In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from collections.abc import Generator

import numpy as np
import torchvision
from tqdm.auto import tqdm

from nnx import tinygrad
from nnx.tinygrad.activations import ReLU, Softmax
from nnx.tinygrad.initialisation import xavier_uniform
from nnx.tinygrad.layers import Dropout, Linear, Reshape, Sequential
from nnx.tinygrad.loss import Tensor, cross_entropy_loss
from nnx.tinygrad.optimizer import AdamW

# Fixed seed so that repeated runs of the notebook draw the same random numbers.
seed = 3

# Install a seeded NumPy generator as tinygrad's module-level `rng`
# (presumably consumed by the initialisers and Dropout — confirm in nnx.tinygrad).
tinygrad.rng = np.random.default_rng(seed=seed)

In [None]:
"""Layer normalization implementation."""

import numpy as np

from nnx.tinygrad.layers import Layer
from nnx.tinygrad.tensor import Tensor


class LayerNorm(Layer):
    """Layer normalization over the trailing (feature) axis.

    Every feature vector is standardised to zero mean and unit variance and
    then passed through a learnable element-wise affine map
    ``gamma * x_hat + beta``.
    """

    def __init__(self, normalized_shape: int, eps: float = 1e-5) -> None:
        """Create the layer together with its learnable scale and shift.

        Args:
            normalized_shape: Length of the last input axis, i.e. the number
                of features that are normalised together.
            eps: Constant added to the variance before the square root is
                taken, for numerical stability.
        """
        super().__init__()
        self._normalized_shape = normalized_shape
        self._eps = eps

        # gamma starts as an all-ones scale, beta as an all-zeros shift.
        self._gamma = Tensor(np.ones(normalized_shape), requires_grad=True)
        self._beta = Tensor(np.zeros(normalized_shape), requires_grad=True)

        # Register both tensors so they're available for optimization.
        for parameter in (self._gamma, self._beta):
            self._parameters.append(parameter)

    def forward(self, inputs: Tensor) -> Tensor:
        """Normalise ``inputs`` along the last axis, then scale and shift.

        Args:
            inputs: Tensor whose last axis has length ``normalized_shape``.

        Returns:
            A tensor of the same shape holding ``gamma * x_hat + beta``. When
            the result requires gradients, a backward closure is registered
            on it that accumulates gradients into gamma, beta and ``inputs``.

        Raises:
            ValueError: If the last axis of ``inputs`` does not have length
                ``normalized_shape``.
        """
        if inputs.data.shape[-1] != self._normalized_shape:
            msg = (
                f"Last dimension of inputs {inputs.data.shape[-1]} "
                f"doesn't match normalized_shape {self._normalized_shape}"
            )
            raise ValueError(msg)

        # Per-vector statistics; keepdims lets them broadcast against inputs.
        mean = np.mean(inputs.data, axis=-1, keepdims=True)
        var = np.var(inputs.data, axis=-1, keepdims=True)
        x_norm = (inputs.data - mean) / np.sqrt(var + self._eps)

        needs_grad = (
            inputs.requires_grad
            or self._gamma.requires_grad
            or self._beta.requires_grad
        )
        outputs = Tensor(
            self._gamma.data * x_norm + self._beta.data,
            requires_grad=needs_grad,
        )

        if not outputs.requires_grad:
            return outputs

        # Hook the result into the computational graph.
        outputs.prev = {inputs, self._gamma, self._beta}

        def _backward() -> None:
            upstream = outputs.grad
            if upstream is None:
                return

            def _add_grad(tensor: Tensor, grad: np.ndarray) -> None:
                # Accumulate rather than overwrite an existing gradient.
                tensor.grad = grad if tensor.grad is None else tensor.grad + grad

            inv_std = 1.0 / np.sqrt(var + self._eps)
            # gamma/beta are shared by every leading position, so their
            # gradients are summed over all axes except the feature axis.
            leading_axes = tuple(range(upstream.ndim - 1))

            if self._gamma.requires_grad:
                _add_grad(self._gamma, np.sum(upstream * x_norm, axis=leading_axes))

            if self._beta.requires_grad:
                _add_grad(self._beta, np.sum(upstream, axis=leading_axes))

            if not inputs.requires_grad:
                return

            # Chain rule through x_hat = (x - mean) / sqrt(var + eps), where
            # mean and var themselves depend on every feature of x.
            n_features = inputs.data.shape[-1]
            centered = inputs.data - mean
            dx_norm = upstream * self._gamma.data

            dvar = (
                -0.5
                * np.sum(dx_norm * centered, axis=-1, keepdims=True)
                * inv_std**3
            )
            dmean = -np.sum(dx_norm * inv_std, axis=-1, keepdims=True) + dvar * (
                -2.0 / n_features
            ) * np.sum(centered, axis=-1, keepdims=True)

            dx = (
                dx_norm * inv_std
                + dvar * (2.0 / n_features) * centered
                + dmean * (1.0 / n_features)
            )
            _add_grad(inputs, dx)

        outputs.register_backward(_backward)
        return outputs

