# Task2 - Classifier
This Jupyter notebook contains the code for training a CNN-based classifier on FashionMNIST (images resized to 32x32).

In [7]:
import os
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
from torchvision import transforms as trafos

from einops import rearrange

import time

import matplotlib.pyplot as plt
%matplotlib inline

In [8]:
#dataset: FashionMNIST, upsampled to 32x32; Normalize(0.5, 0.5) maps ToTensor's [0, 1] range to [-1, 1]
batch_size = 64
# Cap worker processes at the number of available CPUs: asking DataLoader for more
# workers than cores oversubscribes the machine (and triggers a DataLoader warning).
# os.cpu_count() may return None, hence the `or 1`.
num_workers = min(16, os.cpu_count() or 1)

# Normalize expects per-channel sequences: (0.5,) is a 1-tuple, whereas (0.5) is just a float.
trafo_train = trafos.Compose([
    trafos.RandomHorizontalFlip(),
    trafos.ToTensor(),
    trafos.Resize(32),
    trafos.Normalize((0.5,), (0.5,)),
])
trafo_test = trafos.Compose([
    trafos.ToTensor(),
    trafos.Resize(32),
    trafos.Normalize((0.5,), (0.5,)),
])

trainset = torchvision.datasets.FashionMNIST(root="datasets", train=True, download=True, transform=trafo_train)
testset = torchvision.datasets.FashionMNIST(root="datasets", train=False, download=True, transform=trafo_test)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, num_workers=num_workers, drop_last=False, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, num_workers=num_workers, drop_last=False, shuffle=False)

In [9]:
#Simple CNN for classification
class Classifier(nn.Module):
    """Small CNN classifier whose penultimate layer doubles as a feature extractor.

    Architecture: four Conv2d -> BatchNorm2d -> LeakyReLU(0.2) stages (the last
    three use kernel 4, stride 2, padding 1, which halves even spatial sizes),
    global average pooling over the spatial dimensions, a Linear(4*nf -> 100) +
    LeakyReLU feature layer, and a Linear(100 -> num_classes) head.

    Args:
        nf: number of filters in the first conv stage; later stages use 2*nf and 4*nf.
        num_classes: number of output logits.
        in_channels: number of channels of the input images. The default 1 matches
            the single-channel images used in this notebook; pass e.g. 3 for RGB data.
    """

    def __init__(self, nf=64, num_classes=10, in_channels=1):
        super().__init__()

        # Submodule names and order are kept fixed so state_dict keys (and any
        # saved checkpoints) stay compatible.
        block = [nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(nf), nn.LeakyReLU(0.2, True),
                 nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1), nn.BatchNorm2d(nf*2), nn.LeakyReLU(0.2, True),
                 nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1), nn.BatchNorm2d(nf*4), nn.LeakyReLU(0.2, True),
                 nn.Conv2d(nf*4, nf*4, kernel_size=4, stride=2, padding=1), nn.BatchNorm2d(nf*4), nn.LeakyReLU(0.2, True)]
        self.block = nn.Sequential(*block)

        feature = [nn.Linear(nf*4, 100), nn.LeakyReLU(0.2, True)]
        self.feature = nn.Sequential(*feature)

        self.final = nn.Linear(100, num_classes)

    def get_features(self, x):
        """Return the 100-dim penultimate activations, so the trained classifier can serve as a feature extractor.

        Args:
            x: image batch of shape (batch, in_channels, H, W).

        Returns:
            Tensor of shape (batch, 100).
        """
        x = self.block(x)
        # Global average pooling over H and W -> (batch, 4*nf); this makes the
        # head independent of the exact spatial input size.
        x = torch.mean(x, dim=(2, 3))
        return self.feature(x)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes)."""
        return self.final(self.get_features(x))

In [10]:
def print_parameters(net, name):
    """Print how many parameters `net` has, in millions, prefixed by `name`.

    Args:
        net: a module exposing `.parameters()` (e.g. an nn.Module).
        name: label used at the start of the printed line.
    """
    total = sum(param.numel() for param in net.parameters())
    message = "{} has {:.3f}M parameters".format(name, total / 1e6)
    print(message)

# Prefer the second GPU (the original configuration), but fall back to the first
# GPU or the CPU so the notebook also runs on machines with fewer devices.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

classifier = Classifier().to(device)
print_parameters(classifier, "Classifier")

# Base learning rate; the training loop also reads `lr` when annealing.
lr = 1e-4
optim = torch.optim.Adam(classifier.parameters(), lr=lr, betas=(0.5, 0.999))
criterion = nn.CrossEntropyLoss()

Classifier has 1.733M parameters


In [11]:
epoch_decay = 10  # number of epochs after which linear learning-rate decay kicks in
num_epochs = 30

losses_avg_train = []
losses_avg_test = []
acc_train = []
acc_test = []

best_acc = 0.0


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run `model` once over `loader` and return (mean batch loss, accuracy in %).

    When `optimizer` is given the model is put in train mode and updated after
    every batch; otherwise it is evaluated in eval mode with autograd disabled.
    """
    training = optimizer is not None
    model.train(training)
    batch_losses = []
    num_total = 0
    num_right = 0
    # The evaluation pass needs no gradients; disabling autograd avoids building
    # (and keeping alive) a computation graph for every test batch.
    with torch.set_grad_enabled(training):
        for images, targets in loader:
            images = images.to(device)
            targets = targets.to(device)

            out = model(images)
            loss = criterion(out, targets)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            pred = out.argmax(1)
            num_total += images.size(0)
            num_right += (pred == targets).sum().item()
            batch_losses.append(loss.item())
    return np.mean(batch_losses), num_right / num_total * 100.0


