In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Check if GPU is available, and if not, use the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
transforms_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Loading the CIFAR-10 dataset:
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_cifar)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

In [4]:
class DeepNN(nn.Module):
    def __init__(self, num_classes=10):
        super(DeepNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

# Lightweight neural network class to be used as student:
class LightNN(nn.Module):
    def __init__(self, num_classes=10):
        super(LightNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [5]:
def train(model, train_loader, epochs, learning_rate, device):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            # inputs: A collection of batch_size images
            # labels: A vector of dimensionality batch_size with integers denoting class of each image
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            # outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes
            # labels: The actual labels of the images. Vector of dimensionality batch_size
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [6]:
torch.manual_seed(42)
nn_deep = DeepNN(num_classes=10).to(device)
train(nn_deep, train_loader, epochs=50, learning_rate=0.001, device=device)
test_accuracy_deep = test(nn_deep, test_loader, device)

# Instantiate the lightweight network:
torch.manual_seed(42)
nn_light = LightNN(num_classes=10).to(device)

Epoch 1/50, Loss: 1.3355058013935528
Epoch 2/50, Loss: 0.8668691409213464
Epoch 3/50, Loss: 0.6752048631763214
Epoch 4/50, Loss: 0.5310805269214504
Epoch 5/50, Loss: 0.41089592161385907
Epoch 6/50, Loss: 0.3109386853321129
Epoch 7/50, Loss: 0.21642804507861663
Epoch 8/50, Loss: 0.16224463847096618
Epoch 9/50, Loss: 0.14398209176138235
Epoch 10/50, Loss: 0.11642187057763262
Epoch 11/50, Loss: 0.11232386205030982
Epoch 12/50, Loss: 0.09794358372250024
Epoch 13/50, Loss: 0.09115883650834603
Epoch 14/50, Loss: 0.08433126543393678
Epoch 15/50, Loss: 0.07208394244208437
Epoch 16/50, Loss: 0.07749299952269667
Epoch 17/50, Loss: 0.07536998473565139
Epoch 18/50, Loss: 0.06654192930649576
Epoch 19/50, Loss: 0.0761739694622948
Epoch 20/50, Loss: 0.05641717683461964
Epoch 21/50, Loss: 0.06006078817583906
Epoch 22/50, Loss: 0.07211977076630972
Epoch 23/50, Loss: 0.05984745851343931
Epoch 24/50, Loss: 0.055542813671771864
Epoch 25/50, Loss: 0.0609149327799392
Epoch 26/50, Loss: 0.054206395792105545


In [7]:
torch.manual_seed(42)
new_nn_light = LightNN(num_classes=10).to(device)


In [8]:
# Print the norm of the first layer of the initial lightweight model
print("Norm of 1st layer of nn_light:", torch.norm(nn_light.features[0].weight).item())
# Print the norm of the first layer of the new lightweight model
print("Norm of 1st layer of new_nn_light:", torch.norm(new_nn_light.features[0].weight).item())

Norm of 1st layer of nn_light: 2.327361822128296
Norm of 1st layer of new_nn_light: 2.327361822128296


In [9]:
total_params_deep = "{:,}".format(sum(p.numel() for p in nn_deep.parameters()))
print(f"DeepNN parameters: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in nn_light.parameters()))
print(f"LightNN parameters: {total_params_light}")

DeepNN parameters: 1,186,986
LightNN parameters: 267,738


In [10]:
train(nn_light, train_loader, epochs=50, learning_rate=0.001, device=device)
test_accuracy_light_ce = test(nn_light, test_loader, device)

Epoch 1/50, Loss: 1.4696982763612363
Epoch 2/50, Loss: 1.1597190871263099
Epoch 3/50, Loss: 1.0274034961105307
Epoch 4/50, Loss: 0.9194910018645284
Epoch 5/50, Loss: 0.8405520173594775
Epoch 6/50, Loss: 0.7739942942738838
Epoch 7/50, Loss: 0.7083837921966982
Epoch 8/50, Loss: 0.6514772647024726
Epoch 9/50, Loss: 0.5975743385074693
Epoch 10/50, Loss: 0.5475177697818298
Epoch 11/50, Loss: 0.5032763705991418
Epoch 12/50, Loss: 0.4561621770072166
Epoch 13/50, Loss: 0.4133868062358988
Epoch 14/50, Loss: 0.3758860077242107
Epoch 15/50, Loss: 0.34238458312380954
Epoch 16/50, Loss: 0.3135577602993192
Epoch 17/50, Loss: 0.2820614961635731
Epoch 18/50, Loss: 0.26250885327911133
Epoch 19/50, Loss: 0.2440284442947344
Epoch 20/50, Loss: 0.22438596617763915
Epoch 21/50, Loss: 0.2049499749946777
Epoch 22/50, Loss: 0.19051729348461952
Epoch 23/50, Loss: 0.18244457847017156
Epoch 24/50, Loss: 0.1723389409653976
Epoch 25/50, Loss: 0.16546708353988046
Epoch 26/50, Loss: 0.1550621938846453
Epoch 27/50, Lo

In [11]:
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy: {test_accuracy_light_ce:.2f}%")

Teacher accuracy: 74.77%
Student accuracy: 67.79%


In [12]:
def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            with torch.no_grad():
                teacher_logits = teacher(inputs)

            # Forward pass with the student model
            student_logits = student(inputs)
            # print(student_logits.mean(), teacher_logits.mean())
            #Soften the student logits by applying softmax first and log() second
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
            soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

# Apply ``train_knowledge_distillation`` with a temperature of 2. Arbitrarily set the weights to 0.75 for CE and 0.25 for distillation loss.
train_knowledge_distillation(teacher=nn_deep, student=new_nn_light, train_loader=train_loader, epochs=50, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_kd = test(new_nn_light, test_loader, device)

# Compare the student test accuracy with and without the teacher, after distillation
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")


Epoch 1/50, Loss: nan
Epoch 2/50, Loss: nan
Epoch 3/50, Loss: nan
Epoch 4/50, Loss: nan
Epoch 5/50, Loss: nan
Epoch 6/50, Loss: nan
Epoch 7/50, Loss: nan
Epoch 8/50, Loss: nan
Epoch 9/50, Loss: nan
Epoch 10/50, Loss: nan
Epoch 11/50, Loss: nan
Epoch 12/50, Loss: nan
Epoch 13/50, Loss: nan
Epoch 14/50, Loss: nan
Epoch 15/50, Loss: nan
Epoch 16/50, Loss: nan
Epoch 17/50, Loss: nan
Epoch 18/50, Loss: nan
Epoch 19/50, Loss: nan
Epoch 20/50, Loss: nan
Epoch 21/50, Loss: nan
Epoch 22/50, Loss: nan
Epoch 23/50, Loss: nan
Epoch 24/50, Loss: nan
Epoch 25/50, Loss: nan
Epoch 26/50, Loss: nan
Epoch 27/50, Loss: nan
Epoch 28/50, Loss: nan
Epoch 29/50, Loss: nan
Epoch 30/50, Loss: nan
Epoch 31/50, Loss: nan
Epoch 32/50, Loss: nan
Epoch 33/50, Loss: nan
Epoch 34/50, Loss: nan
Epoch 35/50, Loss: nan
Epoch 36/50, Loss: nan
Epoch 37/50, Loss: nan
Epoch 38/50, Loss: nan
Epoch 39/50, Loss: nan
Epoch 40/50, Loss: nan
Epoch 41/50, Loss: nan
Epoch 42/50, Loss: nan
Epoch 43/50, Loss: nan
Epoch 44/50, Loss: n