# Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler

import numpy as np
import matplotlib.pyplot as plt
import time
import os
from tqdm.notebook import tqdm

In [None]:
# Seed PyTorch's and NumPy's global random generators with the same fixed value
# so repeated runs start from the same random state.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Load Data

In [None]:
# Per-channel statistics handed to transforms.Normalize for both splits.
mean = [0.5071, 0.4867, 0.4408]
std = [0.2675, 0.2565, 0.2761]

# Training pipeline: random padded crop and horizontal flip, then tensor
# conversion and normalisation.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Evaluation pipeline: no augmentation, only tensor conversion and normalisation.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Both CIFAR-100 splits live under ./data and are downloaded when missing.
dataset_options = dict(root='./data', download=True)
train_dataset = torchvision.datasets.CIFAR100(
    train=True, transform=train_transform, **dataset_options
)
test_dataset = torchvision.datasets.CIFAR100(
    train=False, transform=test_transform, **dataset_options
)

# The loaders differ only in shuffling: the training split is reshuffled each
# epoch, the test split is read in a fixed order.
batch_size = 128
loader_options = dict(batch_size=batch_size, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

print(f"Training dataset size: {len(train_dataset)}")
print(f"Testing dataset size: {len(test_dataset)}")
print(f"Number of classes: {len(train_dataset.classes)}")

# Models (Teacher and Student)

In [None]:
def create_teacher_model(num_classes=100):
    """Build the teacher network: an untrained ResNet-50 with a modified stem and head.

    The first convolution is replaced by a 3x3, stride-1 convolution and the
    max-pool layer by an identity, so the stem no longer downsamples the input.
    The final fully connected layer is replaced by one with ``num_classes``
    outputs.

    Args:
        num_classes: Number of output logits. Defaults to 100, the value this
            file previously hard-coded.

    Returns:
        The ``torch.nn.Module`` with randomly initialised weights (``weights=None``).
    """
    # Resolve the constructor through the `torchvision` package instead of the
    # `models` alias: this file later rebinds `models` to a list of plot
    # labels, after which `models.resnet50` would raise AttributeError.
    model = torchvision.models.resnet50(weights=None)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, num_classes)

    return model

def create_student_model(num_classes=100):
    """Build the student network: an untrained ResNet-18 with a modified stem and head.

    The first convolution is replaced by a 3x3, stride-1 convolution and the
    max-pool layer by an identity, so the stem no longer downsamples the input.
    The final fully connected layer is replaced by one with ``num_classes``
    outputs.

    Args:
        num_classes: Number of output logits. Defaults to 100, the value this
            file previously hard-coded.

    Returns:
        The ``torch.nn.Module`` with randomly initialised weights (``weights=None``).
    """
    # Resolve the constructor through the `torchvision` package instead of the
    # `models` alias: this file later rebinds `models` to a list of plot
    # labels, after which `models.resnet18` would raise AttributeError.
    model = torchvision.models.resnet18(weights=None)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, num_classes)

    return model

In [None]:
# Build the teacher first, then the student, moving each onto the selected
# device right after construction.
teacher_model = create_teacher_model()
teacher_model = teacher_model.to(device)
student_model = create_student_model()
student_model = student_model.to(device)

