### Convolutional Neural Network on CIFAR-10

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
import haiku as hk
import optax

#### import CIFAR-10

In [None]:
import torch
import torchvision

# Directory passed as `root` to the CIFAR-10 dataset constructors.
PATH = 'data'
# Number of samples per batch, used by both the train and the test DataLoader.
BATCH_SIZE = 16



def custom_transform(x):
    """Convert ``x`` to a float32 array rescaled linearly from [0, 255] to [-1, 1].

    Args:
        x: anything ``np.array`` accepts (used here as the torchvision
            ``transform``, so presumably a PIL image -- confirm against caller).

    Returns:
        A float32 ndarray of the same shape, where 0 maps to -1 and 255 maps to 1.
    """
    unit_range = np.array(x, dtype=np.float32) / 255.0  # values in [0, 1]
    return (unit_range - 0.5) * 2

def custom_collate_fn(batch):
    """Collate a list of ``(image, label)`` samples into two NumPy arrays.

    Args:
        batch: sequence of tuples whose first item is an image array and
            whose second item is the corresponding label.

    Returns:
        A tuple ``(imgs, labels)``: the images stacked along a new leading
        axis, and the labels as an ndarray.
    """
    # Transpose [(img, label), ...] into (all images, all labels).
    columns = tuple(zip(*batch))

    labels = np.array(columns[1])
    imgs = np.stack(columns[0])
    return imgs, labels


# CIFAR-10 train and test splits; both are stored under PATH and preprocessed
# by custom_transform.
train_data = torchvision.datasets.CIFAR10(
    root=PATH,
    train=True,
    transform=custom_transform,
    download=True,
)
test_data = torchvision.datasets.CIFAR10(
    root=PATH,
    train=False,
    transform=custom_transform,
    download=True,
)

# Loaders yield (imgs, labels) NumPy arrays via custom_collate_fn. Only the
# training loader shuffles its samples.
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=custom_collate_fn,
)
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=custom_collate_fn,
)

# Human-readable class names, indexed by the integer label
# (assumed to follow CIFAR-10's label order).
classes = (
    'plane',
    'car',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
)

#### check data

In [None]:
import matplotlib.pyplot as plt

# Pull a single batch from the training loader for a visual sanity check.
x, y = next(iter(train_loader))

print('x[0].shape: ', x[0].shape)

# Undo the scaling applied by custom_transform: map [-1, 1] back to [0, 1]
# before handing the image to imshow.
img = x[0] * 0.5 + 0.5

# Show both the class name and the raw integer label of the first sample.
print(f'label: {classes[y[0]]}, {y[0]}')
plt.imshow(img)
plt.axis("off")

#### define model

In [None]:
class CNN(hk.Module):
    """Small convolutional classifier: two 3x3 conv layers and a linear head.

    Args:
        num_classes: number of output logits. When ``None`` (the default) it
            is resolved to ``len(classes)`` at call time, which matches the
            previous hard-wired behaviour.
        name: optional Haiku module name, forwarded to ``hk.Module`` as is
            conventional for Haiku modules.
    """

    def __init__(self, num_classes=None, name=None):
        super().__init__(name=name)
        self.num_classes = num_classes

    def __call__(self, x_in):
        """Return unnormalised class scores (logits) for a batch of images.

        Args:
            x_in: batch of images; this notebook feeds arrays of shape
                ``(batch, 32, 32, 3)``.

        Returns:
            Array of logits with one column per class.
        """
        # Resolve lazily so the default follows the module-level `classes`.
        num_classes = len(classes) if self.num_classes is None else self.num_classes

        x = hk.Conv2D(output_channels=32, kernel_shape=(3, 3), padding="SAME")(x_in)
        x = jax.nn.relu(x)
        x = hk.Conv2D(output_channels=16, kernel_shape=(3, 3), padding="SAME")(x)
        x = jax.nn.relu(x)
        x = hk.Flatten()(x)
        x = hk.Linear(num_classes)(x)
        # No softmax here: optax.softmax_cross_entropy applies it inside the loss.
        return x

def _conv_net(x):
    """Forward function for ``hk.transform``: build a ``CNN`` and apply it to ``x``."""
    return CNN()(x)

# Turn the Haiku forward function into a pure (init, apply) pair and wrap it
# so that `apply` takes no rng argument.
conv_net = hk.without_apply_rng(hk.transform(_conv_net))

#### init model

In [None]:
rng = jax.random.PRNGKey(42)

# Dummy batch holding one 32x32 image with 3 channels; it is used only to
# trace the network and create the parameters.
x = jnp.ones((1, 32, 32, 3))
params = conv_net.init(rng, x)

# Show the shape of every parameter array. `jax.tree_map` is deprecated and
# has been removed in recent JAX releases; `jax.tree_util.tree_map` is the
# stable spelling.
jax.tree_util.tree_map(lambda p: p.shape, params)

#### check forward pass

In [None]:
# Run one training batch through the network with the current parameters to
# confirm the shapes of the input and the output.
x, _ = next(iter(train_loader))
print('x.shape: ', x.shape)

preds = conv_net.apply(params, x)
print('preds.shape: ', preds.shape)

#### define loss

In [None]:
def loss(params, x, y):
    """Return the mean softmax cross-entropy of the network on one batch.

    Args:
        params: Haiku parameters for ``conv_net``.
        x: batch of input images passed to ``conv_net.apply``.
        y: integer class labels, one per image.

    Returns:
        Scalar cross-entropy averaged over the batch.
    """
    logits = conv_net.apply(params, x)
    # Take the number of classes from the logits rather than hard-coding 10,
    # so the one-hot targets always match the width of the model head.
    y_onehot = jax.nn.one_hot(y, num_classes=logits.shape[-1])
    return jnp.mean(optax.softmax_cross_entropy(logits, y_onehot))


#### define optimizer and update function

In [None]:
# Adam with a learning rate of 1e-3. The magnitude of the learning rate really
# mattered here: with 1e-2 the training made no progress at all.
optimizer = optax.adam(learning_rate=1e-3)
# Optimizer state is created from the initial model parameters.
opt_state = optimizer.init(params)

@jax.jit
def update(params, opt_state, x, y):
    """Perform a single optimisation step on one batch.

    Args:
        params: current model parameters.
        opt_state: current state of ``optimizer``.
        x: batch of input images.
        y: labels for ``x``.

    Returns:
        A tuple ``(params, opt_state)`` holding the values after the step.
    """
    grads = jax.grad(loss)(params, x, y)
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state


#### train

In [None]:
# Number of full passes over the training loader.
EPOCHS = 1
for epoch in range(EPOCHS):
    for xs, ys in train_loader:
        # One optimiser step per mini-batch; the returned parameters and
        # optimiser state replace the previous ones.
        params, opt_state = update(params, opt_state, xs, ys)

#### test

In [None]:
# Use test_loader (not train_loader) so the prediction is checked on data the
# model was not trained on. test_loader does not shuffle, so this is always
# the first test batch.
imgs, labels = next(iter(test_loader))
img = imgs[0]
# Add a leading batch axis so the single image is passed as a batch of one.
im = np.expand_dims(img, axis=0)

print(img.shape)
print(im.shape)

prediction = conv_net.apply(params, im)

# The index of the largest logit is the predicted class.
pred = classes[np.argmax(prediction)]
label = classes[labels[0]]

print(f'predicted: {pred}, label: {label}')

import matplotlib.pyplot as plt
# Map the image from [-1, 1] back to [0, 1] for display.
plt.imshow(img*0.5 + 0.5)
plt.axis('off')