# Setup

In [1]:
!pip install torchinfo
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms, models
from torch.utils.data import DataLoader, random_split
from torchinfo import summary
from tqdm import tqdm
import sys
import numpy as np
import math
import matplotlib.pyplot as plt



In [2]:
torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True

# Download dataset

In [3]:
validation_split = 0.2
batch_size = 128

# data augmentation and normalization
transform_train = transforms.Compose([
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

transform_test = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# download dataset
train_and_val_dataset = torchvision.datasets.CIFAR10(
    root='dataset/',
    train=True,
    transform=transform_train,
    download=True
)

test_dataset = torchvision.datasets.CIFAR10(
    root='dataset/',
    train=False,
    transform=transform_test,
    download=True
)

# split train and validation dataset
train_size = int((1 - validation_split) * len(train_and_val_dataset))
val_size = len(train_and_val_dataset) - train_size
train_dataset, val_dataset = random_split(train_and_val_dataset, [train_size, val_size])

# create dataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

test_num = len(test_dataset)
test_steps = len(test_loader)

100%|██████████| 170M/170M [00:01<00:00, 95.7MB/s]


# Create teacher and student models


In [4]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out

In [5]:
class ResNet(nn.Module):

    def __init__(self, block, blocks_num, num_classes=1000):
        super(ResNet, self).__init__()
        self.in_channel = 64

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel, channel))

        return nn.Sequential(*layers)

    def forward(self, x):
        ######################################################################################################
        # 1. Finish the forward pass and return the output layer as well as hidden features.  #
        # 2. The output layer and hidden features will be used later for distilling.       #
        # 3. You can refer to the ResNet structure illustration to finish it.           #
        ####################################################################################################

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        feature1 = self.layer1(x)
        feature2 = self.layer2(feature1)
        feature3 = self.layer3(feature2)
        feature4 = self.layer4(feature3)

        x = self.avgpool(feature4)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x, [feature1, feature2, feature3, feature4]

def resnet18(num_classes=10):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)

def resnet34(num_classes=10):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)

## Teacher model

In [6]:
Teacher = resnet34(num_classes=10)  # commment out this line if loading trained teacher model
Teacher = torch.load('/kaggle/input/pre-trained/pytorch/kd/1/Teacher.pt', weights_only=False)  # loading trained teacher model
Teacher = Teacher.to(device)

In [7]:
summary(Teacher)

Layer (type:depth-idx)                   Param #
ResNet                                   --
├─Conv2d: 1-1                            1,728
├─BatchNorm2d: 1-2                       128
├─ReLU: 1-3                              --
├─MaxPool2d: 1-4                         --
├─Sequential: 1-5                        --
│    └─BasicBlock: 2-1                   --
│    │    └─Conv2d: 3-1                  36,864
│    │    └─BatchNorm2d: 3-2             128
│    │    └─ReLU: 3-3                    --
│    │    └─Conv2d: 3-4                  36,864
│    │    └─BatchNorm2d: 3-5             128
│    └─BasicBlock: 2-2                   --
│    │    └─Conv2d: 3-6                  36,864
│    │    └─BatchNorm2d: 3-7             128
│    │    └─ReLU: 3-8                    --
│    │    └─Conv2d: 3-9                  36,864
│    │    └─BatchNorm2d: 3-10            128
│    └─BasicBlock: 2-3                   --
│    │    └─Conv2d: 3-11                 36,864
│    │    └─BatchNorm2d: 3-12            12

## Student model

In [8]:
Student = resnet18(num_classes=10)  # commment out this line if loading trained student model
Student = torch.load('/kaggle/input/pre-trained/pytorch/kd/1/Student.pt', weights_only=False)  # loading trained student model
Student = Student.to(device)

In [9]:
summary(Student)

Layer (type:depth-idx)                   Param #
ResNet                                   --
├─Conv2d: 1-1                            1,728
├─BatchNorm2d: 1-2                       128
├─ReLU: 1-3                              --
├─MaxPool2d: 1-4                         --
├─Sequential: 1-5                        --
│    └─BasicBlock: 2-1                   --
│    │    └─Conv2d: 3-1                  36,864
│    │    └─BatchNorm2d: 3-2             128
│    │    └─ReLU: 3-3                    --
│    │    └─Conv2d: 3-4                  36,864
│    │    └─BatchNorm2d: 3-5             128
│    └─BasicBlock: 2-2                   --
│    │    └─Conv2d: 3-6                  36,864
│    │    └─BatchNorm2d: 3-7             128
│    │    └─ReLU: 3-8                    --
│    │    └─Conv2d: 3-9                  36,864
│    │    └─BatchNorm2d: 3-10            128
├─Sequential: 1-6                        --
│    └─BasicBlock: 2-3                   --
│    │    └─Conv2d: 3-11                 73,728

