In [3]:
import jax

from toylib.nn.layers import Linear

### Set up data

In [4]:
# Hyperparameters -- the knobs for this notebook, gathered in one cell.

BATCH_SIZE = 64        # examples per DataLoader batch
LEARNING_RATE = 0.0003 # presumably consumed by a training loop later on
STEPS = 300            # presumably the number of training steps
PRINT_EVERY = 30       # presumably the logging interval, in steps
SEED = 5678            # seed for the root JAX PRNG key

# Root PRNG key derived from SEED.
key = jax.random.PRNGKey(SEED)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
I0000 00:00:1702699753.943133   28402 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


In [5]:
import torch
import torchvision


# ToTensor scales pixels to [0, 1] (per torchvision docs); Normalize((0.5,), (0.5,))
# then computes (x - 0.5) / 0.5, mapping them to [-1, 1].
normalise_data = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)

MNIST_ROOT = "MNIST"  # local download / cache directory for the dataset


def _load_mnist(train: bool) -> torchvision.datasets.MNIST:
    """Return the MNIST train (train=True) or test split, downloading if needed."""
    return torchvision.datasets.MNIST(
        MNIST_ROOT,
        train=train,
        download=True,
        transform=normalise_data,
    )


train_dataset = _load_mnist(train=True)
test_dataset = _load_mnist(train=False)

# Seed the shuffle so batch order is reproducible under Restart & Run All.
trainloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
# Evaluation gains nothing from shuffling; a fixed order keeps test results reproducible.
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to MNIST/MNIST/raw/train-images-idx3-ubyte.gz


100%|███████████████████████████████████████████████████████████████████| 9912422/9912422 [00:00<00:00, 27026452.80it/s]


Extracting MNIST/MNIST/raw/train-images-idx3-ubyte.gz to MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to MNIST/MNIST/raw/train-labels-idx1-ubyte.gz


100%|████████████████████████████████████████████████████████████████████████| 28881/28881 [00:00<00:00, 5000854.30it/s]

Extracting MNIST/MNIST/raw/train-labels-idx1-ubyte.gz to MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz



100%|████████████████████████████████████████████████████████████████████| 1648877/1648877 [00:00<00:00, 8666183.06it/s]


Extracting MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz to MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████████████████████████████████████████████████████████████████████| 4542/4542 [00:00<00:00, 3365817.80it/s]

Extracting MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz to MNIST/MNIST/raw






In [6]:
# Checking our data a bit (by now, everyone knows what the MNIST dataset looks like)
dummy_x, dummy_y = next(iter(trainloader))
dummy_x = dummy_x.numpy()
dummy_y = dummy_y.numpy()
# Pin the shapes shown in the output below ((64, 1, 28, 28) and (64,) for
# BATCH_SIZE=64) to the config constant, so a change that breaks them fails
# loudly instead of silently printing something else.
assert dummy_x.shape == (BATCH_SIZE, 1, 28, 28), dummy_x.shape
assert dummy_y.shape == (BATCH_SIZE,), dummy_y.shape
print(dummy_x.shape)
print(dummy_y.shape)
print(dummy_y)

(64, 1, 28, 28)
(64,)
[6 2 9 0 5 0 3 2 6 5 3 2 4 2 1 1 1 5 2 2 7 4 2 8 7 1 5 3 7 8 6 8 7 9 8 9 9
 8 2 6 2 1 7 9 7 4 1 5 1 2 6 9 8 1 5 4 8 2 8 5 1 0 2 1]


### Linear baseline

In [22]:
from jaxtyping import Array, PRNGKeyArray
from typing import Optional, Callable

from toylib.nn import linear
from toylib.nn import module

