In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
from matplotlib import pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import os
import torchvision.models as tmodels
from termcolor import colored
from torch.utils.tensorboard import SummaryWriter
from custom_transforms import QuantizeBatch
%matplotlib inline
torch.__version__

'1.7.1'

In [2]:
# Data: KMNIST train/test splits. ToTensor scales pixels to [0, 1];
# Normalize(0.5, 0.5) then maps them to [-1, 1].
BATCH_SIZE = 32

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(.5,), std=(.5,)),
])


def _kmnist_split(train):
    """Return (dataset, loader) for one KMNIST split; only the train split is shuffled."""
    dataset = torchvision.datasets.KMNIST(
        root='./data', train=train, download=True, transform=transform
    )
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=BATCH_SIZE, shuffle=train, num_workers=2
    )
    return dataset, loader


trainset, trainloader = _kmnist_split(train=True)
testset, testloader = _kmnist_split(train=False)

In [3]:
from model import Net

lmbda = 100.  # regularization strength, passed to Net as reg_param

# Use the GPU this experiment was run on when it exists; otherwise fall back so
# the notebook does not crash on machines without a third GPU.
if torch.cuda.device_count() > 2:
    DEVICE = 'cuda:2'
elif torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'  # NOTE(review): assumes Net accepts a CPU device — confirm
net = Net(reg_param=lmbda, device=DEVICE)

# Start from the checkpoint saved for this same lambda; deriving the name from
# `lmbda` keeps the two in sync (f'{100.}' == '100.0').
net.load(name=f'kmnist_lambda{lmbda}')

print(net)

def train(epochs, defense=True):
    """Train the global ``net`` on ``trainloader`` for ``epochs`` epochs.

    Each batch is optionally passed through ``QuantizeBatch`` before the
    forward pass. The loss is computed with ``net.loss_fn(...,
    regularize_grads=True)`` on inputs that require grad. After every epoch the
    model is evaluated on ``testloader`` via ``test``, ``net.scheduler`` is
    stepped on the validation loss, and train/test loss and accuracy are
    logged to TensorBoard.

    Args:
        epochs: number of passes over ``trainloader``.
        defense: if True, transform every training batch with ``QuantizeBatch``.

    Returns:
        (epoch_accs, epoch_losses): per-epoch training accuracy in percent and
        mean per-batch training loss, both as lists of Python floats.
    """
    best_acc = -np.inf  # best validation accuracy seen so far
    epoch_losses = []
    epoch_accs = []
    writer = SummaryWriter(comment=f'_q_kmnist_lambda_{lmbda}')
    n_batches = len(trainloader)
    try:
        for epoch in range(epochs):  # loop over the dataset multiple times
            # The previous epoch's evaluation may have switched the net to eval
            # mode; BatchNorm layers must be in training mode while fitting.
            net.train()

            epoch_loss = 0.0
            n_correct = 0
            total = 0
            # Iterating the tqdm wrapper advances the bar by one per batch, so
            # no manual pbar.update() is needed.
            pbar = tqdm(enumerate(trainloader), total=n_batches)
            for i, batch in pbar:
                inputs, labels = batch[0], batch[1]
                if defense:
                    inputs = QuantizeBatch(inputs)

                inputs = inputs.to(net.device)
                labels = labels.to(net.device)
                # Presumably needed because loss_fn(regularize_grads=True)
                # differentiates w.r.t. the inputs — confirm in model.py.
                inputs.requires_grad_(True)

                net.optimizer.zero_grad()
                outputs = net(inputs)
                loss = net.loss_fn(outputs, labels, inputs, regularize_grads=True)
                loss.backward()
                net.optimizer.step()

                # Use the real batch size: the last batch may be smaller.
                batch_size = labels.size(0)
                batch_correct = (torch.argmax(outputs, dim=1) == labels).sum().item()
                batch_loss = loss.item()

                n_correct += batch_correct
                epoch_loss += batch_loss
                total += batch_size

                pbar.set_description(
                    f'Epoch: {epoch}, Done {(i + 1) / n_batches * 100}%, '
                    f'Loss: {batch_loss}, Accuracy: {batch_correct / batch_size * 100}%'
                )
            pbar.close()

            val_acc, val_loss = test(testloader)

            epoch_loss /= n_batches
            epoch_acc = n_correct / total * 100

            net.scheduler.step(val_loss)

            epoch_losses.append(epoch_loss)
            epoch_accs.append(epoch_acc)

            writer.add_scalar('Loss/train', epoch_loss, epoch)
            writer.add_scalar('Accuracy/train', epoch_acc, epoch)
            writer.add_scalar('Loss/test', val_loss, epoch)
            writer.add_scalar('Accuracy/test', val_acc, epoch)

            # best_acc only drives this message; no checkpoint is written here.
            if val_acc > best_acc:
                best_acc = val_acc
            else:
                print(f'Val acc did not improve from {best_acc}')
    finally:
        # close() flushes pending events and releases the event file, even if
        # training is interrupted.
        writer.close()

    print('Finished Training')
    return epoch_accs, epoch_losses

