In [None]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# re-imported before each cell runs, so edits to local source files take
# effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from collections.abc import Generator

import numpy as np
import torchvision
from tqdm.auto import tqdm

from nnx import tinygrad
from nnx.tinygrad.activations import ReLU, Softmax
from nnx.tinygrad.initialisation import xavier_uniform
from nnx.tinygrad.layers import Dropout, LayerNorm, Linear, Reshape, Sequential
from nnx.tinygrad.loss import Tensor, cross_entropy_loss
from nnx.tinygrad.optimizer import AdamW

# Fixed seed for the notebook.
seed = 3

# Replace the framework's module-level generator with a seeded NumPy one.
# NOTE(review): presumably this generator drives weight initialisation and
# dropout inside nnx.tinygrad — confirm in the library.
tinygrad.rng = np.random.default_rng(seed)

# Data loading
----
This notebook demonstrates the TinyGrad framework using the Fashion MNIST dataset. While the focus is on showing the neural network architecture implementation, the torchvision dataset loader is used for convenience. Fashion MNIST serves as a good example: its small input resolution keeps training time manageable for a CPU-only implementation, while it still has 10 distinct, non-trivial classes that need to be classified.

In [None]:
# Both splits are read from the current directory.
# For the initial download pass download=True.
training_data = torchvision.datasets.FashionMNIST(root=".", train=True)
validation_data = torchvision.datasets.FashionMNIST(root=".", train=False)


def create_mnist_batch_loader(
    dataset,
    batch_size=64,
    *,
    drop_last=True,
    num_classes=10,
) -> Generator:
    """Create a batch loader for Fashion MNIST dataset.

    Samples are read in dataset order (no shuffling). Each image is converted
    to a float32 array of shape (28, 28, 1) scaled to [0, 1], and each label is
    one-hot encoded.

    Args:
        dataset: Indexable dataset whose items are ``(image, label)`` pairs and
            which supports ``len()`` (e.g. torchvision's FashionMNIST).
        batch_size: Number of samples per batch. Must be positive.
        drop_last: If True (default), a trailing batch with fewer than
            ``batch_size`` samples is skipped, so every yielded batch has
            exactly ``batch_size`` samples. If False, the trailing partial
            batch is yielded as well.
            NOTE(review): the network in this notebook reshapes to a fixed
            ``(batch_size, 784)``, so it presumably cannot consume a partial
            batch — keep the default unless the model is changed.
        num_classes: Width of the one-hot label vectors.

    Yields:
        Tuples ``(images, labels)`` of ``Tensor`` objects wrapping arrays of
        shape ``(n, 28, 28, 1)`` and ``(n, num_classes)``.

    Raises:
        ValueError: If ``batch_size`` is not positive. Because this is a
            generator, the error is raised on first iteration, not at call.

    """
    if batch_size <= 0:
        msg = f"batch_size must be a positive integer, got {batch_size}"
        raise ValueError(msg)

    dataset_size = len(dataset)
    indices = np.arange(dataset_size)

    for start_idx in range(0, dataset_size, batch_size):
        # Slicing clamps at the end of the array, so only the final batch can
        # come out shorter than batch_size.
        batch_indices = indices[start_idx : start_idx + batch_size]
        actual_batch_size = len(batch_indices)

        if drop_last and actual_batch_size < batch_size:
            break

        # Initialize batch arrays
        batch_images = np.zeros((actual_batch_size, 28, 28, 1), dtype=np.float32)
        batch_labels = np.zeros((actual_batch_size, num_classes), dtype=np.float32)

        for i, idx in enumerate(batch_indices):
            image, label = dataset[idx]

            # Convert PIL image to numpy array, add a channel axis, normalize to [0,1]
            batch_images[i] = np.array(image, dtype=np.float32)[:, :, None] / 255.0
            batch_labels[i, label] = 1.0  # One-hot encoding

        yield Tensor(batch_images), Tensor(batch_labels)

# Model definition
--------
Below is the definition of the neural network architecture and optimizer. The model implementation demonstrates TinyGrad's PyTorch-inspired API design. A multi-layer perceptron (MLP) architecture is used rather than convolutional layers in the initial layers for feature extraction, as the current `Conv2D` implementation prioritizes educational clarity over performance optimization.

The network architecture includes several fully-connected layers with dropout regularization and ReLU activations, concluding with a softmax layer for classification. Batch size and number of epochs were chosen by rule of thumb and are not the result of a grid/random search for optimal parameters.

In [None]:
num_classes = 10
batch_size = 512
epochs = 10

# (in_features, out_features) of each hidden block, in forward order.
hidden_dims = [(784, 784), (784, 512), (512, 256), (256, 64)]

# Flatten each batch first; the target shape is fixed to (batch_size, 784).
layers = [Reshape((batch_size, 784))]

# Every hidden block is Linear -> LayerNorm -> Dropout -> ReLU.
for in_features, out_features in hidden_dims:
    layers += [
        Linear(in_features, out_features, initialiser=xavier_uniform, bias=True),
        LayerNorm(out_features),
        Dropout(0.1),
        ReLU(),
    ]

# Classification head: project to the class count, normalise, then softmax.
layers += [
    Linear(64, num_classes, initialiser=xavier_uniform, bias=True),
    LayerNorm(num_classes),
    Softmax(),
]

