In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import time
from functools import partial
# Number of virtual CPU devices to expose. XLA reads this flag when JAX first
# initialises its backend, which is why it is set ahead of the `import jax` cell.
DEVICE_COUNT = 8
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={DEVICE_COUNT}"

In [None]:
import jax
import jax.numpy as jnp
from jax import random, jit
import haiku as hk
from model import (
    Linear, RMSNorm, DenseFF, RotaryEmbedding, MultiHeadAttention, MoEBlock,
    TransformerBlock, Embedding, MoeTransformer,
)

In [None]:
# Shared test dimensions; the per-head key/value width splits the model width
# evenly across heads.
batch_size, seq_len = 16, 200
emd_dim, num_heads = 256, 8
k_dim = v_dim = emd_dim // num_heads

# Random inputs: `x` has shape (batch, seq, emd_dim); `heads` has shape
# (batch, seq, num_heads, k_dim) and feeds per-head modules (rotary embedding).
# NOTE(review): both draws use key 1; split the key if independent inputs are wanted.
x = random.normal(random.key(1), (batch_size, seq_len, emd_dim))
heads = random.normal(random.key(1), (batch_size, seq_len, num_heads, k_dim))

In [None]:
def transform_module(mod, configs, loss_fn=None):
  """Wrap a Haiku module so one call returns its output and an input gradient.

  Args:
    mod: Haiku module class, instantiated as ``mod(**configs)``.
    configs: keyword arguments for the module constructor.
    loss_fn: optional callable mapping the module output to a scalar. When
      omitted, the loss is the mean of the output, or the mean of its first
      element when the module returns a tuple/list.

  Returns:
    An ``hk.TransformedWithState`` whose ``apply`` returns ``((y, g), state)``,
    where ``y`` is the module output and ``g`` is the gradient of the loss
    with respect to the first positional argument.
  """
  def default_loss(y):
    # Modules may return (output, aux, ...); reduce only the main output. A
    # plain array is reduced as a whole so every element receives gradient.
    out = y[0] if isinstance(y, (tuple, list)) else y
    return jnp.mean(out)

  loss = default_loss if loss_fn is None else loss_fn

  def f(*args, **kwargs):
    # Build the module once and reuse it for both the forward pass and the
    # gradient. A second `mod(**configs)` would be a separately named Haiku
    # module with its own parameters, so `g` would not belong to `y`.
    m = mod(**configs)
    y = m(*args, **kwargs)
    g = jax.grad(lambda *a, **kw: loss(m(*a, **kw)))(*args, **kwargs)
    return y, g

  return hk.transform_with_state(f)

def test_jit(f, *args, **kwargs):
  key = random.PRNGKey(1)
  params, state = f.init(key, *args, **kwargs)
  start = time.perf_counter()
  (y1, g1), state = f.apply(params, state, key, *args, **kwargs)
  end = time.perf_counter()
  non_jitted_latency = end - start

  params, state = jit(f.init)(key, *args, **kwargs)
  jitted = jit(f.apply) 
  _ = jitted(params, state, key, *args, **kwargs)
  start = time.perf_counter()
  (y2, g2), state = jitted(params, state, key, *args, **kwargs)
  end = time.perf_counter()
  jitted_latency = end - start

  y1_leaves = jax.tree_util.tree_leaves(y1) 
  y2_leaves = jax.tree_util.tree_leaves(y2)
  g1_leaves = jax.tree_util.tree_leaves(g1)
  g2_leaves = jax.tree_util.tree_leaves(g2)

  for i in range(len(y1_leaves)):
    y_abs_err = jnp.abs(y1_leaves[i] - y2_leaves[i])
    y_rel_err = jnp.abs(2 * (y1_leaves[i] - y2_leaves[i]) / (y1_leaves[i] + y2_leaves[i]))
    print('y_abs_err: {} \t y_rel_err: {}'.format(
      jnp.mean(y_abs_err), jnp.mean(y_rel_err)
    ))

  for i in range(len(g1_leaves)):
    g_abs_err = jnp.abs(g1_leaves[i] - g2_leaves[i])
    g_rel_err = jnp.abs(2 * (g1_leaves[i] - g2_leaves[i]) / (g1_leaves[i] + g2_leaves[i]))
    print('g_abs_err: {} \t g_rel_err: {}'.format(
      jnp.mean(g_abs_err),  jnp.mean(g_rel_err)
    ))
  
  print('latency: {} \t jitted: {}'.format(non_jitted_latency, jitted_latency))

