<a href="https://colab.research.google.com/github/Alafiade/Implementing-Knowlegde-Distillation-on-CNN-models/blob/main/Implementation_of_KD_on_CNN_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Importing Dependencies

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets



Data Preprocessing for CIFAR10


In [None]:
# ImageNet per-channel mean/std, matching the ImageNet-pretrained torchvision weights used below.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: random augmentation, then tensor conversion and normalisation.
transforms_cifar = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Evaluation pipeline: no random augmentation, so test accuracy is deterministic and is measured on
# the real test images rather than on randomly cropped/flipped/rotated/jittered copies of them.
transforms_cifar_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_cifar)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar_test)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:06<00:00, 26.9MB/s] 


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
# Both loaders share batch size and worker count; only the training split is shuffled.
loader_kwargs = dict(batch_size=64, num_workers=2)
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Use the first CUDA GPU when one is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

Loading the pretrained Teacher and Student models

In [None]:
import torchvision.models as models
#Importing Resnet50 as the Teacher model
teacher_model = models.resnet50(weights=True)
teacher_model.fc = nn.Linear(teacher_model.fc.in_features,10)
device = torch.device('cuda:0' if torch.cuda.is_available()else 'cpu')
teacher_model = teacher_model.to(device)
print(teacher_model)
#Importing MobilenetV2 as the Student model
student_model = models.mobilenet_v2(weights=True)
student_model.classifier[1] = nn.Linear(student_model.classifier[1].in_features,10)
student_model.to(device)
print(student_model)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 177MB/s]
Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

100%|██████████| 13.6M/13.6M [00:00<00:00, 202MB/s]

MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=




Model Training

In [None]:
def train(model, train_loader, epochs, learning_rate, weight_decay, momentum, device):
    """Train ``model`` with cross-entropy loss and SGD, printing the mean loss of every epoch.

    Args:
        model: ``nn.Module`` producing class logits. It is moved to ``device`` and left in train mode.
        train_loader: Iterable of ``(inputs, labels)`` batches.
        epochs: Number of passes over ``train_loader``.
        learning_rate: SGD learning rate.
        weight_decay: SGD L2 weight decay.
        momentum: SGD momentum.
        device: Device that the model and every batch are moved to.

    Returns:
        List with the mean batch loss of each epoch (``nan`` for an epoch with no batches).
    """
    # Move the model before building the optimizer so the optimizer holds the on-device parameters
    # (``test`` and ``train_knowledge_distillation`` also move their models themselves).
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay, momentum=momentum)
    model.train()

    epoch_losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        # Guard against an empty loader instead of raising ZeroDivisionError.
        epoch_loss = running_loss / num_batches if num_batches else float('nan')
        epoch_losses.append(epoch_loss)
        print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss}')
    return epoch_losses

def test(model, test_loader, device):
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy

Training the teacher model and evaluating it on the CIFAR-10 test set

In [None]:
 # Teacher: start from the ImageNet-pretrained ResNet-50 (as in the model-loading cell) rather than a
# randomly initialised network, so the teacher can be stronger than the MobileNetV2 student it teaches.
# Rebuilding it here keeps the cell idempotent: a re-run restarts from the pretrained weights instead
# of continuing to train an already fine-tuned model.
teacher_model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
teacher_model.fc = nn.Linear(teacher_model.fc.in_features, 10)
teacher_model = teacher_model.to(device)
train(teacher_model, train_loader, epochs=40, learning_rate=0.001, device=device, weight_decay=0.01, momentum=0.9)
test_accuracy = test(teacher_model, test_loader, device)

# Fresh ImageNet-pretrained student for the cross-entropy-only baseline, rebuilt so re-runs start
# from untrained (not previously fine-tuned) weights.
student_model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
student_model.classifier[1] = nn.Linear(student_model.classifier[1].in_features, 10)
student_model = student_model.to(device)

Epoch 1/40, Loss: 2.1750342083709016
Epoch 2/40, Loss: 1.8768999825048325
Epoch 3/40, Loss: 1.7331892748927826
Epoch 4/40, Loss: 1.6295619625264726
Epoch 5/40, Loss: 1.5488635874770182
Epoch 6/40, Loss: 1.4814987734455587
Epoch 7/40, Loss: 1.4205073764561997
Epoch 8/40, Loss: 1.3565544781020231
Epoch 9/40, Loss: 1.3085187279507327
Epoch 10/40, Loss: 1.2568118306986815
Epoch 11/40, Loss: 1.2033406598946017
Epoch 12/40, Loss: 1.1651286495006299
Epoch 13/40, Loss: 1.128161735623084
Epoch 14/40, Loss: 1.0926517094187724
Epoch 15/40, Loss: 1.0477832875898123
Epoch 16/40, Loss: 1.0019713789605729
Epoch 17/40, Loss: 0.976882233491639
Epoch 18/40, Loss: 0.9403345416421476
Epoch 19/40, Loss: 0.9152753287569031
Epoch 20/40, Loss: 0.8896232517174137
Epoch 21/40, Loss: 0.8763343661718661
Epoch 22/40, Loss: 0.8565900106259319
Epoch 23/40, Loss: 0.8404212261709716
Epoch 24/40, Loss: 0.8227220418508095
Epoch 25/40, Loss: 0.8079972909691998
Epoch 26/40, Loss: 0.8106145480328508
Epoch 27/40, Loss: 0.78

MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=

Creating a new student model and initializing it

In [None]:
# Second student, trained later with knowledge distillation. Its backbone comes from the same ImageNet
# checkpoint as `student_model`. Seeding makes the randomly initialised 10-way classifier head
# reproducible; note it also seeds torch's global RNG for everything that runs afterwards.
torch.manual_seed(42)
new_student_model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
new_student_model.classifier[1] = nn.Linear(new_student_model.classifier[1].in_features, 10)


Comparing the first-layer weight norms of the student model and the new student model

In [None]:
# features[0][0] is the first convolution of each MobileNetV2; equal norms indicate both students
# were loaded from the same pretrained weights.
for label, net in (('student_model', student_model), ('new_student_model', new_student_model)):
    first_conv_weight = net.features[0][0].weight
    print(f'Norm of the 1st layer of {label}', torch.norm(first_conv_weight).item())

Norm of the 1st layer of student_model 6.710669994354248
Norm of the 1st layer of new_student_model 6.710669040679932


Comparing the Teacher and Student parameter counts

In [None]:
def _count_parameters(net):
    """Total number of elements across every parameter tensor of ``net``."""
    return sum(param.numel() for param in net.parameters())

# Thousands-separated strings, e.g. '23,528,522'.
total_params_teacher = f'{_count_parameters(teacher_model):,}'
print(f'Teacher_model parameters:{total_params_teacher}')
total_params_student = f'{_count_parameters(student_model):,}'
print(f'Student_model parameters:{total_params_student}')

Teacher_model parameters:23,528,522
Student_model parameters:2,236,682


Training the student model without a teacher (cross-entropy only) and evaluating it

In [None]:
# Cross-entropy-only baseline: the same optimiser settings as the teacher-training cell.
baseline_hparams = {'epochs': 40, 'learning_rate': 0.001, 'weight_decay': 0.01, 'momentum': 0.9}
train(student_model, train_loader, device=device, **baseline_hparams)
test_accuracy_student_ce = test(student_model, test_loader, device)

Epoch 1/40, Loss: 1.3361339163597283
Epoch 2/40, Loss: 0.9860750785111771
Epoch 3/40, Loss: 0.878702763706217
Epoch 4/40, Loss: 0.8131215067775658
Epoch 5/40, Loss: 0.7574517800832343
Epoch 6/40, Loss: 0.7314097163317453
Epoch 7/40, Loss: 0.7065221932156921
Epoch 8/40, Loss: 0.6929305765177588
Epoch 9/40, Loss: 0.6801436473341549
Epoch 10/40, Loss: 0.6672952230400442
Epoch 11/40, Loss: 0.6630167188241963
Epoch 12/40, Loss: 0.668178392150213
Epoch 13/40, Loss: 0.6678739740415607
Epoch 14/40, Loss: 0.6763368287812108
Epoch 15/40, Loss: 0.671323751282814
Epoch 16/40, Loss: 0.6771041510431358
Epoch 17/40, Loss: 0.6838685505073089
Epoch 18/40, Loss: 0.6940080307023909
Epoch 19/40, Loss: 0.6991380653189271
Epoch 20/40, Loss: 0.7064757975166106
Epoch 21/40, Loss: 0.7032052334540945
Epoch 22/40, Loss: 0.7173227689531453
Epoch 23/40, Loss: 0.7266162367504271
Epoch 24/40, Loss: 0.7303748961604769
Epoch 25/40, Loss: 0.7339914959409962
Epoch 26/40, Loss: 0.7403880465975807
Epoch 27/40, Loss: 0.747

Comparing the accuracy of the Teacher and Student models

In [None]:
# One print call, one line per model.
print(
    f'Teacher accuracy: {test_accuracy:.2f}%',
    f'Student accuracy: {test_accuracy_student_ce:.2f}%',
    sep='\n',
)

Teacher accuracy: 74.19%
Student accuracy: 71.19%


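Implementing Knowledge Distillation

The student minimises a weighted sum of two terms. The first is the temperature-scaled KL divergence from the teacher's softened predictions. The second is the ordinary cross-entropy on the hard labels: $\mathcal{L} = w_{\text{soft}}\,T^2\,\mathrm{KL}\big(\mathrm{softmax}(z_t/T)\,\|\,\mathrm{softmax}(z_s/T)\big) + w_{\text{ce}}\,\mathrm{CE}(z_s, y)$, where $z_t$ and $z_s$ are the teacher and student logits.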

In [None]:
def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device, momentum=0.0, weight_decay=0.0):
    """Train ``student`` on a mix of softened ``teacher`` predictions and the hard labels.

    Per batch the loss is
    ``soft_target_loss_weight * T**2 * KL(softmax(t/T) || softmax(s/T)) + ce_loss_weight * CE(s, labels)``,
    with ``t``/``s`` the teacher/student logits. The ``T**2`` factor is the usual scaling from
    Hinton et al.'s distillation paper; it keeps soft-target gradients comparable to the hard-label ones.

    Args:
        teacher: Frozen model. It runs in eval mode under ``torch.no_grad()`` and is never optimised.
        student: Model being trained. It is left in train mode.
        train_loader: Iterable of ``(inputs, labels)`` batches.
        epochs: Number of passes over ``train_loader``.
        learning_rate: SGD learning rate for the student.
        T: Softmax temperature applied to both teacher and student logits in the soft-target term.
        soft_target_loss_weight: Weight of the distillation (KL) term.
        ce_loss_weight: Weight of the hard-label cross-entropy term.
        device: Device that both models and every batch are moved to.
        momentum: SGD momentum. The default 0.0 is plain SGD; pass the baseline's value (0.9) for a
            like-for-like comparison with ``train``.
        weight_decay: SGD L2 weight decay (default 0.0).

    Returns:
        List with the mean total loss of each epoch (``nan`` for an epoch with no batches).
    """
    ce_loss = nn.CrossEntropyLoss()

    # Ensure models are on the correct device before the optimizer captures the student's parameters.
    teacher.to(device)
    student.to(device)
    optimizer = optim.SGD(student.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)

    teacher.eval()  # e.g. BatchNorm uses running statistics; the teacher is not being trained
    student.train()

    epoch_losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            with torch.no_grad():
                teacher_logits = teacher(inputs)
            student_logits = student(inputs)

            # Keep both distributions in log-space. Taking log() of a softmax underflows to -inf for a
            # very confident teacher, and 0 * -inf = NaN would poison the loss.
            teacher_log_probs = nn.functional.log_softmax(teacher_logits / T, dim=-1)
            student_log_probs = nn.functional.log_softmax(student_logits / T, dim=-1)

            # KL(teacher || student), summed over classes and averaged over the batch, scaled by T^2.
            soft_target_loss = nn.functional.kl_div(
                student_log_probs, teacher_log_probs, reduction='batchmean', log_target=True
            ) * (T ** 2)

            label_loss = ce_loss(student_logits, labels)

            loss = soft_target_loss_weight * soft_target_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        # Guard against an empty loader instead of raising ZeroDivisionError.
        epoch_loss = running_loss / num_batches if num_batches else float('nan')
        epoch_losses.append(epoch_loss)
        print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}')
    return epoch_losses

# Move models to the same device BEFORE training
teacher_model.to(device)
new_student_model.to(device)

# Training
train_knowledge_distillation(teacher=teacher_model,
                             student=new_student_model,
                             train_loader=train_loader,
                             epochs=20,
                             learning_rate=0.001,
                             T=4,
                             soft_target_loss_weight=0.25,
                             ce_loss_weight=0.75,
                             device=device)

# Testing
test_accuracy_student_ce_and_kd = test(new_student_model, test_loader, device)


Epoch 1/20, Loss: 0.7445
Epoch 2/20, Loss: 0.7393
Epoch 3/20, Loss: 0.7311
Epoch 4/20, Loss: 0.7299
Epoch 5/20, Loss: 0.7237
Epoch 6/20, Loss: 0.7165
Epoch 7/20, Loss: 0.7166
Epoch 8/20, Loss: 0.7226
Epoch 9/20, Loss: 0.7051
Epoch 10/20, Loss: 0.7052
Epoch 11/20, Loss: 0.6973
Epoch 12/20, Loss: 0.6948
Epoch 13/20, Loss: 0.6894
Epoch 14/20, Loss: 0.6878
Epoch 15/20, Loss: 0.6778
Epoch 16/20, Loss: 0.6799
Epoch 17/20, Loss: 0.6740
Epoch 18/20, Loss: 0.6749
Epoch 19/20, Loss: 0.6681
Epoch 20/20, Loss: 0.6652
Test Accuracy: 75.78%


Comparing Teacher and Student accuracy with and without Knowledge Distillation

In [None]:
# Summary: teacher, CE-only student baseline, and the student trained with CE + distillation.
print(
    f'Teacher accuracy: {test_accuracy:.2f}%',
    f'Student accuracy without the teacher: {test_accuracy_student_ce:.2f}%',
    f'Student accuracy with ce_and_kd : {test_accuracy_student_ce_and_kd:.2f}%',
    sep='\n',
)