In [None]:
# Enable IPython's autoreload extension; mode 2 reloads all imported modules
# before each cell is executed, so edits to module sources are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np

from nnx import autograd
from nnx.autograd.activations import ReLU, Softmax
from nnx.autograd.initialisation import xavier_uniform
from nnx.autograd.layers import Conv2D
from nnx.autograd.tensor import Tensor

# Fixed seed used for the random generator in this notebook.
seed = 3

# Replace the generator stored on the nnx autograd package with a seeded one.
# NOTE(review): presumably the initialisers draw from `autograd.rng` - confirm.
autograd.rng = np.random.default_rng(seed)

In [None]:
from typing import Callable

from nnx.autograd.layers import Layer


class Linear(Layer):
    """Implements a linear transformation ``y = x @ W.T (+ b)``."""

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        *,
        initialiser: Callable,
        bias: bool = True,
    ) -> None:
        """C'tor of the Linear layer.

        Args:
            in_dim: count of input neurons.
            out_dim: count of output neurons.
            initialiser: callable to initialise layers. It is invoked as
                ``initialiser(in_dim, out_dim, size=(out_dim, in_dim))`` and
                must return a ``(weights, bias)`` pair.
            bias: whether we want to use the bias term.

        """
        self._in_dim = in_dim
        self._out_dim = out_dim

        weights, bias_ = initialiser(
            in_dim,
            out_dim,
            size=(out_dim, in_dim),
        )

        self._weights = Tensor(weights, requires_grad=True)
        self._bias = Tensor(bias_, requires_grad=True) if bias else None

    @staticmethod
    def _sum_to_shape(grad: np.ndarray, shape: tuple) -> np.ndarray:
        """Undo numpy broadcasting by summing ``grad`` down to ``shape``.

        Args:
            grad: gradient of a broadcast result.
            shape: shape of the operand that was broadcast.

        Returns:
            ``grad`` summed over every axis that broadcasting added or
            stretched, so that the result has shape ``shape``.

        """
        extra = grad.ndim - len(shape)
        if extra > 0:
            # Axes that broadcasting prepended to the operand.
            grad = grad.sum(axis=tuple(range(extra)))
        elif extra < 0:
            grad = grad.reshape((1,) * (-extra) + grad.shape)

        # Axes of size one in the operand that were stretched.
        stretched = tuple(
            axis
            for axis, size in enumerate(shape)
            if size == 1 and grad.shape[axis] != 1
        )
        if stretched:
            grad = grad.sum(axis=stretched, keepdims=True)

        return grad

    def forward(self, inputs: Tensor) -> Tensor:
        """Compute the transformation given the inputs.

        Args:
            inputs: Tensor which needs to be transformed; its last axis must
                have ``in_dim`` entries.

        Returns:
            Transformed Tensor whose last axis has ``out_dim`` entries.

        """
        outputs: Tensor = inputs @ self._weights.T

        outputs.prev = {inputs, self._weights}
        if self._bias is not None:
            outputs.prev.add(self._bias)
            # NOTE(review): if Tensor does not implement in-place addition,
            # `+=` rebinds `outputs` to a new Tensor and the `prev` set above
            # stays on the discarded one - confirm against nnx's Tensor.
            outputs += self._bias

        def _backward() -> None:
            grad = outputs.grad

            if inputs.requires_grad:
                # y = x @ W.T  =>  dL/dx = dL/dy @ W
                dx = grad @ self._weights.data
                inputs.grad = dx if inputs.grad is None else inputs.grad + dx

            if self._weights.requires_grad:
                # dL/dW = (dL/dy).T @ x; all leading (batch) axes are
                # collapsed so 1-D, 2-D and N-D inputs are handled alike.
                dw = (
                    grad.reshape(-1, self._out_dim).T
                    @ inputs.data.reshape(-1, self._in_dim)
                )
                self._weights.grad = (
                    dw if self._weights.grad is None else self._weights.grad + dw
                )

            if self._bias is not None and self._bias.requires_grad:
                # The bias is broadcast over the batch axes in the forward
                # pass, so its gradient is the sum over those axes.
                db = self._sum_to_shape(grad, self._bias.data.shape)

                self._bias.grad = (
                    db if self._bias.grad is None else self._bias.grad + db
                )

        outputs.register_backward(_backward)

        return outputs

In [None]:
class SimpleNetwork:
    """Represents a small wrapper for multiple layers."""

    def __init__(self, num_classes: int, *, flat_dim: int = 86528) -> None:
        """C'tor of SimpleNetwork.

        Args:
            num_classes: count of output neurons of the final Linear layer.
            flat_dim: count of features entering the Linear layer after the
                flatten step. The default 86528 equals 128 * 26 * 26;
                presumably sized for 3x32x32 inputs with unpadded, stride-1
                convolutions - TODO confirm against Conv2D.

        """
        self._layers = [
            Conv2D((3, 3), 3, 32, initialiser=xavier_uniform),
            ReLU(),
            Conv2D((3, 3), 32, 64, initialiser=xavier_uniform),
            ReLU(),
            Conv2D((3, 3), 64, 128, initialiser=xavier_uniform),
            ReLU(),
            self._flatten,
            Linear(flat_dim, num_classes, initialiser=xavier_uniform),
            Softmax(),
        ]

    @staticmethod
    def _flatten(inputs: Tensor) -> Tensor:
        """Reshape ``inputs`` to a single row without cutting the graph.

        Wrapping the reshaped data in a bare Tensor would leave it without
        predecessors, so no gradient could reach the layers before it. The
        result is therefore linked to ``inputs`` and given a backward step,
        using the same Tensor API as ``Linear.forward``.

        Args:
            inputs: Tensor to flatten.

        Returns:
            Tensor of shape ``(1, n)`` holding all values of ``inputs``.

        """
        outputs = Tensor(
            inputs.data.reshape(1, -1),
            requires_grad=inputs.requires_grad,
        )
        outputs.prev = {inputs}

        def _backward() -> None:
            if inputs.requires_grad:
                # A reshape only rearranges values; route the gradient back
                # in the original layout.
                dx = outputs.grad.reshape(inputs.data.shape)
                inputs.grad = dx if inputs.grad is None else inputs.grad + dx

        outputs.register_backward(_backward)

        return outputs

    def __call__(self, inputs: Tensor) -> Tensor:
        """Compute a forward pass.

        Args:
            inputs: Tensor fed to the first layer.

        Returns:
            The result after computed forward pass.

        """
        output = inputs
        for layer in self._layers:
            output = layer(output)

        return output


In [None]:
# Build the demo network with ten output classes.
network = SimpleNetwork(num_classes=10)

In [None]:
# Run one forward pass through the network.
# NOTE(review): `inputs` is not defined in any cell visible here - it must be
# created (as a Tensor) before this cell is run.
outputs = network(inputs=inputs)
# Tensor of ones with the same shape as the network output; it is not used by
# any cell visible here.
mock_label = Tensor(np.ones_like(outputs.data))


In [None]:
# Start backpropagation from the network output, passing an all-ones array
# with the shape of the output as the initial gradient.
outputs.backward(np.ones_like(outputs.data))