In [1]:
import time
import copy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision import datasets, transforms
from torchsummary import summary
from torch.optim import lr_scheduler
import torch.nn.functional as F
import torch.nn as nn
import torchvision.models as models
from torch import nn, optim

In [4]:
# # Dowload the dataset
# from torchvision.datasets.utils import download_url
# dataset_url = "https://s3.amazonaws.com/fast-ai-imageclas/cifar10.tgz"a
# download_url(dataset_url, '.')

Downloading https://s3.amazonaws.com/fast-ai-imageclas/cifar10.tgz to ./cifar10.tgz


  0%|          | 0/135107811 [00:00<?, ?it/s]

In [7]:
# # Extract from archive
# import tarfile
# with tarfile.open('./cifar10.tgz', 'r:gz') as tar:
#     tar.extractall(path='./data')

In [15]:
transform = transforms.Compose([transforms.Resize((224,224)),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485,0.456,  
                                0.406], [0.229, 0.224, 0.225])])
trainset = datasets.CIFAR10(root='/content/data/cifar10/train', download=True, train=True, transform=transform)
valset = datasets.CIFAR10(root='/content/data/cifar10/test', download=True, train=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
valloader = torch.utils.data.DataLoader(valset, batch_size=64, shuffle=True)
len_trainset = len(trainset)
len_valset = len(valset)
classes = ("plane", "car", "bird", "cat","deer", "dog", "frog", "horse", "ship", "truck")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /content/data/cifar10/train/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting /content/data/cifar10/train/cifar-10-python.tar.gz to /content/data/cifar10/train
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /content/data/cifar10/test/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting /content/data/cifar10/test/cifar-10-python.tar.gz to /content/data/cifar10/test


In [16]:
dataiter = iter(trainloader)
images, labels = dataiter.next()
print(images.shape)
print(labels.shape)

torch.Size([64, 3, 224, 224])
torch.Size([64])


In [17]:
resnet = models.resnet50(pretrained=True)
for param in resnet.parameters():
   param.requires_grad = False
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(num_ftrs, 10)
resnet = resnet.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(resnet.fc.parameters())

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

In [30]:
def train_and_evaluate(model, trainloader, valloader, criterion, optimizer, len_trainset, len_valset, num_epochs=25):

   model.train()
   best_model_wts = copy.deepcopy(model.state_dict())
   best_acc = 0.0
   for epoch in range(num_epochs):
      model.train()
      print("Epoch {}/{}".format(epoch, num_epochs-1))
      print('-' * 10)
      running_loss = 0.0
      running_corrects = 0
      for inputs, labels in trainloader:
         inputs = inputs.to(device)
         labels = labels.to(device)
         optimizer.zero_grad()
         outputs = model(inputs)
         loss = criterion(outputs, labels)
         _, preds = torch.max(outputs, 1)
         loss.backward() 
         optimizer.step()  
         running_loss += loss.item() * inputs.size(0)
         running_corrects += torch.sum(preds == labels.data)
      epoch_loss = running_loss / len_trainset
      epoch_acc = running_corrects.double() / len_trainset
      print("Train Loss: {:.4f} Acc: {:.4f}".format(epoch_loss,epoch_acc)) 
         
      model.eval()
      running_loss_val = 0.0 
      running_corrects_val = 0
      for inputs, labels in valloader:
         inputs = inputs.to(device)
         labels = labels.to(device)
         outputs = model(inputs) 
         loss = criterion(outputs,labels)
         _, preds = torch.max(outputs, 1)
         running_loss_val += loss.item() * inputs.size(0)
         running_corrects_val += torch.sum(preds == labels.data)
      
      epoch_loss_val = running_loss_val / len_valset
      epoch_acc_val = running_corrects_val.double() / len_valset
      
      if epoch_acc_val > best_acc:
         best_acc = epoch_acc_val
         best_model_wts = copy.deepcopy(model.state_dict())
      
      print('Val Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss_val,
             epoch_acc_val))
      
      print()
      print('Best val Acc: {:4f}'.format(best_acc))
      model.load_state_dict(best_model_wts)
   return model

