# Import Libraries

In [None]:
import numpy as np
import math
import matplotlib.pyplot as plt

import pickle
import argparse
import time
import itertools
from copy import deepcopy
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
# Use the GPU when one is visible; later cells move tensors and models to `fast_device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", device)
fast_device = device

device: cuda


In [None]:
# Directories where trained teacher / student checkpoints are written (paths end with '/'
# because later cells build file names by string concatenation).
checkpoints_path_teacher = 'checkpoints_teacher/'
checkpoints_path_student = 'checkpoints_student/'
# exist_ok=True replaces the racy exists()-then-create check and makes re-running the cell a no-op.
os.makedirs(checkpoints_path_student, exist_ok=True)
os.makedirs(checkpoints_path_teacher, exist_ok=True)

# Utility Functions

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

def trainStep(network, criterion, optimizer, X, y):
    """
    Run a single optimisation step on one batch.

    Clears the gradients, runs the forward pass, backpropagates
    `criterion(outputs, y)` and lets `optimizer` update the parameters.

    Returns:
        (loss, accuracy): the batch loss exactly as returned by `criterion`, and the
        fraction of samples whose arg-max prediction equals `y`, as a float.
    """
    optimizer.zero_grad()
    logits = network(X)
    batch_loss = criterion(logits, y)
    batch_loss.backward()
    optimizer.step()
    hits = (logits.argmax(dim=1) == y).sum().item()
    return batch_loss, float(hits) / y.shape[0]

def getLossAccuracyOnDataset(network, dataset_loader, fast_device, criterion=None):
    """
    Evaluate `network` on every batch of `dataset_loader` without tracking gradients.

    The network's `is_training` flag is switched off during evaluation, so its dropout
    is disabled. The caller's previous value is restored afterwards, even if evaluation
    raises.

    Args:
        network: module exposing an `is_training` attribute.
        dataset_loader: iterable of (X, y) batches.
        fast_device: device each batch is moved to.
        criterion: optional loss function; when None the returned loss is 0.0.

    Returns:
        (loss, accuracy): the batch-size-weighted mean of `criterion` over all samples
        (as a Python float) and the fraction of samples whose arg-max prediction equals
        the label.
    """
    was_training = network.is_training
    network.is_training = False
    total_loss = 0.0
    correct = 0.0
    dataset_size = 0
    try:
        with torch.no_grad():
            for X, y in dataset_loader:
                X = X.to(fast_device)
                y = y.to(fast_device)
                pred = network(X)
                if criterion is not None:
                    # criterion presumably averages over the batch; re-weight by batch size
                    # so the final value is a per-sample mean across uneven batches.
                    # .item() keeps the accumulator a float instead of a device tensor.
                    total_loss += criterion(pred, y).item() * y.shape[0]
                correct += torch.sum(torch.argmax(pred, dim=1) == y).item()
                dataset_size += y.shape[0]
    finally:
        # Restore (rather than force to True) so an inference-mode network stays in inference mode.
        network.is_training = was_training
    return total_loss / dataset_size, correct / dataset_size

def trainTeacherOnHparam(teacher_net, hparam, num_epochs,
                         train_loader, val_loader,
                         print_every=0,
                         fast_device='cuda:0'):
    """
    Train the teacher network with the given hyperparameters for `num_epochs` epochs.

    Args:
        teacher_net: network exposing `dropout_input` / `dropout_hidden`, already on `fast_device`.
        hparam: dict with keys 'dropout_input', 'dropout_hidden', 'lr', 'momentum',
            'weight_decay' and 'lr_decay'.
        num_epochs: number of passes over `train_loader`.
        train_loader: iterable of (X, y) batches.
        val_loader: iterable of (X, y) batches, or None to skip validation.
        print_every: print batch statistics every `print_every` updates (0 disables).
        fast_device: device each batch is moved to.

    Returns:
        dict with 'train_loss' / 'train_acc' (one float per update, measured only on that
        batch) and 'val_loss' / 'val_acc' (one entry before training plus one per epoch;
        empty when val_loader is None).
    """
    train_loss_list, train_acc_list, val_loss_list, val_acc_list = [], [], [], []
    teacher_net.dropout_input = hparam['dropout_input']
    teacher_net.dropout_hidden = hparam['dropout_hidden']
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(teacher_net.parameters(), lr=hparam['lr'], momentum=hparam['momentum'], weight_decay=hparam['weight_decay'])
    # Learning rate is multiplied by lr_decay once per epoch.
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=hparam['lr_decay'])

    def _validate(epoch):
        # Evaluate on val_loader and record/print loss and accuracy for `epoch`.
        val_loss, val_acc = getLossAccuracyOnDataset(teacher_net, val_loader, fast_device, criterion)
        val_loss_list.append(val_loss)
        val_acc_list.append(val_acc)
        print('epoch: %d validation loss: %.3f validation accuracy: %.3f' % (epoch, val_loss, val_acc))

    # Baseline validation before any parameter update.
    if val_loader is not None:
        _validate(0)
    for epoch in range(num_epochs):
        for i, (X, y) in enumerate(train_loader):
            X, y = X.to(fast_device), y.to(fast_device)
            loss, acc = trainStep(teacher_net, criterion, optimizer, X, y)
            # Keep a Python float: storing the loss tensor for every update would pin
            # device memory and autograd history for the whole run.
            loss = float(loss)
            train_loss_list.append(loss)
            train_acc_list.append(acc)
            if print_every > 0 and i % print_every == print_every - 1:
                print('[%d, %5d/%5d] train loss: %.3f train accuracy: %.3f' %
                      (epoch + 1, i + 1, len(train_loader), loss, acc))
        # Step the scheduler after the epoch's optimizer steps (required order since
        # PyTorch 1.1); stepping first decayed the LR before the very first update.
        lr_scheduler.step()
        if val_loader is not None:
            _validate(epoch + 1)
    return {'train_loss': train_loss_list,
            'train_acc': train_acc_list,
            'val_loss': val_loss_list,
            'val_acc': val_acc_list}

def studentTrainStep(teacher_net, student_net, studentLossFn, optimizer, X, y, T, alpha):
    """
    Run a single optimisation step of the student on one batch.

    The teacher is queried (without gradient tracking) only when `alpha > 0`;
    otherwise `None` is passed to `studentLossFn` in place of its logits.

    Returns:
        (loss, accuracy): the value returned by `studentLossFn` and the fraction of
        samples whose student arg-max prediction equals `y`, as a float.
    """
    optimizer.zero_grad()
    if alpha > 0:
        with torch.no_grad():
            teacher_logits = teacher_net(X)
    else:
        teacher_logits = None
    student_logits = student_net(X)
    batch_loss = studentLossFn(teacher_logits, student_logits, y, T, alpha)
    batch_loss.backward()
    optimizer.step()
    n_correct = (student_logits.argmax(dim=1) == y).sum().item()
    return batch_loss, float(n_correct) / y.shape[0]

def trainStudentOnHparam(teacher_net, student_net, hparam, num_epochs,
                         train_loader, val_loader,
                         print_every=0,
                         fast_device=torch.device('cpu')):
    """
    Train the student network, optionally distilling from the teacher, for `num_epochs` epochs.

    Args:
        teacher_net: trained teacher whose logits supply soft targets; it is only run when
            hparam['alpha'] > 0. Its `is_training` flag (if present) is switched off for the
            duration of training, so its dropout does not perturb the soft targets. The
            previous value is restored afterwards.
        student_net: network to train; must expose `dropout_input` / `dropout_hidden`.
        hparam: dict with keys 'T', 'alpha', 'dropout_input', 'dropout_hidden', 'lr',
            'momentum', 'weight_decay' and 'lr_decay'.
        num_epochs: number of passes over `train_loader`.
        train_loader: iterable of (X, y) batches.
        val_loader: iterable of (X, y) batches, or None to skip validation.
        print_every: print batch statistics every `print_every` updates (0 disables).
        fast_device: device each batch is moved to.

    Returns:
        dict with 'train_loss' / 'train_acc' (one float per update, measured only on that
        batch) and 'val_acc' (one value before training plus one per epoch; empty when
        val_loader is None).
    """
    train_loss_list, train_acc_list, val_acc_list = [], [], []
    T = hparam['T']
    alpha = hparam['alpha']
    student_net.dropout_input = hparam['dropout_input']
    student_net.dropout_hidden = hparam['dropout_hidden']
    optimizer = optim.SGD(student_net.parameters(), lr=hparam['lr'], momentum=hparam['momentum'], weight_decay=hparam['weight_decay'])
    # Learning rate is multiplied by lr_decay once per epoch.
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=hparam['lr_decay'])

    def studentLossFn(teacher_pred, student_pred, y, T, alpha):
        """
        Student loss:
            alpha * T**2 * KL(softmax(teacher_pred / T) || softmax(student_pred / T))
            + (1 - alpha) * cross_entropy(student_pred, y)
        When alpha <= 0 only the cross-entropy with the true labels is used
        (teacher_pred may then be None).
        """
        if alpha > 0:
            loss = F.kl_div(F.log_softmax(student_pred / T, dim=1), F.softmax(teacher_pred / T, dim=1), reduction='batchmean') * (T ** 2) * alpha + F.cross_entropy(student_pred, y) * (1 - alpha)
        else:
            loss = F.cross_entropy(student_pred, y)
        return loss

    def _validate(epoch):
        # Accuracy only: no criterion is passed, so the returned loss is ignored.
        _, val_acc = getLossAccuracyOnDataset(student_net, val_loader, fast_device)
        val_acc_list.append(val_acc)
        print('epoch: %d validation accuracy: %.3f' % (epoch, val_acc))

    # The teacher only provides soft targets: put it in inference mode for the whole run.
    teacher_prev_mode = getattr(teacher_net, 'is_training', None)
    if teacher_prev_mode is not None:
        teacher_net.is_training = False
    try:
        # Baseline validation before any parameter update.
        if val_loader is not None:
            _validate(0)
        for epoch in range(num_epochs):
            for i, (X, y) in enumerate(train_loader):
                X, y = X.to(fast_device), y.to(fast_device)
                loss, acc = studentTrainStep(teacher_net, student_net, studentLossFn, optimizer, X, y, T, alpha)
                # Store a float, not the graph-attached loss tensor.
                loss = float(loss)
                train_loss_list.append(loss)
                train_acc_list.append(acc)
                if print_every > 0 and i % print_every == print_every - 1:
                    print('[%d, %5d/%5d] train loss: %.3f train accuracy: %.3f' %
                          (epoch + 1, i + 1, len(train_loader), loss, acc))
            # Step the scheduler after the epoch's optimizer steps (required order since PyTorch 1.1).
            lr_scheduler.step()
            if val_loader is not None:
                _validate(epoch + 1)
    finally:
        if teacher_prev_mode is not None:
            teacher_net.is_training = teacher_prev_mode
    return {'train_loss': train_loss_list,
            'train_acc': train_acc_list,
            'val_acc': val_acc_list}

