In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
from matplotlib import pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import os
import torchvision.models as tmodels
from termcolor import colored
from torch.utils.tensorboard import SummaryWriter
%matplotlib inline
torch.__version__

'1.7.1'

In [2]:
#data

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((.5,), (.5,))
])

BATCH_SIZE=32

trainset = torchvision.datasets.KMNIST(root='./data', train=True,
                                        download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.KMNIST(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=2)

In [3]:
from model import Net
lmbda = 100.
net = Net(reg_param=lmbda)

print(net)

def train(epochs):
    min_acc = -np.inf
    epoch_losses = []
    epoch_accs = []
    writer = SummaryWriter(comment=f'_kmnist_lambda_{lmbda}')
    for epoch in range(epochs):  # loop over the dataset multiple times

        epoch_loss = 0.0
        epoch_acc = 0.
        val_acc = -np.inf
        val_loss = np.inf
        total = 0
        pbar = tqdm(enumerate(trainloader, 0))
        for i, data in pbar:
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data[0].to(net.device), data[1].to(net.device)
            inputs.requires_grad_(requires_grad=True)

            # zero the parameter gradients
            net.optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = net.loss_fn(outputs, labels, inputs, regularize_grads=True)
            loss.backward()
            net.optimizer.step()
            pred = torch.argmax(outputs, axis=1)
            running_acc = (pred == labels).float().sum()
            running_loss = loss.item()

            epoch_acc += running_acc
            epoch_loss += running_loss
            
            total += labels.size(0)

            # print statistics
            pbar.set_description(f'Epoch: {epoch}, Done {(i + 1)/len(trainloader) * 100}%, Loss: {running_loss}, Accuracy: {running_acc  / BATCH_SIZE * 100}%')
            pbar.update(BATCH_SIZE)
        pbar.close()

        val_acc, val_loss = test(testloader)
        
        epoch_loss /= len(trainloader)
        epoch_acc /= total
        epoch_acc *= 100
        
        net.scheduler.step(val_loss)
        
        epoch_losses.append(epoch_loss)
        epoch_accs.append(epoch_acc)
        
        writer.add_scalar('Loss/train', epoch_loss, epoch)
        writer.add_scalar('Accuracy/train', epoch_acc, epoch)
        writer.add_scalar('Loss/test', val_loss, epoch)
        writer.add_scalar('Accuracy/test', val_acc, epoch)

        #print(f"Avg loss: {colored(str(epoch_loss), 'green')}, Avg accuracy: {colored(str(epoch_acc.item()), 'red')}, Val loss: {val_loss}, Val accuracy: {val_acc}")

        if val_acc > min_acc:
            print(f'Improved val acc from {min_acc} to {val_acc}, saving model')
            min_acc = val_acc
            net.save(name=f'kmnist_lambda{net.loss_fn.hp}')
        else:
            print(f'Val acc did not improve from {min_acc}')

    print('Finished Training')
    writer.flush()
    return epoch_accs, epoch_losses

def test(dataloader):
    correct = 0.
    total = 0.
    loss = 0.
    with torch.no_grad():
        for data in dataloader:
            images, labels = data[0].to(net.device), data[1].to(net.device)
            outputs = net(images)
            run_loss = net.loss_fn(outputs, labels, images, regularize_grads=False).item()
            loss += run_loss
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            run_acc = (predicted == labels).sum().item()
            correct += run_acc

    return correct / total * 100, loss / len(dataloader)
    

re_param: 100.0
Net(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (bn1): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=10, bias=True)
  (act): Softmax(dim=-1)
)


In [None]:
accs, losses = train(100)
acc, loss = test(testloader)
print(f'Test acc: {acc}%, test loss: {loss}')

Epoch: 0, Done 100.0%, Loss: 0.548210620880127, Accuracy: 93.75%: : 1875it [00:26, 69.68it/s]                  


Improved val acc from -inf to 85.99, saving model


