In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
transforms_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Loading the CIFAR-10 dataset:
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_cifar)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar)


Files already downloaded and verified
Files already downloaded and verified


In [None]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)


In [None]:
teacher_model = torchvision.models.resnet18(pretrained=True)
teacher_model.fc = nn.Linear(teacher_model.fc.in_features, 10)  # Modify the last layer to output 10 classes as CIFAR 10 has 10 classes
teacher_model.eval()




ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [None]:
class LightNN(nn.Module):
    def __init__(self, num_classes=10):
        super(LightNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


In [None]:
student_model = LightNN()
teacher_model.to(device)
student_model.to(device)



LightNN(
  (features): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=1024, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=256, out_features=10, bias=True)
  )
)

In [None]:
def train(model, train_loader, epochs, learning_rate, device):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy


In [None]:
torch.manual_seed(12)
#test accuracy of teacher model
nn_light = LightNN(num_classes=10).to(device)
train(teacher_model,train_loader,10,0.001,device)
      # Instantiate the lightweight network:
torch.manual_seed(12)
nn_light = LightNN(num_classes=10).to(device)

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 1/10, Loss: 0.8913398543587121
Epoch 2/10, Loss: 0.5771907635052186
Epoch 3/10, Loss: 0.44804791869867183
Epoch 4/10, Loss: 0.3595827208531787
Epoch 5/10, Loss: 0.27550305530924324
Epoch 6/10, Loss: 0.21650786953204124
Epoch 7/10, Loss: 0.17429714370757113
Epoch 8/10, Loss: 0.14862033497075292
Epoch 9/10, Loss: 0.12260043871639024
Epoch 10/10, Loss: 0.10782287292696936


In [None]:
test_teacher_acc = test(teacher_model,test_loader,device) # After training the teacher model i.e ResNet18

Test Accuracy: 80.30%


