From eaf393caabc3b5aa3a4520640c57c35495792757 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Fri, 30 Aug 2024 00:44:01 +0000 Subject: [PATCH 1/2] Update Jetstream, add optional sampler args. --- deps/JetStream | 2 +- jetstream_pt/engine.py | 29 ++++++++++++++++++----------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/deps/JetStream b/deps/JetStream index 69ce8a2..93de590 160000 --- a/deps/JetStream +++ b/deps/JetStream @@ -1 +1 @@ -Subproject commit 69ce8a2646ac32bea9194019078248b49e69728e +Subproject commit 93de5901a19d5271ceea7de107406b6a40f52c0c diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index f375cb5..85e3301 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -14,7 +14,7 @@ """Implement Jet Engine API.""" -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union, Callable import threading import functools import os @@ -256,6 +256,7 @@ def prefill( existing_prefix: Optional[Prefix] = None, padded_tokens: PrefillInputs, # PrefillInputs[jax.Array], true_length: int, + sampler: Optional[Callable[[Any], Any]] = None, ) -> Tuple[Prefix, engine_api.ResultTokens]: if isinstance(padded_tokens, jax.Array): batched_token = padded_tokens.reshape(1, -1) @@ -273,14 +274,17 @@ def prefill( ) if len(logits.shape) == 3: # b, seqlen, num words logits = logits[0] # seqlen, num words - token = sampling_utils.sampling( - logits[true_length - 1], - self.rng, - self.env.sampling_algorithm, - self.env.topk, - self.env.nucleus_topp, - self.env.temperature, - ) + if sampler: + token = sampler(logits[true_length - 1]) + else: + token = sampling_utils.sampling( + logits[true_length - 1], + self.rng, + self.env.sampling_algorithm, + self.env.topk, + self.env.nucleus_topp, + self.env.temperature, + ) token_out = jnp.reshape(token, (1, 1)) data = jnp.concatenate( [ @@ -610,7 +614,7 @@ def false_comp(b, i, bk, start, end): return b_next, i_next def generate( - self, params: Any, decode_state: DecodeState + self, params: Any, decode_state: DecodeState, sampler = None ) -> tuple[DecodeState, engine_api.ResultTokens]: # seq_len = padded_tokens.shape[0] pos = decode_state.current_position @@ -653,7 +657,10 @@ def update_mask(): # fill mask later, now use flash attention mask = update_mask() - next_token = self._sampling(logits, self.env.batch_size) + if sampler: + next_token = sampler(logits[:, -1]) + else: + next_token = self._sampling(logits, self.env.batch_size) if self.env.ring_buffer: input_pos = decode_state.input_pos + 1 lens = decode_state.lens + 1 From 4b5ce4019ac72337bf8d1012622d18cc2bc6a943 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Fri, 30 Aug 2024 18:58:16 +0000 Subject: [PATCH 2/2] Ray lint --- jetstream_pt/engine.py | 2 +- jetstream_pt/ray_engine.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 85e3301..2a8ad31 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -614,7 +614,7 @@ def false_comp(b, i, bk, start, end): return b_next, i_next def generate( - self, params: Any, decode_state: DecodeState, sampler = None + self, params: Any, decode_state: DecodeState, sampler=None ) -> tuple[DecodeState, engine_api.ResultTokens]: # seq_len = padded_tokens.shape[0] pos = decode_state.current_position diff --git a/jetstream_pt/ray_engine.py b/jetstream_pt/ray_engine.py index a064a44..8cdcd3f 100644 --- a/jetstream_pt/ray_engine.py +++ b/jetstream_pt/ray_engine.py @@ -70,6 +70,7 @@ def prefill( existing_prefix: Optional[Prefix] = None, padded_tokens: np.ndarray, # PrefillInputs[np.ndarray], true_length: int, + sampler=None, ) -> Tuple[Prefix, engine_api.ResultTokens]: if self.is_disaggregated: return self.prefill_impl( @@ -144,7 +145,10 @@ def insert( return None def generate( - self, params: Any, decode_state: DecodeState + self, + params: Any, + decode_state: DecodeState, + sampler=None, ) -> tuple[None, engine_api.ResultTokens]: if self.is_disaggregated: return self.generate_impl(params=params, decode_state=decode_state)