<a href="https://colab.research.google.com/github/YSW2/CV-KnowledgeDistillation/blob/master/TAKD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Run on CUDA when PyTorch reports a usable GPU; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
def train(model, train_loader, epochs, learning_rate, device):
    """
    Trains ``model`` in place with cross-entropy loss and Nesterov SGD.

    The optimizer uses momentum 0.9 and weight decay 1e-4. After every epoch the
    mean per-batch loss is printed. The model is left in training mode and is
    NOT moved to ``device`` here; only the batches are.

    :param model: Network mapping a batch of inputs to per-class logits.
    :param train_loader: Sized iterable yielding ``(inputs, labels)`` batches.
    :param epochs: Number of full passes over ``train_loader``.
    :param learning_rate: Learning rate handed to SGD.
    :param device: Device each batch is moved to before the forward pass.
    :return: None
    """
    loss_fn = nn.CrossEntropyLoss()
    sgd = optim.SGD(
        model.parameters(),
        lr=learning_rate,
        momentum=0.9,
        nesterov=True,
        weight_decay=1e-4,
    )

    model.train()

    for epoch_number in range(1, epochs + 1):
        loss_total = 0.0
        for batch_inputs, batch_labels in train_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            # Gradients accumulate across backward() calls, so clear them per batch.
            sgd.zero_grad()

            # Logits are (batch_size, num_classes); labels are class indices.
            batch_loss = loss_fn(model(batch_inputs), batch_labels)
            batch_loss.backward()
            sgd.step()

            loss_total += batch_loss.item()

        mean_loss = loss_total / len(train_loader)
        print(f"Epoch {epoch_number}/{epochs}, Loss: {mean_loss}")

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [3]:
def data_loader(num_classes=10):
    """
    Returns the CIFAR train and test datasets for the requested label set.

    Both splits are downloaded to ``./data`` if missing and share one transform:
    conversion to a tensor followed by per-channel normalization. Batching is
    left to the caller (wrap the datasets in a ``DataLoader``).

    :param num_classes: 10 for CIFAR-10 or 100 for CIFAR-100.
    :return: Tuple ``(train_dataset, test_dataset)``.
    :raises ValueError: if ``num_classes`` is neither 10 nor 100.
    """
    # Validate first: previously an unsupported value fell through both
    # branches and failed with an UnboundLocalError at the return statement.
    if num_classes not in (10, 100):
        raise ValueError(
            f"num_classes must be 10 (CIFAR-10) or 100 (CIFAR-100), got {num_classes!r}"
        )

    # NOTE(review): these mean/std values match the commonly used ImageNet
    # channel statistics rather than CIFAR-specific ones -- confirm intended.
    transforms_cifar = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    dataset_cls = datasets.CIFAR10 if num_classes == 10 else datasets.CIFAR100

    train_dataset = dataset_cls(root='./data', train=True, download=True, transform=transforms_cifar)
    test_dataset = dataset_cls(root='./data', train=False, download=True, transform=transforms_cifar)

    return train_dataset, test_dataset