# Data loading
----
This notebook demonstrates the TinyGrad framework using the Fashion MNIST dataset. While the focus is on showing the neural network architecture implementation, the torchvision dataset loader is used for convenience. Fashion MNIST serves as a good example: its small input resolution keeps training time manageable for a CPU-only implementation, while it still has 10 distinct, non-trivial classes to classify.

In [None]:
# Fashion MNIST training split, read from the current directory. The
# `download` argument is left at its default here.
training_data = torchvision.datasets.FashionMNIST(
    ".",
    train=True,
)  # for initial download set download=True
# Split selected by train=False; this notebook uses it as its validation set.
validation_data = torchvision.datasets.FashionMNIST(".", train=False)


def create_mnist_batch_loader(dataset, batch_size=64, *, drop_last=True) -> Generator:
    """Create a batch loader for Fashion MNIST dataset.

    Samples are read sequentially (no shuffling) and converted to float32
    arrays on the fly.

    Args:
        dataset: Sized, indexable dataset (e.g. torchvision FashionMNIST)
            whose items are ``(image, label)`` pairs; each image must convert
            to a 28 x 28 array and each label must be an integer in [0, 10).
        batch_size: Number of samples per batch. Must be positive.
        drop_last: If True (default, the historical behaviour), a trailing
            batch with fewer than ``batch_size`` samples is skipped so every
            batch has the same size. If False, that smaller batch is yielded
            as well.

    Yields:
        Tuples ``(images, labels)`` of Tensors: images with shape
        ``(B, 28, 28, 1)`` scaled to [0, 1], and one-hot labels with shape
        ``(B, 10)``.

    Raises:
        ValueError: If ``batch_size`` is not positive. Because this is a
            generator, the error surfaces on the first iteration.
    """
    if batch_size <= 0:
        msg = f"batch_size must be a positive integer, got {batch_size}"
        raise ValueError(msg)

    dataset_size = len(dataset)
    num_batches = dataset_size // batch_size
    if not drop_last and dataset_size % batch_size:
        # Account for the incomplete batch at the end of the dataset.
        num_batches += 1

    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        # Only shorter than batch_size for the final batch with drop_last=False.
        end_idx = min(start_idx + batch_size, dataset_size)
        actual_batch_size = end_idx - start_idx

        batch_images = np.zeros((actual_batch_size, 28, 28, 1), dtype=np.float32)
        batch_labels = np.zeros((actual_batch_size, 10), dtype=np.float32)

        for row, sample_idx in enumerate(range(start_idx, end_idx)):
            image, label = dataset[sample_idx]

            # Convert PIL image to numpy array, add a channel axis and
            # normalize to [0, 1].
            batch_images[row] = np.array(image, dtype=np.float32)[:, :, None] / 255.0
            batch_labels[row, label] = 1.0  # One-hot encoding

        yield Tensor(batch_images), Tensor(batch_labels)

# Model definition
--------
Below is the definition of the neural network architecture and optimizer. The model implementation demonstrates TinyGrad's PyTorch-inspired API design. A multi-layer perceptron (MLP) architecture is used rather than convolutional layers in the initial layers for feature extraction, as the current `Conv2D` implementation prioritizes educational clarity over performance optimization.

The network architecture includes several fully-connected layers with dropout regularization and ReLU activations, concluding with a softmax layer for classification. Batch size and number of epochs were chosen by rule of thumb and are not the product of a grid/random search for optimal parameters.

In [None]:
# Hyperparameters shared by the model definition and the training loop.
num_classes = 10  # width of the final Linear layer / one-hot label length
batch_size = 256
dim = 784  # 28 * 28 * 1 values per image as produced by the batch loader
epochs = 25

# MLP: three hidden Linear -> LayerNorm -> Dropout -> ReLU stages followed by
# a classification head that ends in Softmax.
network = Sequential(
    # NOTE(review): the target shape hard-codes batch_size, so this presumably
    # requires every batch to hold exactly batch_size samples — confirm.
    Reshape((batch_size, dim)),
    Linear(dim, dim * 4, initialiser=xavier_uniform, bias=True),
    LayerNorm(dim * 4),
    Dropout(0.3),
    ReLU(),
    Linear(dim * 4, dim * 4, initialiser=xavier_uniform, bias=True),
    LayerNorm(dim * 4),
    Dropout(0.3),
    ReLU(),
    # Optional extra hidden stage, currently disabled:
    # Linear(dim * 4, dim * 4, initialiser=xavier_uniform, bias=True),
    # LayerNorm(dim * 4),
    # Dropout(0.3),
    # ReLU(),
    Linear(dim * 4, dim, initialiser=xavier_uniform, bias=True),
    LayerNorm(dim),
    Dropout(0.3),
    ReLU(),
    Linear(dim, num_classes, initialiser=xavier_uniform, bias=True),
    # NOTE(review): LayerNorm and Dropout are applied to the class scores
    # right before Softmax — confirm this is intended.
    LayerNorm(num_classes),
    Dropout(0.3),
    Softmax(),
)

# NOTE: this AdamW instance is replaced by the SGD optimizer assigned below,
# so SGD is the optimizer the training loop actually uses.
optimizer = AdamW(network.parameters, lr=1e-3, weight_decay=1e-6)

