In [None]:
# Re-import edited modules (e.g. the local nnx package) automatically before
# each cell executes, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from collections.abc import Generator

import numpy as np
import torchvision
from tqdm import tqdm

from nnx import autograd
from nnx.autograd.activations import ReLU, Softmax
from nnx.autograd.initialisation import xavier_uniform
from nnx.autograd.layers import Conv2D, Linear, Reshape, Sequential
from nnx.autograd.loss import Tensor, cross_entropy_loss
from nnx.autograd.optimizer import SGD

# Reproducibility: swap nnx's module-level RNG for a seeded generator.
# NOTE(review): presumably the initialisers (e.g. xavier_uniform) draw from
# autograd.rng — confirm in nnx.autograd.
seed = 3
autograd.rng = np.random.default_rng(seed)

In [None]:
num_classes = 10

# Hidden widths of the MLP: 784 -> 1568 -> 3136 -> 1568 -> 784, each followed
# by ReLU, then a final projection to class probabilities via Softmax.
_widths = [784, 784 * 2, 784 * 4, 784 * 2, 784]

# Positional arguments (including the starred generator) are evaluated left to
# right, so layers are constructed in the same order as a hand-written list.
network = Sequential(
    Reshape((-1, 784)),
    *(
        layer
        for fan_in, fan_out in zip(_widths, _widths[1:])
        for layer in (
            Linear(fan_in, fan_out, initialiser=xavier_uniform, bias=True),
            ReLU(),
        )
    ),
    Linear(_widths[-1], num_classes, initialiser=xavier_uniform, bias=True),
    Softmax(),
)

In [None]:
# Download on first use so "Restart & Run All" works on a fresh checkout.
# torchvision does not re-download when the dataset is already under root.
training_data = torchvision.datasets.FashionMNIST(".", train=True, download=True)
validation_data = torchvision.datasets.FashionMNIST(".", train=False, download=True)


def create_mnist_batch_loader(
    dataset,
    batch_size=64,
    num_classes: int = 10,
    shuffle: bool = False,
    rng: "np.random.Generator | None" = None,
) -> Generator:
    """Create a batch loader for a (Fashion-)MNIST style dataset.

    Args:
        dataset: Indexable, sized dataset whose items are ``(image, label)``
            pairs, where ``image`` converts to a 28x28 array (e.g. a PIL
            grayscale image from torchvision's FashionMNIST) and ``label`` is
            an integer class index in ``[0, num_classes)``.
        batch_size: Number of samples per batch; the last batch may be smaller.
        num_classes: Width of the one-hot label vectors.
        shuffle: If True, visit samples in a random order.
        rng: Generator used when ``shuffle`` is True; a fresh unseeded one is
            created if omitted.

    Yields:
        ``(images, labels)`` Tensors: images of shape ``(B, 28, 28, 1)``
        scaled to ``[0, 1]`` and one-hot labels of shape ``(B, num_classes)``.

    Raises:
        ValueError: If ``batch_size`` is not a positive integer.

    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    dataset_size = len(dataset)
    indices = np.arange(dataset_size)
    if shuffle:
        (rng if rng is not None else np.random.default_rng()).shuffle(indices)

    for start_idx in range(0, dataset_size, batch_size):
        batch_indices = indices[start_idx : start_idx + batch_size]
        actual_batch_size = len(batch_indices)

        # Single grayscale channel. A 3-channel buffer would broadcast each
        # (28, 28, 1) image into 3 copies, and a downstream Reshape((-1, 784))
        # would then emit 3 rows per sample, misaligning outputs and labels.
        batch_images = np.zeros((actual_batch_size, 28, 28, 1), dtype=np.float32)
        batch_labels = np.zeros((actual_batch_size, num_classes), dtype=np.float32)

        for i, idx in enumerate(batch_indices):
            image, label = dataset[int(idx)]

            # Convert (PIL) image to a float array with a trailing channel
            # axis and normalise pixel values to [0, 1].
            batch_images[i] = np.array(image, dtype=np.float32)[:, :, None] / 255.0
            batch_labels[i, label] = 1.0  # one-hot encoding

        images_tensor = Tensor(batch_images, requires_grad=True)
        labels_tensor = Tensor(batch_labels, requires_grad=False)

        yield images_tensor, labels_tensor

In [None]:
batch_size = 1

optimizer = SGD(network.parameters, lr=0.01, clip_value=1.0)
train_loader = create_mnist_batch_loader(training_data, batch_size=batch_size)

# The loader is a generator with no len(), so tell tqdm how many batches to
# expect; otherwise it cannot show a percentage or ETA.
num_train_batches = (len(training_data) + batch_size - 1) // batch_size  # ceil
tqdm_iter = tqdm(train_loader, desc="", total=num_train_batches)

In [None]:
# Sanity check: repeatedly fit one batch. The loss should fall steadily towards
# zero; if it does not, suspect the model, the loss, or the optimiser.
# Bounded (rather than `while True`) so the cell terminates under Run All.
num_overfit_steps = 1_000

# Take the first batch from a fresh loader so re-running this cell always
# trains on the same batch (a shared generator would hand out the next one).
images, targets = next(create_mnist_batch_loader(training_data, batch_size=batch_size))

progress = tqdm(range(num_overfit_steps), desc="")
for step in progress:
    output = network(images)

    loss = cross_entropy_loss(output, targets)
    # Show the latest loss on the bar itself instead of printing every step;
    # set_description() also redraws the bar immediately.
    progress.set_description(f"Step {step} | Current loss: {loss.data} ")

    optimizer.zero_grad()
    loss.backward()

    optimizer.step()