Epoch: 1, Done 100.0%, Loss: 0.1493290811777115, Accuracy: 96.875%: : 1875it [00:27, 67.88it/s]                


Improved val acc from 85.99 to 88.25, saving model


Epoch: 2, Done 100.0%, Loss: 0.39601752161979675, Accuracy: 90.625%: : 1875it [00:28, 65.80it/s]               


Improved val acc from 88.25 to 90.03, saving model


Epoch: 3, Done 100.0%, Loss: 0.10974828898906708, Accuracy: 100.0%: : 1875it [00:27, 68.83it/s]                


Improved val acc from 90.03 to 91.16, saving model


Epoch: 4, Done 100.0%, Loss: 0.42359286546707153, Accuracy: 93.75%: : 1875it [00:26, 70.33it/s]                


Improved val acc from 91.16 to 91.92, saving model


Epoch: 5, Done 100.0%, Loss: 0.05182027816772461, Accuracy: 100.0%: : 1875it [00:27, 69.38it/s]                


Improved val acc from 91.92 to 92.69, saving model


Epoch: 6, Done 100.0%, Loss: 0.33457332849502563, Accuracy: 96.875%: : 1875it [00:27, 68.36it/s]               


Improved val acc from 92.69 to 93.38, saving model


Epoch: 7, Done 100.0%, Loss: 0.03191111981868744, Accuracy: 100.0%: : 1875it [00:27, 68.74it/s]                


Improved val acc from 93.38 to 93.57, saving model


Epoch: 8, Done 100.0%, Loss: 0.11894914507865906, Accuracy: 100.0%: : 1875it [00:27, 67.52it/s]                


Improved val acc from 93.57 to 94.19, saving model


Epoch: 9, Done 100.0%, Loss: 0.0524582639336586, Accuracy: 100.0%: : 1875it [00:26, 69.58it/s]                 


Val acc did not improve from 94.19


Epoch: 10, Done 100.0%, Loss: 0.17203441262245178, Accuracy: 96.875%: : 1875it [00:27, 68.35it/s]               


Val acc did not improve from 94.19


Epoch: 11, Done 100.0%, Loss: 0.04992722347378731, Accuracy: 100.0%: : 1875it [00:28, 66.57it/s]                


Improved val acc from 94.19 to 94.28999999999999, saving model


Epoch: 12, Done 100.0%, Loss: 0.14768120646476746, Accuracy: 96.875%: : 1875it [00:27, 68.50it/s]               


Val acc did not improve from 94.28999999999999


Epoch: 13, Done 100.0%, Loss: 0.4688013195991516, Accuracy: 93.75%: : 1875it [00:27, 68.79it/s]                 


Val acc did not improve from 94.28999999999999


Epoch: 14, Done 100.0%, Loss: 0.1132582426071167, Accuracy: 100.0%: : 1875it [00:27, 69.40it/s]                 


Val acc did not improve from 94.28999999999999


Epoch: 15, Done 100.0%, Loss: 0.21664747595787048, Accuracy: 96.875%: : 1875it [00:27, 68.86it/s]               


Val acc did not improve from 94.28999999999999


Epoch: 16, Done 100.0%, Loss: 0.10986211150884628, Accuracy: 100.0%: : 1875it [00:28, 66.94it/s]                


Improved val acc from 94.28999999999999 to 94.67999999999999, saving model


Epoch: 17, Done 100.0%, Loss: 0.022847823798656464, Accuracy: 100.0%: : 1875it [00:27, 69.04it/s]               


Val acc did not improve from 94.67999999999999


Epoch: 18, Done 100.0%, Loss: 0.10567933320999146, Accuracy: 100.0%: : 1875it [00:25, 74.10it/s]                


Improved val acc from 94.67999999999999 to 94.86, saving model


Epoch: 19, Done 100.0%, Loss: 0.2771374583244324, Accuracy: 96.875%: : 1875it [00:26, 71.92it/s]                


Val acc did not improve from 94.86


