In [None]:
from torchvision.datasets import FashionMNIST
import jax_dataloader.core as jdl
import haiku as hk
import numpy as np
import jax
import jax.numpy as jnp
import optax
from functools import partial
import time

In [None]:
class FlattenAndCast:
    """Dataset transform that converts an image into a float NumPy array.

    Despite the name, no flattening happens here: the result keeps whatever
    shape ``np.array`` derives from the input.
    """

    def __call__(self, pic):
        """Return a new ``float64`` NumPy array built from ``pic``."""
        as_float_array = np.array(pic, dtype=np.float64)
        return as_float_array

In [None]:
# TODO: Add normalization (net_fn currently just rescales pixel values by 1/255).
# Both splits share the same download directory and the same transform.
train_ds = FashionMNIST(
    '/tmp/mnist/',
    train=True,
    transform=FlattenAndCast(),
    download=True,
)
test_ds = FashionMNIST(
    '/tmp/mnist/',
    train=False,
    transform=FlattenAndCast(),
    download=True,
)

In [None]:
def net_fn(imgs: jnp.ndarray):
    """Compute class logits for a batch of single-channel images with a small CNN.

    Intended to be wrapped with ``hk.transform`` (as done where the classifier
    is built), since it instantiates Haiku modules.

    Args:
        imgs: Rank-3 array of shape ``(batch, height, width)``. Values are cast
            to ``float32`` and divided by 255 (assumes 0-255 pixel intensities
            -- TODO confirm against the dataset transform).

    Returns:
        The output of the final ``hk.Linear(10)`` layer: 10 logits per image.
    """
    batch, height, width = imgs.shape
    # Add a trailing channel axis of size 1 for the convolution layers.
    with_channel = imgs.reshape(batch, height, width, 1)
    scaled = with_channel.astype(jnp.float32) / 255.

    # Three stride-2 3x3 convolutions with growing channel counts, each
    # followed by a ReLU. Modules are created in the same order as before so
    # the Haiku parameter names stay the same.
    layers = []
    for channels in (32, 64, 128):
        layers.append(hk.Conv2D(output_channels=channels, kernel_shape=3, stride=2))
        layers.append(jax.nn.relu)

    # Dense classification head.
    layers.extend([
        hk.Flatten(),
        hk.Linear(256),
        jax.nn.relu,
        hk.Linear(10),
    ])

    conv_net = hk.Sequential(layers)
    return conv_net(scaled)

    # Alternative fully-connected model (disabled):
    # mlp = hk.Sequential([
    #     hk.Flatten(),
    #     hk.Linear(512),
    #     jax.nn.relu,
    #     hk.Linear(256),
    #     jax.nn.relu,
    #     hk.Linear(10),
    # ])
    # return mlp(scaled)

In [None]:
def loss(
    params: hk.Params,
    classifier: hk.Transformed,
    imgs: jnp.ndarray,
    labels: jnp.ndarray
):
    """Mean softmax cross-entropy of the classifier over one batch.

    Args:
        params: Parameters handed to ``classifier.apply``.
        classifier: Transformed network; only its ``apply`` function is used,
            called as ``apply(params, imgs)`` (no RNG argument).
        imgs: Batch of input images, forwarded unchanged to the classifier.
        labels: Integer class index for each example in the batch.

    Returns:
        Scalar loss: the per-example cross-entropy averaged over the batch.
    """
    logits = classifier.apply(params, imgs)
    # The optax loss already accepts batched logits ([..., num_classes]) and
    # integer labels ([...]), so wrapping it in jax.vmap is unnecessary.
    per_example = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
    return jnp.mean(per_example)

In [None]:
# TODO: Jax dataloader does not work (the 'jax' backend variant is left
# commented out below); the 'pytorch' backend is used instead.
train_dl = jdl.DataLoader(
    train_ds,
    'pytorch',
    batch_size=128,
    shuffle=True,
)
# test_dl = jdl.DataLoader(test_ds, 'jax', batch_size=32, shuffle=False)

# Wrap net_fn into an init/apply pair whose apply takes no RNG key.
classifier = hk.without_apply_rng(hk.transform(net_fn))
# Adam optimizer with a learning rate of 1e-3.
opt = optax.adam(learning_rate=1e-3)
# Draw one batch from the loader; it is not used for initialization below.
img, label = next(iter(train_dl))
# Initialize the parameters from a dummy all-ones batch of 32 images (28x28).
params = classifier.init(
    jax.random.PRNGKey(42),
    jnp.ones((32, 28, 28)),
)
opt_state = opt.init(params)



In [None]:
@jax.jit
def update(
    params: hk.Params,
    opt_state: optax.OptState,
    imgs: jnp.ndarray,
    labels: jnp.ndarray
):
    """Run one jit-compiled optimization step on a single batch.

    The network (``classifier``) and the optimizer (``opt``) are not
    parameters of this function; they are read from the enclosing notebook
    scope.

    Args:
        params: Current network parameters.
        opt_state: Current optimizer state.
        imgs: Batch of input images passed on to ``loss``.
        labels: Batch of labels passed on to ``loss``.

    Returns:
        A ``(params, opt_state)`` tuple holding the updated parameters and the
        updated optimizer state.
    """
    # Gradient of the loss with respect to its first argument (the parameters).
    gradients = jax.grad(loss)(params, classifier, imgs, labels)
    param_updates, next_opt_state = opt.update(gradients, opt_state)
    next_params = optax.apply_updates(params, param_updates)
    return next_params, next_opt_state

In [None]:
# Number of passes over the training set.
NUM_EPOCHS = 1

for i in range(NUM_EPOCHS):
    # perf_counter is the clock intended for measuring elapsed intervals.
    start_time = time.perf_counter()
    for img, label in train_dl:
        params, opt_state = update(params, opt_state, img, label)
    # JAX dispatches computations asynchronously: without blocking here the
    # timer could stop before the last update steps have actually finished.
    jax.block_until_ready(params)
    epoch_time = time.perf_counter() - start_time
    # Note: the first call to the jitted `update` also includes compilation
    # time, so the first epoch is slower than subsequent ones.
    print(f'Epoch {i} took {epoch_time: .3f} seconds')
    print(f'Per batch: {epoch_time / len(train_dl): .3f} seconds')

Epoch 0 took  12.626 seconds
Per batch:  0.027 seconds
