### Load Libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision                import datasets, transforms
from torchvision.datasets.utils import download_url
from torchvision.datasets       import ImageFolder
from torch.optim.lr_scheduler   import StepLR
from pytorch_model_summary      import summary
import torch.utils.data as data
import time
import random
import numpy as np
import os

In [2]:
batch_size = 64
epochs     = 50
lr         = 0.01
device     = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data_path = "./CIFAR10"
temperature = 4
alpha       = 0.95
print(device)

cuda


In [3]:
def count_params(model):
    return sum(p.numel() for p in model.parameters())

In [4]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [5]:
def calculate_accuracy(y_pred, y):
    top_pred = y_pred.argmax(1, keepdim = True)
    correct = top_pred.eq(y.view_as(top_pred)).sum()
    acc = correct.float() / y.shape[0]
    return acc

### Teacher Model

In [6]:
class CNN_Net(nn.Module):
    def __init__(self):
        super(CNN_Net, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
            
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(1024, 512),
            nn.Dropout(p=0.2),
            nn.ReLU(inplace=True),
            
            nn.Linear(512, 10)
        )
    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)
        x = self.fc_layers(x)
        output = F.log_softmax(x, dim=1)
        return output, x

In [7]:
teacher = CNN_Net().to(device)
teacher.load_state_dict(torch.load('./models/CIFAR10_CNN_teacher.pt'))
teacher = teacher.eval()

### Student Model

In [8]:
class CNN_Student(nn.Module):
    def __init__(self):
        super(CNN_Student, self).__init__()
        self.conv_layers = nn.Sequential(
        nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5),
        nn.ReLU(inplace=True),

        nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5),
        nn.ReLU(inplace=True),
            
        nn.MaxPool2d(kernel_size=3, stride=3),
            
        nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3),
        nn.BatchNorm2d(128),
        nn.ReLU(inplace=True),
            
        nn.MaxPool2d(kernel_size=2, stride=2))
        
        self.fc_layers = nn.Sequential(
            nn.Linear(1152, 10))
        
    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)
        logits = self.fc_layers(x)
        y_pred = F.log_softmax(logits, dim=1)
        return y_pred, logits
    
def distillation_loss(logits_stu, logits_base, T):
    logits_stu  = logits_stu  / T
    logits_base = logits_base / T
    pred_stu    = F.log_softmax(logits_stu,  dim=1)
    prop_base   = nn.Softmax(dim=1) (logits_base)
    pred_base   = torch.argmax(prop_base, dim=1)
    loss        = F.nll_loss(pred_stu, pred_base, reduction='sum')
    return loss
    
def train(model, iterator , optimizer, teacher, device, T, alpha):    
    epoch_loss = 0
    epoch_acc = 0    
    model = model.train()
    teacher = teacher.eval()
    for (x, y) in iterator:        
        x = x.to(device)
        y = y.to(device)        
        optimizer.zero_grad()         
        if teacher is not None:
            y_pred, logits_pred = model(x)        
            y_teacher, logits_teacher = teacher(x)
            dist_loss = distillation_loss(logits_pred,logits_teacher,T)            
            stu_loss  = F.nll_loss(y_pred, y, reduction='sum')     
            loss      = alpha * dist_loss + (1-alpha) * stu_loss
        else:
            y_pred, _ = model(x)        
            loss = F.nll_loss(y_pred, y, reduction='sum')   
            
        acc = calculate_accuracy(y_pred, y)        
        loss.backward()        
        optimizer.step()        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

def evaluate(model, iterator, device):    
    epoch_loss = 0
    epoch_acc = 0
    model.eval()    
    with torch.no_grad():        
        for (x, y) in iterator:
            x = x.to(device)
            y = y.to(device)
            y_pred, _ = model(x)
            loss = F.nll_loss(y_pred, y, reduction='sum')
            acc = calculate_accuracy(y_pred, y)
            epoch_loss += loss.item()
            epoch_acc += acc.item()        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

### Initialize & Define loaders

