<a href="https://colab.research.google.com/github/Alafiade/Knowledge-distillation-using-Cosine-runLoss-Minimization/blob/main/Knowledge_Distillation_using_Cosine_Loss_Minimization_run.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

IMPORTING DEPENDENCIES

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets





In [None]:
import torch

# Pick the active accelerator when there is one, otherwise run on the CPU.
# `torch.accelerator` is only present in recent PyTorch releases, so fall back
# to the classic CUDA check when the module is missing instead of raising
# AttributeError.
if hasattr(torch, 'accelerator'):
    device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else 'cpu'
else:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

Using cudadevice


DATA LOADING AND PREPROCESSING

In [None]:
# Convert each image to a tensor and standardise the three colour channels.
# The mean/std values are the ImageNet channel statistics; the first std entry
# is 0.229 (it was previously mistyped as 0.299).
transforms_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

# The same preprocessing is applied to the training and the test split.
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_cifar)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar)

100%|██████████| 170M/170M [00:05<00:00, 33.8MB/s]


### SETTING DATA LOADERS

In [None]:
# Mini-batch iterators: only the training split is shuffled, the test split
# keeps its original order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
)

DEFINING MODEL

In [None]:
class DeepNN(nn.Module):
  """Larger convolutional classifier used as the teacher network.

  Three 3x3 convolutions with two 2x2 max-pooling steps feed a two-layer
  fully connected head. The head expects 2048 flattened features, which is
  what 3x32x32 inputs produce (32 channels x 8 x 8).
  """

  def __init__(self, num_classes=10):
    super().__init__()
    conv_layers = [
        nn.Conv2d(3, 128, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(128, 64, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(64, 32, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    ]
    self.features = nn.Sequential(*conv_layers)

    head_layers = [
        nn.Linear(2048, 512),
        nn.ReLU(),
        nn.Dropout(0.1),
        nn.Linear(512, num_classes),
    ]
    self.classifier = nn.Sequential(*head_layers)

  def forward(self, x):
    """Return the class logits for a batch of images."""
    feature_maps = self.features(x)
    return self.classifier(feature_maps.flatten(start_dim=1))


#Lightweight neural network class to be used as student:
class LightNN(nn.Module):
  """Small convolutional classifier used as the student network.

  Two 3x3 convolutions, each followed by ReLU and 2x2 max-pooling, feed a
  two-layer fully connected head. The head expects 1024 flattened features,
  which is what 3x32x32 inputs produce (16 channels x 8 x 8).
  """

  def __init__(self, num_classes=10):
    super().__init__()
    conv_layers = [
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(16, 16, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    ]
    self.features = nn.Sequential(*conv_layers)

    head_layers = [
        nn.Linear(1024, 256),
        nn.ReLU(),
        nn.Dropout(0.1),
        nn.Linear(256, num_classes),
    ]
    self.classifier = nn.Sequential(*head_layers)

  def forward(self, x):
    """Return the class logits for a batch of images."""
    feature_maps = self.features(x)
    return self.classifier(feature_maps.flatten(start_dim=1))


In [None]:
def train(model, train_loader, epochs, learning_rate, device):
    """Train ``model`` with cross-entropy loss and the Adam optimizer.

    Args:
        model: Network mapping a batch of inputs to class logits.
        train_loader: Iterable of ``(inputs, labels)`` batches that also
            supports ``len()`` (used to average the loss per epoch).
        epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate handed to Adam.
        device: Device on which the model and every batch are placed.

    Returns:
        None. The mean batch loss of each epoch is printed.
    """
    # Place the model on the target device here as well, as the evaluation
    # and distillation helpers do, so callers need not remember to do it.
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:

            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

CROSS ENTROPY RUNS

In [None]:
# Seed before building the teacher so its initial weights are reproducible,
# then train it with cross-entropy only and record its test accuracy.
torch.manual_seed(42)
nn_deep = DeepNN(num_classes=10).to(device)
train(nn_deep, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_deep = test(nn_deep, test_loader, device)

# Re-seed right before creating the student: any other LightNN built directly
# after torch.manual_seed(42) then starts from the same weights.
torch.manual_seed(42)
nn_light = LightNN(num_classes=10).to(device)

Epoch 1/10, Loss: 1.3288499415683015
Epoch 2/10, Loss: 0.878422740475296
Epoch 3/10, Loss: 0.6890814704510867
Epoch 4/10, Loss: 0.547064641552508
Epoch 5/10, Loss: 0.42462955426681986
Epoch 6/10, Loss: 0.31261353587250573
Epoch 7/10, Loss: 0.22779948868410058
Epoch 8/10, Loss: 0.16838190793190771
Epoch 9/10, Loss: 0.1418863251862471
Epoch 10/10, Loss: 0.12091316542852565
Test Accuracy: 74.17%


In [None]:
# Second student instance, created with the same seed as nn_light so both
# start from identical initial weights.
torch.manual_seed(42)
new_nn_light = LightNN(num_classes=10).to(device)

Comparing the first layer of the initial lightweight model and the new lightweight model

In [None]:
# Equal norms of the first conv layer indicate that both students were
# initialised with the same weights.
print('Norm of 1st layer of nn_light:', torch.norm(nn_light.features[0].weight).item())
print('Norm of 1st layer of new_nn_light:', torch.norm(new_nn_light.features[0].weight).item())

Norm of 1st layer of nn_light: 2.327361822128296
Norm of 1st layer of new_nn_light: 2.327361822128296


In [None]:
# Parameter count of each network, rendered with thousands separators.
total_params_deep = f'{sum(weights.numel() for weights in nn_deep.parameters()):,}'
total_params_light = f'{sum(weights.numel() for weights in nn_light.parameters()):,}'
print('Deep NN parameters: ' + total_params_deep)
print(' LightNN parameters: ' + total_params_light)

Deep NN parameters: 1,150,058
 LightNN parameters: 267,738


In [None]:
# Baseline: train the student on the hard labels only (plain cross-entropy).
train(
    model=nn_light,
    train_loader=train_loader,
    epochs=10,
    learning_rate=0.001,
    device=device,
)
test_accuracy_light_ce = test(model=nn_light, test_loader=test_loader, device=device)

Epoch 1/10, Loss: 1.4754226991282704
Epoch 2/10, Loss: 1.1681805099062907
Epoch 3/10, Loss: 1.039349402003276
Epoch 4/10, Loss: 0.9384047581106806
Epoch 5/10, Loss: 0.8638731817455243
Epoch 6/10, Loss: 0.7978585077368695
Epoch 7/10, Loss: 0.7358091481963692
Epoch 8/10, Loss: 0.6794487219637312
Epoch 9/10, Loss: 0.626218048203022
Epoch 10/10, Loss: 0.5802044873042485
Test Accuracy: 70.54%


In [None]:
# Side-by-side accuracy of the teacher and the cross-entropy-only student.
print('Test accuracy of DeepNN: {:.2f}%'.format(test_accuracy_deep))
print('Test accuracy of LightNN: {:.2f}%'.format(test_accuracy_light_ce))

Test accuracy of DeepNN: 74.17%
Test accuracy of LightNN: 70.54%


SAVING MODEL

In [None]:
# Persist only the learned parameters (state dicts) of both trained networks
# to files in the current working directory.
torch.save(nn_deep.state_dict(), 'deep_model.pth')
torch.save(nn_light.state_dict(), 'light_model.pth')

MODIFIED STUDENT AND TEACHER ARCHITECTURE

In [None]:
class ModifiedDeepNNCosine(nn.Module):
  """Teacher network that also exposes a hidden representation.

  The layer layout is kept identical to ``DeepNN`` (same modules at the same
  ``nn.Sequential`` indices), so every tensor of a trained ``DeepNN`` state
  dict has a counterpart here with the same name and shape. A previous version
  had an additional ``Conv2d(64, 64)`` layer, which left two convolutions
  without trained weights after the transfer from ``DeepNN``.

  ``forward`` returns ``(logits, hidden)``, where ``hidden`` is the flattened
  convolutional output (2048 values per sample for 3x32x32 inputs) averaged
  over adjacent pairs, giving 1024 values per sample.
  """

  def __init__(self, num_classes=10):
    super(ModifiedDeepNNCosine, self).__init__()
    self.features = nn.Sequential(
        nn.Conv2d(3,128, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(128,64, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(64,32, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2,  stride=2),
    )
    self.classifier = nn.Sequential(
        nn.Linear(2048,512),
        nn.ReLU(),
        nn.Dropout(0.1),
        nn.Linear(512, num_classes)
    )

  def forward(self,x):
    """Return ``(logits, pooled_hidden_representation)`` for a batch."""
    x = self.features(x)
    flattened_conv_output = torch.flatten(x,1)
    x = self.classifier(flattened_conv_output)
    # Halve the feature length by averaging neighbouring pairs of values.
    flattened_conv_output_after_pooling = torch.nn.functional.avg_pool1d(flattened_conv_output,2)
    return x, flattened_conv_output_after_pooling


class ModifiedLightNNCosine(nn.Module):
  """Student network that also exposes its hidden representation.

  ``forward`` returns ``(logits, hidden)``, where ``hidden`` is the flattened
  output of the convolutional stack (1024 values per sample for 3x32x32
  inputs).
  """

  def __init__(self, num_classes=10):
    super().__init__()
    conv_layers = (
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(16, 16, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    )
    self.features = nn.Sequential(*conv_layers)

    head_layers = (
        nn.Linear(1024, 256),
        nn.ReLU(),
        nn.Dropout(0.1),
        nn.Linear(256, num_classes),
    )
    self.classifier = nn.Sequential(*head_layers)

  def forward(self, x):
    """Return ``(logits, flattened_conv_output)`` for a batch."""
    hidden = self.features(x).flatten(start_dim=1)
    logits = self.classifier(hidden)
    return logits, hidden

In [None]:
modified_nn_deep = ModifiedDeepNNCosine(num_classes=10).to(device)
original_state_dict = nn_deep.state_dict()
modified_state_dict = modified_nn_deep.state_dict()

# Copy every trained tensor whose name and shape agree with the new network.
# Tensors that cannot be transferred are collected instead of being skipped
# silently, because each of them leaves part of the teacher untrained.
skipped_teacher_keys = []
for name, param in original_state_dict.items():
  if name in modified_state_dict and param.shape == modified_state_dict[name].shape:
    modified_state_dict[name].copy_(param)
  else:
    skipped_teacher_keys.append(name)

# Tensors that exist only in the new network have nothing to be copied from.
extra_teacher_keys = [name for name in modified_state_dict if name not in original_state_dict]

if skipped_teacher_keys or extra_teacher_keys:
  print('WARNING: the trained teacher weights were only partially transferred;')
  print('the layers listed below keep their random initialisation.')
  print('  Not copied (missing or shape mismatch):', skipped_teacher_keys)
  print('  Present only in the modified network:', extra_teacher_keys)

modified_nn_deep.load_state_dict(modified_state_dict, strict=False)

<All keys matched successfully>

In [None]:
# Compare the norm of the first conv layer of the trained teacher and of the
# modified teacher. Equal values confirm the transfer for this layer only;
# the other layers are not covered by this check.
print('Norm of 1st layer for deep_nn:', torch.norm(nn_deep.features[0].weight).item())
print('Norm of 1st layer for modified_deep_nn:', torch.norm(modified_nn_deep.features[0].weight).item())


Norm of 1st layer for deep_nn: 7.891228199005127
Norm of 1st layer for modified_deep_nn: 7.891228199005127


In [None]:
# Initializing the Modified Light weight network with the same seed as our previous lightweight instances
torch.manual_seed(42)
modified_nn_light = ModifiedLightNNCosine(num_classes=10).to(device)
print('Norm of 1st layer:', torch.norm(modified_nn_light.features[0].weight).item())

Norm of 1st layer: 2.327361822128296


PREPARING MODEL FOR DISTILLATION

In [None]:
# Push one random batch (128 samples of 3x32x32) through both networks to
# confirm that their hidden representations have the same shape, which the
# cosine loss requires.
sample_input = torch.randn(128, 3, 32, 32).to(device)

logits, hidden_representation = modified_nn_light(sample_input)
print(f'Student logits shape: {logits.shape}')
print(f'Student hidden representation shape: {hidden_representation.shape}')

logits, hidden_representation = modified_nn_deep(sample_input)
print(f'Teacher logits shape: {logits.shape}')
print(f'Teacher hidden representation shape: {hidden_representation.shape}')

Student logits shape: torch.Size([128, 10])
Student hidden representation shape: torch.Size([128, 1024])
Teacher logits shape: torch.Size([128, 10])
Teacher hidden representation shape: torch.Size([128, 1024])


IMPLEMENTING KNOWLEDGE DISTILLATION TRAINING PROCESS USING COSINE EMBEDDING LOSS

In [None]:
def train_cosine_loss(teacher, student, train_loader, epochs, learning_rate, hidden_rep_loss_weight, ce_loss_weight,device):
  """Train ``student`` on labels while aligning its hidden features with ``teacher``.

  Both networks must return ``(logits, hidden_representation)`` with hidden
  representations of equal shape. The optimised loss is
  ``hidden_rep_loss_weight * cosine_loss + ce_loss_weight * cross_entropy``.
  Only the student's parameters are given to the optimizer; the teacher is
  run in evaluation mode without gradient tracking.

  Args:
    teacher: Network providing the target hidden representation.
    student: Network being trained.
    train_loader: Iterable of ``(inputs, labels)`` batches supporting ``len()``.
    epochs: Number of passes over ``train_loader``.
    learning_rate: Learning rate handed to Adam.
    hidden_rep_loss_weight: Weight of the cosine embedding loss term.
    ce_loss_weight: Weight of the cross-entropy term.
    device: Device on which both networks and every batch are placed.

  Returns:
    None. The mean combined loss of each epoch is printed.
  """
  ce_loss = nn.CrossEntropyLoss()
  cosine_loss = nn.CosineEmbeddingLoss()
  optimizer = optim.Adam(student.parameters(), lr=learning_rate)

  # Ensure both teacher and student models are on the correct device
  teacher.to(device)
  student.to(device)
  teacher.eval()
  student.train()

  for epoch in range(epochs):
    running_loss = 0.0
    for inputs, labels in train_loader:
      # Ensure inputs and labels are on the correct device
      inputs, labels = inputs.to(device), labels.to(device)

      optimizer.zero_grad()
      # Forward pass with the teacher model and keep only hidden representation
      with torch.no_grad():
        _, teacher_hidden_representation = teacher(inputs)

      student_logits, student_hidden_representation = student(inputs)

      # A target of +1 for every sample asks the loss to pull each pair of
      # vectors together. Allocate it directly on the device rather than on
      # the host followed by a copy.
      similarity_target = torch.ones(inputs.size(0), device=device)
      hidden_rep_loss = cosine_loss(student_hidden_representation, teacher_hidden_representation, target=similarity_target)

      label_loss = ce_loss(student_logits, labels)

      loss = hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_loss

      loss.backward()
      optimizer.step()

      running_loss += loss.item()

    print(f'Epoch {epoch+1}/{epochs}, Loss:{running_loss/len(train_loader)}')

EVALUATION ON THE TEST DATASET

In [None]:
def test_multiple_outputs(model, test_loader, device):
  model.to(device)
  model.eval()

  correct = 0
  total = 0

  with torch.no_grad():
    for inputs, labels in test_loader:
      inputs, labels = inputs.to(device), labels.to(device)
      # The model returns two values, so we unpack them into outputs and _
      outputs, _ = model(inputs)  # Changed this line
      _, predicted = torch.max(outputs.data,1)

      total += labels.size(0)
      correct += (predicted == labels).sum().item()

  accuracy = 100 * correct / total
  print(f'Test Accuracy: {accuracy:.2f}%')
  return accuracy

In [None]:
# Distil the student from the teacher: 25% cosine loss on the hidden
# representations plus 75% cross-entropy on the labels, then evaluate.
train_cosine_loss(
    teacher=modified_nn_deep,
    student=modified_nn_light,
    train_loader=train_loader,
    epochs=10,
    learning_rate=0.001,
    hidden_rep_loss_weight=0.25,
    ce_loss_weight=0.75,
    device=device,
)
test_accuracy_light_ce_and_cosine_loss = test_multiple_outputs(
    modified_nn_light,
    test_loader,
    device,
)

Epoch 1/10, Loss:0.6063390386378978
Epoch 2/10, Loss:0.5788718739433971
Epoch 3/10, Loss:0.557668267156157
Epoch 4/10, Loss:0.532042542915515
Epoch 5/10, Loss:0.5103208540040819
Epoch 6/10, Loss:0.49039658492483446
Epoch 7/10, Loss:0.4699215755590697
Epoch 8/10, Loss:0.4498132763768706
Epoch 9/10, Loss:0.4319571669754165
Epoch 10/10, Loss:0.416853809410044
Test Accuracy: 71.15%