In [32]:
resnet_teacher = train_and_evaluate(resnet,trainloader,
                                   valloader,criterion,optimizer,
                                   len_trainset,len_valset,10)

Epoch 0/9
----------
Train Loss: 0.7527 Acc: 0.7525
Val Loss: 0.5969 Acc: 0.7962

Best val Acc: 0.796200
Epoch 1/9
----------
Train Loss: 0.5909 Acc: 0.7934
Val Loss: 0.5518 Acc: 0.8100

Best val Acc: 0.810000
Epoch 2/9
----------
Train Loss: 0.5650 Acc: 0.8041
Val Loss: 0.5458 Acc: 0.8116

Best val Acc: 0.811600
Epoch 3/9
----------
Train Loss: 0.5485 Acc: 0.8100
Val Loss: 0.5283 Acc: 0.8183

Best val Acc: 0.818300
Epoch 4/9
----------
Train Loss: 0.5302 Acc: 0.8153
Val Loss: 0.5244 Acc: 0.8189

Best val Acc: 0.818900
Epoch 5/9
----------
Train Loss: 0.5178 Acc: 0.8203
Val Loss: 0.4985 Acc: 0.8305

Best val Acc: 0.830500
Epoch 6/9
----------
Train Loss: 0.5089 Acc: 0.8215
Val Loss: 0.5340 Acc: 0.8154

Best val Acc: 0.830500
Epoch 7/9
----------
Train Loss: 0.5144 Acc: 0.8199
Val Loss: 0.5681 Acc: 0.8078

Best val Acc: 0.830500
Epoch 8/9
----------
Train Loss: 0.5140 Acc: 0.8205
Val Loss: 0.5198 Acc: 0.8216

Best val Acc: 0.830500
Epoch 9/9
----------
Train Loss: 0.5091 Acc: 0.8229
Val

In [34]:
torch.save(resnet_teacher.state_dict(),'/content/resnet_teacher_weight.pth')

In [36]:
class Net(nn.Module):
  """
    This will be your student network that will learn from the 
    teacher network in our case resnet50.
    """
  def __init__(self):
    super(Net, self).__init__()
    self.layer1 = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size = (3,3), stride = (1,1), 
        padding = (1,1)),
        nn.ReLU(inplace=True),
        nn.Conv2d(64, 64, kernel_size = (3,3), stride = (1,1), 
        padding = (1,1)),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=2, stride=2, padding=0, 
        dilation=1, ceil_mode=False)
    )
    self.layer2 = nn.Sequential(
        nn.Conv2d(64, 128, kernel_size = (3,3), stride = (1,1), 
        padding = (1,1)),
        nn.ReLU(inplace=True),
        nn.Conv2d(128, 128, kernel_size = (3,3), stride = (1,1), 
        padding = (1,1)),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=2, stride=2, padding=0, 
        dilation=1, ceil_mode=False)
    )
    self.pool1 = nn.AdaptiveAvgPool2d(output_size=(1,1))
    self.fc1 = nn.Linear(128, 32)
    self.fc2 = nn.Linear(32, 10)
    self.dropout_rate = 0.5

  def forward(self, x):
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.pool1(x)
    x = x.view(x.size(0), -1)
    x = self.fc1(x)
    x = self.fc2(x)
    return x
net = Net().to(device)

In [37]:
net

Net(
  (layer1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (pool1): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc1): Linear(in_features=128, out_features=32, bias=True)
  (fc2): Linear(in_features=32, out_features=10, bias=True)
)

In [38]:
dataiter = iter(trainloader)
images, labels = dataiter.next()
out = net(images.cuda())
print(out.shape)

torch.Size([64, 10])


