In [1]:
import os
import sys
project_dir = os.path.dirname(os.getcwd())
sys.path.append(project_dir)

In [2]:
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
seed = torch.Generator().manual_seed(42)
print(device)

cuda


# Data Ingestion

In [3]:
from data.mnist import get_mnist_pipeline

train_loader, val_loader, test_loader = get_mnist_pipeline(batch_size=32)
for i, data in enumerate(train_loader, 0):
    inputs, labels = data
    if i == 0:
        print("Shape of training batch inputs:", inputs.shape)
        print("Shape of training batch labels:", labels.shape)
        break
for i, data in enumerate(val_loader, 0):
    inputs, labels = data
    if i == 0:
        print("Shape of validation batch inputs:", inputs.shape)
        print("Shape of validation batch labels:", labels.shape)
        break
for i, data in enumerate(test_loader, 0):
    inputs, labels = data
    if i == 0:
        print("Shape of testing batch inputs:", inputs.shape)
        print("Shape of testing batch labels:", labels.shape)
        break
print("MNIST dataset loaded into PyTorch DataLoaders.")

Shape of training batch inputs: torch.Size([32, 1, 28, 28])
Shape of training batch labels: torch.Size([32])
Shape of validation batch inputs: torch.Size([32, 1, 28, 28])
Shape of validation batch labels: torch.Size([32])
Shape of testing batch inputs: torch.Size([32, 1, 28, 28])
Shape of testing batch labels: torch.Size([32])
MNIST dataset loaded into PyTorch DataLoaders.


# Model

In [None]:
import torch
import torch.optim as optim
import models.baseline as mlp
from utils.summary import get_model_stats

teacher = mlp.mnist1200().to(device)
student = mlp.mnist400().to(device)
smaller = mlp.mnist400().to(device)

sample = torch.randn(1, 1, 28, 28).to(device)
with torch.no_grad():
    print("Teacher model stats:")
    pred = teacher(sample)
    for name, param in get_model_stats(teacher, sample.shape).items():
        print(f"{name}: {param}")
    print("Student model stats:")
    pred = smaller(sample)
    for name, param in get_model_stats(student, sample.shape).items():
        print(f"{name}: {param}")

Teacher model stats:
flops: 2392800
params: 2395210
Student model stats:
flops: 477600
params: 478410


# Training Loop

In [None]:
from dataclasses import dataclass

# container for things to pass to train_val and eval

@dataclass
class ScriptArgs:
    model: str
    criterion: str
    optimizer: str   # lr included in optimizer
    scheduler: str
    device: str
    epochs: int
    lr: float
    metrics: dict
    path: str
    
    def __post_init__(self):
        # instantiate these
        self.criterion = None
        self.optimizer = getattr(optim, self.optimizer)  # create the optimizer
        self.scheduler = None

In [None]:
import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np


def train_model(train_loader: DataLoader, model: nn.Module, criterion: nn.Module, optimizer: nn.Module,
                scheduler: nn.Module=None, device: str='cpu') -> list:
    """
    Train the model for one epoch.

    Args:
        model: The PyTorch model to train.
        train_loader: DataLoader for the training data.
        optimizer: Optimizer for updating model parameters.
        criterion: Loss function.
        device: Device to run the training on ('cpu' or 'cuda').

    Returns:
        list: Collection of train losses.
    """
    model.train()
    epoch_losses = []
    for inputs, targets in tqdm.tqdm(train_loader, desc='training...', file=sys.stdout):
        inputs = inputs.to(device)
        targets = targets.to(device)
        preds = model(inputs)
        loss = criterion(preds, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())
        if scheduler and isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(loss.item())
        elif scheduler:
            scheduler.step()
    return epoch_losses


def evaluate_model(val_loader: DataLoader, model: nn.Module, criterion: nn.Module, device: str='cpu') -> list:
    """
    Evaluate the model on validation data.

    Args:
        model: The PyTorch model to evaluate.
        val_loader: DataLoader for the validation data.
        criterion: Loss function.
        device: Device to run the evaluation on ('cpu' or 'cuda').

    Returns:
        list: Collection of metrics.
    """
    model.eval()
    epoch_metrics = []
    with torch.no_grad():
        for inputs, targets in tqdm.tqdm(val_loader, desc='evaluating...', file=sys.stdout):
            inputs = inputs.to(device)
            targets = targets.to(device)
            preds = model(inputs)
            loss = criterion(preds, targets)
            epoch_metrics.append(loss.item())
    return epoch_metrics

