In [2]:
import time
import copy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision import datasets, transforms
from torchsummary import summary
from torch.optim import lr_scheduler
import torch.nn.functional as F
import torch.nn as nn
import torchvision.models as models
from torch import nn, optim
from google.colab import files


In [3]:
BATCH_SIZE = 64
DIR = 'content'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
ALPHA = 0.5
TEMP = 1

In [4]:
transform = transforms.Compose([transforms.Resize((224,224)), #resizing images to match ResNet architecture
                                transforms.ToTensor(),
                                transforms.Normalize([0.485,0.456,  #using mean and std from ImageNet to match ResNet input data scaling
                                0.406], [0.229, 0.224, 0.225])])
trainset = datasets.CIFAR10(f'/DIR/train/', download=True, train=True, transform=transform)
valset = datasets.CIFAR10(f'/DIR/val/', download=True, train=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
valloader = torch.utils.data.DataLoader(valset, batch_size=BATCH_SIZE, shuffle=True)
len_trainset = len(trainset)
len_valset = len(valset)
classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
dataiter = iter(trainloader)


Files already downloaded and verified
Files already downloaded and verified


In [5]:
resnet = models.resnet50(pretrained=True) #setting up ResNet50 the teacher model
for param in resnet.parameters():
   param.requires_grad = False
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(num_ftrs, 10)
resnet = resnet.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(resnet.fc.parameters())

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 67.5MB/s]


In [28]:
def train_and_evaluate(model, trainloader, valloader, criterion, optimizer, len_trainset, len_valset, num_epochs=25):
	model.train()
	best_model_wts = copy.deepcopy(model.state_dict())
	best_acc = 0.0
	for epoch in range(num_epochs):
		model.train()
		print(f'Epoch {epoch}/{num_epochs-1}')
		print('-' * 10)
		running_loss = 0.0
		running_corrects = 0
		for inputs, labels in trainloader:
			inputs = inputs.to(device)
			labels = labels.to(device)
			optimizer.zero_grad()
			outputs = model(inputs)
			loss = criterion(outputs, labels)
			_, preds = torch.max(outputs, 1)
			loss.backward()
			optimizer.step()
			running_loss += loss.item() * inputs.size(0)
			running_corrects += torch.sum(preds == labels.data)
		epoch_loss = running_loss / len_trainset
		epoch_acc = running_corrects.double() / len_trainset
		print(' Train Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss,
				 epoch_acc))

		model.eval()
		running_loss_val = 0.0
		running_corrects_val = 0
		for inputs, labels in valloader:
			inputs = inputs.to(device)
			labels = labels.to(device)
			outputs = model(inputs)
			loss = criterion(outputs,labels)
			_, preds = torch.max(outputs, 1)
			running_loss_val += loss.item() * inputs.size(0)
			running_corrects_val += torch.sum(preds == labels.data)

		epoch_loss_val = running_loss_val / len_valset
		epoch_acc_val = running_corrects_val.double() / len_valset

		if epoch_acc_val > best_acc:
			best_acc = epoch_acc_val
			best_model_wts = copy.deepcopy(model.state_dict())

		print(' Val Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss_val,
				 epoch_acc_val))

		print()
		print('Best val Acc: {:4f}'.format(best_acc))
		model.load_state_dict(best_model_wts)
	return model

In [30]:
resnet_teacher = train_and_evaluate(resnet,trainloader,
                                   valloader,criterion,optimizer,
                                   len_trainset,len_valset,10)

Epoch 0/9
----------
 Train Loss: 0.7513 Acc: 0.7537
 Val Loss: 0.6128 Acc: 0.7914

Best val Acc: 0.791400
Epoch 1/9
----------
 Train Loss: 0.5895 Acc: 0.7943
 Val Loss: 0.5586 Acc: 0.8082

Best val Acc: 0.808200
Epoch 2/9
----------
 Train Loss: 0.5631 Acc: 0.8044
 Val Loss: 0.5537 Acc: 0.8093

Best val Acc: 0.809300
Epoch 3/9
----------
 Train Loss: 0.5494 Acc: 0.8100
 Val Loss: 0.5410 Acc: 0.8147

Best val Acc: 0.814700
Epoch 4/9
----------
 Train Loss: 0.5296 Acc: 0.8154
 Val Loss: 0.5221 Acc: 0.8192

Best val Acc: 0.819200
Epoch 5/9
----------
 Train Loss: 0.5240 Acc: 0.8189
 Val Loss: 0.6081 Acc: 0.7998

Best val Acc: 0.819200
Epoch 6/9
----------
 Train Loss: 0.5149 Acc: 0.8210
 Val Loss: 0.5523 Acc: 0.8133

Best val Acc: 0.819200
Epoch 7/9
----------
 Train Loss: 0.5151 Acc: 0.8218
 Val Loss: 0.5031 Acc: 0.8269

Best val Acc: 0.826900
Epoch 8/9
----------
 Train Loss: 0.5114 Acc: 0.8230
 Val Loss: 0.5341 Acc: 0.8201

