In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from collections.abc import Generator

import numpy as np
import torchvision
from tqdm.auto import tqdm

from nnx import autograd
from nnx.autograd.activations import ReLU, Softmax
from nnx.autograd.initialisation import xavier_uniform
from nnx.autograd.layers import Linear, Reshape, Sequential, Dropout
from nnx.autograd.loss import Tensor, cross_entropy_loss
from nnx.autograd.optimizer import SGD

# Seed nnx's module-level random generator once, up front, so that a fresh
# Restart & Run All reproduces the same run. NOTE(review): presumably
# initialisers such as xavier_uniform and the Dropout layers draw from
# `autograd.rng` — TODO confirm in nnx.
seed = 3
autograd.rng = np.random.default_rng(seed)

In [None]:
num_classes = 10
batch_size = 512
dim = 784  # 28 * 28 pixels per flattened Fashion-MNIST image
epochs = 10
dropout_p = 0.3  # dropout probability applied after each hidden Linear layer

# The Reshape bakes in `batch_size`, so the batch loader used below must only
# ever yield full batches of exactly that size.
network = Sequential(
    Reshape((batch_size, dim)),
    Linear(dim, dim * 2, initialiser=xavier_uniform, bias=True),
    Dropout(dropout_p),
    ReLU(),
    Linear(dim * 2, dim * 4, initialiser=xavier_uniform, bias=True),
    Dropout(dropout_p),
    ReLU(),
    Linear(dim * 4, dim * 2, initialiser=xavier_uniform, bias=True),
    Dropout(dropout_p),
    ReLU(),
    Linear(dim * 2, dim, initialiser=xavier_uniform, bias=True),
    Dropout(dropout_p),
    ReLU(),
    # No dropout on the output layer: randomly zeroing (and rescaling) class
    # logits right before the softmax corrupts the predicted distribution.
    Linear(dim, num_classes, initialiser=xavier_uniform, bias=True),
    Softmax(),
)

In [None]:
# download=True is safe to leave on: torchvision skips the download when the
# dataset files already exist under the root directory, and keeping it on lets
# a fresh checkout run top to bottom (Restart & Run All).
training_data = torchvision.datasets.FashionMNIST(".", train=True, download=True)
validation_data = torchvision.datasets.FashionMNIST(".", train=False, download=True)