In [None]:
def count_parameters(model):
    """Return the total element count of the parameters of ``model`` that require gradients."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable

# Count each model once, then report the sizes and how many times larger the
# teacher is than the student.
teacher_param_count = count_parameters(teacher_model)
student_param_count = count_parameters(student_model)

print(f"Teacher model (ResNet-50) parameters: {teacher_param_count:,}")
print(f"Student model (ResNet-18) parameters: {student_param_count:,}")
print(f"Compression ratio: {teacher_param_count / student_param_count:.2f}x")

# Training the Teacher Model

In [None]:
def train_model(model, train_loader, test_loader, epochs, save_path=None):
    """Train ``model`` with cross-entropy and evaluate it after every epoch.

    Uses AdamW (lr 0.001, weight decay 0.01), a cosine-annealing learning-rate
    schedule stepped once per epoch, and mixed precision through ``autocast``
    and ``GradScaler``. Batches are moved to the module-level ``device``.

    Args:
        model: Network to optimise; trained in place.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        test_loader: Iterable of ``(inputs, targets)`` batches evaluated after
            each epoch.
        epochs: Number of passes over ``train_loader``.
        save_path: Optional checkpoint path. When given, the model's
            ``state_dict`` is written there each time test accuracy exceeds
            the best value seen so far. When omitted, nothing is saved.

    Returns:
        ``(model, history)`` where ``history`` maps ``'train_loss'``,
        ``'train_acc'``, ``'test_loss'`` and ``'test_acc'`` to per-epoch lists.
        Losses are per-sample means; accuracies are percentages.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    scaler = GradScaler()

    best_acc = 0.0

    history = {
        'train_loss': [],
        'train_acc': [],
        'test_loss': [],
        'test_acc': []
    }

    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")
        for inputs, targets in train_bar:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()

            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, targets)

            # The scaler multiplies the loss before backward and skips the
            # optimizer step when the scaled gradients contain inf/NaN.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Read the loss from the device once and reuse the Python float.
            batch_loss = loss.item()
            train_loss += batch_loss * inputs.size(0)
            _, predicted = outputs.max(1)
            train_total += targets.size(0)
            train_correct += predicted.eq(targets).sum().item()

            train_bar.set_postfix(
                {"loss": batch_loss, "acc": 100 * train_correct / train_total}
            )

        # Normalise by the samples actually processed so the mean stays right
        # even if the loader drops or subsamples part of the dataset.
        train_loss = train_loss / train_total
        train_acc = 100.0 * train_correct / train_total

        # ---- evaluation pass ----
        model.eval()
        test_loss = 0.0
        test_correct = 0
        test_total = 0

        # Keep a handle on the validation bar: the running metrics used to be
        # written to the training bar, which has already finished by now.
        valid_bar = tqdm(test_loader, desc=f"Epoch {epoch+1}/{epochs} [Valid]")
        with torch.no_grad():
            for inputs, targets in valid_bar:
                inputs, targets = inputs.to(device), targets.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, targets)

                batch_loss = loss.item()
                test_loss += batch_loss * inputs.size(0)
                _, predicted = outputs.max(1)
                test_total += targets.size(0)
                test_correct += predicted.eq(targets).sum().item()

                valid_bar.set_postfix(
                    {"loss": batch_loss, "acc": 100 * test_correct / test_total}
                )

        test_loss = test_loss / test_total
        test_acc = 100.0 * test_correct / test_total

        scheduler.step()

        if test_acc > best_acc and save_path:
            best_acc = test_acc
            torch.save(model.state_dict(), save_path)
            print(f"Saved best model with accuracy: {best_acc:.2f}%")

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_loss'].append(test_loss)
        history['test_acc'].append(test_acc)

        print(f"Epoch {epoch+1}/{epochs} - "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, "
              f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")

    return model, history

In [None]:
# Reuse an existing teacher checkpoint when one is on disk; otherwise train the
# teacher from scratch and let train_model() write the best checkpoint.
teacher_path = './teacher_model.pth'
if os.path.exists(teacher_path):
    print("Loading pre-trained teacher model...")
    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint written on a GPU machine also loads on a CPU-only one.
    teacher_model.load_state_dict(torch.load(teacher_path, map_location=device))
else:
    print("Training teacher model...")
    # train_model() has no `is_teacher` parameter; passing it raised TypeError.
    teacher_model, teacher_history = train_model(
        teacher_model, train_loader, test_loader, epochs=100, save_path=teacher_path
    )

# Knowledge Distillation

In [None]:
class DistillationLoss(nn.Module):
    """Weighted sum of hard-label cross-entropy and a softened teacher-matching term.

    The soft term is the KL divergence (``batchmean`` reduction) between the
    teacher's and the student's class distributions, both computed from logits
    divided by ``temperature``, and multiplied by ``temperature ** 2``.
    ``alpha`` weights the soft term and ``1 - alpha`` the hard term.
    """

    def __init__(self, alpha=0.5, temperature=4.0):
        super().__init__()
        self.alpha = alpha
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, student_outputs, teacher_outputs, targets):
        """Return the combined loss for one batch of student/teacher logits and labels."""
        temp = self.temperature

        # Soft term: student log-probabilities against teacher probabilities,
        # both taken over dim 1 at the shared temperature.
        student_log_probs = F.log_softmax(student_outputs / temp, dim=1)
        teacher_probs = F.softmax(teacher_outputs / temp, dim=1)
        distill_term = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (temp ** 2)

        # Hard term: ordinary cross-entropy on the unscaled student logits.
        label_term = self.criterion(student_outputs, targets)

        return (1 - self.alpha) * label_term + self.alpha * distill_term

# Training Students (From Scratch vs Distillation)

In [None]:
# Baseline: a fresh student trained on hard labels only, with the same
# train_model() recipe and epoch count as the teacher.
student_model_baseline = create_student_model().to(device)

