diff --git a/.pylintrc b/.pylintrc index 66a6589e..1df7809e 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,2 +1,2 @@ [MESSAGES CONTROL] -disable=C0114,R0801,E1102,W0613,R1711,too-many-locals +disable=C0114,R0801,R0903,R0913,E1102,W0613,R1711,too-many-locals diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py new file mode 100644 index 00000000..96bb4233 --- /dev/null +++ b/jetstream_pt/attention_kernel.py @@ -0,0 +1,347 @@ +import functools +import math + +import jax +import jax.numpy as jnp +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +from jax.experimental.shard_map import shard_map + +import torch +import torch.nn.functional as F + +import numpy as np + +DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) + + +def ragged_flash_attention_kernel( + start_ref, + end_ref, + line_end_ref, + pre_b_ref, + pre_i_ref, + q_ref, + k_ref, + v_ref, + k_scaler_ref, + v_scaler_ref, + o_ref, # outputs + m_ref, # row max + l_ref, # propogation coefficient + bk: int, + mask_value: float, + normalize_var: bool, + quantized: bool, +): + """Pallas kernel for flash attention.""" + with jax.named_scope("attention_kernel"): + b, i = pl.program_id(0), pl.program_id(1) + + @pl.when(i == 0) + def init(): + with jax.named_scope("init"): + m_ref[...] = jnp.full_like(m_ref, -jnp.inf) + l_ref[...] = jnp.zeros_like(l_ref) + o_ref[...] = jnp.zeros_like(o_ref) + + length = line_end_ref[b] + start = start_ref[b] + end = end_ref[b] + + @pl.when(jnp.logical_and(i * bk < length, start != end)) + def run(): + with jax.named_scope("run_qk"): + q = q_ref[...].astype(jnp.float32) + k = k_ref[...].astype(jnp.float32) + v = v_ref[...].astype(jnp.float32) + m_prev, l_prev = m_ref[...], l_ref[...] + + qk = jax.lax.dot_general( + q, k, (((1,), (1,)), ((), ())), preferred_element_type=jnp.float32 + ) + if normalize_var: + qk = qk / jnp.sqrt(k.shape[-1]) + if quantized: + qk = qk * k_scaler_ref[...] + with jax.named_scope("run_mask"): + start = start_ref[b] + end = end_ref[b] + iota = jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1) + mask_start_lt_end = jnp.logical_and( + i * bk + iota >= start, i * bk + iota < end + ).astype(jnp.int32) + mask_start_gt_end = jnp.logical_or( + i * bk + iota >= start, i * bk + iota < end + ).astype(jnp.int32) + # mask = jax.lax.cond(start <= end, lambda: mask_start_lt_end, lambda: mask_start_gt_end) + mask = jnp.where(start <= end, mask_start_lt_end, mask_start_gt_end) + + qk = qk + jnp.where(mask, 0.0, mask_value) + + with jax.named_scope("run_softmax"): + m_curr = qk.max(axis=-1) + + s_curr = jnp.exp(qk - m_curr[..., None]) + + l_curr = jax.lax.broadcast_in_dim( + s_curr.sum(axis=-1), l_prev.shape, (0,) + ) + if quantized: + s_curr = s_curr * v_scaler_ref[...] + o_curr_times_l_curr = jnp.dot(s_curr, v) + m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,)) + m_next = jnp.maximum(m_prev, m_curr) + alpha = jnp.exp(m_prev - m_next) + beta = jnp.exp(m_curr - m_next) + l_next = alpha * l_prev + beta * l_curr + l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) + + m_ref[...], l_ref[...] = m_next, l_next_safe + o_ref[...] = ( + (l_prev * alpha * o_ref[...] 
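Note (not part of the patch): a minimal NumPy sketch of the block-mask rule the kernel applies above. A slot's valid cache region is [start, end) when start <= end, and wraps around the circular cache (everything outside [end, start)) when start > end, matching mask_start_lt_end / mask_start_gt_end.

```python
import numpy as np

def block_mask(block_idx, bk, start, end):
  """Valid-position mask for one kernel block of a circular KV cache."""
  pos = block_idx * bk + np.arange(bk)      # absolute cache positions in this block
  if start <= end:
    return (pos >= start) & (pos < end)     # contiguous window
  return (pos >= start) | (pos < end)       # wrapped window

print(block_mask(0, 8, start=2, end=6))     # positions 2..5 valid
print(block_mask(0, 8, start=6, end=2))     # wrapped: positions 0,1 and 6,7 valid
```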
+ beta * o_curr_times_l_curr) + / l_next_safe + ).astype(o_ref.dtype) + + +@functools.partial( + jax.jit, static_argnames=["bk", "mask_value", "normalize_var"] +) +def ragged_mqa( + q: jax.Array, + k: jax.Array, + v: jax.Array, + start: jax.Array, + end: jax.Array, + k_scaler: jax.Array | None = None, + v_scaler: jax.Array | None = None, + ragged_batch_index=None, + ragged_block_index=None, + bk: int = 512, + mask_value: float = DEFAULT_MASK_VALUE, + normalize_var: bool = True, +) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: + """Ragged multi query attention.""" + with jax.named_scope("ragged_mqa"): + batch_size, num_heads, head_dim = q.shape + seq_len = k.shape[1] + + def kv_index_map( + b, + i, + start_ref, + end_ref, + line_end_ref, + ragged_batch_index_ref, + ragged_block_index_ref, + ): + index = b * (seq_len // bk) + i + return ragged_batch_index_ref[index], ragged_block_index_ref[index], 0 + + def q_index_map( + b, + i, + start_ref, + end_ref, + line_end_ref, + ragged_batch_index_ref, + ragged_block_index_ref, + ): + index = b * (seq_len // bk) + i + return ragged_batch_index_ref[index], 0, 0 + + def scaler_index_map(b, i, *_): + return b, 0, i + + line_end = jnp.where(start < end, end, seq_len - 1) + + in_specs = [ + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(kv_index_map, (None, bk, head_dim)), + pl.BlockSpec(kv_index_map, (None, bk, head_dim)), + ] + inputs = ( + start, + end, + line_end, + ragged_batch_index, + ragged_block_index, + q, + k, + v, + ) + quantized = False + if k_scaler is not None: + in_specs = in_specs + [ + pl.BlockSpec(scaler_index_map, (None, 1, bk)), + pl.BlockSpec(scaler_index_map, (None, 1, bk)), + ] + inputs = inputs + (k_scaler, v_scaler) + quantized = True + + out, m, l = pl.pallas_call( + functools.partial( + ragged_flash_attention_kernel, + bk=bk, + mask_value=mask_value, + normalize_var=normalize_var, + quantized=quantized, + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=5, + in_specs=in_specs, + out_specs=[ + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + ], + grid=(batch_size, seq_len // bk), + ), + compiler_params={"dimension_semantics": ("parallel", "arbitrary")}, + out_shape=[ + q, + jax.ShapeDtypeStruct( + (batch_size, num_heads, head_dim), jnp.float32 + ), + jax.ShapeDtypeStruct( + (batch_size, num_heads, head_dim), jnp.float32 + ), + ], + )(*inputs) + return out, (m[..., 0], l[..., 0]) + + +@functools.partial( + jax.jit, static_argnames=["bk", "mask_value", "normalize_var", "shard_axis"] +) +def ragged_mha( + q: jax.Array, + k: jax.Array, + v: jax.Array, + start: jax.Array, + end: jax.Array, + ragged_batch_index: jax.Array, + ragged_block_index: jax.Array, + k_scaler: jax.Array | None = None, + v_scaler: jax.Array | None = None, + bk: int = 512, + mask_value: float = DEFAULT_MASK_VALUE, + normalize_var: bool = True, + shard_axis: int = 1, +) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: + """Ragged multi head attention. + Args: + q: A [batch_size, compute_dim, num_heads, head_dim] jax.Array. + k: A [batch_size, num_heads, seq_len, head_dim] jax.Array or + PartitionQuantizedTensor. + v: A [batch_size, num_heads, seq_len, head_dim] jax.Array or + PartitionQuantizedTensor. + start: A i32[batch_size] jax.Array + end: A i32[batch_size] jax.Array + bk: An integer that is the sequence block size. + logit_cap: An optional float that caps logits via tanh. 
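Note (not part of the patch): the o/m/l update above is the standard streaming-softmax (flash-attention) recurrence. A small NumPy reference with the same block-wise accumulation, omitting the sqrt(head_dim) scaling, masking, and quantization scalers; it matches a full softmax attention.

```python
import numpy as np

def streaming_attention(q, k, v, bk):
  m = np.full((q.shape[0],), -np.inf)              # running row max
  l = np.zeros((q.shape[0],))                      # running softmax denominator
  o = np.zeros((q.shape[0], v.shape[-1]))          # running normalized output
  for s in range(0, k.shape[0], bk):
    qk = q @ k[s:s + bk].T
    m_curr = qk.max(axis=-1)
    p = np.exp(qk - m_curr[:, None])
    l_curr = p.sum(axis=-1)
    m_next = np.maximum(m, m_curr)
    alpha, beta = np.exp(m - m_next), np.exp(m_curr - m_next)
    l_next = alpha * l + beta * l_curr
    # Same shape of update as o_ref above: rescale the old partial output, add the new block.
    o = (l[:, None] * alpha[:, None] * o + beta[:, None] * (p @ v[s:s + bk])) / l_next[:, None]
    m, l = m_next, l_next
  return o

q, k, v = np.random.randn(4, 8), np.random.randn(32, 8), np.random.randn(32, 8)
s = np.exp(q @ k.T - (q @ k.T).max(-1, keepdims=True))
ref = (s / s.sum(-1, keepdims=True)) @ v
assert np.allclose(streaming_attention(q, k, v, bk=8), ref)
```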
By default there is + no logit capping. + mask_value: The value used for padding in attention. By default it is a very + negative floating point number. + out_dtype: An optional dtype for the output. If not provided, the output + dtype will be q's dtype. + Returns: + The output of attention([batch_size, num_heads, compute_dim, head_dim]), + along with the max logit ([batch_size, num_heads, compute_dim, 1]) and + softmax denominator ([batch_size, num_heads, compute_dim, 1]). + """ + mask_value = DEFAULT_MASK_VALUE + if k_scaler is None: + replicated_in_axes = 4 + replicated_inputs = (ragged_batch_index, ragged_block_index) + else: + replicated_in_axes = 6 + replicated_inputs = ( + jnp.squeeze(k_scaler, -1), + jnp.squeeze(v_scaler, -1), + ragged_batch_index, + ragged_block_index, + ) + + with jax.named_scope("ragged_mha_vmap"): + out, (m, l) = jax.vmap( + functools.partial( + ragged_mqa, + bk=bk, + mask_value=mask_value, + normalize_var=normalize_var, + # out_dtype=out_dtype, + ), + in_axes=( + shard_axis, + shard_axis, + shard_axis, + *([None] * replicated_in_axes), + ), + out_axes=shard_axis, + )(q, k, v, start, end, *replicated_inputs) + return out, (m, l) + + +def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None): + """The vanilla attention kernel implementation.""" + + bsz, _, _, head_dim = xq.shape + with jax.named_scope("attn_mat1"): + ## Attention start + # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim) + scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim) + if k_scaler is not None: + scores = scores * (k_scaler.reshape(bsz, 1, 1, keys.shape[2])) + if mask is not None: + # if mask.shape != (1,1,16,16): + # breakpoint() + scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen) + with jax.named_scope("attn_soft"): + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + if v_scaler is not None: + scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2])) + + with jax.named_scope("attn_mat2"): + # output = torch.einsum( + # "ikjm,ikml->ikjl", scores, values + # ) # (bs, n_local_heads, seqlen, head_dim) + output = torch.einsum("ikjm,ikml->ikjl", scores, values) + return output + + +class RaggedAttentionKernel: + """Ragged attention kernel.""" + + def __init__(self, env, input_specs, output_specs, sharding_axis): + self.binded_ragged_mha = functools.partial( + ragged_mha, bk=env.block_size, shard_axis=sharding_axis + ) + self.binded_ragged_mha = shard_map( + ragged_mha, env.mesh, input_specs, output_specs, check_rep=False + ) + self.binded_ragged_mha = jax.jit(self.binded_ragged_mha) + + def __call__( + self, + xq, + keys, + values, + start, + end, + ragged_batch_index, + ragged_block_index, + k_scaler=None, + v_scaler=None, + ): + return self.binded_ragged_mha( + xq, + keys, + values, + start, + end, + ragged_batch_index, + ragged_block_index, + k_scaler, + v_scaler, + ) diff --git a/jetstream_pt/config.py b/jetstream_pt/config.py index a4e391c3..bdf5fe41 100644 --- a/jetstream_pt/config.py +++ b/jetstream_pt/config.py @@ -58,6 +58,26 @@ lambda value: value in _VALID_QUANTIZATION_TYPE, f"quantize_type is invalid, supported quantization types are {_VALID_QUANTIZATION_TYPE}", ) +flags.DEFINE_bool( + "profiling_prefill", + False, + "Whether to profile the prefill, " + "if set to false, profile generate function only", + required=False, +) +flags.DEFINE_bool( + "ragged_mha", + False, + "Whether to enable Ragged multi head attention", + required=False, +) +flags.DEFINE_integer( + 
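Note (not part of the patch): ragged_mha reuses the per-head multi-query kernel by vmapping it over the head axis (shard_axis, 1 by default), with start/end and the precomputed ragged indices broadcast via in_axes=None. A toy sketch of the same pattern; the mqa stand-in below is a plain softmax attention, not the Pallas kernel.

```python
import jax
import jax.numpy as jnp

def mqa(q, k, v):                      # stand-in for ragged_mqa: q [B, 1, D], k/v [B, S, D]
  w = jax.nn.softmax(jnp.einsum("bqd,bsd->bqs", q, k) / q.shape[-1] ** 0.5, axis=-1)
  return jnp.einsum("bqs,bsd->bqd", w, v)

# vmap over axis 1 (the head axis), mirroring in_axes=(shard_axis, shard_axis, shard_axis, ...)
mha = jax.vmap(mqa, in_axes=(1, 1, 1), out_axes=1)
q = jnp.ones((2, 8, 1, 16))            # [batch, heads, 1, head_dim]
k = v = jnp.ones((2, 8, 128, 16))      # [batch, heads, seq_len, head_dim]
print(mha(q, k, v).shape)              # (2, 8, 1, 16)
```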
"starting_position", + 512, + "The starting position of decoding, " + "for performance tuning and debugging only", + required=False, +) def create_quantization_config_from_flags(): @@ -112,6 +132,8 @@ def create_engine_from_config_flags(): max_cache_length=FLAGS.max_cache_length, sharding_config=sharding_file_name, shard_on_batch=FLAGS.shard_on_batch, + ragged_mha=FLAGS.ragged_mha, + starting_position=FLAGS.starting_position, ) print("Initialize engine", time.perf_counter() - start) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index defa3e94..a709e90a 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -64,6 +64,7 @@ class DecodeState: ] # only present in quantized kv current_position: int lens: jax.Array # [batch_size, 1] + start: jax.Array # [batch_size, 1], the starting pos for each slot input_pos: jax.Array # [batch_size, 1] input pos for each slot mask: jax.Array # [batch_size, seqlen] -inf for invalid; 0 for valid @@ -126,8 +127,9 @@ def init_decode_state( jnp.zeros((self.env.batch_size, 1), dtype=jnp.int32), caches, scalers, - self.env.max_input_sequence_length, - jnp.zeros((self.env.batch_size, 1), dtype=jnp.int32), + self.env.starting_position, + jnp.zeros((self.env.batch_size, 1), dtype=jnp.int32), # lens + jnp.zeros((self.env.batch_size,), dtype=jnp.int32), # start pos jnp.zeros((self.env.batch_size,), dtype=jnp.int32), # input pos jnp.full( (self.env.batch_size, self.env.cache_sequence_length), @@ -145,7 +147,10 @@ def _call_model_generate( caches, cache_scales, mask, + start, input_pos, + ragged_batch_index, + ragged_block_index, ): if self.env.quant_config.enable_kv_quantization: caches_obj = [ @@ -163,7 +168,15 @@ def _call_model_generate( ] mask = jnp.expand_dims(mask, (1, 2)) - args = (tokens, input_pos, caches_obj, mask) + args = ( + tokens, + input_pos, + caches_obj, + mask, + start, + ragged_batch_index, + ragged_block_index, + ) paramst, argst = torchjax.to_torch((weights, args)) with self._lock: with torchjax.jax_mode: @@ -272,6 +285,9 @@ def _insert_no_wrap( cond = jnp.logical_and(x <= decode_state.current_position, x >= pos) mask_insert = jnp.where(cond, 0, float("-inf")) mask = decode_state.mask.at[slot].set(mask_insert) + start = decode_state.start.at[slot].set( + pos % self.env.cache_sequence_length + ) input_pos = decode_state.input_pos.at[slot].set(prefix.seq_len) if not self.env.quant_config.enable_kv_quantization: @@ -328,6 +344,7 @@ def insert(cache, scaler, new_entry): scales, decode_state.current_position, lens, + start, input_pos, mask, ) @@ -366,6 +383,7 @@ def _insert_wrap( mask_insert = jnp.where(cond, 0, float("-inf")) mask = decode_state.mask.at[slot].set(mask_insert) + start = decode_state.start.at[slot].set(start_insert) input_pos = decode_state.input_pos.at[slot].set(prefix.seq_len) old_caches = decode_state.caches @@ -420,6 +438,7 @@ def insert(cache, scaler, new_entry): scales, decode_state.current_position, lens, + start, input_pos, mask, ) @@ -448,6 +467,63 @@ def insert( slot, ) + def precompute_ragged_block_indices(self, decode_state: DecodeState): + """Precompute the ragged attention block indices. Ragged attention iterates the grid + and relies on the computed grid index to skip the unnecessary blocks. 
The basic idea + is to use input_pos, which is the length of each slot to determine if we should + work on the next block of the slot or move to the next slot.""" + start = decode_state.start + end = (start + decode_state.input_pos) % self.env.cache_len + batch_size = start.shape[0] + bk = self.env.block_size + # The batch index + b = jnp.arange(batch_size).reshape((batch_size, 1)) + num_bk = self.env.cache_len // self.env.block_size + # The block index + i = jnp.arange(num_bk).reshape((1, num_bk)) + i = jnp.broadcast_to(i, (batch_size, num_bk)) + + start = start.reshape((batch_size, 1)) + end = end.reshape((batch_size, 1)) + + am_last_batch = b == batch_size - 1 + last_good_block = jnp.where( + start < end, + jax.lax.div(end - 1, bk), + jax.lax.div(self.env.cache_len - 1, bk), + ) + + next_b = jnp.where(am_last_batch, b, b + 1) + next_i = jnp.where(am_last_batch, last_good_block, 0) + + # start < end, continue work on the block is there is overlap with the [start, end) + def true_comp(b, i, bk, start, end, next_b, next_i): + b_next = jnp.where(i * bk >= end, next_b, b) + i_next = jnp.where(i * bk >= end, next_i, i) + i_next = jnp.where((i + 1) * bk <= start, jax.lax.div(start, bk), i_next) + return b_next, i_next + + # start > end, continue work on the block is there is no overlap with [end, start) + def false_comp(b, i, bk, start, end): + b_next = b + i_next = jnp.where( + jnp.logical_and(i * bk >= end, (i + 1) * bk <= start), + jax.lax.div(start, bk), + i, + ) + return b_next, i_next + + true_comp_b, true_comp_i = true_comp(b, i, bk, start, end, next_b, next_i) + false_comp_b, false_comp_i = false_comp(b, i, bk, start, end) + + b_next = jnp.where( + start < end, true_comp_b, jnp.where(start == end, next_b, false_comp_b) + ) + i_next = jnp.where( + start < end, true_comp_i, jnp.where(start == end, next_i, false_comp_i) + ) + return b_next, i_next + def generate( self, params: Any, decode_state: DecodeState ) -> tuple[DecodeState, engine_api.ResultTokens]: @@ -457,6 +533,13 @@ def generate( # fill mask first mask = decode_state.mask.at[:, decode_state.current_position].set(0) + ragged_batch_index, ragged_block_index = ( + self.precompute_ragged_block_indices(decode_state) + ) + ragged_batch_index, ragged_block_index = ragged_batch_index.reshape( + (-1) + ), ragged_block_index.reshape((-1)) + logits, new_caches, new_scales = self._call_model_generate( params, decode_state.tokens, @@ -464,8 +547,12 @@ def generate( decode_state.caches, decode_state.cache_scales, mask, + decode_state.start, decode_state.input_pos, + ragged_batch_index, + ragged_block_index, ) + next_token = self._sampling(logits, self.env.batch_size) lens = decode_state.lens + 1 data = jnp.concatenate( @@ -493,6 +580,7 @@ def generate( new_scales, (decode_state.current_position + 1) % self.env.cache_sequence_length, lens, + decode_state.start, decode_state.input_pos + 1, mask, ) @@ -619,6 +707,7 @@ def get_decode_state_sharding(self) -> DecodeState: self.replicated, self.replicated, self.replicated, + self.replicated, ) def get_prefix_sequence_ddim(self) -> Any: @@ -666,6 +755,8 @@ def create_pytorch_engine( max_cache_length=1024, sharding_config=None, shard_on_batch=False, + ragged_mha=False, + starting_position=512, ) -> PyTorchEngine: """Returns: The pytorch engine.""" @@ -724,6 +815,8 @@ def create_pytorch_engine( bf16_enable=bf16_enable, sharding_config_path=sharding_config, shard_on_batch=shard_on_batch, + ragged_mha=ragged_mha, + starting_position=starting_position, ) if shard_on_batch and sharding_config: @@ -756,6 
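Note (not part of the patch, index values made up): the arrays returned by precompute_ragged_block_indices act as an indirection table over the (batch, block) grid. kv_index_map in ragged_mqa looks up index = b * num_blocks + i and fetches the redirected (batch, block) pair instead, so fully masked blocks are never loaded.

```python
import numpy as np

batch_size, num_blocks = 2, 4                    # num_blocks == cache_len // block_size
# Hypothetical flattened outputs of precompute_ragged_block_indices (reshape((-1)) in generate()):
ragged_batch_index = np.array([0, 0, 1, 1, 1, 1, 1, 1])
ragged_block_index = np.array([1, 2, 0, 0, 0, 1, 2, 3])

for b in range(batch_size):
  for i in range(num_blocks):
    idx = b * num_blocks + i                     # same flattening used by kv_index_map
    print((b, i), "->", (int(ragged_batch_index[idx]), int(ragged_block_index[idx])))
```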
+849,7 @@ def create_pytorch_engine( env_data.model_type = model_name + "-" + param_size env_data.num_layers = args.num_hidden_layers env = JetEngineEnvironment(env_data) + print(f"Enviroment variables: {vars(env)}") pt_model = gemma_model.GemmaModel(args, env) else: raise RuntimeError(f"Model with name {model_name} not found") diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py index b4df9980..5ea8f3a3 100644 --- a/jetstream_pt/environment.py +++ b/jetstream_pt/environment.py @@ -87,6 +87,15 @@ class JetEngineEnvironmentData: # Whether to shard on batch dimension. i.e. data parallel. shard_on_batch: bool = False + # Whether to enable ragged multi head attention. + ragged_mha: bool = False + + # The block size for the ragged attention. + block_size: int = 512 + + # Starting position + starting_position: int = 512 + # pylint: disable-next=all class JetEngineEnvironment: @@ -95,19 +104,22 @@ def __init__(self, data: JetEngineEnvironmentData): self._data = data self.seq_len = self._data.max_input_sequence_length - + self.cache_len = self._data.cache_sequence_length + self.ragged_mha = self._data.ragged_mha + self.block_size = self._data.block_size + self.starting_position = self._data.starting_position P = jax.sharding.PartitionSpec num_of_partitions = jax.device_count() # make mesh etc. - self._mesh = jsharding.Mesh( + self.mesh = jsharding.Mesh( mesh_utils.create_device_mesh((num_of_partitions, 1)), axis_names=("x", "y"), ) - self.y_sharding = jsharding.NamedSharding(self._mesh, P(None, "x")) - self.x_sharding = jsharding.NamedSharding(self._mesh, P("x")) - self.replicated = jsharding.NamedSharding(self._mesh, P()) + self.y_sharding = jsharding.NamedSharding(self.mesh, P(None, "x")) + self.x_sharding = jsharding.NamedSharding(self.mesh, P("x")) + self.replicated = jsharding.NamedSharding(self.mesh, P()) if data.shard_on_batch: cache_sharding_axis = 0 @@ -144,17 +156,19 @@ def apply_sharding(self, tensor, *, axis: int | None): # pylint: disable-next=all tensor._elem = jax.lax.with_sharding_constraint(tensor._elem, sharding_spec) - def sharding_by_axis(self, axis): + def partition_by_axis(self, axis=None): """return sharding partition spc by axis, options are x, y, -1 or Noe""" if axis == -1 or axis is None: - return jsharding.NamedSharding(self._mesh, jax.sharding.PartitionSpec()) + return jax.sharding.PartitionSpec() sharding = [None] * (axis + 1) sharding[axis] = "x" - sharding_spec = jsharding.NamedSharding( - self._mesh, jax.sharding.PartitionSpec(*sharding) - ) + sharding_spec = jax.sharding.PartitionSpec(*sharding) return sharding_spec + def sharding_by_axis(self, axis): + """return sharding partition spc by axis, options are x, y, -1 or Noe""" + return jsharding.NamedSharding(self.mesh, self.partition_by_axis(axis)) + def make_caches_prefill(self): """Create kv caches for inference prefill""" caches = [] diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index a44c1f3e..c5e305b8 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -15,10 +15,10 @@ # pylint: disable-all """This version contains modification to make it easier to trace and support batch.""" -import math from typing import Optional, Tuple import jax +from . import attention_kernel as ak import jax.numpy as jnp import torch import torch.nn.functional as F @@ -31,6 +31,7 @@ quantize_tensor, ) from torch import nn +from . 
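Note (not part of the patch): the environment refactor splits the old sharding_by_axis into a bare-PartitionSpec helper (partition_by_axis, consumable directly as shard_map input/output specs) plus a NamedSharding wrapper on the now-public mesh. A standalone sketch of what the new helper returns:

```python
from jax.sharding import PartitionSpec as P

def partition_by_axis(axis=None):
  if axis is None or axis == -1:
    return P()                         # replicated
  spec = [None] * (axis + 1)
  spec[axis] = "x"
  return P(*spec)

print(partition_by_axis())             # PartitionSpec()
print(partition_by_axis(1))            # PartitionSpec(None, 'x') -> shard the head axis over 'x'
```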
import attention_kernel as ak def _calc_cosine_dist(x, y): @@ -39,6 +40,9 @@ def _calc_cosine_dist(x, y): return (torch.dot(x, y) / (x.norm() * y.norm())).item() +import numpy as np + + class Int8Embedding(torch.nn.Module): def __init__(self, num_embeddings, embedding_dims, device="cpu"): @@ -399,8 +403,29 @@ class AttentionKernel: def __init__(self, env): self.env = env + self.shard_axis = 0 if self.env.shard_on_batch else 1 + qkv_pspec = self.env.partition_by_axis(self.shard_axis) # Number of heads + others_pspec = self.env.partition_by_axis() + self.dense_attention = ak.dense_attention + self.ragged_attention = ak.RaggedAttentionKernel( + env, + input_specs=(*([qkv_pspec] * 3), *([others_pspec] * 4)), + output_specs=(qkv_pspec, (others_pspec, others_pspec)), + sharding_axis=self.shard_axis, + ) - def __call__(self, xq, xk, xv, mask, cache): + def __call__( + self, + xq, + xk, + xv, + mask, + cache, + start=None, + end=None, + ragged_batch_index=None, + ragged_block_index=None, + ): """ Args: xq: torch.Tensor of (batch size, num_heads, seqlen, head_dim) @@ -412,35 +437,34 @@ def __call__(self, xq, xk, xv, mask, cache): bsz, num_heads, seqlen, head_dim = xq.shape _, num_kv_heads, _, kv_head_dim = xk.shape n_rep = num_heads // num_kv_heads - if seqlen == 1: + if not self.env.ragged_mha and seqlen == 1: xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) with jax.named_scope("attn_insert_cache"): keys, values = cache.update(xk, xv) keys = repeat_kv(keys, n_rep) values = repeat_kv(values, n_rep) - with jax.named_scope("attn_mat1"): - ## Attention start - # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim) - scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim) - if mask is not None: - # if mask.shape != (1,1,16,16): - # breakpoint() - scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen) - with jax.named_scope("attn_soft"): - scores = F.softmax(scores.float(), dim=-1).type_as(xq) - - with jax.named_scope("attn_mat2"): - # output = torch.einsum( - # "ikjm,ikml->ikjl", scores, values - # ) # (bs, n_local_heads, seqlen, head_dim) - output = torch.einsum("ikjm,ikml->ikjl", scores, values) - if seqlen == 1: + + with jax.named_scope("attn_qkv"): + if self.env.ragged_mha and seqlen == 1: + output, _ = torch_xla2.interop.call_jax( + self.ragged_attention, + xq, + keys, + values, + start, + end, + ragged_batch_index, + ragged_block_index, + ) + else: + output = self.dense_attention(xq, keys, values, None, None, mask) + + if not self.env.ragged_mha and seqlen == 1: output = output[:, :, 0:1, :] # For XLA matmul performance boost # output = torch.matmul(scores, values) - shard_axis = 0 if self.env.shard_on_batch else 1 - self.env.apply_sharding(output, axis=shard_axis) + self.env.apply_sharding(output, axis=self.shard_axis) return output @@ -448,8 +472,29 @@ class Int8KVAttentionKernel: def __init__(self, env): self.env = env + self.shard_axis = 0 if self.env.shard_on_batch else 1 + qkv_pspec = self.env.partition_by_axis(self.shard_axis) # Number of heads + others_pspec = self.env.partition_by_axis() + self.dense_attention = ak.dense_attention + self.ragged_attention = ak.RaggedAttentionKernel( + env, + input_specs=(*([qkv_pspec] * 3), *([others_pspec] * 6)), + output_specs=(qkv_pspec, (others_pspec, others_pspec)), + sharding_axis=self.shard_axis, + ) - def __call__(self, xq, xk, xv, mask, cache): + def __call__( + self, + xq, + xk, + xv, + mask, + cache, + start=None, + end=None, + ragged_batch_index=None, + 
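Note (not part of the patch): the seqlen == 1 broadcast/slice is kept only on the non-ragged decode path; per the existing comment it is an XLA matmul performance workaround and does not change the result. A standalone check using the same einsum contractions as dense_attention:

```python
import math
import torch

bsz, heads, head_dim, cache_len = 2, 4, 16, 32
xq = torch.randn(bsz, heads, 1, head_dim)
keys = torch.randn(bsz, heads, cache_len, head_dim)
values = torch.randn(bsz, heads, cache_len, head_dim)

xq2 = torch.broadcast_to(xq, (bsz, heads, 2, head_dim))          # pretend seqlen == 2
scores = torch.einsum("ikjl,ikml->ikjm", xq2, keys) / math.sqrt(head_dim)
out = torch.einsum("ikjm,ikml->ikjl", torch.softmax(scores, dim=-1), values)
out = out[:, :, 0:1, :]                                          # slice back to seqlen == 1

ref_scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim)
ref = torch.einsum("ikjm,ikml->ikjl", torch.softmax(ref_scores, dim=-1), values)
assert torch.allclose(out, ref, atol=1e-6)
```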
ragged_block_index=None, + ): """ Args: xq: torch.Tensor of (batch size, num_heads, seqlen, head_dim) @@ -461,37 +506,38 @@ def __call__(self, xq, xk, xv, mask, cache): bsz, num_heads, seqlen, head_dim = xq.shape _, num_kv_heads, _, kv_head_dim = xk.shape n_rep = num_heads // num_kv_heads - if seqlen == 1: + + if not self.env.ragged_mha and seqlen == 1: xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) with jax.named_scope("attn_insert_cache"): keys, values, k_scaler, v_scaler = cache.update(xk, xv) keys = repeat_kv(keys, n_rep) values = repeat_kv(values, n_rep) - with jax.named_scope("attn_mat1"): - ## Attention start - # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim) - scores = ( - torch.einsum("ikjl,ikml->ikjm", xq, keys) - / math.sqrt(head_dim) - * (k_scaler.reshape(bsz, 1, 1, keys.shape[2])) - ) - if mask is not None: - scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen) - with jax.named_scope("attn_soft"): - scores = F.softmax(scores.float(), dim=-1).type_as(xq) - scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2])) - - with jax.named_scope("attn_mat2"): - # output = torch.einsum( - # "ikjm,ikml->ikjl", scores, values - # ) # (bs, n_local_heads, seqlen, head_dim) - output = torch.einsum("ikjm,ikml->ikjl", scores, values) - if seqlen == 1: + + with jax.named_scope("attn_qkv"): + if self.env.ragged_mha and seqlen == 1: + output, _ = torch_xla2.interop.call_jax( + self.ragged_attention, + xq, + keys, + values, + start, + end, + ragged_batch_index, + ragged_block_index, + k_scaler, + v_scaler, + ) + else: + output = self.dense_attention( + xq, keys, values, k_scaler, v_scaler, mask + ) + + if not self.env.ragged_mha and seqlen == 1: output = output[:, :, 0:1, :] - # output = torch.matmul(scores, values) - shard_axis = 0 if self.env.shard_on_batch else 1 - self.env.apply_sharding(output, axis=shard_axis) + + self.env.apply_sharding(output, axis=self.shard_axis) return output @@ -566,6 +612,10 @@ def forward( freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], cache, + start=None, + end=None, + ragged_batch_index=None, + ragged_block_index=None, ): with jax.named_scope("attn_linear_before_cache"): bsz, seqlen = x.shape[0], x.shape[-2] @@ -593,6 +643,16 @@ def forward( xv = xv.transpose(1, 2) xq = xq.transpose(1, 2) - output = self.attention_kernel(xq, xk, xv, mask, cache) + output = self.attention_kernel( + xq, + xk, + xv, + mask, + cache, + start, + end, + ragged_batch_index, + ragged_block_index, + ).type_as(xq) output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1) return self.wo(output) diff --git a/jetstream_pt/third_party/gemma/model.py b/jetstream_pt/third_party/gemma/model.py index 8cd87b13..73d8e07e 100644 --- a/jetstream_pt/third_party/gemma/model.py +++ b/jetstream_pt/third_party/gemma/model.py @@ -135,6 +135,10 @@ def forward( freqs_cis, mask, cache, + start=None, + end=None, + ragged_batch_index=None, + ragged_block_index=None, ) -> torch.Tensor: hidden_states_shape = hidden_states.shape assert len(hidden_states_shape) == 3 @@ -164,7 +168,17 @@ def forward( xv = xv.transpose(1, 2) xq = xq.transpose(1, 2) - output = self.attention_kernel(xq, xk, xv, mask, cache) + output = self.attention_kernel( + xq, + xk, + xv, + mask, + cache, + start, + end, + ragged_batch_index, + ragged_block_index, + ) # [batch_size, input_len, hidden_dim] output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1) @@ -264,6 +278,10 @@ def forward( freqs_cis: torch.Tensor, cache: Any, 
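Note (not part of the patch): the quantized path passes k_scaler/v_scaler into the kernels instead of dequantizing the cache because, with one scale per cached position, scaling the scores (and the post-softmax weights) is equivalent to dequantizing the int8 keys/values first. A small check of the key-side identity:

```python
import torch

bsz, heads, cache_len, head_dim = 1, 2, 8, 4
q = torch.randn(bsz, heads, 1, head_dim)
k_int = torch.randint(-128, 128, (bsz, heads, cache_len, head_dim)).float()
k_scaler = torch.rand(bsz, 1, 1, cache_len)                      # one scale per cached position

dequant = torch.einsum("ikjl,ikml->ikjm", q, k_int * k_scaler.reshape(bsz, 1, cache_len, 1))
scaled = torch.einsum("ikjl,ikml->ikjm", q, k_int) * k_scaler    # what the kernels do
assert torch.allclose(dequant, scaled, atol=1e-5)
```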
mask: torch.Tensor, + start: torch.Tensor | None = None, + end: torch.Tensor | None = None, + ragged_batch_index: torch.Tensor | None = None, + ragged_block_index: torch.Tensor | None = None, ) -> torch.Tensor: # Self Attention residual = hidden_states @@ -273,6 +291,10 @@ def forward( freqs_cis=freqs_cis, mask=mask, cache=cache, + start=start, + end=end, + ragged_batch_index=ragged_batch_index, + ragged_block_index=ragged_block_index, ) hidden_states = residual + hidden_states @@ -321,7 +343,20 @@ def forward( input_pos: torch.Tensor, caches: List[Any], mask, + start=None, + ragged_batch_index=None, + ragged_block_index=None, ): + """ + tokens: the input token for decoding + caches: kv caches + mask: causal mask to filter the attention results + start: the starting position for each slot + input_pos: the decoding position relative to the start, which is the length of the decoding results + ragged_batch_index: precomputed batch index for ragged attention + ragged_block_index: precomputed block index for ragged attention + """ + with jax.named_scope("transformer_freq"): bsz, seqlen = tokens.shape freqs_cis = self.freqs_cis[input_pos] @@ -330,6 +365,8 @@ def forward( hidden_states = self.embedder(tokens) hidden_states = hidden_states * (self.config.hidden_size**0.5) + end = None if start is None else (start + input_pos) % self.env.cache_len + for i in range(len(self.layers)): layer = self.layers[i] hidden_states = layer( @@ -337,6 +374,10 @@ def forward( freqs_cis=freqs_cis, cache=caches[i], mask=mask, + start=start, + end=end, + ragged_batch_index=ragged_batch_index, + ragged_block_index=ragged_block_index, ) hidden_states = self.norm(hidden_states) diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py index 7c692b22..2385839e 100644 --- a/jetstream_pt/third_party/llama/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -109,10 +109,21 @@ def forward( freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], cache, + start=None, + end=None, + ragged_batch_index=None, + ragged_block_index=None, ): with jax.named_scope("Attention"): attn = self.attention.forward( - self.attention_norm(x), freqs_cis, mask, cache + self.attention_norm(x), + freqs_cis, + mask, + cache, + start, + end, + ragged_batch_index, + ragged_block_index, ) with jax.named_scope("ffn_norm"): h = x + attn @@ -183,7 +194,20 @@ def forward( input_pos: torch.Tensor, caches: List[Any], mask, + start=None, + ragged_batch_index=None, + ragged_block_index=None, ): + """ + tokens: the input token for decoding + caches: kv caches + mask: causal mask to filter the attention results + start: the starting position for each slot + input_pos: the decoding position relative to the start, which is the length of the decoding results + ragged_batch_index: precomputed batch index for ragged attention + ragged_block_index: precomputed block index for ragged attention + """ + with jax.named_scope("transformer_tok"): seqlen = tokens.shape[-1] h = self.tok_embeddings(tokens) @@ -196,9 +220,19 @@ def forward( assert len(caches) == len( self.layers ), f"Number of caches ({len(caches)}) and layers ({len(self.layers)}) dont match" + end = None if start is None else (start + input_pos) % self.env.cache_len for layer, cache in zip(self.layers, caches): with jax.named_scope("TransformerBlock"): - h = layer(h, freqs_cis, mask, cache) + h = layer( + h, + freqs_cis, + mask, + cache, + start, + end, + ragged_batch_index, + ragged_block_index, + ) with 
jax.named_scope("transformer_norm"): h = self.norm(h) diff --git a/run_interactive_multiple_host.py b/run_interactive_multiple_host.py index 66695bf5..d939b991 100644 --- a/run_interactive_multiple_host.py +++ b/run_interactive_multiple_host.py @@ -44,6 +44,7 @@ def create_engine(): max_cache_length=FLAGS.max_cache_length, sharding_config=FLAGS.sharding_config, shard_on_batch=FLAGS.shard_on_batch, + ragged_mha=FLAGS.ragged_mha, ) print("Initialize engine", time.perf_counter() - start) @@ -53,7 +54,7 @@ def create_engine(): # pylint: disable-next=all def main(argv): - engine = create_engine() + engine = create_engine_from_config_flags() start = time.perf_counter() engine.load_params()