In [None]:
import torch
import math
import torch.nn as nn
from torch.utils.data import Dataset
import os
import datetime
from torch.utils.data import DataLoader
import argparse
from PIL import Image
import torch.optim as optim
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath
from timm.models.registry import register_model
from einops import rearrange
from einops.layers.torch import Rearrange

# Dataset locations, relative to the directory the notebook is run from.
train_path = os.path.join("input", "plant-seedlings-classification", "train")
test_path = os.path.join("input", "plant-seedlings-classification", "test")
# Manifest files ("<image path> <label index>" per line) produced by gen_txt.
train_txt_path = os.path.join("working", "train.txt")
dev_txt_path = os.path.join("working", "dev.txt")

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train_dev_split(img_dir):
    """Collect labelled image paths under ``img_dir`` and split them 90/10.

    Each immediate sub-directory of ``img_dir`` is one class: its name is the
    label and every entry inside it is treated as an image of that class.

    Args:
        img_dir: Root directory containing one sub-directory per class.

    Returns:
        A ``(train, dev, label)`` tuple: two lists of image paths and the
        sorted list of class names. The label index of an image is the
        position of its class name in ``label``.
    """
    # Only first-level directories define classes (os.walk would also pick up
    # nested directories). Sorting keeps the name -> index mapping stable
    # regardless of the order in which the filesystem enumerates entries.
    label = sorted(
        entry for entry in os.listdir(img_dir)
        if os.path.isdir(os.path.join(img_dir, entry))
    )
    img_list = []
    for class_name in label:
        class_dir = os.path.join(img_dir, class_name)
        img_list.extend(os.path.join(class_dir, file_name)
                        for file_name in sorted(os.listdir(class_dir)))
    # Fixed random_state keeps the split reproducible between runs.
    train, dev = train_test_split(img_list, train_size=0.9, random_state=2)
    return train, dev, label

def gen_txt(txt_path, data, label):
    """Write a ``"<image path> <label index>"`` manifest, one image per line.

    Args:
        txt_path: Destination text file; overwritten if it already exists.
        data: Iterable of image paths whose parent directory name is the
            class name.
        label: List of class names; the position of a class in this list is
            the label index written to the file.

    Raises:
        ValueError: If an image's parent directory name is not in ``label``.
    """
    # The context manager guarantees the manifest is flushed and closed.
    with open(txt_path, 'w') as f:
        for img_path in data:
            # Derive the class from the parent directory instead of a fixed
            # path depth and '/' separator, so any root path / OS works.
            class_name = os.path.basename(os.path.dirname(img_path))
            f.write(img_path + ' ' + str(label.index(class_name)) + '\n')

# Split the training images once and persist both manifests; img_label keeps
# the class-name list used to map label indices back to names.
train, dev, img_label = train_dev_split(train_path)
gen_txt(train_txt_path, train, img_label)
gen_txt(dev_txt_path, dev, img_label)

In [None]:
# Hyper-parameters and output locations, exposed through argparse so they can
# be read as attributes of `opt`.
# NOTE(review): the description and the --data_path default still mention
# Caltech-101; --gpu, --data_path, --save_dir, --checkpoint_frequency and
# --test_num do not appear to be read by the visible cells — confirm.
parser = argparse.ArgumentParser(description='caltech')
parser.add_argument('--gpu', type=str, default='0', help='gpu')
parser.add_argument('--data_path', type=str, default='../../datasets/caltech-101/101_ObjectCategories',
                    help='path to train set')

# Output locations for checkpoints and TensorBoard logs.
parser.add_argument('--save_dir', type=str, default='working/checkpoint', help='save dir')
parser.add_argument('--log_dir', type=str, default='working/log', help='log dir')
parser.add_argument('--save_prefix', type=str, default='working/ResNet18_SGD_e', help='save prefix')

# Optimizer settings.
parser.add_argument('--lr_initial', type=float, default=1e-4, help='initial learning rate')
parser.add_argument('--weight_decay', type=float, default=2e-5, help='weight decay')

# Training schedule.
parser.add_argument('--batch_size', type=int, default=16, help='batch size')
parser.add_argument('--epoch_num', type=int, default=60, help='epoch num')

parser.add_argument('--checkpoint_frequency', type=int, default=3, help='checkpoint frequency')
parser.add_argument('--test_num', type=int, default=2, help='test num')
# args=[] makes argparse ignore the notebook kernel's own command line, so
# every option takes its default value.
opt = parser.parse_args(args=[])

