## ResNet on Tiny ImageNet

In [resnet.ipynb](./resnet.ipynb) we tested ResNet on the Imagenette classification task. Imagenette is a 10-class subset of ImageNet. Here we compare different ResNet architectures on Tiny ImageNet.

Tiny ImageNet covers 200 classes of 64×64 color images. Each class has 500 training images, 50 validation images, and 50 test images, so the training set contains 100,000 images in total.

Fifteen randomly selected classes:

![Sample images from 15 Tiny ImageNet classes](resources/tinyimagenet.png)

In [1]:
# Automatically reload external modules so that code changes take effect without re-importing
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

from hdd.device.utils import get_device

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Path to the training data
DATA_ROOT = "~/workspace/hands-dirty-on-dl/dataset"
# Path for TensorBoard logs
TENSORBOARD_ROOT = "~/workspace/hands-dirty-on-dl/dataset"
# Path for pretrained model weights
TORCH_HUB_PATH = "~/workspace/hands-dirty-on-dl/pretrained_models"
torch.hub.set_dir(TORCH_HUB_PATH)
# Pick the best available training device
DEVICE = get_device(["cuda", "cpu"])
print("Use device: ", DEVICE)

Use device:  cuda


In [2]:
from hdd.dataset.tiny_imagenet import TinyImagenet
import torchvision.transforms.v2

# Precomputed per-channel mean and standard deviation of the training set
TRAIN_MEAN = [0.4802, 0.4481, 0.3975]
TRAIN_STD = [0.2302, 0.2265, 0.2262]


train_dataset_transforms = transforms.Compose(
    [
        torchvision.transforms.v2.RandomResize(66, 72),
        transforms.RandomRotation(3),
        transforms.RandomCrop(64),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=TRAIN_MEAN, std=TRAIN_STD),
    ]
)
train_dataset = TinyImagenet(
    root=DATA_ROOT, split="train", download=True, transform=train_dataset_transforms
)
val_dataset = TinyImagenet(
    root=DATA_ROOT,
    split="val",
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(TRAIN_MEAN, TRAIN_STD)]
    ),
)
print("Basic Info of train dataset: \n", train_dataset)
print("Basic Info of test dataset: \n", val_dataset)

Basic Info of train dataset: 
 Dataset TinyImagenet
    Number of datapoints: 100000
    Root location: /home/tf/workspace/hands-dirty-on-dl/dataset
    StandardTransform
Transform: Compose(
               RandomResize(min_size=66, max_size=72, interpolation=InterpolationMode.BILINEAR, antialias=True)
               RandomRotation(degrees=[-3.0, 3.0], interpolation=nearest, expand=False, fill=0)
               RandomCrop(size=(64, 64), padding=None)
               RandomHorizontalFlip(p=0.5)
               ToTensor()
               Normalize(mean=[0.4802, 0.4481, 0.3975], std=[0.2302, 0.2265, 0.2262])
           )
Basic Info of test dataset: 
 Dataset TinyImagenet
    Number of datapoints: 10000
    Root location: /home/tf/workspace/hands-dirty-on-dl/dataset
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=[0.4802, 0.4481, 0.3975], std=[0.2302, 0.2265, 0.2262])
           )
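
The cell above uses precomputed `TRAIN_MEAN` and `TRAIN_STD` without showing where they come from. The following is a minimal sketch of how per-channel statistics like these could be computed. It assumes a dataset that yields `(C, H, W)` float tensors in `[0, 1]` (for example, after `ToTensor()` only, with no augmentation or normalization). The helper name `channel_mean_std` is hypothetical and is not part of `hdd`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def channel_mean_std(dataset, batch_size=256):
    """Return per-channel (mean, std) over all pixels of all images.

    Assumes each item is (image, label) with image shaped (C, H, W).
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    pixel_sum = 0.0     # running sum of pixel values per channel
    pixel_sq_sum = 0.0  # running sum of squared pixel values per channel
    n_pixels = 0        # number of pixels seen per channel
    for images, _ in loader:
        images = images.double()  # accumulate in float64 to limit rounding error
        pixel_sum = pixel_sum + images.sum(dim=(0, 2, 3))
        pixel_sq_sum = pixel_sq_sum + (images ** 2).sum(dim=(0, 2, 3))
        n_pixels += images.shape[0] * images.shape[2] * images.shape[3]
    mean = pixel_sum / n_pixels
    # Population std via E[x^2] - E[x]^2
    std = (pixel_sq_sum / n_pixels - mean ** 2).sqrt()
    return mean, std


# Usage on random stand-in data (not Tiny ImageNet)
fake_images = torch.rand(32, 3, 64, 64)
fake_labels = torch.zeros(32, dtype=torch.long)
mean, std = channel_mean_std(TensorDataset(fake_images, fake_labels), batch_size=8)
print(mean, std)
```

For the real dataset, the statistics would be computed once over the training split and then reused for both the training and validation transforms, as the notebook does.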


In [3]:
BATCH_SIZE = 64
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)
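
As a quick sanity check on the loaders above, the number of batches per epoch follows directly from the dataset sizes and `BATCH_SIZE`. This sketch relies on `DataLoader` keeping the last partial batch by default (`drop_last=False`). The helper name `batches_per_epoch` is illustrative only.

```python
import math


def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields for one pass over the data."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


print(batches_per_epoch(100_000, 64))  # training set: 1563
print(batches_per_epoch(10_000, 64))   # validation set: 157
```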

## Comparing different ResNet architectures

In [4]:
from hdd.models.cnn.resnet import ResnetSmall, resnet18_config
from hdd.train.early_stopping import EarlyStoppingInMem
from hdd.train.classification_utils import (
    naive_train_classification_model,
    eval_image_classifier,
)
from hdd.models.nn_utils import count_trainable_parameter


def train_net(
    resnet_config,
    train_dataloader,
    val_dataloader,
    dropout,
    lr,
    weight_decay,
    step_size=30,
    gamma=0.1,
    patience=40,
    max_epochs=120,
) -> tuple[ResnetSmall, dict[str, list[float]]]:
    net = ResnetSmall(resnet_config, num_classes=200, dropout=dropout).to(DEVICE)
    criteria = nn.CrossEntropyLoss()
    # SGD converges far more slowly than Adam here
    # optimizer = torch.optim.SGD(
    #     net.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay
    # )
    optimizer = optim.AdamW(
        net.parameters(), lr=lr, eps=1e-6, weight_decay=weight_decay
    )
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=step_size, gamma=gamma, last_epoch=-1
    )
    early_stopper = EarlyStoppingInMem(patience=patience, verbose=False)
    training_stats = naive_train_classification_model(
        net,
        criteria,
        max_epochs,
        train_dataloader,
        val_dataloader,
        DEVICE,
        optimizer,
        scheduler,
        early_stopper,
        verbose=True,
    )
    return net, training_stats


net, resnet18_stats = train_net(
    resnet18_config,
    train_dataloader,
    val_dataloader,
    dropout=0.5,
    lr=0.005,
    weight_decay=1e-2,
)

eval_result = eval_image_classifier(net, val_dataloader.dataset, DEVICE)
ss = [result.gt_label == result.predicted_label for result in eval_result]
print(f"#Parameter: {count_trainable_parameter(net)} Accuracy: {sum(ss) / len(ss)}")

Epoch: 1/120 Train Loss: 4.7006 Accuracy: 0.0528 Time: 63.13437  | Val Loss: 4.3960 Accuracy: 0.0854
Epoch: 2/120 Train Loss: 3.9700 Accuracy: 0.1361 Time: 62.87105  | Val Loss: 3.6732 Accuracy: 0.1883
Epoch: 3/120 Train Loss: 3.5430 Accuracy: 0.2017 Time: 62.97417  | Val Loss: 3.5985 Accuracy: 0.1993
Epoch: 4/120 Train Loss: 3.2475 Accuracy: 0.2553 Time: 62.57284  | Val Loss: 3.2399 Accuracy: 0.2616
Epoch: 5/120 Train Loss: 3.0202 Accuracy: 0.2968 Time: 62.38601  | Val Loss: 3.4862 Accuracy: 0.2313
Epoch: 6/120 Train Loss: 2.8270 Accuracy: 0.3346 Time: 60.49896  | Val Loss: 3.1019 Accuracy: 0.2943
Epoch: 7/120 Train Loss: 2.6774 Accuracy: 0.3674 Time: 60.54692  | Val Loss: 3.2807 Accuracy: 0.2776
Epoch: 8/120 Train Loss: 2.5473 Accuracy: 0.3917 Time: 60.56805  | Val Loss: 2.7279 Accuracy: 0.3597
Epoch: 9/120 Train Loss: 2.4343 Accuracy: 0.4171 Time: 60.57106  | Val Loss: 3.1570 Accuracy: 0.3223
Epoch: 10/120 Train Loss: 2.3302 Accuracy: 0.4373 Time: 60.54537  | Val Loss: 2.7059 Accura
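
`train_net` wraps the optimizer in `StepLR(step_size=30, gamma=0.1)`, which multiplies the learning rate by `gamma` every `step_size` scheduler steps. The sketch below reproduces that schedule in closed form for the ResNet-18 run (`lr=0.005`). It assumes `naive_train_classification_model` calls `scheduler.step()` once per epoch, which the notebook does not show. The function name `step_lr` is illustrative only.

```python
def step_lr(base_lr, epoch, step_size=30, gamma=0.1):
    """Learning rate after `epoch` scheduler steps under a StepLR-style decay.

    `epoch` is 0-based: epoch 0 is the first training epoch.
    """
    return base_lr * gamma ** (epoch // step_size)


# Learning rate at the start of each decay stage over 120 epochs
for epoch in (0, 30, 60, 90):
    print(f"epoch {epoch:3d}: lr = {step_lr(0.005, epoch):.1e}")
```

Under these assumptions, the learning rate stays at its initial value for the first 30 epochs and then drops by a factor of 10 at epochs 30, 60, and 90.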

In [5]:
from hdd.models.cnn.resnet import resnet34_config

net, resnet34_stats = train_net(
    resnet34_config,
    train_dataloader,
    val_dataloader,
    dropout=0.5,
    lr=0.05,
    weight_decay=1e-2,
)

eval_result = eval_image_classifier(net, val_dataloader.dataset, DEVICE)
ss = [result.gt_label == result.predicted_label for result in eval_result]
print(f"#Parameter: {count_trainable_parameter(net)} Accuracy: {sum(ss) / len(ss)}")

Epoch: 1/120 Train Loss: 5.2847 Accuracy: 0.0075 Time: 102.45276  | Val Loss: 5.3287 Accuracy: 0.0053
Epoch: 2/120 Train Loss: 5.2336 Accuracy: 0.0088 Time: 102.45402  | Val Loss: 5.2376 Accuracy: 0.0111
Epoch: 3/120 Train Loss: 5.2131 Accuracy: 0.0094 Time: 102.43996  | Val Loss: 5.1594 Accuracy: 0.0126
Epoch: 4/120 Train Loss: 5.1762 Accuracy: 0.0105 Time: 102.43808  | Val Loss: 5.1073 Accuracy: 0.0130
Epoch: 5/120 Train Loss: 5.1691 Accuracy: 0.0113 Time: 102.41869  | Val Loss: 5.1693 Accuracy: 0.0105
Epoch: 6/120 Train Loss: 5.1633 Accuracy: 0.0114 Time: 102.38357  | Val Loss: 5.6657 Accuracy: 0.0097
Epoch: 7/120 Train Loss: 5.1570 Accuracy: 0.0117 Time: 102.37989  | Val Loss: 6.8334 Accuracy: 0.0103
Epoch: 8/120 Train Loss: 5.1298 Accuracy: 0.0119 Time: 102.36425  | Val Loss: 5.0727 Accuracy: 0.0151
Epoch: 9/120 Train Loss: 5.0950 Accuracy: 0.0134 Time: 102.37211  | Val Loss: 5.3892 Accuracy: 0.0111
Epoch: 10/120 Train Loss: 5.0718 Accuracy: 0.0137 Time: 102.35102  | Val Loss: 5.9
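
To put the ResNet-34 log above in context, it helps to know what a classifier that guesses uniformly at random would score on 200 classes. The sketch below computes that baseline. The function name `chance_level` is illustrative only.

```python
import math


def chance_level(num_classes):
    """Accuracy and cross-entropy loss of a uniform random guess."""
    accuracy = 1.0 / num_classes
    loss = math.log(num_classes)  # -log(1 / num_classes)
    return accuracy, loss


accuracy, loss = chance_level(200)
print(f"chance accuracy: {accuracy:.4f}, chance loss: {loss:.4f}")
```

With 200 classes, the chance accuracy is 0.005 and the chance loss is about 5.298. The ResNet-34 run starts at a training loss of 5.28 and reaches only about 1.4% training accuracy after ten epochs, so it is barely above chance, whereas the ResNet-18 run reaches about 44% in the same time. One plausible cause is the ten times larger learning rate (`lr=0.05` versus `lr=0.005`), although the log alone does not establish this.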

In [6]:
import torchvision

In [None]:
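%matplotlib inline
import matplotlib.pyplot as plt

plt.figure(figsize=(16, 5))
images_per_class = 5
num_classes = 15
# Tensors for undoing Normalize(mean, std) before display
mean = torch.tensor(TRAIN_MEAN).view(3, 1, 1)
std = torch.tensor(TRAIN_STD).view(3, 1, 1)

for target_class in range(num_classes):
    # Indices of the first few training images of this class
    idxs = [
        idx
        for idx, target in enumerate(train_dataset.targets)
        if target == target_class
    ]
    idxs = idxs[:images_per_class]
    for i, idx in enumerate(idxs):
        # One column per class, one row per sample
        plt_idx = i * num_classes + target_class + 1
        plt.subplot(images_per_class, num_classes, plt_idx)
        image, _ = train_dataset[idx]
        # The transform yields a normalized CHW tensor; imshow expects HWC in [0, 1]
        image = (image * std + mean).clamp(0, 1).permute(1, 2, 0).numpy()
        plt.imshow(image)
        plt.axis("off")
        if i == 0:
            plt.title(f"{train_dataset.classes[target_class][0]}", fontsize="x-small")
plt.show()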