Merged
Changes from all commits
41 commits
cf45d7f
Stable version of ragged attention.
wang2yn84 May 16, 2024
d2bb514
Converts the attention output types the same as q.
wang2yn84 May 16, 2024
8482117
Fixes the typo for the ragged attention.
wang2yn84 May 17, 2024
4585ab4
Provides the default value for partition_by_axis.
wang2yn84 May 17, 2024
1498ba9
Provides mesh to the shard_map.
wang2yn84 May 17, 2024
81bfaa6
Fixes typo.
wang2yn84 May 17, 2024
01d2eef
Fixes typo, should be start instead of start_pos.
wang2yn84 May 17, 2024
5603879
Should use "//" instead of "/" to get int results.
wang2yn84 May 17, 2024
2488297
Use block size // 2 as the starting current position for better initi…
wang2yn84 May 17, 2024
f04b20a
Updates the run_interactive script to use the correct result token pr…
wang2yn84 May 17, 2024
6aaf6d9
Fix typo, should use token_utils.process_result_token.
wang2yn84 May 17, 2024
cd84291
Fix typo.
wang2yn84 May 17, 2024
53240bc
Fixes the sampled tokens list.
wang2yn84 May 17, 2024
ed368b5
Use text_tokens_to_str to convert the output tokens.
wang2yn84 May 17, 2024
5264f11
Reshape the precomputed grid indices to 1D. Removes the
wang2yn84 May 17, 2024
a4241d9
Should check if X is None instead of if X
wang2yn84 May 17, 2024
00a8fa0
Fix the dense_attention not returning data.
wang2yn84 May 17, 2024
4a26aed
Reshape the kv scaler to 3 dim for ragged attention.
wang2yn84 May 17, 2024
7fdf340
Cannot stop the input_pos counter from increasing since we are using …
wang2yn84 May 20, 2024
0721646
Adds starting_position and profiling_prefill for better testing and b…
wang2yn84 May 20, 2024
930eaa0
Move flags in scripts to a common function (#92)
lsy323 May 20, 2024
97c6435
Stable version of ragged attention.
wang2yn84 May 16, 2024
6be5ec3
Fix the merge conflicts
wang2yn84 May 21, 2024
6ae0f9d
Fixes the missing pieces after merging conflicts. Adds couple of new …
wang2yn84 May 21, 2024
212aa8e
Integrates ragged attention to Gemma too.
wang2yn84 May 21, 2024
ddb32e0
Somehow have some local changes to run_interactive, reverting them to…
wang2yn84 May 21, 2024
fb68025
Set the default value for the newly added parameters.
wang2yn84 May 21, 2024
2def37c
Adds more descriptions to the ragged attention index precompuation fu…
wang2yn84 May 22, 2024
268c407
Merges the quantized ragged attention kernel with the non quantized v…
wang2yn84 May 22, 2024
8fa8fcb
Moves the attention calculation to attention.py for better code struc…
wang2yn84 May 22, 2024
98856f6
Fix run issues refactoring.
wang2yn84 May 22, 2024
ab38726
Fix the quantized version for ragged attention.
wang2yn84 May 22, 2024
a712862
Fix test_attention by adding default value for the newly added argume…
wang2yn84 May 23, 2024
c50aba6
Fixes unit tests, changes the Transformer model call argument order(i…
wang2yn84 May 23, 2024
d431803
Format attention_kernel.py
wang2yn84 May 23, 2024
9061fac
Add descrpitions to ragged attention outputs.
wang2yn84 May 23, 2024
a6059e9
Fix quantization tests by adding default value to quantization kernel…
wang2yn84 May 23, 2024
a73658d
Reformat attention_kernel.py. Format with black doesn't comply with t…
wang2yn84 May 23, 2024
9aaa7a6
Ignores R0913: Too many arguments link error for ragged attention ker…
wang2yn84 May 23, 2024
1286bb5
Ignore R0903: Too few public methods. Fix lint errors.
wang2yn84 May 23, 2024
f64fe6f
Fix the rest of the lint errors.
wang2yn84 May 23, 2024
2 changes: 1 addition & 1 deletion .pylintrc
@@ -1,2 +1,2 @@
[MESSAGES CONTROL]
disable=C0114,R0801,E1102,W0613,R1711,too-many-locals
disable=C0114,R0801,R0903,R0913,E1102,W0613,R1711,too-many-locals
347 changes: 347 additions & 0 deletions jetstream_pt/attention_kernel.py
@@ -0,0 +1,347 @@
import functools
import math

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.shard_map import shard_map

import torch
import torch.nn.functional as F

import numpy as np

DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)


def ragged_flash_attention_kernel(
start_ref,  # [batch_size] start of each row's valid kv range
end_ref,  # [batch_size] end (exclusive) of each row's valid kv range
line_end_ref,  # [batch_size] end of the kv scan line (seq_len - 1 when the range wraps)
pre_b_ref,  # precomputed ragged batch index for each (batch, block) grid cell
pre_i_ref,  # precomputed ragged block index for each (batch, block) grid cell
Collaborator

q, k, v and related are easy to read. What are the b, i, o, m, l, bk and pre? Can you add brief description to describe them?

Collaborator Author

Added. Done.

q_ref,  # [num_heads, head_dim] query block
k_ref,  # [bk, head_dim] key block
v_ref,  # [bk, head_dim] value block
k_scaler_ref,  # key quantization scalers, used only when quantized
v_scaler_ref,  # value quantization scalers, used only when quantized
o_ref,  # output: attention result
m_ref,  # output: running row max
l_ref,  # output: running softmax denominator (propagation coefficient)
bk: int,  # kv sequence block size
mask_value: float,
normalize_var: bool,
quantized: bool,
):
"""Pallas kernel for flash attention."""
with jax.named_scope("attention_kernel"):
b, i = pl.program_id(0), pl.program_id(1)

@pl.when(i == 0)
def init():
with jax.named_scope("init"):
m_ref[...] = jnp.full_like(m_ref, -jnp.inf)
l_ref[...] = jnp.zeros_like(l_ref)
o_ref[...] = jnp.zeros_like(o_ref)

length = line_end_ref[b]
start = start_ref[b]
end = end_ref[b]

@pl.when(jnp.logical_and(i * bk < length, start != end))
def run():
with jax.named_scope("run_qk"):
q = q_ref[...].astype(jnp.float32)
k = k_ref[...].astype(jnp.float32)
v = v_ref[...].astype(jnp.float32)
m_prev, l_prev = m_ref[...], l_ref[...]

qk = jax.lax.dot_general(
q, k, (((1,), (1,)), ((), ())), preferred_element_type=jnp.float32
)
if normalize_var:
qk = qk / jnp.sqrt(k.shape[-1])
if quantized:
qk = qk * k_scaler_ref[...]
with jax.named_scope("run_mask"):
start = start_ref[b]
end = end_ref[b]
iota = jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1)
mask_start_lt_end = jnp.logical_and(
i * bk + iota >= start, i * bk + iota < end
).astype(jnp.int32)
mask_start_gt_end = jnp.logical_or(
i * bk + iota >= start, i * bk + iota < end
).astype(jnp.int32)
mask = jnp.where(start <= end, mask_start_lt_end, mask_start_gt_end)

qk = qk + jnp.where(mask, 0.0, mask_value)

with jax.named_scope("run_softmax"):
m_curr = qk.max(axis=-1)

s_curr = jnp.exp(qk - m_curr[..., None])

l_curr = jax.lax.broadcast_in_dim(
s_curr.sum(axis=-1), l_prev.shape, (0,)
)
if quantized:
s_curr = s_curr * v_scaler_ref[...]
o_curr_times_l_curr = jnp.dot(s_curr, v)
m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,))
m_next = jnp.maximum(m_prev, m_curr)
alpha = jnp.exp(m_prev - m_next)
beta = jnp.exp(m_curr - m_next)
l_next = alpha * l_prev + beta * l_curr
l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next)

m_ref[...], l_ref[...] = m_next, l_next_safe
o_ref[...] = (
(l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr)
/ l_next_safe
).astype(o_ref.dtype)
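
The m/l/o bookkeeping above is the standard flash-attention "online softmax": each kv block updates a running row max, a running denominator, and a running normalized output. As a reference, the short sketch below applies the same recurrence outside Pallas with plain jnp (the helper name online_softmax_update is illustrative, not part of this PR) and can be checked against a single full softmax.

import jax
import jax.numpy as jnp

def online_softmax_update(m_prev, l_prev, o_prev, qk_blk, v_blk):
  # One block step of the running softmax used by the kernel above:
  # m is the running row max, l the running denominator, o the normalized output.
  m_curr = qk_blk.max(axis=-1, keepdims=True)
  s_curr = jnp.exp(qk_blk - m_curr)
  l_curr = s_curr.sum(axis=-1, keepdims=True)
  o_curr_times_l_curr = s_curr @ v_blk
  m_next = jnp.maximum(m_prev, m_curr)
  alpha, beta = jnp.exp(m_prev - m_next), jnp.exp(m_curr - m_next)
  l_next = alpha * l_prev + beta * l_curr
  o_next = (l_prev * alpha * o_prev + beta * o_curr_times_l_curr) / l_next
  return m_next, l_next, o_next

q = jnp.ones((4, 8))
k = jnp.arange(128.0).reshape(16, 8) / 100.0
v = k + 1.0
qk = q @ k.T
m, l, o = jnp.full((4, 1), -jnp.inf), jnp.zeros((4, 1)), jnp.zeros((4, 8))
for blk in range(2):  # two kv blocks of 8 keys each, processed one at a time
  sl = slice(blk * 8, (blk + 1) * 8)
  m, l, o = online_softmax_update(m, l, o, qk[:, sl], v[sl])
reference = jax.nn.softmax(qk, axis=-1) @ v  # equals o up to float tolerance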


@functools.partial(
jax.jit, static_argnames=["bk", "mask_value", "normalize_var"]
)
def ragged_mqa(
q: jax.Array,
k: jax.Array,
v: jax.Array,
start: jax.Array,
end: jax.Array,
k_scaler: jax.Array | None = None,
v_scaler: jax.Array | None = None,
ragged_batch_index=None,
ragged_block_index=None,
bk: int = 512,
mask_value: float = DEFAULT_MASK_VALUE,
normalize_var: bool = True,
) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
"""Ragged multi query attention."""
with jax.named_scope("ragged_mqa"):
batch_size, num_heads, head_dim = q.shape
seq_len = k.shape[1]

def kv_index_map(
b,
i,
start_ref,
end_ref,
line_end_ref,
ragged_batch_index_ref,
ragged_block_index_ref,
):
index = b * (seq_len // bk) + i
return ragged_batch_index_ref[index], ragged_block_index_ref[index], 0

def q_index_map(
b,
i,
start_ref,
end_ref,
line_end_ref,
ragged_batch_index_ref,
ragged_block_index_ref,
):
index = b * (seq_len // bk) + i
return ragged_batch_index_ref[index], 0, 0

def scaler_index_map(b, i, *_):
return b, 0, i

line_end = jnp.where(start < end, end, seq_len - 1)

in_specs = [
pl.BlockSpec(q_index_map, (None, num_heads, head_dim)),
pl.BlockSpec(kv_index_map, (None, bk, head_dim)),
pl.BlockSpec(kv_index_map, (None, bk, head_dim)),
]
inputs = (
start,
end,
line_end,
ragged_batch_index,
ragged_block_index,
q,
k,
v,
)
quantized = False
if k_scaler is not None:
in_specs = in_specs + [
pl.BlockSpec(scaler_index_map, (None, 1, bk)),
pl.BlockSpec(scaler_index_map, (None, 1, bk)),
]
inputs = inputs + (k_scaler, v_scaler)
quantized = True

out, m, l = pl.pallas_call(
functools.partial(
ragged_flash_attention_kernel,
bk=bk,
mask_value=mask_value,
normalize_var=normalize_var,
quantized=quantized,
),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=5,
in_specs=in_specs,
out_specs=[
pl.BlockSpec(q_index_map, (None, num_heads, head_dim)),
pl.BlockSpec(q_index_map, (None, num_heads, head_dim)),
pl.BlockSpec(q_index_map, (None, num_heads, head_dim)),
],
grid=(batch_size, seq_len // bk),
),
compiler_params={"dimension_semantics": ("parallel", "arbitrary")},
out_shape=[
q,
jax.ShapeDtypeStruct(
(batch_size, num_heads, head_dim), jnp.float32
),
jax.ShapeDtypeStruct(
(batch_size, num_heads, head_dim), jnp.float32
),
],
)(*inputs)
return out, (m[..., 0], l[..., 0])
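
ragged_batch_index and ragged_block_index are precomputed outside this kernel and flattened to 1D so the index maps above can address them with b * (seq_len // bk) + i; the precomputation helper itself is not part of this file. A deliberately simplified sketch of such a precomputation, assuming start <= end and no cross-batch block skipping (the name precompute_ragged_indices is hypothetical, not the PR's helper), could look like this:

import jax.numpy as jnp

def precompute_ragged_indices(start, end, seq_len, bk):
  # For each grid cell (b, i), return which batch row and which kv block the
  # kernel should load. Inactive blocks are mapped back to block 0 of their own
  # row; the kernel's pl.when guard then skips the compute for them.
  batch_size = start.shape[0]
  num_blocks = seq_len // bk
  b_idx = jnp.repeat(jnp.arange(batch_size), num_blocks)  # flattened batch ids
  i_idx = jnp.tile(jnp.arange(num_blocks), batch_size)  # flattened block ids
  first = jnp.repeat(start // bk, num_blocks)
  last = jnp.repeat((end + bk - 1) // bk, num_blocks)  # exclusive, rounded up
  active = jnp.logical_and(i_idx >= first, i_idx < last)
  return b_idx, jnp.where(active, i_idx, 0)

start = jnp.array([0, 256], dtype=jnp.int32)
end = jnp.array([512, 768], dtype=jnp.int32)
ragged_batch_index, ragged_block_index = precompute_ragged_indices(
    start, end, seq_len=1024, bk=512
)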


@functools.partial(
jax.jit, static_argnames=["bk", "mask_value", "normalize_var", "shard_axis"]
)
def ragged_mha(
q: jax.Array,
k: jax.Array,
v: jax.Array,
start: jax.Array,
end: jax.Array,
ragged_batch_index: jax.Array,
ragged_block_index: jax.Array,
k_scaler: jax.Array | None = None,
v_scaler: jax.Array | None = None,
bk: int = 512,
mask_value: float = DEFAULT_MASK_VALUE,
normalize_var: bool = True,
shard_axis: int = 1,
) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
"""Ragged multi head attention.
Args:
q: A [batch_size, compute_dim, num_heads, head_dim] jax.Array.
k: A [batch_size, num_heads, seq_len, head_dim] jax.Array or
PartitionQuantizedTensor.
v: A [batch_size, num_heads, seq_len, head_dim] jax.Array or
PartitionQuantizedTensor.
start: A i32[batch_size] jax.Array
end: A i32[batch_size] jax.Array
bk: An integer that is the sequence block size.
logit_cap: An optional float that caps logits via tanh. By default there is
no logit capping.
mask_value: The value used for padding in attention. By default it is a very
negative floating point number.
out_dtype: An optional dtype for the output. If not provided, the output
dtype will be q's dtype.
Returns:
The output of attention([batch_size, num_heads, compute_dim, head_dim]),
along with the max logit ([batch_size, num_heads, compute_dim, 1]) and
softmax denominator ([batch_size, num_heads, compute_dim, 1]).
"""
mask_value = DEFAULT_MASK_VALUE
if k_scaler is None:
replicated_in_axes = 4
replicated_inputs = (ragged_batch_index, ragged_block_index)
else:
replicated_in_axes = 6
replicated_inputs = (
jnp.squeeze(k_scaler, -1),
jnp.squeeze(v_scaler, -1),
ragged_batch_index,
ragged_block_index,
)

with jax.named_scope("ragged_mha_vmap"):
out, (m, l) = jax.vmap(
functools.partial(
ragged_mqa,
bk=bk,
mask_value=mask_value,
normalize_var=normalize_var,
# out_dtype=out_dtype,
),
in_axes=(
shard_axis,
shard_axis,
shard_axis,
*([None] * replicated_in_axes),
),
out_axes=shard_axis,
)(q, k, v, start, end, *replicated_inputs)
return out, (m, l)
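
The multi-head wrapper is just a vmap of the single-head ragged_mqa over the head axis. The toy sketch below uses a stand-in per_head function instead of the Pallas kernel (so it runs anywhere), but shows the same in_axes=1 / out_axes=1 pattern that shard_axis selects here.

import jax
import jax.numpy as jnp

def per_head(q, k, v):
  # Stand-in for ragged_mqa: q is [batch, head_dim], k and v are [batch, seq, head_dim].
  scores = jnp.einsum("bd,bsd->bs", q, k) / jnp.sqrt(q.shape[-1])
  return jnp.einsum("bs,bsd->bd", jax.nn.softmax(scores, axis=-1), v)

batch, heads, seq_len, head_dim = 2, 4, 16, 8
q = jnp.ones((batch, heads, head_dim))
k = jnp.ones((batch, heads, seq_len, head_dim))
v = jnp.ones((batch, heads, seq_len, head_dim))
# Map over axis 1 (the head axis) of every argument and stack the per-head
# results back on axis 1, mirroring shard_axis=1 in ragged_mha.
mha = jax.vmap(per_head, in_axes=(1, 1, 1), out_axes=1)
out = mha(q, k, v)
print(out.shape)  # (2, 4, 8)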


def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None):
"""The vanilla attention kernel implementation."""

bsz, _, _, head_dim = xq.shape
with jax.named_scope("attn_mat1"):
## Attention start
scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim)
if k_scaler is not None:
scores = scores * (k_scaler.reshape(bsz, 1, 1, keys.shape[2]))
if mask is not None:
scores = scores + mask  # (bs, n_local_heads, seqlen, max_seqlen)
with jax.named_scope("attn_soft"):
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
if v_scaler is not None:
scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2]))

with jax.named_scope("attn_mat2"):
output = torch.einsum("ikjm,ikml->ikjl", scores, values)  # (bs, n_local_heads, seqlen, head_dim)
return output


class RaggedAttentionKernel:
"""Ragged attention kernel."""

def __init__(self, env, input_specs, output_specs, sharding_axis):
self.binded_ragged_mha = functools.partial(
ragged_mha, bk=env.block_size, shard_axis=sharding_axis
)
self.binded_ragged_mha = shard_map(
self.binded_ragged_mha, env.mesh, input_specs, output_specs, check_rep=False
)
self.binded_ragged_mha = jax.jit(self.binded_ragged_mha)

def __call__(
self,
xq,
keys,
values,
start,
end,
ragged_batch_index,
ragged_block_index,
k_scaler=None,
v_scaler=None,
):
return self.binded_ragged_mha(
xq,
keys,
values,
start,
end,
ragged_batch_index,
ragged_block_index,
k_scaler,
v_scaler,
)
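
For reference, a construction sketch for this wrapper (the SimpleNamespace env, the single-axis mesh, and the partition specs below are illustrative assumptions; the engine builds the real ones from its sharding config):

import types
import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec as P
from jetstream_pt.attention_kernel import RaggedAttentionKernel

env = types.SimpleNamespace(
    block_size=512,
    mesh=Mesh(np.array(jax.devices()), axis_names=("x",)),
)
shard_axis = 1  # shard q, k, v along the head dimension
sharded = P(None, "x")  # heads split across the "x" mesh axis
replicated = P()
# One spec per positional argument of __call__:
# q, k, v, start, end, ragged_batch_index, ragged_block_index, k_scaler, v_scaler
input_specs = (sharded, sharded, sharded, *([replicated] * 6))
# ragged_mha returns (out, (max_logit, denominator)).
output_specs = (sharded, (sharded, sharded))
ragged_attention = RaggedAttentionKernel(env, input_specs, output_specs, shard_axis)
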
22 changes: 22 additions & 0 deletions jetstream_pt/config.py
@@ -58,6 +58,26 @@
lambda value: value in _VALID_QUANTIZATION_TYPE,
f"quantize_type is invalid, supported quantization types are {_VALID_QUANTIZATION_TYPE}",
)
flags.DEFINE_bool(
"profiling_prefill",
False,
"Whether to profile the prefill, "
"if set to false, profile generate function only",
required=False,
)
flags.DEFINE_bool(
"ragged_mha",
False,
"Whether to enable Ragged multi head attention",
required=False,
)
flags.DEFINE_integer(
"starting_position",
512,
"The starting position of decoding, "
"for performance tuning and debugging only",
required=False,
)
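
For illustration, the new flags can be set like any other absl flag. A rough sketch, assuming a fresh process in which only jetstream_pt.config has been imported (values are arbitrary, and the usual model, checkpoint and tokenizer flags still have to be supplied before building an engine):

from absl import flags
from jetstream_pt import config  # registers the flags defined above

FLAGS = flags.FLAGS
# Equivalent to passing the options on a run script's command line.
FLAGS([
    "run_interactive",
    "--ragged_mha=True",
    "--starting_position=1024",
    "--profiling_prefill=True",
])
print(FLAGS.ragged_mha, FLAGS.starting_position, FLAGS.profiling_prefill)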


def create_quantization_config_from_flags():
@@ -112,6 +132,8 @@ def create_engine_from_config_flags():
max_cache_length=FLAGS.max_cache_length,
sharding_config=sharding_file_name,
shard_on_batch=FLAGS.shard_on_batch,
ragged_mha=FLAGS.ragged_mha,
starting_position=FLAGS.starting_position,
)

print("Initialize engine", time.perf_counter() - start)