diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py
index 38edc89..234e9fe 100644
--- a/jetstream_pt/attention_kernel.py
+++ b/jetstream_pt/attention_kernel.py
@@ -558,33 +558,6 @@ def ragged_mha(
   return out, (m, l)
 
 
-def _dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None):
-  """The vanilla attention kernel implementation."""
-
-  bsz, _, _, head_dim = xq.shape
-  with jax.named_scope("attn_mat1"):
-    ## Attention start
-    # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim)
-    scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim)
-    if k_scaler is not None:
-      scores = scores * (k_scaler.reshape(bsz, 1, 1, keys.shape[2]))
-    if mask is not None:
-      # if mask.shape != (1,1,16,16):
-      #   breakpoint()
-      scores = scores + mask  # (bs, n_local_heads, seqlen, max_seqlen)
-  with jax.named_scope("attn_soft"):
-    scores = F.softmax(scores.float(), dim=-1).type_as(xq)
-    if v_scaler is not None:
-      scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2]))
-
-  with jax.named_scope("attn_mat2"):
-    # output = torch.einsum(
-    #     "ikjm,ikml->ikjl", scores, values
-    # )  # (bs, n_local_heads, seqlen, head_dim)
-    output = torch.einsum("ikjm,ikml->ikjl", scores, values)
-  return output
-
-
 def reshape_heads(xq, keys):
   """Reshapes the query head for GQA"""
   bq, hq, tq, dq = xq.shape
@@ -607,6 +580,29 @@ def reshape_outputs(rep, o, m=None, d=None):
   return o, (m, d)
 
 
+def _dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None):
+  """The vanilla attention kernel implementation."""
+
+  bsz, _, _, head_dim = xq.shape
+  with jax.named_scope("attn_mat1"):
+    ## Attention start
+    scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim)
+    if k_scaler is not None:
+      scores = scores * (k_scaler.reshape(bsz, 1, 1, keys.shape[2]))
+    if mask is not None:
+      scores = scores + mask  # (bs, n_local_heads, seqlen, max_seqlen)
+  with jax.named_scope("attn_soft"):
+    scores = F.softmax(scores.float(), dim=-1).type_as(xq)
+    if v_scaler is not None:
+      scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2]))
+
+  with jax.named_scope("attn_mat2"):
+    output = torch.einsum(
+        "ikjm,ikml->ikjl", scores, values
+    )  # (bs, n_local_heads, seqlen, head_dim)
+  return output
+
+
 def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None):
   """The vanilla attention kernel implementation."""
   xq, rep = reshape_heads(xq, keys)
@@ -680,7 +676,14 @@ def flash_attention(
   """Flash attention kernel."""
   xq, rep = reshape_heads(xq, keys)
   o, (logits_max, denominator) = _flash_attention(
-      xq, keys, values, k_scaler, v_scaler, mask
+      xq=xq,
+      keys=keys,
+      values=values,
+      layer=layer,
+      k_scaler=k_scaler,
+      v_scaler=v_scaler,
+      mask=mask,
+      normalize_var=normalize_var,
   )
   return reshape_outputs(rep, o, logits_max, denominator)
diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py
index d66909d..1ef2fe5 100644
--- a/jetstream_pt/layers.py
+++ b/jetstream_pt/layers.py
@@ -433,12 +433,11 @@ def attend(xq, keys, values, local_mask=None):
       # When GQA is enabled, it not necessary to expand
       if not (self.env.ragged_mha and n_rep > 1) and seqlen == 1:
         true_len = 2
-        # xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3]))
         xq = torch.nn.functional.pad(
             xq, (0, 0, 0, true_len - seqlen), "constant", 0
         )
 
-      if self.env.ragged_mha and seqlen == 1:
+      if self.env.ragged_mha and seqlen == 1 and keys.shape[-2] > 1:
        local_output, (local_max, local_denom) = torch_xla2.interop.call_jax(
            impl,
            xq,
@@ -449,15 +448,28 @@ def attend(xq, keys, values, local_mask=None):
            end,
            ragged_batch_index,
            ragged_block_index,
+           None,  # k_scaler
+           None,  # v_scaler
        )
      elif self.env.flash_attention and seqlen == 1:
        with torch_xla2.default_env():
          local_output, (local_max, local_denom) = self.flash_attention(
-             xq, keys, values, self.layer_id, mask=local_mask
+             xq=xq,
+             keys=keys,
+             values=values,
+             layer=self.layer_id,
+             k_scaler=None,
+             v_scaler=None,
+             mask=local_mask,
          )
      else:
        local_output = self.dense_attention(
-           xq, keys, values, None, None, local_mask
+           xq=xq,
+           keys=keys,
+           values=values,
+           k_scaler=None,
+           v_scaler=None,
+           mask=local_mask,
        )
        local_max = None
        local_denom = None
@@ -474,9 +486,6 @@ def attend(xq, keys, values, local_mask=None):
        if local_denom is not None:
          local_denom = local_denom[:, :, 0:seqlen, :]
 
-      # print(f"attention kernel local_output {local_output.shape} seqlen {seqlen}")
-      # if local_max is not None and local_denom is not None:
-      #   print(f"local_max {local_max.shape} local_denom {local_denom.shape}")
      self.env.apply_sharding(local_output, axis=self.q_shard_axis)
      return local_output, (local_max, local_denom)
 
@@ -486,7 +495,7 @@ def attend(xq, keys, values, local_mask=None):
    # print(f"attention kernel xq {xq.shape} seqlen {seqlen} keys {keys.shape} mask {mask.shape}")
    with jax.named_scope("attn_qkv"):
      existing_output, (existing_max, existing_denom) = attend(
-         xq, orig_keys, orig_values, mask
+         xq=xq, keys=orig_keys, values=orig_values, local_mask=mask
      )
    # Updating cache during each step still has very large impact on latency.
    # For non flash attention or prefill, existing output contains everything
@@ -495,23 +504,20 @@ def attend(xq, keys, values, local_mask=None):
    # For flash attention, existing output contains the existing kv cache generated logits
    with jax.named_scope("attn_new_qkv"):
-     new_output, (new_max, new_denom) = attend(xq, xk, xv, None)
+     new_output, (new_max, new_denom) = attend(
+         xq=xq, keys=xk, values=xv, local_mask=None
+     )
 
    with jax.named_scope("attn_global"):
-     # print(f"existing_output {existing_output} existing_max {existing_max} existing_denom {existing_denom}")
-     # print(f"new_output {new_output} new_max {new_max} new_denom {new_denom}")
-
-     global_sum = existing_denom * torch.exp(
-         existing_max
-     ) + new_denom * torch.exp(new_max)
-     existing_output = (
-         existing_output
-         * existing_denom
-         * torch.exp(existing_max)
-         / global_sum
-     )
-     new_output = new_output * new_denom * torch.exp(new_max) / global_sum
-     attn_out = existing_output + new_output
+     global_max = torch.max(existing_max, new_max)
+     alpha = torch.exp(existing_max - global_max)
+     beta = torch.exp(new_max - global_max)
+     global_denom = alpha * existing_denom + beta * new_denom
+     # global_denom = torch.where(global_denom == 0.0, 1.0, global_denom)
+     attn_out = (
+         existing_denom * alpha * existing_output
+         + beta * new_output * new_denom
+     ) / global_denom
 
    return attn_out
 
@@ -588,8 +594,7 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None):
            xq, (0, 0, 0, true_len - seqlen), "constant", 0
        )
 
-      # We are not using ragged attention for prefill yet.
-      if self.env.ragged_mha and seqlen == 1:
+      if self.env.ragged_mha and seqlen == 1 and keys.shape[-2] > 1:
        local_output, (local_max, local_denom) = torch_xla2.interop.call_jax(
            impl,
            xq,
@@ -606,17 +611,22 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None):
      elif self.env.flash_attention and seqlen == 1:
        with torch_xla2.default_env():
          local_output, (local_max, local_denom) = self.flash_attention(
-             xq,
-             keys,
-             values,
-             self.layer_id,
-             k_scaler,
-             v_scaler,
+             xq=xq,
+             keys=keys,
+             values=values,
+             layer=self.layer_id,
+             k_scaler=k_scaler,
+             v_scaler=v_scaler,
              mask=local_mask,
          )
      else:
        local_output = self.dense_attention(
-           xq, keys, values, k_scaler, v_scaler, local_mask
+           xq=xq,
+           keys=keys,
+           values=values,
+           k_scaler=k_scaler,
+           v_scaler=v_scaler,
+           mask=local_mask,
        )
        local_max = None
        local_denom = None
@@ -648,7 +658,12 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None):
    ) = cache.update(xk, xv, self.layer_id)
    with jax.named_scope("attn_qkv"):
      existing_output, (existing_max, existing_denom) = attend(
-         xq, orig_keys, orig_values, k_scaler, v_scaler, mask
+         xq=xq,
+         keys=orig_keys,
+         values=orig_values,
+         k_scaler=k_scaler,
+         v_scaler=v_scaler,
+         local_mask=mask,
      )
 
    # For non flash attention or prefill, existing output contains everything
@@ -663,18 +678,15 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None):
      )
 
    with jax.named_scope("attn_global"):
-     global_sum = existing_denom * torch.exp(
-         existing_max
-     ) + new_denom * torch.exp(new_max)
-     existing_output = (
-         existing_output
-         * existing_denom
-         * torch.exp(existing_max)
-         / global_sum
-     )
-     new_output = new_output * new_denom * torch.exp(new_max) / global_sum
-     attn_out = existing_output + new_output
-
+     global_max = torch.max(existing_max, new_max)
+     alpha = torch.exp(existing_max - global_max)
+     beta = torch.exp(new_max - global_max)
+     global_denom = alpha * existing_denom + beta * new_denom
+     # global_denom = torch.where(global_denom == 0.0, 1.0, global_denom)
+     attn_out = (
+         existing_denom * alpha * existing_output
+         + beta * new_output * new_denom
+     ) / global_denom
    return attn_out
 
@@ -800,16 +812,16 @@ def forward(
    # if cache is not None and cache.cache_k is not None:
    #   print(f"xq {xq.shape} xk {xk.shape} cache shape {cache.cache_k.shape}")
    output = self.attention_kernel(
-       xq,
-       xk,
-       xv,
-       mask,
+       xq=xq,
+       xk=xk,
+       xv=xv,
+       mask=mask,
        # cache[self.layer_id],
-       cache,
-       start,
-       end,
-       ragged_batch_index,
-       ragged_block_index,
+       cache=cache,
+       start=start,
+       end=end,
+       ragged_batch_index=ragged_batch_index,
+       ragged_block_index=ragged_block_index,
    ).type_as(xq)
    # print(f"output {output.shape}")
    output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1)
diff --git a/tests/test_model_impl.py b/tests/test_model_impl.py
index 0b76c86..ff9ff84 100644
--- a/tests/test_model_impl.py
+++ b/tests/test_model_impl.py
@@ -18,6 +18,8 @@
 import torch
 import torch_xla2
 
+from absl.testing import parameterized
+
 from jetstream_pt.third_party.llama import model_exportable
 from jetstream_pt.third_party.llama import model_original
 from jetstream_pt.third_party.gemma import model_original as gemma_orig
@@ -32,7 +34,7 @@
 from . import helpers
 
 
-class ModelComponentTest(unittest.TestCase):
+class ModelComponentTest(parameterized.TestCase):
   """Test diff between original model and xla model for transformer, transformer block,
   attention and other component in model"""
 
@@ -75,7 +77,7 @@ def _generate_mask(self, cache_length, pos, seqlen, ring_buffer=True):
    if ring_buffer:
      cond = jnp.logical_and(x <= pos, x >= pos - seqlen)
    else:
-      # Left aligned buffer we postpone the cache update
+      # Left-aligned buffer: the cache update is postponed, so mask out pos
      cond = jnp.logical_and(x < pos, x >= pos - seqlen)
    res = jnp.where(cond, 0, float("-inf"))
    return torchjax.to_torch(res)
@@ -98,10 +100,33 @@ def _make_one_cache_for_generate(self, env, pos):
    )
    return cache_decode
 
+  @parameterized.named_parameters(
+      ("ring_buffer", "ring"),
+      ("non_ring_buffer_flash_attention", "flash"),
+      ("non_ring_buffer_ragged_attention", "ragged"),
+  )
  # pylint: disable-next=all
-  def test_attention(self):
+  def test_attention(self, attn_type):
    torch.manual_seed(0)
    env, model_arg = helpers.make_env_tiny(False)
+    if attn_type == "ring":
+      env.lazy_cache_update = False
+      env.ragged_mha = False
+      env.flash_attention = False
+      self.generate_cache_stacked = False
+      env.ring_buffer = True
+    elif attn_type == "flash":
+      env.lazy_cache_update = True
+      env.ragged_mha = False
+      env.flash_attention = True
+      self.generate_cache_stacked = True
+      env.ring_buffer = False
+    elif attn_type == "ragged":
+      env.lazy_cache_update = True
+      env.ragged_mha = True
+      env.flash_attention = True
+      self.generate_cache_stacked = True
+      env.ring_buffer = False
 
    attention_orig = model_original.Attention(model_arg)
    attention_ours = layers.Attention(
@@ -167,10 +192,14 @@ def test_attention(self):
    )
    expected_out = attention_orig(*inputs_orig2)
    cache_decode.input_pos = [pos]  # next position to update
-    mask = self._generate_mask(env.cache_sequence_length, pos, seqlen)
+    mask = self._generate_mask(
+        env.cache_sequence_length, pos, seqlen, env.ring_buffer
+    )
    mask = mask.reshape(1, 1, 1, -1)  # seq dim is the last one
    freqs_cis = freqs_cis.reshape(batch, 1, -1)
-    input_ours2 = (x2, freqs_cis, mask, cache_decode)
+    start = torch.tensor([0] * batch, dtype=torch.int)
+    end = torch.tensor([pos] * batch, dtype=torch.int)
+    input_ours2 = (x2, freqs_cis, mask, cache_decode, start, end)
    result_torch = helpers.call_xla_model(
        attention_ours, attention_orig.state_dict(), input_ours2
    )
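
Note on the attn_global change: both attend variants now merge the cached-KV partial and the new-token partial against a shared running maximum instead of multiplying each block by torch.exp(existing_max) and torch.exp(new_max) directly, which is the standard numerically stable way to combine flash-attention partials and avoids overflow when the raw maxima are large. The sketch below reproduces that merge on toy tensors and checks it against a single softmax over the concatenated keys; merge_partials and partial_attention are illustrative helpers written for this note, not functions from this repo, and the 1/sqrt(head_dim) scaling and quantization scalers are omitted.

# Standalone sketch of the merge used in attn_global (illustrative, not library code).
import torch


def merge_partials(o_a, m_a, d_a, o_b, m_b, d_b):
  """Merges two normalized attention partials, mirroring the attn_global math."""
  m_g = torch.max(m_a, m_b)        # shared running max
  alpha = torch.exp(m_a - m_g)     # rescales block A onto the shared max
  beta = torch.exp(m_b - m_g)      # rescales block B onto the shared max
  d_g = alpha * d_a + beta * d_b   # merged softmax denominator
  return (alpha * d_a * o_a + beta * d_b * o_b) / d_g


def partial_attention(q, k, v):
  """One attention block: normalized output plus its (max, denominator) stats."""
  s = q @ k.T
  m = s.max(dim=-1, keepdim=True).values
  p = torch.exp(s - m)
  d = p.sum(dim=-1, keepdim=True)
  return (p @ v) / d, m, d


torch.manual_seed(0)
q = torch.randn(1, 4)
k_a, v_a = torch.randn(5, 4), torch.randn(5, 2)  # "existing" cache block
k_b, v_b = torch.randn(3, 4), torch.randn(3, 2)  # "new token" block

o_a, m_a, d_a = partial_attention(q, k_a, v_a)
o_b, m_b, d_b = partial_attention(q, k_b, v_b)
merged = merge_partials(o_a, m_a, d_a, o_b, m_b, d_b)

# Reference: one softmax over the concatenated blocks.
scores = q @ torch.cat([k_a, k_b]).T
reference = torch.softmax(scores, dim=-1) @ torch.cat([v_a, v_b])
assert torch.allclose(merged, reference, atol=1e-6)

The commented-out torch.where guard in the diff corresponds to the case where both denominators are zero (a fully masked row); the sketch assumes each block has at least one unmasked key.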
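
Note on the test-side mask change: _generate_mask now receives env.ring_buffer because the left-aligned (non-ring) layout postpones the cache update, so position pos is not yet in the cache and must be masked out (x < pos instead of x <= pos); the current token is covered by the separate attn_new_qkv path. A small standalone illustration of the two conditions for a toy cache of length 8 follows; toy_mask is a hypothetical helper written for this note, not test code.

# Illustrative only: reproduces the two masking conditions from _generate_mask.
import jax.numpy as jnp


def toy_mask(cache_length, pos, seqlen, ring_buffer):
  x = jnp.arange(0, cache_length)
  if ring_buffer:
    # Ring buffer: the current position is already written, so keep it.
    cond = jnp.logical_and(x <= pos, x >= pos - seqlen)
  else:
    # Left-aligned buffer with lazy cache update: pos is not in the cache yet,
    # so it is masked out here and attended through the new-token path.
    cond = jnp.logical_and(x < pos, x >= pos - seqlen)
  return jnp.where(cond, 0, float("-inf"))


print(toy_mask(8, pos=3, seqlen=1, ring_buffer=True))   # keeps slots 2 and 3
print(toy_mask(8, pos=3, seqlen=1, ring_buffer=False))  # keeps slot 2 only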