def hparamToString(hparam):
    """
    Render `hparam` as "key=value, key=value, ..." with keys in sorted order,
    so the same hyperparameters always produce the same string (used in checkpoint names).
    """
    parts = [key + '=' + str(value) for key, value in sorted(hparam.items())]
    return ', '.join(parts)

def hparamDictToTuple(hparam):
    """
    Return the values of `hparam` as a tuple ordered by sorted key, giving a
    deterministic, hashable representation (used as a results-dict key).
    """
    ordered_keys = sorted(hparam)
    return tuple(hparam[key] for key in ordered_keys)

def getTrainMetricPerEpoch(train_metric, updates_per_epoch):
    """
    Average a per-batch training metric over each epoch.

    Consecutive groups of `updates_per_epoch` values are averaged; a trailing
    incomplete group is discarded.

    Input: sequence of per-batch metric values.
    Output: list with one averaged value per complete epoch.
    """
    epoch_means = []
    running = 0.0
    for idx, value in enumerate(train_metric):
        running += value
        if idx % updates_per_epoch == updates_per_epoch - 1:
            epoch_means.append(running / updates_per_epoch)
            running = 0.0
    return epoch_means

# Observe FLOPs and Number of Parameters

In [None]:
def print_the_model_out(net):
    """
    Print one row per nn.Linear layer of `net` that has recorded activations.

    Each row has the layer name, input shape, output shape, weight shape, the
    number of weight elements (bias excluded), and the multiply-accumulate count
    for one sample (in_features * out_features).

    Activations are read from the layer's `input` / `output` attributes (set by
    the FC layer's forward), so run a forward pass first. Linear layers without
    recorded activations are skipped instead of raising AttributeError.
    """
    for name, module in net.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        layer_in = getattr(module, 'input', None)
        layer_out = getattr(module, 'output', None)
        if layer_in is None or layer_out is None:
            continue
        # Shapes only are needed, so no device->host copy / NumPy conversion.
        in_shape = tuple(layer_in.shape)
        out_shape = tuple(layer_out.shape)
        weight = module.weight
        num_Param = torch.numel(weight)
        # One MAC per (input feature, output feature) pair for a single sample; using
        # the last dimension also works for inputs with extra leading dimensions.
        num_MAC = in_shape[-1] * out_shape[-1]
        print(f'{name:10} {str(in_shape):20} {str(out_shape):20} {str(weight.shape):20} {str(num_Param):10} {str(num_MAC):10}')


# Reproducibility Seed

In [None]:
def reproducibilitySeed():
    """
    Make runs repeatable: seed the torch and NumPy global RNGs with 0 and
    force deterministic, non-autotuned cuDNN kernels.
    """
    seed = 0
    torch.manual_seed(seed)
    np.random.seed(seed)
    # benchmark=True lets cuDNN pick kernels by timing, which can vary between runs.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

reproducibilitySeed()

# Student & Teacher Network Implementation

In [None]:
class FC(nn.Linear):
    """
    nn.Linear that remembers the input and output of its most recent forward pass,
    so the activation shapes can be inspected afterwards (see print_the_model_out).
    """
    def __init__(self, in_features, out_features, bias=True):
        super(FC, self).__init__(in_features, out_features, bias)
        # Activations of the latest forward call; None until the layer is first run.
        self.input = None
        self.output = None

    def forward(self, input):
        output = F.linear(input, self.weight, self.bias)
        # Record detached views only: keeping the graph-attached output would hold the
        # last forward's autograd graph alive between calls. The returned tensor is
        # still attached, so gradients flow normally.
        self.input = input.detach()
        self.output = output.detach()
        return output

class TeacherNetwork(nn.Module):
    """
    Teacher MLP for flattened 28x28 inputs: 784 -> 1200 -> 1200 -> 10, with ReLU
    and dropout. Dropout is active only while `is_training` is True; the rates
    are plain attributes (not part of the state_dict) so training code can set them.
    """
    def __init__(self):
        super().__init__()
        # Keep this creation order: it fixes the RNG draw order of parameter initialisation.
        self.fc1 = FC(28 * 28, 1200)
        self.fc2 = FC(1200, 1200)
        self.fc3 = FC(1200, 10)
        self.dropout_input = 0.2
        self.dropout_hidden = 0.5
        self.is_training = True

    def forward(self, x):
        p_in, p_hid, train = self.dropout_input, self.dropout_hidden, self.is_training
        h = F.dropout(x.view(-1, 28 * 28), p=p_in, training=train)
        for hidden_layer in (self.fc1, self.fc2):
            h = F.dropout(F.relu(hidden_layer(h)), p=p_hid, training=train)
        return self.fc3(h)

class StudentNetwork(nn.Module):
    """
    Student MLP for flattened 28x28 inputs: 784 -> 400 -> 10 with ReLU.
    Dropout (disabled by default, rates 0.0) applies only while `is_training` is True.
    """
    def __init__(self):
        super().__init__()
        # Keep this creation order: it fixes the RNG draw order of parameter initialisation.
        self.fc1 = FC(28 * 28, 400)
        self.fc2 = FC(400, 10)
        self.dropout_input = 0.0
        self.dropout_hidden = 0.0
        self.is_training = True

    def forward(self, x):
        h = x.view(-1, 28 * 28)
        h = F.dropout(h, p=self.dropout_input, training=self.is_training)
        h = F.relu(self.fc1(h))
        h = F.dropout(h, p=self.dropout_hidden, training=self.is_training)
        return self.fc2(h)

class StudentNetworkSmall(nn.Module):
    """
    Smaller student MLP for flattened 28x28 inputs: 784 -> 30 -> 10 with ReLU.
    Dropout (disabled by default, rates 0.0) applies only while `is_training` is True.
    """
    def __init__(self):
        super().__init__()
        # Keep this creation order: it fixes the RNG draw order of parameter initialisation.
        self.fc1 = FC(28 * 28, 30)
        self.fc2 = FC(30, 10)
        self.dropout_input = 0.0
        self.dropout_hidden = 0.0
        self.is_training = True

    def forward(self, x):
        active = self.is_training
        flat = F.dropout(x.view(-1, 28 * 28), p=self.dropout_input, training=active)
        hidden = F.dropout(F.relu(self.fc1(flat)), p=self.dropout_hidden, training=active)
        return self.fc2(hidden)

# Teacher dataloader with Augmentation

In [None]:
import torchvision
import torchvision.transforms as transforms

mnist_image_shape = (28, 28)
random_pad_size = 2
# Training augmentation: zero-pad 2 pixels on every side, then take a random 28x28 crop,
# i.e. shift the image by at most 2 pixels in each direction.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(mnist_image_shape, random_pad_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

# Evaluation images are only converted and normalised (no augmentation).
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

train_val_dataset = torchvision.datasets.MNIST(root='./MNIST_dataset/', train=True,
                                               download=True, transform=transform_train)

test_dataset = torchvision.datasets.MNIST(root='./MNIST_dataset/', train=False,
                                          download=True, transform=transform_test)

# 95% / 5% train / validation split of the MNIST training set.
num_train = int(1.0 * len(train_val_dataset) * 95 / 100)
num_val = len(train_val_dataset) - num_train
train_dataset, val_dataset = torch.utils.data.random_split(train_val_dataset, [num_train, num_val])
# random_split subsets share the parent's augmenting transform. Re-point the validation
# subset (same indices) at an un-augmented copy so validation uses clean images.
val_dataset = torch.utils.data.Subset(
    torchvision.datasets.MNIST(root='./MNIST_dataset/', train=True,
                               download=True, transform=transform_test),
    val_dataset.indices,
)

