Integrates ragged attention to JetStream Pytorch #93
Merged
41 commits
cf45d7f  Stable version of ragged attention. (wang2yn84)
d2bb514  Converts the attention output types the same as q. (wang2yn84)
8482117  Fixes the typo for the ragged attention. (wang2yn84)
4585ab4  Provides the default value for partition_by_axis. (wang2yn84)
1498ba9  Provides mesh to the shard_map. (wang2yn84)
81bfaa6  Fixes typo. (wang2yn84)
01d2eef  Fixes typo, should be start instead of start_pos. (wang2yn84)
5603879  Should use "//" instead of "/" to get int results. (wang2yn84)
2488297  Use block size // 2 as the starting current position for better initi… (wang2yn84)
f04b20a  Updates the run_interactive script to use the correct result token pr… (wang2yn84)
6aaf6d9  Fix typo, should use token_utils.process_result_token. (wang2yn84)
cd84291  Fix typo. (wang2yn84)
53240bc  Fixes the sampled tokens list. (wang2yn84)
ed368b5  Use text_tokens_to_str to convert the output tokens. (wang2yn84)
5264f11  Reshape the precomputed grid indices to 1D. Removes the (wang2yn84)
a4241d9  Should check if X is None instead of if X (wang2yn84)
00a8fa0  Fix the dense_attention not returning data. (wang2yn84)
4a26aed  Reshape the kv scaler to 3 dim for ragged attention. (wang2yn84)
7fdf340  Cannot stop the input_pos counter from increasing since we are using … (wang2yn84)
0721646  Adds starting_position and profiling_prefill for better testing and b… (wang2yn84)
930eaa0  Move flags in scripts to a common function (#92) (lsy323)
97c6435  Stable version of ragged attention. (wang2yn84)
6be5ec3  Fix the merge conflicts (wang2yn84)
6ae0f9d  Fixes the missing pieces after merging conflicts. Adds couple of new … (wang2yn84)
212aa8e  Integrates ragged attention to Gemma too. (wang2yn84)
ddb32e0  Somehow have some local changes to run_interactive, reverting them to… (wang2yn84)
fb68025  Set the default value for the newly added parameters. (wang2yn84)
2def37c  Adds more descriptions to the ragged attention index precompuation fu… (wang2yn84)
268c407  Merges the quantized ragged attention kernel with the non quantized v… (wang2yn84)
8fa8fcb  Moves the attention calculation to attention.py for better code struc… (wang2yn84)
98856f6  Fix run issues refactoring. (wang2yn84)
ab38726  Fix the quantized version for ragged attention. (wang2yn84)
a712862  Fix test_attention by adding default value for the newly added argume… (wang2yn84)
c50aba6  Fixes unit tests, changes the Transformer model call argument order(i… (wang2yn84)
d431803  Format attention_kernel.py (wang2yn84)
9061fac  Add descrpitions to ragged attention outputs. (wang2yn84)
a6059e9  Fix quantization tests by adding default value to quantization kernel… (wang2yn84)
a73658d  Reformat attention_kernel.py. Format with black doesn't comply with t… (wang2yn84)
9aaa7a6  Ignores R0913: Too many arguments link error for ragged attention ker… (wang2yn84)
1286bb5  Ignore R0903: Too few public methods. Fix lint errors. (wang2yn84)
f64fe6f  Fix the rest of the lint errors. (wang2yn84)
```diff
@@ -1,2 +1,2 @@
 [MESSAGES CONTROL]
-disable=C0114,R0801,E1102,W0613,R1711,too-many-locals
+disable=C0114,R0801,R0903,R0913,E1102,W0613,R1711,too-many-locals
```
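The two new suppressions match the lint-fix commits above: R0913 (too-many-arguments), triggered by the ragged attention kernel's long argument list, and R0903 (too-few-public-methods), most likely triggered by the new RaggedAttentionKernel wrapper class, which exposes only `__call__`.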
New file (347 lines added):

```python
import functools
import math

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.shard_map import shard_map

import torch
import torch.nn.functional as F

import numpy as np

DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)


def ragged_flash_attention_kernel(
    start_ref,
    end_ref,
    line_end_ref,
    pre_b_ref,
    pre_i_ref,
    q_ref,
    k_ref,
    v_ref,
    k_scaler_ref,
    v_scaler_ref,
    o_ref,  # outputs
    m_ref,  # row max
    l_ref,  # propogation coefficient
    bk: int,
    mask_value: float,
    normalize_var: bool,
    quantized: bool,
):
  """Pallas kernel for flash attention."""
  with jax.named_scope("attention_kernel"):
    b, i = pl.program_id(0), pl.program_id(1)

    @pl.when(i == 0)
    def init():
      with jax.named_scope("init"):
        m_ref[...] = jnp.full_like(m_ref, -jnp.inf)
        l_ref[...] = jnp.zeros_like(l_ref)
        o_ref[...] = jnp.zeros_like(o_ref)

    length = line_end_ref[b]
    start = start_ref[b]
    end = end_ref[b]

    @pl.when(jnp.logical_and(i * bk < length, start != end))
    def run():
      with jax.named_scope("run_qk"):
        q = q_ref[...].astype(jnp.float32)
        k = k_ref[...].astype(jnp.float32)
        v = v_ref[...].astype(jnp.float32)
        m_prev, l_prev = m_ref[...], l_ref[...]

        qk = jax.lax.dot_general(
            q, k, (((1,), (1,)), ((), ())), preferred_element_type=jnp.float32
        )
        if normalize_var:
          qk = qk / jnp.sqrt(k.shape[-1])
        if quantized:
          qk = qk * k_scaler_ref[...]
      with jax.named_scope("run_mask"):
        start = start_ref[b]
        end = end_ref[b]
        iota = jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1)
        mask_start_lt_end = jnp.logical_and(
            i * bk + iota >= start, i * bk + iota < end
        ).astype(jnp.int32)
        mask_start_gt_end = jnp.logical_or(
            i * bk + iota >= start, i * bk + iota < end
        ).astype(jnp.int32)
        # mask = jax.lax.cond(start <= end, lambda: mask_start_lt_end, lambda: mask_start_gt_end)
        mask = jnp.where(start <= end, mask_start_lt_end, mask_start_gt_end)

        qk = qk + jnp.where(mask, 0.0, mask_value)

      with jax.named_scope("run_softmax"):
        m_curr = qk.max(axis=-1)

        s_curr = jnp.exp(qk - m_curr[..., None])

        l_curr = jax.lax.broadcast_in_dim(
            s_curr.sum(axis=-1), l_prev.shape, (0,)
        )
        if quantized:
          s_curr = s_curr * v_scaler_ref[...]
        o_curr_times_l_curr = jnp.dot(s_curr, v)
        m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,))
        m_next = jnp.maximum(m_prev, m_curr)
        alpha = jnp.exp(m_prev - m_next)
        beta = jnp.exp(m_curr - m_next)
        l_next = alpha * l_prev + beta * l_curr
        l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next)

        m_ref[...], l_ref[...] = m_next, l_next_safe
        o_ref[...] = (
            (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr)
            / l_next_safe
        ).astype(o_ref.dtype)


@functools.partial(
    jax.jit, static_argnames=["bk", "mask_value", "normalize_var"]
)
def ragged_mqa(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    start: jax.Array,
    end: jax.Array,
    k_scaler: jax.Array | None = None,
    v_scaler: jax.Array | None = None,
    ragged_batch_index=None,
    ragged_block_index=None,
    bk: int = 512,
    mask_value: float = DEFAULT_MASK_VALUE,
    normalize_var: bool = True,
) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
  """Ragged multi query attention."""
  with jax.named_scope("ragged_mqa"):
    batch_size, num_heads, head_dim = q.shape
    seq_len = k.shape[1]

    def kv_index_map(
        b,
        i,
        start_ref,
        end_ref,
        line_end_ref,
        ragged_batch_index_ref,
        ragged_block_index_ref,
    ):
      index = b * (seq_len // bk) + i
      return ragged_batch_index_ref[index], ragged_block_index_ref[index], 0

    def q_index_map(
        b,
        i,
        start_ref,
        end_ref,
        line_end_ref,
        ragged_batch_index_ref,
        ragged_block_index_ref,
    ):
      index = b * (seq_len // bk) + i
      return ragged_batch_index_ref[index], 0, 0

    def scaler_index_map(b, i, *_):
      return b, 0, i

    line_end = jnp.where(start < end, end, seq_len - 1)

    in_specs = [
        pl.BlockSpec(q_index_map, (None, num_heads, head_dim)),
        pl.BlockSpec(kv_index_map, (None, bk, head_dim)),
        pl.BlockSpec(kv_index_map, (None, bk, head_dim)),
    ]
    inputs = (
        start,
        end,
        line_end,
        ragged_batch_index,
        ragged_block_index,
        q,
        k,
        v,
    )
    quantized = False
    if k_scaler is not None:
      in_specs = in_specs + [
          pl.BlockSpec(scaler_index_map, (None, 1, bk)),
          pl.BlockSpec(scaler_index_map, (None, 1, bk)),
      ]
      inputs = inputs + (k_scaler, v_scaler)
      quantized = True

    out, m, l = pl.pallas_call(
        functools.partial(
            ragged_flash_attention_kernel,
            bk=bk,
            mask_value=mask_value,
            normalize_var=normalize_var,
            quantized=quantized,
        ),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=5,
            in_specs=in_specs,
            out_specs=[
                pl.BlockSpec(q_index_map, (None, num_heads, head_dim)),
                pl.BlockSpec(q_index_map, (None, num_heads, head_dim)),
                pl.BlockSpec(q_index_map, (None, num_heads, head_dim)),
            ],
            grid=(batch_size, seq_len // bk),
        ),
        compiler_params={"dimension_semantics": ("parallel", "arbitrary")},
        out_shape=[
            q,
            jax.ShapeDtypeStruct(
                (batch_size, num_heads, head_dim), jnp.float32
            ),
            jax.ShapeDtypeStruct(
                (batch_size, num_heads, head_dim), jnp.float32
            ),
        ],
    )(*inputs)
    return out, (m[..., 0], l[..., 0])


@functools.partial(
    jax.jit, static_argnames=["bk", "mask_value", "normalize_var", "shard_axis"]
)
def ragged_mha(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    start: jax.Array,
    end: jax.Array,
    ragged_batch_index: jax.Array,
    ragged_block_index: jax.Array,
    k_scaler: jax.Array | None = None,
    v_scaler: jax.Array | None = None,
    bk: int = 512,
    mask_value: float = DEFAULT_MASK_VALUE,
    normalize_var: bool = True,
    shard_axis: int = 1,
) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
  """Ragged multi head attention.
  Args:
    q: A [batch_size, compute_dim, num_heads, head_dim] jax.Array.
    k: A [batch_size, num_heads, seq_len, head_dim] jax.Array or
      PartitionQuantizedTensor.
    v: A [batch_size, num_heads, seq_len, head_dim] jax.Array or
      PartitionQuantizedTensor.
    start: A i32[batch_size] jax.Array
    end: A i32[batch_size] jax.Array
    bk: An integer that is the sequence block size.
    logit_cap: An optional float that caps logits via tanh. By default there is
      no logit capping.
    mask_value: The value used for padding in attention. By default it is a very
      negative floating point number.
    out_dtype: An optional dtype for the output. If not provided, the output
      dtype will be q's dtype.
  Returns:
    The output of attention([batch_size, num_heads, compute_dim, head_dim]),
    along with the max logit ([batch_size, num_heads, compute_dim, 1]) and
    softmax denominator ([batch_size, num_heads, compute_dim, 1]).
  """
  mask_value = DEFAULT_MASK_VALUE
  if k_scaler is None:
    replicated_in_axes = 4
    replicated_inputs = (ragged_batch_index, ragged_block_index)
  else:
    replicated_in_axes = 6
    replicated_inputs = (
        jnp.squeeze(k_scaler, -1),
        jnp.squeeze(v_scaler, -1),
        ragged_batch_index,
        ragged_block_index,
    )

  with jax.named_scope("ragged_mha_vmap"):
    out, (m, l) = jax.vmap(
        functools.partial(
            ragged_mqa,
            bk=bk,
            mask_value=mask_value,
            normalize_var=normalize_var,
            # out_dtype=out_dtype,
        ),
        in_axes=(
            shard_axis,
            shard_axis,
            shard_axis,
            *([None] * replicated_in_axes),
        ),
        out_axes=shard_axis,
    )(q, k, v, start, end, *replicated_inputs)
  return out, (m, l)


def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None):
  """The vanilla attention kernel implementation."""

  bsz, _, _, head_dim = xq.shape
  with jax.named_scope("attn_mat1"):
    ## Attention start
    # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim)
    scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim)
    if k_scaler is not None:
      scores = scores * (k_scaler.reshape(bsz, 1, 1, keys.shape[2]))
    if mask is not None:
      # if mask.shape != (1,1,16,16):
      #   breakpoint()
      scores = scores + mask  # (bs, n_local_heads, seqlen, max_seqlen)
  with jax.named_scope("attn_soft"):
    scores = F.softmax(scores.float(), dim=-1).type_as(xq)
    if v_scaler is not None:
      scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2]))

  with jax.named_scope("attn_mat2"):
    # output = torch.einsum(
    #     "ikjm,ikml->ikjl", scores, values
    # )  # (bs, n_local_heads, seqlen, head_dim)
    output = torch.einsum("ikjm,ikml->ikjl", scores, values)
  return output


class RaggedAttentionKernel:
  """Ragged attention kernel."""

  def __init__(self, env, input_specs, output_specs, sharding_axis):
    self.binded_ragged_mha = functools.partial(
        ragged_mha, bk=env.block_size, shard_axis=sharding_axis
    )
    self.binded_ragged_mha = shard_map(
        ragged_mha, env.mesh, input_specs, output_specs, check_rep=False
    )
    self.binded_ragged_mha = jax.jit(self.binded_ragged_mha)

  def __call__(
      self,
      xq,
      keys,
      values,
      start,
      end,
      ragged_batch_index,
      ragged_block_index,
      k_scaler=None,
      v_scaler=None,
  ):
    return self.binded_ragged_mha(
        xq,
        keys,
        values,
        start,
        end,
        ragged_batch_index,
        ragged_block_index,
        k_scaler,
        v_scaler,
    )
```
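The Pallas kernel above walks a (batch, seq_len // bk) grid, and for each grid cell the q/kv index maps look up which batch row and which kv block to load from the prefetched ragged_batch_index and ragged_block_index arrays (the fourth and fifth scalar-prefetch inputs). As a rough illustration of the layout those arrays must have, here is a hypothetical dense (non-skipping) variant; it is not the precompute helper added elsewhere in this PR, which presumably remaps cells so kv blocks outside [start, end) are skipped:

```python
import numpy as np


def dense_grid_indices(batch_size: int, seq_len: int, bk: int):
  """Hypothetical stand-in for the PR's index-precompute helper.

  Shows only the expected layout: one (batch, block) pair per grid cell,
  looked up by the kernel's index maps as index = b * (seq_len // bk) + i.
  This dense mapping visits every kv block; the real helper can reorder or
  skip blocks based on start/end.
  """
  num_blocks = seq_len // bk
  ragged_batch_index = np.repeat(np.arange(batch_size), num_blocks)
  ragged_block_index = np.tile(np.arange(num_blocks), batch_size)
  return ragged_batch_index, ragged_block_index


# Example: batch_size=2, seq_len=1024, bk=512 gives two kv blocks per row:
#   ragged_batch_index == [0, 0, 1, 1]
#   ragged_block_index == [0, 1, 0, 1]
```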
Review comment: q, k, v and related names are easy to read, but what are b, i, o, m, l, bk, and pre? Can you add a brief description for each?
Reply: Added. Done.
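For readers of the merged diff, the short names the reviewer asked about can be summarized from the kernel body and its comments above (a reading aid, not new API):

```python
# ragged_flash_attention_kernel name guide (derived from the kernel above):
#   b, i                  pl.program_id(0) / pl.program_id(1): batch row and kv-block position in the grid
#   pre_b_ref, pre_i_ref  prefetched ragged batch / block indices (ragged_batch_index, ragged_block_index)
#   o_ref                 running attention output block
#   m_ref                 running row max of the logits (flash-attention state)
#   l_ref                 running softmax normalizer ("propogation coefficient" in the code comment)
#   bk                    kv sequence block size (env.block_size; defaults to 512 in ragged_mqa/ragged_mha)
```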