In [24]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

Loading CIFAR-10
================

CIFAR-10 is a popular image dataset with ten classes. Our objective is
to predict one of the following classes for each input image.

![Example of CIFAR-10
images](https://pytorch.org/tutorials//../_static/img/cifar10.png){.align-center}

The input images are RGB, so they have 3 channels and are 32x32 pixels.
Basically, each image is described by 3 x 32 x 32 = 3072 numbers ranging
from 0 to 255. A common practice in neural networks is to normalize the
input, which is done for multiple reasons, including avoiding saturation
in commonly used activation functions and increasing numerical
stability. Our normalization process consists of subtracting the mean
and dividing by the standard deviation along each channel. The tensors
\"mean=\[0.485, 0.456, 0.406\]\" and \"std=\[0.229, 0.224, 0.225\]\"
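were already computed. Note that these particular numbers are the
widely used ImageNet channel statistics; they are applied here as a fixed,
precomputed stand-in for the per-channel mean and standard deviation of
the predefined subset of CIFAR-10 intended to be the training set. Notice
how we use these values for the test set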
as well, without recomputing the mean and standard deviation from
scratch. This is because the network was trained on features produced by
subtracting and dividing the numbers above, and we want to maintain
consistency. Furthermore, in real life, we would not be able to compute
the mean and standard deviation of the test set since, under our
assumptions, this data would not be accessible at that point.

As a closing point, we often refer to this held-out set as the
validation set, and we use a separate set, called the test set, after
optimizing a model\'s performance on the validation set. This is done to
avoid selecting a model based on the greedy and biased optimization of a
single metric.


In [25]:
# Preprocessing applied to every CIFAR-10 image: convert it to a tensor, then
# normalize each of the three channels with fixed, precomputed statistics.
# (The tutorial uses an arbitrary batch size of 128 when batching this data.)
# NOTE(review): these mean/std values look like the common ImageNet statistics
# rather than ones measured on CIFAR-10 -- confirm this is intended.
transforms_cifar = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# CIFAR-10 training and test splits, stored under ./data (downloaded if needed).
# Both splits go through the same preprocessing pipeline.
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transforms_cifar,
)
test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transforms_cifar,
)

Files already downloaded and verified
Files already downloaded and verified


In [26]:
# Mini-batch loaders over the two splits: batches of 128 images, 2 loader workers.
# Only the training data is shuffled; the test split is read in a fixed order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
)

Defining model classes and utility functions
============================================

Next, we need to define our model classes. Several user-defined
parameters need to be set here. We use two different architectures,
keeping the number of filters fixed across our experiments to ensure
fair comparisons. Both architectures are Convolutional Neural Networks
(CNNs) with a different number of convolutional layers that serve as
feature extractors, followed by a classifier with 10 classes. The number
of filters and neurons is smaller for the student.

In [32]:
# Teacher model: the deeper of the two CNNs.
class DeepNN(nn.Module):
    """Four-convolution CNN with a two-layer fully connected classifier.

    The feature extractor uses 3x3 convolutions with padding 1 (spatial size
    preserved) and two 2x2 max-pools (spatial size halved twice), ending with
    32 channels. The first linear layer expects 2048 flattened features, i.e.
    32 channels x 8 x 8, which matches 32x32 input images.

    Args:
        num_classes: Number of logits produced per image.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Layers are created in the same order as they are applied.
        conv_stack = [
            nn.Conv2d(in_channels=3, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.features = nn.Sequential(*conv_stack)

        head = [
            nn.Linear(in_features=2048, out_features=512),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(in_features=512, out_features=num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits for a batch of images."""
        feature_maps = self.features(x)
        # Keep the batch dimension and collapse everything else into one vector.
        flat = torch.flatten(feature_maps, start_dim=1)
        return self.classifier(flat)

# Student model: the smaller of the two CNNs.
class LightNN(nn.Module):
    """Two-convolution CNN with a two-layer fully connected classifier.

    Each 3x3 convolution (padding 1) is followed by ReLU and a 2x2 max-pool,
    so the spatial size is halved twice and the extractor ends with 16
    channels. The first linear layer expects 1024 flattened features, i.e.
    16 channels x 8 x 8, which matches 32x32 input images.

    Args:
        num_classes: Number of logits produced per image.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Layers are created in the same order as they are applied.
        conv_stack = [
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.features = nn.Sequential(*conv_stack)

        head = [
            nn.Linear(in_features=1024, out_features=256),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(in_features=256, out_features=num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits for a batch of images."""
        feature_maps = self.features(x)
        # Keep the batch dimension and collapse everything else into one vector.
        flat = torch.flatten(feature_maps, start_dim=1)
        return self.classifier(flat)

We employ 2 functions to help us produce and evaluate the results on our
original classification task. One function is called `train` and takes
the following arguments:

-   `model`: A model instance to train (update its weights) via this
    function.
-   `train_loader`: We defined our `train_loader` above, and its job is
    to feed the data into the model.
-   `epochs`: How many times we loop over the dataset.
-   `learning_rate`: The learning rate determines how large our steps
    towards convergence should be. Too large or too small steps can be
    detrimental.
-   `device`: Determines the device to run the workload on. Can be
    either CPU or GPU depending on availability.

Our test function is similar, but it will be invoked with `test_loader`
to load images from the test set.

![Train both networks with Cross-Entropy. The student will be used as a
baseline:](https://pytorch.org/tutorials//../_static/img/knowledge_distillation/ce_only.png){.align-center}

Cross-entropy runs
==================

For reproducibility, we need to set the torch manual seed. We train
networks using different methods, so to compare them fairly, it makes
sense to initialize the networks with the same weights. Start by
training the teacher network using cross-entropy.

In [28]:
def train(model, train_loader, epochs, learning_rate, device):
    """Train ``model`` in place with cross-entropy loss and the Adam optimizer.

    After every epoch the mean per-batch loss is printed. Nothing is returned;
    the model's weights are updated in place and the model is left in
    training mode.

    Args:
        model: The ``nn.Module`` whose weights are updated.
        train_loader: Iterable yielding ``(inputs, labels)`` batches. It must
            support ``len()``, which is used to average the epoch loss.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate handed to Adam.
        device: Device on which the model and every batch are placed.
    """
    # Keep the parameters on the same device as the batches, as test() does.
    # Done before the optimizer is built; a no-op if the model is already there.
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            # inputs: A collection of batch_size images
            # labels: A vector of dimensionality batch_size with integers denoting class of each image
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            # outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes
            # labels: The actual labels of the images. Vector of dimensionality batch_size
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accurac

In [29]:
# Seed the RNG right before building the teacher so its initial weights are reproducible.
torch.manual_seed(42)
nn_deep = DeepNN(num_classes=10).to(device)
# Train the teacher with cross-entropy for a single epoch, then record its test accuracy.
train(nn_deep, train_loader, epochs=1, learning_rate=0.001, device=device)
test_accuracy_deep = test(nn_deep, test_loader, device)

# Instantiate the lightweight network:
# The seed is reset to the same value first so the student's initialization is reproducible too.
torch.manual_seed(42)
nn_light = LightNN(num_classes=10).to(device)

Epoch 1/1, Loss: 1.3284041829731152
Test Accuracy: 65.00%


NameError: name 'accurac' is not defined

In [30]:
# Sanity check: display the device selected at the top of the notebook.
device

device(type='cuda')

In [31]:
# Sanity check: report whether PyTorch can see a CUDA device.
print(torch.cuda.is_available())


True
