In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import torchvision
from tqdm import tqdm

from nnx import autograd
from nnx.autograd.activations import ReLU, Softmax
from nnx.autograd.initialisation import xavier_uniform
from nnx.autograd.layers import Conv2D
from nnx.autograd.tensor import Tensor

seed = 3

# Install a seeded generator as nnx.autograd's module-level RNG, presumably
# consumed by the weight initialisers, so runs are reproducible.
autograd.rng = np.random.default_rng(seed)

In [None]:
from typing import Callable

from nnx.autograd.layers import Layer


class Reshape(Layer):
    """Reshape a tensor without breaking the autograd graph.

    The wrapped array is reshaped to ``target_shape``. When the input takes
    part in autograd, a backward hook reshapes the incoming gradient back to
    the input's original layout and accumulates it into ``inputs.grad``.
    """

    def __init__(self, target_shape):
        """Remember the shape to reshape to.

        Args:
            target_shape: Desired output shape, passed straight to NumPy's
                ``reshape``; one entry may be -1 to have it inferred.
        """
        super().__init__()
        self.target_shape = target_shape

    def forward(self, inputs: Tensor) -> Tensor:
        """Return ``inputs`` reshaped to ``self.target_shape``.

        Args:
            inputs: Tensor to reshape.

        Returns:
            A new Tensor carrying the same ``requires_grad`` flag as
            ``inputs``; when that flag is set, gradients flow back to
            ``inputs`` through a registered backward hook.
        """
        needs_grad = inputs.requires_grad
        source_shape = inputs.data.shape
        result = Tensor(inputs.data.reshape(self.target_shape), requires_grad=needs_grad)

        if not needs_grad:
            return result

        result.prev = {inputs}

        def _backward():
            upstream = result.grad
            if upstream is None:
                return
            # Undo the forward reshape so the gradient matches the input.
            local = upstream.reshape(source_shape)
            if inputs.grad is None:
                inputs.grad = local
            else:
                inputs.grad = inputs.grad + local

        result.register_backward(_backward)
        return result


class Linear(Layer):
    """Implements a linear (fully connected) transformation ``x @ W.T + b``."""

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        *,
        initialiser: Callable,
        bias: bool = True,
    ) -> None:
        """C'tor of the Linear layer.

        Args:
            in_dim: count of input neurons.
            out_dim: count of output neurons.
            initialiser: callable invoked as
                ``initialiser(in_dim, out_dim, size=(out_dim, in_dim))``;
                it must return a ``(weights, bias)`` pair.
            bias: whether we want to use the bias term. When False, the bias
                returned by the initialiser is discarded.

        """
        super().__init__()
        self._in_dim = in_dim
        self._out_dim = out_dim

        weights, bias_ = initialiser(
            in_dim,
            out_dim,
            size=(out_dim, in_dim),
        )

        # Weights are requested with size (out_dim, in_dim), hence the
        # transpose in forward().
        self._weights = Tensor(weights, requires_grad=True)
        self._parameters.append(self._weights)

        self._bias = None
        if bias:
            self._bias = Tensor(bias_, requires_grad=True)
            self._parameters.append(self._bias)

    def forward(self, inputs: Tensor) -> Tensor:
        """Compute the transformation given the inputs.

        Args:
            inputs: Tensor which needs to be transformed; its last axis must
                have size ``in_dim``.

        Returns:
            Transformed Tensor ``inputs @ W.T`` plus the bias when enabled.

        """
        outputs: Tensor = inputs @ self._weights.T

        if self._bias is not None:
            # Bind a new node rather than using ``+=``. If Tensor implements
            # ``__iadd__`` in place, ``+=`` would overwrite the matmul result
            # without recording the addition in the autograd graph.
            outputs = outputs + self._bias

        return outputs