Epoch: 20, Done 100.0%, Loss: 0.7745958566665649, Accuracy: 96.875%: : 1875it [00:26, 70.96it/s]                


Val acc did not improve from 94.86


Epoch: 21, Done 100.0%, Loss: 0.21532246470451355, Accuracy: 96.875%: : 1875it [00:26, 71.42it/s]               


Val acc did not improve from 94.86


Epoch: 22, Done 100.0%, Loss: 0.03561604395508766, Accuracy: 100.0%: : 1875it [00:27, 67.69it/s]                


Val acc did not improve from 94.86


Epoch: 23, Done 100.0%, Loss: 0.15868420898914337, Accuracy: 96.875%: : 1875it [00:27, 67.39it/s]               


Val acc did not improve from 94.86


Epoch: 24, Done 100.0%, Loss: 0.45108938217163086, Accuracy: 93.75%: : 1875it [00:26, 71.91it/s]                


Improved val acc from 94.86 to 95.08, saving model


Epoch: 25, Done 100.0%, Loss: 0.15739774703979492, Accuracy: 100.0%: : 1875it [00:25, 72.22it/s]                


Val acc did not improve from 95.08


Epoch: 26, Done 100.0%, Loss: 0.043070584535598755, Accuracy: 100.0%: : 1875it [00:25, 73.78it/s]               


Val acc did not improve from 95.08


Epoch: 27, Done 100.0%, Loss: 0.16497516632080078, Accuracy: 100.0%: : 1875it [00:25, 73.70it/s]                


Val acc did not improve from 95.08


Epoch: 28, Done 100.0%, Loss: 0.028757117688655853, Accuracy: 100.0%: : 1875it [00:25, 72.85it/s]               


Val acc did not improve from 95.08


Epoch: 29, Done 100.0%, Loss: 0.1720074713230133, Accuracy: 96.875%: : 1875it [00:25, 73.75it/s]                


Improved val acc from 95.08 to 95.28, saving model


Epoch: 30, Done 100.0%, Loss: 0.041887640953063965, Accuracy: 100.0%: : 1875it [00:25, 73.36it/s]               


Val acc did not improve from 95.28


Epoch: 31, Done 100.0%, Loss: 0.3341827094554901, Accuracy: 96.875%: : 1875it [00:26, 70.30it/s]                


Val acc did not improve from 95.28


Epoch: 32, Done 100.0%, Loss: 0.4016231298446655, Accuracy: 96.875%: : 1875it [00:26, 71.43it/s]                


Improved val acc from 95.28 to 95.53, saving model


Epoch: 33, Done 100.0%, Loss: 0.2292618453502655, Accuracy: 96.875%: : 1875it [00:25, 74.96it/s]                


Val acc did not improve from 95.53


Epoch: 34, Done 100.0%, Loss: 0.14417888224124908, Accuracy: 96.875%: : 1875it [00:25, 73.38it/s]                


Val acc did not improve from 95.53


Epoch: 35, Done 100.0%, Loss: 0.03730227053165436, Accuracy: 100.0%: : 1875it [00:25, 73.44it/s]                


Val acc did not improve from 95.53


Epoch: 36, Done 100.0%, Loss: 0.1804284155368805, Accuracy: 100.0%: : 1875it [00:25, 72.73it/s]                 


Val acc did not improve from 95.53


Epoch: 37, Done 100.0%, Loss: 0.06435190141201019, Accuracy: 100.0%: : 1875it [00:24, 76.59it/s]                


Val acc did not improve from 95.53


Epoch: 38, Done 100.0%, Loss: 0.4678618609905243, Accuracy: 96.875%: : 1875it [00:25, 73.71it/s]                


Val acc did not improve from 95.53


Epoch: 39, Done 100.0%, Loss: 0.23847696185112, Accuracy: 96.875%: : 1875it [00:25, 74.75it/s]                  


Val acc did not improve from 95.53


