In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from tqdm.notebook import tqdm

from resnet import ResNet34

# Per-channel normalization statistics (mean, std), keyed by evaluation dataset.
mean_std = {
    'svhn': ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    'cifar10': ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    'cifar100': ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
}

train_dataset = 'cifar10_train'
test_dataset = 'cifar10'

# Both pipelines normalize with the statistics of `test_dataset`. mean_std has
# no '*_train' keys, so this is only meaningful when the train and test splits
# come from the same base dataset.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(*mean_std[test_dataset]),
])
transform_train = transforms.Compose([
    transforms.RandomCrop((32, 32), padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(*mean_std[test_dataset]),
])

# Zero-argument builders instead of dataset instances. Building every entry
# eagerly would download/verify SVHN, CIFAR-10 and CIFAR-100 (both splits) on
# every run, even though only `train_dataset` and `test_dataset` are used.
dataset_builders = {
    'svhn': lambda: torchvision.datasets.SVHN(
        root='./data', split='test', download=True, transform=transform_test),
    'svhn_train': lambda: torchvision.datasets.SVHN(
        root='./data', split='train', download=True, transform=transform_train),
    'cifar10': lambda: torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test),
    'cifar10_train': lambda: torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train),
    'cifar100': lambda: torchvision.datasets.CIFAR100(
        root='./data', train=False, download=True, transform=transform_test),
    'cifar100_train': lambda: torchvision.datasets.CIFAR100(
        root='./data', train=True, download=True, transform=transform_train),
}

trainset = dataset_builders[train_dataset]()
testset = dataset_builders[test_dataset]()

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

In [None]:
# Output-layer size for each evaluation dataset key.
num_classes = {
    'svhn': 10,
    'cifar10': 10,
    'cifar100': 100,
}

# Hyperparameters are chosen in accordance with the paper.
n_epochs = 200
# `net` and `model` are two names for the same ResNet34 instance on the GPU.
net = model = ResNet34(num_classes[test_dataset]).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)

def eval_acc(model, loader=None):
    """Compute top-1 classification accuracy of ``model`` over ``loader``.

    The model is put in eval mode (with gradients disabled) for the pass and is
    restored afterwards to whichever train/eval mode it was in before the call.

    Args:
        model: classifier whose output's dim 1 indexes classes (the prediction
            is ``argmax`` over that dim).
        loader: iterable of batches where ``batch[0]`` are inputs and
            ``batch[1]`` are integer labels. Defaults to the module-level
            ``testloader``.

    Returns:
        The fraction of correctly classified samples, in [0, 1]. The same
        value is printed as a percentage.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if loader is None:
        loader = testloader
    # Evaluate on the device holding the model's weights (CUDA in this
    # notebook); fall back to CUDA, as before, for a parameterless model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    was_training = model.training
    correct = 0
    total = 0
    model.eval()
    try:
        with torch.no_grad():
            for batch in loader:
                images, labels = batch[0].to(device), batch[1].to(device)
                predicted = model(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Restore the caller's mode even if a batch raised.
        model.train(was_training)

    if total == 0:
        raise ValueError('eval_acc: loader yielded no samples')
    print(f'Validation accuracy: {100 * correct / total:.2f}')
    return correct / total

def adjust_lr(optimizer, epoch):
    """Apply the paper's step learning-rate schedule for ``epoch``.

    Reads the global ``n_epochs``. While ``epoch`` is at most half of
    ``n_epochs``, the optimizer's learning rate is left untouched. Past the
    halfway point every parameter group is set to 0.01, and past three
    quarters to 0.001.
    """
    if epoch <= n_epochs * 0.50:
        return  # first half of training: keep the current learning rate
    new_lr = 0.001 if epoch > n_epochs * 0.75 else 0.01
    for group in optimizer.param_groups:
        group['lr'] = new_lr

last_best = 0.0  # best validation accuracy seen so far (fraction in [0, 1])
# Validation-accuracy thresholds, in PERCENT. A snapshot is saved whenever a new
# best accuracy crosses the next threshold that has not been reached yet.
acc_thresholds = [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90]
last_thres = 0  # index of the next not-yet-crossed threshold
# Training loss over time: one entry per 50-batch window, holding the SUM of the
# 50 batch losses (i.e. 50x the mean that is printed).
loss_t = []

ckpt_dir = './models/resnet34'
os.makedirs(ckpt_dir, exist_ok=True)  # torch.save does not create missing dirs

for epoch in range(n_epochs):

    adjust_lr(optimizer, epoch)
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):

        inputs, labels = data[0].cuda(), data[1].cuda()

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 50 == 49:    # print every 50 mini-batches
            print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 50:.4f}')
            loss_t.append(running_loss)
            running_loss = 0.0

    accuracy = eval_acc(model)
    if accuracy > last_best:
        last_best = accuracy
        # Thresholds are percentages and accuracy is a fraction, so every
        # comparison scales the threshold by 0.01. The length checks stop the
        # index from running off the list once the best accuracy exceeds the
        # last threshold (no further threshold snapshots are taken after that).
        if last_thres < len(acc_thresholds) and accuracy > 0.01 * acc_thresholds[last_thres]:
            while last_thres < len(acc_thresholds) and 0.01 * acc_thresholds[last_thres] <= last_best:
                last_thres += 1
            torch.save(model, f'{ckpt_dir}/{test_dataset}_{accuracy*10000:.0f}')

torch.save(model, f'{ckpt_dir}/{test_dataset}_{accuracy*10000:.0f}_final')
print('Finished Training')