In [None]:
import os
import tempfile
import time

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models
from tqdm import tqdm

ModuleNotFoundError: No module named 'torch'

### Get device. CUDA if available, CPU otherwise

In [None]:
def get_device(silent=False):
    """
    Pick the torch device that computations should run on.

    Args:
        silent: When False (the default), print the selected device.

    Returns:
        torch.device: ``cuda`` when ``torch.cuda.is_available()`` is true,
        ``cpu`` otherwise.
    """
    if torch.cuda.is_available():
        backend = "cuda"
    else:
        backend = "cpu"
    chosen = torch.device(backend)

    # Quiet mode: hand the device back without reporting it.
    if silent:
        return chosen

    print(f"Using device: {chosen}")
    return chosen

### Get CIFAR-10 data loaders

In [None]:
def get_cifar10_loaders(train_ratio=0.9, train_batch_size=128, test_batch_size=128, silent=False, seed=None):
    """
    Build CIFAR-10 data loaders for training, validation and testing.

    The official CIFAR-10 training set is randomly split into a training
    subset and a validation subset; the official test set is used as-is.
    Only the training loader shuffles its samples. Data is stored under
    ``./data`` (``download=True`` is passed to torchvision).

    Args:
        train_ratio: Fraction of the official training set kept for training;
            the remainder becomes the validation subset. Must lie in [0, 1].
        train_batch_size: Batch size of the training AND validation loaders.
        test_batch_size: Batch size of the test loader.
        silent: When False (the default), print dataset and loader sizes.
        seed: Optional integer seed for the train/validation split. With the
            default ``None`` the split uses torch's global RNG (the original
            behaviour). Pass a fixed seed whenever training may be resumed
            from a checkpoint, otherwise a new split is drawn on every call
            and former validation samples end up in the training subset.

    Returns:
        tuple: ``(train_loader, val_loader, test_loader)``.

    Raises:
        ValueError: If ``train_ratio`` is outside [0, 1].

    reference: Learning Multiple Layers of Features from Tiny Images, Alex Krizhevsky, 2009.
    """

    # An out-of-range ratio would otherwise yield a negative validation size
    # and a silently wrong split.
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")

    # define transform for CIFAR-10 dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean = [0.49139968, 0.48215827, 0.44653124],  # CIFAR-10 means
                             std  = [0.24703233, 0.24348505, 0.26158768])  # CIFAR-10 stds
    ])

    # load full CIFAR-10 train set
    full_trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

    # calculate split sizes for train and validation sets
    train_size = int(train_ratio * len(full_trainset))
    val_size = len(full_trainset) - train_size

    # perform split (reproducible when a seed is given)
    if seed is None:
        train_subset, val_subset = random_split(full_trainset, [train_size, val_size])
    else:
        split_generator = torch.Generator().manual_seed(seed)
        train_subset, val_subset = random_split(
            full_trainset, [train_size, val_size], generator=split_generator
        )

    # create DataLoaders (validation shares the training batch size)
    train_loader = DataLoader(train_subset, batch_size=train_batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=train_batch_size, shuffle=False)

    # CIFAR-10 test set and loader for accuracy evaluation
    test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    test_loader = DataLoader(test_set, batch_size=test_batch_size, shuffle=False)

    if not silent:
        print(f"Full train set size: {len(full_trainset)}")
        print(f"Train ratio: {train_ratio}")
        print(f"Train samples: {len(train_subset)}")
        print(f"Validation samples: {len(val_subset)}")
        print(f"Test samples: {len(test_set)}") 
        print(f"Number of training batches: {len(train_loader)}")
        print(f"Number of validation batches: {len(val_loader)}")
        print(f"Number of test batches: {len(test_loader)}")

    return train_loader, val_loader, test_loader

### Models

In [None]:
def get_resnet50_for_cifar10(device=None):
    """
    Build an untrained ResNet-50 with a 10-class head, adapted for CIFAR-10.

    Two layers of the torchvision model are swapped out: ``conv1`` becomes a
    3x3, stride-1, padding-1 convolution without bias, and ``maxpool``
    becomes an identity.

    Args:
        device: Device to place the model on. When None, the device is chosen
            by ``get_device(silent=True)``.

    Returns:
        The model, moved to the requested device.
    """
    target_device = get_device(silent=True) if device is None else device

    net = models.resnet50(weights=None, num_classes=10)
    net.conv1 = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    net.maxpool = nn.Identity()

    net = net.to(target_device)
    return net

### Train function