def create_mnist_batch_loader(
    dataset,
    batch_size=64,
    *,
    shuffle=False,
    rng=None,
    num_classes=10,
) -> Generator:
    """Create a batch loader for the Fashion-MNIST dataset.

    Only full batches are produced. A trailing remainder of fewer than
    ``batch_size`` samples is dropped, which keeps every batch compatible with
    a model that reshapes to a fixed batch size.

    Args:
        dataset: Indexable, sized dataset whose items are ``(image, label)``
            pairs, e.g. torchvision's FashionMNIST. ``image`` must convert to a
            28x28 array via ``np.array`` and ``label`` must be an integer class
            index in ``[0, num_classes)``.
        batch_size: Number of samples per batch. Must be positive.
        shuffle: If True, visit the samples in a random order. The default
            (False) preserves the original fixed dataset order.
        rng: ``numpy.random.Generator`` used for shuffling. If None, a fresh
            unseeded ``np.random.default_rng()`` is used. Pass a seeded
            generator for reproducible epochs.
        num_classes: Length of each one-hot label vector.

    Yields:
        ``(images, labels)`` Tensor pairs:
        - ``images`` has shape ``(batch_size, 28, 28, 1)``, float32 values in
          ``[0, 1]``.
        - ``labels`` has shape ``(batch_size, num_classes)``, one-hot float32.

    Raises:
        ValueError: If ``batch_size`` is not positive. As this is a generator,
            the error surfaces on the first iteration.

    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    dataset_size = len(dataset)
    indices = np.arange(dataset_size)
    if shuffle:
        (rng if rng is not None else np.random.default_rng()).shuffle(indices)

    # Floor division drops the trailing partial batch, so every slice below
    # holds exactly `batch_size` indices.
    num_batches = dataset_size // batch_size

    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        batch_indices = indices[start_idx : start_idx + batch_size]

        batch_images = np.zeros((batch_size, 28, 28, 1), dtype=np.float32)
        batch_labels = np.zeros((batch_size, num_classes), dtype=np.float32)

        for i, idx in enumerate(batch_indices):
            image, label = dataset[int(idx)]

            # Convert the PIL image to a float array in [0, 1] and store it in
            # the single trailing channel.
            batch_images[i, :, :, 0] = np.array(image, dtype=np.float32) / 255.0
            batch_labels[i, label] = 1.0  # one-hot encoding

        # NOTE(review): images keep requires_grad=True as before. The training
        # code never reads input gradients, so this may be unnecessary work —
        # confirm nnx does not rely on it before turning it off.
        images_tensor = Tensor(batch_images, requires_grad=True)
        labels_tensor = Tensor(batch_labels, requires_grad=False)

        yield images_tensor, labels_tensor

In [None]:
# Note: re-running this cell continues training the existing `network` (defined
# in an earlier cell) with a freshly created optimizer.
optimizer = SGD(network.parameters, lr=0.5, clip_value=1.0)
losses = []  # training loss for every batch, across all epochs
accuracies = []  # mean validation accuracy, one float per epoch

for epoch_idx in range(epochs):
    # --- training pass ---
    batch_pbar = tqdm(
        create_mnist_batch_loader(training_data, batch_size=batch_size),
        desc="Starting training...",
    )

    for idx, (images, targets) in enumerate(batch_pbar):
        output = network(images)
        loss = cross_entropy_loss(output, targets)

        batch_pbar.set_description(f"Epoch {epoch_idx + 1}/{epochs} | Batch {idx}")
        batch_pbar.set_postfix({"loss": f"{loss.data:.4f}"})
        losses.append(loss.data)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # --- validation pass ---
    # NOTE(review): the Dropout layers are presumably still active here, which
    # makes validation accuracy noisy. Switch `network` to an evaluation mode
    # if nnx provides one — TODO confirm the API.
    val_pbar = tqdm(
        create_mnist_batch_loader(validation_data, batch_size=batch_size),
        desc="Starting validation pass...",
    )

    batch_accuracies = []
    for idx, (images, targets) in enumerate(val_pbar):
        output = network(images)

        # Score every sample in the batch, not just the first one.
        predicted = np.argmax(output.data, axis=-1)
        expected = np.argmax(targets.data, axis=-1)
        batch_accuracy = float(np.mean(predicted == expected))
        batch_accuracies.append(batch_accuracy)

        val_pbar.set_description(f"Iteration {idx}")
        val_pbar.set_postfix({"accuracy": f"{batch_accuracy:.4f}"})

    # The loader only yields full batches, so the mean of the per-batch
    # accuracies equals the per-sample accuracy over the validation set.
    accuracies.append(
        float(np.mean(batch_accuracies)) if batch_accuracies else float("nan")
    )

In [None]:
import matplotlib.pyplot as plt

# Reduce `accuracies` to one value per epoch. If it holds per-batch results for
# each epoch (2-D), average them per epoch.
epoch_accuracy = np.asarray(accuracies, dtype=float)
if epoch_accuracy.ndim > 1:
    epoch_accuracy = epoch_accuracy.mean(axis=1)
epoch_numbers = np.arange(1, len(epoch_accuracy) + 1)

fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

loss_ax.plot(np.asarray(losses, dtype=float).ravel())
loss_ax.set(title="Training loss", xlabel="Batch", ylabel="Cross-entropy loss")

acc_ax.plot(epoch_numbers, epoch_accuracy, marker="o")
acc_ax.set(title="Validation accuracy", xlabel="Epoch", ylabel="Accuracy")

plt.show()