In [None]:
class MyDataset(Dataset):
    """Image dataset backed by a ``"<image path> <label index>"`` manifest.

    Args:
        txt_path: Manifest file with one sample per line; the label is the
            integer after the last space, so paths may contain spaces.
        transform: Optional callable applied to the loaded RGB PIL image.
        target_transform: Optional callable applied to the integer label.
    """

    def __init__(self, txt_path, transform=None, target_transform=None):
        imgs = []
        # The context manager closes the manifest once it has been parsed.
        with open(txt_path, 'r') as fh:
            for line in fh:
                line = line.rstrip()
                if not line:
                    continue  # tolerate blank lines (e.g. a trailing newline)
                idx = line.rfind(' ')
                imgs.append((line[:idx], int(line[idx:])))
        # Assigned after the loop so an empty manifest still produces a
        # usable (empty) dataset instead of one without attributes.
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index`` with transforms applied."""
        fn, label = self.imgs[index]
        img = Image.open(fn).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        """Number of samples listed in the manifest."""
        return len(self.imgs)

In [None]:
def vgg_block(nums_conv2d, in_channels, out_channels):
    """Build one VGG stage.

    The stage is ``nums_conv2d`` repetitions of 3x3 conv (padding 1) ->
    BatchNorm -> ReLU, followed by a 2x2 max-pool with stride 2.

    Args:
        nums_conv2d: Number of conv/BN/ReLU triples in the stage.
        in_channels: Channels of the stage input (used by the first conv).
        out_channels: Channels produced by every conv in the stage.

    Returns:
        An ``nn.Sequential`` holding the stage layers in order.
    """
    layers = []
    channels_in = in_channels
    for _ in range(nums_conv2d):
        layers += [
            nn.Conv2d(in_channels=channels_in, out_channels=out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels, affine=True),
            nn.ReLU(),
        ]
        # Every conv after the first consumes the stage's own output width.
        channels_in = out_channels
    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*layers)

class VGG(nn.Module):
    """VGG-style classifier with BatchNorm and a 12-way output layer.

    Five ``vgg_block`` stages (each ending in a 2x max-pool) feed three
    fully connected layers. The first linear layer expects ``512 * 7 * 7``
    features, i.e. a 7x7 map after the five poolings (e.g. 224x224 inputs).
    """

    def __init__(self):
        super().__init__()
        # (number of convs, input channels, output channels) per stage.
        self.conv_arch = ((1, 3, 64), (1, 64, 128), (2, 128, 256), (2, 256, 512), (2, 512, 512))
        self.features = nn.Sequential(
            *[vgg_block(conv_count, c_in, c_out) for conv_count, c_in, c_out in self.conv_arch]
        )
        self.F = nn.Flatten()
        self.Linear = nn.Sequential(nn.Linear(512 * 7 * 7, 4096), nn.ReLU(), nn.Dropout(0.5))
        self.Linear2 = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Dropout(0.5))
        self.Linear3 = nn.Linear(4096, 12)

    def forward(self, x):
        """Return the 12 class logits for a batch of images ``x``."""
        flat = self.F(self.features(x))
        hidden = self.Linear2(self.Linear(flat))
        return self.Linear3(hidden)

In [None]:
class S_SE(nn.Module):
    """Squeeze-and-excitation channel gate.

    Global-average-pools the input to one value per channel, passes it
    through Linear -> ReLU -> Linear -> Sigmoid, and rescales each input
    channel by the resulting weight.

    Args:
        channel: Number of channels of the input feature map.
        reduction: Divisor for the bottleneck width (``channel // reduction``).
    """

    def __init__(self, channel, reduction=4):
        super().__init__()
        squeezed = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(channel, squeezed)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(squeezed, channel)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Gate ``x`` of shape (B, C, H, W) channel-wise; the shape is unchanged."""
        pooled = self.avg_pool(x).flatten(1)          # (B, C)
        hidden = self.relu(self.fc1(pooled))          # (B, C // reduction)
        gate = self.sigmoid(self.fc2(hidden))         # (B, C), values in (0, 1)
        # Add two singleton axes so the gate broadcasts over H and W.
        return gate[:, :, None, None] * x

In [None]:
class ResBlock(nn.Module):
    """Residual block: two 3x3 conv/BN layers plus a shortcut connection.

    Args:
        in_channel: Channels of the block input.
        out_channel: Channels of the block output.
        stride: Stride of the first 3x3 conv (and of the projection shortcut).
        se: When true, the main branch output is gated by an ``S_SE`` layer.
        reduction: Bottleneck reduction passed to ``S_SE``.
    """

    def __init__(self, in_channel, out_channel, stride=1, se=False, reduction=4):
        super(ResBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channel)
        )
        if se:
            self.se_layer = S_SE(channel=out_channel, reduction=reduction)
        else:
            self.se_layer = nn.Identity()
        # The shortcut has to match the main branch in both channel count and
        # spatial size, so it projects whenever either one changes and uses
        # the same stride as the main branch (previously hard-coded to 2).
        if stride != 1 or in_channel != out_channel:
            self.short_cut = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channel)
            )
        else:
            # Empty Sequential acts as the identity mapping.
            self.short_cut = nn.Sequential()

    def forward(self, x):
        """Return ``se_layer(conv(x)) + short_cut(x)``."""
        return self.se_layer(self.conv(x)) + self.short_cut(x)


