Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deps/JetStream
29 changes: 18 additions & 11 deletions jetstream_pt/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Implement Jet Engine API."""

from typing import Any, List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union, Callable
import threading
import functools
import os
Expand Down Expand Up @@ -256,6 +256,7 @@ def prefill(
existing_prefix: Optional[Prefix] = None,
padded_tokens: PrefillInputs, # PrefillInputs[jax.Array],
true_length: int,
sampler: Optional[Callable[[Any], Any]] = None,
) -> Tuple[Prefix, engine_api.ResultTokens]:
if isinstance(padded_tokens, jax.Array):
batched_token = padded_tokens.reshape(1, -1)
Expand All @@ -273,14 +274,17 @@ def prefill(
)
if len(logits.shape) == 3: # b, seqlen, num words
logits = logits[0] # seqlen, num words
token = sampling_utils.sampling(
logits[true_length - 1],
self.rng,
self.env.sampling_algorithm,
self.env.topk,
self.env.nucleus_topp,
self.env.temperature,
)
if sampler:
token = sampler(logits[true_length - 1])
else:
token = sampling_utils.sampling(
logits[true_length - 1],
self.rng,
self.env.sampling_algorithm,
self.env.topk,
self.env.nucleus_topp,
self.env.temperature,
)
token_out = jnp.reshape(token, (1, 1))
data = jnp.concatenate(
[
Expand Down Expand Up @@ -610,7 +614,7 @@ def false_comp(b, i, bk, start, end):
return b_next, i_next

def generate(
self, params: Any, decode_state: DecodeState
self, params: Any, decode_state: DecodeState, sampler=None
) -> tuple[DecodeState, engine_api.ResultTokens]:
# seq_len = padded_tokens.shape[0]
pos = decode_state.current_position
Expand Down Expand Up @@ -653,7 +657,10 @@ def update_mask():
# fill mask later, now use flash attention
mask = update_mask()

next_token = self._sampling(logits, self.env.batch_size)
if sampler:
next_token = sampler(logits[:, -1])
else:
next_token = self._sampling(logits, self.env.batch_size)
if self.env.ring_buffer:
input_pos = decode_state.input_pos + 1
lens = decode_state.lens + 1
Expand Down
6 changes: 5 additions & 1 deletion jetstream_pt/ray_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def prefill(
existing_prefix: Optional[Prefix] = None,
padded_tokens: np.ndarray, # PrefillInputs[np.ndarray],
true_length: int,
sampler=None,
) -> Tuple[Prefix, engine_api.ResultTokens]:
if self.is_disaggregated:
return self.prefill_impl(
Expand Down Expand Up @@ -144,7 +145,10 @@ def insert(
return None

def generate(
self, params: Any, decode_state: DecodeState
self,
params: Any,
decode_state: DecodeState,
sampler=None,
) -> tuple[None, engine_api.ResultTokens]:
if self.is_disaggregated:
return self.generate_impl(params=params, decode_state=decode_state)
Expand Down