def train_val(model, train_loader, val_loader, criterion, optimizer, scheduler, device, path, metrics={}):
    """
    Trains and Validates Model
    """
    try:
        best_val_loss = torch.load(path)['val_loss']
    except FileNotFoundError:
        best_val_loss = float('inf')
    patience = 10
    counter = 0
    epochs = 200
    
    metrics.update({'train_loss': [], 'val_loss': []})

    for epoch in range(epochs):
        # dump config + add device
        train_loss = train_model(train_loader, model, criterion, optimizer, scheduler, device)
        val_loss = evaluate_model(val_loader, model, criterion, device)
        # pop metrics from config
        metrics['train_loss'].append(np.mean(train_loss))
        metrics['val_loss'].append(np.mean(val_loss))
        if metrics['val_loss'][-1] < best_val_loss:
            best_val_loss = metrics['val_loss'][-1]
            counter = 0
            print(f"Epoch {epoch+1}: New best val loss: {best_val_loss:.4f}, saving model...")
            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'val_loss': best_val_loss
            }
            torch.save(state, path)
        else:
            counter += 1
        if counter >= patience:
            print(f"Epoch {epoch+1}: Early stop triggered.")
            break

# Train Teacher

In [None]:
teacher = teacher.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(teacher.parameters(), lr=0.01, weight_decay=1e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

torch.manual_seed(42)
metrics = {"train_loss": [], "val_loss": []}
best_val_loss = float('inf')
patience = 10
counter = 0
epochs = 200

for epoch in range(epochs):
    train_loss = train_model(train_loader, teacher, criterion, optimizer, scheduler, device)
    val_loss = evaluate_model(val_loader, teacher, criterion, device)
    metrics['train_loss'].append(np.mean(train_loss))
    metrics['val_loss'].append(np.mean(val_loss))
    if metrics['val_loss'][-1] < best_val_loss:
        best_val_loss = metrics['val_loss'][-1]
        counter = 0
        print(f"Epoch {epoch+1}: New best val loss: {best_val_loss:.4f}, saving model...")
        state = {
            'epoch': epoch,
            'state_dict': teacher.state_dict(),
            'optimizer': optimizer.state_dict(),
            'val_loss': best_val_loss
        }
        torch.save(state, '../models/weights/teacher.pth')
    else:
        counter += 1
    if counter >= patience:
        print(f"Epoch {epoch+1}: Early stop triggered.")
        break

cuda
training...: 100%|██████████| 1500/1500 [00:13<00:00, 112.85it/s]
evaluating...: 100%|██████████| 375/375 [00:02<00:00, 160.89it/s]
Epoch 1: New best val loss: 0.6518, saving model...
training...: 100%|██████████| 1500/1500 [00:11<00:00, 129.90it/s]
evaluating...: 100%|██████████| 375/375 [00:02<00:00, 167.68it/s]
Epoch 2: New best val loss: 0.6467, saving model...
training...: 100%|██████████| 1500/1500 [00:13<00:00, 110.51it/s]
evaluating...: 100%|██████████| 375/375 [00:02<00:00, 166.37it/s]
training...: 100%|██████████| 1500/1500 [00:11<00:00, 126.31it/s]
evaluating...: 100%|██████████| 375/375 [00:02<00:00, 164.37it/s]
training...: 100%|██████████| 1500/1500 [00:11<00:00, 127.61it/s]
evaluating...: 100%|██████████| 375/375 [00:02<00:00, 166.43it/s]
training...: 100%|██████████| 1500/1500 [00:13<00:00, 111.70it/s]
evaluating...: 100%|██████████| 375/375 [00:02<00:00, 164.40it/s]
training...: 100%|██████████| 1500/1500 [00:11<00:00, 129.30it/s]
evaluating...: 100%|██████████| 3

# Train Smaller

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(smaller.parameters(), lr=0.01, weight_decay=1e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

torch.manual_seed(42)
metrics = {"train_loss": [], "val_loss": []}
best_val_loss = float('inf')
patience = 10
counter = 0
epochs = 200

for epoch in range(epochs):
    train_loss = train_model(train_loader, smaller, criterion, optimizer, scheduler, device)
    val_loss = evaluate_model(val_loader, smaller, criterion, device)
    metrics['train_loss'].append(np.mean(train_loss))
    metrics['val_loss'].append(np.mean(val_loss))
    if metrics['val_loss'][-1] < best_val_loss:
        best_val_loss = metrics['val_loss'][-1]
        counter = 0
        print(f"Epoch {epoch+1}: New best val loss: {best_val_loss:.4f}, saving model...")
        state = {
            'epoch': epoch,
            'state_dict': smaller.state_dict(),
            'optimizer': optimizer.state_dict(),
            'val_loss': best_val_loss
        }
        torch.save(state, '../models/weights/smaller.pth')
    else:
        counter += 1
    if counter >= patience:
        print(f"Epoch {epoch+1}: Early stop triggered.")
        break

# Distill Student

In [7]:
from utils.losses import DistillationLoss


def distill_model(train_loader: DataLoader, student: nn.Module, teacher: nn.Module, 
                  criterion: nn.Module, optimizer: nn.Module,
                  scheduler: nn.Module=None, device: str='cpu') -> list:
    """
    Train the student model for one epoch using knowledge distillation.

    Args:
        train_loader: DataLoader for the training data.
        student_model: The PyTorch student model to train.
        teacher_model: The PyTorch teacher model (should be in eval mode).
        optimizer: Optimizer for updating student model parameters.
        criterion: Distillation loss function (e.g., DistillationLoss).
        device: Device to run the training on ('cpu' or 'cuda').

    Returns:
        list: Collection of train losses.
    """
    student.train()
    teacher.eval()  # Ensure teacher model is in evaluation mode
    epoch_losses = []
    for inputs, targets in tqdm.tqdm(train_loader, desc='distilling...', file=sys.stdout):
        inputs = inputs.to(device)
        targets = targets.to(device)
        # preds
        student_preds = student(inputs)
        with torch.no_grad():
            teacher_preds = teacher(inputs)
        # distillation loss
        loss = criterion(student_preds, teacher_preds, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())
        if scheduler and isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(loss.item())
        elif scheduler:
            scheduler.step()
    return epoch_losses

In [None]:
teacher.load_state_dict(torch.load('../models/weights/teacher.pth', map_location=device)['state_dict'])
student = student.to(device)
criterion = DistillationLoss(T=20)
crossentropy = nn.CrossEntropyLoss()
optimizer = optim.Adam(teacher.parameters(), lr=0.001)  # remove regularization
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

metrics = {"train_loss": [], "val_loss": []}
best_val_loss = float('inf')
patience = 10
counter = 0
epochs = 200

for epoch in range(epochs):
    train_loss = distill_model(train_loader, student, teacher, criterion, optimizer, scheduler, device)
    val_loss = evaluate_model(val_loader, student, crossentropy, device)
    metrics['train_loss'].append(np.mean(train_loss))
    metrics['val_loss'].append(np.mean(val_loss))
    if metrics['val_loss'][-1] < best_val_loss:
        best_val_loss = metrics['val_loss'][-1]
        counter = 0
        print(f"Epoch {epoch+1}: New best val loss: {best_val_loss:.4f}, saving model...")
        state = {
            'epoch': epoch,
            'state_dict': student.state_dict(),
            'optimizer': optimizer.state_dict(),
            'val_loss': best_val_loss
        }
        torch.save(state, '../models/weights/student.pth')
    else:
        counter += 1
    if counter >= patience:
        print(f"Epoch {epoch+1}: Early stop triggered.")
        break

  teacher.load_state_dict(torch.load('../models/pretrained/teacher.pth', map_location=device))


distilling...: 100%|██████████| 1500/1500 [00:11<00:00, 134.87it/s]
evaluating...: 100%|██████████| 375/375 [00:02<00:00, 162.68it/s]
Epoch 1: New best val loss: 2.3134, saving model...
distilling...: 100%|██████████| 1500/1500 [00:13<00:00, 112.94it/s]
evaluating...: 100%|██████████| 375/375 [00:02<00:00, 158.44it/s]
Epoch 2: New best val loss: 2.3129, saving model...
distilling...: 100%|██████████| 1500/1500 [00:11<00:00, 128.90it/s]
evaluating...: 100%|██████████| 375/375 [00:02<00:00, 156.66it/s]
distilling...: 100%|██████████| 1500/1500 [00:14<00:00, 105.95it/s]
evaluating...: 100%|██████████| 375/375 [00:02<00:00, 153.57it/s]
distilling...: 100%|██████████| 1500/1500 [00:12<00:00, 124.79it/s]
evaluating...: 100%|██████████| 375/375 [00:02<00:00, 165.77it/s]
distilling...: 100%|██████████| 1500/1500 [00:12<00:00, 119.59it/s]
evaluating...: 100%|██████████| 375/375 [00:04<00:00, 88.98it/s] 
distilling...: 100%|██████████| 1500/1500 [00:12<00:00, 117.42it/s]
evaluating...: 100%|████

# Compare: Teacher, Smaller, Student

In [None]:
def eval(model, test_loader, criterion, device, metrics={}):
    # abstract to model rather than teacher/small    
    metrics.update({'test_loss': []})

    for epoch in range(epochs):
        # dump config + add device
        test_loss = evaluate_model(test_loader, model, criterion, device)
        # pop metrics from config
        metrics['test_loss'].append(np.mean(test_loss))
        
def plugins():
    """Some visualization bs"""
    pass

In [None]:
teacher_state = torch.load('../models/weights/teacher.pth', map_location=device)
student_state = torch.load('../models/weights/student.pth', map_location=device)
smaller_state = torch.load('../models/weights/smaller.pth', map_location=device)
teacher.load_state_dict(teacher_state['state_dict'])
student.load_state_dict(student_state['state_dict'])
smaller.load_state_dict(smaller_state['state_dict'])

# look at the best val_losses

# work on test sets

# Transfer Set Training

In [None]:
def get_transfer_set(dropped_classes):
    """
    """
    pass

In [None]:
teacher = mlp.mnist1200().to(device)
student = mlp.mnist400().to(device)