In [1]:
# Colab-specific TPU setup helper shipped with JAX; run before the other cells use JAX.
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

In [2]:
# Install Haiku (imported as `haiku` in the next cell). NOTE(review): the version is not pinned.
!pip install --quiet dm-haiku

In [3]:
import dataclasses
import functools

import jax
from jax import numpy as jnp
import haiku as hk
import tensorflow_datasets as tfds
import tensorflow as tf

# Datasets setup

In [4]:
# Load the MNIST train and test splits as (image, label) pairs, plus the dataset metadata.
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
# Show the metadata as the cell output.
ds_info

tfds.core.DatasetInfo(
    name='mnist',
    version=3.0.1,
    description='The MNIST database of handwritten digits.',
    homepage='http://yann.lecun.com/exdb/mnist/',
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    total_num_examples=70000,
    splits={
        'test': 10000,
        'train': 60000,
    },
    supervised_keys=('image', 'label'),
    citation="""@article{lecun2010mnist,
      title={MNIST handwritten digit database},
      author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
      journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist},
      volume={2},
      year={2010}
    }""",
    redistribution_info=,
)

In [5]:
def normalize_img(image, label):
  """Cast the image to `float32` and divide by 255; the label passes through unchanged."""
  scaled = tf.cast(image, tf.float32) / 255.
  return scaled, label

