# Imports

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader
from torch.amp import autocast, GradScaler

import numpy as np
import matplotlib.pyplot as plt
import time
import os
from tqdm import tqdm

In [12]:
# Fix the global RNG seeds (PyTorch first, then NumPy) so weight initialisation,
# data shuffling and augmentation are repeatable across runs.
torch.manual_seed(seed=42)
np.random.seed(seed=42)

In [13]:
# Use the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

Using device: cuda


# Constants

In [None]:
IMAGE_SIZE = 32
NUM_CLASSES = 100

# Per-channel normalisation statistics of the CIFAR-100 training set (the dataset this
# notebook loads). For reference, CIFAR-10's are roughly (0.4914, 0.4822, 0.4465).
CIFAR100_MEAN = [0.5071, 0.4867, 0.4408]
CIFAR100_STD = [0.2675, 0.2565, 0.2761]
# Backward-compatible aliases: these names were used for the CIFAR-100 stats above.
CIFAR10_MEAN = CIFAR100_MEAN
CIFAR10_STD = CIFAR100_STD

BATCH_SIZE = 128
# Optimiser hyper-parameters for the three training runs (teacher, student from scratch,
# distilled student). NOTE(review): keep these in sync with the values the training
# cells actually pass / default to (0.001 and 0.01).
TEACHER_LR = 0.001
RANDOM_INIT_LR = 0.001
STUDENT_DISTILL_LR = 0.001
WEIGHT_DECAY = 0.01

# Load Data

In [15]:
# Normalise with the dataset statistics from the constants cell. `mean` / `std` are not
# defined anywhere in this notebook, so referencing them raised NameError on a fresh kernel.
# NOTE(review): RandomResizedCrop's default scale=(0.08, 1.0) can keep as little as 8% of a
# 32x32 image before upsampling; RandomCrop(IMAGE_SIZE, padding=4) is the more common CIFAR
# recipe and may be worth an ablation.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(size=IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

train_dataset = torchvision.datasets.CIFAR100(root='./data', train=True,
                                             download=True, transform=train_transform)

test_dataset = torchvision.datasets.CIFAR100(root='./data', train=False,
                                            download=True, transform=test_transform)

# Pinned host memory only speeds up host-to-GPU copies; on CPU-only runs it just warns.
use_pinned_memory = device.type == 'cuda'
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=2, pin_memory=use_pinned_memory)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                         num_workers=2, pin_memory=use_pinned_memory)

print(f"Training dataset size: {len(train_dataset)}")
print(f"Testing dataset size: {len(test_dataset)}")
print(f"Number of classes: {len(train_dataset.classes)}")

Files already downloaded and verified
Files already downloaded and verified
Training dataset size: 50000
Testing dataset size: 10000
Number of classes: 100


# Models (Teacher: ResNet-50, Student: MobileNetV2)

In [16]:
def create_teacher_model(num_classes=100):
    """Build a randomly initialised ResNet-50 teacher adapted to 32x32 inputs.

    The stem convolution is replaced by a 3x3, stride-1 conv and the max-pool is removed,
    so small CIFAR images are not downsampled before the first residual stage. The final
    fully connected layer is resized to ``num_classes`` outputs.

    Args:
        num_classes: Number of output logits. Defaults to 100 (CIFAR-100, NUM_CLASSES).

    Returns:
        The ``nn.Module`` (on the CPU; callers move it to a device).
    """
    model = models.resnet50(weights=None)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

def create_student_model(num_classes=100):
    """Build a randomly initialised MobileNetV2 student adapted to 32x32 inputs.

    The first conv of ``features`` is replaced by a 3x3, stride-1 conv to avoid an early
    downsample on small images. The classifier is replaced by a single ``nn.Linear`` mapping
    ``model.last_channel`` features to ``num_classes`` logits.

    NOTE(review): torchvision's default MobileNetV2 classifier is Dropout + Linear; using a
    bare Linear drops that dropout. It is kept as-is so the ``classifier.weight`` /
    ``classifier.bias`` state-dict keys of existing checkpoints stay loadable. Confirm that
    losing dropout is intended.

    Args:
        num_classes: Number of output logits. Defaults to 100 (CIFAR-100, NUM_CLASSES).

    Returns:
        The ``nn.Module`` (on the CPU; callers move it to a device).
    """
    model = models.mobilenet_v2(weights=None)
    model.features[0][0] = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
    model.classifier = nn.Linear(model.last_channel, num_classes)
    return model

In [17]:
# Build both architectures and place them on the selected device.
teacher_model = create_teacher_model()
teacher_model = teacher_model.to(device)
student_model = create_student_model()
student_model = student_model.to(device)

In [None]:
def count_parameters(model):
    """Return the total number of scalar parameters in ``model`` that require gradients."""
    trainable_total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        trainable_total += param.numel()
    return trainable_total

# The student built by create_student_model() is a MobileNetV2, not a ResNet-18.
print(f"Teacher model (ResNet-50) parameters: {count_parameters(teacher_model):,}")
print(f"Student model (MobileNetV2) parameters: {count_parameters(student_model):,}")

Teacher model (ResNet-50) parameters: 23,705,252
Student model (ResNet-18) parameters: 2,351,972


# Training the Teacher Model

