# Training a basic ANN on MNIST

In [1]:
import equinox as eqx
import jax
import jax.numpy as jnp
import optax  # https://github.com/deepmind/optax
import torch  # https://pytorch.org
import torchvision  # https://pytorch.org
from jaxtyping import Array, Float, Int, PyTree  # https://github.com/google/jaxtyping


from torch.utils.tensorboard import SummaryWriter

In [2]:
# Hyperparameters


BATCH_SIZE = 64

LEARNING_RATE = 0.005

STEPS = 300

PRINT_EVERY = 30

SEED = 5678

key = jax.random.PRNGKey(SEED)

## Loading dataset

In [3]:
normalise_data = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)
train_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=True,
    download=True,
    transform=normalise_data,
)
test_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=False,
    download=True,
    transform=normalise_data,
)
trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=True
)

In [4]:
# Checking our data a bit (by now, everyone knows what the MNIST dataset looks like)
dummy_x, dummy_y = next(iter(trainloader))

dummy_x = dummy_x.numpy()
dummy_y = dummy_y.numpy()
print(dummy_x.shape)  # 64x1x28x28
print(dummy_y.shape)  # 64
print(dummy_y)

(64, 1, 28, 28)
(64,)
[1 7 5 0 8 1 4 2 7 9 8 3 9 2 3 4 5 5 7 9 2 1 0 6 9 5 4 4 2 9 3 9 9 0 9 3 7
 4 7 2 7 8 6 1 6 2 9 4 5 3 9 1 5 1 8 5 1 4 8 0 9 4 0 6]


In [5]:
class CNN(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        # Standard CNN setup: convolutional layer, followed by flattening,
        # with a small MLP on top.
        self.layers = [
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            jnp.ravel,
            eqx.nn.Linear(1728, 512, key=key2),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=key3),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=key4),
            jax.nn.log_softmax,
        ]

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        for layer in self.layers:
            x = layer(x)
        return x


key, subkey = jax.random.split(key, 2)
model = CNN(subkey)

In [9]:
print(type(dummy_x[0, 0, 0, 0]))

<class 'numpy.float32'>


In [44]:
print(model)

