Skip to content
2 changes: 1 addition & 1 deletion deps/JetStream
27 changes: 27 additions & 0 deletions jetstream_pt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,29 @@
"for performance tuning and debugging only",
required=False,
)
# Token-sampling flags. create_engine_from_config_flags reads each of these
# from FLAGS and forwards it as the keyword argument of the same name.
# Adjacent string literals are concatenated verbatim, so every fragment of a
# multi-line help text except the last ends with an explicit space.
flags.DEFINE_float(
    "temperature",
    1.0,
    "temperature parameter for scaling probability. "
    "Only invoked when sampling algorithm is set to "
    "weighted or topk",
)
flags.DEFINE_string(
    "sampling_algorithm",
    "greedy",
    "sampling algorithm to use. Options: "
    "('greedy', 'weighted', 'nucleus', 'topk')",
)
flags.DEFINE_float(
    "nucleus_topp",
    0.0,
    "restricting to p probability mass before sampling",
)
flags.DEFINE_integer(
    "topk",
    0,
    "size of top k used when sampling next token",
)


def create_quantization_config_from_flags():
Expand Down Expand Up @@ -140,6 +163,10 @@ def create_engine_from_config_flags():
shard_on_batch=FLAGS.shard_on_batch,
ragged_mha=FLAGS.ragged_mha,
starting_position=FLAGS.starting_position,
temperature=FLAGS.temperature,
sampling_algorithm=FLAGS.sampling_algorithm,
nucleus_topp=FLAGS.nucleus_topp,
topk=FLAGS.topk,
)

print("Initialize engine", time.perf_counter() - start)
Expand Down
32 changes: 28 additions & 4 deletions jetstream_pt/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import numpy as np

from jetstream.engine import engine_api, tokenizer_api, tokenizer_pb2, token_utils
from jetstream.engine import sampling_utils
import torch_xla2
from torch.utils import _pytree as pytree

