# Student Teacher Networks

Training thin, deep networks with the student-teacher learning paradigm [1] has received a lot of attention because of its strong results. In this paradigm, a large neural network, known as the teacher, is an expert at a certain task. A much smaller student network learns to perform the same task using some form of guidance from the teacher.

The student can be smaller in terms of 1) depth or 2) number of parameters.

The teacher provides this guidance through hints of one form or another. In this notebook we look at one such setup, where the guidance comes from the outputs of the teacher network [2].

Here are the imports.

In [4]:
import numpy as np
import torch 
import torchvision
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable

### Hyperparameters

In [5]:
num_epochs = 5
batch_size = 100
learning_rate = 0.001

### Downloading MNIST data

In [6]:
train_dataset = dsets.MNIST(root='../../data/lab6',
                            train=True, 
                            transform=transforms.ToTensor(),
                            download=True)

test_dataset = dsets.MNIST(root='../../data/lab6',
                           train=False, 
                           transform=transforms.ToTensor())


Files already downloaded


### Dataloader

In [7]:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size, 
                                          shuffle=False)
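MNIST has 60,000 training and 10,000 test images, so with `batch_size = 100` every batch is full. The plain-Python arithmetic below is a small sketch (no PyTorch needed) that counts the mini-batches per epoch and the total number of optimizer steps for `num_epochs = 5`. It assumes the DataLoader default `drop_last=False`.

```python
import math

num_train, num_test = 60000, 10000  # MNIST split sizes
batch_size, num_epochs = 100, 5

# With drop_last=False, a DataLoader yields ceil(N / batch_size) batches
train_batches = math.ceil(num_train / batch_size)
test_batches = math.ceil(num_test / batch_size)
total_steps = num_epochs * train_batches  # one optimizer.step() per training batch

print(train_batches, test_batches, total_steps)  # 600 100 3000
```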

### Defining the Teacher Network

A bigger and deeper network than the student network defined later.

In [8]:
class Teacher(nn.Module):
    def __init__(self):
        super(Teacher, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU())
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.layer3 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.fc1 = nn.Linear(7*7*32, 300)
        self.fc2 = nn.Linear(300, 10)
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.view(out.size(0), -1)
        out = self.fc1(out)
        out = self.fc2(out)
        return out
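The input size of `fc1` (`7*7*32`) follows from how each layer changes the 28×28 MNIST image. The helper below is a hypothetical sketch that applies the standard output-size formula `(size + 2*padding - kernel) // stride + 1` used by `Conv2d` and `MaxPool2d` (dilation 1, no ceil mode); BatchNorm and ReLU do not change the spatial size.

```python
def conv_out(size, kernel, padding=0, stride=1):
    # Output size of Conv2d / MaxPool2d along one spatial dimension (dilation 1)
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size, kernel=2):
    # nn.MaxPool2d(k) uses stride = k and no padding by default
    return conv_out(size, kernel, padding=0, stride=kernel)

s = 28                                   # MNIST images are 28x28
s = conv_out(s, 5, padding=2)            # layer1 conv (5x5, pad 2): 28
s = pool_out(conv_out(s, 3, padding=1))  # layer2 conv + pool: 14
s = pool_out(conv_out(s, 3, padding=1))  # layer3 conv + pool: 7
teacher_fc_in = s * s * 32               # layer3 outputs 32 channels

print(s, teacher_fc_in)  # 7 1568
```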
    

### Defining the student network

A smaller and shallower network than the teacher.

In [9]:
class Student(nn.Module):
    def __init__(self):
        super(Student, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.fc1 = nn.Linear(14*14*16, 10)
        
    def forward(self, x):
        out = self.layer1(x)
        out = out.view(out.size(0), -1)
        out = self.fc1(out)
        return out
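To make "smaller in terms of number of parameters" concrete, the sketch below counts the learnable parameters of both architectures by hand. The helper names are hypothetical; the counts should match `sum(p.numel() for p in net.parameters())` in PyTorch, since `Conv2d` and `Linear` have a bias by default and `BatchNorm2d` (with `affine=True`) has one weight and one bias per channel. Its running mean and variance are buffers, not parameters.

```python
def conv_params(c_in, c_out, k):
    # k x k kernel weights for every (input, output) channel pair, plus one bias per output channel
    return c_in * c_out * k * k + c_out

def bn_params(c):
    # learnable scale and shift per channel (running stats are buffers)
    return 2 * c

def linear_params(n_in, n_out):
    return n_in * n_out + n_out

teacher_params = (
    conv_params(1, 16, 5) + bn_params(16)     # layer1
    + conv_params(16, 16, 3) + bn_params(16)  # layer2
    + conv_params(16, 32, 3) + bn_params(32)  # layer3
    + linear_params(7 * 7 * 32, 300)          # fc1
    + linear_params(300, 10)                  # fc2
)
student_params = (
    conv_params(1, 16, 3) + bn_params(16)     # layer1
    + linear_params(14 * 14 * 16, 10)         # fc1
)

print(teacher_params, student_params)  # 481214 31562
```

Most of the teacher's parameters sit in `fc1`, and the student has roughly 15 times fewer parameters overall.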
    

<b>The function below creates a freshly initialized network (teacher or student) and defines the loss criterion and the optimizer for it.</b>

In [10]:
def reset_model(is_teacher = True):
    if is_teacher == True:
        net = Teacher()
    else:
        net = Student()
    net = net.cuda()


    # Loss and Optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    return net,criterion,optimizer

### Training the teacher network

The first step is to train the teacher network to become an expert. We use a regular training procedure with the cross-entropy loss and the Adam optimizer.
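`nn.CrossEntropyLoss` applies a log-softmax to the raw network outputs (logits) and takes the negative log-probability of the true class, averaged over the batch by default. The NumPy sketch below (a hypothetical helper, not the PyTorch implementation) computes the same quantity by hand.

```python
import numpy as np

def cross_entropy(logits, labels):
    # Mean over the batch of -log softmax(logits)[label]
    logits = np.asarray(logits, dtype=float)
    # Subtracting the row max keeps exp() from overflowing without changing the result
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])
labels = np.array([0, 2])
print(cross_entropy(logits, labels))
```

With uniform logits over three classes the loss is log(3), the value of a network that is maximally unsure.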

In [11]:
teacher, criterion, optimizer = reset_model()

In [12]:
# Train the Model

def training(net, reset = True):
    # reset=True discards `net` and trains a freshly built Teacher (see reset_model)
    if reset:
        net, criterion, optimizer = reset_model()
    else:
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    
    net.train()
    for epoch in range(num_epochs):
        total_loss = 0
        accuracy = []
        for i, (images, labels) in enumerate(train_loader):
            images = images.cuda()
            labels = labels.cuda()

            # Forward + Backward + Optimize
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # total_loss sums the per-batch mean losses; .item() returns a Python float
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct = (predicted == labels).sum().item()
            accuracy.append(correct / float(labels.size(0)))

        print('Epoch: %d, Loss: %.4f, Accuracy: %.4f' % (epoch + 1, total_loss, sum(accuracy) / len(accuracy)))
    
    return net

### Testing the teacher network

In [13]:
# Test the Model
def testing(net):
    net.eval()  # BatchNorm uses its running statistics in eval mode
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for images, labels in test_loader:
            images = images.cuda()
            labels = labels.cuda()
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Test Accuracy of the network on the 10000 test images: %.2f %%' % (100.0 * correct / total))

In [18]:
reset = True
teacher = training(teacher, reset)
testing(teacher)

Epoch: 1, Loss: 77.1952, Accuracy: 0.9619
Epoch: 2, Loss: 30.7803, Accuracy: 0.9840
Epoch: 3, Loss: 22.5943, Accuracy: 0.9882
Epoch: 4, Loss: 20.9438, Accuracy: 0.9887
Epoch: 5, Loss: 16.7684, Accuracy: 0.9909
Test Accuracy of the network on the 10000 test images: 99.03 %


### Parameters for the Student Network

Here we define a few more parameters for training the student network. The student is trained on soft targets as well as hard targets. The soft targets are computed from the teacher's logits $z_i$ with a temperature $T$ (`temperature` in the code below):

$$
f(z_{i}) = \frac{\exp(z_{i}/T)}{\sum_{j}\exp(z_{j}/T)}
$$

With $T = 1$ this is the ordinary softmax.

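Choosing $T > 1$ softens the teacher's output distribution: probability mass moves from the predicted class toward the other classes, which reveals how similar the teacher considers the wrong classes to be. These soft targets serve as hints for the student network.
<img src='images/stud_teach.png' style="width: 350px;">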
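The effect of the temperature can be checked numerically. This NumPy sketch implements the equation above for a single, made-up logit vector (the function name and the logits are illustrative assumptions) and evaluates it at a few temperatures, including the notebook's `temperature = 1.5`.

```python
import numpy as np

def softmax_with_temperature(z, T=1.0):
    z = np.asarray(z, dtype=float) / T
    z = z - z.max()  # subtract the max for numerical stability; the result is unchanged
    e = np.exp(z)
    return e / e.sum()

logits = [8.0, 2.0, 1.0]  # hypothetical teacher logits for three classes
for T in (1.0, 1.5, 5.0):
    print(T, np.round(softmax_with_temperature(logits, T), 4))
```

For these logits the top class gets about 0.997 of the probability at T = 1, about 0.973 at T = 1.5 and about 0.646 at T = 5, while the ranking of the classes stays the same.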

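The teacher is already trained, so the loss should not be backpropagated through it; we therefore set `requires_grad = False` on all of its parameters.

For the soft targets we use the mean squared error loss, computed between the student's raw outputs (logits) and the teacher's softened probabilities. This is a design choice rather than a necessity: a cross-entropy against the soft-target distribution is also meaningful, and it is what Hinton et al. [2] use.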
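For comparison, the NumPy sketch below shows the soft-target loss described by Hinton et al. [2] instead of the MSE used here: the student's logits are softened with the same temperature, the soft term is the KL divergence from the teacher's soft distribution to the student's (equal to the soft cross-entropy up to a constant), and it is scaled by T² so its gradients stay comparable as T changes. `distillation_loss` is a hypothetical helper; `alpha` plays the same role as in the notebook.

```python
import numpy as np

def log_softmax(z, T=1.0):
    z = np.asarray(z, dtype=float) / T
    z = z - z.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

def distillation_loss(student_logits, teacher_logits, labels, T=1.5, alpha=0.6):
    """Hypothetical helper: alpha * hard CE + (1 - alpha) * T^2 * KL(teacher_T || student_T)."""
    labels = np.asarray(labels)
    n = labels.shape[0]
    # Hard loss: ordinary cross-entropy (T = 1) against the true labels
    hard = -log_softmax(student_logits)[np.arange(n), labels].mean()
    # Soft loss: KL divergence between the temperature-softened distributions
    log_q_teacher = log_softmax(teacher_logits, T)
    log_q_student = log_softmax(student_logits, T)
    kl = (np.exp(log_q_teacher) * (log_q_teacher - log_q_student)).sum(axis=1).mean()
    return alpha * hard + (1 - alpha) * (T ** 2) * kl

teacher_logits = np.array([[8.0, 2.0, 1.0]])
student_logits = np.array([[5.0, 3.0, 1.0]])
print(distillation_loss(student_logits, teacher_logits, np.array([0])))
```

In PyTorch the soft term would correspond to `nn.KLDivLoss(reduction='batchmean')` applied to the student's log-softmax at temperature T, with the teacher's softmax at temperature T as the target.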

In [15]:
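temperature = 1.5
# Freeze the teacher: its weights receive no gradient updates
for p in teacher.parameters():
    p.requires_grad = False
# Keep the teacher's BatchNorm layers in inference mode
teacher.eval()

student, criterion, optimizer = reset_model(is_teacher = False)
alpha = 0.6  # weight of the hard-label loss; (1 - alpha) weights the soft loss

mse_criterion = nn.MSELoss()
softmax = nn.Softmax(dim=1)  # normalize over the class dimension

print(student)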

Student (
  (layer1): Sequential (
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True)
    (2): ReLU ()
    (3): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
  )
  (fc1): Linear (3136 -> 10)
)


### Training and testing the student network

In [16]:
#Train the Model

for epoch in range(num_epochs):
    total_loss = 0
    accuracy = []
    for i, (images, labels) in enumerate(train_loader):
        images = images.cuda()
        labels = labels.cuda()
        
        # Forward + Backward + Optimize
        optimizer.zero_grad()
        
        student_outputs = student(images)
        
        # Teacher logits, softened by the temperature into soft targets
        hard_outputs = teacher(images)
        soft_outputs = softmax(hard_outputs / temperature)
        
        # Hard loss: cross-entropy against the true labels.
        # Soft loss: MSE between the student's raw logits and the teacher's soft targets.
        hard_loss = criterion(student_outputs, labels)
        soft_loss = mse_criterion(student_outputs, soft_outputs)
        loss = alpha*hard_loss + (1-alpha)*soft_loss
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        _, predicted = torch.max(student_outputs, 1)
        correct = (predicted == labels).sum().item()
        accuracy.append(correct / float(labels.size(0)))
    
    print('Epoch: %d, Loss: %.4f, Accuracy: %.4f' % (epoch + 1, total_loss, sum(accuracy) / len(accuracy)))

Epoch: 1, Loss: 354.3688, Accuracy: 0.9353
Epoch: 2, Loss: 313.1935, Accuracy: 0.9691
Epoch: 3, Loss: 305.5597, Accuracy: 0.9732
Epoch: 4, Loss: 300.7783, Accuracy: 0.9759
Epoch: 5, Loss: 297.6083, Accuracy: 0.9768
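The printed `Loss` is a sum of per-batch mean losses over the 600 batches of an epoch, so dividing by 600 gives the average loss per batch. The short sketch below does this for the first and last epochs, using the values copied from the teacher and student logs in this notebook. The two are not directly comparable: the student's loss also contains the weighted MSE soft term.

```python
batches_per_epoch = 600  # 60000 training images / batch_size 100

# Summed per-epoch losses copied from the logs (epochs 1 and 5)
teacher_losses = [77.1952, 16.7684]
student_losses = [354.3688, 297.6083]

teacher_mean = [round(l / batches_per_epoch, 4) for l in teacher_losses]
student_mean = [round(l / batches_per_epoch, 4) for l in student_losses]
print(teacher_mean, student_mean)
```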


In [17]:
testing(student)

Test Accuracy of the network on the 10000 test images: 97.69 %


### Exercise

Try out the small student network on the CIFAR dataset (for example CIFAR-10, which `torchvision.datasets.CIFAR10` loads the same way as MNIST).
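If you attempt the exercise with CIFAR-10, note that its images are 3×32×32 (RGB) rather than 1×28×28, so the first convolution needs 3 input channels and the `fc1` input sizes change. This sketch reuses the output-size formula, assuming the layer configurations above are kept unchanged apart from the input channels.

```python
def conv_out(size, kernel, padding=0, stride=1):
    # Output size of Conv2d / MaxPool2d along one spatial dimension (dilation 1)
    return (size + 2 * padding - kernel) // stride + 1

cifar_size, cifar_channels = 32, 3  # CIFAR-10 images are 3x32x32

# Student: conv 3x3 (pad 1) keeps 32, MaxPool2d(2) halves it to 16
s = conv_out(conv_out(cifar_size, 3, padding=1), 2, stride=2)
student_fc_in = s * s * 16

# Teacher: conv 5x5 (pad 2) keeps 32, then two conv + pool stages give 16 and 8
t = conv_out(cifar_size, 5, padding=2)
t = conv_out(conv_out(t, 3, padding=1), 2, stride=2)
t = conv_out(conv_out(t, 3, padding=1), 2, stride=2)
teacher_fc_in = t * t * 32

print(cifar_channels, student_fc_in, teacher_fc_in)  # 3 4096 2048
```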

### References

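1. Romero et al., "FitNets: Hints for Thin Deep Nets", https://arxiv.org/abs/1412.6550
2. Hinton, Vinyals and Dean, "Distilling the Knowledge in a Neural Network", https://www.cs.toronto.edu/~hinton/absps/distillation.pdf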