In [None]:
import os
import random
from typing import Dict, cast

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms.v2 as v2
from torch.utils.data import DataLoader, Dataset, Subset
import nannyml as nml
from IPython.display import display
import loss_estimation

# Seed NumPy's global RNG and print float arrays with four decimal places
np.random.seed(0)
np.set_printoptions(formatter={"float": lambda x: f"{x:0.4f}"})
# Seed PyTorch's RNG and select the "high" float32 matmul precision mode
torch.manual_seed(0)
torch.set_float32_matmul_precision("high")
# Run on the GPU when one is available, otherwise on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Ask TorchDynamo to suppress compilation errors instead of raising them
torch._dynamo.config.suppress_errors = True

# Seed Python's built-in RNG and request deterministic PyTorch algorithms
random.seed(0)
torch.use_deterministic_algorithms(True)
# cuBLAS workspace configuration used together with deterministic algorithms
# on CUDA (see the PyTorch reproducibility notes)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

# NOTE(review): the return value of this call is discarded. Called without a
# function argument, `disable` appears to return a decorator/context object
# rather than switching Dynamo off globally, so this line may have no effect.
# Confirm the intent (the model below is still passed to torch.compile).
torch._dynamo.disable()

In [None]:
# Download MNIST into ./data (if missing) and expose both splits as float32
# image tensors rescaled by the ToDtype(scale=True) step
to_tensor = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
mnist_kwargs = {"download": True, "transform": to_tensor}
train_ds = datasets.MNIST("./data", train=True, **mnist_kwargs)
test_ds = datasets.MNIST("./data", train=False, **mnist_kwargs)

# Class labels are the integers 0-9
class_names = [digit for digit in range(10)]

In [None]:
class Corrupt(v2.Transform):
    """Additive Gaussian-noise corruption for image tensors.

    Each transformed input is brought to the [0, 1] range, perturbed with
    per-pixel Gaussian noise centred on the pixel value, and clipped back to
    [0, 1].

    Args:
        std: Standard deviation of the noise. Defaults to 0.3, the value that
            was previously hard-coded.
    """

    def __init__(self, std: float = 0.3):
        super().__init__()
        self.std = std

    def _transform(self, inpt, params):
        """Per-input hook called by ``v2.Transform``; ``params`` is unused."""
        return self.contrast(inpt)

    def transform(self, inpt, params):
        """Same hook under its public name.

        NOTE(review): newer torchvision releases are understood to look for
        ``transform`` instead of ``_transform``; both are defined so either
        lookup works. Confirm against the installed torchvision version.
        """
        return self.contrast(inpt)

    def contrast(self, sample):
        """Return a noisy, clipped float32 copy of ``sample``.

        Floating-point input is treated as already scaled to [0, 1] (this is
        what ``v2.ToDtype(torch.float32, scale=True)`` produces), so it is not
        divided again. Integer input is treated as 0-255 pixel values and is
        divided by 255 first.

        Args:
            sample: Image tensor.

        Returns:
            A float32 tensor with the same shape as ``sample`` and values
            in [0, 1].
        """
        if torch.is_floating_point(sample):
            # Dividing an already-scaled image by 255 would push every pixel
            # to ~0 and leave nothing but noise.
            x = sample.float()
        else:
            x = sample.float() / 255.0

        noisy = torch.normal(x, std=self.std)
        return torch.clip(noisy, 0, 1)


# Same tensor conversion as the clean pipeline, followed by the Corrupt step
c_to_tensor = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        Corrupt(),
    ]
)
# Pipeline that applies only the Corrupt step
c_contrast = v2.Compose([Corrupt()])

# MNIST test split viewed through the corrupting pipeline
c_test_ds = datasets.MNIST(root="./data", train=False, download=True, transform=c_to_tensor)

In [None]:
# Restrict each dataset to its leading samples: the first 2000 training images
# and the first 500 images of both the clean and the corrupted test sets
train_ds = Subset(train_ds, indices=range(2000))
test_ds, c_test_ds = (Subset(ds, indices=range(500)) for ds in (test_ds, c_test_ds))

