## Haiku tutorial: training an MLP classifier on MNIST

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
import haiku as hk
import optax


#### Load the MNIST dataset

In [None]:

import torch
import torchvision

# Root directory that torchvision downloads MNIST into and reads it from.
PATH = "data"
# Mini-batch size shared by the train and test DataLoaders.
BATCHSIZE = 128


def custom_transform(x):
    """Turn one 28x28 image into a flat float32 vector of length 784.

    Pixel values are divided by 255, so 8-bit intensities land in [0, 1].
    """
    pixels = np.array(x, dtype=np.float32)
    flat = pixels.reshape(28 * 28)
    return flat / 255.

def custom_collate_fn(batch):
    """Collate a list of ``(image, label)`` samples into NumPy arrays.

    Used as the DataLoaders' ``collate_fn`` so that each batch is a pair of
    plain NumPy arrays that can be passed straight to the JAX update step.

    Args:
        batch: sequence of samples; element 0 of each sample is the image
            (already transformed by the dataset's transform) and element 1
            is its label.

    Returns:
        ``(imgs, labels)``: ``imgs`` stacks the images along a new leading
        batch axis, and ``labels`` is an array of the corresponding labels.
    """
    # Transpose [(img, label), ...] into [(img, img, ...), (label, label, ...)].
    columns = list(zip(*batch))

    imgs = np.stack(columns[0])
    labels = np.array(columns[1])

    return imgs, labels


# Download MNIST into PATH if needed. custom_transform turns every image into
# a flat float32 vector.
train_data = torchvision.datasets.MNIST(
    root=PATH, train=True, transform=custom_transform, download=True)
test_data = torchvision.datasets.MNIST(
    root=PATH, train=False, transform=custom_transform, download=True)


# drop_last=True on the training loader keeps every batch the same shape, so
# the jitted update step does not have to be retraced for a smaller batch.
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=BATCHSIZE, shuffle=True,
    collate_fn=custom_collate_fn, drop_last=True)
# The test loader keeps its final partial batch so that evaluation can see
# every test image; drop_last=True would silently discard the remainder
# (16 images for the 10,000-image MNIST test split with batches of 128).
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=BATCHSIZE, shuffle=False,
    collate_fn=custom_collate_fn, drop_last=False)

#### Define the model's forward pass with Haiku, initialize the parameters, and define the loss, optimizer, and update step

In [None]:
# The network as a plain function of its input; hk.transform (next) turns it
# into a pure pair of init/apply functions.
def forward(x):
    """Map a flattened image (or a batch of them) to 10 class logits.

    Hidden layers have widths 784 and 100, followed by a 10-unit output layer.
    """
    net = hk.nets.MLP(output_sizes=[784, 100, 10])
    logits = net(x)
    return logits

# Wrap the impure Haiku function into a pure pair: `init` creates the
# parameters and `apply` runs the network with given parameters.
forward = hk.transform(forward)

# Initialize the parameters from a single dummy flattened 28x28 example. The
# same params are later applied to whole batches.
rng = jax.random.PRNGKey(42)
x = jnp.ones([28 * 28])
params = forward.init(rng, x)

# The network needs no random key at apply time, so drop the rng argument.
# From here on the call is `forward.apply(params, x)`.
forward = hk.without_apply_rng(forward)

def loss(params, x, y):
    """Mean softmax cross-entropy of the network's logits against labels.

    Args:
        params: Haiku parameters produced by ``forward.init``.
        x: flattened image(s) fed to ``forward.apply``.
        y: integer class labels in ``[0, 10)``.

    Returns:
        Scalar loss averaged over the batch.
    """
    logits = forward.apply(params, x)
    y_onehot = jax.nn.one_hot(y, num_classes=10)
    # Softmax cross-entropy treats the outputs as class scores, which is the
    # appropriate objective for classification (an L2 loss would instead
    # regress the raw logits onto one-hot vectors). Averaging rather than
    # summing keeps the loss scale independent of the batch size.
    return jnp.mean(optax.softmax_cross_entropy(logits, y_onehot))

# Adam optimizer from optax with a fixed learning rate.
optimizer = optax.adam(learning_rate=1e-2)


@jax.jit
def update(params, opt_state, x, y):
    """Run one jitted optimization step on a batch.

    Differentiates ``loss`` w.r.t. ``params``, turns the gradients into
    updates with ``optimizer``, and applies them.

    Returns:
        ``(params, opt_state)`` after the step.
    """
    grads = jax.grad(loss)(params, x, y)
    step_updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, step_updates)
    return new_params, new_opt_state

# Show the structure of the parameter tree, with each array replaced by its
# shape. jax.tree_map is deprecated and removed in recent JAX, so this uses
# jax.tree_util.tree_map.
jax.tree_util.tree_map(lambda leaf: leaf.shape, params)


#### Train

In [None]:
EPOCHS = 3

# Fresh optimizer state for the current parameters, then plain
# mini-batch training for EPOCHS passes over the training set.
opt_state = optimizer.init(params)
for epoch in range(EPOCHS):
    # The loader reshuffles the training data on every pass (shuffle=True).
    for xs, ys in train_loader:
        params, opt_state = update(params, opt_state, xs, ys)


#### Test

In [None]:
# Take the first test batch and pick one image from it at random.
imgs, labels = next(iter(test_loader))
# np.random.randint's upper bound is exclusive, so use the batch length so
# that every index, including the last one, can be drawn.
r_int = np.random.randint(0, len(imgs))
img = imgs[r_int]

print(img.shape)

# A single unbatched example gives one vector of 10 logits.
prediction = forward.apply(params, img)

print(f'predicted: {np.argmax(prediction)}, label: {labels[r_int]}')

import matplotlib.pyplot as plt
# Undo the flattening to display the image in 2-D.
plt.imshow(np.reshape(img, (28, -1)))