diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py
index 6d571d2..38edc89 100644
--- a/jetstream_pt/attention_kernel.py
+++ b/jetstream_pt/attention_kernel.py
@@ -558,7 +558,7 @@ def ragged_mha(
  return out, (m, l)


-def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None):
+def _dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None):
  """The vanilla attention kernel implementation."""

  bsz, _, _, head_dim = xq.shape
@@ -585,7 +585,37 @@ def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None):
  return output


-def flash_attention(
+def reshape_heads(xq, keys):
+  """Reshapes the query head for GQA"""
+  bq, hq, tq, dq = xq.shape
+  hkv = keys.shape[-3]
+  rep = hq // hkv
+  if rep > 1:
+    xq = xq.reshape(bq, hkv, rep, tq, dq).reshape(bq, hkv, rep * tq, dq)
+  return xq, rep
+
+
+def reshape_outputs(rep, o, m=None, d=None):
+  """Reshapes back the attention output for GQA"""
+  bq, hqo, tqo, dq = o.shape
+  tq = tqo // rep
+  hq = hqo * rep
+  o = o.reshape(bq, hqo, rep, tq, dq).reshape(bq, hq, tq, dq)
+  if m is not None and d is not None:
+    m = m.reshape(bq, hqo, rep, tq, 1).reshape(bq, hq, tq, 1)
+    d = d.reshape(bq, hqo, rep, tq, 1).reshape(bq, hq, tq, 1)
+  return o, (m, d)
+
+
+def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None):
+  """The vanilla attention kernel implementation."""
+  xq, rep = reshape_heads(xq, keys)
+  output = _dense_attention(xq, keys, values, k_scaler, v_scaler, mask)
+  output, _ = reshape_outputs(rep, output)
+  return output
+
+
+def _flash_attention(
    xq,
    keys,
    values,
@@ -637,6 +667,24 @@ def flash_attention(
  return o, (logits_max, denominator)


+def flash_attention(
+    xq,
+    keys,
+    values,
+    layer,
+    k_scaler=None,
+    v_scaler=None,
+    mask=None,
+    normalize_var=True,
+):
+  """Flash attention kernel."""
+  xq, rep = reshape_heads(xq, keys)
+  o, (logits_max, denominator) = _flash_attention(
+      xq, keys, values, k_scaler, v_scaler, mask
+  )
+  return reshape_outputs(rep, o, logits_max, denominator)
+
+
 class RaggedAttentionKernel:
  """Ragged attention kernel."""
diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py
index 79dfb94..c859fe8 100644
--- a/jetstream_pt/engine.py
+++ b/jetstream_pt/engine.py
@@ -216,11 +216,12 @@ def _call_model_prefill(self, weights, tokens, input_indexes):
        for _ in self.pt_model.layers
    ]
    mask = jnp.full(
-        (1, 1, tokens.shape[1], tokens.shape[1]),
+        (1, self.env.n_reps, tokens.shape[1], tokens.shape[1]),
        float("-inf"),
        dtype=self.default_dtype,
    )
    mask = jnp.triu(mask, k=1)
+    mask = mask.reshape(1, 1, -1, tokens.shape[1])
    start = jnp.zeros((tokens.shape[0],), dtype=jnp.int32)
    args = (tokens, input_indexes, caches, mask, start)
@@ -970,6 +971,7 @@ def create_pytorch_engine(
    )
    env_data.model_type = model_name + "-" + param_size
    env_data.num_layers = args.n_layers
+    env_data.n_reps = args.n_heads // args.n_kv_heads
    env = JetEngineEnvironment(env_data)
    pt_model = llama_model.Transformer(args, env)
  elif model_name == "gemma":
@@ -982,6 +984,7 @@ def create_pytorch_engine(
    )
    env_data.model_type = model_name + "-" + param_size
    env_data.num_layers = args.num_hidden_layers
+    env_data.n_reps = args.num_attention_heads // args.num_key_value_heads
    env = JetEngineEnvironment(env_data)
    print(f"Enviroment variables: {vars(env)}")
    pt_model = gemma_model.GemmaModel(args, env)
@@ -995,6 +998,7 @@ def create_pytorch_engine(
        args.dim // args.n_head,
    )
    env_data.num_layers = args.n_layer
+    env_data.n_reps = args.n_head // args.n_local_heads
    env = JetEngineEnvironment(env_data)
    pt_model = mixtral_model.Transformer(args, env)
  else:
diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py
index 84289d9..fad4472 100644
--- a/jetstream_pt/environment.py
+++ b/jetstream_pt/environment.py
@@ -121,10 +121,13 @@ class JetEngineEnvironmentData:
  testing_seed: int = 0

+  # The ratio between query heads and kv heads
+  n_reps: int = 0
+

# pylint: disable-next=all
class JetEngineEnvironment:
-
+  # pylint: disable-next=all
  def __init__(self, data: JetEngineEnvironmentData):
    self._data = data
@@ -144,6 +147,12 @@ def __init__(self, data: JetEngineEnvironmentData):
      self.flash_attention = True
      self.generate_cache_stacked = True
      self.new_cache_stacked = True
+    else:
+      self.lazy_cache_update = False
+      self.ragged_mha = False
+      self.flash_attention = False
+      self.generate_cache_stacked = False
+      self.new_cache_stacked = False

    if self.testing:
      self.lazy_cache_update = self._data.lazy_cache_update
diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py
index d484df9..d66909d 100644
--- a/jetstream_pt/layers.py
+++ b/jetstream_pt/layers.py
@@ -365,29 +365,6 @@ def apply_rotary_emb(
  return xq_out.type_as(xq), xk_out.type_as(xk)


-def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
-  """torch.repeat_interleave(x, dim=2, repeats=n_rep)."""
-
-  *_, bs, n_kv_heads, slen, head_dim = x.shape
-  stacked = x.ndim == 5
-
-  if n_rep == 1:
-    return x
-
-  if stacked:
-    layer = x.shape[0]
-    return (
-        x[:, :, :, None, :, :]
-        .expand(layer, bs, n_kv_heads, n_rep, slen, head_dim)
-        .reshape(layer, bs, n_kv_heads * n_rep, slen, head_dim)
-    )
-  return (
-      x[:, :, None, :, :]
-      .expand(bs, n_kv_heads, n_rep, slen, head_dim)
-      .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
-  )
-
-
class AttentionKernel:

  def __init__(self, env, layer_id):
@@ -473,7 +450,7 @@ def attend(xq, keys, values, local_mask=None):
            ragged_batch_index,
            ragged_block_index,
        )
-      elif self.env.flash_attention:
+      elif self.env.flash_attention and seqlen == 1:
        with torch_xla2.default_env():
          local_output, (local_max, local_denom) = self.flash_attention(
              xq, keys, values, self.layer_id, mask=local_mask
          )
@@ -505,10 +482,6 @@ def attend(xq, keys, values, local_mask=None):

    with jax.named_scope("attn_insert_cache"):
      orig_keys, orig_values = cache.update(xk, xv, self.layer_id)
-      # We are not using ragged attention for prefill yet.
-      if not self.env.ragged_mha or seqlen > 1:
-        orig_keys = repeat_kv(orig_keys, n_rep)
-        orig_values = repeat_kv(orig_values, n_rep)
      # print(f"attention kernel xq {xq.shape} seqlen {seqlen} keys {keys.shape} mask {mask.shape}")

    with jax.named_scope("attn_qkv"):
@@ -522,9 +495,6 @@ def attend(xq, keys, values, local_mask=None):

    # For flash attention, existing output contains the existing kv cache generated logits
    with jax.named_scope("attn_new_qkv"):
-      if not self.env.ragged_mha or seqlen > 1:
-        xk = repeat_kv(xk, n_rep)
-        xv = repeat_kv(xv, n_rep)
      new_output, (new_max, new_denom) = attend(xq, xk, xv, None)

    with jax.named_scope("attn_global"):
@@ -633,7 +603,7 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None):
            k_scaler,
            v_scaler,
        )
-      elif self.env.flash_attention:
+      elif self.env.flash_attention and seqlen == 1:
        with torch_xla2.default_env():
          local_output, (local_max, local_denom) = self.flash_attention(
              xq,
@@ -676,10 +646,6 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None):
          new_k_scaler,
          new_v_scaler,
      ) = cache.update(xk, xv, self.layer_id)
-      # We are not using ragged attention for prefill yet.
-      if not self.env.ragged_mha or seqlen > 1:
-        orig_keys = repeat_kv(orig_keys, n_rep)
-        orig_values = repeat_kv(orig_values, n_rep)
    with jax.named_scope("attn_qkv"):
      existing_output, (existing_max, existing_denom) = attend(
          xq, orig_keys, orig_values, k_scaler, v_scaler, mask
@@ -692,9 +658,6 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None):
    # For flash attention, existing output contains the existing kv cache generated logits
    with jax.named_scope("attn_new_qkv"):
      # At this point, flash attention or ragged attention must have been enabled
-      if not self.env.ragged_mha or seqlen > 1:
-        new_key = repeat_kv(new_key, n_rep)
-        new_value = repeat_kv(new_value, n_rep)
      new_output, (new_max, new_denom) = attend(
          xq, new_key, new_value, new_k_scaler, new_v_scaler, None
      )
diff --git a/tests/helpers.py b/tests/helpers.py
index 3c5cb4e..09d718a 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -27,6 +27,7 @@ def make_env_tiny(bf16_enable=True, env_data_update_fn=lambda _: None):
      environment_data.cache_sequence_length,
      config.dim // config.n_heads,
  )
+  environment_data.n_reps = config.n_heads // config.n_kv_heads
  environment_data.testing = True
  env_data_update_fn(environment_data)
  env = environment.JetEngineEnvironment(environment_data)
@@ -54,6 +55,7 @@ def make_mixtral_env(bf16_enable=True):
      environment_data.cache_sequence_length,
      config.dim // config.n_head,
  )
+  environment_data.n_reps = config.n_head // config.n_local_heads
  env = environment.JetEngineEnvironment(environment_data)
  env.apply_sharding = lambda *args, **kwargs: None  # don't shard on cpu
  return env, config
diff --git a/tests/test_model_impl.py b/tests/test_model_impl.py
index efbaa09..0b76c86 100644
--- a/tests/test_model_impl.py
+++ b/tests/test_model_impl.py
@@ -42,17 +42,21 @@ def setUp(self):
    jax.config.update("jax_enable_x64", False)
    torch.set_default_dtype(torch.float32)

-  def _prefill_mask(self, seqlen, start_pos):
+  def _prefill_mask(self, seqlen, start_pos, env):
    mask = torch.full((seqlen, seqlen), float("-inf"))
    mask = torch.triu(mask, diagonal=1)
-
    # When performing key-value caching, we compute the attention scores
    # only for the new sequence. Thus, the matrix of scores is of size
    # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
    # j > cache_len + i, since row i corresponds to token cache_len + i.
-    mask = torch.hstack([torch.zeros((seqlen, start_pos)), mask])
-    return mask
+    orig_mask = torch.hstack([torch.zeros((seqlen, start_pos)), mask])
+
+    mask = mask.repeat((env.n_reps, 1))
+    our_mask = torch.hstack(
+        [torch.zeros((seqlen * env.n_reps, start_pos)), mask]
+    )
+    return orig_mask, our_mask

  def _make_freqs_cis(self, model_arg, seqlen, start_pos):
    freqs_cis = model_original.precompute_freqs_cis(
@@ -117,8 +121,8 @@ def test_attention(self):
    )  # (batch, seqlen, embedding dim)
    start_pos = 0
    freqs_cis = self._make_freqs_cis(model_arg, seqlen, start_pos)
-    mask = self._prefill_mask(seqlen, start_pos)
-    inputs_orig = (x, start_pos, freqs_cis, mask)
+    orig_mask, our_mask = self._prefill_mask(seqlen, start_pos, env)
+    inputs_orig = (x, start_pos, freqs_cis, orig_mask)

    expected_out = attention_orig(*inputs_orig)
@@ -127,7 +131,7 @@ def test_attention(self):
    input_ours = (
        x,
        freqs_cis,
-        mask,
+        our_mask,
        cache,
    )
@@ -236,11 +240,17 @@ def load_hook(state_dict, prefix, *args):
    )  # (batch, seqlen, embedding dim)
    start_pos = 0
    freqs_cis = self._make_freqs_cis(model_arg, seqlen, start_pos)
-    mask = self._prefill_mask(seqlen, start_pos)
+    orig_mask, our_mask = self._prefill_mask(seqlen, start_pos, env)
    kv_write_indexes = torch.arange(0, seqlen)
    cache_k = torch.zeros((batch, seqlen, num_kv_heads, head_dim))
    cache_v = torch.zeros((batch, seqlen, num_kv_heads, head_dim))
-    inputs_orig = (x, freqs_cis, kv_write_indexes, (cache_k, cache_v), mask)
+    inputs_orig = (
+        x,
+        freqs_cis,
+        kv_write_indexes,
+        (cache_k, cache_v),
+        orig_mask,
+    )

    expected_out = attention_orig(*inputs_orig)
@@ -249,7 +259,7 @@ def load_hook(state_dict, prefix, *args):
    input_ours = (
        x,
        freqs_cis,
-        mask,
+        our_mask,
        cache,
    )
@@ -284,8 +294,8 @@ def test_transformer_block(self):
    )  # (batch, seqlen, embedding dim)
    start_pos = 0
    freqs_cis = self._make_freqs_cis(model_arg, seqlen, start_pos)
-    mask = self._prefill_mask(seqlen, start_pos)
-    inputs_orig = (x, start_pos, freqs_cis, mask)
+    orig_mask, our_mask = self._prefill_mask(seqlen, start_pos, env)
+    inputs_orig = (x, start_pos, freqs_cis, orig_mask)

    expected_out = block_orig(*inputs_orig)
@@ -294,7 +304,7 @@ def test_transformer_block(self):
    input_ours = (
        x,
        freqs_cis,
-        mask,
+        our_mask,
        cache,
    )
@@ -356,7 +366,7 @@ def test_transformer(self):
    seqlen = 32
    x = torch.randint(0, 32000, (1, seqlen))  # (batch, seqlen, embedding dim)
    start_pos = 0
-    mask = self._prefill_mask(seqlen, start_pos)
+    _, our_mask = self._prefill_mask(seqlen, start_pos, env)
    inputs_orig = (x, start_pos)

    expected_out = model_orig(*inputs_orig)
@@ -367,7 +377,7 @@ def test_transformer(self):
        x,
        input_pos,
        caches,
-        mask,
+        our_mask,
    )

    result_torch = helpers.call_xla_model(model_ours, state_dict, input_ours)
@@ -417,7 +427,7 @@ def test_mixtral_transformer(self):
    seqlen = 32
    x = torch.randint(0, 32000, (1, seqlen))  # (batch, seqlen, embedding dim)
    start_pos = 0
-    mask = self._prefill_mask(seqlen, start_pos)
+    _, our_mask = self._prefill_mask(seqlen, start_pos, env)
    input_pos = torch.arange(0, seqlen)
    inputs_orig = (x, input_pos)
@@ -430,7 +440,7 @@ def test_mixtral_transformer(self):
        x,
        input_pos,
        caches,
-        mask,
+        our_mask,
    )

    result_torch = helpers.call_xla_model(model_ours, new_dict, input_ours)
diff --git a/tests/test_run_server.py b/tests/test_run_server.py
index 73022a7..a69a09f 100644
--- a/tests/test_run_server.py
+++ b/tests/test_run_server.py
@@ -45,6 +45,7 @@ def reset_flags(self):

  def setup(self):
    """Setup."""
+    # pylint: disable-next=all
    from run_server import flags

    f = flags.FLAGS
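
Note on the GQA reshape trick above (a standalone sketch, not part of the patch): instead of materializing repeated KV heads with repeat_kv, reshape_heads folds the repeat factor rep = n_heads // n_kv_heads into the query sequence axis so the kernels run with only n_kv_heads heads, and reshape_outputs unfolds the result (and, for flash attention, the running max and denominator) back to n_heads. This is also why the prefill mask in engine.py is now built with n_reps rows per query position and why the tests repeat the mask with mask.repeat((env.n_reps, 1)). The sketch below uses made-up shapes and a plain-softmax reference, so it is only an illustration of the equivalence for dense attention, not the repository's kernels.

import torch
import torch.nn.functional as F

bq, hq, hkv, tq, cache_len, dq = 1, 8, 2, 4, 12, 16
rep = hq // hkv  # env.n_reps in the patch
ts = cache_len + tq  # total key/value length

xq = torch.randn(bq, hq, tq, dq)
keys = torch.randn(bq, hkv, ts, dq)
values = torch.randn(bq, hkv, ts, dq)

# Causal-with-cache mask, built the same way as _prefill_mask in the tests.
mask = torch.triu(torch.full((tq, tq), float("-inf")), diagonal=1)
mask = torch.hstack([torch.zeros((tq, cache_len)), mask])

# Reference: classic GQA, materializing repeated KV heads (the old repeat_kv path).
k_rep = keys.repeat_interleave(rep, dim=1)
v_rep = values.repeat_interleave(rep, dim=1)
scores = xq @ k_rep.transpose(-2, -1) / dq**0.5 + mask
expected = F.softmax(scores, dim=-1) @ v_rep

# Trick from reshape_heads: fold rep into the query sequence axis.
xq_folded = xq.reshape(bq, hkv, rep, tq, dq).reshape(bq, hkv, rep * tq, dq)
mask_folded = mask.repeat((rep, 1))  # mirrors mask.repeat((env.n_reps, 1)) in the tests
scores = xq_folded @ keys.transpose(-2, -1) / dq**0.5 + mask_folded
out = F.softmax(scores, dim=-1) @ values

# Unfold as reshape_outputs does, back to (batch, n_heads, seqlen, head_dim).
out = out.reshape(bq, hkv, rep, tq, dq).reshape(bq, hq, tq, dq)
assert torch.allclose(out, expected, atol=1e-5)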