From a78f80647853c1b2ffbf7084bb25155966a411d4 Mon Sep 17 00:00:00 2001
From: Yuyan Peng
Date: Mon, 31 Mar 2025 23:52:03 +0000
Subject: [PATCH] Refactor chunked prefill

---
 jetstream/core/orchestrator.py | 155 +++++++++++++++------------------
 jetstream/engine/engine_api.py |  12 +--
 2 files changed, 77 insertions(+), 90 deletions(-)

diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py
index 9b057baf..01ee2f97 100644
--- a/jetstream/core/orchestrator.py
+++ b/jetstream/core/orchestrator.py
@@ -91,7 +91,6 @@
 import grpc
 import jax
-import jax.numpy as jnp
 
 from jetstream.core.proto import jetstream_pb2
 from jetstream.core.proto import jetstream_pb2_grpc
@@ -523,61 +522,19 @@ def _process_prefill_content(
       tokenizer: tokenizer_api.Tokenizer,
       is_bos: bool,
       max_prefill_length: int,
-      chunked_prefill: bool = False,
-      chunk_size: Optional[int] = None,
-  ) -> (
-      Tuple[(jax.Array | np.ndarray), int, jax.Array]
-      | Tuple[
-          list[jax.Array | np.ndarray],
-          list[int],
-          list[jax.Array],
-      ]
-  ):
-    assert (chunked_prefill and chunk_size is not None) or (
-        not chunked_prefill
-    ), "Set chunk_size when chunked_prefill is True to use chunked prefill"
-
+  ) -> Tuple[jax.Array | np.ndarray, int]:
     content = request.prefill_content
     if isinstance(content, str):
       # If it's text input, tokenize and pad the input.
-      tokens, true_length = tokenizer.encode(
+      return tokenizer.encode(
           content,
           is_bos=is_bos,
           max_prefill_length=max_prefill_length,
           jax_padding=self._jax_padding,
       )
-      positions = jnp.expand_dims(
-          jnp.arange(0, len(tokens), dtype=jnp.int32), 0
-      )
-
-      if chunked_prefill and chunk_size is not None:
-        # tokenizer.encode handle the is_bos already,
-        # set is_bos to False while chunking
-        return token_utils.chunk_and_pad_tokens(
-            tokens[:true_length],
-            tokenizer.bos_id,
-            tokenizer.pad_id,
-            is_bos=False,
-            max_prefill_length=max_prefill_length,
-            chunk_size=chunk_size,
-            jax_padding=self._jax_padding,
-        )
-      return tokens, true_length, positions
-    else:
-      if chunked_prefill and chunk_size is not None:
-        return token_utils.chunk_and_pad_tokens(
-            content,
-            tokenizer.bos_id,
-            tokenizer.pad_id,
-            is_bos=is_bos,
-            max_prefill_length=max_prefill_length,
-            chunk_size=chunk_size,
-            jax_padding=self._jax_padding,
-        )
-      # If it's token input, pad the input.
-      tokens, true_length = token_utils.pad_tokens(
+    # If it's token input, pad the input.
+    return token_utils.pad_tokens(
         content,
         tokenizer.bos_id,
         tokenizer.pad_id,
@@ -585,10 +542,63 @@ def _process_prefill_content(
         max_prefill_length=max_prefill_length,
         jax_padding=self._jax_padding,
     )
-    positions = jnp.expand_dims(
-        jnp.arange(0, len(tokens), dtype=jnp.int32), 0
+
+  def _do_chunked_prefill(
+      self,
+      prefill_engine: engine_api.Engine,
+      prefill_params: Any,
+      tokenizer: tokenizer_api.Tokenizer,
+      tokens: jax.Array | np.ndarray,
+  ) -> Tuple[engine_api.Prefix, engine_api.ResultTokens]:
+    """Do chunked prefill.
+
+    Should not be used without enabling the use_chunked_prefill config.
+ """ + + assert prefill_engine.use_chunked_prefill + + prefill_result = None + first_token = None + + existing_prefix = None + for start_pos in range( + 0, + len(tokens), + prefill_engine.prefill_chunk_size, + ): + input_token = tokens[ + start_pos : min( + len(tokens), start_pos + prefill_engine.prefill_chunk_size + ) + ] + padded_input_token, input_true_length = token_utils.pad_tokens( + input_token, + tokenizer.bos_id, + tokenizer.pad_id, + is_bos=False, + max_prefill_length=prefill_engine.max_prefill_length, + jax_padding=self._jax_padding, + ) + prefill_result, first_token = prefill_engine.prefill( + params=prefill_params, + existing_prefix=existing_prefix, + padded_tokens=padded_input_token, + true_length=input_true_length, + ) + existing_prefix = engine_api.ExistingPrefix( + cache=prefill_result["cache"], + common_prefix_tokens=tokens[ + 0 : min( + len(tokens), start_pos + prefill_engine.prefill_chunk_size + ) + ], ) - return tokens, true_length, positions + + # Should assign in the loop + assert prefill_result is not None + assert first_token is not None + + return prefill_result, first_token def _prefill_thread(self, idx: int): """Thread which runs in the background performing prefills.""" @@ -616,12 +626,11 @@ def _prefill_thread(self, idx: int): f" is_bos: {is_bos}", ) # Tokenize and padding the text or token input. - padded_tokens, true_length, _ = self._process_prefill_content( + padded_tokens, true_length = self._process_prefill_content( request, tokenizer, is_bos, prefill_engine.max_prefill_length, - False, ) # Compute new kv cache for the prefill_content. @@ -636,40 +645,12 @@ def _prefill_thread(self, idx: int): else: # if chunked_prefill is used, if prefill_engine.use_chunked_prefill: - padded_chunked_tokens, true_lengths_of_chunks, positions_chunks = ( - self._process_prefill_content( - request, - tokenizer, - is_bos, - prefill_engine.max_prefill_length, - prefill_engine.use_chunked_prefill, - prefill_engine.prefill_chunk_size, - ) + prefill_result, first_token = self._do_chunked_prefill( + prefill_engine, + prefill_params, + tokenizer, + padded_tokens[:true_length], ) - prefill_result = None - for chunk_num, _ in enumerate(padded_chunked_tokens): - cache_so_far = ( - {} if prefill_result is None else prefill_result["cache"] # pylint: disable=unsubscriptable-object - ) - prefill_result, first_token = prefill_engine.prefill( - params=prefill_params | {"cache": cache_so_far}, - padded_tokens=padded_chunked_tokens[chunk_num], - true_length=true_lengths_of_chunks[chunk_num], - positions=positions_chunks[chunk_num], - previous_chunk=prefill_result, - complete_prompt_true_length=true_length, - complete_padded_prompt=padded_tokens, - ) - # true_length_array is arrays of 1 true lengths so far - t_l_array = jnp.expand_dims( - jnp.arange( - 0, - chunk_num * prefill_engine.prefill_chunk_size - + true_lengths_of_chunks[chunk_num], - ), - 0, - ) - prefill_result["true_length_array"] = t_l_array else: # Compute new kv cache for the prefill_content. 
           prefill_result, first_token = prefill_engine.prefill(
@@ -677,8 +658,12 @@ def _prefill_thread(self, idx: int):
               params=prefill_params,
               padded_tokens=padded_tokens,
               true_length=true_length,
           )
+
+      request.complete = np.zeros(
+          (prefill_engine.samples_per_slot,), np.bool_
+      )
+
       request.prefill_result = prefill_result
-      request.complete = np.zeros((prefill_engine.samples_per_slot,), np.bool_)
 
       # put first token to detokenize queue
       my_detokenize_backlog = self._detokenize_backlogs[idx]

diff --git a/jetstream/engine/engine_api.py b/jetstream/engine/engine_api.py
index 1c85b256..5e5bb044 100644
--- a/jetstream/engine/engine_api.py
+++ b/jetstream/engine/engine_api.py
@@ -47,6 +47,12 @@
 PRNGKeyType = Any
 
 
+@struct.dataclass
+class ExistingPrefix:
+  cache: Any
+  common_prefix_tokens: jax.Array
+
+
 @struct.dataclass
 class SlotData:
   """Class to store slot data."""
@@ -157,14 +163,10 @@ def prefill(
       self,
       *,
       params: Params,
-      existing_prefix: Optional[Prefix] = None,
+      existing_prefix: Optional[ExistingPrefix] = None,
       padded_tokens: jax.Array,
       true_length: int,
       sampler: Optional[Callable[[Any], Any]] = None,
-      complete_prompt_true_length: Optional[int] = None,
-      complete_padded_prompt: Optional[jax.Array] = None,
-      positions: Optional[jax.Array] = None,
-      previous_chunk: Optional[Any] = None,
       request_id: Optional[uuid.UUID] = None,
   ) -> Tuple[Prefix, ResultTokens]:
     """Computes a kv-cache for a set of tokens conditional on existing cache.
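
A minimal usage sketch of the calling convention this patch introduces, mirroring
the _do_chunked_prefill loop added above. It is illustrative only and not part of
the patch: engine, params, tokenizer, and tokens are hypothetical stand-ins for a
concrete Engine, its loaded parameters, a tokenizer_api.Tokenizer, and an unpadded
prompt whose BOS token has already been applied.

from typing import Any, Optional

import numpy as np

from jetstream.engine import engine_api, token_utils


def chunked_prefill_sketch(
    engine: engine_api.Engine,  # hypothetical engine with use_chunked_prefill on
    params: Any,                # model params loaded elsewhere
    tokenizer: Any,             # a tokenizer_api.Tokenizer instance
    tokens: np.ndarray,         # unpadded prompt tokens, BOS already applied
):
  """Illustrative driver for the chunked prefill API in this patch."""
  existing_prefix: Optional[engine_api.ExistingPrefix] = None
  prefill_result, first_token = None, None

  for start in range(0, len(tokens), engine.prefill_chunk_size):
    end = min(len(tokens), start + engine.prefill_chunk_size)
    # Pad only the current chunk; is_bos=False because BOS is already in tokens.
    padded_chunk, chunk_true_length = token_utils.pad_tokens(
        tokens[start:end],
        tokenizer.bos_id,
        tokenizer.pad_id,
        is_bos=False,
        max_prefill_length=engine.max_prefill_length,
        jax_padding=True,
    )
    # Each call conditions on the KV cache built from the previous chunks.
    prefill_result, first_token = engine.prefill(
        params=params,
        existing_prefix=existing_prefix,
        padded_tokens=padded_chunk,
        true_length=chunk_true_length,
    )
    # Carry forward the cache plus every token prefilled so far.
    existing_prefix = engine_api.ExistingPrefix(
        cache=prefill_result["cache"],
        common_prefix_tokens=tokens[:end],
    )

  return prefill_result, first_token

Because ExistingPrefix carries both the cache and the full prefix token sequence,
an engine can derive chunk positions internally; that is what lets this patch drop
the positions, previous_chunk, complete_padded_prompt, and
complete_prompt_true_length parameters from prefill().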