In [1]:
import datasets
import haiku as hk
import jax
import jax.nn as nn
import jax.numpy as jnp
import optax

In [2]:
seed = 23

jax.config.update('jax_platform_name', 'cpu')
key = jax.random.PRNGKey(seed)

2022-03-27 19:27:42.398851: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected


In [69]:
mnist_train = datasets.load_dataset("mnist", split="train[:10%]")
mnist_test = datasets.load_dataset("mnist", split="test[:10%]")

def preprocess(batch):
     return {
        "image": jnp.array([jnp.array(img, dtype=jnp.float32) for img in batch["image"]]),
        "label": jnp.array(batch["label"])
    }

mnist_train.set_transform(preprocess)
mnist_test.set_transform(preprocess)

num_classes = mnist_train.features["label"].num_classes



In [70]:
@hk.without_apply_rng
@hk.transform
def model(batch):
    x = batch["image"]
    flat = hk.Flatten()
    l1 = hk.Linear(100)
    l2 = hk.Linear(100)
    l3 = hk.Linear(num_classes)

    x = flat(x)
    x = l1(x)
    x = l2(x)
    x = l3(x)
    return x


key, key1 = jax.random.split(key)
params = model.init(key1, mnist_train[[0]])

y = model.apply(params, mnist_train[[0]])
y.shape

(1, 10)

In [26]:
optimizer = optax.sgd(1e-4)

In [79]:
def cross_entropy(y_true, y_pred):
    return jnp.sum(-y_true * nn.log_softmax(y_pred), axis=-1)

@jax.value_and_grad
def loss(params: hk.Params, batch):
    y = nn.one_hot(batch["label"], num_classes)
    y_pred = model.apply(params, batch)
    loss = jnp.mean(cross_entropy(y, y_pred), axis=0)
    return loss

@jax.jit
def step(params: hk.Params, opt_state: optax.OptState, batch):
    loss_value, grads = loss(params, batch)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss_value

@jax.jit
def accuracy(params: hk.Params, batch):
    pred = jnp.argmax(model.apply(params, batch), axis=-1)
    return jnp.mean(pred == batch["label"])


In [75]:
epochs = 5
log_every = 100
batch_size = 32

In [76]:
def batch_dset(batch_size: int, dset: datasets.Dataset):
    data = dset.shuffle()
    for i in range(0, len(dset), batch_size):
        yield data[i: i+batch_size]

In [80]:
key, key1 = jax.random.split(key)
params = model.init(key1, mnist_train[:batch_size])
opt_state = optimizer.init(params)

i = 0
for _ in range(epochs):
    for batch in batch_dset(batch_size, mnist_train):
        params, opt_state, loss_val = step(params, opt_state, batch)

        if i % log_every == 0:
            accs = []
            for test_batch in batch_dset(batch_size, mnist_test):
                acc = accuracy(params, test_batch)
                accs.append(jax.device_get(acc))
            acc = sum(accs) / len(accs)
            acc = sum(accs) / len(accs)
            print(f"Test accuracy: {acc}")
            print(f"Test loss: {loss_val}")
        i += 1





Test accuracy: 0.1484375
Test loss: 98.7048568725586
Test accuracy: 0.7197265625
Test loss: 9.346254348754883
Test accuracy: 0.814453125
Test loss: 14.27621078491211
Test accuracy: 0.7958984375
Test loss: 0.7828835248947144
Test accuracy: 0.8076171875
Test loss: 4.752987861633301
Test accuracy: 0.8310546875
Test loss: 6.615781784057617
Test accuracy: 0.828125
Test loss: 1.4483803510665894
Test accuracy: 0.8125
Test loss: 1.4546937942504883
Test accuracy: 0.802734375
Test loss: 2.5628137588500977
Test accuracy: 0.818359375
Test loss: 2.5403683185577393
