In [None]:
from ipynb_path import *

In [None]:
from torchvision.datasets import FashionMNIST
import jax_dataloader as jdl
import haiku as hk
import numpy as np
import jax
import jax.numpy as jnp
import optax
from functools import partial
import time
from datasets import load_dataset
from torch.utils.data import DataLoader

In [None]:
class FlattenAndCast(object):
  """Dataset transform that turns a sample into a float NumPy array.

  Despite the name, nothing is flattened here: the sample is only handed to
  ``np.array`` with ``dtype=float``, so the result keeps whatever shape NumPy
  infers from the input.
  """

  def __call__(self, pic):
    """Return ``pic`` converted to a new ``np.ndarray`` with ``dtype=float``."""
    as_float = np.array(pic, dtype=float)
    return as_float

In [None]:
# Fashion-MNIST exposed through three dataset types so that the same training
# loop can be timed against different dataset/backend combinations.

# torchvision datasets (downloaded to /tmp/mnist/ when missing);
# FlattenAndCast is registered as the per-sample transform.
train_ds_torch = FashionMNIST(
    '/tmp/mnist/', download=True, transform=FlattenAndCast(), train=True)
test_ds_torch = FashionMNIST(
    '/tmp/mnist/', download=True, transform=FlattenAndCast(), train=False)

# jax_dataloader ArrayDatasets built from the torchvision datasets' `.data`
# and `.targets` tensors converted to NumPy.
# NOTE(review): presumably `.data` holds the raw images, i.e. the transform
# above is not applied to these arrays — confirm.
train_ds_jax = jdl.ArrayDataset(
    train_ds_torch.data.numpy(), train_ds_torch.targets.numpy())
test_ds_jax = jdl.ArrayDataset(
    test_ds_torch.data.numpy(), test_ds_torch.targets.numpy())

# Hugging Face `datasets` train/test splits of the same dataset.
train_ds_hf = load_dataset('fashion_mnist', split='train')
test_ds_hf = load_dataset('fashion_mnist', split='test')



In [None]:
def net_fn(imgs: jnp.ndarray):
    """Small convolutional classifier producing 10 logits per image.

    This function builds Haiku modules, so it is meant to be wrapped with
    ``hk.transform`` (as ``init`` does) rather than called directly.

    Args:
        imgs: A rank-3 array unpacked as ``(batch, height, width)``; any other
            rank fails at the shape unpacking.

    Returns:
        The output of the final ``hk.Linear(10)`` layer.
    """
    batch, height, width = imgs.shape
    # Add a trailing channel axis, then cast to float32 and divide by 255.
    with_channel = imgs.reshape(batch, height, width, 1)
    scaled = with_channel.astype(jnp.float32) / 255.

    # Modules are created in the same order as they are applied.
    layers = [
        hk.Conv2D(32, 3, 2),
        jax.nn.relu,
        hk.Conv2D(64, 3, 2),
        jax.nn.relu,
        hk.Conv2D(128, 3, 2),
        jax.nn.relu,
        hk.Flatten(),
        hk.Linear(256),
        jax.nn.relu,
        hk.Linear(10),
    ]
    network = hk.Sequential(layers)
    return network(scaled)


# Short alias for optax's softmax cross-entropy that takes integer class
# labels; used by loss().
optax_cross_entropy = optax.softmax_cross_entropy_with_integer_labels

def loss(
    params: hk.Params,
    classifier: hk.Transformed,
    imgs: jnp.ndarray,
    labels: jnp.ndarray
):
    """Mean softmax cross-entropy of ``classifier`` on one batch.

    Args:
        params: Parameters passed to ``classifier.apply``.
        classifier: Transformed network; only its ``apply`` is used.
        imgs: Batch of inputs forwarded unchanged to ``classifier.apply``.
        labels: Integer class labels, one per example in ``imgs``.

    Returns:
        A scalar: the per-example cross-entropy averaged over the batch.
    """
    logits = classifier.apply(params, imgs)
    # The optax loss already operates over leading batch dimensions, so it is
    # applied to the whole batch directly; wrapping it in jax.vmap was
    # redundant.
    per_example = optax_cross_entropy(logits, labels=labels)
    return jnp.mean(per_example)

In [None]:
def init():
    """Create the classifier, optimizer and their initial states.

    The network is ``net_fn`` wrapped with ``hk.transform`` and
    ``hk.without_apply_rng``; parameters are initialised with
    ``PRNGKey(42)`` on an all-ones batch of shape ``(32, 28, 28)``. The
    optimizer is ``optax.adam(1e-3)``.

    Returns:
        A ``(classifier, opt, params, opt_state)`` tuple.
    """
    rng_key = jax.random.PRNGKey(42)
    dummy_batch = jnp.ones((32, 28, 28))

    classifier = hk.without_apply_rng(hk.transform(net_fn))
    params = classifier.init(rng_key, dummy_batch)

    opt = optax.adam(1e-3)
    opt_state = opt.init(params)
    return classifier, opt, params, opt_state