In [9]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=5),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

In [10]:
train_data = datasets.CIFAR10("CIFAR10", train=True, download=True, transform=transform_train)
valid_data = datasets.CIFAR10("CIFAR10", train=False, transform=transform_test)
    
#print summary
print(f'Number of training examples: {len(train_data)}')
print(f'Number of validation examples: {len(valid_data)}')

Files already downloaded and verified
Number of training examples: 50000
Number of validation examples: 10000


In [11]:
train_loader  = torch.utils.data.DataLoader(train_data, batch_size = batch_size, shuffle= True, pin_memory=True)
valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size = batch_size)

### Create model

In [12]:
#dist_model = CNN_Student().to(device)
#optimizer = optim.SGD(dist_model.parameters(), lr = lr)

In [13]:
#print(f'The model has {count_params(dist_model):,} trainable parameters')

In [14]:
def run (T):
    best_valid_acc  = 0
    dist_model = CNN_Student().to(device)
    optimizer = optim.SGD(dist_model.parameters(), lr = lr)

    for epoch in range(1, epochs + 1):    
        start_time = time.monotonic()
        
        train_loss, train_acc = train(dist_model, train_loader, optimizer,teacher, device, T, alpha)
        valid_loss, valid_acc = evaluate(dist_model, valid_loader, device)

        if valid_acc > best_valid_acc:
            best_valid_acc  = valid_acc
            torch.save(dist_model.state_dict(), './models/CIFAR10_CNN_dist T = ' + str(T)+ '_.pt')

        end_time = time.monotonic()

        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        print(f'Epoch: {epoch:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
        print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}% |  Best Val. Acc: {best_valid_acc*100:.2f}%')

    return best_valid_acc

In [15]:
Ts  = [7.0] #[4.0, 5.0, 6.0, 7.0, 8.0, 10.0, 15.0, 20.0]
acc = []

for temperature in Ts:
    print (f"-------------------- Processing with T = {temperature} --------------------")
    acc.append(run(temperature))


-------------------- Processing with T = 7.0 --------------------
Epoch: 01 | Epoch Time: 0m 21s
	Train Loss: 107.964 | Train Acc: 41.55%
	 Val. Loss: 197.922 |  Val. Acc: 56.48% |  Best Val. Acc: 56.48%
Epoch: 02 | Epoch Time: 0m 19s
	Train Loss: 84.786 | Train Acc: 56.34%
	 Val. Loss: 189.551 |  Val. Acc: 63.61% |  Best Val. Acc: 63.61%
Epoch: 03 | Epoch Time: 0m 19s
	Train Loss: 74.090 | Train Acc: 62.38%
	 Val. Loss: 176.124 |  Val. Acc: 67.46% |  Best Val. Acc: 67.46%
Epoch: 04 | Epoch Time: 0m 19s
	Train Loss: 67.289 | Train Acc: 65.87%
	 Val. Loss: 194.634 |  Val. Acc: 68.20% |  Best Val. Acc: 68.20%
Epoch: 05 | Epoch Time: 0m 19s
	Train Loss: 63.371 | Train Acc: 68.05%
	 Val. Loss: 233.073 |  Val. Acc: 67.76% |  Best Val. Acc: 68.20%
Epoch: 06 | Epoch Time: 0m 19s
	Train Loss: 59.693 | Train Acc: 70.04%
	 Val. Loss: 207.873 |  Val. Acc: 69.79% |  Best Val. Acc: 69.79%
Epoch: 07 | Epoch Time: 0m 19s
	Train Loss: 57.035 | Train Acc: 71.31%
	 Val. Loss: 174.706 |  Val. Acc: 73.66%

### Load and re-run tests

In [17]:
dist_model = CNN_Student().to(device)
dist_model.load_state_dict(torch.load('./models/CIFAR10_CNN_dist T = 7.0_.pt'))
valid_loss, valid_acc = evaluate(dist_model, valid_loader, device)
print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

	 Val. Loss: 137.401 |  Val. Acc: 82.08%
