In [1]:
import os
import sys
project_dir = os.path.dirname(os.getcwd())
sys.path.append(project_dir)

# Data Ingestion

In [2]:
from data.mnist import get_mnist_pipeline

train_loader, val_loader, test_loader = get_mnist_pipeline(batch_size=32)
for i, data in enumerate(train_loader, 0):
    inputs, labels = data
    if i == 0:
        print("Shape of training batch inputs:", inputs.shape)
        print("Shape of training batch labels:", labels.shape)
        break
for i, data in enumerate(val_loader, 0):
    inputs, labels = data
    if i == 0:
        print("Shape of testing batch inputs:", inputs.shape)
        print("Shape of testing batch labels:", labels.shape)
        break
print("MNIST dataset loaded into PyTorch DataLoaders.")

Shape of training batch inputs: torch.Size([32, 1, 28, 28])
Shape of training batch labels: torch.Size([32])
Shape of testing batch inputs: torch.Size([32, 1, 28, 28])
Shape of testing batch labels: torch.Size([32])
MNIST dataset loaded into PyTorch DataLoaders.


# Model

In [3]:
import torch
import torch.optim as optim
from models.baseline import BasicMLP
from utils.summary import get_model_flops

teacher = BasicMLP(hidden_size=1200)
student = BasicMLP(hidden_size=400)

sample = torch.randn(1, 1, 28, 28)

with torch.no_grad():
    pred = teacher(sample)
    print("Teacher model output shape:", pred.shape)
    # print("Teacher model FLOPs:", get_model_flops(teacher, sample.shape[1:]))

Teacher model output shape: torch.Size([1, 10])


# Train Teacher

In [4]:
import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np


def train_model(train_loader: DataLoader, model: nn.Module, criterion: nn.Module, optimizer: nn.Module,
                scheduler: nn.Module=None, device: str='cpu') -> list:
    """
    Train the model for one epoch.

    Args:
        model: The PyTorch model to train.
        train_loader: DataLoader for the training data.
        optimizer: Optimizer for updating model parameters.
        criterion: Loss function.
        device: Device to run the training on ('cpu' or 'cuda').

    Returns:
        list: Collection of train losses.
    """
    model.train()
    epoch_losses = []
    for inputs, targets in tqdm.tqdm(train_loader, desc='training...', file=sys.stdout):
        inputs = inputs.to(device)
        targets = targets.to(device)
        preds = model(inputs)
        loss = criterion(preds, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())
        if scheduler and isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(loss.item())
        elif scheduler:
            scheduler.step()
    return epoch_losses


def evaluate_model(val_loader: DataLoader, model: nn.Module, criterion: nn.Module, device: str='cpu') -> list:
    """
    Evaluate the model on validation data.

    Args:
        model: The PyTorch model to evaluate.
        val_loader: DataLoader for the validation data.
        criterion: Loss function.
        device: Device to run the evaluation on ('cpu' or 'cuda').

    Returns:
        list: Collection of metrics.
    """
    model.eval()
    epoch_metrics = []
    with torch.no_grad():
        for inputs, targets in tqdm.tqdm(val_loader, desc='evaluating...', file=sys.stdout):
            inputs = inputs.to(device)
            targets = targets.to(device)
            preds = model(inputs)
            loss = criterion(preds, targets)
            epoch_metrics.append(loss.item())
    return epoch_metrics

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
seed = torch.Generator().manual_seed(42)
print(device)

