<a href="https://colab.research.google.com/github/Original-2/Equinox_Examples/blob/main/simple_mnist_convnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Equinox version

In [1]:
!pip3 install equinox
!pip3 install optax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting equinox
  Downloading equinox-0.5.3-py3-none-any.whl (63 kB)
[K     |████████████████████████████████| 63 kB 1.8 MB/s 
Installing collected packages: equinox
Successfully installed equinox-0.5.3
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting optax
  Downloading optax-0.1.2-py3-none-any.whl (140 kB)
[K     |████████████████████████████████| 140 kB 33.8 MB/s 
Collecting chex>=0.0.4
  Downloading chex-0.1.3-py3-none-any.whl (72 kB)
[K     |████████████████████████████████| 72 kB 634 kB/s 
Installing collected packages: chex, optax
Successfully installed chex-0.1.3 optax-0.1.2


In [2]:
import jax.numpy as jnp
from tensorflow import keras
import equinox as eqx
import jax
import optax
from tqdm import tqdm, trange

In [3]:
# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255

x_train = jnp.expand_dims(x_train, -1)
x_test = jnp.expand_dims(x_test, -1)

x_train = jnp.resize(jnp.array(x_train), (len(x_train),1,28,28,))
x_test = jnp.resize(jnp.array(x_test), (len(x_test),1,28,28,))

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [4]:
y_train = list(y_train)

for i in range(len(y_train)):
  temp = y_train[i]
  y_train[i] = [0 for i in range(10)]
  y_train[i][temp] = 1

y_test = list(y_test)

for i in range(len(y_test)):
  temp = y_test[i]
  y_test[i] = [0 for i in range(10)]
  y_test[i][temp] = 1

y_train = jnp.array(y_train)
y_test = jnp.array(y_test)

In [5]:
import equinox as eqx
import jax
import jax.example_libraries.stax as stax
import optax
import jax.numpy as jnp

class ConvNet(eqx.Module):
    conv: list
    dense: eqx.nn.Linear
    drop: eqx.nn.Dropout


    def __init__(
        self,
        *,
        key,
    ):
        key1, key2, key3 = jax.random.split(key, 3)

        self.conv = [eqx.nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=2, key=key1),
                      eqx.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3,3), padding=2, key=key2),
                      ]

        self.dense = eqx.nn.Linear(4096, 10, key=key3)
        self.drop = eqx.nn.Dropout()
    def __call__(self, x, key, inference=False):
        x = jax.nn.relu(self.conv[0](x))
        x = eqx.nn.AvgPool2D((2,2),2)(x)

        x = jax.nn.relu(self.conv[1](x))
        x = eqx.nn.AvgPool2D((2,2),2)(x)

        x = x.flatten()
        x = self.drop(x, key=key)
        x = self.dense(x)
        return jax.nn.softmax(x)