Best val Acc: 0.826900
Epoch 9/9
----------
 Train Loss: 0.5

In [33]:
torch.save(resnet_teacher.state_dict(), 'resnet_teacher.pt')
files.download('resnet_teacher.pt')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [7]:
class Net(nn.Module):
	def __init__(self):
		super(Net, self).__init__()
		self.layer1 = nn.Sequential(
			nn.Conv2d(3, 64, kernel_size = (3,3), stride = (1,1),
			padding = (1,1)),
			nn.ReLU(inplace=True),
			nn.Conv2d(64, 64, kernel_size = (3,3), stride = (1,1),
			padding = (1,1)),
			nn.ReLU(inplace=True),
			nn.MaxPool2d(kernel_size=2, stride=2, padding=0,
			dilation=1, ceil_mode=False)
		)
		self.layer2 = nn.Sequential(
			nn.Conv2d(64, 128, kernel_size = (3,3), stride = (1,1),
			padding = (1,1)),
			nn.ReLU(inplace=True),
			nn.Conv2d(128, 128, kernel_size = (3,3), stride = (1,1),
			padding = (1,1)),
			nn.ReLU(inplace=True),
			nn.MaxPool2d(kernel_size=2, stride=2, padding=0,
			dilation=1, ceil_mode=False)
		)
		self.pool1 = nn.AdaptiveAvgPool2d(output_size=(1,1))
		self.fc1 = nn.Linear(128, 32)
		self.fc2 = nn.Linear(32, 10)
		self.dropout_rate = 0.5

	def forward(self, x):
		x = self.layer1(x)
		x = self.layer2(x)
		x = self.pool1(x)
		x = x.view(x.size(0), -1)
		x = self.fc1(x)
		x = self.fc2(x)
		return x
net = Net().to(device)

In [8]:
class SmallNet(nn.Module):
    def __init__(self, num_classes=10):
        super(SmallNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

small_net = SmallNet().to(device)


In [12]:
def loss_kd(outputs, labels, teacher_outputs, temp, alpha):
	KD_loss = nn.KLDivLoss()(F.log_softmax(outputs/temp,
				 dim=1),F.softmax(teacher_outputs/temp,dim=1)) * (alpha * temp * temp) + F.cross_entropy(outputs, labels) * (1 - alpha)
    #( torch.sum(F.softmax(teacher_outputs/temp,dim=-1) * (F.softmax(teacher_outputs/temp,dim=-1).log() - F.log_softmax(outputs/temp, dim=-1))) / soft_prob.size()[0] * (temp**2 * alpha) ) + F.cross_entropy(outputs, labels) * (1 - alpha)
	return KD_loss

def get_outputs(model, dataloader):
	outputs = []
	for inputs, labels in dataloader:
		inputs_batch, labels_batch = inputs.to(device), labels.to(device)
		output_batch = model(inputs_batch).data.cpu().numpy()
		outputs.append(output_batch)
	return outputs

In [10]:
def train_kd(model,teacher_out, optimizer, use_kd, loss_fn, dataloader, temp, alpha):
	model.train()
	running_loss = 0.0
	running_corrects = 0
	for i,(images, labels) in enumerate(dataloader):
		inputs = images.to(device)
		labels = labels.to(device)
		optimizer.zero_grad()
		outputs = model(inputs)
		if use_kd:
			outputs_teacher = torch.from_numpy(teacher_out[i]).to(device)
			loss = loss_fn(outputs,labels,outputs_teacher,temp,
							alpha)
		else:
			loss = loss_fn(outputs, labels)
		_, preds = torch.max(outputs, 1)
		loss.backward()
		optimizer.step()
		running_loss += loss.item() * inputs.size(0)
		running_corrects += torch.sum(preds == labels.data)

	epoch_loss = running_loss / len(trainset)
	epoch_acc = running_corrects.double() / len(trainset)
	print(' Train Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss,
			 epoch_acc))

def eval_kd(model,teacher_out, optimizer, use_kd, loss_fn, dataloader, temp, alpha):
	model.eval()
	running_loss = 0.0
	running_corrects = 0
	for i,(images, labels) in enumerate(dataloader):
		inputs = images.to(device)
		labels = labels.to(device)
		outputs = model(inputs)
		if use_kd:
			outputs_teacher = torch.from_numpy(teacher_out[i]).cuda()
			loss = loss_fn(outputs,labels,outputs_teacher,temp,
							alpha)
		else:
			loss = loss_fn(outputs, labels)
		_, preds = torch.max(outputs, 1)
		running_loss += loss.item() * inputs.size(0)
		running_corrects += torch.sum(preds == labels.data)
	epoch_loss = running_loss / len(valset)
	epoch_acc = running_corrects.double() / len(valset)
	print(' Val Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss,
			 epoch_acc))
	return epoch_acc