class ResNet18(nn.Module):
    """ResNet-18 style classifier assembled from ``ResBlock`` stages.

    Args:
        class_num: Number of output classes.
        se: Forwarded to every ``ResBlock`` to enable its SE gate.
    """

    def __init__(self, class_num, se=False):
        super().__init__()
        self.residual_layer_nums = [2, 2, 2, 2]
        # Stage i maps out_channels[i] -> out_channels[i + 1].
        self.out_channels = [64, 64, 128, 256, 512]
        self.se = se
        self.in_proj = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        # The max-pool comes first, followed by the four residual stages.
        stages = [nn.MaxPool2d(kernel_size=3, stride=2, padding=1)]
        for stage_idx, block_count in enumerate(self.residual_layer_nums):
            stages.append(self.make_layer(
                ResBlock,
                block_count,
                self.out_channels[stage_idx],
                self.out_channels[stage_idx + 1],
            ))
        self.residual_layer = nn.Sequential(*stages)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, class_num)

    def make_layer(self, block, layer_num, in_channel, out_channel):
        """Stack ``layer_num`` blocks; only the first may change channels/stride.

        The first block uses stride 2 when the stage changes the channel
        count; every other block keeps stride 1.
        """
        blocks = []
        for position in range(layer_num):
            downsample = position == 0 and in_channel != out_channel
            blocks.append(block(in_channel, out_channel, stride=2 if downsample else 1,
                                se=self.se))
            in_channel = out_channel
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return class logits; shape comments assume 224x224 inputs."""
        stem = self.in_proj(x)                          # B, 64, 112, 112
        features = self.residual_layer(stem)            # B, 512, 7, 7
        pooled = self.avg_pool(features).flatten(1)     # B, 512
        return self.fc(pooled)                          # B, class_num


In [None]:
class Block(nn.Module):
    r""" ConvNeXt residual block.

    Two equivalent formulations exist:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv, all in (N, C, H, W)
    (2) DwConv -> permute to (N, H, W, C) -> LayerNorm (channels_last) -> Linear -> GELU -> Linear -> permute back
    This class implements (2).

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """

    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        # The two pointwise (1x1) convolutions are expressed as linear layers.
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        # Layer scale is a learnable per-channel factor, disabled when the init value is <= 0.
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                      requires_grad=True)
        else:
            self.gamma = None
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Apply the block to ``x`` of shape (N, C, H, W); shape is preserved."""
        residual = x
        y = self.dwconv(x).permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        y = self.pwconv2(self.act(self.pwconv1(self.norm(y))))
        if self.gamma is not None:
            y = self.gamma * y
        y = y.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return residual + self.drop_path(y)

