## ResNet training notebook

In [7]:
from src.args import ResNetExpArgs, TrainingArgs, to_exp_name
from src.data_loading import get_image_data_loader, transforms_image_net
from src.models import MLP, ResNet
from src.test import test_loop
from src.train import training_loop
from src.utils import accuracy, get_optimizer

In [8]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary

### Params

In [9]:
# Number of target classes per supported dataset. Defined before the argument
# objects so that `num_classes` is derived from the dataset name instead of
# being hard-coded separately.
dataset_to_n_classes = {
    "mnist": 10,
    "cifar10": 10,
    "fmnist": 10,
}

# Experiment-level hyperparameters; `to_exp_name(args)` below turns them into
# the experiment name.
args = ResNetExpArgs(
    batch_size=128,
    dataset_name="cifar10",
    learning_rate=0.1,
    num_epochs=300,
    momentum=0.9,
    weight_decay=0.00001,
    cosine_lr=True,
    warmup_epochs=5,
    decay_interval=2,
    decay_gamma=0.1,
    mixup_alpha=0.5,
    lean_stem=True,
    smart_downsampling=True,
    use_gpu=True,
)

# Arguments consumed by `training_loop`; everything that also lives on `args`
# is forwarded from there so the two objects cannot drift apart.
training_args = TrainingArgs(
    batch_size=args.batch_size,
    # Fails fast with a KeyError if the dataset is not listed above.
    num_classes=dataset_to_n_classes[args.dataset_name],
    num_epochs=args.num_epochs,
    learning_rate=args.learning_rate,
    cosine_lr=args.cosine_lr,
    warmup_epochs=args.warmup_epochs,
    decay_interval=args.decay_interval,
    decay_gamma=args.decay_gamma,
    mixup_alpha=args.mixup_alpha,
    print_every=1000,
    write_every=1000,
    plot_every=10000,
    check_every=10000,
)

exp_name = to_exp_name(args)

# Explicit check instead of a bare `assert`: asserts are removed when Python
# runs with -O, and this gives a readable message when no GPU is present.
if args.use_gpu and not torch.cuda.is_available():
    raise RuntimeError("use_gpu=True but no CUDA device is available")
device = "cuda" if args.use_gpu else "cpu"
device, exp_name

('cuda', 'cifar10_128_0.1_300_0.9_1e-05_True_5_2_0.1_0.5_True_True_True')

### Data loading

In [10]:
# Options for the training-time transform built by `transforms_image_net`.
# NOTE(review): crop_size=28 is smaller than the (3, 32, 32) input used for the
# model summary further down - confirm this crop size is intended.
transform_options = dict(
    crop=True,
    crop_size=28,
    flip=True,
    colors=True,
    standardize=False,
    is_image=True,
)
transform = transforms_image_net(**transform_options)

# Settings for the train/validation loaders; 10% of the training split is held
# out for validation (see the "Dataset lengths" line printed by this cell).
train_loader_options = dict(
    train=True,
    val_share=0.1,
    shuffle=True,
    batch_size=args.batch_size,
    single_batch=False,
    transform=transform,
)
train_data, eval_data = get_image_data_loader(
    args.dataset_name, **train_loader_options
)

Files already downloaded and verified
Dataset lengths: train-45000, val-5000


### Model selection

In [12]:
# ImageNet-style configuration, kept for reference only:
# resnet_imagenet = ResNet(
#     img_channels=3,
#     n_classes=10,
#     extra_blocks_per_layer=[1, 3, 5, 2],
#     resnet_channels=[64, 128, 256, 512],
#     stem_channels=64,
#     stem_downsample=True,
# )

# CIFAR-10 configuration: three stages with 16/32/64 channels.
resnet_cifar = ResNet(
    img_channels=3,
    # Derived from the dataset name rather than hard-coded.
    n_classes=dataset_to_n_classes[args.dataset_name],
    extra_blocks_per_layer=[5, 5, 5],
    resnet_channels=[16, 32, 64],
    stem_channels=16,
    # 3x3 stem convolution when the lean stem is requested, 7x7 otherwise.
    stem_conv_size=7 if not args.lean_stem else 3,
    stem_downsample=False,
    slender_stem=args.lean_stem,
    better_downsampling=args.smart_downsampling,
)