# Training pipeline: normalise, cache, shuffle over the whole split, then
# emit full batches of 128 only (the remainder is dropped).
ds_train = (
    ds_train
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(ds_info.splits['train'].num_examples)
    .batch(128, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

# Test pipeline: no shuffling; batches are formed before caching and the
# final partial batch is kept.
ds_test = (
    ds_test
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(128)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

In [6]:
# Peek at one training batch straight from the tf.data pipeline.
x, y = next(iter(ds_train))
x.shape, type(x)

(TensorShape([128, 28, 28, 1]), tensorflow.python.framework.ops.EagerTensor)

In [7]:
# Same peek through `tfds.as_numpy`: the batch now arrives as a NumPy array (see the output).
x, y = next(iter(tfds.as_numpy(ds_train)))
x.shape, type(x)

((128, 28, 28, 1), numpy.ndarray)

# Haiku

In [8]:
@dataclasses.dataclass
class Tokenizer(hk.Module):
  """Turns an image into a sequence of patch tokens.

  A convolution with kernel 4 and stride 4 embeds the image into `embed_dim`
  channels; the two spatial axes are then flattened into one token axis.
  """
  embed_dim: int

  def __call__(self, img: jnp.array):
    patchify = hk.Conv2D(self.embed_dim, kernel_shape=4, stride=4)
    patches = patchify(img)
    # The unpacking requires a rank-3 (unbatched) convolution output.
    rows, cols, channels = patches.shape
    return jnp.reshape(patches, (rows * cols, channels))


@dataclasses.dataclass
class MLP(hk.Module):
  """Two-layer feed-forward network: expand, GELU, project back to `embed_dim`."""
  embed_dim: int
  expand_factor: int
  init_fn: hk.initializers.Initializer

  def __call__(self, tokens: jnp.array):
    hidden_dim = self.embed_dim * self.expand_factor
    hidden = hk.Linear(hidden_dim, w_init=self.init_fn)(tokens)
    activated = jax.nn.gelu(hidden)
    project = hk.Linear(self.embed_dim, w_init=self.init_fn)
    return project(activated)


@dataclasses.dataclass
class MultiHeadsSelfAtt(hk.Module):
  """Multi-head self-attention: every token produces a query, a key and a value."""
  embed_dim: int
  nb_heads: int
  init_fn: hk.initializers.Initializer

  def __call__(self, tokens: jnp.array):
    # Target shape that splits the feature axis into (heads, per-head features).
    head_shape = (-1, self.nb_heads, self.embed_dim // self.nb_heads)

    # One fused projection yields queries, keys and values side by side.
    fused = hk.Linear(3 * self.embed_dim, w_init=self.init_fn)(tokens)
    q, k, v = (
        jnp.reshape(part, head_shape)
        for part in jnp.split(fused, indices_or_sections=3, axis=-1)
    )

    attended = _mhsa(q, k, v)

    out_proj = hk.Linear(self.embed_dim, w_init=self.init_fn)
    return out_proj(attended)


@dataclasses.dataclass
class MultiHeadsClassAtt(hk.Module):
  """Multi-head attention in which only the first token issues a query.

  Keys and values come from every token; the result therefore has a single
  row of `embed_dim` features.
  """
  embed_dim: int
  nb_heads: int
  init_fn: hk.initializers.Initializer

  def __call__(self, tokens: jnp.array):
    head_dim = self.embed_dim // self.nb_heads

    fused = hk.Linear(3 * self.embed_dim, w_init=self.init_fn)(tokens)
    q_all, k_all, v_all = jnp.split(fused, indices_or_sections=3, axis=-1)

    # Keep the query of token 0 only; keys and values keep the whole sequence.
    q = jnp.reshape(q_all[0], (1, self.nb_heads, head_dim))
    k = jnp.reshape(k_all, (-1, self.nb_heads, head_dim))
    v = jnp.reshape(v_all, (-1, self.nb_heads, head_dim))

    attended = _mhsa(q, k, v)

    out_proj = hk.Linear(self.embed_dim, w_init=self.init_fn)
    return out_proj(attended)

@jax.jit
def _mhsa(q, k, v) -> jnp.array:
    embed_dim = q.shape[-1] * q.shape[-2]

    att_logits = jnp.einsum('...thd,...Thd->...htT', q, k)
    # eq. to jnp.matmul(x.transpose(1,0,2), y.transpose(1, 2, 0))
    scale = 1 / jnp.sqrt(embed_dim)
    att = jax.nn.softmax(att_logits * scale)

    o = jnp.einsum("...htT,...Thd->...thd", att, v)
    # eq. to jnp.matmul(x, y.transpose(1, 0, 2)).transpose(1, 0, 2)
    o = jnp.reshape(o, (-1, embed_dim))

    return o


@dataclasses.dataclass
class Block(hk.Module):
    """Pre-norm transformer block: attention, then MLP, each wrapped in a residual add.

    `att_fn` is the attention module class to instantiate. If it returns a
    single row (as `MultiHeadsClassAtt` does), the first residual add
    broadcasts that row over every input token.
    """
    embed_dim: int
    nb_heads: int
    init_fn: hk.initializers.Initializer
    att_fn: hk.Module
    expand_factor: int

    def __call__(self, tokens):
      # Attention branch.
      normed = hk.LayerNorm(-1, create_scale=True, create_offset=True)(tokens)
      attention = self.att_fn(self.embed_dim, self.nb_heads, self.init_fn)
      after_att = tokens + attention(normed)

      # Feed-forward branch.
      normed = hk.LayerNorm(-1, create_scale=True, create_offset=True)(after_att)
      feed_forward = MLP(self.embed_dim, self.expand_factor, self.init_fn)
      return after_att + feed_forward(normed)


@dataclasses.dataclass
class PosEmbedding(hk.Module):
  """Adds a learned embedding, shaped exactly like the incoming tokens, to them."""

  def __call__(self, tokens: jnp.array):
    pos_emb = hk.get_parameter(
        "pos_emb",
        tokens.shape,
        init=hk.initializers.TruncatedNormal(stddev=0.02),
    )
    return tokens + pos_emb


@dataclasses.dataclass
class ViT(hk.Module):
  """Small Vision Transformer classifier with an optional class-attention head.

  With `cait=False` the class token is concatenated in front of the patch
  tokens, the positional embedding is added, and everything goes through
  `nb_layers` self-attention blocks.

  With `cait=True` the patch tokens alone receive the positional embedding and
  go through `nb_layers - nb_ca` self-attention blocks; `nb_ca`
  class-attention blocks then refine the class token while the patch tokens
  stay fixed.

  In both cases the final class token is layer-normalised and linearly
  projected to `nb_classes` logits.
  """
  embed_dim: int      # width of every token
  expand_factor: int  # hidden-width multiplier forwarded to each block's MLP
  nb_layers: int      # total number of blocks
  nb_heads: int       # attention heads per block
  nb_classes: int     # number of output logits
  nb_ca: int = 1      # number of class-attention blocks (only read when `cait`)
  cait: bool = False  # select the class-attention variant

  def __call__(self, img):
    tokens = Tokenizer(self.embed_dim)(img)

    init_token = hk.initializers.TruncatedNormal(stddev=0.02)
    init_var = hk.initializers.VarianceScaling(2 / self.nb_layers)

    cls_token = hk.get_parameter("cls_token", (1, self.embed_dim), init=init_token)

    if self.cait:
      tokens = PosEmbedding()(tokens)

      for _ in range(self.nb_layers - self.nb_ca):
        tokens = Block(
            self.embed_dim, self.nb_heads, init_var, MultiHeadsSelfAtt, self.expand_factor
        )(tokens)

      for _ in range(self.nb_ca):
        # Build the block input without overwriting `tokens`: the patch tokens
        # must stay as they are for the next class-attention block.
        cls_and_patches = jnp.concatenate((cls_token, tokens))
        # The block returns one row per input token; only row 0 is the
        # updated class token, so keep just that row (shape (1, embed_dim)).
        cls_token = Block(
            self.embed_dim, self.nb_heads, init_var, MultiHeadsClassAtt, self.expand_factor
        )(cls_and_patches)[:1]

      final_emb = cls_token[0]
    else:
      tokens = jnp.concatenate((cls_token, tokens))
      tokens = PosEmbedding()(tokens)

      for _ in range(self.nb_layers):
        tokens = Block(
            self.embed_dim, self.nb_heads, init_var, MultiHeadsSelfAtt, self.expand_factor
        )(tokens)

      # The class token was concatenated first, so it sits at index 0.
      final_emb = tokens[0]

    return hk.Linear(
        self.nb_classes, w_init=init_var
    )(hk.LayerNorm(-1, create_scale=True, create_offset=True)(final_emb))


def _vit(x):
  """Apply the notebook's fixed ViT configuration (class-attention variant) to `x`."""
  model = ViT(
      embed_dim=10,
      expand_factor=4,
      nb_layers=3,
      nb_heads=2,
      nb_classes=10,
      nb_ca=1,
      cait=True,
  )
  return model(x)

In [9]:
# Number of devices visible to JAX; later cells use it to replicate the parameters.
nb_devices = len(jax.devices())
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [10]:
rng = jax.random.PRNGKey(1)
# NOTE(review): `keys` is not used by any cell visible here.
keys = jax.random.split(rng, len(jax.devices()))

# Turn the model function into an init/apply pair whose apply takes no RNG argument.
vit = hk.without_apply_rng(hk.transform(_vit))

# Initialise the parameters from a single unbatched 28x28x1 input.
dummy_x = jnp.ones((28, 28, 1))
params = vit.init(rng, dummy_x)

In [11]:
# Stack `nb_devices` copies of every parameter along a new leading axis so
# that `pmap` can hand one copy to each device.
# `jax.tree_util.tree_map` is used because the `jax.tree_map` alias is
# deprecated and absent from newer JAX releases.
dist_params = jax.tree_util.tree_map(lambda x: jnp.array([x] * nb_devices), params)

In [12]:
# vmap over the leading axis of the inputs only; the parameters are shared (in_axes None).
batched = jax.jit(jax.vmap(vit.apply, axis_name="batch", in_axes=(None, 0)))
# pmap over the leading axis of both arguments, so the parameters need one copy per device.
dist_batched = jax.pmap(batched, axis_name="device")
# Variant that does not map over the parameters: pass them un-replicated.
dist_batched_not_replicated = jax.pmap(batched, in_axes=(None, 0))

# Sanity check with 8 shards of 32 images each.
dist_batched(dist_params, jnp.ones((8, 32, 28, 28, 1))).shape

(8, 32, 10)

In [13]:
# Input for the timing cells below: 8 shards of 4096 images each.
x = jnp.ones((8, 4096, 28, 28, 1))

In [14]:
# Time the pmapped forward pass that takes replicated parameters.
%timeit dist_batched(dist_params, x).block_until_ready()

The slowest run took 16.66 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 204 ms per loop


In [15]:
# Time the pmapped forward pass that takes un-replicated parameters.
%timeit dist_batched_not_replicated(params, x).block_until_ready()

The slowest run took 17.06 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 207 ms per loop


In [16]:
# Wrap the pmapped function in jit and time it; the recorded output shows
# JAX warning that the jitted function includes a pmap.
j_dist_batched = jax.jit(dist_batched)

%timeit j_dist_batched(dist_params, x).block_until_ready()

  f"The jitted function {name} includes a pmap. Using "


The slowest run took 13.70 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 273 ms per loop


# Training

In [17]:
# Iterate the tf.data pipelines as NumPy batches.
loader_train = tfds.as_numpy(ds_train)
loader_test = tfds.as_numpy(ds_test)

In [18]:
def to_sharded(x, nb_shards=None):
  """Split the leading (batch) axis of `x` into equally sized shards.

  Args:
    x: array whose first axis is the batch axis.
    nb_shards: number of shards to create. Defaults to `len(jax.devices())`.

  Returns:
    `x` reshaped to `(nb_shards, batch // nb_shards, *x.shape[1:])`.

  Raises:
    ValueError: if `nb_shards` is not positive or does not divide the batch
      size (a reshape would otherwise fail with a less helpful message).
  """
  b, *shp = x.shape
  d = len(jax.devices()) if nb_shards is None else nb_shards
  if d <= 0 or b % d != 0:
    raise ValueError(
        f"Cannot split a batch of {b} examples into {d} equal shards."
    )
  return jnp.reshape(x, (d, b // d, *shp))


def loss_fn(params, xs, ys):
  """Mean softmax cross-entropy of the batched model on one batch.

  Args:
    params: model parameters passed to `batched`.
    xs: batch of inputs.
    ys: integer class labels, one per input.

  Returns:
    Scalar negative mean log-likelihood of the labels.
  """
  logits = batched(params, xs)
  # Take the number of classes from the logits instead of hard-coding it, so
  # the loss keeps working if the model's output size changes.
  labels = jax.nn.one_hot(ys, logits.shape[-1])
  log_likelihood = jnp.mean(jnp.sum(labels * jax.nn.log_softmax(logits), -1))
  return -log_likelihood


@functools.partial(jax.pmap, axis_name="nb_devices")
def update(params, xs, ys):
  """One SGD step, run in parallel on every device.

  Args:
    params: this device's copy of the parameters.
    xs: this device's shard of the inputs.
    ys: this device's shard of the labels.

  Returns:
    `(new_params, loss)`; both are computed from device-averaged quantities.
  """
  loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)

  # Average across devices so every replica applies the same update.
  loss = jax.lax.pmean(loss, axis_name="nb_devices")
  grads = jax.lax.pmean(grads, axis_name="nb_devices")

  learning_rate = 0.05
  # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
  new_params = jax.tree_util.tree_map(
      lambda p, g: p - learning_rate * g,
      params, grads
  )
  return new_params, loss


@jax.jit
def _eval(params, xs, ys, result, count):
    """Add one batch to the running (correct predictions, examples seen) totals."""
    predictions = jnp.argmax(batched(params, xs), -1)
    nb_correct = jnp.sum(predictions == ys)
    return result + nb_correct, count + len(ys)


def eval(params, loader):
  """Return the accuracy, in percent, of `params` over every batch of `loader`.

  NOTE(review): the name shadows Python's built-in `eval`.
  """
  correct = jnp.array(0.)
  seen = jnp.array(0.)

  for batch in loader:
    images, labels = batch
    correct, seen = _eval(params, images, labels, correct, seen)

  return 100 * correct / seen

In [19]:
# Start training from freshly initialised parameters, replicated once per device.
params = vit.init(rng, jnp.ones((28, 28, 1)))
# `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
dist_params = jax.tree_util.tree_map(lambda x: jnp.array([x] * nb_devices), params)

In [20]:
EPOCHS = 10

for epoch in range(EPOCHS):
  mean_loss, c = jnp.array(0.), 0
  for xs, ys in loader_train:
    dist_params, loss = update(dist_params, to_sharded(xs), to_sharded(ys))
    # `update` averages the loss across devices, so reading index 0 is enough.
    mean_loss += loss[0]
    c += 1

  # Bring the first replica back to the host for evaluation.
  # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
  params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], dist_params))
  # Convert to a Python float before rounding, as the accuracy lines do;
  # rounding the JAX array directly did not print a 5-decimal value.
  print(f"[{epoch}] Mean train loss: {round(float(mean_loss / c), 5)}")
  print(f"[{epoch}] Train accuracy: {round(float(eval(params, loader_train)), 2)}")

# Evaluated once, after the last epoch, with the final parameters.
print(f"[{epoch}] Test accuracy: {round(float(eval(params, loader_test)), 2)}")

[0] Mean train loss: 2.1996798515319824
[0] Train accuracy: 24.71
[1] Mean train loss: 1.5767399072647095
[1] Train accuracy: 63.02
[2] Mean train loss: 0.9171499609947205
[2] Train accuracy: 72.03
[3] Mean train loss: 0.6782900094985962
[3] Train accuracy: 83.51
[4] Mean train loss: 0.5320799946784973
[4] Train accuracy: 85.68
[5] Mean train loss: 0.4148799777030945
[5] Train accuracy: 88.36
[6] Mean train loss: 0.3565099835395813
[6] Train accuracy: 88.4
[7] Mean train loss: 0.305869996547699
[7] Train accuracy: 90.83
[8] Mean train loss: 0.27741000056266785
[8] Train accuracy: 92.3
[9] Mean train loss: 0.24948999285697937
[9] Train accuracy: 92.51
[9] Test accuracy: 92.62


In [21]:
# Evaluate the replica held at index 0 of the replicated parameters.
# `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
params_0 = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], dist_params))
print(f"[{epoch}] Train accuracy: {round(float(eval(params_0, loader_train)), 2)}")

[9] Train accuracy: 92.52


In [22]:
# Evaluate the replica held at index 1 of the replicated parameters.
# `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
params_1 = jax.device_get(jax.tree_util.tree_map(lambda x: x[1], dist_params))
print(f"[{epoch}] Train accuracy: {round(float(eval(params_1, loader_train)), 2)}")

[9] Train accuracy: 92.51


In [23]:
# Check that the `vi_t/linear` weights of replicas 0 and 1 are (numerically) equal.
jnp.allclose(params_0["vi_t/linear"]["w"], params_1["vi_t/linear"]["w"])

DeviceArray(True, dtype=bool)