CNN(
  layers=[
    Conv2d(
      num_spatial_dims=2,
      weight=f32[3,1,4,4],
      bias=f32[3,1,1],
      in_channels=1,
      out_channels=3,
      kernel_size=(4, 4),
      stride=(1, 1),
      padding=((0, 0), (0, 0)),
      dilation=(1, 1),
      groups=1,
      use_bias=True
    ),
    MaxPool2d(
      init=-inf,
      operation=<function max>,
      num_spatial_dims=2,
      kernel_size=(2, 2),
      stride=(1, 1),
      padding=((0, 0), (0, 0)),
      use_ceil=False
    ),
    <wrapped function relu>,
    <wrapped function ravel>,
    Linear(
      weight=f32[512,1728],
      bias=f32[512],
      in_features=1728,
      out_features=512,
      use_bias=True
    ),
    <wrapped function sigmoid>,
    Linear(
      weight=f32[64,512],
      bias=f32[64],
      in_features=512,
      out_features=64,
      use_bias=True
    ),
    <wrapped function relu>,
    Linear(
      weight=f32[10,64],
      bias=f32[10],
      in_features=64,
      out_features=10,
      use_bias=True
  

In [45]:
def loss(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    # Our input has the shape (BATCH_SIZE, 1, 28, 28), but our model operations on
    # a single input input image of shape (1, 28, 28).
    #
    # Therefore, we have to use jax.vmap, which in this case maps our model over the
    # leading (batch) axis.
    pred_y = jax.vmap(model)(x)
    return cross_entropy(y, pred_y)


def cross_entropy(
    y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
) -> Float[Array, ""]:
    # y are the true targets, and should be integers 0-9.
    # pred_y are the log-softmax'd predictions.
    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
    return -jnp.mean(pred_y)


# Example loss
loss_value = loss(model, dummy_x, dummy_y)
print(loss_value.shape)  # scalar loss
# Example inference
output = jax.vmap(model)(dummy_x)
print(output.shape)  # batch of predictions

()
(64, 10)


In [46]:
# Getting the parameters

value, grads = eqx.filter_value_and_grad(loss)(model, dummy_x, dummy_y)
print(value)


2.304172


# Evaluation

In [47]:
loss = eqx.filter_jit(loss)  # JIT our loss function from earlier!


@eqx.filter_jit
def compute_accuracy(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    """This function takes as input the current model
    and computes the average accuracy on a batch.
    """
    pred_y = jax.vmap(model)(x)
    pred_y = jnp.argmax(pred_y, axis=1)
    return jnp.mean(y == pred_y)


In [48]:
def evaluate(model: CNN, testloader: torch.utils.data.DataLoader):
    """This function evaluates the model on the test dataset,
    computing both the average loss and the average accuracy.
    """
    avg_loss = 0
    avg_acc = 0
    for x, y in testloader:
        x = x.numpy()
        y = y.numpy()
        # Note that all the JAX operations happen inside `loss` and `compute_accuracy`,
        # and both have JIT wrappers, so this is fast.
        avg_loss += loss(model, x, y)
        avg_acc += compute_accuracy(model, x, y)
    return avg_loss / len(testloader), avg_acc / len(testloader)


In [49]:
evaluate(model, testloader)


(Array(2.307958, dtype=float32), Array(0.10081609, dtype=float32))

# Training

In [60]:
def train(
    model: CNN,
    trainloader: torch.utils.data.DataLoader,
    testloader: torch.utils.data.DataLoader,
    optim: optax.GradientTransformation,
    steps: int,
    print_every: int,
) -> CNN:
    # Just like earlier: It only makes sense to train the arrays in our model,
    # so filter out everything else.
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    # Always wrap everything -- computing gradients, running the optimiser, updating
    # the model -- into a single JIT region. This ensures things run as fast as
    # possible.
    @eqx.filter_jit
    def make_step(
        model: CNN,
        opt_state: PyTree,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
        updates, opt_state = optim.update(grads, opt_state, model)


        print("updates = ", updates)
        print("opt_state =", opt_state)
        
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    # Loop over our training dataset as many times as we need.
    def infinite_trainloader():
        while True:
            yield from trainloader


    writer = SummaryWriter()

    for step, (x, y) in zip(range(steps), infinite_trainloader()):
        # PyTorch dataloaders give PyTorch tensors by default,
        # so convert them to NumPy arrays.
        x = x.numpy()
        y = y.numpy()
        model, opt_state, train_loss = make_step(model, opt_state, x, y)
        if (step % print_every) == 0 or (step == steps - 1):
            test_loss, test_accuracy = evaluate(model, testloader)
            train_loss, train_accuracy = evaluate(model, trainloader)

            writer.add_scalar("train loss", float(train_loss), step)
            writer.add_scalar("test loss", float(test_loss), step)

            writer.add_scalar("train accuracy", float(train_accuracy), step)
            writer.add_scalar("test accuracy", float(test_accuracy), step)

            print(
                f"train_loss={train_loss.item()}, train_accuracy={train_accuracy.item()} "
                f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
            )
    return model

In [61]:
optim = optax.adamw(LEARNING_RATE)

# optax.adam(LEARNING_RATE).update()

model = train(model, trainloader, testloader, optim, STEPS, PRINT_EVERY)

updates =  CNN(
  layers=[
    Conv2d(
      num_spatial_dims=2,
      weight=f32[3,1,4,4],
      bias=f32[3,1,1],
      in_channels=1,
      out_channels=3,
      kernel_size=(4, 4),
      stride=(1, 1),
      padding=((0, 0), (0, 0)),
      dilation=(1, 1),
      groups=1,
      use_bias=True
    ),
    MaxPool2d(
      init=None,
      operation=None,
      num_spatial_dims=2,
      kernel_size=(2, 2),
      stride=(1, 1),
      padding=((0, 0), (0, 0)),
      use_ceil=False
    ),
    None,
    None,
    Linear(
      weight=f32[512,1728],
      bias=f32[512],
      in_features=1728,
      out_features=512,
      use_bias=True
    ),
    None,
    Linear(
      weight=f32[64,512],
      bias=f32[64],
      in_features=512,
      out_features=64,
      use_bias=True
    ),
    None,
    Linear(
      weight=f32[10,64],
      bias=f32[10],
      in_features=64,
      out_features=10,
      use_bias=True
    ),
    None
  ]
)
opt_state = (ScaleByAdamState(count=Traced<ShapedArray(int3

In [81]:
optax.adam({"e":0})

TypeError: unsupported operand type(s) for *: 'int' and 'dict'

In [73]:
print(model.layers[0].weight)

[[[[ 3.3234400e-01  4.3591189e-01  1.6111216e-01  3.9150640e-01]
   [ 1.0801359e-01 -2.9369362e-03 -8.8275865e-02  3.7083393e-01]
   [-1.6904424e-01 -2.5167051e-01 -1.6658500e-01 -2.6474583e-01]
   [-1.8858138e-01  8.7269619e-02 -2.4614206e-01  1.7511271e-01]]]


 [[[ 3.8009986e-02 -2.0937578e-01  1.5259872e-01  8.7809712e-02]
   [-1.5437129e-01 -2.5056186e-01 -1.7233023e-01 -2.4398933e-01]
   [ 2.2501110e-01  2.4705268e-01 -1.5951101e-01 -2.5225306e-01]
   [-2.2187206e-04  3.9700784e-02  2.0728935e-01  2.4530593e-01]]]


 [[[-9.0941712e-02  1.7508501e-01  1.8408975e-01  1.9494964e-01]
   [-3.4508991e-01  6.3339532e-03  3.6781996e-01 -1.1202988e-01]
   [-3.3668917e-01 -1.1023605e-02  3.3684352e-01 -4.7630627e-02]
   [-1.3542780e-01  1.8202873e-01  3.3244020e-01  7.8686327e-02]]]]


Archive

- 0.38 loss - 3.5 min

BATCH_SIZE = 64

LEARNING_RATE = 3e-4

STEPS = 300

PRINT_EVERY = 30

SEED = 5678


- fucked

BATCH_SIZE = 64

LEARNING_RATE = 3e-2

STEPS = 300

PRINT_EVERY = 30

SEED = 5678


- 0.38 loss

BATCH_SIZE = 64

LEARNING_RATE = 2*3e-4

STEPS = 300

PRINT_EVERY = 30

SEED = 5678


- 0.15 loss

BATCH_SIZE = 64

LEARNING_RATE = 0.005

STEPS = 300

PRINT_EVERY = 30

SEED = 5678