In [4]:
!pip install dm-haiku

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting dm-haiku
  Downloading dm_haiku-0.0.7-py3-none-any.whl (342 kB)
[K     |████████████████████████████████| 342 kB 6.7 MB/s 
Collecting jmp>=0.0.2
  Downloading jmp-0.0.2-py3-none-any.whl (16 kB)
Installing collected packages: jmp, dm-haiku
Successfully installed dm-haiku-0.0.7 jmp-0.0.2
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting optax
  Downloading optax-0.1.3-py3-none-any.whl (145 kB)
[K     |████████████████████████████████| 145 kB 6.6 MB/s 
Collecting chex>=0.0.4
  Downloading chex-0.1.3-py3-none-any.whl (72 kB)
[K     |████████████████████████████████| 72 kB 806 kB/s 
Installing collected packages: chex, optax
Successfully installed chex-0.1.3 optax-0.1.3


In [5]:
import dataclasses

import jax
from jax import numpy as jnp
import haiku as hk
import tensorflow_datasets as tfds
import tensorflow as tf

# Dataset setup

In [6]:
# Open the MNIST train/test splits as (image, label) datasets, plus the
# dataset metadata. The data is downloaded on first use.
(ds_train, ds_test), ds_info = tfds.load(
    name="mnist",
    split=["train", "test"],
    as_supervised=True,
    with_info=True,
    shuffle_files=True,
)
ds_info  # shown as the cell output

[1mDownloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...[0m


local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.



Dl Completed...:   0%|          | 0/4 [00:00<?, ? file/s]


[1mDataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.[0m


tfds.core.DatasetInfo(
    name='mnist',
    version=3.0.1,
    description='The MNIST database of handwritten digits.',
    homepage='http://yann.lecun.com/exdb/mnist/',
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    total_num_examples=70000,
    splits={
        'test': 10000,
        'train': 60000,
    },
    supervised_keys=('image', 'label'),
    citation="""@article{lecun2010mnist,
      title={MNIST handwritten digit database},
      author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
      journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist},
      volume={2},
      year={2010}
    }""",
    redistribution_info=,
)

In [7]:
def normalize_img(image, label):
  """Casts `image` to float32 and divides it by 255; `label` is returned untouched."""
  scaled = tf.cast(image, tf.float32) / 255.
  return scaled, label

# NOTE: both names are rebound from themselves, so re-running this cell
# without re-running the load cell stacks the transformations a second time.

# Training pipeline: normalise, cache, shuffle over the whole split,
# batch by 128, prefetch.
ds_train = (
    ds_train
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(ds_info.splits['train'].num_examples)
    .batch(128)
    .prefetch(tf.data.AUTOTUNE)
)

# Test pipeline: no shuffling, and batches are formed before caching.
ds_test = (
    ds_test
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(128)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

In [8]:
# Pull one batch straight from tf.data and inspect its shape and container type.
x, y = next(iter(ds_train))
x.shape, type(x)

(TensorShape([128, 28, 28, 1]), tensorflow.python.framework.ops.EagerTensor)

In [9]:
# Same inspection through tfds.as_numpy, the form later fed to the JAX functions.
x, y = next(iter(tfds.as_numpy(ds_train)))
x.shape, type(x)

((128, 28, 28, 1), numpy.ndarray)

# Model definition (Haiku)

In [48]:
@dataclasses.dataclass
class Tokenizer(hk.Module):
  """Turns one image into a flat sequence of patch embeddings.

  A convolution with kernel 4 and stride 4 projects the image to `embed_dim`
  channels; the two leading axes of the result are then merged into a single
  token axis, giving an array of shape (tokens, embed_dim).
  """
  embed_dim: int

  def __call__(self, img: jnp.array):
    patch_conv = hk.Conv2D(self.embed_dim, kernel_shape=4, stride=4)
    feature_map = patch_conv(img)
    # The three-way unpacking requires a rank-3 map, i.e. an unbatched image.
    rows, cols, channels = feature_map.shape
    return feature_map.reshape((rows * cols, channels))


@dataclasses.dataclass
class MLP(hk.Module):
  """Two-layer feed-forward net: widen by `expand_factor`, GELU, project back.

  Both linear layers take their weight initializer from `init_fn`.
  """
  embed_dim: int
  expand_factor: int
  init_fn: hk.initializers.Initializer

  def __call__(self, tokens: jnp.array):
    hidden_dim = self.embed_dim * self.expand_factor
    hidden = hk.Linear(hidden_dim, w_init=self.init_fn)(tokens)
    activated = jax.nn.gelu(hidden)
    project_back = hk.Linear(self.embed_dim, w_init=self.init_fn)
    return project_back(activated)


@dataclasses.dataclass
class MultiHeadsSelfAtt(hk.Module):
  """Multi-head self-attention in which every token acts as a query."""
  embed_dim: int
  nb_heads: int
  init_fn: hk.initializers.Initializer

  def __call__(self, tokens: jnp.array):
    # A single fused projection yields queries, keys and values side by side.
    fused = hk.Linear(3 * self.embed_dim, w_init=self.init_fn)(tokens)

    # Split the feature axis into (heads, head_dim).
    # NOTE(review): assumes nb_heads divides embed_dim -- not checked here.
    per_head = (-1, self.nb_heads, self.embed_dim // self.nb_heads)
    q, k, v = (
        jnp.reshape(part, per_head)
        for part in jnp.split(fused, indices_or_sections=3, axis=-1)
    )

    attended = _mhsa(q, k, v)

    return hk.Linear(self.embed_dim, w_init=self.init_fn)(attended)


@dataclasses.dataclass
class MultiHeadsClassAtt(hk.Module):
  """Multi-head attention where only the first token acts as the query.

  Keys and values come from every token, so the result is a single row of
  shape (1, embed_dim).
  """
  embed_dim: int
  nb_heads: int
  init_fn: hk.initializers.Initializer

  def __call__(self, tokens: jnp.array):
    fused = hk.Linear(3 * self.embed_dim, w_init=self.init_fn)(tokens)
    q_all, k_all, v_all = jnp.split(fused, indices_or_sections=3, axis=-1)

    # NOTE(review): assumes nb_heads divides embed_dim -- not checked here.
    head_dim = self.embed_dim // self.nb_heads
    # The query projection is computed for all rows, but only row 0 is kept.
    q = jnp.reshape(q_all[0], (1, self.nb_heads, head_dim))
    k = jnp.reshape(k_all, (-1, self.nb_heads, head_dim))
    v = jnp.reshape(v_all, (-1, self.nb_heads, head_dim))

    attended = _mhsa(q, k, v)

    return hk.Linear(self.embed_dim, w_init=self.init_fn)(attended)

@jax.jit
def _mhsa(q, k, v) -> jnp.array:
    embed_dim = q.shape[-1] * q.shape[-2]

    att_logits = jnp.einsum('...thd,...Thd->...htT', q, k)
    # eq. to jnp.matmul(x.transpose(1,0,2), y.transpose(1, 2, 0))
    scale = 1 / jnp.sqrt(embed_dim)
    att = jax.nn.softmax(att_logits * scale)

    o = jnp.einsum("...htT,...Thd->...thd", att, v)
    # eq. to jnp.matmul(x, y.transpose(1, 0, 2)).transpose(1, 0, 2)
    o = jnp.reshape(o, (-1, embed_dim))

    return o


@dataclasses.dataclass
class Block(hk.Module):
  """Pre-norm residual block: an attention sub-layer, then an MLP sub-layer.

  `att_fn` is invoked as `att_fn(embed_dim, nb_heads, init_fn)` to build the
  attention module that is applied to the normalised tokens.
  """
  embed_dim: int
  nb_heads: int
  init_fn: hk.initializers.Initializer
  att_fn: hk.Module
  expand_factor: int

  def __call__(self, tokens):
    norm_kwargs = dict(create_scale=True, create_offset=True)

    # Attention sub-layer with a residual connection. If the attention module
    # returns a single row, the sum below relies on broadcasting.
    normed = hk.LayerNorm(-1, **norm_kwargs)(tokens)
    att_out = self.att_fn(self.embed_dim, self.nb_heads, self.init_fn)(normed)
    after_att = tokens + att_out

    # Feed-forward sub-layer with a residual connection.
    normed = hk.LayerNorm(-1, **norm_kwargs)(after_att)
    mlp_out = MLP(self.embed_dim, self.expand_factor, self.init_fn)(normed)

    return after_att + mlp_out


@dataclasses.dataclass
class PosEmbedding(hk.Module):
  """Adds a learned "pos_emb" parameter, shaped like the input, to the tokens."""

  def __call__(self, tokens: jnp.array):
    pos_emb = hk.get_parameter(
        "pos_emb",
        shape=tokens.shape,
        init=hk.initializers.TruncatedNormal(stddev=0.02),
    )
    return tokens + pos_emb


@dataclasses.dataclass
class ViT(hk.Module):
  """Small vision transformer classifier for a single (unbatched) image.

  Two layouts are supported:

  * `cait=False` (default): a class token is prepended to the patch tokens,
    positional embeddings are added, and `nb_layers` self-attention blocks
    are applied; the class row is used for classification.
  * `cait=True`: the patch tokens get positional embeddings and go through
    `nb_layers - nb_ca` self-attention blocks; then `nb_ca` class-attention
    blocks update a class token that attends to the fixed patch tokens.

  In both cases the final class embedding is layer-normalised and projected
  to `nb_classes` logits.
  """
  embed_dim: int
  expand_factor: int
  nb_layers: int
  nb_heads: int
  nb_classes: int
  nb_ca: int = 1
  cait: bool = False

  def __call__(self, img):
    tokens = Tokenizer(self.embed_dim)(img)

    init_token = hk.initializers.TruncatedNormal(stddev=0.02)
    init_var = hk.initializers.VarianceScaling(2 / self.nb_layers)

    cls_token = hk.get_parameter("cls_token", (1, self.embed_dim), init=init_token)

    if self.cait:
      tokens = PosEmbedding()(tokens)

      for _ in range(self.nb_layers - self.nb_ca):
        tokens = Block(
            self.embed_dim, self.nb_heads, init_var, MultiHeadsSelfAtt, self.expand_factor
        )(tokens)

      for _ in range(self.nb_ca):
        # The patch tokens stay fixed across class-attention layers; only the
        # class row is carried forward. Re-using the whole block output (all
        # rows) as the next class token would duplicate the sequence and make
        # it grow on every iteration when nb_ca > 1.
        ca_input = jnp.concatenate((cls_token, tokens))
        ca_output = Block(
            self.embed_dim, self.nb_heads, init_var, MultiHeadsClassAtt, self.expand_factor
        )(ca_input)
        cls_token = ca_output[:1]

      final_emb = cls_token[0]
    else:
      tokens = jnp.concatenate((cls_token, tokens))
      tokens = PosEmbedding()(tokens)

      for _ in range(self.nb_layers):
        tokens = Block(
            self.embed_dim, self.nb_heads, init_var, MultiHeadsSelfAtt, self.expand_factor
        )(tokens)

      final_emb = tokens[0]

    return hk.Linear(
        self.nb_classes, w_init=init_var
    )(hk.LayerNorm(-1, create_scale=True, create_offset=True)(final_emb))


rng = jax.random.PRNGKey(1)


def _vit(x):
  """Builds the class-attention variant of the model and applies it to one image."""
  model = ViT(
      embed_dim=10,
      expand_factor=4,
      nb_layers=3,
      nb_heads=2,
      nb_classes=10,
      nb_ca=1,
      cait=True,
  )
  return model(x)


# Convert the Haiku function into an init/apply pair whose `apply` takes no
# rng argument, then create the parameters from a dummy 28x28x1 image.
vit = hk.without_apply_rng(hk.transform(_vit))
params = vit.init(rng, jnp.ones((28, 28, 1)))

In [11]:
# Sanity check: output shape of the forward pass on one unbatched dummy image.
vit.apply(params, jnp.ones((28, 28, 1))).shape

(10,)

In [12]:
# Vectorise the single-image forward pass over a leading batch axis of the
# images; the parameters are shared across the batch (not mapped).
batch_apply = jax.vmap(vit.apply, in_axes=(None, 0))

batch_apply(params, jnp.ones((32, 28, 28, 1))).shape

(32, 10)

In [13]:
# Time the batched forward pass on 4096 dummy images; `batch_apply` itself is
# not wrapped in jax.jit here.
x = jnp.ones((4096, 28, 28, 1))

# block_until_ready() makes the timing wait for the result to be computed.
%timeit batch_apply(params, x).block_until_ready()

The slowest run took 65.69 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 83.5 ms per loop


In [14]:
# Jit-compile the batched forward pass and time it on the same input.
j_batch_apply = jax.jit(batch_apply)

%timeit j_batch_apply(params, x).block_until_ready()

The slowest run took 152.18 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 21.4 ms per loop


# Training

In [15]:
# Wrap the tf.data pipelines so that iterating them yields NumPy batches.
loader_train = tfds.as_numpy(ds_train)
loader_test = tfds.as_numpy(ds_test)

In [16]:
# Scratch check: summing (one-hot * per-class values) over the last axis
# leaves one value per example.
jnp.sum(jax.nn.one_hot([0, 1], 10) * jnp.ones((2, 10)), -1).shape

(2,)

In [50]:
def loss_fn(params, x, y):
  """Mean cross-entropy between the model's predictions and integer labels.

  Args:
    params: model parameters passed to `j_batch_apply`.
    x: batch of images.
    y: batch of class indices, one per image.

  Returns:
    Scalar negative mean log-likelihood of the true classes.
  """
  logits = j_batch_apply(params, x)
  # Size the one-hot encoding from the logits rather than hard-coding the
  # number of classes, so the loss follows the model's output width.
  labels = jax.nn.one_hot(y, logits.shape[-1])

  log_probs = jax.nn.log_softmax(logits, axis=-1)
  log_likelihood = jnp.mean(jnp.sum(labels * log_probs, axis=-1))
  return -log_likelihood


@jax.jit
def update(params, x, y, learning_rate=0.05):
  """Performs one plain gradient-descent step on a batch.

  Args:
    params: current model parameters (a pytree).
    x: batch of images.
    y: batch of labels.
    learning_rate: step size; defaults to the previously hard-coded 0.05.

  Returns:
    A `(new_params, loss)` pair, where `loss` was evaluated at the incoming
    parameters.
  """
  loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
  # `jax.tree_util.tree_map` is the supported spelling; the top-level
  # `jax.tree_map` alias is deprecated.
  params = jax.tree_util.tree_map(
      lambda p, g: p - learning_rate * g,
      params, grads
  )
  return params, loss


@jax.jit
def _eval(params, x, y, result, count):
  """Adds one batch's number of correct predictions and size to running totals.

  Returns the updated `(result, count)` pair.
  """
  predictions = jnp.argmax(j_batch_apply(params, x), axis=-1)
  nb_correct = jnp.sum(predictions == y)
  return result + nb_correct, count + len(y)

def eval(params, loader):
  """Returns the accuracy, in percent, of `params` over every batch of `loader`.

  Note that this name shadows Python's built-in `eval` inside the notebook.
  """
  correct = jnp.array(0.)
  seen = jnp.array(0.)

  for images, targets in loader:
    correct, seen = _eval(params, images, targets, correct, seen)

  return 100 * correct / seen

In [22]:
# Dummy batch for timing. Labels are class indices, so use an integer dtype:
# jit specialises on dtype, and float labels would make the timed call trace
# a different signature than integer-labelled batches.
x = jnp.ones((128, 28, 28, 1))
y = jnp.ones((128,), dtype=jnp.int32)

In [30]:
%timeit update(params, x, y)

100 loops, best of 5: 1.83 ms per loop


In [38]:
%timeit eval(params, loader_test)

10 loops, best of 5: 100 ms per loop


In [51]:
# Re-create the parameters from `rng` right before training.
params = vit.init(rng, jnp.ones((28, 28, 1)))

In [52]:
EPOCHS = 10

for epoch in range(EPOCHS):
  # Running sum of the per-batch losses (kept as a JAX scalar) and batch count.
  mean_loss, c = jnp.array(0.), 0
  for x, y in loader_train:
    params, loss = update(params, x, y)
    mean_loss += loss
    c += 1

  # Convert to a Python float before rounding: rounding the float32 JAX scalar
  # directly prints values such as 1.9896999597549438 instead of 1.9897.
  print(f"[{epoch}] Mean train loss: {round(float(mean_loss / c), 5)}")
  print(f"[{epoch}] Train accuracy: {round(float(eval(params, loader_train)), 2)}")

# Test accuracy is reported once, after the last epoch.
print(f"[{epoch}] Test accuracy: {round(float(eval(params, loader_test)), 2)}")

[0] Mean train loss: 1.9896999597549438
[0] Train accuracy: 41.35
[1] Mean train loss: 1.2107899188995361
[1] Train accuracy: 64.65
[2] Mean train loss: 0.7491499781608582
[2] Train accuracy: 80.72
[3] Mean train loss: 0.5136500000953674
[3] Train accuracy: 87.49
[4] Mean train loss: 0.40366998314857483
[4] Train accuracy: 90.06
[5] Mean train loss: 0.343860000371933
[5] Train accuracy: 91.16
[6] Mean train loss: 0.29712000489234924
[6] Train accuracy: 91.75
[7] Mean train loss: 0.27177000045776367
[7] Train accuracy: 92.84
[8] Mean train loss: 0.2602999806404114
[8] Train accuracy: 92.23
[9] Mean train loss: 0.23503999412059784
[9] Train accuracy: 93.26
[9] Test accuracy: 93.25