# `summary` prints the table itself, so wrapping it in print() only adds a
# stray "None". Its `device` argument defaults to "cuda"; pass the configured
# device so the dummy input lives on the same device as the model.
summary(resnet_cifar.to(device), (3, 32, 32), device=device)
model = resnet_cifar

# Small single-channel configuration, kept for reference only:
# baby_resnet = ResNet(
#     img_channels=1,
#     n_classes=10,
#     extra_blocks_per_layer=[1, 1],
#     resnet_channels=[16, 32],
#     stem_channels=16,
#     stem_downsample=False,
#     slender_stem=True,
#     better_downsampling=True,
# )
# model = baby_resnet

# Make sure whichever model was selected above ends up on the target device.
model = model.to(device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 32, 32]             448
       BatchNorm2d-2           [-1, 16, 32, 32]              32
            Conv2d-3           [-1, 16, 32, 32]           2,320
       BatchNorm2d-4           [-1, 16, 32, 32]              32
            Conv2d-5           [-1, 16, 32, 32]           2,320
       BatchNorm2d-6           [-1, 16, 32, 32]              32
         MaxPool2d-7           [-1, 16, 32, 32]               0
        ResnetStem-8           [-1, 16, 32, 32]               0
            Conv2d-9           [-1, 16, 32, 32]             272
      BatchNorm2d-10           [-1, 16, 32, 32]              32
           Conv2d-11           [-1, 16, 32, 32]           2,320
      BatchNorm2d-12           [-1, 16, 32, 32]              32
           Conv2d-13           [-1, 64, 32, 32]           1,088
      BatchNorm2d-14           [-1, 64,

### Training

In [None]:
# The model hands back two parameter groups; they are passed to the optimizer
# factory separately as `no_decay_params` and `decay_params`.
no_decay, decay = model.get_params()
optimizer = get_optimizer(
    decay_params=decay,
    no_decay_params=no_decay,
    lr=args.learning_rate,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)


def soft_target_loss(inputs, targets):
    """Batch-mean KL divergence between softmax(inputs) and `targets`.

    `inputs` are passed through a log-softmax over dim 1; `targets` are used
    as given (F.kl_div treats them as probabilities by default). Numerically
    identical to nn.KLDivLoss(reduction="batchmean") applied to
    nn.LogSoftmax(dim=1)(inputs), without rebuilding modules on every call.
    """
    return F.kl_div(F.log_softmax(inputs, dim=1), targets, reduction="batchmean")


# The KL-based loss is selected whenever mixup is configured; otherwise the
# standard cross-entropy loss is used.
if args.mixup_alpha is not None:
    loss_fn = soft_target_loss
else:
    loss_fn = nn.CrossEntropyLoss(reduction="mean")

training_loop(
    exp_name,
    args=training_args,
    model=model,
    opt=optimizer,
    train_loader=train_data,
    eval_loader=eval_data,
    loss_fn=loss_fn,
    device=device,
    metric_fn=accuracy,
)

Training starts for cifar10_128_0.1_300_0.9_1e-05_True_5_2_0.1_0.5_True_True_True
Start of epoch 1


### Load from checkpoint

In [None]:
# Training step of the checkpoint to restore; change this to load another one.
CHECKPOINT_STEP = 70000
checkpoint_path = os.path.join(
    "data", "checkpoints", f"{exp_name}-{CHECKPOINT_STEP}.pt"
)

# NOTE: torch.load unpickles the file, which can execute arbitrary code - only
# load checkpoints produced by a trusted source.
loaded = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(loaded["model_state"])

### Load test data

In [None]:
# Loader settings for the held-out test split (train=False). No transform is
# passed here, unlike for the training loaders.
test_loader_options = dict(
    train=False,
    val_share=0.1,
    shuffle=True,
    batch_size=args.batch_size,
    single_batch=False,
)
test_loader, should_be_none = get_image_data_loader(
    args.dataset_name, **test_loader_options
)
# The second element of the returned pair is expected to be None for the test
# split; anything else indicates the loader was built incorrectly.
assert should_be_none is None

### Evaluate the loaded model

In [None]:
# Evaluate the restored model on the test loader. Plain cross-entropy is used
# as the loss here, and accuracy as the metric.
evaluation_options = dict(
    model=model,
    test_loader=test_loader,
    loss_fn=F.cross_entropy,
    metric_fn=accuracy,
    device=device,
    plot=True,
)
test_loop(**evaluation_options)