class MLP(module.Module):
    """Multi-layer perceptron made of a stack of `linear.Linear` layers.

    Layer `i` maps `dims[i] -> dims[i + 1]` and is stored on the instance as
    the attribute `layer_{i}`; `__call__` applies the layers in order.
    """

    # The input, hidden, output dimensions of the MLP
    # e.g.: [128, 256, 10] represents an MLP of input dimension 128,
    # hidden dimension of 256 and output dimension of 10
    dims: list[int]
    # Activation applied between consecutive layers (never after the last
    # one). None means no nonlinearity, i.e. a plain stack of Linear layers.
    activation: Optional[Callable] = None
    # Whether every Linear layer is built with a bias term.
    use_bias: bool

    def __init__(self, dims: list[int], activation: Optional[Callable] = None, use_bias: bool = True, *, key: PRNGKeyArray) -> None:
        """Build one `linear.Linear` per consecutive pair in `dims`.

        Args:
            dims: Layer widths `[in, hidden..., out]`; any iterable of ints
                with at least two entries (it is copied into a list).
            activation: Optional callable applied after every layer except
                the last one.
            use_bias: Forwarded to every `linear.Linear`.
            key: PRNG key; split so that each layer receives its own key.

        Raises:
            AssertionError: if fewer than two dimensions are given.
        """
        self.dims = list(dims)
        self.activation = activation
        self.use_bias = use_bias

        # Validate the materialised list so tuple/generator inputs also work.
        assert len(self.dims) > 1, "Need at least input and output dimension in `dims`"

        # One independent key per layer. Presumably `linear.Linear` draws its
        # initial weights from `key`, so passing the same key to every layer
        # would correlate (or, for equal shapes, duplicate) their initial weights.
        layer_keys = jax.random.split(key, len(self.dims) - 1)
        for ix, layer_key in enumerate(layer_keys):
            setattr(
                self,
                f'layer_{ix}',
                linear.Linear(self.dims[ix], self.dims[ix + 1], use_bias=use_bias, key=layer_key),
            )

    def __call__(self, x: Array) -> Array:
        """Run `x` through every layer in order and return the last layer's output.

        NOTE(review): assumes a `linear.Linear` instance is callable on an array
        whose trailing dimension equals its input size -- confirm against
        toylib.nn.linear. Use `jax.vmap` to map over a leading batch axis.
        """
        num_layers = len(self.dims) - 1
        for ix in range(num_layers):
            x = getattr(self, f'layer_{ix}')(x)
            # Skip the activation on the output layer so callers get raw outputs.
            if self.activation is not None and ix < num_layers - 1:
                x = self.activation(x)
        return x

In [25]:
# 128 -> 128 -> 1 network with no activation (stacked Linear layers), seeded from the root key.
model_mlp = MLP(dims=[128, 128, 1], activation=None, use_bias=True, key=key)

In [26]:
# Inspect how the module flattens: a (children, {'aux': ..., 'dynamic_keys': ...}) pair.
# NOTE(review): the saved output below is stale -- it shows dims [10, 1], not the
# [128, 128, 1] used to build model_mlp above; re-run to refresh it. It also lists
# 'layer_0' under 'aux' with empty 'dynamic_keys', which suggests the Linear
# parameters are treated as static data rather than leaves -- confirm in toylib.nn.module.
model_mlp.tree_flatten()

([],
 {'aux': {'dims': [10, 1],
   'activation': None,
   'layers': {},
   'use_bias': True,
   'layer_0': <toylib.nn.linear.Linear at 0x7f76b6332e60>},
  'dynamic_keys': []})

In [29]:
import numpy as np
# Dummy batch of 16 inputs whose width (128) matches model_mlp's input dimension.
# Drawn from a generator seeded with SEED so re-runs produce the same values.
xs = np.random.default_rng(SEED).random((16, 128))

In [30]:
# Sanity check on the dummy batch: (batch of 16, feature width 128).
xs.shape

(16, 128)

In [None]:
# Map the per-example forward pass over the leading (batch) axis of xs.
# (Previously `jax.vmap(model)()`: `model` is undefined and no input was passed.)
jax.vmap(model_mlp)(xs)

### Convolution!