<a href="https://colab.research.google.com/github/Alafiade/Implementing-Knowledge-Distillation/blob/main/Knowledge_Distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

IMPORTING DEPENDENCIES

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets





In [None]:
# Select the compute device used by every later cell.
# `torch.accelerator` was introduced in PyTorch 2.6; on older installs fall
# back to the classic CUDA availability check so the notebook still runs.
# (torch itself is already imported in the dependencies cell above.)
if hasattr(torch, 'accelerator'):
    device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else 'cpu'
else:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

Using cudadevice


DATA LOADING AND PREPROCESSING

In [None]:
# Convert images to tensors and normalise each RGB channel.
# The statistics are the standard ImageNet channel mean / std; the red-channel
# std was previously written as 0.299, a transposition of the usual 0.229.
# NOTE(review): re-running with the corrected value will shift the saved
# loss / accuracy outputs slightly.
transforms_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# CIFAR-10 train / test splits, downloaded into ./data on first use.
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_cifar)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar)

100%|██████████| 170M/170M [00:06<00:00, 26.3MB/s]


### SETTING DATA LOADERS

In [None]:
# Mini-batch iterators over the two splits: the training split is reshuffled
# on every pass, the test split keeps its stored order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
)

DEFINING MODEL

In [None]:
class DeepNN(nn.Module):
    """Larger convolutional classifier, used as the teacher network.

    Three 3x3 convolutions (3 -> 128 -> 64 -> 32 channels) with two 2x2
    max-poolings feed a two-layer fully connected head.  The head expects
    2048 flattened features, i.e. 32 channels x 8 x 8, which is what the
    convolutional stack produces for 3x32x32 inputs.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Layers are created in the same order as they are applied so that
        # seeded weight initialisation stays reproducible.
        conv_stack = [
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.features = nn.Sequential(*conv_stack)

        head = [
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits for a batch of images."""
        feature_maps = self.features(x)
        flat = torch.flatten(feature_maps, 1)  # keep the batch dim, merge the rest
        return self.classifier(flat)


#Lightweight neural network class to be used as student:
class LightNN(nn.Module):
    """Small convolutional classifier, used as the student network.

    Two 3x3 convolutions (3 -> 16 -> 16 channels), each followed by ReLU and
    a 2x2 max-pool, feed a two-layer fully connected head.  The head expects
    1024 flattened features, i.e. 16 channels x 8 x 8, which is what the
    convolutional stack produces for 3x32x32 inputs.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Layers are created in application order so seeded initialisation
        # remains reproducible.
        conv_stack = [
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.features = nn.Sequential(*conv_stack)

        head = [
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits for a batch of images."""
        feature_maps = self.features(x)
        flat = torch.flatten(feature_maps, 1)  # keep the batch dim, merge the rest
        return self.classifier(flat)


In [None]:
def train(model, train_loader, epochs, learning_rate, device):
    """Train ``model`` with cross-entropy loss and the Adam optimizer.

    Prints the mean batch loss after every epoch.  The model is updated in
    place and left in training mode.

    Args:
        model: Network to optimise; it is moved to ``device``.
        train_loader: Sized iterable yielding ``(inputs, labels)`` batches.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate passed to Adam.
        device: Device the model and each batch are moved to.

    Returns:
        None.
    """
    # Move the model before building the optimizer, mirroring `test`, so
    # callers do not have to remember to do it themselves.
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    num_batches = len(train_loader)

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # An empty loader has no meaningful average; report NaN instead of
        # raising ZeroDivisionError.
        mean_loss = running_loss / num_batches if num_batches else float('nan')
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss}")

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

CROSS ENTROPY RUNS

