In [33]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Check if GPU is available, and if not, use the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

Loading CIFAR-10
================

CIFAR-10 is a popular image dataset with ten classes. For each input
image, our objective is to predict one of the following classes:
airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck.

![Example of CIFAR-10
images](https://pytorch.org/tutorials//../_static/img/cifar10.png){.align-center}

The input images are RGB, so each one has 3 channels of 32x32 pixels.
Each image is therefore described by 3 x 32 x 32 = 3072 numbers ranging
from 0 to 255. A common practice in neural networks is to normalize the
input. There are several reasons for this, including avoiding saturation
in commonly used activation functions and increasing numerical stability.
`transforms.ToTensor()` first scales the pixel values to the range
[0, 1]. Our normalization then subtracts the mean and divides by the
standard deviation along each channel. The values
`mean=[0.485, 0.456, 0.406]` and `std=[0.229, 0.224, 0.225]` are
precomputed per-channel statistics. These particular numbers are the
widely used ImageNet statistics. The per-channel statistics of the
CIFAR-10 training set are similar but not identical. We use the same
values for the test set and do not recompute the mean and standard
deviation from scratch. The network is trained on inputs normalized with
these values, so test inputs must be transformed the same way for
consistency. Furthermore, in real life, we would not be able to compute
the mean and standard deviation of the test set, because under our
assumptions this data would not be accessible at training time.
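The following plain-Python sketch shows the per-channel arithmetic. It does not use torchvision, and the helper name `normalize_pixel` is hypothetical. A raw 8-bit value is first divided by 255, as `ToTensor` does, and then shifted and scaled with the same fixed per-channel statistics. This happens for every image, whether it comes from the training set or the test set.

```python
# Per-channel statistics used in this tutorial (R, G, B).
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb):
    """Hypothetical helper: normalize one raw RGB pixel (0-255 per channel).

    Mirrors ToTensor (divide by 255) followed by Normalize ((x - mean) / std).
    """
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD))

# The same fixed statistics are applied to every image, training or test.
black = normalize_pixel((0, 0, 0))
white = normalize_pixel((255, 255, 255))
print([round(x, 3) for x in black])
print([round(x, 3) for x in white])
```

With these statistics, each channel ends up roughly in the range -2.1 to 2.6 rather than 0 to 255.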

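As a closing point, when a held-out set like this one is used to guide
modeling choices, it is often called the validation set. A separate set,
called the test set, is then used only after the model's performance has
been optimized on the validation set. This avoids selecting a model
through greedy, biased optimization of a single metric on the same data
used to report final results. For simplicity, this tutorial uses the
CIFAR-10 test split as its only held-out set.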
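If you need a separate validation set, one common approach is to carve it out of the training data with a seeded random split of indices. The sketch below is plain Python, and both the helper `train_val_split` and the 45,000/5,000 sizes are hypothetical choices. In PyTorch, `torch.utils.data.random_split` (or `torch.utils.data.Subset` with a list of indices) can play the same role.

```python
import random

def train_val_split(num_samples, val_size, seed=0):
    """Hypothetical helper: return disjoint (train_indices, val_indices)."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # seeded, so the split is reproducible
    return indices[val_size:], indices[:val_size]

# CIFAR-10 has 50,000 training images; hold out 5,000 of them for validation.
train_idx, val_idx = train_val_split(50_000, 5_000, seed=42)
print(len(train_idx), len(val_idx))
```

Because the split is seeded, every run (and every model you compare) sees the same validation images.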


In [34]:
# Preprocess CIFAR-10: ToTensor scales pixels to [0, 1], then Normalize standardizes each channel.
# The batch size of 128 used in the DataLoaders below is an arbitrary choice.
transforms_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Loading the CIFAR-10 dataset:
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_cifar)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar)

Files already downloaded and verified
Files already downloaded and verified


In [35]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)
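`DataLoader` uses `drop_last=False` by default. In that case, the number of batches per epoch, which is what `len(train_loader)` returns, is the dataset size divided by the batch size, rounded up, and the last batch is smaller than the others. Here is a quick plain-Python check for CIFAR-10's 50,000 training images and 10,000 test images. The helper `batch_stats` is ours.

```python
import math

def batch_stats(num_samples, batch_size):
    """Number of batches and size of the final batch when drop_last=False."""
    num_batches = math.ceil(num_samples / batch_size)
    last_batch = num_samples - (num_batches - 1) * batch_size
    return num_batches, last_batch

print(batch_stats(50_000, 128))  # training set: (391, 80)
print(batch_stats(10_000, 128))  # test set: (79, 16)
```

For example, the epoch loss that `train` prints below, `running_loss / len(train_loader)`, is an average over 391 per-batch losses.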

Defining model classes and utility functions
============================================

Next, we need to define our model classes. We use two different
architectures and keep the number of filters of each architecture fixed
across our experiments to ensure fair comparisons. Both architectures
are Convolutional Neural Networks (CNNs). They have a different number
of convolutional layers that serve as feature extractors, followed by a
classifier with 10 classes. The student has fewer convolutional layers,
filters, and neurons than the teacher.

In [36]:
# Deeper neural network class to be used as teacher:
class DeepNN(nn.Module):
    def __init__(self, num_classes=10):
        super(DeepNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

# Lightweight neural network class to be used as student:
class LightNN(nn.Module):
    def __init__(self, num_classes=10):
        super(LightNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
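The input sizes of the first `nn.Linear` layers (2048 and 1024) follow from the feature extractors. A 3x3 convolution with `padding=1` keeps the spatial size unchanged, and each 2x2 max pool with stride 2 halves it. The following plain-Python sketch tracks that bookkeeping. The helpers `conv_out` and `flattened_size` are ours, not PyTorch's.

```python
def conv_out(size, kernel=3, padding=1, stride=1):
    """Output side length of a square convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

def flattened_size(in_size, layers, in_channels=3):
    """Hypothetical helper: track (channels, side) through ('conv', c_out) / ('pool',) layers."""
    channels, side = in_channels, in_size
    for layer in layers:
        if layer[0] == "conv":
            channels, side = layer[1], conv_out(side)
        else:  # 2x2 max pool, stride 2, no padding
            side = conv_out(side, kernel=2, padding=0, stride=2)
    return channels * side * side

deep = [("conv", 128), ("conv", 64), ("pool",), ("conv", 64), ("conv", 32), ("pool",)]
light = [("conv", 16), ("pool",), ("conv", 16), ("pool",)]
print(flattened_size(32, deep))   # 2048 = 32 channels * 8 * 8
print(flattened_size(32, light))  # 1024 = 16 channels * 8 * 8
```

Both networks end at an 8x8 spatial map, so the difference in classifier input comes entirely from the number of channels in the last convolution.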

We use two functions to train and evaluate models on our original
classification task. The first, `train`, takes the following arguments:

-   `model`: A model instance to train (update its weights) via this
    function.
-   `train_loader`: We defined our `train_loader` above, and its job is
    to feed the data into the model.
-   `epochs`: How many times we loop over the dataset.
-   `learning_rate`: The learning rate determines how large our steps
    towards convergence should be. Steps that are too large can make
    training unstable or even diverge, while steps that are too small
    make convergence slow.
-   `device`: Determines the device to run the workload on. Can be
    either CPU or GPU depending on availability.

The `test` function takes `model`, `test_loader`, and `device`. It is
invoked with `test_loader` to load images from the test set, and it
returns the classification accuracy as a percentage.
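The following is a toy illustration of the `learning_rate` trade-off. It is not the tutorial's training setup. It runs plain gradient descent on the one-dimensional function f(x) = x², whose gradient is 2x, so each step multiplies x by (1 - 2 * learning_rate). Adam, the optimizer used in `train`, rescales each parameter's step individually, but `learning_rate` still sets the overall step size.

```python
def gradient_descent(learning_rate, steps=20, x0=1.0):
    """Minimize f(x) = x**2 (gradient 2x) with plain gradient descent."""
    x = x0
    for _ in range(steps):
        x -= learning_rate * 2 * x
    return x

# Too small: barely moves. Moderate: converges quickly. Too large: diverges.
for lr in (0.01, 0.4, 1.1):
    print(f"lr={lr}: x after 20 steps = {gradient_descent(lr):.4g}")
```

The minimum is at x = 0. With `lr=0.01` the iterate is still far from it after 20 steps. With `lr=0.4` it is essentially there. With `lr=1.1` it oscillates with growing magnitude.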

![Train both networks with Cross-Entropy. The student will be used as a
baseline:](https://pytorch.org/tutorials//../_static/img/knowledge_distillation/ce_only.png){.align-center}

Cross-entropy runs
==================

For reproducibility, we set the random seed with `torch.manual_seed`.
We train the student network with different methods. To compare them
fairly, it makes sense to initialize each copy with the same weights. We
start by training the teacher network using cross-entropy:
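`torch.manual_seed` resets PyTorch's random number generator, so layers created afterwards draw the same initial weights. The sketch below shows the same idea in plain Python. It uses `random.Random` as a stand-in for PyTorch's generator, and the `init_weights` helper is hypothetical.

```python
import random

def init_weights(seed, n=5):
    """Hypothetical initializer: draw n weights from a freshly seeded generator."""
    rng = random.Random(seed)
    return [rng.gauss(0.0, 0.1) for _ in range(n)]

first = init_weights(42)
second = init_weights(42)  # same seed -> identical "initial weights"
other = init_weights(7)    # different seed -> different weights
print(first == second, first == other)
```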

In [37]:
def train(model, train_loader, epochs, learning_rate, device):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            # inputs: A collection of batch_size images
            # labels: A vector of dimensionality batch_size with integers denoting class of each image
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            # outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes
            # labels: The actual labels of the images. Vector of dimensionality batch_size
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy
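The accuracy in `test` works as follows. For each image, it takes the index of the largest logit (the second value returned by `torch.max(outputs.data, 1)`) and compares it with the label. The following plain-Python sketch does the same bookkeeping on a made-up batch of logits. The helper `batch_accuracy` is ours.

```python
def batch_accuracy(logits, labels):
    """Percentage of rows whose argmax matches the label."""
    predicted = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(p == y for p, y in zip(predicted, labels))
    return 100 * correct / len(labels)

# Made-up logits for 4 images over 3 classes; the last image is misclassified.
logits = [[2.0, 0.1, -1.0], [0.3, 1.5, 0.2], [0.0, 0.2, 3.1], [1.2, 0.9, 0.8]]
labels = [0, 1, 2, 2]
print(f"{batch_accuracy(logits, labels):.2f}%")  # 75.00%
```

Note that `test` accumulates `correct` and `total` over all batches before dividing, rather than averaging per-batch accuracies. As a result, the smaller final batch gets the right weight.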

In [41]:
torch.manual_seed(42)
nn_deep = DeepNN(num_classes=10).to(device)
train(nn_deep, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_deep = test(nn_deep, test_loader, device)

# Instantiate the lightweight network:
torch.manual_seed(42)
nn_light = LightNN(num_classes=10).to(device)

Epoch 1/10, Loss: 1.3290103455943525
Epoch 2/10, Loss: 0.8618725254712507
Epoch 3/10, Loss: 0.6830734239362389
Epoch 4/10, Loss: 0.5315783902659745
Epoch 5/10, Loss: 0.4122244017127225
Epoch 6/10, Loss: 0.3005805501852499
Epoch 7/10, Loss: 0.21594502018464495
Epoch 8/10, Loss: 0.1664198536873626
Epoch 9/10, Loss: 0.1377350527107182
Epoch 10/10, Loss: 0.11376107812327954
Test Accuracy: 75.72%


In [42]:
torch.manual_seed(42)
new_nn_light = LightNN(num_classes=10).to(device)

To check that we have created a copy of the first network, we inspect
the norm of its first layer. Both networks were created right after
setting the same seed, so matching norms are a good sanity check that
their weights are indeed the same:

In [43]:
# Print the norm of the first layer of the initial lightweight model
print("Norm of 1st layer of nn_light:", torch.norm(nn_light.features[0].weight).item())
# Print the norm of the first layer of the new lightweight model
print("Norm of 1st layer of new_nn_light:", torch.norm(new_nn_light.features[0].weight).item())

Norm of 1st layer of nn_light: 2.327361822128296
Norm of 1st layer of new_nn_light: 2.327361822128296
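Matching norms are strong evidence here, because both models were created from the same seed. Still, a norm match alone does not prove that two weight tensors are equal, since different values can share a norm. The plain-Python illustration below uses two made-up weight vectors. For real models, a stricter check is to compare every parameter element-wise, for example by applying `torch.equal` to matching entries of the two models' `state_dict()`.

```python
import math

def l2_norm(weights):
    """Euclidean (L2) norm of a flat list of weights."""
    return math.sqrt(sum(w * w for w in weights))

a = [3.0, 4.0]
b = [4.0, 3.0]
print(l2_norm(a), l2_norm(b))  # 5.0 5.0
print(a == b)                  # False: same norm, different weights
```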


## Print the total number of parameters of each model

In [44]:
total_params_deep = "{:,}".format(sum(p.numel() for p in nn_deep.parameters()))
print(f"DeepNN parameters: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in nn_light.parameters()))
print(f"LightNN parameters: {total_params_light}")

DeepNN parameters: 1,186,986
LightNN parameters: 267,738
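These totals can be checked by hand. A `Conv2d` layer has `in_channels * out_channels * kernel_h * kernel_w` weights plus `out_channels` biases. A `Linear` layer has `in_features * out_features` weights plus `out_features` biases. ReLU, pooling, and dropout layers have no parameters. The following plain-Python sketch does the count. The helper names are ours.

```python
def conv_params(c_in, c_out, k=3):
    """Weights plus biases of a k x k Conv2d layer."""
    return c_in * c_out * k * k + c_out

def linear_params(n_in, n_out):
    """Weights plus biases of a Linear layer."""
    return n_in * n_out + n_out

deep_total = (conv_params(3, 128) + conv_params(128, 64) + conv_params(64, 64)
              + conv_params(64, 32) + linear_params(2048, 512) + linear_params(512, 10))
light_total = (conv_params(3, 16) + conv_params(16, 16)
               + linear_params(1024, 256) + linear_params(256, 10))
print(f"{deep_total:,} {light_total:,}")  # 1,186,986 267,738
```

Most of the teacher's parameters (1,049,088) sit in its first fully connected layer, `nn.Linear(2048, 512)`.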


## Train and test the lightweight network with cross-entropy loss

In [45]:
train(nn_light, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_light_ce = test(nn_light, test_loader, device)

Epoch 1/10, Loss: 1.4703492529861761
Epoch 2/10, Loss: 1.1586465205987702
Epoch 3/10, Loss: 1.0233277681538515
Epoch 4/10, Loss: 0.9199803726142629
Epoch 5/10, Loss: 0.84628275547491
Epoch 6/10, Loss: 0.7804365962972422
Epoch 7/10, Loss: 0.7133219540881379
Epoch 8/10, Loss: 0.6576626432673706
Epoch 9/10, Loss: 0.6038828446432147
Epoch 10/10, Loss: 0.5523658263713808
Test Accuracy: 69.91%


Based on test accuracy, we can now compare the deeper network, which will be used as the teacher, with the lightweight network, which is our student. So far, the teacher has not been involved in training the student, so the student achieves this performance on its own. The metrics so far can be printed with the following lines:

In [46]:
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy: {test_accuracy_light_ce:.2f}%")

Teacher accuracy: 75.72%
Student accuracy: 69.91%


Knowledge distillation run
==========================

Now let's try to improve the test accuracy of the student network by
incorporating the teacher. Knowledge distillation is a straightforward
technique to achieve this. It relies on the fact that both networks
output a probability distribution over the same classes, so the two
networks have the same number of output neurons. The method adds an
extra loss term, based on the softmax output of the teacher network, to
the traditional cross-entropy loss. The assumption is that the output
activations of a properly trained teacher network carry additional
information that a student network can leverage during training.

The original work suggests that utilizing ratios of smaller
probabilities in the soft targets can help achieve the underlying
objective of deep neural networks. That objective is to create a
similarity structure over the data where similar objects are mapped
closer together. For example, in CIFAR-10, a truck could be mistaken for
an automobile or an airplane if its wheels are present, but it is less
likely to be mistaken for a dog. Therefore, it makes sense to assume
that valuable information resides not only in the top prediction of a
properly trained model but in its entire output distribution. However,
cross-entropy alone does not sufficiently exploit this information. The
one-hot labels say nothing about which wrong classes are similar. In
addition, the model's activations for non-predicted classes tend to be
so small that the propagated gradients do not meaningfully change the
weights to construct this desirable vector space.
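To make this concrete, the following plain-Python sketch uses made-up teacher logits for four of the CIFAR-10 classes, given an image of a truck. The logits are illustrative, not outputs of the tutorial's `DeepNN`. The one-hot label treats automobile and dog as equally wrong. The teacher's softmax output, by contrast, ranks automobile far above dog.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)  # subtract the max before exponentiating
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

classes = ["truck", "automobile", "airplane", "dog"]
one_hot = [1.0, 0.0, 0.0, 0.0]          # hard label: truck
teacher_logits = [9.0, 5.0, 3.0, -1.0]  # made-up teacher outputs
teacher_probs = softmax(teacher_logits)
for name, hard, soft in zip(classes, one_hot, teacher_probs):
    print(f"{name:>10}: hard={hard:.0f}  teacher={soft:.5f}")
```

Notice also how small the non-target probabilities are at this default scale: this is the motivation for the temperature `T` introduced next.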

Our first helper function that introduces a teacher-student dynamic
needs a few extra parameters:

-   `T`: Temperature controls the smoothness of the output
    distributions. Larger `T` leads to smoother distributions, thus
    smaller probabilities get a larger boost.
-   `soft_target_loss_weight`: A weight assigned to the extra objective
    we're about to include.
-   `ce_loss_weight`: A weight assigned to cross-entropy. Tuning these
    weights pushes the network towards optimizing for either objective.
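The plain-Python sketch below shows how `T` reshapes a distribution, assuming the usual approach of dividing the logits by `T` before the softmax. The logits are made up, and the helper name `softmax_with_temperature` is ours. The ranking of the classes never changes, but at higher `T` the small probabilities grow much larger relative to the top one.

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """Softmax of logits / T; T > 1 gives a smoother distribution."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [9.0, 5.0, 3.0, -1.0]  # made-up: truck, automobile, airplane, dog
for T in (1.0, 2.0, 4.0):
    probs = softmax_with_temperature(logits, T)
    print(f"T={T}: " + ", ".join(f"{p:.3f}" for p in probs))
```

At `T=1` almost all of the mass sits on the top class. At `T=4` the distribution is roughly 0.60, 0.22, 0.13, 0.05, which exposes the relative similarities to the student.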

![Distillation loss is calculated from the logits of the networks. It
only returns gradients to the
student:](https://pytorch.org/tutorials//../_static/img/knowledge_distillation/distillation_output_loss.png){.align-center}
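One common way to write the combined objective follows the formulation of the original distillation work. It assumes that the soft-target term is a KL divergence between temperature-softened distributions. The exact form implemented in this tutorial's helper function may differ. In the formula, $z_s$ and $z_t$ are the student and teacher logits, $y$ is the label, and $\sigma$ is the softmax:

```latex
\mathcal{L} = w_{\text{ce}} \, \mathrm{CE}\big(y, \sigma(z_s)\big)
            + w_{\text{soft}} \, T^2 \, \mathrm{KL}\big(\sigma(z_t / T) \,\|\, \sigma(z_s / T)\big)
```

Here $w_{\text{ce}}$ corresponds to `ce_loss_weight` and $w_{\text{soft}}$ to `soft_target_loss_weight`. The teacher's logits are treated as constants, so gradients flow only into the student. In PyTorch this is typically done by computing the teacher's outputs under `torch.no_grad()`. The gradients of the softened term shrink roughly as $1/T^2$, so the $T^2$ factor keeps them on a scale comparable to the cross-entropy gradients as $T$ changes.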
