In [None]:
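# Benchmark MaxText's TPU Splash (flash) attention kernel across block sizes and Q/K/V layouts (a dot-product attention comparison is kept, commented out, further down).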
%load_ext autoreload
%autoreload 2
from MaxText.layers.attentions import AttentionOp
from MaxText import pyconfig
from MaxText import maxtext_utils
from MaxText.layers import quantizations

import jax
import jax.numpy as jnp

# Arguments for pyconfig.initialize, in argv style: the base YAML followed by key=value overrides.
# NOTE(review): the leading "" appears to be a placeholder for argv[0] (program name) — confirm
# against pyconfig.initialize.
model_args = [
    "",
    "MaxText/configs/base.yml",
    "weight_dtype=bfloat16",
    "quantization=int8",
    "quantize_kvcache=True",
    "checkpoint_is_quantized=True",
    "ici_tensor_parallelism=1",
    "ici_fsdp_parallelism=1",
    # NOTE(review): presumably requires 8 devices to be visible to JAX — confirm for the target host.
    "ici_context_parallelism=8",
    "context_parallel_load_balance=False",
]

config = pyconfig.initialize(model_args)

# Quantization settings derived from the parsed config.
quant = quantizations.configure_quantization(config)

# Build the device mesh and wrap it in a jax Mesh named by config.mesh_axes.
devices_array = maxtext_utils.create_device_mesh(config=config)
mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)

In [6]:
# Build the attention op under benchmark. (Autoreload is enabled once in the first cell; loading
# it again here only printed "The autoreload extension is already loaded".)
# The kernel choice does not matter here: the benchmark cells call
# `attention_op.apply_attention_dot` / `attention_op.tpu_flash_attention` directly.
kernel = "flash"
MAX_TARGET_LEN = 8192
NUM_QUERY_HEADS = 64
NUM_KV_HEADS = 8  # Fewer KV heads than query heads: 8 query heads per KV head (GQA-style).
HEAD_DIM = 128  # Used to shape the random q/k/v inputs; not passed to AttentionOp.
attention_op = AttentionOp(
  config, mesh, kernel, MAX_TARGET_LEN, NUM_QUERY_HEADS, NUM_KV_HEADS, compute_axis_order=(0,1,2,3), quant=quant, kv_quant=quantizations.configure_kv_quant(config)
)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
# Random bfloat16 q/k/v inputs for the benchmark.
# Reference (Google-internal link): https://source.corp.google.com/piper///depot/google3/experimental/users/tohaowu/pallas/flash_tuning.py
original_key = jax.random.PRNGKey(0)
# One independent key per tensor. Reusing a single key would make k and v identical,
# since they share shape and dtype.
q_key, k_key, v_key = jax.random.split(original_key, 3)
batch_size = 1
q_seq_len = MAX_TARGET_LEN
kv_seq_len = MAX_TARGET_LEN
# NOTE(review): q_seq_len == MAX_TARGET_LEN is a full-sequence query even though model_mode is
# "autoregressive" — confirm this is the mode intended for the benchmark.
model_mode = "autoregressive"


q = jax.random.normal(q_key, (batch_size, q_seq_len, NUM_QUERY_HEADS, HEAD_DIM), dtype=jnp.bfloat16)
k = jax.random.normal(k_key, (batch_size, kv_seq_len, NUM_KV_HEADS, HEAD_DIM), dtype=jnp.bfloat16)
v = jax.random.normal(v_key, (batch_size, kv_seq_len, NUM_KV_HEADS, HEAD_DIM), dtype=jnp.bfloat16)

# Place on JAX's default device (no explicit sharding is given).
q = jax.device_put(q)
k = jax.device_put(k)
v = jax.device_put(v)

# Every position gets segment id 0, i.e. one segment spanning the whole sequence.
# NOTE(review): confirm the attention mask treats id 0 as a real segment rather than padding.
decoder_segment_ids = jnp.zeros((batch_size, kv_seq_len), dtype=jnp.int32)

In [8]:
# NOTE: This cell is entirely commented out and does nothing when run. It is an earlier
# dot-product vs. Splash (flash) attention comparison; uncomment it to time the dot-product
# baseline against the flash kernel.
# NOTE(review): `global_*_layout = "qkv"` below does not match the values listed in its own trailing
# comment (HEAD_DIM_MINOR, SEQ_MINOR); it is likely stale — confirm before uncommenting.
# NOTE(review): the StepTraceAnnotation labels look swapped ("gqa_pallas" sits above the dot-product
# benchmark, "gqa_reference" above the Splash one), and they use `k_seq_len`, which this notebook
# never defines (the sequence length variable is `kv_seq_len`).
# import timeit

# WARMUP_ITERS = 10
# TIME_IT_NUMBER = 10
# TIME_IT_REPEAT = 5


# print("Starting warmup...")
# global_block_q = 512
# global_block_kv = 512
# global_block_kv_compute = 512
# global_block_q_dkv = 512
# global_block_kv_dkv = 512
# global_block_kv_dkv_compute = 512
# global_block_q_dq = 512
# global_block_kv_dq = 512
# global_q_layout = "qkv"
# global_k_layout = "qkv"
# global_v_layout = "qkv" # HEAD_DIM_MINOR, SEQ_MINOR

# for _ in range(WARMUP_ITERS):
#   jax.block_until_ready(attention_op.apply_attention_dot(q,k,v,decoder_segment_ids, model_mode))
#   jax.block_until_ready(attention_op.tpu_flash_attention(q,k,v,decoder_segment_ids, model_mode=model_mode, global_block_q=global_block_q,
#       global_block_kv=global_block_kv,
#       global_block_kv_compute=  global_block_kv_compute,
#       global_block_q_dkv=global_block_q_dkv,
#       global_block_kv_dkv=global_block_kv_dkv,
#       global_block_kv_dkv_compute=global_block_kv_dkv_compute,
#       global_block_q_dq=global_block_q_dq,
#       global_block_kv_dq=global_block_kv_dq,
#       global_q_layout=global_q_layout,
#       global_k_layout=global_k_layout,
#       global_v_layout=global_v_layout))


# # # with jax.profiler.StepTraceAnnotation("gqa_pallas", step_num=k_seq_len):
# print("Starting benchmark for dot product attention...")
# times_dot_product = timeit.repeat(lambda: jax.block_until_ready(attention_op.apply_attention_dot(q,k,v,decoder_segment_ids, model_mode)), repeat=TIME_IT_REPEAT, number=TIME_IT_NUMBER)
# times_dot_product = min(times_dot_product) / TIME_IT_NUMBER

# # with jax.profiler.StepTraceAnnotation("gqa_reference", step_num=k_seq_len):
# print("Starting benchmark for Splash attention...")
# times_flash = timeit.repeat(lambda: jax.block_until_ready(attention_op.tpu_flash_attention(q,k,v,decoder_segment_ids, model_mode=model_mode)), repeat=TIME_IT_REPEAT, number=TIME_IT_NUMBER)
# times_flash = min(times_flash) / TIME_IT_NUMBER


# print(f"    Dot Product Attention:        {times_dot_product:.6f} s")
# print(f"    Flash Attention: {times_flash:.6f} s")
# print(f"    Speedup: {times_dot_product / times_flash if times_dot_product > 0 else float('inf'):.2f}x")

In [9]:
import timeit

# Sweep Splash (flash) attention block sizes and Q/K/V layouts, timing each configuration.
WARMUP_ITERS = 20
TIME_IT_NUMBER = 15
TIME_IT_REPEAT = 10
BLOCK_SIZES = [256, 512, 1024]
LAYOUTS = ["HEAD_DIM_MINOR", "SEQ_MINOR"]


def flash_block_kwargs(block_size, layout):
  """Return tuning kwargs for `attention_op.tpu_flash_attention`.

  Every forward and backward block-size argument is set to `block_size`, and the
  q/k/v layout arguments are all set to `layout`.
  """
  return dict(
      global_block_q=block_size,
      global_block_kv=block_size,
      global_block_kv_compute=block_size,
      global_block_q_dkv=block_size,
      global_block_kv_dkv=block_size,
      global_block_kv_dkv_compute=block_size,
      global_block_q_dq=block_size,
      global_block_kv_dq=block_size,
      global_q_layout=layout,
      global_k_layout=layout,
      global_v_layout=layout,
  )


# Seconds per call for each (block_size, layout): best of TIME_IT_REPEAT runs of TIME_IT_NUMBER calls.
flash_times = {}
for block_size in BLOCK_SIZES:
  for layout in LAYOUTS:
    print(f"Trying block size {block_size} and layout {layout}")
    print("Starting warmup...")
    flash_kwargs = flash_block_kwargs(block_size, layout)

    def run_flash(flash_kwargs=flash_kwargs):
      # Kwargs are bound as a default so this config's settings are captured at definition time.
      return jax.block_until_ready(
          attention_op.tpu_flash_attention(q, k, v, decoder_segment_ids, model_mode=model_mode, **flash_kwargs)
      )

    # Warm up only the kernel being timed. The dot-product call that previously ran here on every
    # warmup iteration was never benchmarked in this cell, so it only added runtime.
    for _ in range(WARMUP_ITERS):
      run_flash()

    print("Starting benchmark for Splash attention...")
    times_flash = min(timeit.repeat(run_flash, repeat=TIME_IT_REPEAT, number=TIME_IT_NUMBER)) / TIME_IT_NUMBER
    flash_times[(block_size, layout)] = times_flash
    print(f"    Flash Attention: {times_flash:.6f} s")

if flash_times:
  best_block_size, best_layout = min(flash_times, key=flash_times.get)
  print(f"Fastest: block size {best_block_size}, layout {best_layout} ({flash_times[(best_block_size, best_layout)]:.6f} s)")


Trying block size 256 and layout HEAD_DIM_MINOR
Starting warmup...
Starting benchmark for Splash attention...
    Flash Attention: 0.067719 s
Trying block size 256 and layout SEQ_MINOR
Starting warmup...
Starting benchmark for Splash attention...
    Flash Attention: 0.065657 s
Trying block size 512 and layout HEAD_DIM_MINOR
Starting warmup...
Starting benchmark for Splash attention...
    Flash Attention: 0.047557 s
Trying block size 512 and layout SEQ_MINOR
Starting warmup...
Starting benchmark for Splash attention...
    Flash Attention: 0.047107 s
Trying block size 1024 and layout HEAD_DIM_MINOR
Starting warmup...
Starting benchmark for Splash attention...
    Flash Attention: 0.044642 s
Trying block size 1024 and layout SEQ_MINOR
Starting warmup...
Starting benchmark for Splash attention...
    Flash Attention: 0.044642 s
