In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import random_split, DataLoader
from sklearn.metrics import confusion_matrix, classification_report, ConfusionMatrixDisplay
from student_models import DeepNN, LightNN, MiniEfficientNet, MiniResNet, MiniShuffleNet, MiniSqueezeNet
import torchvision.models as models

# Check if GPU is available, and if not, use the CPU
print("Check current device: ")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available(): # Should return True 
    print(f"Using GPU: {torch.cuda.get_device_name(0)}") # Should show your GPU name
else:
    print("Using CPU")

Check current device: 
Using GPU: NVIDIA GeForce RTX 4060


Loading CIFAR-10
================


In [5]:
# Transformations for data preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Loading the CIFAR-10 dataset: (train set will later be split into train and val)
full_trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Split trainset into train and validation datasets (80% train, 20% val)
train_size = int(0.8 * len(full_trainset))
val_size = len(full_trainset) - train_size
trainset, valset = random_split(full_trainset, [train_size, val_size])

# DataLoaders for train, validation, and test datasets
train_loader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(valset, batch_size=32, shuffle=False, num_workers=2)
test_loader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=2)

# Check if dataset loads correctly
print(f"Number of training samples: {len(trainset)}")
print(f"Number of validation samples: {len(valset)}")
print(f"Number of testing samples: {len(testset)}")

Files already downloaded and verified
Files already downloaded and verified
Number of training samples: 40000
Number of validation samples: 10000
Number of testing samples: 10000


Defining model classes and utility functions
============================================

In [6]:
def train(model, train_loader, val_loader, epochs, learning_rate, device):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    train_losses = []
    val_losses = []

    model.train()

    for epoch in range(epochs):

        # Training Step
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_train_loss = running_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        print(f"Epoch {epoch+1}/{epochs}, Training Loss: {avg_train_loss}")

        # Validation Step
        val_loss = 0.0
        with torch.no_grad():  # Disable gradient computation for validation
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item() 

        avg_val_loss = val_loss / len(val_loader)  # Average validation loss
        val_losses.append(avg_val_loss)
        print(f"Epoch {epoch+1}/{epochs}, Validation Loss: {avg_val_loss:.4f}")
    return train_losses, val_losses 

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    all_labels = []
    all_predictions = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
	
            # Collect predictions and true labels
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Calculate metrics using sklearn
    cm = confusion_matrix(all_labels, all_predictions)
    report = classification_report(all_labels, all_predictions, output_dict=True, zero_division=0)

    return cm, report

Cross-entropy runs
==================


In [7]:
torch.manual_seed(42)
nn_deep = models.mobilenet_v3_large(weights=None).to(device)
train(nn_deep, train_loader, val_loader, epochs=10, learning_rate=0.001, device=device)
test_deep = test(nn_deep, test_loader, device)
test_accuracy_deep = test_deep[1]["accuracy"] * 100
print(f"Teacher Accuracy: {test_accuracy_deep:.2f}%")

# Instantiate the lightweight network:
torch.manual_seed(42)
nn_light = models.mobilenet_v3_small(weights=None).to(device)

Epoch 1/10, Training Loss: 1.519029443502426
Epoch 1/10, Validation Loss: 1.1415
Epoch 2/10, Training Loss: 0.9902840671777725
Epoch 2/10, Validation Loss: 0.8332
Epoch 3/10, Training Loss: 0.7527625097990036
Epoch 3/10, Validation Loss: 0.6874
Epoch 4/10, Training Loss: 0.6156341633796691
Epoch 4/10, Validation Loss: 0.6446
Epoch 5/10, Training Loss: 0.5264834771513939
Epoch 5/10, Validation Loss: 0.5632
Epoch 6/10, Training Loss: 0.4482562301039696
Epoch 6/10, Validation Loss: 0.5331
Epoch 7/10, Training Loss: 0.39021080291867255
Epoch 7/10, Validation Loss: 0.5237
Epoch 8/10, Training Loss: 0.34441135109364984
Epoch 8/10, Validation Loss: 0.5171
Epoch 9/10, Training Loss: 0.2993848577469587
Epoch 9/10, Validation Loss: 0.4838
Epoch 10/10, Training Loss: 0.2628646334066987
Epoch 10/10, Validation Loss: 0.5234
Teacher Accuracy: 84.83%


In [8]:
torch.manual_seed(42)
new_nn_light = models.mobilenet_v3_small(weights=None).to(device)

print("Norm of 1st layer of nn_light:", torch.norm(nn_light.features[0][0].weight).item())
print("Norm of 1st layer of new_nn_light:", torch.norm(new_nn_light.features[0][0].weight).item())