In [None]:
@partial(jax.jit, static_argnums=(2, 3))
def update(
    params: hk.Params,
    opt_state: optax.OptState,
    classifier: hk.Transformed,
    opt: optax.GradientTransformation,
    imgs: jnp.ndarray,
    labels: jnp.ndarray
):
    """Run one jitted optimisation step on a single batch.

    ``classifier`` and ``opt`` are declared static for ``jax.jit`` (argument
    positions 2 and 3), so they must be hashable; passing different objects
    for them triggers a new compilation.

    Args:
        params: Current model parameters.
        opt_state: Current optimizer state.
        classifier: Transformed network handed to ``loss``.
        opt: Optax gradient transformation used for the update.
        imgs: Batch of inputs.
        labels: Integer class labels for ``imgs``.

    Returns:
        A ``(params, opt_state)`` tuple holding the updated values.
    """
    grads = jax.grad(loss)(params, classifier, imgs, labels)
    # Hand the current params to the optimizer as well: transformations such
    # as weight decay need them, while those that do not simply ignore them.
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state

In [None]:
def _numpy_collate(batch):
    if isinstance(batch[0], (np.ndarray, jax.Array)):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        transposed = zip(*batch)
        return [_numpy_collate(samples) for samples in transposed]
    elif isinstance(batch[0], dict):
        return {key: _numpy_collate([d[key] for d in batch]) for key in batch[0]}
    else:
        return np.array(batch)


In [None]:
def get_img_labels(batch):
    """Extract ``(imgs, labels)`` from a batch in one of the supported layouts.

    Supported layouts:

    * a dict with ``'image'`` and ``'label'`` keys;
    * a tuple or list whose first element is such a dict (only that first
      element is read);
    * any other tuple or list, which is unpacked as exactly two items.

    Raises:
        ValueError: If ``batch`` is not a tuple, list or dict.
    """
    if isinstance(batch, dict):
        return batch['image'], batch['label']

    if not isinstance(batch, (tuple, list)):
        raise ValueError(f'Unknown batch type: {type(batch)}', )

    head = batch[0]
    if isinstance(head, dict):
        return head['image'], head['label']

    imgs, labels = batch
    return imgs, labels

def train(
    train_ds,
    backend: str,
    batch_size: int,
    shuffle: bool = True,
    n_epochs: int = 1
):
    """Train a freshly initialised classifier and report wall-clock timings.

    A ``jdl.DataLoader`` is built over ``train_ds`` and iterated for
    ``n_epochs`` epochs, calling ``update`` on every batch. Per-epoch and
    total timings are printed.

    Args:
        train_ds: Dataset handed to ``jdl.DataLoader``.
        backend: Backend name forwarded to ``jdl.DataLoader``.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader shuffles the data.
        n_epochs: Number of passes over the loader.

    Returns:
        Total training time in seconds.

    Raises:
        AssertionError: If a batch of images is not ``(N, 28, 28)`` or the
            number of samples seen differs from
            ``n_epochs * len(train_ds)``.
    """
    train_dl = jdl.DataLoader(
        train_ds, backend=backend, batch_size=batch_size, shuffle=shuffle)
    classifier, opt, params, opt_state = init()

    # Only the sample count is tracked; keeping every batch around just to
    # check a shape would hold the whole dataset in memory, and a list that
    # grows across epochs would fail the final check for n_epochs > 1.
    n_samples_seen = 0

    train_start_time = time.time()
    for i in range(n_epochs):
        epoch_start_time = time.time()
        for batch in train_dl:
            imgs, labels = get_img_labels(batch)
            assert tuple(imgs.shape[1:]) == (28, 28)
            n_samples_seen += imgs.shape[0]

            params, opt_state = update(
                params, opt_state, classifier, opt, imgs, labels)

        # JAX dispatches work asynchronously; wait for the last update to
        # finish so the measured time includes the actual computation.
        jax.tree_util.tree_map(lambda leaf: leaf.block_until_ready(), params)

        epoch_time = time.time() - epoch_start_time
        print(f'Epoch {i} took {epoch_time: .3f} seconds')
        print(f'Per batch: {epoch_time / len(train_dl): .3f} seconds')

    train_time = time.time() - train_start_time
    print(f'Training took {train_time: .3f} seconds')

    # Every sample must have been visited exactly once per epoch.
    assert n_samples_seen == n_epochs * len(train_ds)
    return train_time

In [None]:
# Benchmark one dataset/backend pairing at a time with batch size 128; the
# commented-out calls are the other pairings. As the last expression of the
# cell, the returned training time (in seconds) is displayed.
# train(train_ds_torch, 'pytorch', 128)
train(train_ds_jax, 'jax', 128)
# train(train_ds_jax, 'pytorch', 128)
# train(train_ds_hf, 'jax', 128)
# train(train_ds_hf.with_format('jax'), 'pytorch', 128)


Epoch 0 took  10.049 seconds
Per batch:  0.021 seconds
Training took  10.050 seconds
(60000, 28, 28)


10.049551725387573