# Define testing function

In [10]:
def test(model, test_loader ,device, type=None):
    criterion = nn.CrossEntropyLoss()
    acc = 0.0
    test_loss = 0.0

    if type == None:
        model.eval()
    elif type == 'distiller':
        model.eval()
        model.teacher.eval()
        model.student.eval()
    else:
       raise ValueError(f'Error: only support response-based and feature-based distillation')

    with torch.no_grad():
        test_bar = tqdm(test_loader, file=sys.stdout)
        for test_data in test_bar:
            test_images, test_labels = test_data
            test_images, test_labels = test_images.to(device), test_labels.to(device)
            if type == None:
                outputs, features = model(test_images)
                loss = criterion(outputs, test_labels)
            elif type == 'distiller':
                outputs, loss = model(test_images, test_labels)
            else:
                raise ValueError(f'Error: only support response-based and feature-based distillation')

            predict_y = torch.max(outputs, dim=1)[1]
            acc += torch.eq(predict_y, test_labels.to(device)).sum().item()
            test_loss += loss.item()
            test_bar.desc = "test"

    test_accurate = acc / test_num
    print('test_loss: %.3f  test_accuracy: %.3f' %(test_loss / test_steps, test_accurate * 100))
    return test_loss / test_steps, test_accurate * 100.

# Define distillation fuction

In [11]:
#####################################################################
# Finish the loss function for response-based distillation. #
#####################################################################
def loss_re(student_logits, teacher_logits, labels, T, alpha):
    soft_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(student_logits/T, dim=1), F.softmax(teacher_logits/T, dim=1)) * (T*T)
    hard_loss = nn.CrossEntropyLoss()(student_logits, labels)
    
    loss = alpha * soft_loss + (1 - alpha) * hard_loss
    
    return loss

In [12]:
class Distiller(nn.Module):
    def __init__(self, teacher, student, type, T, alpha):
        super(Distiller, self).__init__()

        ########################################
        # 1. Finish the __init__ method. #
        ########################################

        self.teacher = teacher
        self.student = student

        self.T = T
        self.alpha = alpha

        self.type = type

        self.teacher.train()
        for param in self.teacher.parameters():
            param.requires_grad = False

    def forward(self, x, target):

        #####################################
        # 2. Finish the forward pass. #
        #####################################
        teacher_logits, teacher_feature = self.teacher(x)
        student_logits, student_feature = self.student(x)

        if self.type == 'response':
            loss_distill = loss_re(student_logits, teacher_logits, target, T=self.T, alpha=self.alpha)
        elif self.type == 'feature':
            loss_distill = loss_fe(student_logits, student_feature, teacher_feature, target)

        return student_logits, loss_distill