teacher = teacher.to(device)
criterion = nn.CrossEntropyLoss()   # temperatione?
optimizer = optim.Adam(teacher.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

torch.manual_seed(42)
metrics = {"train_loss": [], "val_loss": []}
best_val_loss = float('inf')
patience = 10
counter = 0
epochs = 200

for epoch in range(epochs):
    train_loss = train_model(train_loader, teacher, criterion, optimizer, scheduler, device)
    val_loss = evaluate_model(val_loader, teacher, criterion, device)
    metrics['train_loss'].append(np.mean(train_loss))
    metrics['val_loss'].append(np.mean(val_loss))
    if metrics['val_loss'][-1] < best_val_loss:
        best_val_loss = metrics['val_loss'][-1]
        counter = 0
        print(f"Epoch {epoch+1}: New best val loss: {best_val_loss:.4f}, saving model...")
        torch.save(teacher.state_dict(), '../models/pretrained/teacher.pth')
    else:
        counter += 1
    if counter >= patience:
        print(f"Epoch {epoch+1}: Early stop triggered.")
        break

cuda
training...: 100%|██████████| 1500/1500 [00:13<00:00, 112.85it/s]
evaluating...: 100%|██████████| 375/375 [00:02<00:00, 160.89it/s]
Epoch 1: New best val loss: 0.6518, saving model...
training...: 100%|██████████| 1500/1500 [00:11<00:00, 129.90it/s]
evaluating...: 100%|██████████| 375/375 [00:02<00:00, 167.68it/s]
Epoch 2: New best val loss: 0.6467, saving model...
training...: 100%|██████████| 1500/1500 [00:13<00:00, 110.51it/s]
evaluating...: 100%|██████████| 375/375 [00:02<00:00, 166.37it/s]
training...: 100%|██████████| 1500/1500 [00:11<00:00, 126.31it/s]
evaluating...: 100%|██████████| 375/375 [00:02<00:00, 164.37it/s]
training...: 100%|██████████| 1500/1500 [00:11<00:00, 127.61it/s]
evaluating...: 100%|██████████| 375/375 [00:02<00:00, 166.43it/s]
training...: 100%|██████████| 1500/1500 [00:13<00:00, 111.70it/s]
evaluating...: 100%|██████████| 375/375 [00:02<00:00, 164.40it/s]
training...: 100%|██████████| 1500/1500 [00:11<00:00, 129.30it/s]
evaluating...: 100%|██████████| 3

# Train Student

In [None]:
from utils.losses import DistillationLoss


def distill_model(train_loader: DataLoader, student: nn.Module, teacher: nn.Module, 
                  criterion: nn.Module, optimizer: nn.Module,
                  scheduler: nn.Module=None, device: str='cpu') -> list:
    """
    Train the student model for one epoch using knowledge distillation.

    Args:
        train_loader: DataLoader for the training data.
        student_model: The PyTorch student model to train.
        teacher_model: The PyTorch teacher model (should be in eval mode).
        optimizer: Optimizer for updating student model parameters.
        criterion: Distillation loss function (e.g., DistillationLoss).
        device: Device to run the training on ('cpu' or 'cuda').

    Returns:
        list: Collection of train losses.
    """
    student.train()
    teacher.eval()  # Ensure teacher model is in evaluation mode
    epoch_losses = []
    for inputs, targets in tqdm.tqdm(train_loader, desc='distilling...', file=sys.stdout):
        inputs = inputs.to(device)
        targets = targets.to(device)
        # preds
        student_preds = student(inputs)
        with torch.no_grad():
            teacher_preds = teacher(inputs)
        # distillation loss
        loss = criterion(student_preds, teacher_preds, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())
        if scheduler and isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(loss.item())
        elif scheduler:
            scheduler.step()
    return epoch_losses

In [None]:
teacher.load_state_dict(torch.load('../models/pretrained/teacher.pth', map_location=device))
student = student.to(device)
criterion = DistillationLoss(T=20)
optimizer = optim.Adam(teacher.parameters(), lr=0.001)  # remove regularization
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

metrics = {"train_loss": [], "val_loss": []}
best_val_loss = float('inf')
patience = 5
counter = 0
epochs = 200

for epoch in range(epochs):
    train_loss = train_model(train_loader, teacher, criterion, optimizer, scheduler, device)
    val_loss = evaluate_model(val_loader, student, criterion, device)
    metrics['train_loss'].append(np.mean(train_loss))
    metrics['val_loss'].append(np.mean(val_loss))
    if metrics['val_loss'][-1] < best_val_loss:
        best_val_loss = metrics['val_loss'][-1]
        counter = 0
        print(f"Epoch {epoch+1}: New best val loss: {best_val_loss:.4f}, saving model...")
        torch.save(teacher.state_dict(), '../models/pretrained/student.pth')
    else:
        counter += 1
    if counter >= patience:
        print(f"Epoch {epoch+1}: Early stop triggered.")
        break