In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from collections.abc import Generator

import numpy as np
import torchvision
from tqdm.auto import tqdm

from nnx import autograd
from nnx.autograd.activations import ReLU, Softmax
from nnx.autograd.initialisation import xavier_uniform
from nnx.autograd.layers import Dropout, Linear, Reshape, Sequential
from nnx.autograd.loss import Tensor, cross_entropy_loss
from nnx.autograd.optimizer import AdamW

seed = 3

autograd.rng = np.random.default_rng(seed=seed)

  from .autonotebook import tqdm as notebook_tqdm


# Data loading
----
Since the focus is on the main framework we just use some dataset from torchvision. The decision was to use a dataset with small resolution to being able to quickly train a model.

In [3]:
training_data = torchvision.datasets.FashionMNIST(
    ".",
    train=True,
)  # for initial download set download=True
validation_data = torchvision.datasets.FashionMNIST(".", train=False)


def create_mnist_batch_loader(dataset, batch_size=64) -> Generator:
    """Create a batch loader for Fashion MNIST dataset.

    Args:
        dataset: The PyTorch dataset (e.g., FashionMNIST)
        batch_size: Number of samples per batch

    Yields:
        Generator that yields batches of (images, one-hot labels).

    """
    dataset_size = len(dataset)
    indices = np.arange(dataset_size)
    num_batches = int(dataset_size // batch_size)

    for batch_idx in range(num_batches):
        # Get indices for this batch
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, dataset_size)
        batch_indices = indices[start_idx:end_idx]

        # Initialize batch arrays
        actual_batch_size = len(batch_indices)
        batch_images = np.zeros((actual_batch_size, 28, 28, 1), dtype=np.float32)
        batch_labels = np.zeros((actual_batch_size, 10), dtype=np.float32)

        for i, idx in enumerate(batch_indices):
            image, label = dataset[idx]

            # Convert PIL image to numpy array, normalize to [0,1]
            img_array = np.array(image, dtype=np.float32)[:, :, None] / 255.0

            batch_images[i] = img_array
            batch_labels[i, label] = 1.0  # One-hot encoding

        images_tensor = Tensor(batch_images)
        labels_tensor = Tensor(batch_labels)

        yield images_tensor, labels_tensor

# Model definition
--------
Definition of the network and optimizer. While one could easily go into tuning parameters to get an optimal performance, the main goal is to show that we can actually
train a model using the framework and combination with some optimizer such as AdamW. Using `Conv2D` was also tried for some epochs but its current implementation does not make it feasible when it comes to time to be trained until the end. An MLP is however sufficient for Demo purposes.

In [None]:
num_classes = 10
batch_size = 512
dim = 784
epochs = 25

network = Sequential(
    Reshape((batch_size, dim)),
    Linear(dim, dim * 2, initialiser=xavier_uniform, bias=True),
    Dropout(0.3),
    ReLU(),
    Linear(dim * 2, dim * 4, initialiser=xavier_uniform, bias=True),
    Dropout(0.3),
    ReLU(),
    Linear(dim * 4, dim * 4, initialiser=xavier_uniform, bias=True),
    Dropout(0.3),
    ReLU(),
    Linear(dim * 4, dim * 2, initialiser=xavier_uniform, bias=True),
    Dropout(0.3),
    ReLU(),
    Linear(dim * 2, 784, initialiser=xavier_uniform, bias=True),
    Dropout(0.3),
    ReLU(),
    Linear(784, num_classes, initialiser=xavier_uniform, bias=True),
    Dropout(0.3),
    Softmax(),
)

optimizer = AdamW(network.parameters, lr=1e-3, weight_decay=1e-4)

# Plotting
------
We add some plots related to the training of the network such as the gradients and evalutation by looking at the validation accuracy over time and looking at some predictions. Since this has nothing to do with the 

In [5]:
from plots import plot_gradient_flow, plot_model_predictions, plot_validation_accuracy


def extract_gradient_stats(model: Sequential, gradient_history: list) -> None:
    """Extract gradients for visualisation purposes."""
    grad_stats = {}
    for i, layer in enumerate(model.layers):
        if hasattr(layer, "parameters"):
            for j, param in enumerate(layer.parameters):
                if param.grad is not None:
                    layer_name = f"Layer {i}-Param {j}"
                    grad_stats[layer_name] = np.mean(np.abs(param.grad))
    gradient_history.append(grad_stats)


# Training

In [None]:
losses = []
accuracies = []
gradient_history = []

for epoch_idx in range(epochs):
    loss_per_epoch = []
    batch_pbar = tqdm(
        create_mnist_batch_loader(training_data, batch_size=batch_size),
        desc="Starting training...",
    )

    for idx, [images, targets] in enumerate(batch_pbar):
        output = network(images)
        loss = cross_entropy_loss(output, targets)

        batch_pbar.set_description(f"Epoch {epoch_idx + 1}/{epochs} | Batch {idx}")
        batch_pbar.set_postfix({"loss": f"{loss.data:.4f}"})
        loss_per_epoch.append(loss.data)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        extract_gradient_stats(network, gradient_history)

    losses.append(loss_per_epoch)

    # Valiation logic
    val_pbar = tqdm(
        create_mnist_batch_loader(validation_data, batch_size=batch_size),
        desc="Starting validation pass...",
    )

    tmp = []
    for idx, [images, targets] in enumerate(val_pbar):
        output = network(images)

        predicted_classes = np.argmax(output.data, axis=1)
        target_classes = np.argmax(targets.data, axis=1)
        result = predicted_classes == target_classes

        tmp.extend(result.tolist())
        batch_accuracy = result.mean()

        val_pbar.set_description(f"Iteration {idx}")
        batch_pbar.set_postfix({"accuracy": f"{batch_accuracy}"})
    accuracies.append(tmp)

Epoch 1/25 | Batch 46: : 46it [01:49,  2.48s/it, loss=1.2499]

In [None]:
fig_gradients = plot_gradient_flow(gradient_history)
fig_gradients.savefig("gradient_flow.png", dpi=300, bbox_inches="tight")

fig = plot_validation_accuracy(
    [sum(acc) / len(acc) for acc in accuracies],
    [sum(loss) / len(loss) for loss in losses],
)
fig.savefig("accuracy_plot.png", dpi=300, bbox_inches="tight")

val_images, val_labels = next(
    iter(create_mnist_batch_loader(validation_data, batch_size=batch_size)),
)
predictions = network(val_images)
class_names = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

fig = plot_model_predictions(
    val_images.data,
    val_labels.data,
    predictions.data,
    class_names,
    num_samples=48,
)
fig.savefig("prediction_showcase.png", dpi=300, bbox_inches="tight")