Epoch: 40, Done 100.0%, Loss: 0.03789467364549637, Accuracy: 100.0%: : 1875it [00:25, 73.56it/s]                


Val acc did not improve from 95.53


Epoch: 41, Done 100.0%, Loss: 0.22776785492897034, Accuracy: 100.0%: : 1875it [00:25, 73.62it/s]                


Val acc did not improve from 95.53


Epoch: 42, Done 100.0%, Loss: 0.13958171010017395, Accuracy: 100.0%: : 1875it [00:25, 73.63it/s]                 


Val acc did not improve from 95.53


Epoch: 43, Done 100.0%, Loss: 0.5995777249336243, Accuracy: 90.625%: : 1875it [00:25, 74.31it/s]                


Val acc did not improve from 95.53


Epoch: 44, Done 100.0%, Loss: 0.15091736614704132, Accuracy: 100.0%: : 1875it [00:25, 73.37it/s]                


Improved val acc from 95.53 to 96.14, saving model


Epoch: 45, Done 100.0%, Loss: 0.03728713095188141, Accuracy: 100.0%: : 1875it [00:25, 73.18it/s]                 


Improved val acc from 96.14 to 96.19, saving model


Epoch: 46, Done 100.0%, Loss: 0.16799254715442657, Accuracy: 100.0%: : 1875it [00:24, 75.67it/s]                


Improved val acc from 96.19 to 96.34, saving model


Epoch: 47, Done 100.0%, Loss: 0.11970114707946777, Accuracy: 100.0%: : 1875it [00:24, 75.51it/s]                


Val acc did not improve from 96.34


Epoch: 48, Done 100.0%, Loss: 0.014541162177920341, Accuracy: 100.0%: : 1875it [00:25, 74.16it/s]                


Improved val acc from 96.34 to 96.41999999999999, saving model


Epoch: 49, Done 100.0%, Loss: 0.029268622398376465, Accuracy: 100.0%: : 1875it [00:25, 73.36it/s]                


Val acc did not improve from 96.41999999999999


Epoch: 50, Done 100.0%, Loss: 0.23249542713165283, Accuracy: 96.875%: : 1875it [00:26, 71.78it/s]                


Val acc did not improve from 96.41999999999999


Epoch: 51, Done 100.0%, Loss: 0.3086022734642029, Accuracy: 96.875%: : 1875it [00:25, 73.71it/s]                 


Val acc did not improve from 96.41999999999999


Epoch: 52, Done 100.0%, Loss: 0.04623791575431824, Accuracy: 100.0%: : 1875it [00:25, 73.73it/s]                 


Val acc did not improve from 96.41999999999999


Epoch: 53, Done 100.0%, Loss: 0.041052334010601044, Accuracy: 100.0%: : 1875it [00:25, 72.78it/s]                


Val acc did not improve from 96.41999999999999


Epoch: 54, Done 100.0%, Loss: 0.04241614043712616, Accuracy: 100.0%: : 1875it [00:26, 71.35it/s]                


Val acc did not improve from 96.41999999999999


Epoch: 55, Done 100.0%, Loss: 0.2792566418647766, Accuracy: 100.0%: : 1875it [00:25, 73.28it/s]                  


Val acc did not improve from 96.41999999999999


Epoch: 56, Done 100.0%, Loss: 0.017643947154283524, Accuracy: 100.0%: : 1875it [00:25, 72.52it/s]                


Val acc did not improve from 96.41999999999999


Epoch: 57, Done 100.0%, Loss: 0.1790192872285843, Accuracy: 96.875%: : 1875it [00:25, 73.29it/s]                 


Val acc did not improve from 96.41999999999999


Epoch: 58, Done 100.0%, Loss: 0.00627024844288826, Accuracy: 100.0%: : 1875it [00:25, 72.86it/s]                 


Val acc did not improve from 96.41999999999999


Epoch: 59, Done 100.0%, Loss: 0.027344048023223877, Accuracy: 100.0%: : 1875it [00:25, 74.45it/s]                


