## Step 1: Let's define some model building blocks

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import numpy as np

### MLP

In [None]:
class SimpleMLP(nn.Module):
    """Single-hidden-layer perceptron that maps a batch of inputs to class logits.

    Each sample is flattened, passed through one hidden linear layer followed
    by the selected activation, and then projected to `num_classes` raw logits
    (no softmax is applied here).

    Args:
        activation: Hidden-layer non-linearity: "sigmoid", "relu" or "tanh".
        in_features: Number of values per sample after flattening. The default
            of 784 matches 1x28x28 inputs; a 3x32x32 input needs 3072.
        hidden_features: Width of the hidden layer.
        num_classes: Number of output logits.

    Raises:
        ValueError: If `activation` is not one of the supported names.
    """

    def __init__(
        self,
        activation="sigmoid",
        in_features=784,
        hidden_features=256,
        num_classes=10,
    ):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes)

        # Activation (same set of names that PureConvCNN accepts)
        if activation == "sigmoid":
            self.activation = nn.Sigmoid()
        elif activation == "relu":
            self.activation = nn.ReLU()
        elif activation == "tanh":
            self.activation = nn.Tanh()
        else:
            raise ValueError(f"Unsupported activation: {activation}")

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for the batch `x`."""
        x = self.flatten(x)
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return x

In [None]:
class PureConvCNN(nn.Module):
    """Classifier built only from convolution and pooling layers.

    Two 3x3 convolutions (16 then 32 channels, padding 1) are each followed by
    the selected activation and a 2x2 max-pool. A final 7x7 convolution emits
    one channel per class, and global average pooling reduces whatever spatial
    extent remains to a single value per class.

    Args:
        in_channels: Number of channels in the input images.
        num_classes: Number of output logits.
        activation: "sigmoid", "relu" or "tanh".

    Raises:
        ValueError: If `activation` is not one of the supported names.
    """

    def __init__(self, in_channels=1, num_classes=10, activation="sigmoid"):
        super().__init__()

        # Pick the non-linearity by name; the else-branch runs only when no
        # supported name matched.
        supported = (("sigmoid", nn.Sigmoid), ("relu", nn.ReLU), ("tanh", nn.Tanh))
        for label, factory in supported:
            if activation == label:
                self.act = factory()
                break
        else:
            raise ValueError(f"Unsupported activation: {activation}")

        # Feature extractor: each stage keeps the spatial size in the conv
        # (3x3 kernel, padding 1) and halves it in the pool.
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(2)

        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(2)

        # Convolutional classification head: the unpadded 7x7 kernel needs a
        # feature map of at least 7x7 here (i.e. inputs of 28x28 or larger).
        self.classifier = nn.Conv2d(32, num_classes, kernel_size=7)
        self.global_avg = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for the image batch `x`."""
        features = self.pool1(self.act(self.conv1(x)))
        features = self.pool2(self.act(self.conv2(features)))
        scores = self.global_avg(self.classifier(features))
        return torch.flatten(scores, 1)  # (batch, num_classes)

## Training & Plotting Architecture

This needs to do the following:
- create an API/function so that if I pass it a model and some data it will invoke the training loop
- return the trained model to me so that I can then go on and test the model out
- return data related to the loss and accuracy of the model after each epoch of training

something like 

`trainModel(model, data, hyperparameters, etc...) -> [trained]_model, accuracy_per_iteration, loss_per_iteration`


In [None]:
def _make_optimizer(model, name="sgd", lr=0.05):
    name = name.lower()
    if name == "sgd":
        return optim.SGD(model.parameters(), lr=lr)
    if name == "adam":
        return optim.Adam(model.parameters(), lr=lr)
    raise ValueError(f"Unknown optimizer: {name}")


