# Tutorial for Classification Dataset Sufficiency Analysis via Learning Curves


In [None]:
# Load everything we need

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchmetrics

from daml.datasets import DamlDataset
from daml.metrics.sufficiency import Sufficiency

# Seed the NumPy and PyTorch random number generators with fixed values
# so that repeated runs of this notebook draw the same random numbers
np.random.seed(0)
torch.manual_seed(0)

## Load the MNIST data with load_dataset(), define network architecture, and define training and evaluation functions.


In [None]:
def load_dataset(
    path: str = "../../tests/datasets/mnist.npz",
    train_size: int = 4000,
    test_size: int = 500,
):
    """
    Load a subset of the MNIST data and wrap it in DAML datasets

    Parameters
    ----------
    path : str
        Location of the ``.npz`` archive holding the ``x_train``, ``y_train``,
        ``x_test`` and ``y_test`` arrays
    train_size : int
        Maximum number of leading training samples to keep
    test_size : int
        Maximum number of leading test samples to keep

    Returns
    -------
    tuple
        ``(train_ds, test_ds)``, two ``DamlDataset`` objects built from the
        selected training and test samples
    """
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute arbitrary code. Only point `path` at a trusted file.
    with np.load(path, allow_pickle=True) as fp:
        images, labels = fp["x_train"][:train_size], fp["y_train"][:train_size]
        test_images, test_labels = fp["x_test"][:test_size], fp["y_test"][:test_size]

    # Add the single channel axis. Using -1 lets NumPy infer the sample count,
    # so an archive with fewer samples than requested no longer breaks reshape.
    images = images.reshape((-1, 1, 28, 28))
    test_images = test_images.reshape((-1, 1, 28, 28))

    train_ds = DamlDataset(images, labels)
    test_ds = DamlDataset(test_images, test_labels)
    return train_ds, test_ds


# Define our network architecture
class Net(nn.Module):
    """
    Small convolutional classifier

    Two unpadded 5x5 convolutions feed three fully connected layers, the last
    of which emits 10 unnormalized class scores per sample.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor: 1 -> 6 -> 16 channels
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Fully connected head; 6400 is the flattened size of the conv2 output
        # for 28x28 inputs (16 channels * 20 * 20)
        self.fc1 = nn.Linear(in_features=6400, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Run a batch through the network and return the class scores"""
        # ReLU after each convolution
        for conv in (self.conv1, self.conv2):
            x = F.relu(conv(x))
        # Collapse everything but the batch dimension
        x = torch.flatten(x, start_dim=1)
        # ReLU after each hidden dense layer, none after the output layer
        for dense in (self.fc1, self.fc2):
            x = F.relu(dense(x))
        return self.fc3(x)


def custom_train(model: nn.Module, X: torch.Tensor, y: torch.Tensor, *, epochs: int = 5):
    """
    Train a model in place by repeatedly passing the full data through it
    with backpropagation

    Parameters
    ----------
    model : nn.Module
        The model to be trained; its parameters are updated in place
    X : torch.Tensor
        The training data to be passed through the model
    y : torch.Tensor
        The training labels corresponding to the data
    epochs : int
        Number of full-batch optimization steps to take (default 5)
    """
    # Defined only for this testing scenario
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    # custom_eval switches the model to evaluation mode, so switch back to
    # training mode here; otherwise mode-dependent layers (dropout, batch
    # norm) would silently keep their evaluation behavior while training.
    model.train()

    for _ in range(epochs):
        # Zero out gradients
        optimizer.zero_grad()
        # Forward Propagation
        outputs = model(X)
        # Back prop
        loss = criterion(outputs, y)
        loss.backward()
        # Update the model parameters
        optimizer.step()


def custom_eval(model: nn.Module, X: torch.Tensor, y: torch.Tensor) -> float:
    """
    Evaluate a model on a single pass with a given metric

    Parameters
    ----------
    model : nn.Module
        The trained model that will be evaluated
    X : torch.Tensor
        The testing data to be passed through the model
    y : torch.Tensor
        The testing labels corresponding to the data

    Returns
    -------
    float
        The calculated performance of the model (multiclass accuracy)
    """
    metric = torchmetrics.Accuracy(task="multiclass", num_classes=10)

    # Set model layers into evaluation mode
    model.eval()
    # Tell PyTorch to not track gradients, greatly speeds up processing
    with torch.no_grad():
        preds = model(X)
        result = metric(preds, y)
    # The metric call yields a tensor; convert it so the function returns
    # the plain Python float its signature promises
    return result.item()

## Instantiate the DAML sufficiency metric, and attach the custom training and evaluation functions.


In [None]:
# Build the training/test datasets and a freshly initialized network
train_ds, test_ds = load_dataset()
model = Net()
# Number of training samples; handed to suff.setup() in a later cell
length = len(train_ds)

# Instantiate sufficiency metric
suff = Sufficiency()
# Set predefined training and eval functions
suff.set_training_func(custom_train)
suff.set_eval_func(custom_eval)

## Define number of models to train in parallel (stability), as well as the number of steps along the learning curve to evaluate. Train models to produce learning curve.


In [None]:
# Create data indices for training
# m_count: number of models to train (more models -> more stable curve)
m_count = 10
# num_steps: number of points along the learning curve to evaluate
num_steps = 10
suff.setup(length, m_count, num_steps)
# Train & test model; `output` is passed to suff.plot() in the next cell
output = suff.run(model, train_ds, test_ds)

In [None]:
# Plot the results returned by suff.run()
suff.plot(output)

## Using this learning curve, we can project how the same model would perform when trained on much larger datasets.