network = Sequential(*layers)

optimizer = AdamW(network.parameters, lr=1e-3, weight_decay=1e-2)


# Visualization
---------------
To monitor training progress, several visualization functions were implemented. These help us track:
- Gradient flow through the network
- Validation accuracy over time
- Model predictions on sample data

While these visualization utilities are not core to the TinyGrad implementation, they provide valuable insights into model training dynamics and performance.


In [None]:
from plots import plot_gradient_flow, plot_model_predictions, plot_validation_accuracy


def extract_gradient_stats(model: Sequential, gradient_history: list) -> None:
    """Append a snapshot of per-parameter gradient magnitudes to a history list.

    For every layer of ``model`` that exposes a ``parameters`` attribute, the
    mean absolute value of each parameter's gradient is stored under the key
    ``"Layer <layer index>-Param <parameter index>"``. Parameters whose
    ``grad`` is ``None`` are left out. The resulting dict is appended to
    ``gradient_history`` in place; nothing is returned.
    """
    snapshot = {
        f"Layer {layer_pos}-Param {param_pos}": np.mean(np.abs(param.grad))
        for layer_pos, layer in enumerate(model.layers)
        if hasattr(layer, "parameters")
        for param_pos, param in enumerate(layer.parameters)
        if param.grad is not None
    }
    gradient_history.append(snapshot)


# Training
---------
In this section, we train our model on the Fashion MNIST dataset, tracking metrics throughout the process. The training loop demonstrates TinyGrad's end-to-end capabilities, from forward and backward passes to optimization and evaluation.


In [None]:
# Per-epoch records collected for the plots further down:
#   losses[e]        -> list of per-batch training losses of epoch e
#   accuracies[e]    -> list of per-sample correctness flags on the validation set
#   gradient_history -> one dict of mean |grad| per parameter for every training batch
losses = []
accuracies = []
gradient_history = []

for epoch_idx in range(epochs):
    # ---- Training pass ----
    loss_per_epoch = []
    batch_pbar = tqdm(
        create_mnist_batch_loader(training_data, batch_size=batch_size),
        desc="Starting training...",
    )

    for idx, [images, targets] in enumerate(batch_pbar):
        output = network(images)
        loss = cross_entropy_loss(output, targets)

        batch_pbar.set_description(f"Epoch {epoch_idx + 1}/{epochs} | Batch {idx}")
        batch_pbar.set_postfix({"loss": f"{loss.data:.4f}"})
        loss_per_epoch.append(loss.data)

        # Clear stale gradients, backpropagate, then apply the update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Gradients from backward() are still attached to the parameters here
        # (zero_grad only runs at the start of the next iteration).
        extract_gradient_stats(network, gradient_history)

    losses.append(loss_per_epoch)

    # ---- Validation pass ----
    # NOTE(review): network.eval() is only called in the cell after this loop,
    # so these validation passes appear to run with the network still in
    # training mode (Dropout presumably active). Confirm, and switch modes
    # around this pass if the library offers a way back to training mode.
    val_pbar = tqdm(
        create_mnist_batch_loader(validation_data, batch_size=batch_size),
        desc="Starting validation pass...",
    )

    epoch_results = []
    for idx, [images, targets] in enumerate(val_pbar):
        output = network(images)

        # Compare the arg-max class of the prediction with the one-hot target.
        predicted_classes = np.argmax(output.data, axis=1)
        target_classes = np.argmax(targets.data, axis=1)
        result = predicted_classes == target_classes

        epoch_results.extend(result.tolist())
        batch_accuracy = result.mean()

        val_pbar.set_description(f"Iteration {idx}")
        # Report on the validation bar; the training bar has already finished.
        val_pbar.set_postfix({"accuracy": f"{batch_accuracy}"})
    accuracies.append(epoch_results)

In [None]:
# Put the network into evaluation mode before producing the final figures.
# NOTE(review): presumably this disables Dropout — confirm in nnx.tinygrad.
network.eval()

# Keyword arguments shared by every exported figure.
savefig_options = {"dpi": 300, "bbox_inches": "tight"}

# Gradient magnitudes recorded during training.
fig_gradients = plot_gradient_flow(gradient_history)
fig_gradients.savefig("images/gradient_flow.png", **savefig_options)

# Collapse the per-epoch records into one mean value per epoch.
mean_accuracy_per_epoch = [sum(epoch_hits) / len(epoch_hits) for epoch_hits in accuracies]
mean_loss_per_epoch = [sum(epoch_losses) / len(epoch_losses) for epoch_losses in losses]

fig = plot_validation_accuracy(mean_accuracy_per_epoch, mean_loss_per_epoch)
fig.savefig("images/accuracy_plot.png", **savefig_options)

# Label names, listed by class index.
class_names = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

# Only the first validation batch is used for the prediction showcase.
validation_batches = create_mnist_batch_loader(validation_data, batch_size=batch_size)
val_images, val_labels = next(validation_batches)
predictions = network(val_images)

fig = plot_model_predictions(
    val_images.data,
    val_labels.data,
    predictions.data,
    class_names,
    num_samples=36,
)
fig.savefig("images/prediction_showcase.png", **savefig_options)