In [1]:
import torch
import torch.nn as nn
import torchvision
import numpy as np
import os
import time
import pandas as pd
from torch.utils.data import Dataset, DataLoader

## prepare the datasets

In [2]:
batch_size = 16
lr = 0.001
weight_decay = 0.0001
momentum = 0.9
num_classes = 12

In [3]:
from torchvision.transforms import transforms as T
from PIL import Image

class Cat12Dataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None):
        self.img_labels = pd.read_csv(annotations_file, sep='\t', header=None)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        # image = torchvision.io.read_image(img_path) / 255.0
        image = Image.open(img_path).convert('RGB')
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        return image, label


dataset_path = 'D:\\code\\python\\datasets'
cat12_path = os.path.join(dataset_path, 'cat12_classification')
anno_file_path = os.path.join(cat12_path, 'train_list.txt')
img_path = os.path.join(cat12_path, 'cat_12_train')
transforms = [
    T.Resize([224, 224]),
    T.RandomHorizontalFlip(0.3),
    T.RandomVerticalFlip(0.3),
    T.ToTensor(),
]
transforms = T.Compose(transforms)
cat_12_dataset = Cat12Dataset(anno_file_path, img_path, transform=transforms)
train_loader = DataLoader(cat_12_dataset, batch_size=batch_size, shuffle=True)
next(iter(train_loader))

[tensor([[[[0.1255, 0.1216, 0.1137,  ..., 0.0784, 0.1176, 0.1412],
           [0.1255, 0.1216, 0.1137,  ..., 0.0745, 0.1137, 0.1373],
           [0.1216, 0.1176, 0.1098,  ..., 0.0667, 0.1020, 0.1255],
           ...,
           [0.1804, 0.1804, 0.1804,  ..., 0.2039, 0.1843, 0.1765],
           [0.1647, 0.1686, 0.1725,  ..., 0.2000, 0.1804, 0.1725],
           [0.1608, 0.1647, 0.1686,  ..., 0.2000, 0.1804, 0.1686]],
 
          [[0.1412, 0.1373, 0.1294,  ..., 0.0902, 0.1294, 0.1529],
           [0.1412, 0.1373, 0.1294,  ..., 0.0863, 0.1255, 0.1490],
           [0.1373, 0.1333, 0.1255,  ..., 0.0784, 0.1137, 0.1373],
           ...,
           [0.1216, 0.1216, 0.1216,  ..., 0.2471, 0.2314, 0.2275],
           [0.1059, 0.1098, 0.1137,  ..., 0.2392, 0.2196, 0.2118],
           [0.1020, 0.1059, 0.1098,  ..., 0.2353, 0.2157, 0.2039]],
 
          [[0.0824, 0.0784, 0.0706,  ..., 0.0549, 0.0941, 0.1176],
           [0.0824, 0.0784, 0.0706,  ..., 0.0510, 0.0902, 0.1137],
           [0.0824, 0.07

## prepare the network

In [13]:
class PretrainedResnet50(nn.Module):
    def __init__(self, num_classes):
        super(PretrainedResnet50, self).__init__()
        self.num_classes = num_classes
        self.backbone = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
        self.backbone = torch.nn.Sequential(*(list(self.backbone.children())[:-1]))
        self.flatten = nn.Flatten()
        self.classifer = nn.Linear(2048, self.num_classes, bias=True)
        nn.init.xavier_uniform(self.classifer.weight)

    def forward(self, x):
        x = self.backbone(x)
        x = self.flatten(x)
        x = self.classifer(x)
        return x

simple_net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(16),
    nn.MaxPool2d(kernel_size=3),

    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(32),
    nn.MaxPool2d(kernel_size=3),

    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(64),
    nn.MaxPool2d(kernel_size=3),

    nn.Conv2d(64, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(32),
    nn.MaxPool2d(kernel_size=3),

    nn.Conv2d(32, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(32),

    nn.AdaptiveAvgPool2d(output_size=(1, 1)),
    nn.Flatten(),
    nn.Linear(32, num_classes)
)

net = PretrainedResnet50(num_classes=num_classes)
# net = simple_net
net(torch.rand((1,3,224,224)))

  nn.init.xavier_uniform(self.classifer.weight)


tensor([[-0.0378, -0.9124,  0.2419, -0.4415, -0.3957,  0.7750,  0.0911,  1.4446,
         -0.3541, -0.2201,  0.2108,  1.0582]], grad_fn=<AddmmBackward0>)

## start training

In [14]:
from torch.optim import SGD
from torch.nn import CrossEntropyLoss

optimizer = SGD(net.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
criterion = CrossEntropyLoss(reduction='none')
epoches = 20
weight_path = './cat_12_classification_resnet50.pth'
# weight_path = './cat_12_classification_simple.pth'
if os.path.exists(weight_path):
    net.load_state_dict(torch.load(weight_path))

for epoch in range(1, epoches + 1):
    print(f'start training {epoch}')
    start_time = time.time()
    total_loss = 0
    total_image_num = 0
    for batch in train_loader:
        imgs, targets = batch
        img_num = imgs.shape[0]
        optimizer.zero_grad()
        preds = net(imgs)
        l = criterion(preds, targets).mean(-1).sum()
        l.backward()
        optimizer.step()

        total_loss += l.item()
        total_image_num += img_num
    end_time = time.time()
    print(f'epoch:{epoch}, mean loss: {total_loss / total_image_num}, time cost: {end_time - start_time} seconds, imgnum: {total_image_num}')
    torch.save(net.state_dict(), weight_path)
    print('save weights successfully.')

start training 1
epoch:1, mean loss: 0.08034447794435201, time cost: 2165.7837524414062 seconds, imgnum: 2160
save weights successfully.
start training 2
epoch:2, mean loss: 0.026703220323004106, time cost: 2157.387532711029 seconds, imgnum: 2160
save weights successfully.
start training 3
epoch:3, mean loss: 0.017371856171154866, time cost: 2152.631463766098 seconds, imgnum: 2160
save weights successfully.
start training 4
epoch:4, mean loss: 0.013771994341233815, time cost: 2224.7634887695312 seconds, imgnum: 2160
save weights successfully.
start training 5
epoch:5, mean loss: 0.008497150270785722, time cost: 2185.9317770004272 seconds, imgnum: 2160
save weights successfully.
start training 6
epoch:6, mean loss: 0.008411633007711283, time cost: 2171.4820413589478 seconds, imgnum: 2160
save weights successfully.
start training 7
epoch:7, mean loss: 0.0069087716298074356, time cost: 2180.850682258606 seconds, imgnum: 2160
save weights successfully.
start training 8
epoch:8, mean loss: 

KeyboardInterrupt: 