In [None]:
# Teacher baseline: seed the RNG before construction so the initial weights
# are repeatable, then train and evaluate DeepNN with the cross-entropy
# `train` / `test` helpers defined above.
torch.manual_seed(42)
nn_deep = DeepNN(num_classes=10).to(device)
train(nn_deep, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_deep = test(nn_deep, test_loader, device)

# Re-seed with the same value right before building the student so that its
# initialisation can be reproduced later for a like-for-like comparison.
torch.manual_seed(42)
nn_light = LightNN(num_classes=10).to(device)

Epoch 1/10, Loss: 1.3274331938885058
Epoch 2/10, Loss: 0.8768705486336632
Epoch 3/10, Loss: 0.6902951818445454
Epoch 4/10, Loss: 0.5474821895437167
Epoch 5/10, Loss: 0.4261758156749598
Epoch 6/10, Loss: 0.31982939071057703
Epoch 7/10, Loss: 0.2332105693381156
Epoch 8/10, Loss: 0.17445363737928593
Epoch 9/10, Loss: 0.13516285310468407
Epoch 10/10, Loss: 0.11944542772820234
Test Accuracy: 73.88%


In [None]:
# Second student instance, built right after re-seeding with the same value
# (42) that preceded the construction of nn_light; the norm comparison in the
# next cell shows both start from identical first-layer weights.  This copy
# is the one trained with knowledge distillation further down.
torch.manual_seed(42)
new_nn_light = LightNN(num_classes=10).to(device)

Comparing the first layer of the initial lightweight model and the new lightweight model

In [None]:
# Sanity check: both students were built from the same seed, so the weights
# of their first convolution should have the same norm.
light_first_conv = nn_light.features[0].weight
new_light_first_conv = new_nn_light.features[0].weight
print('Norm of 1st layer of nn_light:', torch.norm(light_first_conv).item())
print('Norm of 1st layer of new_nn_light:', torch.norm(new_light_first_conv).item())

Norm of 1st layer of nn_light: 2.327361822128296
Norm of 1st layer of new_nn_light: 2.327361822128296


In [None]:
# Count the parameters of each network, then format the totals with
# thousands separators for display.
deep_param_count = sum(p.numel() for p in nn_deep.parameters())
light_param_count = sum(p.numel() for p in nn_light.parameters())
total_params_deep = f'{deep_param_count:,}'
total_params_light = f'{light_param_count:,}'
print(f'Deep NN parameters: {total_params_deep}')
print(f' LightNN parameters: {total_params_light}')

Deep NN parameters: 1,150,058
 LightNN parameters: 267,738


In [None]:
# Student baseline: train the lightweight network on hard labels only, then
# record its test accuracy for the later comparison.
train(
    nn_light,
    train_loader,
    epochs=10,
    learning_rate=0.001,
    device=device,
)
test_accuracy_light_ce = test(nn_light, test_loader, device)

Epoch 1/10, Loss: 1.4760891415578934
Epoch 2/10, Loss: 1.1697084418952923
Epoch 3/10, Loss: 1.0418787882151201
Epoch 4/10, Loss: 0.9423012375221838
Epoch 5/10, Loss: 0.8657975026103847
Epoch 6/10, Loss: 0.8009030367712231
Epoch 7/10, Loss: 0.7372408323275769
Epoch 8/10, Loss: 0.6811386679139588
Epoch 9/10, Loss: 0.6305688669919358
Epoch 10/10, Loss: 0.5824526600977954
Test Accuracy: 70.66%


In [None]:
# Side-by-side accuracy of the two cross-entropy baselines.
deep_report = f'Test accuracy of DeepNN: {test_accuracy_deep:.2f}%'
light_report = f'Test accuracy of LightNN: {test_accuracy_light_ce:.2f}%'
print(deep_report, light_report, sep='\n')

Test accuracy of DeepNN: 73.88%
Test accuracy of LightNN: 70.66%


APPLYING KNOWLEDGE DISTILLATION

In [None]:
def train_knowledge_distillation( teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    """Train ``student`` on hard labels plus the teacher's softened outputs.

    For every batch the loss is
    ``soft_target_loss_weight * KL(teacher_T || student_T) * T**2
    + ce_loss_weight * CrossEntropy(student_logits, labels)``,
    where the ``_T`` distributions are softmaxes of the logits divided by
    the temperature ``T``.  Only the student's parameters are optimised
    (Adam); the teacher runs in eval mode without gradients.  The mean batch
    loss is printed after every epoch.

    Args:
        teacher: Trained network providing soft targets; moved to ``device``.
        student: Network to optimise in place; moved to ``device``.
        train_loader: Sized iterable yielding ``(inputs, labels)`` batches.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate passed to Adam.
        T: Softmax temperature applied to both sets of logits.
        soft_target_loss_weight: Weight of the distillation (KL) term.
        ce_loss_weight: Weight of the hard-label cross-entropy term.
        device: Device the models and each batch are moved to.

    Returns:
        None.
    """
    # Move both networks up front (mirrors `test`), before the optimizer is
    # built from the student's parameters.
    teacher.to(device)
    student.to(device)

    # Cross-entropy against the hard labels.
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()   # teacher is frozen: no dropout, no gradient tracking
    student.train()

    num_batches = len(train_loader)

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Teacher forward pass without building a graph.
            with torch.no_grad():
                teacher_logits = teacher(inputs)
                soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)

            student_logits = student(inputs)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # KL(teacher || student), averaged over the batch and scaled by
            # T**2.  `kl_div` treats 0 * log(0) as 0, whereas the hand-written
            # `p * (p.log() - q)` form yields a NaN loss as soon as a teacher
            # probability underflows to zero.
            soft_targets_loss = nn.functional.kl_div(
                soft_prob, soft_targets, reduction='batchmean'
            ) * (T ** 2)

            label_loss = ce_loss(student_logits, labels)

            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # An empty loader has no meaningful average; report NaN instead of
        # raising ZeroDivisionError.
        mean_loss = running_loss / num_batches if num_batches else float('nan')
        print(f'Epoch {epoch+1}/{epochs},Loss: {mean_loss}')
# Distil the trained teacher into the freshly initialised student, then
# evaluate it on the test split.
train_knowledge_distillation(
    teacher=nn_deep,
    student=new_nn_light,
    train_loader=train_loader,
    epochs=10,
    learning_rate=0.001,
    T=4,
    soft_target_loss_weight=0.25,
    ce_loss_weight=0.75,
    device=device,
)
test_accuracy_light_ce_and_kd = test(new_nn_light, test_loader, device)

# Final comparison: teacher vs. student baseline vs. distilled student.
summary_lines = (
    f"Teacher accuracy: {test_accuracy_deep:.2f}%",
    f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%",
    f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%",
)
for summary_line in summary_lines:
    print(summary_line)


Epoch 1/10,Loss: 1.6343824311595438
Epoch 2/10,Loss: 1.5003332262453826
Epoch 3/10,Loss: 1.4221786013649553
Epoch 4/10,Loss: 1.3413194050569364
Epoch 5/10,Loss: 1.2725227819684217
Epoch 6/10,Loss: 1.216282490116861
Epoch 7/10,Loss: 1.1557223115430768
Epoch 8/10,Loss: 1.1037090134132854
Epoch 9/10,Loss: 1.0659202149761913
Epoch 10/10,Loss: 1.022444518630767
Test Accuracy: 71.86%
Teacher accuracy: 73.88%
Student accuracy without teacher: 70.66%
Student accuracy with CE + KD: 71.86%