Val acc did not improve from 96.41999999999999


Epoch: 60, Done 100.0%, Loss: 0.014766641892492771, Accuracy: 100.0%: : 1875it [00:25, 73.95it/s]                


Improved val acc from 96.41999999999999 to 96.44, saving model


Epoch: 61, Done 100.0%, Loss: 0.007684311829507351, Accuracy: 100.0%: : 1875it [00:24, 75.41it/s]                


Val acc did not improve from 96.44


Epoch: 62, Done 100.0%, Loss: 0.04735754430294037, Accuracy: 100.0%: : 1875it [00:26, 70.94it/s]                 


Val acc did not improve from 96.44


Epoch: 63, Done 100.0%, Loss: 0.03592117875814438, Accuracy: 100.0%: : 1875it [00:24, 75.60it/s]                 


Val acc did not improve from 96.44


Epoch: 64, Done 100.0%, Loss: 0.015884799882769585, Accuracy: 100.0%: : 1875it [00:25, 72.14it/s]                


Val acc did not improve from 96.44


Epoch: 65, Done 100.0%, Loss: 0.0958961471915245, Accuracy: 100.0%: : 1875it [00:25, 74.20it/s]                  


Val acc did not improve from 96.44


Epoch: 66, Done 100.0%, Loss: 0.010744275525212288, Accuracy: 100.0%: : 1875it [00:25, 74.47it/s]                


Val acc did not improve from 96.44


Epoch: 67, Done 100.0%, Loss: 0.004159553907811642, Accuracy: 100.0%: : 1875it [00:25, 73.68it/s]                


Val acc did not improve from 96.44


Epoch: 68, Done 100.0%, Loss: 0.023417770862579346, Accuracy: 100.0%: : 1875it [00:26, 69.65it/s]                


Improved val acc from 96.44 to 96.5, saving model


Epoch: 69, Done 100.0%, Loss: 0.008520716801285744, Accuracy: 100.0%: : 1875it [00:25, 73.41it/s]                


Val acc did not improve from 96.5


Epoch: 70, Done 100.0%, Loss: 0.08505982160568237, Accuracy: 100.0%: : 1875it [00:25, 74.34it/s]                 


Val acc did not improve from 96.5


Epoch: 71, Done 100.0%, Loss: 0.03035121038556099, Accuracy: 100.0%: : 1875it [00:26, 71.77it/s]                 


Val acc did not improve from 96.5


Epoch: 72, Done 100.0%, Loss: 0.15197530388832092, Accuracy: 100.0%: : 1875it [00:26, 71.43it/s]                 


Val acc did not improve from 96.5


Epoch: 73, Done 100.0%, Loss: 0.0030148569494485855, Accuracy: 100.0%: : 1875it [00:26, 71.45it/s]               


Val acc did not improve from 96.5


Epoch: 74, Done 100.0%, Loss: 0.012720312923192978, Accuracy: 100.0%: : 1875it [00:28, 66.70it/s]                


Val acc did not improve from 96.5


Epoch: 75, Done 100.0%, Loss: 0.05142361670732498, Accuracy: 100.0%: : 1875it [00:28, 66.63it/s]                 


Val acc did not improve from 96.5


Epoch: 76, Done 100.0%, Loss: 0.056902721524238586, Accuracy: 100.0%: : 1875it [00:27, 68.10it/s]                


Val acc did not improve from 96.5


Epoch: 77, Done 100.0%, Loss: 0.02758907899260521, Accuracy: 100.0%: : 1875it [00:27, 68.60it/s]                 


Val acc did not improve from 96.5


Epoch: 78, Done 100.0%, Loss: 0.028505492955446243, Accuracy: 100.0%: : 1875it [00:25, 72.24it/s]                


Val acc did not improve from 96.5


Epoch: 79, Done 100.0%, Loss: 0.011104840785264969, Accuracy: 100.0%: : 1875it [00:25, 74.52it/s]                


Val acc did not improve from 96.5