In [None]:
class ConvNeXt(nn.Module):
    r""" ConvNeXt backbone with a linear classification head.

    PyTorch implementation of `A ConvNet for the 2020s`
    (https://arxiv.org/pdf/2201.03545.pdf).

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for the classification head. Default: 12
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """

    def __init__(self, in_chans=3, num_classes=12,
                 depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
                 layer_scale_init_value=1e-6, head_init_scale=1.,
                 ):
        super().__init__()

        # Stem (4x4 conv, stride 4) followed by three 2x downsampling layers.
        reducers = [nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )]
        for idx in range(3):
            reducers.append(nn.Sequential(
                LayerNorm(dims[idx], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[idx], dims[idx + 1], kernel_size=2, stride=2),
            ))
        self.downsample_layers = nn.ModuleList(reducers)

        # Four feature-resolution stages; the drop-path rate grows linearly
        # from 0 to drop_path_rate across all blocks of all stages.
        dp_rates = torch.linspace(0, drop_path_rate, sum(depths)).tolist()
        stage_modules = []
        offset = 0
        for idx in range(4):
            blocks = [
                Block(dim=dims[idx], drop_path=dp_rates[offset + j],
                      layer_scale_init_value=layer_scale_init_value)
                for j in range(depths[idx])
            ]
            stage_modules.append(nn.Sequential(*blocks))
            offset += depths[idx]
        self.stages = nn.ModuleList(stage_modules)

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
        self.head = nn.Linear(dims[-1], num_classes)

        self.apply(self._init_weights)
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):
        """Truncated-normal weights and zero bias for every conv / linear layer."""
        if not isinstance(m, (nn.Conv2d, nn.Linear)):
            return
        trunc_normal_(m.weight, std=.02)
        nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        """Run the four stages and return the normalised pooled feature (N, C)."""
        for reduce, stage in zip(self.downsample_layers, self.stages):
            x = stage(reduce(x))
        pooled = x.mean([-2, -1])  # global average pooling, (N, C, H, W) -> (N, C)
        return self.norm(pooled)

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.head(self.forward_features(x))

In [None]:
class LayerNorm(nn.Module):
    r""" LayerNorm over the channel axis for two memory layouts.

    ``channels_last`` (default) expects (batch_size, height, width, channels)
    and delegates to ``F.layer_norm``; ``channels_first`` expects
    (batch_size, channels, height, width) and normalises over dim 1 by hand.

    Raises:
        NotImplementedError: If ``data_format`` is neither of the two layouts.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        """Normalise ``x`` over its channel axis and apply the affine parameters."""
        if self.data_format == "channels_first":
            mean = x.mean(1, keepdim=True)
            variance = (x - mean).pow(2).mean(1, keepdim=True)  # biased variance over channels
            normed = (x - mean) / torch.sqrt(variance + self.eps)
            # Broadcast the per-channel scale/shift over H and W.
            return self.weight[:, None, None] * normed + self.bias[:, None, None]
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)

In [None]:
class S_MBConv(nn.Module):
    """Simplified MBConv-style residual block with constant channel count.

    Main branch: 1x1 conv -> BatchNorm -> ReLU -> 3x3 depthwise conv ->
    squeeze-and-excitation gate -> 1x1 conv -> BatchNorm. The branch output
    is added to the block input.

    Args:
        channel: Number of input (and output) channels.
        reduction: Bottleneck reduction factor of the SE gate.
    """

    def __init__(self, channel, reduction):
        super(S_MBConv, self).__init__()
        self.bone = nn.Sequential(
            nn.Conv2d(channel, channel, 1),
            nn.BatchNorm2d(channel),
            nn.ReLU(),
            nn.Conv2d(channel, channel, kernel_size=3, padding=1, groups=channel),
            # S_SE takes (channel, reduction). The previous SE(channel, reduction)
            # call treated `reduction` as SE's `oup`, giving a gate whose linear
            # layers could not accept the (B, channel) pooled vector.
            S_SE(channel, reduction),
            nn.Conv2d(channel, channel, 1),
            nn.BatchNorm2d(channel)
        )

    def forward(self, x):
        """Return ``x + bone(x)``; the output shape equals the input shape."""
        return x + self.bone(x)


class My_Improved(nn.Module):
    """ConvNeXt-shaped network built from ``S_MBConv`` blocks and BatchNorm.

    Args:
        in_channel: Number of input image channels.
        depths: Number of ``S_MBConv`` blocks in each of the four stages.
        out_channel: Channel width of each of the four stages.
        reduction: SE reduction factor forwarded to every ``S_MBConv``.
        num_classes: Number of output classes.
    """

    def __init__(self, in_channel=3, depths=[3, 3, 9, 3], out_channel=[96, 192, 384, 768], reduction=4, num_classes=12):
        super().__init__()
        # Stem (4x4 conv, stride 4) followed by three 2x downsampling layers.
        reducers = [nn.Sequential(
            nn.Conv2d(in_channel, out_channel[0], kernel_size=4, stride=4),
            nn.BatchNorm2d(out_channel[0])
        )]
        for idx in range(3):
            reducers.append(nn.Sequential(
                nn.BatchNorm2d(out_channel[idx]),
                nn.Conv2d(out_channel[idx], out_channel[idx + 1], kernel_size=2, stride=2)
            ))
        self.downsample_layers = nn.ModuleList(reducers)

        self.stages = nn.ModuleList([
            nn.Sequential(*[S_MBConv(out_channel[idx], reduction) for _ in range(depths[idx])])
            for idx in range(4)
        ])

        self.head = nn.Linear(out_channel[-1], num_classes)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Kaiming-normal weights (a=sqrt(5)) and zero bias for conv / linear layers."""
        if not isinstance(m, (nn.Conv2d, nn.Linear)):
            return
        nn.init.kaiming_normal_(m.weight, a=math.sqrt(5))
        nn.init.constant_(m.bias, 0)

    def extract_features(self, x):
        """Run the four stages and global-average-pool to (N, C)."""
        for reduce, stage in zip(self.downsample_layers, self.stages):
            x = stage(reduce(x))
        return x.mean([-2, -1])

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.head(self.extract_features(x))

