#   0 environment and data setup

In [None]:
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
import os
import glob
import cv2
import tqdm

In [None]:
'''location for data'''
# Mount Google Drive (Colab-only) so the dataset archives and the saved model
# files can be read and written from this runtime.
from google.colab import drive
drive.mount('/content/drive')
# Project folder on Drive; the model files used later in the notebook
# (net_params.pkl, teacher.pkl, net_params_new.pkl) are joined onto this path.
BASE_PATH = '/content/drive/My Drive/high10'

Mounted at /content/drive


In [None]:
import os
if not os.path.exists('/content/train'):
  !tar --exclude='._*' -xvf /content/drive/My\ Drive/high10/hightrain.tar
if not os.path.exists('/content/test'):
  !tar --exclude='._*' -xvf /content/drive/My\ Drive/high10/hightest.tar
if os.path.exists('/content/test') & os.path.exists('/content/train'):
  print("all done")


In [None]:
import os
# Root folder into which the train/test archives were extracted.
data_path = '/content'
# Pass path components to os.path.join separately; joining a single
# pre-concatenated string bypasses the separator handling entirely.
traindir = os.path.join(data_path, 'train')
print(traindir)
testdir = os.path.join(data_path, 'test')
print(testdir)

/content/train
/content/test


In [None]:
# Data augmentation / preprocessing.
# Standard ImageNet per-channel mean and std, applied after ToTensor().
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Training images: resize to 256x256, random resized 224 crop, random horizontal
# flip, then tensor conversion and normalization.
train_dataset = datasets.ImageFolder(
    root=traindir,
    transform=transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]),
)

# Test images: deterministic resize straight to 224x224 (no augmentation).
test_dataset = datasets.ImageFolder(
    root=testdir,
    transform=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ]),
)

In [None]:
# Shuffle only the training data. Evaluation does not need random order, and a
# fixed test order makes evaluation runs reproducible.
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False)

# 1 utils setup

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

def trainStep(network, criterion, optimizer, X, y):
    """
    One training step of the network: forward prop + backprop + update parameters.

    Args:
        network: model mapping the batch ``X`` to class scores; the predicted
            class is the argmax over dim 1.
        criterion: callable ``criterion(outputs, y)`` returning a scalar loss tensor.
        optimizer: optimizer over ``network``'s parameters.
        X: input batch.
        y: integer class labels, one per row of ``X``.

    Return: (loss, accuracy) of the current batch, both as Python floats.
    """
    optimizer.zero_grad()
    outputs = network(X)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()
    accuracy = float(torch.sum(torch.argmax(outputs, dim=1) == y).item()) / y.shape[0]
    # Return a plain float. The loss tensor is attached to the autograd graph, and
    # storing it in per-batch history lists (as the callers do) would keep that
    # graph alive.
    return loss.item(), accuracy

def getLossAccuracyOnDataset(network, dataset_loader, fast_device, criterion=None):
    """
    Returns (loss, accuracy) of network on given dataset.

    During the pass the network is put into evaluation mode in two ways:
    - ``network.eval()``, which affects standard modules such as ``nn.Dropout``;
    - the custom ``is_training`` flag, which the student network feeds to
      ``F.dropout``.
    Afterwards the previous ``train``/``eval`` mode is restored and
    ``is_training`` is set back to True.

    Args:
        network: model mapping a batch to class scores (argmax over dim 1).
        dataset_loader: iterable of ``(X, y)`` batches.
        fast_device: device the batches are moved to.
        criterion: optional loss callable; when None the returned loss is 0.0.

    Return: (loss, accuracy) as Python floats.
        loss: mean of ``criterion`` over samples, weighting each batch's value
            by its size.
        accuracy: fraction of correct argmax predictions.
        Both are 0.0 for an empty loader.
    """
    was_training = network.training
    network.eval()
    network.is_training = False
    accuracy = 0.0
    loss = 0.0
    dataset_size = 0
    try:
        with torch.no_grad():
            for X, y in dataset_loader:
                X = X.to(fast_device)
                y = y.to(fast_device)
                pred = network(X)
                if criterion is not None:
                    # .item() keeps the running sum a plain float.
                    loss += criterion(pred, y).item() * y.shape[0]
                accuracy += torch.sum(torch.argmax(pred, dim=1) == y).item()
                dataset_size += y.shape[0]
    finally:
        # Restore the caller's mode even if evaluation raised.
        network.train(was_training)
        network.is_training = True
    if dataset_size == 0:
        return 0.0, 0.0
    return loss / dataset_size, accuracy / dataset_size

