Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 70 additions & 85 deletions jetstream/core/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@

import grpc
import jax
import jax.numpy as jnp

from jetstream.core.proto import jetstream_pb2
from jetstream.core.proto import jetstream_pb2_grpc
Expand Down Expand Up @@ -523,72 +522,83 @@ def _process_prefill_content(
tokenizer: tokenizer_api.Tokenizer,
is_bos: bool,
max_prefill_length: int,
chunked_prefill: bool = False,
chunk_size: Optional[int] = None,
) -> (
Tuple[(jax.Array | np.ndarray), int, jax.Array]
| Tuple[
list[jax.Array | np.ndarray],
list[int],
list[jax.Array],
]
):
assert (chunked_prefill and chunk_size is not None) or (
not chunked_prefill
), "Set chunk_size when chunked_prefill is True to use chunked prefill"

) -> Tuple[jax.Array | np.ndarray, int]:
content = request.prefill_content
if isinstance(content, str):
# If it's text input, tokenize and pad the input.
tokens, true_length = tokenizer.encode(
return tokenizer.encode(
content,
is_bos=is_bos,
max_prefill_length=max_prefill_length,
jax_padding=self._jax_padding,
)
positions = jnp.expand_dims(
jnp.arange(0, len(tokens), dtype=jnp.int32), 0
)

if chunked_prefill and chunk_size is not None:
# tokenizer.encode has already handled is_bos,
# so pass is_bos=False when chunking.
return token_utils.chunk_and_pad_tokens(
tokens[:true_length],
tokenizer.bos_id,
tokenizer.pad_id,
is_bos=False,
max_prefill_length=max_prefill_length,
chunk_size=chunk_size,
jax_padding=self._jax_padding,
)
return tokens, true_length, positions

else:
if chunked_prefill and chunk_size is not None:
return token_utils.chunk_and_pad_tokens(
content,
tokenizer.bos_id,
tokenizer.pad_id,
is_bos=is_bos,
max_prefill_length=max_prefill_length,
chunk_size=chunk_size,
jax_padding=self._jax_padding,
)

# If it's token input, pad the input.
tokens, true_length = token_utils.pad_tokens(
return token_utils.pad_tokens(
content,
tokenizer.bos_id,
tokenizer.pad_id,
is_bos=is_bos,
max_prefill_length=max_prefill_length,
jax_padding=self._jax_padding,
)
positions = jnp.expand_dims(
jnp.arange(0, len(tokens), dtype=jnp.int32), 0

def _do_chunked_prefill(
    self,
    prefill_engine: engine_api.Engine,
    prefill_params: Any,
    tokenizer: tokenizer_api.Tokenizer,
    tokens: jax.Array | np.ndarray,
) -> Tuple[engine_api.Prefix, engine_api.ResultTokens]:
  """Prefills `tokens` one chunk at a time, reusing the cache built so far.

  `tokens` is walked in consecutive slices of
  `prefill_engine.prefill_chunk_size`. Each slice is padded with
  `token_utils.pad_tokens` and passed to `prefill_engine.prefill` together
  with an `ExistingPrefix` describing everything prefetched before it (none
  for the first slice).

  Args:
    prefill_engine: Engine used for prefill; must have `use_chunked_prefill`
      enabled.
    prefill_params: Parameters forwarded unchanged to every `prefill` call.
    tokenizer: Supplies the `bos_id` and `pad_id` used when padding a chunk.
    tokens: The prompt tokens to prefill. Presumably unpadded (the visible
      caller passes `padded_tokens[:true_length]`) -- confirm for new callers.

  Returns:
    The `(prefill_result, first_token)` pair produced by the `prefill` call
    for the last chunk.

  Raises:
    AssertionError: If chunked prefill is not enabled on `prefill_engine`, or
      if `tokens` is empty so that no chunk was ever prefetched.
  """
  assert (
      prefill_engine.use_chunked_prefill
  ), "_do_chunked_prefill requires an engine with use_chunked_prefill enabled"

  # Loop invariants, read once.
  chunk_size = prefill_engine.prefill_chunk_size
  max_prefill_length = prefill_engine.max_prefill_length
  num_tokens = len(tokens)

  prefill_result = None
  first_token = None
  existing_prefix = None

  for start_pos in range(0, num_tokens, chunk_size):
    # The last chunk may be shorter than chunk_size.
    end_pos = min(num_tokens, start_pos + chunk_size)

    # is_bos=False: no BOS token is inserted per chunk. Presumably any BOS
    # handling was already applied to the full prompt by the caller.
    padded_chunk, chunk_true_length = token_utils.pad_tokens(
        tokens[start_pos:end_pos],
        tokenizer.bos_id,
        tokenizer.pad_id,
        is_bos=False,
        max_prefill_length=max_prefill_length,
        jax_padding=self._jax_padding,
    )
    prefill_result, first_token = prefill_engine.prefill(
        params=prefill_params,
        existing_prefix=existing_prefix,
        padded_tokens=padded_chunk,
        true_length=chunk_true_length,
    )
    # The next chunk is prefetched on top of the cache just produced, which
    # now covers every token up to the end of this chunk.
    existing_prefix = engine_api.ExistingPrefix(
        cache=prefill_result["cache"],
        common_prefix_tokens=tokens[:end_pos],
    )

  # Both are assigned in the loop; they stay None only when `tokens` is empty.
  assert prefill_result is not None, "chunked prefill needs at least one token"
  assert first_token is not None, "chunked prefill needs at least one token"

  return prefill_result, first_token

def _prefill_thread(self, idx: int):
"""Thread which runs in the background performing prefills."""
Expand Down Expand Up @@ -616,12 +626,11 @@ def _prefill_thread(self, idx: int):
f" is_bos: {is_bos}",
)
# Tokenize and pad the text or token input.
padded_tokens, true_length, _ = self._process_prefill_content(
padded_tokens, true_length = self._process_prefill_content(
request,
tokenizer,
is_bos,
prefill_engine.max_prefill_length,
False,
)

# Compute new kv cache for the prefill_content.
Expand All @@ -636,49 +645,25 @@ def _prefill_thread(self, idx: int):
else:
# If chunked prefill is enabled, prefill the prompt chunk by chunk.
if prefill_engine.use_chunked_prefill:
padded_chunked_tokens, true_lengths_of_chunks, positions_chunks = (
self._process_prefill_content(
request,
tokenizer,
is_bos,
prefill_engine.max_prefill_length,
prefill_engine.use_chunked_prefill,
prefill_engine.prefill_chunk_size,
)
prefill_result, first_token = self._do_chunked_prefill(
prefill_engine,
prefill_params,
tokenizer,
padded_tokens[:true_length],
)
prefill_result = None
for chunk_num, _ in enumerate(padded_chunked_tokens):
cache_so_far = (
{} if prefill_result is None else prefill_result["cache"] # pylint: disable=unsubscriptable-object
)
prefill_result, first_token = prefill_engine.prefill(
params=prefill_params | {"cache": cache_so_far},
padded_tokens=padded_chunked_tokens[chunk_num],
true_length=true_lengths_of_chunks[chunk_num],
positions=positions_chunks[chunk_num],
previous_chunk=prefill_result,
complete_prompt_true_length=true_length,
complete_padded_prompt=padded_tokens,
)
# true_length_array is arrays of 1 true lengths so far
t_l_array = jnp.expand_dims(
jnp.arange(
0,
chunk_num * prefill_engine.prefill_chunk_size
+ true_lengths_of_chunks[chunk_num],
),
0,
)
prefill_result["true_length_array"] = t_l_array
else:
# Compute new kv cache for the prefill_content.
prefill_result, first_token = prefill_engine.prefill(
params=prefill_params,
padded_tokens=padded_tokens,
true_length=true_length,
)

request.complete = np.zeros(
(prefill_engine.samples_per_slot,), np.bool_
)

request.prefill_result = prefill_result
request.complete = np.zeros((prefill_engine.samples_per_slot,), np.bool_)

# put first token to detokenize queue
my_detokenize_backlog = self._detokenize_backlogs[idx]
Expand Down
12 changes: 7 additions & 5 deletions jetstream/engine/engine_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@
PRNGKeyType = Any


@struct.dataclass
class ExistingPrefix:
  """An already-computed cache together with the tokens it covers.

  Passed to `Engine.prefill` as `existing_prefix` so that prefill can be
  conditioned on a cache that already exists.
  """

  # Cache to build on. Left untyped (`Any`); presumably the engine-specific
  # kv-cache taken from an earlier prefill result -- confirm per engine.
  cache: Any
  # Tokens of the prefix that `cache` was computed for.
  common_prefix_tokens: jax.Array


@struct.dataclass
class SlotData:
"""Class to store slot data."""
Expand Down Expand Up @@ -157,14 +163,10 @@ def prefill(
self,
*,
params: Params,
existing_prefix: Optional[Prefix] = None,
existing_prefix: Optional[ExistingPrefix] = None,
padded_tokens: jax.Array,
true_length: int,
sampler: Optional[Callable[[Any], Any]] = None,
complete_prompt_true_length: Optional[int] = None,
complete_padded_prompt: Optional[jax.Array] = None,
positions: Optional[jax.Array] = None,
previous_chunk: Optional[Any] = None,
request_id: Optional[uuid.UUID] = None,
) -> Tuple[Prefix, ResultTokens]:
"""Computes a kv-cache for a set of tokens conditional on existing cache.
Expand Down
Loading