In [40]:
def loss_kd(outputs, labels, teacher_outputs, temparature, alpha):
  KD_loss = nn.KLDivLoss()(F.log_softmax(outputs/temparature, \
            dim=1),F.softmax(teacher_outputs/temparature,dim=1))*\
            (alpha * temparature * temparature)+\
            F.cross_entropy(outputs, labels) *(1.- alpha)
  return KD_loss
def get_outputs(model, dataloader):
   '''
   Used to get the output of the teacher network
   '''
   outputs = []
   for inputs, labels in dataloader:
      inputs_batch, labels_batch = inputs.cuda(), labels.cuda()
      output_batch = model(inputs_batch).data.cpu().numpy()
      outputs.append(output_batch)
   return outputs

In [42]:
def train_kd(model,teacher_out, optimizer, loss_kd, dataloader, temparature, alpha):
   model.train()
   running_loss = 0.0
   running_corrects = 0
   for i,(images, labels) in enumerate(dataloader):
      inputs = images.to(device)
      labels = labels.to(device)
      optimizer.zero_grad()
      outputs = model(inputs)
      outputs_teacher = torch.from_numpy(teacher_out[i]).to(device)
      loss = loss_kd(outputs,labels,outputs_teacher,temparature, 
                     alpha)
      _, preds = torch.max(outputs, 1)
      loss.backward()
      optimizer.step()
      running_loss += loss.item() * inputs.size(0)
      running_corrects += torch.sum(preds == labels.data)
   
   epoch_loss = running_loss / len(trainset)
   epoch_acc = running_corrects.double() / len(trainset)
   print("Train Loss: {:.4f} Acc: {:.4f}".format(epoch_loss,epoch_acc))
def eval_kd(model,teacher_out, optimizer, loss_kd, dataloader, temparature, alpha):
   model.eval()
   running_loss = 0.0
   running_corrects = 0
   for i,(images, labels) in enumerate(dataloader):
      inputs = images.to(device)
      labels = labels.to(device)
      outputs = model(inputs)
      outputs_teacher = torch.from_numpy(teacher_out[i]).cuda()
      loss = loss_kd(outputs,labels,outputs_teacher,temparature, 
                     alpha)
      _, preds = torch.max(outputs, 1)
      running_loss += loss.item() * inputs.size(0)
      running_corrects += torch.sum(preds == labels.data)
   epoch_loss = running_loss / len(valset)
   epoch_acc = running_corrects.double() / len(valset)
   print("Val Loss: {:.4f} Acc: {:.4f}".format(epoch_loss,
          epoch_acc))
   return epoch_acc
def train_and_evaluate_kd(model, teacher_model, optimizer, loss_kd, trainloader, valloader, temparature, alpha, num_epochs=25):
   teacher_model.eval()
   best_model_wts = copy.deepcopy(model.state_dict())
   outputs_teacher_train = get_outputs(teacher_model, trainloader)
   outputs_teacher_val = get_outputs(teacher_model, valloader)
   print("Teacher’s outputs are computed now starting the training process-")
   best_acc = 0.0
   for epoch in range(num_epochs):
      print("Epoch {}/{}".format(epoch, num_epochs-1))
      print("-" * 10)
      
      # Training the student with the soft labes as the outputs 
      #from the teacher and using the loss_kd function
      
      train_kd(model, outputs_teacher_train, 
               optim.Adam(net.parameters()),loss_kd,trainloader, 
               temparature, alpha)
     
      # Evaluating the student network
      epoch_acc_val = eval_kd(model, outputs_teacher_val, 
                          optim.Adam(net.parameters()), loss_kd, 
                          valloader, temparature, alpha)
      if epoch_acc_val > best_acc:
         best_acc = epoch_acc_val
         best_model_wts = copy.deepcopy(model.state_dict())
         print("Best val Acc: {:4f}".format(best_acc))
         model.load_state_dict(best_model_wts)
   return model

In [None]:
stud=train_and_evaluate_kd(net,resnet_teacher,optim.Adam(net.parameters()),loss_kd,trainloader,valloader,1,0.5,20)

In [None]:
torch.save(stud.state_dict(),'/content/student_weight.pth')