@torch.no_grad()
def evaluate(model, loader, device="cpu", criterion=None):
    """Compute accuracy (and optionally mean loss) of `model` over `loader`.

    The model is switched to eval mode and left that way; gradients are
    disabled for the whole call.

    Args:
        model: Module mapping an input batch to per-class logits.
        loader: Iterable yielding `(inputs, targets)` batches.
        device: Device each batch is moved to before the forward pass.
        criterion: Optional loss callable. Its per-batch value is weighted by
            the batch size, so the result is a per-sample mean provided the
            criterion itself averages over the batch.

    Returns:
        `(avg_loss, accuracy)`; `avg_loss` is None when no criterion is given.
    """
    model.eval()
    want_loss = criterion is not None
    n_seen = 0
    n_right = 0
    loss_sum = 0.0

    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = model(inputs)
        batch_size = inputs.size(0)

        if want_loss:
            loss_sum += criterion(outputs, targets).item() * batch_size
        predictions = outputs.argmax(1)
        n_right += (predictions == targets).sum().item()
        n_seen += batch_size

    mean_loss = loss_sum / n_seen if want_loss else None
    accuracy = n_right / n_seen
    return mean_loss, accuracy


def train_model(
    model,
    train_loader,
    test_loader,
    epochs=15,
    optimizer_name="sgd",
    lr=0.05,
    device=None,
    log_every=100,
):
    """
    Fit `model` on `train_loader` with cross-entropy loss, scoring it on
    `test_loader` after every epoch.

    Args:
      model:          module mapping an input batch to class logits
      train_loader:   iterable of (inputs, targets) training batches
      test_loader:    iterable of (inputs, targets) evaluation batches
      epochs:         number of passes over `train_loader`
      optimizer_name: name forwarded to `_make_optimizer`
      lr:             learning rate forwarded to `_make_optimizer`
      device:         target device; when falsy, "cuda" if available else "cpu"
      log_every:      print the batch loss every this many steps (falsy = never)

    Returns (trained_model, history) where:
      history = {
        'train_loss': [...], 'train_acc': [...],
        'test_loss':  [...], 'test_acc':  [...]
      }
    with one entry per epoch in each list.
    """
    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    loss_fn = nn.CrossEntropyLoss()
    opt = _make_optimizer(model, optimizer_name, lr)

    history = {
        key: [] for key in ("train_loss", "train_acc", "test_loss", "test_acc")
    }

    for epoch in range(1, epochs + 1):
        # evaluate() leaves the model in eval mode, so re-enable training
        # behaviour at the start of every epoch.
        model.train()
        seen = 0
        hits = 0
        running_loss = 0.0

        for step, (inputs, targets) in enumerate(train_loader, start=1):
            inputs = inputs.to(device)
            targets = targets.to(device)

            opt.zero_grad()
            outputs = model(inputs)
            batch_loss = loss_fn(outputs, targets)
            batch_loss.backward()
            opt.step()

            # Weight the batch loss by batch size so the epoch figure is a
            # per-sample average even when the last batch is smaller.
            batch_size = inputs.size(0)
            running_loss += batch_loss.item() * batch_size
            hits += (outputs.argmax(1) == targets).sum().item()
            seen += batch_size

            if log_every and step % log_every == 0:
                print(f"epoch {epoch} step {step}: loss={batch_loss.item():.4f}")

        # Per-epoch metrics: running training figures plus a fresh test pass.
        train_loss = running_loss / seen
        train_acc = hits / seen
        test_loss, test_acc = evaluate(
            model, test_loader, device=device, criterion=loss_fn
        )

        epoch_metrics = {
            "train_loss": train_loss,
            "train_acc": train_acc,
            "test_loss": test_loss,
            "test_acc": test_acc,
        }
        for key, value in epoch_metrics.items():
            history[key].append(value)

        print(
            f"[{epoch}/{epochs}] "
            f"train_loss={train_loss:.4f} train_acc={train_acc:.4f} | "
            f"test_loss={test_loss:.4f} test_acc={test_acc:.4f}"
        )

    return model, history