def test(dataloader):
    correct = 0.
    total = 0.
    loss = 0.
    with torch.no_grad():
        for data in dataloader:
            images, labels = data[0].to(net.device), data[1].to(net.device)
            outputs = net(images)
            run_loss = net.loss_fn(outputs, labels, images, regularize_grads=False).item()
            loss += run_loss
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            run_acc = (predicted == labels).sum().item()
            correct += run_acc

    return correct / total * 100, loss / len(dataloader)
    

re_param: 100.0
Net(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (bn1): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=10, bias=True)
  (act): Softmax(dim=-1)
)


In [4]:
# Train for 10 epochs, then report the final model's test accuracy and loss.
accs, losses = train(epochs=10)
acc, loss = test(dataloader=testloader)
print('Test acc: {}%, test loss: {}'.format(acc, loss))

Epoch: 0, Done 100.0%, Loss: 0.04418201372027397, Accuracy: 100.0%: : 1875it [06:37,  4.72it/s]               
Epoch: 1, Done 100.0%, Loss: 0.06483405083417892, Accuracy: 100.0%: : 1875it [06:39,  4.69it/s]               


Val acc did not improve from 95.03


Epoch: 2, Done 100.0%, Loss: 0.14885002374649048, Accuracy: 96.875%: : 1875it [06:55,  4.51it/s]              
Epoch: 3, Done 100.0%, Loss: 0.20340050756931305, Accuracy: 96.875%: : 1875it [06:40,  4.68it/s]              
Epoch: 4, Done 100.0%, Loss: 0.3872072100639343, Accuracy: 93.75%: : 1875it [06:37,  4.71it/s]                


Val acc did not improve from 95.65


Epoch: 5, Done 100.0%, Loss: 0.07207611203193665, Accuracy: 100.0%: : 1875it [07:00,  4.46it/s]               


Val acc did not improve from 95.65


Epoch: 6, Done 100.0%, Loss: 0.36319953203201294, Accuracy: 96.875%: : 1875it [06:34,  4.76it/s]              


Val acc did not improve from 95.65


Epoch: 7, Done 100.0%, Loss: 0.080357126891613, Accuracy: 100.0%: : 1875it [06:46,  4.62it/s]                 


Val acc did not improve from 95.65


Epoch: 8, Done 100.0%, Loss: 0.09410664439201355, Accuracy: 100.0%: : 1875it [07:27,  4.19it/s]               


Val acc did not improve from 95.65


Epoch: 9, Done 100.0%, Loss: 0.20304983854293823, Accuracy: 100.0%: : 1875it [06:53,  4.54it/s]               


Val acc did not improve from 95.65
Finished Training
Test acc: 94.69999999999999%, test loss: 0.19378279700589637


In [6]:
# Save under a "q_"-prefixed name so the kmnist_lambda... checkpoint loaded
# earlier is not overwritten.
net.save(name='q_kmnist_lambda' + str(net.loss_fn.hp))