In [1]:
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader, Subset, Dataset
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import copy
import os
from torchinfo import summary

In [3]:
def set_seed(seed=0):
    """Sets the seed for reproducibility."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

os.makedirs('models_resnet_teacher', exist_ok=True)
os.makedirs('results_resnet_teacher', exist_ok=True)

Using device: cuda


In [5]:
MNIST_IMG_SIZE = 28
MNIST_INPUT_DIM = MNIST_IMG_SIZE * MNIST_IMG_SIZE
MNIST_OUTPUT_DIM = 10
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)

STUDENT_HIDDEN = 400
STUDENT_DROPOUT_INPUT = 0.0
STUDENT_DROPOUT_HIDDEN = 0.0

BATCH_SIZE = 128

LR_RESNET = 0.001
LR_MLP = 0.01
WEIGHT_DECAY_RESNET = 1e-4
WEIGHT_DECAY_MLP = 1e-5
MOMENTUM_MLP = 0.9
OPTIMIZER_MLP = 'SGD'
EPOCHS_TEACHER_RESNET = 30
EPOCHS_STUDENT_MLP = 30
PRINT_FREQ = 100

DEFAULT_T = 4
ALPHA = 0.5

## 3. Data Loading and Preprocessing


In [6]:


transform_basic = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD)
])

train_val_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_basic)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_basic)
num_total = len(train_val_dataset)
num_train = int(0.95 * num_total)
num_val = num_total - num_train
train_indices, val_indices = torch.utils.data.random_split(range(num_total), [num_train, num_val])

train_dataset = Subset(train_val_dataset, train_indices.indices)
val_dataset = Subset(train_val_dataset, val_indices.indices)

print(f"Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(val_dataset)}")
print(f"Test set size: {len(test_dataset)}")

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

Training set size: 57000
Validation set size: 3000
Test set size: 10000


In [7]:

class ResNet18_MNIST(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.resnet = models.resnet18(weights=None)
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.resnet.maxpool = nn.Identity()
        num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        if x.dim() == 3:
            x = x.unsqueeze(1)
        elif x.dim() == 2:
             x = x.view(x.size(0), 1, MNIST_IMG_SIZE, MNIST_IMG_SIZE)

        return self.resnet(x)

class StudentNet(nn.Module):
    def __init__(self, input_dim=MNIST_INPUT_DIM, hidden_units=STUDENT_HIDDEN, output_dim=MNIST_OUTPUT_DIM,
                 dropout_input=STUDENT_DROPOUT_INPUT, dropout_hidden=STUDENT_DROPOUT_HIDDEN):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_units)
        self.fc2 = nn.Linear(hidden_units, output_dim)
        self.dropout_input_rate = dropout_input
        self.dropout_hidden_rate = dropout_hidden

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.dropout(x, p=self.dropout_input_rate, training=self.training)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=self.dropout_hidden_rate, training=self.training)
        x = self.fc2(x)
        return x

def evaluate_accuracy(model: nn.Module, dataloader: DataLoader, device: torch.device):
    """Evaluates model accuracy on the provided dataloader."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    return accuracy

def distillation_loss_fn(student_logits, teacher_logits, hard_labels, temp, alpha):
    """Computes the distillation loss."""
    if alpha < 0 or alpha > 1:
        raise ValueError("alpha must be between 0 and 1")
    soft_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(student_logits/temp, dim=1),
                                                     F.softmax(teacher_logits/temp, dim=1).detach()) * (temp * temp)
    hard_loss = nn.CrossEntropyLoss()(student_logits, hard_labels)
    total_loss = alpha * soft_loss + (1 - alpha) * hard_loss
    return total_loss


In [8]:

teacher_model_path = 'models_resnet_teacher/resnet18_teacher_final.pth'
results_teacher_path = 'results_resnet_teacher/resnet18_teacher_results.npy'
teacher_model = ResNet18_MNIST().to(device)

print("--- ResNet-18 Teacher Model Summary ---")

summary(teacher_model, input_size=(BATCH_SIZE, 1, MNIST_IMG_SIZE, MNIST_IMG_SIZE))