In [22]:
def _train_one_epoch(model, loader, criterion, optimizer, scaler, device, use_amp, desc):
    """Run one optimisation pass of ``model`` over ``loader``.

    Returns:
        (mean per-sample loss over ``loader.dataset``, accuracy in percent).
    """
    model.train()
    running_loss = 0.0
    correct = 0
    seen = 0
    progress_bar = tqdm(loader, desc=desc)
    for inputs, targets in progress_bar:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()
        # Mixed precision is used only on CUDA; elsewhere both autocast and the scaler are
        # disabled, so training matches the fp32 evaluation path.
        with autocast(device_type=device.type, enabled=use_amp):
            logits = model(inputs)
            loss = criterion(logits, targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Read the loss once and reuse it for both the running total and the progress bar.
        batch_loss = loss.item()
        running_loss += batch_loss * inputs.size(0)
        seen += targets.size(0)
        correct += logits.argmax(dim=1).eq(targets).sum().item()
        progress_bar.set_postfix({"loss": batch_loss, "acc": 100 * correct / seen})

    return running_loss / len(loader.dataset), 100.0 * correct / seen


def _evaluate(model, loader, criterion, device, desc):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns:
        (mean per-sample loss over ``loader.dataset``, accuracy in percent).
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    seen = 0
    progress_bar = tqdm(loader, desc=desc)
    with torch.no_grad():
        for inputs, targets in progress_bar:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            logits = model(inputs)
            loss = criterion(logits, targets)

            batch_loss = loss.item()
            running_loss += batch_loss * inputs.size(0)
            seen += targets.size(0)
            correct += logits.argmax(dim=1).eq(targets).sum().item()
            progress_bar.set_postfix({"loss": batch_loss, "acc": 100 * correct / seen})

    return running_loss / len(loader.dataset), 100.0 * correct / seen


def train_model(model, train_loader, test_loader, epochs, device, save_path=None,
                lr=0.001, weight_decay=0.01):
    """Train ``model`` with cross-entropy, AdamW and a cosine LR schedule.

    The model is evaluated on ``test_loader`` after every epoch.

    Args:
        model: Classifier returning logits, assumed (batch, num_classes); it is expected to
            already live on ``device``.
        train_loader: DataLoader yielding (inputs, integer targets) batches for training.
        test_loader: DataLoader yielding (inputs, integer targets) batches for evaluation.
        epochs: Number of passes over ``train_loader``; also the cosine schedule's T_max.
        device: ``torch.device`` or device string that batches are moved to.
        save_path: If given, the state_dict is saved here whenever test accuracy improves.
        lr: AdamW learning rate (default 0.001).
        weight_decay: AdamW weight decay (default 0.01).

    Returns:
        ``(model, history)``. ``model`` holds the weights after the LAST epoch, which are not
        necessarily the best checkpoint written to ``save_path``. ``history`` maps
        'train_loss', 'train_acc', 'test_loss' and 'test_acc' to per-epoch lists.

    NOTE(review): the "best" checkpoint is chosen on the test set, so its accuracy is an
    optimistic estimate; a held-out validation split would avoid the leak.
    """
    if not isinstance(device, torch.device):
        device = torch.device(device)
    use_amp = device.type == 'cuda'

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    # An explicit device plus `enabled` avoids GradScaler's "CUDA not available" warning on CPU.
    scaler = GradScaler(device.type, enabled=use_amp)

    best_acc = 0.0
    history = {
        'train_loss': [],
        'train_acc': [],
        'test_loss': [],
        'test_acc': []
    }

    for epoch in range(epochs):
        tag = f"Epoch {epoch+1}/{epochs}"
        train_loss, train_acc = _train_one_epoch(
            model, train_loader, criterion, optimizer, scaler, device, use_amp,
            desc=f"{tag} [Train]",
        )
        test_loss, test_acc = _evaluate(
            model, test_loader, criterion, device, desc=f"{tag} [Test]"
        )

        scheduler.step()

        if test_acc > best_acc and save_path:
            best_acc = test_acc
            torch.save(model.state_dict(), save_path)
            print(f"Saved best model with accuracy: {best_acc:.2f}%")

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_loss'].append(test_loss)
        history['test_acc'].append(test_acc)

        print(f"Epoch {epoch+1}/{epochs} - "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, "
              f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")

    return model, history

In [23]:
# Optional cache reset. The next cell trains the teacher only when no checkpoint exists, so
# deleting it forces a full (~2 h) retrain. The path is relative to the notebook and the
# delete is guarded, so this cell is safe on a fresh kernel or outside Kaggle.
FORCE_RETRAIN_TEACHER = False
if FORCE_RETRAIN_TEACHER and os.path.exists('./teacher_model.pth'):
    os.remove('./teacher_model.pth')

In [None]:
teacher_path = './teacher_model.pth'
# Stays None when the teacher comes from the cache rather than being trained here.
teacher_history = None
if os.path.exists(teacher_path):
    print("Loading pre-trained teacher model...")
else:
    print("Training teacher model...")
    teacher_model, teacher_history = train_model(
        teacher_model, train_loader, test_loader, epochs=100, device=device, save_path=teacher_path
    )

# Load the best saved checkpoint in BOTH branches. train_model returns last-epoch weights,
# so without this the distillation teacher would depend on whether the cache existed.
# map_location lets a checkpoint saved on GPU load on a CPU-only host.
if os.path.exists(teacher_path):
    teacher_model.load_state_dict(
        torch.load(teacher_path, map_location=device, weights_only=True)
    )

Training teacher model...


Epoch 1/100 [Train]: 100%|██████████| 391/391 [01:17<00:00,  5.05it/s, loss=3.75, acc=7.12]
Epoch 1/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.48it/s, loss=3.84, acc=10.5]


Saved best model with accuracy: 10.54%
Epoch 1/100 - Train Loss: 4.1612, Train Acc: 7.12%, Test Loss: 3.8179, Test Acc: 10.54%


Epoch 2/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=3.65, acc=13.3]
Epoch 2/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.54it/s, loss=3.1, acc=18.4] 


Saved best model with accuracy: 18.39%
Epoch 2/100 - Train Loss: 3.6806, Train Acc: 13.29%, Test Loss: 3.4289, Test Acc: 18.39%


Epoch 3/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=3.12, acc=19]  
Epoch 3/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.57it/s, loss=2.82, acc=23.6]


Saved best model with accuracy: 23.65%
Epoch 3/100 - Train Loss: 3.3643, Train Acc: 18.95%, Test Loss: 3.1248, Test Acc: 23.65%


Epoch 4/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=3.44, acc=24.3]
Epoch 4/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.59it/s, loss=4.99, acc=30.1]


Saved best model with accuracy: 30.05%
Epoch 4/100 - Train Loss: 3.0725, Train Acc: 24.31%, Test Loss: 3.0867, Test Acc: 30.05%


Epoch 5/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=3.04, acc=28.7]
Epoch 5/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.63it/s, loss=2, acc=34.7]   


Saved best model with accuracy: 34.66%
Epoch 5/100 - Train Loss: 2.8337, Train Acc: 28.70%, Test Loss: 2.5290, Test Acc: 34.66%


Epoch 6/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=2.84, acc=32.9]
Epoch 6/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.59it/s, loss=2.06, acc=38.5]


Saved best model with accuracy: 38.45%
Epoch 6/100 - Train Loss: 2.6290, Train Acc: 32.90%, Test Loss: 2.3870, Test Acc: 38.45%


Epoch 7/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=2.4, acc=36.2] 
Epoch 7/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.59it/s, loss=3.06, acc=40.9]


Saved best model with accuracy: 40.89%
Epoch 7/100 - Train Loss: 2.4801, Train Acc: 36.16%, Test Loss: 2.7862, Test Acc: 40.89%