In [None]:
def conv_3x3_bn(inp, oup, image_size, downsample=False):
    """Build a bias-free 3x3 conv -> BatchNorm -> GELU stem layer.

    Args:
        inp: Input channels.
        oup: Output channels.
        image_size: Accepted for signature parity with the other block
            constructors; not used here.
        downsample: When not equal to ``False`` the conv uses stride 2,
            otherwise stride 1.
    """
    # Equality (not truthiness) is kept deliberately so every argument value
    # maps to the same stride as before.
    if downsample == False:  # noqa: E712
        stride = 1
    else:
        stride = 2
    layers = [
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.GELU(),
    ]
    return nn.Sequential(*layers)


class PreNorm(nn.Module):
    """Apply a normalisation layer before calling the wrapped function.

    Args:
        dim: Size passed to the ``norm`` constructor.
        fn: Callable (typically a module) invoked on the normalised input.
        norm: Normalisation class / factory called as ``norm(dim)``.
    """

    def __init__(self, dim, fn, norm):
        super().__init__()
        self.norm = norm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(norm(x), **kwargs)``."""
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)


class SE(nn.Module):
    """Squeeze-and-excitation gate with GELU bottleneck and no biases.

    Args:
        inp: Channel count used to size the bottleneck (``int(inp * expansion)``).
        oup: Channel count of the tensor being gated.
        expansion: Bottleneck ratio relative to ``inp``.
    """

    def __init__(self, inp, oup, expansion=0.25):
        super().__init__()
        bottleneck = int(inp * expansion)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(oup, bottleneck, bias=False),
            nn.GELU(),
            nn.Linear(bottleneck, oup, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Scale each channel of the 4-D tensor ``x`` by its learned gate."""
        batch, chans, _height, _width = x.size()
        squeezed = self.avg_pool(x).view(batch, chans)
        gate = self.fc(squeezed).view(batch, chans, 1, 1)
        return x * gate


