### Load Libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision                import datasets, transforms
from torchvision.datasets.utils import download_url
from torchvision.datasets       import ImageFolder
from torch.optim.lr_scheduler   import StepLR
from pytorch_model_summary      import summary
import torch.utils.data as data
import time
import random
import numpy as np
import os

### Define params

In [2]:
batch_size = 64
epochs     = 50
lr         = 0.01
device     = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data_path = "./CIFAR10"
print(device)

cuda


### Some utility functions

In [3]:
def count_params(model):
    return sum(p.numel() for p in model.parameters())

In [4]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [5]:
def calculate_accuracy(y_pred, y):
    top_pred = y_pred.argmax(1, keepdim = True)
    correct = top_pred.eq(y.view_as(top_pred)).sum()
    acc = correct.float() / y.shape[0]
    return acc

### define network

In [6]:
class CNN_Net(nn.Module):
    def __init__(self):
        super(CNN_Net, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
            
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(1024, 512),
            nn.Dropout(p=0.2),
            nn.ReLU(inplace=True),
            
            nn.Linear(512, 10)
        )
    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)
        x = self.fc_layers(x)
        output = F.log_softmax(x, dim=1)
        return output, x


def train(model, train_loader, optimizer, device):    
    epoch_loss = 0
    epoch_acc = 0    
    model.train()    
    for (x, y) in train_loader:        
        x = x.to(device)
        y = y.to(device)        
        optimizer.zero_grad()                
        y_pred, _ = model(x)        
        loss = F.nll_loss(y_pred, y, reduction='mean')        
        acc = calculate_accuracy(y_pred, y)        
        loss.backward()        
        optimizer.step()     
       
        epoch_loss += loss.item()
        epoch_acc += acc.item()
    
    return epoch_loss / len(train_loader), epoch_acc / len(train_loader)


def evaluate(model, test_loader, device):    
    epoch_loss = 0
    epoch_acc = 0
    model.eval()    
    with torch.no_grad():        
        for (x, y) in test_loader:
            x = x.to(device)
            y = y.to(device)
            y_pred, _ = model(x)
            loss = F.nll_loss(y_pred, y, reduction='sum')
            acc = calculate_accuracy(y_pred, y)
            epoch_loss += loss.item()
            epoch_acc += acc.item()        
    return epoch_loss / len(test_loader), epoch_acc / len(test_loader)


### Initialize & Define loaders

In [7]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding = 5),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

In [8]:
train_data = datasets.CIFAR10("CIFAR10", train=True, download=True, transform=transform_train)
valid_data = datasets.CIFAR10("CIFAR10", train=False, transform=transform_test)
    
#print summary
print(f'Number of training examples:   {len(train_data)}')
print(f'Number of validation examples: {len(valid_data)}')

Files already downloaded and verified
Number of training examples:   50000
Number of validation examples: 10000


In [9]:
train_loader  = torch.utils.data.DataLoader(train_data, batch_size = batch_size, shuffle= True, pin_memory=True)
valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size = batch_size)

### Create model

In [10]:
teacher_model = CNN_Net().to(device)
optimizer = optim.SGD(teacher_model.parameters(), lr = lr)

In [11]:
print(f'The model has {count_params(teacher_model):,} trainable parameters')

The model has 6,822,154 trainable parameters


### Train Teacher

In [12]:
best_valid_acc  = 0.0

In [13]:
for epoch in range(1, epochs + 1):    
    start_time = time.monotonic()
    
    train_loss, train_acc = train(teacher_model, train_loader, optimizer, device)
    valid_loss, valid_acc = evaluate(teacher_model, valid_loader, device)
    
    if valid_acc > best_valid_acc:
        best_valid_acc  = valid_acc
        torch.save(teacher_model.state_dict(), './models/CIFAR10_CNN_teacher.pt')
    
    end_time = time.monotonic()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    print(f'Epoch: {epoch:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}% |  Best Val. Acc: {best_valid_acc*100:.2f}%')

Epoch: 01 | Epoch Time: 0m 30s
	Train Loss: 1.471 | Train Acc: 46.13%
	 Val. Loss: 81.776 |  Val. Acc: 53.98% |  Best Val. Acc: 53.98%
Epoch: 02 | Epoch Time: 0m 27s
	Train Loss: 1.095 | Train Acc: 60.62%
	 Val. Loss: 87.091 |  Val. Acc: 53.29% |  Best Val. Acc: 53.98%
Epoch: 03 | Epoch Time: 0m 27s
	Train Loss: 0.919 | Train Acc: 67.33%
	 Val. Loss: 67.519 |  Val. Acc: 63.82% |  Best Val. Acc: 63.82%
Epoch: 04 | Epoch Time: 0m 27s
	Train Loss: 0.819 | Train Acc: 70.98%
	 Val. Loss: 53.725 |  Val. Acc: 70.52% |  Best Val. Acc: 70.52%
Epoch: 05 | Epoch Time: 0m 27s
	Train Loss: 0.746 | Train Acc: 73.61%
	 Val. Loss: 56.125 |  Val. Acc: 69.52% |  Best Val. Acc: 70.52%
Epoch: 06 | Epoch Time: 0m 27s
	Train Loss: 0.689 | Train Acc: 75.96%
	 Val. Loss: 64.995 |  Val. Acc: 67.66% |  Best Val. Acc: 70.52%
Epoch: 07 | Epoch Time: 0m 27s
	Train Loss: 0.648 | Train Acc: 77.42%
	 Val. Loss: 50.372 |  Val. Acc: 72.97% |  Best Val. Acc: 72.97%
Epoch: 08 | Epoch Time: 0m 27s
	Train Loss: 0.611 | Tra

### Load and re-run tests

In [14]:
teacher_model.load_state_dict(torch.load('./models/CIFAR10_CNN_teacher.pt'))
valid_loss, valid_acc = evaluate(teacher_model, valid_loader, device)
print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

	 Val. Loss: 26.294 |  Val. Acc: 87.31%
