In [1]:
import torch
from torch import nn
from src.utils import load_model, save_model
from src.data_loader import get_cifar10_loader
from src.train import train_model_kd, loss_fn_kd
from src.model import ResNet, BasicBlock
from src.evaluate import evaluate


In [2]:
# Parameters
# Prefer the MPS backend (the checkpoint file names below carry an `_mps`
# suffix), but fall back to CUDA or CPU so the notebook still runs on
# machines where MPS is not available instead of failing at first use.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Seed PyTorch's RNG so repeated runs start from the same random state.
# NOTE(review): whether the data loader's shuffling draws from this RNG
# depends on `get_cifar10_loader` -- confirm.
seed = 42
torch.manual_seed(seed)

# Checkpoints: the teacher that provides soft targets and the pruned student
# that is fine-tuned by distillation.
teacher_model_path = "models/resnet110_baseline_120_mps.pth"
student_model_path = "models/pruned_45-30_resnet110_mps.pth"

# Optimisation hyper-parameters.
batch_size = 128
learning_rate = 0.001
num_epochs = 10

# Knowledge-distillation hyper-parameters, forwarded to `train_model_kd`
# as `alpha` and `temperature`.
kd_alpha = 0.7
kd_temperature = 4.0


In [3]:
# Restore the teacher and student checkpoints onto the selected device.
# NOTE(review): nothing in this cell puts the teacher into eval mode or
# freezes its weights -- confirm `load_model` / `train_model_kd` take care
# of that.
teacher_model, student_model = (
    load_model(checkpoint_path, device=device)
    for checkpoint_path in (teacher_model_path, student_model_path)
)

# Only the student's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(
    params=student_model.parameters(),
    lr=learning_rate,
)
# NOTE(review): `criterion` is not among the arguments passed to
# `train_model_kd` in this notebook -- confirm whether it is still needed.
criterion = nn.CrossEntropyLoss()

# CIFAR-10 loaders for the 'train' and 'val' splits, sharing one batch size.
train_loader, val_loader = (
    get_cifar10_loader(split_name, batch_size=batch_size)
    for split_name in ("train", "val")
)


In [None]:
# Fine-tune the student with knowledge distillation from the teacher.
train_model_kd(
    # models
    teacher_model=teacher_model,
    student_model=student_model,
    # data and optimisation
    train_loader=train_loader,
    optimizer=optimizer,
    num_epochs=num_epochs,
    device=device,
    # distillation settings from the parameters cell
    alpha=kd_alpha,
    temperature=kd_temperature,
)


Epoch 2/10:   2%|▏         | 8/391 [00:03<03:06,  2.05it/s, acc=94.63%, loss=1.1864]  

In [5]:
# Score the student on the validation loader. Per the recorded output this
# prints accuracy, average loss and elapsed time, and the call's value is a
# 3-tuple of those numbers (shown as the cell result).
evaluate(student_model, val_loader, device)

Validation Accuracy: 88.47%, Avg Loss: 0.6132, Time: 3.57s


(88.47, 0.6131786368846893, 3.5690488815307617)

In [6]:
# Persist the distilled student. The epoch count in the file name is derived
# from `num_epochs` so the checkpoint label cannot drift from the setting
# actually used (with num_epochs = 10 the name is unchanged).
# NOTE(review): unlike the checkpoints loaded earlier, this path has no
# "models/" prefix -- confirm `save_model` adds the directory itself.
save_model(student_model, f"pruned_45-30_kd_{num_epochs}_resnet110_mps.pth")