Epoch 8/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=2.21, acc=38.9]
Epoch 8/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.60it/s, loss=1.5, acc=44.3] 


Saved best model with accuracy: 44.34%
Epoch 8/100 - Train Loss: 2.3511, Train Acc: 38.92%, Test Loss: 2.2202, Test Acc: 44.34%


Epoch 9/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=2.21, acc=41.5]
Epoch 9/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.56it/s, loss=1.79, acc=46.6]


Saved best model with accuracy: 46.63%
Epoch 9/100 - Train Loss: 2.2210, Train Acc: 41.50%, Test Loss: 2.0841, Test Acc: 46.63%


Epoch 10/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=2, acc=43.4]   
Epoch 10/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.60it/s, loss=2.91, acc=49.8]


Saved best model with accuracy: 49.84%
Epoch 10/100 - Train Loss: 2.1360, Train Acc: 43.43%, Test Loss: 1.9119, Test Acc: 49.84%


Epoch 11/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=2.06, acc=46]  
Epoch 11/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.55it/s, loss=1.4, acc=49.1] 


Epoch 11/100 - Train Loss: 2.0331, Train Acc: 45.97%, Test Loss: 1.9206, Test Acc: 49.11%


Epoch 12/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=2.13, acc=47.4]
Epoch 12/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.57it/s, loss=1.23, acc=51.9]


Saved best model with accuracy: 51.86%
Epoch 12/100 - Train Loss: 1.9605, Train Acc: 47.45%, Test Loss: 1.7481, Test Acc: 51.86%


Epoch 13/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=1.88, acc=49.6]
Epoch 13/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.59it/s, loss=1.45, acc=51.7]


Epoch 13/100 - Train Loss: 1.8800, Train Acc: 49.65%, Test Loss: 1.8228, Test Acc: 51.71%


Epoch 14/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=1.91, acc=51.3]
Epoch 14/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.55it/s, loss=1.32, acc=54.7]


Saved best model with accuracy: 54.73%
Epoch 14/100 - Train Loss: 1.8100, Train Acc: 51.32%, Test Loss: 1.8661, Test Acc: 54.73%


Epoch 15/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=2.07, acc=52.5]
Epoch 15/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.59it/s, loss=1.11, acc=56.3]


Saved best model with accuracy: 56.27%
Epoch 15/100 - Train Loss: 1.7468, Train Acc: 52.54%, Test Loss: 1.6553, Test Acc: 56.27%


Epoch 16/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=1.59, acc=54]  
Epoch 16/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.57it/s, loss=1.06, acc=56.7]


Saved best model with accuracy: 56.74%
Epoch 16/100 - Train Loss: 1.6960, Train Acc: 53.99%, Test Loss: 1.5695, Test Acc: 56.74%


Epoch 17/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=1.82, acc=54.5]
Epoch 17/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.59it/s, loss=0.828, acc=57] 


Saved best model with accuracy: 56.95%
Epoch 17/100 - Train Loss: 1.6791, Train Acc: 54.48%, Test Loss: 1.5680, Test Acc: 56.95%


Epoch 18/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=1.58, acc=55.6]
Epoch 18/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.59it/s, loss=1.23, acc=60.5]


Saved best model with accuracy: 60.51%
Epoch 18/100 - Train Loss: 1.6256, Train Acc: 55.63%, Test Loss: 1.4186, Test Acc: 60.51%


Epoch 19/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=1.66, acc=57.5]
Epoch 19/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.54it/s, loss=1.62, acc=58.3]


Epoch 19/100 - Train Loss: 1.5680, Train Acc: 57.46%, Test Loss: 1.4970, Test Acc: 58.31%


Epoch 20/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=1.53, acc=58.5]
Epoch 20/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.59it/s, loss=0.99, acc=60.4]


Epoch 20/100 - Train Loss: 1.5154, Train Acc: 58.47%, Test Loss: 1.4795, Test Acc: 60.35%


Epoch 21/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=1.49, acc=59.7]
Epoch 21/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.56it/s, loss=0.731, acc=61.2]


Saved best model with accuracy: 61.19%
Epoch 21/100 - Train Loss: 1.4713, Train Acc: 59.69%, Test Loss: 1.4356, Test Acc: 61.19%


Epoch 22/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=1.26, acc=61]  
Epoch 22/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.60it/s, loss=1.05, acc=60.4]


Epoch 22/100 - Train Loss: 1.4125, Train Acc: 60.98%, Test Loss: 1.5251, Test Acc: 60.41%


Epoch 23/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=1.52, acc=62.1]
Epoch 23/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.58it/s, loss=0.749, acc=64] 


Saved best model with accuracy: 63.96%
Epoch 23/100 - Train Loss: 1.3717, Train Acc: 62.09%, Test Loss: 1.3297, Test Acc: 63.96%


Epoch 24/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=1.54, acc=63]  
Epoch 24/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.60it/s, loss=0.799, acc=62.9]


Epoch 24/100 - Train Loss: 1.3406, Train Acc: 62.97%, Test Loss: 1.3861, Test Acc: 62.91%


Epoch 25/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=1.17, acc=63.6]
Epoch 25/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.59it/s, loss=1.04, acc=63.7] 


Epoch 25/100 - Train Loss: 1.3119, Train Acc: 63.56%, Test Loss: 1.3326, Test Acc: 63.68%


Epoch 26/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=1.2, acc=64.3]  
Epoch 26/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.57it/s, loss=0.954, acc=64.1]


Saved best model with accuracy: 64.09%
Epoch 26/100 - Train Loss: 1.2915, Train Acc: 64.25%, Test Loss: 1.3110, Test Acc: 64.09%


Epoch 27/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=0.95, acc=65.6] 
Epoch 27/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.58it/s, loss=1.23, acc=64.9] 


Saved best model with accuracy: 64.92%
Epoch 27/100 - Train Loss: 1.2476, Train Acc: 65.59%, Test Loss: 1.3029, Test Acc: 64.92%


Epoch 28/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=0.945, acc=65.7]
Epoch 28/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.62it/s, loss=0.987, acc=65.9]


Saved best model with accuracy: 65.90%
Epoch 28/100 - Train Loss: 1.2237, Train Acc: 65.72%, Test Loss: 1.2530, Test Acc: 65.90%


Epoch 29/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=1.4, acc=66.9]  
Epoch 29/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.61it/s, loss=0.91, acc=62.5]


Epoch 29/100 - Train Loss: 1.1829, Train Acc: 66.92%, Test Loss: 1.4578, Test Acc: 62.48%