In [4]:
class ConvNetMaker(nn.Module):
    """
    Creates a convolutional neural network (CNN) based on a given specification of layers.

    The network is sized for inputs of shape (batch_size, 3, 32, 32). Layer
    specifications are matched by prefix:

    * ``"Conv<N>"`` -- 3x3 convolution (padding 1) with N filters, followed by
      BatchNorm2d and ReLU.
    * ``"MaxPool"`` -- 2x2 max pooling with stride 2 (halves height and width).
    * ``"FC<N>"``   -- fully connected layer with N outputs. Every FC layer except
      the last one is followed by ReLU, so the final layer emits raw logits.
    """

    def __init__(self, layers):
        """
        Initializes the CNN model with the specified layers.

        :param layers: An iterable of strings, each representing a layer specification,
                       such as ["Conv64", "MaxPool", "Conv128", "MaxPool", "FC100", "FC10"].
                       "Conv64" means a convolutional layer with 64 filters,
                       "MaxPool" means a max pooling layer,
                       "FC100" means a fully connected layer with 100 neurons.
        :raises ValueError: if a specification is not a Conv, MaxPool or FC entry.
        """
        super().__init__()
        layers = list(layers)  # iterated twice below, so materialize once

        conv_layers = []
        fc_layers = []

        # Track the feature-map shape so the first FC layer gets the right input size.
        h, w, d = 32, 32, 3
        previous_layer_size = h * w * d
        num_fc_layers_remained = sum(1 for spec in layers if spec.startswith("FC"))

        for layer in layers:
            if layer.startswith("Conv"):
                filter_count = int(layer[4:])
                conv_layers += [
                    nn.Conv2d(d, filter_count, kernel_size=3, padding=1),
                    nn.BatchNorm2d(filter_count),
                    nn.ReLU(inplace=True),
                ]
                # padding=1 with a 3x3 kernel keeps h and w; only the depth changes.
                d = filter_count
                previous_layer_size = h * w * d
            elif layer.startswith("MaxPool"):
                conv_layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                h, w = h // 2, w // 2
                previous_layer_size = h * w * d
            elif layer.startswith("FC"):
                num_fc_layers_remained -= 1
                current_layer_size = int(layer[2:])
                fc_layers.append(nn.Linear(previous_layer_size, current_layer_size))
                # The last FC layer produces logits, so it gets no activation.
                if num_fc_layers_remained != 0:
                    fc_layers.append(nn.ReLU(inplace=True))
                previous_layer_size = current_layer_size
            else:
                # Previously unknown entries were skipped silently, which hid typos.
                raise ValueError(f"Unknown layer specification: {layer!r}")

        self.conv_layers = nn.Sequential(*conv_layers)
        self.fc_layers = nn.Sequential(*fc_layers)

    def forward(self, x):
        """
        Defines the forward pass of the model.

        :param x: Input tensor of shape (batch_size, 3, 32, 32)
        :return: Output tensor of shape (batch_size, N), where N is the size of
                 the last FC layer
        """
        x = self.conv_layers(x)
        # Collapse (channels, height, width) into one feature axis per sample.
        x = x.flatten(1)
        x = self.fc_layers(x)
        return x


In [5]:
# Layer specifications for the plain CNNs, in the format ConvNetMaker consumes.
# Each key is the number of convolutional layers in that network, as a string.

# CIFAR-10 variants: every network ends in a 10-way classifier.
plane_cifar10_book = {
    "2": ["Conv16", "MaxPool", "Conv16", "MaxPool", "FC10"],
    "4": [
        "Conv16", "Conv16", "MaxPool",
        "Conv32", "Conv32", "MaxPool",
        "FC10",
    ],
    "6": [
        "Conv16", "Conv16", "MaxPool",
        "Conv32", "Conv32", "MaxPool",
        "Conv64", "Conv64", "MaxPool",
        "FC10",
    ],
    "8": [
        "Conv16", "Conv16", "MaxPool",
        "Conv32", "Conv32", "MaxPool",
        "Conv64", "Conv64", "MaxPool",
        "Conv128", "Conv128", "MaxPool",
        "FC64", "FC10",
    ],
    "10": [
        "Conv32", "Conv32", "MaxPool",
        "Conv64", "Conv64", "MaxPool",
        "Conv128", "Conv128", "MaxPool",
        "Conv256", "Conv256", "Conv256", "Conv256", "MaxPool",
        "FC128", "FC10",
    ],
}


# CIFAR-100 variants: every network ends in a 100-way classifier.
plane_cifar100_book = {
    "2": ["Conv32", "MaxPool", "Conv32", "MaxPool", "FC100"],
    "4": [
        "Conv32", "Conv32", "MaxPool",
        "Conv64", "Conv64", "MaxPool",
        "FC100",
    ],
    # Unlike the CIFAR-10 counterpart, this one has no pooling after the last conv pair.
    "6": [
        "Conv32", "Conv32", "MaxPool",
        "Conv64", "Conv64", "MaxPool",
        "Conv128", "Conv128",
        "FC100",
    ],
    "8": [
        "Conv32", "Conv32", "MaxPool",
        "Conv64", "Conv64", "MaxPool",
        "Conv128", "Conv128", "MaxPool",
        "Conv256", "Conv256", "MaxPool",
        "FC64", "FC100",
    ],
    "10": [
        "Conv32", "Conv32", "MaxPool",
        "Conv64", "Conv64", "MaxPool",
        "Conv128", "Conv128", "MaxPool",
        "Conv256", "Conv256", "Conv256", "Conv256", "MaxPool",
        "FC512", "FC100",
    ],
}

In [6]:
# Build the CIFAR-100 datasets and wrap each split in a batched loader.
train_dataset, test_dataset = data_loader(num_classes=100)

# Only the training split is shuffled; the test split keeps its stored order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:03<00:00, 42756795.34it/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified


In [8]:
# Teacher: the 10-conv-layer CIFAR-100 network, trained with plain cross-entropy.
# Seed before construction so the initial weights are reproducible.
torch.manual_seed(42)
layer = plane_cifar100_book['10']

teacher_model = ConvNetMaker(layer)
teacher_model = teacher_model.to(device)
train(
    teacher_model,
    train_loader,
    epochs=160,
    learning_rate=0.1,
    device=device,
)
test_accuracy_deep = test(teacher_model, test_loader, device)

Epoch 1/160, Loss: 4.0023466302915605
Epoch 2/160, Loss: 3.3825098907246307
Epoch 3/160, Loss: 2.837200911148735
Epoch 4/160, Loss: 2.4425559885361614
Epoch 5/160, Loss: 2.1535565828728247
Epoch 6/160, Loss: 1.9373015598262973
Epoch 7/160, Loss: 1.752702307213298
Epoch 8/160, Loss: 1.5946563842046597
Epoch 9/160, Loss: 1.453705199844087
Epoch 10/160, Loss: 1.3163965981634682
Epoch 11/160, Loss: 1.2052330216178504
Epoch 12/160, Loss: 1.104205213544314
Epoch 13/160, Loss: 0.9951046691526233
Epoch 14/160, Loss: 0.9089377065143927
Epoch 15/160, Loss: 0.831626842546341
Epoch 16/160, Loss: 0.7648328394841051
Epoch 17/160, Loss: 0.6993429978637744
Epoch 18/160, Loss: 0.6465785038440733
Epoch 19/160, Loss: 0.580706291262756
Epoch 20/160, Loss: 0.554457656105461
Epoch 21/160, Loss: 0.5211712013730003
Epoch 22/160, Loss: 0.4704688175407517
Epoch 23/160, Loss: 0.45097348997202674
Epoch 24/160, Loss: 0.4142204585587582
Epoch 25/160, Loss: 0.3983548833890949
Epoch 26/160, Loss: 0.3866790081457714
E

In [9]:
# Baseline student: the 2-conv-layer CIFAR-100 network trained without a teacher.
# `layer` stays a global because the distillation cell builds its student from it.
layer = plane_cifar100_book['2']
torch.manual_seed(42)
student_model = ConvNetMaker(layer)
student_model = student_model.to(device)

train(
    student_model,
    train_loader,
    epochs=160,
    learning_rate=0.1,
    device=device,
)
test_accuracy_light_ce = test(student_model, test_loader, device)

Epoch 1/160, Loss: 4.239775285696434
Epoch 2/160, Loss: 3.8019423356751347
Epoch 3/160, Loss: 3.651276835395247
Epoch 4/160, Loss: 3.5618112514086087
Epoch 5/160, Loss: 3.489645529281148
Epoch 6/160, Loss: 3.4355800554270632
Epoch 7/160, Loss: 3.389972392860276
Epoch 8/160, Loss: 3.3560808839090646
Epoch 9/160, Loss: 3.3215228322216923
Epoch 10/160, Loss: 3.2897250896219705
Epoch 11/160, Loss: 3.270614765489193
Epoch 12/160, Loss: 3.250949508393817
Epoch 13/160, Loss: 3.2386700344817414
Epoch 14/160, Loss: 3.220334542072033
Epoch 15/160, Loss: 3.2080346989204815
Epoch 16/160, Loss: 3.1948150753060265
Epoch 17/160, Loss: 3.1832464399849973
Epoch 18/160, Loss: 3.173683879930345
Epoch 19/160, Loss: 3.165260699094104
Epoch 20/160, Loss: 3.1542178944248676
Epoch 21/160, Loss: 3.1503302953432284
Epoch 22/160, Loss: 3.14686003060597
Epoch 23/160, Loss: 3.137268679221268
Epoch 24/160, Loss: 3.1337230327489127
Epoch 25/160, Loss: 3.12435141731711
Epoch 26/160, Loss: 3.1241686295365434
Epoch 27/

