### Dataset Sufficiency Analysis for Classification Tutorial

In [None]:
import os
# Set the TF_CPP_MIN_LOG_LEVEL environment variable for this process before
# anything else is imported (presumably to quiet TensorFlow's native logging
# if a dependency loads it).
os.environ.update({'TF_CPP_MIN_LOG_LEVEL': '3'})

In [None]:
import typing

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchmetrics
import torchvision.datasets as datasets
import torchvision.transforms.v2 as v2
from torch.utils.data import DataLoader, Dataset, Subset

from daml.datasets import DamlDataset
from daml.metrics.sufficiency import Sufficiency

# Reproducibility: pin the global PyTorch and NumPy random seeds.
torch.manual_seed(0)
np.random.seed(0)
# Select the 'high' float32 matmul precision mode for PyTorch.
torch.set_float32_matmul_precision('high')
# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Fetch both MNIST splits (train and test) into ./data ahead of time, so the
# later cells that build the datasets find the files already on disk.
datasets.MNIST('./data', train=True, download=True)
datasets.MNIST('./data', train=False, download=True)

##### Load data and define functions

Load the MNIST data and create the training and test datasets. Only the first 2,000 training images and the first 500 test images are used.


In [None]:
def to_ds(data: Dataset, len: int) -> DamlDataset:
    """Wrap the first ``len`` samples of ``data`` in a ``DamlDataset``.

    Args:
        data: Map-style dataset whose items are ``(image, label, ...)`` tuples
            that the default ``DataLoader`` collation can stack into tensors.
        len: Number of leading samples to keep. The name shadows the built-in
            ``len`` inside this function; it is kept so existing callers work.

    Returns:
        A ``DamlDataset`` built from the stacked images and labels, both
        converted to NumPy arrays.
    """
    # One batch of size ``len`` materialises the whole subset in a single step.
    loader = DataLoader(Subset(data, range(len)), len)
    # Fetch that batch exactly once: every next(iter(loader)) call re-reads and
    # re-transforms all samples, so doing it per field doubles the work.
    batch = next(iter(loader))
    return DamlDataset(batch[0].numpy(), batch[1].numpy())

# Per-image preprocessing: convert to a tensor image, then to float32 with
# value scaling enabled.
to_tensor = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
])

# Keep the first 2000 training samples and the first 500 test samples.
train_ds = to_ds(
    datasets.MNIST('./data', train=True, download=True, transform=to_tensor),
    2000,
)
test_ds = to_ds(
    datasets.MNIST('./data', train=False, download=True, transform=to_tensor),
    500,
)

In [None]:
fig = plt.figure()

# 2 x 5 grid: one panel per digit class, each showing the first training
# sample that carries that label.
for slot, digit in enumerate(range(10), start=1):
    first_hit = np.argwhere(train_ds.labels == digit)[0]
    picture = np.reshape(train_ds.images[first_hit], (28, 28))
    fig.add_subplot(2, 5, slot).imshow(picture, cmap='gray_r')

Next, we define the network architecture we will be using.

In [None]:
# Define our network architecture
class Net(nn.Module):
    """Small CNN: two 5x5 convolutions followed by three linear layers.

    With no padding or pooling, each 5x5 convolution shortens both spatial
    sides by 4, so a 1x28x28 input becomes 16x20x20 after ``conv2``. Flattened,
    that is 16 * 20 * 20 = 6400 features, which is the input size of ``fc1``.
    The last layer returns 10 raw (un-normalised) scores per sample.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Fully connected classifier head
        self.fc1 = nn.Linear(in_features=6400, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Map a batch of images to a ``(batch, 10)`` tensor of class scores."""
        features = x
        for conv in (self.conv1, self.conv2):
            features = F.relu(conv(features))
        # Collapse everything except the batch dimension
        hidden = features.flatten(start_dim=1)
        for dense in (self.fc1, self.fc2):
            hidden = F.relu(dense(hidden))
        # No activation on the output layer: raw scores are returned
        return self.fc3(hidden)