Epoch 30/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.985, acc=67.3]
Epoch 30/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.63it/s, loss=0.828, acc=66.7]


Saved best model with accuracy: 66.68%
Epoch 30/100 - Train Loss: 1.1730, Train Acc: 67.32%, Test Loss: 1.2463, Test Acc: 66.68%


Epoch 31/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=1.08, acc=67.8] 
Epoch 31/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.62it/s, loss=1.15, acc=65.7] 


Epoch 31/100 - Train Loss: 1.1594, Train Acc: 67.82%, Test Loss: 1.2984, Test Acc: 65.65%


Epoch 32/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=1.03, acc=68.9] 
Epoch 32/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.60it/s, loss=0.887, acc=68.2]


Saved best model with accuracy: 68.23%
Epoch 32/100 - Train Loss: 1.1174, Train Acc: 68.88%, Test Loss: 1.2455, Test Acc: 68.23%


Epoch 33/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=1.23, acc=70.2] 
Epoch 33/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.57it/s, loss=1.07, acc=65.8]


Epoch 33/100 - Train Loss: 1.0695, Train Acc: 70.16%, Test Loss: 1.3562, Test Acc: 65.82%


Epoch 34/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=1.29, acc=71]   
Epoch 34/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.60it/s, loss=0.886, acc=67.5]


Epoch 34/100 - Train Loss: 1.0328, Train Acc: 71.04%, Test Loss: 1.3115, Test Acc: 67.45%


Epoch 35/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=1.21, acc=71.8] 
Epoch 35/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.64it/s, loss=0.724, acc=64] 


Epoch 35/100 - Train Loss: 1.0112, Train Acc: 71.77%, Test Loss: 1.6566, Test Acc: 64.01%


Epoch 36/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=1.01, acc=71.7] 
Epoch 36/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.55it/s, loss=1.04, acc=67.5] 


Epoch 36/100 - Train Loss: 1.0081, Train Acc: 71.66%, Test Loss: 1.3390, Test Acc: 67.49%


Epoch 37/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.912, acc=73]  
Epoch 37/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.63it/s, loss=0.768, acc=68.4]


Saved best model with accuracy: 68.40%
Epoch 37/100 - Train Loss: 0.9629, Train Acc: 72.96%, Test Loss: 1.2497, Test Acc: 68.40%


Epoch 38/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=1.29, acc=73.2] 
Epoch 38/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.58it/s, loss=1.2, acc=69]    


Saved best model with accuracy: 68.97%
Epoch 38/100 - Train Loss: 0.9543, Train Acc: 73.15%, Test Loss: 1.2628, Test Acc: 68.97%


Epoch 39/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=1.01, acc=74]   
Epoch 39/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.57it/s, loss=1.09, acc=66.7] 


Epoch 39/100 - Train Loss: 0.9219, Train Acc: 73.99%, Test Loss: 1.3709, Test Acc: 66.67%


Epoch 40/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=0.88, acc=75]   
Epoch 40/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.65it/s, loss=0.97, acc=69.3] 


Saved best model with accuracy: 69.35%
Epoch 40/100 - Train Loss: 0.8915, Train Acc: 74.96%, Test Loss: 1.2684, Test Acc: 69.35%


Epoch 41/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.796, acc=75.3]
Epoch 41/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.58it/s, loss=0.946, acc=68.3]


Epoch 41/100 - Train Loss: 0.8816, Train Acc: 75.25%, Test Loss: 1.3104, Test Acc: 68.34%


Epoch 42/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=0.909, acc=75.9]
Epoch 42/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.64it/s, loss=0.814, acc=69]  


Epoch 42/100 - Train Loss: 0.8610, Train Acc: 75.88%, Test Loss: 1.2832, Test Acc: 68.96%


Epoch 43/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.879, acc=76.5]
Epoch 43/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.64it/s, loss=0.757, acc=68.3]


Epoch 43/100 - Train Loss: 0.8342, Train Acc: 76.52%, Test Loss: 1.5722, Test Acc: 68.32%


Epoch 44/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=0.541, acc=77.2]
Epoch 44/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.62it/s, loss=0.972, acc=68.5]


Epoch 44/100 - Train Loss: 0.8131, Train Acc: 77.19%, Test Loss: 1.5642, Test Acc: 68.47%


Epoch 45/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=0.848, acc=77.7]
Epoch 45/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.63it/s, loss=0.763, acc=69]  


Epoch 45/100 - Train Loss: 0.7902, Train Acc: 77.75%, Test Loss: 1.3752, Test Acc: 68.98%


Epoch 46/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.89, acc=78.2] 
Epoch 46/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.61it/s, loss=0.887, acc=69.9]


Saved best model with accuracy: 69.86%
Epoch 46/100 - Train Loss: 0.7751, Train Acc: 78.15%, Test Loss: 1.4325, Test Acc: 69.86%


Epoch 47/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=1.07, acc=78.8] 
Epoch 47/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.64it/s, loss=0.862, acc=69.7]


Epoch 47/100 - Train Loss: 0.7543, Train Acc: 78.76%, Test Loss: 1.3315, Test Acc: 69.70%


Epoch 48/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.485, acc=79.3]
Epoch 48/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.64it/s, loss=0.869, acc=68.3]


Epoch 48/100 - Train Loss: 0.7383, Train Acc: 79.33%, Test Loss: 1.5348, Test Acc: 68.33%


Epoch 49/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.81, acc=79.9] 
Epoch 49/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.65it/s, loss=0.938, acc=68.8]


Epoch 49/100 - Train Loss: 0.7190, Train Acc: 79.89%, Test Loss: 1.4600, Test Acc: 68.85%


Epoch 50/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.625, acc=80.4]
Epoch 50/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.60it/s, loss=1.04, acc=69.3] 


Epoch 50/100 - Train Loss: 0.7035, Train Acc: 80.41%, Test Loss: 1.4512, Test Acc: 69.33%


Epoch 51/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=0.716, acc=81]  
Epoch 51/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.59it/s, loss=1.03, acc=69]  


Epoch 51/100 - Train Loss: 0.6837, Train Acc: 80.95%, Test Loss: 1.5433, Test Acc: 68.98%


Epoch 52/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=1.04, acc=81.2] 
Epoch 52/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.61it/s, loss=0.907, acc=69.5]


Epoch 52/100 - Train Loss: 0.6742, Train Acc: 81.19%, Test Loss: 1.6513, Test Acc: 69.48%


Epoch 53/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.641, acc=81.7]
Epoch 53/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.56it/s, loss=0.954, acc=69.6]