class SGD:
    """Stochastic gradient descent with element-wise gradient clipping."""

    def __init__(
        self,
        parameters: list[Tensor],
        lr: float = 0.01,
        clip_value: float = 1.0,
    ) -> None:
        """Initialize SGD optimizer with gradient clipping.

        Args:
            parameters: Tensors to optimise; each is read via ``.grad`` and
                updated in place via ``._data``.
            lr: Step size applied to the clipped gradient.
            clip_value: Gradients are clipped element-wise to
                ``[-clip_value, clip_value]`` before the update.
        """
        self.parameters = parameters
        self.lr = lr
        self.clip_value = clip_value

    @staticmethod
    def _reduce_to_shape(grad, shape):
        """Collapse a broadcast gradient back onto a parameter's ``shape``."""
        # Sum across leading axes that broadcasting prepended (typically batch).
        extra_axes = np.ndim(grad) - len(shape)
        if extra_axes > 0:
            grad = np.sum(grad, axis=tuple(range(extra_axes)))

        # Sum across axes that were size 1 in the parameter but got stretched,
        # e.g. a (1, out) bias receiving an (N, out) gradient. This is only
        # attempted when element counts disagree, i.e. exactly the cases where
        # the plain reshape below would raise.
        if (
            np.shape(grad) != tuple(shape)
            and np.ndim(grad) == len(shape)
            and np.size(grad) != np.prod(shape)
        ):
            stretched = tuple(
                axis
                for axis, (g_dim, p_dim) in enumerate(zip(np.shape(grad), shape))
                if p_dim == 1 and g_dim != 1
            )
            if stretched:
                grad = np.sum(grad, axis=stretched, keepdims=True)

        # Same element count but different layout: reshape as a last resort.
        if np.shape(grad) != tuple(shape):
            grad = np.reshape(grad, shape)
        return grad

    def step(self) -> None:
        """Apply one clipped gradient-descent update to every parameter.

        Parameters whose ``grad`` is None are skipped.
        """
        for param in self.parameters:
            if param.grad is None:
                continue

            grad = param.grad
            if grad.shape != param._data.shape:
                grad = self._reduce_to_shape(grad, param._data.shape)

            clipped_grad = np.clip(grad, -self.clip_value, self.clip_value)

            # ``-=`` updates the parameter's array in place.
            param._data -= self.lr * clipped_grad

    def zero_grad(self) -> None:
        """Reset gradients to None.

        This is a real method: it was previously nested inside ``step``, so
        ``optimizer.zero_grad()`` raised AttributeError.
        """
        for param in self.parameters:
            param.grad = None


def cross_entropy_loss(predictions: Tensor, targets: Tensor) -> Tensor:
    """Cross entropy loss for classification.

    Computes ``-sum(targets * log(predictions + eps)) / batch_size``, where
    ``batch_size`` is the length of the leading axis of ``targets``.

    Args:
        predictions: Model predictions (after softmax), same shape as
            ``targets``.
        targets: One-hot encoded target labels; leading axis is the batch.

    Returns:
        Loss value as a Tensor with gradient connections preserved. When
        ``predictions.requires_grad`` is set, its backward hook accumulates
        d(loss)/d(predictions) into ``predictions.grad``.
    """
    epsilon = 1e-10  # keeps log() and the division below finite at p == 0
    batch_size = targets.data.shape[0]
    shifted_probs = predictions.data + epsilon
    loss_val = -np.sum(targets.data * np.log(shifted_probs)) / batch_size
    loss = Tensor(loss_val, requires_grad=predictions.requires_grad)

    if predictions.requires_grad:
        # Connect the computational graph
        loss.prev = {predictions}

        def _backward():
            # Derivative of the loss above w.r.t. the softmax *output* p:
            #   d/dp [-sum(t * log(p + eps)) / N] = -t / (p + eps) / N.
            # The familiar (p - t) / N is the gradient w.r.t. the softmax
            # *input* (logits). Feeding it here would push it through the
            # Softmax layer's Jacobian a second time.
            # NOTE(review): assumes nnx's Softmax backward applies its own
            # Jacobian; confirm against the library.
            grad = -targets.data / shifted_probs / batch_size
            predictions.grad = (
                grad if predictions.grad is None else predictions.grad + grad
            )

        loss.register_backward(_backward)

    return loss

