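## ResNet training notebook

Trains a small ResNet on MNIST (mixup, cosine learning-rate schedule), then restores a saved checkpoint and evaluates it on the test split.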

In [1]:
from src.args import ResNetExpArgs, TrainingArgs, to_exp_name
from src.data_loading import get_image_data_loader, transforms_image_net
from src.models import MLP, ResNet
from src.test import test_loop
from src.train import training_loop
from src.utils import accuracy, get_optimizer

In [2]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary

### Params

In [3]:
# Experiment hyper-parameters; `to_exp_name(args)` below turns them into the
# experiment name, so every field here affects naming.
args = ResNetExpArgs(
    batch_size=128,
    dataset_name="mnist",
    learning_rate=0.1,
    num_epochs=1,
    momentum=0.9,
    weight_decay=0.0001,
    cosine_lr=True,
    warmup_epochs=0,
    decay_interval=2,
    decay_gamma=0.1,
    mixup_alpha=0.5,
    lean_stem=True,
    smart_downsampling=True,
    use_gpu=False,
)

# Number of output classes per supported dataset. It is looked up below rather
# than hard-coded, so changing `dataset_name` cannot leave a stale class count.
dataset_to_n_classes = {
    "mnist": 10,
    "cifar10": 10,
    "fmnist": 10,
}

training_args = TrainingArgs(
    args.batch_size,
    dataset_to_n_classes[args.dataset_name],  # num classes
    args.num_epochs,
    args.cosine_lr,
    args.warmup_epochs,
    args.decay_interval,
    args.decay_gamma,
    args.mixup_alpha,
    print_every=1,
    write_every=1,
    plot_every=10,
    check_every=0,
)

exp_name = to_exp_name(args)
device = "cuda" if args.use_gpu else "cpu"
if args.use_gpu:
    assert torch.cuda.is_available(), "use_gpu=True but CUDA is not available"
device, exp_name

('cpu', 'mnist_128_0.1_1_0.9_0.0001_True_0_2_0.1_0.5_True_True_False')

### Data loading

In [4]:
# Augmentation pipeline applied to the training split.
# NOTE(review): if `flip` means horizontal flips and `colors` means colour
# jitter (confirm in transforms_image_net), both are questionable for
# grayscale digits, where a mirrored digit is not the same class.
transform = transforms_image_net(
    is_image=True,
    standardize=False,
    crop=True,
    crop_size=28,  # MNIST images are 28x28
    flip=True,
    colors=True,
)

# Split the training set into train / validation loaders (10% held out).
train_data, eval_data = get_image_data_loader(
    args.dataset_name,
    train=True,
    transform=transform,
    val_share=0.1,
    batch_size=args.batch_size,
    shuffle=True,
    single_batch=False,
)

Dataset lengths: train-54000, val-6000


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


### Model selection

In [5]:
# Reference configurations for larger inputs (not used in this notebook).
#
# ImageNet version
# resnet_imagenet = ResNet(
#     img_channels=3,
#     n_classes=10,
#     extra_blocks_per_layer=[1, 3, 5, 2],
#     resnet_channels=[64, 128, 256, 512],
#     stem_channels=64,
#     stem_downsample=True,
# )
#
# CIFAR-10 version
# resnet_cifar = ResNet(
#     img_channels=3,
#     n_classes=10,
#     extra_blocks_per_layer=[5, 5, 5],
#     resnet_channels=[16, 32, 64],
#     stem_channels=16,
#     stem_conv_size=7 if not args.lean_stem else 3,
#     stem_downsample=False,
#     slender_stem=args.lean_stem,
#     better_downsampling=args.smart_downsampling,
# )
# torchsummary's summary() prints the table itself and defaults to CUDA,
# so pass the device explicitly and do not wrap the call in print().
# summary(resnet_cifar.to(device), (3, 32, 32), device=device)

# Small two-stage ResNet used for this experiment.
# The stem and downsampling variants are read from `args`, the same fields that
# build `exp_name` and hence the checkpoint path. Toggling them therefore changes
# the model as well as the name.
# NOTE(review): img_channels=1 assumes a grayscale dataset (mnist / fmnist).
baby_resnet = ResNet(
    img_channels=1,
    n_classes=dataset_to_n_classes[args.dataset_name],
    extra_blocks_per_layer=[1, 1],
    resnet_channels=[16, 32],
    stem_channels=16,
    stem_downsample=False,
    slender_stem=args.lean_stem,
    better_downsampling=args.smart_downsampling,
)
model = baby_resnet.to(device)

### Training

In [6]:
# Parameter groups from the model: weight decay is applied only to `decay`.
no_decay, decay = model.get_params()
optimizer = get_optimizer(
    decay_params=decay,
    no_decay_params=no_decay,
    lr=args.learning_rate,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)

if args.mixup_alpha is not None:
    # With mixup the targets are mixed, so the model's log-distribution is
    # compared with the target distribution using KL divergence.
    # NOTE(review): assumes training_loop passes targets as per-class
    # probabilities of shape (batch, n_classes) when mixup is on. Confirm.
    def loss_fn(inputs, targets):
        """Batch-mean KL divergence between log_softmax(inputs) and `targets`."""
        return F.kl_div(F.log_softmax(inputs, dim=1), targets, reduction="batchmean")
else:
    loss_fn = nn.CrossEntropyLoss(reduction="mean")

# Use the experiment name derived from `args` rather than a hard-coded label, so
# the run name follows the dataset and mixup settings.
# NOTE(review): the checkpoint cell below looks for "{exp_name}-<step>.pt", so
# this assumes training_loop names its checkpoints after this argument. Confirm.
training_loop(
    exp_name,
    args=training_args,
    model=model,
    opt=optimizer,
    train_loader=train_data,
    eval_loader=eval_data,
    loss_fn=loss_fn,
    device=device,
    metric_fn=accuracy,
)

Training starts for mnist_mixup
Post-warmup begins
Start of epoch 1
Step: 1 | Training Loss: 1.87148
Step: 1 | Training Metric: 0.52344
Step: 2 | Training Loss: 1.90643
Step: 2 | Training Metric: 0.54688
Step: 3 | Training Loss: 1.87435
Step: 3 | Training Metric: 0.54688
Step: 4 | Training Loss: 1.85099
Step: 4 | Training Metric: 0.60156
Step: 5 | Training Loss: 1.72477
Step: 5 | Training Metric: 0.63281
Step: 6 | Training Loss: 1.72915
Step: 6 | Training Metric: 0.67969
Step: 7 | Training Loss: 1.99270
Step: 7 | Training Metric: 0.79688
Step: 8 | Training Loss: 1.80596
Step: 8 | Training Metric: 0.59375
Step: 9 | Training Loss: 2.06831
Step: 9 | Training Metric: 0.74219
Step: 10 | Training Loss: 1.99760
Step: 10 | Training Metric: 0.75781
Step: 11 | Training Loss: 1.73777
Step: 11 | Training Metric: 0.75781
Step: 12 | Training Loss: 1.92045
Step: 12 | Training Metric: 0.75781
Step: 13 | Training Loss: 1.54518
Step: 13 | Training Metric: 0.75000
Step: 14 | Training Loss: 1.98052
Step: 

KeyboardInterrupt: 

### Load from checkpoint

In [None]:
# Training step of the checkpoint to restore.
# NOTE(review): with num_epochs=1 and ~422 batches per epoch (54000 / 128), the
# run above does not reach this step. It must come from an earlier, longer run.
# Adjust it to match a file in data/checkpoints.
CHECKPOINT_STEP = 70000
checkpoint_path = os.path.join(
    "data", "checkpoints", f"{exp_name}-{CHECKPOINT_STEP}.pt"
)
# torch.load unpickles the file, and unpickling can execute arbitrary code.
# Only load checkpoints you produced yourself.
loaded = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(loaded["model_state"])

### Load test data

In [None]:
# Test split. With train=False no validation split is produced, so the
# second return value is expected to be None.
# Shuffling is off so the evaluation order, and any samples plotted from it,
# are reproducible across runs.
# NOTE(review): no `transform` is passed, so the loader's default is used.
# Confirm it matches the training preprocessing (standardize=False above).
test_loader, should_be_none = get_image_data_loader(
    args.dataset_name,
    train=False,
    val_share=0.1,
    shuffle=False,
    batch_size=args.batch_size,
    single_batch=False,
)
assert should_be_none is None, "expected no validation loader for the test split"

### Evaluate the loaded model

In [None]:
# Evaluate the restored model on the test split (plot=True also requests plots).
# Plain cross-entropy is used because test targets are presumably hard class
# labels (no mixup at evaluation time). Confirm in test_loop.
test_loop(
    model=model,
    test_loader=test_loader,
    loss_fn=F.cross_entropy,
    metric_fn=accuracy,
    device=device,
    plot=True,
)