From fb617dde46ba31ea52e0bdd1af00d86e7d5a93fc Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Wed, 1 May 2024 23:12:32 +0000 Subject: [PATCH 01/11] Support llama3 --- jetstream_pt/engine.py | 48 +++++++++--------- jetstream_pt/environment.py | 5 +- .../third_party/{llama2 => llama}/LICENSE | 0 .../third_party/{llama2 => llama}/__init__.py | 0 .../{llama2 => llama}/generation_original.py | 4 +- .../{llama2 => llama}/model_args.py | 35 +++++++++---- .../{llama2 => llama}/model_exportable.py | 4 +- .../{llama2 => llama}/model_original.py | 0 .../{llama2 => llama}/tokenizer.model | Bin .../{llama2 => llama}/tokenizer.py | 0 run_interactive.py | 32 +++++------- run_server.py | 2 + tests/test_engine.py | 2 +- tests/test_llama_e2e.py | 4 +- tests/test_model_impl.py | 6 +-- 15 files changed, 78 insertions(+), 64 deletions(-) rename jetstream_pt/third_party/{llama2 => llama}/LICENSE (100%) rename jetstream_pt/third_party/{llama2 => llama}/__init__.py (100%) rename jetstream_pt/third_party/{llama2 => llama}/generation_original.py (99%) rename jetstream_pt/third_party/{llama2 => llama}/model_args.py (73%) rename jetstream_pt/third_party/{llama2 => llama}/model_exportable.py (97%) rename jetstream_pt/third_party/{llama2 => llama}/model_original.py (100%) rename jetstream_pt/third_party/{llama2 => llama}/tokenizer.model (100%) rename jetstream_pt/third_party/{llama2 => llama}/tokenizer.py (100%) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 8c7881da..9c2daee9 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -24,9 +24,9 @@ import torch import numpy as np -from jetstream.engine import engine_api, tokenizer_pb2, token_utils +from jetstream.engine import engine_api, tokenizer_api, tokenizer_pb2, token_utils import torch_xla2 -from jetstream_pt.third_party.llama2 import model_exportable, model_args +from jetstream_pt.third_party.llama import model_exportable, model_args from jetstream_pt import cache_manager from jetstream_pt import quantize @@ -34,8 +34,6 @@ from torch.utils import _pytree as pytree - - Mesh = jax.sharding.Mesh P = jax.sharding.PartitionSpec @@ -476,6 +474,12 @@ def generate( def get_tokenizer(self) -> tokenizer_pb2.TokenizerParameters: return tokenizer_pb2.TokenizerParameters(path=self.env.tokenizer_path) + def build_tokenizer(self, meta: tokenizer_pb2.TokenizerParameters) -> tokenizer_api.Tokenizer: + if 'llama-3' in self.env.model_type: + return token_utils.TikToken(meta) + else: + return token_utils.SentencePieceTokenizer(meta) + def join_prefixes( self, prefix1: engine_api.Prefix, @@ -592,13 +596,16 @@ def create_pytorch_engine( context_length: int = 1024, batch_size: int = 1, max_decode_length: int = 4096, - model_name = "llama", + model_name = "llama-2", quantize_weights = False, quantize_kv = False, max_cache_length = 1024, ) -> PyTorchEngine: """Returns: The pytorch engine.""" + supported_models = ['llama-2', 'llama-3'] + if model_name not in supported_models: + raise NotImplementedError('Model name should be one of {}'.format(','.join(supported_models))) # See issue b/309529778 if it's turned on. jax.config.update('jax_dynamic_shapes', False) # Pytorch exports has int64 constants. 
@@ -630,7 +637,7 @@ def create_pytorch_engine( tokenizer_path=tokenizer_path, checkpoint_path = checkpoint_path, checkpoint_format = checkpoint_format, - model_type = 'llama-2-' + param_size, + model_type = model_name + '-' + param_size, batch_size = batch_size, max_decode_length = max_decode_length, max_input_sequence_length = context_length, @@ -640,23 +647,18 @@ def create_pytorch_engine( bf16_enable = bf16_enable, ) env = JetEngineEnvironment(env_data) - - tokenizer = token_utils.load_vocab(tokenizer_path) - pt_model = None - shard_weights_fn = None - if model_name == "llama": - args = model_args.get_model_args(param_size, context_length, batch_size, tokenizer.vocab_size, bf16_enable) - args.device = 'meta' - args.quantize = quantize_weights - pt_model = model_exportable.Transformer(args, env) - - num_params_size = 0 - num_params = 0 - for k, v in pt_model.state_dict().items(): - num_params += 1 - num_params_size += np.prod(v.shape) * (1 if v.dtype == jnp.int8 else 2) - print('Number of param Gbytes:', num_params_size / (1 << 30)) - print('Number of param: ', num_params) + args = model_args.get_model_args(model_name + '-' + param_size, context_length, batch_size, bf16_enable) + args.device = 'meta' + args.quantize = quantize_weights + pt_model = model_exportable.Transformer(args, env) + + num_params_size = 0 + num_params = 0 + for k, v in pt_model.state_dict().items(): + num_params += 1 + num_params_size += np.prod(v.shape) * (1 if v.dtype == jnp.int8 else 2) + print('Number of param Gbytes:', num_params_size / (1 << 30)) + print('Number of param: ', num_params) return PyTorchEngine( pt_model=pt_model, diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py index e24c15a9..3a2681df 100644 --- a/jetstream_pt/environment.py +++ b/jetstream_pt/environment.py @@ -19,7 +19,7 @@ import dataclasses from typing import Tuple, Dict -from jetstream_pt.third_party.llama2 import model_args +from jetstream_pt.third_party.llama import model_args from jetstream_pt import cache_manager import torch_xla2 @@ -65,10 +65,9 @@ def __init__(self, data: JetEngineEnvironmentData): self._data = data # Get 13b self._model_arg = model_args.get_model_args( - data.model_type.replace('llama-2-', ''), + data.model_type, context_length=data.max_input_sequence_length, batch_size=data.batch_size, - vocab_size=32000, # ? 
bf16_enable=data.bf16_enable, ) diff --git a/jetstream_pt/third_party/llama2/LICENSE b/jetstream_pt/third_party/llama/LICENSE similarity index 100% rename from jetstream_pt/third_party/llama2/LICENSE rename to jetstream_pt/third_party/llama/LICENSE diff --git a/jetstream_pt/third_party/llama2/__init__.py b/jetstream_pt/third_party/llama/__init__.py similarity index 100% rename from jetstream_pt/third_party/llama2/__init__.py rename to jetstream_pt/third_party/llama/__init__.py diff --git a/jetstream_pt/third_party/llama2/generation_original.py b/jetstream_pt/third_party/llama/generation_original.py similarity index 99% rename from jetstream_pt/third_party/llama2/generation_original.py rename to jetstream_pt/third_party/llama/generation_original.py index a78fa482..84a2a7ec 100644 --- a/jetstream_pt/third_party/llama2/generation_original.py +++ b/jetstream_pt/third_party/llama/generation_original.py @@ -5,8 +5,8 @@ from typing import List, Literal, Optional, Tuple, TypedDict import torch -from jetstream_pt.third_party.llama2 import model_original -from jetstream_pt.third_party.llama2.tokenizer import Tokenizer +from jetstream_pt.third_party.llama import model_original +from jetstream_pt.third_party.llama.tokenizer import Tokenizer Role = Literal["system", "user", "assistant"] diff --git a/jetstream_pt/third_party/llama2/model_args.py b/jetstream_pt/third_party/llama/model_args.py similarity index 73% rename from jetstream_pt/third_party/llama2/model_args.py rename to jetstream_pt/third_party/llama/model_args.py index 8d8baf04..500cd4b5 100755 --- a/jetstream_pt/third_party/llama2/model_args.py +++ b/jetstream_pt/third_party/llama/model_args.py @@ -34,44 +34,49 @@ class ModelArgs: device = 'cpu' quantize = False + rope_theta: float = 10000.0 + def get_arg( - param_size: str, + model_name: str, seqlen, batch_size, - vocab_size: int, bf16_enable: bool = False, ) -> ModelArgs: """Gets model args.""" data = {} - if param_size == "tiny": + if model_name == "llama-2-tiny": data = { "dim": 128, + "vocab_size": 32000, "multiple_of": 32, "n_heads": 8, "n_layers": 3, "norm_eps": 1e-05, } - elif param_size == "7b": + elif model_name == "llama-2-7b": data = { "dim": 4096, + "vocab_size": 32000, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, } - elif param_size == "13b": + elif model_name == "llama-2-13b": data = { "dim": 5120, + "vocab_size": 32000, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-05, } - elif param_size == "70b": + elif model_name == "llama-2-70b": data = { "dim": 8192, + "vocab_size": 32000, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, @@ -79,20 +84,30 @@ def get_arg( "n_layers": 80, "norm_eps": 1e-05, } + elif model_name == 'llama-3-8b': + data = { + "dim": 4096, + "vocab_size": 128256, + "multiple_of": 1024, + "ffn_dim_multiplier": 1.3, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + } return ModelArgs( max_seq_len=seqlen, max_batch_size=batch_size, - vocab_size=vocab_size, bf16_enable=bf16_enable, **data, ) -def get_model_args(param_size, context_length, batch_size, vocab_size, bf16_enable): +def get_model_args(model_name, context_length, batch_size, bf16_enable): model_args = get_arg( - param_size=param_size, + model_name=model_name, seqlen=context_length, batch_size=batch_size, - vocab_size=vocab_size, bf16_enable=bf16_enable, ) model_args.n_kv_heads = ( diff --git a/jetstream_pt/third_party/llama2/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py 
similarity index 97% rename from jetstream_pt/third_party/llama2/model_exportable.py rename to jetstream_pt/third_party/llama/model_exportable.py index c56501de..d6f9f5dd 100644 --- a/jetstream_pt/third_party/llama2/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -166,7 +166,9 @@ def __init__( ) # TODO what to do with this freqs_cis = precompute_freqs_cis( - self.params.dim // self.params.n_heads, self.params.max_seq_len * 2 + self.params.dim // self.params.n_heads, + self.params.max_seq_len * 2, + theta = self.params.rope_theta, ) self.register_buffer("freqs_cis", freqs_cis) diff --git a/jetstream_pt/third_party/llama2/model_original.py b/jetstream_pt/third_party/llama/model_original.py similarity index 100% rename from jetstream_pt/third_party/llama2/model_original.py rename to jetstream_pt/third_party/llama/model_original.py diff --git a/jetstream_pt/third_party/llama2/tokenizer.model b/jetstream_pt/third_party/llama/tokenizer.model similarity index 100% rename from jetstream_pt/third_party/llama2/tokenizer.model rename to jetstream_pt/third_party/llama/tokenizer.model diff --git a/jetstream_pt/third_party/llama2/tokenizer.py b/jetstream_pt/third_party/llama/tokenizer.py similarity index 100% rename from jetstream_pt/third_party/llama2/tokenizer.py rename to jetstream_pt/third_party/llama/tokenizer.py diff --git a/run_interactive.py b/run_interactive.py index 6ec9bd42..05d2a3a0 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -21,15 +21,14 @@ from jetstream.engine import token_utils from colorama import Fore, Style - +import numpy as np import os from jetstream_pt import engine as je import time - -logging.getLogger().setLevel(logging.ERROR) +# logging.getLogger().setLevel(logging.ERROR) FLAGS = flags.FLAGS @@ -64,7 +63,7 @@ _QUANTIZE_WEIGHTS = flags.DEFINE_bool('quantize_weights', False, 'weight quantization') _QUANTIZE_KV_CACHE = flags.DEFINE_bool('quantize_kv_cache', False, 'kv_cache_quantize') _MAX_CACHE_LENGTH = flags.DEFINE_integer('max_cache_length', 1024, 'kv_cache_quantize') - +_MODEL_NAME = flags.DEFINE_string('model', 'llama-2', 'name of the model. 
Supported options are llama-2 and llama-3') def create_engine(): jax.config.update('jax_default_prng_impl', 'unsafe_rbg') @@ -82,6 +81,7 @@ def create_engine(): param_size=_SIZE.value, context_length=_CONTEXT_LENGTH.value, batch_size=_BATCH_SIZE.value, + model_name = _MODEL_NAME.value, quantize_weights=_QUANTIZE_WEIGHTS.value, quantize_kv=_QUANTIZE_KV_CACHE.value, max_cache_length = _MAX_CACHE_LENGTH.value, @@ -100,9 +100,7 @@ def main(argv): print('Load params ', time.perf_counter() - start) metadata = engine.get_tokenizer() - vocab = token_utils.load_vocab( - metadata.path, metadata.extra_ids) - stop_tokens = [vocab.eos_id, vocab.pad_id] + tokenizer = engine.build_tokenizer(metadata) max_output_length = 1024 if _PROFILING_OUTPUT.value: @@ -118,7 +116,7 @@ def main(argv): ] for prompt in prompts: slot = random.randint(0, _BATCH_SIZE.value) - tokens, true_length = token_utils.tokenize_and_pad(prompt, vocab, is_bos=True) + tokens, true_length = tokenizer.encode(prompt, is_bos=True) print(f"---- Input prompts are: {prompt}") print(f"---- Encoded tokens are: {tokens}") @@ -130,22 +128,18 @@ def main(argv): ) sampled_tokens_list = [] print(f"---- Streaming decode started on #slot{slot}.") + complete = np.zeros((1,), dtype=np.bool_) while True: decode_state, result_tokens = engine.generate( params, decode_state ) - - slot_data = result_tokens.get_result_at_slot(slot) - slot_tokens = slot_data.tokens - slot_lengths = slot_data.lengths - - token_id = slot_tokens[slot, 0].item() - if slot_lengths > max_output_length or token_id in stop_tokens: + result_tokens = result_tokens.convert_to_numpy() + output, complete = tokenizer.decode(slot, max_output_length, result_tokens, complete) + if complete[0]: break - - sampled_tokens_list.append(token_id) - output = token_utils.mix_decode(vocab, token_id) - print(Fore.GREEN + output, end="", flush=True) + token_id = output[0][0] + output_str = tokenizer.decode_str([token_id]) + print(Fore.GREEN + output_str, end="", flush=True) print(Style.RESET_ALL + "\n") print("---- Streaming decode finished.") diff --git a/run_server.py b/run_server.py index ff285e3a..fb4a3db7 100644 --- a/run_server.py +++ b/run_server.py @@ -70,6 +70,7 @@ 'The model size the server runs on.', required=False, ) +_MODEL_NAME = flags.DEFINE_string('model', 'llama-2', 'name of the model. 
Supported options are llama-2 and llama-3') _QUANTIZE_WEIGHTS = flags.DEFINE_bool('quantize_weights', False, 'weight quantization') _QUANTIZE_KV_CACHE = flags.DEFINE_bool('quantize_kv_cache', False, 'kv_cache_quantize') @@ -90,6 +91,7 @@ def main(argv: Sequence[str]): param_size=_PARAM_SIZE.value, context_length=_CONTEXT_LENGTH.value, batch_size=_BATCH_SIZE.value, + model_name = _MODEL_NAME.value, quantize_weights=_QUANTIZE_WEIGHTS.value, quantize_kv=_QUANTIZE_KV_CACHE.value, max_cache_length = _MAX_CACHE_LENGTH.value, diff --git a/tests/test_engine.py b/tests/test_engine.py index 0666381a..be877c47 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -20,7 +20,7 @@ from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData from jetstream_pt.engine import PyTorchEngine, Prefix, DecodeState -from jetstream_pt.third_party.llama2 import model_exportable, model_original +from jetstream_pt.third_party.llama import model_exportable, model_original # This model will output tokens with value of 2 # and will update caches with value of 1.0 diff --git a/tests/test_llama_e2e.py b/tests/test_llama_e2e.py index 1ed56df2..95d8db0e 100644 --- a/tests/test_llama_e2e.py +++ b/tests/test_llama_e2e.py @@ -20,8 +20,8 @@ import jax.numpy as jnp import numpy as np from jetstream_pt.engine import PyTorchEngine -from jetstream_pt.third_party.llama2 import model_exportable -from jetstream_pt.third_party.llama2.generation_original import LlamaOriginal +from jetstream_pt.third_party.llama import model_exportable +from jetstream_pt.third_party.llama.generation_original import LlamaOriginal from jetstream_pt import environment diff --git a/tests/test_model_impl.py b/tests/test_model_impl.py index 2937d56c..46f21afc 100644 --- a/tests/test_model_impl.py +++ b/tests/test_model_impl.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from jetstream_pt.third_party.llama2 import model_exportable -from jetstream_pt.third_party.llama2 import model_original +from jetstream_pt.third_party.llama import model_exportable +from jetstream_pt.third_party.llama import model_original from jetstream_pt import layers -from jetstream_pt.third_party.llama2 import model_args +from jetstream_pt.third_party.llama import model_args from jetstream_pt import environment from jetstream_pt import cache_manager import torch From 78c95e2c8b5af77aa8a20dd6baed2e114d98865f Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Thu, 2 May 2024 00:35:54 +0000 Subject: [PATCH 02/11] Sync with main branch --- jetstream_pt/environment.py | 17 ++++++++--------- run_interactive.py | 9 +++------ 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py index 23ea7898..a52562c0 100644 --- a/jetstream_pt/environment.py +++ b/jetstream_pt/environment.py @@ -71,15 +71,14 @@ class JetEngineEnvironmentData: # pylint: disable-next=all class JetEngineEnvironment: - def __init__(self, data: JetEngineEnvironmentData): - self._data = data - # Get 13b - self._model_arg = model_args.get_model_args( - data.model_type, - context_length=data.max_input_sequence_length, - batch_size=data.batch_size, - bf16_enable=data.bf16_enable, - ) + def __init__(self, data: JetEngineEnvironmentData): + self._data = data + self._model_arg = model_args.get_model_args( + data.model_type, + context_length=data.max_input_sequence_length, + batch_size=data.batch_size, + bf16_enable=data.bf16_enable, + ) self.batch_size = self._data.batch_size self.seq_len = self._data.max_input_sequence_length diff --git a/run_interactive.py b/run_interactive.py index 3ca08968..db214450 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -148,12 +148,9 @@ def main(argv): if complete[0]: break token_id = output[0][0] - output_str = tokenizer.decode_str([token_id]) - print(Fore.GREEN + output_str, end="", flush=True) - sampled_tokens_list.append(token_id) - # output = token_utils.mix_decode(vocab, token_id) - # print(Fore.GREEN + output, end="", flush=True) + # output_str = tokenizer.decode_str([token_id]) + # print(Fore.GREEN + output_str, end="", flush=True) # print(Style.RESET_ALL + "\n") # print("---- Streaming decode finished.") @@ -161,7 +158,7 @@ def main(argv): print("---- All output tokens.") print(sampled_tokens_list) print("---- All output text.") - print(vocab.tokenizer.decode(sampled_tokens_list)) + print(tokenizer.decode_str(sampled_tokens_list)) if _PROFILING_OUTPUT.value: jax.profiler.stop_trace() From 7a32006f9c3ebf525cf7175cbf48bb14dab11551 Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Thu, 2 May 2024 04:36:10 +0000 Subject: [PATCH 03/11] Fix CI --- benchmarks/run_offline.py | 7 +- jetstream_pt/engine.py | 60 +- jetstream_pt/ray_worker.py | 2 +- .../third_party/llama/generation_original.py | 672 ++++++++++-------- tests/test_llama_e2e.py | 14 +- 5 files changed, 409 insertions(+), 346 deletions(-) diff --git a/benchmarks/run_offline.py b/benchmarks/run_offline.py index 8d97df14..97d6fb2a 100644 --- a/benchmarks/run_offline.py +++ b/benchmarks/run_offline.py @@ -21,7 +21,6 @@ import jax import jax.numpy as jnp -from jetstream.engine import token_utils from jetstream_pt import engine as je # pylint: disable-next=all from benchmarks import analyze_sharegpt @@ -97,12 +96,10 @@ def create_engine(): def run_prefill_time(engine, params, decode_state, seqlen): """Run prefill and measure time.""" metadata = engine.get_tokenizer() - 
vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids) + tokenizer = engine.build_tokenizer(metadata) text = "This is a beautiful day" - tokens, true_length = token_utils.tokenize_and_pad( - text, vocab, is_bos=True, prefill_lengths=[seqlen] - ) + tokens, true_length = tokenizer.encode(text, is_bos=True, prefill_lengths=[]) for _ in range(3): prefill_result = engine.prefill( diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 39445fd8..740cc72e 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -529,11 +529,13 @@ def get_tokenizer(self) -> tokenizer_pb2.TokenizerParameters: # pylint: disable-next=all return tokenizer_pb2.TokenizerParameters(path=self.env.tokenizer_path) - def build_tokenizer(self, meta: tokenizer_pb2.TokenizerParameters) -> tokenizer_api.Tokenizer: - if 'llama-3' in self.env.model_type: - return token_utils.TikToken(meta) - else: - return token_utils.SentencePieceTokenizer(meta) + def build_tokenizer( + self, metadata: tokenizer_pb2.TokenizerParameters + ) -> tokenizer_api.Tokenizer: + if "llama-3" in self.env.model_type: + return token_utils.TikToken(metadata) + + return token_utils.SentencePieceTokenizer(metadata) def join_prefixes( self, @@ -661,16 +663,18 @@ def create_pytorch_engine( context_length: int = 1024, batch_size: int = 1, max_decode_length: int = 4096, - model_name = "llama-2", - quantize_weights = False, - quantize_kv = False, - max_cache_length = 1024, + model_name="llama-2", + quantize_weights=False, + quantize_kv=False, + max_cache_length=1024, ) -> PyTorchEngine: """Returns: The pytorch engine.""" - supported_models = ['llama-2', 'llama-3'] + supported_models = ["llama-2", "llama-3"] if model_name not in supported_models: - raise NotImplementedError('Model name should be one of {}'.format(','.join(supported_models))) + raise NotImplementedError( + f"Model name should be one of{','.join(supported_models)}" + ) # See issue b/309529778 if it's turned on. jax.config.update("jax_dynamic_shapes", False) # Pytorch exports has int64 constants. 
@@ -703,30 +707,32 @@ def create_pytorch_engine( checkpoint_path = paths[0] env_data = JetEngineEnvironmentData( - tokenizer_path=tokenizer_path, - checkpoint_path = checkpoint_path, - checkpoint_format = checkpoint_format, - model_type = model_name + '-' + param_size, - batch_size = batch_size, - max_decode_length = max_decode_length, - max_input_sequence_length = context_length, - enable_weight_quantization = quantize_weights, - enable_kv_quantization = quantize_kv, - cache_sequence_length = max_cache_length, - bf16_enable = bf16_enable, + tokenizer_path=tokenizer_path, + checkpoint_path=checkpoint_path, + checkpoint_format=checkpoint_format, + model_type=model_name + "-" + param_size, + batch_size=batch_size, + max_decode_length=max_decode_length, + max_input_sequence_length=context_length, + enable_weight_quantization=quantize_weights, + enable_kv_quantization=quantize_kv, + cache_sequence_length=max_cache_length, + bf16_enable=bf16_enable, ) env = JetEngineEnvironment(env_data) - args = model_args.get_model_args(model_name + '-' + param_size, context_length, batch_size, bf16_enable) - args.device = 'meta' + args = model_args.get_model_args( + model_name + "-" + param_size, context_length, batch_size, bf16_enable + ) + args.device = "meta" args.quantize = quantize_weights pt_model = model_exportable.Transformer(args, env) num_params_size = 0 num_params = 0 - for k, v in pt_model.state_dict().items(): + for _, v in pt_model.state_dict().items(): num_params += 1 num_params_size += np.prod(v.shape) * (1 if v.dtype == jnp.int8 else 2) - print('Number of param Gbytes:', num_params_size / (1 << 30)) - print('Number of param: ', num_params) + print("Number of param Gbytes:", num_params_size / (1 << 30)) + print("Number of param: ", num_params) return PyTorchEngine(pt_model=pt_model, env=env) diff --git a/jetstream_pt/ray_worker.py b/jetstream_pt/ray_worker.py index edd782e4..c5fe0186 100644 --- a/jetstream_pt/ray_worker.py +++ b/jetstream_pt/ray_worker.py @@ -34,7 +34,7 @@ from jetstream.engine import engine_api, tokenizer_pb2, token_utils -from jetstream_pt.third_party.llama2 import model_exportable, model_args +from jetstream_pt.third_party.llama import model_exportable, model_args from jetstream_pt import cache_manager from jetstream_pt import quantize diff --git a/jetstream_pt/third_party/llama/generation_original.py b/jetstream_pt/third_party/llama/generation_original.py index e36db4a6..5d4e2810 100644 --- a/jetstream_pt/third_party/llama/generation_original.py +++ b/jetstream_pt/third_party/llama/generation_original.py @@ -1,4 +1,3 @@ - # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
@@ -6,27 +5,28 @@ from typing import List, Literal, Optional, Tuple, TypedDict import torch -from jetstream_pt.third_party.llama import model_original -from jetstream_pt.third_party.llama.tokenizer import Tokenizer +from jetstream_pt.third_party.llama2 import model_original +from flax import struct +from jetstream_pt.third_party.llama2.tokenizer import Tokenizer Role = Literal["system", "user", "assistant"] class Message(TypedDict): - role: Role - content: str + role: Role + content: str class CompletionPrediction(TypedDict, total=False): - generation: str - tokens: List[str] # not required - logprobs: List[float] # not required + generation: str + tokens: List[str] # not required + logprobs: List[float] # not required class ChatPrediction(TypedDict, total=False): - generation: Message - tokens: List[str] # not required - logprobs: List[float] # not required + generation: Message + tokens: List[str] # not required + logprobs: List[float] # not required Dialog = List[Message] @@ -38,301 +38,361 @@ class ChatPrediction(TypedDict, total=False): UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt." +@struct.dataclass +class DecodeStateOriginal: + prev_pos: int + cur_pos: int + tokens: torch.tensor + out_tokens: List[List[int]] + logits: torch.tensor + input_text_mask: torch.tensor + prompt_tokens: List[List[int]] + + class LlamaOriginal: - @staticmethod - def build( - tokenizer_path: str, - model_args: model_original.ModelArgs, - seed: int = 1, - ) -> "LlamaOriginal": - - local_rank = int(os.environ.get("LOCAL_RANK", 0)) - - - # seed must be the same in all processes - torch.manual_seed(seed) - - tokenizer = Tokenizer(model_path=tokenizer_path) - model_args.vocab_size = tokenizer.n_words - model = model_original.Transformer(model_args) - - return LlamaOriginal(model, tokenizer) - - def __init__(self, model: model_original.Transformer, tokenizer: Tokenizer): - self.model = model - self.tokenizer = tokenizer - - - @torch.inference_mode() - def prefill( - self, - prompt_tokens: List[List[int]], - max_gen_len: int, - ) -> List[List[int]]: - """ - Do greedy search on CPU and return tokens only. 
- """ - - params = self.model.params - bsz = len(prompt_tokens) - assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) - - min_prompt_len = min(len(t) for t in prompt_tokens) - max_prompt_len = max(len(t) for t in prompt_tokens) - assert max_prompt_len <= params.max_seq_len - total_len = min(params.max_seq_len, max_gen_len + max_prompt_len) - - pad_id = self.tokenizer.pad_id - tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cpu") - for k, t in enumerate(prompt_tokens): - tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cpu") - - - prev_pos = 0 - eos_reached = torch.tensor([False] * bsz, device="cpu") - input_text_mask = tokens != pad_id - - cur_pos = min_prompt_len - logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) - next_token = torch.argmax(logits[:, -1], dim=-1) - - next_token = next_token.reshape(-1) - # only replace token if prompt has already been generated - next_token = torch.where( - input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token - ) - tokens[:, cur_pos] = next_token - eos_reached |= (~input_text_mask[:, cur_pos]) & ( - next_token == self.tokenizer.eos_id - ) - prev_pos = cur_pos - - out_tokens, out_logprobs = [], [] - for i, toks in enumerate(tokens.tolist()): - # cut to max gen len - start = len(prompt_tokens[i]) - toks = toks[start : start + 1] - probs = None - # cut to eos tok if any - if self.tokenizer.eos_id in toks: - eos_idx = toks.index(self.tokenizer.eos_id) - toks = toks[:eos_idx] - out_tokens.append(toks) - out_logprobs.append(probs) - return out_tokens - - @torch.inference_mode() - def generate( - self, - prompt_tokens: List[List[int]], - max_gen_len: int, - ) -> List[List[int]]: - """ - Do greedy search on CPU and return tokens only. - """ - - params = self.model.params - bsz = len(prompt_tokens) - assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) - - min_prompt_len = min(len(t) for t in prompt_tokens) - max_prompt_len = max(len(t) for t in prompt_tokens) - assert max_prompt_len <= params.max_seq_len - total_len = min(params.max_seq_len, max_gen_len + max_prompt_len) - - pad_id = self.tokenizer.pad_id - tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cpu") - for k, t in enumerate(prompt_tokens): - tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cpu") - - - prev_pos = 0 - eos_reached = torch.tensor([False] * bsz, device="cpu") - input_text_mask = tokens != pad_id - - for cur_pos in range(min_prompt_len, total_len): - logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) - next_token = torch.argmax(logits[:, -1], dim=-1) - - next_token = next_token.reshape(-1) - # only replace token if prompt has already been generated - next_token = torch.where( - input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token - ) - tokens[:, cur_pos] = next_token - eos_reached |= (~input_text_mask[:, cur_pos]) & ( - next_token == self.tokenizer.eos_id - ) - prev_pos = cur_pos - if all(eos_reached): - break - - out_tokens, out_logprobs = [], [] - for i, toks in enumerate(tokens.tolist()): - # cut to max gen len - start = len(prompt_tokens[i]) - toks = toks[start : len(prompt_tokens[i]) + max_gen_len] - probs = None - # cut to eos tok if any - if self.tokenizer.eos_id in toks: - eos_idx = toks.index(self.tokenizer.eos_id) - toks = toks[:eos_idx] - out_tokens.append(toks) - out_logprobs.append(probs) - return out_tokens - - def text_completion( - self, - prompts: List[str], - temperature: float = 0.6, - top_p: float = 0.9, - max_gen_len: 
Optional[int] = None, - logprobs: bool = False, - echo: bool = False, - ) -> List[CompletionPrediction]: - """ - Perform text completion for a list of prompts using the language generation model. - - Args: - prompts (List[str]): List of text prompts for completion. - temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. - top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. - max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence. - If not provided, it's set to the model's maximum sequence length minus 1. - logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. - echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. - - Returns: - List[CompletionPrediction]: List of completion predictions, each containing the generated text completion. - - Note: - This method generates text completions for the provided prompts, employing nucleus sampling to introduce controlled randomness. - If logprobs is True, token log probabilities are computed for each generated token. - - """ - if max_gen_len is None: - max_gen_len = self.model.params.max_seq_len - 1 - prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts] - generation_tokens, generation_logprobs = self.generate( - prompt_tokens=prompt_tokens, - max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - logprobs=logprobs, - echo=echo, - ) - return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens] - - def chat_completion( - self, - dialogs: List[Dialog], - temperature: float = 0.6, - top_p: float = 0.9, - max_gen_len: Optional[int] = None, - logprobs: bool = False, - ) -> List[ChatPrediction]: - """ - Generate assistant responses for a list of conversational dialogs using the language generation model. - - Args: - dialogs (List[Dialog]): List of conversational dialogs, where each dialog is a list of messages. - temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. - top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. - max_gen_len (Optional[int], optional): Maximum length of the generated response sequence. - If not provided, it's set to the model's maximum sequence length minus 1. - logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. - - Returns: - List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response. - - Raises: - AssertionError: If the last message in a dialog is not from the user. - AssertionError: If the dialog roles are not in the required 'user', 'assistant', and optional 'system' order. - - Note: - This method generates assistant responses for the provided conversational dialogs. - It employs nucleus sampling to introduce controlled randomness in text generation. - If logprobs is True, token log probabilities are computed for each generated token. 
- - """ - if max_gen_len is None: - max_gen_len = self.model.params.max_seq_len - 1 - prompt_tokens = [] - unsafe_requests = [] - for dialog in dialogs: - unsafe_requests.append( - any([tag in msg["content"] for tag in SPECIAL_TAGS for msg in dialog]) - ) - if dialog[0]["role"] == "system": - dialog = [ - { - "role": dialog[1]["role"], - "content": B_SYS - + dialog[0]["content"] - + E_SYS - + dialog[1]["content"], - } - ] + dialog[2:] - assert all([msg["role"] == "user" for msg in dialog[::2]]) and all( - [msg["role"] == "assistant" for msg in dialog[1::2]] - ), ( - "model only supports 'system', 'user' and 'assistant' roles, " - "starting with 'system', then 'user' and alternating (u/a/u/a/u...)" - ) - dialog_tokens: List[int] = sum( - [ - self.tokenizer.encode( - f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ", - bos=True, - eos=True, - ) - for prompt, answer in zip( - dialog[::2], - dialog[1::2], - ) - ], - [], - ) - assert ( - dialog[-1]["role"] == "user" - ), f"Last message must be from user, got {dialog[-1]['role']}" - dialog_tokens += self.tokenizer.encode( - f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}", - bos=True, - eos=False, - ) - prompt_tokens.append(dialog_tokens) - - generation_tokens, generation_logprobs = self.generate( - prompt_tokens=prompt_tokens, - max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - logprobs=logprobs, - ) - if logprobs: - return [ - { - "generation": { - "role": "assistant", - "content": self.tokenizer.decode(t) - if not unsafe - else UNSAFE_ERROR, - }, - "tokens": [self.tokenizer.decode(x) for x in t], - "logprobs": logprobs_i, - } - for t, logprobs_i, unsafe in zip( - generation_tokens, generation_logprobs, unsafe_requests - ) - ] - return [ + + @staticmethod + def build( + tokenizer_path: str, + model_args: model_original.ModelArgs, + seed: int = 1, + ) -> "LlamaOriginal": + + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + + # seed must be the same in all processes + torch.manual_seed(seed) + + tokenizer = Tokenizer(model_path=tokenizer_path) + model_args.vocab_size = tokenizer.n_words + model = model_original.Transformer(model_args) + + return LlamaOriginal(model, tokenizer) + + def __init__(self, model: model_original.Transformer, tokenizer: Tokenizer): + self.model = model + self.tokenizer = tokenizer + + @torch.inference_mode() + def prefill( + self, + prompt_tokens: List[List[int]], + max_gen_len: int, + ) -> DecodeStateOriginal: + """ + Do greedy search on CPU and return tokens only. 
+ """ + + params = self.model.params + bsz = len(prompt_tokens) + assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) + + min_prompt_len = min(len(t) for t in prompt_tokens) + max_prompt_len = max(len(t) for t in prompt_tokens) + assert max_prompt_len <= params.max_seq_len + total_len = min(params.max_seq_len, max_gen_len + max_prompt_len) + + pad_id = self.tokenizer.pad_id + tokens = torch.full( + (bsz, total_len), pad_id, dtype=torch.long, device="cpu" + ) + for k, t in enumerate(prompt_tokens): + tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cpu") + + prev_pos = 0 + input_text_mask = tokens != pad_id + + cur_pos = min_prompt_len + logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) + next_token = torch.argmax(logits[:, -1], dim=-1) + + next_token = next_token.reshape(-1) + # only replace token if prompt has already been generated + next_token = torch.where( + input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token + ) + tokens[:, cur_pos] = next_token + + prev_pos = cur_pos + + out_tokens, out_logprobs = [], [] + for i, toks in enumerate(tokens.tolist()): + # cut to max gen len + start = len(prompt_tokens[i]) + toks = toks[start : start + 1] + probs = None + out_tokens.append(toks) + out_logprobs.append(probs) + state = DecodeStateOriginal( + prev_pos=cur_pos, + cur_pos=cur_pos + 1, + tokens=tokens, + out_tokens=out_tokens, + logits=logits, + input_text_mask=input_text_mask, + prompt_tokens=prompt_tokens, + ) + + return state + + @torch.inference_mode() + def decode( + self, + decode_state: DecodeStateOriginal, + ) -> DecodeStateOriginal: + + prev_pos = decode_state.prev_pos + cur_pos = decode_state.cur_pos + tokens = decode_state.tokens + input_text_mask = decode_state.input_text_mask + prompt_tokens = decode_state.prompt_tokens + + logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) + next_token = torch.argmax(logits[:, -1], dim=-1) + + next_token = next_token.reshape(-1) + # only replace token if prompt has already been generated + next_token = torch.where( + input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token + ) + tokens[:, cur_pos] = next_token + prev_pos = cur_pos + + out_tokens, out_logprobs = [], [] + for i, toks in enumerate(tokens.tolist()): + toks = toks[cur_pos : cur_pos + 1] + probs = None + out_tokens.append(toks) + out_logprobs.append(probs) + + state = DecodeStateOriginal( + prev_pos=cur_pos, + cur_pos=cur_pos + 1, + tokens=tokens, + out_tokens=out_tokens, + logits=logits, + input_text_mask=input_text_mask, + prompt_tokens=prompt_tokens, + ) + + return state + + @torch.inference_mode() + def generate( + self, + prompt_tokens: List[List[int]], + max_gen_len: int, + ) -> List[List[int]]: + """ + Do greedy search on CPU and return tokens only. 
+ """ + + params = self.model.params + bsz = len(prompt_tokens) + assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) + + min_prompt_len = min(len(t) for t in prompt_tokens) + max_prompt_len = max(len(t) for t in prompt_tokens) + assert max_prompt_len <= params.max_seq_len + total_len = min(params.max_seq_len, max_gen_len + max_prompt_len) + + pad_id = self.tokenizer.pad_id + tokens = torch.full( + (bsz, total_len), pad_id, dtype=torch.long, device="cpu" + ) + for k, t in enumerate(prompt_tokens): + tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cpu") + + prev_pos = 0 + eos_reached = torch.tensor([False] * bsz, device="cpu") + input_text_mask = tokens != pad_id + + for cur_pos in range(min_prompt_len, total_len): + logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) + next_token = torch.argmax(logits[:, -1], dim=-1) + + next_token = next_token.reshape(-1) + # only replace token if prompt has already been generated + next_token = torch.where( + input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token + ) + tokens[:, cur_pos] = next_token + eos_reached |= (~input_text_mask[:, cur_pos]) & ( + next_token == self.tokenizer.eos_id + ) + prev_pos = cur_pos + if all(eos_reached): + break + + out_tokens, out_logprobs = [], [] + for i, toks in enumerate(tokens.tolist()): + # cut to max gen len + start = len(prompt_tokens[i]) + toks = toks[start : len(prompt_tokens[i]) + max_gen_len] + probs = None + # cut to eos tok if any + if self.tokenizer.eos_id in toks: + eos_idx = toks.index(self.tokenizer.eos_id) + toks = toks[:eos_idx] + out_tokens.append(toks) + out_logprobs.append(probs) + return out_tokens + + def text_completion( + self, + prompts: List[str], + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len: Optional[int] = None, + logprobs: bool = False, + echo: bool = False, + ) -> List[CompletionPrediction]: + """ + Perform text completion for a list of prompts using the language generation model. + + Args: + prompts (List[str]): List of text prompts for completion. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. + max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence. + If not provided, it's set to the model's maximum sequence length minus 1. + logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. + echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. + + Returns: + List[CompletionPrediction]: List of completion predictions, each containing the generated text completion. + + Note: + This method generates text completions for the provided prompts, employing nucleus sampling to introduce controlled randomness. + If logprobs is True, token log probabilities are computed for each generated token. 
+ + """ + if max_gen_len is None: + max_gen_len = self.model.params.max_seq_len - 1 + prompt_tokens = [ + self.tokenizer.encode(x, bos=True, eos=False) for x in prompts + ] + generation_tokens, generation_logprobs = self.generate( + prompt_tokens=prompt_tokens, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + logprobs=logprobs, + echo=echo, + ) + return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens] + + def chat_completion( + self, + dialogs: List[Dialog], + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len: Optional[int] = None, + logprobs: bool = False, + ) -> List[ChatPrediction]: + """ + Generate assistant responses for a list of conversational dialogs using the language generation model. + + Args: + dialogs (List[Dialog]): List of conversational dialogs, where each dialog is a list of messages. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. + max_gen_len (Optional[int], optional): Maximum length of the generated response sequence. + If not provided, it's set to the model's maximum sequence length minus 1. + logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. + + Returns: + List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response. + + Raises: + AssertionError: If the last message in a dialog is not from the user. + AssertionError: If the dialog roles are not in the required 'user', 'assistant', and optional 'system' order. + + Note: + This method generates assistant responses for the provided conversational dialogs. + It employs nucleus sampling to introduce controlled randomness in text generation. + If logprobs is True, token log probabilities are computed for each generated token. 
+ + """ + if max_gen_len is None: + max_gen_len = self.model.params.max_seq_len - 1 + prompt_tokens = [] + unsafe_requests = [] + for dialog in dialogs: + unsafe_requests.append( + any([tag in msg["content"] for tag in SPECIAL_TAGS for msg in dialog]) + ) + if dialog[0]["role"] == "system": + dialog = [ { - "generation": { - "role": "assistant", - "content": self.tokenizer.decode(t) if not unsafe else UNSAFE_ERROR, - } + "role": dialog[1]["role"], + "content": B_SYS + + dialog[0]["content"] + + E_SYS + + dialog[1]["content"], } - for t, unsafe in zip(generation_tokens, unsafe_requests) - ] - + ] + dialog[2:] + assert all([msg["role"] == "user" for msg in dialog[::2]]) and all( + [msg["role"] == "assistant" for msg in dialog[1::2]] + ), ( + "model only supports 'system', 'user' and 'assistant' roles, " + "starting with 'system', then 'user' and alternating (u/a/u/a/u...)" + ) + dialog_tokens: List[int] = sum( + [ + self.tokenizer.encode( + f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ", + bos=True, + eos=True, + ) + for prompt, answer in zip( + dialog[::2], + dialog[1::2], + ) + ], + [], + ) + assert ( + dialog[-1]["role"] == "user" + ), f"Last message must be from user, got {dialog[-1]['role']}" + dialog_tokens += self.tokenizer.encode( + f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}", + bos=True, + eos=False, + ) + prompt_tokens.append(dialog_tokens) + + generation_tokens, generation_logprobs = self.generate( + prompt_tokens=prompt_tokens, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + logprobs=logprobs, + ) + if logprobs: + return [ + { + "generation": { + "role": "assistant", + "content": self.tokenizer.decode(t) + if not unsafe + else UNSAFE_ERROR, + }, + "tokens": [self.tokenizer.decode(x) for x in t], + "logprobs": logprobs_i, + } + for t, logprobs_i, unsafe in zip( + generation_tokens, generation_logprobs, unsafe_requests + ) + ] + return [ + { + "generation": { + "role": "assistant", + "content": self.tokenizer.decode(t) + if not unsafe + else UNSAFE_ERROR, + } + } + for t, unsafe in zip(generation_tokens, unsafe_requests) + ] \ No newline at end of file diff --git a/tests/test_llama_e2e.py b/tests/test_llama_e2e.py index 317fbad1..fdced201 100644 --- a/tests/test_llama_e2e.py +++ b/tests/test_llama_e2e.py @@ -67,7 +67,7 @@ def test_original_llama2_seed(self): tokens = np.arange(10, dtype=np.int32) file_dir = os.path.dirname(__file__) tokenizer_path = os.path.join( - file_dir, "../jetstream_pt/third_party/llama2/tokenizer.model" + file_dir, "../jetstream_pt/third_party/llama/tokenizer.model" ) output_tokens_multiple = [] for i in [1, 999, 99999]: @@ -104,7 +104,7 @@ def test_jetstream_llama2_seed(self): file_dir = os.path.dirname(__file__) tokenizer_path = os.path.join( - file_dir, "../jetstream_pt/third_party/llama2/tokenizer.model" + file_dir, "../jetstream_pt/third_party/llama/tokenizer.model" ) seed = 1 @@ -169,7 +169,7 @@ def _llama_e2e(self, env): file_dir = os.path.dirname(__file__) tokenizer_path = os.path.join( - file_dir, "../jetstream_pt/third_party/llama2/tokenizer.model" + file_dir, "../jetstream_pt/third_party/llama/tokenizer.model" ) # orginal @@ -255,7 +255,7 @@ def test_llama_e2e_two_addtional_tokens(self): file_dir = os.path.dirname(__file__) tokenizer_path = os.path.join( - file_dir, "../jetstream_pt/third_party/llama2/tokenizer.model" + file_dir, "../jetstream_pt/third_party/llama/tokenizer.model" ) # orginal @@ -329,7 +329,7 @@ def test_llama_e2e_four_addtional_tokens(self): file_dir = 
os.path.dirname(__file__) tokenizer_path = os.path.join( - file_dir, "../jetstream_pt/third_party/llama2/tokenizer.model" + file_dir, "../jetstream_pt/third_party/llama/tokenizer.model" ) # orginal @@ -399,7 +399,7 @@ def test_llama_with_original_prefill_decode_32(self): file_dir = os.path.dirname(__file__) tokenizer_path = os.path.join( - file_dir, "../jetstream_pt/third_party/llama2/tokenizer.model" + file_dir, "../jetstream_pt/third_party/llama/tokenizer.model" ) # orginal @@ -476,7 +476,7 @@ def test_llama_with_original_prefill_decode(self): file_dir = os.path.dirname(__file__) tokenizer_path = os.path.join( - file_dir, "../jetstream_pt/third_party/llama2/tokenizer.model" + file_dir, "../jetstream_pt/third_party/llama/tokenizer.model" ) # orginal From 68bd3f8414f82bb5bee66f39438972feac4ff168 Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Thu, 2 May 2024 05:46:11 +0000 Subject: [PATCH 04/11] fix linting --- jetstream_pt/engine.py | 4 ++-- jetstream_pt/ray_worker.py | 10 ++++------ jetstream_pt/third_party/llama/generation_original.py | 4 ++-- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 740cc72e..8d195253 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -530,7 +530,7 @@ def get_tokenizer(self) -> tokenizer_pb2.TokenizerParameters: return tokenizer_pb2.TokenizerParameters(path=self.env.tokenizer_path) def build_tokenizer( - self, metadata: tokenizer_pb2.TokenizerParameters + self, metadata: tokenizer_pb2.TokenizerParameters # pylint: disable=all ) -> tokenizer_api.Tokenizer: if "llama-3" in self.env.model_type: return token_utils.TikToken(metadata) @@ -674,7 +674,7 @@ def create_pytorch_engine( if model_name not in supported_models: raise NotImplementedError( f"Model name should be one of{','.join(supported_models)}" - ) + ) # See issue b/309529778 if it's turned on. jax.config.update("jax_dynamic_shapes", False) # Pytorch exports has int64 constants. 
diff --git a/jetstream_pt/ray_worker.py b/jetstream_pt/ray_worker.py index c5fe0186..cca2797a 100644 --- a/jetstream_pt/ray_worker.py +++ b/jetstream_pt/ray_worker.py @@ -32,7 +32,7 @@ from torch.utils import _pytree as pytree import torch_xla2 -from jetstream.engine import engine_api, tokenizer_pb2, token_utils +from jetstream.engine import engine_api, tokenizer_pb2 from jetstream_pt.third_party.llama import model_exportable, model_args @@ -99,7 +99,7 @@ def __init__( context_length: int = 1024, batch_size: int = 1, max_decode_length: int = 4096, - model_name="llama", + model_name="llama-2", quantize_weights=False, quantize_kv=False, max_cache_length=1024, @@ -159,14 +159,12 @@ def __init__( ) env = JetEngineEnvironment(env_data) - tokenizer = token_utils.load_vocab(tokenizer_path) pt_model = None - if model_name == "llama": + if 'llama' in model_name : args = model_args.get_model_args( - param_size, + model_name + '-' + param_size, context_length, batch_size, - tokenizer.vocab_size, bf16_enable, ) args.device = "meta" diff --git a/jetstream_pt/third_party/llama/generation_original.py b/jetstream_pt/third_party/llama/generation_original.py index 5d4e2810..0fa82d56 100644 --- a/jetstream_pt/third_party/llama/generation_original.py +++ b/jetstream_pt/third_party/llama/generation_original.py @@ -5,9 +5,9 @@ from typing import List, Literal, Optional, Tuple, TypedDict import torch -from jetstream_pt.third_party.llama2 import model_original +from jetstream_pt.third_party.llama import model_original from flax import struct -from jetstream_pt.third_party.llama2.tokenizer import Tokenizer +from jetstream_pt.third_party.llama.tokenizer import Tokenizer Role = Literal["system", "user", "assistant"] From e6e25ee50d0a5932855dc0de23005dc005b25fb1 Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Thu, 2 May 2024 05:50:53 +0000 Subject: [PATCH 05/11] Fix pyink issues --- jetstream_pt/ray_worker.py | 4 +-- .../third_party/llama/generation_original.py | 2 +- jetstream_pt/third_party/llama/model_args.py | 31 ++++++++++--------- .../third_party/llama/model_exportable.py | 2 +- run_interactive.py | 19 +++++++----- run_server.py | 8 +++-- 6 files changed, 38 insertions(+), 28 deletions(-) diff --git a/jetstream_pt/ray_worker.py b/jetstream_pt/ray_worker.py index cca2797a..2176d0bb 100644 --- a/jetstream_pt/ray_worker.py +++ b/jetstream_pt/ray_worker.py @@ -160,9 +160,9 @@ def __init__( env = JetEngineEnvironment(env_data) pt_model = None - if 'llama' in model_name : + if "llama" in model_name: args = model_args.get_model_args( - model_name + '-' + param_size, + model_name + "-" + param_size, context_length, batch_size, bf16_enable, diff --git a/jetstream_pt/third_party/llama/generation_original.py b/jetstream_pt/third_party/llama/generation_original.py index 0fa82d56..dd4339ee 100644 --- a/jetstream_pt/third_party/llama/generation_original.py +++ b/jetstream_pt/third_party/llama/generation_original.py @@ -395,4 +395,4 @@ def chat_completion( } } for t, unsafe in zip(generation_tokens, unsafe_requests) - ] \ No newline at end of file + ] diff --git a/jetstream_pt/third_party/llama/model_args.py b/jetstream_pt/third_party/llama/model_args.py index 500cd4b5..b9143384 100755 --- a/jetstream_pt/third_party/llama/model_args.py +++ b/jetstream_pt/third_party/llama/model_args.py @@ -31,7 +31,7 @@ class ModelArgs: bf16_enable: bool = False head_dim = -1 infer_length = 0 - device = 'cpu' + device = "cpu" quantize = False rope_theta: float = 10000.0 @@ -84,7 +84,7 @@ def get_arg( "n_layers": 80, "norm_eps": 1e-05, } - 
elif model_name == 'llama-3-8b': + elif model_name == "llama-3-8b": data = { "dim": 4096, "vocab_size": 128256, @@ -103,17 +103,18 @@ def get_arg( **data, ) + def get_model_args(model_name, context_length, batch_size, bf16_enable): - model_args = get_arg( - model_name=model_name, - seqlen=context_length, - batch_size=batch_size, - bf16_enable=bf16_enable, - ) - model_args.n_kv_heads = ( - model_args.n_heads - if model_args.n_kv_heads is None - else model_args.n_kv_heads - ) - model_args.head_dim = model_args.dim // model_args.n_heads - return model_args \ No newline at end of file + model_args = get_arg( + model_name=model_name, + seqlen=context_length, + batch_size=batch_size, + bf16_enable=bf16_enable, + ) + model_args.n_kv_heads = ( + model_args.n_heads + if model_args.n_kv_heads is None + else model_args.n_kv_heads + ) + model_args.head_dim = model_args.dim // model_args.n_heads + return model_args diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py index 99337d5d..106e1f0b 100644 --- a/jetstream_pt/third_party/llama/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -159,7 +159,7 @@ def __init__( freqs_cis = precompute_freqs_cis( self.params.dim // self.params.n_heads, self.params.max_seq_len * 2, - theta = self.params.rope_theta, + theta=self.params.rope_theta, ) self.register_buffer("freqs_cis", freqs_cis) diff --git a/run_interactive.py b/run_interactive.py index db214450..0117ead2 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -69,7 +69,12 @@ _MAX_CACHE_LENGTH = flags.DEFINE_integer( "max_cache_length", 1024, "kv_cache_quantize" ) -_MODEL_NAME = flags.DEFINE_string('model', 'llama-2', 'name of the model. Supported options are llama-2 and llama-3') +_MODEL_NAME = flags.DEFINE_string( + "model", + "llama-2", + "name of the model. Supported options are llama-2 and llama-3", +) + def create_engine(): """create a pytorch engine""" @@ -86,7 +91,7 @@ def create_engine(): param_size=_SIZE.value, context_length=_CONTEXT_LENGTH.value, batch_size=_BATCH_SIZE.value, - model_name = _MODEL_NAME.value, + model_name=_MODEL_NAME.value, quantize_weights=_QUANTIZE_WEIGHTS.value, quantize_kv=_QUANTIZE_KV_CACHE.value, max_cache_length=_MAX_CACHE_LENGTH.value, @@ -125,7 +130,7 @@ def main(argv): "[INST] <>\nYou are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n<>\n\nContinue the following story.\n\nKay didn't have shoes that fit her feet properly. She only wore sneakers, because the \nChoose from: [I] shoes fitted badly. [II] sneakers fitted badly. 
[/INST]", ] for prompt in prompts: - slot = random.randint(0, _BATCH_SIZE.value) + slot = random.randint(0, _BATCH_SIZE.value) tokens, true_length = tokenizer.encode(prompt, is_bos=True) print(f"---- Input prompts are: {prompt}") print(f"---- Encoded tokens are: {tokens}") @@ -140,11 +145,11 @@ def main(argv): print(f"---- Streaming decode started on #slot{slot}.") complete = np.zeros((1,), dtype=np.bool_) while True: - decode_state, result_tokens = engine.generate( - params, decode_state - ) + decode_state, result_tokens = engine.generate(params, decode_state) result_tokens = result_tokens.convert_to_numpy() - output, complete = tokenizer.decode(slot, max_output_length, result_tokens, complete) + output, complete = tokenizer.decode( + slot, max_output_length, result_tokens, complete + ) if complete[0]: break token_id = output[0][0] diff --git a/run_server.py b/run_server.py index a9af8264..77d7f173 100644 --- a/run_server.py +++ b/run_server.py @@ -71,7 +71,11 @@ "The model size the server runs on.", required=False, ) -_MODEL_NAME = flags.DEFINE_string('model', 'llama-2', 'name of the model. Supported options are llama-2 and llama-3') +_MODEL_NAME = flags.DEFINE_string( + "model", + "llama-2", + "name of the model. Supported options are llama-2 and llama-3", +) _QUANTIZE_WEIGHTS = flags.DEFINE_bool( "quantize_weights", False, "weight quantization" @@ -99,7 +103,7 @@ def main(argv: Sequence[str]): param_size=_PARAM_SIZE.value, context_length=_CONTEXT_LENGTH.value, batch_size=_BATCH_SIZE.value, - model_name = _MODEL_NAME.value, + model_name=_MODEL_NAME.value, quantize_weights=_QUANTIZE_WEIGHTS.value, quantize_kv=_QUANTIZE_KV_CACHE.value, max_cache_length=_MAX_CACHE_LENGTH.value, From 8ba6120f7b26336f8747c5e200ee3061709055e0 Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Thu, 2 May 2024 05:56:13 +0000 Subject: [PATCH 06/11] fix run_offline script --- benchmarks/run_offline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/run_offline.py b/benchmarks/run_offline.py index 97d6fb2a..fc62dba2 100644 --- a/benchmarks/run_offline.py +++ b/benchmarks/run_offline.py @@ -99,7 +99,7 @@ def run_prefill_time(engine, params, decode_state, seqlen): tokenizer = engine.build_tokenizer(metadata) text = "This is a beautiful day" - tokens, true_length = tokenizer.encode(text, is_bos=True, prefill_lengths=[]) + tokens, true_length = tokenizer.encode(text, is_bos=True, prefill_lengths=[seqlen]) for _ in range(3): prefill_result = engine.prefill( From ac11b4acb2cac6febae9a59e3d505942e7718bca Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Thu, 2 May 2024 05:56:45 +0000 Subject: [PATCH 07/11] Fix pyink --- benchmarks/run_offline.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/benchmarks/run_offline.py b/benchmarks/run_offline.py index fc62dba2..07128881 100644 --- a/benchmarks/run_offline.py +++ b/benchmarks/run_offline.py @@ -99,7 +99,9 @@ def run_prefill_time(engine, params, decode_state, seqlen): tokenizer = engine.build_tokenizer(metadata) text = "This is a beautiful day" - tokens, true_length = tokenizer.encode(text, is_bos=True, prefill_lengths=[seqlen]) + tokens, true_length = tokenizer.encode( + text, is_bos=True, prefill_lengths=[seqlen] + ) for _ in range(3): prefill_result = engine.prefill( From c5aaae59c4e8d78639b7d9ee04f7a7e894bd463e Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Thu, 2 May 2024 21:45:19 +0000 Subject: [PATCH 08/11] Fix after merging main --- jetstream_pt/engine.py | 6 +----- tests/test_llama_e2e.py | 2 +- 2 files changed, 2 
insertions(+), 6 deletions(-) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index fa549918..7ff41cec 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -709,11 +709,7 @@ def create_pytorch_engine( if model_name.startswith("llama"): args = model_args.get_model_args( - param_size, - context_length, - batch_size, - tokenizer.vocab_size, - bf16_enable, + model_name + "-" + param_size, context_length, batch_size, bf16_enable ) args.device = "meta" args.quantize = quantize_weights diff --git a/tests/test_llama_e2e.py b/tests/test_llama_e2e.py index 22d10459..10e698d6 100644 --- a/tests/test_llama_e2e.py +++ b/tests/test_llama_e2e.py @@ -77,7 +77,7 @@ def test_original_llama2_seed(self): file_dir, "../jetstream_pt/third_party/llama/tokenizer.model" ) output_tokens_multiple = [] - model_arg = model_args.get_model_args("tiny", 128, 1, 32000, True) + model_arg = model_args.get_model_args("llama-2-tiny", 128, 1, True) for i in [1, 999, 99999]: llama_original = LlamaOriginal.build(tokenizer_path, model_arg, i) prompt_tokens = [tokens] From 4748ffbe864413dca1879b5a7117f3cd17bcf712 Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Fri, 3 May 2024 23:19:38 +0000 Subject: [PATCH 09/11] Update jetstream version in install_everything.sh --- install_everything.sh | 2 +- jetstream_pt/engine.py | 2 +- run_interactive.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/install_everything.sh b/install_everything.sh index aca732b3..f9838a45 100644 --- a/install_everything.sh +++ b/install_everything.sh @@ -13,7 +13,7 @@ # limitations under the License. TORCHXLA_TAG=jetstream-pytorch -JETSTREAM_TAG=v0.2.0 +JETSTREAM_TAG=v0.2.1 # Uninstall existing jax pip3 show jax && pip3 uninstall -y jax diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 7ff41cec..328f0fb3 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -709,7 +709,7 @@ def create_pytorch_engine( if model_name.startswith("llama"): args = model_args.get_model_args( - model_name + "-" + param_size, context_length, batch_size, bf16_enable + model_name + "-" + param_size, context_length, batch_size, bf16_enable ) args.device = "meta" args.quantize = quantize_weights diff --git a/run_interactive.py b/run_interactive.py index 0117ead2..2262b4f6 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -130,7 +130,7 @@ def main(argv): "[INST] <>\nYou are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n<>\n\nContinue the following story.\n\nKay didn't have shoes that fit her feet properly. She only wore sneakers, because the \nChoose from: [I] shoes fitted badly. [II] sneakers fitted badly. 
[/INST]", ] for prompt in prompts: - slot = random.randint(0, _BATCH_SIZE.value) + slot = random.randint(0, _BATCH_SIZE.value - 1) tokens, true_length = tokenizer.encode(prompt, is_bos=True) print(f"---- Input prompts are: {prompt}") print(f"---- Encoded tokens are: {tokens}") From 5d40ebf729e009fbbf596ceb7a8a5defda8cf423 Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Fri, 3 May 2024 23:27:35 +0000 Subject: [PATCH 10/11] Fix unit tests --- tests/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/helpers.py b/tests/helpers.py index 764bc4b0..ab302672 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,6 +1,6 @@ import torch import jax -from jetstream_pt.third_party.llama2 import model_args +from jetstream_pt.third_party.llama import model_args from jetstream_pt import environment From f5569562d2870455c2e3b77423fe775cd92ce2ba Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Fri, 3 May 2024 23:33:02 +0000 Subject: [PATCH 11/11] Fix test --- tests/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/helpers.py b/tests/helpers.py index ab302672..dd0c7c50 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -9,7 +9,7 @@ def make_env_tiny(bf16_enable=True): torch.set_default_dtype(torch_dtype) jax.config.update("jax_dynamic_shapes", False) jax.config.update("jax_traceback_filtering", "off") - config = model_args.get_model_args("tiny", 128, 1, 32000, True) + config = model_args.get_model_args("llama-2-tiny", 128, 1, True) environment_data = environment.JetEngineEnvironmentData() environment_data.max_input_sequence_length = 128 environment_data.max_input_sequence_length = 128