Epoch: 80, Done 100.0%, Loss: 0.01098649576306343, Accuracy: 100.0%: : 1875it [00:25, 74.20it/s]                 


Val acc did not improve from 96.5


Epoch: 81, Done 100.0%, Loss: 0.011050987057387829, Accuracy: 100.0%: : 1875it [00:26, 69.87it/s]                


Val acc did not improve from 96.5


Epoch: 82, Done 100.0%, Loss: 0.2390151023864746, Accuracy: 96.875%: : 1875it [00:25, 74.37it/s]                 


Val acc did not improve from 96.5


Epoch: 83, Done 100.0%, Loss: 0.023331299424171448, Accuracy: 100.0%: : 1875it [00:25, 73.52it/s]                


Val acc did not improve from 96.5


Epoch: 84, Done 100.0%, Loss: 0.06000722199678421, Accuracy: 100.0%: : 1875it [00:25, 72.69it/s]                 


Val acc did not improve from 96.5


Epoch: 85, Done 100.0%, Loss: 0.0814594030380249, Accuracy: 100.0%: : 1875it [00:26, 69.93it/s]                  


Val acc did not improve from 96.5


Epoch: 86, Done 100.0%, Loss: 0.03459516912698746, Accuracy: 100.0%: : 1875it [00:25, 73.06it/s]                 


Val acc did not improve from 96.5


Epoch: 87, Done 100.0%, Loss: 0.004360952414572239, Accuracy: 100.0%: : 1875it [00:25, 73.39it/s]                


Val acc did not improve from 96.5


Epoch: 88, Done 100.0%, Loss: 0.009726028889417648, Accuracy: 100.0%: : 1875it [00:25, 73.15it/s]                


Val acc did not improve from 96.5


Epoch: 89, Done 100.0%, Loss: 0.0027813464403152466, Accuracy: 100.0%: : 1875it [00:25, 73.59it/s]               


Val acc did not improve from 96.5


Epoch: 90, Done 100.0%, Loss: 0.045122724026441574, Accuracy: 100.0%: : 1875it [00:27, 68.26it/s]                


Val acc did not improve from 96.5


Epoch: 91, Done 100.0%, Loss: 0.06363168358802795, Accuracy: 100.0%: : 1875it [00:25, 73.32it/s]                 


Val acc did not improve from 96.5


Epoch: 92, Done 100.0%, Loss: 0.0040525770746171474, Accuracy: 100.0%: : 1875it [00:25, 74.71it/s]               


Val acc did not improve from 96.5


Epoch: 93, Done 100.0%, Loss: 0.017768487334251404, Accuracy: 100.0%: : 1875it [00:25, 73.73it/s]                


Val acc did not improve from 96.5


Epoch: 94, Done 100.0%, Loss: 0.004585486836731434, Accuracy: 100.0%: : 1875it [00:25, 73.32it/s]                


Val acc did not improve from 96.5


Epoch: 95, Done 100.0%, Loss: 0.014138059690594673, Accuracy: 100.0%: : 1875it [00:26, 71.46it/s]                


Val acc did not improve from 96.5


Epoch: 96, Done 100.0%, Loss: 0.0071483189240098, Accuracy: 100.0%: : 1875it [00:25, 72.62it/s]                  


Val acc did not improve from 96.5


Epoch: 97, Done 100.0%, Loss: 0.06560099124908447, Accuracy: 100.0%: : 1875it [00:25, 73.47it/s]                 


Val acc did not improve from 96.5


Epoch: 98, Done 100.0%, Loss: 0.07437742501497269, Accuracy: 100.0%: : 1875it [00:25, 74.45it/s]                 


Val acc did not improve from 96.5


Epoch: 99, Done 45.22666666666667%, Loss: 0.008779189549386501, Accuracy: 100.0%: : 5342it [00:11, 2010.01it/s]  

In [None]:
net.save(name=f'kmnist_lambda{net.loss_fn.hp}')