diff --git a/deps/JetStream b/deps/JetStream
index ec26ec24..8a1e3132 160000
--- a/deps/JetStream
+++ b/deps/JetStream
@@ -1 +1 @@
-Subproject commit ec26ec2427fad737f898bdec9a186f2acd49d6f1
+Subproject commit 8a1e31322e8e953909482b71f2689f82dbf4572f
diff --git a/jetstream_pt/config.py b/jetstream_pt/config.py
index 354ed5d3..f860557e 100644
--- a/jetstream_pt/config.py
+++ b/jetstream_pt/config.py
@@ -78,6 +78,29 @@
     "for performance tuning and debugging only",
     required=False,
 )
+flags.DEFINE_float(
+    "temperature",
+    1.0,
+    "temperature parameter for scaling probability. "
+    "Only invoked when sampling algorithm is set to "
+    "weighted or topk",
+)
+flags.DEFINE_string(
+    "sampling_algorithm",
+    "greedy",
+    "sampling algorithm to use. Options: "
+    "('greedy', 'weighted', 'nucleus', 'topk')",
+)
+flags.DEFINE_float(
+    "nucleus_topp",
+    0.0,
+    "restricting to p probability mass before sampling",
+)
+flags.DEFINE_integer(
+    "topk",
+    0,
+    "size of top k used when sampling next token",
+)
 
 
 def create_quantization_config_from_flags():
@@ -140,6 +163,10 @@ def create_engine_from_config_flags():
       shard_on_batch=FLAGS.shard_on_batch,
       ragged_mha=FLAGS.ragged_mha,
       starting_position=FLAGS.starting_position,
+      temperature=FLAGS.temperature,
+      sampling_algorithm=FLAGS.sampling_algorithm,
+      nucleus_topp=FLAGS.nucleus_topp,
+      topk=FLAGS.topk,
   )
 
   print("Initialize engine", time.perf_counter() - start)
diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py
index 68402722..ced821ec 100644
--- a/jetstream_pt/engine.py
+++ b/jetstream_pt/engine.py
@@ -28,6 +28,7 @@
 import numpy as np
 
 from jetstream.engine import engine_api, tokenizer_api, tokenizer_pb2, token_utils
+from jetstream.engine import sampling_utils
 import torch_xla2
 from torch.utils import _pytree as pytree
 
