In [None]:
# Reload modified modules (e.g. the nnx package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import torchvision
from tqdm import tqdm

from nnx import autograd
from nnx.autograd.activations import ReLU, Softmax
from nnx.autograd.initialisation import xavier_uniform
from nnx.autograd.layers import Conv2D
from nnx.autograd.tensor import Tensor

# Fixed seed for reproducible runs. The generator is installed as
# ``nnx.autograd.rng``; presumably nnx's initialisers draw from it — confirm.
seed = 3
autograd.rng = np.random.default_rng(seed)

In [None]:
from typing import Callable

from nnx.autograd.layers import Layer


class Linear(Layer):
    """Implements a linear (fully connected) transformation ``y = x @ W.T + b``."""

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        *,
        initialiser: Callable,
        bias: bool = True,
    ) -> None:
        """C'tor of the Linear layer.

        Args:
            in_dim: count of input neurons.
            out_dim: count of output neurons.
            initialiser: callable invoked as
                ``initialiser(in_dim, out_dim, size=(out_dim, in_dim))``; it
                must return a ``(weights, bias)`` pair.
            bias: whether we want to use the bias term.

        """
        super().__init__()
        self._in_dim = in_dim
        self._out_dim = out_dim

        weights, bias_ = initialiser(in_dim, out_dim, size=(out_dim, in_dim))

        # Weights are requested as (out_dim, in_dim), hence the transpose in forward().
        self._weights = Tensor(weights, requires_grad=True)
        self._parameters.append(self._weights)

        self._bias = None
        if bias:
            self._bias = Tensor(bias_, requires_grad=True)
            self._parameters.append(self._bias)

    def forward(self, inputs: Tensor) -> Tensor:
        """Compute the transformation given the inputs.

        Args:
            inputs: Tensor which needs to be transformed; its last axis must
                have size ``in_dim``.

        Returns:
            Transformed Tensor whose last axis has size ``out_dim``.

        """
        outputs: Tensor = inputs @ self._weights.T

        if self._bias is not None:
            # Out-of-place add: ``+=`` would dispatch to Tensor.__iadd__ if one
            # exists and could mutate ``outputs`` in place instead of recording
            # a new node through ``__add__``.
            outputs = outputs + self._bias

        return outputs


class SGD:
    """Plain stochastic gradient descent: ``param <- param - lr * param.grad``."""

    def __init__(self, parameters: list[Tensor], lr: float = 0.01) -> None:
        """Initialize SGD optimizer.

        Args:
            parameters: List of parameters to optimize
            lr: Learning rate

        """
        self.parameters = parameters
        self.lr = lr

    def step(self) -> None:
        """Update parameters in place using their accumulated gradients.

        Parameters whose ``grad`` is None (no gradient reached them) are left
        untouched.
        """
        for param in self.parameters:
            if param.grad is None:
                continue
            # Updates Tensor's private ``_data`` buffer in place.
            param._data -= self.lr * param.grad

    def zero_grad(self) -> None:
        """Reset gradients to None."""
        for param in self.parameters:
            param.grad = None


def cross_entropy_loss(predictions: Tensor, targets: Tensor) -> Tensor:
    """Cross entropy loss for classification.

    Computes ``-sum(targets * log(predictions + eps)) / batch_size``, where
    the leading axis of ``predictions`` is the batch axis.

    Args:
        predictions: Tensor of class probabilities (e.g. a Softmax output).
        targets: Tensor holding the target distribution (e.g. one-hot labels)
            with the same shape as ``predictions``. No gradient is propagated
            into it.

    Returns:
        Scalar Tensor holding the loss. If ``predictions.requires_grad`` is
        set, a backward closure is registered on it that accumulates
        ``d loss / d predictions`` into ``predictions.grad``.

    """
    epsilon = 1e-10  # guards against log(0) and division by zero
    batch_size = predictions.data.shape[0]
    log_probs = np.log(predictions.data + epsilon)
    loss_val = -np.sum(targets.data * log_probs) / batch_size
    loss = Tensor(loss_val, requires_grad=predictions.requires_grad)

    if predictions.requires_grad:

        def _backward() -> None:
            # Chain rule: scale by the loss's upstream gradient. When the loss
            # is the graph root its grad may still be None; treat that as 1.
            upstream = 1.0 if loss.grad is None else loss.grad
            grad = -(targets.data / (predictions.data + epsilon)) / batch_size * upstream
            predictions.grad = (
                grad if predictions.grad is None else predictions.grad + grad
            )

        loss.register_backward(_backward)

    return loss