total_params_deep = "{:,}".format(sum(p.numel() for p in nn_deep.parameters()))
print(f"DeepNN parameters: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in nn_light.parameters()))
print(f"LightNN parameters: {total_params_light}")

Norm of 1st layer of nn_light: 2.354264974594116
Norm of 1st layer of new_nn_light: 2.354264974594116
DeepNN parameters: 5,483,032
LightNN parameters: 2,542,856


In [9]:
train(nn_light, train_loader, val_loader, epochs=10, learning_rate=0.001, device=device)
test_light_ce = test(nn_light, test_loader, device)
test_accuracy_light_ce = test_light_ce[1]["accuracy"] * 100
print(f"Student Accuracy: {test_accuracy_light_ce:.2f}%")

Epoch 1/10, Training Loss: 1.464234339094162
Epoch 1/10, Validation Loss: 1.1129
Epoch 2/10, Training Loss: 0.9595084422588348
Epoch 2/10, Validation Loss: 0.8683
Epoch 3/10, Training Loss: 0.755334764122963
Epoch 3/10, Validation Loss: 0.7111
Epoch 4/10, Training Loss: 0.6297902264595032
Epoch 4/10, Validation Loss: 0.6671
Epoch 5/10, Training Loss: 0.5427592440009117
Epoch 5/10, Validation Loss: 0.5837
Epoch 6/10, Training Loss: 0.47154156504273415
Epoch 6/10, Validation Loss: 0.5935
Epoch 7/10, Training Loss: 0.4155354429960251
Epoch 7/10, Validation Loss: 0.5643
Epoch 8/10, Training Loss: 0.36720344785749914
Epoch 8/10, Validation Loss: 0.5373
Epoch 9/10, Training Loss: 0.32618121957480906
Epoch 9/10, Validation Loss: 0.5513
Epoch 10/10, Training Loss: 0.2928263768732548
Epoch 10/10, Validation Loss: 0.5135
Student Accuracy: 84.30%


In [10]:
torch.save(nn_deep.state_dict(), "MobileNetV3_large_CE.pth")
torch.save(nn_light.state_dict(), "MobileNetV3_small_CE.pth")

print("Model saved as teacher (MobileNetV3 Large) and student (MobileNetV3 small) with CE")

Model saved as teacher (MobileNetV3 Large) and student (MobileNetV3 small) with CE


Knowledge distillation run
==========================

In [11]:
def train_knowledge_distillation(teacher, student, train_loader, val_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    train_losses = []
    val_losses = []
    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            with torch.no_grad():
                teacher_logits = teacher(inputs)

            # Forward pass with the student model
            student_logits = student(inputs)

            #Soften the student logits by applying softmax first and log() second
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
            soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        avg_train_loss = running_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        print(f"Epoch {epoch+1}/{epochs}, Training Loss: {avg_train_loss}")

        # Validation Step
        val_loss = 0.0
        with torch.no_grad():  # Disable gradient computation for validation
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = student(inputs)
                loss = ce_loss(outputs, labels)
                val_loss += loss.item() 

        avg_val_loss = val_loss / len(val_loader)  # Average validation loss
        val_losses.append(avg_val_loss)
        print(f"Epoch {epoch+1}/{epochs}, Validation Loss: {avg_val_loss:.4f}")
    return train_losses, val_losses 

In [12]:
temperatures = [2,3,4,5]
pairs = [(0.2, 0.8), (0.25, 0.75), (0.3, 0.7), (0.4, 0.6), (0.5, 0.5)]
name = 0
for t in temperatures:
    for x in pairs:
        torch.manual_seed(42)
        new_nn_light = models.mobilenet_v3_small(weights=None).to(device)
        # Apply ``train_knowledge_distillation`` with a temperature of 2. Arbitrarily set the weights to 0.75 for CE and 0.25 for distillation loss.
        train_light_ce_and_kd = train_knowledge_distillation(teacher=nn_deep, student=new_nn_light, train_loader=train_loader, val_loader=val_loader,
            epochs=10, learning_rate=0.001, T=t, soft_target_loss_weight=x[0], ce_loss_weight=x[1], device=device)
        test_light_ce_and_kd = test(new_nn_light, test_loader, device)
        test_accuracy_light_ce_and_kd = test_light_ce_and_kd[1]["accuracy"] * 100
        precision_light_ce_and_kd = test_light_ce_and_kd[1]["weighted avg"]["precision"]
        recall_light_ce_and_kd = test_light_ce_and_kd[1]["weighted avg"]["recall"]
        f1_light_ce_and_kd = test_light_ce_and_kd[1]["weighted avg"]["f1-score"]
        print("-----------------------------------------")
        print(f"Student accuracy with CE + KD and T={t} weights={x}:")
        print(f"Accuracy: {test_accuracy_light_ce_and_kd:.2f}%")
        print(f"Precision: {precision_light_ce_and_kd:.2f}")
        print(f"Recall: {recall_light_ce_and_kd:.2f}")
        print(f"F1 Score: {f1_light_ce_and_kd:.2f}")
        torch.save(new_nn_light.state_dict(), f"KD_model_{t}_{name}.pth")
        print(f"Model saved as KD_model_{t}_{name}.pth")
        name += 1
        print()

Epoch 1/10, Training Loss: 1.9868482129096985
Epoch 1/10, Validation Loss: 1.0530
Epoch 2/10, Training Loss: 1.17822644135952
Epoch 2/10, Validation Loss: 0.8189
Epoch 3/10, Training Loss: 0.890794629907608
Epoch 3/10, Validation Loss: 0.7048
Epoch 4/10, Training Loss: 0.7407869048595428
Epoch 4/10, Validation Loss: 0.6368
Epoch 5/10, Training Loss: 0.6390546699762344
Epoch 5/10, Validation Loss: 0.5820
Epoch 6/10, Training Loss: 0.562871788930893
Epoch 6/10, Validation Loss: 0.5645
Epoch 7/10, Training Loss: 0.5030176959872246
Epoch 7/10, Validation Loss: 0.5531
Epoch 8/10, Training Loss: 0.4519233579158783
Epoch 8/10, Validation Loss: 0.5128
Epoch 9/10, Training Loss: 0.405798263835907
Epoch 9/10, Validation Loss: 0.5055
Epoch 10/10, Training Loss: 0.366661584186554
Epoch 10/10, Validation Loss: 0.5014
-----------------------------------------
Student accuracy with CE + KD and T=2 weights=(0.2, 0.8):
Accuracy: 85.23%
Precision: 0.85
Recall: 0.85
F1 Score: 0.85
Model saved as KD_model