In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import torch.optim as optim
import torchvision.datasets as datasets

In [2]:
# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
#random seeds
torch.manual_seed(42)
np.random.seed(42)

## Teacher Model (ResNet18)

In [None]:
# Load the pre-trained ResNet model
resnet = models.resnet18(pretrained=True)

In [4]:
# Freeze all layers except the last one
for param in resnet.parameters():
    param.requires_grad = False

In [5]:
# Replace the last fully connected layer with a new one
num_ftrs = resnet.fc.in_features
num_classes = 10  # Assuming 10 classes for example
resnet.fc = nn.Linear(num_ftrs, num_classes)

In [6]:
# Transfer the model to the GPU
resnet = resnet.to(device)

In [7]:
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(resnet.fc.parameters(), lr=0.001, momentum=0.9)

In [12]:
# Load your custom dataset
train_dataset = datasets.CIFAR10(root='./data', transform=transforms.Compose([
                                            transforms.Resize(256),
                                            transforms.CenterCrop(224),
                                            transforms.ToTensor(),
                                            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                 std=[0.229, 0.224, 0.225])]))

In [13]:
# Define DataLoader
batch_size = 32  # Adjust batch size according to your GPU memory
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [14]:
# Train the model
num_epochs = 10
for epoch in range(num_epochs):
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)  # Transfer data to GPU
        optimizer.zero_grad()
        outputs = resnet(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * inputs.size(0)
    
    epoch_loss = running_loss / len(train_loader.dataset)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}")

print('Finished Training')

Epoch [1/10], Loss: 0.9527
Epoch [2/10], Loss: 0.7423
Epoch [3/10], Loss: 0.7051
Epoch [4/10], Loss: 0.6883
Epoch [5/10], Loss: 0.6724
Epoch [6/10], Loss: 0.6670
Epoch [7/10], Loss: 0.6596
Epoch [8/10], Loss: 0.6547
Epoch [9/10], Loss: 0.6530
Epoch [10/10], Loss: 0.6480
Finished Training


In [17]:
# Set the model to evaluation mode
resnet.eval()

# Load your test dataset
test_dataset = datasets.CIFAR10(root='./data', transform=transforms.Compose([
                                            transforms.Resize(256),
                                            transforms.CenterCrop(224),
                                            transforms.ToTensor(),
                                            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                 std=[0.229, 0.224, 0.225])]))

# Define DataLoader for test data
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [19]:
# Evaluate the model on test data
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)  # Transfer data to GPU
        outputs =resnet(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print('Accuracy of the network on the test images: %d %%' % accuracy)

Accuracy of the network on the test images: 79 %


## Student Model

In [31]:
# Below we are preprocessing data for CIFAR-10. We use an arbitrary batch size of 128.
transforms_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Loading the CIFAR-10 dataset:
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_cifar)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar)

Files already downloaded and verified
Files already downloaded and verified


In [32]:
#Dataloaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)


In [33]:
# Lightweight neural network class to be used as student:
class LightNN(nn.Module):
    def __init__(self, num_classes=10):
        super(LightNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [34]:
def train(model, train_loader, epochs, learning_rate, device):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            # inputs: A collection of batch_size images
            # labels: A vector of dimensionality batch_size with integers denoting class of each image
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            # outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes
            # labels: The actual labels of the images. Vector of dimensionality batch_size
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [35]:
# Instantiate the lightweight network:
torch.manual_seed(42)
nn_light = LightNN().to(device)

In [36]:
# Training Student Model
train(nn_light, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_light_ce = test(nn_light, test_loader, device)

Epoch 1/10, Loss: 1.4696010293253243
Epoch 2/10, Loss: 1.1599607729850827
Epoch 3/10, Loss: 1.031422875878756
Epoch 4/10, Loss: 0.9300340305813743
Epoch 5/10, Loss: 0.8561210892999264
Epoch 6/10, Loss: 0.7905343470670988
Epoch 7/10, Loss: 0.7257177899865543
Epoch 8/10, Loss: 0.6702893704861936
Epoch 9/10, Loss: 0.6185725430393463
Epoch 10/10, Loss: 0.5673962647805129
Test Accuracy: 70.54%


In [38]:
print(f"Teacher accuracy: 79%")
print(f"Student accuracy: {test_accuracy_light_ce:.2f}%")

Teacher accuracy: 79%
Student accuracy: 70.54%


## Student Model with Knowledge Distillation

In [40]:
# New Light Model
torch.manual_seed(42)
new_nn_light = LightNN(num_classes=10).to(device)

In [41]:
def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            with torch.no_grad():
                teacher_logits = teacher(inputs)

            # Forward pass with the student model
            student_logits = student(inputs)

            #Soften the student logits by applying softmax first and log() second
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
            soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

In [42]:
# Apply ``train_knowledge_distillation`` with a temperature of 2. Arbitrarily set the weights to 0.75 for CE and 0.25 for distillation loss.
train_knowledge_distillation(teacher=resnet, student=new_nn_light, train_loader=train_loader, epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_kd = test(new_nn_light, test_loader, device)

# Compare the student test accuracy with and without the teacher, after distillation
print(f"Teacher accuracy: 79%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")

Epoch 1/10, Loss: 1.6131410138381412
Epoch 2/10, Loss: 1.3859690101555242
Epoch 3/10, Loss: 1.2988521235678203
Epoch 4/10, Loss: 1.2367138804682076
Epoch 5/10, Loss: 1.1921843900095166
Epoch 6/10, Loss: 1.1545858133174574
Epoch 7/10, Loss: 1.118700791045528
Epoch 8/10, Loss: 1.091860987188871
Epoch 9/10, Loss: 1.0659223573896892
Epoch 10/10, Loss: 1.038486744589208
Test Accuracy: 66.52%
Teacher accuracy: 79%
Student accuracy without teacher: 70.54%
Student accuracy with CE + KD: 66.52%


## Results on ResNet18

In [43]:
total_params_resnet = "{:,}".format(sum(p.numel() for p in resnet.parameters()))
print(f"Teacher parameters: {total_params_resnet}")
total_params_light = "{:,}".format(sum(p.numel() for p in nn_light.parameters()))
print(f"LightNN parameters: {total_params_light}")

Teacher parameters: 11,181,642
LightNN parameters: 267,738


Teacher accuracy: 79% 

Student accuracy without teacher: 70.54%

Student accuracy with CE + KD: 66.52%