In [10]:
def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    """
    Trains ``student`` in place against both the true labels and a frozen teacher.

    Per batch, the loss is a weighted sum of two terms:

    * the cross-entropy between the temperature-softened teacher distribution and
      the temperature-softened student distribution, averaged over the batch and
      scaled by ``T**2``;
    * the ordinary cross-entropy between the student logits and the labels.

    Only the student's parameters are optimized (Nesterov SGD, momentum 0.9,
    weight decay 1e-4). The teacher is put in evaluation mode and run without
    gradient tracking. The mean per-batch loss is printed after every epoch.

    :param teacher: Trained network providing the soft targets.
    :param student: Network being trained.
    :param train_loader: Sized iterable yielding ``(inputs, labels)`` batches.
    :param epochs: Number of full passes over ``train_loader``.
    :param learning_rate: Learning rate handed to SGD.
    :param T: Softmax temperature applied to both sets of logits.
    :param soft_target_loss_weight: Weight of the teacher-matching term.
    :param ce_loss_weight: Weight of the true-label term.
    :param device: Device each batch is moved to before the forward passes.
    :return: None
    """
    hard_label_criterion = nn.CrossEntropyLoss()
    sgd = optim.SGD(
        student.parameters(),
        lr=learning_rate,
        momentum=0.9,
        nesterov=True,
        weight_decay=1e-4,
    )

    teacher.eval()   # teacher is only queried, never updated
    student.train()

    for epoch_number in range(1, epochs + 1):
        loss_total = 0.0
        for batch_inputs, batch_labels in train_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            sgd.zero_grad()

            # The teacher's weights are fixed, so skip building its autograd graph.
            with torch.no_grad():
                teacher_logits = teacher(batch_inputs)

            student_logits = student(batch_inputs)

            # Teacher probabilities and student log-probabilities at temperature T.
            teacher_probs = nn.functional.softmax(teacher_logits / T, dim=-1)
            student_log_probs = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Batch-mean soft cross-entropy, scaled by T**2 as suggested in
            # "Distilling the knowledge in a neural network".
            batch_size = student_log_probs.size(0)
            summed_cross_entropy = -torch.sum(teacher_probs * student_log_probs)
            soft_targets_loss = summed_cross_entropy / batch_size * (T**2)

            label_loss = hard_label_criterion(student_logits, batch_labels)

            batch_loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            batch_loss.backward()
            sgd.step()

            loss_total += batch_loss.item()

        mean_loss = loss_total / len(train_loader)
        print(f"Epoch {epoch_number}/{epochs}, Loss: {mean_loss}")


In [11]:
# Apply ``train_knowledge_distillation`` with a temperature of 2. Arbitrarily set the weights to 0.75 for CE and 0.25 for distillation loss.
# Re-seed before building the student so it starts from the same initial weights
# as the CE-only baseline student (also seeded with 42 right before construction);
# otherwise the accuracy comparison below mixes the effect of distillation with a
# different random initialization.
torch.manual_seed(42)
# Look the spec up explicitly rather than reusing the global ``layer``, which the
# teacher cell also assigns (to the 10-layer spec) and could be stale if cells are
# run out of order.
kd_student_model = ConvNetMaker(plane_cifar100_book['2']).to(device)
train_knowledge_distillation(teacher=teacher_model, student=kd_student_model, train_loader=train_loader, epochs=160, learning_rate=0.1, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_kd = test(kd_student_model, test_loader, device)

# Compare the student test accuracy with and without the teacher, after distillation
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")

Epoch 1/160, Loss: 7.300865008702973
Epoch 2/160, Loss: 6.638971815328769
Epoch 3/160, Loss: 6.403528075693818
Epoch 4/160, Loss: 6.252380539389217
Epoch 5/160, Loss: 6.163560833162664
Epoch 6/160, Loss: 6.0971230923977044
Epoch 7/160, Loss: 6.063027346530534
Epoch 8/160, Loss: 6.0121019051203035
Epoch 9/160, Loss: 5.930418932834245
Epoch 10/160, Loss: 5.754262536383041
Epoch 11/160, Loss: 5.649084627780768
Epoch 12/160, Loss: 5.589612998620933
Epoch 13/160, Loss: 5.548440801518042
Epoch 14/160, Loss: 5.51570523791301
Epoch 15/160, Loss: 5.484533177007495
Epoch 16/160, Loss: 5.473247542710561
Epoch 17/160, Loss: 5.453340045021623
Epoch 18/160, Loss: 5.435664272064443
Epoch 19/160, Loss: 5.42592137792836
Epoch 20/160, Loss: 5.414791857190145
Epoch 21/160, Loss: 5.400842664186912
Epoch 22/160, Loss: 5.3958924664255905
Epoch 23/160, Loss: 5.3903499930106165
Epoch 24/160, Loss: 5.374655116244655
Epoch 25/160, Loss: 5.376932133189248
Epoch 26/160, Loss: 5.366078758483653
Epoch 27/160, Loss: