In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
from matplotlib import pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import os
import torchvision.models as tmodels
from termcolor import colored
from torch.utils.tensorboard import SummaryWriter
%matplotlib inline
torch.__version__

In [None]:
# data: CIFAR-10 train/test datasets and their batch loaders.

BATCH_SIZE = 32

# Per-channel mean and std handed to Normalize (three channels, 0.5 each).
_CHANNEL_MEAN = (.5, .5, .5)
_CHANNEL_STD = (.5, .5, .5)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])

# Settings shared by both loaders; only `shuffle` differs between them.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2)

# Both splits live under ./data and are downloaded on first use.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)

# Training batches are reshuffled; test batches keep dataset order.
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **_loader_kwargs)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **_loader_kwargs)

In [None]:
# model: build the network from the project-local `model` module.
from model import Net

# Value passed to Net as `reg_param`.
REG_PARAM = 3.

net = Net(reg_param=REG_PARAM)
print(net)

def train(epochs):
    """Train the global ``net`` for ``epochs`` epochs, validating after each one.

    Each epoch makes one pass over the global ``trainloader`` using
    ``net.loss_fn`` with ``regularize_grads=True``, evaluates on the global
    ``testloader`` through ``test``, steps ``net.scheduler`` with the
    validation loss, logs train/test loss and accuracy to TensorBoard, and
    calls ``net.save`` whenever validation accuracy improves on the best
    value seen so far.

    Args:
        epochs: Number of passes over ``trainloader``.

    Returns:
        ``(epoch_accs, epoch_losses)``: per-epoch training accuracy in
        percent and per-epoch mean batch loss, both as lists of Python floats.
    """
    best_acc = -np.inf
    epoch_losses = []
    epoch_accs = []
    num_batches = len(trainloader)
    writer = SummaryWriter(comment='_cifar_lambda_3')
    try:
        for epoch in range(epochs):  # loop over the dataset multiple times
            # Make sure training-mode behaviour is active for this epoch.
            # NOTE(review): assumes Net is a torch.nn.Module — confirm in model.py.
            net.train()

            epoch_loss = 0.0
            epoch_correct = 0.0
            total = 0

            # Wrapping the loader itself lets tqdm read its length and advance
            # once per batch; no manual pbar.update() is needed.
            with tqdm(trainloader, total=num_batches) as pbar:
                for i, data in enumerate(pbar):
                    # get the inputs; data is a list of [inputs, labels]
                    inputs, labels = data[0].to(net.device), data[1].to(net.device)
                    # The loss below receives `inputs`; presumably it differentiates
                    # with respect to them when regularize_grads=True — verify in model.py.
                    inputs.requires_grad_(requires_grad=True)

                    # zero the parameter gradients
                    net.optimizer.zero_grad()

                    # forward + backward + optimize
                    outputs = net(inputs)
                    loss = net.loss_fn(outputs, labels, inputs, regularize_grads=True)
                    loss.backward()
                    net.optimizer.step()

                    # Use the real batch size: the last batch can be shorter
                    # than BATCH_SIZE.
                    batch_size = labels.size(0)
                    pred = torch.argmax(outputs, dim=1)
                    # .item() keeps the running statistics as plain Python numbers
                    # instead of device tensors.
                    running_acc = (pred == labels).sum().item()
                    running_loss = loss.item()

                    epoch_correct += running_acc
                    epoch_loss += running_loss
                    total += batch_size

                    # print statistics
                    pbar.set_description(f'Epoch: {epoch}, Done {(i + 1)/num_batches * 100}%, Loss: {running_loss}, Accuracy: {running_acc / batch_size * 100}%')

            val_acc, val_loss = test(testloader)

            epoch_loss /= num_batches
            epoch_acc = epoch_correct / total * 100

            net.scheduler.step(val_loss)

            epoch_losses.append(epoch_loss)
            epoch_accs.append(epoch_acc)

            writer.add_scalar('Loss/train', epoch_loss, epoch)
            writer.add_scalar('Accuracy/train', epoch_acc, epoch)
            writer.add_scalar('Loss/test', val_loss, epoch)
            writer.add_scalar('Accuracy/test', val_acc, epoch)

            if val_acc > best_acc:
                print(f'Improved val acc from {best_acc} to {val_acc}, saving model')
                best_acc = val_acc
                net.save(name=f'cifar_lambda{net.loss_fn.hp}')
            else:
                print(f'Val acc did not improve from {best_acc}')

        print('Finished Training')
    finally:
        # close() flushes pending events and releases the event file even if
        # training is interrupted.
        writer.close()
    return epoch_accs, epoch_losses

def test(dataloader):
    """Evaluate the global ``net`` on every batch of ``dataloader``.

    The network is put in evaluation mode for the duration of the call and
    its previous train/eval mode is restored afterwards, so calling this in
    the middle of training does not leave the model in eval mode.

    Args:
        dataloader: Iterable of ``(images, labels)`` batches that also
            supports ``len()``.

    Returns:
        ``(accuracy, loss)``: accuracy in percent over all samples, and the
        loss (computed with ``regularize_grads=False``) averaged over batches.
    """
    correct = 0.
    total = 0.
    loss = 0.

    # NOTE(review): assumes Net is a torch.nn.Module (provides .training,
    # .eval() and .train(mode)) — confirm in model.py.
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for data in dataloader:
                images, labels = data[0].to(net.device), data[1].to(net.device)
                outputs = net(images)
                loss += net.loss_fn(outputs, labels, images, regularize_grads=False).item()
                # Predicted class = index of the largest output per sample.
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Hand the model back in whatever mode the caller had it in.
        net.train(was_training)

    return correct / total * 100, loss / len(dataloader)
    

In [None]:
# Run the full training schedule, then report metrics for the weights `net`
# holds after the last epoch.
NUM_EPOCHS = 150

accs, losses = train(NUM_EPOCHS)
acc, loss = test(testloader)
print(f'Test acc: {acc}%, test loss: {loss}')

In [None]:
# Save the last-epoch weights under their own name. train() saves its
# best-validation checkpoint as f'cifar_lambda{net.loss_fn.hp}', so reusing
# that exact name here would replace the best model with the final one.
net.save(name=f'cifar_lambda{net.loss_fn.hp}_final')