class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Input and output feature size.
        hidden_dim: Width of the hidden layer.
        dropout: Dropout probability used after each linear layer.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP over the last dimension of ``x``."""
        return self.net(x)


class MBConv(nn.Module):
    """Pre-normalised MBConv block with an optional downsampling shortcut.

    Args:
        inp: Input channels.
        oup: Output channels.
        image_size: Accepted for signature parity with the other blocks;
            not used here.
        downsample: When set, the main branch uses stride 2 and the shortcut
            is max-pool followed by a 1x1 projection.
        expansion: Hidden width multiplier; ``1`` selects the variant
            without the expansion conv and SE gate.
    """

    def __init__(self, inp, oup, image_size, downsample=False, expansion=4):
        super().__init__()
        self.downsample = downsample
        # Equality (not truthiness) is kept so every argument value maps to
        # the same stride as before.
        if self.downsample == False:  # noqa: E712
            stride = 1
        else:
            stride = 2
        hidden_dim = int(inp * expansion)

        if self.downsample:
            self.pool = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)

        if expansion == 1:
            layers = [
                # depthwise conv
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
                          1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                # pointwise linear projection
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            ]
        else:
            layers = [
                # pointwise expansion; downsampling happens in this first conv
                nn.Conv2d(inp, hidden_dim, 1, stride, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                # depthwise conv
                nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1,
                          groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                SE(inp, hidden_dim),
                # pointwise linear projection
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            ]

        # BatchNorm is applied to the input before the conv stack.
        self.conv = PreNorm(inp, nn.Sequential(*layers), nn.BatchNorm2d)

    def forward(self, x):
        """Return shortcut(x) + conv(x)."""
        shortcut = self.proj(self.pool(x)) if self.downsample else x
        return shortcut + self.conv(x)

In [None]:
class Attention(nn.Module):
    """Multi-head self-attention with a learned relative position bias.

    Args:
        inp: Feature size of each input token.
        oup: Feature size of each output token (when the output projection is used).
        image_size: ``(ih, iw)`` grid size; the bias is built for ``ih * iw`` tokens.
        heads: Number of attention heads.
        dim_head: Feature size per head.
        dropout: Dropout probability after the output projection.
    """

    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        # The output projection is skipped only for a single head whose size equals inp.
        project_out = not (heads == 1 and dim_head == inp)

        self.ih, self.iw = image_size

        self.heads = heads
        self.scale = dim_head ** -0.5

        # parameter table of relative position bias
        # One row per possible (dy, dx) offset, one column per head.
        self.relative_bias_table = nn.Parameter(
            torch.zeros((2 * self.ih - 1) * (2 * self.iw - 1), heads))

        # Row/column coordinates of every grid position, flattened to (2, ih*iw).
        coords = torch.meshgrid((torch.arange(self.ih), torch.arange(self.iw)))
        coords = torch.flatten(torch.stack(coords), 1)
        # Pairwise coordinate differences, shape (2, ih*iw, ih*iw).
        relative_coords = coords[:, :, None] - coords[:, None, :]

        # Shift offsets to be non-negative, then fold (dy, dx) into a single
        # row index of the bias table: dy * (2*iw - 1) + dx.
        relative_coords[0] += self.ih - 1
        relative_coords[1] += self.iw - 1
        relative_coords[0] *= 2 * self.iw - 1
        relative_coords = rearrange(relative_coords, 'c h w -> h w c')
        relative_index = relative_coords.sum(-1).flatten().unsqueeze(1)
        # Registered as a buffer: follows the module across devices but is not trained.
        self.register_buffer("relative_index", relative_index)

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(inp, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, oup),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        """Attend over tokens ``x`` of shape (b, n, inp); n is expected to be ih * iw."""
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        # Split the feature axis into heads: (b, n, h*d) -> (b, h, n, d).
        q, k, v = map(lambda t: rearrange(
            t, 'b n (h d) -> b h n d', h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # Use "gather" for more efficiency on GPUs
        relative_bias = self.relative_bias_table.gather(
            0, self.relative_index.repeat(1, self.heads))
        # (n*n, heads) -> (1, heads, n, n) so it broadcasts over the batch.
        relative_bias = rearrange(
            relative_bias, '(h w) c -> 1 c h w', h=self.ih*self.iw, w=self.ih*self.iw)
        dots = dots + relative_bias

        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        # Merge the heads back: (b, h, n, d) -> (b, n, h*d).
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out

In [None]:
class Transformer(nn.Module):
    """Pre-norm transformer block operating on (B, C, H, W) feature maps.

    Args:
        inp: Input channels.
        oup: Output channels.
        image_size: ``(ih, iw)`` size of the feature map the block outputs.
        heads: Number of attention heads.
        dim_head: Feature size per head.
        downsample: When set, the input is max-pooled (stride 2) on both the
            attention branch and the 1x1-projected shortcut.
        dropout: Dropout probability inside attention and feed-forward.
    """

    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, downsample=False, dropout=0.):
        super().__init__()
        hidden_dim = int(inp * 4)

        self.ih, self.iw = image_size
        self.downsample = downsample

        if self.downsample:
            self.pool1 = nn.MaxPool2d(3, 2, 1)
            self.pool2 = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)

        attention = Attention(inp, oup, image_size, heads, dim_head, dropout)
        feed_forward = FeedForward(oup, hidden_dim, dropout)

        # Both sub-layers flatten the map to tokens, apply LayerNorm plus the
        # wrapped module, and restore the (B, C, ih, iw) layout.
        self.attn = nn.Sequential(
            Rearrange('b c ih iw -> b (ih iw) c'),
            PreNorm(inp, attention, nn.LayerNorm),
            Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)
        )
        self.ff = nn.Sequential(
            Rearrange('b c ih iw -> b (ih iw) c'),
            PreNorm(oup, feed_forward, nn.LayerNorm),
            Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)
        )

    def forward(self, x):
        """Apply the attention and feed-forward residual sub-layers."""
        if self.downsample:
            attended = self.proj(self.pool1(x)) + self.attn(self.pool2(x))
        else:
            attended = x + self.attn(x)
        return attended + self.ff(attended)


class CoAtNet(nn.Module):
    """CoAtNet: a conv stem followed by four MBConv / Transformer stages.

    Args:
        image_size: ``(ih, iw)`` size of the input images.
        in_channels: Number of input image channels.
        num_blocks: Depth of the stem and of each of the four stages (5 values).
        channels: Output channels of the stem and of each stage (5 values).
        num_classes: Number of output classes.
        block_types: Per-stage block kind, ``'C'`` for MBConv or ``'T'`` for Transformer.
    """

    def __init__(self, image_size, in_channels, num_blocks, channels, num_classes=1000, block_types=['C', 'C', 'T', 'T']):
        super().__init__()
        ih, iw = image_size
        block = {'C': MBConv, 'T': Transformer}

        # Stage k (k = 0..4) halves the resolution again, so it works on a
        # feature map of size (ih // 2**(k+1), iw // 2**(k+1)).
        builders = [conv_3x3_bn] + [block[block_types[k]] for k in range(4)]
        for stage, builder in enumerate(builders):
            scale = 2 ** (stage + 1)
            stage_in = in_channels if stage == 0 else channels[stage - 1]
            setattr(self, 's%d' % stage, self._make_layer(
                builder, stage_in, channels[stage], num_blocks[stage], (ih // scale, iw // scale)))

        self.pool = nn.AvgPool2d(ih // 32, 1)
        self.fc = nn.Linear(channels[-1], num_classes, bias=False)

    def forward(self, x):
        """Return class logits for a batch of images."""
        for stage in (self.s0, self.s1, self.s2, self.s3, self.s4):
            x = stage(x)

        pooled = self.pool(x)
        return self.fc(pooled.view(-1, x.shape[1]))

    def _make_layer(self, block, inp, oup, depth, image_size):
        """Stack ``depth`` blocks; the first downsamples and changes the channel count."""
        layers = [
            block(inp, oup, image_size, downsample=True) if i == 0 else block(oup, oup, image_size)
            for i in range(depth)
        ]
        return nn.Sequential(*layers)

In [None]:
# ---- Logging -------------------------------------------------------------
log_dir = os.path.join(opt.log_dir)
# Timestamp without fractional seconds, with ':' replaced so it is file-name safe.
datetime_now = datetime.datetime.now().isoformat()[:-7].replace(':', '-')
log_txt_name = os.path.join(log_dir, datetime_now + '.txt')
print('Now time is : ', datetime_now)

writer = SummaryWriter(log_dir)
# NOTE(review): num_blocks / channels match CoAtNet's constructor arguments
# but are not used by the ResNet18 model selected below — confirm.
num_blocks = [2, 2, 3, 5, 2] 
channels = [64, 96, 192, 384, 768]
model = ResNet18(class_num=12)
# from Lab3.model.resnet18 import ResNet18
# model = ResNet18(class_num=101, se=True)
# from torchvision.models import resnet18
# model = resnet18(num_classes=101)

if torch.cuda.device_count() > 1:
    print('Use', torch.cuda.device_count(), 'GPUs!')
# The model is always wrapped in DataParallel, so saved state-dict keys carry
# the 'module.' prefix.
model = torch.nn.DataParallel(model).to(device)

# optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=opt.lr_initial,
#                        betas=(0.9, 0.999), eps=1e-8, weight_decay=opt.weight_decay)

# ---- Optimisation ----------------------------------------------------------
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=opt.lr_initial,
                       weight_decay=opt.weight_decay)

# Multiplies the learning rate by 0.1 at scheduler steps 5, 10 and 15.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 10, 15], gamma=0.1)

start_epoch = 1
criterion = nn.CrossEntropyLoss().to(device)

# ---- Data ------------------------------------------------------------------
# Training adds a random horizontal flip; both pipelines resize to 224x224
# and map each channel to [-1, 1].
train_transform = transforms.Compose([
    # transforms.CenterCrop([224, 224]),
    transforms.Resize([224, 224]),
    transforms.RandomHorizontalFlip(0.5),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
test_transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

train_dataset = MyDataset(txt_path=train_txt_path, transform=train_transform)
dev_dataset = MyDataset(txt_path=dev_txt_path, transform=test_transform)
train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=opt.batch_size)
dev_dataloader = DataLoader(dataset=dev_dataset, shuffle=False, batch_size=16)

# Number of batches per training epoch.
total_step = len(train_dataloader)
print('Total step: {}'.format(total_step))

epoch_num = opt.epoch_num
print('===> Start Epoch {} End Epoch {}'.format(start_epoch, epoch_num))

step = 0
# Best dev accuracy (in percent) seen so far.
dev_acc = 0

# Train for epoch_num epochs; after each epoch evaluate on the dev split and
# save a checkpoint whenever the dev accuracy improves.
for epoch in range(start_epoch, epoch_num + 1):
    model.train()
    epoch_loss = 0.  # running sum of this epoch's batch losses
    for i, (image, label) in enumerate(tqdm(train_dataloader, desc='Train'), 1):
        optimizer.zero_grad()
        # Follow `device` instead of calling .cuda() so the loop also runs on
        # CPU-only machines.
        image = image.to(device)
        label = label.to(device).long()
        pred = model(image)
        loss = criterion(pred, label)
        # Mean over the actual batch: the last batch can be smaller than
        # opt.batch_size, which made the old ratio under-report accuracy.
        acc = (torch.argmax(pred, dim=1) == label).float().mean().item()
        epoch_loss += loss.item()
        writer.add_scalar('Training_loss', epoch_loss, epoch*total_step+i)
        writer.add_scalar('Acc in the train dataset', acc, epoch*total_step+i)
        loss.backward()
        optimizer.step()

    # Advance the LR schedule once per epoch; without this call the
    # MultiStepLR milestones never take effect and the LR stays constant.
    scheduler.step()

    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in dev_dataloader:
            images, labels = images.to(device), labels.to(device)
            output = model(images)
            _, predicted = torch.max(output, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        epoch_dev_acc = 100 * correct / total
        print('Accuracy of the network on the dev images: {} %'.format(epoch_dev_acc))
        writer.add_scalar('Acc in the dev dataset', epoch_dev_acc, epoch)

        if epoch_dev_acc > dev_acc:  # keep the parameters of the most accurate model so far
            torch.save(model.state_dict(), opt.save_prefix + str(epoch) + '.pth')
            dev_acc = epoch_dev_acc
    # Make sure this epoch's scalars reach disk even if the run is interrupted.
    writer.flush()

In [None]:
import pandas as pd
class TestDataset(Dataset):
    """Unlabelled image dataset over every entry of a directory.

    Args:
        test_dir: Directory whose entries are all treated as images.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, test_dir, transform=None):
        # Order follows os.listdir; it is fixed for the lifetime of the dataset.
        self.imgs = [os.path.join(test_dir, name) for name in os.listdir(test_dir)]
        self.transform = transform

    def __getitem__(self, index):
        """Return the (optionally transformed) image at ``index``."""
        img = Image.open(self.imgs[index]).convert('RGB')
        if self.transform is None:
            return img
        return self.transform(img)

    def __len__(self):
        """Number of files found in the directory."""
        return len(self.imgs)
    