from nnx.tinygrad.optimizer import SGD

# Rebinds `optimizer`; the AdamW object created above is discarded.
optimizer = SGD(network.parameters, lr=1e-1, momentum=0.99, weight_decay=1e-4)


# Visualization
---------------
To monitor training progress, several visualization functions were implemented. These help us track:
- Gradient flow through the network
- Validation accuracy over time
- Model predictions on sample data

While these visualization utilities are not core to the TinyGrad implementation, they provide valuable insights into model training dynamics and performance.


In [None]:
from plots import plot_gradient_flow, plot_model_predictions, plot_validation_accuracy


def extract_gradient_stats(model: Sequential, gradient_history: list) -> None:
    """Record the mean absolute gradient of every parameter for plotting.

    Builds one dictionary keyed by ``"Layer <i>-Param <j>"`` and appends it to
    ``gradient_history`` (which is modified in place). Layers without a
    ``parameters`` attribute and parameters whose ``grad`` is ``None`` are
    left out.

    Args:
        model: Network whose ``layers`` are inspected.
        gradient_history: List that receives the new snapshot.
    """
    snapshot = {}
    for layer_idx, layer in enumerate(model.layers):
        if not hasattr(layer, "parameters"):
            continue
        for param_idx, param in enumerate(layer.parameters):
            if param.grad is None:
                continue
            key = f"Layer {layer_idx}-Param {param_idx}"
            snapshot[key] = np.mean(np.abs(param.grad))
    gradient_history.append(snapshot)


# Training
---------
In this section, we train our model on the Fashion MNIST dataset, tracking metrics throughout the process. The training loop demonstrates TinyGrad's end-to-end capabilities, from forward and backward passes to optimization and evaluation.


In [None]:
# Per-epoch bookkeeping, consumed by the plotting cell:
#   losses[e]        -> list of batch losses of epoch e
#   accuracies[e]    -> list of per-sample correctness flags of epoch e
#   gradient_history -> one gradient snapshot per training batch
losses = []
accuracies = []
gradient_history = []

for epoch_idx in range(epochs):
    loss_per_epoch = []
    batch_pbar = tqdm(
        create_mnist_batch_loader(training_data, batch_size=batch_size),
        desc="Starting training...",
    )

    for idx, [images, targets] in enumerate(batch_pbar):
        output = network(images)
        loss = cross_entropy_loss(output, targets)

        batch_pbar.set_description(f"Epoch {epoch_idx + 1}/{epochs} | Batch {idx}")
        batch_pbar.set_postfix({"loss": f"{loss.data:.4f}"})
        loss_per_epoch.append(loss.data)

        # Clear stale gradients, backpropagate, then update the parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Gradients are still attached to the parameters after the step.
        extract_gradient_stats(network, gradient_history)

    losses.append(loss_per_epoch)

    # Validation logic
    # NOTE(review): the network is invoked exactly as during training; if
    # Dropout offers an inference mode it is not enabled here — confirm.
    val_pbar = tqdm(
        create_mnist_batch_loader(validation_data, batch_size=batch_size),
        desc="Starting validation pass...",
    )

    tmp = []
    for idx, [images, targets] in enumerate(val_pbar):
        output = network(images)

        # Compare the most probable class with the one-hot target.
        predicted_classes = np.argmax(output.data, axis=1)
        target_classes = np.argmax(targets.data, axis=1)
        result = predicted_classes == target_classes

        tmp.extend(result.tolist())
        batch_accuracy = result.mean()

        val_pbar.set_description(f"Iteration {idx}")
        # Report on the validation bar (previously written to the already
        # finished training bar, so the accuracy never showed up here).
        val_pbar.set_postfix({"accuracy": f"{batch_accuracy}"})
    accuracies.append(tmp)

In [None]:
# Gradient magnitudes collected after every training batch.
# NOTE: all savefig calls below write into an existing "images/" directory.
fig_gradients = plot_gradient_flow(gradient_history)
fig_gradients.savefig("images/gradient_flow.png", dpi=300, bbox_inches="tight")

# Average each epoch's inner list to get one accuracy / loss value per epoch.
# NOTE(review): the [:5] slices plot only the first five epochs although
# `epochs` is 25 — confirm this is intended.
fig = plot_validation_accuracy(
    [sum(acc) / len(acc) for acc in accuracies[:5]],
    [sum(loss) / len(loss) for loss in losses[:5]],
)
fig.savefig("images/accuracy_plot.png", dpi=300, bbox_inches="tight")

# Take the first validation batch and run it through the trained network.
val_images, val_labels = next(
    iter(create_mnist_batch_loader(validation_data, batch_size=batch_size)),
)
predictions = network(val_images)
# Human-readable names, indexed by class id (Fashion MNIST label order).
class_names = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

# Showcase of predictions versus labels for 48 samples of that batch.
fig = plot_model_predictions(
    val_images.data,
    val_labels.data,
    predictions.data,
    class_names,
    num_samples=48,
)
fig.savefig("images/prediction_showcase.png", dpi=300, bbox_inches="tight")