52 changes: 50 additions & 2 deletions jetstream_pt/attention_kernel.py
@@ -558,7 +558,7 @@ def ragged_mha(
return out, (m, l)


def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None):
def _dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None):
"""The vanilla attention kernel implementation."""

bsz, _, _, head_dim = xq.shape
@@ -585,7 +585,37 @@ def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None):
return output


def flash_attention(
def reshape_heads(xq, keys):
"""Reshapes the query head for GQA"""
bq, hq, tq, dq = xq.shape
hkv = keys.shape[-3]
rep = hq // hkv
if rep > 1:
xq = xq.reshape(bq, hkv, rep, tq, dq).reshape(bq, hkv, rep * tq, dq)
return xq, rep


def reshape_outputs(rep, o, m=None, d=None):
"""Reshapes back the attention output for GQA"""
bq, hqo, tqo, dq = o.shape
tq = tqo // rep
hq = hqo * rep
o = o.reshape(bq, hqo, rep, tq, dq).reshape(bq, hq, tq, dq)
if m is not None and d is not None:
m = m.reshape(bq, hqo, rep, tq, 1).reshape(bq, hq, tq, 1)
d = d.reshape(bq, hqo, rep, tq, 1).reshape(bq, hq, tq, 1)
return o, (m, d)
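
For intuition, a shape-only sketch of the head-folding round trip performed by the two helpers above, using toy sizes and assuming they are importable from jetstream_pt.attention_kernel as added in this PR:

import torch
from jetstream_pt.attention_kernel import reshape_heads, reshape_outputs

# Toy GQA layout: batch 1, 8 query heads, 2 KV heads, 16 tokens, head_dim 4.
xq = torch.randn(1, 8, 16, 4)
keys = torch.randn(1, 2, 128, 4)

folded, rep = reshape_heads(xq, keys)   # rep == 8 // 2 == 4
print(folded.shape)                     # (1, 2, 64, 4): query heads folded into the token axis

# After attention the result keeps the folded layout; reshape_outputs restores
# the original (batch, query_heads, tokens, head_dim) layout.
out = torch.randn(1, 2, 64, 4)
restored, _ = reshape_outputs(rep, out)
print(restored.shape)                   # (1, 8, 16, 4)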


def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None):
"""The vanilla attention kernel implementation."""
xq, rep = reshape_heads(xq, keys)
output = _dense_attention(xq, keys, values, k_scaler, v_scaler, mask)
output, _ = reshape_outputs(rep, output)
return output


def _flash_attention(
xq,
keys,
values,
@@ -637,6 +667,24 @@ def flash_attention(
return o, (logits_max, denominator)


def flash_attention(
xq,
keys,
values,
layer,
k_scaler=None,
v_scaler=None,
mask=None,
normalize_var=True,
):
"""Flash attention kernel."""
xq, rep = reshape_heads(xq, keys)
o, (logits_max, denominator) = _flash_attention(
xq, keys, values, k_scaler, v_scaler, mask
)
return reshape_outputs(rep, o, logits_max, denominator)
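
The wrapper reshapes logits_max and denominator as well because they are per-row softmax statistics that callers use to merge partial attention results (for example, the existing cache versus newly written tokens). Below is a minimal sketch of that standard merge, assuming each partial output has already been normalized by its own denominator; the function name and signature are illustrative, not this module's API:

import torch

def merge_flash_partials(o1, m1, d1, o2, m2, d2):
  """Combines two flash-attention partials computed over disjoint key blocks."""
  m = torch.maximum(m1, m2)        # new running max per query row
  s1 = torch.exp(m1 - m) * d1      # rescaled partial denominators
  s2 = torch.exp(m2 - m) * d2
  d = s1 + s2
  # Each partial output was normalized by its own denominator, so re-weight
  # by the rescaled denominators and renormalize by the combined one.
  o = (o1 * s1 + o2 * s2) / d
  return o, m, d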


class RaggedAttentionKernel:
"""Ragged attention kernel."""

6 changes: 5 additions & 1 deletion jetstream_pt/engine.py
@@ -216,11 +216,12 @@ def _call_model_prefill(self, weights, tokens, input_indexes):
for _ in self.pt_model.layers
]
mask = jnp.full(
(1, 1, tokens.shape[1], tokens.shape[1]),
(1, self.env.n_reps, tokens.shape[1], tokens.shape[1]),
float("-inf"),
dtype=self.default_dtype,
)
mask = jnp.triu(mask, k=1)
mask = mask.reshape(1, 1, -1, tokens.shape[1])
start = jnp.zeros((tokens.shape[0],), dtype=jnp.int32)
args = (tokens, input_indexes, caches, mask, start)
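
A shape-only sketch of the mask construction above: after reshape_heads folds query heads into the token axis, each KV head attends over n_reps * seqlen query rows, so the causal mask is built per repeat and then flattened to match that layout (toy numbers below, not taken from the PR):

import jax.numpy as jnp

seqlen, n_reps = 4, 2                       # toy values for illustration
mask = jnp.full((1, n_reps, seqlen, seqlen), float("-inf"))
mask = jnp.triu(mask, k=1)                  # causal mask within every repeat
mask = mask.reshape(1, 1, -1, seqlen)       # (1, 1, n_reps * seqlen, seqlen)
# Row i of each repeated block masks the same future tokens, matching the
# folded query layout (batch, kv_heads, n_reps * seqlen, head_dim).
print(mask.shape)                           # (1, 1, 8, 4)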

@@ -970,6 +971,7 @@ def create_pytorch_engine(
)
env_data.model_type = model_name + "-" + param_size
env_data.num_layers = args.n_layers
env_data.n_reps = args.n_heads // args.n_kv_heads
env = JetEngineEnvironment(env_data)
pt_model = llama_model.Transformer(args, env)
elif model_name == "gemma":
@@ -982,6 +984,7 @@
)
env_data.model_type = model_name + "-" + param_size
env_data.num_layers = args.num_hidden_layers
env_data.n_reps = args.num_attention_heads // args.num_key_value_heads
env = JetEngineEnvironment(env_data)
print(f"Enviroment variables: {vars(env)}")
pt_model = gemma_model.GemmaModel(args, env)
@@ -995,6 +998,7 @@
args.dim // args.n_head,
)
env_data.num_layers = args.n_layer
env_data.n_reps = args.n_head // args.n_local_heads
env = JetEngineEnvironment(env_data)
pt_model = mixtral_model.Transformer(args, env)
else:
11 changes: 10 additions & 1 deletion jetstream_pt/environment.py
@@ -121,10 +121,13 @@ class JetEngineEnvironmentData:

testing_seed: int = 0

# The ratio between query heads and kv heads
n_reps: int = 0


# pylint: disable-next=all
class JetEngineEnvironment:

# pylint: disable-next=all
def __init__(self, data: JetEngineEnvironmentData):
self._data = data

@@ -144,6 +147,12 @@ def __init__(self, data: JetEngineEnvironmentData):
self.flash_attention = True
self.generate_cache_stacked = True
self.new_cache_stacked = True
else:
self.lazy_cache_update = False
self.ragged_mha = False
self.flash_attention = False
self.generate_cache_stacked = False
self.new_cache_stacked = False

if self.testing:
self.lazy_cache_update = self._data.lazy_cache_update
41 changes: 2 additions & 39 deletions jetstream_pt/layers.py
@@ -365,29 +365,6 @@ def apply_rotary_emb(
return xq_out.type_as(xq), xk_out.type_as(xk)


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)."""

*_, bs, n_kv_heads, slen, head_dim = x.shape
stacked = x.ndim == 5

if n_rep == 1:
return x

if stacked:
layer = x.shape[0]
return (
x[:, :, :, None, :, :]
.expand(layer, bs, n_kv_heads, n_rep, slen, head_dim)
.reshape(layer, bs, n_kv_heads * n_rep, slen, head_dim)
)
return (
x[:, :, None, :, :]
.expand(bs, n_kv_heads, n_rep, slen, head_dim)
.reshape(bs, n_kv_heads * n_rep, slen, head_dim)
)


class AttentionKernel:

def __init__(self, env, layer_id):
@@ -473,7 +450,7 @@ def attend(xq, keys, values, local_mask=None):
ragged_batch_index,
ragged_block_index,
)
elif self.env.flash_attention:
elif self.env.flash_attention and seqlen == 1:
with torch_xla2.default_env():
local_output, (local_max, local_denom) = self.flash_attention(
xq, keys, values, self.layer_id, mask=local_mask
@@ -505,10 +482,6 @@ def attend(xq, keys, values, local_mask=None):

with jax.named_scope("attn_insert_cache"):
orig_keys, orig_values = cache.update(xk, xv, self.layer_id)
# We are not using ragged attention for prefill yet.
if not self.env.ragged_mha or seqlen > 1:
orig_keys = repeat_kv(orig_keys, n_rep)
orig_values = repeat_kv(orig_values, n_rep)

# print(f"attention kernel xq {xq.shape} seqlen {seqlen} keys {keys.shape} mask {mask.shape}")
with jax.named_scope("attn_qkv"):
@@ -522,9 +495,6 @@

# For flash attention, existing output contains the existing kv cache generated logits
with jax.named_scope("attn_new_qkv"):
if not self.env.ragged_mha or seqlen > 1:
xk = repeat_kv(xk, n_rep)
xv = repeat_kv(xv, n_rep)
new_output, (new_max, new_denom) = attend(xq, xk, xv, None)

with jax.named_scope("attn_global"):
@@ -633,7 +603,7 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None):
k_scaler,
v_scaler,
)
elif self.env.flash_attention:
elif self.env.flash_attention and seqlen == 1:
with torch_xla2.default_env():
local_output, (local_max, local_denom) = self.flash_attention(
xq,
@@ -676,10 +646,6 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None):
new_k_scaler,
new_v_scaler,
) = cache.update(xk, xv, self.layer_id)
# We are not using ragged attention for prefill yet.
if not self.env.ragged_mha or seqlen > 1:
orig_keys = repeat_kv(orig_keys, n_rep)
orig_values = repeat_kv(orig_values, n_rep)
with jax.named_scope("attn_qkv"):
existing_output, (existing_max, existing_denom) = attend(
xq, orig_keys, orig_values, k_scaler, v_scaler, mask
@@ -692,9 +658,6 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None):
# For flash attention, existing output contains the existing kv cache generated logits
with jax.named_scope("attn_new_qkv"):
# At this point, flash attention or ragged attention must have been enabled
if not self.env.ragged_mha or seqlen > 1:
new_key = repeat_kv(new_key, n_rep)
new_value = repeat_kv(new_value, n_rep)
new_output, (new_max, new_denom) = attend(
xq, new_key, new_value, new_k_scaler, new_v_scaler, None
)
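With repeat_kv removed from layers.py, the GQA expansion now happens by folding query heads instead of materializing repeated keys and values. A small equivalence sketch, illustrative only; it assumes contiguous query-head groups and imports the helpers added in attention_kernel.py:

import math
import torch
from jetstream_pt.attention_kernel import reshape_heads, reshape_outputs

b, hq, hkv, t, d = 1, 8, 2, 16, 4
rep = hq // hkv
xq = torch.randn(b, hq, 1, d)               # one decode step
keys = torch.randn(b, hkv, t, d)
values = torch.randn(b, hkv, t, d)

# Old style: expand KV heads to match the query heads, then attend per head.
keys_rep = keys.repeat_interleave(rep, dim=1)
values_rep = values.repeat_interleave(rep, dim=1)
scores = torch.einsum("bhqd,bhkd->bhqk", xq, keys_rep) / math.sqrt(d)
out_old = torch.einsum("bhqk,bhkd->bhqd", scores.softmax(dim=-1), values_rep)

# New style: fold query heads into the token axis, attend against the
# un-repeated KV, then unfold the result.
xq_f, rep = reshape_heads(xq, keys)
scores = torch.einsum("bhqd,bhkd->bhqk", xq_f, keys) / math.sqrt(d)
out_new = torch.einsum("bhqk,bhkd->bhqd", scores.softmax(dim=-1), values)
out_new, _ = reshape_outputs(rep, out_new)

print(torch.allclose(out_old, out_new, atol=1e-6))   # True
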
2 changes: 2 additions & 0 deletions tests/helpers.py
@@ -27,6 +27,7 @@ def make_env_tiny(bf16_enable=True, env_data_update_fn=lambda _: None):
environment_data.cache_sequence_length,
config.dim // config.n_heads,
)
environment_data.n_reps = config.n_heads // config.n_kv_heads
environment_data.testing = True
env_data_update_fn(environment_data)
env = environment.JetEngineEnvironment(environment_data)
@@ -54,6 +55,7 @@ def make_mixtral_env(bf16_enable=True):
environment_data.cache_sequence_length,
config.dim // config.n_head,
)
environment_data.n_reps = config.n_head // config.n_local_heads
env = environment.JetEngineEnvironment(environment_data)
env.apply_sharding = lambda *args, **kwargs: None # don't shard on cpu
return env, config
44 changes: 27 additions & 17 deletions tests/test_model_impl.py
@@ -42,17 +42,21 @@ def setUp(self):
jax.config.update("jax_enable_x64", False)
torch.set_default_dtype(torch.float32)

def _prefill_mask(self, seqlen, start_pos):
def _prefill_mask(self, seqlen, start_pos, env):
mask = torch.full((seqlen, seqlen), float("-inf"))

mask = torch.triu(mask, diagonal=1)

# When performing key-value caching, we compute the attention scores
# only for the new sequence. Thus, the matrix of scores is of size
# (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
# j > cache_len + i, since row i corresponds to token cache_len + i.
mask = torch.hstack([torch.zeros((seqlen, start_pos)), mask])
return mask
orig_mask = torch.hstack([torch.zeros((seqlen, start_pos)), mask])

mask = mask.repeat((env.n_reps, 1))
our_mask = torch.hstack(
[torch.zeros((seqlen * env.n_reps, start_pos)), mask]
)
return orig_mask, our_mask

def _make_freqs_cis(self, model_arg, seqlen, start_pos):
freqs_cis = model_original.precompute_freqs_cis(
@@ -117,8 +121,8 @@ def test_attention(self):
) # (batch, seqlen, embedding dim)
start_pos = 0
freqs_cis = self._make_freqs_cis(model_arg, seqlen, start_pos)
mask = self._prefill_mask(seqlen, start_pos)
inputs_orig = (x, start_pos, freqs_cis, mask)
orig_mask, our_mask = self._prefill_mask(seqlen, start_pos, env)
inputs_orig = (x, start_pos, freqs_cis, orig_mask)

expected_out = attention_orig(*inputs_orig)

@@ -127,7 +131,7 @@
input_ours = (
x,
freqs_cis,
mask,
our_mask,
cache,
)

@@ -236,11 +240,17 @@ def load_hook(state_dict, prefix, *args):
) # (batch, seqlen, embedding dim)
start_pos = 0
freqs_cis = self._make_freqs_cis(model_arg, seqlen, start_pos)
mask = self._prefill_mask(seqlen, start_pos)
orig_mask, our_mask = self._prefill_mask(seqlen, start_pos, env)
kv_write_indexes = torch.arange(0, seqlen)
cache_k = torch.zeros((batch, seqlen, num_kv_heads, head_dim))
cache_v = torch.zeros((batch, seqlen, num_kv_heads, head_dim))
inputs_orig = (x, freqs_cis, kv_write_indexes, (cache_k, cache_v), mask)
inputs_orig = (
x,
freqs_cis,
kv_write_indexes,
(cache_k, cache_v),
orig_mask,
)

expected_out = attention_orig(*inputs_orig)

@@ -249,7 +259,7 @@ def load_hook(state_dict, prefix, *args):
input_ours = (
x,
freqs_cis,
mask,
our_mask,
cache,
)

@@ -284,8 +294,8 @@ def test_transformer_block(self):
) # (batch, seqlen, embedding dim)
start_pos = 0
freqs_cis = self._make_freqs_cis(model_arg, seqlen, start_pos)
mask = self._prefill_mask(seqlen, start_pos)
inputs_orig = (x, start_pos, freqs_cis, mask)
orig_mask, our_mask = self._prefill_mask(seqlen, start_pos, env)
inputs_orig = (x, start_pos, freqs_cis, orig_mask)

expected_out = block_orig(*inputs_orig)

@@ -294,7 +304,7 @@ def test_transformer_block(self):
input_ours = (
x,
freqs_cis,
mask,
our_mask,
cache,
)

@@ -356,7 +366,7 @@ def test_transformer(self):
seqlen = 32
x = torch.randint(0, 32000, (1, seqlen)) # (batch, seqlen, embedding dim)
start_pos = 0
mask = self._prefill_mask(seqlen, start_pos)
_, our_mask = self._prefill_mask(seqlen, start_pos, env)
inputs_orig = (x, start_pos)

expected_out = model_orig(*inputs_orig)
@@ -367,7 +377,7 @@ def test_transformer(self):
x,
input_pos,
caches,
mask,
our_mask,
)

result_torch = helpers.call_xla_model(model_ours, state_dict, input_ours)
@@ -417,7 +427,7 @@ def test_mixtral_transformer(self):
seqlen = 32
x = torch.randint(0, 32000, (1, seqlen)) # (batch, seqlen, embedding dim)
start_pos = 0
mask = self._prefill_mask(seqlen, start_pos)
_, our_mask = self._prefill_mask(seqlen, start_pos, env)
input_pos = torch.arange(0, seqlen)
inputs_orig = (x, input_pos)

@@ -430,7 +440,7 @@ def test_mixtral_transformer(self):
x,
input_pos,
caches,
mask,
our_mask,
)
result_torch = helpers.call_xla_model(model_ours, new_dict, input_ours)

1 change: 1 addition & 0 deletions tests/test_run_server.py
@@ -45,6 +45,7 @@ def reset_flags(self):

def setup(self):
"""Setup."""
# pylint: disable-next=all
from run_server import flags

f = flags.FLAGS