In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
from matplotlib import pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import os
import torchvision.models as tmodels
from termcolor import colored
from torch.utils.tensorboard import SummaryWriter
%matplotlib inline
torch.__version__

'1.7.1'

In [2]:
#data

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((.5, .5, .5), (.5, .5, .5))
])

BATCH_SIZE=32

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
from model import Net
net = Net(reg_param=0.)

print(net)

def train(epochs):
    min_acc = -np.inf
    epoch_losses = []
    epoch_accs = []
    writer = SummaryWriter(comment='_lambda_0')
    for epoch in range(epochs):  # loop over the dataset multiple times

        epoch_loss = 0.0
        epoch_acc = 0.
        val_acc = -np.inf
        val_loss = np.inf
        total = 0
        pbar = tqdm(enumerate(trainloader, 0))
        for i, data in pbar:
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data[0].to(net.device), data[1].to(net.device)
            inputs.requires_grad_(requires_grad=True)

            # zero the parameter gradients
            net.optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = net.loss_fn(outputs, labels, inputs, regularize_grads=False)
            loss.backward()
            net.optimizer.step()
            pred = torch.argmax(outputs, axis=1)
            running_acc = (pred == labels).float().sum()
            running_loss = loss.item()

            epoch_acc += running_acc
            epoch_loss += running_loss
            
            total += labels.size(0)

            # print statistics
            pbar.set_description(f'Epoch: {epoch}, Done {(i + 1)/len(trainloader) * 100}%, Loss: {running_loss}, Accuracy: {running_acc  / BATCH_SIZE * 100}%')
            pbar.update(BATCH_SIZE)
        pbar.close()

        val_acc, val_loss = test(testloader)
        
        epoch_loss /= len(trainloader)
        epoch_acc /= total
        epoch_acc *= 100
        
        net.scheduler.step(val_loss)
        
        epoch_losses.append(epoch_loss)
        epoch_accs.append(epoch_acc)
        
        writer.add_scalar('Loss/train', epoch_loss, epoch)
        writer.add_scalar('Accuracy/train', epoch_acc, epoch)
        writer.add_scalar('Loss/test', val_loss, epoch)
        writer.add_scalar('Accuracy/test', val_acc, epoch)

        print(f"Avg loss: {colored(str(epoch_loss), 'green')}, Avg accuracy: {colored(str(epoch_acc.item()), 'red')}, Val loss: {val_loss}, Val accuracy: {val_acc}")

        if val_acc > min_acc:
            print(f'Improved val acc from {min_acc} to {val_acc}, saving model')
            min_acc = val_acc
            net.save(name=f'cifar_lambda{net.loss_fn.hp}')
        else:
            print(f'Val acc did not improve from {min_acc}')

    print('Finished Training')
    writer.flush()
    return epoch_accs, epoch_losses

def test(dataloader):
    correct = 0.
    total = 0.
    loss = 0.
    with torch.no_grad():
        for data in dataloader:
            images, labels = data[0].to(net.device), data[1].to(net.device)
            outputs = net(images)
            run_loss = net.loss_fn(outputs, labels, images, regularize_grads=False).item()
            loss += run_loss
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            run_acc = (predicted == labels).sum().item()
            correct += run_acc

    return correct / total * 100, loss / len(dataloader)
    

re_param: 0.0
Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (bn1): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=10, bias=True)
  (act): Softmax(dim=-1)
)


In [None]:
accs, losses = train(150)
acc, loss = test(testloader)
print(f'Test acc: {acc}%, test loss: {loss}')

Epoch: 0, Done 100.0%, Loss: 0.9258631467819214, Accuracy: 34.375%: : 1563it [00:20, 76.55it/s]               


Avg loss: [32m1.4373399600796568[0m, Avg accuracy: [31m48.38999938964844[0m, Val loss: 1.2374058491505755, Val accuracy: 55.58
Improved val acc from -inf to 55.58, saving model


Epoch: 1, Done 100.0%, Loss: 1.1482802629470825, Accuracy: 34.375%: : 1563it [00:19, 78.60it/s]               


Avg loss: [32m1.1827745335802236[0m, Avg accuracy: [31m57.90999984741211[0m, Val loss: 1.1485459067570134, Val accuracy: 59.519999999999996
Improved val acc from 55.58 to 59.519999999999996, saving model


Epoch: 2, Done 100.0%, Loss: 0.7744807600975037, Accuracy: 40.625%: : 1563it [00:20, 75.96it/s]               


Avg loss: [32m1.0877392071031715[0m, Avg accuracy: [31m61.61199951171875[0m, Val loss: 1.0918042794964946, Val accuracy: 61.160000000000004
Improved val acc from 59.519999999999996 to 61.160000000000004, saving model


Epoch: 3, Done 100.0%, Loss: 1.194390892982483, Accuracy: 31.25%: : 1563it [00:18, 83.64it/s]                 


Avg loss: [32m1.0292355630615928[0m, Avg accuracy: [31m63.74799728393555[0m, Val loss: 1.0941345449834585, Val accuracy: 61.19
Improved val acc from 61.160000000000004 to 61.19, saving model


Epoch: 4, Done 100.0%, Loss: 1.0768628120422363, Accuracy: 34.375%: : 1563it [00:19, 79.80it/s]                


Avg loss: [32m0.9792178829968624[0m, Avg accuracy: [31m65.6659927368164[0m, Val loss: 1.050448401096149, Val accuracy: 63.41
Improved val acc from 61.19 to 63.41, saving model


Epoch: 5, Done 100.0%, Loss: 0.7844680547714233, Accuracy: 34.375%: : 1563it [00:18, 82.48it/s]               


Avg loss: [32m0.9463099978211135[0m, Avg accuracy: [31m66.94200134277344[0m, Val loss: 1.0264987252390803, Val accuracy: 63.88
Improved val acc from 63.41 to 63.88, saving model


Epoch: 6, Done 100.0%, Loss: 1.1219488382339478, Accuracy: 31.25%: : 1563it [00:20, 77.37it/s]                 


Avg loss: [32m0.9129055019799365[0m, Avg accuracy: [31m68.18999481201172[0m, Val loss: 1.030639633203086, Val accuracy: 63.85999999999999
Val acc did not improve from 63.88


Epoch: 7, Done 100.0%, Loss: 1.2511104345321655, Accuracy: 28.125%: : 1563it [00:19, 81.02it/s]                


Avg loss: [32m0.8850360805989838[0m, Avg accuracy: [31m69.0[0m, Val loss: 0.9883254170417786, Val accuracy: 65.68
Improved val acc from 63.88 to 65.68, saving model


Epoch: 8, Done 100.0%, Loss: 0.9473572969436646, Accuracy: 28.125%: : 1563it [00:19, 79.22it/s]                


Avg loss: [32m0.8592227272932451[0m, Avg accuracy: [31m69.96399688720703[0m, Val loss: 1.0062914869655817, Val accuracy: 65.07
Val acc did not improve from 65.68


Epoch: 9, Done 100.0%, Loss: 0.8846527934074402, Accuracy: 31.25%: : 1563it [00:19, 81.15it/s]                 


Avg loss: [32m0.8359925532790994[0m, Avg accuracy: [31m70.85199737548828[0m, Val loss: 1.0046435160377918, Val accuracy: 65.61
Val acc did not improve from 65.68


Epoch: 10, Done 10.300703774792067%, Loss: 0.5823954343795776, Accuracy: 81.25%: : 5120it [00:02, 2575.01it/s] 