Expand Down Expand Up @@ -85,6 +86,7 @@ def __init__(
self.pt_model = pt_model
self.env = env
self.default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32
self.rng = jax.random.PRNGKey(0)

self.y_sharding = env.sharding_by_axis(1)
self.x_sharding = env.sharding_by_axis(0)
Expand Down Expand Up @@ -220,7 +222,14 @@ def _sampling(self, logits: Any, batch_size: int) -> jnp.ndarray:
if len(logits.shape) == 2:
logits = jnp.expand_dims(logits, 0)
return (
jnp.argmax(logits[:, -1], axis=-1)
sampling_utils.sampling(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a unit test for the sampling?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added tests to test_engine.py

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding it, looks good to me!

logits[:, -1],
self.rng,
self.env.sampling_algorithm,
self.env.topk,
self.env.nucleus_topp,
self.env.temperature,
)
.reshape(batch_size, -1)
.astype(jnp.int32)
)
Expand Down Expand Up @@ -248,9 +257,16 @@ def prefill(
input_indexes,
)
if len(logits.shape) == 3: # b, seqlen, num words
logits = logits[0]

token = jnp.argmax(logits[true_length - 1])
logits = logits[0] # seqlen, num words

token = sampling_utils.sampling(
logits[true_length - 1],
self.rng,
self.env.sampling_algorithm,
self.env.topk,
self.env.nucleus_topp,
self.env.temperature,
)

# truncating to true_length didn't work here; it needs to be done outside of jit
# caches = [
Expand Down Expand Up @@ -762,6 +778,10 @@ def create_pytorch_engine(
shard_on_batch=False,
ragged_mha=False,
starting_position=512,
temperature=None,
sampling_algorithm="greedy",
nucleus_topp=None,
topk=None,
) -> PyTorchEngine:
"""Returns: The pytorch engine."""

Expand Down Expand Up @@ -827,6 +847,10 @@ def create_pytorch_engine(
shard_on_batch=shard_on_batch,
ragged_mha=ragged_mha,
starting_position=starting_position,
temperature=temperature,
sampling_algorithm=sampling_algorithm,
nucleus_topp=nucleus_topp,
topk=topk,
)

if shard_on_batch and sharding_config:
Expand Down
13 changes: 13 additions & 0 deletions jetstream_pt/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,19 @@ class JetEngineEnvironmentData:
# Starting position
starting_position: int = 512

# Variables used in token sampling
# sampling algorithm to use ("greedy", "weighted", "nucleus", "topk")
sampling_algorithm: str = "greedy"

# size of top k used when sampling next token
topk: int = 0

# restricting to p probability mass before sampling
nucleus_topp: float = 0.0

# temperature parameter for scaling probability
temperature: float = 1.0


# pylint: disable-next=all
class JetEngineEnvironment:
Expand Down
130 changes: 88 additions & 42 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,46 +14,92 @@

# pylint: disable=all


# This model will output tokens with value of 2
# and will update caches with value of 1.0
# class Dummy(torch.nn.Module):

# def __init__(self):
# super().__init__()
# self.params = None

# def forward(
# self,
# tokens: torch.Tensor,
# input_pos: torch.Tensor,
# caches: List[Any],
# mask,
# ):
# batch_size, seqlen = tokens.shape
# for cache in caches:
# cache.update(torch.ones((batch_size, seqlen)))
# return torch.ones((batch_size, seqlen), dtype=torch.int32) * 2


# class EngineTest(unittest.TestCase):

# def _make_small_engine(self, quantize=False):
# env_data = JetEngineEnvironmentData()
# env_data.max_input_sequence_length = 128
# env_data.max_input_sequence_length = 128
# env_data.cache_sequence_length = 128
# env_data.model_type = 'llama-2-tiny'
# if quantize:
# env_data.enable_kv_quantization = True
# env_data.enable_weight_quantization = True

# env = JetEngineEnvironment(env_data)
# model = Dummy()
# model.params = env._model_arg # llama's model arg

# engine = PyTorchEngine(model, env)
# return engine
import unittest
import jax
import jax.numpy as jnp

from jetstream_pt.third_party.llama import model_exportable
from jetstream_pt.engine import PyTorchEngine
from tests import helpers


class EngineTest(unittest.TestCase):
  """Tests PyTorchEngine._sampling with each sampling algorithm setting."""

  def setup(self):
    """Build an engine around the tiny test model with a fixed PRNG key.

    This is a plain helper, not unittest's ``setUp`` hook: each test calls
    it explicitly and works with the returned engine.

    Returns:
      A PyTorchEngine whose ``rng`` is ``jax.random.PRNGKey(0)``.
    """
    env, model_arg = helpers.make_env_tiny(bf16_enable=True)
    model_ours = model_exportable.Transformer(model_arg, env)
    engine = PyTorchEngine(pt_model=model_ours, env=env)
    engine.rng = jax.random.PRNGKey(0)
    return engine

  def _assert_tokens(self, token, expected):
    """Assert that ``token`` matches ``expected`` and has dtype int32.

    Args:
      token: array returned by ``engine._sampling``.
      expected: nested list holding the expected token ids.
    """
    self.assertTrue(
        jnp.array_equal(token, jnp.array(expected)),
        f"expected tokens {expected}, got {token}",
    )
    # jnp.isdtype takes a dtype as its first argument, so pass token.dtype
    # rather than the array itself.
    self.assertTrue(
        jnp.isdtype(token.dtype, jnp.int32),
        f"expected dtype int32, got {token.dtype}",
    )

  def test_sampling_2D(self):
    """Sample from 2D logits, using a batch size of 1."""
    # test greedy
    engine = self.setup()
    self.assertEqual(engine.env.sampling_algorithm, "greedy")
    logits = jnp.array([[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]])
    token = engine._sampling(logits, batch_size=1)
    self._assert_tokens(token, [[0]])

    # test weighted
    engine.env.sampling_algorithm = "weighted"
    engine.env.temperature = 5.0
    token = engine._sampling(logits, batch_size=1)
    self._assert_tokens(token, [[0]])

    # test topk
    engine.env.sampling_algorithm = "topk"
    engine.env.temperature = 5.0
    engine.env.topk = 4
    token = engine._sampling(logits, batch_size=1)
    self._assert_tokens(token, [[0]])

    # test nucleus
    engine.env.sampling_algorithm = "nucleus"
    engine.env.temperature = 0.0
    engine.env.nucleus_topp = 0.8
    token = engine._sampling(logits, batch_size=1)
    self._assert_tokens(token, [[0]])

  def test_sampling_3D(self):
    """Sample from 3D logits, using a batch size of 2."""
    # test greedy
    engine = self.setup()
    self.assertEqual(engine.env.sampling_algorithm, "greedy")
    logits = jnp.array(
        [
            [[0.4, 0.3, 0.2, 0.1], [0.5, 0.6, 0.7, 0.8]],
            [[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]],
        ]
    )
    token = engine._sampling(logits, batch_size=2)
    self._assert_tokens(token, [[3], [0]])

    # test weighted
    engine.env.sampling_algorithm = "weighted"
    engine.env.temperature = 10.0
    token = engine._sampling(logits, batch_size=2)
    self._assert_tokens(token, [[3], [1]])

    # test topk
    engine.env.sampling_algorithm = "topk"
    engine.env.temperature = 1.0
    engine.env.topk = 3
    token = engine._sampling(logits, batch_size=2)
    self._assert_tokens(token, [[1], [0]])

    # test nucleus
    engine.env.sampling_algorithm = "nucleus"
    engine.env.temperature = 1.0
    engine.env.nucleus_topp = 0.8
    token = engine._sampling(logits, batch_size=2)
    self._assert_tokens(token, [[3], [1]])


# def test_insert(self):
Expand Down Expand Up @@ -229,5 +275,5 @@
# # prefill


# if __name__ == '__main__':
# unittest.main()
# Run the unittest runner when this file is executed directly as a script.
if __name__ == "__main__":
  unittest.main()