In [None]:
# Square Linear layer: emd_dim -> emd_dim.
linear = transform_module(Linear, dict(in_dim=emd_dim, out_dim=emd_dim))
test_jit(linear, x)

In [None]:
# RMSNorm built with its default constructor arguments.
rms_norm = transform_module(RMSNorm, dict())
test_jit(rms_norm, x)

In [None]:
# Exercise DenseFF once per activation. The header line tells the three
# groups of error/latency output apart.
for activation in ('gelu', 'silu', 'relu'):
  print('DenseFF activation: {}'.format(activation))
  ff = transform_module(DenseFF, {
      'emd_dim': emd_dim,
      'activation': activation,
      'hidden_dim': emd_dim * 2,
  })
  test_jit(ff, x)

In [None]:
# Rotary embedding over per-head vectors of width k_dim. `offset=10` is
# forwarded to the module call (presumably a position offset; confirm in model).
rote = transform_module(RotaryEmbedding, dict(dim=k_dim))
test_jit(rote, heads, offset=10)

In [None]:
# Project MultiHeadAttention (imported in the top import cell) with as many
# key/value heads as query heads; the same tensor `x` is passed as all three inputs.
mha = transform_module(MultiHeadAttention, {
    'emd_dim': emd_dim,
    'num_q_heads': num_heads,
    'num_kv_heads': num_heads,
    'v_dim': v_dim,
    'k_dim': k_dim,
})
test_jit(mha, x, x, x)



In [None]:
# Haiku's built-in MultiHeadAttention, used as a baseline for the project version.
initializer = hk.initializers.TruncatedNormal(stddev=1)
mha2 = transform_module(
    hk.MultiHeadAttention,
    dict(num_heads=num_heads, key_size=k_dim, model_size=emd_dim, w_init=initializer),
)
test_jit(mha2, x, x, x)

In [None]:
# Compare an 8-expert MoEBlock with the single-expert configuration; only
# `num_experts` differs between the two runs.
for num_experts in (8, 1):
  print('MoEBlock num_experts: {}'.format(num_experts))
  moe_block = transform_module(MoEBlock, {
      'emd_dim': emd_dim,
      'hidden_dim': emd_dim * 4,
      'num_experts': num_experts,
      'active_experts': 1,
      'multi_device': True,
  })
  test_jit(moe_block, x)

In [None]:
def transform_embedding(mod, configs):
  """Build transformed wrappers around a module's ``encode`` and ``decode``.

  Each wrapper instantiates ``mod(**configs)``, calls the method, and also
  returns the gradient of the mean of that method's output with respect to
  the first positional argument.

  Returns:
    ``(encode, decode)``, each an ``hk.TransformedWithState`` whose ``apply``
    yields ``((y, g), state)``.
  """
  def _wrap(method_name):
    def run(*args, **kwargs):
      method = getattr(mod(**configs), method_name)
      mean_output = lambda *a, **kw: jnp.mean(method(*a, **kw))
      return method(*args, **kwargs), jax.grad(mean_output)(*args, **kwargs)

    return hk.transform_with_state(run)

  return _wrap('encode'), _wrap('decode')

# Embedding whose vocabulary size equals the model width (n_vocab = emd_dim).
# NOTE(review): both directions are driven with the float tensor `x`; confirm
# `encode` is meant to accept that rather than integer token ids.
emd_encode, emd_decode = transform_embedding(
    Embedding, dict(emd_dim=emd_dim, n_vocab=emd_dim)
)
for transformed in (emd_encode, emd_decode):
  test_jit(transformed, x)



In [None]:
# One TransformerBlock (imported in the top import cell), configured with both
# attention settings (q/kv heads) and MoE settings (a single expert).
transformer_block = transform_module(TransformerBlock, {
    'emd_dim': emd_dim,
    'num_q_heads': num_heads,
    'num_kv_heads': num_heads,
    'v_dim': v_dim,
    'k_dim': k_dim,
    'hidden_dim': emd_dim * 4,
    'num_experts': 1,
    'active_experts': 1,
    'expert_capacity': 1.0,
})
test_jit(transformer_block, x)

