#Knowledge distillation notebook


In [None]:
# Fetch the dataset archive (wget's -c flag resumes a partial download) and unpack it.
# NOTE(review): 'path_dataset' looks like a placeholder for the real download URL - confirm.
! wget -c 'path_dataset'
! unzip dataset.zip


In [None]:
import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm_notebook
import numpy as np
import time

#start_time = time.time() #timing of ex.
#print("--- %s seconds ---" % (time.time() - start_time))

# Preprocessing shared by the train and test sets: resize every image to
# 150x150, convert it to a tensor and normalise each channel with the
# mean/std values listed below.
transform = transforms.Compose([
    transforms.Resize((150, 150)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


train_data = datasets.ImageFolder('seg_train/seg_train', transform=transform)
test_data = datasets.ImageFolder('seg_test/seg_test', transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True)
# The test loader is shuffled too, so taking the first batch of a fresh
# iterator yields a random 512-image sample of the test set.
test_loader = torch.utils.data.DataLoader(test_data, batch_size=512, shuffle=True)

# Peek at one test batch to check the tensor shape.
# The builtin next() is used because DataLoader iterators do not expose a
# `.next()` method in recent PyTorch releases.
dataiter = iter(test_loader)
images, labels = next(dataiter)
images.shape


In [None]:
import matplotlib.pyplot as plt

def show_imgs(imgs, labels, start=100, count=10):
    """Plot `count` consecutive images of a batch side by side.

    Args:
        imgs: batch tensor; each ``imgs[k]`` is transposed with axes
            ``(1, 2, 0)`` before plotting, so it is expected to be
            channel-first (presumably ``(C, H, W)`` - confirm with caller).
        labels: tensor aligned with ``imgs``; ``labels[k]`` becomes the title.
        start: index of the first image to show (defaults to 100).
        count: number of images to show (defaults to 10).

    Raises:
        IndexError: if the batch holds fewer than ``start + count`` items.
    """
    # squeeze=False keeps `axes` two-dimensional even when count == 1.
    _, axes = plt.subplots(1, count, figsize=(30, 30), squeeze=False)
    for offset, axis in enumerate(axes[0]):
        idx = start + offset
        picture = np.transpose(imgs[idx].numpy(), (1, 2, 0))
        # The colour map only takes effect for single-channel images.
        axis.imshow(np.squeeze(picture), cmap='gray')
        axis.set_title(labels[idx].numpy())
    plt.show()

#show_imgs(images, labels)

#Knowledge from logits
####The student (Net) is a simple network with one Conv2d layer and two Linear layers
####The teacher (Sofistic_Teacher) is a pretrained AlexNet backbone followed by a new Linear layer



In [None]:
class Teacher(nn.Module):
    """Small CNN with three convolutions, batch norm and two linear layers.

    ``forward`` returns raw, unnormalised class scores (logits) for 6
    classes. No softmax is applied here because the losses used in this
    notebook (``CrossEntropyLoss`` and the distillation loss) apply the
    (log-)softmax themselves; normalising twice would flatten the
    distribution and weaken the gradients.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 5, kernel_size=3)
        self.maxpool1 = nn.MaxPool2d((2, 2))
        self.bn1 = nn.BatchNorm2d(5)
        self.conv2 = nn.Conv2d(5, 4, kernel_size=3)
        self.maxpool2 = nn.MaxPool2d((2, 2))
        self.bn2 = nn.BatchNorm2d(4)
        self.conv3 = nn.Conv2d(4, 3, kernel_size=3)

        # 3468 = 3 channels * 34 * 34, the feature-map size reached by a
        # 150x150 input after the conv/pool stack above.
        self.fc1 = nn.Linear(3468, 256)
        self.bn3 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 6)

    def forward(self, x):
        """Return logits of shape (batch, 6) for a (batch, 3, 150, 150) input."""
        x = F.relu(self.conv1(x))
        x = self.maxpool1(x)
        x = self.bn1(x)
        x = F.relu(self.conv2(x))
        x = self.maxpool2(x)
        x = self.bn2(x)
        x = F.relu(self.conv3(x))
        # Flatten everything except the batch dimension.
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = self.bn3(x)
        return self.fc2(x)

from torchvision import models

# Pretrained AlexNet used as a feature extractor: the last layer of its
# classifier is dropped, leaving the remaining classifier layers in place.
AlexNet = models.alexnet(pretrained=True)
AlexNet.classifier = AlexNet.classifier[:-1]

class Sofistic_Teacher(nn.Module):
    """Teacher model: the module-level ``AlexNet`` backbone plus a 6-way linear head.

    ``forward`` returns raw logits. Softmax is intentionally not applied
    here because ``CrossEntropyLoss`` and the distillation loss apply the
    (log-)softmax themselves.
    """

    def __init__(self):
        super().__init__()
        # This stores a reference to the module-level `AlexNet` object, so
        # every instance shares the same backbone weights.
        self.net = AlexNet
        self.fc = nn.Linear(4096, 6)

    def forward(self, x):
        """Return logits of shape (batch, 6)."""
        features = self.net(x)
        return self.fc(features)

class Teacher_Assistant(nn.Module):
  """Placeholder network: no layers are defined yet."""

  def __init__(self):
    # nn.Module must be initialised before the instance can be used;
    # parameters(), state_dict(), to() and friends rely on it.
    super().__init__()

class Net(nn.Module):
    """Student network: one 3x3 convolution, 2x2 max-pooling, two linear layers.

    Produces 6 raw class scores per input image.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=1, kernel_size=3)
        self.maxpool = nn.MaxPool2d(kernel_size=(2, 2))
        # 5476 = 74 * 74: one channel of a 150x150 input after conv + pool.
        self.fc1 = nn.Linear(in_features=5476, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=6)

    def forward(self, x):
        """Map a (batch, 3, 150, 150) tensor to (batch, 6) scores."""
        pooled = self.maxpool(F.relu(self.conv(x)))
        batch_size = pooled.shape[0]
        flat = pooled.view(batch_size, -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)


#KD-LOSS

In [None]:
def loss_fn_kd(outputs, labels, teacher_outputs, alpha, T):
    """Knowledge-distillation loss: soft-target KL term plus hard-label CE.

    Args:
        outputs: student logits, shape (batch, classes).
        labels: ground-truth class indices, shape (batch,).
        teacher_outputs: teacher logits, same shape as ``outputs``.
        alpha: weight of the distillation term; ``1 - alpha`` weights the
            cross-entropy on the hard labels.
        T: softmax temperature applied to both student and teacher logits.

    Returns:
        Scalar tensor ``alpha * T^2 * KL(teacher || student) + (1 - alpha) * CE``.
    """
    soft_student = F.log_softmax(outputs / T, dim=1)
    # The teacher is a fixed target: detach so no gradient flows back into it.
    soft_teacher = F.softmax(teacher_outputs.detach() / T, dim=1)
    # 'batchmean' sums over classes and averages over the batch, which is
    # the actual KL divergence; the default 'mean' would additionally
    # divide by the number of classes.
    distillation = F.kl_div(soft_student, soft_teacher, reduction='batchmean')
    hard = F.cross_entropy(outputs, labels)

    # Keep the tensors on the left of each product so that NumPy scalar
    # weights (as passed by the grid search) are handled by torch.
    KD_loss = distillation * (alpha * T * T) + hard * (1. - alpha)

    return KD_loss


def accuracy(outputs, labels):
    """Return the fraction of rows of `outputs` whose arg-max equals the label.

    `outputs` is a 2-D array of per-class scores and `labels` a NumPy array
    of class indices with one entry per row.
    """
    predicted = np.argmax(outputs, axis=1)
    n_correct = np.sum(np.equal(predicted, labels))
    n_total = float(labels.size)
    return n_correct / n_total

#Train

In [None]:
def train(net, n_epoch=5, teachernet=None, alpha=None, T=None):
    """Train `net` on `train_loader`, optionally distilling from `teachernet`.

    Args:
        net: model to optimise (Adam, lr=1e-2).
        n_epoch: number of passes over `train_loader`.
        teachernet: optional teacher model. When given, the loss is
            `loss_fn_kd(student, labels, teacher, alpha, T)`; otherwise
            plain cross-entropy on the labels is used.
        alpha: distillation weight forwarded to `loss_fn_kd`.
        T: temperature forwarded to `loss_fn_kd`.

    Returns:
        Tuple `(net, acc_train, acc_test)` where the two lists hold the
        accuracy on the current training batch and on one test batch,
        sampled every 25 training batches.

    Note:
        Predictions are converted with `.numpy()`, so tensors are expected
        to live on the CPU.
    """
    acc_train = []
    acc_test = []

    # Compare against None explicitly: truthiness of a module can depend on
    # its length (e.g. an empty nn.Sequential is falsy).
    use_teacher = teachernet is not None

    net.train()
    if use_teacher:
        teachernet.eval()
    loss_hard = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=60, gamma=0.6)
    for _ in tqdm_notebook(range(n_epoch)):

        running_loss = 0.0
        batches_since_log = 0
        for i, (X_batch, y_batch) in enumerate(tqdm_notebook(train_loader)):
            optimizer.zero_grad()
            y_pred = net(X_batch)
            if use_teacher:
                # The teacher only provides targets; skip building its graph.
                with torch.no_grad():
                    teacher_outputs = teachernet(X_batch)
                loss = loss_fn_kd(y_pred, y_batch, teacher_outputs, alpha, T)
            else:
                loss = loss_hard(y_pred, y_batch)
            loss.backward()
            optimizer.step()
            # The scheduler is stepped per batch, so the learning rate is
            # multiplied by 0.6 every 60 batches (not every 60 epochs).
            scheduler.step()
            running_loss += loss.item()
            batches_since_log += 1

            if i % 25 == 0:
                # A fresh iterator over the shuffled test loader gives a
                # random test batch.
                images, labels = next(iter(test_loader))
                # Evaluate in inference mode without autograd, then switch
                # back to training mode.
                net.eval()
                with torch.no_grad():
                    test_pred = net(images)
                net.train()

                train_acc = accuracy(y_pred.detach().numpy(), y_batch.numpy())
                acc_test.append(accuracy(test_pred.numpy(), labels.numpy()))
                acc_train.append(train_acc)

                # Average over the batches actually accumulated since the
                # previous log line (only one at i == 0).
                print('loss: %.3f,' % (running_loss / batches_since_log), 'acc: %.3f,' % train_acc)
                running_loss = 0.0
                batches_since_log = 0
    print('Обучение закончено')
    return net, acc_train, acc_test

In [None]:
start_time = time.time()  # wall-clock timing of the baseline run

# Baseline: train the student on the hard labels only (no teacher).
net = Net()
net, acc_train_net, acc_test_net = train(net, 3)

print("--- %s seconds ---" % (time.time() - start_time))


# Accuracy on one test batch. The builtin next() is used because
# DataLoader iterators do not expose `.next()` in recent PyTorch releases.
test_dataiter = iter(test_loader)
images, labels = next(test_dataiter)
print(accuracy(net.forward(images).detach().numpy(), labels.numpy()))

# Accuracy curves recorded during training.
plt.plot(range(len(acc_train_net)), acc_train_net, label='acc_train_net')
plt.plot(range(len(acc_test_net)), acc_test_net, label='acc_test_net')
plt.legend()
plt.show()


#GRID SEARCH

In [None]:
# Grid search over the distillation weight (alpha) and temperature:
# nets[i, j] is the test-batch accuracy of a fresh student trained with
# alpha[i] and temperature[j].
# NOTE(review): `teacher` is not defined anywhere in the visible notebook;
# presumably it is a previously trained teacher model - confirm before running.
alpha = np.array([0, 0.2, 0.4, 0.6, 0.9, 1])
temperature = np.array([1, 5, 10, 50, 200])
nets = np.zeros([len(alpha), len(temperature)])
for i, alpha_value in enumerate(alpha):
    for j, temperature_value in enumerate(temperature):
        print(i, j)
        net3 = Net()
        net3, acc_train_net3, acc_test_net3 = train(net3, 3, teacher, alpha_value, temperature_value)

        # The builtin next() is used because DataLoader iterators do not
        # expose `.next()` in recent PyTorch releases.
        test_dataiter = iter(test_loader)
        images, labels = next(test_dataiter)
        nets[i, j] = accuracy(net3.forward(images).detach().numpy(), labels.numpy())
nets

In [None]:
sof_teacher = Sofistic_Teacher()
# train() returns (net, acc_train, acc_test); unpack it so `sof_teacher`
# stays a model rather than becoming a tuple.
sof_teacher, acc_train_teacher, acc_test_teacher = train(sof_teacher, 5)

# Evaluate in inference mode so that training-only layers (e.g. dropout)
# are disabled, and without building an autograd graph.
sof_teacher.eval()
test_dataiter = iter(test_loader)
images, labels = next(test_dataiter)
with torch.no_grad():
    print(accuracy(sof_teacher.forward(images).numpy(), labels.numpy()))

In [None]:
# Set the default figure size before the figure is created; changing
# rcParams afterwards does not resize an existing figure.
plt.rcParams["figure.figsize"] = (5, 5)

# Transpose so that each row holds the accuracies for one temperature
# across all alpha values.
nets2 = nets.transpose()
print(nets2.shape)
colours = ['r', 'g', 'b', 'k', 'y']
plt.figure()
# One curve per temperature: accuracy as a function of alpha.
for row, colour, temperature_value in zip(nets2, colours, temperature):
    plt.plot(alpha, row, colour, label='{}'.format(temperature_value))
plt.legend()
plt.show()
nets2

In [None]:
'''#In the picture we can see that max value of accuracy are 0.4-0.7        
net = Net()
net = train(net,3)

test_dataiter = iter(test_loader)
images, labels = test_dataiter.next()
print(accuracy(net.forward(images).detach().numpy(),labels.numpy()))
'''
# Train a student with distillation (alpha=0.5, T=2).
# NOTE(review): `teacher` is not defined anywhere in the visible notebook;
# presumably it is a previously trained teacher model - confirm before running.
net = Net()
# train() returns (net, acc_train, acc_test); unpack it so `net` stays a model.
net, acc_train_kd, acc_test_kd = train(net, 3, teacher, 0.5, 2)

# The builtin next() is used because DataLoader iterators do not expose
# `.next()` in recent PyTorch releases.
test_dataiter = iter(test_loader)
images, labels = next(test_dataiter)
print(accuracy(net.forward(images).detach().numpy(), labels.numpy()))

In [None]:
# In the picture we can see that the best accuracy is reached for alpha in 0.4-0.7.
# Baseline for comparison: a student trained on the hard labels only.
net = Net()
# train() returns (net, acc_train, acc_test); unpack it so `net` stays a model.
net, acc_train_base, acc_test_base = train(net, 3)

# The builtin next() is used because DataLoader iterators do not expose
# `.next()` in recent PyTorch releases.
test_dataiter = iter(test_loader)
images, labels = next(test_dataiter)
print(accuracy(net.forward(images).detach().numpy(), labels.numpy()))

'''import torch.nn as nn

modules = []
modules.append(nn.Linear(500, 256))
modules.append(nn.Linear(10, 10))

sequential = nn.Sequential(*modules)'''

In [None]:
# Distil the previously trained student `net` into a fresh student `net2`.
net2 = Net()
# train() returns (net, acc_train, acc_test). The curves are bound to
# `acc_train` / `acc_test` because the plotting cell that follows reads
# those names.
net2, acc_train, acc_test = train(net2, 3, net, 0.5, 2)

# The builtin next() is used because DataLoader iterators do not expose
# `.next()` in recent PyTorch releases.
test_dataiter = iter(test_loader)
images, labels = next(test_dataiter)
# Report the freshly trained model, not its teacher `net`.
print(accuracy(net2.forward(images).detach().numpy(), labels.numpy()))

In [None]:
# Draw the recorded train and test accuracy curves on one set of axes.
for curve_label, curve_values in (('acc_train', acc_train), ('acc_test', acc_test)):
    steps = range(len(curve_values))
    plt.plot(steps, curve_values, label=curve_label)
plt.legend()
plt.show()


In [None]:
# Train the larger `Teacher` architecture with the student `net` acting
# as the distillation teacher (alpha=0.5, T=2).
net3 = Teacher()
# train() returns (net, acc_train, acc_test); unpack it so `net3` stays a model.
net3, acc_train_net3, acc_test_net3 = train(net3, 3, net, 0.5, 2)

# Evaluate in inference mode so the BatchNorm layers use their running
# statistics, and without building an autograd graph.
net3.eval()
test_dataiter = iter(test_loader)
images, labels = next(test_dataiter)
with torch.no_grad():
    # Report the freshly trained model, not its teacher `net`.
    print(accuracy(net3.forward(images).numpy(), labels.numpy()))