In [6]:
def main(
    batch_size=128,
    learning_rate=0.001,
    steps=60000//128,
    val_steps=10000//128,
    seed=5678,
    epochs=15
):
    key = jax.random.PRNGKey(seed)
    key, temp = jax.random.split(key, 2)
    model = ConvNet(key=temp)

    @eqx.filter_value_and_grad
    def compute_loss(model, x, y, key):
        predy = jax.vmap(model)(x, key=jnp.array(key))    
        predy = jnp.clip(predy, 1e-7, 1 - 1e-7)
        losses = jnp.sum(y * -jnp.log(predy), axis=-1, keepdims=False)
        return jnp.mean(losses)

    @eqx.filter_jit
    def make_step(model, x, y, opt_state, key):
        loss, grads = compute_loss(model, x, y, key)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    def compute_metrics(model, x, y, key):
        predy = jax.vmap(model)(x, jnp.array(key), jnp.array([True for i in range(batch_size)]))    
        predy = jnp.clip(predy, 1e-7, 1 - 1e-7)
        losses = jnp.sum(y * -jnp.log(predy), axis=-1, keepdims=False)

        accuracy = jnp.mean(jnp.argmax(predy, -1) == jnp.argmax(y, -1))
        return jnp.mean(losses), accuracy

    optim = optax.adam(learning_rate)
    opt_state = optim.init(model)

    for epoch in range(epochs):
        bar = trange(steps)
        for i in bar:
            bar.set_description(f"epoch {epoch}")
            start = i*batch_size
            stop = i*batch_size+batch_size

            key, *bkeys = jax.random.split(key, batch_size+1)
            loss, model, opt_state = make_step(model, x_train[start:stop], y_train[start:stop], opt_state, bkeys)
            loss = loss.item()
            bar.set_postfix(loss=loss)

        ### compute metrics ###

        metrics = {"loss":[],
                  "accuracy":[]}

        for i in range(val_steps):
          start = i*batch_size
          stop = i*batch_size+batch_size

          key, *bkeys = jax.random.split(key, batch_size+1)
          l = compute_metrics(model, x_test[start:stop], y_test[start:stop], bkeys)

          metrics["loss"].append(l[0])
          metrics["accuracy"].append(l[1])

        print("loss: " + str(jnp.mean(jnp.array(metrics["loss"]))) + ", accuracy: " + str(100*jnp.mean(jnp.array(metrics["accuracy"]))) + "%")

In [7]:
main()

epoch 0: 100%|██████████| 468/468 [00:26<00:00, 17.81it/s, loss=0.0538]


loss: 0.15498702, accuracy: 95.44271%


epoch 1: 100%|██████████| 468/468 [00:09<00:00, 47.58it/s, loss=0.0349]


loss: 0.09950211, accuracy: 96.905045%


epoch 2: 100%|██████████| 468/468 [00:09<00:00, 48.25it/s, loss=0.0543]


loss: 0.08405958, accuracy: 97.3758%


epoch 3: 100%|██████████| 468/468 [00:09<00:00, 47.93it/s, loss=0.0141]


loss: 0.07233128, accuracy: 97.79648%


epoch 4: 100%|██████████| 468/468 [00:09<00:00, 47.77it/s, loss=0.0262]


loss: 0.06449932, accuracy: 98.02684%


epoch 5: 100%|██████████| 468/468 [00:09<00:00, 47.57it/s, loss=0.0202]


loss: 0.05718883, accuracy: 98.20713%


epoch 6: 100%|██████████| 468/468 [00:10<00:00, 46.33it/s, loss=0.0247]


loss: 0.05789528, accuracy: 98.02684%


epoch 7: 100%|██████████| 468/468 [00:09<00:00, 47.22it/s, loss=0.0335]


loss: 0.052895613, accuracy: 98.227165%


epoch 8: 100%|██████████| 468/468 [00:09<00:00, 47.36it/s, loss=0.015]


loss: 0.05443692, accuracy: 98.31731%


epoch 9: 100%|██████████| 468/468 [00:09<00:00, 47.43it/s, loss=0.0278]


loss: 0.04910239, accuracy: 98.507614%


epoch 10: 100%|██████████| 468/468 [00:09<00:00, 47.55it/s, loss=0.0156]


loss: 0.047235426, accuracy: 98.4375%


epoch 11: 100%|██████████| 468/468 [00:09<00:00, 47.31it/s, loss=0.0119]


loss: 0.043852136, accuracy: 98.58774%


epoch 12: 100%|██████████| 468/468 [00:09<00:00, 47.04it/s, loss=0.0181]


loss: 0.043277003, accuracy: 98.547676%


epoch 13: 100%|██████████| 468/468 [00:09<00:00, 47.25it/s, loss=0.0088]


loss: 0.04302097, accuracy: 98.64784%


epoch 14: 100%|██████████| 468/468 [00:09<00:00, 47.27it/s, loss=0.019]


loss: 0.040528145, accuracy: 98.63782%