In [None]:
# Full MoeTransformer (imported in the top import cell), 5 layers deep with
# num_experts=1. It is bound to a descriptive name rather than a generic
# `transformer`, so a later cell cannot silently rebind it.
# NOTE(review): it is driven with the float tensor `x` although configured
# with `n_vocab`; confirm the model accepts embeddings rather than token ids.
moe_transformer = transform_module(MoeTransformer, {
    'depth': 5,
    'n_vocab': emd_dim,
    'emd_dim': emd_dim,
    'num_q_heads': num_heads,
    'num_kv_heads': num_heads,
    'v_dim': v_dim,
    'k_dim': k_dim,
    'hidden_dim': emd_dim * 4,
    'num_experts': 1,
    'active_experts': 1,
    'expert_capacity': 1.0,
})
test_jit(moe_transformer, x)

In [None]:
import dataclasses
from typing import Optional

import numpy as np


def _layer_norm(x: jax.Array) -> jax.Array:
  """Applies a fresh LayerNorm (learned scale and offset) over the last axis of `x`."""
  ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
  return ln(x)


# Baseline: a plain Haiku transformer (attention + GELU MLP residual blocks)
# built from stock modules, used to compare forward and gradient latency.
REF_DEPTH = 5  # number of residual blocks; same as `depth` in the MoeTransformer cell
initializer = hk.initializers.TruncatedNormal(stddev=1)


def reference_transformer(h):
  """Apply REF_DEPTH residual blocks to `h` and return the mean activation.

  Needs an rng key at `apply` time because the MLP branch uses dropout.
  """
  for _ in range(REF_DEPTH):
    attn_block = hk.MultiHeadAttention(
        num_heads=num_heads,
        key_size=k_dim,
        model_size=emd_dim,
        w_init=initializer,
    )
    # NOTE(review): attention gets neither a pre-LayerNorm nor dropout, unlike
    # the MLP branch below; confirm this matches the model being compared.
    h = h + attn_block(h, h, h, None)

    dense_block = hk.Sequential([
        hk.Linear(4 * emd_dim, w_init=initializer),
        jax.nn.gelu,
        hk.Linear(emd_dim, w_init=initializer),
    ])
    h_dense = dense_block(_layer_norm(h))
    h_dense = hk.dropout(hk.next_rng_key(), 0.2, h_dense)
    h = h + h_dense

  # Scalar output, so jax.grad can differentiate `apply` w.r.t. params directly.
  return jnp.mean(h)


ref_model = hk.transform(reference_transformer)
ref_key = random.PRNGKey(1)
ref_params = ref_model.init(ref_key, x)

ref_forward = jit(ref_model.apply)
ref_grad = jit(jax.grad(ref_model.apply))

# Warm-up: compile both functions so the timings below exclude tracing/compilation.
jax.block_until_ready(ref_forward(ref_params, ref_key, x))
jax.block_until_ready(ref_grad(ref_params, ref_key, x))

# JAX dispatches asynchronously; block on each result so the timers measure
# the computation itself, not just its dispatch.
start = time.perf_counter()
ref_y = jax.block_until_ready(ref_forward(ref_params, ref_key, x))
forward_latency = time.perf_counter() - start
print(forward_latency)

start = time.perf_counter()
ref_grads = jax.block_until_ready(ref_grad(ref_params, ref_key, x))
grad_latency = time.perf_counter() - start
print(grad_latency)




In [None]:
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding

# Lay all DEVICE_COUNT devices out along one axis. On CPU these are the
# virtual devices requested through XLA_FLAGS before JAX was imported.
assert jax.device_count() == DEVICE_COUNT, (
    'expected {} devices, found {}; XLA_FLAGS must be set before JAX '
    'initialises its backend'.format(DEVICE_COUNT, jax.device_count()))
devices = mesh_utils.create_device_mesh((DEVICE_COUNT,))
# NOTE(review): newer JAX releases deprecate PositionalSharding in favour of
# Mesh + NamedSharding; consider migrating.
sharding = PositionalSharding(devices)

In [None]:
# Distribute a large matrix across the devices. A distinct name keeps the small
# test input `x` intact, so the module-test cells can still be re-run afterwards.
big_x = jax.random.normal(jax.random.key(0), (8192, 8192))
# 4 x 2 device grid: the product must equal the number of devices in `sharding`.
big_x_sharded = jax.device_put(big_x, sharding.reshape(4, 2))
jax.debug.visualize_array_sharding(big_x_sharded)

In [None]:
# Display the sharding rearranged into a 4 x 2 device grid.
sharding.reshape((4, 2))