# Knowledge Distillation
- The concept of **knowledge distillation** is to utilize class probabilities of a higher-capacity model (teacher) as soft targets of a smaller model (student)
- The implement processes can be divided into several stages:
  1. Finish the `ResNet()` classes
  2. Train the teacher model (ResNet50) and the student model (ResNet18) from scratch, i.e. **without KD**
  3. Define the `Distiller()` class and `loss_re()`, `loss_fe()` functions
  4. Train the student model **with KD** from the teacher model in two different ways, response-based and feature based distillation
  5. Comparison of student models w/ & w/o KD

## Setup

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset, random_split
from torchinfo import summary
from tqdm import tqdm
import sys
import numpy as np
import math
import matplotlib.pyplot as plt
import os
from PIL import Image

In [2]:
torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True

## Download dataset

In [3]:
validation_split = 0.1
batch_size = 32

# data augmentation and normalization
transform_train = transforms.Compose([
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

transform_test = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# download dataset
train_and_val_dataset = torchvision.datasets.CIFAR10(
    root='dataset/',
    train=True,
    transform=transform_train,
    download=True
)

test_dataset = torchvision.datasets.CIFAR10(
    root='dataset/',
    train=False,
    transform=transform_test,
    download=True
)

# split train and validation dataset
train_size = int((1 - validation_split) * len(train_and_val_dataset))
val_size = len(train_and_val_dataset) - train_size
train_dataset, val_dataset = random_split(train_and_val_dataset, [train_size, val_size])

# create dataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

test_num = len(test_dataset)
test_steps = len(test_loader)

## Create teacher and student models
### Define BottleNeck for ResNet50

In [4]:
class BottleNeck(nn.Module):
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BottleNeck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel * self.expansion, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out

### Define Resifual Block

In [5]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out

### Define ResNet Model

In [6]:
class ResNet(nn.Module):

    def __init__(self, block, blocks_num, num_classes=1000):
        super(ResNet, self).__init__()
        self.in_channel = 64

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel, channel))

        return nn.Sequential(*layers)

    def forward(self, x):
        # 1. Finish the forward pass and return the output layer as well as hidden features.
        # 2. The output layer and hidden features will be used later for distilling.
        # 3. You can refer to the ResNet structure illustration to finish it.
        ## stem
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)

        ## extract features from each stage
        feature1 = self.layer1(x)
        feature2 = self.layer2(feature1)
        feature3 = self.layer3(feature2)
        feature4 = self.layer4(feature3)

        ## classification head
        out = self.avgpool(feature4)
        out = torch.flatten(out, 1)
        out = self.fc(out)

        ## return logits + hidden features for distillation
        return out, [feature1, feature2, feature3, feature4]

### Define ResNet50 and Resnet18

In [7]:
def resnet18(num_classes=10):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)

def resnet50(num_classes=10):
    return ResNet(BottleNeck, [3, 4, 6, 3], num_classes=num_classes)

## Teacher Model (ResNet50)

In [8]:
Teacher = resnet50(num_classes=10)  # commment out this line if loading trained teacher model
# Teacher = torch.load('Teacher.pt', weights_only=False)  # loading trained teacher model
Teacher = Teacher.to(device)

In [9]:
summary(Teacher)

Layer (type:depth-idx)                   Param #
ResNet                                   --
├─Conv2d: 1-1                            1,728
├─BatchNorm2d: 1-2                       128
├─ReLU: 1-3                              --
├─MaxPool2d: 1-4                         --
├─Sequential: 1-5                        --
│    └─BottleNeck: 2-1                   --
│    │    └─Conv2d: 3-1                  4,096
│    │    └─BatchNorm2d: 3-2             128
│    │    └─Conv2d: 3-3                  36,864
│    │    └─BatchNorm2d: 3-4             128
│    │    └─Conv2d: 3-5                  16,384
│    │    └─BatchNorm2d: 3-6             512
│    │    └─ReLU: 3-7                    --
│    │    └─Sequential: 3-8              16,896
│    └─BottleNeck: 2-2                   --
│    │    └─Conv2d: 3-9                  16,384
│    │    └─BatchNorm2d: 3-10            128
│    │    └─Conv2d: 3-11                 36,864
│    │    └─BatchNorm2d: 3-12            128
│    │    └─Conv2d: 3-13               