def trainTeacherOnHparam(teacher_net, hparam, num_epochs,
                         train_loader, val_loader,
                         print_every=0,
                         fast_device=torch.device('cpu')):
    """
    Trains teacher on given hyperparameters for given number of epochs.

    Pass val_loader=None when validation after every epoch is not required.

    Args:
        teacher_net: network to train. Its ``dropout_input``/``dropout_hidden``
            attributes are set from ``hparam``.
        hparam: dict with keys 'dropout_input', 'dropout_hidden', 'lr',
            'momentum', 'weight_decay' and 'lr_decay' (the per-epoch StepLR gamma).
        num_epochs: number of passes over ``train_loader``.
        train_loader: iterable of ``(X, y)`` training batches.
        val_loader: iterable of ``(X, y)`` validation batches, or None.
        print_every: log batch loss/accuracy every this many batches (0 = never).
        fast_device: device batches are moved to.

    Return: dict with
        'train_loss', 'train_acc': mean batch loss / accuracy for each epoch;
        'val_loss', 'val_acc': validation loss / accuracy before training and
            after every epoch (empty lists when val_loader is None).
    """
    train_losses, train_acces = [], []
    val_loss_list, val_acc_list = [], []
    teacher_net.dropout_input = hparam['dropout_input']
    teacher_net.dropout_hidden = hparam['dropout_hidden']
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(teacher_net.parameters(), lr=hparam['lr'], momentum=hparam['momentum'],
                          weight_decay=hparam['weight_decay'])
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=hparam['lr_decay'])

    def _validate(epoch_label):
        """Evaluate on val_loader, record the result and log it."""
        val_loss, val_acc = getLossAccuracyOnDataset(teacher_net, val_loader, fast_device, criterion)
        val_loss_list.append(val_loss)
        val_acc_list.append(val_acc)
        print('epoch: %d validation loss: %.3f validation accuracy: %.3f' % (epoch_label, val_loss, val_acc))

    for epoch in range(num_epochs):
        if epoch == 0 and val_loader is not None:
            _validate(epoch)
        # Per-epoch batch metrics. The old code averaged a list that grew
        # across all epochs.
        epoch_losses, epoch_accs = [], []
        for i, (X, y) in enumerate(train_loader):
            X, y = X.to(fast_device), y.to(fast_device)
            loss, acc = trainStep(teacher_net, criterion, optimizer, X, y)
            loss = float(loss)
            epoch_losses.append(loss)
            epoch_accs.append(acc)

            if print_every > 0 and i % print_every == print_every - 1:
                print('[%d, %5d/%5d] train loss: %.3f train accuracy: %.3f' %
                      (epoch + 1, i + 1, len(train_loader), loss, acc))
        train_acces.append(np.mean(epoch_accs))
        # Previously this appended the accuracy list here by mistake.
        train_losses.append(np.mean(epoch_losses))
        # Decay the learning rate once the epoch's optimizer steps are done.
        lr_scheduler.step()
        if val_loader is not None:
            _validate(epoch + 1)
    return {'train_loss': train_losses,
            'train_acc': train_acces,
            'val_loss': val_loss_list,
            'val_acc': val_acc_list}

def studentTrainStep(teacher_net, student_net, studentLossFn, optimizer, X, y, T, alpha, max_grad_norm=20):
    """
    One training step of student network: forward prop + backprop + update parameters.

    Args:
        teacher_net: network providing soft targets. It is only evaluated when
            ``alpha > 0``, and under ``torch.no_grad()``.
        student_net: network being trained; predicted class is argmax over dim 1.
        studentLossFn: callable ``(teacher_pred, student_pred, y, T, alpha)``
            returning a scalar loss. ``teacher_pred`` is None when ``alpha <= 0``.
        optimizer: optimizer over ``student_net``'s parameters.
        X, y: input batch and integer class labels.
        T: distillation temperature, forwarded to ``studentLossFn``.
        alpha: soft-target weight, forwarded to ``studentLossFn``.
        max_grad_norm: gradient-norm clipping threshold (default 20, as before).

    Return: (loss, accuracy) of the current batch, both as Python floats.
    """
    optimizer.zero_grad()
    teacher_pred = None
    if alpha > 0:
        # The teacher only supplies targets; no gradients flow into it.
        with torch.no_grad():
            teacher_pred = teacher_net(X)
    student_pred = student_net(X)
    loss = studentLossFn(teacher_pred, student_pred, y, T, alpha)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(student_net.parameters(), max_grad_norm)
    optimizer.step()
    accuracy = float(torch.sum(torch.argmax(student_pred, dim=1) == y).item()) / y.shape[0]
    # Return a float so per-batch history lists do not retain the autograd graph.
    return loss.item(), accuracy

def trainStudentOnHparam(teacher_net, student_net, hparam, num_epochs,
                         train_loader, val_loader,
                         print_every=0,
                         fast_device=torch.device('cpu')):
    """
    Trains the student by knowledge distillation from ``teacher_net`` on given
    hyperparameters for given number of epochs.

    Pass val_loader=None when validation after every epoch is not required.
    The teacher is switched to eval mode (it is only used for inference here)
    and is left in that mode.

    Args:
        teacher_net: trained network providing soft targets.
        student_net: network being trained. Its ``dropout_input`` and
            ``dropout_hidden`` attributes are set from ``hparam``.
        hparam: dict with keys 'T', 'alpha', 'dropout_input', 'dropout_hidden',
            'lr', 'momentum', 'weight_decay' and 'lr_decay' (per-epoch StepLR gamma).
        num_epochs: number of passes over ``train_loader``.
        train_loader: iterable of ``(X, y)`` training batches.
        val_loader: iterable of ``(X, y)`` validation batches, or None.
        print_every: log batch loss/accuracy every this many batches (0 = never).
        fast_device: device batches are moved to.

    Return: dict with
        'train_loss', 'train_acc': one entry per training batch;
        'val_acc': validation accuracy before training and after every epoch
            (empty when val_loader is None).
    """
    train_loss_list, train_acc_list, val_acc_list = [], [], []
    T = hparam['T']
    alpha = hparam['alpha']
    student_net.dropout_input = hparam['dropout_input']
    student_net.dropout_hidden = hparam['dropout_hidden']
    optimizer = optim.SGD(student_net.parameters(), lr=hparam['lr'], momentum=hparam['momentum'],
                          weight_decay=hparam['weight_decay'])
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=hparam['lr_decay'])
    # Inference-only use of the teacher: disable any dropout so soft targets are stable.
    teacher_net.eval()

    def studentLossFn(teacher_pred, student_pred, y, T, alpha):
        """
        Loss function for student network:
            Loss = alpha * (distillation loss with soft targets)
                   + (1 - alpha) * (cross-entropy loss with true labels).
        The distillation term is KL(teacher || student) on temperature-T softmaxes,
        scaled by T**2.
        Return: loss
        """
        if alpha > 0:
            loss = F.kl_div(F.log_softmax(student_pred / T, dim=1), F.softmax(teacher_pred / T, dim=1),
                            reduction='batchmean') * (T ** 2) * alpha + F.cross_entropy(student_pred, y) * (1 - alpha)
        else:
            loss = F.cross_entropy(student_pred, y)
        return loss

    for epoch in range(num_epochs):
        if epoch == 0 and val_loader is not None:
            _, val_acc = getLossAccuracyOnDataset(student_net, val_loader, fast_device)
            val_acc_list.append(val_acc)
            print('epoch: %d validation accuracy: %.3f' % (epoch, val_acc))
        # Distillation step per batch. This replaces a pasted segmentation loop
        # that referenced undefined names (batch['image'], net, device, criterion, acc).
        for i, (X, y) in enumerate(train_loader):
            X, y = X.to(fast_device), y.to(fast_device)
            loss, acc = studentTrainStep(teacher_net, student_net, studentLossFn, optimizer, X, y, T, alpha)
            loss = float(loss)
            train_loss_list.append(loss)
            train_acc_list.append(acc)
            if print_every > 0 and i % print_every == print_every - 1:
                print('[%d, %5d/%5d] train loss: %.3f train accuracy: %.3f' %
                      (epoch + 1, i + 1, len(train_loader), loss, acc))
        # Decay the learning rate after the epoch's optimizer steps.
        lr_scheduler.step()

        if val_loader is not None:
            _, val_acc = getLossAccuracyOnDataset(student_net, val_loader, fast_device)
            val_acc_list.append(val_acc)
            print('epoch: %d validation accuracy: %.3f' % (epoch + 1, val_acc))

    return {'train_loss': train_loss_list,
            'train_acc': train_acc_list,
            'val_acc': val_acc_list}

def hparamToString(hparam):
    """
    Render an hparam dict as ``"key=value, key=value, ..."``.

    Keys appear in sorted order, so equal dicts always give the same string.
    An empty dict gives ``""``.
    """
    return ', '.join(key + '=' + str(value) for key, value in sorted(hparam.items()))

def hparamDictToTuple(hparam):
    """
    Return the values of an hparam dict as a tuple ordered by sorted key.

    The fixed ordering makes the tuple usable as a stable dictionary key for
    results.
    """
    return tuple(value for _key, value in sorted(hparam.items()))