if os.path.exists(teacher_model_path):
    print(f"\nLoading pre-trained ResNet-18 teacher model from {teacher_model_path}")
    teacher_model.load_state_dict(torch.load(teacher_model_path, map_location=device))
    teacher_train_results = np.load(results_teacher_path, allow_pickle=True).item()
else:
    print("\n--- Training ResNet-18 Teacher Model ---")
    teacher_model.to(device)
    optimizer = optim.Adam(teacher_model.parameters(), lr=LR_RESNET, weight_decay=WEIGHT_DECAY_RESNET)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS_TEACHER_RESNET)

    criterion = nn.CrossEntropyLoss()
    teacher_train_losses = []
    teacher_train_accs = []
    teacher_val_accs = []
    best_val_acc = 0.0

    for epoch in range(EPOCHS_TEACHER_RESNET):
        teacher_model.train()
        batch_losses = []
        batch_accs = []
        progress_bar = tqdm(train_loader, desc=f"ResNet Epoch {epoch+1}/{EPOCHS_TEACHER_RESNET}", leave=False)

        for i, (inputs, labels) in enumerate(progress_bar):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = teacher_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(outputs.data, 1)
            accuracy = (predicted == labels).float().mean().item()
            batch_losses.append(loss.item())
            batch_accs.append(accuracy)

            if (i + 1) % PRINT_FREQ == 0:
                 progress_bar.set_postfix(loss=f"{np.mean(batch_losses[-PRINT_FREQ:]):.3f}", acc=f"{np.mean(batch_accs[-PRINT_FREQ:]):.3f}")

        epoch_train_loss = np.mean(batch_losses)
        epoch_train_acc = np.mean(batch_accs)
        teacher_train_losses.append(epoch_train_loss)
        teacher_train_accs.append(epoch_train_acc)

        epoch_val_acc = evaluate_accuracy(teacher_model, val_loader, device)
        teacher_val_accs.append(epoch_val_acc)
        print(f"Epoch {epoch+1}/{EPOCHS_TEACHER_RESNET} - Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}, Val Acc: {epoch_val_acc:.2f}%")
        if epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            print(f"  Saving best model with Val Acc: {best_val_acc:.2f}%")
            torch.save(teacher_model.state_dict(), teacher_model_path)

        scheduler.step()

    print("--- Finished ResNet-18 Teacher Training ---")
    if os.path.exists(teacher_model_path):
         print(f"Loading best teacher model weights from {teacher_model_path}")
         teacher_model.load_state_dict(torch.load(teacher_model_path, map_location=device))
    teacher_train_results = {'train_loss': teacher_train_losses, 'train_acc': teacher_train_accs, 'val_acc': teacher_val_accs}
    np.save(results_teacher_path, teacher_train_results)
teacher_test_accuracy = evaluate_accuracy(teacher_model, test_loader, device)
print(f"\nFinal ResNet-18 Teacher Model Test Accuracy: {teacher_test_accuracy:.2f}%")

--- ResNet-18 Teacher Model Summary ---

Loading pre-trained ResNet-18 teacher model from models_resnet_teacher/resnet18_teacher_final.pth

Final ResNet-18 Teacher Model Test Accuracy: 99.38%


## 7. Train MLP Student Baseline (Against Hard Labels)

In [10]:
student_baseline_model_path = 'models_resnet_teacher/mlp_student_baseline_final.pth'
results_student_baseline_path = 'results_resnet_teacher/mlp_student_baseline_results.npy'
student_baseline = StudentNet().to(device) # Fresh student model

print("\n--- MLP Student Model Summary ---")
summary(student_baseline, input_size=(BATCH_SIZE, MNIST_INPUT_DIM))

if os.path.exists(student_baseline_model_path):
    print(f"\nLoading pre-trained MLP student baseline model from {student_baseline_model_path}")
    student_baseline.load_state_dict(torch.load(student_baseline_model_path, map_location=device))
    student_baseline_train_results = np.load(results_student_baseline_path, allow_pickle=True).item()