## Student Model (ResNet18)

In [10]:
Student = resnet18(num_classes=10)  # commment out this line if loading trained student model
# Student = torch.load('Student.pt', weights_only=False)  # loading trained student model
Student = Student.to(device)

In [11]:
summary(Student)

Layer (type:depth-idx)                   Param #
ResNet                                   --
├─Conv2d: 1-1                            1,728
├─BatchNorm2d: 1-2                       128
├─ReLU: 1-3                              --
├─MaxPool2d: 1-4                         --
├─Sequential: 1-5                        --
│    └─BasicBlock: 2-1                   --
│    │    └─Conv2d: 3-1                  36,864
│    │    └─BatchNorm2d: 3-2             128
│    │    └─ReLU: 3-3                    --
│    │    └─Conv2d: 3-4                  36,864
│    │    └─BatchNorm2d: 3-5             128
│    └─BasicBlock: 2-2                   --
│    │    └─Conv2d: 3-6                  36,864
│    │    └─BatchNorm2d: 3-7             128
│    │    └─ReLU: 3-8                    --
│    │    └─Conv2d: 3-9                  36,864
│    │    └─BatchNorm2d: 3-10            128
├─Sequential: 1-6                        --
│    └─BasicBlock: 2-3                   --
│    │    └─Conv2d: 3-11                 73,728

## Define training function

