# Knowledge Distillation
We will implement the [TCN](https://www.bmvc2021-virtualconference.com/conference/papers/paper_0831.html) paper. It is a variant of knowledge distillation which uses dense feature vectors instead of logits to transfer knowledge from teacher to student.

In [None]:
# Mount Google Drive; the checkpoint paths used later in this notebook live under
# /content/gdrive/MyDrive.
from google.colab import drive

GDRIVE_MOUNT_POINT = '/content/gdrive'
drive.mount(GDRIVE_MOUNT_POINT)

Mounted at /content/gdrive


# Training base teacher network
This section is not graded

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision import models
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch.optim as optim
import os
from torch.autograd import Variable
import tqdm

# Mini-batch size shared by both loaders. Later cells slice cached feature tensors
# with `batch_idx * batch_size`, so the loaders must use this very same value.
batch_size = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Training pipeline: random crop (with 4px padding) and horizontal flip, then
# per-channel normalisation (presumably the CIFAR-10 mean/std -- TODO confirm).
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Evaluation pipeline: no augmentation, same normalisation as training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
# shuffle=False keeps the sample order fixed across passes; the feature-distillation
# cells index cached targets by batch position and rely on that fixed order.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=False, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2)

# Human-readable names of the 10 class indices.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

cuda
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
# Layer specifications per architecture: an int is the output-channel count of a
# 3x3 conv block, 'M' is a 2x2 max-pool.
cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


class VGG(nn.Module):
    """VGG-style network: a conv feature extractor followed by one linear layer (512 -> 10)."""

    def __init__(self, vgg_name):
        """Build the network described by `cfg[vgg_name]`."""
        super().__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(512, 10)

    def forward(self, x):
        """Run the feature extractor, flatten per sample, and apply the classifier."""
        feature_map = self.features(x)
        flat = feature_map.view(feature_map.size(0), -1)
        return self.classifier(flat)

    def _make_layers(self, cfg):
        """Translate a layer specification list into an nn.Sequential.

        Each int becomes Conv2d(3x3, padding 1) + BatchNorm2d + ReLU; each 'M'
        becomes MaxPool2d(2, 2). A 1x1 AvgPool2d is appended at the end.
        """
        modules = []
        prev_channels = 3  # the first conv consumes 3 input channels
        for spec in cfg:
            if spec == 'M':
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            modules.extend([
                nn.Conv2d(prev_channels, spec, kernel_size=3, padding=1),
                nn.BatchNorm2d(spec),
                nn.ReLU(inplace=True),
            ])
            prev_channels = spec
        modules.append(nn.AvgPool2d(kernel_size=1, stride=1))
        return nn.Sequential(*modules)

# Instantiate the VGG16 teacher directly on the selected device.
teacher = VGG('VGG16').to(device)

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(teacher.parameters(), lr=0.0001)

def train(epoch):
    """Train `teacher` for one pass over `trainloader` with cross-entropy loss.

    Prints the running accuracy (%) and mean loss every 200 batches.

    Args:
        epoch: zero-based epoch index; only used for the progress header.
    """
    print('\nEpoch: %d' % (epoch+1))
    teacher.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        # Plain tensors are enough; the legacy autograd.Variable wrapper is not needed.
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = teacher(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx % 200 == 0:
            print("Accuracy : ",100.*correct/total," Loss : ", train_loss/(batch_idx+1))
def test(epoch):
    """Evaluate `teacher` on `testloader` without gradients.

    Prints the running accuracy (%) and mean cross-entropy loss every 20 batches.
    `epoch` is accepted for symmetry with `train` and is not used.
    """
    teacher.eval()
    running_loss = 0
    num_correct = 0
    num_seen = 0
    with torch.no_grad():
        for step, (images, labels) in enumerate(testloader):
            images = images.to(device)
            labels = labels.to(device)
            logits = teacher(images)
            running_loss += criterion(logits, labels).item()

            predicted = logits.max(1)[1]
            num_seen += labels.size(0)
            num_correct += predicted.eq(labels).sum().item()
            if step % 20 == 0:
                print("Accuracy : ",100.*num_correct/num_seen," Loss : ", running_loss/(step+1))

In [None]:
# Train the teacher for a fixed number of epochs, validating after each one.
start_epoch = 0
best_acc = 0
num_epochs = 10
for epoch in range(start_epoch, start_epoch + num_epochs):
    train(epoch)
    print("Validation: ")
    test(epoch)


Epoch: 1
Accuracy :  14.0  Loss :  2.3733975887298584
Accuracy :  41.19900497512438  Loss :  1.5865661386233658
Accuracy :  47.80548628428928  Loss :  1.41774071199341
Validation: 
Accuracy :  67.0  Loss :  0.9107257127761841
Accuracy :  65.33333333333333  Loss :  0.9950999220212301
Accuracy :  64.17073170731707  Loss :  1.0056527911162958
Accuracy :  64.47540983606558  Loss :  0.9981735944747925
Accuracy :  64.24691358024691  Loss :  1.005907082999194

Epoch: 2
Accuracy :  64.0  Loss :  0.9784829616546631
Accuracy :  63.975124378109456  Loss :  1.007334947289519
Accuracy :  65.65336658354114  Loss :  0.9614940865378725
Validation: 
Accuracy :  75.0  Loss :  0.7087078094482422
Accuracy :  73.0952380952381  Loss :  0.7782325758820489
Accuracy :  72.07317073170732  Loss :  0.8009776901908037
Accuracy :  72.45901639344262  Loss :  0.7891016011355353
Accuracy :  72.55555555555556  Loss :  0.7947008230803926

Epoch: 3
Accuracy :  64.0  Loss :  0.8603017330169678
Accuracy :  72.074626865671

In [None]:
# Restore the teacher weights from the checkpoint stored on Google Drive.
model_save_name = 'teacher.pt'
path = F"/content/gdrive/MyDrive/{model_save_name}"
#torch.save(teacher.state_dict(), path)
# map_location remaps the stored tensors onto the current device, so a checkpoint
# saved on a GPU runtime also loads on a CPU-only runtime.
teacher.load_state_dict(torch.load(path, map_location=device))

<All keys matched successfully>

# Creating Dense Feature Dataset
1.1 In this cell we remove the head of the teacher network (i.e., the last fully connected layer) and add a flatten layer at the end.

In [None]:
from torchsummary import summary

# Layer-by-layer summary of the full teacher for a 3x32x32 input.
input_shape = (3, 32, 32)
summary(teacher, input_shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,792
       BatchNorm2d-2           [-1, 64, 32, 32]             128
              ReLU-3           [-1, 64, 32, 32]               0
            Conv2d-4           [-1, 64, 32, 32]          36,928
       BatchNorm2d-5           [-1, 64, 32, 32]             128
              ReLU-6           [-1, 64, 32, 32]               0
         MaxPool2d-7           [-1, 64, 16, 16]               0
            Conv2d-8          [-1, 128, 16, 16]          73,856
       BatchNorm2d-9          [-1, 128, 16, 16]             256
             ReLU-10          [-1, 128, 16, 16]               0
           Conv2d-11          [-1, 128, 16, 16]         147,584
      BatchNorm2d-12          [-1, 128, 16, 16]             256
             ReLU-13          [-1, 128, 16, 16]               0
        MaxPool2d-14            [-1, 12

In [None]:
# Keep every top-level child of the teacher except the last one (its classifier head)
# and flatten the remaining output to one vector per sample.
teacher_body = list(teacher.children())[:-1]
teacher_WOH = nn.Sequential(*teacher_body, nn.Flatten())

The summary of the new teacher without head:

In [None]:
from torchsummary import summary

In [None]:
# `summary` comes from torchsummary, imported in an earlier cell.
input_shape = (3, 32, 32)
summary(teacher_WOH, input_shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,792
       BatchNorm2d-2           [-1, 64, 32, 32]             128
              ReLU-3           [-1, 64, 32, 32]               0
            Conv2d-4           [-1, 64, 32, 32]          36,928
       BatchNorm2d-5           [-1, 64, 32, 32]             128
              ReLU-6           [-1, 64, 32, 32]               0
         MaxPool2d-7           [-1, 64, 16, 16]               0
            Conv2d-8          [-1, 128, 16, 16]          73,856
       BatchNorm2d-9          [-1, 128, 16, 16]             256
             ReLU-10          [-1, 128, 16, 16]               0
           Conv2d-11          [-1, 128, 16, 16]         147,584
      BatchNorm2d-12          [-1, 128, 16, 16]             256
             ReLU-13          [-1, 128, 16, 16]               0
        MaxPool2d-14            [-1, 12

1.2 In this cell you have to create the dense feature label dataset (i.e., the outputs of the teacher network without its head). To do that, run a forward pass over the whole dataset and append the outputs to a variable.

In [None]:
# teacher_WOH.eval()
# DenseTrain = None
# DenseTest = None
# with torch.no_grad():
#     for batch_idx, (inputs, targets) in enumerate(trainloader):
#         inputs, targets = inputs.to(device), targets.to(device)
#         outputs = teacher_WOH(inputs)
#         if(DenseTrain == None):
#             DenseTrain = outputs
#         else:
#             DenseTrain = torch.cat((DenseTrain,outputs))
# with torch.no_grad():
#     for batch_idx, (inputs, targets) in enumerate(testloader):
#         inputs, targets = inputs.to(device), targets.to(device)
#         outputs = teacher_WOH(inputs)
#         if(DenseTest == None):
#             DenseTest = outputs
#         else:
#             DenseTest = torch.cat((DenseTest,outputs))

# Creating ad-hoc student network
We create an ad-hoc student network.

In [None]:
# Layer specifications per architecture: an int is the output-channel count of a
# 3x3 conv block, 'M' is a 2x2 max-pool. 'VGGS' is the small student layout.
cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
    'VGGS': [32, 32, 'M', 64, 64, 'M', 128, 128, 'M', 256, 256, 'M',512,'M'],

}

class VGG(nn.Module):
    """VGG-style network whose head is a 512 -> 512 linear layer (dense-feature output)."""

    def __init__(self, vgg_name):
        """Build the network described by `cfg[vgg_name]`."""
        super().__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(512, 512)

    def forward(self, x):
        """Run the feature extractor, flatten per sample, and apply the linear head."""
        feature_map = self.features(x)
        flat = feature_map.view(feature_map.size(0), -1)
        return self.classifier(flat)

    def _make_layers(self, cfg):
        """Translate a layer specification list into an nn.Sequential.

        Each int becomes Conv2d(3x3, padding 1) + BatchNorm2d + ReLU; each 'M'
        becomes MaxPool2d(2, 2). A 1x1 AvgPool2d is appended at the end.
        """
        modules = []
        prev_channels = 3  # the first conv consumes 3 input channels
        for spec in cfg:
            if spec == 'M':
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            modules.extend([
                nn.Conv2d(prev_channels, spec, kernel_size=3, padding=1),
                nn.BatchNorm2d(spec),
                nn.ReLU(inplace=True),
            ])
            prev_channels = spec
        modules.append(nn.AvgPool2d(kernel_size=1, stride=1))
        return nn.Sequential(*modules)

# Build the small student on the selected device and print its layer summary.
s1 = VGG('VGGS').to(device)
summary(s1, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 32, 32]             896
       BatchNorm2d-2           [-1, 32, 32, 32]              64
              ReLU-3           [-1, 32, 32, 32]               0
            Conv2d-4           [-1, 32, 32, 32]           9,248
       BatchNorm2d-5           [-1, 32, 32, 32]              64
              ReLU-6           [-1, 32, 32, 32]               0
         MaxPool2d-7           [-1, 32, 16, 16]               0
            Conv2d-8           [-1, 64, 16, 16]          18,496
       BatchNorm2d-9           [-1, 64, 16, 16]             128
             ReLU-10           [-1, 64, 16, 16]               0
           Conv2d-11           [-1, 64, 16, 16]          36,928
      BatchNorm2d-12           [-1, 64, 16, 16]             128
             ReLU-13           [-1, 64, 16, 16]               0
        MaxPool2d-14             [-1, 6

# Training Student
We will train the student network using Dense Features that we created.
Dataset datagen will provide data in batches so we need to extract the corresponding batch of targets from our Dense feature variable from 1.2, for this we use the following formula:

batch_index * batch_size --> (batch_index * batch_size) + batch_size

In [None]:
# optimizer = optim.Adam(s1.parameters(), lr=0.0001)
# criterion = nn.MSELoss()

# def train(epoch):
#     print('\nEpoch: %d' % (epoch+1))
#     s1.train()
#     train_loss = 0
#     for batch_idx, (inputs, targets) in enumerate(trainloader):
#         inputs, targets = inputs.to(device), targets.to(device)
#         targets = DenseTrain[batch_idx*batch_size:(batch_idx*batch_size)+batch_size]
        
#         inputs = Variable(inputs, requires_grad=False)
#         targets = Variable(targets)
        
#         s1.zero_grad()
#         outputs = s1(inputs)
#         loss = criterion(outputs, targets)
#         loss.backward()
#         optimizer.step()

#         train_loss += loss.item()
#         if(batch_idx % 10 == 0):
#           print("Loss : ", train_loss/(batch_idx+1))
# def test(epoch):
#     s1.eval()
#     test_loss = 0
#     with torch.no_grad():
#         for batch_idx, (inputs, targets) in enumerate(testloader):
#             inputs, targets = inputs.to(device), targets.to(device)
#             targets = DenseTest[batch_idx*batch_size:(batch_idx*batch_size)+batch_size]
#             outputs = s1(inputs)
#             loss = criterion(outputs, targets)

#             test_loss += loss.item()
#             if(batch_idx % 20 == 0):
#               print(" Loss : ", test_loss/(batch_idx+1))

In [None]:
# NOTE(review): the student-specific train()/test() definitions in the cell above are
# commented out, so on a fresh kernel these names still refer to the most recent live
# definitions (the teacher's cross-entropy loops). Re-enable that cell, and the
# dense-feature cell it reads DenseTrain/DenseTest from, before running this loop.
start_epoch = 0
best_acc = 0
num_epochs = 60
for epoch in range(start_epoch, start_epoch + num_epochs):
    train(epoch)
    print("Validation: ")
    test(epoch)


Epoch: 1
Loss :  3.6982531547546387
Loss :  2.6693457798524336
Loss :  2.040993488970257
Loss :  1.6936085147242392
Loss :  1.467016842307114
Loss :  1.3230298921173693
Loss :  1.2173858148152712
Loss :  1.1410718870834566
Loss :  1.0773069829116633
Loss :  1.0279522839483324
Loss :  0.9826088010674656
Loss :  0.9441749147466711
Loss :  0.913011798188706
Loss :  0.8847250028420951
Loss :  0.8595928154515882
Loss :  0.8373164838513002
Loss :  0.8162185294287545
Loss :  0.7974991399293755
Loss :  0.7824804616567179
Loss :  0.7662618968499268
Loss :  0.751274244405737
Loss :  0.7378353629067046
Loss :  0.7251876513073348
Loss :  0.7137788181955164
Loss :  0.7028869803515707
Loss :  0.6934796735584974
Loss :  0.683782813078599
Loss :  0.6741902095805234
Loss :  0.6649867340763268
Loss :  0.65718315843864
Loss :  0.6493550655849748
Loss :  0.642017930650251
Loss :  0.634771172595544
Loss :  0.6277699426401778
Loss :  0.6213034062791081
Loss :  0.6148867536643972
Loss :  0.6086351433123908


# Finetuning student on cross-entropy
Now the student is trained. In this cell we need to replace the classifier (i.e., the fully connected layer) of the student network, swapping the one whose output has the shape of the dense features for one whose output has the shape of the classes, e.g. nn.Linear(256,512) to nn.Linear(256,10). After this we need to freeze the Conv layers in the network and fine-tune the network using the original dataset.

In [None]:
# Replace the student's head with a fresh 10-way linear classifier.
s1.classifier = nn.Linear(512, 10)

# Freeze every convolution so fine-tuning leaves the conv weights and biases untouched.
conv_layers = [layer for layer in s1.modules() if isinstance(layer, nn.Conv2d)]
for conv in conv_layers:
    conv.weight.requires_grad = False
    if conv.bias is not None:
        conv.bias.requires_grad = False

# Make sure the freshly created head lives on `device` as well.
s1 = s1.to(device)
summary(s1, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 32, 32]             896
       BatchNorm2d-2           [-1, 32, 32, 32]              64
              ReLU-3           [-1, 32, 32, 32]               0
            Conv2d-4           [-1, 32, 32, 32]           9,248
       BatchNorm2d-5           [-1, 32, 32, 32]              64
              ReLU-6           [-1, 32, 32, 32]               0
         MaxPool2d-7           [-1, 32, 16, 16]               0
            Conv2d-8           [-1, 64, 16, 16]          18,496
       BatchNorm2d-9           [-1, 64, 16, 16]             128
             ReLU-10           [-1, 64, 16, 16]               0
           Conv2d-11           [-1, 64, 16, 16]          36,928
      BatchNorm2d-12           [-1, 64, 16, 16]             128
             ReLU-13           [-1, 64, 16, 16]               0
        MaxPool2d-14             [-1, 6

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(s1.parameters(), lr=0.0001)

def train(epoch):
    """Fine-tune `s1` for one pass over `trainloader` with cross-entropy loss.

    Prints the running accuracy (%) and mean loss every 200 batches.

    Args:
        epoch: zero-based epoch index; only used for the progress header.
    """
    print('\nEpoch: %d' % (epoch+1))
    s1.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        # Plain tensors are enough; the legacy autograd.Variable wrapper is not needed.
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = s1(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx % 200 == 0:
            print("Accuracy : ",100.*correct/total," Loss : ", train_loss/(batch_idx+1))
def test(epoch):
    """Evaluate `s1` on `testloader` without gradients.

    Prints the running accuracy (%) and mean cross-entropy loss every 20 batches.
    `epoch` is accepted for symmetry with `train` and is not used.
    """
    s1.eval()
    running_loss = 0
    num_correct = 0
    num_seen = 0
    with torch.no_grad():
        for step, (images, labels) in enumerate(testloader):
            images = images.to(device)
            labels = labels.to(device)
            logits = s1(images)
            running_loss += criterion(logits, labels).item()

            predicted = logits.max(1)[1]
            num_seen += labels.size(0)
            num_correct += predicted.eq(labels).sum().item()
            if step % 20 == 0:
                print("Accuracy : ",100.*num_correct/num_seen," Loss : ", running_loss/(step+1))

In [None]:
# Fine-tune the student for a fixed number of epochs, validating after each one.
start_epoch = 0
best_acc = 0
num_epochs = 10
for epoch in range(start_epoch, start_epoch + num_epochs):
    train(epoch)
    print("Validation: ")
    test(epoch)


Epoch: 1
Accuracy :  13.0  Loss :  2.3406362533569336
Accuracy :  69.44278606965175  Loss :  1.4717139343717205
Accuracy :  77.3640897755611  Loss :  1.1176168737268806
Validation: 
Accuracy :  87.0  Loss :  0.5915505886077881
Accuracy :  84.14285714285714  Loss :  0.5783692726067134
Accuracy :  83.82926829268293  Loss :  0.5829475471159307
Accuracy :  84.09836065573771  Loss :  0.5820076895541832
Accuracy :  84.11111111111111  Loss :  0.5806706735381374

Epoch: 2
Accuracy :  82.0  Loss :  0.5544735193252563
Accuracy :  85.71641791044776  Loss :  0.5207162867138042
Accuracy :  86.07481296758105  Loss :  0.49043705888519856
Validation: 
Accuracy :  87.0  Loss :  0.4678022861480713
Accuracy :  85.0  Loss :  0.4648697092419579
Accuracy :  84.60975609756098  Loss :  0.4711946045480123
Accuracy :  84.75409836065573  Loss :  0.46949492419352296
Accuracy :  84.85185185185185  Loss :  0.4675458275977476

Epoch: 3
Accuracy :  87.0  Loss :  0.40074536204338074
Accuracy :  86.33830845771145  Los

In [None]:
# model_save_name = '1student.pt'
# path = F"/content/gdrive/MyDrive/{model_save_name}" 
# torch.save(s1.state_dict(), path)

In [None]:
# Restore the fine-tuned student weights from the checkpoint stored on Google Drive.
model_save_name = '1student.pt'
path = F"/content/gdrive/MyDrive/{model_save_name}"
# map_location remaps the stored tensors onto the current device, so a checkpoint
# saved on a GPU runtime also loads on a CPU-only runtime.
s1.load_state_dict(torch.load(path, map_location=device))

<All keys matched successfully>

# Extract dense features from s1

In [None]:
# Headless copy of s1: every top-level child except the last one (the classifier),
# followed by a flatten. The wrapped modules are shared with s1, not copied.
s1_body = list(s1.children())[:-1]
S1_AS_TA_WOH = nn.Sequential(*s1_body, nn.Flatten())

# Print both summaries: the full student first, then the headless version.
summary(s1, (3, 32, 32))
summary(S1_AS_TA_WOH, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 32, 32]             896
       BatchNorm2d-2           [-1, 32, 32, 32]              64
              ReLU-3           [-1, 32, 32, 32]               0
            Conv2d-4           [-1, 32, 32, 32]           9,248
       BatchNorm2d-5           [-1, 32, 32, 32]              64
              ReLU-6           [-1, 32, 32, 32]               0
         MaxPool2d-7           [-1, 32, 16, 16]               0
            Conv2d-8           [-1, 64, 16, 16]          18,496
       BatchNorm2d-9           [-1, 64, 16, 16]             128
             ReLU-10           [-1, 64, 16, 16]               0
           Conv2d-11           [-1, 64, 16, 16]          36,928
      BatchNorm2d-12           [-1, 64, 16, 16]             128
             ReLU-13           [-1, 64, 16, 16]               0
        MaxPool2d-14             [-1, 6

In [None]:
def _extract_dense_features(extractor, loader, device):
    """Return `extractor`'s outputs for every batch of `loader`, concatenated along dim 0.

    The extractor is switched to eval mode and run with gradients disabled.

    Args:
        extractor: module mapping a batch of inputs to a batch of feature vectors.
        loader: iterable yielding `(inputs, targets)` pairs; targets are ignored.
        device: device the inputs are moved to before the forward pass.

    Returns:
        A single tensor with one row per sample, in loader order, or None when the
        loader yields no batches.
    """
    extractor.eval()
    chunks = []
    with torch.no_grad():
        for inputs, _ in loader:
            chunks.append(extractor(inputs.to(device)))
    # Concatenate once at the end instead of re-allocating a growing tensor per batch.
    return torch.cat(chunks) if chunks else None


# NOTE(review): trainloader's transform contains RandomCrop/RandomHorizontalFlip, so the
# cached training features correspond to one random augmentation draw and will not
# exactly match the augmented inputs produced on later passes over the loader.
S1DenseTrain = _extract_dense_features(S1_AS_TA_WOH, trainloader, device)
s1DenseTest = _extract_dense_features(S1_AS_TA_WOH, testloader, device)

# Creating 2 more students

In [None]:
# Layer specifications per architecture: an int is the output-channel count of a
# 3x3 conv block, 'M' is a 2x2 max-pool. 'VGGS' here ends with 256 channels.
cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
    'VGGS': [32, 32, 'M', 64, 64, 'M', 128, 128, 'M', 256, 256, 'M','M'],


}

class VGG(nn.Module):
    """VGG-style network whose head is a 256 -> 256 linear layer (half of the dense features)."""

    def __init__(self, vgg_name):
        """Build the network described by `cfg[vgg_name]`."""
        super().__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(256, 256)

    def forward(self, x):
        """Run the feature extractor, flatten per sample, and apply the linear head."""
        feature_map = self.features(x)
        flat = feature_map.view(feature_map.size(0), -1)
        return self.classifier(flat)

    def _make_layers(self, cfg):
        """Translate a layer specification list into an nn.Sequential.

        Each int becomes Conv2d(3x3, padding 1) + BatchNorm2d + ReLU; each 'M'
        becomes MaxPool2d(2, 2). A 1x1 AvgPool2d is appended at the end.
        """
        modules = []
        prev_channels = 3  # the first conv consumes 3 input channels
        for spec in cfg:
            if spec == 'M':
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            modules.extend([
                nn.Conv2d(prev_channels, spec, kernel_size=3, padding=1),
                nn.BatchNorm2d(spec),
                nn.ReLU(inplace=True),
            ])
            prev_channels = spec
        modules.append(nn.AvgPool2d(kernel_size=1, stride=1))
        return nn.Sequential(*modules)

# Two independently initialised students with 256-d outputs, each on `device`.
s01 = VGG('VGGS').to(device)
summary(s01, (3, 32, 32))
s2 = VGG('VGGS').to(device)
summary(s2, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 32, 32]             896
       BatchNorm2d-2           [-1, 32, 32, 32]              64
              ReLU-3           [-1, 32, 32, 32]               0
            Conv2d-4           [-1, 32, 32, 32]           9,248
       BatchNorm2d-5           [-1, 32, 32, 32]              64
              ReLU-6           [-1, 32, 32, 32]               0
         MaxPool2d-7           [-1, 32, 16, 16]               0
            Conv2d-8           [-1, 64, 16, 16]          18,496
       BatchNorm2d-9           [-1, 64, 16, 16]             128
             ReLU-10           [-1, 64, 16, 16]               0
           Conv2d-11           [-1, 64, 16, 16]          36,928
      BatchNorm2d-12           [-1, 64, 16, 16]             128
             ReLU-13           [-1, 64, 16, 16]               0
        MaxPool2d-14             [-1, 6

# Training Multi Students
1.6 In this step you will train two students instead of one. In the training loop you will pass the input through both students and then backpropagate both losses.

In [None]:
optimizer1 = optim.Adam(s01.parameters(), lr=0.0001)
optimizer2 = optim.Adam(s2.parameters(), lr=0.0001)
criterion = nn.MSELoss()

def train(epoch):
    """Train `s01` and `s2` for one epoch to regress the cached dense features of s1.

    For every batch the matching rows of `S1DenseTrain` are used as targets:
    `s01` regresses columns [:256] and `s2` regresses columns [256:], each with
    its own MSE loss and optimizer. Running mean losses are printed every 10 batches.

    Args:
        epoch: zero-based epoch index; only used for the progress header.
    """
    print('\nEpoch: %d' % (epoch+1))
    # Put BOTH students in training mode; test() switches them to eval mode.
    s01.train()
    s2.train()
    train_loss1 = 0
    train_loss2 = 0
    for batch_idx, (inputs, _) in enumerate(trainloader):
        inputs = inputs.to(device)
        # Rows of the cached features that belong to this batch (relies on the
        # loader yielding samples in the same order as when the cache was built).
        start = batch_idx * batch_size
        targets = S1DenseTrain[start:start + batch_size].to(device)
        optimizer1.zero_grad()
        optimizer2.zero_grad()
        output1 = s01(inputs)
        output2 = s2(inputs)
        loss1 = criterion(output1, targets[:, :256])
        loss2 = criterion(output2, targets[:, 256:])
        loss1.backward()
        loss2.backward()
        optimizer1.step()
        optimizer2.step()

        train_loss1 += loss1.item()
        train_loss2 += loss2.item()
        if batch_idx % 10 == 0:
            print("Loss S01: ", train_loss1/(batch_idx+1))
            print("Loss S2: ", train_loss2/(batch_idx+1))
def test(epoch):
    """Evaluate `s01` and `s2` against the cached test features of s1.

    `s01` is scored on columns [:256] of `s1DenseTest` and `s2` on columns [256:],
    both with `criterion`. Running mean losses are printed every 20 batches.
    `epoch` is accepted for symmetry with `train` and is not used.
    """
    # Both students must be in eval mode so BatchNorm uses its running statistics
    # and does not update them from validation batches.
    s01.eval()
    s2.eval()
    test_loss1 = 0
    test_loss2 = 0
    with torch.no_grad():
        for batch_idx, (inputs, _) in enumerate(testloader):
            inputs = inputs.to(device)
            # Rows of the cached features that belong to this batch.
            start = batch_idx * batch_size
            targets = s1DenseTest[start:start + batch_size].to(device)
            output1 = s01(inputs)
            output2 = s2(inputs)
            loss1 = criterion(output1, targets[:, :256])
            loss2 = criterion(output2, targets[:, 256:])
            test_loss1 += loss1.item()
            test_loss2 += loss2.item()
            if batch_idx % 20 == 0:
                print(" Loss S01: ", test_loss1/(batch_idx+1))
                print(" Loss S2: ", test_loss2/(batch_idx+1))

In [None]:
# Train both students on the cached features for a fixed number of epochs,
# validating after each one.
start_epoch = 0
best_acc = 0
num_epochs = 60
for epoch in range(start_epoch, start_epoch + num_epochs):
    train(epoch)
    print("Validation: ")
    test(epoch)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Loss S2:  0.04375377510404645
Loss S01:  0.038112056333208875
Loss S2:  0.04377395882671633
Loss S01:  0.03808712256583979
Loss S2:  0.043727828505213465
Loss S01:  0.038085202637916246
Loss S2:  0.04373429564425464
Loss S01:  0.03805583996281127
Loss S2:  0.04373177210335721
Loss S01:  0.0380508425017241
Loss S2:  0.04374178717167414
Loss S01:  0.03805911801201776
Loss S2:  0.043720615580836175
Loss S01:  0.038063195307512544
Loss S2:  0.04373016571645429
Loss S01:  0.03802304793824489
Loss S2:  0.04369718822709412
Validation: 
 Loss S01:  0.03389983996748924
 Loss S2:  0.04940960928797722
 Loss S01:  0.0355720402938979
 Loss S2:  0.05561993909733636
 Loss S01:  0.035594048750836674
 Loss S2:  0.05614946591781407
 Loss S01:  0.03548704863327448
 Loss S2:  0.05598935301675171
 Loss S01:  0.03547275038780989
 Loss S2:  0.05580454098957556

Epoch: 17
Loss S01:  0.04076644405722618
Loss S2:  0.0450172983109951
Loss S01:  0.0

# Create ensemble model

1.8 In this step you will create a new network class that takes s1 and s2 as parameters. This class should instantiate a new network that ensembles both s1 and s2 and has a classifier for cross-entropy. In the forward method, pass the input x through both s1 and s2 and then concatenate their outputs along axis 1. Then pass this concatenated output through a classifier of appropriate shape.

In [None]:
class Net(nn.Module):
    """Ensemble of two sub-networks whose concatenated outputs feed a linear classifier."""

    def __init__(self, s1,s2):
        """Store both sub-networks and create the 512 -> 10 classification layer."""
        super().__init__()
        self.s1 = s1
        self.s2 = s2
        # 512 inputs: the width of the two sub-network outputs concatenated along dim 1.
        self.classifier = nn.Linear(512, 10)

    def forward(self, x):
        """Feed `x` to both sub-networks, concatenate along dim 1, and classify."""
        branch_outputs = (self.s1(x), self.s2(x))
        fused = torch.cat(branch_outputs, 1)
        return self.classifier(fused)
# Wrap the two trained students in the ensemble, move it to `device`, and summarise it.
net = Net(s01, s2).to(device)
summary(net, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 32, 32]             896
       BatchNorm2d-2           [-1, 32, 32, 32]              64
              ReLU-3           [-1, 32, 32, 32]               0
            Conv2d-4           [-1, 32, 32, 32]           9,248
       BatchNorm2d-5           [-1, 32, 32, 32]              64
              ReLU-6           [-1, 32, 32, 32]               0
         MaxPool2d-7           [-1, 32, 16, 16]               0
            Conv2d-8           [-1, 64, 16, 16]          18,496
       BatchNorm2d-9           [-1, 64, 16, 16]             128
             ReLU-10           [-1, 64, 16, 16]               0
           Conv2d-11           [-1, 64, 16, 16]          36,928
      BatchNorm2d-12           [-1, 64, 16, 16]             128
             ReLU-13           [-1, 64, 16, 16]               0
        MaxPool2d-14             [-1, 6

# Train the ensembled network
1.9 In this step you will freeze all the conv layers in the ensembled network and then fine-tune it on the original dataset.

In [None]:
# Freeze every convolution inside the ensemble (both students) so that fine-tuning
# leaves the conv weights and biases untouched.
conv_layers = [layer for layer in net.modules() if isinstance(layer, nn.Conv2d)]
for conv in conv_layers:
    conv.weight.requires_grad = False
    if conv.bias is not None:
        conv.bias.requires_grad = False

net = net.to(device)
summary(net, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 32, 32]             896
       BatchNorm2d-2           [-1, 32, 32, 32]              64
              ReLU-3           [-1, 32, 32, 32]               0
            Conv2d-4           [-1, 32, 32, 32]           9,248
       BatchNorm2d-5           [-1, 32, 32, 32]              64
              ReLU-6           [-1, 32, 32, 32]               0
         MaxPool2d-7           [-1, 32, 16, 16]               0
            Conv2d-8           [-1, 64, 16, 16]          18,496
       BatchNorm2d-9           [-1, 64, 16, 16]             128
             ReLU-10           [-1, 64, 16, 16]               0
           Conv2d-11           [-1, 64, 16, 16]          36,928
      BatchNorm2d-12           [-1, 64, 16, 16]             128
             ReLU-13           [-1, 64, 16, 16]               0
        MaxPool2d-14             [-1, 6

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001)

def train(epoch):
    """Fine-tune the ensemble `net` for one pass over `trainloader` with cross-entropy.

    Prints the running accuracy (%) and mean loss every 200 batches.

    Args:
        epoch: zero-based epoch index; only used for the progress header.
    """
    print('\nEpoch: %d' % (epoch+1))
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        # Plain tensors are enough; the legacy autograd.Variable wrapper is not needed.
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx % 200 == 0:
            print("Accuracy : ",100.*correct/total," Loss : ", train_loss/(batch_idx+1))
def test(epoch):
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            if(batch_idx % 20 == 0):
              print("Accuracy : ",100.*correct/total," Loss : ", test_loss/(batch_idx+1))

In [None]:
# Fine-tune `net` for 10 epochs, printing test-set metrics after each one.
start_epoch = 0
best_acc = 0  # NOTE(review): assigned here but never read or updated in this cell
for epoch in range(start_epoch, start_epoch+10):
    train(epoch)
    print("Validation: ")
    test(epoch)


Epoch: 1
Accuracy :  4.0  Loss :  2.5721871852874756
Accuracy :  75.68656716417911  Loss :  1.0726231840712515
Accuracy :  80.09476309226933  Loss :  0.7963963118277286
Validation: 
Accuracy :  86.0  Loss :  0.4607113301753998
Accuracy :  84.85714285714286  Loss :  0.465903381506602
Accuracy :  84.09756097560975  Loss :  0.4800547274147592
Accuracy :  83.8688524590164  Loss :  0.47959929997803735
Accuracy :  83.93827160493827  Loss :  0.4771355927726369

Epoch: 2
Accuracy :  85.0  Loss :  0.4412436783313751
Accuracy :  85.17910447761194  Loss :  0.4412178424608648
Accuracy :  85.31920199501246  Loss :  0.43435492126870334
Validation: 
Accuracy :  85.0  Loss :  0.4136185050010681
Accuracy :  85.04761904761905  Loss :  0.4412853866815567
Accuracy :  84.36585365853658  Loss :  0.4556760987857493
Accuracy :  84.36065573770492  Loss :  0.4555130300463223
Accuracy :  84.41975308641975  Loss :  0.4533260306458414

Epoch: 3
Accuracy :  85.0  Loss :  0.3970857560634613
Accuracy :  85.686567164

In [None]:
# Persist the weights of students `s01` and `s2` to the mounted Google Drive.
# NOTE(review): the destination is a hard-coded absolute Colab path; both
# `model_save_name` and `path` are re-bound for the second checkpoint.
model_save_name = 'ss1student.pt'
path = F"/content/gdrive/MyDrive/{model_save_name}"
torch.save(s01.state_dict(), path)
# To restore instead of saving: s01.load_state_dict(torch.load(path))
model_save_name = 'ss2student.pt'
path = F"/content/gdrive/MyDrive/{model_save_name}"
torch.save(s2.state_dict(), path)
# To restore instead of saving: s2.load_state_dict(torch.load(path))



---


**Create 4 more students.**



In [None]:
# Layer recipes consumed by VGG._make_layers: an integer adds a 3x3
# Conv -> BatchNorm -> ReLU stage with that many output channels, and 'M'
# adds a 2x2 max-pool. Only 'VGGS2' is instantiated in this section.
cfg = dict(
    VGG11=[64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    VGG13=[64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    VGG16=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    VGG19=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
    VGGS1=[32, 32, 'M', 64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 'M'],
    VGGS33=[32, 32, 'M', 64, 64, 'M', 128, 'M', 'M', 'M'],
    VGGS2=[32, 'M', 32, 'M', 64, 64, 'M', 128, 128, 'M', 128, 128, 'M'],
)

class VGG(nn.Module):
    """VGG-style network: a configurable conv stack followed by one linear layer.

    ``vgg_name`` picks a layer recipe from the module-level ``cfg`` dict. The
    head is a fixed ``Linear(128, 128)``, so the chosen recipe has to flatten
    to exactly 128 values per sample for ``forward`` to succeed.
    """

    def __init__(self, vgg_name):
        super().__init__()
        recipe = cfg[vgg_name]
        self.features = self._make_layers(recipe)
        self.classifier = nn.Linear(128, 128)

    def forward(self, x):
        """Return the 128-d output of the linear head for a batch ``x``."""
        feature_maps = self.features(x)
        # Collapse everything except the batch axis before the linear layer.
        flat = feature_maps.view(feature_maps.size(0), -1)
        return self.classifier(flat)

    def _make_layers(self, cfg):
        """Translate a recipe list into an ``nn.Sequential``.

        'M' becomes a 2x2 max-pool; any other entry is taken as the output
        channel count of a 3x3 Conv -> BatchNorm -> ReLU stage. The input is
        assumed to have 3 channels. A 1x1 average pool (stride 1) is always
        appended last.
        """
        modules = []
        prev_channels = 3
        for spec in cfg:
            if spec == 'M':
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            modules.append(nn.Conv2d(prev_channels, spec, kernel_size=3, padding=1))
            modules.append(nn.BatchNorm2d(spec))
            modules.append(nn.ReLU(inplace=True))
            prev_channels = spec
        modules.append(nn.AvgPool2d(kernel_size=1, stride=1))
        return nn.Sequential(*modules)

# Four separately initialised students, all built from the 'VGGS2' recipe.
# Each one is moved to `device` and summarised for a 3x32x32 input.
s11 = VGG('VGGS2').to(device)
summary(s11, (3, 32, 32))

s22 = VGG('VGGS2').to(device)
summary(s22, (3,32,32))

s33 = VGG('VGGS2').to(device)
summary(s33, (3, 32, 32))

s44 = VGG('VGGS2').to(device)
summary(s44, (3, 32, 32))


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 32, 32]             896
       BatchNorm2d-2           [-1, 32, 32, 32]              64
              ReLU-3           [-1, 32, 32, 32]               0
         MaxPool2d-4           [-1, 32, 16, 16]               0
            Conv2d-5           [-1, 32, 16, 16]           9,248
       BatchNorm2d-6           [-1, 32, 16, 16]              64
              ReLU-7           [-1, 32, 16, 16]               0
         MaxPool2d-8             [-1, 32, 8, 8]               0
            Conv2d-9             [-1, 64, 8, 8]          18,496
      BatchNorm2d-10             [-1, 64, 8, 8]             128
             ReLU-11             [-1, 64, 8, 8]               0
           Conv2d-12             [-1, 64, 8, 8]          36,928
      BatchNorm2d-13             [-1, 64, 8, 8]             128
             ReLU-14             [-1, 6

In [None]:
# TA_WOH = nn.Sequential(*list(s01.children())[:-1],nn.Flatten())
# #TA_WOH = nn.Sequential(*list(s1.children())[:],nn.Flatten())
# summary(s01, (3, 32, 32))
# summary(TA_WOH, (3, 32, 32))

In [None]:
# TA2_WOH = nn.Sequential(*list(s2.children())[:-1],nn.Flatten())
# summary(s2, (3, 32, 32))
# summary(TA2_WOH, (3, 32, 32))

In [None]:
# # TA_WOH.eval()
# # TA2_WOH.eval()
# # s1.eval()
# # s2.eval()
# TADenseTrain = None
# TADenseTest = None

# with torch.no_grad():
#     for batch_idx, (inputs, targets) in enumerate(trainloader):
#         inputs, targets = inputs.to(device), targets.to(device)
#         outputs1 = TA_WOH(inputs)
#         outputs2 = TA2_WOH(inputs)
#         if(TADenseTrain == None):
#             TADenseTrain = torch.cat((outputs1,outputs2),1) 
#         else:
#             totalOUTPUT = torch.cat((outputs1,outputs2),1)         
#             TADenseTrain = torch.cat((TADenseTrain,totalOUTPUT))
           
# with torch.no_grad():
#     for batch_idx, (inputs, targets) in enumerate(testloader):
#         inputs, targets = inputs.to(device), targets.to(device)
#         outputs1 = TA_WOH(inputs)
#         outputs2 = TA2_WOH(inputs)
#         if(TADenseTest == None):
#             TADenseTest = torch.cat((outputs1,outputs2),1)
#         else:
#             totalOUTPUT = torch.cat((outputs1,outputs2),1)         
#             TADenseTest = torch.cat((TADenseTest,totalOUTPUT))
           

In [None]:
# print(TADenseTrain.shape)  

In [None]:
# Feature-regression objective plus one Adam optimiser per student, so each
# network is updated independently of the others.
# NOTE: this re-binds the global `criterion` to an MSE loss.
criterion = nn.MSELoss()
optimizer1 = optim.Adam(params=s11.parameters(), lr=1e-4)
optimizer2 = optim.Adam(params=s22.parameters(), lr=1e-4)
optimizer3 = optim.Adam(params=s33.parameters(), lr=1e-4)
optimizer4 = optim.Adam(params=s44.parameters(), lr=1e-4)

def train4(epoch):
    """Train the four students for one epoch against slices of ``S1DenseTrain``.

    Student ``k`` (``s11``, ``s22``, ``s33``, ``s44``) regresses columns
    ``[128*k, 128*(k+1))`` of the matching rows of ``S1DenseTrain`` using the
    global ``criterion`` and its own optimiser ``optimizer1``..``optimizer4``.
    Running mean losses are printed every 10 batches.

    Rows of ``S1DenseTrain`` are selected by batch position, so the targets
    only line up with the images if ``trainloader`` yields samples in the same
    fixed order, with batches of ``batch_size``, that was used when the tensor
    was built (presumably in an earlier cell; verify).

    Fixes over the previous version:
      * all four students are put in training mode, not only ``s11``;
      * the result of ``.to(device)`` on the targets is kept (``Tensor.to``
        is not in-place);
      * the class labels from the loader are no longer moved to the device,
        since they were discarded immediately.

    Args:
        epoch: Zero-based epoch index; only used for the ``Epoch: N`` banner.
    """
    print('\nEpoch: %d' % (epoch+1))
    students = (s11, s22, s33, s44)
    optimizers = (optimizer1, optimizer2, optimizer3, optimizer4)
    labels = ("S11", "S22", "S33", "S44")
    width = 128  # number of target columns assigned to each student

    for student in students:
        student.train()

    running = [0.0] * len(students)
    for batch_idx, (inputs, _) in enumerate(trainloader):
        inputs = inputs.to(device)
        first_row = batch_idx * batch_size
        targets = S1DenseTrain[first_row:first_row + batch_size].to(device)

        for student in students:
            student.zero_grad()
        outputs = [student(inputs) for student in students]
        losses = [
            criterion(out, targets[:, k * width:(k + 1) * width])
            for k, out in enumerate(outputs)
        ]
        for loss in losses:
            loss.backward()
        for opt in optimizers:
            opt.step()

        for k, loss in enumerate(losses):
            running[k] += loss.item()
        if batch_idx % 10 == 0:
            for label, total in zip(labels, running):
                print("Loss %s: " % label, total/(batch_idx+1))
def test4(epoch):
    s11.eval()
    
    test_loss1 = 0
    test_loss2 = 0
    test_loss3 = 0
    test_loss4= 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            #code here
            targets = s1DenseTest[batch_idx*batch_size:(batch_idx*batch_size)+batch_size]
            targets.to(device)
            output1 = s11(inputs)
            output2 = s22(inputs)
            output3 = s33(inputs)
            output4 = s44(inputs)
            loss1 = criterion(output1, targets[:,:128])
            loss2 = criterion(output2, targets[:,128:256])
            loss3 = criterion(output3, targets[:,256:384])
            loss4 = criterion(output4, targets[:,384:512])
            test_loss1 += loss1.item()
            test_loss2 += loss2.item()
            test_loss3 += loss3.item()
            test_loss4 += loss4.item()
            if(batch_idx % 20 == 0):
              print(" Loss S11: ", test_loss1/(batch_idx+1))
              print(" Loss S22: ", test_loss2/(batch_idx+1))
              print(" Loss S33: ", test_loss3/(batch_idx+1))
              print(" Loss S44: ", test_loss4/(batch_idx+1))

In [None]:
# Distil into the four students for 60 epochs, reporting test-split losses
# after each epoch.
start_epoch = 0
best_acc = 0  # NOTE(review): assigned here but never read or updated in this cell
for epoch in range(start_epoch, start_epoch+60):
    train4(epoch)
    print("Validation: ")
    test4(epoch)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Loss S44:  0.041593300783557524
Loss S11:  0.03358945997310019
Loss S22:  0.0361513559100134
Loss S33:  0.03852142209065295
Loss S44:  0.04159874789347159
Loss S11:  0.03355841392094095
Loss S22:  0.03615825169148582
Loss S33:  0.03850166783231983
Loss S44:  0.041638054516380286
Loss S11:  0.03358996336148981
Loss S22:  0.03622488099043845
Loss S33:  0.03852689967564474
Loss S44:  0.041659492660652504
Loss S11:  0.03360727827391054
Loss S22:  0.03627605151990031
Loss S33:  0.0385355099377639
Loss S44:  0.041699970517827575
Loss S11:  0.03362002729894877
Loss S22:  0.03627966703986362
Loss S33:  0.038523078594263904
Loss S44:  0.04168877383265799
Loss S11:  0.033613092183183146
Loss S22:  0.036250844070652745
Loss S33:  0.038469527840533986
Loss S44:  0.04164751606046993
Loss S11:  0.033590202295561165
Loss S22:  0.03625302219621622
Loss S33:  0.038446644345528184
Loss S44:  0.041655483491855654
Loss S11:  0.03355972064406

In [None]:
class Net2(nn.Module):
    """Four-branch ensemble: concatenated student outputs feed a 10-way head.

    The same input is passed through ``s11``, ``s22``, ``s33`` and ``s44``;
    their outputs are joined along dim 1 and mapped by ``Linear(512, 10)``,
    so the four outputs together must total 512 features per sample.
    """

    def __init__(self, s11,s22,s33,s44):
        super().__init__()
        # Attribute names double as state-dict prefixes; keep them stable.
        self.s11, self.s22, self.s33, self.s44 = s11, s22, s33, s44
        self.classifier = nn.Linear(512, 10)

    def forward(self, x):
        """Return class logits of shape (batch, 10) for the batch ``x``."""
        branches = (self.s11, self.s22, self.s33, self.s44)
        fused = torch.cat([branch(x) for branch in branches], 1)
        return self.classifier(fused)
# Combine the four students behind one 10-way classifier head, move the
# ensemble to `device`, and print its layer summary.
net4students = Net2(s11=s11, s22=s22, s33=s33, s44=s44).to(device)
summary(net4students, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 32, 32]             896
       BatchNorm2d-2           [-1, 32, 32, 32]              64
              ReLU-3           [-1, 32, 32, 32]               0
         MaxPool2d-4           [-1, 32, 16, 16]               0
            Conv2d-5           [-1, 32, 16, 16]           9,248
       BatchNorm2d-6           [-1, 32, 16, 16]              64
              ReLU-7           [-1, 32, 16, 16]               0
         MaxPool2d-8             [-1, 32, 8, 8]               0
            Conv2d-9             [-1, 64, 8, 8]          18,496
      BatchNorm2d-10             [-1, 64, 8, 8]             128
             ReLU-11             [-1, 64, 8, 8]               0
           Conv2d-12             [-1, 64, 8, 8]          36,928
      BatchNorm2d-13             [-1, 64, 8, 8]             128
             ReLU-14             [-1, 6

In [None]:
# Freeze every convolution in the ensemble. Only nn.Conv2d modules are
# touched here, so BatchNorm and Linear parameters keep requires_grad as-is.
for m in net4students.modules():
    if not isinstance(m, nn.Conv2d):
        continue
    m.weight.requires_grad_(False)
    if m.bias is not None:
        m.bias.requires_grad_(False)

net4students = net4students.to(device)
summary(net4students, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 32, 32]             896
       BatchNorm2d-2           [-1, 32, 32, 32]              64
              ReLU-3           [-1, 32, 32, 32]               0
         MaxPool2d-4           [-1, 32, 16, 16]               0
            Conv2d-5           [-1, 32, 16, 16]           9,248
       BatchNorm2d-6           [-1, 32, 16, 16]              64
              ReLU-7           [-1, 32, 16, 16]               0
         MaxPool2d-8             [-1, 32, 8, 8]               0
            Conv2d-9             [-1, 64, 8, 8]          18,496
      BatchNorm2d-10             [-1, 64, 8, 8]             128
             ReLU-11             [-1, 64, 8, 8]               0
           Conv2d-12             [-1, 64, 8, 8]          36,928
      BatchNorm2d-13             [-1, 64, 8, 8]             128
             ReLU-14             [-1, 6

In [None]:
# Classification objective and optimiser for the 4-student ensemble.
# NOTE: both names are notebook globals that earlier cells also bind; the
# train/test helpers pick up whichever binding is current when they run.
optimizer = optim.Adam(params=net4students.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

def train41(epoch):
    """Run one training epoch of the global ``net4students`` ensemble.

    The function reads ``net4students``, ``trainloader``, ``criterion``,
    ``optimizer`` and ``device`` from the notebook's global scope at call
    time. Every 200 batches it prints the running accuracy (in percent) and
    the mean loss accumulated since the start of the epoch.

    Args:
        epoch: Zero-based epoch index; only used for the ``Epoch: N`` banner.
    """
    print('\nEpoch: %d' % (epoch+1))
    net4students.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        # Plain tensors carry autograd state themselves, so the deprecated
        # ``Variable`` wrappers (no-ops since PyTorch 0.4) are not needed.
        inputs, targets = inputs.to(device), targets.to(device)
        net4students.zero_grad()
        outputs = net4students(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx % 200 == 0:
            print("Accuracy : ",100.*correct/total," Loss : ", train_loss/(batch_idx+1))
def test42(epoch):
    net4students.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net4students(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            if(batch_idx % 20 == 0):
              print("Accuracy : ",100.*correct/total," Loss : ", test_loss/(batch_idx+1))

In [None]:
# Fine-tune the 4-student ensemble for 10 epochs, printing test-set metrics
# after each one.
start_epoch = 0
best_acc = 0  # NOTE(review): assigned here but never read or updated in this cell
for epoch in range(start_epoch, start_epoch+10):
    train41(epoch)
    print("Validation: ")
    test42(epoch)


Epoch: 1
Accuracy :  26.0  Loss :  2.5862815380096436
Accuracy :  73.69651741293532  Loss :  1.1676214880018092
Accuracy :  78.35910224438902  Loss :  0.8690274864658156
Validation: 
Accuracy :  87.0  Loss :  0.474447101354599
Accuracy :  83.19047619047619  Loss :  0.5040500305947804
Accuracy :  82.70731707317073  Loss :  0.5196995618866711
Accuracy :  82.72131147540983  Loss :  0.5172153574521424
Accuracy :  82.81481481481481  Loss :  0.5134969408865329

Epoch: 2
Accuracy :  84.0  Loss :  0.5228055715560913
Accuracy :  83.56218905472637  Loss :  0.4840512572236322
Accuracy :  83.84788029925187  Loss :  0.47575937690877557
Validation: 
Accuracy :  88.0  Loss :  0.4355241060256958
Accuracy :  84.19047619047619  Loss :  0.47791553678966703
Accuracy :  83.46341463414635  Loss :  0.49652906089294246
Accuracy :  83.29508196721312  Loss :  0.4926418574129949
Accuracy :  83.39506172839506  Loss :  0.4885882275339998

Epoch: 3
Accuracy :  88.0  Loss :  0.4041728675365448
Accuracy :  84.199004

**Multiple students, part 2: 8 students**


In [None]:
# Layer recipes consumed by VGG._make_layers: an integer adds a 3x3
# Conv -> BatchNorm -> ReLU stage with that many output channels, and 'M'
# adds a 2x2 max-pool. Only 'VGGS33' is instantiated in this section.
# NOTE: this re-binds the notebook-global `cfg`.
cfg = dict(
    VGG11=[64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    VGG13=[64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    VGG16=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    VGG19=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
    VGGS1=[32, 32, 'M', 64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 'M'],
    VGGS33=[32, 32, 'M', 32, 32, 'M', 32, 64, 'M', 64, 64, 'M', 64, 64, 64, 'M'],
    VGGS2=[32, 'M', 32, 'M', 64, 64, 'M', 128, 128, 'M', 256, 'M'],
)

class VGG(nn.Module):
    """VGG-style feature extractor that returns flattened conv features.

    ``vgg_name`` picks a layer recipe from the module-level ``cfg`` dict. A
    ``Linear(64, 64)`` layer is still created as ``self.classifier`` but
    ``forward`` does not apply it, so the output is the flattened result of
    ``self.features``.
    """

    def __init__(self, vgg_name):
        super().__init__()
        recipe = cfg[vgg_name]
        self.features = self._make_layers(recipe)
        self.classifier = nn.Linear(64, 64)

    def forward(self, x):
        """Return ``self.features(x)`` flattened to (batch, -1)."""
        feature_maps = self.features(x)
        # The linear head is intentionally bypassed here.
        return feature_maps.view(feature_maps.size(0), -1)

    def _make_layers(self, cfg):
        """Translate a recipe list into an ``nn.Sequential``.

        'M' becomes a 2x2 max-pool; any other entry is taken as the output
        channel count of a 3x3 Conv -> BatchNorm -> ReLU stage. The input is
        assumed to have 3 channels. A 1x1 average pool (stride 1) is always
        appended last.
        """
        modules = []
        prev_channels = 3
        for spec in cfg:
            if spec == 'M':
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            modules.append(nn.Conv2d(prev_channels, spec, kernel_size=3, padding=1))
            modules.append(nn.BatchNorm2d(spec))
            modules.append(nn.ReLU(inplace=True))
            prev_channels = spec
        modules.append(nn.AvgPool2d(kernel_size=1, stride=1))
        return nn.Sequential(*modules)

# Eight separately initialised students, all built from the 'VGGS33' recipe.
# Each one is moved to `device` and summarised for a 3x32x32 input.
ss11 = VGG('VGGS33').to(device)
summary(ss11, (3, 32, 32))

ss22 = VGG('VGGS33').to(device)
summary(ss22, (3,32,32))

ss33 = VGG('VGGS33').to(device)
summary(ss33, (3, 32, 32))

ss44 = VGG('VGGS33').to(device)
summary(ss44, (3, 32, 32))

ss55 = VGG('VGGS33').to(device)
summary(ss55, (3, 32, 32))

ss66 = VGG('VGGS33').to(device)
summary(ss66, (3,32,32))

ss77 = VGG('VGGS33').to(device)
summary(ss77, (3, 32, 32))

ss88 = VGG('VGGS33').to(device)
summary(ss88, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 32, 32]             896
       BatchNorm2d-2           [-1, 32, 32, 32]              64
              ReLU-3           [-1, 32, 32, 32]               0
            Conv2d-4           [-1, 32, 32, 32]           9,248
       BatchNorm2d-5           [-1, 32, 32, 32]              64
              ReLU-6           [-1, 32, 32, 32]               0
         MaxPool2d-7           [-1, 32, 16, 16]               0
            Conv2d-8           [-1, 32, 16, 16]           9,248
       BatchNorm2d-9           [-1, 32, 16, 16]              64
             ReLU-10           [-1, 32, 16, 16]               0
           Conv2d-11           [-1, 32, 16, 16]           9,248
      BatchNorm2d-12           [-1, 32, 16, 16]              64
             ReLU-13           [-1, 32, 16, 16]               0
        MaxPool2d-14             [-1, 3

In [None]:
# TAA1_WOH = nn.Sequential(*list(s11.children())[:-1],nn.Flatten())
# summary(s11, (3, 32, 32))
# summary(TAA1_WOH, (3, 32, 32))

# TAA2_WOH = nn.Sequential(*list(s22.children())[:-1],nn.Flatten())
# summary(s22, (3, 32, 32))
# summary(TAA2_WOH, (3, 32, 32))

# TAA3_WOH = nn.Sequential(*list(s33.children())[:-1],nn.Flatten())
# summary(s33, (3, 32, 32))
# summary(TAA3_WOH, (3, 32, 32))

# TAA4_WOH = nn.Sequential(*list(s44.children())[:-1],nn.Flatten())
# summary(s44, (3, 32, 32))
# summary(TAA4_WOH, (3, 32, 32))

In [None]:
# TAA1_WOH.eval()
# TAA2_WOH.eval()
# TAA3_WOH.eval()
# TAA4_WOH.eval()
# TA2DenseTrain = None
# TA2DenseTest = None

# with torch.no_grad():
#     for batch_idx, (inputs, targets) in enumerate(trainloader):
#         inputs, targets = inputs.to(device), targets.to(device)
#         outputs1 = TAA1_WOH(inputs)
#         outputs2 = TAA2_WOH(inputs)
#         outputs3 = TAA3_WOH(inputs)
#         outputs4 = TAA4_WOH(inputs)
#         if(TA2DenseTrain == None):
#             TA2DenseTrain = torch.cat((outputs1,outputs2,outputs3,outputs4),1) 
#         else:
#             totalOUTPUT = torch.cat((outputs1,outputs2,outputs3,outputs4),1)         
#             TA2DenseTrain = torch.cat((TA2DenseTrain,totalOUTPUT))
           
# with torch.no_grad():
#     for batch_idx, (inputs, targets) in enumerate(testloader):
#         inputs, targets = inputs.to(device), targets.to(device)
#         outputs1 = TAA1_WOH(inputs)
#         outputs2 = TAA2_WOH(inputs)
#         outputs3 = TAA3_WOH(inputs)
#         outputs4 = TAA4_WOH(inputs)
#         if(TA2DenseTest == None):
#             TA2DenseTest = torch.cat((outputs1,outputs2,outputs3,outputs4),1) 
#         else:
#             totalOUTPUT = torch.cat((outputs1,outputs2,outputs3,outputs4),1)      
#             TA2DenseTest = torch.cat((TA2DenseTest,totalOUTPUT))
           

In [None]:
# print(TA2DenseTrain.shape)

In [None]:
# Feature-regression objective plus one Adam optimiser per student, so each
# of the eight networks is updated independently.
# NOTE: this re-binds the globals `criterion` and `optimizer1`..`optimizer4`
# that the 4-student cells also use.
criterion = nn.MSELoss()
optimizer1 = optim.Adam(params=ss11.parameters(), lr=1e-4)
optimizer2 = optim.Adam(params=ss22.parameters(), lr=1e-4)
optimizer3 = optim.Adam(params=ss33.parameters(), lr=1e-4)
optimizer4 = optim.Adam(params=ss44.parameters(), lr=1e-4)
optimizer5 = optim.Adam(params=ss55.parameters(), lr=1e-4)
optimizer6 = optim.Adam(params=ss66.parameters(), lr=1e-4)
optimizer7 = optim.Adam(params=ss77.parameters(), lr=1e-4)
optimizer8 = optim.Adam(params=ss88.parameters(), lr=1e-4)

def train8(epoch):
    """Train the eight students for one epoch against slices of ``S1DenseTrain``.

    Student ``k`` (``ss11`` .. ``ss88``) regresses columns
    ``[64*k, 64*(k+1))`` of the matching rows of ``S1DenseTrain`` using the
    global ``criterion`` and its own optimiser ``optimizer1``..``optimizer8``.
    Running mean losses are printed every 10 batches.

    Rows of ``S1DenseTrain`` are selected by batch position, so the targets
    only line up with the images if ``trainloader`` yields samples in the same
    fixed order, with batches of ``batch_size``, that was used when the tensor
    was built (presumably in an earlier cell; verify).

    Fixes over the previous version:
      * the eight ``ss*`` students are put in training mode; the old code
        called ``s11.train()``, a model from the 4-student experiment;
      * the result of ``.to(device)`` on the targets is kept (``Tensor.to``
        is not in-place);
      * the class labels from the loader are no longer moved to the device,
        since they were discarded immediately.

    Args:
        epoch: Zero-based epoch index; only used for the ``Epoch: N`` banner.
    """
    print('\nEpoch: %d' % (epoch+1))
    students = (ss11, ss22, ss33, ss44, ss55, ss66, ss77, ss88)
    optimizers = (optimizer1, optimizer2, optimizer3, optimizer4,
                  optimizer5, optimizer6, optimizer7, optimizer8)
    labels = ("SS11", "SS22", "SS33", "SS44", "SS55", "SS66", "SS77", "SS88")
    width = 64  # number of target columns assigned to each student

    for student in students:
        student.train()

    running = [0.0] * len(students)
    for batch_idx, (inputs, _) in enumerate(trainloader):
        inputs = inputs.to(device)
        first_row = batch_idx * batch_size
        targets = S1DenseTrain[first_row:first_row + batch_size].to(device)

        for student in students:
            student.zero_grad()
        outputs = [student(inputs) for student in students]
        losses = [
            criterion(out, targets[:, k * width:(k + 1) * width])
            for k, out in enumerate(outputs)
        ]
        for loss in losses:
            loss.backward()
        for opt in optimizers:
            opt.step()

        for k, loss in enumerate(losses):
            running[k] += loss.item()
        if batch_idx % 10 == 0:
            for label, total in zip(labels, running):
                print("Loss %s: " % label, total/(batch_idx+1))
def test8(epoch):
    ss11.eval()
    
    test_loss1 = 0
    test_loss2 = 0
    test_loss3 = 0
    test_loss4= 0
    test_loss5 = 0
    test_loss6 = 0
    test_loss7 = 0
    test_loss8= 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            #code here
            targets =   s1DenseTest[batch_idx*batch_size:(batch_idx*batch_size)+batch_size]
            targets.to(device)
            output1 = ss11(inputs)
            output2 = ss22(inputs)
            output3 = ss33(inputs)
            output4 = ss44(inputs)
            output5 = ss55(inputs)
            output6 = ss66(inputs)
            output7 = ss77(inputs)
            output8 = ss88(inputs)
            loss1 = criterion(output1, targets[:,:64])
            loss2 = criterion(output2, targets[:,64:128])
            loss3 = criterion(output3, targets[:,128:192])
            loss4 = criterion(output4, targets[:,192:256])
            loss5 = criterion(output5, targets[:,256:320])
            loss6 = criterion(output6, targets[:,320:384])
            loss7 = criterion(output7, targets[:,384:448])
            loss8 = criterion(output8, targets[:,448:512])
            test_loss1 += loss1.item()
            test_loss2 += loss2.item()
            test_loss3 += loss3.item()
            test_loss4 += loss4.item()

            test_loss5 += loss5.item()
            test_loss6 += loss6.item()
            test_loss7 += loss7.item()
            test_loss8 += loss8.item()
            if(batch_idx % 20 == 0):
              print(" Loss SS11: ", test_loss1/(batch_idx+1))
              print(" Loss SS22: ", test_loss2/(batch_idx+1))
              print(" Loss SS33: ", test_loss3/(batch_idx+1))
              print(" Loss SS44: ", test_loss4/(batch_idx+1))
              print(" Loss SS55: ", test_loss5/(batch_idx+1))
              print(" Loss SS66: ", test_loss6/(batch_idx+1))
              print(" Loss SS77: ", test_loss7/(batch_idx+1))
              print(" Loss SS88: ", test_loss8/(batch_idx+1))

In [None]:
# Distil into the eight students for 60 epochs, reporting test-split losses
# after each epoch.
start_epoch = 0
best_acc = 0  # NOTE(review): assigned here but never read or updated in this cell
for epoch in range(start_epoch, start_epoch+60):
    train8(epoch)
    print("Validation: ")
    test8(epoch)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Loss SS33:  0.030099085844157603
Loss SS44:  0.03940099960817096
Loss SS55:  0.041487391828499795
Loss SS66:  0.03372475795943261
Loss SS77:  0.04204665821836428
Loss SS88:  0.037911063174495614
Loss SS11:  0.03513001892576342
Loss SS22:  0.033769242251826044
Loss SS33:  0.03011040861934349
Loss SS44:  0.03943429851992766
Loss SS55:  0.0414892865479587
Loss SS66:  0.033750974361847465
Loss SS77:  0.04206380065056749
Loss SS88:  0.037917338334897216
Loss SS11:  0.03512453000512146
Loss SS22:  0.03378360563494863
Loss SS33:  0.030101229991863532
Loss SS44:  0.03944364233608663
Loss SS55:  0.04150627100264649
Loss SS66:  0.03372454839937588
Loss SS77:  0.04205544401694388
Loss SS88:  0.03789606344616471
Loss SS11:  0.03513397019790074
Loss SS22:  0.033797359417000176
Loss SS33:  0.03011147310236146
Loss SS44:  0.039445145759120975
Loss SS55:  0.041515296789642754
Loss SS66:  0.033730129748936785
Loss SS77:  0.042049729608158

In [None]:
class Net4(nn.Module):
    """Eight-branch ensemble: concatenated student outputs feed a 10-way head.

    The same input is passed through ``ss11`` .. ``ss88``; their outputs are
    joined along dim 1 and mapped by ``Linear(512, 10)``, so the eight
    outputs together must total 512 features per sample.
    """

    def __init__(self, ss11,ss22,ss33,ss44,ss55,ss66,ss77,ss88):
        super().__init__()
        # Attribute names double as state-dict prefixes; keep them stable.
        self.ss11, self.ss22, self.ss33, self.ss44 = ss11, ss22, ss33, ss44
        self.ss55, self.ss66, self.ss77, self.ss88 = ss55, ss66, ss77, ss88
        self.classifier = nn.Linear(512, 10)

    def forward(self, x):
        """Return class logits of shape (batch, 10) for the batch ``x``."""
        branches = (
            self.ss11, self.ss22, self.ss33, self.ss44,
            self.ss55, self.ss66, self.ss77, self.ss88,
        )
        fused = torch.cat([branch(x) for branch in branches], 1)
        return self.classifier(fused)
# Combine the eight students behind one 10-way classifier head, move the
# ensemble to `device`, and print its layer summary.
net8students = Net4(
    ss11=ss11, ss22=ss22, ss33=ss33, ss44=ss44,
    ss55=ss55, ss66=ss66, ss77=ss77, ss88=ss88,
).to(device)
summary(net8students, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 32, 32]             896
       BatchNorm2d-2           [-1, 32, 32, 32]              64
              ReLU-3           [-1, 32, 32, 32]               0
            Conv2d-4           [-1, 32, 32, 32]           9,248
       BatchNorm2d-5           [-1, 32, 32, 32]              64
              ReLU-6           [-1, 32, 32, 32]               0
         MaxPool2d-7           [-1, 32, 16, 16]               0
            Conv2d-8           [-1, 32, 16, 16]           9,248
       BatchNorm2d-9           [-1, 32, 16, 16]              64
             ReLU-10           [-1, 32, 16, 16]               0
           Conv2d-11           [-1, 32, 16, 16]           9,248
      BatchNorm2d-12           [-1, 32, 16, 16]              64
             ReLU-13           [-1, 32, 16, 16]               0
        MaxPool2d-14             [-1, 3

In [None]:
# Freeze every convolution in the ensemble. Only nn.Conv2d modules are
# touched here, so BatchNorm and Linear parameters keep requires_grad as-is.
for m in net8students.modules():
    if not isinstance(m, nn.Conv2d):
        continue
    m.weight.requires_grad_(False)
    if m.bias is not None:
        m.bias.requires_grad_(False)

net8students = net8students.to(device)
summary(net8students, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 32, 32]             896
       BatchNorm2d-2           [-1, 32, 32, 32]              64
              ReLU-3           [-1, 32, 32, 32]               0
            Conv2d-4           [-1, 32, 32, 32]           9,248
       BatchNorm2d-5           [-1, 32, 32, 32]              64
              ReLU-6           [-1, 32, 32, 32]               0
         MaxPool2d-7           [-1, 32, 16, 16]               0
            Conv2d-8           [-1, 32, 16, 16]           9,248
       BatchNorm2d-9           [-1, 32, 16, 16]              64
             ReLU-10           [-1, 32, 16, 16]               0
           Conv2d-11           [-1, 32, 16, 16]           9,248
      BatchNorm2d-12           [-1, 32, 16, 16]              64
             ReLU-13           [-1, 32, 16, 16]               0
        MaxPool2d-14             [-1, 3

In [None]:
# Classification objective and optimiser for the 8-student ensemble.
# NOTE: both names are notebook globals that earlier cells also bind; the
# train/test helpers pick up whichever binding is current when they run.
optimizer = optim.Adam(params=net8students.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

def train81(epoch):
    """Run one training epoch of the global ``net8students`` ensemble.

    The function reads ``net8students``, ``trainloader``, ``criterion``,
    ``optimizer`` and ``device`` from the notebook's global scope at call
    time. Every 200 batches it prints the running accuracy (in percent) and
    the mean loss accumulated since the start of the epoch.

    Args:
        epoch: Zero-based epoch index; only used for the ``Epoch: N`` banner.
    """
    print('\nEpoch: %d' % (epoch+1))
    net8students.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        # Plain tensors carry autograd state themselves, so the deprecated
        # ``Variable`` wrappers (no-ops since PyTorch 0.4) are not needed.
        inputs, targets = inputs.to(device), targets.to(device)
        net8students.zero_grad()
        outputs = net8students(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx % 200 == 0:
            print("Accuracy : ",100.*correct/total," Loss : ", train_loss/(batch_idx+1))
def test82(epoch):
    net8students.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net8students(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            if(batch_idx % 20 == 0):
              print("Accuracy : ",100.*correct/total," Loss : ", test_loss/(batch_idx+1))

In [None]:
# Fine-tune the 8-student ensemble for 10 epochs, printing test-set metrics
# after each one.
start_epoch = 0
best_acc = 0  # NOTE(review): assigned here but never read or updated in this cell
for epoch in range(start_epoch, start_epoch+10):
    train81(epoch)
    print("Validation: ")
    test82(epoch)


Epoch: 1
Accuracy :  10.0  Loss :  2.6411001682281494
Accuracy :  64.66666666666667  Loss :  1.5341597134201088
Accuracy :  73.356608478803  Loss :  1.1754024572800519
Validation: 
Accuracy :  87.0  Loss :  0.6132931709289551
Accuracy :  82.9047619047619  Loss :  0.6218314185028985
Accuracy :  82.34146341463415  Loss :  0.6272539547303828
Accuracy :  82.45901639344262  Loss :  0.6230301412402607
Accuracy :  82.46913580246914  Loss :  0.6209356645007192

Epoch: 2
Accuracy :  81.0  Loss :  0.6154868602752686
Accuracy :  82.96019900497512  Loss :  0.5849587328694946
Accuracy :  83.24438902743142  Loss :  0.5587492741254202
Validation: 
Accuracy :  87.0  Loss :  0.4913302958011627
Accuracy :  83.28571428571429  Loss :  0.5138541474228814
Accuracy :  82.92682926829268  Loss :  0.5211354341448807
Accuracy :  83.0  Loss :  0.515239635940458
Accuracy :  82.96296296296296  Loss :  0.5131170065314682

Epoch: 3
Accuracy :  83.0  Loss :  0.45704376697540283
Accuracy :  83.71641791044776  Loss :  