else:
    print("\n--- Training MLP Student Model (Baseline) ---")
    student_baseline.to(device)
    optimizer = optim.SGD(student_baseline.parameters(), lr=LR_MLP, momentum=MOMENTUM_MLP, weight_decay=WEIGHT_DECAY_MLP)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

    criterion = nn.CrossEntropyLoss()
    student_baseline_train_losses = []
    student_baseline_train_accs = []
    student_baseline_val_accs = []
    best_val_acc = 0.0

    for epoch in range(EPOCHS_STUDENT_MLP):
        student_baseline.train()
        batch_losses = []
        batch_accs = []
        progress_bar = tqdm(train_loader, desc=f"MLP BL Epoch {epoch+1}/{EPOCHS_STUDENT_MLP}", leave=False)

        for i, (inputs, labels) in enumerate(progress_bar):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = student_baseline(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(outputs.data, 1)
            accuracy = (predicted == labels).float().mean().item()
            batch_losses.append(loss.item())
            batch_accs.append(accuracy)

            if (i + 1) % PRINT_FREQ == 0:
                 progress_bar.set_postfix(loss=f"{np.mean(batch_losses[-PRINT_FREQ:]):.3f}", acc=f"{np.mean(batch_accs[-PRINT_FREQ:]):.3f}")

        epoch_train_loss = np.mean(batch_losses)
        epoch_train_acc = np.mean(batch_accs)
        student_baseline_train_losses.append(epoch_train_loss)
        student_baseline_train_accs.append(epoch_train_acc)

        epoch_val_acc = evaluate_accuracy(student_baseline, val_loader, device)
        student_baseline_val_accs.append(epoch_val_acc)
        print(f"Epoch {epoch+1}/{EPOCHS_STUDENT_MLP} - Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}, Val Acc: {epoch_val_acc:.2f}%")
        if epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            print(f"  Saving best model with Val Acc: {best_val_acc:.2f}%")
            torch.save(student_baseline.state_dict(), student_baseline_model_path)

        scheduler.step()

    print("--- Finished MLP Student Baseline Training ---")

    if os.path.exists(student_baseline_model_path):
        print(f"Loading best baseline student model weights from {student_baseline_model_path}")
        student_baseline.load_state_dict(torch.load(student_baseline_model_path, map_location=device))
    student_baseline_train_results = {'train_loss': student_baseline_train_losses, 'train_acc': student_baseline_train_accs, 'val_acc': student_baseline_val_accs}
    np.save(results_student_baseline_path, student_baseline_train_results)
student_baseline_test_accuracy = evaluate_accuracy(student_baseline, test_loader, device)
print(f"\nFinal MLP Student Model (Baseline) Test Accuracy: {student_baseline_test_accuracy:.2f}%")


--- MLP Student Model Summary ---

--- Training MLP Student Model (Baseline) ---


MLP BL Epoch 1/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 1/30 - Train Loss: 0.3486, Train Acc: 0.8997, Val Acc: 94.17%
  Saving best model with Val Acc: 94.17%


MLP BL Epoch 2/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 2/30 - Train Loss: 0.1543, Train Acc: 0.9557, Val Acc: 96.03%
  Saving best model with Val Acc: 96.03%


MLP BL Epoch 3/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 3/30 - Train Loss: 0.1099, Train Acc: 0.9695, Val Acc: 96.83%
  Saving best model with Val Acc: 96.83%


MLP BL Epoch 4/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 4/30 - Train Loss: 0.0855, Train Acc: 0.9759, Val Acc: 96.73%


MLP BL Epoch 5/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 5/30 - Train Loss: 0.0701, Train Acc: 0.9804, Val Acc: 97.27%
  Saving best model with Val Acc: 97.27%


MLP BL Epoch 6/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 6/30 - Train Loss: 0.0590, Train Acc: 0.9839, Val Acc: 97.53%
  Saving best model with Val Acc: 97.53%


MLP BL Epoch 7/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 7/30 - Train Loss: 0.0503, Train Acc: 0.9871, Val Acc: 97.93%
  Saving best model with Val Acc: 97.93%


MLP BL Epoch 8/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 8/30 - Train Loss: 0.0446, Train Acc: 0.9885, Val Acc: 97.73%


MLP BL Epoch 9/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 9/30 - Train Loss: 0.0388, Train Acc: 0.9902, Val Acc: 97.87%


MLP BL Epoch 10/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 10/30 - Train Loss: 0.0346, Train Acc: 0.9916, Val Acc: 98.10%
  Saving best model with Val Acc: 98.10%


MLP BL Epoch 11/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 11/30 - Train Loss: 0.0307, Train Acc: 0.9929, Val Acc: 97.97%


MLP BL Epoch 12/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 12/30 - Train Loss: 0.0283, Train Acc: 0.9940, Val Acc: 97.90%


MLP BL Epoch 13/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 13/30 - Train Loss: 0.0257, Train Acc: 0.9946, Val Acc: 98.10%


MLP BL Epoch 14/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 14/30 - Train Loss: 0.0235, Train Acc: 0.9954, Val Acc: 98.27%
  Saving best model with Val Acc: 98.27%


MLP BL Epoch 15/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 15/30 - Train Loss: 0.0215, Train Acc: 0.9963, Val Acc: 98.30%
  Saving best model with Val Acc: 98.30%


MLP BL Epoch 16/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 16/30 - Train Loss: 0.0201, Train Acc: 0.9964, Val Acc: 98.23%


MLP BL Epoch 17/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 17/30 - Train Loss: 0.0188, Train Acc: 0.9970, Val Acc: 98.27%


MLP BL Epoch 18/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 18/30 - Train Loss: 0.0176, Train Acc: 0.9974, Val Acc: 98.33%
  Saving best model with Val Acc: 98.33%


MLP BL Epoch 19/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 19/30 - Train Loss: 0.0166, Train Acc: 0.9977, Val Acc: 98.30%


MLP BL Epoch 20/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 20/30 - Train Loss: 0.0158, Train Acc: 0.9982, Val Acc: 98.30%


MLP BL Epoch 21/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 21/30 - Train Loss: 0.0151, Train Acc: 0.9981, Val Acc: 98.33%


MLP BL Epoch 22/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 22/30 - Train Loss: 0.0142, Train Acc: 0.9985, Val Acc: 98.27%


MLP BL Epoch 23/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 23/30 - Train Loss: 0.0135, Train Acc: 0.9986, Val Acc: 98.30%


MLP BL Epoch 24/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 24/30 - Train Loss: 0.0130, Train Acc: 0.9987, Val Acc: 98.30%


MLP BL Epoch 25/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 25/30 - Train Loss: 0.0125, Train Acc: 0.9988, Val Acc: 98.30%


MLP BL Epoch 26/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 26/30 - Train Loss: 0.0121, Train Acc: 0.9988, Val Acc: 98.40%
  Saving best model with Val Acc: 98.40%


MLP BL Epoch 27/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 27/30 - Train Loss: 0.0117, Train Acc: 0.9989, Val Acc: 98.27%


MLP BL Epoch 28/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 28/30 - Train Loss: 0.0114, Train Acc: 0.9989, Val Acc: 98.37%


MLP BL Epoch 29/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 29/30 - Train Loss: 0.0109, Train Acc: 0.9991, Val Acc: 98.30%


MLP BL Epoch 30/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 30/30 - Train Loss: 0.0107, Train Acc: 0.9991, Val Acc: 98.30%
--- Finished MLP Student Baseline Training ---
Loading best baseline student model weights from models_resnet_teacher/mlp_student_baseline_final.pth

Final MLP Student Model (Baseline) Test Accuracy: 98.14%


## 8. Knowledge Distillation (ResNet Teacher -> MLP Student)

In [11]:
print(f"\n--- Starting Distillation (ResNet Teacher -> MLP Student, T={DEFAULT_T}, alpha={ALPHA}) ---")
student_distilled = StudentNet().to(device)
teacher_model_kd = ResNet18_MNIST().to(device)
teacher_model_kd.load_state_dict(torch.load(teacher_model_path, map_location=device))
teacher_model_kd.eval()
optimizer = optim.SGD(student_distilled.parameters(), lr=LR_MLP, momentum=MOMENTUM_MLP, weight_decay=WEIGHT_DECAY_MLP)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

student_distilled_train_losses = []
student_distilled_train_accs = []
student_distilled_val_accs = []
best_val_acc = 0.0
distilled_model_path = f'models_resnet_teacher/mlp_student_distilled_T{DEFAULT_T}_alpha{ALPHA}_final.pth'
for epoch in range(EPOCHS_STUDENT_MLP):
    student_distilled.train()
    batch_losses = []
    batch_accs = []
    progress_bar = tqdm(train_loader, desc=f"Distill Epoch {epoch+1}/{EPOCHS_STUDENT_MLP}", leave=False)

    for i, (inputs, labels) in enumerate(progress_bar):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        with torch.no_grad():
            teacher_logits = teacher_model_kd(inputs)
        student_logits = student_distilled(inputs)

        loss = distillation_loss_fn(student_logits, teacher_logits, labels, DEFAULT_T, ALPHA)

        loss.backward()
        optimizer.step()

        _, predicted = torch.max(student_logits.data, 1)
        accuracy = (predicted == labels).float().mean().item()
        batch_losses.append(loss.item())
        batch_accs.append(accuracy)

        if (i + 1) % PRINT_FREQ == 0:
            progress_bar.set_postfix(loss=f"{np.mean(batch_losses[-PRINT_FREQ:]):.3f}", acc=f"{np.mean(batch_accs[-PRINT_FREQ:]):.3f}")

    epoch_train_loss = np.mean(batch_losses)
    epoch_train_acc = np.mean(batch_accs)
    student_distilled_train_losses.append(epoch_train_loss)
    student_distilled_train_accs.append(epoch_train_acc)

    epoch_val_acc = evaluate_accuracy(student_distilled, val_loader, device)
    student_distilled_val_accs.append(epoch_val_acc)
    print(f"Epoch {epoch+1}/{EPOCHS_STUDENT_MLP} - Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}, Val Acc: {epoch_val_acc:.2f}%")

    if epoch_val_acc > best_val_acc:
        best_val_acc = epoch_val_acc
        print(f"  Saving best model with Val Acc: {best_val_acc:.2f}%")
        torch.save(student_distilled.state_dict(), distilled_model_path)

    scheduler.step()

print("--- Finished Distillation Training ---")
if os.path.exists(distilled_model_path):
    print(f"Loading best distilled student model weights from {distilled_model_path}")
    student_distilled.load_state_dict(torch.load(distilled_model_path, map_location=device))

distilled_train_results = {'train_loss': student_distilled_train_losses, 'train_acc': student_distilled_train_accs, 'val_acc': student_distilled_val_accs}
np.save(f'results_resnet_teacher/mlp_student_distilled_T{DEFAULT_T}_alpha{ALPHA}_results.npy', distilled_train_results)


student_distilled_test_accuracy = evaluate_accuracy(student_distilled, test_loader, device)

print(f"\n--- Distillation Results (ResNet Teacher -> MLP Student, T={DEFAULT_T}, alpha={ALPHA}) ---")
print(f"ResNet-18 Teacher Test Accuracy:     {teacher_test_accuracy:.2f}%")
print(f"MLP Student Baseline Test Accuracy:  {student_baseline_test_accuracy:.2f}%")
print(f"MLP Student Distilled Test Accuracy: {student_distilled_test_accuracy:.2f}%")
improvement = student_distilled_test_accuracy - student_baseline_test_accuracy
print(f"Improvement over Baseline:           {improvement:+.2f}%")


--- Starting Distillation (ResNet Teacher -> MLP Student, T=4, alpha=0.5) ---


Distill Epoch 1/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 1/30 - Train Loss: 1.5765, Train Acc: 0.9179, Val Acc: 96.27%
  Saving best model with Val Acc: 96.27%


Distill Epoch 2/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 2/30 - Train Loss: 0.6217, Train Acc: 0.9695, Val Acc: 97.27%
  Saving best model with Val Acc: 97.27%


Distill Epoch 3/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 3/30 - Train Loss: 0.4688, Train Acc: 0.9790, Val Acc: 97.90%
  Saving best model with Val Acc: 97.90%


Distill Epoch 4/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 4/30 - Train Loss: 0.3939, Train Acc: 0.9833, Val Acc: 97.90%


Distill Epoch 5/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 5/30 - Train Loss: 0.3440, Train Acc: 0.9865, Val Acc: 98.27%
  Saving best model with Val Acc: 98.27%


Distill Epoch 6/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 6/30 - Train Loss: 0.3087, Train Acc: 0.9885, Val Acc: 98.37%
  Saving best model with Val Acc: 98.37%


Distill Epoch 7/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 7/30 - Train Loss: 0.2835, Train Acc: 0.9900, Val Acc: 98.23%


Distill Epoch 8/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 8/30 - Train Loss: 0.2631, Train Acc: 0.9918, Val Acc: 98.50%
  Saving best model with Val Acc: 98.50%


Distill Epoch 9/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 9/30 - Train Loss: 0.2466, Train Acc: 0.9927, Val Acc: 98.53%
  Saving best model with Val Acc: 98.53%


Distill Epoch 10/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 10/30 - Train Loss: 0.2339, Train Acc: 0.9934, Val Acc: 98.57%
  Saving best model with Val Acc: 98.57%


Distill Epoch 11/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 11/30 - Train Loss: 0.2222, Train Acc: 0.9941, Val Acc: 98.40%


Distill Epoch 12/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 12/30 - Train Loss: 0.2129, Train Acc: 0.9949, Val Acc: 98.53%


Distill Epoch 13/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 13/30 - Train Loss: 0.2044, Train Acc: 0.9953, Val Acc: 98.50%


Distill Epoch 14/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 14/30 - Train Loss: 0.1969, Train Acc: 0.9957, Val Acc: 98.57%


Distill Epoch 15/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 15/30 - Train Loss: 0.1909, Train Acc: 0.9960, Val Acc: 98.43%


Distill Epoch 16/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 16/30 - Train Loss: 0.1857, Train Acc: 0.9964, Val Acc: 98.53%


Distill Epoch 17/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 17/30 - Train Loss: 0.1804, Train Acc: 0.9964, Val Acc: 98.50%


Distill Epoch 18/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 18/30 - Train Loss: 0.1758, Train Acc: 0.9967, Val Acc: 98.57%


Distill Epoch 19/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 19/30 - Train Loss: 0.1719, Train Acc: 0.9969, Val Acc: 98.50%


Distill Epoch 20/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 20/30 - Train Loss: 0.1683, Train Acc: 0.9970, Val Acc: 98.47%


Distill Epoch 21/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 21/30 - Train Loss: 0.1650, Train Acc: 0.9972, Val Acc: 98.47%


Distill Epoch 22/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 22/30 - Train Loss: 0.1623, Train Acc: 0.9973, Val Acc: 98.57%


Distill Epoch 23/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 23/30 - Train Loss: 0.1595, Train Acc: 0.9975, Val Acc: 98.53%


Distill Epoch 24/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 24/30 - Train Loss: 0.1570, Train Acc: 0.9974, Val Acc: 98.50%


Distill Epoch 25/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 25/30 - Train Loss: 0.1548, Train Acc: 0.9977, Val Acc: 98.47%


Distill Epoch 26/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 26/30 - Train Loss: 0.1529, Train Acc: 0.9976, Val Acc: 98.57%


Distill Epoch 27/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 27/30 - Train Loss: 0.1508, Train Acc: 0.9978, Val Acc: 98.50%


Distill Epoch 28/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 28/30 - Train Loss: 0.1490, Train Acc: 0.9979, Val Acc: 98.50%


Distill Epoch 29/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 29/30 - Train Loss: 0.1475, Train Acc: 0.9979, Val Acc: 98.50%


Distill Epoch 30/30:   0%|          | 0/446 [00:00<?, ?it/s]

Epoch 30/30 - Train Loss: 0.1459, Train Acc: 0.9981, Val Acc: 98.53%
--- Finished Distillation Training ---
Loading best distilled student model weights from models_resnet_teacher/mlp_student_distilled_T4_alpha0.5_final.pth

--- Distillation Results (ResNet Teacher -> MLP Student, T=4, alpha=0.5) ---
ResNet-18 Teacher Test Accuracy:     99.38%
MLP Student Baseline Test Accuracy:  98.14%
MLP Student Distilled Test Accuracy: 98.33%
Improvement over Baseline:           +0.19%