def _plot_curves(train_values, test_values, title, ylabel, path):
    """Save train and test curves (x axis: epochs 1..N) to `path`."""
    epochs = np.arange(1, len(train_values) + 1)
    fig, ax = plt.subplots(dpi=300)
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.plot(epochs, train_values, label="Train")
    ax.plot(epochs, test_values, label="Test")
    ax.legend()
    fig.savefig(path)
    plt.close(fig)  # free the figure; one is created per epoch


for epoch in range(1, num_epochs + 1):
    start = time.time()

    #train
    train_loss, train_acc = _run_epoch(classifier, trainloader, criterion, device, optimizer=optim)
    losses_avg_train.append(train_loss)
    acc_train.append(train_acc)

    #test
    test_loss, test_acc = _run_epoch(classifier, testloader, criterion, device)
    losses_avg_test.append(test_loss)
    acc_test.append(test_acc)

    #plot losses and accuracies
    _plot_curves(losses_avg_train, losses_avg_test, "Classification Losses", "Loss", "losses_classifier.png")
    _plot_curves(acc_train, acc_test, "Classification Accuracy", "Accuracy [%]", "accuracy_classifier.png")

    #save model with the best test accuracy so far
    # NOTE(review): selecting the checkpoint on the test set makes the reported test
    # accuracy optimistically biased; consider a held-out validation split.
    if acc_test[-1] > best_acc:
        best_acc = acc_test[-1]
        torch.save(classifier.state_dict(), "classifier.ckpt")

    #anneal learning rate linearly once past `epoch_decay`; the value set here applies to the next epoch
    if epoch > epoch_decay:
        nlrg = lr * (1.0 - (epoch - epoch_decay) / float(num_epochs - epoch_decay + 1))  # new learning rate
        for param_group in optim.param_groups:
            param_group['lr'] = nlrg

    print("End of epoch {}. (Time taken: {:.3f}s)".format(epoch, time.time() - start))

End of epoch 1. (Time taken: 9.944s)
End of epoch 2. (Time taken: 14.720s)
End of epoch 3. (Time taken: 11.332s)
End of epoch 4. (Time taken: 11.246s)
End of epoch 5. (Time taken: 10.890s)
End of epoch 6. (Time taken: 10.895s)
End of epoch 7. (Time taken: 12.518s)
End of epoch 8. (Time taken: 12.881s)
End of epoch 9. (Time taken: 15.645s)
End of epoch 10. (Time taken: 18.293s)
End of epoch 11. (Time taken: 16.362s)
End of epoch 12. (Time taken: 13.164s)
End of epoch 13. (Time taken: 10.388s)
End of epoch 14. (Time taken: 10.899s)
End of epoch 15. (Time taken: 12.339s)
End of epoch 16. (Time taken: 11.186s)
End of epoch 17. (Time taken: 12.709s)
End of epoch 18. (Time taken: 8.556s)
End of epoch 19. (Time taken: 12.482s)
End of epoch 20. (Time taken: 12.003s)
End of epoch 21. (Time taken: 12.062s)
End of epoch 22. (Time taken: 12.527s)
End of epoch 23. (Time taken: 11.788s)
End of epoch 24. (Time taken: 11.405s)
End of epoch 25. (Time taken: 12.820s)
End of epoch 26. (Time taken: 5.981s

: 