def getTrainMetricPerEpoch(train_metric, updates_per_epoch):
    """
    Average a per-batch training metric over each epoch.

    Args:
        train_metric: sequence of per-batch metric values.
        updates_per_epoch: number of batches that make up one epoch.

    Returns:
        List with one mean value per complete epoch. A trailing partial epoch
        is dropped.
    """
    epoch_means = []
    running_total = 0.0
    for batch_idx, value in enumerate(train_metric):
        running_total += value
        # The last batch of an epoch closes that epoch's window.
        if batch_idx % updates_per_epoch == updates_per_epoch - 1:
            epoch_means.append(running_total / updates_per_epoch)
            running_total = 0.0
    return epoch_means

# 2 network 

In [None]:
import torch
from torch import nn

class spectral_pool_layer(nn.Module):
  """
  Spectral pooling followed by an optional ReLU.

  Keeps only the lowest ``filter_size`` x ``filter_size`` frequencies of each
  channel's 2-D DFT and transforms back.

  When both ``freq_dropout_lower_bound`` and ``freq_dropout_upper_bound`` are
  given and ``train_phase`` is True, random frequency dropout is applied: a
  cutoff is drawn uniformly between the bounds and every retained frequency
  whose index exceeds it is zeroed.

  Uses the legacy ``torch.rfft``/``torch.irfft`` API, where real and imaginary
  parts sit in a trailing dimension of size 2, like the rest of this notebook.
  Only odd ``filter_size`` values are supported.
  """

  def __init__(self, filter_size=3, freq_dropout_lower_bound=None, freq_dropout_upper_bound=None, train_phase=False):
    super(spectral_pool_layer, self).__init__()
    # assert only 1 dimension passed for filter size
    assert isinstance(filter_size, int)

    self.filter_size = filter_size
    self.freq_dropout_lower_bound = freq_dropout_lower_bound
    self.freq_dropout_upper_bound = freq_dropout_upper_bound
    # forward applies self.activation.relu to the output; set to None to skip it.
    self.activation = F
    self.train_phase = train_phase

  def forward(self, x):
    # assumes x is a real (N, C, H, W) tensor -- TODO confirm
    # Full (two-sided) 2-D FFT -> (N, C, H, W, 2)
    im_fft = torch.rfft(x, 2, onesided=False)

    # Keep only the low-frequency filter_size x filter_size block.
    im_transformed = self._common_spectral_pool(im_fft, self.filter_size)

    use_freq_dropout = (self.freq_dropout_lower_bound is not None
                        and self.freq_dropout_upper_bound is not None
                        and self.train_phase)
    if use_freq_dropout:
      lower, upper = self.freq_dropout_lower_bound, self.freq_dropout_upper_bound
      random_cutoff = lower + (upper - lower) * torch.rand(()).item()
      dropout_mask = self._frequency_dropout_mask(self.filter_size, random_cutoff,
                                                  device=im_transformed.device,
                                                  dtype=im_transformed.dtype)
      # The mask is real, so scaling the real and imaginary parts alike equals a
      # complex multiplication by the mask.
      im_transformed = im_transformed * dropout_mask.unsqueeze(-1)
    # In the testing phase (or without bounds) the truncated spectrum is used unchanged.
    im_out = torch.irfft(im_transformed, 2, onesided=False)

    if self.activation is not None:
      return self.activation.relu(im_out)
    return im_out

  def _common_spectral_pool(self, images, filter_size):
    """
    Crop the low-frequency block from an unshifted spectrum laid out as (N, C, H, W, 2).

    With n = (filter_size - 1) // 2, keeps frequency rows/columns 0..n and -n..-1,
    giving a (N, C, filter_size, filter_size, 2) result.

    Raises:
        ValueError: for even ``filter_size`` (not implemented; previously this
            silently returned None).
    """
    assert len(images.shape) == 5
    assert filter_size >= 3
    if filter_size % 2 == 0:
      raise ValueError('spectral_pool_layer only supports odd filter_size, got %d' % filter_size)

    n = (filter_size - 1) // 2
    top_left = images[:, :, :n + 1, :n + 1]
    top_right = images[:, :, :n + 1, -n:]
    bottom_left = images[:, :, -n:, :n + 1]
    bottom_right = images[:, :, -n:, -n:]
    # dim -2 is the W axis and dim -3 the H axis (the last dim holds real/imag).
    top_combined = torch.cat([top_left, top_right], dim=-2)
    bottom_combined = torch.cat([bottom_left, bottom_right], dim=-2)
    return torch.cat([top_combined, bottom_combined], dim=-3)

  @staticmethod
  def _frequency_dropout_mask(height, cutoff, device=None, dtype=torch.float32):
    """
    Build a (height, height) 0/1 mask for an unshifted square spectrum.

    The entry at (i, j) is 1 when max(freq_index(i), freq_index(j)) <= cutoff.
    Frequency indexes run 0, 1, ..., then back down, e.g. [0, 1, 2, 2, 1] for
    height 5.
    """
    mid = height // 2
    go_to = mid + 1 if height % 2 == 1 else mid
    indexes = torch.cat([torch.arange(go_to, device=device),
                         torch.arange(mid, 0, -1, device=device)]).to(dtype)
    xs = indexes.expand(height, height)
    highest_frequency = torch.max(xs, xs.t())
    return (highest_frequency <= cutoff).to(dtype)

