# Knowledge distillation using VGG-16 to train VGG-11 with CIFAR-10 

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch_directml

# Run everything on the default DirectML adapter.
# NOTE(review): no CPU fallback is attempted here; this cell requires a working torch_directml device.
default_adapter_index = torch_directml.default_device()
device = torch_directml.device(default_adapter_index)
print(device)

privateuseone:0


In [2]:
# Preprocessing applied to every image: resize to 227x227, convert to a tensor,
# then normalise each of the three channels with the given mean / std.
preprocessing_steps = [
    transforms.Resize((227, 227)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transforms_cifar = transforms.Compose(preprocessing_steps)

# CIFAR-10 train / test splits. Both live under ../data (download=True) and
# share the preprocessing pipeline defined above.
shared_dataset_args = dict(root='../data', download=True, transform=transforms_cifar)
train_dataset = datasets.CIFAR10(train=True, **shared_dataset_args)
test_dataset = datasets.CIFAR10(train=False, **shared_dataset_args)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Both loaders yield mini-batches of 32 images using 2 loader workers.
# Only the training split is shuffled; the test split keeps its order.
shared_loader_args = dict(batch_size=32, num_workers=2)
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **shared_loader_args)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **shared_loader_args)

## Define the VGG models

### VGG-16

In [4]:
class VGG16(nn.Module):
    """VGG-16 style classifier with batch normalisation after every convolution.

    The network is 13 convolution blocks (``layer1`` ... ``layer13``) followed by
    three fully connected stages (``fc``, ``fc1``, ``fc2``). Every block is
    ``Conv2d(3x3, stride 1, padding 1) -> BatchNorm2d -> ReLU``; five of them end
    with a 2x2 max-pool that halves the spatial size.

    ``fc`` expects ``7 * 7 * 512`` flattened features, i.e. a 7x7 map after the
    five poolings (for example from the 227x227 images produced by this
    notebook's transforms).

    Args:
        num_classes: size of the final linear layer's output (default 10).
    """

    # (in_channels, out_channels, ends_with_max_pool) for layer1 ... layer13, in order.
    _CONV_SPECS = (
        (3, 64, False),
        (64, 64, True),
        (64, 128, False),
        (128, 128, True),
        (128, 256, False),
        (256, 256, False),
        (256, 256, True),
        (256, 512, False),
        (512, 512, False),
        (512, 512, True),
        (512, 512, False),
        (512, 512, False),
        (512, 512, True),
    )

    def __init__(self, num_classes=10):
        super().__init__()

        # Register the blocks under the names layer1..layer13, in that order, so the
        # state_dict keys and the parameter-initialisation order stay fixed.
        for index, (c_in, c_out, pooled) in enumerate(self._CONV_SPECS, start=1):
            setattr(self, f"layer{index}", self._conv_block(c_in, c_out, pooled))

        # Fully connected head: two dropout-regularised hidden stages ...
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(7 * 7 * 512, 4096),
            nn.ReLU(),
        )
        self.fc1 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
        )
        # ... and the classification layer, one output score per class.
        self.fc2 = nn.Sequential(nn.Linear(4096, num_classes))

    @staticmethod
    def _conv_block(c_in, c_out, pooled):
        """Build one Conv2d -> BatchNorm2d -> ReLU block, optionally followed by a 2x2 max-pool."""
        stages = [
            nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(),
        ]
        if pooled:
            stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through all convolution blocks and the head; returns the raw ``fc2`` outputs."""
        out = x
        for index in range(1, len(self._CONV_SPECS) + 1):
            out = getattr(self, f"layer{index}")(out)
        # Flatten every sample's feature map into one vector for the linear layers.
        out = out.reshape(out.size(0), -1)
        return self.fc2(self.fc1(self.fc(out)))

### VGG-11

In [5]:
class VGG11(nn.Module):
    """VGG-11 style classifier with batch normalisation, used here as the student network.

    Eight convolution blocks (``layer1`` ... ``layer8``), each
    ``Conv2d(3x3, stride 1, padding 1) -> BatchNorm2d -> ReLU``, five of them
    followed by a 2x2 max-pool. ``layer8`` additionally ends with
    ``AvgPool2d(kernel_size=1, stride=2)``, which keeps every second position of
    the feature map so that ``fc`` can take ``4 * 4 * 512`` inputs.

    Args:
        num_classes: size of the final linear layer's output (default 10).
    """

    # (in_channels, out_channels, ends_with_max_pool) for layer1 ... layer8, in order.
    _CONV_SPECS = (
        (3, 64, True),
        (64, 128, True),
        (128, 256, False),
        (256, 256, True),
        (256, 512, False),
        (512, 512, True),
        (512, 512, False),
        (512, 512, True),
    )

    def __init__(self, num_classes=10):
        super().__init__()

        # Convolution blocks, registered as layer1..layer8 in this exact order.
        final_index = len(self._CONV_SPECS)
        for index, (c_in, c_out, pooled) in enumerate(self._CONV_SPECS, start=1):
            stages = [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
            if pooled:
                stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
            if index == final_index:
                # Halve the output size once more instead of halving the input image.
                stages.append(nn.AvgPool2d(kernel_size=1, stride=2))
            setattr(self, f"layer{index}", nn.Sequential(*stages))

        # Fully connected layers.
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(4 * 4 * 512, 4096),
            nn.ReLU(),
        )
        self.fc1 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
        )

        # Classification layer: one output score per class.
        self.fc2 = nn.Sequential(nn.Linear(4096, num_classes))

    def forward(self, x):
        """Run ``x`` through all convolution blocks and the head; returns the raw ``fc2`` outputs."""
        out = x
        for index in range(1, len(self._CONV_SPECS) + 1):
            out = getattr(self, f"layer{index}")(out)
        # Flatten every sample's feature map into one vector for the linear layers.
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        out = self.fc1(out)
        return self.fc2(out)

## Utility functions for training and testing

In [6]:
def train(model, train_loader, epochs, learning_rate, device):
    """Train ``model`` with cross-entropy loss and the Adam optimizer.

    Args:
        model: the ``nn.Module`` to optimise; it is moved to ``device`` and left in train mode.
        train_loader: iterable of ``(inputs, labels)`` batches. It is iterated once per
            epoch and does not need to implement ``__len__``.
        epochs: number of passes over ``train_loader``.
        learning_rate: learning rate handed to Adam.
        device: device the model and every batch are moved to.

    Returns:
        list[float]: mean batch loss of each epoch (``nan`` for an epoch without batches).
    """
    # Same contract as test(): the caller does not have to move the model beforehand.
    # Done before the optimizer is created so it sees the on-device parameters.
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    epoch_losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for i, (inputs, labels) in enumerate(train_loader):
            # inputs: A collection of batch_size images
            # labels: A vector of dimensionality batch_size with integers denoting class of each image
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            # outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes
            # labels: The actual labels of the images. Vector of dimensionality batch_size
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

            print(f"Epoch {epoch+1}/{epochs}, Step {i+1}")

        # Average over the batches actually seen instead of len(train_loader): this also
        # works for loaders without __len__ and cannot divide by zero on an empty loader.
        epoch_loss = running_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(epoch_loss)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss}")

    return epoch_losses

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

## Create the student and teacher

In [7]:
# Teacher model: VGG-16 initialised from a previously trained checkpoint.
# map_location='cpu' deserialises the stored tensors on the CPU no matter which device
# they were saved from; load_state_dict then copies them into the on-device parameters.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted
# source (consider weights_only=True where the installed torch version supports it).
# NOTE(review): the file name says 228X228 while the transforms resize to 227x227 — confirm.
teacher = VGG16().to(device)
teacher.load_state_dict(torch.load('./VGG-16_CIFAR-10_228X228.PT', map_location='cpu'))

# Student models: re-seeding before each construction gives the KD student and the
# baseline identical initial weights, so both start from the same point.
torch.manual_seed(42)
student = VGG11().to(device)
torch.manual_seed(42)
baseline = VGG11().to(device)

total_params_teacher = "{:,}".format(sum(p.numel() for p in teacher.parameters()))
total_params_student = "{:,}".format(sum(p.numel() for p in student.parameters()))

print(f'Teacher VGG-16 parameters: {total_params_teacher}\nStudent VGG-11 parameters: {total_params_student}')

Teacher VGG-16 parameters: 134,309,962
Student VGG-11 parameters: 59,606,794


In [8]:
# Sanity check: both students were built from the same seed, so the weight norms
# of their first convolution should be exactly the same.
baseline_first_conv = baseline.layer1[0].weight
student_first_conv = student.layer1[0].weight
print("Norm of 1st layer of baseline:", torch.norm(baseline_first_conv).item())
print("Norm of 1st layer of student:", torch.norm(student_first_conv).item())

Norm of 1st layer of baseline: 4.591419696807861
Norm of 1st layer of student: 4.591419696807861


### Train the baseline student network

In [9]:
# Baseline: the student architecture trained on the hard labels only (no teacher),
# then scored on the test split.
train(
    model=baseline,
    train_loader=train_loader,
    epochs=10,
    learning_rate=0.001,
    device=device,
)
test_accuracy_baseline_ce = test(model=baseline, test_loader=test_loader, device=device)

Epoch 1/10, Step 1
Epoch 1/10, Step 2
Epoch 1/10, Step 3
Epoch 1/10, Step 4
Epoch 1/10, Step 5
Epoch 1/10, Step 6
Epoch 1/10, Step 7
Epoch 1/10, Step 8
Epoch 1/10, Step 9
Epoch 1/10, Step 10
Epoch 1/10, Step 11
Epoch 1/10, Step 12
Epoch 1/10, Step 13
Epoch 1/10, Step 14
Epoch 1/10, Step 15
Epoch 1/10, Step 16
Epoch 1/10, Step 17
Epoch 1/10, Step 18
Epoch 1/10, Step 19
Epoch 1/10, Step 20
Epoch 1/10, Step 21
Epoch 1/10, Step 22
Epoch 1/10, Step 23
Epoch 1/10, Step 24
Epoch 1/10, Step 25
Epoch 1/10, Step 26
Epoch 1/10, Step 27
Epoch 1/10, Step 28
Epoch 1/10, Step 29
Epoch 1/10, Step 30
Epoch 1/10, Step 31
Epoch 1/10, Step 32
Epoch 1/10, Step 33
Epoch 1/10, Step 34
Epoch 1/10, Step 35
Epoch 1/10, Step 36
Epoch 1/10, Step 37
Epoch 1/10, Step 38
Epoch 1/10, Step 39
Epoch 1/10, Step 40
Epoch 1/10, Step 41
Epoch 1/10, Step 42
Epoch 1/10, Step 43
Epoch 1/10, Step 44
Epoch 1/10, Step 45
Epoch 1/10, Step 46
Epoch 1/10, Step 47
Epoch 1/10, Step 48
Epoch 1/10, Step 49
Epoch 1/10, Step 50
Epoch 1/1

In [10]:
# Score the pretrained teacher on the test split and show it next to the
# cross-entropy-only baseline student.
test_accuracy_teacher = test(model=teacher, test_loader=test_loader, device=device)

accuracy_report = [
    f"Teacher accuracy: {test_accuracy_teacher:.2f}%",
    f"Student accuracy: {test_accuracy_baseline_ce:.2f}%",
]
print("\n".join(accuracy_report))

Test Accuracy: 82.61%
Teacher accuracy: 82.61%
Student accuracy: 81.31%


## Knowledge distillation training with Cross Entropy loss

In [11]:
def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    """Train ``student`` on the true labels plus the temperature-softened outputs of ``teacher``.

    For every batch the loss is
    ``soft_target_loss_weight * KL(teacher || student) * T**2 + ce_loss_weight * CE(student, labels)``,
    where both distributions in the KL term are softmaxes of ``logits / T`` and the KL term
    is averaged over the batch. Only the student's parameters are updated.

    Args:
        teacher: frozen reference network; moved to ``device`` and put in eval mode.
        student: network being trained; moved to ``device`` and left in train mode.
        train_loader: iterable of ``(inputs, labels)`` batches. It is iterated once per
            epoch and does not need to implement ``__len__``.
        epochs: number of passes over ``train_loader``.
        learning_rate: learning rate handed to Adam.
        T: softmax temperature applied to both teacher and student logits.
        soft_target_loss_weight: weight of the distillation (soft target) term.
        ce_loss_weight: weight of the cross-entropy term on the true labels.
        device: device the models and every batch are moved to.

    Returns:
        list[float]: mean combined loss of each epoch (``nan`` for an epoch without batches).
    """
    # Same contract as test(): callers do not have to move the models beforehand.
    teacher.to(device)
    student.to(device)

    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    epoch_losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            with torch.no_grad():
                teacher_logits = teacher(inputs)
                # Softened teacher distribution and its logarithm. The log comes from
                # log_softmax rather than softmax(...).log(): a probability that underflows
                # to 0 would otherwise give log(0) = -inf and 0 * -inf = NaN in the loss.
                soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
                soft_targets_log = nn.functional.log_softmax(teacher_logits / T, dim=-1)

            # Forward pass with the student model
            student_logits = student(inputs)

            # Softened student log-probabilities
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
            soft_targets_loss = torch.sum(soft_targets * (soft_targets_log - soft_prob)) / soft_prob.size()[0] * (T**2)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

            print(f"Epoch {epoch+1}/{epochs}, Step {i+1}")

        # Average over the batches actually seen instead of len(train_loader): this also
        # works for loaders without __len__ and cannot divide by zero on an empty loader.
        epoch_loss = running_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(epoch_loss)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss}")

    return epoch_losses

In [12]:
# Distil the teacher into the student with temperature T=2. The loss weights are an
# arbitrary choice: 0.75 on the true-label cross-entropy, 0.25 on the distillation term.
train_knowledge_distillation(
    teacher=teacher,
    student=student,
    train_loader=train_loader,
    epochs=10,
    learning_rate=0.001,
    T=2,
    soft_target_loss_weight=0.25,
    ce_loss_weight=0.75,
    device=device,
)
test_accuracy_student_ce_and_kd = test(model=student, test_loader=test_loader, device=device)

Epoch 1/10, Step 1
Epoch 1/10, Step 2
Epoch 1/10, Step 3
Epoch 1/10, Step 4
Epoch 1/10, Step 5
Epoch 1/10, Step 6
Epoch 1/10, Step 7
Epoch 1/10, Step 8
Epoch 1/10, Step 9
Epoch 1/10, Step 10
Epoch 1/10, Step 11
Epoch 1/10, Step 12
Epoch 1/10, Step 13
Epoch 1/10, Step 14
Epoch 1/10, Step 15
Epoch 1/10, Step 16
Epoch 1/10, Step 17
Epoch 1/10, Step 18
Epoch 1/10, Step 19
Epoch 1/10, Step 20
Epoch 1/10, Step 21
Epoch 1/10, Step 22
Epoch 1/10, Step 23
Epoch 1/10, Step 24
Epoch 1/10, Step 25
Epoch 1/10, Step 26
Epoch 1/10, Step 27
Epoch 1/10, Step 28
Epoch 1/10, Step 29
Epoch 1/10, Step 30
Epoch 1/10, Step 31
Epoch 1/10, Step 32
Epoch 1/10, Step 33
Epoch 1/10, Step 34
Epoch 1/10, Step 35
Epoch 1/10, Step 36
Epoch 1/10, Step 37
Epoch 1/10, Step 38
Epoch 1/10, Step 39
Epoch 1/10, Step 40
Epoch 1/10, Step 41
Epoch 1/10, Step 42
Epoch 1/10, Step 43
Epoch 1/10, Step 44
Epoch 1/10, Step 45
Epoch 1/10, Step 46
Epoch 1/10, Step 47
Epoch 1/10, Step 48
Epoch 1/10, Step 49
Epoch 1/10, Step 50
Epoch 1/1

In [13]:
# Side-by-side test accuracy: the teacher, the student trained without it,
# and the student trained with distillation.
comparison_report = [
    f"Teacher accuracy: {test_accuracy_teacher:.2f}%",
    f"Student accuracy without teacher: {test_accuracy_baseline_ce:.2f}%",
    f"Student accuracy with CE + KD: {test_accuracy_student_ce_and_kd:.2f}%",
]
print("\n".join(comparison_report))

Teacher accuracy: 82.61%
Student accuracy without teacher: 81.31%
Student accuracy with CE + KD: 80.44%


# RESULTS:

## Using teacher with 82.61% accuracy

### Students trained with 3 epochs
    ~baseline = 58.63%
    ~student_KD = 61.12%
    ~improvement = 2.49%

### Students trained with 10 epochs
    ~baseline = 81.31%
    ~student_KD = 80.44%
    ~improvement = -0.87%

In [14]:
# Keep training the already-distilled student for 3 more epochs, now with temperature T=4
# (same loss weights as before), then re-evaluate it. This overwrites the previous
# value of test_accuracy_student_ce_and_kd.
train_knowledge_distillation(
    teacher=teacher,
    student=student,
    train_loader=train_loader,
    epochs=3,
    learning_rate=0.001,
    T=4,
    soft_target_loss_weight=0.25,
    ce_loss_weight=0.75,
    device=device,
)
test_accuracy_student_ce_and_kd = test(model=student, test_loader=test_loader, device=device)
for report_line in (
    f"Teacher accuracy: {test_accuracy_teacher:.2f}%",
    f"Student accuracy without teacher: {test_accuracy_baseline_ce:.2f}%",
    f"Student accuracy with CE + KD: {test_accuracy_student_ce_and_kd:.2f}%",
):
    print(report_line)

Epoch 1/3, Step 1
Epoch 1/3, Step 2
Epoch 1/3, Step 3
Epoch 1/3, Step 4
Epoch 1/3, Step 5
Epoch 1/3, Step 6
Epoch 1/3, Step 7
Epoch 1/3, Step 8
Epoch 1/3, Step 9
Epoch 1/3, Step 10
Epoch 1/3, Step 11
Epoch 1/3, Step 12
Epoch 1/3, Step 13
Epoch 1/3, Step 14
Epoch 1/3, Step 15
Epoch 1/3, Step 16
Epoch 1/3, Step 17
Epoch 1/3, Step 18
Epoch 1/3, Step 19
Epoch 1/3, Step 20
Epoch 1/3, Step 21
Epoch 1/3, Step 22
Epoch 1/3, Step 23
Epoch 1/3, Step 24
Epoch 1/3, Step 25
Epoch 1/3, Step 26
Epoch 1/3, Step 27
Epoch 1/3, Step 28
Epoch 1/3, Step 29
Epoch 1/3, Step 30
Epoch 1/3, Step 31
Epoch 1/3, Step 32
Epoch 1/3, Step 33
Epoch 1/3, Step 34
Epoch 1/3, Step 35
Epoch 1/3, Step 36
Epoch 1/3, Step 37
Epoch 1/3, Step 38
Epoch 1/3, Step 39
Epoch 1/3, Step 40
Epoch 1/3, Step 41
Epoch 1/3, Step 42
Epoch 1/3, Step 43
Epoch 1/3, Step 44
Epoch 1/3, Step 45
Epoch 1/3, Step 46
Epoch 1/3, Step 47
Epoch 1/3, Step 48
Epoch 1/3, Step 49
Epoch 1/3, Step 50
Epoch 1/3, Step 51
Epoch 1/3, Step 52
Epoch 1/3, Step 53
Ep