In [None]:
# Define our network architecture
class Net(nn.Module):
    """Small convolutional classifier for single-channel 28x28 images.

    Two unpadded 5x5 convolutions are followed by three fully connected
    layers and a softmax over the ten output classes.

    NOTE(review): ``forward`` returns probabilities, not logits.
    ``custom_train`` feeds these outputs to ``CrossEntropyLoss``, which
    applies its own log-softmax. Confirm that this double normalisation is
    intended.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 6400 = 16 channels * 20 * 20: each unpadded 5x5 convolution trims
        # four pixels per side length (28 -> 24 -> 20)
        self.fc1 = nn.Linear(6400, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        # Normalise over the class dimension explicitly; omitting `dim` is
        # deprecated and triggers a warning on every forward pass
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Map a batch of shape (N, 1, 28, 28) to class probabilities (N, 10)."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        x = self.softmax(x)
        return x


# Build the network on the selected device and compile it.
# torch.compile's return value is not typed as Net, so cast() restores the
# static type for type checkers; the object is unchanged at runtime.
model = cast(Net, torch.compile(Net().to(device)))

In [None]:
def custom_train(
    model: nn.Module,
    dataset: Dataset,
    epochs: int = 10,
    batch_size: int = 16,
    lr: float = 0.01,
    momentum: float = 0.9,
):
    """Train ``model`` in place on ``dataset`` with SGD and cross-entropy loss.

    Defined only for this testing scenario.

    Args:
        model: Network to optimise. Batches are moved to the device of its
            first parameter, so the model decides where training runs.
        dataset: Dataset whose items are ``(input, target)`` pairs.
        epochs: Number of full passes over ``dataset``.
        batch_size: Number of samples per optimisation step.
        lr: SGD learning rate.
        momentum: SGD momentum.

    Returns:
        None. The parameters of ``model`` are updated in place.

    Raises:
        ValueError: Raised by the optimizer if ``model`` has no parameters.
    """
    # NOTE(review): CrossEntropyLoss expects raw logits. The Net class in this
    # notebook already applies softmax in forward(), so the loss is computed
    # on probabilities. Confirm whether that is intended.
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    # Take the device from the model rather than from a notebook global, so
    # inputs always end up next to the weights
    target_device = next(model.parameters()).device
    criterion = torch.nn.CrossEntropyLoss().to(target_device)

    # Define the dataloader for training
    dataloader = DataLoader(dataset, batch_size=batch_size)

    # Make sure train-time behaviour is active (no-op for a model already in train mode)
    model.train()
    for _ in range(epochs):
        for batch in dataloader:
            # as_tensor keeps the dtype of each batch item; the legacy
            # torch.Tensor(...) constructor is discouraged for existing data
            # Load data/images to device
            X = torch.as_tensor(batch[0]).to(target_device)
            # Load targets/labels to device
            y = torch.as_tensor(batch[1]).to(target_device)
            # Zero out gradients
            optimizer.zero_grad()
            # Forward propagation
            outputs = model(X)
            # Compute loss
            loss = criterion(outputs, y)
            # Back prop
            loss.backward()
            # Update weights/parameters
            optimizer.step()

def reset_parameters(model: nn.Module):
    """
    Re-initialise every submodule of ``model`` that exposes a callable
    ``reset_parameters`` method, and return the model.

    Submodules without such a method are left untouched.
    """

    def _reinitialise(module: nn.Module):
        # Not every module type defines reset_parameters; skip those that don't
        layer_reset = getattr(module, "reset_parameters", None)
        if not callable(layer_reset):
            return
        with torch.no_grad():
            layer_reset()

    # nn.Module.apply calls the function on every submodule and on the module
    # itself, then returns the module, see:
    # https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    return model.apply(_reinitialise)

In [None]:
# Re-initialise the weights so training starts from an untrained model
model = reset_parameters(model)
# Optional keyword arguments for the training and evaluation steps
train_kwargs = dict()
eval_kwargs = dict()
# Fit the model on the training subset
custom_train(model, train_ds, **train_kwargs)

In [None]:
# Build an estimator from the project-local loss_estimation module.
# NOTE(review): "CBPE" presumably selects a confidence-based performance
# estimation method for multiclass classification; confirm against that module.
estimator = loss_estimation.LossEstimator("CBPE", "classification_multiclass")
# Evaluate the trained model, passing the clean test subset, the corrupted
# test subset, and the class labels
results = estimator.evaluate(model, test_ds, c_test_ds, class_names)

In [None]:
# Show the estimator's output as plain text
print(results)