batch_size = 128
train_val_loader = torch.utils.data.DataLoader(train_val_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST_dataset/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./MNIST_dataset/MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST_dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST_dataset/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./MNIST_dataset/MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST_dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST_dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./MNIST_dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST_dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST_dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./MNIST_dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST_dataset/MNIST/raw



# Peek at Teacher & Student Network Shape

In [None]:
# Instantiate the teacher on `device` and print its layer structure.
teacher_net = TeacherNetwork()
teacher_net = teacher_net.to(device)
print(teacher_net)

TeacherNetwork(
  (fc1): FC(in_features=784, out_features=1200, bias=True)
  (fc2): FC(in_features=1200, out_features=1200, bias=True)
  (fc3): FC(in_features=1200, out_features=10, bias=True)
)


In [None]:
# Instantiate the student on `device` and print its layer structure.
student_net = StudentNetwork()
student_net = student_net.to(device)
print(student_net)

StudentNetwork(
  (fc1): FC(in_features=784, out_features=400, bias=True)
  (fc2): FC(in_features=400, out_features=10, bias=True)
)


# Train Teacher

In [None]:
# Teacher training length, and how many updates between printed batch statistics.
num_epochs, print_every = 60, 100

In [None]:
# Hyperparameter grid for the teacher; widen any list to sweep more values, e.g.
#   learning_rates = list(np.logspace(-4, -2, 3))
#   weight_decays = [0.0] + list(np.logspace(-5, -1, 5))
#   dropout_probabilities = [(0.2, 0.5), (0.0, 0.0)]
learning_rates = [1e-2]
learning_rate_decays = [0.95]           # multiplicative LR decay (StepLR gamma, one step per epoch)
weight_decays = [1e-5]                  # L2 regularisation strength passed to SGD
momentums = [0.9]
dropout_probabilities = [(0.0, 0.0)]    # (input dropout, hidden dropout)
hparams_list = [
    {
        'dropout_input': p_in,
        'dropout_hidden': p_hid,
        'weight_decay': wd,
        'lr_decay': decay,
        'momentum': mom,
        'lr': lr,
    }
    for (p_in, p_hid), wd, decay, mom, lr in itertools.product(
        dropout_probabilities, weight_decays, learning_rate_decays, momentums, learning_rates)
]

# Train one teacher per grid point on train+val data and checkpoint it with its results.
results = {}
for hparam in hparams_list:
    print('Training with hparams' + hparamToString(hparam))
    reproducibilitySeed()
    teacher_net = TeacherNetwork().to(fast_device)
    hparam_tuple = hparamDictToTuple(hparam)
    results[hparam_tuple] = trainTeacherOnHparam(
        teacher_net, hparam, num_epochs,
        train_val_loader, None,
        print_every=print_every,
        fast_device=fast_device,
    )
    save_path = checkpoints_path_teacher + hparamToString(hparam) + '_final.tar'
    torch.save(
        {
            'results': results[hparam_tuple],
            'model_state_dict': teacher_net.state_dict(),
            'epoch': num_epochs,
        },
        save_path,
    )

Training with hparamsdropout_hidden=0.0, dropout_input=0.0, lr=0.01, lr_decay=0.95, momentum=0.9, weight_decay=1e-05




[1,   100/  469] train loss: 0.979 train accuracy: 0.672
[1,   200/  469] train loss: 0.686 train accuracy: 0.781
[1,   300/  469] train loss: 0.771 train accuracy: 0.781
[1,   400/  469] train loss: 0.437 train accuracy: 0.867
[2,   100/  469] train loss: 0.246 train accuracy: 0.914
[2,   200/  469] train loss: 0.260 train accuracy: 0.914
[2,   300/  469] train loss: 0.409 train accuracy: 0.883
[2,   400/  469] train loss: 0.264 train accuracy: 0.930
[3,   100/  469] train loss: 0.196 train accuracy: 0.922
[3,   200/  469] train loss: 0.125 train accuracy: 0.977
[3,   300/  469] train loss: 0.187 train accuracy: 0.945
[3,   400/  469] train loss: 0.148 train accuracy: 0.969
[4,   100/  469] train loss: 0.210 train accuracy: 0.945
[4,   200/  469] train loss: 0.155 train accuracy: 0.938
[4,   300/  469] train loss: 0.080 train accuracy: 0.977
[4,   400/  469] train loss: 0.179 train accuracy: 0.938
[5,   100/  469] train loss: 0.181 train accuracy: 0.938
[5,   200/  469] train loss: 0.

# Load Teacher Network

In [None]:
# Hyperparameters the teacher checkpoint was trained with; they determine its file name.
learning_rates = [1e-2]
learning_rate_decays = [0.95]
weight_decays = [1e-5]
momentums = [0.9]
# keeping dropout input = dropout hidden
dropout_probabilities = [(0.0, 0.0)]
hparams_list = []
for hparam_tuple in itertools.product(dropout_probabilities, weight_decays, learning_rate_decays,
                                      momentums, learning_rates):
    hparam = {}
    hparam['dropout_input'] = hparam_tuple[0][0]
    hparam['dropout_hidden'] = hparam_tuple[0][1]
    hparam['weight_decay'] = hparam_tuple[1]
    hparam['lr_decay'] = hparam_tuple[2]
    hparam['momentum'] = hparam_tuple[3]
    hparam['lr'] = hparam_tuple[4]
    hparams_list.append(hparam)

load_path = checkpoints_path_teacher + hparamToString(hparams_list[0]) + '_final.tar'
teacher_net = TeacherNetwork()
teacher_net.load_state_dict(torch.load(load_path, map_location=fast_device)['model_state_dict'])
# Dropout rates are plain attributes, not part of the state_dict: a fresh TeacherNetwork()
# would silently use its defaults (0.2 / 0.5). Restore the rates the checkpoint was trained with.
teacher_net.dropout_input = hparams_list[0]['dropout_input']
teacher_net.dropout_hidden = hparams_list[0]['dropout_hidden']
teacher_net = teacher_net.to(fast_device)

## Calculate the Teacher accuracy

In [None]:
# Teacher accuracy on the MNIST test set (no criterion passed, so the loss slot is ignored).
_, test_accuracy = getLossAccuracyOnDataset(
    network=teacher_net, dataset_loader=test_loader, fast_device=fast_device)
print('teacher test accuracy: ', test_accuracy)

# Student dataloader without data augmentation

In [None]:
import torchvision
import torchvision.transforms as transforms

# Student trained without data augmentation: images are only converted and normalised.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

train_val_dataset = torchvision.datasets.MNIST(root='./MNIST_dataset/', train=True,
                                               download=True, transform=transform)

test_dataset = torchvision.datasets.MNIST(root='./MNIST_dataset/', train=False,
                                          download=True, transform=transform)

# 95% / 5% train / validation split of the MNIST training set.
num_train = int(1.0 * len(train_val_dataset) * 95 / 100)
num_val = len(train_val_dataset) - num_train
train_dataset, val_dataset = torch.utils.data.random_split(train_val_dataset, [num_train, num_val])

# Use the batch_size variable (previously defined but ignored in favour of a literal 128).
batch_size = 128
train_val_loader = torch.utils.data.DataLoader(train_val_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Train student network without distillation

In [None]:
# Student training length, and how many updates between printed batch statistics.
num_epochs, print_every = 60, 100

In [None]:
temperatures = [1]    # distillation temperature; irrelevant here because alpha = 0
# Trade-off between the soft-target (st) term and the true-target (tt) cross-entropy:
#   loss = alpha * st + (1 - alpha) * tt
alphas = [0.0]
learning_rates = [1e-2]
learning_rate_decays = [0.95]
weight_decays = [1e-5]
momentums = [0.9]
# No dropout used: (input dropout, hidden dropout)
dropout_probabilities = [(0.0, 0.0)]
hparams_list = [
    {
        'alpha': a,
        'T': temp,
        'dropout_input': p_in,
        'dropout_hidden': p_hid,
        'weight_decay': wd,
        'lr_decay': decay,
        'momentum': mom,
        'lr': lr,
    }
    for a, temp, (p_in, p_hid), wd, decay, mom, lr in itertools.product(
        alphas, temperatures, dropout_probabilities, weight_decays,
        learning_rate_decays, momentums, learning_rates)
]

# Baseline: train the student on hard labels only and checkpoint it.
results_no_distill = {}
for hparam in hparams_list:
    print('Training with hparams' + hparamToString(hparam))
    reproducibilitySeed()
    student_net = StudentNetwork().to(fast_device)
    hparam_tuple = hparamDictToTuple(hparam)
    results_no_distill[hparam_tuple] = trainStudentOnHparam(
        teacher_net, student_net, hparam, num_epochs,
        train_val_loader, None,
        print_every=print_every,
        fast_device=fast_device,
    )
    save_path = checkpoints_path_student + hparamToString(hparam) + '_no_distillation_final.tar'
    torch.save(
        {
            'results': results_no_distill[hparam_tuple],
            'model_state_dict': student_net.state_dict(),
            'epoch': num_epochs,
        },
        save_path,
    )

Training with hparamsT=1, alpha=0.0, dropout_hidden=0.0, dropout_input=0.0, lr=0.01, lr_decay=0.95, momentum=0.9, weight_decay=1e-05
[1,   100/  469] train loss: 0.336 train accuracy: 0.906
[1,   200/  469] train loss: 0.495 train accuracy: 0.859
[1,   300/  469] train loss: 0.259 train accuracy: 0.891
[1,   400/  469] train loss: 0.249 train accuracy: 0.906
[2,   100/  469] train loss: 0.180 train accuracy: 0.938
[2,   200/  469] train loss: 0.204 train accuracy: 0.969
[2,   300/  469] train loss: 0.278 train accuracy: 0.891
[2,   400/  469] train loss: 0.286 train accuracy: 0.938
[3,   100/  469] train loss: 0.074 train accuracy: 0.984
[3,   200/  469] train loss: 0.110 train accuracy: 0.961
[3,   300/  469] train loss: 0.251 train accuracy: 0.922
[3,   400/  469] train loss: 0.175 train accuracy: 0.953
[4,   100/  469] train loss: 0.187 train accuracy: 0.945
[4,   200/  469] train loss: 0.218 train accuracy: 0.914
[4,   300/  469] train loss: 0.086 train accuracy: 0.977
[4,   400/  

## Calculate student test accuracy

In [None]:
# Test-set accuracy of the student trained without distillation.
_, test_accuracy = getLossAccuracyOnDataset(
    network=student_net, dataset_loader=test_loader, fast_device=fast_device)
print('student test accuracy (w/o distillation): ', test_accuracy)

student test accuracy (w/o distillation):  0.9806


# View FLOPs and Parameter Counts of the Loaded Teacher and Student (Vanilla) Weights

In [None]:
teacher_net = TeacherNetwork()
teacher_net.load_state_dict(torch.load('checkpoints_teacher/dropout_hidden=0.0, dropout_input=0.0, lr=0.01, lr_decay=0.95, momentum=0.9, weight_decay=1e-05_final.tar', map_location=fast_device)['model_state_dict'])
# This checkpoint was trained with dropout_input = dropout_hidden = 0.0 (see its file name).
# Dropout rates are not stored in the state_dict, so set them explicitly instead of
# inheriting TeacherNetwork's defaults (0.2 / 0.5). Otherwise dropout noise is added to
# the teacher's outputs whenever it runs with is_training=True, e.g. when it is reused
# for distillation.
teacher_net.dropout_input = 0.0
teacher_net.dropout_hidden = 0.0
teacher_net = teacher_net.to(fast_device)

In [None]:
# One forward pass so the teacher's FC layers record their input/output activations
# for print_the_model_out. This is inspection only: call the module (not .forward)
# and skip autograd bookkeeping.
data = train_dataset[1][0].to(fast_device)
with torch.no_grad():
    out = teacher_net(data)

In [None]:
# Per-layer shapes, weight counts and per-sample MACs of the teacher.
print_the_model_out(net=teacher_net)

fc1        (1, 784)             (1, 1200)            torch.Size([1200, 784]) 940800     940800    
fc2        (1, 1200)            (1, 1200)            torch.Size([1200, 1200]) 1440000    1440000   
fc3        (1, 1200)            (1, 10)              torch.Size([10, 1200]) 12000      12000     


In [None]:
# Reload the student trained without distillation (T=1, alpha=0.0).
student_net = StudentNetwork()
student_net.load_state_dict(
    torch.load(
        'checkpoints_student/T=1, alpha=0.0, dropout_hidden=0.0, dropout_input=0.0, lr=0.01, lr_decay=0.95, momentum=0.9, weight_decay=1e-05_no_distillation_final.tar',
        map_location=fast_device,
    )['model_state_dict']
)
student_net = student_net.to(fast_device)

In [None]:
# One forward pass so the student's FC layers record their input/output activations
# for print_the_model_out. This is inspection only: call the module (not .forward)
# and skip autograd bookkeeping.
data = test_dataset[1][0].to(fast_device)
with torch.no_grad():
    out = student_net(data)

In [None]:
# Per-layer shapes, weight counts and per-sample MACs of the student.
print_the_model_out(net=student_net)

fc1        (1, 784)             (1, 400)             torch.Size([400, 784]) 313600     313600    
fc2        (1, 400)             (1, 10)              torch.Size([10, 400]) 4000       4000      


# Student Training with Distillation (Temperature = 10)

In [None]:
# Distillation run length, and how many updates between printed batch statistics.
num_epochs, print_every = 60, 100

In [None]:
temperatures = [10]    # softmax temperature T of the soft-target term
# Trade-off between the soft-target (st) term and the true-target (tt) cross-entropy:
#   loss = alpha * st + (1 - alpha) * tt
alphas = [0.5]
learning_rates = [1e-2]
learning_rate_decays = [0.95]
weight_decays = [1e-5]
momentums = [0.9]
dropout_probabilities = [(0.0, 0.0)]    # (input dropout, hidden dropout)
hparams_list = [
    {
        'alpha': a,
        'T': temp,
        'dropout_input': p_in,
        'dropout_hidden': p_hid,
        'weight_decay': wd,
        'lr_decay': decay,
        'momentum': mom,
        'lr': lr,
    }
    for a, temp, (p_in, p_hid), wd, decay, mom, lr in itertools.product(
        alphas, temperatures, dropout_probabilities, weight_decays,
        learning_rate_decays, momentums, learning_rates)
]

# Distil the teacher into a fresh student per grid point and checkpoint it.
results_distill = {}
for hparam in hparams_list:
    print('Training with hparams' + hparamToString(hparam))
    reproducibilitySeed()
    student_net = StudentNetwork().to(fast_device)
    hparam_tuple = hparamDictToTuple(hparam)
    results_distill[hparam_tuple] = trainStudentOnHparam(
        teacher_net, student_net, hparam, num_epochs,
        train_val_loader, None,
        print_every=print_every,
        fast_device=fast_device,
    )
    save_path = checkpoints_path_student + hparamToString(hparam) + '_distillation_final.tar'
    torch.save(
        {
            'results': results_distill[hparam_tuple],
            'model_state_dict': student_net.state_dict(),
            'epoch': num_epochs,
        },
        save_path,
    )


Training with hparamsT=10, alpha=0.5, dropout_hidden=0.0, dropout_input=0.0, lr=0.01, lr_decay=0.95, momentum=0.9, weight_decay=1e-05
[1,   100/  469] train loss: 4.576 train accuracy: 0.859
[1,   200/  469] train loss: 4.671 train accuracy: 0.859
[1,   300/  469] train loss: 3.365 train accuracy: 0.953
[1,   400/  469] train loss: 3.962 train accuracy: 0.945
[2,   100/  469] train loss: 3.606 train accuracy: 0.953
[2,   200/  469] train loss: 3.654 train accuracy: 0.961
[2,   300/  469] train loss: 3.580 train accuracy: 0.969
[2,   400/  469] train loss: 3.577 train accuracy: 0.953
[3,   100/  469] train loss: 3.372 train accuracy: 1.000
[3,   200/  469] train loss: 3.675 train accuracy: 0.969
[3,   300/  469] train loss: 3.607 train accuracy: 0.945
[3,   400/  469] train loss: 3.646 train accuracy: 0.953
[4,   100/  469] train loss: 3.477 train accuracy: 0.977
[4,   200/  469] train loss: 3.378 train accuracy: 0.922
[4,   300/  469] train loss: 3.405 train accuracy: 0.969
[4,   400/ 

## Calculate student test accuracy

In [None]:
# Test-set accuracy of the distilled student (T=10); no criterion, so the loss slot is ignored.
_, test_accuracy = getLossAccuracyOnDataset(
    network=student_net, dataset_loader=test_loader, fast_device=fast_device)
print('student test accuracy (w distillation): ', test_accuracy)

student test accuracy (w distillation):  0.985


# Student Training with Distillation (Temperature = 20)

In [None]:
temperatures = [20]    # softmax temperature T of the soft-target term
# Trade-off between the soft-target (st) term and the true-target (tt) cross-entropy:
#   loss = alpha * st + (1 - alpha) * tt
alphas = [0.5]
learning_rates = [1e-2]
learning_rate_decays = [0.95]
weight_decays = [1e-5]
momentums = [0.9]
dropout_probabilities = [(0.0, 0.0)]    # (input dropout, hidden dropout)
hparams_list = [
    {
        'alpha': a,
        'T': temp,
        'dropout_input': p_in,
        'dropout_hidden': p_hid,
        'weight_decay': wd,
        'lr_decay': decay,
        'momentum': mom,
        'lr': lr,
    }
    for a, temp, (p_in, p_hid), wd, decay, mom, lr in itertools.product(
        alphas, temperatures, dropout_probabilities, weight_decays,
        learning_rate_decays, momentums, learning_rates)
]

# Distil the teacher into a fresh student per grid point and checkpoint it.
results_distill = {}
for hparam in hparams_list:
    print('Training with hparams' + hparamToString(hparam))
    reproducibilitySeed()
    student_net = StudentNetwork().to(fast_device)
    hparam_tuple = hparamDictToTuple(hparam)
    results_distill[hparam_tuple] = trainStudentOnHparam(
        teacher_net, student_net, hparam, num_epochs,
        train_val_loader, None,
        print_every=print_every,
        fast_device=fast_device,
    )
    save_path = checkpoints_path_student + hparamToString(hparam) + '_distillation_final.tar'
    torch.save(
        {
            'results': results_distill[hparam_tuple],
            'model_state_dict': student_net.state_dict(),
            'epoch': num_epochs,
        },
        save_path,
    )


## Calculate student test accuracy

In [None]:
# Test-set accuracy of the distilled student (T=20); no criterion, so the loss slot is ignored.
_, test_accuracy = getLossAccuracyOnDataset(
    network=student_net, dataset_loader=test_loader, fast_device=fast_device)
print('student test accuracy (w distillation): ', test_accuracy)

# Student Training with Distillation (Temperature = 5)

In [None]:
temperatures = [5]    # softmax temperature T of the soft-target term
# Trade-off between the soft-target (st) term and the true-target (tt) cross-entropy:
#   loss = alpha * st + (1 - alpha) * tt
alphas = [0.5]
learning_rates = [1e-2]
learning_rate_decays = [0.95]
weight_decays = [1e-5]
momentums = [0.9]
dropout_probabilities = [(0.0, 0.0)]    # (input dropout, hidden dropout)
hparams_list = [
    {
        'alpha': a,
        'T': temp,
        'dropout_input': p_in,
        'dropout_hidden': p_hid,
        'weight_decay': wd,
        'lr_decay': decay,
        'momentum': mom,
        'lr': lr,
    }
    for a, temp, (p_in, p_hid), wd, decay, mom, lr in itertools.product(
        alphas, temperatures, dropout_probabilities, weight_decays,
        learning_rate_decays, momentums, learning_rates)
]

# Distil the teacher into a fresh student per grid point and checkpoint it.
results_distill = {}
for hparam in hparams_list:
    print('Training with hparams' + hparamToString(hparam))
    reproducibilitySeed()
    student_net = StudentNetwork().to(fast_device)
    hparam_tuple = hparamDictToTuple(hparam)
    results_distill[hparam_tuple] = trainStudentOnHparam(
        teacher_net, student_net, hparam, num_epochs,
        train_val_loader, None,
        print_every=print_every,
        fast_device=fast_device,
    )
    save_path = checkpoints_path_student + hparamToString(hparam) + '_distillation_final.tar'
    torch.save(
        {
            'results': results_distill[hparam_tuple],
            'model_state_dict': student_net.state_dict(),
            'epoch': num_epochs,
        },
        save_path,
    )


## Calculate student test accuracy

In [None]:
# Evaluate the most recently trained student on the full MNIST test split;
# only the accuracy half of the returned (loss, accuracy) pair is reported.
_, test_accuracy = getLossAccuracyOnDataset(network=student_net,
                                            dataset_loader=test_loader,
                                            fast_device=fast_device)
print('student test accuracy (w distillation): ', test_accuracy)

# Student Training with Distillation (Temperature = 25)

In [None]:
# Training schedule: total epochs, and how many minibatches between progress logs.
num_epochs, print_every = 60, 100

In [None]:
temperatures = [25]
# Weighting between the soft-target (st) and true-target (tt) cross-entropy terms:
# loss = alpha * st + (1 - alpha) * tt
alphas = [0.5]
learning_rates = [1e-2]
learning_rate_decays = [0.95]
weight_decays = [1e-5]
momentums = [0.9]
dropout_probabilities = [(0.0, 0.0)]

# One hyper-parameter dict per point of the Cartesian grid above
# (comprehension variables do not leak into the notebook namespace).
hparams_list = [
    {
        'alpha': alpha,
        'T': T,
        'dropout_input': dropout[0],
        'dropout_hidden': dropout[1],
        'weight_decay': weight_decay,
        'lr_decay': lr_decay,
        'momentum': momentum,
        'lr': lr,
    }
    for alpha, T, dropout, weight_decay, lr_decay, momentum, lr in itertools.product(
        alphas, temperatures, dropout_probabilities, weight_decays,
        learning_rate_decays, momentums, learning_rates)
]

# For every configuration: call reproducibilitySeed(), build a new student,
# distil it from teacher_net on the full train+val loader and checkpoint it.
results_distill = {}
for hparam in hparams_list:
    print('Training with hparams' + hparamToString(hparam))
    reproducibilitySeed()
    student_net = StudentNetwork().to(fast_device)
    hparam_tuple = hparamDictToTuple(hparam)
    results_distill[hparam_tuple] = trainStudentOnHparam(
        teacher_net, student_net, hparam, num_epochs,
        train_val_loader, None,
        print_every=print_every,
        fast_device=fast_device,
    )
    save_path = checkpoints_path_student + hparamToString(hparam) + '_distillation_final.tar'
    torch.save(
        {
            'results': results_distill[hparam_tuple],
            'model_state_dict': student_net.state_dict(),
            'epoch': num_epochs,
        },
        save_path,
    )


## Calculate student test accuracy

In [None]:
# Evaluate the most recently trained student on the full MNIST test split;
# only the accuracy half of the returned (loss, accuracy) pair is reported.
_, test_accuracy = getLossAccuracyOnDataset(network=student_net,
                                            dataset_loader=test_loader,
                                            fast_device=fast_device)
print('student test accuracy (w distillation): ', test_accuracy)

# Student Training with Distillation (Temperature = 30)

In [None]:
# Training schedule: total epochs, and how many minibatches between progress logs.
num_epochs, print_every = 60, 100

In [None]:
temperatures = [30]
# Weighting between the soft-target (st) and true-target (tt) cross-entropy terms:
# loss = alpha * st + (1 - alpha) * tt
alphas = [0.5]
learning_rates = [1e-2]
learning_rate_decays = [0.95]
weight_decays = [1e-5]
momentums = [0.9]
dropout_probabilities = [(0.0, 0.0)]

# One hyper-parameter dict per point of the Cartesian grid above
# (comprehension variables do not leak into the notebook namespace).
hparams_list = [
    {
        'alpha': alpha,
        'T': T,
        'dropout_input': dropout[0],
        'dropout_hidden': dropout[1],
        'weight_decay': weight_decay,
        'lr_decay': lr_decay,
        'momentum': momentum,
        'lr': lr,
    }
    for alpha, T, dropout, weight_decay, lr_decay, momentum, lr in itertools.product(
        alphas, temperatures, dropout_probabilities, weight_decays,
        learning_rate_decays, momentums, learning_rates)
]

# For every configuration: call reproducibilitySeed(), build a new student,
# distil it from teacher_net on the full train+val loader and checkpoint it.
results_distill = {}
for hparam in hparams_list:
    print('Training with hparams' + hparamToString(hparam))
    reproducibilitySeed()
    student_net = StudentNetwork().to(fast_device)
    hparam_tuple = hparamDictToTuple(hparam)
    results_distill[hparam_tuple] = trainStudentOnHparam(
        teacher_net, student_net, hparam, num_epochs,
        train_val_loader, None,
        print_every=print_every,
        fast_device=fast_device,
    )
    save_path = checkpoints_path_student + hparamToString(hparam) + '_distillation_final.tar'
    torch.save(
        {
            'results': results_distill[hparam_tuple],
            'model_state_dict': student_net.state_dict(),
            'epoch': num_epochs,
        },
        save_path,
    )


## Calculate student test accuracy

In [None]:
# Evaluate the most recently trained student on the full MNIST test split;
# only the accuracy half of the returned (loss, accuracy) pair is reported.
_, test_accuracy = getLossAccuracyOnDataset(network=student_net,
                                            dataset_loader=test_loader,
                                            fast_device=fast_device)
print('student test accuracy (w distillation): ', test_accuracy)

# Student Training with Distillation (Temperature = 15)

In [None]:
# Training schedule: total epochs, and how many minibatches between progress logs.
num_epochs, print_every = 60, 100

In [None]:
temperatures = [15]
# Weighting between the soft-target (st) and true-target (tt) cross-entropy terms:
# loss = alpha * st + (1 - alpha) * tt
alphas = [0.5]
learning_rates = [1e-2]
learning_rate_decays = [0.95]
weight_decays = [1e-5]
momentums = [0.9]
dropout_probabilities = [(0.0, 0.0)]

# One hyper-parameter dict per point of the Cartesian grid above
# (comprehension variables do not leak into the notebook namespace).
hparams_list = [
    {
        'alpha': alpha,
        'T': T,
        'dropout_input': dropout[0],
        'dropout_hidden': dropout[1],
        'weight_decay': weight_decay,
        'lr_decay': lr_decay,
        'momentum': momentum,
        'lr': lr,
    }
    for alpha, T, dropout, weight_decay, lr_decay, momentum, lr in itertools.product(
        alphas, temperatures, dropout_probabilities, weight_decays,
        learning_rate_decays, momentums, learning_rates)
]

# For every configuration: call reproducibilitySeed(), build a new student,
# distil it from teacher_net on the full train+val loader and checkpoint it.
results_distill = {}
for hparam in hparams_list:
    print('Training with hparams' + hparamToString(hparam))
    reproducibilitySeed()
    student_net = StudentNetwork().to(fast_device)
    hparam_tuple = hparamDictToTuple(hparam)
    results_distill[hparam_tuple] = trainStudentOnHparam(
        teacher_net, student_net, hparam, num_epochs,
        train_val_loader, None,
        print_every=print_every,
        fast_device=fast_device,
    )
    save_path = checkpoints_path_student + hparamToString(hparam) + '_distillation_final.tar'
    torch.save(
        {
            'results': results_distill[hparam_tuple],
            'model_state_dict': student_net.state_dict(),
            'epoch': num_epochs,
        },
        save_path,
    )


## Calculate student test accuracy

In [None]:
# Evaluate the most recently trained student on the full MNIST test split;
# only the accuracy half of the returned (loss, accuracy) pair is reported.
_, test_accuracy = getLossAccuracyOnDataset(network=student_net,
                                            dataset_loader=test_loader,
                                            fast_device=fast_device)
print('student test accuracy (w distillation): ', test_accuracy)

# Student training without digit one ground truth

## Dataloader without digit one

In [None]:
import torchvision
import torchvision.transforms as transforms

# Student trained without data augmentation: the only preprocessing is
# ToTensor() followed by Normalize(mean=0.5, std=0.5), i.e. (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

train_val_dataset = torchvision.datasets.MNIST(root='./MNIST_dataset/', train=True,
                                               download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./MNIST_dataset/', train=False,
                                          download=True, transform=transform)

# 95% / 5% train/validation split of the MNIST training set.
# NOTE(review): random_split is not given a seeded generator, so the split
# depends on the global torch RNG state at this point.
num_train = int(1.0 * len(train_val_dataset) * 95 / 100)
num_val = len(train_val_dataset) - num_train
train_dataset, val_dataset = torch.utils.data.random_split(train_val_dataset, [num_train, num_val])

# All loaders share `batch_size` instead of repeating the literal 128,
# so changing the batch size here actually takes effect everywhere.
batch_size = 128
train_val_loader = torch.utils.data.DataLoader(train_val_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

In [None]:
missing_digit = 1
# Positions of every training/validation sample whose label is NOT `missing_digit`.
# Labels are read from the dataset's `targets` tensor; the previous
# `for img, label in dataset` loop decoded and transformed every image just to
# look at its label.
indices_list = torch.nonzero(train_val_dataset.targets != missing_digit).flatten().tolist()

In [None]:
import torch.utils.data as data_utils
# View of train_val_dataset restricted to the positions listed in `indices_list`.
train_val_no_one_dataset = data_utils.Subset(dataset=train_val_dataset, indices=indices_list)

In [None]:
missing_digit = 1
# Positions of every test sample whose label is NOT `missing_digit`.
# Labels are read from the dataset's `targets` tensor; the previous loop
# decoded and transformed every image just to look at its label.
indices_list = torch.nonzero(test_dataset.targets != missing_digit).flatten().tolist()

In [None]:
import torch.utils.data as data_utils
# View of test_dataset restricted to the positions listed in `indices_list`.
test_no_one_dataset = data_utils.Subset(dataset=test_dataset, indices=indices_list)

In [None]:
# Use the notebook-wide `batch_size` (set in the dataset cell) instead of a repeated literal.
train_val_no_one_loader = torch.utils.data.DataLoader(train_val_no_one_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_no_one_loader = torch.utils.data.DataLoader(test_no_one_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

## Train student without distillation and without digit one (T=1)

In [None]:
# Training schedule: total epochs, and how many minibatches between progress logs.
num_epochs, print_every = 60, 100

In [None]:
temperatures = [1]  # temperature for distillation loss
# Weighting between the soft-target (st) and true-target (tt) cross-entropy terms:
# loss = alpha * st + (1 - alpha) * tt. With alpha = 0.0 only the true-target
# term remains, so this run trains the student without distillation.
alphas = [0.0]
learning_rates = [1e-2]
learning_rate_decays = [0.95]
weight_decays = [1e-5]
momentums = [0.9]
# No dropout used
dropout_probabilities = [(0.0, 0.0)]

# One hyper-parameter dict per point of the Cartesian grid above
# (comprehension variables do not leak into the notebook namespace).
hparams_list = [
    {
        'alpha': alpha,
        'T': T,
        'dropout_input': dropout[0],
        'dropout_hidden': dropout[1],
        'weight_decay': weight_decay,
        'lr_decay': lr_decay,
        'momentum': momentum,
        'lr': lr,
    }
    for alpha, T, dropout, weight_decay, lr_decay, momentum, lr in itertools.product(
        alphas, temperatures, dropout_probabilities, weight_decays,
        learning_rate_decays, momentums, learning_rates)
]

# For every configuration: call reproducibilitySeed(), build a new student,
# train it on the loader that excludes digit one and checkpoint it.
results_distill = {}
for hparam in hparams_list:
    print('Training with hparams' + hparamToString(hparam))
    reproducibilitySeed()
    student_net = StudentNetwork().to(fast_device)
    hparam_tuple = hparamDictToTuple(hparam)
    results_distill[hparam_tuple] = trainStudentOnHparam(
        teacher_net, student_net, hparam, num_epochs,
        train_val_no_one_loader, None,
        print_every=print_every,
        fast_device=fast_device,
    )
    save_path = checkpoints_path_student + hparamToString(hparam) + '_no_one_no_distillation_final.tar'
    torch.save(
        {
            'results': results_distill[hparam_tuple],
            'model_state_dict': student_net.state_dict(),
            'epoch': num_epochs,
        },
        save_path,
    )


Training with hparamsT=1, alpha=0.0, dropout_hidden=0.0, dropout_input=0.0, lr=0.01, lr_decay=0.95, momentum=0.9, weight_decay=1e-05
[1,   100/  417] train loss: 0.306 train accuracy: 0.922
[1,   200/  417] train loss: 0.438 train accuracy: 0.859
[1,   300/  417] train loss: 0.314 train accuracy: 0.930
[1,   400/  417] train loss: 0.340 train accuracy: 0.906
[2,   100/  417] train loss: 0.302 train accuracy: 0.898
[2,   200/  417] train loss: 0.147 train accuracy: 0.969
[2,   300/  417] train loss: 0.326 train accuracy: 0.906
[2,   400/  417] train loss: 0.206 train accuracy: 0.945
[3,   100/  417] train loss: 0.258 train accuracy: 0.930
[3,   200/  417] train loss: 0.185 train accuracy: 0.961
[3,   300/  417] train loss: 0.130 train accuracy: 0.969
[3,   400/  417] train loss: 0.136 train accuracy: 0.953
[4,   100/  417] train loss: 0.096 train accuracy: 0.969
[4,   200/  417] train loss: 0.178 train accuracy: 0.953
[4,   300/  417] train loss: 0.224 train accuracy: 0.922
[4,   400/  

## DataLoader containing only digit one

In [None]:
import torchvision
import torchvision.transforms as transforms

# Student trained without data augmentation: the only preprocessing is
# ToTensor() followed by Normalize(mean=0.5, std=0.5), i.e. (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

train_val_dataset = torchvision.datasets.MNIST(root='./MNIST_dataset/', train=True,
                                               download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./MNIST_dataset/', train=False,
                                          download=True, transform=transform)

# 95% / 5% train/validation split of the MNIST training set.
# NOTE(review): random_split is not given a seeded generator, so the split
# depends on the global torch RNG state at this point.
num_train = int(1.0 * len(train_val_dataset) * 95 / 100)
num_val = len(train_val_dataset) - num_train
train_dataset, val_dataset = torch.utils.data.random_split(train_val_dataset, [num_train, num_val])

# All loaders share `batch_size` instead of repeating the literal 128,
# so changing the batch size here actually takes effect everywhere.
batch_size = 128
train_val_loader = torch.utils.data.DataLoader(train_val_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

In [None]:
missing_digit = 1
# Positions of every training/validation sample whose label IS `missing_digit`
# (the digit held out from the student's training data).
# Labels are read from the dataset's `targets` tensor; the previous loop
# decoded and transformed every image just to look at its label.
indices_list = torch.nonzero(train_val_dataset.targets == missing_digit).flatten().tolist()

In [None]:
import torch.utils.data as data_utils
# View of train_val_dataset restricted to the positions listed in `indices_list`.
train_val_only_one_dataset = data_utils.Subset(dataset=train_val_dataset, indices=indices_list)

In [None]:
missing_digit = 1
# Positions of every test sample whose label IS `missing_digit`.
# Labels are read from the dataset's `targets` tensor; the previous loop
# decoded and transformed every image just to look at its label.
indices_list = torch.nonzero(test_dataset.targets == missing_digit).flatten().tolist()

In [None]:
import torch.utils.data as data_utils
# View of test_dataset restricted to the positions listed in `indices_list`.
test_only_one_dataset = data_utils.Subset(dataset=test_dataset, indices=indices_list)

In [None]:
# Use the notebook-wide `batch_size` (set in the dataset cell) instead of a repeated literal.
train_val_only_one_loader = torch.utils.data.DataLoader(train_val_only_one_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_only_one_loader = torch.utils.data.DataLoader(test_only_one_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

## Calculate student test accuracy on digit one

In [None]:
# Evaluate the student trained without digit one (and without distillation)
# on the digit-one test images only; the loss half of the result is discarded.
_, test_accuracy = getLossAccuracyOnDataset(network=student_net,
                                            dataset_loader=test_only_one_loader,
                                            fast_device=fast_device)
print('student test accuracy (w/o distillation) on one digit: ', test_accuracy)

student test accuracy (w/o distillation) on one digit:  0.0


## Calculate student test accuracy on all digits except one

In [None]:
# Evaluate the same student on every test image whose label is not one;
# the loss half of the result is discarded.
_, test_accuracy = getLossAccuracyOnDataset(network=student_net,
                                            dataset_loader=test_no_one_loader,
                                            fast_device=fast_device)
print('student test accuracy (w/o distillation): ', test_accuracy)

student test accuracy (w/o distillation):  0.9803722504230118


# Vanilla Distillation with one digit missing (Digit One)

## Dataloader without digit one

In [None]:
import torchvision
import torchvision.transforms as transforms

# Student trained without data augmentation: the only preprocessing is
# ToTensor() followed by Normalize(mean=0.5, std=0.5), i.e. (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

train_val_dataset = torchvision.datasets.MNIST(root='./MNIST_dataset/', train=True,
                                               download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./MNIST_dataset/', train=False,
                                          download=True, transform=transform)

# 95% / 5% train/validation split of the MNIST training set.
# NOTE(review): random_split is not given a seeded generator, so the split
# depends on the global torch RNG state at this point.
num_train = int(1.0 * len(train_val_dataset) * 95 / 100)
num_val = len(train_val_dataset) - num_train
train_dataset, val_dataset = torch.utils.data.random_split(train_val_dataset, [num_train, num_val])

# All loaders share `batch_size` instead of repeating the literal 128,
# so changing the batch size here actually takes effect everywhere.
batch_size = 128
train_val_loader = torch.utils.data.DataLoader(train_val_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

In [None]:
# Positions of every training/validation sample whose label is not 1.
# Labels are read from the dataset's `targets` tensor; the previous loop
# decoded and transformed every image just to look at its label.
indices_list = torch.nonzero(train_val_dataset.targets != 1).flatten().tolist()

In [None]:
import torch.utils.data as data_utils
# View of train_val_dataset restricted to the positions listed in `indices_list`.
train_val_no_one_dataset = data_utils.Subset(dataset=train_val_dataset, indices=indices_list)

In [None]:
# Positions of every test sample whose label is not 1.
# Labels are read from the dataset's `targets` tensor; the previous loop
# decoded and transformed every image just to look at its label.
indices_list = torch.nonzero(test_dataset.targets != 1).flatten().tolist()

In [None]:
import torch.utils.data as data_utils
# View of test_dataset restricted to the positions listed in `indices_list`.
test_no_one_dataset = data_utils.Subset(dataset=test_dataset, indices=indices_list)

In [None]:
# Use the notebook-wide `batch_size` (set in the dataset cell) instead of a repeated literal.
train_val_no_one_loader = torch.utils.data.DataLoader(train_val_no_one_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_no_one_loader = torch.utils.data.DataLoader(test_no_one_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

## Train student without digit one with distillation (T=10)

In [None]:
# Training schedule: total epochs, and how many minibatches between progress logs.
num_epochs, print_every = 60, 100

In [None]:
temperatures = [10]
# Weighting between the soft-target (st) and true-target (tt) cross-entropy terms:
# loss = alpha * st + (1 - alpha) * tt
alphas = [0.5]
learning_rates = [1e-2]
learning_rate_decays = [0.95]
weight_decays = [1e-5]
momentums = [0.9]
dropout_probabilities = [(0.0, 0.0)]

# One hyper-parameter dict per point of the Cartesian grid above
# (comprehension variables do not leak into the notebook namespace).
hparams_list = [
    {
        'alpha': alpha,
        'T': T,
        'dropout_input': dropout[0],
        'dropout_hidden': dropout[1],
        'weight_decay': weight_decay,
        'lr_decay': lr_decay,
        'momentum': momentum,
        'lr': lr,
    }
    for alpha, T, dropout, weight_decay, lr_decay, momentum, lr in itertools.product(
        alphas, temperatures, dropout_probabilities, weight_decays,
        learning_rate_decays, momentums, learning_rates)
]

# For every configuration: call reproducibilitySeed(), build a new student,
# distil it from teacher_net using only samples whose label is not one,
# and checkpoint it.
results_distill = {}
for hparam in hparams_list:
    print('Training with hparams' + hparamToString(hparam))
    reproducibilitySeed()
    student_net = StudentNetwork().to(fast_device)
    hparam_tuple = hparamDictToTuple(hparam)
    results_distill[hparam_tuple] = trainStudentOnHparam(
        teacher_net, student_net, hparam, num_epochs,
        train_val_no_one_loader, None,
        print_every=print_every,
        fast_device=fast_device,
    )
    save_path = checkpoints_path_student + hparamToString(hparam) + '_no_one_distillation_final.tar'
    torch.save(
        {
            'results': results_distill[hparam_tuple],
            'model_state_dict': student_net.state_dict(),
            'epoch': num_epochs,
        },
        save_path,
    )


Training with hparamsT=10, alpha=0.5, dropout_hidden=0.0, dropout_input=0.0, lr=0.01, lr_decay=0.95, momentum=0.9, weight_decay=1e-05
[1,   100/  417] train loss: 4.465 train accuracy: 0.914
[1,   200/  417] train loss: 4.336 train accuracy: 0.891
[1,   300/  417] train loss: 3.743 train accuracy: 0.953
[1,   400/  417] train loss: 4.042 train accuracy: 0.906
[2,   100/  417] train loss: 3.512 train accuracy: 0.914
[2,   200/  417] train loss: 3.846 train accuracy: 0.977
[2,   300/  417] train loss: 3.553 train accuracy: 0.922
[2,   400/  417] train loss: 3.441 train accuracy: 0.953
[3,   100/  417] train loss: 3.568 train accuracy: 0.938
[3,   200/  417] train loss: 3.365 train accuracy: 0.938
[3,   300/  417] train loss: 3.516 train accuracy: 0.977
[3,   400/  417] train loss: 3.395 train accuracy: 0.977
[4,   100/  417] train loss: 3.748 train accuracy: 0.953
[4,   200/  417] train loss: 3.125 train accuracy: 0.961
[4,   300/  417] train loss: 3.170 train accuracy: 0.945
[4,   400/ 

## Calculate student test accuracy on digit one only

In [None]:
# Evaluate the student distilled without any digit-one training samples on the
# digit-one test images only. The label says "on one digit" so this result can
# be told apart from the all-other-digits evaluation in the next cell, which
# previously printed the identical text.
_, test_accuracy = getLossAccuracyOnDataset(student_net, test_only_one_loader, fast_device)
print('student test accuracy (w distillation) on one digit: ', test_accuracy)

student test accuracy (w distillation):  0.9850220264317181


## Calculate student test accuracy on all digits except one

In [None]:
# Evaluate the same distilled student on every test image whose label is not one;
# the loss half of the result is discarded.
_, test_accuracy = getLossAccuracyOnDataset(network=student_net,
                                            dataset_loader=test_no_one_loader,
                                            fast_device=fast_device)
print('student test accuracy (w distillation): ', test_accuracy)

student test accuracy (w distillation):  0.9825155104342922


# Student training without digit two ground truth

## Dataloader without digit two

In [None]:
import torchvision
import torchvision.transforms as transforms

# Student trained without data augmentation: the only preprocessing is
# ToTensor() followed by Normalize(mean=0.5, std=0.5), i.e. (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

train_val_dataset = torchvision.datasets.MNIST(root='./MNIST_dataset/', train=True,
                                               download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./MNIST_dataset/', train=False,
                                          download=True, transform=transform)

# 95% / 5% train/validation split of the MNIST training set.
# NOTE(review): random_split is not given a seeded generator, so the split
# depends on the global torch RNG state at this point.
num_train = int(1.0 * len(train_val_dataset) * 95 / 100)
num_val = len(train_val_dataset) - num_train
train_dataset, val_dataset = torch.utils.data.random_split(train_val_dataset, [num_train, num_val])

# All loaders share `batch_size` instead of repeating the literal 128,
# so changing the batch size here actually takes effect everywhere.
batch_size = 128
train_val_loader = torch.utils.data.DataLoader(train_val_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

In [None]:
# Positions of every training/validation sample whose label is not 2.
# Labels are read from the dataset's `targets` tensor; the previous loop
# decoded and transformed every image just to look at its label.
indices_list = torch.nonzero(train_val_dataset.targets != 2).flatten().tolist()

In [None]:
import torch.utils.data as data_utils
# View of train_val_dataset restricted to the positions listed in `indices_list`.
train_val_no_two_dataset = data_utils.Subset(dataset=train_val_dataset, indices=indices_list)

In [None]:
# Positions of every training/validation sample whose label IS 2.
# Labels are read from the dataset's `targets` tensor; the previous loop
# decoded and transformed every image just to look at its label.
indices_list = torch.nonzero(train_val_dataset.targets == 2).flatten().tolist()

In [None]:
import torch.utils.data as data_utils
# View of train_val_dataset restricted to the positions listed in `indices_list`.
train_val_only_two_dataset = data_utils.Subset(dataset=train_val_dataset, indices=indices_list)

In [None]:
# Positions of every test sample whose label is not 2.
# Labels are read from the dataset's `targets` tensor; the previous loop
# decoded and transformed every image just to look at its label.
indices_list = torch.nonzero(test_dataset.targets != 2).flatten().tolist()

In [None]:
import torch.utils.data as data_utils
# View of test_dataset restricted to the positions listed in `indices_list`.
test_no_two_dataset = data_utils.Subset(dataset=test_dataset, indices=indices_list)

In [None]:
# Positions of every test sample whose label IS 2.
# Labels are read from the dataset's `targets` tensor; the previous loop
# decoded and transformed every image just to look at its label.
indices_list = torch.nonzero(test_dataset.targets == 2).flatten().tolist()

In [None]:
import torch.utils.data as data_utils
# View of test_dataset restricted to the positions listed in `indices_list`.
test_only_two_dataset = data_utils.Subset(dataset=test_dataset, indices=indices_list)

In [None]:
# Use the notebook-wide `batch_size` (set in the dataset cell) instead of a repeated literal.
train_val_no_two_loader = torch.utils.data.DataLoader(train_val_no_two_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_no_two_loader = torch.utils.data.DataLoader(test_no_two_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

In [None]:
# Use the notebook-wide `batch_size` (set in the dataset cell) instead of a repeated literal.
train_val_only_two_loader = torch.utils.data.DataLoader(train_val_only_two_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_only_two_loader = torch.utils.data.DataLoader(test_only_two_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

## Train student without distillation and without digit two (T=1)

In [None]:
# Training schedule: total epochs, and how many minibatches between progress logs.
num_epochs, print_every = 60, 100

In [None]:
temperatures = [1]  # temperature for distillation loss
# Weighting between the soft-target (st) and true-target (tt) cross-entropy terms:
# loss = alpha * st + (1 - alpha) * tt. With alpha = 0.0 only the true-target
# term remains, so this run trains the student without distillation.
alphas = [0.0]
learning_rates = [1e-2]
learning_rate_decays = [0.95]
weight_decays = [1e-5]
momentums = [0.9]
# No dropout used
dropout_probabilities = [(0.0, 0.0)]

# One hyper-parameter dict per point of the Cartesian grid above
# (comprehension variables do not leak into the notebook namespace).
hparams_list = [
    {
        'alpha': alpha,
        'T': T,
        'dropout_input': dropout[0],
        'dropout_hidden': dropout[1],
        'weight_decay': weight_decay,
        'lr_decay': lr_decay,
        'momentum': momentum,
        'lr': lr,
    }
    for alpha, T, dropout, weight_decay, lr_decay, momentum, lr in itertools.product(
        alphas, temperatures, dropout_probabilities, weight_decays,
        learning_rate_decays, momentums, learning_rates)
]

# For every configuration: call reproducibilitySeed(), build a new student,
# train it on the loader that excludes digit two and checkpoint it.
results_distill = {}
for hparam in hparams_list:
    print('Training with hparams' + hparamToString(hparam))
    reproducibilitySeed()
    student_net = StudentNetwork().to(fast_device)
    hparam_tuple = hparamDictToTuple(hparam)
    results_distill[hparam_tuple] = trainStudentOnHparam(
        teacher_net, student_net, hparam, num_epochs,
        train_val_no_two_loader, None,
        print_every=print_every,
        fast_device=fast_device,
    )
    save_path = checkpoints_path_student + hparamToString(hparam) + '_no_two_no_distillation_final.tar'
    torch.save(
        {
            'results': results_distill[hparam_tuple],
            'model_state_dict': student_net.state_dict(),
            'epoch': num_epochs,
        },
        save_path,
    )


Training with hparamsT=1, alpha=0.0, dropout_hidden=0.0, dropout_input=0.0, lr=0.01, lr_decay=0.95, momentum=0.9, weight_decay=1e-05
[1,   100/  423] train loss: 0.361 train accuracy: 0.867
[1,   200/  423] train loss: 0.323 train accuracy: 0.898
[1,   300/  423] train loss: 0.171 train accuracy: 0.961
[1,   400/  423] train loss: 0.302 train accuracy: 0.914
[2,   100/  423] train loss: 0.177 train accuracy: 0.961
[2,   200/  423] train loss: 0.271 train accuracy: 0.938
[2,   300/  423] train loss: 0.163 train accuracy: 0.945
[2,   400/  423] train loss: 0.217 train accuracy: 0.922
[3,   100/  423] train loss: 0.200 train accuracy: 0.930
[3,   200/  423] train loss: 0.146 train accuracy: 0.961
[3,   300/  423] train loss: 0.089 train accuracy: 0.977
[3,   400/  423] train loss: 0.106 train accuracy: 0.977
[4,   100/  423] train loss: 0.171 train accuracy: 0.953
[4,   200/  423] train loss: 0.109 train accuracy: 0.953
[4,   300/  423] train loss: 0.166 train accuracy: 0.953
[4,   400/  

## Calculate student test accuracy on digit two only

In [None]:
# Evaluate the student trained WITHOUT distillation (alpha = 0.0) and without
# any digit-two training samples, on the digit-two test images only.
# The label previously claimed "(w distillation)", which was wrong for this run;
# it now mirrors the digit-one counterpart ("(w/o distillation) on one digit").
_, test_accuracy = getLossAccuracyOnDataset(student_net, test_only_two_loader, fast_device)
print('student test accuracy (w/o distillation) on two digit: ', test_accuracy)

student test accuracy (w distillation):  0.0


## Calculate student test accuracy on all digits except two

In [None]:
# Evaluate the same non-distilled student (alpha = 0.0) on every test image whose
# label is not two. The label previously claimed "(w distillation)", which was
# wrong for this run.
_, test_accuracy = getLossAccuracyOnDataset(student_net, test_no_two_loader, fast_device)
print('student test accuracy (w/o distillation): ', test_accuracy)

student test accuracy (w distillation):  0.9811552185548618


# Vanilla Distillation with one digit missing (Digit Two)

## Dataloader without digit two

In [None]:
import torchvision
import torchvision.transforms as transforms

# Student trained without data augmentation: the only preprocessing is
# ToTensor() followed by Normalize(mean=0.5, std=0.5), i.e. (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

train_val_dataset = torchvision.datasets.MNIST(root='./MNIST_dataset/', train=True,
                                               download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./MNIST_dataset/', train=False,
                                          download=True, transform=transform)

# 95% / 5% train/validation split of the MNIST training set.
# NOTE(review): random_split is not given a seeded generator, so the split
# depends on the global torch RNG state at this point.
num_train = int(1.0 * len(train_val_dataset) * 95 / 100)
num_val = len(train_val_dataset) - num_train
train_dataset, val_dataset = torch.utils.data.random_split(train_val_dataset, [num_train, num_val])

# All loaders share `batch_size` instead of repeating the literal 128,
# so changing the batch size here actually takes effect everywhere.
batch_size = 128
train_val_loader = torch.utils.data.DataLoader(train_val_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

In [None]:
# Positions of every training/validation sample whose label is not 2.
# Labels are read from the dataset's `targets` tensor; the previous loop
# decoded and transformed every image just to look at its label.
indices_list = torch.nonzero(train_val_dataset.targets != 2).flatten().tolist()

In [None]:
import torch.utils.data as data_utils
# View of train_val_dataset restricted to the positions listed in `indices_list`.
train_val_no_two_dataset = data_utils.Subset(dataset=train_val_dataset, indices=indices_list)

In [None]:
# Positions of every test sample whose label is not 2.
# Labels are read from the dataset's `targets` tensor; the previous loop
# decoded and transformed every image just to look at its label.
indices_list = torch.nonzero(test_dataset.targets != 2).flatten().tolist()

In [None]:
from torch.utils import data as data_utils

# View of the test set restricted to the non-digit-two indices.
test_no_two_dataset = data_utils.Subset(dataset=test_dataset, indices=indices_list)

In [None]:
# Loaders over the digit-two-free subsets. They reuse the notebook-wide
# `batch_size` rather than repeating the literal 128. Only the training
# loader is shuffled.
train_val_no_two_loader = torch.utils.data.DataLoader(train_val_no_two_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_no_two_loader = torch.utils.data.DataLoader(test_no_two_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

## Train student without digit two with distillation (T=10)

In [None]:
# Training schedule: number of epochs, and number of batches between progress prints.
num_epochs, print_every = 60, 100

In [None]:
# Hyper-parameter grid for distilling the student without digit two.
# alpha trades off soft-target (st) and true-target (tt) cross-entropy:
#     loss = alpha * st + (1 - alpha) * tt
temperatures = [10]
alphas = [0.5]
learning_rates = [1e-2]
learning_rate_decays = [0.95]
weight_decays = [1e-5]
momentums = [0.9]
# Each entry is an (input dropout, hidden dropout) pair.
dropout_probabilities = [(0.0, 0.0)]

# One dict per point of the Cartesian product of the lists above.
hparams_list = [
    {
        'alpha': alpha,
        'T': T,
        'dropout_input': p_input,
        'dropout_hidden': p_hidden,
        'weight_decay': wd,
        'lr_decay': decay,
        'momentum': mom,
        'lr': lr,
    }
    for alpha, T, (p_input, p_hidden), wd, decay, mom, lr in itertools.product(
        alphas, temperatures, dropout_probabilities, weight_decays,
        learning_rate_decays, momentums, learning_rates)
]

# Distil a fresh student for each hyper-parameter setting. Training uses the
# train+val split with all digit-two images removed and no validation loader.
# The final weights are checkpointed together with the training results.
results_distill = {}
for hparam in hparams_list:
    hparam_str = hparamToString(hparam)
    print(f'Training with hparams{hparam_str}')
    # reproducibilitySeed is defined earlier in the notebook; presumably it seeds
    # the RNGs before the network is built. Confirm there.
    reproducibilitySeed()
    student_net = StudentNetwork().to(fast_device)
    hparam_tuple = hparamDictToTuple(hparam)
    results_distill[hparam_tuple] = trainStudentOnHparam(
        teacher_net, student_net, hparam, num_epochs,
        train_val_no_two_loader, None,
        print_every=print_every,
        fast_device=fast_device,
    )
    save_path = checkpoints_path_student + hparam_str + '_no_two_distillation_final.tar'
    checkpoint = {
        'results': results_distill[hparam_tuple],
        'model_state_dict': student_net.state_dict(),
        'epoch': num_epochs,
    }
    torch.save(checkpoint, save_path)


Training with hparamsT=10, alpha=0.5, dropout_hidden=0.0, dropout_input=0.0, lr=0.01, lr_decay=0.95, momentum=0.9, weight_decay=1e-05
[1,   100/  423] train loss: 4.568 train accuracy: 0.836
[1,   200/  423] train loss: 4.346 train accuracy: 0.930
[1,   300/  423] train loss: 3.925 train accuracy: 0.977
[1,   400/  423] train loss: 3.970 train accuracy: 0.914
[2,   100/  423] train loss: 3.889 train accuracy: 0.969
[2,   200/  423] train loss: 3.863 train accuracy: 0.977
[2,   300/  423] train loss: 3.961 train accuracy: 0.977
[2,   400/  423] train loss: 3.233 train accuracy: 0.953
[3,   100/  423] train loss: 3.918 train accuracy: 0.953
[3,   200/  423] train loss: 3.750 train accuracy: 0.969
[3,   300/  423] train loss: 3.446 train accuracy: 0.984
[3,   400/  423] train loss: 3.341 train accuracy: 0.992
[4,   100/  423] train loss: 3.753 train accuracy: 0.953
[4,   200/  423] train loss: 4.013 train accuracy: 0.961
[4,   300/  423] train loss: 3.438 train accuracy: 0.969
[4,   400/ 

## Calculate student test accuracy on the test set containing only digit two

In [None]:
# Student accuracy on `test_only_two_loader` (defined earlier in the notebook;
# presumably it holds only digit-two test images). The student above was
# trained with every label-2 sample filtered out. The loss value is discarded.
test_accuracy = getLossAccuracyOnDataset(student_net, test_only_two_loader, fast_device)[1]
print('student test accuracy (w distillation): ', test_accuracy)

student test accuracy (w distillation):  0.9273255813953488


## Calculate student test accuracy on the test set without digit two

In [None]:
# Student accuracy on the test subset with every label-2 image removed.
# The loss value is discarded.
test_accuracy = getLossAccuracyOnDataset(student_net, test_no_two_loader, fast_device)[1]
print('student test accuracy (w distillation): ', test_accuracy)

student test accuracy (w distillation):  0.9850579839429081
