In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import tqdm

# Define temperature and alpha for distillation
TEMPERATURE = 3.0  # divisor applied to both logit sets before the softened KL term
ALPHA = 0.9  # weight of the distillation term; (1 - ALPHA) weights the hard-label loss

# Pretrained teacher weights. NOTE: despite the "101" in the name these are the
# ResNet-152 weights; the name is kept so existing references keep working.
weights101 = models.ResNet152_Weights.DEFAULT
# Preprocessing preset bundled with the teacher weights. It is bound to
# `transform` (singular) so it does not shadow the `torchvision.transforms`
# module imported at the top of the notebook.
transform = weights101.transforms()

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load Teacher and Student Models
teacher_model = models.resnet152(
    weights=weights101
).to(device)

print("======TEACHER======")
print(teacher_model)

# The student starts from random initialization (no pretrained weights).
student_model = models.resnet50(weights=None).to(device)

print("======STUDENT======")
print(student_model)

# Freeze teacher model parameters and keep it in inference mode so that
# batch-norm layers use their stored running statistics.
teacher_model.requires_grad_(False)
teacher_model.eval()

# Define loss functions
criterion = nn.CrossEntropyLoss()  # hard-label term
distillation_loss = nn.KLDivLoss(reduction='batchmean')  # soft-label term; expects log-probabilities as input

# Define optimizer for the student model
optimizer = optim.Adam(student_model.parameters(), lr=0.001)

# Transform and load your dataset
train_dataset = datasets.ImageFolder(
    '/kaggle/input/imagenetmini-1000/imagenet-mini/train',
    transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy (in percent) for every k listed in ``topk``.

    ``output`` holds per-class scores with classes along dimension 1 and
    ``target`` holds the true class index of each sample. The result is a
    list with one single-element tensor per requested k, in the same order
    as ``topk``.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # Indices of the `largest_k` best-scoring classes per sample,
        # transposed so that row r holds every sample's rank-r prediction.
        top_indices = output.topk(largest_k, 1, True, True).indices.t()
        expected = target.view(1, -1).expand_as(top_indices)
        hits = top_indices.eq(expected)

        # A sample counts as correct for k if its label appears in the first k rows.
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_samples)
            for k in topk
        ]

# Knowledge Distillation Function
def distillation_loss_fn(student_logits, teacher_logits, labels):
    """Combine the hard-label loss with the temperature-softened distillation loss.

    The hard term is ``criterion`` applied to the raw student logits and the
    labels. The soft term is ``distillation_loss`` between the student's
    log-probabilities and the teacher's probabilities, both computed from
    logits divided by ``TEMPERATURE``, then multiplied by ``TEMPERATURE ** 2``.
    The result is ``ALPHA * soft + (1 - ALPHA) * hard``.
    """
    functional = torch.nn.functional

    # Supervised term against the ground-truth labels
    hard_loss = criterion(student_logits, labels)

    # Softened distributions along the class dimension (dim=1)
    student_log_probs = functional.log_softmax(student_logits / TEMPERATURE, dim=1)
    teacher_probs = functional.softmax(teacher_logits / TEMPERATURE, dim=1)
    soft_loss = distillation_loss(student_log_probs, teacher_probs) * (TEMPERATURE ** 2)

    # Blend the two terms
    return ALPHA * soft_loss + (1 - ALPHA) * hard_loss