In [None]:
def train(
    model,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    device,
    epochs,
    scheduler=None,
    grad_clip=None,
    save_path="best_model.pt",
    early_stopping_patience=5,
    resume=True
):
    """
    Trains the model using the provided data loaders, optimizer, and loss function.
    Supports early stopping and model checkpointing.

    Each epoch runs one pass over ``train_loader`` (with optional gradient-norm
    clipping), one evaluation pass over ``val_loader``, one scheduler step, and
    then saves a checkpoint if the validation loss improved.

    Args:
        model: Module to train; moved to ``device``.
        train_loader: Iterable of ``(inputs, targets)`` batches; must support ``len()``.
        val_loader: Iterable of ``(inputs, targets)`` batches; must support ``len()``.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        device: Device the batches are moved to.
        epochs: Total number of epochs (a resumed run continues up to this number).
        scheduler: Optional LR scheduler, stepped once per epoch.
            ``ReduceLROnPlateau`` receives the average validation loss; every
            other scheduler is stepped without arguments.
        grad_clip: Optional max gradient norm passed to ``clip_grad_norm_``.
        save_path: Checkpoint file, written whenever validation loss improves.
        early_stopping_patience: Stop after this many consecutive epochs
            without validation-loss improvement. ``None`` disables early stopping.
        resume: When True and ``save_path`` exists, restore model, optimizer,
            scheduler (if a state was saved) and best loss from it, and continue
            with the epoch after the one stored in the checkpoint.

    Returns:
        None. Progress is reported through ``tqdm.write`` / ``print``.
    """

    model.to(device)

    start_epoch = 0
    best_val_loss = float("inf")
    epochs_without_improvement = 0

    # Optional resume
    if resume and os.path.exists(save_path):
        checkpoint = torch.load(save_path, map_location=device)
        model.load_state_dict(checkpoint["model_state"])
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        # A checkpoint written without a scheduler stores None under this key;
        # loading that into a scheduler would crash, so skip it.
        scheduler_state = checkpoint.get("scheduler_state")
        if scheduler is not None and scheduler_state is not None:
            scheduler.load_state_dict(scheduler_state)
        best_val_loss = checkpoint.get("best_val_loss", best_val_loss)
        start_epoch = checkpoint.get("epoch", 0) + 1
        print(f"🔁 Resumed training from epoch {start_epoch}")

    for epoch in range(start_epoch, epochs):
        model.train()
        total_loss = 0.0
        total_correct = 0
        total_samples = 0

        train_loop = tqdm(train_loader, desc=f"[Epoch {epoch+1}/{epochs}]", leave=False)
        for inputs, targets in train_loop:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()

            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            optimizer.step()

            # Accumulate as a tensor; it is converted to a float once per epoch.
            total_loss += loss.detach()
            preds = outputs.argmax(dim=1)
            total_correct += (preds == targets).sum().item()
            total_samples += targets.size(0)

        # Mean of the per-batch losses.
        avg_train_loss = (total_loss / len(train_loader)).item()
        train_accuracy = total_correct / total_samples
        tqdm.write(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Acc: {train_accuracy:.4f}")

        # Validation
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_samples = 0

        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                val_loss += loss.detach()
                preds = outputs.argmax(dim=1)
                val_correct += (preds == targets).sum().item()
                val_samples += targets.size(0)

        avg_val_loss = (val_loss / len(val_loader)).item()
        val_accuracy = val_correct / val_samples
        tqdm.write(f"          | Val   Loss: {avg_val_loss:.4f} | Acc: {val_accuracy:.4f}")

        # Scheduler step.
        # Dispatch on the scheduler type instead of catching TypeError: ordinary
        # schedulers accept an optional (deprecated) `epoch` argument, so passing
        # the loss to them does not raise -- it is silently used as the epoch index.
        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(avg_val_loss)
            else:
                scheduler.step()

        # Early stopping + checkpoint
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_without_improvement = 0
            torch.save({
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict() if scheduler else None,
                "best_val_loss": best_val_loss,
                "epoch": epoch,
            }, save_path)
            tqdm.write(f"          | ✅ New best model saved to '{save_path}'")
        else:
            epochs_without_improvement += 1
            tqdm.write(f"          | No improvement for {epochs_without_improvement} epoch(s)")

        # A patience of None means "never stop early".
        if (
            early_stopping_patience is not None
            and epochs_without_improvement >= early_stopping_patience
        ):
            tqdm.write(f"🛑 Early stopping triggered after {early_stopping_patience} epochs without improvement.")
            break

    print("Training complete.")

### Evaluate function

In [None]:
def evaluate(model, data_loader, device):
    """
    Compute the classification accuracy of ``model`` over ``data_loader``.

    The model is switched to eval mode and run without gradient tracking.
    A prediction is the index of the largest output along dimension 1.

    Args:
        model: Module called as ``model(images)``.
        data_loader: Iterable of ``(images, labels)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        float: Number of correct predictions divided by number of samples.
    """
    model.eval()

    n_correct, n_seen = 0, 0
    progress = tqdm(data_loader, desc="Evaluating", unit="batch")

    with torch.no_grad():
        for batch_images, batch_labels in progress:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            predictions = model(batch_images).argmax(dim=1)

            n_seen += batch_labels.size(0)
            n_correct += (predictions == batch_labels).sum().item()

    return n_correct / n_seen

### Measure Latency

In [None]:
class Timer:
    """
    A simple timer class to measure the time taken for operations.
    It uses CUDA events if a GPU is available, otherwise it uses the
    monotonic ``time.perf_counter()`` clock.

    Usage: call ``start()``, run the work, then call ``stop()`` to get the
    elapsed time in milliseconds.
    """

    def __init__(self):
        # The module-level `torch` is used directly; the instance has no
        # `torch` / `time` attributes.
        self.use_cuda = torch.cuda.is_available()
        self.start_time = None  # set by start() on the CPU path
        if self.use_cuda:
            self.starter = torch.cuda.Event(enable_timing=True)
            self.ender = torch.cuda.Event(enable_timing=True)

    def start(self):
        """Mark the beginning of the measured interval."""
        if self.use_cuda:
            self.starter.record()
        else:
            self.start_time = time.perf_counter()

    def stop(self):
        """
        Mark the end of the measured interval.

        Returns:
            float: Elapsed time since ``start()`` in milliseconds.

        Raises:
            RuntimeError: On the CPU path, if ``start()`` was never called.
        """
        if self.use_cuda:
            self.ender.record()
            # Wait for queued GPU work so the end event has actually completed.
            torch.cuda.synchronize()
            return self.starter.elapsed_time(self.ender)  # ms

        if self.start_time is None:
            raise RuntimeError("Timer.stop() called before Timer.start()")
        return (time.perf_counter() - self.start_time) * 1000  # ms

def estimate_latency(model, example_inputs, repetitions=50):
    """
    Estimates the latency of the model by running it on example inputs multiple times.
    Returns the mean and standard deviation of the latencies.

    The model is run 5 times as warm-up, then ``repetitions`` timed forward
    passes are recorded with ``Timer``. All passes run under ``torch.no_grad()``.
    If ``model`` is an ``nn.Module`` in training mode it is switched to eval
    mode for the measurement (so that layers such as BatchNorm do not update
    their running statistics) and switched back afterwards.

    Args:
        model: Callable invoked as ``model(example_inputs)``.
        example_inputs: Input passed unchanged to every forward pass.
        repetitions: Number of timed forward passes; must be at least 1.

    Returns:
        tuple: ``(mean, std)`` of the per-pass timings as reported by
        ``Timer.stop()`` (milliseconds).

    Raises:
        ValueError: If ``repetitions`` is smaller than 1.
    """

    if repetitions < 1:
        raise ValueError(f"repetitions must be >= 1, got {repetitions}")

    timer = Timer()
    timings = np.zeros((repetitions, 1))

    # Measure inference behaviour without leaving the caller's model altered.
    was_training = isinstance(model, nn.Module) and model.training
    if was_training:
        model.eval()

    try:
        with torch.no_grad():

            # warm-up (inside no_grad so no autograd graph is built)
            for _ in range(5):
                _ = model(example_inputs)

            # measure latency
            for rep in tqdm(range(repetitions), desc="Measuring latency"):
                timer.start()
                _ = model(example_inputs)
                elapsed = timer.stop()
                timings[rep] = elapsed
    finally:
        if was_training:
            model.train()

    return np.mean(timings), np.std(timings)

### Measure size of model

In [None]:

def get_size(model):
    """
    Returns the size of the model in MB.

    The model's ``state_dict()`` is serialized with ``torch.save`` and the
    resulting file size is reported in units of 1e6 bytes.

    The file is written inside a private temporary directory, so an existing
    ``temp.p`` in the working directory is never overwritten or deleted, and
    the temporary file is removed even if saving fails.

    Args:
        model: Object exposing ``state_dict()``.

    Returns:
        float: Serialized size in MB (bytes / 1e6).
    """

    with tempfile.TemporaryDirectory() as scratch_dir:
        # Keep the original file name inside the scratch directory.
        scratch_path = os.path.join(scratch_dir, "temp.p")
        torch.save(model.state_dict(), scratch_path)
        return os.path.getsize(scratch_path) / 1e6