def train_and_evaluate_kd(model, teacher_model, optimizer, loss_fn, trainloader, valloader, temp, alpha, num_epochs=25):
	teacher_model.eval()
	best_model_wts = copy.deepcopy(model.state_dict())
	outputs_teacher_train = get_outputs(teacher_model, trainloader)
	outputs_teacher_val = get_outputs(teacher_model, valloader)
	print("Starting the training process")
	best_acc = 0.0
	for epoch in range(num_epochs):
		print('Epoch {}/{}'.format(epoch, num_epochs - 1))
		print('-' * 10)

		train_kd(model, outputs_teacher_train,
					optim.Adam(net.parameters()),True, loss_fn,trainloader,
					temp, alpha)

		# Evaluating the student network
		epoch_acc_val = eval_kd(model, outputs_teacher_val,
									optim.Adam(net.parameters()), True, loss_fn,
									valloader, temp, alpha)
		if epoch_acc_val > best_acc:
			best_acc = epoch_acc_val
			best_model_wts = copy.deepcopy(model.state_dict())
			print('Best val Acc: {:4f}'.format(best_acc))
			model.load_state_dict(best_model_wts)
	return model

def train_and_evaluate_no_kd(model, optimizer, loss_fn, trainloader, valloader, temp, alpha, num_epochs=25):
	best_model_wts = copy.deepcopy(model.state_dict())
	outputs_teacher_train = None
	outputs_teacher_val = None
	print("Starting the training process")
	best_acc = 0.0
	for epoch in range(num_epochs):
		print('Epoch {}/{}'.format(epoch, num_epochs - 1))
		print('-' * 10)

		train_kd(model, outputs_teacher_train,
					optim.Adam(net.parameters()),False, loss_fn,trainloader,
					temp, alpha)

		# Evaluating the student network
		epoch_acc_val = eval_kd(model, outputs_teacher_val,
									optim.Adam(net.parameters()), False, loss_fn,
									valloader, temp, alpha)
		if epoch_acc_val > best_acc:
			best_acc = epoch_acc_val
			best_model_wts = copy.deepcopy(model.state_dict())
			print('Best val Acc: {:4f}'.format(best_acc))
			model.load_state_dict(best_model_wts)
	return model


In [48]:
student_kd = train_and_evaluate_kd(net,resnet_teacher,optim.Adam(net.parameters()),loss_kd,trainloader,valloader,TEMP,ALPHA,10)
torch.save(student_kd.state_dict(), 'student_kd.pt')
files.download('student_kd.pt')

Starting the training process
Epoch 0/9
----------
 Train Loss: 0.9521 Acc: 0.4017
 Val Loss: 0.9186 Acc: 0.4436
Best val Acc: 0.443600
Epoch 1/9
----------
 Train Loss: 0.8998 Acc: 0.4533
 Val Loss: 0.9029 Acc: 0.4613
Best val Acc: 0.461300
Epoch 2/9
----------
 Train Loss: 0.8618 Acc: 0.4900
 Val Loss: 0.8489 Acc: 0.5084
Best val Acc: 0.508400
Epoch 3/9
----------
 Train Loss: 0.8292 Acc: 0.5192
 Val Loss: 0.8243 Acc: 0.5216
Best val Acc: 0.521600
Epoch 4/9
----------
 Train Loss: 0.8075 Acc: 0.5419
 Val Loss: 0.7901 Acc: 0.5546
Best val Acc: 0.554600
Epoch 5/9
----------
 Train Loss: 0.7902 Acc: 0.5560
 Val Loss: 0.7923 Acc: 0.5511
Epoch 6/9
----------
 Train Loss: 0.7788 Acc: 0.5649
 Val Loss: 0.7827 Acc: 0.5668
Best val Acc: 0.566800
Epoch 7/9
----------
 Train Loss: 0.7658 Acc: 0.5788
 Val Loss: 0.7532 Acc: 0.5848
Best val Acc: 0.584800
Epoch 8/9
----------
 Train Loss: 0.7567 Acc: 0.5867
 Val Loss: 0.7493 Acc: 0.5901
Best val Acc: 0.590100
Epoch 9/9
----------
 Train Loss: 0.748

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
student_no_kd = train_and_evaluate_no_kd(net,optim.Adam(net.parameters()),nn.CrossEntropyLoss(),trainloader,valloader,TEMP,ALPHA,10)
torch.save(student_no_kd.state_dict(), 'student_no_kd.pt')
files.download('student_no_kd.pt')

Starting the training process
Epoch 0/9
----------


In [None]:
!pip install gdown
!gdown https://drive.google.com/uc?id=1fvCetBJguJMyS8a66Uv2jvwW8bDgbl4X

In [None]:
resnet_teacher = models.resnet50(pretrained=True)
for param in resnet_teacher.parameters():
   param.requires_grad = False
resnet_teacher.to(device)
resnet_teacher.load_state_dict(torch.load('/content/resnet_teacher.pt'))
resnet_teacher.eval()

In [None]:
student_kd_small = train_and_evaluate_kd(small_net,resnet_teacher,optim.Adam(small_net.parameters()),loss_kd,trainloader,valloader,TEMP,ALPHA,10)
torch.save(student_kd_small.state_dict(), 'student_kd_small.pt')
files.download('student_kd_small.pt')