@@ -85,6 +86,7 @@ def __init__(
     self.pt_model = pt_model
     self.env = env
     self.default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32
+    self.rng = jax.random.PRNGKey(0)  # NOTE(review): fixed seed; the sampling calls in this patch reuse this key as-is
 
     self.y_sharding = env.sharding_by_axis(1)
     self.x_sharding = env.sharding_by_axis(0)
@@ -220,7 +222,14 @@ def _sampling(self, logits: Any, batch_size: int) -> jnp.ndarray:
     if len(logits.shape) == 2:
       logits = jnp.expand_dims(logits, 0)
     return (
-        jnp.argmax(logits[:, -1], axis=-1)
+        sampling_utils.sampling(
+            logits[:, -1],
+            self.rng,
+            self.env.sampling_algorithm,
+            self.env.topk,
+            self.env.nucleus_topp,
+            self.env.temperature,
+        )
         .reshape(batch_size, -1)
         .astype(jnp.int32)
     )
@@ -248,9 +257,16 @@ def prefill(
         input_indexes,
     )
     if len(logits.shape) == 3:  # b, seqlen, num words
-      logits = logits[0]
-
-    token = jnp.argmax(logits[true_length - 1])
+      logits = logits[0]  # seqlen, num words
+
+    token = sampling_utils.sampling(
+        logits[true_length - 1],
+        self.rng,
+        self.env.sampling_algorithm,
+        self.env.topk,
+        self.env.nucleus_topp,
+        self.env.temperature,
+    )
 
     # truncate to true_length didnt work need to be out side of jit
     # caches = [
@@ -762,6 +778,10 @@ def create_pytorch_engine(
     shard_on_batch=False,
     ragged_mha=False,
     starting_position=512,
+    temperature=1.0,
+    sampling_algorithm="greedy",
+    nucleus_topp=0.0,
+    topk=0,
 ) -> PyTorchEngine:
   """Returns: The pytorch engine."""
 
@@ -827,6 +847,10 @@ def create_pytorch_engine(
       shard_on_batch=shard_on_batch,
       ragged_mha=ragged_mha,
       starting_position=starting_position,
+      temperature=temperature,
+      sampling_algorithm=sampling_algorithm,
+      nucleus_topp=nucleus_topp,
+      topk=topk,
   )
 
   if shard_on_batch and sharding_config:
diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py
index 5ea8f3a3..6f87147f 100644
--- a/jetstream_pt/environment.py
+++ b/jetstream_pt/environment.py
@@ -96,6 +96,19 @@ class JetEngineEnvironmentData:
   # Starting position
   starting_position: int = 512
 
+  # Variables used in token sampling
+  # sampling algorithm to use ("greedy", "weighted", "nucleus", "topk")
+  sampling_algorithm: str = "greedy"
+
+  # size of top k used when sampling next token
+  topk: int = 0
+
+  # restricting to p probability mass before sampling
+  nucleus_topp: float = 0.0
+
+  # temperature parameter for scaling probability
+  temperature: float = 1.0
+
 
 # pylint: disable-next=all
 class JetEngineEnvironment:
diff --git a/tests/test_engine.py b/tests/test_engine.py
index 286e9b31..57245c07 100644
--- a/tests/test_engine.py
+++ b/tests/test_engine.py
@@ -14,46 +14,92 @@
 
 # pylint: disable=all
 
-
-# This model will output tokens with value of 2
-# and will update caches with value of 1.0
-# class Dummy(torch.nn.Module):
-
-#   def __init__(self):
-#     super().__init__()
-#     self.params = None
-
-#   def forward(
-#       self,
-#       tokens: torch.Tensor,
-#       input_pos: torch.Tensor,
-#       caches: List[Any],
-#       mask,
-#   ):
-#     batch_size, seqlen = tokens.shape
-#     for cache in caches:
-#       cache.update(torch.ones((batch_size, seqlen)))
-#     return torch.ones((batch_size, seqlen), dtype=torch.int32) * 2
-
-
-# class EngineTest(unittest.TestCase):
-
-#   def _make_small_engine(self, quantize=False):
-#     env_data = JetEngineEnvironmentData()
-#     env_data.max_input_sequence_length = 128
-#     env_data.max_input_sequence_length = 128
-#     env_data.cache_sequence_length = 128
-#     env_data.model_type = 'llama-2-tiny'
-#     if quantize:
-#       env_data.enable_kv_quantization = True
-#       env_data.enable_weight_quantization = True
-
-#     env = JetEngineEnvironment(env_data)
-#     model = Dummy()
-#     model.params = env._model_arg  # llama's model arg
-
-#     engine = PyTorchEngine(model, env)
-#     return engine
+import unittest
+import jax
+import jax.numpy as jnp
+
+from jetstream_pt.third_party.llama import model_exportable
+from jetstream_pt.engine import PyTorchEngine
+from tests import helpers
+
+
+class EngineTest(unittest.TestCase):
+
+  def setup(self):  # plain helper called by each test; not unittest's setUp hook
+    env, model_arg = helpers.make_env_tiny(bf16_enable=True)
+    model_ours = model_exportable.Transformer(model_arg, env)
+    engine = PyTorchEngine(pt_model=model_ours, env=env)
+    engine.rng = jax.random.PRNGKey(0)
+    return engine
+
+  def test_sampling_2D(self):
+    # test greedy (2D logits are expanded to a batch of 1; last row is sampled)
+    engine = self.setup()
+    self.assertEqual(engine.env.sampling_algorithm, "greedy")
+    logits = jnp.array([[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]])
+    token = engine._sampling(logits, batch_size=1)
+    self.assertTrue(jnp.array_equal(token, jnp.array([[0]])))
+    self.assertTrue(jnp.isdtype(token.dtype, jnp.int32))
+
+    # test weighted
+    engine.env.sampling_algorithm = "weighted"
+    engine.env.temperature = 5.0
+    token = engine._sampling(logits, batch_size=1)
+    self.assertTrue(jnp.array_equal(token, jnp.array([[0]])))
+    self.assertTrue(jnp.isdtype(token.dtype, jnp.int32))
+
+    # test topk
+    engine.env.sampling_algorithm = "topk"
+    engine.env.temperature = 5.0
+    engine.env.topk = 4
+    token = engine._sampling(logits, batch_size=1)
+    self.assertTrue(jnp.array_equal(token, jnp.array([[0]])))
+    self.assertTrue(jnp.isdtype(token.dtype, jnp.int32))
+
+    # test nucleus
+    engine.env.sampling_algorithm = "nucleus"
+    engine.env.temperature = 0.0
+    engine.env.nucleus_topp = 0.8
+    token = engine._sampling(logits, batch_size=1)
+    self.assertTrue(jnp.array_equal(token, jnp.array([[0]])))
+    self.assertTrue(jnp.isdtype(token.dtype, jnp.int32))
+
+  def test_sampling_3D(self):
+    # test greedy (the last position of each batch entry is sampled)
+    engine = self.setup()
+    self.assertEqual(engine.env.sampling_algorithm, "greedy")
+    logits = jnp.array(
+        [
+            [[0.4, 0.3, 0.2, 0.1], [0.5, 0.6, 0.7, 0.8]],
+            [[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]],
+        ]
+    )
+    token = engine._sampling(logits, batch_size=2)
+    self.assertTrue(jnp.array_equal(token, jnp.array([[3], [0]])))
+    self.assertTrue(jnp.isdtype(token.dtype, jnp.int32))
+
+    # test weighted
+    engine.env.sampling_algorithm = "weighted"
+    engine.env.temperature = 10.0
+    token = engine._sampling(logits, batch_size=2)
+    self.assertTrue(jnp.array_equal(token, jnp.array([[3], [1]])))
+    self.assertTrue(jnp.isdtype(token.dtype, jnp.int32))
+
+    # test topk
+    engine.env.sampling_algorithm = "topk"
+    engine.env.temperature = 1.0
+    engine.env.topk = 3
+    token = engine._sampling(logits, batch_size=2)
+    self.assertTrue(jnp.array_equal(token, jnp.array([[1], [0]])))
+    self.assertTrue(jnp.isdtype(token.dtype, jnp.int32))
+
+    # test nucleus
+    engine.env.sampling_algorithm = "nucleus"
+    engine.env.temperature = 1.0
+    engine.env.nucleus_topp = 0.8
+    token = engine._sampling(logits, batch_size=2)
+    self.assertTrue(jnp.array_equal(token, jnp.array([[3], [1]])))
+    self.assertTrue(jnp.isdtype(token.dtype, jnp.int32))
 
 
 #   def test_insert(self):
@@ -229,5 +275,5 @@
 #     # prefill
 
 
-# if __name__ == '__main__':
-#   unittest.main()
+if __name__ == "__main__":
+  unittest.main()