# Training Loop
def train_student(teacher_model, student_model, dataloader, optimizer, epochs=10):
    """Train ``student_model`` against ``teacher_model`` with knowledge distillation.

    For every batch the teacher's logits are computed without gradients, the
    student's logits are computed normally, and the loss returned by
    ``distillation_loss_fn`` is back-propagated. Per-epoch loss and top-1 /
    top-5 accuracy are averaged over the number of samples actually seen.
    Whenever the epoch loss improves, a checkpoint is written to
    ``best_model.pth``.

    Args:
        teacher_model: Model providing the target logits; switched to eval mode.
        student_model: Model being trained; switched to train mode.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called once per
            batch; expected to hold the student's parameters.
        epochs: Number of passes over ``dataloader``.

    Returns:
        The same ``student_model`` object, in its state after the last epoch
        (not necessarily the best checkpoint).

    Raises:
        ValueError: If an epoch finishes without ``dataloader`` yielding any sample.
    """
    student_model.train()
    teacher_model.eval()

    best_loss = float('inf')

    # Per-epoch metrics as plain floats; stored in the checkpoint below.
    history = {
        'loss': [],
        'acc1': [],
        'acc5': []
    }

    for epoch in range(epochs):
        loss_sum = 0.0
        acc1_sum = 0.0
        acc5_sum = 0.0
        samples_seen = 0
        pbar = tqdm.tqdm(dataloader, desc=f'Epoch {epoch+1}/{epochs}')

        for images, labels in pbar:
            inputs, labels = images.to(device), labels.to(device)

            # Teacher forward pass: no graph is built, so no teacher gradients
            with torch.no_grad():
                teacher_logits = teacher_model(inputs)

            # Student forward pass
            student_logits = student_model(inputs)

            # Compute loss
            loss = distillation_loss_fn(student_logits, teacher_logits, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            acc1, acc5 = accuracy(student_logits, labels, (1, 5))
            batch_loss = loss.item()
            batch_acc1 = acc1.item()
            batch_acc5 = acc5.item()

            # Weight every batch by its size so a smaller final batch does
            # not skew the epoch averages.
            batch_size = images.size(0)
            samples_seen += batch_size
            loss_sum += batch_loss * batch_size
            acc1_sum += batch_acc1 * batch_size
            acc5_sum += batch_acc5 * batch_size

            # Update progress bar
            pbar.set_postfix({
                'loss': f'{batch_loss:.4f}',
                'acc1': f'{batch_acc1:.2f}%',
                'acc5': f'{batch_acc5:.2f}%'
            })

        if samples_seen == 0:
            raise ValueError('dataloader yielded no samples')

        # The sums are sample-weighted, so divide by the sample count
        # (not by the number of batches).
        epoch_loss = loss_sum / samples_seen
        epoch_acc1 = acc1_sum / samples_seen
        epoch_acc5 = acc5_sum / samples_seen

        # Update history
        history['loss'].append(epoch_loss)
        history['acc1'].append(epoch_acc1)
        history['acc5'].append(epoch_acc5)

        # Print epoch summary
        print(f'\nEpoch {epoch+1}/{epochs}:')
        print(f'Loss: {epoch_loss:.4f}')
        print(f'Accuracy@1: {epoch_acc1:.2f}%')
        print(f'Accuracy@5: {epoch_acc5:.2f}%')

        # Save best model (lowest training loss so far)
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': student_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_loss,
                'accuracy': epoch_acc1,
                'history': history
            }, 'best_model.pth')

    return student_model

# Train the student model: runs the distillation loop over train_loader for
# 100 epochs. train_student returns the student_model object it was given,
# so trained_student_model refers to that same model.
trained_student_model = train_student(teacher_model, student_model, train_loader, optimizer, epochs=100)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

Epoch 1/100:   2%|▏         | 13/543 [00:16<10:13,  1.16s/it, loss=1.1794, acc1=0.00%, acc5=1.56%]

In [None]:
from torch.utils.data import DataLoader
from torchvision import models
import torch

# Function to evaluate model accuracy
def evaluate_model(model, dataloader, device):
    """Measure top-1 and top-5 accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode and run without gradient tracking on
    every ``(inputs, labels)`` batch after both are moved to ``device``.

    Returns:
        A ``(top1_accuracy, top5_accuracy)`` tuple, each a percentage of the
        total number of labels seen.
    """
    model.eval()
    hits_top1, hits_top5, seen = 0, 0, 0

    with torch.no_grad():
        for batch_inputs, batch_labels in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            scores = model(batch_inputs)

            # Top-1: the single best-scoring class must equal the label
            best_class = scores.max(1).indices
            hits_top1 += (best_class == batch_labels).sum().item()

            # Top-5: the label must appear among the five best-scoring classes
            best_five = scores.topk(5, dim=1).indices
            label_column = batch_labels.view(-1, 1)
            hits_top5 += (best_five == label_column).sum().item()

            seen += batch_labels.size(0)

    return 100 * hits_top1 / seen, 100 * hits_top5 / seen

# DataLoader for the validation set (adjust the path accordingly).
# The preprocessing preset is taken directly from the teacher's ResNet-152
# weights, so this cell does not depend on the name the training cell used
# for its preprocessing object.
val_transform = models.ResNet152_Weights.DEFAULT.transforms()
val_dataset = datasets.ImageFolder('/kaggle/input/imagenetmini-1000/imagenet-mini/val', transform=val_transform)
val_loader = DataLoader(val_dataset, batch_size=40, shuffle=False)

top1_teacher, top5_teacher = evaluate_model(teacher_model, val_loader, device)
print(f"Teacher Model - Top-1 Accuracy: {top1_teacher:.2f}%, Top-5 Accuracy: {top5_teacher:.2f}%")

# Baseline without distillation: an ImageNet-pretrained ResNet-50 evaluated
# as-is (nothing is trained here). IMAGENET1K_V1 is the weight set that the
# deprecated `pretrained=True` flag used to select.
student_model_wokd = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1).to(device)
top1_student_baseline, top5_student_baseline = evaluate_model(student_model_wokd, val_loader, device)
print(f"Student Model Baseline - Top-1 Accuracy: {top1_student_baseline:.2f}%, Top-5 Accuracy: {top5_student_baseline:.2f}%")

# Evaluate the student produced by the distillation training cell.
# trained_student_model is the model as it stands after the last epoch.
top1_student_distilled, top5_student_distilled = evaluate_model(trained_student_model, val_loader, device)
print(f"Student Model with Distillation - Top-1 Accuracy: {top1_student_distilled:.2f}%, Top-5 Accuracy: {top5_student_distilled:.2f}%")