In [None]:
import torch
import numpy as np
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from scipy import signal
import scipy
from torch import nn
import time

def dft_conv(imgR, imgIm, kernelR, kernelIm):
    """
    Pointwise complex product of image and kernel spectra, summed over dim 1.

    Uses the three-multiplication trick. For (a + bi) * (c + di):
        real = ac - bd
        imag = (a + b)(c + d) - ac - bd
    In FFT_Conv_Layer, dim 1 is the input-channel axis, so the sum produces
    the per-output-channel result.

    Returns:
        (real, imag) tensors with dim 1 reduced.
    """
    real_prod = kernelR * imgR
    imag_prod = kernelIm * imgIm
    cross_prod = (kernelR + kernelIm) * (imgR + imgIm)

    out_real = (real_prod - imag_prod).sum(1)
    out_imag = (cross_prod - real_prod - imag_prod).sum(1)
    return out_real, out_imag

class FFT_Conv_Layer(nn.Module):
    """
    2-D convolution implemented as a pointwise product in the DFT domain.

    The learnable filters ``filts`` have shape
    (1, inCs, outCs, filtSize, filtSize, imagDim) and are initialised from
    N(0, 0.01) using NumPy's global RNG. ``torch.fft`` needs a trailing
    real/imag dimension of size 2, so ``imagDim`` must be 2.

    Inputs are presumably (N, inCs, imgSize, imgSize); the padded image and the
    padded filter spectra only line up when H == W == imgSize -- TODO confirm.
    The output crops one element from each side of the full
    (imgSize + filtSize - 1) result, which is "same" size for filtSize == 3.
    Uses the legacy torch.rfft / torch.fft / torch.ifft API.
    """

    def __init__(self, imgSize, inCs, outCs, imagDim, filtSize, cuda=False):

        super(FFT_Conv_Layer, self).__init__()
        filts = np.random.normal(0, 0.01, (1, inCs, outCs, filtSize, filtSize, imagDim))
        self.imgSize = imgSize
        self.filtSize = np.size(filts, 4)

        # Always register the filters as a Parameter so they are trained and follow
        # .to()/.cuda(). Previously cuda=False left them as a plain NumPy array,
        # which F.pad in forward cannot handle.
        filts = torch.from_numpy(filts).type(torch.float32)
        if cuda:
            filts = filts.cuda()
        self.filts = Parameter(filts)

    def forward(self, imgs):

        # Pad and transform the image.
        # Pad arg = (last dim pad left side, last dim pad right side, 2nd last dim left side, etc.)
        # Add an output-channel axis (dim 2) and a temporary trailing axis so that
        # F.pad pads H and W on the right by filtSize - 1 (linear, not circular, conv).
        imgs = imgs.unsqueeze(2)
        imgs = imgs.unsqueeze(5)

        imgs = F.pad(imgs, (0, 0, 0, self.filtSize - 1, 0, self.filtSize - 1))
        imgs = imgs.squeeze(5)

        imgs = torch.rfft(imgs, 2, onesided=False)

        # Extract the real and imaginary parts
        imgsR = imgs[:, :, :, :, :, 0]
        imgsIm = imgs[:, :, :, :, :, 1]

        # Pad the filters to the same spatial size, then transform them.
        filts = F.pad(self.filts, (0, 0, 0, self.imgSize - 1, 0, self.imgSize - 1))

        filts = torch.fft(filts, 2)

        # Extract the real and imaginary parts
        filtR = filts[:, :, :, :, :, 0]
        filtIm = filts[:, :, :, :, :, 1]

        # Element-wise complex multiplication, summed over input channels (dim 1).
        imgsR, imgsIm = dft_conv(imgsR, imgsIm, filtR, filtIm)

        # Add a trailing dim to hold real/imag for the inverse transform.
        imgsR = imgsR.unsqueeze(4)
        imgsIm = imgsIm.unsqueeze(4)

        # Concatenate real and imaginary parts again, then inverse FFT.
        imgs = torch.cat((imgsR, imgsIm), -1)
        imgs = torch.ifft(imgs, 2)

        # Keep the real part and drop one border element per side.
        # NOTE(review): the filters have a random imaginary part too (imagDim=2),
        # so the imaginary part of the result is not necessarily ~0.
        imgs = imgs[:, :, 1:-1, 1:-1, 0]

        return imgs

