diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py index 234e9fe..4689ef5 100644 --- a/jetstream_pt/attention_kernel.py +++ b/jetstream_pt/attention_kernel.py @@ -1,18 +1,22 @@ +from collections.abc import Callable import functools import math +from typing import Any import jax import jax.numpy as jnp from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu +from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention from jax.experimental.shard_map import shard_map - +import numpy as np import torch import torch.nn.functional as F +from jetstream_pt import torchjax -import numpy as np DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) +P = jax.sharding.PartitionSpec def ragged_flash_attention_kernel( @@ -735,3 +739,52 @@ def __call__( k_scaler, v_scaler, ) + + +def shard_kv_heads( + paged_attention_impl: Callable[..., Any], + mesh: jax.sharding.Mesh, + kv_head_mesh_axis_name: str, +): + """Shard map on kv head.""" + in_specs = ( + P(None, kv_head_mesh_axis_name, None), # q + P(kv_head_mesh_axis_name, None, None, None), # k + P(kv_head_mesh_axis_name, None, None, None), # v + P(), # lengths + P(), # page_indices + ) + + out_specs = P(None, kv_head_mesh_axis_name, None) # q + + return jax.jit( + shard_map( + paged_attention_impl, + mesh=mesh, + in_specs=in_specs, + out_specs=out_specs, + check_rep=False, + ) + ) + + +def call_paged_attention(env, xq, keys, values, seq_lens, page_indices): + """Paged attention kernel.""" + xq, keys, values, seq_lens, page_indices = torchjax.from_torch( + (xq, keys, values, seq_lens, page_indices) + ) + paged_attention_impl = functools.partial( + paged_attention, + pages_per_compute_block=env.block_size // env.paged_attention_page_size, + # mask_value=float("-inf") + ) + sharded_paged_attention_impl = shard_kv_heads( + paged_attention_impl, + env.mesh, + kv_head_mesh_axis_name="x", + ) + output = sharded_paged_attention_impl( + xq, keys, values, seq_lens, page_indices + ) + + return torchjax.to_torch(output) diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index 9a44a47..7687859 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -19,6 +19,7 @@ import torch_xla2 from jetstream_pt import torchjax +from jetstream_pt.page_attention_manager import PageAttentionManager # pylint: disable-next=all @@ -663,6 +664,7 @@ def __init__( self, cache_k: torch.Tensor, # previous cache cache_v: torch.Tensor, # previous cache + page_attention_manager: PageAttentionManager, page_token_indices: torch.Tensor, # page and token indices for the cache sharding, env=None, @@ -670,11 +672,13 @@ def __init__( super().__init__() self.cache_k = cache_k self.cache_v = cache_v + self.page_attention_manager = page_attention_manager self.page_token_indices = page_token_indices self.sharding = sharding self.env = env + self.stacked = False - def update(self, key, value): + def update(self, key, value, layer_id=0): """Update kv cache""" keyj, valuej, page_token_indicesj = torchjax.from_torch( (key, value, self.page_token_indices) @@ -683,12 +687,14 @@ def update(self, key, value): def _update(cache, x): x = x.squeeze(2).transpose((1, 0, 2)) x = x[:, page_token_indicesj[2], :] - head, _, page_size, dim = cache.shape + head, _, paged_attention_page_size, dim = cache.shape selected_cache = cache[:, page_token_indicesj[0], :, :] selected_cache = selected_cache.reshape((head, -1, dim)) selected_cache = selected_cache.at[:, 
page_token_indicesj[1], :].set(x) - selected_cache = selected_cache.reshape((head, -1, page_size, dim)) + selected_cache = selected_cache.reshape( + (head, -1, paged_attention_page_size, dim) + ) cache = cache.at[:, page_token_indicesj[0], :, :].set(selected_cache) return cache @@ -696,19 +702,23 @@ def _update(cache, x): # pylint: disable-next=all self.cache_k._elem = _update(self.cache_k._elem, keyj) # pylint: disable-next=all - self.cache_k._elem = _update(self.cache_v._elem, valuej) + self.cache_v._elem = _update(self.cache_v._elem, valuej) return self.cache_k, self.cache_v def state(self): """Get kv cache state""" # pylint: disable-next=all - return self.cache_k.jax(), self.cache_v.jax() + return torchjax.from_torch((self.cache_k, self.cache_v)) + + def finalize(self): + """Do nothing now""" + return @classmethod - def empty(cls, shape, device, bf16_enable, env): + def empty(cls, shape, device, env): """Create empty kv caches""" - default_dtype = jnp.bfloat16 if bf16_enable else jnp.float32 + default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32 k = jnp.zeros(shape, device=device, dtype=default_dtype) v = jnp.zeros(shape, device=device, dtype=default_dtype) k, v = torchjax.to_torch((k, v)) - return cls(k, v, None, device, env=env) + return cls(k, v, None, None, device, env=env) diff --git a/jetstream_pt/config.py b/jetstream_pt/config.py index 70b530f..52738b3 100644 --- a/jetstream_pt/config.py +++ b/jetstream_pt/config.py @@ -139,6 +139,18 @@ "size of top k used when sampling next token", ) +flags.DEFINE_integer( + "paged_attention_total_num_pages", + 0, + "total number of pages per layer for page attention", +) + +flags.DEFINE_integer( + "paged_attention_page_size", + 64, + "page size per page", +) + def create_quantization_config_from_flags(): """Create Quantization Config from cmd flags""" @@ -213,6 +225,8 @@ def create_engine_from_config_flags(): generate_cache_stacked=FLAGS.generate_cache_stacked, new_cache_stacked=FLAGS.new_cache_stacked, lazy_cache_update=FLAGS.lazy_cache_update, + paged_attention_total_num_pages=FLAGS.paged_attention_total_num_pages, + paged_attention_page_size=FLAGS.paged_attention_page_size, ) print("Initialize engine", time.perf_counter() - start) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 2a8ad31..2c7e38e 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -38,6 +38,7 @@ from jetstream_pt import torchjax from jetstream_pt.hf_tokenizer import HFTokenizerAdapter from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData, QuantizationConfig +from jetstream_pt.page_attention_manager import PageAttentionManager from jetstream_pt.third_party.llama import model_exportable as llama_model, model_args from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model from jetstream_pt.third_party.mixtral import config as mixtral_config, model as mixtral_model @@ -103,6 +104,7 @@ def __init__( jax.config.update("jax_enable_x64", False) + self.prefill_cache_sharding = self.env.prefill_cache_sharding self.prefill = jax.jit( self.prefill, out_shardings=(self.get_prefix_destination_sharding(), None), @@ -113,10 +115,42 @@ def __init__( out_shardings=self.get_decode_state_sharding(), ) self.generate = jax.jit( - self.generate, + self.generate_impl, donate_argnums=(1,), out_shardings=(self.get_decode_state_sharding(), None), ) + + if self.env.page_attention: + max_pages_per_sequence = ( + self.env._data.cache_sequence_length + // self.env._data.paged_attention_page_size + ) + 
assert ( + self.env._data.cache_sequence_length + % self.env._data.paged_attention_page_size + == 0 + ), f"cache_sequence_length {self.env._data.cache_sequence_length} should divide paged_attention_page_size {self.env._data.paged_attention_page_size}" + + self.page_attention_manager = PageAttentionManager( + batch_size=self.env.batch_size, + paged_attention_total_num_pages=self.env._data.paged_attention_total_num_pages, + paged_attention_page_size=self.env._data.paged_attention_page_size, + max_pages_per_sequence=max_pages_per_sequence, + ) + + self._insert_page_attention_jit = jax.jit( + self._insert_page_attention, + donate_argnums=(0, 1), + out_shardings=self.get_decode_state_sharding(), + ) + self.insert = self.insert_page_attention_with_reservation + self.generate_jit = jax.jit( + self.generate_impl, + donate_argnums=(1,), + out_shardings=(self.get_decode_state_sharding(), None), + ) + + self.generate = self.generate_page_attention # self._insert_wrap = jax.jit(self._insert_wrap, donate_argnums=(0, 1), # out_shardings=self.get_decode_state_sharding()) @@ -162,6 +196,7 @@ def _call_model_generate( input_pos, ragged_batch_index, ragged_block_index, + page_token_indices, ): if self.env.quant_config.enable_kv_quantization: caches_obj = [ @@ -172,6 +207,18 @@ def _call_model_generate( list(zip(caches, cache_scales)) ) ] + elif self.env.page_attention: + caches_obj = [ + cache_manager.PageKVCacheGenerate( + k, + v, + self.page_attention_manager, + page_token_indices, + self.cache_sharding, + env=self.env, + ) + for k, v in torchjax.to_torch(caches) + ] else: caches_obj = [ cache_manager.KVCacheGenerate( @@ -533,6 +580,50 @@ def insert(cache, scaler, new_entry): mask, ) + def _insert_page_attention( + self, + prefix: Prefix, + decode_state: DecodeState, + slot: int, + num_pages: int, + update_indexes: jax.Array, + tep_kv: jax.Array, + ): + caches = self.page_attention_manager.insert_prefill_cache( + prefill_caches=prefix.caches, + decode_caches=decode_state.caches, + update_indexes=update_indexes, + tep_kv=tep_kv, + sharding=self.cache_sharding, + ) + + current_pos = prefix.seq_len + + pos = current_pos - prefix.seq_len + tokens = decode_state.tokens.at[slot].set(prefix.token) + + x = jnp.arange(0, self.env.cache_sequence_length) + cond = jnp.logical_and(x < current_pos, x >= pos) + mask_insert = jnp.where(cond, 0, float("-inf")) + mask = decode_state.mask.at[slot].set(mask_insert) + start = decode_state.start.at[slot].set( + pos % self.env.cache_sequence_length + ) + + input_pos = decode_state.input_pos.at[slot].set(prefix.seq_len) + scales = None + lens = decode_state.lens.at[slot].set(1) + return DecodeState( + tokens, + caches, + scales, + decode_state.current_position, + lens, + start, + input_pos, + mask, + ) + def insert( self, prefix: Prefix, @@ -556,6 +647,29 @@ def insert( else: return self._insert_no_wrap(prefix, decode_state, slot) + def insert_page_attention_with_reservation( + self, + prefix: Prefix, + decode_state: DecodeState, + slot: int, + ) -> DecodeState: + num_pages, update_indexes = ( + self.page_attention_manager.reserve_pages_insert(slot, prefix.seq_len) + ) + _, kv_heads, _, dim = prefix.caches[0][0].shape + tep_kv = jnp.zeros( + ( + kv_heads, + num_pages * self.page_attention_manager.paged_attention_page_size, + dim, + ), + dtype=self.default_dtype, + device=self.prefill_cache_sharding, + ) + return self._insert_page_attention_jit( + prefix, decode_state, slot, num_pages, update_indexes, tep_kv + ) + def precompute_ragged_block_indices(self, decode_state: 
DecodeState): """Precompute the ragged attention block indices. Ragged attention iterates the grid and relies on the computed grid index to skip the unnecessary blocks. The basic idea @@ -615,8 +729,32 @@ def false_comp(b, i, bk, start, end): def generate( self, params: Any, decode_state: DecodeState, sampler=None + ) -> tuple[DecodeState, engine_api.ResultTokens]: + return (None, None) + + def generate_page_attention( + self, params: Any, decode_state: DecodeState + ) -> tuple[DecodeState, engine_api.ResultTokens]: + self.page_attention_manager.fill_new_pages(decode_state.input_pos) + page_token_indices = self.page_attention_manager.get_page_token_indices( + decode_state.input_pos + ) + new_decode_state, result_tokens = self.generate_jit( + params, decode_state, page_token_indices=page_token_indices + ) + # new_decode_state, result_tokens = self.generate_impl(params, decode_state, page_token_indices) + return new_decode_state, result_tokens + + def generate_impl( + self, + params: Any, + decode_state: DecodeState, + sampler=None, + page_token_indices=None, ) -> tuple[DecodeState, engine_api.ResultTokens]: # seq_len = padded_tokens.shape[0] + if self.env.page_attention: + page_token_indices = torchjax.to_torch(page_token_indices) pos = decode_state.current_position if self.env.ring_buffer: input_indexes = jnp.full((1,), pos) @@ -651,6 +789,7 @@ def update_mask(): decode_state.input_pos, ragged_batch_index, ragged_block_index, + page_token_indices, ) if self.env.lazy_cache_update: @@ -815,7 +954,11 @@ def get_prefix_destination_sharding(self) -> Prefix: """Returns the shardings necessary to transfer data between engines.""" return Prefix( self.replicated, - self.replicated if self.env.shard_on_batch else self.cache_sharding, + self.replicated + if self.env.shard_on_batch + else self.prefill_cache_sharding + if self.env.page_attention + else self.cache_sharding, self.replicated, ) @@ -888,6 +1031,8 @@ def create_pytorch_engine( generate_cache_stacked=False, new_cache_stacked=False, lazy_cache_update=False, + paged_attention_total_num_pages=0, + paged_attention_page_size=64, ) -> PyTorchEngine: """Returns: The pytorch engine.""" @@ -962,6 +1107,8 @@ def create_pytorch_engine( generate_cache_stacked=generate_cache_stacked, new_cache_stacked=new_cache_stacked, lazy_cache_update=lazy_cache_update, + paged_attention_total_num_pages=paged_attention_total_num_pages, + paged_attention_page_size=paged_attention_page_size, ) if shard_on_batch and sharding_config: @@ -974,10 +1121,19 @@ def create_pytorch_engine( ) args.device = "meta" env_data.cache_shape = ( - batch_size, - args.n_kv_heads, - max_cache_length, - args.dim // args.n_heads, + ( + batch_size, + args.n_kv_heads, + max_cache_length, + args.dim // args.n_heads, + ) + if env_data.paged_attention_total_num_pages == 0 + else ( + args.n_kv_heads, + env_data.paged_attention_total_num_pages, + env_data.paged_attention_page_size, + args.head_dim, + ) ) env_data.model_type = model_name + "-" + param_size env_data.num_layers = args.n_layers diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py index ff7feaa..4917705 100644 --- a/jetstream_pt/environment.py +++ b/jetstream_pt/environment.py @@ -65,6 +65,13 @@ class JetEngineEnvironmentData: "head_dim", ) + prefill_attention_kv_axis_names: Tuple[str, ...] = ( + "batch", + "num_attn_heads", + "sequence_length", + "head_dim", + ) + # Shape of cache len(cache_shape) == len(attention_kv_axis_names) cache_shape: Tuple[int, ...] 
= () @@ -99,6 +106,12 @@ class JetEngineEnvironmentData: flash_attention: bool = False + # total number of pages per layer + paged_attention_total_num_pages: int = 0 + + # page size per page + paged_attention_page_size: int = 64 + generate_cache_stacked: bool = False new_cache_stacked: bool = False @@ -140,6 +153,7 @@ def __init__(self, data: JetEngineEnvironmentData): self.testing = self._data.testing self.testing_seed = self._data.testing_seed self.ring_buffer = self._data.ring_buffer + self.page_attention = self.paged_attention_total_num_pages > 0 # If not None, then use this tokenizer without # trying to create new ones. @@ -165,6 +179,13 @@ def __init__(self, data: JetEngineEnvironmentData): self.generate_cache_stacked = self._data.generate_cache_stacked self.new_cache_stacked = self._data.new_cache_stacked + if self.page_attention: + self.lazy_cache_update = False + self.ragged_mha = False + self.flash_attention = False + self.generate_cache_stacked = False + self.new_cache_stacked = False + self.default_type = jnp.bfloat16 if self._data.bf16_enable else jnp.float32 if self.generate_cache_stacked: @@ -193,6 +214,13 @@ def __init__(self, data: JetEngineEnvironmentData): "sequence_length", "head_dim", ) + elif self.page_attention: + self.attention_kv_axis_names = ( + "num_attn_heads", # kv_heads + "paged_attention_total_num_pages", + "pages_size", + "head_dim", + ) if data.shard_on_batch: self.kv_cache_shard_axis = "batch" else: @@ -201,6 +229,9 @@ def __init__(self, data: JetEngineEnvironmentData): self.cache_sharding_axis = self.attention_kv_axis_names.index( self.kv_cache_shard_axis ) + self.prefill_cache_sharding_axis = ( + self.prefill_attention_kv_axis_names.index(self.kv_cache_shard_axis) + ) if self.cache_shape[self.cache_sharding_axis] == 1: # cannot shard on an axis that is 1 @@ -208,6 +239,9 @@ def __init__(self, data: JetEngineEnvironmentData): self.cache_sharding_axis = len(self.cache_shape) - 1 self.cache_sharding = self.sharding_by_axis(self.cache_sharding_axis) + self.prefill_cache_sharding = self.sharding_by_axis( + self.prefill_cache_sharding_axis + ) self._load_sharding_config() def _load_sharding_config(self): @@ -263,6 +297,12 @@ def make_caches_generate(self): self.cache_shape, self.cache_sharding, env=self ) ) + elif self.page_attention: + caches.append( + cache_manager.PageKVCacheGenerate.empty( + self.cache_shape, self.cache_sharding, env=self + ) + ) else: caches.append( cache_manager.KVCacheGenerate.empty( diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index c41fe76..0f08dde 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -396,6 +396,7 @@ def __init__(self, env, layer_id): others_pspec = self.env.partition_by_axis() self.dense_attention = ak.dense_attention self.flash_attention = ak.flash_attention + self.page_attention = ak.call_paged_attention self.ragged_attention_orig = ak.RaggedAttentionKernel( env, input_specs=(q_pspec, kv_pspec, kv_pspec, *([others_pspec] * 7)), @@ -445,12 +446,18 @@ def attend(xq, keys, values, local_mask=None): true_len = seqlen # When GQA is enabled, it not necessary to expand - if not (self.env.ragged_mha and n_rep > 1) and seqlen == 1: + if ( + not (self.env.ragged_mha and n_rep > 1) + and seqlen == 1 + and not self.env.page_attention + ): true_len = 2 xq = torch.nn.functional.pad( xq, (0, 0, 0, true_len - seqlen), "constant", 0 ) + local_max = None + local_denom = None if self.env.ragged_mha and seqlen == 1 and keys.shape[-2] > 1: local_output, (local_max, local_denom) = 
torch_xla2.interop.call_jax( impl, @@ -476,6 +483,15 @@ def attend(xq, keys, values, local_mask=None): v_scaler=None, mask=local_mask, ) + elif self.env.page_attention and seqlen == 1: + local_output = self.page_attention( + self.env, + torch.squeeze(xq, 2), + keys, + values, + cache.page_attention_manager.lengths, + cache.page_attention_manager.page_indices, + ) else: local_output = self.dense_attention( xq=xq, @@ -485,8 +501,6 @@ def attend(xq, keys, values, local_mask=None): v_scaler=None, mask=local_mask, ) - local_max = None - local_denom = None local_output = local_output.reshape(bsz, num_heads, true_len, head_dim) if local_max is not None: diff --git a/jetstream_pt/page_attention_manager.py b/jetstream_pt/page_attention_manager.py index ab2d2c9..eb17765 100644 --- a/jetstream_pt/page_attention_manager.py +++ b/jetstream_pt/page_attention_manager.py @@ -20,39 +20,50 @@ class PageAttentionManager: def __init__( self, batch_size: int, - total_num_pages: int, - page_size: int, + paged_attention_total_num_pages: int, + paged_attention_page_size: int, max_pages_per_sequence: int, ): self.unused_pages = queue.Queue() self.batch_size = batch_size self.page_indices = jnp.full( - (batch_size, max_pages_per_sequence), -1, dtype=jnp.int32 + (batch_size, max_pages_per_sequence), + paged_attention_total_num_pages - 1, + dtype=jnp.int32, ) self.lengths = jnp.zeros(batch_size, dtype=jnp.int32) - self.page_size = page_size + self.paged_attention_page_size = paged_attention_page_size self.max_pages_per_sequence = max_pages_per_sequence - for i in range(total_num_pages): + for i in range(paged_attention_total_num_pages): self.unused_pages.put(i, block=False) # pylint: disable-next=all - def reserve_pages_insert(self, slot: int, seq_len: int) -> Tuple[int, list]: + def reserve_pages_insert( + self, slot: int, seq_len: int + ) -> Tuple[int, jax.Array]: self.lengths = self.lengths.at[slot].set(seq_len) - num_pages = seq_len // self.page_size - if seq_len % self.page_size != 0: - num_pages = num_pages + 1 + num_pages = ( + seq_len // self.paged_attention_page_size + if seq_len % self.paged_attention_page_size == 0 + else seq_len // self.paged_attention_page_size + 1 + ) indices = [self.unused_pages.get(block=False) for _ in range(num_pages)] self.page_indices = self.page_indices.at[slot, :num_pages].set(indices) - return num_pages + return num_pages, self.page_indices[slot, :num_pages] # pylint: disable-next=all def reserve_pages_decode(self, slot: int, seq_len: int): - if seq_len > 0 and seq_len % self.page_size == 0: + if seq_len > 0 and seq_len % self.paged_attention_page_size == 0: index = self.unused_pages.get(block=False) - num_pages = seq_len // self.page_size + num_pages = seq_len // self.paged_attention_page_size self.page_indices = self.page_indices.at[slot, num_pages].set(index) + # pylint: disable-next=all + def fill_new_pages(self, lens: jax.Array): + for slot in range(self.batch_size): + self.reserve_pages_decode(slot, lens[slot]) + # pylint: disable-next=all def prefill_cache_padding( self, @@ -61,65 +72,65 @@ def prefill_cache_padding( num_pages: int, ) -> List[Tuple[jax.Array, jax.Array]]: - pad_width = num_pages * self.page_size - seq_len + pad_width = num_pages * self.paged_attention_page_size - seq_len if pad_width == 0: return caches - caches = [ + return [ (self.pad_sequences(k, pad_width), self.pad_sequences(v, pad_width)) for k, v in caches ] - return caches def insert_prefill_cache( self, prefill_caches: List[Tuple[jax.Array, jax.Array]], decode_caches: List[Tuple[jax.Array, 
jax.Array]], - slot: int, - seq_len: int, + update_indexes: jax.Array, + tep_kv: jax.Array, sharding: jsharding.Sharding, ) -> List[Tuple[jax.Array, jax.Array]]: - """Insert prefill caches to decode caches slot. + """Insert prefill caches to decode caches. Args: prefill_caches: List of Tuple K, V. For each K, V: [batch_size, num_heads, seq_len, head_dim] jax.Array. decode_caches: List of Tuple K, V. For each K, V: - [num_heads, total_num_pages, page_size, head_dim] jax.Array. - slot: Slot of batch size in decode. - seq_len: Prefill tokens seqeunce length. + [num_heads, paged_attention_total_num_pages, paged_attention_page_size, head_dim] jax.Array. + update_indexes: Page indexes for insertion. + tep_kv: List of Tuple K, V. For each K, V: + kv_heads, num_pages * .paged_attention_page_size, dim. sharding: Decode cache sharding. Returns: Decode cache. List of Tuple K, V. For each K, V: - [num_heads, total_num_pages, page_size, head_dim] jax.Array. + [num_heads, paged_attention_total_num_pages, paged_attention_page_size, head_dim] jax.Array. """ - - num_pages = self.reserve_pages_insert(slot, seq_len) - padded_caches = self.prefill_cache_padding( - prefill_caches, seq_len, num_pages - ) # Reduce cache batch deminsion # [kv_heads, seq_len, dim] squeezed_caches = [ (jnp.squeeze(k, axis=0), jnp.squeeze(v, axis=0)) - for k, v in padded_caches + for k, v in prefill_caches ] - kv_heads, _, dim = squeezed_caches[0][0].shape - # [kv_heads, num_pages, page_size, dim] - paged_caches = [ + tmp_caches = [ ( - jnp.reshape(k, (kv_heads, -1, self.page_size, dim)), - jnp.reshape(v, (kv_heads, -1, self.page_size, dim)), + tep_kv.at[:, : k.shape[1], :].set(k), + tep_kv.at[:, : v.shape[1], :].set(v), ) for k, v in squeezed_caches ] - update_indexes = self.page_indices[slot, :num_pages] + kv_heads, _, dim = tmp_caches[0][0].shape + # [kv_heads, num_pages, paged_attention_page_size, dim] + paged_caches = [ + ( + jnp.reshape(k, (kv_heads, -1, self.paged_attention_page_size, dim)), + jnp.reshape(v, (kv_heads, -1, self.paged_attention_page_size, dim)), + ) + for k, v in tmp_caches + ] @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True) def insert(cache, new_entry): - new_entry = new_entry.squeeze(0) res = cache.at[:, update_indexes, :, :].set(new_entry) res = jax.lax.with_sharding_constraint(res, sharding) return res @@ -133,41 +144,61 @@ def insert(cache, new_entry): # pylint: disable-next=all def get_page_token_indices(self, lens: jax.Array) -> jax.Array: - - assert lens.shape == ( - self.batch_size, - 1, - ), f"len shape: {lens.shape} not equals batch size: {self.batch_size, 1}" + # assert lens.shape == ( + # self.batch_size, + # 1, + # ), f"len shape: {lens.shape} not equals batch size: {self.batch_size, 1}" update_page_indices = [] token_scale_indices = [] batch_slots = [] - new_lens = [] offset = 0 + for slot in range(self.batch_size): - seq_len = lens[slot][0] - num_pages = seq_len // self.page_size + 1 - token_pos = seq_len % self.page_size - page_index = self.page_indices[slot, num_pages - 1] - if page_index < 0: + seq_len = lens[slot] + if seq_len == 0: continue + num_pages = seq_len // self.paged_attention_page_size + 1 + token_pos = seq_len % self.paged_attention_page_size + page_index = self.page_indices[slot, num_pages - 1] + update_page_indices.append(page_index) token_scale_indices.append(offset + token_pos) batch_slots.append(slot) - new_lens.append(seq_len + 1) - offset += self.page_size + offset += self.paged_attention_page_size + self.lengths = jnp.where(lens == 0, 0, lens + 1) + 
update_page_indices = jnp.asarray(update_page_indices) + token_scale_indices = jnp.asarray(token_scale_indices) + batch_slots = jnp.asarray(batch_slots) return jnp.stack( ( - jnp.asarray(update_page_indices), - jnp.asarray(token_scale_indices), - jnp.asarray(batch_slots), - jnp.asarray(new_lens), + update_page_indices, + token_scale_indices, + batch_slots, ) ) # pylint: disable-next=all - def fill_new_pages(self, lens: jax.Array): - for slot in range(self.batch_size): - self.reserve_pages_decode(slot, lens[slot]) + def get_compress_kv_cache( + self, + decode_caches: List[Tuple[jax.Array, jax.Array]], + slot: int, + ) -> List[Tuple[jax.Array, jax.Array]]: + lens = self.lengths[slot] + indices = self.page_indices[slot] + return [ + ( + self._compress_cache(k, lens, indices), + self._compress_cache(v, lens, indices), + ) + for k, v in decode_caches + ] + + def _compress_cache(self, cache: jax.Array, lens: int, indices: jax.Array): + head, _, _, dim = cache.shape + selected_cache = cache[:, indices, :, :] + selected_cache = selected_cache.reshape((head, -1, dim)) + selected_cache = selected_cache[:, 0:lens, :] + return selected_cache # pylint: disable-next=all def pad_sequences(self, array, pad_width=10): @@ -188,5 +219,5 @@ def free_pages_resource(self, slot): break self.unused_pages.put(index, block=False) - self.page_indices = self.page_indices.at[slot, :].set(jnp.asarray([-1])) + self.page_indices = self.page_indices.at[slot, :].set(jnp.asarray([0])) return None diff --git a/tests/helpers.py b/tests/helpers.py index 09d718a..ac0ea5f 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -74,3 +74,36 @@ def call_xla_model(model, weights, args): result = torch.func.functional_call(model, xla_weights, xla_inputs) result_torch = torch_xla2.tensor.j2t(result.jax()) return result_torch + + +# pylint: disable-next=all +def make_page_attention_env_tiny( + bf16_enable=True, env_data_update_fn=lambda _: None +): + torch_dtype = torch.bfloat16 if bf16_enable else torch.float32 + torch.set_default_dtype(torch_dtype) + jax.config.update("jax_dynamic_shapes", False) + jax.config.update("jax_traceback_filtering", "off") + config = model_args.get_model_args("llama-2-tiny", 128, 1, True) + environment_data = environment.JetEngineEnvironmentData() + environment_data.paged_attention_page_size = 32 + environment_data.paged_attention_total_num_pages = 16 + environment_data.block_size = 64 + environment_data.max_input_sequence_length = 128 + environment_data.max_input_sequence_length = 128 + environment_data.cache_sequence_length = 128 + environment_data.bf16_enable = bf16_enable + environment_data.model_type = "llama-2-tiny" + environment_data.batch_size = 1 + environment_data.num_layers = config.n_layers + environment_data.cache_shape = ( + config.n_kv_heads, + environment_data.paged_attention_total_num_pages, + environment_data.paged_attention_page_size, + config.dim // config.n_heads, + ) + environment_data.testing = True + env_data_update_fn(environment_data) + env = environment.JetEngineEnvironment(environment_data) + env.apply_sharding = lambda *args, **kwargs: None # don't shard on cpu + return env, config diff --git a/tests/test_attention_kernal.py b/tests/test_attention_kernal.py new file mode 100644 index 0000000..835fd41 --- /dev/null +++ b/tests/test_attention_kernal.py @@ -0,0 +1,492 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +from absl.testing import parameterized +from collections.abc import Callable +import math +import functools +from typing import Any + +import torch +import jax +import jax.numpy as jnp +import numpy as np +from jetstream_pt.third_party.llama import model_args +from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention +from jax.experimental import shard_map +import jetstream_pt.attention_kernel as ak +from jetstream_pt import torchjax +from jetstream_pt import environment + + +P = jax.sharding.PartitionSpec +mesh = jax.sharding.Mesh(np.array(jax.devices()), axis_names=("x",)) + + +class PageAttentionTest(parameterized.TestCase): + + def _make_env(self, bf16_enable=True): + torch_dtype = torch.bfloat16 if bf16_enable else torch.float32 + torch.set_default_dtype(torch_dtype) + jax.config.update("jax_dynamic_shapes", False) + jax.config.update("jax_traceback_filtering", "off") + jax.config.update("jax_platform_name", "cpu") + jax.config.update("jax_enable_x64", False) + replicated = jax.sharding.NamedSharding(mesh, P()) + config = model_args.get_model_args("tiny", 128, 1, True) + environment_data = environment.JetEngineEnvironmentData() + environment_data.max_input_sequence_length = 128 + environment_data.max_input_sequence_length = 128 + environment_data.cache_sequence_length = 128 + environment_data.bf16_enable = bf16_enable + environment_data.model_type = "llama-2-tiny" + environment_data.batch_size = 3 + environment_data.num_layers = config.n_layers + environment_data.cache_shape = ( + 1, + config.n_kv_heads, + environment_data.cache_sequence_length, + config.dim // config.n_heads, + ) + env = environment.JetEngineEnvironment(environment_data) + env.apply_sharding = lambda *args, **kwargs: None # don't shard on cpu + env.sharding = replicated + return env, config + + # def test_dense_vs_page_attention(self): + # self._make_env() + # page_attention_output = test_sharded_multi_page_grouped_query_attention() + # output = test_dense_attention() + # print(f"output : {output[0, 0, 0:10]}") + # print(f"output : {output[0, 1, 0:10]}") + # print(f"page_attention_output : {page_attention_output[0, 0, 0:10]}") + # print(f"page_attention_output : {page_attention_output[0, 1, 0:10]}") + # self.assertTrue(jnp.array_equal(page_attention_output, output)) + + def test_jax_dense_vs_torch_dense(self): + self._make_env() + torch_output = test_torch_dense_attention() + output = test_dense_attention() + # print(f"output : {output[0, 1, 0:10]}") + # print(f"page_attention_output : {torch_output[0, 1, 0:10]}") + self.assertTrue(jnp.allclose(torch_output, torch_output, atol=1e-4)) + + # def test_torch_dense_attention_with_saved_data(self): + # self._make_env() + # _torch_dense_attention_with_saved_data() + + # def test_dense_attention_with_saved_data(self): + # self._make_env() + # _dense_attention_with_saved_data() + + # def test_sharded_multi_page_grouped_query_attention_with_saved_data(self): + # self._make_env() + # _sharded_multi_page_grouped_query_attention_with_saved_data() + + +multi_page_grouped_query_attention_fully_pipelined = 
paged_attention + + +def dense_attention(xq, keys, values, mask=None): + + bsz, _, _, head_dim = xq.shape + + scores = jnp.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim) + + if mask is not None: + scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen) + + scores = jax.nn.softmax(scores, axis=-1) + + output = jnp.einsum("ikjm,ikml->ikjl", scores, values) + return output + + +def shard_kv_heads( + paged_attention_impl: Callable[..., Any], + mesh: jax.sharding.Mesh, + kv_head_mesh_axis_name: str, +): + """Shards GQA PagedAttention along KV heads.""" + in_specs = ( + P(None, kv_head_mesh_axis_name, None), # q + P(kv_head_mesh_axis_name, None, None, None), # k + P(kv_head_mesh_axis_name, None, None, None), # v + P(), # lengths + P(), # page_indices + ) + + out_specs = P(None, kv_head_mesh_axis_name, None) # q + + return jax.jit( + shard_map.shard_map( + paged_attention_impl, + mesh=mesh, + in_specs=in_specs, + out_specs=out_specs, + check_rep=False, + ) + ) + + +def get_data(step: int = 0): + n_heads = 8 + paged_attention_total_num_pages = 256 + paged_attention_page_size = 64 + xq = get_xq(1, n_heads, step) + keys = get_keys( + n_heads, paged_attention_total_num_pages, paged_attention_page_size, step + ) + values = get_values( + n_heads, paged_attention_total_num_pages, paged_attention_page_size, step + ) + seq_lens = get_seq_lens(step) + page_indices = get_page_indices() + return xq, keys, values, seq_lens, page_indices + + +def get_dense_data(xq, keys, values, seq_lens, page_indices, n_heads=8): + paged_attention_total_num_pages = 256 + paged_attention_page_size = 64 + batch_size = 1 + dim = 128 + + xq = jnp.expand_dims(xq, 2) + + seq_len = seq_lens[0] + num_pages = ( + seq_len + paged_attention_page_size - 1 + ) // paged_attention_page_size + page_indices = page_indices[0, 0:num_pages] + + keys = keys[:, page_indices, :, :] + keys = keys.reshape( + batch_size, n_heads, num_pages * paged_attention_page_size, dim + ) + values = values[:, page_indices, :, :] + values = values.reshape( + batch_size, n_heads, num_pages * paged_attention_page_size, dim + ) + + mask = jnp.full( + (batch_size, num_pages * paged_attention_page_size), + float("-inf"), + dtype=jnp.float32, + ) + batch = jnp.arange(batch_size) + mask = mask.at[batch, 0:seq_len].set(0) + return xq, keys, values, mask + + +def get_xq(batch, head, step): + # xq[0:1,0:1,0:2] + xq1 = [[[0.976562, 0.335938]]] + xq2 = [[[0.494141, 1.66406]]] + xq = xq1 if step == 0 else xq2 + xq = jnp.asarray(xq, dtype=jnp.float32) + xq = jnp.tile(xq, (batch, head, 64)) + return xq + + +def get_keys(head, pages, paged_attention_page_size, step): + # keys[0, 0, 0:9, 0:2] + key1 = [ + [ + [ + [-0.419922, -0.194336], + [-0.0270996, -0.211914], + [-0.234375, -0.178711], + [0.279297, 0.699219], + [0.0114746, -0.138672], + [-0.886719, -0.296875], + [0.106934, -0.269531], + [0.133789, -0.363281], + [0.953125, 0.165039], + ] + ] + ] + + # keys[0, 0, 0:10, 0:2] + key2 = [ + [ + [-0.419922, -0.194336], + [-0.0270996, -0.211914], + [-0.234375, -0.178711], + [0.279297, 0.699219], + [0.0114746, -0.138672], + [-0.886719, -0.296875], + [0.106934, -0.269531], + [0.133789, -0.363281], + [0.953125, 0.165039], + [-0.0247803, 0.197266], + ] + ] + + key = key1 if step == 0 else key2 + key = jnp.asarray(key) + key = jnp.tile(key, (1, 1, 1, 64)) + + r = jnp.zeros( + (head, pages, paged_attention_page_size, 128), dtype=jnp.float32 + ) + r = r.at[:, 0:1, 0 : key.shape[2], :].set(key) + return r + + +def get_values(head, pages, paged_attention_page_size, step): + # 
values[0:1, 0:1, :9, 0:2] + v1 = [ + [ + [ + [-0.000770569, -0.019043], + [-0.00793457, -0.00564575], + [0.00234985, -0.0113525], + [0.00311279, -0.00210571], + [0.012085, 0.00242615], + [-0.00665283, -0.00382996], + [-0.000762939, -0.00738525], + [-0.00811768, 0.00646973], + [-0.00352478, -0.00128174], + ] + ] + ] + + # values[0:1, 0:1, :10, 0:2] + v2 = [ + [-0.000770569, -0.019043], + [-0.00793457, -0.00564575], + [0.00234985, -0.0113525], + [0.00311279, -0.00210571], + [0.012085, 0.00242615], + [-0.00665283, -0.00382996], + [-0.000762939, -0.00738525], + [-0.00811768, 0.00646973], + [-0.00352478, -0.00128174], + [0.0014801, -0.00915527], + ] + v = v1 if step == 0 else v2 + v = jnp.asarray(v) + v = jnp.tile(v, (1, 1, 1, 64)) + + r = jnp.zeros( + (head, pages, paged_attention_page_size, 128), dtype=jnp.float32 + ) + r = r.at[:, 0:1, 0 : v.shape[2], :].set(v) + return r + + +def get_seq_lens(step): + lens = [9] if step == 0 else [10] + return jnp.asarray(lens, dtype=jnp.int32) + + +def get_page_indices(): + # (1, 32) + indices = [0] * 32 + return jnp.asarray(indices, dtype=jnp.int32).reshape(-1, 32) + + +def get_output(): + # (1, 1, 2) + output = [[[-0.00352478, -0.001297]]] + return jnp.asarray(output) + + +def test_sharded_multi_page_grouped_query_attention(): + xq, keys, values, seq_lens, page_indices = get_data(0) + + paged_attention_page_size = 64 + block_size = 512 + + print(f"mesh shape:{mesh.shape}") + q_pspec = jax.sharding.NamedSharding(mesh, P(None, "x", None)) + kv_pspec = jax.sharding.NamedSharding(mesh, P("x", None, None, None)) + q_sharded = jax.device_put(xq, q_pspec) + k_pages_sharded = jax.device_put(keys, kv_pspec) + v_pages_sharded = jax.device_put(values, kv_pspec) + + paged_attention_impl = functools.partial( + multi_page_grouped_query_attention_fully_pipelined, + pages_per_compute_block=block_size // paged_attention_page_size, + ) + sharded_paged_attention_impl = shard_kv_heads( + paged_attention_impl, + mesh, + kv_head_mesh_axis_name="x", + ) + + def run(): + o_sharded = sharded_paged_attention_impl( + q_sharded, + k_pages_sharded, + v_pages_sharded, + seq_lens, + page_indices, + ) + return o_sharded + + with mesh: + return run() # warm up + + +def test_dense_attention(): + xq, keys, values, seq_lens, page_indices = get_data(0) + xq, keys, values, mask = get_dense_data( + xq, keys, values, seq_lens, page_indices + ) + output = dense_attention(xq, keys, values, mask=mask) + return output.squeeze(2) + + +def test_torch_dense_attention(): + xq, keys, values, seq_lens, page_indices = get_data(0) + xq, keys, values, mask = get_dense_data( + xq, keys, values, seq_lens, page_indices + ) + + xq, keys, values, mask = torchjax.to_torch((xq, keys, values, mask)) + + output = ak.dense_attention(xq, keys, values, mask=mask) + output = torchjax.from_torch(output) + return output.squeeze(2) + + +def _torch_dense_attention_with_saved_data(): + loaded_data = jnp.load("/home/**/data/test/paged_attention1.npy.npz") + xq = loaded_data["xq"] + keys = loaded_data["keys"] + values = loaded_data["values"] + seq_lens = loaded_data["seq_lens"] + page_indices = loaded_data["page_indices"] + output = loaded_data["output"] + print(f"output result: {output[0, 0, 0:10]}") + print(f"output result1: {output[0, 0, 0:10]}") + + q_pspec = jax.sharding.NamedSharding(mesh, P(None, "x", None)) + kv_pspec = jax.sharding.NamedSharding(mesh, P("x", None, None, None)) + replicated = jax.sharding.NamedSharding(mesh, P()) + xq = jax.device_put(xq, q_pspec) + keys = jax.device_put(keys, kv_pspec) + values = 
jax.device_put(values, kv_pspec) + seq_lens = jax.device_put(seq_lens, replicated) + page_indices = jax.device_put(page_indices, replicated) + xq, keys, values, mask = get_dense_data( + xq, keys, values, seq_lens, page_indices, n_heads=32 + ) + + xq, keys, values, mask = torchjax.to_torch((xq, keys, values, mask)) + + output = ak.dense_attention(xq, keys, values, mask=mask) + return output.squeeze(2) + + +def _dense_attention_with_saved_data(): + loaded_data = jnp.load("/home/**/data/test/paged_attention1.npy.npz") + xq = loaded_data["xq"] + keys = loaded_data["keys"] + values = loaded_data["values"] + seq_lens = loaded_data["seq_lens"] + page_indices = loaded_data["page_indices"] + output = loaded_data["output"] + print(f"output result: {output[0, 0, 0:10]}") + print(f"output result1: {output[0, 1, 0:10]}") + q_pspec = jax.sharding.NamedSharding(mesh, P(None, "x", None)) + kv_pspec = jax.sharding.NamedSharding(mesh, P("x", None, None, None)) + replicated = jax.sharding.NamedSharding(mesh, P()) + xq = jax.device_put(xq, q_pspec) + keys = jax.device_put(keys, kv_pspec) + values = jax.device_put(values, kv_pspec) + seq_lens = jax.device_put(seq_lens, replicated) + page_indices = jax.device_put(page_indices, replicated) + xq, keys, values, mask = get_dense_data( + xq, keys, values, seq_lens, page_indices, n_heads=32 + ) + + output = dense_attention(xq, keys, values, mask=mask) + return output.squeeze(2) + + +def _sharded_multi_page_grouped_query_attention_with_saved_data(): + loaded_data = jnp.load("/home/**/data/test/paged_attention1.npy.npz") + xq = loaded_data["xq"] + keys = loaded_data["keys"] + values = loaded_data["values"] + seq_lens = loaded_data["seq_lens"] + page_indices = loaded_data["page_indices"] + output = loaded_data["output"] + print(f"output : {output[0, 0, 0:10]}") + print(f"output : {output[0, 0, 0:10]}") + + paged_attention_page_size = 64 + block_size = 512 + + print(f"mesh shape:{mesh.shape}") + q_pspec = jax.sharding.NamedSharding(mesh, P(None, "x", None)) + kv_pspec = jax.sharding.NamedSharding(mesh, P("x", None, None, None)) + replicated = jax.sharding.NamedSharding(mesh, P()) + q_sharded = jax.device_put(xq, q_pspec) + k_pages_sharded = jax.device_put(keys, kv_pspec) + v_pages_sharded = jax.device_put(values, kv_pspec) + seq_lens = jax.device_put(seq_lens, replicated) + page_indices = jax.device_put(page_indices, replicated) + + paged_attention_impl = functools.partial( + multi_page_grouped_query_attention_fully_pipelined, + pages_per_compute_block=block_size // paged_attention_page_size, + ) + sharded_paged_attention_impl = shard_kv_heads( + paged_attention_impl, + mesh, + kv_head_mesh_axis_name="x", + ) + + def run(): + o_sharded = sharded_paged_attention_impl( + q_sharded, + k_pages_sharded, + v_pages_sharded, + seq_lens, + page_indices, + ) + return o_sharded + + with mesh: + result = run() # warm up + print(f"output result: {result[0, 0, 0:10]}") + print(f"output result1: {result[0, 1, 0:10]}") + print(f"array equal: {jnp.array_equal(result, output)}") + return result + + +def test_compare_attention_saved_data(): + p_loaded_data = jnp.load("/home/fanhai/data/test/paged_attention1.npy.npz") + + p_output = p_loaded_data["output"] + + print(f"p_output : {p_output[0, 0, 0:10]}") + print(f"p_output : {p_output[0, 1, 0:10]}") + + loaded_data = jnp.load("/home/fanhai/data/test/dense.npy.npz") + xq = loaded_data["xq"] + keys = loaded_data["keys"] + values = loaded_data["values"] + output = loaded_data["output"] + output = output[:, :, 0, :] + print(f"output : {output[0, 
0, 0:10]}") + print(f"output : {output[0, 1, 0:10]}") + print(f"array equal: {jnp.array_equal(p_output, output)}") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_kv_cache_manager.py b/tests/test_kv_cache_manager.py index 85df2a4..74c7bce 100644 --- a/tests/test_kv_cache_manager.py +++ b/tests/test_kv_cache_manager.py @@ -1,6 +1,7 @@ import unittest import jax +import numpy as np import jax.numpy as jnp import torch @@ -11,6 +12,8 @@ from jetstream_pt import torchjax from absl.testing import parameterized +P = jax.sharding.PartitionSpec + class PageAttentnioTest(parameterized.TestCase): @@ -37,6 +40,9 @@ def _make_env(self, bf16_enable=True): ) env = environment.JetEngineEnvironment(environment_data) env.apply_sharding = lambda *args, **kwargs: None # don't shard on cpu + mesh = jax.sharding.Mesh(np.array(jax.devices()), axis_names=("x",)) + replicated = jax.sharding.NamedSharding(mesh, P()) + env.sharding = replicated return env, config def test_page_attention_update(self): @@ -46,14 +52,15 @@ def test_page_attention_update(self): env, _ = self._make_env() pam = PageAttentionManager( - batch_size=5, total_num_pages=20, page_size=4, max_pages_per_sequence=4 + batch_size=5, + paged_attention_total_num_pages=20, + paged_attention_page_size=4, + max_pages_per_sequence=4, ) shape = (1, 20, 4, 2) decode_caches = [] decode_caches.append( - PageKVCacheGenerate.empty( - shape=shape, device=None, bf16_enable=True, env=env - ) + PageKVCacheGenerate.empty(shape=shape, device=None, env=env) ) decode_caches = [c.state() for c in decode_caches] @@ -68,23 +75,32 @@ def _insert_prefill(seq_len, dim, slot): prefill_chache.update(k, v, 0) prefill_caches = [prefill_chache] prefill_caches = [c.state() for c in prefill_caches] - - return pam.insert_prefill_cache( - prefill_caches, decode_caches, slot, seq_len, env.cache_sharding + num_pages, update_indexes = pam.reserve_pages_insert(slot, seq_len) + _, kv_heads, _, dim = prefill_caches[0][0].shape + tep_kv = jnp.zeros((kv_heads, num_pages * 4, dim), dtype=jnp.bfloat16) + + caches = pam.insert_prefill_cache( + prefill_caches=prefill_caches, + decode_caches=decode_caches, + update_indexes=update_indexes, + tep_kv=tep_kv, + sharding=env.sharding, ) + return caches + decode_caches = _insert_prefill(3, 2, 0) decode_caches = _insert_prefill(8, 2, 1) decode_caches = _insert_prefill(13, 2, 3) - lens = jnp.asarray([3, 8, 0, 13, 0]).reshape(5, 1) + lens = jnp.asarray([3, 8, 0, 13, 0]) pam.fill_new_pages(lens) page_token_indices = pam.get_page_token_indices(lens) page_token_indices = torchjax.to_torch(page_token_indices) caches_obj = [ PageKVCacheGenerate( - k, v, page_token_indices, self.cache_sharding, env=env + k, v, pam, page_token_indices, self.cache_sharding, env=env ) for k, v in torchjax.to_torch(decode_caches) ] diff --git a/tests/test_page_attention.py b/tests/test_page_attention.py index 4880fc8..294cb8d 100644 --- a/tests/test_page_attention.py +++ b/tests/test_page_attention.py @@ -1,6 +1,7 @@ import unittest import jax +import numpy as np import jax.numpy as jnp import torch @@ -10,6 +11,8 @@ from jetstream_pt.cache_manager import PageKVCacheGenerate, KVCachePrefill from absl.testing import parameterized +P = jax.sharding.PartitionSpec + class PageAttentionTest(parameterized.TestCase): @@ -19,6 +22,9 @@ def _make_env(self, bf16_enable=True): jax.config.update("jax_dynamic_shapes", False) jax.config.update("jax_traceback_filtering", "off") jax.config.update("jax_platform_name", "cpu") + jax.config.update("jax_enable_x64", False) + 
mesh = jax.sharding.Mesh(np.array(jax.devices()), axis_names=("x",)) + replicated = jax.sharding.NamedSharding(mesh, P()) config = model_args.get_model_args("tiny", 128, 1, True) environment_data = environment.JetEngineEnvironmentData() environment_data.max_input_sequence_length = 128 @@ -36,6 +42,7 @@ def _make_env(self, bf16_enable=True): ) env = environment.JetEngineEnvironment(environment_data) env.apply_sharding = lambda *args, **kwargs: None # don't shard on cpu + env.sharding = replicated return env, config def test_prefill_insert(self): @@ -43,14 +50,15 @@ def test_prefill_insert(self): env, _ = self._make_env() pam = PageAttentionManager( - batch_size=3, total_num_pages=20, page_size=4, max_pages_per_sequence=4 + batch_size=3, + paged_attention_total_num_pages=20, + paged_attention_page_size=4, + max_pages_per_sequence=4, ) shape = (1, 6, 4, 2) decode_caches = [] decode_caches.append( - PageKVCacheGenerate.empty( - shape=shape, device=None, bf16_enable=True, env=env - ) + PageKVCacheGenerate.empty(shape=shape, device=None, env=env) ) decode_caches = [c.state() for c in decode_caches] @@ -61,8 +69,30 @@ def test_prefill_insert(self): prefill_caches = [prefill_chache] prefill_caches = [c.state() for c in prefill_caches] - pam.insert_prefill_cache( - prefill_caches, decode_caches, 1, 3, env.x_sharding + num_pages, update_indexes = pam.reserve_pages_insert(0, 3) + _, kv_heads, _, dim = prefill_caches[0][0].shape + tep_kv = jnp.zeros((kv_heads, num_pages * 4, dim), dtype=jnp.bfloat16) + + caches = pam.insert_prefill_cache( + prefill_caches=prefill_caches, + decode_caches=decode_caches, + update_indexes=update_indexes, + tep_kv=tep_kv, + sharding=env.sharding, + ) + expected_kv = jnp.arange(6).reshape(3, 2) + padding = jnp.asarray([[0, 0]]) + expected_kv = jnp.concatenate((expected_kv, padding)) + + self.assertTrue( + jnp.array_equal( + caches[0][0][0, 0, 0:4, 0:2], expected_kv.astype(jnp.bfloat16) + ) + ) + self.assertTrue( + jnp.array_equal( + caches[0][1][0, 0, 0:4, 0:2], expected_kv.astype(jnp.bfloat16) + ) ) def test_prefill_insert_multiple_pages(self): @@ -73,14 +103,15 @@ def test_prefill_insert_multiple_pages(self): env, _ = self._make_env() pam = PageAttentionManager( - batch_size=3, total_num_pages=20, page_size=4, max_pages_per_sequence=4 + batch_size=3, + paged_attention_total_num_pages=20, + paged_attention_page_size=4, + max_pages_per_sequence=4, ) shape = (1, 20, 4, 2) decode_caches = [] decode_caches.append( - PageKVCacheGenerate.empty( - shape=shape, device=None, bf16_enable=True, env=env - ) + PageKVCacheGenerate.empty(shape=shape, device=None, env=env) ) decode_caches = [c.state() for c in decode_caches] @@ -93,9 +124,18 @@ def test_prefill_insert_multiple_pages(self): prefill_caches = [prefill_chache] prefill_caches = [c.state() for c in prefill_caches] + num_pages, update_indexes = pam.reserve_pages_insert(0, 6) + _, kv_heads, _, dim = prefill_caches[0][0].shape + tep_kv = jnp.zeros((kv_heads, num_pages * 4, dim), dtype=jnp.bfloat16) + decode_caches = pam.insert_prefill_cache( - prefill_caches, decode_caches, 1, 6, env.cache_sharding + prefill_caches=prefill_caches, + decode_caches=decode_caches, + update_indexes=update_indexes, + tep_kv=tep_kv, + sharding=env.sharding, ) + self.assertEqual(len(decode_caches), 1) expected = jnp.arange(16).at[12:16].set([0, 0, 0, 0]).reshape(1, 2, 4, 2) @@ -109,7 +149,10 @@ def test_reserve_pages_decode(self): env, _ = self._make_env() pam = PageAttentionManager( - batch_size=3, total_num_pages=20, page_size=4, 
max_pages_per_sequence=4 + batch_size=3, + paged_attention_total_num_pages=20, + paged_attention_page_size=4, + max_pages_per_sequence=4, ) slot = 1 seq_len = 8 @@ -122,13 +165,13 @@ def test_reserve_pages_decode(self): lens = jnp.asarray([0, seq_len, 0]) pam.fill_new_pages(lens) - expected_slot_page_indices = jnp.asarray([0, 1, 2]) - slot_page_indices = pam.page_indices[slot][0:3] + expected_slot_page_indices = jnp.asarray([0, 1, 2, 19]) + slot_page_indices = pam.page_indices[slot] self.assertTrue( jnp.array_equal(slot_page_indices, expected_slot_page_indices) ) - expected_0_page_indices = jnp.asarray([-1, -1, -1, -1]) + expected_0_page_indices = jnp.asarray([19, 19, 19, 19]) zer0_page_indices = pam.page_indices[0][0:4] self.assertTrue(jnp.array_equal(zer0_page_indices, expected_0_page_indices)) @@ -136,13 +179,16 @@ def test_get_page_token_indices(self): env, _ = self._make_env() pam = PageAttentionManager( - batch_size=5, total_num_pages=20, page_size=4, max_pages_per_sequence=4 + batch_size=5, + paged_attention_total_num_pages=20, + paged_attention_page_size=4, + max_pages_per_sequence=4, ) pam.reserve_pages_insert(1, 8) pam.reserve_pages_insert(3, 13) pam.reserve_pages_insert(0, 3) - lens = jnp.asarray([3, 8, 0, 13, 0]).reshape(5, 1) + lens = jnp.asarray([3, 8, 0, 13, 0]) pam.fill_new_pages(lens) page_token_indices = pam.get_page_token_indices(lens)
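
Note on the KV-head sharding added in jetstream_pt/attention_kernel.py: the sketch below is not part of the patch. It mirrors the in/out specs used by shard_kv_heads, but substitutes a trivial stand-in for the Pallas paged_attention kernel so it runs on CPU. The mesh axis name "x", the fake_paged_attention helper, and all tensor sizes are illustrative assumptions.

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map

P = jax.sharding.PartitionSpec
mesh = jax.sharding.Mesh(np.array(jax.devices()), axis_names=("x",))


def fake_paged_attention(q, k_pages, v_pages, lengths, page_indices):
  # Stand-in for the Pallas paged_attention kernel (assumption: same argument
  # order as the real kernel). It only demonstrates the sharded shapes, so it
  # returns zeros with q's per-shard shape [batch, heads_per_shard, head_dim].
  del k_pages, v_pages, lengths, page_indices
  return jnp.zeros_like(q)


sharded = jax.jit(
    shard_map(
        fake_paged_attention,
        mesh=mesh,
        in_specs=(
            P(None, "x", None),        # q: [batch, q_heads, head_dim], split on heads
            P("x", None, None, None),  # k_pages: [kv_heads, pages, page_size, dim]
            P("x", None, None, None),  # v_pages: same layout as k_pages
            P(),                       # lengths: replicated
            P(),                       # page_indices: replicated
        ),
        out_specs=P(None, "x", None),  # output gathered back along the head axis
        check_rep=False,
    )
)

# Tiny illustrative shapes; the head counts equal the device count so the
# "x" axis divides them evenly.
batch, heads, dim, pages, page_size = 1, len(jax.devices()), 16, 8, 4
q = jnp.ones((batch, heads, dim))
k_pages = jnp.ones((heads, pages, page_size, dim))
v_pages = jnp.ones((heads, pages, page_size, dim))
lengths = jnp.asarray([4], dtype=jnp.int32)
page_indices = jnp.zeros((batch, 2), dtype=jnp.int32)
print(sharded(q, k_pages, v_pages, lengths, page_indices).shape)  # (1, heads, 16)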
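
Note on jetstream_pt/page_attention_manager.py: the following is a deliberately simplified, hypothetical page-table class (the names TinyPageTable, reserve_for_prefill, and maybe_grow_for_decode are made up) that sketches the bookkeeping PageAttentionManager performs: a free list of physical pages, per-slot page indices, and claiming one extra page whenever a decoding sequence crosses a page boundary.

import queue


class TinyPageTable:
  """Minimal sketch of paged-KV bookkeeping; not the PR's PageAttentionManager."""

  def __init__(self, batch_size, total_pages, page_size, max_pages_per_seq):
    self.page_size = page_size
    self.free = queue.Queue()
    for i in range(total_pages):
      self.free.put(i, block=False)
    # One row of physical page ids per decode slot; -1 marks "unassigned".
    self.page_indices = [[-1] * max_pages_per_seq for _ in range(batch_size)]
    self.lengths = [0] * batch_size

  def reserve_for_prefill(self, slot, seq_len):
    """Claim enough pages from the free list to hold seq_len prefill tokens."""
    num_pages = -(-seq_len // self.page_size)  # ceil division
    pages = [self.free.get(block=False) for _ in range(num_pages)]
    self.page_indices[slot][:num_pages] = pages
    self.lengths[slot] = seq_len
    return pages

  def maybe_grow_for_decode(self, slot):
    """Claim one more page when the next decoded token starts a new page."""
    seq_len = self.lengths[slot]
    if seq_len > 0 and seq_len % self.page_size == 0:
      used = seq_len // self.page_size
      self.page_indices[slot][used] = self.free.get(block=False)
    self.lengths[slot] = seq_len + 1


table = TinyPageTable(batch_size=2, total_pages=8, page_size=4, max_pages_per_seq=4)
print(table.reserve_for_prefill(slot=0, seq_len=6))  # two pages claimed, e.g. [0, 1]
for _ in range(3):                                   # decode tokens 7, 8, 9
  table.maybe_grow_for_decode(slot=0)
print(table.page_indices[0])                         # a third page claimed at token 9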
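
Note on the new tests: they check the paged kernel against a dense attention reference. The sketch below, with made-up shapes, shows the same idea in a CPU-runnable form: gather one slot's pages into a contiguous [heads, len, dim] layout, mask out the padding beyond the true length, and run ordinary softmax attention.

import math
import jax
import jax.numpy as jnp


def gather_pages(pages, page_indices):
  # pages: [heads, total_pages, page_size, dim] -> [heads, n_pages * page_size, dim]
  heads, _, _, dim = pages.shape
  return pages[:, page_indices, :, :].reshape(heads, -1, dim)


def dense_reference(q, k_pages, v_pages, length, page_indices):
  # q: [heads, dim]; keys/values are gathered from the slot's assigned pages.
  keys = gather_pages(k_pages, page_indices)
  values = gather_pages(v_pages, page_indices)
  scores = jnp.einsum("hd,hld->hl", q, keys) / math.sqrt(q.shape[-1])
  # Positions past the true sequence length are padding in the last page.
  mask = jnp.where(jnp.arange(keys.shape[1]) < length, 0.0, -jnp.inf)
  weights = jax.nn.softmax(scores + mask, axis=-1)
  return jnp.einsum("hl,hld->hd", weights, values)


heads, total_pages, page_size, dim = 2, 8, 4, 16
k_pages = jax.random.normal(jax.random.PRNGKey(0), (heads, total_pages, page_size, dim))
v_pages = jax.random.normal(jax.random.PRNGKey(1), (heads, total_pages, page_size, dim))
q = jax.random.normal(jax.random.PRNGKey(2), (heads, dim))
page_indices = jnp.asarray([3, 5])  # physical pages assigned to this slot
out = dense_reference(q, k_pages, v_pages, length=6, page_indices=page_indices)
print(out.shape)  # (2, 16)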