## _Spoticore Stage 2 PyTorchified_

We are going to refactor the Stage 2 code so that Spoticore's neural-network building blocks (roughly) match PyTorch's public `torch.nn` API.


### Constants

In [2]:
from typing import Final

In [3]:
# Fixed seed for the torch.Generator that Linear uses when initialising weights,
# so runs are reproducible.
SAMPLE_SEED: Final[int] = 534_150_593

In [5]:
import torch
from math import sqrt
from collections.abc import Iterator

In [None]:
class Linear:
    """Fully connected layer computing ``x @ weight + bias``, loosely modelled on ``torch.nn.Linear``.

    Unlike PyTorch, ``weight`` is stored with shape ``(fan_in, fan_out)``, so the
    forward pass is ``x @ weight`` rather than ``x @ weight.T``.

    Args:
        fan_in: Number of input features.
        fan_out: Number of output features.
        bias: If True, add a bias vector of shape ``(fan_out,)`` initialised to zeros.
    """

    def __init__(self, fan_in: int, fan_out: int, bias: bool = True) -> None:
        # NOTE(review): every Linear seeds its own generator with the same
        # SAMPLE_SEED, so two layers of the same shape start with identical
        # weights. Kept as-is for reproducibility with the original notebook.
        self.generator = torch.Generator().manual_seed(SAMPLE_SEED)
        # Kaiming-style init with a gain of 1 (linear layer): std = 1 / sqrt(fan_in).
        self.weight = torch.randn(
            (fan_in, fan_out), generator=self.generator
        ) / sqrt(fan_in)
        self.bias = torch.zeros(fan_out) if bias else None

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer to ``x`` (last dimension must equal ``fan_in``) and return the output."""
        out = x @ self.weight
        # Compare against None explicitly: truth-testing a multi-element tensor
        # raises, and a one-element zero bias would be treated as "no bias".
        if self.bias is not None:
            out = out + self.bias
        # Cache the activation so it can be inspected after the forward pass.
        self.out = out
        return out

    def parameters(self) -> Iterator[torch.Tensor]:
        """Yield the layer's parameter tensors: ``weight``, then ``bias`` if present."""
        yield self.weight
        if self.bias is not None:
            yield self.bias

In [None]:
class BatchNorm1d:
    """Batch normalisation over dimension 0, loosely modelled on ``torch.nn.BatchNorm1d``.

    In training mode the input is normalised with the current batch statistics and
    the running averages are updated. In eval mode (``training`` is False) the
    running averages are used instead, and they are not updated.

    Args:
        num_features: Size of the feature (last) dimension of the input.
        eps: Added to the variance before the square root for numerical stability.
        momentum: Weight given to the current batch when updating the running stats.
    """

    def __init__(
        self, num_features: int, eps: float = 1e-5, momentum: float = 0.1
    ) -> None:
        self.eps = eps
        self.momentum = momentum
        self.training = True
        # Learnable scale (gamma) and shift (beta), applied after normalisation.
        self.gamma = torch.ones(num_features)
        self.beta = torch.zeros(num_features)
        # Running statistics used in eval mode; not returned by parameters().
        self.running_mean = torch.zeros(num_features)
        self.running_var = torch.ones(num_features)

    def train(self, mode: bool = True) -> "BatchNorm1d":
        """Set training mode (``mode=False`` selects eval mode) and return self."""
        self.training = mode
        return self

    def eval(self) -> "BatchNorm1d":
        """Switch to eval mode (normalise with running stats) and return self."""
        return self.train(False)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` over dimension 0, then scale by ``gamma`` and shift by ``beta``."""
        if self.training:
            # Reduce without keepdim so the batch stats have the same shape as
            # running_mean / running_var; otherwise the running buffers would be
            # reshaped to (1, num_features) on the first update. Broadcasting
            # against a (batch, num_features) input is unaffected.
            xmean = x.mean(0)
            # NOTE: unbiased (Bessel-corrected) variance is used for normalisation;
            # torch.nn.BatchNorm1d normalises with the biased estimator.
            xvar = x.var(0)
        else:
            xmean = self.running_mean
            xvar = self.running_var

        # Normalise to zero mean / unit variance.
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)
        self.out = self.gamma * xhat + self.beta

        if self.training:
            # Exponential moving average of the batch stats; no_grad keeps these
            # buffers out of the autograd graph.
            with torch.no_grad():
                self.running_mean = (
                    1 - self.momentum
                ) * self.running_mean + self.momentum * xmean
                self.running_var = (
                    1 - self.momentum
                ) * self.running_var + self.momentum * xvar

        return self.out

    def parameters(self) -> Iterator[torch.Tensor]:
        """Yield the learnable tensors ``gamma`` and ``beta``."""
        yield self.gamma
        yield self.beta

In [None]:
# Scratch check: inspect a 1-D tensor of 20 zeros (shape given as an explicit tuple).
torch.zeros((20,))

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])