In [None]:
def test_model(model, test_loader, device=None):
    """Return the classification accuracy of `model` over `test_loader`.

    Args:
        model: Module mapping an input batch to class logits.
        test_loader: Iterable yielding `(inputs, targets)` batches.
        device: Device to evaluate on; when falsy, "cuda" if available,
            otherwise "cpu".

    Returns:
        Fraction of correctly classified samples, as reported by `evaluate`.

    Note:
        The model is moved to `device` and, because `evaluate` calls
        `model.eval()`, is left in eval mode.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    # evaluate() moves every batch to `device`, so the model has to live there
    # too; without this a CPU model on a CUDA machine fails with a device
    # mismatch. Mirrors what train_model does before its loop.
    model.to(device)
    _, acc = evaluate(model, test_loader, device=device, criterion=None)
    return acc

In [None]:
def plot_history(history, title="Training History"):
    """
    Draw two figures from a training history dictionary: one comparing train
    and test loss per epoch, one comparing train and test accuracy per epoch.

    Expected layout (one value per epoch in every list):
    history = {
        'train_loss': [...],
        'train_acc':  [...],
        'test_loss':  [...],
        'test_acc':   [...]
    }

    `title` is used as the prefix of each figure's title.
    """
    # Epochs are numbered from 1; the count is taken from the train-loss list.
    epochs = np.arange(1, len(history["train_loss"]) + 1)

    # (history-key suffix, display name) for each figure, in drawing order.
    for suffix, metric in (("loss", "Loss"), ("acc", "Accuracy")):
        plt.figure(figsize=(8, 5))
        plt.plot(epochs, history[f"train_{suffix}"], "o-", label=f"Train {metric}")
        plt.plot(epochs, history[f"test_{suffix}"], "s-", label=f"Test {metric}")
        plt.title(f"{title} - {metric}")
        plt.xlabel("Epoch")
        plt.ylabel(metric)
        plt.xticks(epochs)
        plt.grid(True, linestyle="--", alpha=0.5)
        plt.legend()
        plt.show()

## Preparing the Data (MNIST and CIFAR-10)

In [None]:
# MNIST preprocessing: convert each image to a tensor, then normalise its
# single channel with mean 0.1307 and std 0.3081.
mnist_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Train split first, then test split; both are downloaded into ./data if
# missing and share the same transform.
mnist_train_dataset, mnist_test_dataset = (
    datasets.MNIST(
        root="./data", train=is_train, download=True, transform=mnist_transform
    )
    for is_train in (True, False)
)

# Only the training loader shuffles; the test loader uses larger batches.
mnist_train_loader = DataLoader(
    mnist_train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
mnist_test_loader = DataLoader(
    mnist_test_dataset,
    batch_size=1000,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

print(
    "MNIST data loaded:",
    f"\n  Training samples: {len(mnist_train_dataset)}",
    f"\n  Test samples: {len(mnist_test_dataset)}",
)

In [None]:
# CIFAR-10 preprocessing: convert each image to a tensor, then normalise the
# three channels with the per-channel mean and std given below.
cifar10_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ]
)

# Train split first, then test split; both are downloaded into ./data if
# missing and share the same transform.
cifar10_train_dataset, cifar10_test_dataset = (
    datasets.CIFAR10(
        root="./data", train=is_train, download=True, transform=cifar10_transform
    )
    for is_train in (True, False)
)

# Only the training loader shuffles; the test loader uses larger batches.
cifar10_train_loader = DataLoader(
    cifar10_train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
cifar10_test_loader = DataLoader(
    cifar10_test_dataset,
    batch_size=1000,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

print(
    "CIFAR-10 data loaded:",
    f"\n  Training samples: {len(cifar10_train_dataset)}",
    f"\n  Test samples: {len(cifar10_test_dataset)}",
)

## Testing Models with Data

### Scenario 1: Base Parameters

In [None]:
# Baseline hyperparameters: SGD optimizer, learning rate 0.05, 15 epochs.
# The intended batch size (64) and loss function (CrossEntropyLoss) are not
# stored in this dict.
config = {
    "learning_rate": 0.05,
    "epochs": 15,
    "optimizer": "sgd",
}