In [None]:
train(student_model, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_light_ce = test(student_model, test_loader, device)

Epoch 1/10, Loss: 1.4839133606542407
Epoch 2/10, Loss: 1.1422722562194785
Epoch 3/10, Loss: 1.0060766395705436
Epoch 4/10, Loss: 0.9065643340120535
Epoch 5/10, Loss: 0.8275762406151618
Epoch 6/10, Loss: 0.7629977048510481
Epoch 7/10, Loss: 0.6959948579368689
Epoch 8/10, Loss: 0.6385537817350129
Epoch 9/10, Loss: 0.5796255488377398
Epoch 10/10, Loss: 0.5325781712141793
Test Accuracy: 70.94%


In [None]:
new_nn_light = LightNN(num_classes=10).to(device)

def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            with torch.no_grad():
                teacher_logits = teacher(inputs)

            # Forward pass with the student model
            student_logits = student(inputs)

            #Soften the student logits by applying softmax first and log() second
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
            soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size(0) * (T**2)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")


In [None]:

# Apply ``train_knowledge_distillation`` with a temperature of 2. Arbitrarily set the weights to 0.75 for CE and 0.25 for distillation loss.
train_knowledge_distillation(teacher=teacher_model, student=new_nn_light, train_loader=train_loader, epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_kd = test(new_nn_light, test_loader, device)


Epoch 1/10, Loss: 2.255328537862929
Epoch 2/10, Loss: 1.7522254224933322
Epoch 3/10, Loss: 1.547888114324311
Epoch 4/10, Loss: 1.3963010968149776
Epoch 5/10, Loss: 1.2733512610730613
Epoch 6/10, Loss: 1.172116348810513
Epoch 7/10, Loss: 1.0820440149977995
Epoch 8/10, Loss: 1.005013661463852
Epoch 9/10, Loss: 0.931414150063644
Epoch 10/10, Loss: 0.8675993246495571
Test Accuracy: 71.01%


In [None]:

# Compare the student test accuracy with and without the teacher, after distillation
print(f"Teacher accuracy: {test(teacher_model,test_loader,device):.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")

Test Accuracy: 80.30%
Teacher accuracy: 80.30%
Student accuracy without teacher: 70.94%
Student accuracy with CE + KD: 71.01%


In [None]:
num_params = sum(p.numel() for p in student_model.parameters() if p.requires_grad)
print("Number of trainable parameters:", num_params)


Number of trainable parameters: 267738


### Choosing a less complex student model

In [None]:
class SimpleNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(8, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(8 * 8 * 8, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


student_model2 = SimpleNN()
student_model2.to(device)

SimpleNN(
  (features): Sequential(
    (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=512, out_features=64, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=64, out_features=10, bias=True)
  )
)

In [None]:
num_params = sum(p.numel() for p in student_model2.parameters() if p.requires_grad)
print("Number of trainable parameters:", num_params)


Number of trainable parameters: 34290


### Have reduced the training parameters 8-fold


In [None]:
train(student_model2, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_light_ce2 = test(student_model, test_loader, device) # have tested the earlier student model here, have tested the newer trained model 2 blocks ahead

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 1/10, Loss: 1.6995985050640448
Epoch 2/10, Loss: 1.4124027164390935
Epoch 3/10, Loss: 1.3158904317089968
Epoch 4/10, Loss: 1.2542868084309962
Epoch 5/10, Loss: 1.2044663305782601
Epoch 6/10, Loss: 1.1705441418511178
Epoch 7/10, Loss: 1.1365708762117663
Epoch 8/10, Loss: 1.1128534750865244
Epoch 9/10, Loss: 1.0901222395165193
Epoch 10/10, Loss: 1.0715871628592997
Test Accuracy: 70.94%


In [None]:
new_nn_light2 = SimpleNN(num_classes=10).to(device)
# Apply ``train_knowledge_distillation`` with a temperature of 2. Arbitrarily set the weights to 0.75 for CE and 0.25 for distillation loss.
train_knowledge_distillation(teacher=teacher_model, student=new_nn_light2, train_loader=train_loader, epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_kd2 = test(new_nn_light2, test_loader, device)


  self.pid = os.fork()
  self.pid = os.fork()


Epoch 1/10, Loss: 2.7138322710686023
Epoch 2/10, Loss: 2.219245237462661
Epoch 3/10, Loss: 2.0599898176120064
Epoch 4/10, Loss: 1.949792299429169
Epoch 5/10, Loss: 1.8652809367460363
Epoch 6/10, Loss: 1.8011495612771309
Epoch 7/10, Loss: 1.7504935206659615
Epoch 8/10, Loss: 1.713055707609562
Epoch 9/10, Loss: 1.6756213196098346
Epoch 10/10, Loss: 1.6412708597719823
Test Accuracy: 62.50%


In [None]:
print(f"Teacher accuracy: {test(teacher_model,test_loader,device):.2f}%")
print(f"Student accuracy without teacher: {test(student_model2, test_loader, device):.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd2:.2f}%")

Test Accuracy: 80.30%
Teacher accuracy: 80.30%
Test Accuracy: 62.46%
Student accuracy without teacher: 62.46%
Student accuracy with CE + KD: 62.50%


### Testing with the fractions of CE loss and distillation loss on the earlier student architecture. Originally, the split is 0.75 CE and 0.25 DE

In [None]:
new_student_ = LightNN(num_classes=10).to(device)
newer_student_ = LightNN(num_classes=10).to(device)

In [None]:
train_knowledge_distillation(teacher=teacher_model, student=new_student_, train_loader=train_loader, epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.5, ce_loss_weight=0.5, device=device)
test_accuracy_light_ce_and_kd = test(new_student_, test_loader, device)


  self.pid = os.fork()
  self.pid = os.fork()


Epoch 1/10, Loss: 3.064442022987034
Epoch 2/10, Loss: 2.3497703560173053
Epoch 3/10, Loss: 2.0468436199076034
Epoch 4/10, Loss: 1.8678961894701205
Epoch 5/10, Loss: 1.706568258192838
Epoch 6/10, Loss: 1.5891584795149392
Epoch 7/10, Loss: 1.4777253796072567
Epoch 8/10, Loss: 1.379095415325116
Epoch 9/10, Loss: 1.2905288721289476
Epoch 10/10, Loss: 1.2057416649425732
Test Accuracy: 70.67%


In [None]:
print(f"Teacher accuracy: {test(teacher_model,test_loader,device):.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE - 0.5 + KD - 0.5: {test_accuracy_light_ce_and_kd:.2f}%")

Test Accuracy: 80.30%
Teacher accuracy: 80.30%
Student accuracy without teacher: 70.94%
Student accuracy with CE - 0.5 + KD - 0.5: 70.67%


In [None]:
train_knowledge_distillation(teacher=teacher_model, student=newer_student_, train_loader=train_loader, epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.75, ce_loss_weight=0.25, device=device)
test_accuracy_light_ce_and_kd_new = test(newer_student_, test_loader, device)


Epoch 1/10, Loss: 3.7721775011028473
Epoch 2/10, Loss: 2.8648899132028567
Epoch 3/10, Loss: 2.491447638977519
Epoch 4/10, Loss: 2.2353563915433177
Epoch 5/10, Loss: 2.038106931749817
Epoch 6/10, Loss: 1.8922760700020949
Epoch 7/10, Loss: 1.7563159075539436
Epoch 8/10, Loss: 1.6474022938467352
Epoch 9/10, Loss: 1.5314176097855239
Epoch 10/10, Loss: 1.448274991701326
Test Accuracy: 71.00%


In [None]:
print(f"Teacher accuracy: {test(teacher_model,test_loader,device):.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE - 0.25 + KD - 0.75: {test_accuracy_light_ce_and_kd_new:.2f}%")

Test Accuracy: 80.30%
Teacher accuracy: 80.30%
Student accuracy without teacher: 70.94%
Student accuracy with CE - 0.25 + KD - 0.75: 71.00%


### Effect of the parameter T. Earlier it was kept = 2

In [None]:
new__student = LightNN(num_classes=10).to(device)
newer__student = LightNN(num_classes=10).to(device)

In [None]:
train_knowledge_distillation(teacher=teacher_model, student=new__student, train_loader=train_loader, epochs=10, learning_rate=0.001, T=1.5, soft_target_loss_weight=0.5, ce_loss_weight=0.5, device=device)
test_accuracy_light_ce_and_kd_ = test(new__student, test_loader, device)


  self.pid = os.fork()
  self.pid = os.fork()


Epoch 1/10, Loss: 2.166744774869641
Epoch 2/10, Loss: 1.6616742272511162
Epoch 3/10, Loss: 1.460659090820176
Epoch 4/10, Loss: 1.3118485248912022
Epoch 5/10, Loss: 1.1865622020133622
Epoch 6/10, Loss: 1.0955751151075144
Epoch 7/10, Loss: 1.0138910421934884
Epoch 8/10, Loss: 0.9373732555247939
Epoch 9/10, Loss: 0.8705272830050924
Epoch 10/10, Loss: 0.8093018786376699
Test Accuracy: 70.65%


In [None]:
print(f"Teacher accuracy: {test(teacher_model,test_loader,device):.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD + T = 1.5: {test_accuracy_light_ce_and_kd_:.2f}%")

Test Accuracy: 80.30%
Teacher accuracy: 80.30%
Student accuracy without teacher: 70.94%
Student accuracy with CE + KD + T = 1.5: 70.65%


In [None]:
train_knowledge_distillation(teacher=teacher_model, student=newer__student, train_loader=train_loader, epochs=10, learning_rate=0.001, T=1.01, soft_target_loss_weight=0.5, ce_loss_weight=0.5, device=device)
test_accuracy_light_ce_and_kd__ = test(newer__student, test_loader, device)


  self.pid = os.fork()
  self.pid = os.fork()


Epoch 1/10, Loss: nan
Epoch 2/10, Loss: nan
Epoch 3/10, Loss: nan
Epoch 4/10, Loss: nan
Epoch 5/10, Loss: nan
Epoch 6/10, Loss: nan
Epoch 7/10, Loss: nan
Epoch 8/10, Loss: nan
Epoch 9/10, Loss: nan
Epoch 10/10, Loss: nan
Test Accuracy: 70.44%


In [None]:
print(f"Teacher accuracy: {test(teacher_model,test_loader,device):.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD + T = 1.01: {test_accuracy_light_ce_and_kd__:.2f}%")

Test Accuracy: 80.30%
Teacher accuracy: 80.30%
Student accuracy without teacher: 70.94%
Student accuracy with CE + KD + T = 1.01: 70.44%
