From cf45d7fa18cfc7cec3ff74b438f030354a13b496 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Thu, 16 May 2024 23:41:16 +0000 Subject: [PATCH 01/41] Stable version of ragged attention. --- jetstream_pt/engine.py | 68 +++- jetstream_pt/environment.py | 20 +- jetstream_pt/layers.py | 385 ++++++++++++++++-- .../third_party/llama/model_exportable.py | 14 +- run_interactive_multiple_host.py | 1 + 5 files changed, 431 insertions(+), 57 deletions(-) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index defa3e94..91626664 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -64,6 +64,7 @@ class DecodeState: ] # only present in quantized kv current_position: int lens: jax.Array # [batch_size, 1] + start: jax.Array # [batch_size, 1], the starting pos for each slot input_pos: jax.Array # [batch_size, 1] input pos for each slot mask: jax.Array # [batch_size, seqlen] -inf for invalid; 0 for valid @@ -126,8 +127,9 @@ def init_decode_state( jnp.zeros((self.env.batch_size, 1), dtype=jnp.int32), caches, scalers, - self.env.max_input_sequence_length, + self.env.max_input_sequence_length / 4, jnp.zeros((self.env.batch_size, 1), dtype=jnp.int32), + jnp.zeros((self.env.batch_size,), dtype=jnp.int32), # start pos jnp.zeros((self.env.batch_size,), dtype=jnp.int32), # input pos jnp.full( (self.env.batch_size, self.env.cache_sequence_length), @@ -145,7 +147,10 @@ def _call_model_generate( caches, cache_scales, mask, + start, input_pos, + pre_batch, + pre_block, ): if self.env.quant_config.enable_kv_quantization: caches_obj = [ @@ -163,7 +168,7 @@ def _call_model_generate( ] mask = jnp.expand_dims(mask, (1, 2)) - args = (tokens, input_pos, caches_obj, mask) + args = (tokens, caches_obj, mask, start, input_pos, pre_batch, pre_block) paramst, argst = torchjax.to_torch((weights, args)) with self._lock: with torchjax.jax_mode: @@ -193,7 +198,7 @@ def _call_model_prefill(self, weights, tokens, input_indexes): dtype=self.default_dtype, ) mask = jnp.triu(mask, k=1) - args = (tokens, input_indexes, caches, mask) + args = (tokens, caches, mask, None, input_indexes, None, None) paramst, argst = torchjax.to_torch((weights, args)) with self._lock: @@ -272,6 +277,7 @@ def _insert_no_wrap( cond = jnp.logical_and(x <= decode_state.current_position, x >= pos) mask_insert = jnp.where(cond, 0, float("-inf")) mask = decode_state.mask.at[slot].set(mask_insert) + start = decode_state.start_pos.at[slot].set(pos % self.env.cache_sequence_length) input_pos = decode_state.input_pos.at[slot].set(prefix.seq_len) if not self.env.quant_config.enable_kv_quantization: @@ -328,6 +334,7 @@ def insert(cache, scaler, new_entry): scales, decode_state.current_position, lens, + start, input_pos, mask, ) @@ -366,6 +373,7 @@ def _insert_wrap( mask_insert = jnp.where(cond, 0, float("-inf")) mask = decode_state.mask.at[slot].set(mask_insert) + start = decode_state.start_pos.at[slot].set(start_insert) input_pos = decode_state.input_pos.at[slot].set(prefix.seq_len) old_caches = decode_state.caches @@ -420,6 +428,7 @@ def insert(cache, scaler, new_entry): scales, decode_state.current_position, lens, + start, input_pos, mask, ) @@ -448,6 +457,45 @@ def insert( slot, ) + def precompute_ragged_block_indices(self, decode_state: DecodeState): + start = decode_state.start + end = (start + decode_state.input_pos) % self.env.cache_len + batch_size = start.shape[0] + bk = self.env.block_size + b = jnp.arange(batch_size).reshape((batch_size, 1)) + num_bk = self.env.cache_len // self.env.block_size + i = jnp.arange(num_bk).reshape((1, 
num_bk)) + i = jnp.broadcast_to(i, (batch_size, num_bk)) + + start = start.reshape((batch_size, 1)) + end = end.reshape((batch_size, 1)) + + am_last_batch = b == batch_size - 1 + last_good_block = jnp.where(start < end, jnp.div(end - 1, bk), jnp.div(self.env.cache_len -1, bk)) + + next_b = jnp.where(am_last_batch, b, b + 1) + next_i = jnp.where(am_last_batch, last_good_block, 0) + + # start < end + def true_comp(b, i, bk, start, end, next_b, next_i): + b_next = jnp.where(i * bk >= end, next_b, b) + i_next = jnp.where(i * bk >= end, next_i, i) + i_next = jnp.where((i + 1) * bk <= start, jnp.div(start, bk), i_next) + return b_next, i_next + + # start > end + def false_comp(b, i, bk, start, end): + b_next = b + i_next = jnp.where(jnp.logical_and(i * bk >= end, (i + 1) * bk <= start), jnp.div(start, bk), i) + return b_next, i_next + + true_comp_b, true_comp_i = true_comp(b, i, bk, start, end, next_b, next_i) + false_comp_b, false_comp_i = false_comp(b, i, bk, start, end) + + b_next = jnp.where(start < end, true_comp_b, jnp.where(start == end, next_b, false_comp_b)) + i_next = jnp.where(start < end, true_comp_i, jnp.where(start == end, next_i, false_comp_i)) + return b_next, i_next + def generate( self, params: Any, decode_state: DecodeState ) -> tuple[DecodeState, engine_api.ResultTokens]: @@ -457,6 +505,8 @@ def generate( # fill mask first mask = decode_state.mask.at[:, decode_state.current_position].set(0) + pre_batch, pre_block = self.precompute_ragged_block_indices(decode_state) + logits, new_caches, new_scales = self._call_model_generate( params, decode_state.tokens, @@ -464,8 +514,12 @@ def generate( decode_state.caches, decode_state.cache_scales, mask, + decode_state.start, decode_state.input_pos, + pre_batch, + pre_block, ) + next_token = self._sampling(logits, self.env.batch_size) lens = decode_state.lens + 1 data = jnp.concatenate( @@ -493,7 +547,9 @@ def generate( new_scales, (decode_state.current_position + 1) % self.env.cache_sequence_length, lens, - decode_state.input_pos + 1, + decode_state.start, + # Stop the input_pos from increasing if it's 0, for better ragged attention performance + jnp.where(decode_state.input_pos == 0, 0, decode_state.input_pos + 1), mask, ) print( @@ -619,6 +675,7 @@ def get_decode_state_sharding(self) -> DecodeState: self.replicated, self.replicated, self.replicated, + self.replicated, ) def get_prefix_sequence_ddim(self) -> Any: @@ -666,6 +723,7 @@ def create_pytorch_engine( max_cache_length=1024, sharding_config=None, shard_on_batch=False, + ragged_mha=False, ) -> PyTorchEngine: """Returns: The pytorch engine.""" @@ -724,6 +782,7 @@ def create_pytorch_engine( bf16_enable=bf16_enable, sharding_config_path=sharding_config, shard_on_batch=shard_on_batch, + ragged_mha=ragged_mha, ) if shard_on_batch and sharding_config: @@ -756,6 +815,7 @@ def create_pytorch_engine( env_data.model_type = model_name + "-" + param_size env_data.num_layers = args.num_hidden_layers env = JetEngineEnvironment(env_data) + print(f"Enviroment variables: {vars(env)}") pt_model = gemma_model.GemmaModel(args, env) else: raise RuntimeError(f"Model with name {model_name} not found") diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py index b4df9980..3e9e4be6 100644 --- a/jetstream_pt/environment.py +++ b/jetstream_pt/environment.py @@ -87,6 +87,11 @@ class JetEngineEnvironmentData: # Whether to shard on batch dimension. i.e. data parallel. shard_on_batch: bool = False + # Whether to enable ragged multi head attention. 
+ ragged_mha: bool = False + + # The block size for the ragged attention. + block_size: int = 512 # pylint: disable-next=all class JetEngineEnvironment: @@ -95,6 +100,9 @@ def __init__(self, data: JetEngineEnvironmentData): self._data = data self.seq_len = self._data.max_input_sequence_length + self.cache_len = self._data.cache_sequence_length + self.ragged_mha = self._data.ragged_mha + self.block_size = self._data.block_size P = jax.sharding.PartitionSpec @@ -144,17 +152,19 @@ def apply_sharding(self, tensor, *, axis: int | None): # pylint: disable-next=all tensor._elem = jax.lax.with_sharding_constraint(tensor._elem, sharding_spec) - def sharding_by_axis(self, axis): + def partition_by_axis(self, axis): """return sharding partition spc by axis, options are x, y, -1 or Noe""" if axis == -1 or axis is None: - return jsharding.NamedSharding(self._mesh, jax.sharding.PartitionSpec()) + return jax.sharding.PartitionSpec() sharding = [None] * (axis + 1) sharding[axis] = "x" - sharding_spec = jsharding.NamedSharding( - self._mesh, jax.sharding.PartitionSpec(*sharding) - ) + sharding_spec = jax.sharding.PartitionSpec(*sharding) return sharding_spec + def sharding_by_axis(self, axis): + """return sharding partition spc by axis, options are x, y, -1 or Noe""" + return jsharding.NamedSharding(self._mesh, self.partition_by_axis(axis)) + def make_caches_prefill(self): """Create kv caches for inference prefill""" caches = [] diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index a44c1f3e..2112b213 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -17,9 +17,13 @@ import math from typing import Optional, Tuple +import functools import jax import jax.numpy as jnp +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +from jax.experimental.shard_map import shard_map import torch import torch.nn.functional as F import torch_xla2 @@ -38,6 +42,7 @@ def _calc_cosine_dist(x, y): y = y.flatten().to(torch.float32) return (torch.dot(x, y) / (x.norm() * y.norm())).item() +import numpy as np class Int8Embedding(torch.nn.Module): @@ -395,12 +400,314 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: ) +DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) + +def ragged_flash_attention_kernel( + start_ref, + end_ref, + line_end_ref, + pre_b_ref, + pre_i_ref, + q_ref, + k_ref, + v_ref, + k_scaler_ref, + v_scaler_ref, + o_ref, + m_ref, + l_ref, + bk: int, + mask_value: float, + normalize_var: bool, + quantized: bool, +): + """Pallas kernel for flash attention.""" + with jax.named_scope("attention_kernel"): + b, i = pl.program_id(0), pl.program_id(1) + + @pl.when(i == 0) + def init(): + with jax.named_scope("init"): + m_ref[...] = jnp.full_like(m_ref, -jnp.inf) + l_ref[...] = jnp.zeros_like(l_ref) + o_ref[...] = jnp.zeros_like(o_ref) + + length = line_end_ref[b] + start = start_ref[b] + end = end_ref[b] + + @pl.when(jnp.logical_and(i * bk < length, start != end)) + def run(): + with jax.named_scope("run_qk"): + q = q_ref[...].astype(jnp.float32) + k = k_ref[...].astype(jnp.float32) + v = v_ref[...].astype(jnp.float32) + m_prev, l_prev = m_ref[...], l_ref[...] + + qk = jax.lax.dot_general( + q, k, (((1,), (1,)), ((), ())), preferred_element_type=jnp.float32 + ) + if normalize_var: + qk = qk / jnp.sqrt(k.shape[-1]) + if quantized: + qk = qk * k_scaler_ref[...] 
+ with jax.named_scope("run_mask"): + start = start_ref[b] + end = end_ref[b] + iota = jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1) + mask_start_lt_end = jnp.logical_and(i * bk + iota >= start, i * bk + iota < end).astype(jnp.int32) + mask_start_gt_end = jnp.logical_or(i * bk + iota >= start, i * bk + iota < end).astype(jnp.int32) + #mask = jax.lax.cond(start <= end, lambda: mask_start_lt_end, lambda: mask_start_gt_end) + mask = jnp.where(start <= end, mask_start_lt_end, mask_start_gt_end) + + qk = qk + jnp.where(mask, 0.0, mask_value) + + with jax.named_scope("run_softmax"): + m_curr = qk.max(axis=-1) + + s_curr = jnp.exp(qk - m_curr[..., None]) + + l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,)) + if quantized: + s_curr = s_curr * v_scaler_ref[...] + o_curr_times_l_curr = jnp.dot(s_curr, v) + m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,)) + m_next = jnp.maximum(m_prev, m_curr) + alpha = jnp.exp(m_prev - m_next) + beta = jnp.exp(m_curr - m_next) + l_next = alpha * l_prev + beta * l_curr + l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) + + m_ref[...], l_ref[...] = m_next, l_next_safe + o_ref[...] = ( + (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next_safe + ).astype(o_ref.dtype) + +@functools.partial(jax.jit, static_argnames=["bk", "mask_value", "normalize_var"]) +def ragged_mqa( + q: jax.Array, + k: jax.Array, + v: jax.Array, + start: jax.Array, + end: jax.Array, + k_scaler: jax.Array | None = None, + v_scaler: jax.Array | None = None, + pre_batch = None, + pre_block = None, + bk: int = 512, + mask_value: float = DEFAULT_MASK_VALUE, + normalize_var: bool = True, +) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: + """Ragged multi query attention.""" + with jax.named_scope("ragged_mqa"): + batch_size, num_heads, head_dim = q.shape + seq_len = k.shape[1] + + def kv_index_map(b, i, start_ref, end_ref, line_end_ref, pre_batch_ref, pre_block_ref): + index = b * (seq_len // bk) + i + return pre_batch_ref[index], pre_block_ref[index], 0 + + def q_index_map(b, i, start_ref, end_ref, line_end_ref, pre_batch_ref, pre_block_ref): + index = b * (seq_len // bk) + i + return pre_batch_ref[index], 0, 0 + + def scaler_index_map(b, i, *_): + return b, 0, i + + line_end = jnp.where(start < end, end, seq_len - 1) + + + if k_scaler is not None: + out, m, l = pl.pallas_call( + functools.partial( + ragged_flash_attention_kernel, + bk=bk, + mask_value=mask_value, + normalize_var=normalize_var, + quantized=False, + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=5, + in_specs=[ + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(kv_index_map, (None, bk, head_dim)), + pl.BlockSpec(kv_index_map, (None, bk, head_dim)), + pl.BlockSpec(scaler_index_map, (None, 1, bk)), + pl.BlockSpec(scaler_index_map, (None, 1, bk)), + ], + out_specs=[ + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + ], + grid=(batch_size, seq_len // bk), + ), + compiler_params=dict(dimension_semantics=("parallel", "arbitrary")), + out_shape=[ + q, + jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), + jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), + ], + )(start, end, line_end, pre_batch, pre_block, q, k, v, k_scaler, v_scaler) + else: + out, m, l = pl.pallas_call( + functools.partial( + ragged_flash_attention_kernel, + bk=bk, + mask_value=mask_value, + 
normalize_var=normalize_var, + quantized=True, + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=5, + in_specs=[ + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(kv_index_map, (None, bk, head_dim)), + pl.BlockSpec(kv_index_map, (None, bk, head_dim)), + ], + out_specs=[ + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + ], + grid=(batch_size, seq_len // bk), + ), + compiler_params=dict(dimension_semantics=("parallel", "arbitrary")), + out_shape=[ + q, + jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), + jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), + ], + )(start, end, line_end, pre_batch, pre_block, q, k, v) + return out, (m[..., 0], l[..., 0]) + + +@functools.partial(jax.jit, static_argnames=['bk', 'mask', 'normalize_var', 'shard_axis']) +def ragged_mha( + q: jax.Array, + k: jax.Array, + v: jax.Array, + start: jax.Array, + end: jax.Array, + pre_batch: jax.Array, + pre_block: jax.Array, + k_scaler: jax.Array | None = None, + v_scaler: jax.Array | None = None, + bk: int = 512, + mask_value : float = DEFAULT_MASK_VALUE, + normalize_var: bool = True, + shard_axis: int = 1 +) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: + """Ragged multi head attention. + Args: + q: A [batch_size, compute_dim, num_heads, head_dim] jax.Array. + k: A [batch_size, num_heads, seq_len, head_dim] jax.Array or + PartitionQuantizedTensor. + v: A [batch_size, num_heads, seq_len, head_dim] jax.Array or + PartitionQuantizedTensor. + start: A i32[batch_size] jax.Array + end: A i32[batch_size] jax.Array + bk: An integer that is the sequence block size. + logit_cap: An optional float that caps logits via tanh. By default there is + no logit capping. + mask_value: The value used for padding in attention. By default it is a very + negative floating point number. + out_dtype: An optional dtype for the output. If not provided, the output + dtype will be q's dtype. + Returns: + The output of attention([batch_size, num_heads, compute_dim, head_dim]), + along with the max logit ([batch_size, num_heads, compute_dim, 1]) and + softmax denominator ([batch_size, num_heads, compute_dim, 1]). 
+ """ + mask_value = DEFAULT_MASK_VALUE + seqlen = q.shape[-2] + if k_scaler is None: + replicated_in_axes = 4 + replicated_inputs = (pre_batch, pre_block) + else: + replicated_in_axes = 6 + replicated_inputs = (k_scaler, v_scaler, pre_batch, pre_block) + + with jax.named_scope("ragged_mha_vmap"): + out, (m, l) = jax.vmap( + functools.partial( + ragged_mqa, + bk=bk, + mask_value=mask_value, + normalize_var=normalize_var, + #out_dtype=out_dtype, + ), + in_axes=(shard_axis, shard_axis, shard_axis, *([None]*replicated_in_axes)), + out_axes=shard_axis, + )(q, k, v, start, end, *replicated_inputs) + return out, (m, l) + + +def dense_attention(xq, keys, values, mask): + head_dim = xq.shape[-1] + with jax.named_scope("attn_mat1"): + ## Attention start + # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim) + scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim) + if mask is not None: + # if mask.shape != (1,1,16,16): + # breakpoint() + scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen) + with jax.named_scope("attn_soft"): + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + + with jax.named_scope("attn_mat2"): + # output = torch.einsum( + # "ikjm,ikml->ikjl", scores, values + # ) # (bs, n_local_heads, seqlen, head_dim) + output = torch.einsum("ikjm,ikml->ikjl", scores, values) + + +def dense_attention_quantized( + xq: jax.Array, + keys: jax.Array, + values: jax.Array, + k_scaler = None, + v_scaler = None, + mask = None, +): + bsz, _, _, head_dim = xq.shape + + with jax.named_scope("attn_mat1"): + ## Attention start + # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim) + scores = ( + torch.einsum("ikjl,ikml->ikjm", xq, keys) + / math.sqrt(head_dim) + * (k_scaler.reshape(bsz, 1, 1, keys.shape[2])) + ) + if mask is not None: + scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen) + with jax.named_scope("attn_soft"): + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2])) + + with jax.named_scope("attn_mat2"): + # output = torch.einsum( + # "ikjm,ikml->ikjl", scores, values + # ) # (bs, n_local_heads, seqlen, head_dim) + output = torch.einsum("ikjm,ikml->ikjl", scores, values) + + return output + + class AttentionKernel: def __init__(self, env): self.env = env - - def __call__(self, xq, xk, xv, mask, cache): + self.shard_axis = 0 if self.env.shard_on_batch else 1 + qkv_pspec = self.env.partition_by_axis(self.shard_axis) # Number of heads + others_pspec = self.env.partition_by_axis() + self.binded_ragged_mha = functools.partial(ragged_mha, bk=self.env.block_size, shard_axis=self.shard_axis) + self.binded_ragged_mha = shard_map(ragged_mha, in_specs=(*([qkv_pspec] * 3), *([others_pspec] * 4)), out_specs=(others_pspec, (others_pspec, others_pspec)), check_rep=False) + self.binded_ragged_mha = jax.jit(self.binded_ragged_mha) + + def __call__(self, xq, xk, xv, mask, cache, start, end, pre_batch, pre_block): """ Args: xq: torch.Tensor of (batch size, num_heads, seqlen, head_dim) @@ -419,28 +726,18 @@ def __call__(self, xq, xk, xv, mask, cache): keys, values = cache.update(xk, xv) keys = repeat_kv(keys, n_rep) values = repeat_kv(values, n_rep) - with jax.named_scope("attn_mat1"): - ## Attention start - # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim) - scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim) - if mask is not None: - # if mask.shape != 
(1,1,16,16): - # breakpoint() - scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen) - with jax.named_scope("attn_soft"): - scores = F.softmax(scores.float(), dim=-1).type_as(xq) - - with jax.named_scope("attn_mat2"): - # output = torch.einsum( - # "ikjm,ikml->ikjl", scores, values - # ) # (bs, n_local_heads, seqlen, head_dim) - output = torch.einsum("ikjm,ikml->ikjl", scores, values) + + with jax.named_scope("attn_qkv"): + if self.env.ragged_mha and seqlen == 1: + output, _ = torch_xla2.extra.call_jax(self.binded_ragged_mha, xq, keys, values, start, end, pre_batch, pre_block) + else: + output = dense_attention(xq, keys, values, mask) + if seqlen == 1: output = output[:, :, 0:1, :] # For XLA matmul performance boost # output = torch.matmul(scores, values) - shard_axis = 0 if self.env.shard_on_batch else 1 - self.env.apply_sharding(output, axis=shard_axis) + self.env.apply_sharding(output, axis=self.shard_axis) return output @@ -448,8 +745,14 @@ class Int8KVAttentionKernel: def __init__(self, env): self.env = env - - def __call__(self, xq, xk, xv, mask, cache): + self.shard_axis = 0 if self.env.shard_on_batch else 1 + qkv_pspec = self.env.partition_by_axis(self.shard_axis) # Number of heads + others_pspec = self.env.partition_by_axis() + self.binded_ragged_mha_quantized = functools.partial(ragged_mha, bk=self.env.block_size, shard_axis=self.shard_axis) + self.binded_ragged_mha_quantized = shard_map(self.binded_ragged_mha_quantized, in_specs=(*([qkv_pspec] * 3), *([others_pspec]*6)), out_specs=(others_pspec, (others_pspec, others_pspec)), check_rep=False) + self.binded_ragged_mha_quantized = jax.jit(self.binded_ragged_mha_quantized) + + def __call__(self, xq, xk, xv, mask, cache, start, end, pre_batch, pre_block): """ Args: xq: torch.Tensor of (batch size, num_heads, seqlen, head_dim) @@ -461,6 +764,7 @@ def __call__(self, xq, xk, xv, mask, cache): bsz, num_heads, seqlen, head_dim = xq.shape _, num_kv_heads, _, kv_head_dim = xk.shape n_rep = num_heads // num_kv_heads + if seqlen == 1: xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) @@ -468,30 +772,17 @@ def __call__(self, xq, xk, xv, mask, cache): keys, values, k_scaler, v_scaler = cache.update(xk, xv) keys = repeat_kv(keys, n_rep) values = repeat_kv(values, n_rep) - with jax.named_scope("attn_mat1"): - ## Attention start - # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim) - scores = ( - torch.einsum("ikjl,ikml->ikjm", xq, keys) - / math.sqrt(head_dim) - * (k_scaler.reshape(bsz, 1, 1, keys.shape[2])) - ) - if mask is not None: - scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen) - with jax.named_scope("attn_soft"): - scores = F.softmax(scores.float(), dim=-1).type_as(xq) - scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2])) - - with jax.named_scope("attn_mat2"): - # output = torch.einsum( - # "ikjm,ikml->ikjl", scores, values - # ) # (bs, n_local_heads, seqlen, head_dim) - output = torch.einsum("ikjm,ikml->ikjl", scores, values) + + with jax.named_scope("attn_qkv"): + if self.env.ragged_mha and seqlen == 1: + output = torch_xla2.extra.call_jax(self.binded_ragged_mha_quantized, xq, keys, values, start, end, pre_batch, pre_block, k_scaler, v_scaler) + else: + output, _ = dense_attention_quantized(xq, keys, values, k_scaler, v_scaler, mask) + if seqlen == 1: output = output[:, :, 0:1, :] - # output = torch.matmul(scores, values) - shard_axis = 0 if self.env.shard_on_batch else 1 - self.env.apply_sharding(output, axis=shard_axis) + + 
self.env.apply_sharding(output, axis=self.shard_axis) return output @@ -566,6 +857,10 @@ def forward( freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], cache, + start, + end, + pre_batch, + pre_block, ): with jax.named_scope("attn_linear_before_cache"): bsz, seqlen = x.shape[0], x.shape[-2] @@ -593,6 +888,6 @@ def forward( xv = xv.transpose(1, 2) xq = xq.transpose(1, 2) - output = self.attention_kernel(xq, xk, xv, mask, cache) + output = self.attention_kernel(xq, xk, xv, mask, cache, start, end, pre_batch, pre_block) output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1) return self.wo(output) diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py index 7c692b22..1838bef6 100644 --- a/jetstream_pt/third_party/llama/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -109,10 +109,14 @@ def forward( freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], cache, + start, + end, + pre_batch, + pre_block, ): with jax.named_scope("Attention"): attn = self.attention.forward( - self.attention_norm(x), freqs_cis, mask, cache + self.attention_norm(x), freqs_cis, mask, cache, start, end, pre_batch, pre_block ) with jax.named_scope("ffn_norm"): h = x + attn @@ -180,9 +184,12 @@ def __init__( def forward( self, tokens: torch.Tensor, - input_pos: torch.Tensor, caches: List[Any], mask, + start, + input_pos, + pre_batch, + pre_block, ): with jax.named_scope("transformer_tok"): seqlen = tokens.shape[-1] @@ -196,9 +203,10 @@ def forward( assert len(caches) == len( self.layers ), f"Number of caches ({len(caches)}) and layers ({len(self.layers)}) dont match" + end = None if start is None else (start + input_pos) % self.env.cache_len for layer, cache in zip(self.layers, caches): with jax.named_scope("TransformerBlock"): - h = layer(h, freqs_cis, mask, cache) + h = layer(h, freqs_cis, mask, cache, start, end, pre_batch, pre_block) with jax.named_scope("transformer_norm"): h = self.norm(h) diff --git a/run_interactive_multiple_host.py b/run_interactive_multiple_host.py index 66695bf5..dca59f6c 100644 --- a/run_interactive_multiple_host.py +++ b/run_interactive_multiple_host.py @@ -44,6 +44,7 @@ def create_engine(): max_cache_length=FLAGS.max_cache_length, sharding_config=FLAGS.sharding_config, shard_on_batch=FLAGS.shard_on_batch, + ragged_mha=FLAGS.ragged_mha ) print("Initialize engine", time.perf_counter() - start) From d2bb5143120fc35811897f7546d4170b5202bf65 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Thu, 16 May 2024 23:47:35 +0000 Subject: [PATCH 02/41] Converts the attention output types the same as q. --- jetstream_pt/layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 2112b213..0af3e42e 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -888,6 +888,6 @@ def forward( xv = xv.transpose(1, 2) xq = xq.transpose(1, 2) - output = self.attention_kernel(xq, xk, xv, mask, cache, start, end, pre_batch, pre_block) + output = self.attention_kernel(xq, xk, xv, mask, cache, start, end, pre_batch, pre_block).type_as(xq) output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1) return self.wo(output) From 8482117abb94bfa6ee887ef54ca94758c1f71a6b Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 17 May 2024 00:38:36 +0000 Subject: [PATCH 03/41] Fixes the typo for the ragged attention. 
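static_argnames only takes effect when the listed names are actual parameters of the jitted function; ragged_mha takes mask_value, not mask, so the old spelling is either rejected or silently ignored depending on the JAX version. A minimal, self-contained sketch of the pattern (attention_block_sketch and its body are hypothetical, not part of this change):

    import functools
    import jax
    import jax.numpy as jnp

    @functools.partial(jax.jit, static_argnames=["bk", "mask_value"])
    def attention_block_sketch(q, bk=512, mask_value=-1e9):
      # bk and mask_value stay Python constants at trace time, so they can
      # drive block sizes and masking logic without becoming traced arrays.
      return jnp.where(q > 0, q, mask_value)

    print(attention_block_sketch(jnp.ones((2, 4)), bk=4))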
--- jetstream_pt/layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 0af3e42e..27bd44b7 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -582,7 +582,7 @@ def scaler_index_map(b, i, *_): return out, (m[..., 0], l[..., 0]) -@functools.partial(jax.jit, static_argnames=['bk', 'mask', 'normalize_var', 'shard_axis']) +@functools.partial(jax.jit, static_argnames=['bk', 'mask_value', 'normalize_var', 'shard_axis']) def ragged_mha( q: jax.Array, k: jax.Array, From 4585ab4fb8cf35178984e1f27c1d51ce48aecadd Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 17 May 2024 00:41:58 +0000 Subject: [PATCH 04/41] Provides the default value for partition_by_axis. --- jetstream_pt/environment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py index 3e9e4be6..39631603 100644 --- a/jetstream_pt/environment.py +++ b/jetstream_pt/environment.py @@ -152,7 +152,7 @@ def apply_sharding(self, tensor, *, axis: int | None): # pylint: disable-next=all tensor._elem = jax.lax.with_sharding_constraint(tensor._elem, sharding_spec) - def partition_by_axis(self, axis): + def partition_by_axis(self, axis=None): """return sharding partition spc by axis, options are x, y, -1 or Noe""" if axis == -1 or axis is None: return jax.sharding.PartitionSpec() From 1498ba91a3500651b007f3eae4fb491e3be05680 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 17 May 2024 00:46:49 +0000 Subject: [PATCH 05/41] Provides mesh to the shard_map. --- jetstream_pt/environment.py | 10 +++++----- jetstream_pt/layers.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py index 39631603..3a261012 100644 --- a/jetstream_pt/environment.py +++ b/jetstream_pt/environment.py @@ -108,14 +108,14 @@ def __init__(self, data: JetEngineEnvironmentData): num_of_partitions = jax.device_count() # make mesh etc. 
- self._mesh = jsharding.Mesh( + self.mesh = jsharding.Mesh( mesh_utils.create_device_mesh((num_of_partitions, 1)), axis_names=("x", "y"), ) - self.y_sharding = jsharding.NamedSharding(self._mesh, P(None, "x")) - self.x_sharding = jsharding.NamedSharding(self._mesh, P("x")) - self.replicated = jsharding.NamedSharding(self._mesh, P()) + self.y_sharding = jsharding.NamedSharding(self.mesh, P(None, "x")) + self.x_sharding = jsharding.NamedSharding(self.mesh, P("x")) + self.replicated = jsharding.NamedSharding(self.mesh, P()) if data.shard_on_batch: cache_sharding_axis = 0 @@ -163,7 +163,7 @@ def partition_by_axis(self, axis=None): def sharding_by_axis(self, axis): """return sharding partition spc by axis, options are x, y, -1 or Noe""" - return jsharding.NamedSharding(self._mesh, self.partition_by_axis(axis)) + return jsharding.NamedSharding(self.mesh, self.partition_by_axis(axis)) def make_caches_prefill(self): """Create kv caches for inference prefill""" diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 27bd44b7..d7fd7344 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -704,7 +704,7 @@ def __init__(self, env): qkv_pspec = self.env.partition_by_axis(self.shard_axis) # Number of heads others_pspec = self.env.partition_by_axis() self.binded_ragged_mha = functools.partial(ragged_mha, bk=self.env.block_size, shard_axis=self.shard_axis) - self.binded_ragged_mha = shard_map(ragged_mha, in_specs=(*([qkv_pspec] * 3), *([others_pspec] * 4)), out_specs=(others_pspec, (others_pspec, others_pspec)), check_rep=False) + self.binded_ragged_mha = shard_map(ragged_mha, env.mesh, in_specs=(*([qkv_pspec] * 3), *([others_pspec] * 4)), out_specs=(others_pspec, (others_pspec, others_pspec)), check_rep=False) self.binded_ragged_mha = jax.jit(self.binded_ragged_mha) def __call__(self, xq, xk, xv, mask, cache, start, end, pre_batch, pre_block): @@ -749,7 +749,7 @@ def __init__(self, env): qkv_pspec = self.env.partition_by_axis(self.shard_axis) # Number of heads others_pspec = self.env.partition_by_axis() self.binded_ragged_mha_quantized = functools.partial(ragged_mha, bk=self.env.block_size, shard_axis=self.shard_axis) - self.binded_ragged_mha_quantized = shard_map(self.binded_ragged_mha_quantized, in_specs=(*([qkv_pspec] * 3), *([others_pspec]*6)), out_specs=(others_pspec, (others_pspec, others_pspec)), check_rep=False) + self.binded_ragged_mha_quantized = shard_map(self.binded_ragged_mha_quantized, env.mesh, in_specs=(*([qkv_pspec] * 3), *([others_pspec]*6)), out_specs=(others_pspec, (others_pspec, others_pspec)), check_rep=False) self.binded_ragged_mha_quantized = jax.jit(self.binded_ragged_mha_quantized) def __call__(self, xq, xk, xv, mask, cache, start, end, pre_batch, pre_block): From 81bfaa68e301ba41a520ca1bd06d37be44a3c6d3 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 17 May 2024 00:48:15 +0000 Subject: [PATCH 06/41] Fixes typo. 
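The underlying issue is a return-convention mismatch: ragged_mha returns the attention output together with the running logit max and softmax denominator, while the dense quantized path returns a single tensor, so only the ragged call site should unpack a tuple. A self-contained toy illustration (both *_sketch functions below are hypothetical stand-ins for the real kernels):

    import jax.numpy as jnp

    def ragged_attention_sketch(q):
      out = q                          # attention output
      m = jnp.zeros(q.shape[:-1])      # placeholder for the running logit max
      l = jnp.ones(q.shape[:-1])       # placeholder for the softmax denominator
      return out, (m, l)

    def dense_attention_sketch(q):
      return q                         # single tensor, nothing to unpack

    q = jnp.ones((2, 8, 1, 128))
    out, _ = ragged_attention_sketch(q)   # the (m, l) pair must be unpacked
    out = dense_attention_sketch(q)       # plain assignment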
--- jetstream_pt/layers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index d7fd7344..3034c891 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -775,9 +775,9 @@ def __call__(self, xq, xk, xv, mask, cache, start, end, pre_batch, pre_block): with jax.named_scope("attn_qkv"): if self.env.ragged_mha and seqlen == 1: - output = torch_xla2.extra.call_jax(self.binded_ragged_mha_quantized, xq, keys, values, start, end, pre_batch, pre_block, k_scaler, v_scaler) + output, _ = torch_xla2.extra.call_jax(self.binded_ragged_mha_quantized, xq, keys, values, start, end, pre_batch, pre_block, k_scaler, v_scaler) else: - output, _ = dense_attention_quantized(xq, keys, values, k_scaler, v_scaler, mask) + output= dense_attention_quantized(xq, keys, values, k_scaler, v_scaler, mask) if seqlen == 1: output = output[:, :, 0:1, :] From 01d2eef0640d72ee58c22905867f338a41aeb93c Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 17 May 2024 00:53:22 +0000 Subject: [PATCH 07/41] Fixes typo, should be start instead of start_pos. --- jetstream_pt/engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 91626664..c501e7de 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -277,7 +277,7 @@ def _insert_no_wrap( cond = jnp.logical_and(x <= decode_state.current_position, x >= pos) mask_insert = jnp.where(cond, 0, float("-inf")) mask = decode_state.mask.at[slot].set(mask_insert) - start = decode_state.start_pos.at[slot].set(pos % self.env.cache_sequence_length) + start = decode_state.start.at[slot].set(pos % self.env.cache_sequence_length) input_pos = decode_state.input_pos.at[slot].set(prefix.seq_len) if not self.env.quant_config.enable_kv_quantization: @@ -373,7 +373,7 @@ def _insert_wrap( mask_insert = jnp.where(cond, 0, float("-inf")) mask = decode_state.mask.at[slot].set(mask_insert) - start = decode_state.start_pos.at[slot].set(start_insert) + start = decode_state.start.at[slot].set(start_insert) input_pos = decode_state.input_pos.at[slot].set(prefix.seq_len) old_caches = decode_state.caches From 560387916c562d50da29c56dcc5a99a848119014 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 17 May 2024 01:21:36 +0000 Subject: [PATCH 08/41] Should use "//" instead of "/" to get int results. --- jetstream_pt/engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index c501e7de..182f351d 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -127,8 +127,8 @@ def init_decode_state( jnp.zeros((self.env.batch_size, 1), dtype=jnp.int32), caches, scalers, - self.env.max_input_sequence_length / 4, - jnp.zeros((self.env.batch_size, 1), dtype=jnp.int32), + self.env.max_input_sequence_length // 4, + jnp.zeros((self.env.batch_size, 1), dtype=jnp.int32), # lens jnp.zeros((self.env.batch_size,), dtype=jnp.int32), # start pos jnp.zeros((self.env.batch_size,), dtype=jnp.int32), # input pos jnp.full( From 24882977b587c18dc6c9b61da16d19ad5ef937df Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 17 May 2024 01:25:24 +0000 Subject: [PATCH 09/41] Use block size // 2 as the starting current position for better initial performance. 
Fix the typo that should use jax.lax.div instead of jnp.div --- jetstream_pt/engine.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 182f351d..4a8e11ee 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -127,7 +127,7 @@ def init_decode_state( jnp.zeros((self.env.batch_size, 1), dtype=jnp.int32), caches, scalers, - self.env.max_input_sequence_length // 4, + self.env.block_size // 2, jnp.zeros((self.env.batch_size, 1), dtype=jnp.int32), # lens jnp.zeros((self.env.batch_size,), dtype=jnp.int32), # start pos jnp.zeros((self.env.batch_size,), dtype=jnp.int32), # input pos @@ -471,7 +471,7 @@ def precompute_ragged_block_indices(self, decode_state: DecodeState): end = end.reshape((batch_size, 1)) am_last_batch = b == batch_size - 1 - last_good_block = jnp.where(start < end, jnp.div(end - 1, bk), jnp.div(self.env.cache_len -1, bk)) + last_good_block = jnp.where(start < end, jax.lax.div(end - 1, bk), jax.lax.div(self.env.cache_len -1, bk)) next_b = jnp.where(am_last_batch, b, b + 1) next_i = jnp.where(am_last_batch, last_good_block, 0) @@ -480,13 +480,13 @@ def precompute_ragged_block_indices(self, decode_state: DecodeState): def true_comp(b, i, bk, start, end, next_b, next_i): b_next = jnp.where(i * bk >= end, next_b, b) i_next = jnp.where(i * bk >= end, next_i, i) - i_next = jnp.where((i + 1) * bk <= start, jnp.div(start, bk), i_next) + i_next = jnp.where((i + 1) * bk <= start, jax.lax.div(start, bk), i_next) return b_next, i_next # start > end def false_comp(b, i, bk, start, end): b_next = b - i_next = jnp.where(jnp.logical_and(i * bk >= end, (i + 1) * bk <= start), jnp.div(start, bk), i) + i_next = jnp.where(jnp.logical_and(i * bk >= end, (i + 1) * bk <= start), jax.lax.div(start, bk), i) return b_next, i_next true_comp_b, true_comp_i = true_comp(b, i, bk, start, end, next_b, next_i) From f04b20ab69f798f5a02ed93ef3e19debf038aba6 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 17 May 2024 01:44:44 +0000 Subject: [PATCH 10/41] Updates the run_interactive script to use the correct result token processing API from JetStream. --- run_interactive.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/run_interactive.py b/run_interactive.py index 77b3a702..c5811107 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -69,13 +69,8 @@ def main(argv): complete = np.zeros((1,), dtype=np.bool_) while True: decode_state, result_tokens = engine.generate(params, decode_state) - result_tokens = result_tokens.convert_to_numpy() - output, complete = token_utils.process_result_tokens( - tokenizer=tokenizer, - slot=slot, - slot_max_length=max_output_length, - result_tokens=result_tokens, - complete=complete, + output, complete = tokenizer.process_result_token( + tokenizer, slot, max_output_length, result_tokens, complete ) if complete[0]: break From 6aaf6d9e17864abdcd9497a738a7bbf2370ce5e3 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 17 May 2024 01:47:36 +0000 Subject: [PATCH 11/41] Fix typo, should use token_utils.process_result_token. 
--- run_interactive.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run_interactive.py b/run_interactive.py index c5811107..4d8693c0 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -69,7 +69,7 @@ def main(argv): complete = np.zeros((1,), dtype=np.bool_) while True: decode_state, result_tokens = engine.generate(params, decode_state) - output, complete = tokenizer.process_result_token( + output, complete = token_utils.process_result_token( tokenizer, slot, max_output_length, result_tokens, complete ) if complete[0]: From cd84291ef5a65da16ef3b5243b8bdab386eb2ffd Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 17 May 2024 01:49:25 +0000 Subject: [PATCH 12/41] Fix typo. --- run_interactive.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run_interactive.py b/run_interactive.py index 4d8693c0..34eded54 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -69,7 +69,7 @@ def main(argv): complete = np.zeros((1,), dtype=np.bool_) while True: decode_state, result_tokens = engine.generate(params, decode_state) - output, complete = token_utils.process_result_token( + output, complete = token_utils.process_result_tokens( tokenizer, slot, max_output_length, result_tokens, complete ) if complete[0]: From 53240bc92652c107c50e5f8dd33f042ec30e4979 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 17 May 2024 01:53:44 +0000 Subject: [PATCH 13/41] Fixes the sampled tokens list. --- run_interactive.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/run_interactive.py b/run_interactive.py index 34eded54..d265995b 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -74,8 +74,12 @@ def main(argv): ) if complete[0]: break - token_ids = output[0].token_ids - sampled_tokens_list.extend(token_ids) + sampled_tokens_list = output[0] + # output_str = tokenizer.decode_str([token_id]) + # print(Fore.GREEN + output_str, end="", flush=True) + + # print(Style.RESET_ALL + "\n") + # print("---- Streaming decode finished.") print("---- All output tokens.") print(sampled_tokens_list) From ed368b590af57f4c56adac76858dde4b2e9de9d8 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 17 May 2024 02:00:28 +0000 Subject: [PATCH 14/41] Use text_tokens_to_str to convert the output tokens. --- run_interactive.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run_interactive.py b/run_interactive.py index d265995b..517a439c 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -84,7 +84,7 @@ def main(argv): print("---- All output tokens.") print(sampled_tokens_list) print("---- All output text.") - print(tokenizer.decode(sampled_tokens_list)) + print(token_utils.text_tokens_to_str(sampled_tokens_list)) if profiling_output: jax.profiler.stop_trace() From 5264f11d0ee579508c0aebd4097ecb4c2b285707 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 17 May 2024 23:12:18 +0000 Subject: [PATCH 15/41] Reshape the precomputed grid indices to 1D. Removes the dense_attention_quantized and use option to control if it's quantization or not. Use the new torch_xla2 API. 
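Folding the quantized path into dense_attention with optional k_scaler/v_scaler arguments keeps a single dense code path for both cases. The reshape to 1D matters because the Pallas index_maps in ragged_mqa read the prefetched indices with a flattened offset, index = b * (seq_len // bk) + i, so the (batch_size, num_blocks) arrays produced by precompute_ragged_block_indices have to be laid out flat. A small sketch of that equivalence (the shapes below are illustrative, not taken from a real config):

    import jax.numpy as jnp

    batch_size, num_blocks = 4, 8
    pre_block = jnp.arange(batch_size * num_blocks).reshape(batch_size, num_blocks)
    flat = pre_block.reshape((-1))

    b, i = 2, 5
    assert flat[b * num_blocks + i] == pre_block[b, i]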
--- jetstream_pt/engine.py | 1 + jetstream_pt/layers.py | 53 ++++++++++-------------------------------- 2 files changed, 13 insertions(+), 41 deletions(-) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 4a8e11ee..88621bcd 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -506,6 +506,7 @@ def generate( # fill mask first mask = decode_state.mask.at[:, decode_state.current_position].set(0) pre_batch, pre_block = self.precompute_ragged_block_indices(decode_state) + pre_batch, pre_block = pre_batch.reshape((-1)), pre_block.reshape((-1)) logits, new_caches, new_scales = self._call_model_generate( params, diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 3034c891..99935f22 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -643,18 +643,22 @@ def ragged_mha( return out, (m, l) -def dense_attention(xq, keys, values, mask): - head_dim = xq.shape[-1] +def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None): + bsz, _, _, head_dim = xq.shape with jax.named_scope("attn_mat1"): ## Attention start # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim) scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim) + if k_scaler: + scores = scores * (k_scaler.reshape(bsz, 1, 1, keys.shape[2])) if mask is not None: # if mask.shape != (1,1,16,16): # breakpoint() scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen) with jax.named_scope("attn_soft"): scores = F.softmax(scores.float(), dim=-1).type_as(xq) + if v_scaler: + scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2])) with jax.named_scope("attn_mat2"): # output = torch.einsum( @@ -663,39 +667,6 @@ def dense_attention(xq, keys, values, mask): output = torch.einsum("ikjm,ikml->ikjl", scores, values) -def dense_attention_quantized( - xq: jax.Array, - keys: jax.Array, - values: jax.Array, - k_scaler = None, - v_scaler = None, - mask = None, -): - bsz, _, _, head_dim = xq.shape - - with jax.named_scope("attn_mat1"): - ## Attention start - # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim) - scores = ( - torch.einsum("ikjl,ikml->ikjm", xq, keys) - / math.sqrt(head_dim) - * (k_scaler.reshape(bsz, 1, 1, keys.shape[2])) - ) - if mask is not None: - scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen) - with jax.named_scope("attn_soft"): - scores = F.softmax(scores.float(), dim=-1).type_as(xq) - scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2])) - - with jax.named_scope("attn_mat2"): - # output = torch.einsum( - # "ikjm,ikml->ikjl", scores, values - # ) # (bs, n_local_heads, seqlen, head_dim) - output = torch.einsum("ikjm,ikml->ikjl", scores, values) - - return output - - class AttentionKernel: def __init__(self, env): @@ -704,7 +675,7 @@ def __init__(self, env): qkv_pspec = self.env.partition_by_axis(self.shard_axis) # Number of heads others_pspec = self.env.partition_by_axis() self.binded_ragged_mha = functools.partial(ragged_mha, bk=self.env.block_size, shard_axis=self.shard_axis) - self.binded_ragged_mha = shard_map(ragged_mha, env.mesh, in_specs=(*([qkv_pspec] * 3), *([others_pspec] * 4)), out_specs=(others_pspec, (others_pspec, others_pspec)), check_rep=False) + self.binded_ragged_mha = shard_map(ragged_mha, env.mesh, in_specs=(*([qkv_pspec] * 3), *([others_pspec] * 4)), out_specs=(qkv_pspec, (others_pspec, others_pspec)), check_rep=False) self.binded_ragged_mha = jax.jit(self.binded_ragged_mha) def __call__(self, xq, xk, xv, mask, 
cache, start, end, pre_batch, pre_block): @@ -729,9 +700,9 @@ def __call__(self, xq, xk, xv, mask, cache, start, end, pre_batch, pre_block): with jax.named_scope("attn_qkv"): if self.env.ragged_mha and seqlen == 1: - output, _ = torch_xla2.extra.call_jax(self.binded_ragged_mha, xq, keys, values, start, end, pre_batch, pre_block) + output, _ = torch_xla2.interop.call_jax(self.binded_ragged_mha, xq, keys, values, start, end, pre_batch, pre_block) else: - output = dense_attention(xq, keys, values, mask) + output = dense_attention(xq, keys, values, None, None, mask) if seqlen == 1: output = output[:, :, 0:1, :] @@ -749,7 +720,7 @@ def __init__(self, env): qkv_pspec = self.env.partition_by_axis(self.shard_axis) # Number of heads others_pspec = self.env.partition_by_axis() self.binded_ragged_mha_quantized = functools.partial(ragged_mha, bk=self.env.block_size, shard_axis=self.shard_axis) - self.binded_ragged_mha_quantized = shard_map(self.binded_ragged_mha_quantized, env.mesh, in_specs=(*([qkv_pspec] * 3), *([others_pspec]*6)), out_specs=(others_pspec, (others_pspec, others_pspec)), check_rep=False) + self.binded_ragged_mha_quantized = shard_map(self.binded_ragged_mha_quantized, env.mesh, in_specs=(*([qkv_pspec] * 3), *([others_pspec]*6)), out_specs=(qkv_pspec, (others_pspec, others_pspec)), check_rep=False) self.binded_ragged_mha_quantized = jax.jit(self.binded_ragged_mha_quantized) def __call__(self, xq, xk, xv, mask, cache, start, end, pre_batch, pre_block): @@ -775,9 +746,9 @@ def __call__(self, xq, xk, xv, mask, cache, start, end, pre_batch, pre_block): with jax.named_scope("attn_qkv"): if self.env.ragged_mha and seqlen == 1: - output, _ = torch_xla2.extra.call_jax(self.binded_ragged_mha_quantized, xq, keys, values, start, end, pre_batch, pre_block, k_scaler, v_scaler) + output, _ = torch_xla2.interop.call_jax(self.binded_ragged_mha_quantized, xq, keys, values, start, end, pre_batch, pre_block, k_scaler, v_scaler) else: - output= dense_attention_quantized(xq, keys, values, k_scaler, v_scaler, mask) + output= dense_attention(xq, keys, values, k_scaler, v_scaler, mask) if seqlen == 1: output = output[:, :, 0:1, :] From a4241d9bd579b338022f55b0b5d4a69ef2380e8c Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 17 May 2024 23:19:04 +0000 Subject: [PATCH 16/41] Should check if X is None instead of if X --- jetstream_pt/layers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 99935f22..576cf1dd 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -649,7 +649,7 @@ def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None): ## Attention start # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim) scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim) - if k_scaler: + if k_scaler is not None: scores = scores * (k_scaler.reshape(bsz, 1, 1, keys.shape[2])) if mask is not None: # if mask.shape != (1,1,16,16): @@ -657,7 +657,7 @@ def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None): scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen) with jax.named_scope("attn_soft"): scores = F.softmax(scores.float(), dim=-1).type_as(xq) - if v_scaler: + if v_scaler is not None: scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2])) with jax.named_scope("attn_mat2"): From 00a8fa0ff0b723ba56fc8ca2ea6ea59764e7481d Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 17 May 2024 23:31:56 +0000 Subject: [PATCH 
17/41] Fix the dense_attention not returning data. --- jetstream_pt/layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 576cf1dd..2e9f30c5 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -665,7 +665,7 @@ def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None): # "ikjm,ikml->ikjl", scores, values # ) # (bs, n_local_heads, seqlen, head_dim) output = torch.einsum("ikjm,ikml->ikjl", scores, values) - + return output class AttentionKernel: From 4a26aed4207e1e6606744cf0c7515df430e63e4c Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 17 May 2024 23:54:23 +0000 Subject: [PATCH 18/41] Reshape the kv scaler to 3 dim for ragged attention. --- jetstream_pt/layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 2e9f30c5..d7f82ac6 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -626,7 +626,7 @@ def ragged_mha( replicated_inputs = (pre_batch, pre_block) else: replicated_in_axes = 6 - replicated_inputs = (k_scaler, v_scaler, pre_batch, pre_block) + replicated_inputs = (jnp.squeeze(k_scaler, -1), jnp.squeeze(v_scaler, -1), pre_batch, pre_block) with jax.named_scope("ragged_mha_vmap"): out, (m, l) = jax.vmap( From 7fdf340e3fa42818da0208d84cb07ced84984f81 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Mon, 20 May 2024 17:27:00 +0000 Subject: [PATCH 19/41] Cannot stop the input_pos counter from increasing since we are using a ring buffer. Will cause error. --- jetstream_pt/engine.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 88621bcd..806f2a17 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -549,8 +549,7 @@ def generate( (decode_state.current_position + 1) % self.env.cache_sequence_length, lens, decode_state.start, - # Stop the input_pos from increasing if it's 0, for better ragged attention performance - jnp.where(decode_state.input_pos == 0, 0, decode_state.input_pos + 1), + decode_state.input_pos + 1, mask, ) print( From 072164637f2a4d8b510c6a17cd2a3a788e60afd0 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Mon, 20 May 2024 21:25:33 +0000 Subject: [PATCH 20/41] Adds starting_position and profiling_prefill for better testing and benchmarking. 
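starting_position replaces the hard-coded block_size // 2 as the decode state's initial current_position, so tests and benchmarks can start the ring-buffer cache from an arbitrary offset. A hedged usage sketch, assuming the remaining JetEngineEnvironmentData fields keep their defaults (the values shown are just the defaults added here):

    from jetstream_pt.environment import (
        JetEngineEnvironment,
        JetEngineEnvironmentData,
    )

    env_data = JetEngineEnvironmentData(
        ragged_mha=True,
        block_size=512,
        starting_position=512,  # initial current_position of the decode ring buffer
    )
    env = JetEngineEnvironment(env_data)
    assert env.starting_position == 512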
--- jetstream_pt/engine.py | 4 +++- jetstream_pt/environment.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 806f2a17..d80243be 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -127,7 +127,7 @@ def init_decode_state( jnp.zeros((self.env.batch_size, 1), dtype=jnp.int32), caches, scalers, - self.env.block_size // 2, + self.env.starting_position, jnp.zeros((self.env.batch_size, 1), dtype=jnp.int32), # lens jnp.zeros((self.env.batch_size,), dtype=jnp.int32), # start pos jnp.zeros((self.env.batch_size,), dtype=jnp.int32), # input pos @@ -724,6 +724,7 @@ def create_pytorch_engine( sharding_config=None, shard_on_batch=False, ragged_mha=False, + starting_position=512, ) -> PyTorchEngine: """Returns: The pytorch engine.""" @@ -783,6 +784,7 @@ def create_pytorch_engine( sharding_config_path=sharding_config, shard_on_batch=shard_on_batch, ragged_mha=ragged_mha, + starting_position=starting_position, ) if shard_on_batch and sharding_config: diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py index 3a261012..9573e503 100644 --- a/jetstream_pt/environment.py +++ b/jetstream_pt/environment.py @@ -93,6 +93,8 @@ class JetEngineEnvironmentData: # The block size for the ragged attention. block_size: int = 512 + # Starting position + starting_position: int = 512 # pylint: disable-next=all class JetEngineEnvironment: @@ -103,7 +105,7 @@ def __init__(self, data: JetEngineEnvironmentData): self.cache_len = self._data.cache_sequence_length self.ragged_mha = self._data.ragged_mha self.block_size = self._data.block_size - + self.starting_position = self._data.starting_position P = jax.sharding.PartitionSpec num_of_partitions = jax.device_count() From 930eaa0288aaacf01ab76152015135ccb553ee50 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Mon, 20 May 2024 09:15:03 -0700 Subject: [PATCH 21/41] Move flags in scripts to a common function (#92) * refactor flags * clean up: * fix run_server * move common flags to global * format * update * udpate readme * update run_interactive --- run_interactive_multiple_host.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run_interactive_multiple_host.py b/run_interactive_multiple_host.py index dca59f6c..f59674fc 100644 --- a/run_interactive_multiple_host.py +++ b/run_interactive_multiple_host.py @@ -54,7 +54,7 @@ def create_engine(): # pylint: disable-next=all def main(argv): - engine = create_engine() + engine = create_engine_from_config_flags() start = time.perf_counter() engine.load_params() From 97c6435c49ff8c9b7dae6b3f87193a4e8f58b952 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Thu, 16 May 2024 23:41:16 +0000 Subject: [PATCH 22/41] Stable version of ragged attention. 
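For reference, ragged_flash_attention_kernel folds the KV cache in blocks of bk tokens and keeps three running buffers per (batch, head): the output o, the running logit max m, and the softmax denominator l. A self-contained sketch of that online-softmax accumulation in plain jax.numpy (no Pallas), checked against a direct softmax on a toy example; the l_next == 0 guard used by the real kernel for fully masked blocks is omitted here:

    import jax.numpy as jnp

    def online_softmax_weighted_sum(qk_blocks, v_blocks):
      m, l = -jnp.inf, 0.0                      # running max / denominator
      o = jnp.zeros(v_blocks[0].shape[-1])      # running output
      for qk, v in zip(qk_blocks, v_blocks):
        m_curr = qk.max()
        s = jnp.exp(qk - m_curr)
        l_curr = s.sum()
        m_next = jnp.maximum(m, m_curr)
        alpha, beta = jnp.exp(m - m_next), jnp.exp(m_curr - m_next)
        l_next = alpha * l + beta * l_curr
        o = (l * alpha * o + beta * (s @ v)) / l_next
        m, l = m_next, l_next
      return o

    qk = jnp.array([0.1, 0.9, -0.3, 0.5])
    v = jnp.arange(8.0).reshape(4, 2)
    blockwise = online_softmax_weighted_sum([qk[:2], qk[2:]], [v[:2], v[2:]])
    direct = (jnp.exp(qk - qk.max()) / jnp.exp(qk - qk.max()).sum()) @ v
    assert jnp.allclose(blockwise, direct, atol=1e-5)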
--- jetstream_pt/layers.py | 296 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 296 insertions(+) diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index d7f82ac6..692fb133 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -582,6 +582,302 @@ def scaler_index_map(b, i, *_): return out, (m[..., 0], l[..., 0]) +@functools.partial(jax.jit, static_argnames=['bk', 'mask', 'normalize_var', 'shard_axis']) +def ragged_mha( + q: jax.Array, + k: jax.Array, + v: jax.Array, + start: jax.Array, + end: jax.Array, + pre_batch: jax.Array, + pre_block: jax.Array, + k_scaler: jax.Array | None = None, + v_scaler: jax.Array | None = None, + bk: int = 512, + mask_value : float = DEFAULT_MASK_VALUE, + normalize_var: bool = True, + shard_axis: int = 1 +) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: + """Ragged multi head attention. + Args: + q: A [batch_size, compute_dim, num_heads, head_dim] jax.Array. + k: A [batch_size, num_heads, seq_len, head_dim] jax.Array or + PartitionQuantizedTensor. + v: A [batch_size, num_heads, seq_len, head_dim] jax.Array or + PartitionQuantizedTensor. + start: A i32[batch_size] jax.Array + end: A i32[batch_size] jax.Array + bk: An integer that is the sequence block size. + logit_cap: An optional float that caps logits via tanh. By default there is + no logit capping. + mask_value: The value used for padding in attention. By default it is a very + negative floating point number. + out_dtype: An optional dtype for the output. If not provided, the output + dtype will be q's dtype. + Returns: + The output of attention([batch_size, num_heads, compute_dim, head_dim]), + along with the max logit ([batch_size, num_heads, compute_dim, 1]) and + softmax denominator ([batch_size, num_heads, compute_dim, 1]). 
+ """ + mask_value = DEFAULT_MASK_VALUE + seqlen = q.shape[-2] + if k_scaler is None: + replicated_in_axes = 4 + replicated_inputs = (pre_batch, pre_block) + else: + replicated_in_axes = 6 + replicated_inputs = (k_scaler, v_scaler, pre_batch, pre_block) + + with jax.named_scope("ragged_mha_vmap"): + out, (m, l) = jax.vmap( + functools.partial( + ragged_mqa, + bk=bk, + mask_value=mask_value, + normalize_var=normalize_var, + #out_dtype=out_dtype, + ), + in_axes=(shard_axis, shard_axis, shard_axis, *([None]*replicated_in_axes)), + out_axes=shard_axis, + )(q, k, v, start, end, *replicated_inputs) + return out, (m, l) + + +def dense_attention(xq, keys, values, mask): + head_dim = xq.shape[-1] + with jax.named_scope("attn_mat1"): + ## Attention start + # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim) + scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim) + if mask is not None: + # if mask.shape != (1,1,16,16): + # breakpoint() + scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen) + with jax.named_scope("attn_soft"): + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + + with jax.named_scope("attn_mat2"): + # output = torch.einsum( + # "ikjm,ikml->ikjl", scores, values + # ) # (bs, n_local_heads, seqlen, head_dim) + output = torch.einsum("ikjm,ikml->ikjl", scores, values) + + +def dense_attention_quantized( + xq: jax.Array, + keys: jax.Array, + values: jax.Array, + k_scaler = None, + v_scaler = None, + mask = None, +): + bsz, _, _, head_dim = xq.shape + + with jax.named_scope("attn_mat1"): + ## Attention start + # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim) + scores = ( + torch.einsum("ikjl,ikml->ikjm", xq, keys) + / math.sqrt(head_dim) + * (k_scaler.reshape(bsz, 1, 1, keys.shape[2])) + ) + if mask is not None: + scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen) + with jax.named_scope("attn_soft"): + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2])) + + with jax.named_scope("attn_mat2"): + # output = torch.einsum( + # "ikjm,ikml->ikjl", scores, values + # ) # (bs, n_local_heads, seqlen, head_dim) + output = torch.einsum("ikjm,ikml->ikjl", scores, values) + + return output + + +DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) + +def ragged_flash_attention_kernel( + start_ref, + end_ref, + line_end_ref, + pre_b_ref, + pre_i_ref, + q_ref, + k_ref, + v_ref, + k_scaler_ref, + v_scaler_ref, + o_ref, + m_ref, + l_ref, + bk: int, + mask_value: float, + normalize_var: bool, + quantized: bool, +): + """Pallas kernel for flash attention.""" + with jax.named_scope("attention_kernel"): + b, i = pl.program_id(0), pl.program_id(1) + + @pl.when(i == 0) + def init(): + with jax.named_scope("init"): + m_ref[...] = jnp.full_like(m_ref, -jnp.inf) + l_ref[...] = jnp.zeros_like(l_ref) + o_ref[...] = jnp.zeros_like(o_ref) + + length = line_end_ref[b] + start = start_ref[b] + end = end_ref[b] + + @pl.when(jnp.logical_and(i * bk < length, start != end)) + def run(): + with jax.named_scope("run_qk"): + q = q_ref[...].astype(jnp.float32) + k = k_ref[...].astype(jnp.float32) + v = v_ref[...].astype(jnp.float32) + m_prev, l_prev = m_ref[...], l_ref[...] + + qk = jax.lax.dot_general( + q, k, (((1,), (1,)), ((), ())), preferred_element_type=jnp.float32 + ) + if normalize_var: + qk = qk / jnp.sqrt(k.shape[-1]) + if quantized: + qk = qk * k_scaler_ref[...] 
+ with jax.named_scope("run_mask"): + start = start_ref[b] + end = end_ref[b] + iota = jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1) + mask_start_lt_end = jnp.logical_and(i * bk + iota >= start, i * bk + iota < end).astype(jnp.int32) + mask_start_gt_end = jnp.logical_or(i * bk + iota >= start, i * bk + iota < end).astype(jnp.int32) + #mask = jax.lax.cond(start <= end, lambda: mask_start_lt_end, lambda: mask_start_gt_end) + mask = jnp.where(start <= end, mask_start_lt_end, mask_start_gt_end) + + qk = qk + jnp.where(mask, 0.0, mask_value) + + with jax.named_scope("run_softmax"): + m_curr = qk.max(axis=-1) + + s_curr = jnp.exp(qk - m_curr[..., None]) + + l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,)) + if quantized: + s_curr = s_curr * v_scaler_ref[...] + o_curr_times_l_curr = jnp.dot(s_curr, v) + m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,)) + m_next = jnp.maximum(m_prev, m_curr) + alpha = jnp.exp(m_prev - m_next) + beta = jnp.exp(m_curr - m_next) + l_next = alpha * l_prev + beta * l_curr + l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) + + m_ref[...], l_ref[...] = m_next, l_next_safe + o_ref[...] = ( + (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next_safe + ).astype(o_ref.dtype) + +@functools.partial(jax.jit, static_argnames=["bk", "mask_value", "normalize_var"]) +def ragged_mqa( + q: jax.Array, + k: jax.Array, + v: jax.Array, + start: jax.Array, + end: jax.Array, + k_scaler: jax.Array | None = None, + v_scaler: jax.Array | None = None, + pre_batch = None, + pre_block = None, + bk: int = 512, + mask_value: float = DEFAULT_MASK_VALUE, + normalize_var: bool = True, +) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: + """Ragged multi query attention.""" + with jax.named_scope("ragged_mqa"): + batch_size, num_heads, head_dim = q.shape + seq_len = k.shape[1] + + def kv_index_map(b, i, start_ref, end_ref, line_end_ref, pre_batch_ref, pre_block_ref): + index = b * (seq_len // bk) + i + return pre_batch_ref[index], pre_block_ref[index], 0 + + def q_index_map(b, i, start_ref, end_ref, line_end_ref, pre_batch_ref, pre_block_ref): + index = b * (seq_len // bk) + i + return pre_batch_ref[index], 0, 0 + + def scaler_index_map(b, i, *_): + return b, 0, i + + line_end = jnp.where(start < end, end, seq_len - 1) + + + if k_scaler is not None: + out, m, l = pl.pallas_call( + functools.partial( + ragged_flash_attention_kernel, + bk=bk, + mask_value=mask_value, + normalize_var=normalize_var, + quantized=False, + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=5, + in_specs=[ + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(kv_index_map, (None, bk, head_dim)), + pl.BlockSpec(kv_index_map, (None, bk, head_dim)), + pl.BlockSpec(scaler_index_map, (None, 1, bk)), + pl.BlockSpec(scaler_index_map, (None, 1, bk)), + ], + out_specs=[ + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + ], + grid=(batch_size, seq_len // bk), + ), + compiler_params=dict(dimension_semantics=("parallel", "arbitrary")), + out_shape=[ + q, + jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), + jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), + ], + )(start, end, line_end, pre_batch, pre_block, q, k, v, k_scaler, v_scaler) + else: + out, m, l = pl.pallas_call( + functools.partial( + ragged_flash_attention_kernel, + bk=bk, + mask_value=mask_value, + 
normalize_var=normalize_var, + quantized=True, + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=5, + in_specs=[ + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(kv_index_map, (None, bk, head_dim)), + pl.BlockSpec(kv_index_map, (None, bk, head_dim)), + ], + out_specs=[ + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + ], + grid=(batch_size, seq_len // bk), + ), + compiler_params=dict(dimension_semantics=("parallel", "arbitrary")), + out_shape=[ + q, + jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), + jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), + ], + )(start, end, line_end, pre_batch, pre_block, q, k, v) + return out, (m[..., 0], l[..., 0]) + + @functools.partial(jax.jit, static_argnames=['bk', 'mask_value', 'normalize_var', 'shard_axis']) def ragged_mha( q: jax.Array, From 6be5ec322c3c0f8e09b3decc8fa2b00400e40153 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Tue, 21 May 2024 01:00:37 +0000 Subject: [PATCH 23/41] Fix the merge conflicts --- jetstream_pt/layers.py | 297 ----------------------------------------- 1 file changed, 297 deletions(-) diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 692fb133..a67c997a 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -399,303 +399,6 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: .reshape(bs, n_kv_heads * n_rep, slen, head_dim) ) - -DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) - -def ragged_flash_attention_kernel( - start_ref, - end_ref, - line_end_ref, - pre_b_ref, - pre_i_ref, - q_ref, - k_ref, - v_ref, - k_scaler_ref, - v_scaler_ref, - o_ref, - m_ref, - l_ref, - bk: int, - mask_value: float, - normalize_var: bool, - quantized: bool, -): - """Pallas kernel for flash attention.""" - with jax.named_scope("attention_kernel"): - b, i = pl.program_id(0), pl.program_id(1) - - @pl.when(i == 0) - def init(): - with jax.named_scope("init"): - m_ref[...] = jnp.full_like(m_ref, -jnp.inf) - l_ref[...] = jnp.zeros_like(l_ref) - o_ref[...] = jnp.zeros_like(o_ref) - - length = line_end_ref[b] - start = start_ref[b] - end = end_ref[b] - - @pl.when(jnp.logical_and(i * bk < length, start != end)) - def run(): - with jax.named_scope("run_qk"): - q = q_ref[...].astype(jnp.float32) - k = k_ref[...].astype(jnp.float32) - v = v_ref[...].astype(jnp.float32) - m_prev, l_prev = m_ref[...], l_ref[...] - - qk = jax.lax.dot_general( - q, k, (((1,), (1,)), ((), ())), preferred_element_type=jnp.float32 - ) - if normalize_var: - qk = qk / jnp.sqrt(k.shape[-1]) - if quantized: - qk = qk * k_scaler_ref[...] - with jax.named_scope("run_mask"): - start = start_ref[b] - end = end_ref[b] - iota = jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1) - mask_start_lt_end = jnp.logical_and(i * bk + iota >= start, i * bk + iota < end).astype(jnp.int32) - mask_start_gt_end = jnp.logical_or(i * bk + iota >= start, i * bk + iota < end).astype(jnp.int32) - #mask = jax.lax.cond(start <= end, lambda: mask_start_lt_end, lambda: mask_start_gt_end) - mask = jnp.where(start <= end, mask_start_lt_end, mask_start_gt_end) - - qk = qk + jnp.where(mask, 0.0, mask_value) - - with jax.named_scope("run_softmax"): - m_curr = qk.max(axis=-1) - - s_curr = jnp.exp(qk - m_curr[..., None]) - - l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,)) - if quantized: - s_curr = s_curr * v_scaler_ref[...] 
- o_curr_times_l_curr = jnp.dot(s_curr, v) - m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,)) - m_next = jnp.maximum(m_prev, m_curr) - alpha = jnp.exp(m_prev - m_next) - beta = jnp.exp(m_curr - m_next) - l_next = alpha * l_prev + beta * l_curr - l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) - - m_ref[...], l_ref[...] = m_next, l_next_safe - o_ref[...] = ( - (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next_safe - ).astype(o_ref.dtype) - -@functools.partial(jax.jit, static_argnames=["bk", "mask_value", "normalize_var"]) -def ragged_mqa( - q: jax.Array, - k: jax.Array, - v: jax.Array, - start: jax.Array, - end: jax.Array, - k_scaler: jax.Array | None = None, - v_scaler: jax.Array | None = None, - pre_batch = None, - pre_block = None, - bk: int = 512, - mask_value: float = DEFAULT_MASK_VALUE, - normalize_var: bool = True, -) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: - """Ragged multi query attention.""" - with jax.named_scope("ragged_mqa"): - batch_size, num_heads, head_dim = q.shape - seq_len = k.shape[1] - - def kv_index_map(b, i, start_ref, end_ref, line_end_ref, pre_batch_ref, pre_block_ref): - index = b * (seq_len // bk) + i - return pre_batch_ref[index], pre_block_ref[index], 0 - - def q_index_map(b, i, start_ref, end_ref, line_end_ref, pre_batch_ref, pre_block_ref): - index = b * (seq_len // bk) + i - return pre_batch_ref[index], 0, 0 - - def scaler_index_map(b, i, *_): - return b, 0, i - - line_end = jnp.where(start < end, end, seq_len - 1) - - - if k_scaler is not None: - out, m, l = pl.pallas_call( - functools.partial( - ragged_flash_attention_kernel, - bk=bk, - mask_value=mask_value, - normalize_var=normalize_var, - quantized=False, - ), - grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=5, - in_specs=[ - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(kv_index_map, (None, bk, head_dim)), - pl.BlockSpec(kv_index_map, (None, bk, head_dim)), - pl.BlockSpec(scaler_index_map, (None, 1, bk)), - pl.BlockSpec(scaler_index_map, (None, 1, bk)), - ], - out_specs=[ - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - ], - grid=(batch_size, seq_len // bk), - ), - compiler_params=dict(dimension_semantics=("parallel", "arbitrary")), - out_shape=[ - q, - jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), - jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), - ], - )(start, end, line_end, pre_batch, pre_block, q, k, v, k_scaler, v_scaler) - else: - out, m, l = pl.pallas_call( - functools.partial( - ragged_flash_attention_kernel, - bk=bk, - mask_value=mask_value, - normalize_var=normalize_var, - quantized=True, - ), - grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=5, - in_specs=[ - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(kv_index_map, (None, bk, head_dim)), - pl.BlockSpec(kv_index_map, (None, bk, head_dim)), - ], - out_specs=[ - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - ], - grid=(batch_size, seq_len // bk), - ), - compiler_params=dict(dimension_semantics=("parallel", "arbitrary")), - out_shape=[ - q, - jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), - jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), - ], - )(start, end, line_end, pre_batch, 
pre_block, q, k, v) - return out, (m[..., 0], l[..., 0]) - - -@functools.partial(jax.jit, static_argnames=['bk', 'mask', 'normalize_var', 'shard_axis']) -def ragged_mha( - q: jax.Array, - k: jax.Array, - v: jax.Array, - start: jax.Array, - end: jax.Array, - pre_batch: jax.Array, - pre_block: jax.Array, - k_scaler: jax.Array | None = None, - v_scaler: jax.Array | None = None, - bk: int = 512, - mask_value : float = DEFAULT_MASK_VALUE, - normalize_var: bool = True, - shard_axis: int = 1 -) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: - """Ragged multi head attention. - Args: - q: A [batch_size, compute_dim, num_heads, head_dim] jax.Array. - k: A [batch_size, num_heads, seq_len, head_dim] jax.Array or - PartitionQuantizedTensor. - v: A [batch_size, num_heads, seq_len, head_dim] jax.Array or - PartitionQuantizedTensor. - start: A i32[batch_size] jax.Array - end: A i32[batch_size] jax.Array - bk: An integer that is the sequence block size. - logit_cap: An optional float that caps logits via tanh. By default there is - no logit capping. - mask_value: The value used for padding in attention. By default it is a very - negative floating point number. - out_dtype: An optional dtype for the output. If not provided, the output - dtype will be q's dtype. - Returns: - The output of attention([batch_size, num_heads, compute_dim, head_dim]), - along with the max logit ([batch_size, num_heads, compute_dim, 1]) and - softmax denominator ([batch_size, num_heads, compute_dim, 1]). - """ - mask_value = DEFAULT_MASK_VALUE - seqlen = q.shape[-2] - if k_scaler is None: - replicated_in_axes = 4 - replicated_inputs = (pre_batch, pre_block) - else: - replicated_in_axes = 6 - replicated_inputs = (k_scaler, v_scaler, pre_batch, pre_block) - - with jax.named_scope("ragged_mha_vmap"): - out, (m, l) = jax.vmap( - functools.partial( - ragged_mqa, - bk=bk, - mask_value=mask_value, - normalize_var=normalize_var, - #out_dtype=out_dtype, - ), - in_axes=(shard_axis, shard_axis, shard_axis, *([None]*replicated_in_axes)), - out_axes=shard_axis, - )(q, k, v, start, end, *replicated_inputs) - return out, (m, l) - - -def dense_attention(xq, keys, values, mask): - head_dim = xq.shape[-1] - with jax.named_scope("attn_mat1"): - ## Attention start - # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim) - scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim) - if mask is not None: - # if mask.shape != (1,1,16,16): - # breakpoint() - scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen) - with jax.named_scope("attn_soft"): - scores = F.softmax(scores.float(), dim=-1).type_as(xq) - - with jax.named_scope("attn_mat2"): - # output = torch.einsum( - # "ikjm,ikml->ikjl", scores, values - # ) # (bs, n_local_heads, seqlen, head_dim) - output = torch.einsum("ikjm,ikml->ikjl", scores, values) - - -def dense_attention_quantized( - xq: jax.Array, - keys: jax.Array, - values: jax.Array, - k_scaler = None, - v_scaler = None, - mask = None, -): - bsz, _, _, head_dim = xq.shape - - with jax.named_scope("attn_mat1"): - ## Attention start - # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim) - scores = ( - torch.einsum("ikjl,ikml->ikjm", xq, keys) - / math.sqrt(head_dim) - * (k_scaler.reshape(bsz, 1, 1, keys.shape[2])) - ) - if mask is not None: - scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen) - with jax.named_scope("attn_soft"): - scores = F.softmax(scores.float(), dim=-1).type_as(xq) - scores = scores * 
v_scaler.reshape((bsz, 1, 1, keys.shape[2])) - - with jax.named_scope("attn_mat2"): - # output = torch.einsum( - # "ikjm,ikml->ikjl", scores, values - # ) # (bs, n_local_heads, seqlen, head_dim) - output = torch.einsum("ikjm,ikml->ikjl", scores, values) - - return output - - DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) def ragged_flash_attention_kernel( From 6ae0f9d1f7cb451a6c08c8b7109bd2797ea1ae49 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Tue, 21 May 2024 01:16:34 +0000 Subject: [PATCH 24/41] Fixes the missing pieces after merging conflicts. Adds couple of new flags for debugging and performance tuning. --- jetstream_pt/config.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/jetstream_pt/config.py b/jetstream_pt/config.py index a4e391c3..ae3dfe75 100644 --- a/jetstream_pt/config.py +++ b/jetstream_pt/config.py @@ -58,7 +58,26 @@ lambda value: value in _VALID_QUANTIZATION_TYPE, f"quantize_type is invalid, supported quantization types are {_VALID_QUANTIZATION_TYPE}", ) - +flags.DEFINE_bool( + "profiling_prefill", + False, + "Whether to profile the prefill, " + "if set to false, profile generate function only", + required=False +) +flags.DEFINE_bool( + "ragged_mha", + False, + "Whether to enable Ragged multi head attention", + required=False +) +flags.DEFINE_integer( + "starting_position", + 512, + "The starting position of decoding, " + "for performance tuning and debugging only", + required=False +) def create_quantization_config_from_flags(): """Create Quantization Config from cmd flags""" @@ -112,6 +131,8 @@ def create_engine_from_config_flags(): max_cache_length=FLAGS.max_cache_length, sharding_config=sharding_file_name, shard_on_batch=FLAGS.shard_on_batch, + ragged_mha=FLAGS.ragged_mha, + starting_position=FLAGS.starting_position, ) print("Initialize engine", time.perf_counter() - start) From 212aa8e130e1cca16e0d7c1ba5cd7b0cb8aca30a Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Tue, 21 May 2024 04:08:26 +0000 Subject: [PATCH 25/41] Integrates ragged attention to Gemma too. 
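Note on the patch that follows: it threads start, end, ragged_batch_index and ragged_block_index through the Gemma and Llama forward passes. Conceptually, each decode slot only holds valid KV entries in the ring-buffer range [start, end) of its cache, and the ragged kernel should only visit the cache blocks that overlap that range. The standalone sketch below is illustrative only; it ignores the look-ahead to the next batch element that precompute_ragged_block_indices performs, the helper name valid_blocks is ours, and an empty slot (start == end) is skipped by the kernel itself and not handled here.

import numpy as np

cache_len, bk = 16, 4              # toy ring-buffer cache with block size 4

def valid_blocks(start, end):
  """Blocks of the KV ring buffer holding at least one token of [start, end).

  Assumes start != end; an empty slot is skipped by the kernel itself."""
  pos = np.arange(cache_len)
  if start < end:                  # slot does not wrap around the cache
    valid = (pos >= start) & (pos < end)
  else:                            # wrapped slot: valid everywhere outside [end, start)
    valid = (pos >= start) | (pos < end)
  return np.unique(pos[valid] // bk)

print(valid_blocks(2, 9))          # [0 1 2]
print(valid_blocks(13, 3))         # [0 3]   (wrapped slot)

The precomputed ragged_batch_index / ragged_block_index arrays encode, for every (batch, block) grid step, which slot and which block the kernel should actually load, so that blocks outside the valid range, like those skipped above, are never visited.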
--- jetstream_pt/engine.py | 14 +++---- jetstream_pt/layers.py | 38 +++++++++---------- jetstream_pt/third_party/gemma/model.py | 35 ++++++++++++++++- .../third_party/llama/model_exportable.py | 22 ++++++++--- 4 files changed, 75 insertions(+), 34 deletions(-) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index d80243be..648a0553 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -149,8 +149,8 @@ def _call_model_generate( mask, start, input_pos, - pre_batch, - pre_block, + ragged_batch_index, + ragged_block_index, ): if self.env.quant_config.enable_kv_quantization: caches_obj = [ @@ -168,7 +168,7 @@ def _call_model_generate( ] mask = jnp.expand_dims(mask, (1, 2)) - args = (tokens, caches_obj, mask, start, input_pos, pre_batch, pre_block) + args = (tokens, caches_obj, mask, start, input_pos, ragged_batch_index, ragged_block_index) paramst, argst = torchjax.to_torch((weights, args)) with self._lock: with torchjax.jax_mode: @@ -505,8 +505,8 @@ def generate( # fill mask first mask = decode_state.mask.at[:, decode_state.current_position].set(0) - pre_batch, pre_block = self.precompute_ragged_block_indices(decode_state) - pre_batch, pre_block = pre_batch.reshape((-1)), pre_block.reshape((-1)) + ragged_batch_index, ragged_block_index = self.precompute_ragged_block_indices(decode_state) + ragged_batch_index, ragged_block_index = ragged_batch_index.reshape((-1)), ragged_block_index.reshape((-1)) logits, new_caches, new_scales = self._call_model_generate( params, @@ -517,8 +517,8 @@ def generate( mask, decode_state.start, decode_state.input_pos, - pre_batch, - pre_block, + ragged_batch_index, + ragged_block_index, ) next_token = self._sampling(logits, self.env.batch_size) diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index a67c997a..acf9a25c 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -491,8 +491,8 @@ def ragged_mqa( end: jax.Array, k_scaler: jax.Array | None = None, v_scaler: jax.Array | None = None, - pre_batch = None, - pre_block = None, + ragged_batch_index = None, + ragged_block_index = None, bk: int = 512, mask_value: float = DEFAULT_MASK_VALUE, normalize_var: bool = True, @@ -502,13 +502,13 @@ def ragged_mqa( batch_size, num_heads, head_dim = q.shape seq_len = k.shape[1] - def kv_index_map(b, i, start_ref, end_ref, line_end_ref, pre_batch_ref, pre_block_ref): + def kv_index_map(b, i, start_ref, end_ref, line_end_ref, ragged_batch_index_ref, ragged_block_index_ref): index = b * (seq_len // bk) + i - return pre_batch_ref[index], pre_block_ref[index], 0 + return ragged_batch_index_ref[index], ragged_block_index_ref[index], 0 - def q_index_map(b, i, start_ref, end_ref, line_end_ref, pre_batch_ref, pre_block_ref): + def q_index_map(b, i, start_ref, end_ref, line_end_ref, ragged_batch_index_ref, ragged_block_index_ref): index = b * (seq_len // bk) + i - return pre_batch_ref[index], 0, 0 + return ragged_batch_index_ref[index], 0, 0 def scaler_index_map(b, i, *_): return b, 0, i @@ -547,7 +547,7 @@ def scaler_index_map(b, i, *_): jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), ], - )(start, end, line_end, pre_batch, pre_block, q, k, v, k_scaler, v_scaler) + )(start, end, line_end, ragged_batch_index, ragged_block_index, q, k, v, k_scaler, v_scaler) else: out, m, l = pl.pallas_call( functools.partial( @@ -577,7 +577,7 @@ def scaler_index_map(b, i, *_): jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), 
jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), ], - )(start, end, line_end, pre_batch, pre_block, q, k, v) + )(start, end, line_end, ragged_batch_index, ragged_block_index, q, k, v) return out, (m[..., 0], l[..., 0]) @@ -588,8 +588,8 @@ def ragged_mha( v: jax.Array, start: jax.Array, end: jax.Array, - pre_batch: jax.Array, - pre_block: jax.Array, + ragged_batch_index: jax.Array, + ragged_block_index: jax.Array, k_scaler: jax.Array | None = None, v_scaler: jax.Array | None = None, bk: int = 512, @@ -622,10 +622,10 @@ def ragged_mha( seqlen = q.shape[-2] if k_scaler is None: replicated_in_axes = 4 - replicated_inputs = (pre_batch, pre_block) + replicated_inputs = (ragged_batch_index, ragged_block_index) else: replicated_in_axes = 6 - replicated_inputs = (jnp.squeeze(k_scaler, -1), jnp.squeeze(v_scaler, -1), pre_batch, pre_block) + replicated_inputs = (jnp.squeeze(k_scaler, -1), jnp.squeeze(v_scaler, -1), ragged_batch_index, ragged_block_index) with jax.named_scope("ragged_mha_vmap"): out, (m, l) = jax.vmap( @@ -677,7 +677,7 @@ def __init__(self, env): self.binded_ragged_mha = shard_map(ragged_mha, env.mesh, in_specs=(*([qkv_pspec] * 3), *([others_pspec] * 4)), out_specs=(qkv_pspec, (others_pspec, others_pspec)), check_rep=False) self.binded_ragged_mha = jax.jit(self.binded_ragged_mha) - def __call__(self, xq, xk, xv, mask, cache, start, end, pre_batch, pre_block): + def __call__(self, xq, xk, xv, mask, cache, start, end, ragged_batch_index, ragged_block_index): """ Args: xq: torch.Tensor of (batch size, num_heads, seqlen, head_dim) @@ -699,7 +699,7 @@ def __call__(self, xq, xk, xv, mask, cache, start, end, pre_batch, pre_block): with jax.named_scope("attn_qkv"): if self.env.ragged_mha and seqlen == 1: - output, _ = torch_xla2.interop.call_jax(self.binded_ragged_mha, xq, keys, values, start, end, pre_batch, pre_block) + output, _ = torch_xla2.interop.call_jax(self.binded_ragged_mha, xq, keys, values, start, end, ragged_batch_index, ragged_block_index) else: output = dense_attention(xq, keys, values, None, None, mask) @@ -722,7 +722,7 @@ def __init__(self, env): self.binded_ragged_mha_quantized = shard_map(self.binded_ragged_mha_quantized, env.mesh, in_specs=(*([qkv_pspec] * 3), *([others_pspec]*6)), out_specs=(qkv_pspec, (others_pspec, others_pspec)), check_rep=False) self.binded_ragged_mha_quantized = jax.jit(self.binded_ragged_mha_quantized) - def __call__(self, xq, xk, xv, mask, cache, start, end, pre_batch, pre_block): + def __call__(self, xq, xk, xv, mask, cache, start, end, ragged_batch_index, ragged_block_index): """ Args: xq: torch.Tensor of (batch size, num_heads, seqlen, head_dim) @@ -745,7 +745,7 @@ def __call__(self, xq, xk, xv, mask, cache, start, end, pre_batch, pre_block): with jax.named_scope("attn_qkv"): if self.env.ragged_mha and seqlen == 1: - output, _ = torch_xla2.interop.call_jax(self.binded_ragged_mha_quantized, xq, keys, values, start, end, pre_batch, pre_block, k_scaler, v_scaler) + output, _ = torch_xla2.interop.call_jax(self.binded_ragged_mha_quantized, xq, keys, values, start, end, ragged_batch_index, ragged_block_index, k_scaler, v_scaler) else: output= dense_attention(xq, keys, values, k_scaler, v_scaler, mask) @@ -829,8 +829,8 @@ def forward( cache, start, end, - pre_batch, - pre_block, + ragged_batch_index, + ragged_block_index, ): with jax.named_scope("attn_linear_before_cache"): bsz, seqlen = x.shape[0], x.shape[-2] @@ -858,6 +858,6 @@ def forward( xv = xv.transpose(1, 2) xq = xq.transpose(1, 2) - output = self.attention_kernel(xq, xk, 
xv, mask, cache, start, end, pre_batch, pre_block).type_as(xq) + output = self.attention_kernel(xq, xk, xv, mask, cache, start, end, ragged_batch_index, ragged_block_index).type_as(xq) output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1) return self.wo(output) diff --git a/jetstream_pt/third_party/gemma/model.py b/jetstream_pt/third_party/gemma/model.py index 8cd87b13..145053f3 100644 --- a/jetstream_pt/third_party/gemma/model.py +++ b/jetstream_pt/third_party/gemma/model.py @@ -135,6 +135,10 @@ def forward( freqs_cis, mask, cache, + start, + end, + ragged_batch_index, + ragged_block_index, ) -> torch.Tensor: hidden_states_shape = hidden_states.shape assert len(hidden_states_shape) == 3 @@ -164,7 +168,7 @@ def forward( xv = xv.transpose(1, 2) xq = xq.transpose(1, 2) - output = self.attention_kernel(xq, xk, xv, mask, cache) + output = self.attention_kernel(xq, xk, xv, mask, cache, start, end, ragged_batch_index, ragged_block_index) # [batch_size, input_len, hidden_dim] output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1) @@ -264,6 +268,10 @@ def forward( freqs_cis: torch.Tensor, cache: Any, mask: torch.Tensor, + start: torch.Tensor, + end: torch.Tensor, + ragged_batch_index: torch.Tensor, + ragged_block_index: torch.Tensor, ) -> torch.Tensor: # Self Attention residual = hidden_states @@ -273,6 +281,10 @@ def forward( freqs_cis=freqs_cis, mask=mask, cache=cache, + start=start, + end=end, + ragged_batch_index=ragged_batch_index, + ragged_block_index=ragged_block_index, ) hidden_states = residual + hidden_states @@ -318,10 +330,23 @@ def __init__(self, config: gemma_config.GemmaConfig, env): def forward( self, tokens: torch.Tensor, - input_pos: torch.Tensor, caches: List[Any], mask, + start, + input_pos: torch.Tensor, + ragged_batch_index, + ragged_block_index, ): + """ + tokens: the input token for decoding + caches: kv caches + mask: causal mask to filter the attention results + start: the starting position for each slot + input_pos: the decoding position relative to the start, which is the length of the decoding results + ragged_batch_index: precomputed batch index for ragged attention + ragged_block_index: precomputed block index for ragged attention + """ + with jax.named_scope("transformer_freq"): bsz, seqlen = tokens.shape freqs_cis = self.freqs_cis[input_pos] @@ -330,6 +355,8 @@ def forward( hidden_states = self.embedder(tokens) hidden_states = hidden_states * (self.config.hidden_size**0.5) + end = None if start is None else (start + input_pos) % self.env.cache_len + for i in range(len(self.layers)): layer = self.layers[i] hidden_states = layer( @@ -337,6 +364,10 @@ def forward( freqs_cis=freqs_cis, cache=caches[i], mask=mask, + start=start, + end=end, + ragged_batch_index=ragged_batch_index, + ragged_block_index=ragged_block_index, ) hidden_states = self.norm(hidden_states) diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py index 1838bef6..2f6cfdd8 100644 --- a/jetstream_pt/third_party/llama/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -111,12 +111,12 @@ def forward( cache, start, end, - pre_batch, - pre_block, + ragged_batch_index, + ragged_block_index, ): with jax.named_scope("Attention"): attn = self.attention.forward( - self.attention_norm(x), freqs_cis, mask, cache, start, end, pre_batch, pre_block + self.attention_norm(x), freqs_cis, mask, cache, start, end, ragged_batch_index, ragged_block_index ) with jax.named_scope("ffn_norm"): h = x + 
attn @@ -188,9 +188,19 @@ def forward( mask, start, input_pos, - pre_batch, - pre_block, + ragged_batch_index, + ragged_block_index, ): + """ + tokens: the input token for decoding + caches: kv caches + mask: causal mask to filter the attention results + start: the starting position for each slot + input_pos: the decoding position relative to the start, which is the length of the decoding results + ragged_batch_index: precomputed batch index for ragged attention + ragged_block_index: precomputed block index for ragged attention + """ + with jax.named_scope("transformer_tok"): seqlen = tokens.shape[-1] h = self.tok_embeddings(tokens) @@ -206,7 +216,7 @@ def forward( end = None if start is None else (start + input_pos) % self.env.cache_len for layer, cache in zip(self.layers, caches): with jax.named_scope("TransformerBlock"): - h = layer(h, freqs_cis, mask, cache, start, end, pre_batch, pre_block) + h = layer(h, freqs_cis, mask, cache, start, end, ragged_batch_index, ragged_block_index) with jax.named_scope("transformer_norm"): h = self.norm(h) From ddb32e084755fc3c2d9d9e241e824f7cbd723dba Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Tue, 21 May 2024 04:20:49 +0000 Subject: [PATCH 26/41] Somehow have some local changes to run_interactive, reverting them to align with main. --- run_interactive.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/run_interactive.py b/run_interactive.py index 517a439c..ebe97c15 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -69,22 +69,23 @@ def main(argv): complete = np.zeros((1,), dtype=np.bool_) while True: decode_state, result_tokens = engine.generate(params, decode_state) + result_tokens = result_tokens.convert_to_numpy() output, complete = token_utils.process_result_tokens( - tokenizer, slot, max_output_length, result_tokens, complete + tokenizer=tokenizer, + slot=slot, + slot_max_length=max_output_length, + result_tokens=result_tokens, + complete=complete, ) if complete[0]: break - sampled_tokens_list = output[0] - # output_str = tokenizer.decode_str([token_id]) - # print(Fore.GREEN + output_str, end="", flush=True) - - # print(Style.RESET_ALL + "\n") - # print("---- Streaming decode finished.") + token_ids = output[0].token_ids + sampled_tokens_list.extend(token_ids) print("---- All output tokens.") print(sampled_tokens_list) print("---- All output text.") - print(token_utils.text_tokens_to_str(sampled_tokens_list)) + print(tokenizer.decode(sampled_tokens_list)) if profiling_output: jax.profiler.stop_trace() @@ -92,4 +93,4 @@ def main(argv): if __name__ == "__main__": os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" - app.run(main) + app.run(main) \ No newline at end of file From fb680259b16a6b58537567efef84c49b9ece8f49 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Tue, 21 May 2024 06:39:10 +0000 Subject: [PATCH 27/41] Set the default value for the newly added parameters. 
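Note on the patch that follows: giving the new arguments None defaults keeps existing call sites working. Callers that never build the ragged bookkeeping (for example a prefill-style call) can keep invoking forward() unchanged, while the decode path supplies everything explicitly. A minimal sketch with a stand-in function (not the real model; it only mirrors the keyword defaults added below and echoes what was supplied):

def forward(tokens, caches, mask, start=None, input_pos=None,
            ragged_batch_index=None, ragged_block_index=None):
  # Stand-in mirroring the defaults added below; the real models do the same
  # so that callers that never build the ragged indices need not change.
  return dict(start=start, input_pos=input_pos,
              ragged_batch_index=ragged_batch_index,
              ragged_block_index=ragged_block_index)

print(forward([1], [], None, input_pos=7))                      # ragged args stay None
print(forward([1], [], None, start=0, input_pos=7,
              ragged_batch_index=[0], ragged_block_index=[0]))  # fully specified call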
--- jetstream_pt/third_party/gemma/model.py | 24 +++++++++---------- .../third_party/llama/model_exportable.py | 16 ++++++------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/jetstream_pt/third_party/gemma/model.py b/jetstream_pt/third_party/gemma/model.py index 145053f3..d58805d7 100644 --- a/jetstream_pt/third_party/gemma/model.py +++ b/jetstream_pt/third_party/gemma/model.py @@ -135,10 +135,10 @@ def forward( freqs_cis, mask, cache, - start, - end, - ragged_batch_index, - ragged_block_index, + start = None, + end = None, + ragged_batch_index = None, + ragged_block_index = None, ) -> torch.Tensor: hidden_states_shape = hidden_states.shape assert len(hidden_states_shape) == 3 @@ -268,10 +268,10 @@ def forward( freqs_cis: torch.Tensor, cache: Any, mask: torch.Tensor, - start: torch.Tensor, - end: torch.Tensor, - ragged_batch_index: torch.Tensor, - ragged_block_index: torch.Tensor, + start: torch.Tensor | None = None, + end: torch.Tensor | None = None, + ragged_batch_index: torch.Tensor | None = None, + ragged_block_index: torch.Tensor | None = None, ) -> torch.Tensor: # Self Attention residual = hidden_states @@ -332,10 +332,10 @@ def forward( tokens: torch.Tensor, caches: List[Any], mask, - start, - input_pos: torch.Tensor, - ragged_batch_index, - ragged_block_index, + start = None, + input_pos = None, + ragged_batch_index = None, + ragged_block_index = None, ): """ tokens: the input token for decoding diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py index 2f6cfdd8..5a430fec 100644 --- a/jetstream_pt/third_party/llama/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -109,10 +109,10 @@ def forward( freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], cache, - start, - end, - ragged_batch_index, - ragged_block_index, + start = None, + end = None, + ragged_batch_index = None, + ragged_block_index = None, ): with jax.named_scope("Attention"): attn = self.attention.forward( @@ -186,10 +186,10 @@ def forward( tokens: torch.Tensor, caches: List[Any], mask, - start, - input_pos, - ragged_batch_index, - ragged_block_index, + start = None, + input_pos = None, + ragged_batch_index = None, + ragged_block_index = None, ): """ tokens: the input token for decoding From 2def37c03656f41121d26b66ee0e75bf322b3f69 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Wed, 22 May 2024 00:33:49 +0000 Subject: [PATCH 28/41] Adds more descriptions to the ragged attention index precompuation function. --- jetstream_pt/engine.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 648a0553..a8b2d095 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -458,12 +458,18 @@ def insert( ) def precompute_ragged_block_indices(self, decode_state: DecodeState): + """Precompute the ragged attention block indices. Ragged attention iterates the grid + and relies on the computed grid index to skip the unnecessary blocks. The basic idea + is to use input_pos, which is the length of each slot to determine if we should + work on the next block of the slot or move to the next slot. 
""" start = decode_state.start end = (start + decode_state.input_pos) % self.env.cache_len batch_size = start.shape[0] bk = self.env.block_size + # The batch index b = jnp.arange(batch_size).reshape((batch_size, 1)) num_bk = self.env.cache_len // self.env.block_size + # The block index i = jnp.arange(num_bk).reshape((1, num_bk)) i = jnp.broadcast_to(i, (batch_size, num_bk)) @@ -476,14 +482,14 @@ def precompute_ragged_block_indices(self, decode_state: DecodeState): next_b = jnp.where(am_last_batch, b, b + 1) next_i = jnp.where(am_last_batch, last_good_block, 0) - # start < end + # start < end, continue work on the block is there is overlap with the [start, end) def true_comp(b, i, bk, start, end, next_b, next_i): b_next = jnp.where(i * bk >= end, next_b, b) i_next = jnp.where(i * bk >= end, next_i, i) i_next = jnp.where((i + 1) * bk <= start, jax.lax.div(start, bk), i_next) return b_next, i_next - # start > end + # start > end, continue work on the block is there is no overlap with [end, start) def false_comp(b, i, bk, start, end): b_next = b i_next = jnp.where(jnp.logical_and(i * bk >= end, (i + 1) * bk <= start), jax.lax.div(start, bk), i) From 268c407fe69f593284873c5a102a864df66345ef Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Wed, 22 May 2024 00:41:04 +0000 Subject: [PATCH 29/41] Merges the quantized ragged attention kernel with the non quantized version. --- jetstream_pt/layers.py | 65 ++++++++++++++---------------------------- 1 file changed, 21 insertions(+), 44 deletions(-) diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index acf9a25c..25c3b532 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -515,59 +515,36 @@ def scaler_index_map(b, i, *_): line_end = jnp.where(start < end, end, seq_len - 1) - + in_specs = [ + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(kv_index_map, (None, bk, head_dim)), + pl.BlockSpec(kv_index_map, (None, bk, head_dim)), + ] + inputs = (start, end, line_end, ragged_batch_index, ragged_block_index, q, k, v) + quantized = False if k_scaler is not None: - out, m, l = pl.pallas_call( - functools.partial( - ragged_flash_attention_kernel, - bk=bk, - mask_value=mask_value, - normalize_var=normalize_var, - quantized=False, - ), - grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=5, - in_specs=[ - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(kv_index_map, (None, bk, head_dim)), - pl.BlockSpec(kv_index_map, (None, bk, head_dim)), - pl.BlockSpec(scaler_index_map, (None, 1, bk)), - pl.BlockSpec(scaler_index_map, (None, 1, bk)), - ], - out_specs=[ - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - ], - grid=(batch_size, seq_len // bk), - ), - compiler_params=dict(dimension_semantics=("parallel", "arbitrary")), - out_shape=[ - q, - jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), - jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), - ], - )(start, end, line_end, ragged_batch_index, ragged_block_index, q, k, v, k_scaler, v_scaler) - else: - out, m, l = pl.pallas_call( + in_specs = in_specs + [ + pl.BlockSpec(scaler_index_map, (None, 1, bk)), + pl.BlockSpec(scaler_index_map, (None, 1, bk)), + ] + inputs = inputs + (k_scaler, v_scaler) + quantized = True + + out, m, l = pl.pallas_call( functools.partial( ragged_flash_attention_kernel, bk=bk, mask_value=mask_value, normalize_var=normalize_var, - quantized=True, + 
quantized=quantized, ), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=5, - in_specs=[ - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(kv_index_map, (None, bk, head_dim)), - pl.BlockSpec(kv_index_map, (None, bk, head_dim)), - ], + in_specs=in_specs, out_specs=[ - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), ], grid=(batch_size, seq_len // bk), ), @@ -577,7 +554,7 @@ def scaler_index_map(b, i, *_): jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), ], - )(start, end, line_end, ragged_batch_index, ragged_block_index, q, k, v) + )(*inputs) return out, (m[..., 0], l[..., 0]) From 8fa8fcb94d1989b33d9d486e1de5f155a41eaaf3 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Wed, 22 May 2024 01:02:24 +0000 Subject: [PATCH 30/41] Moves the attention calculation to attention.py for better code structure. --- jetstream_pt/attention_kernel.py | 261 ++++++++++++++++++++++++++++ jetstream_pt/layers.py | 284 +++---------------------------- 2 files changed, 285 insertions(+), 260 deletions(-) create mode 100644 jetstream_pt/attention_kernel.py diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py new file mode 100644 index 00000000..d81789c7 --- /dev/null +++ b/jetstream_pt/attention_kernel.py @@ -0,0 +1,261 @@ +import jax +import jax.numpy as jnp +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +from jax.experimental.shard_map import shard_map + +import numpy as np + +DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) + +def ragged_flash_attention_kernel( + start_ref, + end_ref, + line_end_ref, + pre_b_ref, + pre_i_ref, + q_ref, + k_ref, + v_ref, + k_scaler_ref, + v_scaler_ref, + o_ref, + m_ref, + l_ref, + bk: int, + mask_value: float, + normalize_var: bool, + quantized: bool, +): + """Pallas kernel for flash attention.""" + with jax.named_scope("attention_kernel"): + b, i = pl.program_id(0), pl.program_id(1) + + @pl.when(i == 0) + def init(): + with jax.named_scope("init"): + m_ref[...] = jnp.full_like(m_ref, -jnp.inf) + l_ref[...] = jnp.zeros_like(l_ref) + o_ref[...] = jnp.zeros_like(o_ref) + + length = line_end_ref[b] + start = start_ref[b] + end = end_ref[b] + + @pl.when(jnp.logical_and(i * bk < length, start != end)) + def run(): + with jax.named_scope("run_qk"): + q = q_ref[...].astype(jnp.float32) + k = k_ref[...].astype(jnp.float32) + v = v_ref[...].astype(jnp.float32) + m_prev, l_prev = m_ref[...], l_ref[...] + + qk = jax.lax.dot_general( + q, k, (((1,), (1,)), ((), ())), preferred_element_type=jnp.float32 + ) + if normalize_var: + qk = qk / jnp.sqrt(k.shape[-1]) + if quantized: + qk = qk * k_scaler_ref[...] 
+ with jax.named_scope("run_mask"): + start = start_ref[b] + end = end_ref[b] + iota = jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1) + mask_start_lt_end = jnp.logical_and(i * bk + iota >= start, i * bk + iota < end).astype(jnp.int32) + mask_start_gt_end = jnp.logical_or(i * bk + iota >= start, i * bk + iota < end).astype(jnp.int32) + #mask = jax.lax.cond(start <= end, lambda: mask_start_lt_end, lambda: mask_start_gt_end) + mask = jnp.where(start <= end, mask_start_lt_end, mask_start_gt_end) + + qk = qk + jnp.where(mask, 0.0, mask_value) + + with jax.named_scope("run_softmax"): + m_curr = qk.max(axis=-1) + + s_curr = jnp.exp(qk - m_curr[..., None]) + + l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,)) + if quantized: + s_curr = s_curr * v_scaler_ref[...] + o_curr_times_l_curr = jnp.dot(s_curr, v) + m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,)) + m_next = jnp.maximum(m_prev, m_curr) + alpha = jnp.exp(m_prev - m_next) + beta = jnp.exp(m_curr - m_next) + l_next = alpha * l_prev + beta * l_curr + l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) + + m_ref[...], l_ref[...] = m_next, l_next_safe + o_ref[...] = ( + (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next_safe + ).astype(o_ref.dtype) + +@functools.partial(jax.jit, static_argnames=["bk", "mask_value", "normalize_var"]) +def ragged_mqa( + q: jax.Array, + k: jax.Array, + v: jax.Array, + start: jax.Array, + end: jax.Array, + k_scaler: jax.Array | None = None, + v_scaler: jax.Array | None = None, + ragged_batch_index = None, + ragged_block_index = None, + bk: int = 512, + mask_value: float = DEFAULT_MASK_VALUE, + normalize_var: bool = True, +) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: + """Ragged multi query attention.""" + with jax.named_scope("ragged_mqa"): + batch_size, num_heads, head_dim = q.shape + seq_len = k.shape[1] + + def kv_index_map(b, i, start_ref, end_ref, line_end_ref, ragged_batch_index_ref, ragged_block_index_ref): + index = b * (seq_len // bk) + i + return ragged_batch_index_ref[index], ragged_block_index_ref[index], 0 + + def q_index_map(b, i, start_ref, end_ref, line_end_ref, ragged_batch_index_ref, ragged_block_index_ref): + index = b * (seq_len // bk) + i + return ragged_batch_index_ref[index], 0, 0 + + def scaler_index_map(b, i, *_): + return b, 0, i + + line_end = jnp.where(start < end, end, seq_len - 1) + + in_specs = [ + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(kv_index_map, (None, bk, head_dim)), + pl.BlockSpec(kv_index_map, (None, bk, head_dim)), + ] + inputs = (start, end, line_end, ragged_batch_index, ragged_block_index, q, k, v) + quantized = False + if k_scaler is not None: + in_specs = in_specs + [ + pl.BlockSpec(scaler_index_map, (None, 1, bk)), + pl.BlockSpec(scaler_index_map, (None, 1, bk)), + ] + inputs = inputs + (k_scaler, v_scaler) + quantized = True + + out, m, l = pl.pallas_call( + functools.partial( + ragged_flash_attention_kernel, + bk=bk, + mask_value=mask_value, + normalize_var=normalize_var, + quantized=quantized, + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=5, + in_specs=in_specs, + out_specs=[ + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + ], + grid=(batch_size, seq_len // bk), + ), + compiler_params=dict(dimension_semantics=("parallel", "arbitrary")), + out_shape=[ + q, + jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), + 
jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), + ], + )(*inputs) + return out, (m[..., 0], l[..., 0]) + + +@functools.partial(jax.jit, static_argnames=['bk', 'mask_value', 'normalize_var', 'shard_axis']) +def ragged_mha( + q: jax.Array, + k: jax.Array, + v: jax.Array, + start: jax.Array, + end: jax.Array, + ragged_batch_index: jax.Array, + ragged_block_index: jax.Array, + k_scaler: jax.Array | None = None, + v_scaler: jax.Array | None = None, + bk: int = 512, + mask_value : float = DEFAULT_MASK_VALUE, + normalize_var: bool = True, + shard_axis: int = 1 +) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: + """Ragged multi head attention. + Args: + q: A [batch_size, compute_dim, num_heads, head_dim] jax.Array. + k: A [batch_size, num_heads, seq_len, head_dim] jax.Array or + PartitionQuantizedTensor. + v: A [batch_size, num_heads, seq_len, head_dim] jax.Array or + PartitionQuantizedTensor. + start: A i32[batch_size] jax.Array + end: A i32[batch_size] jax.Array + bk: An integer that is the sequence block size. + logit_cap: An optional float that caps logits via tanh. By default there is + no logit capping. + mask_value: The value used for padding in attention. By default it is a very + negative floating point number. + out_dtype: An optional dtype for the output. If not provided, the output + dtype will be q's dtype. + Returns: + The output of attention([batch_size, num_heads, compute_dim, head_dim]), + along with the max logit ([batch_size, num_heads, compute_dim, 1]) and + softmax denominator ([batch_size, num_heads, compute_dim, 1]). + """ + mask_value = DEFAULT_MASK_VALUE + seqlen = q.shape[-2] + if k_scaler is None: + replicated_in_axes = 4 + replicated_inputs = (ragged_batch_index, ragged_block_index) + else: + replicated_in_axes = 6 + replicated_inputs = (jnp.squeeze(k_scaler, -1), jnp.squeeze(v_scaler, -1), ragged_batch_index, ragged_block_index) + + with jax.named_scope("ragged_mha_vmap"): + out, (m, l) = jax.vmap( + functools.partial( + ragged_mqa, + bk=bk, + mask_value=mask_value, + normalize_var=normalize_var, + #out_dtype=out_dtype, + ), + in_axes=(shard_axis, shard_axis, shard_axis, *([None]*replicated_in_axes)), + out_axes=shard_axis, + )(q, k, v, start, end, *replicated_inputs) + return out, (m, l) + + +def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None): + bsz, _, _, head_dim = xq.shape + with jax.named_scope("attn_mat1"): + ## Attention start + # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim) + scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim) + if k_scaler is not None: + scores = scores * (k_scaler.reshape(bsz, 1, 1, keys.shape[2])) + if mask is not None: + # if mask.shape != (1,1,16,16): + # breakpoint() + scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen) + with jax.named_scope("attn_soft"): + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + if v_scaler is not None: + scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2])) + + with jax.named_scope("attn_mat2"): + # output = torch.einsum( + # "ikjm,ikml->ikjl", scores, values + # ) # (bs, n_local_heads, seqlen, head_dim) + output = torch.einsum("ikjm,ikml->ikjl", scores, values) + return output + +class RaggedAttentionKernel: + + def __init(self, env, input_specs, output_specs, sharding_axis): + self.binded_ragged_mha = functools.partial(ragged_mha, bk=env.block_size, shard_axis=sharding_axis) + self.binded_ragged_mha = shard_map(ragged_mha, env.mesh, input_specs, output_specs, 
check_rep=False) + self.binded_ragged_mha = jax.jit(self.binded_ragged_mha) + + def __call__(self, xq, keys, values, start, end, ragged_batch_index, ragged_block_index): + return self.binded_ragged_mha(xq, keys, values, start, end, ragged_batch_index, ragged_block_index) diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 25c3b532..274e3011 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -20,10 +20,8 @@ import functools import jax +from . import attention_kernel as ak import jax.numpy as jnp -from jax.experimental import pallas as pl -from jax.experimental.pallas import tpu as pltpu -from jax.experimental.shard_map import shard_map import torch import torch.nn.functional as F import torch_xla2 @@ -35,6 +33,7 @@ quantize_tensor, ) from torch import nn +from . import attention_kernel as ak def _calc_cosine_dist(x, y): @@ -399,249 +398,6 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: .reshape(bs, n_kv_heads * n_rep, slen, head_dim) ) -DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) - -def ragged_flash_attention_kernel( - start_ref, - end_ref, - line_end_ref, - pre_b_ref, - pre_i_ref, - q_ref, - k_ref, - v_ref, - k_scaler_ref, - v_scaler_ref, - o_ref, - m_ref, - l_ref, - bk: int, - mask_value: float, - normalize_var: bool, - quantized: bool, -): - """Pallas kernel for flash attention.""" - with jax.named_scope("attention_kernel"): - b, i = pl.program_id(0), pl.program_id(1) - - @pl.when(i == 0) - def init(): - with jax.named_scope("init"): - m_ref[...] = jnp.full_like(m_ref, -jnp.inf) - l_ref[...] = jnp.zeros_like(l_ref) - o_ref[...] = jnp.zeros_like(o_ref) - - length = line_end_ref[b] - start = start_ref[b] - end = end_ref[b] - - @pl.when(jnp.logical_and(i * bk < length, start != end)) - def run(): - with jax.named_scope("run_qk"): - q = q_ref[...].astype(jnp.float32) - k = k_ref[...].astype(jnp.float32) - v = v_ref[...].astype(jnp.float32) - m_prev, l_prev = m_ref[...], l_ref[...] - - qk = jax.lax.dot_general( - q, k, (((1,), (1,)), ((), ())), preferred_element_type=jnp.float32 - ) - if normalize_var: - qk = qk / jnp.sqrt(k.shape[-1]) - if quantized: - qk = qk * k_scaler_ref[...] - with jax.named_scope("run_mask"): - start = start_ref[b] - end = end_ref[b] - iota = jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1) - mask_start_lt_end = jnp.logical_and(i * bk + iota >= start, i * bk + iota < end).astype(jnp.int32) - mask_start_gt_end = jnp.logical_or(i * bk + iota >= start, i * bk + iota < end).astype(jnp.int32) - #mask = jax.lax.cond(start <= end, lambda: mask_start_lt_end, lambda: mask_start_gt_end) - mask = jnp.where(start <= end, mask_start_lt_end, mask_start_gt_end) - - qk = qk + jnp.where(mask, 0.0, mask_value) - - with jax.named_scope("run_softmax"): - m_curr = qk.max(axis=-1) - - s_curr = jnp.exp(qk - m_curr[..., None]) - - l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,)) - if quantized: - s_curr = s_curr * v_scaler_ref[...] - o_curr_times_l_curr = jnp.dot(s_curr, v) - m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,)) - m_next = jnp.maximum(m_prev, m_curr) - alpha = jnp.exp(m_prev - m_next) - beta = jnp.exp(m_curr - m_next) - l_next = alpha * l_prev + beta * l_curr - l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) - - m_ref[...], l_ref[...] = m_next, l_next_safe - o_ref[...] = ( - (l_prev * alpha * o_ref[...] 
+ beta * o_curr_times_l_curr) / l_next_safe - ).astype(o_ref.dtype) - -@functools.partial(jax.jit, static_argnames=["bk", "mask_value", "normalize_var"]) -def ragged_mqa( - q: jax.Array, - k: jax.Array, - v: jax.Array, - start: jax.Array, - end: jax.Array, - k_scaler: jax.Array | None = None, - v_scaler: jax.Array | None = None, - ragged_batch_index = None, - ragged_block_index = None, - bk: int = 512, - mask_value: float = DEFAULT_MASK_VALUE, - normalize_var: bool = True, -) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: - """Ragged multi query attention.""" - with jax.named_scope("ragged_mqa"): - batch_size, num_heads, head_dim = q.shape - seq_len = k.shape[1] - - def kv_index_map(b, i, start_ref, end_ref, line_end_ref, ragged_batch_index_ref, ragged_block_index_ref): - index = b * (seq_len // bk) + i - return ragged_batch_index_ref[index], ragged_block_index_ref[index], 0 - - def q_index_map(b, i, start_ref, end_ref, line_end_ref, ragged_batch_index_ref, ragged_block_index_ref): - index = b * (seq_len // bk) + i - return ragged_batch_index_ref[index], 0, 0 - - def scaler_index_map(b, i, *_): - return b, 0, i - - line_end = jnp.where(start < end, end, seq_len - 1) - - in_specs = [ - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(kv_index_map, (None, bk, head_dim)), - pl.BlockSpec(kv_index_map, (None, bk, head_dim)), - ] - inputs = (start, end, line_end, ragged_batch_index, ragged_block_index, q, k, v) - quantized = False - if k_scaler is not None: - in_specs = in_specs + [ - pl.BlockSpec(scaler_index_map, (None, 1, bk)), - pl.BlockSpec(scaler_index_map, (None, 1, bk)), - ] - inputs = inputs + (k_scaler, v_scaler) - quantized = True - - out, m, l = pl.pallas_call( - functools.partial( - ragged_flash_attention_kernel, - bk=bk, - mask_value=mask_value, - normalize_var=normalize_var, - quantized=quantized, - ), - grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=5, - in_specs=in_specs, - out_specs=[ - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - ], - grid=(batch_size, seq_len // bk), - ), - compiler_params=dict(dimension_semantics=("parallel", "arbitrary")), - out_shape=[ - q, - jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), - jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), - ], - )(*inputs) - return out, (m[..., 0], l[..., 0]) - - -@functools.partial(jax.jit, static_argnames=['bk', 'mask_value', 'normalize_var', 'shard_axis']) -def ragged_mha( - q: jax.Array, - k: jax.Array, - v: jax.Array, - start: jax.Array, - end: jax.Array, - ragged_batch_index: jax.Array, - ragged_block_index: jax.Array, - k_scaler: jax.Array | None = None, - v_scaler: jax.Array | None = None, - bk: int = 512, - mask_value : float = DEFAULT_MASK_VALUE, - normalize_var: bool = True, - shard_axis: int = 1 -) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: - """Ragged multi head attention. - Args: - q: A [batch_size, compute_dim, num_heads, head_dim] jax.Array. - k: A [batch_size, num_heads, seq_len, head_dim] jax.Array or - PartitionQuantizedTensor. - v: A [batch_size, num_heads, seq_len, head_dim] jax.Array or - PartitionQuantizedTensor. - start: A i32[batch_size] jax.Array - end: A i32[batch_size] jax.Array - bk: An integer that is the sequence block size. - logit_cap: An optional float that caps logits via tanh. By default there is - no logit capping. - mask_value: The value used for padding in attention. 
By default it is a very - negative floating point number. - out_dtype: An optional dtype for the output. If not provided, the output - dtype will be q's dtype. - Returns: - The output of attention([batch_size, num_heads, compute_dim, head_dim]), - along with the max logit ([batch_size, num_heads, compute_dim, 1]) and - softmax denominator ([batch_size, num_heads, compute_dim, 1]). - """ - mask_value = DEFAULT_MASK_VALUE - seqlen = q.shape[-2] - if k_scaler is None: - replicated_in_axes = 4 - replicated_inputs = (ragged_batch_index, ragged_block_index) - else: - replicated_in_axes = 6 - replicated_inputs = (jnp.squeeze(k_scaler, -1), jnp.squeeze(v_scaler, -1), ragged_batch_index, ragged_block_index) - - with jax.named_scope("ragged_mha_vmap"): - out, (m, l) = jax.vmap( - functools.partial( - ragged_mqa, - bk=bk, - mask_value=mask_value, - normalize_var=normalize_var, - #out_dtype=out_dtype, - ), - in_axes=(shard_axis, shard_axis, shard_axis, *([None]*replicated_in_axes)), - out_axes=shard_axis, - )(q, k, v, start, end, *replicated_inputs) - return out, (m, l) - - -def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None): - bsz, _, _, head_dim = xq.shape - with jax.named_scope("attn_mat1"): - ## Attention start - # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim) - scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim) - if k_scaler is not None: - scores = scores * (k_scaler.reshape(bsz, 1, 1, keys.shape[2])) - if mask is not None: - # if mask.shape != (1,1,16,16): - # breakpoint() - scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen) - with jax.named_scope("attn_soft"): - scores = F.softmax(scores.float(), dim=-1).type_as(xq) - if v_scaler is not None: - scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2])) - - with jax.named_scope("attn_mat2"): - # output = torch.einsum( - # "ikjm,ikml->ikjl", scores, values - # ) # (bs, n_local_heads, seqlen, head_dim) - output = torch.einsum("ikjm,ikml->ikjl", scores, values) - return output class AttentionKernel: @@ -650,9 +406,13 @@ def __init__(self, env): self.shard_axis = 0 if self.env.shard_on_batch else 1 qkv_pspec = self.env.partition_by_axis(self.shard_axis) # Number of heads others_pspec = self.env.partition_by_axis() - self.binded_ragged_mha = functools.partial(ragged_mha, bk=self.env.block_size, shard_axis=self.shard_axis) - self.binded_ragged_mha = shard_map(ragged_mha, env.mesh, in_specs=(*([qkv_pspec] * 3), *([others_pspec] * 4)), out_specs=(qkv_pspec, (others_pspec, others_pspec)), check_rep=False) - self.binded_ragged_mha = jax.jit(self.binded_ragged_mha) + self.dense_attention = ak.dense_attention + self.ragged_attention = ak.RaggedAttentionKernel( + env, + in_specs=(*([qkv_pspec] * 3), *([others_pspec] * 4)), + out_specs=(qkv_pspec, (others_pspec, others_pspec)), + sharding_axis=self.shard_axis + ) def __call__(self, xq, xk, xv, mask, cache, start, end, ragged_batch_index, ragged_block_index): """ @@ -666,7 +426,7 @@ def __call__(self, xq, xk, xv, mask, cache, start, end, ragged_batch_index, ragg bsz, num_heads, seqlen, head_dim = xq.shape _, num_kv_heads, _, kv_head_dim = xk.shape n_rep = num_heads // num_kv_heads - if seqlen == 1: + if not self.env.ragged_mha and seqlen == 1: xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) with jax.named_scope("attn_insert_cache"): @@ -676,11 +436,11 @@ def __call__(self, xq, xk, xv, mask, cache, start, end, ragged_batch_index, ragg with jax.named_scope("attn_qkv"): if 
self.env.ragged_mha and seqlen == 1: - output, _ = torch_xla2.interop.call_jax(self.binded_ragged_mha, xq, keys, values, start, end, ragged_batch_index, ragged_block_index) + output, _ = torch_xla2.interop.call_jax(self.ragged_attention, xq, keys, values, start, end, ragged_batch_index, ragged_block_index) else: - output = dense_attention(xq, keys, values, None, None, mask) + output = self.dense_attention(xq, keys, values, None, None, mask) - if seqlen == 1: + if not self.env.ragged_mha and seqlen == 1: output = output[:, :, 0:1, :] # For XLA matmul performance boost # output = torch.matmul(scores, values) @@ -695,9 +455,13 @@ def __init__(self, env): self.shard_axis = 0 if self.env.shard_on_batch else 1 qkv_pspec = self.env.partition_by_axis(self.shard_axis) # Number of heads others_pspec = self.env.partition_by_axis() - self.binded_ragged_mha_quantized = functools.partial(ragged_mha, bk=self.env.block_size, shard_axis=self.shard_axis) - self.binded_ragged_mha_quantized = shard_map(self.binded_ragged_mha_quantized, env.mesh, in_specs=(*([qkv_pspec] * 3), *([others_pspec]*6)), out_specs=(qkv_pspec, (others_pspec, others_pspec)), check_rep=False) - self.binded_ragged_mha_quantized = jax.jit(self.binded_ragged_mha_quantized) + self.dense_attention = ak.dense_attention + self.ragged_attention = ak.RaggedAttentionKernel( + env, + in_specs=(*([qkv_pspec] * 3), *([others_pspec] * 6)), + out_specs=(qkv_pspec, (others_pspec, others_pspec)), + sharding_axis=self.shard_axis + ) def __call__(self, xq, xk, xv, mask, cache, start, end, ragged_batch_index, ragged_block_index): """ @@ -712,7 +476,7 @@ def __call__(self, xq, xk, xv, mask, cache, start, end, ragged_batch_index, ragg _, num_kv_heads, _, kv_head_dim = xk.shape n_rep = num_heads // num_kv_heads - if seqlen == 1: + if not self.env.ragged_mha and seqlen == 1: xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) with jax.named_scope("attn_insert_cache"): @@ -722,11 +486,11 @@ def __call__(self, xq, xk, xv, mask, cache, start, end, ragged_batch_index, ragg with jax.named_scope("attn_qkv"): if self.env.ragged_mha and seqlen == 1: - output, _ = torch_xla2.interop.call_jax(self.binded_ragged_mha_quantized, xq, keys, values, start, end, ragged_batch_index, ragged_block_index, k_scaler, v_scaler) + output, _ = torch_xla2.interop.call_jax(self.ragged_attention, xq, keys, values, start, end, ragged_batch_index, ragged_block_index, k_scaler, v_scaler) else: - output= dense_attention(xq, keys, values, k_scaler, v_scaler, mask) + output= self.dense_attention(xq, keys, values, k_scaler, v_scaler, mask) - if seqlen == 1: + if not self.env.ragged_mha and seqlen == 1: output = output[:, :, 0:1, :] self.env.apply_sharding(output, axis=self.shard_axis) From 98856f657b25e78f95663b52d457e997e1ac9133 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Wed, 22 May 2024 04:51:52 +0000 Subject: [PATCH 31/41] Fix run issues refactoring. 
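
After this refactor the attention layers no longer bind ragged_mha and shard_map by hand; they go through the new RaggedAttentionKernel wrapper. A minimal sketch of the intended call path, with every name taken from AttentionKernel in layers.py as changed by this patch (illustrative fragment, not a standalone snippet):

    # built once per layer: q/k/v sharded along shard_axis, other inputs replicated
    self.ragged_attention = ak.RaggedAttentionKernel(
        env,
        input_specs=(*([qkv_pspec] * 3), *([others_pspec] * 4)),
        output_specs=(qkv_pspec, (others_pspec, others_pspec)),
        sharding_axis=self.shard_axis,
    )

    # decode step (seqlen == 1): hand the jitted kernel to torch_xla2
    output, _ = torch_xla2.interop.call_jax(
        self.ragged_attention, xq, keys, values, start, end,
        ragged_batch_index, ragged_block_index)
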
--- jetstream_pt/attention_kernel.py | 8 +++++++- jetstream_pt/layers.py | 10 ++++------ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py index d81789c7..165b04f2 100644 --- a/jetstream_pt/attention_kernel.py +++ b/jetstream_pt/attention_kernel.py @@ -3,9 +3,15 @@ from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu from jax.experimental.shard_map import shard_map +import functools + +import torch import numpy as np +import math +import torch.nn.functional as F + DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) def ragged_flash_attention_kernel( @@ -252,7 +258,7 @@ def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None): class RaggedAttentionKernel: - def __init(self, env, input_specs, output_specs, sharding_axis): + def __init__(self, env, input_specs, output_specs, sharding_axis): self.binded_ragged_mha = functools.partial(ragged_mha, bk=env.block_size, shard_axis=sharding_axis) self.binded_ragged_mha = shard_map(ragged_mha, env.mesh, input_specs, output_specs, check_rep=False) self.binded_ragged_mha = jax.jit(self.binded_ragged_mha) diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 274e3011..6fa93853 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -15,9 +15,7 @@ # pylint: disable-all """This version contains modification to make it easier to trace and support batch.""" -import math from typing import Optional, Tuple -import functools import jax from . import attention_kernel as ak @@ -409,8 +407,8 @@ def __init__(self, env): self.dense_attention = ak.dense_attention self.ragged_attention = ak.RaggedAttentionKernel( env, - in_specs=(*([qkv_pspec] * 3), *([others_pspec] * 4)), - out_specs=(qkv_pspec, (others_pspec, others_pspec)), + input_specs=(*([qkv_pspec] * 3), *([others_pspec] * 4)), + output_specs=(qkv_pspec, (others_pspec, others_pspec)), sharding_axis=self.shard_axis ) @@ -458,8 +456,8 @@ def __init__(self, env): self.dense_attention = ak.dense_attention self.ragged_attention = ak.RaggedAttentionKernel( env, - in_specs=(*([qkv_pspec] * 3), *([others_pspec] * 6)), - out_specs=(qkv_pspec, (others_pspec, others_pspec)), + input_specs=(*([qkv_pspec] * 3), *([others_pspec] * 6)), + output_specs=(qkv_pspec, (others_pspec, others_pspec)), sharding_axis=self.shard_axis ) From ab387262a2cd3ac8cb1efe0eeb79a1f5e98c3f81 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Wed, 22 May 2024 17:47:54 +0000 Subject: [PATCH 32/41] Fix the quantized version for ragged attention. 
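
The quantized layer reuses the same wrapper; the int8 scalers simply ride along as optional trailing arguments and stay None on the bf16 path. A rough sketch of the two call sites, mirroring layers.py (self.ragged_attention is the RaggedAttentionKernel instance built in __init__):

    # bf16 kv cache: no scalers
    output, _ = torch_xla2.interop.call_jax(
        self.ragged_attention, xq, keys, values, start, end,
        ragged_batch_index, ragged_block_index)

    # int8 kv cache: also pass the k/v scalers, which the kernel applies to qk
    # and to the softmax numerator before accumulating the output
    output, _ = torch_xla2.interop.call_jax(
        self.ragged_attention, xq, keys, values, start, end,
        ragged_batch_index, ragged_block_index, k_scaler, v_scaler)
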
--- jetstream_pt/attention_kernel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py index 165b04f2..92f6d7ee 100644 --- a/jetstream_pt/attention_kernel.py +++ b/jetstream_pt/attention_kernel.py @@ -263,5 +263,5 @@ def __init__(self, env, input_specs, output_specs, sharding_axis): self.binded_ragged_mha = shard_map(ragged_mha, env.mesh, input_specs, output_specs, check_rep=False) self.binded_ragged_mha = jax.jit(self.binded_ragged_mha) - def __call__(self, xq, keys, values, start, end, ragged_batch_index, ragged_block_index): - return self.binded_ragged_mha(xq, keys, values, start, end, ragged_batch_index, ragged_block_index) + def __call__(self, xq, keys, values, start, end, ragged_batch_index, ragged_block_index, k_scaler=None, v_scaler=None): + return self.binded_ragged_mha(xq, keys, values, start, end, ragged_batch_index, ragged_block_index, k_scaler, v_scaler) From a712862e45786f77fcd3af04a90378a3d4b0bda5 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Thu, 23 May 2024 19:11:40 +0000 Subject: [PATCH 33/41] Fix test_attention by adding default value for the newly added arguments. The error message is missing positional arguments. --- jetstream_pt/layers.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 6fa93853..a42b57a2 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -412,7 +412,7 @@ def __init__(self, env): sharding_axis=self.shard_axis ) - def __call__(self, xq, xk, xv, mask, cache, start, end, ragged_batch_index, ragged_block_index): + def __call__(self, xq, xk, xv, mask, cache, start=None, end=None, ragged_batch_index=None, ragged_block_index=None): """ Args: xq: torch.Tensor of (batch size, num_heads, seqlen, head_dim) @@ -566,10 +566,10 @@ def forward( freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], cache, - start, - end, - ragged_batch_index, - ragged_block_index, + start=None, + end=None, + ragged_batch_index=None, + ragged_block_index=None, ): with jax.named_scope("attn_linear_before_cache"): bsz, seqlen = x.shape[0], x.shape[-2] From c50aba6af83587a50b5f4c50be1280c78e9d8069 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Thu, 23 May 2024 19:42:07 +0000 Subject: [PATCH 34/41] Fixes unit tests, changes the Transformer model call argument order(input_pos) back to original to avoid unnecessary issues. 
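
Concretely, the engine builds its argument tuples as follows after this change (condensed from the engine.py hunks below); prefill keeps the original four-argument form and only generate threads the ragged-attention inputs through:

    # _call_model_generate
    args = (tokens, input_pos, caches_obj, mask, start,
            ragged_batch_index, ragged_block_index)

    # _call_model_prefill
    args = (tokens, input_indexes, caches, mask)

Both model forward() signatures take input_pos back as the second positional parameter, with start, ragged_batch_index and ragged_block_index as trailing arguments that default to None.
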
--- jetstream_pt/engine.py | 4 ++-- jetstream_pt/third_party/gemma/model.py | 2 +- jetstream_pt/third_party/llama/model_exportable.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index a8b2d095..832019d1 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -168,7 +168,7 @@ def _call_model_generate( ] mask = jnp.expand_dims(mask, (1, 2)) - args = (tokens, caches_obj, mask, start, input_pos, ragged_batch_index, ragged_block_index) + args = (tokens, input_pos, caches_obj, mask, start, ragged_batch_index, ragged_block_index) paramst, argst = torchjax.to_torch((weights, args)) with self._lock: with torchjax.jax_mode: @@ -198,7 +198,7 @@ def _call_model_prefill(self, weights, tokens, input_indexes): dtype=self.default_dtype, ) mask = jnp.triu(mask, k=1) - args = (tokens, caches, mask, None, input_indexes, None, None) + args = (tokens, input_indexes, caches, mask) paramst, argst = torchjax.to_torch((weights, args)) with self._lock: diff --git a/jetstream_pt/third_party/gemma/model.py b/jetstream_pt/third_party/gemma/model.py index d58805d7..72607b1d 100644 --- a/jetstream_pt/third_party/gemma/model.py +++ b/jetstream_pt/third_party/gemma/model.py @@ -330,10 +330,10 @@ def __init__(self, config: gemma_config.GemmaConfig, env): def forward( self, tokens: torch.Tensor, + input_pos: torch.Tensor, caches: List[Any], mask, start = None, - input_pos = None, ragged_batch_index = None, ragged_block_index = None, ): diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py index 5a430fec..ece6bcb3 100644 --- a/jetstream_pt/third_party/llama/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -184,10 +184,10 @@ def __init__( def forward( self, tokens: torch.Tensor, + input_pos: torch.Tensor, caches: List[Any], mask, start = None, - input_pos = None, ragged_batch_index = None, ragged_block_index = None, ): From d4318038a16bed3db42cfbd0b209c90c2e8d2c37 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Thu, 23 May 2024 19:42:59 +0000 Subject: [PATCH 35/41] Format attention_kernel.py --- jetstream_pt/attention_kernel.py | 457 ++++++++++++++++++------------- 1 file changed, 260 insertions(+), 197 deletions(-) diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py index 92f6d7ee..e2f0a553 100644 --- a/jetstream_pt/attention_kernel.py +++ b/jetstream_pt/attention_kernel.py @@ -14,6 +14,7 @@ DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) + def ragged_flash_attention_kernel( start_ref, end_ref, @@ -33,67 +34,75 @@ def ragged_flash_attention_kernel( normalize_var: bool, quantized: bool, ): - """Pallas kernel for flash attention.""" - with jax.named_scope("attention_kernel"): - b, i = pl.program_id(0), pl.program_id(1) - - @pl.when(i == 0) - def init(): - with jax.named_scope("init"): - m_ref[...] = jnp.full_like(m_ref, -jnp.inf) - l_ref[...] = jnp.zeros_like(l_ref) - o_ref[...] = jnp.zeros_like(o_ref) - - length = line_end_ref[b] - start = start_ref[b] - end = end_ref[b] - - @pl.when(jnp.logical_and(i * bk < length, start != end)) - def run(): - with jax.named_scope("run_qk"): - q = q_ref[...].astype(jnp.float32) - k = k_ref[...].astype(jnp.float32) - v = v_ref[...].astype(jnp.float32) - m_prev, l_prev = m_ref[...], l_ref[...] 
- - qk = jax.lax.dot_general( - q, k, (((1,), (1,)), ((), ())), preferred_element_type=jnp.float32 - ) - if normalize_var: - qk = qk / jnp.sqrt(k.shape[-1]) - if quantized: - qk = qk * k_scaler_ref[...] - with jax.named_scope("run_mask"): - start = start_ref[b] - end = end_ref[b] - iota = jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1) - mask_start_lt_end = jnp.logical_and(i * bk + iota >= start, i * bk + iota < end).astype(jnp.int32) - mask_start_gt_end = jnp.logical_or(i * bk + iota >= start, i * bk + iota < end).astype(jnp.int32) - #mask = jax.lax.cond(start <= end, lambda: mask_start_lt_end, lambda: mask_start_gt_end) - mask = jnp.where(start <= end, mask_start_lt_end, mask_start_gt_end) - - qk = qk + jnp.where(mask, 0.0, mask_value) - - with jax.named_scope("run_softmax"): - m_curr = qk.max(axis=-1) - - s_curr = jnp.exp(qk - m_curr[..., None]) - - l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,)) - if quantized: - s_curr = s_curr * v_scaler_ref[...] - o_curr_times_l_curr = jnp.dot(s_curr, v) - m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,)) - m_next = jnp.maximum(m_prev, m_curr) - alpha = jnp.exp(m_prev - m_next) - beta = jnp.exp(m_curr - m_next) - l_next = alpha * l_prev + beta * l_curr - l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) - - m_ref[...], l_ref[...] = m_next, l_next_safe - o_ref[...] = ( - (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next_safe - ).astype(o_ref.dtype) + """Pallas kernel for flash attention.""" + with jax.named_scope("attention_kernel"): + b, i = pl.program_id(0), pl.program_id(1) + + @pl.when(i == 0) + def init(): + with jax.named_scope("init"): + m_ref[...] = jnp.full_like(m_ref, -jnp.inf) + l_ref[...] = jnp.zeros_like(l_ref) + o_ref[...] = jnp.zeros_like(o_ref) + + length = line_end_ref[b] + start = start_ref[b] + end = end_ref[b] + + @pl.when(jnp.logical_and(i * bk < length, start != end)) + def run(): + with jax.named_scope("run_qk"): + q = q_ref[...].astype(jnp.float32) + k = k_ref[...].astype(jnp.float32) + v = v_ref[...].astype(jnp.float32) + m_prev, l_prev = m_ref[...], l_ref[...] + + qk = jax.lax.dot_general( + q, k, (((1,), (1,)), ((), ())), preferred_element_type=jnp.float32 + ) + if normalize_var: + qk = qk / jnp.sqrt(k.shape[-1]) + if quantized: + qk = qk * k_scaler_ref[...] + with jax.named_scope("run_mask"): + start = start_ref[b] + end = end_ref[b] + iota = jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1) + mask_start_lt_end = jnp.logical_and( + i * bk + iota >= start, i * bk + iota < end + ).astype(jnp.int32) + mask_start_gt_end = jnp.logical_or( + i * bk + iota >= start, i * bk + iota < end + ).astype(jnp.int32) + # mask = jax.lax.cond(start <= end, lambda: mask_start_lt_end, lambda: mask_start_gt_end) + mask = jnp.where(start <= end, mask_start_lt_end, mask_start_gt_end) + + qk = qk + jnp.where(mask, 0.0, mask_value) + + with jax.named_scope("run_softmax"): + m_curr = qk.max(axis=-1) + + s_curr = jnp.exp(qk - m_curr[..., None]) + + l_curr = jax.lax.broadcast_in_dim( + s_curr.sum(axis=-1), l_prev.shape, (0,) + ) + if quantized: + s_curr = s_curr * v_scaler_ref[...] + o_curr_times_l_curr = jnp.dot(s_curr, v) + m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,)) + m_next = jnp.maximum(m_prev, m_curr) + alpha = jnp.exp(m_prev - m_next) + beta = jnp.exp(m_curr - m_next) + l_next = alpha * l_prev + beta * l_curr + l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) + + m_ref[...], l_ref[...] = m_next, l_next_safe + o_ref[...] = ( + (l_prev * alpha * o_ref[...] 
+ beta * o_curr_times_l_curr) + / l_next_safe + ).astype(o_ref.dtype) + @functools.partial(jax.jit, static_argnames=["bk", "mask_value", "normalize_var"]) def ragged_mqa( @@ -104,74 +113,92 @@ def ragged_mqa( end: jax.Array, k_scaler: jax.Array | None = None, v_scaler: jax.Array | None = None, - ragged_batch_index = None, - ragged_block_index = None, + ragged_batch_index=None, + ragged_block_index=None, bk: int = 512, mask_value: float = DEFAULT_MASK_VALUE, normalize_var: bool = True, ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: - """Ragged multi query attention.""" - with jax.named_scope("ragged_mqa"): - batch_size, num_heads, head_dim = q.shape - seq_len = k.shape[1] - - def kv_index_map(b, i, start_ref, end_ref, line_end_ref, ragged_batch_index_ref, ragged_block_index_ref): - index = b * (seq_len // bk) + i - return ragged_batch_index_ref[index], ragged_block_index_ref[index], 0 - - def q_index_map(b, i, start_ref, end_ref, line_end_ref, ragged_batch_index_ref, ragged_block_index_ref): - index = b * (seq_len // bk) + i - return ragged_batch_index_ref[index], 0, 0 - - def scaler_index_map(b, i, *_): - return b, 0, i - - line_end = jnp.where(start < end, end, seq_len - 1) - - in_specs = [ - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(kv_index_map, (None, bk, head_dim)), - pl.BlockSpec(kv_index_map, (None, bk, head_dim)), - ] - inputs = (start, end, line_end, ragged_batch_index, ragged_block_index, q, k, v) - quantized = False - if k_scaler is not None: - in_specs = in_specs + [ - pl.BlockSpec(scaler_index_map, (None, 1, bk)), - pl.BlockSpec(scaler_index_map, (None, 1, bk)), + """Ragged multi query attention.""" + with jax.named_scope("ragged_mqa"): + batch_size, num_heads, head_dim = q.shape + seq_len = k.shape[1] + + def kv_index_map( + b, + i, + start_ref, + end_ref, + line_end_ref, + ragged_batch_index_ref, + ragged_block_index_ref, + ): + index = b * (seq_len // bk) + i + return ragged_batch_index_ref[index], ragged_block_index_ref[index], 0 + + def q_index_map( + b, + i, + start_ref, + end_ref, + line_end_ref, + ragged_batch_index_ref, + ragged_block_index_ref, + ): + index = b * (seq_len // bk) + i + return ragged_batch_index_ref[index], 0, 0 + + def scaler_index_map(b, i, *_): + return b, 0, i + + line_end = jnp.where(start < end, end, seq_len - 1) + + in_specs = [ + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(kv_index_map, (None, bk, head_dim)), + pl.BlockSpec(kv_index_map, (None, bk, head_dim)), ] - inputs = inputs + (k_scaler, v_scaler) - quantized = True - - out, m, l = pl.pallas_call( - functools.partial( - ragged_flash_attention_kernel, - bk=bk, - mask_value=mask_value, - normalize_var=normalize_var, - quantized=quantized, - ), - grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=5, - in_specs=in_specs, - out_specs=[ - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - ], - grid=(batch_size, seq_len // bk), - ), - compiler_params=dict(dimension_semantics=("parallel", "arbitrary")), - out_shape=[ - q, - jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), - jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), - ], - )(*inputs) - return out, (m[..., 0], l[..., 0]) - - -@functools.partial(jax.jit, static_argnames=['bk', 'mask_value', 'normalize_var', 'shard_axis']) + inputs = (start, end, line_end, ragged_batch_index, ragged_block_index, q, k, v) + 
quantized = False + if k_scaler is not None: + in_specs = in_specs + [ + pl.BlockSpec(scaler_index_map, (None, 1, bk)), + pl.BlockSpec(scaler_index_map, (None, 1, bk)), + ] + inputs = inputs + (k_scaler, v_scaler) + quantized = True + + out, m, l = pl.pallas_call( + functools.partial( + ragged_flash_attention_kernel, + bk=bk, + mask_value=mask_value, + normalize_var=normalize_var, + quantized=quantized, + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=5, + in_specs=in_specs, + out_specs=[ + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + ], + grid=(batch_size, seq_len // bk), + ), + compiler_params=dict(dimension_semantics=("parallel", "arbitrary")), + out_shape=[ + q, + jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), + jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), + ], + )(*inputs) + return out, (m[..., 0], l[..., 0]) + + +@functools.partial( + jax.jit, static_argnames=["bk", "mask_value", "normalize_var", "shard_axis"] +) def ragged_mha( q: jax.Array, k: jax.Array, @@ -183,85 +210,121 @@ def ragged_mha( k_scaler: jax.Array | None = None, v_scaler: jax.Array | None = None, bk: int = 512, - mask_value : float = DEFAULT_MASK_VALUE, + mask_value: float = DEFAULT_MASK_VALUE, normalize_var: bool = True, - shard_axis: int = 1 + shard_axis: int = 1, ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: - """Ragged multi head attention. - Args: - q: A [batch_size, compute_dim, num_heads, head_dim] jax.Array. - k: A [batch_size, num_heads, seq_len, head_dim] jax.Array or - PartitionQuantizedTensor. - v: A [batch_size, num_heads, seq_len, head_dim] jax.Array or - PartitionQuantizedTensor. - start: A i32[batch_size] jax.Array - end: A i32[batch_size] jax.Array - bk: An integer that is the sequence block size. - logit_cap: An optional float that caps logits via tanh. By default there is - no logit capping. - mask_value: The value used for padding in attention. By default it is a very - negative floating point number. - out_dtype: An optional dtype for the output. If not provided, the output - dtype will be q's dtype. - Returns: - The output of attention([batch_size, num_heads, compute_dim, head_dim]), - along with the max logit ([batch_size, num_heads, compute_dim, 1]) and - softmax denominator ([batch_size, num_heads, compute_dim, 1]). - """ - mask_value = DEFAULT_MASK_VALUE - seqlen = q.shape[-2] - if k_scaler is None: - replicated_in_axes = 4 - replicated_inputs = (ragged_batch_index, ragged_block_index) - else: - replicated_in_axes = 6 - replicated_inputs = (jnp.squeeze(k_scaler, -1), jnp.squeeze(v_scaler, -1), ragged_batch_index, ragged_block_index) - - with jax.named_scope("ragged_mha_vmap"): - out, (m, l) = jax.vmap( - functools.partial( - ragged_mqa, - bk=bk, - mask_value=mask_value, - normalize_var=normalize_var, - #out_dtype=out_dtype, - ), - in_axes=(shard_axis, shard_axis, shard_axis, *([None]*replicated_in_axes)), - out_axes=shard_axis, - )(q, k, v, start, end, *replicated_inputs) - return out, (m, l) + """Ragged multi head attention. + Args: + q: A [batch_size, compute_dim, num_heads, head_dim] jax.Array. + k: A [batch_size, num_heads, seq_len, head_dim] jax.Array or + PartitionQuantizedTensor. + v: A [batch_size, num_heads, seq_len, head_dim] jax.Array or + PartitionQuantizedTensor. 
+ start: A i32[batch_size] jax.Array + end: A i32[batch_size] jax.Array + bk: An integer that is the sequence block size. + logit_cap: An optional float that caps logits via tanh. By default there is + no logit capping. + mask_value: The value used for padding in attention. By default it is a very + negative floating point number. + out_dtype: An optional dtype for the output. If not provided, the output + dtype will be q's dtype. + Returns: + The output of attention([batch_size, num_heads, compute_dim, head_dim]), + along with the max logit ([batch_size, num_heads, compute_dim, 1]) and + softmax denominator ([batch_size, num_heads, compute_dim, 1]). + """ + mask_value = DEFAULT_MASK_VALUE + seqlen = q.shape[-2] + if k_scaler is None: + replicated_in_axes = 4 + replicated_inputs = (ragged_batch_index, ragged_block_index) + else: + replicated_in_axes = 6 + replicated_inputs = ( + jnp.squeeze(k_scaler, -1), + jnp.squeeze(v_scaler, -1), + ragged_batch_index, + ragged_block_index, + ) + + with jax.named_scope("ragged_mha_vmap"): + out, (m, l) = jax.vmap( + functools.partial( + ragged_mqa, + bk=bk, + mask_value=mask_value, + normalize_var=normalize_var, + # out_dtype=out_dtype, + ), + in_axes=( + shard_axis, + shard_axis, + shard_axis, + *([None] * replicated_in_axes), + ), + out_axes=shard_axis, + )(q, k, v, start, end, *replicated_inputs) + return out, (m, l) def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None): - bsz, _, _, head_dim = xq.shape - with jax.named_scope("attn_mat1"): - ## Attention start - # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim) - scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim) - if k_scaler is not None: - scores = scores * (k_scaler.reshape(bsz, 1, 1, keys.shape[2])) - if mask is not None: - # if mask.shape != (1,1,16,16): - # breakpoint() - scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen) - with jax.named_scope("attn_soft"): - scores = F.softmax(scores.float(), dim=-1).type_as(xq) - if v_scaler is not None: - scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2])) - - with jax.named_scope("attn_mat2"): - # output = torch.einsum( - # "ikjm,ikml->ikjl", scores, values - # ) # (bs, n_local_heads, seqlen, head_dim) - output = torch.einsum("ikjm,ikml->ikjl", scores, values) - return output + bsz, _, _, head_dim = xq.shape + with jax.named_scope("attn_mat1"): + ## Attention start + # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim) + scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim) + if k_scaler is not None: + scores = scores * (k_scaler.reshape(bsz, 1, 1, keys.shape[2])) + if mask is not None: + # if mask.shape != (1,1,16,16): + # breakpoint() + scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen) + with jax.named_scope("attn_soft"): + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + if v_scaler is not None: + scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2])) + + with jax.named_scope("attn_mat2"): + # output = torch.einsum( + # "ikjm,ikml->ikjl", scores, values + # ) # (bs, n_local_heads, seqlen, head_dim) + output = torch.einsum("ikjm,ikml->ikjl", scores, values) + return output + class RaggedAttentionKernel: - def __init__(self, env, input_specs, output_specs, sharding_axis): - self.binded_ragged_mha = functools.partial(ragged_mha, bk=env.block_size, shard_axis=sharding_axis) - self.binded_ragged_mha = shard_map(ragged_mha, env.mesh, input_specs, output_specs, 
check_rep=False) - self.binded_ragged_mha = jax.jit(self.binded_ragged_mha) + def __init__(self, env, input_specs, output_specs, sharding_axis): + self.binded_ragged_mha = functools.partial( + ragged_mha, bk=env.block_size, shard_axis=sharding_axis + ) + self.binded_ragged_mha = shard_map( + ragged_mha, env.mesh, input_specs, output_specs, check_rep=False + ) + self.binded_ragged_mha = jax.jit(self.binded_ragged_mha) - def __call__(self, xq, keys, values, start, end, ragged_batch_index, ragged_block_index, k_scaler=None, v_scaler=None): - return self.binded_ragged_mha(xq, keys, values, start, end, ragged_batch_index, ragged_block_index, k_scaler, v_scaler) + def __call__( + self, + xq, + keys, + values, + start, + end, + ragged_batch_index, + ragged_block_index, + k_scaler=None, + v_scaler=None, + ): + return self.binded_ragged_mha( + xq, + keys, + values, + start, + end, + ragged_batch_index, + ragged_block_index, + k_scaler, + v_scaler, + ) From 9061fac7e44e7d2c86833684f5bea6b1e8a70815 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Thu, 23 May 2024 19:51:46 +0000 Subject: [PATCH 36/41] Add descrpitions to ragged attention outputs. --- jetstream_pt/attention_kernel.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py index e2f0a553..7bf18900 100644 --- a/jetstream_pt/attention_kernel.py +++ b/jetstream_pt/attention_kernel.py @@ -26,9 +26,9 @@ def ragged_flash_attention_kernel( v_ref, k_scaler_ref, v_scaler_ref, - o_ref, - m_ref, - l_ref, + o_ref, # outputs + m_ref, # row max + l_ref, # propogation coefficient bk: int, mask_value: float, normalize_var: bool, From a6059e9e91f99b00d889aa941ffc079b80407b0b Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Thu, 23 May 2024 19:58:28 +0000 Subject: [PATCH 37/41] Fix quantization tests by adding default value to quantization kernel class. --- jetstream_pt/layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index a42b57a2..89936f84 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -461,7 +461,7 @@ def __init__(self, env): sharding_axis=self.shard_axis ) - def __call__(self, xq, xk, xv, mask, cache, start, end, ragged_batch_index, ragged_block_index): + def __call__(self, xq, xk, xv, mask, cache, start=None, end=None, ragged_batch_index=None, ragged_block_index=None): """ Args: xq: torch.Tensor of (batch size, num_heads, seqlen, head_dim) From a73658d1227011ddd36c49ab8249033f9db53690 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Thu, 23 May 2024 20:13:14 +0000 Subject: [PATCH 38/41] Reformat attention_kernel.py. Format with black doesn't comply with the pylink rules. 
--- jetstream_pt/attention_kernel.py | 503 ++++++++++++++++--------------- 1 file changed, 259 insertions(+), 244 deletions(-) diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py index 7bf18900..0e835261 100644 --- a/jetstream_pt/attention_kernel.py +++ b/jetstream_pt/attention_kernel.py @@ -26,85 +26,87 @@ def ragged_flash_attention_kernel( v_ref, k_scaler_ref, v_scaler_ref, - o_ref, # outputs - m_ref, # row max - l_ref, # propogation coefficient + o_ref, # outputs + m_ref, # row max + l_ref, # propogation coefficient bk: int, mask_value: float, normalize_var: bool, quantized: bool, ): - """Pallas kernel for flash attention.""" - with jax.named_scope("attention_kernel"): - b, i = pl.program_id(0), pl.program_id(1) - - @pl.when(i == 0) - def init(): - with jax.named_scope("init"): - m_ref[...] = jnp.full_like(m_ref, -jnp.inf) - l_ref[...] = jnp.zeros_like(l_ref) - o_ref[...] = jnp.zeros_like(o_ref) - - length = line_end_ref[b] + """Pallas kernel for flash attention.""" + with jax.named_scope("attention_kernel"): + b, i = pl.program_id(0), pl.program_id(1) + + @pl.when(i == 0) + def init(): + with jax.named_scope("init"): + m_ref[...] = jnp.full_like(m_ref, -jnp.inf) + l_ref[...] = jnp.zeros_like(l_ref) + o_ref[...] = jnp.zeros_like(o_ref) + + length = line_end_ref[b] + start = start_ref[b] + end = end_ref[b] + + @pl.when(jnp.logical_and(i * bk < length, start != end)) + def run(): + with jax.named_scope("run_qk"): + q = q_ref[...].astype(jnp.float32) + k = k_ref[...].astype(jnp.float32) + v = v_ref[...].astype(jnp.float32) + m_prev, l_prev = m_ref[...], l_ref[...] + + qk = jax.lax.dot_general( + q, k, (((1,), (1,)), ((), ())), preferred_element_type=jnp.float32 + ) + if normalize_var: + qk = qk / jnp.sqrt(k.shape[-1]) + if quantized: + qk = qk * k_scaler_ref[...] + with jax.named_scope("run_mask"): start = start_ref[b] end = end_ref[b] + iota = jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1) + mask_start_lt_end = jnp.logical_and( + i * bk + iota >= start, i * bk + iota < end + ).astype(jnp.int32) + mask_start_gt_end = jnp.logical_or( + i * bk + iota >= start, i * bk + iota < end + ).astype(jnp.int32) + # mask = jax.lax.cond(start <= end, lambda: mask_start_lt_end, lambda: mask_start_gt_end) + mask = jnp.where(start <= end, mask_start_lt_end, mask_start_gt_end) + + qk = qk + jnp.where(mask, 0.0, mask_value) + + with jax.named_scope("run_softmax"): + m_curr = qk.max(axis=-1) + + s_curr = jnp.exp(qk - m_curr[..., None]) - @pl.when(jnp.logical_and(i * bk < length, start != end)) - def run(): - with jax.named_scope("run_qk"): - q = q_ref[...].astype(jnp.float32) - k = k_ref[...].astype(jnp.float32) - v = v_ref[...].astype(jnp.float32) - m_prev, l_prev = m_ref[...], l_ref[...] - - qk = jax.lax.dot_general( - q, k, (((1,), (1,)), ((), ())), preferred_element_type=jnp.float32 - ) - if normalize_var: - qk = qk / jnp.sqrt(k.shape[-1]) - if quantized: - qk = qk * k_scaler_ref[...] 
- with jax.named_scope("run_mask"): - start = start_ref[b] - end = end_ref[b] - iota = jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1) - mask_start_lt_end = jnp.logical_and( - i * bk + iota >= start, i * bk + iota < end - ).astype(jnp.int32) - mask_start_gt_end = jnp.logical_or( - i * bk + iota >= start, i * bk + iota < end - ).astype(jnp.int32) - # mask = jax.lax.cond(start <= end, lambda: mask_start_lt_end, lambda: mask_start_gt_end) - mask = jnp.where(start <= end, mask_start_lt_end, mask_start_gt_end) - - qk = qk + jnp.where(mask, 0.0, mask_value) - - with jax.named_scope("run_softmax"): - m_curr = qk.max(axis=-1) - - s_curr = jnp.exp(qk - m_curr[..., None]) - - l_curr = jax.lax.broadcast_in_dim( - s_curr.sum(axis=-1), l_prev.shape, (0,) - ) - if quantized: - s_curr = s_curr * v_scaler_ref[...] - o_curr_times_l_curr = jnp.dot(s_curr, v) - m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,)) - m_next = jnp.maximum(m_prev, m_curr) - alpha = jnp.exp(m_prev - m_next) - beta = jnp.exp(m_curr - m_next) - l_next = alpha * l_prev + beta * l_curr - l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) - - m_ref[...], l_ref[...] = m_next, l_next_safe - o_ref[...] = ( - (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) - / l_next_safe - ).astype(o_ref.dtype) - - -@functools.partial(jax.jit, static_argnames=["bk", "mask_value", "normalize_var"]) + l_curr = jax.lax.broadcast_in_dim( + s_curr.sum(axis=-1), l_prev.shape, (0,) + ) + if quantized: + s_curr = s_curr * v_scaler_ref[...] + o_curr_times_l_curr = jnp.dot(s_curr, v) + m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,)) + m_next = jnp.maximum(m_prev, m_curr) + alpha = jnp.exp(m_prev - m_next) + beta = jnp.exp(m_curr - m_next) + l_next = alpha * l_prev + beta * l_curr + l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) + + m_ref[...], l_ref[...] = m_next, l_next_safe + o_ref[...] = ( + (l_prev * alpha * o_ref[...] 
+ beta * o_curr_times_l_curr) + / l_next_safe + ).astype(o_ref.dtype) + + +@functools.partial( + jax.jit, static_argnames=["bk", "mask_value", "normalize_var"] +) def ragged_mqa( q: jax.Array, k: jax.Array, @@ -119,81 +121,94 @@ def ragged_mqa( mask_value: float = DEFAULT_MASK_VALUE, normalize_var: bool = True, ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: - """Ragged multi query attention.""" - with jax.named_scope("ragged_mqa"): - batch_size, num_heads, head_dim = q.shape - seq_len = k.shape[1] - - def kv_index_map( - b, - i, - start_ref, - end_ref, - line_end_ref, - ragged_batch_index_ref, - ragged_block_index_ref, - ): - index = b * (seq_len // bk) + i - return ragged_batch_index_ref[index], ragged_block_index_ref[index], 0 - - def q_index_map( - b, - i, - start_ref, - end_ref, - line_end_ref, - ragged_batch_index_ref, - ragged_block_index_ref, - ): - index = b * (seq_len // bk) + i - return ragged_batch_index_ref[index], 0, 0 - - def scaler_index_map(b, i, *_): - return b, 0, i - - line_end = jnp.where(start < end, end, seq_len - 1) - - in_specs = [ - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(kv_index_map, (None, bk, head_dim)), - pl.BlockSpec(kv_index_map, (None, bk, head_dim)), - ] - inputs = (start, end, line_end, ragged_batch_index, ragged_block_index, q, k, v) - quantized = False - if k_scaler is not None: - in_specs = in_specs + [ - pl.BlockSpec(scaler_index_map, (None, 1, bk)), - pl.BlockSpec(scaler_index_map, (None, 1, bk)), - ] - inputs = inputs + (k_scaler, v_scaler) - quantized = True - - out, m, l = pl.pallas_call( - functools.partial( - ragged_flash_attention_kernel, - bk=bk, - mask_value=mask_value, - normalize_var=normalize_var, - quantized=quantized, + """Ragged multi query attention.""" + with jax.named_scope("ragged_mqa"): + batch_size, num_heads, head_dim = q.shape + seq_len = k.shape[1] + + def kv_index_map( + b, + i, + start_ref, + end_ref, + line_end_ref, + ragged_batch_index_ref, + ragged_block_index_ref, + ): + index = b * (seq_len // bk) + i + return ragged_batch_index_ref[index], ragged_block_index_ref[index], 0 + + def q_index_map( + b, + i, + start_ref, + end_ref, + line_end_ref, + ragged_batch_index_ref, + ragged_block_index_ref, + ): + index = b * (seq_len // bk) + i + return ragged_batch_index_ref[index], 0, 0 + + def scaler_index_map(b, i, *_): + return b, 0, i + + line_end = jnp.where(start < end, end, seq_len - 1) + + in_specs = [ + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(kv_index_map, (None, bk, head_dim)), + pl.BlockSpec(kv_index_map, (None, bk, head_dim)), + ] + inputs = ( + start, + end, + line_end, + ragged_batch_index, + ragged_block_index, + q, + k, + v, + ) + quantized = False + if k_scaler is not None: + in_specs = in_specs + [ + pl.BlockSpec(scaler_index_map, (None, 1, bk)), + pl.BlockSpec(scaler_index_map, (None, 1, bk)), + ] + inputs = inputs + (k_scaler, v_scaler) + quantized = True + + out, m, l = pl.pallas_call( + functools.partial( + ragged_flash_attention_kernel, + bk=bk, + mask_value=mask_value, + normalize_var=normalize_var, + quantized=quantized, + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=5, + in_specs=in_specs, + out_specs=[ + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + ], + grid=(batch_size, seq_len // bk), + ), + compiler_params=dict(dimension_semantics=("parallel", "arbitrary")), + out_shape=[ + q, + 
jax.ShapeDtypeStruct( + (batch_size, num_heads, head_dim), jnp.float32 ), - grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=5, - in_specs=in_specs, - out_specs=[ - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - ], - grid=(batch_size, seq_len // bk), + jax.ShapeDtypeStruct( + (batch_size, num_heads, head_dim), jnp.float32 ), - compiler_params=dict(dimension_semantics=("parallel", "arbitrary")), - out_shape=[ - q, - jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), - jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), - ], - )(*inputs) - return out, (m[..., 0], l[..., 0]) + ], + )(*inputs) + return out, (m[..., 0], l[..., 0]) @functools.partial( @@ -214,99 +229,110 @@ def ragged_mha( normalize_var: bool = True, shard_axis: int = 1, ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: - """Ragged multi head attention. - Args: - q: A [batch_size, compute_dim, num_heads, head_dim] jax.Array. - k: A [batch_size, num_heads, seq_len, head_dim] jax.Array or - PartitionQuantizedTensor. - v: A [batch_size, num_heads, seq_len, head_dim] jax.Array or - PartitionQuantizedTensor. - start: A i32[batch_size] jax.Array - end: A i32[batch_size] jax.Array - bk: An integer that is the sequence block size. - logit_cap: An optional float that caps logits via tanh. By default there is - no logit capping. - mask_value: The value used for padding in attention. By default it is a very - negative floating point number. - out_dtype: An optional dtype for the output. If not provided, the output - dtype will be q's dtype. - Returns: - The output of attention([batch_size, num_heads, compute_dim, head_dim]), - along with the max logit ([batch_size, num_heads, compute_dim, 1]) and - softmax denominator ([batch_size, num_heads, compute_dim, 1]). - """ - mask_value = DEFAULT_MASK_VALUE - seqlen = q.shape[-2] - if k_scaler is None: - replicated_in_axes = 4 - replicated_inputs = (ragged_batch_index, ragged_block_index) - else: - replicated_in_axes = 6 - replicated_inputs = ( - jnp.squeeze(k_scaler, -1), - jnp.squeeze(v_scaler, -1), - ragged_batch_index, - ragged_block_index, - ) - - with jax.named_scope("ragged_mha_vmap"): - out, (m, l) = jax.vmap( - functools.partial( - ragged_mqa, - bk=bk, - mask_value=mask_value, - normalize_var=normalize_var, - # out_dtype=out_dtype, - ), - in_axes=( - shard_axis, - shard_axis, - shard_axis, - *([None] * replicated_in_axes), - ), - out_axes=shard_axis, - )(q, k, v, start, end, *replicated_inputs) - return out, (m, l) + """Ragged multi head attention. + Args: + q: A [batch_size, compute_dim, num_heads, head_dim] jax.Array. + k: A [batch_size, num_heads, seq_len, head_dim] jax.Array or + PartitionQuantizedTensor. + v: A [batch_size, num_heads, seq_len, head_dim] jax.Array or + PartitionQuantizedTensor. + start: A i32[batch_size] jax.Array + end: A i32[batch_size] jax.Array + bk: An integer that is the sequence block size. + logit_cap: An optional float that caps logits via tanh. By default there is + no logit capping. + mask_value: The value used for padding in attention. By default it is a very + negative floating point number. + out_dtype: An optional dtype for the output. If not provided, the output + dtype will be q's dtype. 
+ Returns: + The output of attention([batch_size, num_heads, compute_dim, head_dim]), + along with the max logit ([batch_size, num_heads, compute_dim, 1]) and + softmax denominator ([batch_size, num_heads, compute_dim, 1]). + """ + mask_value = DEFAULT_MASK_VALUE + seqlen = q.shape[-2] + if k_scaler is None: + replicated_in_axes = 4 + replicated_inputs = (ragged_batch_index, ragged_block_index) + else: + replicated_in_axes = 6 + replicated_inputs = ( + jnp.squeeze(k_scaler, -1), + jnp.squeeze(v_scaler, -1), + ragged_batch_index, + ragged_block_index, + ) + + with jax.named_scope("ragged_mha_vmap"): + out, (m, l) = jax.vmap( + functools.partial( + ragged_mqa, + bk=bk, + mask_value=mask_value, + normalize_var=normalize_var, + # out_dtype=out_dtype, + ), + in_axes=( + shard_axis, + shard_axis, + shard_axis, + *([None] * replicated_in_axes), + ), + out_axes=shard_axis, + )(q, k, v, start, end, *replicated_inputs) + return out, (m, l) def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None): - bsz, _, _, head_dim = xq.shape - with jax.named_scope("attn_mat1"): - ## Attention start - # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim) - scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim) - if k_scaler is not None: - scores = scores * (k_scaler.reshape(bsz, 1, 1, keys.shape[2])) - if mask is not None: - # if mask.shape != (1,1,16,16): - # breakpoint() - scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen) - with jax.named_scope("attn_soft"): - scores = F.softmax(scores.float(), dim=-1).type_as(xq) - if v_scaler is not None: - scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2])) - - with jax.named_scope("attn_mat2"): - # output = torch.einsum( - # "ikjm,ikml->ikjl", scores, values - # ) # (bs, n_local_heads, seqlen, head_dim) - output = torch.einsum("ikjm,ikml->ikjl", scores, values) - return output + bsz, _, _, head_dim = xq.shape + with jax.named_scope("attn_mat1"): + ## Attention start + # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim) + scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim) + if k_scaler is not None: + scores = scores * (k_scaler.reshape(bsz, 1, 1, keys.shape[2])) + if mask is not None: + # if mask.shape != (1,1,16,16): + # breakpoint() + scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen) + with jax.named_scope("attn_soft"): + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + if v_scaler is not None: + scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2])) + + with jax.named_scope("attn_mat2"): + # output = torch.einsum( + # "ikjm,ikml->ikjl", scores, values + # ) # (bs, n_local_heads, seqlen, head_dim) + output = torch.einsum("ikjm,ikml->ikjl", scores, values) + return output class RaggedAttentionKernel: - def __init__(self, env, input_specs, output_specs, sharding_axis): - self.binded_ragged_mha = functools.partial( - ragged_mha, bk=env.block_size, shard_axis=sharding_axis - ) - self.binded_ragged_mha = shard_map( - ragged_mha, env.mesh, input_specs, output_specs, check_rep=False - ) - self.binded_ragged_mha = jax.jit(self.binded_ragged_mha) - - def __call__( - self, + def __init__(self, env, input_specs, output_specs, sharding_axis): + self.binded_ragged_mha = functools.partial( + ragged_mha, bk=env.block_size, shard_axis=sharding_axis + ) + self.binded_ragged_mha = shard_map( + ragged_mha, env.mesh, input_specs, output_specs, check_rep=False + ) + self.binded_ragged_mha = 
jax.jit(self.binded_ragged_mha) + + def __call__( + self, + xq, + keys, + values, + start, + end, + ragged_batch_index, + ragged_block_index, + k_scaler=None, + v_scaler=None, + ): + return self.binded_ragged_mha( xq, keys, values, @@ -314,17 +340,6 @@ def __call__( end, ragged_batch_index, ragged_block_index, - k_scaler=None, - v_scaler=None, - ): - return self.binded_ragged_mha( - xq, - keys, - values, - start, - end, - ragged_batch_index, - ragged_block_index, - k_scaler, - v_scaler, - ) + k_scaler, + v_scaler, + ) From 9aaa7a6a21e9a2af21b76f4d87423eab666a36ef Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Thu, 23 May 2024 21:16:35 +0000 Subject: [PATCH 39/41] Ignores R0913: Too many arguments link error for ragged attention kernel. Fix other lint errors. --- .pylintrc | 2 +- jetstream_pt/attention_kernel.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/.pylintrc b/.pylintrc index 66a6589e..84ddc0ae 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,2 +1,2 @@ [MESSAGES CONTROL] -disable=C0114,R0801,E1102,W0613,R1711,too-many-locals +disable=C0114,R0801,R0913,E1102,W0613,R1711,too-many-locals diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py index 0e835261..96bb4233 100644 --- a/jetstream_pt/attention_kernel.py +++ b/jetstream_pt/attention_kernel.py @@ -1,17 +1,17 @@ +import functools +import math + import jax import jax.numpy as jnp from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu from jax.experimental.shard_map import shard_map -import functools import torch +import torch.nn.functional as F import numpy as np -import math -import torch.nn.functional as F - DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) @@ -197,7 +197,7 @@ def scaler_index_map(b, i, *_): ], grid=(batch_size, seq_len // bk), ), - compiler_params=dict(dimension_semantics=("parallel", "arbitrary")), + compiler_params={"dimension_semantics": ("parallel", "arbitrary")}, out_shape=[ q, jax.ShapeDtypeStruct( @@ -251,7 +251,6 @@ def ragged_mha( softmax denominator ([batch_size, num_heads, compute_dim, 1]). """ mask_value = DEFAULT_MASK_VALUE - seqlen = q.shape[-2] if k_scaler is None: replicated_in_axes = 4 replicated_inputs = (ragged_batch_index, ragged_block_index) @@ -285,6 +284,8 @@ def ragged_mha( def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None): + """The vanilla attention kernel implementation.""" + bsz, _, _, head_dim = xq.shape with jax.named_scope("attn_mat1"): ## Attention start @@ -310,6 +311,7 @@ def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None): class RaggedAttentionKernel: + """Ragged attention kernel.""" def __init__(self, env, input_specs, output_specs, sharding_axis): self.binded_ragged_mha = functools.partial( From 1286bb56d362b36947b4a43aa10116a550d8236c Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Thu, 23 May 2024 21:21:13 +0000 Subject: [PATCH 40/41] Ignore R0903: Too few public methods. Fix lint errors. 
--- .pylintrc | 2 +- jetstream_pt/config.py | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/.pylintrc b/.pylintrc index 84ddc0ae..1df7809e 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,2 +1,2 @@ [MESSAGES CONTROL] -disable=C0114,R0801,R0913,E1102,W0613,R1711,too-many-locals +disable=C0114,R0801,R0903,R0913,E1102,W0613,R1711,too-many-locals diff --git a/jetstream_pt/config.py b/jetstream_pt/config.py index ae3dfe75..bdf5fe41 100644 --- a/jetstream_pt/config.py +++ b/jetstream_pt/config.py @@ -60,25 +60,26 @@ ) flags.DEFINE_bool( "profiling_prefill", - False, + False, "Whether to profile the prefill, " - "if set to false, profile generate function only", - required=False + "if set to false, profile generate function only", + required=False, ) flags.DEFINE_bool( "ragged_mha", False, - "Whether to enable Ragged multi head attention", - required=False + "Whether to enable Ragged multi head attention", + required=False, ) flags.DEFINE_integer( - "starting_position", - 512, + "starting_position", + 512, "The starting position of decoding, " - "for performance tuning and debugging only", - required=False + "for performance tuning and debugging only", + required=False, ) + def create_quantization_config_from_flags(): """Create Quantization Config from cmd flags""" config = QuantizationConfig() From f64fe6f64c873d6d8dcf4ca1feb104621f134b72 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Thu, 23 May 2024 21:27:49 +0000 Subject: [PATCH 41/41] Fix the rest of the lint errors. --- jetstream_pt/engine.py | 54 ++++++++--- jetstream_pt/environment.py | 2 + jetstream_pt/layers.py | 90 +++++++++++++++---- jetstream_pt/third_party/gemma/model.py | 40 +++++---- .../third_party/llama/model_exportable.py | 48 ++++++---- run_interactive.py | 2 +- run_interactive_multiple_host.py | 2 +- 7 files changed, 174 insertions(+), 64 deletions(-) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 832019d1..a709e90a 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -64,7 +64,7 @@ class DecodeState: ] # only present in quantized kv current_position: int lens: jax.Array # [batch_size, 1] - start: jax.Array # [batch_size, 1], the starting pos for each slot + start: jax.Array # [batch_size, 1], the starting pos for each slot input_pos: jax.Array # [batch_size, 1] input pos for each slot mask: jax.Array # [batch_size, seqlen] -inf for invalid; 0 for valid @@ -128,7 +128,7 @@ def init_decode_state( caches, scalers, self.env.starting_position, - jnp.zeros((self.env.batch_size, 1), dtype=jnp.int32), # lens + jnp.zeros((self.env.batch_size, 1), dtype=jnp.int32), # lens jnp.zeros((self.env.batch_size,), dtype=jnp.int32), # start pos jnp.zeros((self.env.batch_size,), dtype=jnp.int32), # input pos jnp.full( @@ -168,7 +168,15 @@ def _call_model_generate( ] mask = jnp.expand_dims(mask, (1, 2)) - args = (tokens, input_pos, caches_obj, mask, start, ragged_batch_index, ragged_block_index) + args = ( + tokens, + input_pos, + caches_obj, + mask, + start, + ragged_batch_index, + ragged_block_index, + ) paramst, argst = torchjax.to_torch((weights, args)) with self._lock: with torchjax.jax_mode: @@ -277,7 +285,9 @@ def _insert_no_wrap( cond = jnp.logical_and(x <= decode_state.current_position, x >= pos) mask_insert = jnp.where(cond, 0, float("-inf")) mask = decode_state.mask.at[slot].set(mask_insert) - start = decode_state.start.at[slot].set(pos % self.env.cache_sequence_length) + start = decode_state.start.at[slot].set( + pos % self.env.cache_sequence_length + ) input_pos 
= decode_state.input_pos.at[slot].set(prefix.seq_len) if not self.env.quant_config.enable_kv_quantization: @@ -458,10 +468,10 @@ def insert( ) def precompute_ragged_block_indices(self, decode_state: DecodeState): - """Precompute the ragged attention block indices. Ragged attention iterates the grid - and relies on the computed grid index to skip the unnecessary blocks. The basic idea - is to use input_pos, which is the length of each slot to determine if we should - work on the next block of the slot or move to the next slot. """ + """Precompute the ragged attention block indices. Ragged attention iterates the grid + and relies on the computed grid index to skip the unnecessary blocks. The basic idea + is to use input_pos, which is the length of each slot to determine if we should + work on the next block of the slot or move to the next slot.""" start = decode_state.start end = (start + decode_state.input_pos) % self.env.cache_len batch_size = start.shape[0] @@ -477,7 +487,11 @@ def precompute_ragged_block_indices(self, decode_state: DecodeState): end = end.reshape((batch_size, 1)) am_last_batch = b == batch_size - 1 - last_good_block = jnp.where(start < end, jax.lax.div(end - 1, bk), jax.lax.div(self.env.cache_len -1, bk)) + last_good_block = jnp.where( + start < end, + jax.lax.div(end - 1, bk), + jax.lax.div(self.env.cache_len - 1, bk), + ) next_b = jnp.where(am_last_batch, b, b + 1) next_i = jnp.where(am_last_batch, last_good_block, 0) @@ -492,14 +506,22 @@ def true_comp(b, i, bk, start, end, next_b, next_i): # start > end, continue work on the block is there is no overlap with [end, start) def false_comp(b, i, bk, start, end): b_next = b - i_next = jnp.where(jnp.logical_and(i * bk >= end, (i + 1) * bk <= start), jax.lax.div(start, bk), i) + i_next = jnp.where( + jnp.logical_and(i * bk >= end, (i + 1) * bk <= start), + jax.lax.div(start, bk), + i, + ) return b_next, i_next true_comp_b, true_comp_i = true_comp(b, i, bk, start, end, next_b, next_i) false_comp_b, false_comp_i = false_comp(b, i, bk, start, end) - b_next = jnp.where(start < end, true_comp_b, jnp.where(start == end, next_b, false_comp_b)) - i_next = jnp.where(start < end, true_comp_i, jnp.where(start == end, next_i, false_comp_i)) + b_next = jnp.where( + start < end, true_comp_b, jnp.where(start == end, next_b, false_comp_b) + ) + i_next = jnp.where( + start < end, true_comp_i, jnp.where(start == end, next_i, false_comp_i) + ) return b_next, i_next def generate( @@ -511,8 +533,12 @@ def generate( # fill mask first mask = decode_state.mask.at[:, decode_state.current_position].set(0) - ragged_batch_index, ragged_block_index = self.precompute_ragged_block_indices(decode_state) - ragged_batch_index, ragged_block_index = ragged_batch_index.reshape((-1)), ragged_block_index.reshape((-1)) + ragged_batch_index, ragged_block_index = ( + self.precompute_ragged_block_indices(decode_state) + ) + ragged_batch_index, ragged_block_index = ragged_batch_index.reshape( + (-1) + ), ragged_block_index.reshape((-1)) logits, new_caches, new_scales = self._call_model_generate( params, diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py index 9573e503..5ea8f3a3 100644 --- a/jetstream_pt/environment.py +++ b/jetstream_pt/environment.py @@ -95,6 +95,8 @@ class JetEngineEnvironmentData: # Starting position starting_position: int = 512 + + # pylint: disable-next=all class JetEngineEnvironment: diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 89936f84..c5e305b8 100644 --- a/jetstream_pt/layers.py +++ 
b/jetstream_pt/layers.py @@ -39,8 +39,10 @@ def _calc_cosine_dist(x, y): y = y.flatten().to(torch.float32) return (torch.dot(x, y) / (x.norm() * y.norm())).item() + import numpy as np + class Int8Embedding(torch.nn.Module): def __init__(self, num_embeddings, embedding_dims, device="cpu"): @@ -402,17 +404,28 @@ class AttentionKernel: def __init__(self, env): self.env = env self.shard_axis = 0 if self.env.shard_on_batch else 1 - qkv_pspec = self.env.partition_by_axis(self.shard_axis) # Number of heads + qkv_pspec = self.env.partition_by_axis(self.shard_axis) # Number of heads others_pspec = self.env.partition_by_axis() self.dense_attention = ak.dense_attention self.ragged_attention = ak.RaggedAttentionKernel( - env, - input_specs=(*([qkv_pspec] * 3), *([others_pspec] * 4)), - output_specs=(qkv_pspec, (others_pspec, others_pspec)), - sharding_axis=self.shard_axis + env, + input_specs=(*([qkv_pspec] * 3), *([others_pspec] * 4)), + output_specs=(qkv_pspec, (others_pspec, others_pspec)), + sharding_axis=self.shard_axis, ) - def __call__(self, xq, xk, xv, mask, cache, start=None, end=None, ragged_batch_index=None, ragged_block_index=None): + def __call__( + self, + xq, + xk, + xv, + mask, + cache, + start=None, + end=None, + ragged_batch_index=None, + ragged_block_index=None, + ): """ Args: xq: torch.Tensor of (batch size, num_heads, seqlen, head_dim) @@ -431,10 +444,19 @@ def __call__(self, xq, xk, xv, mask, cache, start=None, end=None, ragged_batch_i keys, values = cache.update(xk, xv) keys = repeat_kv(keys, n_rep) values = repeat_kv(values, n_rep) - + with jax.named_scope("attn_qkv"): if self.env.ragged_mha and seqlen == 1: - output, _ = torch_xla2.interop.call_jax(self.ragged_attention, xq, keys, values, start, end, ragged_batch_index, ragged_block_index) + output, _ = torch_xla2.interop.call_jax( + self.ragged_attention, + xq, + keys, + values, + start, + end, + ragged_batch_index, + ragged_block_index, + ) else: output = self.dense_attention(xq, keys, values, None, None, mask) @@ -451,17 +473,28 @@ class Int8KVAttentionKernel: def __init__(self, env): self.env = env self.shard_axis = 0 if self.env.shard_on_batch else 1 - qkv_pspec = self.env.partition_by_axis(self.shard_axis) # Number of heads + qkv_pspec = self.env.partition_by_axis(self.shard_axis) # Number of heads others_pspec = self.env.partition_by_axis() self.dense_attention = ak.dense_attention self.ragged_attention = ak.RaggedAttentionKernel( - env, - input_specs=(*([qkv_pspec] * 3), *([others_pspec] * 6)), - output_specs=(qkv_pspec, (others_pspec, others_pspec)), - sharding_axis=self.shard_axis + env, + input_specs=(*([qkv_pspec] * 3), *([others_pspec] * 6)), + output_specs=(qkv_pspec, (others_pspec, others_pspec)), + sharding_axis=self.shard_axis, ) - def __call__(self, xq, xk, xv, mask, cache, start=None, end=None, ragged_batch_index=None, ragged_block_index=None): + def __call__( + self, + xq, + xk, + xv, + mask, + cache, + start=None, + end=None, + ragged_batch_index=None, + ragged_block_index=None, + ): """ Args: xq: torch.Tensor of (batch size, num_heads, seqlen, head_dim) @@ -484,9 +517,22 @@ def __call__(self, xq, xk, xv, mask, cache, start=None, end=None, ragged_batch_i with jax.named_scope("attn_qkv"): if self.env.ragged_mha and seqlen == 1: - output, _ = torch_xla2.interop.call_jax(self.ragged_attention, xq, keys, values, start, end, ragged_batch_index, ragged_block_index, k_scaler, v_scaler) + output, _ = torch_xla2.interop.call_jax( + self.ragged_attention, + xq, + keys, + values, + start, + end, + ragged_batch_index, 
+ ragged_block_index,
+ k_scaler,
+ v_scaler,
+ )
 else:
- output= self.dense_attention(xq, keys, values, k_scaler, v_scaler, mask)
+ output = self.dense_attention(
+ xq, keys, values, k_scaler, v_scaler, mask
+ )

 if not self.env.ragged_mha and seqlen == 1:
 output = output[:, :, 0:1, :]
@@ -597,6 +643,16 @@ def forward(
 xv = xv.transpose(1, 2)
 xq = xq.transpose(1, 2)

- output = self.attention_kernel(xq, xk, xv, mask, cache, start, end, ragged_batch_index, ragged_block_index).type_as(xq)
+ output = self.attention_kernel(
+ xq,
+ xk,
+ xv,
+ mask,
+ cache,
+ start,
+ end,
+ ragged_batch_index,
+ ragged_block_index,
+ ).type_as(xq)
 output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1)
 return self.wo(output)
diff --git a/jetstream_pt/third_party/gemma/model.py b/jetstream_pt/third_party/gemma/model.py
index 72607b1d..73d8e07e 100644
--- a/jetstream_pt/third_party/gemma/model.py
+++ b/jetstream_pt/third_party/gemma/model.py
@@ -135,10 +135,10 @@ def forward(
 freqs_cis,
 mask,
 cache,
- start = None,
- end = None,
- ragged_batch_index = None,
- ragged_block_index = None,
+ start=None,
+ end=None,
+ ragged_batch_index=None,
+ ragged_block_index=None,
 ) -> torch.Tensor:
 hidden_states_shape = hidden_states.shape
 assert len(hidden_states_shape) == 3
@@ -168,7 +168,17 @@ def forward(
 xv = xv.transpose(1, 2)
 xq = xq.transpose(1, 2)

- output = self.attention_kernel(xq, xk, xv, mask, cache, start, end, ragged_batch_index, ragged_block_index)
+ output = self.attention_kernel(
+ xq,
+ xk,
+ xv,
+ mask,
+ cache,
+ start,
+ end,
+ ragged_batch_index,
+ ragged_block_index,
+ )

 # [batch_size, input_len, hidden_dim]
 output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1)
@@ -333,18 +343,18 @@ def forward(
 input_pos: torch.Tensor,
 caches: List[Any],
 mask,
- start = None,
- ragged_batch_index = None,
- ragged_block_index = None,
+ start=None,
+ ragged_batch_index=None,
+ ragged_block_index=None,
 ):
 """
- tokens: the input token for decoding
- caches: kv caches
- mask: causal mask to filter the attention results
- start: the starting position for each slot
- input_pos: the decoding position relative to the start, which is the length of the decoding results
- ragged_batch_index: precomputed batch index for ragged attention
- ragged_block_index: precomputed block index for ragged attention
+ tokens: the input token for decoding
+ caches: kv caches
+ mask: causal mask to filter the attention results
+ start: the starting position for each slot
+ input_pos: the decoding position relative to the start, which is the length of the decoding results
+ ragged_batch_index: precomputed batch index for ragged attention
+ ragged_block_index: precomputed block index for ragged attention
 """

 with jax.named_scope("transformer_freq"):
diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py
index ece6bcb3..2385839e 100644
--- a/jetstream_pt/third_party/llama/model_exportable.py
+++ b/jetstream_pt/third_party/llama/model_exportable.py
@@ -109,14 +109,21 @@ def forward(
 freqs_cis: torch.Tensor,
 mask: Optional[torch.Tensor],
 cache,
- start = None,
- end = None,
- ragged_batch_index = None,
- ragged_block_index = None,
+ start=None,
+ end=None,
+ ragged_batch_index=None,
+ ragged_block_index=None,
 ):
 with jax.named_scope("Attention"):
 attn = self.attention.forward(
- self.attention_norm(x), freqs_cis, mask, cache, start, end, ragged_batch_index, ragged_block_index
+ self.attention_norm(x),
+ freqs_cis,
+ mask,
+ cache,
+ start,
+ end,
+ ragged_batch_index,
+ ragged_block_index,
 )
 with jax.named_scope("ffn_norm"):
 h = x + attn
@@ -187,18 +194,18 @@ def forward(
 input_pos: torch.Tensor,
 caches: List[Any],
 mask,
- start = None,
- ragged_batch_index = None,
- ragged_block_index = None,
+ start=None,
+ ragged_batch_index=None,
+ ragged_block_index=None,
 ):
 """
- tokens: the input token for decoding
- caches: kv caches
- mask: causal mask to filter the attention results
- start: the starting position for each slot
- input_pos: the decoding position relative to the start, which is the length of the decoding results
- ragged_batch_index: precomputed batch index for ragged attention
- ragged_block_index: precomputed block index for ragged attention
+ tokens: the input token for decoding
+ caches: kv caches
+ mask: causal mask to filter the attention results
+ start: the starting position for each slot
+ input_pos: the decoding position relative to the start, which is the length of the decoding results
+ ragged_batch_index: precomputed batch index for ragged attention
+ ragged_block_index: precomputed block index for ragged attention
 """

 with jax.named_scope("transformer_tok"):
@@ -216,7 +223,16 @@ def forward(
 end = None if start is None else (start + input_pos) % self.env.cache_len
 for layer, cache in zip(self.layers, caches):
 with jax.named_scope("TransformerBlock"):
- h = layer(h, freqs_cis, mask, cache, start, end, ragged_batch_index, ragged_block_index)
+ h = layer(
+ h,
+ freqs_cis,
+ mask,
+ cache,
+ start,
+ end,
+ ragged_batch_index,
+ ragged_block_index,
+ )

 with jax.named_scope("transformer_norm"):
 h = self.norm(h)
diff --git a/run_interactive.py b/run_interactive.py
index ebe97c15..77b3a702 100644
--- a/run_interactive.py
+++ b/run_interactive.py
@@ -93,4 +93,4 @@ def main(argv):

 if __name__ == "__main__":
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
- app.run(main)
\ No newline at end of file
+ app.run(main)
diff --git a/run_interactive_multiple_host.py b/run_interactive_multiple_host.py
index f59674fc..d939b991 100644
--- a/run_interactive_multiple_host.py
+++ b/run_interactive_multiple_host.py
@@ -44,7 +44,7 @@ def create_engine():
 max_cache_length=FLAGS.max_cache_length,
 sharding_config=FLAGS.sharding_config,
 shard_on_batch=FLAGS.shard_on_batch,
- ragged_mha=FLAGS.ragged_mha
+ ragged_mha=FLAGS.ragged_mha,
 )
 print("Initialize engine", time.perf_counter() - start)
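
Reviewer note (not part of the patch): the sketch below illustrates the decode-time dispatch implemented by the new AttentionKernel.__call__. The ragged kernel is used only when the environment has ragged_mha enabled and the query length is 1 (token-by-token decode); prefill, and decode with ragged attention disabled, fall back to dense attention over the full cache. The names attend, ragged_attention_fn, dense_attention_fn, and _Env are hypothetical stand-ins introduced for illustration and do not appear in this patch.

from dataclasses import dataclass


@dataclass
class _Env:
  """Hypothetical stand-in for the engine environment; only the flag used below."""

  ragged_mha: bool = True


def attend(
    xq,
    keys,
    values,
    mask,
    env,
    start=None,
    end=None,
    ragged_batch_index=None,
    ragged_block_index=None,
    ragged_attention_fn=None,
    dense_attention_fn=None,
):
  """Mirrors the branch structure of AttentionKernel.__call__ in this patch.

  ragged_attention_fn / dense_attention_fn are caller-supplied stand-ins for
  the kernels wired up in layers.py.
  """
  seqlen = xq.shape[-2]  # xq: (batch, num_heads, seqlen, head_dim)
  if env.ragged_mha and seqlen == 1:
    # Decode step: one new token per slot, so the ragged kernel can restrict
    # work to the cache blocks selected by the precomputed batch/block indices.
    output, _ = ragged_attention_fn(
        xq, keys, values, start, end, ragged_batch_index, ragged_block_index
    )
  else:
    # Prefill (seqlen > 1) or ragged attention disabled: dense attention over
    # the whole cache, restricted by the additive mask.
    output = dense_attention_fn(xq, keys, values, None, None, mask)
  if not env.ragged_mha and seqlen == 1:
    # The dense path keeps the full sequence axis; keep only the current step.
    output = output[:, :, 0:1, :]
  return output

The Int8KVAttentionKernel variant in the patch follows the same branch structure, additionally passing k_scaler and v_scaler to both the ragged and dense kernels.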