Epoch 53/100 - Train Loss: 0.6562, Train Acc: 81.69%, Test Loss: 1.4419, Test Acc: 69.59%


Epoch 54/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.56, acc=82.2] 
Epoch 54/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.66it/s, loss=0.809, acc=69.7]


Epoch 54/100 - Train Loss: 0.6439, Train Acc: 82.16%, Test Loss: 1.5054, Test Acc: 69.68%


Epoch 55/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.837, acc=82.5]
Epoch 55/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.62it/s, loss=0.687, acc=70.3]


Saved best model with accuracy: 70.28%
Epoch 55/100 - Train Loss: 0.6286, Train Acc: 82.51%, Test Loss: 1.4427, Test Acc: 70.28%


Epoch 56/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.599, acc=83.1]
Epoch 56/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.63it/s, loss=1.2, acc=70.5]  


Saved best model with accuracy: 70.50%
Epoch 56/100 - Train Loss: 0.6106, Train Acc: 83.11%, Test Loss: 1.4962, Test Acc: 70.50%


Epoch 57/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.635, acc=83.3]
Epoch 57/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.66it/s, loss=1.12, acc=69.9]


Epoch 57/100 - Train Loss: 0.6038, Train Acc: 83.33%, Test Loss: 1.4669, Test Acc: 69.86%


Epoch 58/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.733, acc=83.7]
Epoch 58/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.60it/s, loss=0.762, acc=69.6]


Epoch 58/100 - Train Loss: 0.5891, Train Acc: 83.71%, Test Loss: 1.4799, Test Acc: 69.60%


Epoch 59/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.658, acc=83.6]
Epoch 59/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.64it/s, loss=0.818, acc=70.6]


Saved best model with accuracy: 70.64%
Epoch 59/100 - Train Loss: 0.5923, Train Acc: 83.57%, Test Loss: 1.4614, Test Acc: 70.64%


Epoch 60/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.663, acc=84.2]
Epoch 60/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.63it/s, loss=0.739, acc=70.5]


Epoch 60/100 - Train Loss: 0.5659, Train Acc: 84.23%, Test Loss: 1.5418, Test Acc: 70.52%


Epoch 61/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.583, acc=84.5]
Epoch 61/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.63it/s, loss=0.741, acc=70.5]


Epoch 61/100 - Train Loss: 0.5598, Train Acc: 84.52%, Test Loss: 1.4965, Test Acc: 70.54%


Epoch 62/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.409, acc=85.2]
Epoch 62/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.65it/s, loss=1.16, acc=70.6] 


Epoch 62/100 - Train Loss: 0.5389, Train Acc: 85.19%, Test Loss: 1.4943, Test Acc: 70.60%


Epoch 63/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.582, acc=85.1]
Epoch 63/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.62it/s, loss=1.14, acc=71.1] 


Saved best model with accuracy: 71.07%
Epoch 63/100 - Train Loss: 0.5378, Train Acc: 85.15%, Test Loss: 1.4855, Test Acc: 71.07%


Epoch 64/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.5, acc=85.7]  
Epoch 64/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.64it/s, loss=0.811, acc=71.2]


Saved best model with accuracy: 71.19%
Epoch 64/100 - Train Loss: 0.5216, Train Acc: 85.72%, Test Loss: 1.4522, Test Acc: 71.19%


Epoch 65/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.275, acc=85.9]
Epoch 65/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.61it/s, loss=1.14, acc=71.1] 


Epoch 65/100 - Train Loss: 0.5169, Train Acc: 85.89%, Test Loss: 1.4914, Test Acc: 71.10%


Epoch 66/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=0.625, acc=85.7]
Epoch 66/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.65it/s, loss=0.981, acc=70.6]


Epoch 66/100 - Train Loss: 0.5195, Train Acc: 85.73%, Test Loss: 1.5946, Test Acc: 70.61%


Epoch 67/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=0.341, acc=86.1]
Epoch 67/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.59it/s, loss=1.06, acc=70.2] 


Epoch 67/100 - Train Loss: 0.5060, Train Acc: 86.09%, Test Loss: 1.5462, Test Acc: 70.22%


Epoch 68/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=0.604, acc=86.4]
Epoch 68/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.59it/s, loss=1.1, acc=71.2]  


Epoch 68/100 - Train Loss: 0.4913, Train Acc: 86.35%, Test Loss: 1.5109, Test Acc: 71.15%


Epoch 69/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.497, acc=86.6]
Epoch 69/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.60it/s, loss=1.2, acc=70.8]  


Epoch 69/100 - Train Loss: 0.4851, Train Acc: 86.60%, Test Loss: 1.5665, Test Acc: 70.80%


Epoch 70/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=0.51, acc=86.7] 
Epoch 70/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.62it/s, loss=1.19, acc=70.8] 


Epoch 70/100 - Train Loss: 0.4835, Train Acc: 86.75%, Test Loss: 1.6291, Test Acc: 70.78%


Epoch 71/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=0.423, acc=87.1]
Epoch 71/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.55it/s, loss=0.95, acc=71.3] 


Saved best model with accuracy: 71.26%
Epoch 71/100 - Train Loss: 0.4736, Train Acc: 87.09%, Test Loss: 1.5267, Test Acc: 71.26%


Epoch 72/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=0.49, acc=87.5] 
Epoch 72/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.57it/s, loss=0.948, acc=71.4]


Saved best model with accuracy: 71.39%
Epoch 72/100 - Train Loss: 0.4541, Train Acc: 87.45%, Test Loss: 1.4744, Test Acc: 71.39%


Epoch 73/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=0.383, acc=87.5]
Epoch 73/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.55it/s, loss=1.18, acc=71.2] 


Epoch 73/100 - Train Loss: 0.4577, Train Acc: 87.51%, Test Loss: 1.5496, Test Acc: 71.25%


Epoch 74/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=0.484, acc=87.6]
Epoch 74/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.65it/s, loss=0.981, acc=71.2]


Epoch 74/100 - Train Loss: 0.4528, Train Acc: 87.59%, Test Loss: 1.4922, Test Acc: 71.24%


Epoch 75/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.479, acc=88.1]
Epoch 75/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.64it/s, loss=1.2, acc=71.6]  


Saved best model with accuracy: 71.56%
Epoch 75/100 - Train Loss: 0.4356, Train Acc: 88.07%, Test Loss: 1.6414, Test Acc: 71.56%


Epoch 76/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=0.411, acc=87.8]
Epoch 76/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.65it/s, loss=1.03, acc=71.2]


