diff --git a/benchmarks/run_offline.py b/benchmarks/run_offline.py
index 8d97df14..07128881 100644
--- a/benchmarks/run_offline.py
+++ b/benchmarks/run_offline.py
@@ -21,7 +21,6 @@
 import jax
 import jax.numpy as jnp
 
-from jetstream.engine import token_utils
 from jetstream_pt import engine as je
 # pylint: disable-next=all
 from benchmarks import analyze_sharegpt
@@ -97,11 +96,11 @@ def create_engine():
 def run_prefill_time(engine, params, decode_state, seqlen):
   """Run prefill and measure time."""
   metadata = engine.get_tokenizer()
-  vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
+  tokenizer = engine.build_tokenizer(metadata)
 
   text = "This is a beautiful day"
-  tokens, true_length = token_utils.tokenize_and_pad(
-      text, vocab, is_bos=True, prefill_lengths=[seqlen]
+  tokens, true_length = tokenizer.encode(
+      text, is_bos=True, prefill_lengths=[seqlen]
   )
 
   for _ in range(3):
diff --git a/install_everything.sh b/install_everything.sh
index aca732b3..f9838a45 100644
--- a/install_everything.sh
+++ b/install_everything.sh
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 TORCHXLA_TAG=jetstream-pytorch
-JETSTREAM_TAG=v0.2.0
+JETSTREAM_TAG=v0.2.1
 
 # Uninstall existing jax
 pip3 show jax && pip3 uninstall -y jax
diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py
index 16d0038a..328f0fb3 100644
--- a/jetstream_pt/engine.py
+++ b/jetstream_pt/engine.py
@@ -26,14 +26,14 @@
 import torch
 import numpy as np
 
-from jetstream.engine import engine_api, tokenizer_pb2, token_utils
+from jetstream.engine import engine_api, tokenizer_api, tokenizer_pb2, token_utils
 import torch_xla2
 from torch.utils import _pytree as pytree
 
 from jetstream_pt import cache_manager
 from jetstream_pt import quantize
 from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
-from jetstream_pt.third_party.llama2 import model_exportable, model_args
+from jetstream_pt.third_party.llama import model_exportable, model_args
 
 
 Mesh = jax.sharding.Mesh
@@ -526,6 +526,14 @@ def get_tokenizer(self) -> tokenizer_pb2.TokenizerParameters:
     # pylint: disable-next=all
     return tokenizer_pb2.TokenizerParameters(path=self.env.tokenizer_path)
 
+  def build_tokenizer(
+      self, metadata: tokenizer_pb2.TokenizerParameters  # pylint: disable=all
+  ) -> tokenizer_api.Tokenizer:
+    if "llama-3" in self.env.model_type:
+      return token_utils.TikToken(metadata)
+
+    return token_utils.SentencePieceTokenizer(metadata)
+
   def join_prefixes(
       self,
       prefix1: engine_api.Prefix,
@@ -652,13 +660,18 @@ def create_pytorch_engine(
     context_length: int = 1024,
     batch_size: int = 1,
     max_decode_length: int = 4096,
-    model_name="llama",
+    model_name="llama-2",
    quantize_weights=False,
     quantize_kv=False,
     max_cache_length=1024,
 ) -> PyTorchEngine:
   """Returns: The pytorch engine."""
 
+  supported_models = ["llama-2", "llama-3"]
+  if model_name not in supported_models:
+    raise NotImplementedError(
+        f"Model name should be one of {','.join(supported_models)}"
+    )
   # See issue b/309529778 if it's turned on.
   jax.config.update("jax_dynamic_shapes", False)
   # Pytorch exports has int64 constants.
@@ -696,11 +709,7 @@ def create_pytorch_engine(
 
   if model_name.startswith("llama"):
     args = model_args.get_model_args(
-        param_size,
-        context_length,
-        batch_size,
-        tokenizer.vocab_size,
-        bf16_enable,
+        model_name + "-" + param_size, context_length, batch_size, bf16_enable
     )
     args.device = "meta"
     args.quantize = quantize_weights
diff --git a/jetstream_pt/ray_worker.py b/jetstream_pt/ray_worker.py
index 71ca873b..f289dd57 100644
--- a/jetstream_pt/ray_worker.py
+++ b/jetstream_pt/ray_worker.py
@@ -32,9 +32,9 @@
 from torch.utils import _pytree as pytree
 import torch_xla2
 
-from jetstream.engine import engine_api, tokenizer_pb2, token_utils
+from jetstream.engine import engine_api, tokenizer_pb2
 
-from jetstream_pt.third_party.llama2 import model_exportable, model_args
+from jetstream_pt.third_party.llama import model_exportable, model_args
 
 from jetstream_pt import cache_manager
 from jetstream_pt import quantize
@@ -99,7 +99,7 @@ def __init__(
       context_length: int = 1024,
       batch_size: int = 1,
       max_decode_length: int = 4096,
-      model_name="llama",
+      model_name="llama-2",
       quantize_weights=False,
       quantize_kv=False,
       max_cache_length=1024,
@@ -159,14 +159,12 @@ def __init__(
     )
     env = JetEngineEnvironment(env_data)
 
-    tokenizer = token_utils.load_vocab(tokenizer_path)
     pt_model = None
-    if model_name == "llama":
+    if "llama" in model_name:
       args = model_args.get_model_args(
-          param_size,
+          model_name + "-" + param_size,
           context_length,
           batch_size,
-          tokenizer.vocab_size,
           bf16_enable,
       )
       args.device = "meta"
diff --git a/jetstream_pt/third_party/llama2/LICENSE b/jetstream_pt/third_party/llama/LICENSE
similarity index 100%
rename from jetstream_pt/third_party/llama2/LICENSE
rename to jetstream_pt/third_party/llama/LICENSE
diff --git a/jetstream_pt/third_party/llama2/__init__.py b/jetstream_pt/third_party/llama/__init__.py
similarity index 100%
rename from jetstream_pt/third_party/llama2/__init__.py
rename to jetstream_pt/third_party/llama/__init__.py
diff --git a/jetstream_pt/third_party/llama2/generation_original.py b/jetstream_pt/third_party/llama/generation_original.py
similarity index 99%
rename from jetstream_pt/third_party/llama2/generation_original.py
rename to jetstream_pt/third_party/llama/generation_original.py
index 188d6b53..dd4339ee 100644
--- a/jetstream_pt/third_party/llama2/generation_original.py
+++ b/jetstream_pt/third_party/llama/generation_original.py
@@ -5,9 +5,9 @@
 from typing import List, Literal, Optional, Tuple, TypedDict
 
 import torch
-from jetstream_pt.third_party.llama2 import model_original
+from jetstream_pt.third_party.llama import model_original
 from flax import struct
-from jetstream_pt.third_party.llama2.tokenizer import Tokenizer
+from jetstream_pt.third_party.llama.tokenizer import Tokenizer
 
 
 Role = Literal["system", "user", "assistant"]
diff --git a/jetstream_pt/third_party/llama2/model_args.py b/jetstream_pt/third_party/llama/model_args.py
similarity index 73%
rename from jetstream_pt/third_party/llama2/model_args.py
rename to jetstream_pt/third_party/llama/model_args.py
index b4e51de9..b9143384 100755
--- a/jetstream_pt/third_party/llama2/model_args.py
+++ b/jetstream_pt/third_party/llama/model_args.py
@@ -34,44 +34,49 @@ class ModelArgs:
   device = "cpu"
   quantize = False
 
+  rope_theta: float = 10000.0
+
 
 def get_arg(
-    param_size: str,
+    model_name: str,
     seqlen,
     batch_size,
-    vocab_size: int,
     bf16_enable: bool = False,
 ) -> ModelArgs:
   """Gets model args."""
 
   data = {}
-  if param_size == "tiny":
+  if model_name == "llama-2-tiny":
     data = {
         "dim": 128,
+        "vocab_size": 32000,
"multiple_of": 32, "n_heads": 8, "n_layers": 3, "norm_eps": 1e-05, } - elif param_size == "7b": + elif model_name == "llama-2-7b": data = { "dim": 4096, + "vocab_size": 32000, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, } - elif param_size == "13b": + elif model_name == "llama-2-13b": data = { "dim": 5120, + "vocab_size": 32000, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-05, } - elif param_size == "70b": + elif model_name == "llama-2-70b": data = { "dim": 8192, + "vocab_size": 32000, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, @@ -79,23 +84,31 @@ def get_arg( "n_layers": 80, "norm_eps": 1e-05, } + elif model_name == "llama-3-8b": + data = { + "dim": 4096, + "vocab_size": 128256, + "multiple_of": 1024, + "ffn_dim_multiplier": 1.3, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + } return ModelArgs( max_seq_len=seqlen, max_batch_size=batch_size, - vocab_size=vocab_size, bf16_enable=bf16_enable, **data, ) -def get_model_args( - param_size, context_length, batch_size, vocab_size, bf16_enable -): +def get_model_args(model_name, context_length, batch_size, bf16_enable): model_args = get_arg( - param_size=param_size, + model_name=model_name, seqlen=context_length, batch_size=batch_size, - vocab_size=vocab_size, bf16_enable=bf16_enable, ) model_args.n_kv_heads = ( diff --git a/jetstream_pt/third_party/llama2/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py similarity index 97% rename from jetstream_pt/third_party/llama2/model_exportable.py rename to jetstream_pt/third_party/llama/model_exportable.py index c7909ab3..106e1f0b 100644 --- a/jetstream_pt/third_party/llama2/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -157,7 +157,9 @@ def __init__( ) # TODO what to do with this freqs_cis = precompute_freqs_cis( - self.params.dim // self.params.n_heads, self.params.max_seq_len * 2 + self.params.dim // self.params.n_heads, + self.params.max_seq_len * 2, + theta=self.params.rope_theta, ) self.register_buffer("freqs_cis", freqs_cis) diff --git a/jetstream_pt/third_party/llama2/model_original.py b/jetstream_pt/third_party/llama/model_original.py similarity index 100% rename from jetstream_pt/third_party/llama2/model_original.py rename to jetstream_pt/third_party/llama/model_original.py diff --git a/jetstream_pt/third_party/llama2/tokenizer.model b/jetstream_pt/third_party/llama/tokenizer.model similarity index 100% rename from jetstream_pt/third_party/llama2/tokenizer.model rename to jetstream_pt/third_party/llama/tokenizer.model diff --git a/jetstream_pt/third_party/llama2/tokenizer.py b/jetstream_pt/third_party/llama/tokenizer.py similarity index 100% rename from jetstream_pt/third_party/llama2/tokenizer.py rename to jetstream_pt/third_party/llama/tokenizer.py diff --git a/run_interactive.py b/run_interactive.py index df4aab89..f338beb0 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -24,6 +24,11 @@ import jax from jetstream.engine import token_utils +from colorama import Fore, Style +import numpy as np + +import os + from jetstream_pt import engine as je FLAGS = flags.FLAGS @@ -64,6 +69,11 @@ _MAX_CACHE_LENGTH = flags.DEFINE_integer( "max_cache_length", 1024, "kv_cache_quantize" ) +_MODEL_NAME = flags.DEFINE_string( + "model", + "llama-2", + "name of the model. 
Supported options are llama-2 and llama-3", +) def create_engine(): @@ -81,6 +91,7 @@ def create_engine(): param_size=_SIZE.value, context_length=_CONTEXT_LENGTH.value, batch_size=_BATCH_SIZE.value, + model_name=_MODEL_NAME.value, quantize_weights=_QUANTIZE_WEIGHTS.value, quantize_kv=_QUANTIZE_KV_CACHE.value, max_cache_length=_MAX_CACHE_LENGTH.value, @@ -100,8 +111,7 @@ def main(argv): print("Load params ", time.perf_counter() - start) metadata = engine.get_tokenizer() - vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids) - stop_tokens = [vocab.eos_id, vocab.pad_id] + tokenizer = engine.build_tokenizer(metadata) max_output_length = 1024 if _PROFILING_OUTPUT.value: @@ -121,9 +131,8 @@ def main(argv): ] for prompt in prompts: slot = random.randint(0, _BATCH_SIZE.value - 1) - tokens, true_length = token_utils.tokenize_and_pad( - prompt, vocab, is_bos=True - ) + tokens, true_length = tokenizer.encode(prompt, is_bos=True) + print(f"---- Input prompts are: {prompt}") print(f"---- Encoded tokens are: {tokens}") @@ -135,21 +144,19 @@ def main(argv): decode_state = engine.insert(prefill_result, decode_state, slot=slot) sampled_tokens_list = [] print(f"---- Streaming decode started on #slot{slot}.") + complete = np.zeros((1,), dtype=np.bool_) while True: - # pylint: disable-next=all decode_state, result_tokens = engine.generate(params, decode_state) - - slot_data = result_tokens.get_result_at_slot(slot) - slot_tokens = slot_data.tokens - slot_lengths = slot_data.lengths - - token_id = slot_tokens[slot, 0].item() - if slot_lengths > max_output_length or token_id in stop_tokens: + result_tokens = result_tokens.convert_to_numpy() + output, complete = tokenizer.decode( + slot, max_output_length, result_tokens, complete + ) + if complete[0]: break - + token_id = output[0][0] sampled_tokens_list.append(token_id) - # output = token_utils.mix_decode(vocab, token_id) - # print(Fore.GREEN + output, end="", flush=True) + # output_str = tokenizer.decode_str([token_id]) + # print(Fore.GREEN + output_str, end="", flush=True) # print(Style.RESET_ALL + "\n") # print("---- Streaming decode finished.") @@ -157,7 +164,7 @@ def main(argv): print("---- All output tokens.") print(sampled_tokens_list) print("---- All output text.") - print(vocab.tokenizer.decode(sampled_tokens_list)) + print(tokenizer.decode_str(sampled_tokens_list)) if _PROFILING_OUTPUT.value: jax.profiler.stop_trace() diff --git a/run_server.py b/run_server.py index 39d8447c..77d7f173 100644 --- a/run_server.py +++ b/run_server.py @@ -71,6 +71,11 @@ "The model size the server runs on.", required=False, ) +_MODEL_NAME = flags.DEFINE_string( + "model", + "llama-2", + "name of the model. 
Supported options are llama-2 and llama-3", +) _QUANTIZE_WEIGHTS = flags.DEFINE_bool( "quantize_weights", False, "weight quantization" @@ -98,6 +103,7 @@ def main(argv: Sequence[str]): param_size=_PARAM_SIZE.value, context_length=_CONTEXT_LENGTH.value, batch_size=_BATCH_SIZE.value, + model_name=_MODEL_NAME.value, quantize_weights=_QUANTIZE_WEIGHTS.value, quantize_kv=_QUANTIZE_KV_CACHE.value, max_cache_length=_MAX_CACHE_LENGTH.value, diff --git a/tests/helpers.py b/tests/helpers.py index 764bc4b0..dd0c7c50 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,6 +1,6 @@ import torch import jax -from jetstream_pt.third_party.llama2 import model_args +from jetstream_pt.third_party.llama import model_args from jetstream_pt import environment @@ -9,7 +9,7 @@ def make_env_tiny(bf16_enable=True): torch.set_default_dtype(torch_dtype) jax.config.update("jax_dynamic_shapes", False) jax.config.update("jax_traceback_filtering", "off") - config = model_args.get_model_args("tiny", 128, 1, 32000, True) + config = model_args.get_model_args("llama-2-tiny", 128, 1, True) environment_data = environment.JetEngineEnvironmentData() environment_data.max_input_sequence_length = 128 environment_data.max_input_sequence_length = 128 diff --git a/tests/test_engine.py b/tests/test_engine.py index d6fec57e..a24e0d8e 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -21,7 +21,7 @@ from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData from jetstream_pt.engine import PyTorchEngine, Prefix, DecodeState -from jetstream_pt.third_party.llama2 import model_exportable, model_original +from jetstream_pt.third_party.llama import model_exportable, model_original # This model will output tokens with value of 2 # and will update caches with value of 1.0 diff --git a/tests/test_llama_e2e.py b/tests/test_llama_e2e.py index 31a8c36f..10e698d6 100644 --- a/tests/test_llama_e2e.py +++ b/tests/test_llama_e2e.py @@ -26,8 +26,8 @@ from jetstream_pt.engine import PyTorchEngine -from jetstream_pt.third_party.llama2 import model_exportable, model_args -from jetstream_pt.third_party.llama2.generation_original import LlamaOriginal +from jetstream_pt.third_party.llama import model_exportable, model_args +from jetstream_pt.third_party.llama.generation_original import LlamaOriginal from jetstream_pt import environment @@ -74,10 +74,10 @@ def test_original_llama2_seed(self): tokens = np.arange(10, dtype=np.int32) file_dir = os.path.dirname(__file__) tokenizer_path = os.path.join( - file_dir, "../jetstream_pt/third_party/llama2/tokenizer.model" + file_dir, "../jetstream_pt/third_party/llama/tokenizer.model" ) output_tokens_multiple = [] - model_arg = model_args.get_model_args("tiny", 128, 1, 32000, True) + model_arg = model_args.get_model_args("llama-2-tiny", 128, 1, True) for i in [1, 999, 99999]: llama_original = LlamaOriginal.build(tokenizer_path, model_arg, i) prompt_tokens = [tokens] @@ -111,7 +111,7 @@ def test_jetstream_llama2_seed(self): file_dir = os.path.dirname(__file__) tokenizer_path = os.path.join( - file_dir, "../jetstream_pt/third_party/llama2/tokenizer.model" + file_dir, "../jetstream_pt/third_party/llama/tokenizer.model" ) seed = 1 @@ -175,7 +175,7 @@ def _llama_e2e(self, env, model_arg): file_dir = os.path.dirname(__file__) tokenizer_path = os.path.join( - file_dir, "../jetstream_pt/third_party/llama2/tokenizer.model" + file_dir, "../jetstream_pt/third_party/llama/tokenizer.model" ) # orginal @@ -260,7 +260,7 @@ def test_llama_e2e_two_addtional_tokens(self): file_dir = 
os.path.dirname(__file__) tokenizer_path = os.path.join( - file_dir, "../jetstream_pt/third_party/llama2/tokenizer.model" + file_dir, "../jetstream_pt/third_party/llama/tokenizer.model" ) # orginal @@ -333,7 +333,7 @@ def test_llama_e2e_four_addtional_tokens(self): file_dir = os.path.dirname(__file__) tokenizer_path = os.path.join( - file_dir, "../jetstream_pt/third_party/llama2/tokenizer.model" + file_dir, "../jetstream_pt/third_party/llama/tokenizer.model" ) # orginal @@ -402,7 +402,7 @@ def test_llama_with_original_prefill_decode_32(self): file_dir = os.path.dirname(__file__) tokenizer_path = os.path.join( - file_dir, "../jetstream_pt/third_party/llama2/tokenizer.model" + file_dir, "../jetstream_pt/third_party/llama/tokenizer.model" ) # orginal @@ -478,7 +478,7 @@ def test_llama_with_original_prefill_decode(self): file_dir = os.path.dirname(__file__) tokenizer_path = os.path.join( - file_dir, "../jetstream_pt/third_party/llama2/tokenizer.model" + file_dir, "../jetstream_pt/third_party/llama/tokenizer.model" ) # orginal diff --git a/tests/test_model_impl.py b/tests/test_model_impl.py index 23b9d14d..e8c0f375 100644 --- a/tests/test_model_impl.py +++ b/tests/test_model_impl.py @@ -20,8 +20,8 @@ import torch_xla2 from . import helpers -from jetstream_pt.third_party.llama2 import model_exportable -from jetstream_pt.third_party.llama2 import model_original +from jetstream_pt.third_party.llama import model_exportable +from jetstream_pt.third_party.llama import model_original from jetstream_pt import layers from jetstream_pt import environment from jetstream_pt import cache_manager
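Usage sketch of the tokenizer flow introduced above, mirroring run_interactive.py. Assumptions: `engine` comes from create_pytorch_engine(..., model_name="llama-2" or "llama-3"), and `sampled_tokens_list` holds token ids collected by the generate loop.

    # Pick the tokenizer matching the model family:
    # SentencePieceTokenizer for llama-2, TikToken for llama-3.
    metadata = engine.get_tokenizer()
    tokenizer = engine.build_tokenizer(metadata)

    # Encode a prompt; true_length is the unpadded token count.
    tokens, true_length = tokenizer.encode("This is a beautiful day", is_bos=True)

    # Turn the ids gathered during decoding back into text.
    print(tokenizer.decode_str(sampled_tokens_list))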