### simple (not actually useful) implementation to understand haiku.BatchNorm()

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
import haiku as hk
import optax

#### import CIFAR-10

In [None]:
import torch
import torchvision

# Number of samples per mini-batch handed to both data loaders.
BATCH_SIZE = 16
# Root directory passed to torchvision for the CIFAR-10 files.
PATH = 'data'



def custom_transform(x):
    """Convert an image to a float32 array and rescale it linearly.

    The mapping sends a pixel value of 0 to -1.0 and 255 to 1.0.
    """
    pixels = np.array(x, dtype=np.float32)
    unit_range = pixels / 255.0
    return (unit_range - 0.5) * 2

def custom_collate_fn(batch):
    """Collate a list of (image, label) tuples into two ndarrays.

    Returns:
        A pair ``(imgs, labels)``: the images stacked along a new leading
        axis, and the labels gathered into a single array.
    """
    # Transpose [(img, label), ...] into [(img, ...), (label, ...)].
    columns = [*zip(*batch)]

    targets = np.array(columns[1])
    images = np.stack(columns[0])

    return images, targets


# Both CIFAR-10 splits share every option except the `train` flag.
_dataset_options = dict(root=PATH, transform=custom_transform, download=True)
train_data = torchvision.datasets.CIFAR10(train=True, **_dataset_options)
test_data = torchvision.datasets.CIFAR10(train=False, **_dataset_options)


# Only the training loader shuffles; both collate batches into numpy arrays.
_loader_options = dict(batch_size=BATCH_SIZE, collate_fn=custom_collate_fn)
train_loader = torch.utils.data.DataLoader(train_data, shuffle=True, **_loader_options)
test_loader = torch.utils.data.DataLoader(test_data, shuffle=False, **_loader_options)

# Short display names; looked up by integer label elsewhere in this notebook.
classes = (
    'plane', 'car', 'bird', 'cat',
    'deer', 'dog', 'frog', 'horse', 'ship', 'truck',
)

#### check data

In [None]:
import matplotlib.pyplot as plt

# Pull a single batch and inspect its first sample.
x, y = next(iter(train_loader))
first_image = x[0]

print('x[0].shape, x[0].dtype: ', first_image.shape, first_image.dtype)

# Invert the (v - 0.5) * 2 scaling applied by `custom_transform` for display.
img = first_image * 0.5 + 0.5

print(f'label: {classes[y[0]]}, {y[0]}')
plt.imshow(img)
plt.axis("off")

#### define model

In [None]:
def _batchnet(x, is_training):
    """Forward pass: batch norm -> flatten -> linear layer with one output per class.

    Args:
        x: input batch.
        is_training: forwarded to the batch-norm layer's call.

    Returns:
        The linear layer's output, with ``len(classes)`` units.
    """
    # Modules are built in the same order as they are applied.
    normalize = hk.BatchNorm(create_scale=False, create_offset=False, decay_rate=0.9)
    to_vectors = hk.Flatten()
    classifier = hk.Linear(len(classes))

    normalized = normalize(x, is_training)
    flat = to_vectors(normalized)
    return classifier(flat)

# Turn `_batchnet` into a stateful init/apply pair whose `apply` takes no rng argument.
batchnet = hk.without_apply_rng(hk.transform_with_state(_batchnet))

#### init model

In [None]:
rng = jax.random.PRNGKey(42)

# Dummy single-image batch, used only to build the parameters and state.
x = jnp.ones((1, 32, 32, 3))

# `_batchnet` declares no default for `is_training`, so it must be passed explicitly.
params, batch_state = batchnet.init(rng, x, is_training=True)

# `jax.tree_map` is deprecated and absent from recent JAX releases;
# `jax.tree_util.tree_map` is the spelling available in every version.
print(jax.tree_util.tree_map(lambda leaf: leaf.shape, params))

# Shapes of the state entries (left as the cell's last expression so it is displayed).
jax.tree_util.tree_map(lambda leaf: leaf.shape, batch_state)

#### check forward pass

In [None]:
# Run one real batch through the model in inference mode as a smoke test.
x, y = next(iter(train_loader))

print('x.shape: ', x.shape)

preds, batch_state = batchnet.apply(params, batch_state, x, is_training=False)

print('preds.shape: ', preds.shape)
# `jax.tree_map` is deprecated and absent from recent JAX releases;
# `jax.tree_util.tree_map` is the spelling available in every version.
jax.tree_util.tree_map(lambda leaf: leaf.shape, batch_state)

#### define loss

In [None]:
def loss(params, batch_state, x, y):
    """Mean softmax cross-entropy of the model on one batch.

    Args:
        params: model parameters passed to ``batchnet.apply``.
        batch_state: model state passed to ``batchnet.apply``.
        x: input batch.
        y: integer class labels for the batch.

    Returns:
        A pair ``(mean_loss, new_batch_state)``. The state is returned as an
        auxiliary output so callers can use ``jax.grad(..., has_aux=True)``.
    """
    y_hat, batch_state = batchnet.apply(params, batch_state, x, is_training=True)
    # Take the one-hot depth from the logits rather than hard-coding 10, so the
    # targets always match the model's output width.
    y_onehot = jax.nn.one_hot(y, num_classes=y_hat.shape[-1])
    return jnp.mean(optax.softmax_cross_entropy(y_hat, y_onehot)), batch_state


#### define optimizer and update function

In [None]:
# Step size for Adam, named like the notebook's other constants.
LEARNING_RATE = 1e-3

optimizer = optax.adam(learning_rate=LEARNING_RATE)
opt_state = optimizer.init(params)

@jax.jit
def update(params, batch_state, opt_state, x, y):
    """Perform one optimizer step on a single batch.

    Returns:
        A tuple ``(params, batch_state, opt_state)`` holding the values
        produced by this step.
    """
    # `loss` returns (value, state); has_aux=True hands the state back next to the gradient.
    grad_fn = jax.grad(loss, has_aux=True)
    grads, new_batch_state = grad_fn(params, batch_state, x, y)

    deltas, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, deltas)
    return new_params, new_batch_state, new_opt_state


#### train

In [None]:

EPOCHS = 1
# Each step threads params, batch-norm state and optimizer state through `update`.
for epoch in range(EPOCHS):
    for batch in train_loader:
        xs, ys = batch
        params, batch_state, opt_state = update(
            params, batch_state, opt_state, xs, ys
        )

#### test

In [None]:
import matplotlib.pyplot as plt

# Evaluate on the held-out split: the original drew from `train_loader`,
# which only shows how well the model fits data it was trained on.
imgs, labels = next(iter(test_loader))
img = imgs[0]
# The model is applied to batches, so add a leading batch axis.
im = np.expand_dims(img, axis=0)

print(img.shape)
print(im.shape)

# Inference only: discard the returned state instead of rebinding the global one.
prediction, _ = batchnet.apply(params, batch_state, x=im, is_training=False)

pred = classes[np.argmax(prediction)]
label = classes[labels[0]]

print(f'predicted: {pred}, label: {label}')

# Invert the (v - 0.5) * 2 scaling applied by `custom_transform` for display.
plt.imshow(img*0.5 + 0.5)
plt.axis('off')