### Load Libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision                import datasets, transforms
from torchvision.datasets.utils import download_url
from torchvision.datasets       import ImageFolder
from torch.optim.lr_scheduler   import StepLR
from pytorch_model_summary      import summary
import torch.utils.data as data
import time
import random
import numpy as np
import os

### Define params

In [2]:
batch_size = 64
epochs     = 50
lr         = 0.01
device     = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data_path = "./CIFAR10"
print(device)

cuda


### Some utility functions

In [3]:
def count_params(model):
    return sum(p.numel() for p in model.parameters())

In [4]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [5]:
def calculate_accuracy(y_pred, y):
    top_pred = y_pred.argmax(1, keepdim = True)
    correct = top_pred.eq(y.view_as(top_pred)).sum()
    acc = correct.float() / y.shape[0]
    return acc

### define network

In [6]:
class CNN_Student(nn.Module):
    def __init__(self):
        super(CNN_Student, self).__init__()
        self.conv_layers = nn.Sequential(
        nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5),
        nn.ReLU(inplace=True),

        nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5),
        nn.ReLU(inplace=True),
            
        nn.MaxPool2d(kernel_size=3, stride=3),
            
        nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3),
        nn.BatchNorm2d(128),
        nn.ReLU(inplace=True),
            
        nn.MaxPool2d(kernel_size=2, stride=2))
        
        self.fc_layers = nn.Sequential(
            nn.Linear(1152, 10))
        
    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)
        logits = self.fc_layers(x)
        y_pred = F.log_softmax(logits, dim=1)
        return y_pred, logits

def train(model, train_loader, optimizer, device):    
    epoch_loss = 0
    epoch_acc = 0    
    model.train()    
    for (x, y) in train_loader:        
        x = x.to(device)
        y = y.to(device)        
        optimizer.zero_grad()                
        y_pred, _ = model(x)        
        loss = F.nll_loss(y_pred, y, reduction='sum')        
        acc = calculate_accuracy(y_pred, y)        
        loss.backward()        
        optimizer.step()            
        epoch_loss += loss.item()
        epoch_acc += acc.item()
    
    return epoch_loss / len(train_loader), epoch_acc / len(train_loader)

def evaluate(model, iterator, device):
    
    epoch_loss = 0
    epoch_acc = 0
    model.eval()    
    with torch.no_grad():        
        for (x, y) in iterator:
            x = x.to(device)
            y = y.to(device)
            y_pred, _ = model(x)
            loss = F.nll_loss(y_pred, y, reduction='sum')
            acc = calculate_accuracy(y_pred, y)
            epoch_loss += loss.item()
            epoch_acc += acc.item()        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

### Initialize & Define loaders

In [7]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding = 5),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

In [8]:
train_data = datasets.CIFAR10("CIFAR10", train=True, download=True, transform=transform_train)
valid_data = datasets.CIFAR10("CIFAR10", train=False, transform=transform_test)
    
#print summary
print(f'Number of training examples: {len(train_data)}')
print(f'Number of validation examples: {len(valid_data)}')

Files already downloaded and verified
Number of training examples: 50000
Number of validation examples: 10000


In [9]:
train_loader  = torch.utils.data.DataLoader(train_data, batch_size = batch_size, shuffle= True, pin_memory=True)
valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size = batch_size)

### Create model

In [10]:
student_model = CNN_Student().to(device)
optimizer = optim.SGD(student_model.parameters(), lr = lr)

In [11]:
print(f'The model has {count_params(student_model):,} trainable parameters')

The model has 139,338 trainable parameters


### Train Student

In [12]:
best_valid_acc  = 0

In [13]:
for epoch in range(1, epochs + 1):    
    start_time = time.monotonic()
    
    train_loss, train_acc = train(student_model, train_loader, optimizer, device)
    valid_loss, valid_acc = evaluate(student_model, valid_loader, device)
    
    if valid_acc > best_valid_acc:
        best_valid_acc  = valid_acc
        torch.save(student_model.state_dict(), './models/CIFAR10_CNN_student.pt')
    
    end_time = time.monotonic()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    print(f'Epoch: {epoch:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}% |  Best Val. Acc: {best_valid_acc*100:.2f}%')

Epoch: 01 | Epoch Time: 0m 19s
	Train Loss: 138.320 | Train Acc: 22.91%
	 Val. Loss: 113.878 |  Val. Acc: 35.10% |  Best Val. Acc: 35.10%
Epoch: 02 | Epoch Time: 0m 17s
	Train Loss: 117.081 | Train Acc: 32.43%
	 Val. Loss: 105.936 |  Val. Acc: 38.38% |  Best Val. Acc: 38.38%
Epoch: 03 | Epoch Time: 0m 17s
	Train Loss: 111.135 | Train Acc: 35.19%
	 Val. Loss: 97.364 |  Val. Acc: 43.09% |  Best Val. Acc: 43.09%
Epoch: 04 | Epoch Time: 0m 17s
	Train Loss: 105.785 | Train Acc: 38.44%
	 Val. Loss: 118.375 |  Val. Acc: 36.82% |  Best Val. Acc: 43.09%
Epoch: 05 | Epoch Time: 0m 17s
	Train Loss: 99.577 | Train Acc: 42.65%
	 Val. Loss: 88.782 |  Val. Acc: 48.02% |  Best Val. Acc: 48.02%
Epoch: 06 | Epoch Time: 0m 17s
	Train Loss: 94.694 | Train Acc: 46.18%
	 Val. Loss: 92.545 |  Val. Acc: 47.93% |  Best Val. Acc: 48.02%
Epoch: 07 | Epoch Time: 0m 17s
	Train Loss: 90.437 | Train Acc: 49.02%
	 Val. Loss: 92.123 |  Val. Acc: 49.61% |  Best Val. Acc: 49.61%
Epoch: 08 | Epoch Time: 0m 17s
	Train Los

### Load and re-run tests

In [14]:
student_model.load_state_dict(torch.load('./models/CIFAR10_CNN_student.pt'))
valid_loss, valid_acc = evaluate(student_model, valid_loader, device)
print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}% |  Best Val. Acc: {best_valid_acc*100:.2f}%')

	 Val. Loss: 47.003 |  Val. Acc: 74.78% |  Best Val. Acc: 74.78%