print("Training student model from scratch without distillation...")
student_baseline_path = './student_baseline_model.pth'
# train_model() has no `is_teacher` parameter; passing it raised TypeError.
student_model_baseline, student_baseline_history = train_model(
    student_model_baseline, train_loader, test_loader, epochs=100,
    save_path=student_baseline_path
)

In [None]:
def train_student_with_distillation(student_model, teacher_model, train_loader, test_loader, epochs, alpha=0.5, temperature=4.0, save_path=None):
    """Train ``student_model`` to match ``teacher_model`` using ``DistillationLoss``.

    Every epoch makes one optimisation pass over ``train_loader`` and then
    evaluates the student on ``test_loader``. Teacher logits are computed
    without gradients. The optimiser is AdamW (lr 0.001, weight decay 0.01)
    with a cosine-annealing schedule stepped once per epoch, and the forward
    passes during training run under ``autocast`` with a ``GradScaler``.

    Args:
        student_model: Network being trained (in place).
        teacher_model: Network supplying soft targets; put in eval mode each epoch.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        test_loader: Iterable of ``(inputs, targets)`` evaluation batches.
        epochs: Number of training epochs.
        alpha: Weight of the soft (teacher) term, forwarded to ``DistillationLoss``.
        temperature: Softmax temperature, forwarded to ``DistillationLoss``.
        save_path: Optional path; when given, the student's ``state_dict`` is
            saved whenever test accuracy beats the best value so far.

    Returns:
        ``(student_model, history)``; ``history`` maps ``'train_loss'``,
        ``'train_acc'``, ``'test_loss'`` and ``'test_acc'`` to per-epoch lists.
        ``train_loss`` is the distillation loss whereas ``test_loss`` is plain
        cross-entropy, so the two curves are on different scales.
    """
    kd_criterion = DistillationLoss(alpha=alpha, temperature=temperature)
    ce_criterion = nn.CrossEntropyLoss()

    optimizer = optim.AdamW(student_model.parameters(), lr=0.001, weight_decay=0.01)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    scaler = GradScaler()

    history = {key: [] for key in ('train_loss', 'train_acc', 'test_loss', 'test_acc')}
    best_acc = 0.0

    def run_training_pass(epoch):
        """Optimise the student for one epoch; return (mean loss, accuracy %)."""
        student_model.train()
        teacher_model.eval()

        loss_sum, hits, seen = 0.0, 0, 0
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} - Training"):
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            with autocast():
                student_logits = student_model(images)
                # The teacher only provides targets, so no graph is built for it.
                with torch.no_grad():
                    teacher_logits = teacher_model(images)
                batch_loss = kd_criterion(student_logits, teacher_logits, labels)

            scaler.scale(batch_loss).backward()
            scaler.step(optimizer)
            scaler.update()

            loss_sum += batch_loss.item() * images.size(0)
            guesses = student_logits.max(1)[1]
            seen += labels.size(0)
            hits += guesses.eq(labels).sum().item()

        mean_loss = loss_sum / len(train_loader.dataset)
        accuracy = 100.0 * hits / seen
        return mean_loss, accuracy

    def run_evaluation_pass(epoch):
        """Score the student on the test loader; return (mean CE loss, accuracy %)."""
        student_model.eval()

        loss_sum, hits, seen = 0.0, 0, 0
        with torch.no_grad():
            for images, labels in tqdm(test_loader, desc=f"Epoch {epoch+1}/{epochs} - Testing"):
                images = images.to(device)
                labels = labels.to(device)

                logits = student_model(images)
                batch_loss = ce_criterion(logits, labels)

                loss_sum += batch_loss.item() * images.size(0)
                guesses = logits.max(1)[1]
                seen += labels.size(0)
                hits += guesses.eq(labels).sum().item()

        mean_loss = loss_sum / len(test_loader.dataset)
        accuracy = 100.0 * hits / seen
        return mean_loss, accuracy

    for epoch in range(epochs):
        train_loss, train_acc = run_training_pass(epoch)
        test_loss, test_acc = run_evaluation_pass(epoch)

        scheduler.step()

        # Checkpoint only on improvement, and only when a path was supplied.
        if save_path and test_acc > best_acc:
            best_acc = test_acc
            torch.save(student_model.state_dict(), save_path)
            print(f"Saved best model with accuracy: {best_acc:.2f}%")

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_loss'].append(test_loss)
        history['test_acc'].append(test_acc)

        print(f"Epoch {epoch+1}/{epochs} - "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, "
              f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")

    return student_model, history

# Distil the teacher into the student instantiated earlier in this file; the
# best checkpoint seen during training is written to `student_distill_path`.
student_distill_path = './student_distill_model.pth'
print("Training student model with knowledge distillation...")
student_model_distill, student_distill_history = train_student_with_distillation(
    student_model,
    teacher_model,
    train_loader,
    test_loader,
    epochs=100,
    alpha=0.5,
    temperature=4.0,
    save_path=student_distill_path,
)

## 7. Evaluation and Comparison

In [None]:
def evaluate_model(model, test_loader, model_name, num_classes=100):
    """Measure overall and class-averaged top-1 accuracy of ``model`` on ``test_loader``.

    Args:
        model: Network to evaluate; switched to eval mode.
        test_loader: Iterable of ``(inputs, targets)`` batches. ``targets`` is
            expected to be a 1-D tensor of integer class indices below
            ``num_classes``.
        model_name: Label used in the progress bar and the printed summary.
        num_classes: Number of classes tracked in the per-class statistics.
            Defaults to 100, the value this function previously hard-coded.

    Returns:
        ``(overall_acc, avg_class_acc)`` as percentages. ``avg_class_acc`` is
        the unweighted mean of per-class accuracies over the classes that
        occur at least once in ``test_loader``.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0

    # Per-class tallies are accumulated on the CPU as integer tensors.
    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)

    with torch.no_grad():
        for inputs, targets in tqdm(test_loader, desc=f"Evaluating {model_name}"):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)

            # One transfer per batch replaces a `.item()` call per sample. It
            # also avoids the old `.squeeze()`, which collapsed a batch of
            # size 1 to a 0-dim tensor and made the indexing raise.
            hits = predicted.eq(targets).cpu()
            labels = targets.cpu()

            total += labels.size(0)
            correct += int(hits.sum())

            class_total += torch.bincount(labels, minlength=num_classes)
            class_correct += torch.bincount(labels[hits], minlength=num_classes)

    if total == 0:
        raise ValueError("test_loader yielded no samples to evaluate")

    overall_acc = 100.0 * correct / total
    print(f"{model_name} - Test Accuracy: {overall_acc:.2f}%")

    # Average only over classes that actually appear, as before.
    present = class_total > 0
    class_accuracies = 100.0 * class_correct[present].double() / class_total[present].double()
    avg_class_acc = class_accuracies.mean().item()
    print(f"{model_name} - Average Class Accuracy: {avg_class_acc:.2f}%")

    return overall_acc, avg_class_acc

# Restore the best checkpoint of every model before the final comparison.
# map_location remaps the stored tensors onto the active device, so checkpoints
# written on a GPU machine also load on a CPU-only one.
teacher_model.load_state_dict(torch.load(teacher_path, map_location=device))
student_model_baseline.load_state_dict(torch.load(student_baseline_path, map_location=device))
student_model_distill.load_state_dict(torch.load(student_distill_path, map_location=device))

teacher_acc, teacher_class_acc = evaluate_model(teacher_model, test_loader, "Teacher (ResNet-50)")
baseline_acc, baseline_class_acc = evaluate_model(student_model_baseline, test_loader, "Student Baseline (ResNet-18)")
distill_acc, distill_class_acc = evaluate_model(student_model_distill, test_loader, "Student with Distillation (ResNet-18)")

## 8. Visualization and Analysis

In [None]:
# ---- Figure 1: per-epoch test curves of the two students ----
plt.figure(figsize=(12, 5))

# Left panel: test accuracy, with the teacher's final accuracy as a dashed line.
plt.subplot(1, 2, 1)
for run_history, run_label in (
    (student_baseline_history, 'Student Baseline'),
    (student_distill_history, 'Student with Distillation'),
):
    plt.plot(run_history['test_acc'], label=run_label)
plt.axhline(y=teacher_acc, color='r', linestyle='--', label='Teacher')
plt.title('Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.grid(True)

# Right panel: test loss of both students.
plt.subplot(1, 2, 2)
for run_history, run_label in (
    (student_baseline_history, 'Student Baseline'),
    (student_distill_history, 'Student with Distillation'),
):
    plt.plot(run_history['test_loss'], label=run_label)
plt.title('Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()

# ---- Figure 2: final accuracy and parameter count of all three models ----
# NOTE(review): `models` rebinds the alias created by
# `import torchvision.models as models`. The name is kept because code further
# down this file may read the label list through it.
models = [
    'Teacher (ResNet-50)',
    'Student Baseline (ResNet-18)',
    'Student with Distillation (ResNet-18)',
]
accuracies = [teacher_acc, baseline_acc, distill_acc]
model_sizes = [
    count_parameters(net)
    for net in (teacher_model, student_model_baseline, student_model_distill)
]
sizes_in_millions = [size / 1000000 for size in model_sizes]

plt.figure(figsize=(12, 6))

# Left panel: accuracy bars, each annotated just above the bar.
plt.subplot(1, 2, 1)
plt.bar(models, accuracies, color=['blue', 'orange', 'green'])
plt.title('Model Accuracy Comparison')
plt.ylabel('Accuracy (%)')
plt.xticks(rotation=15, ha='right')
plt.ylim(0, 100)
for bar_index, accuracy in enumerate(accuracies):
    plt.text(bar_index, accuracy + 1, f"{accuracy:.2f}%", ha='center')

# Right panel: parameter counts in millions, annotated the same way.
plt.subplot(1, 2, 2)
plt.bar(models, sizes_in_millions, color=['blue', 'orange', 'green'])
plt.title('Model Size Comparison')
plt.ylabel('Parameters (millions)')
plt.xticks(rotation=15, ha='right')
for bar_index, millions in enumerate(sizes_in_millions):
    plt.text(bar_index, millions + 0.1, f"{millions:.2f}M", ha='center')

plt.tight_layout()
plt.show()

## 9. Inference Speed Comparison

In [None]:
def measure_inference_time(model, input_size=(128, 3, 32, 32), iterations=100):
    """Return the mean forward-pass time of ``model`` in milliseconds per batch.

    A random batch of shape ``input_size`` is placed on the device holding the
    model's parameters (CPU when the model has none). Ten untimed warm-up
    passes are followed by ``iterations`` timed passes, all without gradients.

    Args:
        model: Network to time; switched to eval mode.
        input_size: Shape of the random input batch.
        iterations: Number of timed forward passes; must be positive.

    Returns:
        Average wall-clock time per forward pass, in milliseconds.

    Raises:
        ValueError: If ``iterations`` is not positive.
    """
    if iterations <= 0:
        raise ValueError("iterations must be a positive integer")

    model.eval()

    # Run where the model lives, so input and weights always share a device.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')
    on_cuda = run_device.type == 'cuda'
    x = torch.randn(input_size).to(run_device)

    with torch.no_grad():
        # Warm-up
        for _ in range(10):
            _ = model(x)

        # CUDA kernels launch asynchronously, so drain the queue before
        # reading the clock. On CPU there is nothing to synchronise, and
        # calling torch.cuda.synchronize() there would raise.
        if on_cuda:
            torch.cuda.synchronize(run_device)
        start_time = time.perf_counter()

        # Measure
        for _ in range(iterations):
            _ = model(x)

        if on_cuda:
            torch.cuda.synchronize(run_device)
        end_time = time.perf_counter()

    elapsed_time = end_time - start_time
    return elapsed_time / iterations * 1000  # Convert to ms per batch

# Time one forward pass per model with the default batch shape.
teacher_time = measure_inference_time(teacher_model)
student_baseline_time = measure_inference_time(student_model_baseline)
student_distill_time = measure_inference_time(student_model_distill)

print(f"Teacher (ResNet-50) inference time: {teacher_time:.2f} ms/batch")
print(f"Student Baseline (ResNet-18) inference time: {student_baseline_time:.2f} ms/batch")
print(f"Student with Distillation (ResNet-18) inference time: {student_distill_time:.2f} ms/batch")
print(f"Speed-up: {teacher_time / student_distill_time:.2f}x")

# The bar labels are defined here instead of being read from the global
# `models`. That name is the torchvision.models module unless an earlier
# plotting cell has rebound it to a label list, in which case plt.bar fails.
model_labels = [
    'Teacher (ResNet-50)',
    'Student Baseline (ResNet-18)',
    'Student with Distillation (ResNet-18)',
]

plt.figure(figsize=(8, 6))
inference_times = [teacher_time, student_baseline_time, student_distill_time]
plt.bar(model_labels, inference_times, color=['blue', 'orange', 'green'])
plt.title('Inference Time Comparison')
plt.ylabel('Time per batch (ms)')
plt.xticks(rotation=15, ha='right')
for i, v in enumerate(inference_times):
    plt.text(i, v + 0.2, f"{v:.2f} ms", ha='center')
plt.tight_layout()
plt.show()