In [13]:
def train_distillation(distiller, student, train_loader, val_loader, epochs, learning_rate, device):
    ce_loss = nn.CrossEntropyLoss()
    ###########################
    # define the optimizer #
    ###########################
    optimizer = torch.optim.Adam(distiller.student.parameters(), lr=learning_rate)

    loss = []
    train_error=[]
    val_error = []
    valdation_error = []
    train_loss = []
    valdation_loss = []
    train_accuraacy = []
    valdation_accuracy= []

    for epoch in range(epochs):
        distiller.train()
        distiller.teacher.train()
        distiller.student.train()

        train_loss = 0.0
        valid_loss = 0.0
        train_acc = 0.0
        valid_acc  = 0.0
        correct = 0.
        total = 0.
        V_correct = 0.
        V_total = 0.
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            outputs, loss = distiller(images, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * images.size(0)
            pred = outputs.data.max(1, keepdim=True)[1]
            result = pred.eq(labels.data.view_as(pred))
            result = np.squeeze(result.cpu().numpy())
            correct += np.sum(result)
            total += images.size(0)
            train_bar.desc = "train epoch[{}/{}]".format(epoch + 1, epochs)

        distiller.eval()
        distiller.teacher.eval()
        distiller.student.eval()

        with torch.no_grad():
            val_bar = tqdm(val_loader, file=sys.stdout)
            for val_data in val_bar:

                val_images, val_labels = val_data
                val_images, val_labels = val_images.to(device), val_labels.to(device)

                outputs, loss = distiller(val_images, val_labels)

                valid_loss += loss.item() * val_images.size(0)
                pred = outputs.max(1, keepdim=True)[1]
                V_correct += np.sum(np.squeeze(pred.eq(val_labels.data.view_as(pred))).cpu().numpy())
                V_total += val_images.size(0)
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)

        train_loss = train_loss / len(train_loader.dataset)
        train_error.append(train_loss)
        valid_loss = valid_loss / len(val_loader.dataset)
        val_error.append(valid_loss)
        train_accuraacy.append( correct / total)
        valdation_accuracy.append(V_correct / V_total)

        print('\tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(train_loss, valid_loss))
        print('\tTrain Accuracy: %.3fd%% (%2d/%2d)\tValdation Accuracy: %.3fd%% (%2d/%2d) '% (100. * correct / total, correct, total, 100. * V_correct / V_total, V_correct, V_total))

    print('Finished Distilling')

# Response-based distillation
## Find Temparature and alpha

In [14]:
T_list = [1, 2, 5, 10]
alpha_list = [0.1, 0.3, 0.5, 0.7, 0.9]
results = []

for T in T_list:
    for alpha in alpha_list:
        print(f"\n==== Training with T={T}, alpha={alpha} ====")        
        Student_re = resnet18(num_classes=10)
        Student_re = Student_re.to(device)
        distiller_re = Distiller(Teacher, Student_re, type='response', T=T, alpha=alpha)
        train_distillation(distiller_re, Student_re, train_loader, val_loader, epochs=10, learning_rate=0.001, device=device)
        
        reS_loss, reS_accuracy = test(distiller_re, test_loader, type='distiller', device=device)
        
        results.append((T, alpha, reS_accuracy, reS_loss))

best = max(results, key=lambda x: x[2])
print(f'\nBest Group：T={best[0]}, alpha={best[1]}, accuracy={best[2]:.2f}, loss={best[3]:.2f}')


==== Training with T=1, alpha=0.1 ====
train epoch[1/10]: 100%|██████████| 313/313 [01:07<00:00,  4.61it/s]
valid epoch[1/10]: 100%|██████████| 79/79 [00:10<00:00,  7.58it/s]
	Training Loss: 1.431178 	Validation Loss: 1.199081
	Train Accuracy: 47.523d% (19009/40000)	Valdation Accuracy: 57.230d% (5723/10000) 
train epoch[2/10]: 100%|██████████| 313/313 [01:06<00:00,  4.69it/s]
valid epoch[2/10]: 100%|██████████| 79/79 [00:10<00:00,  7.25it/s]
	Training Loss: 0.975055 	Validation Loss: 1.077327
	Train Accuracy: 65.040d% (26016/40000)	Valdation Accuracy: 62.270d% (6227/10000) 
train epoch[3/10]: 100%|██████████| 313/313 [01:07<00:00,  4.67it/s]
valid epoch[3/10]: 100%|██████████| 79/79 [00:10<00:00,  7.32it/s]
	Training Loss: 0.780014 	Validation Loss: 0.897258
	Train Accuracy: 72.243d% (28897/40000)	Valdation Accuracy: 68.530d% (6853/10000) 
train epoch[4/10]: 100%|██████████| 313/313 [01:07<00:00,  4.65it/s]
valid epoch[4/10]: 100%|██████████| 79/79 [00:10<00:00,  7.34it/s]
	Training L

In [15]:
# print(f"\n==== Training with T= , alpha= ====")        
# Student_re = resnet18(num_classes=10)
# Student_re = Student_re.to(device)
# distiller_re = Distiller(Teacher, Student_re, type='response', T= , alpha= )
# train_distillation(distiller_re, Student_re, train_loader, val_loader, epochs=1, learning_rate=0.001, device=device)

# reS_loss, reS_accuracy = test(distiller_re, test_loader, type='distiller', device=device)

# results.append((T, alpha, reS_accuracy, reS_loss))

# Comparison

In [16]:
print(f'Teacher from scratch: loss-0.482, accuracy-91.420') 
print(f'Student from scratch: loss-0.441, accuracy-85.850')
print(f'Response-based student: loss-{best[3]}, accuracy-{best[2]}')

Teacher from scratch: loss-0.482, accuracy-91.420
Student from scratch: loss-0.441, accuracy-85.850
Response-based student: loss-5.424638434301449, accuracy-87.82