In [None]:
class Sequential(Layer):
    """Sequentially applies a list of layers, feeding each output to the next."""

    def __init__(self, *layers: Layer) -> None:
        """Initialize with a sequence of layers.

        Args:
            *layers: Layers applied in the given order; stored as a tuple.
        """
        super().__init__()
        self.layers = layers

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass through all layers in sequence.

        Args:
            x: Input to the first layer.

        Returns:
            Output of the last layer (``x`` unchanged if there are no layers).
        """
        for layer in self.layers:
            x = layer(x)
        return x

    @property
    def parameters(self) -> list[Tensor]:
        """Get all parameters from all layers, in layer order.

        Layers without a ``parameters`` attribute are reported and skipped.
        """
        params = []
        for layer in self.layers:
            if hasattr(layer, "parameters"):
                params.extend(layer.parameters)
            else:
                print(f"Skipping layer {layer}")
        return params

In [None]:
num_classes = 10

# Fully connected stack on flattened 28x28 inputs:
# 784 -> 1568 -> 3136 -> 1568 -> 784 -> num_classes.
# Every Linear is followed by ReLU except the last, which feeds Softmax.
_layer_widths = [784, 784 * 2, 784 * 4, 784 * 2, 784, num_classes]
_stack = [Reshape((-1, 784))]
for _depth, (_fan_in, _fan_out) in enumerate(
    zip(_layer_widths, _layer_widths[1:]), start=1
):
    _stack.append(Linear(_fan_in, _fan_out, initialiser=xavier_uniform, bias=True))
    _stack.append(Softmax() if _depth == len(_layer_widths) - 1 else ReLU())

network = Sequential(*_stack)

In [None]:
# Fashion-MNIST train and test splits, read from ./FashionMNIST.
# Pass download=True on first use to fetch the files.
training_data, validation_data = (
    torchvision.datasets.FashionMNIST("./FashionMNIST", train=split)
    for split in (True, False)
)


def create_mnist_batch_loader(
    dataset, batch_size=64, shuffle=False, num_classes=10, rng=None
):
    """Create a batch loader for Fashion MNIST dataset.

    Args:
        dataset: Indexable dataset of ``(image, label)`` pairs (e.g.
            FashionMNIST). Each image must convert via ``np.array`` to a
            28x28 array, and each label is an integer class index.
        batch_size: Number of samples per batch; the last batch may be smaller.
        shuffle: Whether to visit the samples in a random order.
        num_classes: Width of the one-hot label encoding.
        rng: Optional ``np.random.Generator`` used when ``shuffle`` is True.
            A fresh unseeded generator is created if omitted.

    Returns:
        Generator that yields batches of ``(images, one-hot labels)`` Tensors.
        Images have shape ``(batch, 28, 28, 1)`` scaled to [0, 1], and labels
        have shape ``(batch, num_classes)``.

    Raises:
        ValueError: If ``batch_size`` is not positive (raised on first
            iteration, as this is a generator).

    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    dataset_size = len(dataset)
    if shuffle:
        if rng is None:
            rng = np.random.default_rng()
        indices = rng.permutation(dataset_size)
    else:
        indices = np.arange(dataset_size)

    for start_idx in range(0, dataset_size, batch_size):
        batch_indices = indices[start_idx:start_idx + batch_size]
        actual_batch_size = len(batch_indices)

        # Fashion-MNIST is single-channel. A 3-channel buffer would broadcast
        # each pixel three times, and Reshape((-1, 784)) would then split every
        # image into three rows, tripling the batch relative to the labels.
        batch_images = np.zeros((actual_batch_size, 28, 28, 1), dtype=np.float32)
        batch_labels = np.zeros((actual_batch_size, num_classes), dtype=np.float32)

        for i, idx in enumerate(batch_indices):
            image, label = dataset[idx]
            # Convert PIL image to numpy array, add a channel axis, normalize to [0,1]
            batch_images[i] = np.array(image, dtype=np.float32)[:, :, None] / 255.0
            batch_labels[i, label] = 1.0  # One-hot encoding

        # NOTE(review): input images rarely need gradients; requires_grad=True
        # is kept from the original to avoid changing graph behaviour.
        images_tensor = Tensor(batch_images, requires_grad=True)
        labels_tensor = Tensor(batch_labels, requires_grad=False)

        yield images_tensor, labels_tensor

In [None]:
optimizer = SGD(network.parameters, lr=0.01, clip_value=1.0)
num_samples = len(training_data)
batch_size = 1

train_loader = create_mnist_batch_loader(training_data, batch_size=batch_size)

# The loader is a generator with no __len__, so tqdm needs an explicit total
# to render progress and an ETA (ceiling division matches the loader's count).
num_batches = (num_samples + batch_size - 1) // batch_size
tqdm_iter = tqdm(train_loader, desc="", total=num_batches)

In [None]:
# Sanity check: try to overfit the first training example.
single_img, single_label = training_data[0]
img_array = np.array(single_img, dtype=np.float32)[:, :, None] / 255.0
# Prepend a batch axis, giving shape (1, H, W, 1).
img_tensor = Tensor(np.expand_dims(img_array, axis=0), requires_grad=True)

# One-hot target of shape (1, 10): pick row `single_label` of the identity.
label_onehot = np.eye(10, dtype=np.float32)[[single_label]]
label_tensor = Tensor(label_onehot, requires_grad=False)

# NOTE(review): with clip_value=1.0, lr=100 lets every step move a weight by
# up to 100, far beyond typical initial weight magnitudes. Confirm this is
# intended. This also rebinds the module-level `optimizer` used later.
optimizer = SGD(network.parameters, lr=100, clip_value=1.0)
for i in range(1000):
    output = network(img_tensor)
    loss = cross_entropy_loss(output, label_tensor)

    if not i % 10:
        print(f"Iteration {i}, Loss: {loss.data}")

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:


# Sanity check: repeatedly train on the first batch only. The loss should
# fall steadily if gradients are wired correctly.
for idx, [images, targets] in enumerate(tqdm_iter):
    break

# Bounded so "Run All" cannot hang; the previous `while True` never exited.
num_overfit_steps = 1000
for step in range(num_overfit_steps):
    output = network(images)

    loss = cross_entropy_loss(output, targets)
    # Report the step counter; `idx` is always 0 here, so "Epoch {idx}" was
    # misleading.
    msg = f"Step {step} | Current loss: {loss.data} "
    # set_description() refreshes the bar; assigning `.desc` alone does not.
    tqdm_iter.set_description(msg)
    print(msg)

    optimizer.zero_grad()
    loss.backward()

    optimizer.step()