def predict_class(pth=None, csv_path="working/ResNet18_SGD.csv"):
    """Classify every test image and write a ``file,species`` CSV.

    Args:
        pth: Checkpoint to load (a state dict saved from the DataParallel
            model). Defaults to ``working/ResNet18_SGD_e57.pth``.
        csv_path: Destination of the prediction table.
    """
    if pth is None:
        pth = os.path.join('working', 'ResNet18_SGD_e57.pth')
    test_dataset = TestDataset(test_dir=test_path, transform=test_transform)
    test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=16, shuffle=False)
    model = ResNet18(class_num=12)
    # Wrapped in DataParallel so the 'module.'-prefixed checkpoint keys match.
    model = torch.nn.DataParallel(model).to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(pth, map_location=device))
    model.eval()
    predict_list = []
    with torch.no_grad():
        for img in test_dataloader:
            predict = model(img.to(device))
            predict_list.extend(torch.argmax(predict, dim=1).cpu().tolist())

    # File names come from the dataset itself, so each prediction is paired
    # with exactly the image it was computed from (and the directory is not
    # re-listed once per row).
    file_predict_table = [
        [os.path.basename(img_path), img_label[class_idx]]
        for img_path, class_idx in zip(test_dataset.imgs, predict_list)
    ]
    df = pd.DataFrame(file_predict_table, columns=['file', 'species'])
    df.to_csv(csv_path, index=False)

# Run inference on the test set and write the prediction CSV.
predict_class()