In [None]:
class Sequential(Layer):
    """Sequentially applies a list of layers."""

    def __init__(self, *layers: Layer) -> None:
        """Initialize with a sequence of layers.

        Args:
            *layers: callables applied in order, each receiving the previous
                output. Entries without a ``parameters`` attribute (e.g. plain
                functions) are allowed and contribute no parameters.

        """
        super().__init__()
        self.layers = layers

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass through all layers in sequence."""
        for layer in self.layers:
            x = layer(x)
        return x

    @property
    def parameters(self) -> list[Tensor]:
        """Get all parameters from all layers, in layer order.

        Layers lacking a ``parameters`` attribute are skipped and reported via
        ``print``.
        """
        params: list[Tensor] = []
        for layer in self.layers:
            if hasattr(layer, "parameters"):
                params.extend(layer.parameters)
            else:
                print(f"Skipping layer {layer}")
        return params

In [None]:
num_classes = 10


def _flatten(x: Tensor) -> Tensor:
    """Flatten every axis after the leading (batch) axis into one.

    NOTE(review): this builds a fresh Tensor from ``x.data`` and registers no
    backward pass, so it looks like gradients cannot flow from the Linear head
    back into the Conv2D layers. Confirm against the Tensor API and swap in a
    graph-tracked reshape if one exists.
    """
    return Tensor(x.data.reshape(x.data.shape[0], -1), requires_grad=x.requires_grad)


# The training loop feeds grayscale FashionMNIST images shaped (1, 28, 28, 1),
# so the first convolution takes a single input channel.
# 61952 == 22 * 22 * 128: presumably three unpadded, stride-1 3x3 convolutions
# shrink 28x28 to 22x22 with 128 output channels — confirm against Conv2D.
network = Sequential(
    Conv2D((3, 3), 1, 32, initialiser=xavier_uniform),
    ReLU(),
    Conv2D((3, 3), 32, 64, initialiser=xavier_uniform),
    ReLU(),
    Conv2D((3, 3), 64, 128, initialiser=xavier_uniform),
    ReLU(),
    _flatten,
    Linear(61952, num_classes, initialiser=xavier_uniform),
    Softmax(),
)

In [None]:
# Both splits are read from the local ./FashionMNIST directory.
# Pass download=True on the first run to fetch the files.
training_data = torchvision.datasets.FashionMNIST("./FashionMNIST", train=True)
validation_data = torchvision.datasets.FashionMNIST("./FashionMNIST", train=False)

In [None]:
optimizer = SGD(network.parameters, lr=0.1)

tqdm_iter = tqdm(training_data, desc="")

for image, label in tqdm_iter:
    # A grayscale PIL image converts to a uint8 array in 0..255; scale it to
    # floats in [0, 1] so lr=0.1 steps stay reasonable and no integer
    # arithmetic leaks into the graph. Layout becomes (1, 28, 28, 1).
    pixels = np.array(image) / 255.0
    image_ = Tensor(pixels[None, :, :, None], requires_grad=True)
    output = network(image_)

    # One-hot target row; targets never need gradients.
    targets = np.zeros(shape=(1, num_classes))
    targets[0, label] = 1
    targets = Tensor(targets, requires_grad=False)

    loss = cross_entropy_loss(output, targets)
    tqdm_iter.set_description_str(f"Current loss: {float(loss.data):.4f}")

    optimizer.zero_grad()
    loss.backward()

    optimizer.step()

In [None]:
# Only a cell's last expression is displayed, so print this explicitly.
print(output.prev)

for m in network.layers:
    # Plain callables in the layer list have no ``parameters``; treat them as
    # empty instead of raising AttributeError (as Sequential.parameters does).
    print([p.grad for p in getattr(m, "parameters", [])])

In [None]:
# Check whether the last network output from the training loop is tracked for gradients.
output.requires_grad

In [None]:
# Shape of the last network output produced in the training loop.
output.data.shape

In [None]:
# Display the last training image the loop iterated over.
image