Epoch 76/100 - Train Loss: 0.4432, Train Acc: 87.78%, Test Loss: 1.5933, Test Acc: 71.25%


Epoch 77/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=0.352, acc=88]  
Epoch 77/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.59it/s, loss=0.733, acc=71.3]


Epoch 77/100 - Train Loss: 0.4391, Train Acc: 87.99%, Test Loss: 1.6608, Test Acc: 71.30%


Epoch 78/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.582, acc=88.1]
Epoch 78/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.57it/s, loss=0.797, acc=71.5]


Epoch 78/100 - Train Loss: 0.4353, Train Acc: 88.14%, Test Loss: 1.5312, Test Acc: 71.50%


Epoch 79/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=0.744, acc=88.5]
Epoch 79/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.59it/s, loss=0.875, acc=71.3]


Epoch 79/100 - Train Loss: 0.4220, Train Acc: 88.51%, Test Loss: 1.5095, Test Acc: 71.35%


Epoch 80/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=0.391, acc=88.7]
Epoch 80/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.53it/s, loss=0.952, acc=71.5]


Epoch 80/100 - Train Loss: 0.4204, Train Acc: 88.72%, Test Loss: 1.5269, Test Acc: 71.47%


Epoch 81/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=0.277, acc=88.8]
Epoch 81/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.62it/s, loss=0.95, acc=71.3]


Epoch 81/100 - Train Loss: 0.4126, Train Acc: 88.82%, Test Loss: 1.5439, Test Acc: 71.26%


Epoch 82/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.514, acc=89.1]
Epoch 82/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.64it/s, loss=0.924, acc=71.8]


Saved best model with accuracy: 71.75%
Epoch 82/100 - Train Loss: 0.4003, Train Acc: 89.08%, Test Loss: 1.5402, Test Acc: 71.75%


Epoch 83/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.258, acc=89.2]
Epoch 83/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.64it/s, loss=1.01, acc=71.3] 


Epoch 83/100 - Train Loss: 0.3958, Train Acc: 89.18%, Test Loss: 1.5395, Test Acc: 71.32%


Epoch 84/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.297, acc=89.3]
Epoch 84/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.56it/s, loss=1.01, acc=72]   


Saved best model with accuracy: 72.03%
Epoch 84/100 - Train Loss: 0.3930, Train Acc: 89.28%, Test Loss: 1.5023, Test Acc: 72.03%


Epoch 85/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=0.342, acc=89.1]
Epoch 85/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.62it/s, loss=0.833, acc=71.8]


Epoch 85/100 - Train Loss: 0.3995, Train Acc: 89.09%, Test Loss: 1.5818, Test Acc: 71.81%


Epoch 86/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, loss=0.391, acc=89.3]
Epoch 86/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.64it/s, loss=0.752, acc=72.1]


Saved best model with accuracy: 72.09%
Epoch 86/100 - Train Loss: 0.3902, Train Acc: 89.25%, Test Loss: 1.4990, Test Acc: 72.09%


Epoch 87/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=0.475, acc=89.5]
Epoch 87/100 [Test]: 100%|██████████| 79/79 [00:05<00:00, 14.62it/s, loss=0.851, acc=72.2]


Saved best model with accuracy: 72.19%
Epoch 87/100 - Train Loss: 0.3863, Train Acc: 89.55%, Test Loss: 1.4954, Test Acc: 72.19%


Epoch 88/100 [Train]: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, loss=0.759, acc=89.5]
Epoch 88/100 [Test]:   6%|▋         | 5/79 [00:00<00:05, 12.36it/s, loss=1.59, acc=73.1]

# Knowledge Distillation

In [None]:
class DistillationLoss(nn.Module):
    """Knowledge-distillation loss mixing hard-label CE with temperature-softened KL.

    loss = (1 - alpha) * CE(student, targets)
           + alpha * T^2 * KL(softmax(teacher / T) || softmax(student / T))

    Args:
        alpha: weight of the soft (teacher) term; ``1 - alpha`` weights the hard term.
        temperature: softmax temperature ``T`` applied to both logit sets.
    """

    def __init__(self, alpha=0.5, temperature=4.0):
        super().__init__()
        self.alpha = alpha
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, student_outputs, teacher_outputs, targets):
        """Return the scalar blended loss for one batch of logits and integer targets."""
        tau = self.temperature
        # kl_div expects log-probabilities for its input and probabilities for its target.
        student_log_probs = F.log_softmax(student_outputs / tau, dim=1)
        teacher_probs = F.softmax(teacher_outputs / tau, dim=1)
        # The conventional T^2 factor keeps the soft term's gradient scale comparable
        # across temperatures.
        distill_term = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (tau ** 2)
        ce_term = self.criterion(student_outputs, targets)
        return (1 - self.alpha) * ce_term + self.alpha * distill_term

# Training Students (From Scratch vs Distillation)

In [None]:
# Baseline student: trained through `train_model` with is_teacher=False and no
# teacher model passed in, as the reference point for distillation.
student_model_baseline = create_student_model().to(device)

print("Training student model from scratch without distillation...")
student_baseline_path = './student_baseline_model.pth'
student_model_baseline, student_baseline_history = train_model(
    student_model_baseline,
    train_loader,
    test_loader,
    epochs=100,
    save_path=student_baseline_path,
    is_teacher=False,
)

In [None]:
def _distill_train_epoch(student_model, teacher_model, loader, criterion, optimizer, scaler,
                         device, use_amp, desc):
    """Run one knowledge-distillation training pass over ``loader``.

    Returns ``(mean_loss, accuracy_percent)`` computed over the samples actually seen.
    """
    student_model.train()
    # The teacher only supplies targets; eval() keeps its dropout/batch-norm in inference mode.
    teacher_model.eval()

    running_loss = 0.0
    correct = 0
    seen = 0
    progress_bar = tqdm(loader, desc=desc)
    for inputs, targets in progress_bar:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        # torch.amp.autocast requires device_type; enabled=False makes it a no-op off CUDA.
        with autocast(device_type=device.type, enabled=use_amp):
            student_outputs = student_model(inputs)
            with torch.no_grad():
                teacher_outputs = teacher_model(inputs)
            loss = criterion(student_outputs, teacher_outputs, targets)

        # A disabled GradScaler passes the loss through unscaled and just calls
        # optimizer.step(), so this path also serves full-precision training.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        running_loss += batch_loss * inputs.size(0)
        seen += targets.size(0)
        correct += student_outputs.argmax(dim=1).eq(targets).sum().item()
        progress_bar.set_postfix({"loss": batch_loss, "acc": 100 * correct / seen})

    if seen == 0:
        raise ValueError("train_loader yielded no batches")
    return running_loss / seen, 100.0 * correct / seen


