diff --git a/deps/JetStream b/deps/JetStream index 69ce8a2..93de590 160000 --- a/deps/JetStream +++ b/deps/JetStream @@ -1 +1 @@ -Subproject commit 69ce8a2646ac32bea9194019078248b49e69728e +Subproject commit 93de5901a19d5271ceea7de107406b6a40f52c0c diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index f375cb5..2a8ad31 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -14,7 +14,7 @@ """Implement Jet Engine API.""" -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union, Callable import threading import functools import os @@ -256,6 +256,7 @@ def prefill( existing_prefix: Optional[Prefix] = None, padded_tokens: PrefillInputs, # PrefillInputs[jax.Array], true_length: int, + sampler: Optional[Callable[[Any], Any]] = None, ) -> Tuple[Prefix, engine_api.ResultTokens]: if isinstance(padded_tokens, jax.Array): batched_token = padded_tokens.reshape(1, -1) @@ -273,14 +274,17 @@ def prefill( ) if len(logits.shape) == 3: # b, seqlen, num words logits = logits[0] # seqlen, num words - token = sampling_utils.sampling( - logits[true_length - 1], - self.rng, - self.env.sampling_algorithm, - self.env.topk, - self.env.nucleus_topp, - self.env.temperature, - ) + if sampler: + token = sampler(logits[true_length - 1]) + else: + token = sampling_utils.sampling( + logits[true_length - 1], + self.rng, + self.env.sampling_algorithm, + self.env.topk, + self.env.nucleus_topp, + self.env.temperature, + ) token_out = jnp.reshape(token, (1, 1)) data = jnp.concatenate( [ @@ -610,7 +614,7 @@ def false_comp(b, i, bk, start, end): return b_next, i_next def generate( - self, params: Any, decode_state: DecodeState + self, params: Any, decode_state: DecodeState, sampler=None ) -> tuple[DecodeState, engine_api.ResultTokens]: # seq_len = padded_tokens.shape[0] pos = decode_state.current_position @@ -653,7 +657,10 @@ def update_mask(): # fill mask later, now use flash attention mask = update_mask() - next_token = self._sampling(logits, self.env.batch_size) + if sampler: + next_token = sampler(logits[:, -1]) + else: + next_token = self._sampling(logits, self.env.batch_size) if self.env.ring_buffer: input_pos = decode_state.input_pos + 1 lens = decode_state.lens + 1 diff --git a/jetstream_pt/ray_engine.py b/jetstream_pt/ray_engine.py index a064a44..8cdcd3f 100644 --- a/jetstream_pt/ray_engine.py +++ b/jetstream_pt/ray_engine.py @@ -70,6 +70,7 @@ def prefill( existing_prefix: Optional[Prefix] = None, padded_tokens: np.ndarray, # PrefillInputs[np.ndarray], true_length: int, + sampler=None, ) -> Tuple[Prefix, engine_api.ResultTokens]: if self.is_disaggregated: return self.prefill_impl( @@ -144,7 +145,10 @@ def insert( return None def generate( - self, params: Any, decode_state: DecodeState + self, + params: Any, + decode_state: DecodeState, + sampler=None, ) -> tuple[None, engine_api.ResultTokens]: if self.is_disaggregated: return self.generate_impl(params=params, decode_state=decode_state)