In [1]:
import torch
from torch import nn
from src.utils import load_model, save_model
from src.data_loader import get_cifar10_loader
from src.train import train_model_kd, loss_fn_kd
from src.model import ResNet, BasicBlock
from src.evaluate import evaluate
from src.train import KDParams


In [2]:
# Parameters
# Prefer the "mps" backend this notebook was written for, but fall back to
# CUDA and then CPU so the notebook still runs on machines where "mps" is not
# available instead of failing at the first allocation on that device.
# NOTE(review): assumes load_model places checkpoint tensors on `device` — confirm.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

teacher_model_path = "models/resnet110_baseline_30_mps.pth"
student_model_path = "models/pruned_45-30_resnet110_mps.pth"

# Seed torch's global random number generator so repeated runs start from the
# same random state.
seed = 42
torch.manual_seed(seed)

batch_size = 128
learning_rate = 0.001
num_epochs = 10
kd_alpha = 0.7        # passed to KDParams as `alpha`
kd_temperature = 4.0  # passed to KDParams as `temperature`


In [3]:
# Load the teacher and student networks from their checkpoint files
# (teacher first, then student).
teacher_model, student_model = [
    load_model(checkpoint_path, device=device)
    for checkpoint_path in (teacher_model_path, student_model_path)
]

# Only the student's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(params=student_model.parameters(), lr=learning_rate)
# NOTE(review): `criterion` is not referenced by any other cell visible here.
criterion = nn.CrossEntropyLoss()

# Data loaders for the 'train' and 'val' splits, built in that order.
train_loader, val_loader = [
    get_cifar10_loader(split, batch_size=batch_size) for split in ("train", "val")
]

# Bundle the distillation settings for the training call.
kd_params = KDParams(alpha=kd_alpha, temperature=kd_temperature)


In [4]:
# Run knowledge-distillation training of the student against the teacher.
train_model_kd(
    # networks
    teacher_model=teacher_model,
    student_model=student_model,
    # data, optimizer and distillation loss
    train_loader=train_loader,
    optimizer=optimizer,
    loss_fn_kd=loss_fn_kd,
    kd_params=kd_params,
    # run configuration
    num_epochs=num_epochs,
    device=device,
    use_amp=False,
)


                                                                                       

In [5]:
# Validation metrics for the student after the distillation training cell; the
# call's return value is shown as this cell's output.
evaluate(student_model, val_loader, device)

Validation Accuracy: 86.62%, Avg Loss: 0.4247, Time: 3.59s


(86.62, 0.4247211040496826)

In [6]:
# Save the trained student. The epoch count in the file name is derived from
# `num_epochs` so the name stays accurate if the training length is changed;
# with num_epochs = 10 this is exactly the path that was hard-coded before.
# NOTE(review): assumes the `10` in the original file name is the number of
# distillation epochs — confirm.
kd_student_model_path = f"models/pruned_45-30_kd_{num_epochs}_resnet110_mps.pth"
save_model(student_model, kd_student_model_path)

In [7]:
# Reload the checkpoint at `student_model_path` (the same file loaded earlier in
# this notebook, before training) and rebind `student_model` to it.
# NOTE(review): this replaces the trained model held under the same name, so
# re-running the earlier evaluate/save cells after this one would act on the
# reloaded model — a distinct variable name would avoid that.
student_model = load_model(student_model_path, device=device)

In [8]:
# Evaluate whatever `student_model` currently refers to — in run order, the
# checkpoint reloaded in the preceding cell — on the validation loader.
evaluate(student_model, val_loader, device)

Validation Accuracy: 84.10%, Avg Loss: 0.5485, Time: 2.81s


(84.1, 0.5484769086837769)