def _evaluate_student(student_model, loader, criterion, device, desc):
    """Evaluate ``student_model`` with plain cross-entropy; return ``(mean_loss, accuracy_percent)``."""
    student_model.eval()

    running_loss = 0.0
    correct = 0
    seen = 0
    progress_bar = tqdm(loader, desc=desc)
    with torch.no_grad():
        for inputs, targets in progress_bar:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = student_model(inputs)
            loss = criterion(outputs, targets)

            batch_loss = loss.item()
            running_loss += batch_loss * inputs.size(0)
            seen += targets.size(0)
            correct += outputs.argmax(dim=1).eq(targets).sum().item()
            progress_bar.set_postfix({"loss": batch_loss, "acc": 100 * correct / seen})

    if seen == 0:
        raise ValueError("test_loader yielded no batches")
    return running_loss / seen, 100.0 * correct / seen


def train_student_with_distillation(student_model, teacher_model, train_loader, test_loader, epochs,
                                    alpha=0.5, temperature=4.0, save_path=None,
                                    lr=0.001, weight_decay=0.01):
    """Train ``student_model`` against a fixed ``teacher_model`` using knowledge distillation.

    Each epoch trains with ``DistillationLoss`` on ``train_loader`` (mixed precision when
    the student is on CUDA), then evaluates the student on ``test_loader`` with plain
    cross-entropy. The learning rate follows a cosine schedule over ``epochs``.

    Args:
        student_model: model being trained; all batches are moved to its parameters' device.
        teacher_model: model providing soft targets; run under ``no_grad`` and never stepped.
        train_loader: iterable of ``(inputs, targets)`` training batches.
        test_loader: iterable of ``(inputs, targets)`` evaluation batches.
        epochs: number of epochs; also the cosine-annealing period ``T_max``.
        alpha: weight of the soft (teacher) loss term.
        temperature: softmax temperature for the soft loss term.
        save_path: if given, the ``state_dict`` of the best-test-accuracy epoch is saved here.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.

    Returns:
        ``(student_model, history)``. ``student_model`` holds the final-epoch weights,
        not necessarily the best ones. ``history`` maps ``'train_loss'``, ``'train_acc'``,
        ``'test_loss'`` and ``'test_acc'`` to per-epoch lists.
    """
    # Take the device from the student itself rather than relying on a notebook-level global.
    device = next(student_model.parameters()).device
    use_amp = device.type == 'cuda'

    distillation_criterion = DistillationLoss(alpha=alpha, temperature=temperature)
    standard_criterion = nn.CrossEntropyLoss()

    optimizer = optim.AdamW(student_model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    scaler = GradScaler(device.type, enabled=use_amp)

    best_acc = 0.0
    history = {
        'train_loss': [],
        'train_acc': [],
        'test_loss': [],
        'test_acc': []
    }

    for epoch in range(epochs):
        epoch_tag = f"Epoch {epoch+1}/{epochs}"

        train_loss, train_acc = _distill_train_epoch(
            student_model, teacher_model, train_loader, distillation_criterion,
            optimizer, scaler, device, use_amp, desc=f"{epoch_tag} [Train]",
        )
        test_loss, test_acc = _evaluate_student(
            student_model, test_loader, standard_criterion, device,
            desc=f"{epoch_tag} [Valid]",
        )

        scheduler.step()

        if test_acc > best_acc and save_path:
            best_acc = test_acc
            torch.save(student_model.state_dict(), save_path)
            print(f"Saved best model with accuracy: {best_acc:.2f}%")

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_loss'].append(test_loss)
        history['test_acc'].append(test_acc)

        print(f"Epoch {epoch+1}/{epochs} - "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, "
              f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")

    return student_model, history

print("Training student model with knowledge distillation...")
# Build a fresh student (same factory as the baseline) so the distillation run starts
# from scratch and does not depend on any leftover `student_model` in the kernel.
student_model_distill = create_student_model().to(device)
student_distill_path = './student_distill_model.pth'
student_model_distill, student_distill_history = train_student_with_distillation(
    student_model_distill, teacher_model, train_loader, test_loader,
    epochs=100, alpha=0.5, temperature=4.0,
    save_path=student_distill_path
)

# Evaluation and Comparison

In [None]:
def evaluate_model(model, test_loader, model_name, num_classes=100, device=None):
    """Print and return overall and mean per-class accuracy of ``model`` on ``test_loader``.

    Args:
        model: classifier whose output's argmax over dim 1 is the predicted class
            (presumably logits of shape ``(batch, num_classes)`` — confirm for new models).
        test_loader: iterable of ``(inputs, targets)`` batches with integer class targets.
        model_name: label used in the progress bar and printed report.
        num_classes: number of classes tracked for per-class accuracy (100 for CIFAR-100).
        device: where to run evaluation; defaults to the device of the model's first
            parameter, or CPU for a parameter-less model.

    Returns:
        ``(overall_acc, avg_class_acc)`` in percent. ``avg_class_acc`` is the unweighted
        mean over classes that appear at least once in ``test_loader``.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    # Per-class tallies stay on-device; bincount replaces a per-sample Python loop
    # (one .item() sync per sample) and also works for a batch of size 1.
    class_correct = torch.zeros(num_classes, dtype=torch.long, device=device)
    class_total = torch.zeros(num_classes, dtype=torch.long, device=device)

    with torch.no_grad():
        for inputs, targets in tqdm(test_loader, desc=f"Evaluating {model_name}"):
            inputs, targets = inputs.to(device), targets.to(device)
            predicted = model(inputs).argmax(dim=1)

            class_total += torch.bincount(targets, minlength=num_classes)
            class_correct += torch.bincount(targets[predicted == targets], minlength=num_classes)

    class_correct = class_correct.cpu()
    class_total = class_total.cpu()

    total = int(class_total.sum())
    if total == 0:
        raise ValueError(f"{model_name}: test_loader yielded no samples")

    overall_acc = 100.0 * int(class_correct.sum()) / total
    print(f"{model_name} - Test Accuracy: {overall_acc:.2f}%")

    # Classes absent from the loader are excluded rather than counted as 0%.
    present = class_total > 0
    class_accuracies = 100.0 * class_correct[present].double() / class_total[present].double()
    avg_class_acc = class_accuracies.mean().item()
    print(f"{model_name} - Average Class Accuracy: {avg_class_acc:.2f}%")

    return overall_acc, avg_class_acc

# Reload the best-accuracy checkpoints written during training before comparing; the
# in-memory models may hold later-epoch weights. weights_only=True restricts unpickling
# to tensors and plain containers, and map_location lets a checkpoint saved on another
# device load onto the current one.
teacher_model.load_state_dict(torch.load(teacher_path, map_location=device, weights_only=True))
student_model_baseline.load_state_dict(torch.load(student_baseline_path, map_location=device, weights_only=True))
student_model_distill.load_state_dict(torch.load(student_distill_path, map_location=device, weights_only=True))

teacher_acc, teacher_class_acc = evaluate_model(teacher_model, test_loader, "Teacher (ResNet-50)")
baseline_acc, baseline_class_acc = evaluate_model(student_model_baseline, test_loader, "Student Baseline (ResNet-18)")
distill_acc, distill_class_acc = evaluate_model(student_model_distill, test_loader, "Student with Distillation (ResNet-18)")

# Visualization and Analysis

In [None]:
fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: per-epoch test accuracy of both students, with the teacher's evaluated
# test accuracy as a dashed reference line.
ax_acc.plot(student_baseline_history['test_acc'], label='Student Baseline')
ax_acc.plot(student_distill_history['test_acc'], label='Student with Distillation')
ax_acc.axhline(y=teacher_acc, color='r', linestyle='--', label='Teacher')
ax_acc.set_title('Test Accuracy')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy (%)')
ax_acc.legend()
ax_acc.grid(True)

# Right panel: per-epoch test loss of both students.
ax_loss.plot(student_baseline_history['test_loss'], label='Student Baseline')
ax_loss.plot(student_distill_history['test_loss'], label='Student with Distillation')
ax_loss.set_title('Test Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend()
ax_loss.grid(True)

fig.tight_layout()
plt.show()

# NOTE(review): this rebinds `models`, shadowing `torchvision.models` imported at the top
# of the notebook, so any cell re-run afterwards that reaches torchvision through that
# name will get this list instead. The name is kept because the inference-time cell
# below reads it.
models = ['Teacher (ResNet-50)', 'Student Baseline (ResNet-18)', 'Student with Distillation (ResNet-18)']
accuracies = [teacher_acc, baseline_acc, distill_acc]
model_sizes = [count_parameters(m) for m in (teacher_model, student_model_baseline, student_model_distill)]

fig, (ax_acc, ax_size) = plt.subplots(1, 2, figsize=(12, 6))

# Left panel: evaluated test accuracy per model, annotated above each bar.
ax_acc.bar(models, accuracies, color=['blue', 'orange', 'green'])
ax_acc.set_title('Model Accuracy Comparison')
ax_acc.set_ylabel('Accuracy (%)')
plt.setp(ax_acc.get_xticklabels(), rotation=15, ha='right')
ax_acc.set_ylim(0, 100)
for idx, acc in enumerate(accuracies):
    ax_acc.text(idx, acc + 1, f"{acc:.2f}%", ha='center')

# Right panel: parameter count per model, in millions.
sizes_in_millions = [size / 1e6 for size in model_sizes]
ax_size.bar(models, sizes_in_millions, color=['blue', 'orange', 'green'])
ax_size.set_title('Model Size Comparison')
ax_size.set_ylabel('Parameters (millions)')
plt.setp(ax_size.get_xticklabels(), rotation=15, ha='right')
for idx, millions in enumerate(sizes_in_millions):
    ax_size.text(idx, millions + 0.1, f"{millions:.2f}M", ha='center')

fig.tight_layout()
plt.show()

# Inference Speed Comparison

In [None]:
def measure_inference_time(model, input_size=(128, 3, 32, 32), iterations=100, warmup=10, device=None):
    """Return the mean forward-pass latency of ``model`` in milliseconds per batch.

    A random input of shape ``input_size`` is created once and fed repeatedly under
    ``torch.no_grad()``. ``warmup`` untimed passes run first.

    Args:
        model: module to benchmark; it is switched to eval mode.
        input_size: shape of the random input batch.
        iterations: number of timed forward passes (must be >= 1).
        warmup: number of untimed forward passes before timing.
        device: where to run; defaults to the device of the model's first parameter,
            or CPU for a parameter-less model.

    Returns:
        Average wall-clock milliseconds per forward pass.

    Raises:
        ValueError: if ``iterations`` < 1.
    """
    if iterations < 1:
        raise ValueError("iterations must be >= 1")
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    device = torch.device(device)

    def _sync():
        # CUDA kernels launch asynchronously, so wait for queued work before reading
        # the clock. torch.cuda.synchronize() would fail on CPU-only setups, so skip it there.
        if device.type == 'cuda':
            torch.cuda.synchronize(device)

    model.eval()
    x = torch.randn(input_size, device=device)

    with torch.no_grad():
        for _ in range(warmup):
            model(x)

        _sync()
        start = time.perf_counter()  # monotonic, high-resolution clock meant for benchmarks
        for _ in range(iterations):
            model(x)
        _sync()
        elapsed = time.perf_counter() - start

    return elapsed / iterations * 1000  # seconds per batch -> ms per batch

# Benchmark every model with the same default batch shape.
teacher_time = measure_inference_time(teacher_model)
student_baseline_time = measure_inference_time(student_model_baseline)
student_distill_time = measure_inference_time(student_model_distill)

print(f"Teacher (ResNet-50) inference time: {teacher_time:.2f} ms/batch")
print(f"Student Baseline (ResNet-18) inference time: {student_baseline_time:.2f} ms/batch")
print(f"Student with Distillation (ResNet-18) inference time: {student_distill_time:.2f} ms/batch")
print(f"Speed-up: {teacher_time / student_distill_time:.2f}x")

# Bar labels are defined here so this cell does not depend on the `models` list built in
# the comparison cell; that list rebinds the `models` name imported from torchvision.
model_labels = ['Teacher (ResNet-50)', 'Student Baseline (ResNet-18)', 'Student with Distillation (ResNet-18)']
inference_times = [teacher_time, student_baseline_time, student_distill_time]

fig, ax = plt.subplots(figsize=(8, 6))
ax.bar(model_labels, inference_times, color=['blue', 'orange', 'green'])
ax.set_title('Inference Time Comparison')
ax.set_ylabel('Time per batch (ms)')
plt.setp(ax.get_xticklabels(), rotation=15, ha='right')
for i, v in enumerate(inference_times):
    ax.text(i, v + 0.2, f"{v:.2f} ms", ha='center')
fig.tight_layout()
plt.show()