# Build a fresh Net on the selected device and compile it. The cast only tells
# static type checkers to keep treating the compiled object as an nn.Module.
model = typing.cast(
    nn.Module,
    torch.compile(Net().to(device)),
)

Finally, we define our custom training and evaluation functions.

In [None]:
def custom_train(model: nn.Module, dataloader: DataLoader):
    """Train ``model`` in place on ``dataloader`` for a fixed number of epochs.

    Uses cross-entropy loss with SGD (lr=0.01, momentum=0.9) for 10 epochs.
    A new optimizer is created on every call, so no optimizer state carries
    over between calls. The model is left in training mode.

    Args:
        model: Network that returns one row of class scores per sample.
        dataloader: Iterable of batches in which ``batch[0]`` holds the inputs
            and ``batch[1]`` the integer class labels.
    """
    # Defined only for this testing scenario
    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    epochs = 10

    # custom_eval switches the model to eval mode and never switches it back,
    # so explicitly re-enter training mode before any weight update.
    model.train()

    for _ in range(epochs):
        for batch in dataloader:
            # Explicit dtypes: the legacy torch.Tensor(...) constructor turns
            # non-tensor integer labels into float32, which CrossEntropyLoss
            # would not accept as class indices.
            X = torch.as_tensor(batch[0], dtype=torch.float32).to(device)
            y = torch.as_tensor(batch[1], dtype=torch.long).to(device)
            # Zero out gradients
            optimizer.zero_grad()
            # Forward Propagation
            outputs = model(X)
            # Back prop
            loss = criterion(outputs, y)
            loss.backward()
            # Update optimizer
            optimizer.step()


def custom_eval(model: nn.Module, dataloader: DataLoader) -> float:
    """Return the 10-class accuracy of ``model`` over ``dataloader``.

    The model is switched to evaluation mode and is left in that mode.

    Args:
        model: Network that returns one row of class scores per sample.
        dataloader: Iterable of batches in which ``batch[0]`` holds the inputs
            and ``batch[1]`` the integer class labels.

    Returns:
        The accumulated accuracy as a plain Python ``float``.
    """
    metric = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(device)

    # Set model layers into evaluation mode
    model.eval()
    # Tell PyTorch to not track gradients, greatly speeds up processing
    with torch.no_grad():
        for batch in dataloader:
            # Explicit dtypes: the legacy torch.Tensor(...) constructor turns
            # non-tensor integer labels into float32 instead of class indices.
            X = torch.as_tensor(batch[0], dtype=torch.float32).to(device)
            y = torch.as_tensor(batch[1], dtype=torch.long).to(device)
            preds = model(X)
            metric.update(preds, y)
    # compute() hands back a tensor (on ``device``); convert it so the function
    # really returns the float its signature promises.
    return float(metric.compute())

##### Initialize metric

Attach the custom training and evaluation functions to the Sufficiency metric.

In [None]:
# Instantiate sufficiency metric
suff = Sufficiency()

# Set training and eval functions defined above:
# custom_train updates the model in place, custom_eval returns its accuracy.
suff.set_training_func(custom_train)
suff.set_eval_func(custom_eval)

##### Define training parameters

Define the number of models to train in parallel (stability), as well as the number of steps along the learning curve to evaluate.

In [None]:
# Number of samples in the training dataset
length = len(train_ds)
# Number of models to train (the "stability" setting described above)
model_count = 10
# Number of steps along the learning curve to evaluate
num_steps = 10

# Create data indices for training
suff.setup(length, model_count, num_steps)

##### Evaluate Sufficiency

Now we can run the metric to train the models and produce the learning curve.


In [None]:
# Train & test model
# (uses the custom_train / custom_eval functions attached to `suff` earlier)
output = suff.run(model, train_ds, test_ds, batch_size=16)

# Plot the learning curve from the run's results
suff.plot(output)

##### Results

Using this learning curve, we can project the model's performance on much larger datasets (assuming the same model is used).