In [17]:
def train_from_scratch(model, train_loader, val_loader, epochs, learning_rate, device, model_name):
    criterion = nn.CrossEntropyLoss()
    params = [p for p in model.parameters() if p.requires_grad]
    # optimizer = torch.optim.Adam(params, lr=learning_rate)
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)

    loss = []
    train_error=[]
    val_error = []
    valdation_error = []
    train_loss = []
    valdation_loss = []
    train_accuraacy = []
    valdation_accuracy= []

    for epoch in range(epochs):
        train_loss = 0.0
        valid_loss = 0.0
        train_acc = 0.0
        valid_acc = 0.0
        correct = 0.
        total = 0.
        V_correct = 0.
        V_total = 0.

        model.train()
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            logits, hidden = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * images.size(0)
            pred = logits.data.max(1, keepdim=True)[1]
            correct += np.sum(np.squeeze(pred.eq(labels.data.view_as(pred))).cpu().numpy())
            total += images.size(0)
            train_acc =  correct/total
            train_bar.desc = "train epoch[{}/{}]".format(epoch + 1, epochs)

        model.eval()
        with torch.no_grad():
            val_bar = tqdm(val_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                val_images, val_labels = val_images.to(device), val_labels.to(device)
                outputs, hidden_outputs = model(val_images)
                loss = criterion(outputs, val_labels)
                valid_loss += loss.item() * val_images.size(0)
                pred = outputs.data.max(1, keepdim=True)[1]
                V_correct += np.sum(np.squeeze(pred.eq(val_labels.data.view_as(pred))).cpu().numpy())
                V_total += val_images.size(0)
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)

        train_loss = train_loss / len(train_loader.dataset)
        train_error.append(train_loss)
        valid_loss = valid_loss / len(val_loader.dataset)
        val_error.append(valid_loss)
        train_accuraacy.append( correct / total)
        valdation_accuracy.append(V_correct / V_total)

        scheduler.step()

        print('\tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(train_loss, valid_loss))
        print('\tTrain Accuracy: %.3fd%% (%2d/%2d)\tValdation Accuracy: %.3fd%% (%2d/%2d) '% (100. * correct / total, correct, total, 100. * V_correct / V_total, V_correct, V_total))

    torch.save(model, f'{model_name}.pt')
    print(f'{model_name}.pt is saved')

    print('Finished Training')

## Define testing function

In [13]:
def test(model, test_loader ,device, type=None):
    criterion = nn.CrossEntropyLoss()
    acc = 0.0
    test_loss = 0.0

    if type == None:
        model.eval()
    elif type == 'distiller':
        model.eval()
        model.teacher.eval()
        model.student.eval()
    else:
       raise ValueError(f'Error: only support response-based and feature-based distillation')

    with torch.no_grad():
        test_bar = tqdm(test_loader, file=sys.stdout)
        for test_data in test_bar:
            test_images, test_labels = test_data
            test_images, test_labels = test_images.to(device), test_labels.to(device)
            if type == None:
                outputs, features = model(test_images)
                loss = criterion(outputs, test_labels)
            elif type == 'distiller':
                outputs, loss = model(test_images, test_labels)
            else:
                raise ValueError(f'Error: only support response-based and feature-based distillation')

            predict_y = torch.max(outputs, dim=1)[1]
            acc += torch.eq(predict_y, test_labels.to(device)).sum().item()
            test_loss += loss.item()
            test_bar.desc = "test"

    test_accurate = acc / test_num
    print('test_loss: %.3f  test_accuracy: %.3f' %(test_loss / test_steps, test_accurate * 100))
    return test_loss / test_steps, test_accurate * 100.

## Train Teacher and Student model from scratch

In [14]:
best_teacher_path = './Teacher.pt'
if os.path.exists(best_teacher_path):
    Teacher = torch.load(best_teacher_path, weights_only=False)
    print(f"Loaded model weights from {best_teacher_path}")
else:
    print("No saved model found, starting from scratch.")


Loaded model weights from ./Teacher.pt


In [20]:
# Decide the epochs and learning rate
train_from_scratch(Teacher, train_loader, val_loader, epochs=10 , learning_rate=0.00001 , device=device, model_name="Teacher")

train epoch[1/10]: 100%|██████████| 1407/1407 [00:26<00:00, 52.46it/s]
valid epoch[1/10]: 100%|██████████| 157/157 [00:01<00:00, 129.41it/s]
	Training Loss: 0.072227 	Validation Loss: 0.454244
	Train Accuracy: 97.482d% (43867/45000)	Valdation Accuracy: 88.940d% (4447/5000) 
train epoch[2/10]: 100%|██████████| 1407/1407 [00:25<00:00, 55.84it/s]
valid epoch[2/10]: 100%|██████████| 157/157 [00:01<00:00, 130.54it/s]
	Training Loss: 0.072908 	Validation Loss: 0.413532
	Train Accuracy: 97.449d% (43852/45000)	Valdation Accuracy: 89.420d% (4471/5000) 
train epoch[3/10]: 100%|██████████| 1407/1407 [00:26<00:00, 52.41it/s]
valid epoch[3/10]: 100%|██████████| 157/157 [00:01<00:00, 132.15it/s]
	Training Loss: 0.071984 	Validation Loss: 0.417559
	Train Accuracy: 97.502d% (43876/45000)	Valdation Accuracy: 89.380d% (4469/5000) 
train epoch[4/10]: 100%|██████████| 1407/1407 [00:26<00:00, 52.51it/s]
valid epoch[4/10]: 100%|██████████| 157/157 [00:01<00:00, 129.23it/s]
	Training Loss: 0.070826 	Validati

In [91]:
T_loss, T_accuracy = test(Teacher, test_loader, device=device)

test: 100%|██████████| 313/313 [00:02<00:00, 113.42it/s]
test_loss: 0.384  test_accuracy: 90.340


In [22]:
best_student_path = './Student.pt'
if os.path.exists(best_student_path):
    Student = torch.load(best_student_path, weights_only=False)
    print(f"Loaded model weights from {best_student_path}")
else:
    print("No saved model found, starting from scratch.")

No saved model found, starting from scratch.


In [24]:
# Decide the epochs and learning rate
train_from_scratch(Student, train_loader, val_loader, epochs= 10, learning_rate= 0.001, device=device, model_name="Student")

train epoch[1/10]: 100%|██████████| 1407/1407 [00:14<00:00, 97.71it/s] 
valid epoch[1/10]: 100%|██████████| 157/157 [00:00<00:00, 205.16it/s]
	Training Loss: 1.373968 	Validation Loss: 1.254976
	Train Accuracy: 50.191d% (22586/45000)	Valdation Accuracy: 54.100d% (2705/5000) 
train epoch[2/10]: 100%|██████████| 1407/1407 [00:11<00:00, 120.24it/s]
valid epoch[2/10]: 100%|██████████| 157/157 [00:00<00:00, 208.07it/s]
	Training Loss: 1.169915 	Validation Loss: 1.101051
	Train Accuracy: 58.033d% (26115/45000)	Valdation Accuracy: 60.240d% (3012/5000) 
train epoch[3/10]: 100%|██████████| 1407/1407 [00:11<00:00, 120.29it/s]
valid epoch[3/10]: 100%|██████████| 157/157 [00:00<00:00, 206.57it/s]
	Training Loss: 1.039268 	Validation Loss: 0.974650
	Train Accuracy: 63.044d% (28370/45000)	Valdation Accuracy: 65.840d% (3292/5000) 
train epoch[4/10]: 100%|██████████| 1407/1407 [00:11<00:00, 119.72it/s]
valid epoch[4/10]: 100%|██████████| 157/157 [00:00<00:00, 207.24it/s]
	Training Loss: 0.942961 	Vali

In [25]:
S_loss, S_accuracy = test(Student, test_loader, device=device)

test: 100%|██████████| 313/313 [00:01<00:00, 282.08it/s]
test_loss: 0.616  test_accuracy: 79.310


## Define distillation

### Define the loss functions

In [26]:
# Finish the loss function for response-based distillation.
def loss_re(student_logits, teacher_logits, targets):
    T = 4 # Set temperature parameter
    alpha = 0.5 # Set weighting parameter

    ## Implement loss calculation
    # ---- 1. Hard Loss ----
    # typical cross entropy loss for student
    hard_loss = F.cross_entropy(student_logits, targets)

    # ---- 2. Soft Loss (Distillation Loss) ----
    # Student soft predictions
    student_soft = F.log_softmax(student_logits / T, dim=1)  # log Q_S

    # Teacher soft predictions
    teacher_soft = F.softmax(teacher_logits / T, dim=1)      # Q_T

    # KL divergence: KL(Q_T || Q_S)
    # In PyTorch: KLDivLoss expects log-prob input + prob target
    soft_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean')

    # ---- 3. Combine ----
    loss = (1 - alpha) * hard_loss + (alpha * (T ** 2)) * soft_loss
    
    return loss

In [138]:
# Finish the loss function for feature-based distillation.
def loss_fe(student_features, teacher_features, connectors, student_logits, labels):
    index = 1
    alpha = 0.1
    # Implement loss calculation whatever you prefer
    # 1. Detach teacher feature (teacher is fixed)
    t = teacher_features[index].detach()

    # 2. Align student feature channels to teacher channels
    #    connector: student C_s → teacher C_t
    s_aligned = connectors[index](student_features[index])

    # 3. Compute L2 loss
    loss_f = F.mse_loss(s_aligned, t)
    
    # 4. Compute CE loss if logits and labels are provided
    if student_logits is not None and labels is not None:
        ce_loss = nn.CrossEntropyLoss()
        loss_ce = ce_loss(student_logits, labels)
    else:
        loss_ce = 0.0

    # 5. Combine feature loss and CE loss
    loss = alpha * loss_f + (1 - alpha) * loss_ce

    return loss

### Define Distillation Framework

In [133]:
class Distiller(nn.Module):
    def __init__(self, teacher, student, type):
        super(Distiller, self).__init__()

        # 1. Finish the __init__ method.
        self.teacher = teacher.eval()       # teacher fixed
        self.student = student              # student trainable
        self.type = type                    # 'response' or 'feature'

        # ---------------------------
        #  Connector layers for FKD
        # ---------------------------
        if type == 'feature':
            self.connectors = nn.ModuleList([
                nn.Conv2d(64, 256, 1),
                nn.Conv2d(128, 512, 1),
                nn.Conv2d(256, 1024, 1),
                nn.Conv2d(512, 2048, 1),
            ]).to(device)
        else:
            self.connectors = None

    def forward(self, x, target):
        # 2. Finish the forward pass.
        with torch.no_grad():
            teacher_logits, teacher_features = self.teacher(x)

        student_logits, student_features = self.student(x)

        if self.type == 'response':
            loss_distill = loss_re(student_logits, teacher_logits, target)
        elif self.type == 'feature':
            loss_distill = loss_fe(student_features, teacher_features, self.connectors, student_logits, target)
        else:
            raise ValueError(f'Error: only support response-based and feature-based distillation')

        return student_logits, loss_distill

### Training function

In [134]:
def train_distillation(distiller, student, train_loader, val_loader, epochs, learning_rate, device):
    ce_loss = nn.CrossEntropyLoss()
    # define the parameter the optimizer used
    optimizer = torch.optim.Adam(student.parameters(), lr=learning_rate)

    loss = []
    train_error=[]
    val_error = []
    valdation_error = []
    train_loss = []
    valdation_loss = []
    train_accuraacy = []
    valdation_accuracy= []

    for epoch in range(epochs):
        distiller.train()
        distiller.teacher.train()
        distiller.student.train()

        train_loss = 0.0
        valid_loss = 0.0
        train_acc = 0.0
        valid_acc  = 0.0
        correct = 0.
        total = 0.
        V_correct = 0.
        V_total = 0.
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            outputs, loss = distiller(images, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * images.size(0)
            pred = outputs.data.max(1, keepdim=True)[1]
            result = pred.eq(labels.data.view_as(pred))
            result = np.squeeze(result.cpu().numpy())
            correct += np.sum(result)
            total += images.size(0)
            train_bar.desc = "train epoch[{}/{}]".format(epoch + 1, epochs)

        distiller.eval()
        distiller.teacher.eval()
        distiller.student.eval()

        with torch.no_grad():
            val_bar = tqdm(val_loader, file=sys.stdout)
            for val_data in val_bar:

                val_images, val_labels = val_data
                val_images, val_labels = val_images.to(device), val_labels.to(device)

                outputs, loss = distiller(val_images, val_labels)

                valid_loss += loss.item() * val_images.size(0)
                pred = outputs.max(1, keepdim=True)[1]
                V_correct += np.sum(np.squeeze(pred.eq(val_labels.data.view_as(pred))).cpu().numpy())
                V_total += val_images.size(0)
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)

        train_loss = train_loss / len(train_loader.dataset)
        train_error.append(train_loss)
        valid_loss = valid_loss / len(val_loader.dataset)
        val_error.append(valid_loss)
        train_accuraacy.append( correct / total)
        valdation_accuracy.append(V_correct / V_total)

        print('\tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(train_loss, valid_loss))
        print('\tTrain Accuracy: %.3fd%% (%2d/%2d)\tValdation Accuracy: %.3fd%% (%2d/%2d) '% (100. * correct / total, correct, total, 100. * V_correct / V_total, V_correct, V_total))

    print('Finished Distilling')

## Response-based distillation

In [127]:
# Decide the epochs and learning rate
Student_re = resnet18(num_classes=10)
Student_re = Student_re.to(device)
distiller_re = Distiller(Teacher, Student_re, type='response')
train_distillation(distiller_re, Student_re, train_loader, val_loader, epochs= 10, learning_rate= 0.001, device=device)

train epoch[1/10]: 100%|██████████| 1407/1407 [00:30<00:00, 46.84it/s]
valid epoch[1/10]: 100%|██████████| 157/157 [00:02<00:00, 77.79it/s]
	Training Loss: 6.805971 	Validation Loss: 5.063555
	Train Accuracy: 48.358d% (21761/45000)	Valdation Accuracy: 58.020d% (2901/5000) 
train epoch[2/10]: 100%|██████████| 1407/1407 [00:29<00:00, 46.99it/s]
valid epoch[2/10]: 100%|██████████| 157/157 [00:02<00:00, 78.49it/s]
	Training Loss: 4.168979 	Validation Loss: 3.376985
	Train Accuracy: 65.836d% (29626/45000)	Valdation Accuracy: 69.220d% (3461/5000) 
train epoch[3/10]: 100%|██████████| 1407/1407 [00:30<00:00, 46.73it/s]
valid epoch[3/10]: 100%|██████████| 157/157 [00:02<00:00, 78.46it/s]
	Training Loss: 2.998611 	Validation Loss: 2.589981
	Train Accuracy: 73.842d% (33229/45000)	Valdation Accuracy: 73.660d% (3683/5000) 
train epoch[4/10]: 100%|██████████| 1407/1407 [00:30<00:00, 46.84it/s]
valid epoch[4/10]: 100%|██████████| 157/157 [00:02<00:00, 76.67it/s]
	Training Loss: 2.322870 	Validation L

In [128]:
reS_loss, reS_accuracy = test(distiller_re, test_loader, type='distiller', device=device)

test: 100%|██████████| 313/313 [00:02<00:00, 120.98it/s]
test_loss: 1.196  test_accuracy: 86.770


## Feature-based distillation

In [139]:
# Decide the epochs and learning rate
Student_fe = resnet18(num_classes=10)
Student_fe = Student_fe.to(device)
distiller_fe = Distiller(Teacher, Student_fe, type='feature')
train_distillation(distiller_fe, Student_fe, train_loader, val_loader, epochs= 10, learning_rate= 0.001, device=device)

train epoch[1/10]: 100%|██████████| 1407/1407 [00:30<00:00, 46.04it/s]
valid epoch[1/10]: 100%|██████████| 157/157 [00:02<00:00, 77.51it/s]
	Training Loss: 1.564217 	Validation Loss: 1.354475
	Train Accuracy: 48.038d% (21617/45000)	Valdation Accuracy: 56.620d% (2831/5000) 
train epoch[2/10]: 100%|██████████| 1407/1407 [00:29<00:00, 47.09it/s]
valid epoch[2/10]: 100%|██████████| 157/157 [00:01<00:00, 85.67it/s]
	Training Loss: 1.190503 	Validation Loss: 1.052575
	Train Accuracy: 63.927d% (28767/45000)	Valdation Accuracy: 69.060d% (3453/5000) 
train epoch[3/10]: 100%|██████████| 1407/1407 [00:30<00:00, 46.26it/s]
valid epoch[3/10]: 100%|██████████| 157/157 [00:01<00:00, 83.90it/s]
	Training Loss: 1.013960 	Validation Loss: 0.993026
	Train Accuracy: 70.696d% (31813/45000)	Valdation Accuracy: 72.380d% (3619/5000) 
train epoch[4/10]: 100%|██████████| 1407/1407 [00:29<00:00, 47.55it/s]
valid epoch[4/10]: 100%|██████████| 157/157 [00:02<00:00, 78.00it/s]
	Training Loss: 0.901546 	Validation L

In [146]:

ftS_loss, ftS_accuracy = test(distiller_fe, test_loader, type='distiller', device=device)

test: 100%|██████████| 313/313 [00:03<00:00, 99.93it/s] 
test_loss: 0.605  test_accuracy: 87.270


## Result and Comparison

In [145]:
print(f'Teacher from scratch: loss = {T_loss:.2f}, accuracy = {T_accuracy:.2f}')
print(f'Student from scratch: loss = {S_loss:.2f}, accuracy = {S_accuracy:.2f}')
print(f'Response-based student: loss = {reS_loss:.2f}, accuracy = {reS_accuracy:.2f}')
print(f'Featured-based student: loss = {ftS_loss:.2f}, accuracy = {ftS_accuracy:.2f}')

Teacher from scratch: loss = 0.38, accuracy = 90.34
Student from scratch: loss = 0.62, accuracy = 79.31
Response-based student: loss = 1.20, accuracy = 86.77
Featured-based student: loss = 0.61, accuracy = 87.27