In [None]:
class StudentNetwork_noRelu(nn.Module):
  """
  Student network built on FFT-domain convolutions.

  Layout:
  - three FFT_Conv_Layer convolutions, with spectral pooling after the first
    two (to 113x113, then 27x27);
  - a 3x3/stride-2 max-pool and a 6x6 adaptive average pool;
  - a 3-layer fully-connected head producing 10-way log-probabilities
    (LogSoftmax).

  Dropout in the head is controlled by ``dropout_input``, ``dropout_hidden``
  and the custom ``is_training`` flag, not by ``nn.Module.train``/``eval``.

  NOTE(review): despite the name, spectral_pool_layer applies F.relu to its
  output by default, so the first two conv stages are followed by a ReLU.
  """
  def __init__(self):
    super(StudentNetwork_noRelu, self).__init__()
    self.conv1 = FFT_Conv_Layer(imgSize=224, inCs=3, outCs=32, imagDim=2, filtSize=3, cuda=True)
    self.conv2 = FFT_Conv_Layer(imgSize=113, inCs=32, outCs=64, imagDim=2, filtSize=3, cuda=True)
    self.conv3 = FFT_Conv_Layer(imgSize=27, inCs=64, outCs=256, imagDim=2, filtSize=3, cuda=True)
    self.fc1 = nn.Linear(9216, 512)
    self.fc2 = nn.Linear(512, 256)
    self.fc3 = nn.Linear(256, 10)
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    self.dropout_input = 0.5
    self.dropout_hidden = 0.5
    self.is_training = True
    self.avepool = nn.AdaptiveAvgPool2d((6, 6))
    self.m = nn.LogSoftmax(dim=1)
    self.max_113 = spectral_pool_layer(113)
    self.max_27 = spectral_pool_layer(27)

  def forward(self, x):
    # assumes x is (N, 3, 224, 224), matching conv1's imgSize -- TODO confirm
    # conv1 is evaluated once. Previously its result was computed, discarded and
    # then recomputed, doubling the cost of the largest FFT convolution.
    forw = self.max_113(self.conv1(x))
    forw = self.max_27(self.conv2(forw))
    forw = self.conv3(forw)
    forw = self.maxpool(forw)
    forw = self.avepool(forw)
    # 256 channels * 6 * 6 = 9216 features per sample.
    forw = forw.view(-1, 9216)
    forw = F.dropout(forw, p=self.dropout_input, training=self.is_training)
    forw = F.dropout(self.fc1(forw), p=self.dropout_hidden, training=self.is_training)
    forw = F.relu(forw)
    forw = self.fc2(forw)
    forw = F.relu(forw)
    forw = self.fc3(forw)
    return self.m(forw)

timing test (forward-pass speed of the student network)

In [None]:
import time

inputs = torch.rand([1, 3, 224, 224], dtype=torch.float)
print(inputs.shape)
inputs = inputs.to("cuda")
model = StudentNetwork_noRelu()
model.to('cuda')


import datetime

# Inference-only benchmark, so no autograd graph is built.
# CUDA kernels launch asynchronously: synchronize before reading the clock.
# One warm-up pass keeps lazy CUDA/FFT initialisation out of the measurement.
n_runs = 100
with torch.no_grad():
  model(inputs)
  torch.cuda.synchronize()
  starttime = time.perf_counter()
  for i in range(n_runs):
    model(inputs)
  torch.cuda.synchronize()
  endtime = time.perf_counter()

elapsed = endtime - starttime
print('%d forward passes: %.3f s (%.2f ms each)' % (n_runs, elapsed, 1000 * elapsed / n_runs))


torch.Size([1, 3, 224, 224])
0:00:04.546426


# 3 pre-training the student

In [None]:
device = 'cuda'
model = StudentNetwork_noRelu()
# The student already returns log-probabilities (LogSoftmax), so NLLLoss is the
# matching criterion.
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.003)
model.to(device)

from tqdm import tqdm

traininglosses = []
testinglosses = []
testaccuracy = []
totalsteps = []
epochs = 50
steps = 0
running_loss = 0
print_every = 100
for epoch in range(epochs):
    for inputs, labels in trainloader:
        steps += 1
        # Move input and label tensors to the default device
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        logps = model(inputs)
        loss = criterion(logps, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if steps % print_every == 0:
            test_loss = 0
            correct = 0
            seen = 0
            # The student's dropout follows its custom `is_training` flag, which
            # model.eval() does not touch, so clear both for evaluation.
            model.eval()
            model.is_training = False
            with torch.no_grad():
                for test_inputs, test_labels in testloader:
                    test_inputs, test_labels = test_inputs.to(device), test_labels.to(device)
                    test_logps = model(test_inputs)
                    test_loss += criterion(test_logps, test_labels).item()
                    # Count correct predictions per sample so a smaller last batch
                    # is not over-weighted.
                    correct += (test_logps.argmax(dim=1) == test_labels).sum().item()
                    seen += test_labels.size(0)
            accuracy = correct / max(seen, 1)

            traininglosses.append(running_loss/print_every)
            testinglosses.append(test_loss/len(testloader))
            testaccuracy.append(accuracy)
            totalsteps.append(steps)
            print(f"Device {device}.."
                  f"Epoch {epoch+1}/{epochs}.. "
                  f"Step {steps}.. "
                  f"Train loss: {running_loss/print_every:.3f}.. "
                  f"Test loss: {test_loss/len(testloader):.3f}.. "
                  f"Test accuracy: {accuracy:.3f}")
            running_loss = 0
            model.train()
            model.is_training = True

# 4 load the pre-trained student model

In [None]:
import torch

# Rebuild the student architecture and restore its pre-trained weights from Drive.
device = "cuda"
pretrained_dict = torch.load(f"{BASE_PATH}/net_params.pkl")
net = StudentNetwork_noRelu()
net.load_state_dict(pretrained_dict)
net.to(device)

In [None]:
# Baseline: accuracy of the pre-trained student before any distillation.
# The label previously (wrongly) described this as the teacher "w distillation".
_, test_accuracy = getLossAccuracyOnDataset(net, testloader, device)
print('student test accuracy (pre-trained, w/o distillation): ', test_accuracy)

teacher test accuracy (w distillation):  0.7272727272727273


In [None]:
# Freeze every parameter whose name contains "conv" (the FFT conv filters) and
# leave the fully-connected head trainable.
for name, param in net.named_parameters():
  param.requires_grad = "conv" not in name


In [None]:
# Sanity check: show which parameters will receive gradients.
for name, param in net.named_parameters():
  print(f"{name} {param.requires_grad}")

conv1.filts False
conv2.filts False
conv3.filts False
fc1.weight True
fc1.bias True
fc2.weight True
fc2.bias True
fc3.weight True
fc3.bias True


# 5 load the teacher model

In [None]:
# Load the full pickled teacher model (architecture + weights) from Drive.
# NOTE: torch.load unpickles arbitrary objects, so only load trusted files.
model = torch.load(BASE_PATH + "/teacher.pkl")
# The teacher is only used for inference (evaluation and soft targets), so put it
# in eval mode. Otherwise dropout/batch-norm layers keep their training behaviour.
model.eval()
model

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (fc1): Linear(in_features=9216, out_features=1024, bias=True)
    (relu1): ReLU()
    (fc2): Linear(

In [None]:
# Shape probe: run a dummy batch through teacher.features[0..10], i.e. up to and
# including the last Conv2d, before its ReLU and the final max-pool.
x = torch.zeros(10, 3, 224, 224)
print(x.shape)
model.features[:11](x.cuda()).shape

torch.Size([10, 3, 224, 224])


torch.Size([10, 256, 13, 13])

In [None]:
# Reference accuracy of the teacher. It is not distilled, so the label no longer
# says "w distillation".
_, test_accuracy = getLossAccuracyOnDataset(model, testloader, "cuda")
print('teacher test accuracy: ', test_accuracy)

teacher test accuracy (w distillation):  0.939572192513369


# 6 knowledge distillation

In [None]:
import numpy as np
import math

%matplotlib inline
import matplotlib.pyplot as plt

import pickle
import argparse
import time
import itertools
from copy import deepcopy
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
    
# IPython autoreload: reload all imported modules before executing each cell.
%load_ext autoreload
%autoreload 2
# Distillation schedule: number of training epochs, and how often (in batches)
# the training loop logs loss/accuracy.
num_epochs = 50
print_every = 200

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# Hyper-parameter grid for distillation; every combination becomes one run.
temperatures = [10]
# Trade-off between the soft-target (st) term and the true-target (tt) cross-entropy:
# loss = alpha * st + (1 - alpha) * tt, where st is the KL divergence between the
# temperature-T softmaxes, scaled by T**2.
alphas = [0.5]
learning_rates = [0.0001]
learning_rate_decays = [0.95]
weight_decays = [1e-5]
momentums = [0.9]
dropout_probabilities = [(0, 0)]  # (input dropout, hidden dropout)
hparams_list = [
    {
        'alpha': alpha_value,
        'T': temperature,
        'dropout_input': drop_in,
        'dropout_hidden': drop_hidden,
        'weight_decay': wd,
        'lr_decay': lr_decay,
        'momentum': mom,
        'lr': lr_value,
    }
    for alpha_value, temperature, (drop_in, drop_hidden), wd, lr_decay, mom, lr_value
    in itertools.product(alphas, temperatures, dropout_probabilities, weight_decays,
                         learning_rate_decays, momentums, learning_rates)
]
print(hparams_list)

[{'alpha': 0.5, 'T': 10, 'dropout_input': 0, 'dropout_hidden': 0, 'weight_decay': 1e-05, 'lr_decay': 0.95, 'momentum': 0.9, 'lr': 0.0001}]


In [None]:
device = torch.device("cuda")
# Derive the flag from the device instead of hard-coding it (True for "cuda").
use_gpu = device.type == "cuda"


def reproducibilitySeed(seed=0):
    """
    Ensure reproducibility of results by seeding PyTorch and NumPy.

    Args:
        seed: seed for both RNGs (default 0, as before).

    When ``use_gpu`` is set, cuDNN is also forced to deterministic algorithms
    and its benchmark autotuner is disabled, since the autotuner may choose
    different kernels from run to run.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    if use_gpu:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


reproducibilitySeed()

In [None]:
model.to(device)
# The teacher only provides soft targets, so evaluate it with dropout disabled.
model.eval()
results_distill = {}

# Distil into the pre-trained student `net` (conv filters frozen).
# `student_net = student_net` used to raise NameError here.
# NOTE: the same network object is reused for every hparam setting, so with more
# than one entry in hparams_list, later runs continue from the previous weights.
student_net = net
for hparam in hparams_list:
    print('Training with hparams' + hparamToString(hparam))
    reproducibilitySeed()
    hparam_tuple = hparamDictToTuple(hparam)
    results_distill[hparam_tuple] = trainStudentOnHparam(model, student_net, hparam, num_epochs,
                                                         trainloader, testloader,
                                                         print_every=print_every,
                                                         fast_device=device)

Training with hparamsT=10, alpha=0.5, dropout_hidden=0, dropout_input=0, lr=0.0001, lr_decay=0.95, momentum=0.9, weight_decay=1e-05




epoch: 0 validation accuracy: 0.780
[1,   200/  625] train loss: 2.211 train accuracy: 0.688
[1,   400/  625] train loss: 0.879 train accuracy: 0.812
[1,   600/  625] train loss: 1.973 train accuracy: 0.688
epoch: 1 validation accuracy: 0.788
[2,   200/  625] train loss: 1.973 train accuracy: 0.938
[2,   400/  625] train loss: 3.238 train accuracy: 0.750
[2,   600/  625] train loss: 2.112 train accuracy: 0.875
epoch: 2 validation accuracy: 0.787
[3,   200/  625] train loss: 1.870 train accuracy: 0.938
[3,   400/  625] train loss: 3.349 train accuracy: 1.000
[3,   600/  625] train loss: 2.072 train accuracy: 0.875
epoch: 3 validation accuracy: 0.789
[4,   200/  625] train loss: 4.293 train accuracy: 0.875
[4,   400/  625] train loss: 3.130 train accuracy: 0.750
[4,   600/  625] train loss: 1.857 train accuracy: 0.688
epoch: 4 validation accuracy: 0.787
[5,   200/  625] train loss: 3.130 train accuracy: 0.750
[5,   400/  625] train loss: 1.783 train accuracy: 0.750
[5,   600/  625] train

In [None]:
torch.save(student_net.state_dict(), f'{BASE_PATH}/net_params_new.pkl')  # distilled student weights

In [None]:
# Accuracy of the student after distillation (previously mislabelled as the teacher).
_, test_accuracy = getLossAccuracyOnDataset(student_net, testloader, device)
print('student test accuracy (w distillation): ', test_accuracy)

teacher test accuracy (w distillation):  0.7983957219251336
