From abf8535818db3f47745cdb93051d728a981d598c Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Fri, 17 May 2024 23:14:21 +0000 Subject: [PATCH 1/8] refactor flags --- README.md | 2 +- benchmarks/prefill_offline.py | 124 ++++++++++++---------- benchmarks/run_offline.py | 145 +++++++++++++------------- jetstream_pt/config.py | 101 +++++++++++++----- run_interactive.py | 108 +++---------------- run_interactive_multiple_host.py | 139 ++++++++++++------------ run_server.py | 174 ++++++++++++++++++------------- 7 files changed, 404 insertions(+), 389 deletions(-) diff --git a/README.md b/README.md index be6c8d65..9d3e21bf 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,7 @@ python run_interactive.py --model_name=$model_name --size=7b --batch_size=64 --m Here is an example to run the server with llama2 7B config. Note that the `--platform=tpu=8` need to specify number of tpu devices (which is 4 for v4-8 and 8 for v5light-8`). ```bash -python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8 --model=$model_name --sharding_config="default_shardings/llama.yaml" +python run_server.py --model_name=$model_name --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8 --sharding_config="default_shardings/llama.yaml" ``` Now you can fire gRPC to it. diff --git a/benchmarks/prefill_offline.py b/benchmarks/prefill_offline.py index 03bf4180..0686b037 100644 --- a/benchmarks/prefill_offline.py +++ b/benchmarks/prefill_offline.py @@ -12,57 +12,62 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import functools import os import time -import functools -import humanize -# pylint: disable-next=all -from absl import app -from absl import flags -import numpy as np +import humanize import jax - +import numpy as np +# pylint: disable-next=all +from absl import app, flags from jetstream_pt import engine as je - -FLAGS = flags.FLAGS - -_TOKENIZER_PATH = flags.DEFINE_string( - "tokenizer_path", - "tokenizer.model", - "The tokenizer model path", - required=False, -) -_CKPT_PATH = flags.DEFINE_string( - "checkpoint_path", None, "Directory for .pth checkpoints", required=False -) -_BF16_ENABLE = flags.DEFINE_bool( - "bf16_enable", False, "Whether to enable bf16", required=False -) -_CONTEXT_LENGTH = flags.DEFINE_integer( - "context_length", 1024, "The context length", required=False -) -_BATCH_SIZE = flags.DEFINE_integer( - "batch_size", 32, "The batch size", required=False -) -_PROFILING_OUTPUT = flags.DEFINE_string( - "profiling_output", - "", - "The profiling output", - required=False, +from jetstream_pt.config import ( + FLAGS, + create_engine_from_config_flags, + define_common_flags, + define_profiling_flags, ) -_SIZE = flags.DEFINE_string("size", "tiny", "size of model") - -_QUANTIZE_WEIGHTS = flags.DEFINE_bool( - "quantize_weights", False, "weight quantization" -) -_QUANTIZE_KV_CACHE = flags.DEFINE_bool( - "quantize_kv_cache", False, "kv_cache_quantize" -) -_MAX_CACHE_LENGTH = flags.DEFINE_integer( - "max_cache_length", 1024, "kv_cache_quantize" -) +define_common_flags() +define_profiling_flags() + +# _TOKENIZER_PATH = flags.DEFINE_string( +# "tokenizer_path", +# "tokenizer.model", +# "The tokenizer model path", +# required=False, +# ) +# _CKPT_PATH = flags.DEFINE_string( +# "checkpoint_path", None, "Directory for .pth checkpoints", required=False +# ) +# _BF16_ENABLE = flags.DEFINE_bool( +# "bf16_enable", False, "Whether to enable bf16", required=False +# ) +# _CONTEXT_LENGTH = flags.DEFINE_integer( +# "context_length", 1024, "The context length", required=False +# ) +# _BATCH_SIZE = flags.DEFINE_integer( +# "batch_size", 32, "The batch size", required=False +# ) +# _PROFILING_OUTPUT = flags.DEFINE_string( +# "profiling_output", +# "", +# "The profiling output", +# required=False, +# ) + +# _SIZE = flags.DEFINE_string("size", "tiny", "size of model") + +# _QUANTIZE_WEIGHTS = flags.DEFINE_bool( +# "quantize_weights", False, "weight quantization" +# ) +# _QUANTIZE_KV_CACHE = flags.DEFINE_bool( +# "quantize_kv_cache", False, "kv_cache_quantize" +# ) +# _MAX_CACHE_LENGTH = flags.DEFINE_integer( +# "max_cache_length", 1024, "kv_cache_quantize" +# ) def create_engine(): @@ -73,16 +78,19 @@ def create_engine(): devices = jax.devices() start = time.perf_counter() engine = je.create_pytorch_engine( + model_name=FLAGS.model_name, devices=devices, - tokenizer_path=_TOKENIZER_PATH.value, - ckpt_path=_CKPT_PATH.value, - bf16_enable=True, - param_size=_SIZE.value, - context_length=_CONTEXT_LENGTH.value, - batch_size=_BATCH_SIZE.value, - quantize_weights=_QUANTIZE_WEIGHTS.value, - quantize_kv=_QUANTIZE_KV_CACHE.value, - max_cache_length=_MAX_CACHE_LENGTH.value, + tokenizer_path=FLAGS.tokenizer_path, + ckpt_path=FLAGS.checkpoint_path, + bf16_enable=FLAGS.bf16_enable, + param_size=FLAGS.size, + context_length=FLAGS.context_length, + batch_size=FLAGS.batch_size, + quantize_weights=FLAGS.quantize_weights, + quantize_kv=FLAGS.quantize_kv_cache, + max_cache_length=FLAGS.max_cache_length, + sharding_config=FLAGS.sharding_config, + shard_on_batch=FLAGS.shard_on_batch, ) print("Initialize engine", 
time.perf_counter() - start) @@ -163,14 +171,16 @@ def prefill_benchmark(tokens_list, engine, params, warmup): # pylint: disable-next=all def main(argv): - engine = create_engine() + engine = create_engine_from_config_flags() start = time.perf_counter() params = engine.load_params() print("Load params ", time.perf_counter() - start) - if _PROFILING_OUTPUT.value: - jax.profiler.start_trace(_PROFILING_OUTPUT.value) + profiling_output = FLAGS.profiling_output + if profiling_output: + jax.profiler.start_trace(profiling_output) + print_mem_usage() tokens_list = create_prefill_tokens() for _ in range(3): @@ -189,7 +199,7 @@ def main(argv): tokens_list=tokens_list, engine=engine, params=params, warmup=False ) - if _PROFILING_OUTPUT.value: + if profiling_output: jax.profiler.stop_trace() diff --git a/benchmarks/run_offline.py b/benchmarks/run_offline.py index df788591..357fe65f 100644 --- a/benchmarks/run_offline.py +++ b/benchmarks/run_offline.py @@ -15,66 +15,68 @@ import logging import os import time -# pylint: disable-next=all -from absl import app -from absl import flags import jax import jax.numpy as jnp - -from jetstream_pt import engine as je +# pylint: disable-next=all +from absl import app, flags # pylint: disable-next=all from benchmarks import analyze_sharegpt - - -logging.getLogger().setLevel(logging.ERROR) - - -FLAGS = flags.FLAGS - -_TOKENIZER_PATH = flags.DEFINE_string( - "tokenizer_path", - "tokenizer.model", - "The tokenizer model path", - required=False, -) -_CKPT_PATH = flags.DEFINE_string( - "checkpoint_path", None, "Directory for .pth checkpoints", required=False -) -_BF16_ENABLE = flags.DEFINE_bool( - "bf16_enable", False, "Whether to enable bf16", required=False -) -_CONTEXT_LENGTH = flags.DEFINE_integer( - "context_length", 1024, "The context length", required=False -) -_BATCH_SIZE = flags.DEFINE_integer( - "batch_size", 32, "The batch size", required=False -) -_PROFILING_OUTPUT = flags.DEFINE_string( - "profiling_output", - "", - "The profiling output", - required=False, +from jetstream_pt import engine as je +from jetstream_pt.config import ( + FLAGS, + create_engine_from_config_flags, + define_common_flags, + define_profiling_flags, ) -_SIZE = flags.DEFINE_string("size", "tiny", "size of model") +logging.getLogger().setLevel(logging.ERROR) -_QUANTIZE_WEIGHTS = flags.DEFINE_bool( - "quantize_weights", False, "weight quantization" -) -_QUANTIZE_KV_CACHE = flags.DEFINE_bool( - "quantize_kv_cache", False, "kv_cache_quantize" -) -_MAX_CACHE_LENGTH = flags.DEFINE_integer( - "max_cache_length", 1024, "kv_cache_quantize" -) -_MODEL_NAME = flags.DEFINE_string("model_name", "", "model_name") -_SHARDING_CONFIG = flags.DEFINE_string( - "sharding_config", "", "path to sharding config" -) -_SHAREGPT_PATH = flags.DEFINE_string( - "sharegpt_path", "", "path to sharegpt json file" -) +define_common_flags() +define_profiling_flags() +# FLAGS = flags.FLAGS + +# _TOKENIZER_PATH = flags.DEFINE_string( +# "tokenizer_path", +# "tokenizer.model", +# "The tokenizer model path", +# required=False, +# ) +# _CKPT_PATH = flags.DEFINE_string( +# "checkpoint_path", None, "Directory for .pth checkpoints", required=False +# ) +# _BF16_ENABLE = flags.DEFINE_bool( +# "bf16_enable", False, "Whether to enable bf16", required=False +# ) +# _CONTEXT_LENGTH = flags.DEFINE_integer( +# "context_length", 1024, "The context length", required=False +# ) +# _BATCH_SIZE = flags.DEFINE_integer( +# "batch_size", 32, "The batch size", required=False +# ) +# _PROFILING_OUTPUT = flags.DEFINE_string( +# 
"profiling_output", +# "", +# "The profiling output", +# required=False, +# ) + +# _SIZE = flags.DEFINE_string("size", "tiny", "size of model") + +# _QUANTIZE_WEIGHTS = flags.DEFINE_bool( +# "quantize_weights", False, "weight quantization" +# ) +# _QUANTIZE_KV_CACHE = flags.DEFINE_bool( +# "quantize_kv_cache", False, "kv_cache_quantize" +# ) +# _MAX_CACHE_LENGTH = flags.DEFINE_integer( +# "max_cache_length", 1024, "kv_cache_quantize" +# ) +# _MODEL_NAME = flags.DEFINE_string("model_name", "", "model_name") +# _SHARDING_CONFIG = flags.DEFINE_string( +# "sharding_config", "", "path to sharding config" +# ) +flags.DEFINE_string("sharegpt_path", "", "path to sharegpt json file") def create_engine(): @@ -85,18 +87,19 @@ def create_engine(): devices = jax.devices() start = time.perf_counter() engine = je.create_pytorch_engine( + model_name=FLAGS.model_name, devices=devices, - tokenizer_path=_TOKENIZER_PATH.value, - ckpt_path=_CKPT_PATH.value, - bf16_enable=True, - param_size=_SIZE.value, - context_length=_CONTEXT_LENGTH.value, - batch_size=_BATCH_SIZE.value, - quantize_weights=_QUANTIZE_WEIGHTS.value, - quantize_kv=_QUANTIZE_KV_CACHE.value, - max_cache_length=_MAX_CACHE_LENGTH.value, - model_name=_MODEL_NAME.value, - sharding_config=_SHARDING_CONFIG.value, + tokenizer_path=FLAGS.tokenizer_path, + ckpt_path=FLAGS.checkpoint_path, + bf16_enable=FLAGS.bf16_enable, + param_size=FLAGS.size, + context_length=FLAGS.context_length, + batch_size=FLAGS.batch_size, + quantize_weights=FLAGS.quantize_weights, + quantize_kv=FLAGS.quantize_kv_cache, + max_cache_length=FLAGS.max_cache_length, + sharding_config=FLAGS.sharding_config, + shard_on_batch=FLAGS.shard_on_batch, ) print("Initialize engine", time.perf_counter() - start) @@ -148,7 +151,7 @@ def run_prefill_time(engine, params, decode_state, seqlen): def main(argv): """Main function to run engine offline.""" - engine = create_engine() + engine = create_engine_from_config_flags() start = time.perf_counter() params = engine.load_params() @@ -156,8 +159,9 @@ def main(argv): prefill_times = {} - if _PROFILING_OUTPUT.value: - jax.profiler.start_trace(_PROFILING_OUTPUT.value) + profiling_output = FLAGS.profiling_output + if profiling_output: + jax.profiler.start_trace(profiling_output) decode_state = engine.init_decode_state() for batch, _ in MAXTEXT_PREFILL.items(): runtime, decode_state = run_prefill_time( @@ -186,18 +190,19 @@ def main(argv): dec_times.append(end - start) print(i, "decode time", (end - start)) - if _PROFILING_OUTPUT.value: + if profiling_output: jax.profiler.stop_trace() print("prefill ", prefill_times) print("decode", sum(dec_times) / 10) prefill_times_ms = {k: v * 1000 for k, v in prefill_times.items()} - decode_time_ms = sum(dec_times) * 1000 / 10 / _BATCH_SIZE.value + decode_time_ms = sum(dec_times) * 1000 / 10 / FLAGS.batch_size - if _SHAREGPT_PATH.value: + sharegpt_path = FLAGS.sharegpt_path + if sharegpt_path: analyze_sharegpt.do_simulation( - _SHAREGPT_PATH.value, prefill_times_ms, decode_time_ms + sharegpt_path, prefill_times_ms, decode_time_ms ) diff --git a/jetstream_pt/config.py b/jetstream_pt/config.py index 04638198..66d8c545 100644 --- a/jetstream_pt/config.py +++ b/jetstream_pt/config.py @@ -12,35 +12,80 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from jetstream.core.config_lib import ServerConfig +import os +import time + +import jax +from absl import flags from jetstream_pt.engine import create_pytorch_engine +FLAGS = flags.FLAGS + + +def define_common_flags(): + """Add common config flags to global FLAG.""" + flags.DEFINE_string( + "tokenizer_path", + None, + "The tokenizer model path", + required=True, + ) + flags.DEFINE_string("model_name", None, "model type", required=False) + flags.DEFINE_string( + "checkpoint_path", None, "Directory for .pth checkpoints", required=False + ) + flags.DEFINE_bool( + "bf16_enable", True, "Whether to enable bf16", required=False + ) + flags.DEFINE_integer( + "context_length", 1024, "The context length", required=False + ) + flags.DEFINE_integer("batch_size", 32, "The batch size", required=False) + flags.DEFINE_string("size", "tiny", "size of model") + flags.DEFINE_bool("quantize_weights", False, "weight quantization") + flags.DEFINE_bool("quantize_kv_cache", False, "kv_cache_quantize") + flags.DEFINE_integer("max_cache_length", 1024, "kv_cache_quantize") + flags.DEFINE_string("sharding_config", "", "config file for sharding") + flags.DEFINE_bool( + "shard_on_batch", + False, + "whether to shard on batch dimension" + "If set true, sharding_config will be ignored.", + ) + -# pylint: disable-next=all -def create_config( - devices, - tokenizer_path, - ckpt_path, - bf16_enable, - param_size, - context_length, - batch_size, - platform, -): - """Create a server config""" - - def func(): - return create_pytorch_engine( - devices=devices, - tokenizer_path=tokenizer_path, - ckpt_path=ckpt_path, - bf16_enable=bf16_enable, - param_size=param_size, - context_length=context_length, - batch_size=batch_size, - ) - - return ServerConfig( - interleaved_slices=(platform,), - interleaved_engine_create_fns=(func,), +def define_profiling_flags(): + """Add profiling related config flags to global FLAG.""" + flags.DEFINE_string( + "profiling_output", + "", + "The profiling output", + required=False, ) + + +def create_engine_from_config_flags(): + """create a pytorch engine from config flag""" + jax.config.update("jax_default_prng_impl", "unsafe_rbg") + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + + devices = jax.devices() + start = time.perf_counter() + engine = create_pytorch_engine( + model_name=FLAGS.model_name, + devices=devices, + tokenizer_path=FLAGS.tokenizer_path, + ckpt_path=FLAGS.checkpoint_path, + bf16_enable=FLAGS.bf16_enable, + param_size=FLAGS.size, + context_length=FLAGS.context_length, + batch_size=FLAGS.batch_size, + quantize_weights=FLAGS.quantize_weights, + quantize_kv=FLAGS.quantize_kv_cache, + max_cache_length=FLAGS.max_cache_length, + sharding_config=FLAGS.sharding_config, + shard_on_batch=FLAGS.shard_on_batch, + ) + + print("Initialize engine", time.perf_counter() - start) + return engine diff --git a/run_interactive.py b/run_interactive.py index e6be9548..e209f41a 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -15,105 +15,28 @@ import os import random import time - from typing import List -from absl import app -from absl import flags -from colorama import Fore, Style import jax - -from jetstream.engine import token_utils -from colorama import Fore, Style import numpy as np - -import os - +from absl import app, flags +from colorama import Fore, Style +from jetstream.engine import token_utils from jetstream_pt import engine as je - -FLAGS = flags.FLAGS - -_TOKENIZER_PATH = flags.DEFINE_string( - "tokenizer_path", - "tokenizer.model", - "The tokenizer model path", - required=False, -) 
-_MODEL_NAME = flags.DEFINE_string( - "model_name", None, "model type", required=False -) -_CKPT_PATH = flags.DEFINE_string( - "checkpoint_path", None, "Directory for .pth checkpoints", required=False -) -_BF16_ENABLE = flags.DEFINE_bool( - "bf16_enable", False, "Whether to enable bf16", required=False -) -_CONTEXT_LENGTH = flags.DEFINE_integer( - "context_length", 1024, "The context length", required=False +from jetstream_pt.config import ( + FLAGS, + create_engine_from_config_flags, + define_common_flags, + define_profiling_flags, ) -_BATCH_SIZE = flags.DEFINE_integer( - "batch_size", 32, "The batch size", required=False -) -_PROFILING_OUTPUT = flags.DEFINE_string( - "profiling_output", - "", - "The profiling output", - required=False, -) - -_SIZE = flags.DEFINE_string("size", "tiny", "size of model") - -_QUANTIZE_WEIGHTS = flags.DEFINE_bool( - "quantize_weights", False, "weight quantization" -) -_QUANTIZE_KV_CACHE = flags.DEFINE_bool( - "quantize_kv_cache", False, "kv_cache_quantize" -) -_MAX_CACHE_LENGTH = flags.DEFINE_integer( - "max_cache_length", 1024, "kv_cache_quantize" -) -_SHARDING_CONFIG = flags.DEFINE_string( - "sharding_config", "", "config file for sharding" -) -_SHARD_ON_BATCH = flags.DEFINE_bool( - "shard_on_batch", - False, - "whether to shard on batch dimension." - "If set true, sharding_config will be ignored.", -) - - -def create_engine(): - """create a pytorch engine""" - jax.config.update("jax_default_prng_impl", "unsafe_rbg") - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" - - devices = jax.devices() - start = time.perf_counter() - engine = je.create_pytorch_engine( - model_name=_MODEL_NAME.value, - devices=devices, - tokenizer_path=_TOKENIZER_PATH.value, - ckpt_path=_CKPT_PATH.value, - bf16_enable=True, - param_size=_SIZE.value, - context_length=_CONTEXT_LENGTH.value, - batch_size=_BATCH_SIZE.value, - quantize_weights=_QUANTIZE_WEIGHTS.value, - quantize_kv=_QUANTIZE_KV_CACHE.value, - max_cache_length=_MAX_CACHE_LENGTH.value, - sharding_config=_SHARDING_CONFIG.value, - shard_on_batch=_SHARD_ON_BATCH.value, - ) - - print("Initialize engine", time.perf_counter() - start) - return engine +define_common_flags() +define_profiling_flags() # pylint: disable-next=all def main(argv): - engine = create_engine() + engine = create_engine_from_config_flags() start = time.perf_counter() params = engine.load_params() @@ -123,8 +46,9 @@ def main(argv): tokenizer = engine.build_tokenizer(metadata) max_output_length = 1024 - if _PROFILING_OUTPUT.value: - jax.profiler.start_trace(_PROFILING_OUTPUT.value) + profiling_output = FLAGS.profiling_output + if profiling_output: + jax.profiler.start_trace(profiling_output) decode_state = engine.init_decode_state() prompts: List[str] = [ @@ -139,7 +63,7 @@ def main(argv): "[INST] <>\nYou are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n<>\n\nContinue the following story.\n\nKay didn't have shoes that fit her feet properly. She only wore sneakers, because the \nChoose from: [I] shoes fitted badly. [II] sneakers fitted badly. 
[/INST]", ] for prompt in prompts: - slot = random.randint(0, _BATCH_SIZE.value - 1) + slot = random.randint(0, FLAGS.batch_size - 1) tokens, true_length = tokenizer.encode(prompt) print(f"---- Input prompts are: {prompt}") @@ -174,7 +98,7 @@ def main(argv): print("---- All output text.") print(tokenizer.decode(sampled_tokens_list)) - if _PROFILING_OUTPUT.value: + if profiling_output: jax.profiler.stop_trace() diff --git a/run_interactive_multiple_host.py b/run_interactive_multiple_host.py index 9de0c492..55be2455 100644 --- a/run_interactive_multiple_host.py +++ b/run_interactive_multiple_host.py @@ -15,63 +15,66 @@ import os import random import time - from typing import List -from absl import app -from absl import flags -from colorama import Fore, Style import jax - +from absl import app, flags +from colorama import Fore, Style from jetstream.engine import token_utils from jetstream_pt import ray_engine - -FLAGS = flags.FLAGS - -_TOKENIZER_PATH = flags.DEFINE_string( - "tokenizer_path", - "tokenizer.model", - "The tokenizer model path", - required=False, -) -_CKPT_PATH = flags.DEFINE_string( - "checkpoint_path", None, "Directory for .pth checkpoints", required=False -) -_BF16_ENABLE = flags.DEFINE_bool( - "bf16_enable", False, "Whether to enable bf16", required=False -) -_CONTEXT_LENGTH = flags.DEFINE_integer( - "context_length", 1024, "The context length", required=False -) -_BATCH_SIZE = flags.DEFINE_integer( - "batch_size", 32, "The batch size", required=False +from jetstream_pt.config import ( + FLAGS, + create_engine_from_config_flags, + define_common_flags, + define_profiling_flags, ) -_PROFILING_OUTPUT = flags.DEFINE_string( - "profiling_output", - "", - "The profiling output", - required=False, -) - -_SIZE = flags.DEFINE_string("size", "tiny", "size of model") -_QUANTIZE_WEIGHTS = flags.DEFINE_bool( - "quantize_weights", False, "weight quantization" -) -_QUANTIZE_KV_CACHE = flags.DEFINE_bool( - "quantize_kv_cache", False, "kv_cache_quantize" -) -_MAX_CACHE_LENGTH = flags.DEFINE_integer( - "max_cache_length", 1024, "kv_cache_quantize" -) - -_MODEL_NAME = flags.DEFINE_string( - "model_name", None, "model type", required=False -) - -_SHARDING_CONFIG = flags.DEFINE_string( - "sharding_config", "", "config file for sharding" -) +define_common_flags() +define_profiling_flags() +# _TOKENIZER_PATH = flags.DEFINE_string( +# "tokenizer_path", +# "tokenizer.model", +# "The tokenizer model path", +# required=False, +# ) +# _CKPT_PATH = flags.DEFINE_string( +# "checkpoint_path", None, "Directory for .pth checkpoints", required=False +# ) +# _BF16_ENABLE = flags.DEFINE_bool( +# "bf16_enable", False, "Whether to enable bf16", required=False +# ) +# _CONTEXT_LENGTH = flags.DEFINE_integer( +# "context_length", 1024, "The context length", required=False +# ) +# _BATCH_SIZE = flags.DEFINE_integer( +# "batch_size", 32, "The batch size", required=False +# ) +# _PROFILING_OUTPUT = flags.DEFINE_string( +# "profiling_output", +# "", +# "The profiling output", +# required=False, +# ) + +# _SIZE = flags.DEFINE_string("size", "tiny", "size of model") + +# _QUANTIZE_WEIGHTS = flags.DEFINE_bool( +# "quantize_weights", False, "weight quantization" +# ) +# _QUANTIZE_KV_CACHE = flags.DEFINE_bool( +# "quantize_kv_cache", False, "kv_cache_quantize" +# ) +# _MAX_CACHE_LENGTH = flags.DEFINE_integer( +# "max_cache_length", 1024, "kv_cache_quantize" +# ) + +# _MODEL_NAME = flags.DEFINE_string( +# "model_name", None, "model type", required=False +# ) + +# _SHARDING_CONFIG = flags.DEFINE_string( +# "sharding_config", 
"", "config file for sharding" +# ) def create_engine(): @@ -81,17 +84,18 @@ def create_engine(): start = time.perf_counter() engine = ray_engine.create_pytorch_ray_engine( - model_name=_MODEL_NAME.value, - tokenizer_path=_TOKENIZER_PATH.value, - ckpt_path=_CKPT_PATH.value, - bf16_enable=True, - param_size=_SIZE.value, - context_length=_CONTEXT_LENGTH.value, - batch_size=_BATCH_SIZE.value, - quantize_weights=_QUANTIZE_WEIGHTS.value, - quantize_kv=_QUANTIZE_KV_CACHE.value, - max_cache_length=_MAX_CACHE_LENGTH.value, - sharding_config=_SHARDING_CONFIG.value, + model_name=FLAGS.model_name, + tokenizer_path=FLAGS.tokenizer_path, + ckpt_path=FLAGS.checkpoint_path, + bf16_enable=FLAGS.bf16_enable, + param_size=FLAGS.size, + context_length=FLAGS.context_length, + batch_size=FLAGS.batch_size, + quantize_weights=FLAGS.quantize_weights, + quantize_kv=FLAGS.quantize_kv_cache, + max_cache_length=FLAGS.max_cache_length, + sharding_config=FLAGS.sharding_config, + shard_on_batch=FLAGS.shard_on_batch, ) print("Initialize engine", time.perf_counter() - start) @@ -101,7 +105,7 @@ def create_engine(): # pylint: disable-next=all def main(argv): - engine = create_engine() + engine = create_engine_from_config_flags() start = time.perf_counter() engine.load_params() @@ -112,8 +116,11 @@ def main(argv): stop_tokens = [vocab.eos_id, vocab.pad_id] max_output_length = 1024 - if _PROFILING_OUTPUT.value: - jax.profiler.start_trace(_PROFILING_OUTPUT.value) + # if _PROFILING_OUTPUT.value: + # jax.profiler.start_trace(_PROFILING_OUTPUT.value) + profiling_output = FLAGS.profiling_output + if profiling_output: + jax.profiler.start_trace(profiling_output) engine.init_decode_state() prompts: List[str] = [ @@ -128,7 +135,7 @@ def main(argv): "[INST] <>\nYou are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n<>\n\nContinue the following story.\n\nKay didn't have shoes that fit her feet properly. She only wore sneakers, because the \nChoose from: [I] shoes fitted badly. [II] sneakers fitted badly. 
[/INST]", ] for prompt in prompts: - slot = random.randint(0, _BATCH_SIZE.value - 1) + slot = random.randint(0, FLAGS.batch_size - 1) tokens, true_length = token_utils.tokenize_and_pad( prompt, vocab, is_bos=True, jax_padding=False ) @@ -161,7 +168,7 @@ def main(argv): print("---- All output text.") print(vocab.tokenizer.decode(sampled_tokens_list)) - if _PROFILING_OUTPUT.value: + if profiling_output: jax.profiler.stop_trace() diff --git a/run_server.py b/run_server.py index 1194da5c..b9bb1b7a 100644 --- a/run_server.py +++ b/run_server.py @@ -16,85 +16,109 @@ import os from typing import Sequence -from absl import app -from absl import flags - +import jetstream_pt +from absl import app, flags from jetstream.core import server_lib from jetstream.core.config_lib import ServerConfig +from jetstream_pt.config import ( + FLAGS, + create_engine_from_config_flags, + define_common_flags, +) -import jetstream_pt - +define_common_flags() -_PORT = flags.DEFINE_integer("port", 9000, "port to listen on") -_THREADS = flags.DEFINE_integer( - "threads", 64, "number of worker threads in thread pool" -) -_CONFIG = flags.DEFINE_string( +flags.DEFINE_integer("port", 9000, "port to listen on") +flags.DEFINE_integer("threads", 64, "number of worker threads in thread pool") +flags.DEFINE_string( "config", "InterleavedCPUTestServer", "available servers", ) - -_TOKENIZER_PATH = flags.DEFINE_string( - "tokenizer_path", - "tokenizer.model", - "The tokenizer model path", - required=False, -) -_CKPT_PATH = flags.DEFINE_string( - "checkpoint_path", None, "Directory for .pth checkpoints", required=False -) -_BF16_ENABLE = flags.DEFINE_bool( - "bf16_enable", True, "Whether to enable bf16", required=False -) -_CONTEXT_LENGTH = flags.DEFINE_integer( - "context_length", 1024, "The context length", required=False -) -_BATCH_SIZE = flags.DEFINE_integer( - "batch_size", 32, "The batch size", required=False -) -_PROFILING_OUTPUT = flags.DEFINE_string( +flags.DEFINE_string( "profiling_output", "", "The profiling output", required=False, ) -_PLATFORM = flags.DEFINE_string( +flags.DEFINE_string( "platform", "tpu=4", "The platform that the engine runs on", required=False, ) -_PARAM_SIZE = flags.DEFINE_string( - "param_size", - "7b", - "The model size the server runs on.", - required=False, -) -_MODEL_NAME = flags.DEFINE_string( - "model", - "llama-2", - "name of the model. 
Supported options are llama-2 and llama-3", -) +# _PORT = flags.DEFINE_integer("port", 9000, "port to listen on") +# _THREADS = flags.DEFINE_integer( +# "threads", 64, "number of worker threads in thread pool" +# ) +# _CONFIG = flags.DEFINE_string( +# "config", +# "InterleavedCPUTestServer", +# "available servers", +# ) +# _PROFILING_OUTPUT = flags.DEFINE_string( +# "profiling_output", +# "", +# "The profiling output", +# required=False, +# ) +# _PLATFORM = flags.DEFINE_string( +# "platform", +# "tpu=4", +# "The platform that the engine runs on", +# required=False, +# ) -_QUANTIZE_WEIGHTS = flags.DEFINE_bool( - "quantize_weights", False, "weight quantization" -) -_QUANTIZE_KV_CACHE = flags.DEFINE_bool( - "quantize_kv_cache", False, "kv_cache_quantize" -) -_MAX_CACHE_LENGTH = flags.DEFINE_integer( - "max_cache_length", 1024, "kv_cache_quantize" -) -_SHARDING_CONFIG = flags.DEFINE_string( - "sharding_config", "", "config file for sharding" -) -_SHARD_ON_BATCH = flags.DEFINE_bool( - "shard_on_batch", - False, - "whether to shard on batch dimension" - "If set true, sharding_config will be ignored.", -) +# Common flags +# _TOKENIZER_PATH = flags.DEFINE_string( +# "tokenizer_path", +# "tokenizer.model", +# "The tokenizer model path", +# required=False, +# ) +# _CKPT_PATH = flags.DEFINE_string( +# "checkpoint_path", None, "Directory for .pth checkpoints", required=False +# ) +# _BF16_ENABLE = flags.DEFINE_bool( +# "bf16_enable", True, "Whether to enable bf16", required=False +# ) +# _CONTEXT_LENGTH = flags.DEFINE_integer( +# "context_length", 1024, "The context length", required=False +# ) +# _BATCH_SIZE = flags.DEFINE_integer( +# "batch_size", 32, "The batch size", required=False +# ) + +# _PARAM_SIZE = flags.DEFINE_string( +# "param_size", +# "7b", +# "The model size the server runs on.", +# required=False, +# ) +# _MODEL_NAME = flags.DEFINE_string( +# "model", +# "llama-2", +# "name of the model. 
Supported options are llama-2 and llama-3", +# ) + +# _QUANTIZE_WEIGHTS = flags.DEFINE_bool( +# "quantize_weights", False, "weight quantization" +# ) +# _QUANTIZE_KV_CACHE = flags.DEFINE_bool( +# "quantize_kv_cache", False, "kv_cache_quantize" +# ) +# _MAX_CACHE_LENGTH = flags.DEFINE_integer( +# "max_cache_length", 1024, "kv_cache_quantize" +# ) +# _SHARDING_CONFIG = flags.DEFINE_string( +# "sharding_config", "", "config file for sharding" +# ) +# _SHARD_ON_BATCH = flags.DEFINE_bool( +# "shard_on_batch", +# False, +# "whether to shard on batch dimension" +# "If set true, sharding_config will be ignored.", +# ) # pylint: disable-next=all @@ -106,22 +130,22 @@ def main(argv: Sequence[str]): print(f"devices: {devices}") sharding_config_path = _SHARDING_CONFIG.value engine = jetstream_pt.create_pytorch_engine( + model_name=FLAGS.model_name, devices=devices, - tokenizer_path=_TOKENIZER_PATH.value, - ckpt_path=_CKPT_PATH.value, - bf16_enable=_BF16_ENABLE.value, - param_size=_PARAM_SIZE.value, - context_length=_CONTEXT_LENGTH.value, - batch_size=_BATCH_SIZE.value, - model_name=_MODEL_NAME.value, - quantize_weights=_QUANTIZE_WEIGHTS.value, - quantize_kv=_QUANTIZE_KV_CACHE.value, - max_cache_length=_MAX_CACHE_LENGTH.value, - sharding_config=sharding_config_path, - shard_on_batch=_SHARD_ON_BATCH.value, + tokenizer_path=FLAGS.tokenizer_path, + ckpt_path=FLAGS.checkpoint_path, + bf16_enable=FLAGS.bf16_enable, + param_size=FLAGS.size, + context_length=FLAGS.context_length, + batch_size=FLAGS.batch_size, + quantize_weights=FLAGS.quantize_weights, + quantize_kv=FLAGS.quantize_kv_cache, + max_cache_length=FLAGS.max_cache_length, + sharding_config=FLAGS.sharding_config, + shard_on_batch=FLAGS.shard_on_batch, ) server_config = ServerConfig( - interleaved_slices=(_PLATFORM.value,), + interleaved_slices=(FLAGS.platform,), interleaved_engine_create_fns=(lambda a: engine,), ) print(f"server_config: {server_config}") @@ -129,8 +153,8 @@ def main(argv: Sequence[str]): # We separate credential from run so that we can unit test it with local credentials. # We would like to add grpc credentials for OSS. 
jetstream_server = server_lib.run( - threads=_THREADS.value, - port=_PORT.value, + threads=FLAGS.threads, + port=FLAGS.port, config=server_config, devices=devices, ) From fa831b6d6fb01c2d8f7b8d6f5c9a0545dacc0985 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Fri, 17 May 2024 23:31:42 +0000 Subject: [PATCH 2/8] clean up: --- .pylintrc | 2 +- benchmarks/prefill_offline.py | 37 --------------- benchmarks/run_offline.py | 42 ----------------- run_interactive.py | 1 + run_interactive_multiple_host.py | 46 ------------------ run_server.py | 80 +------------------------------- 6 files changed, 4 insertions(+), 204 deletions(-) diff --git a/.pylintrc b/.pylintrc index a03b49b3..66a6589e 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,2 +1,2 @@ [MESSAGES CONTROL] -disable=C0114,R0801,E1102,W0613,R1711 \ No newline at end of file +disable=C0114,R0801,E1102,W0613,R1711,too-many-locals diff --git a/benchmarks/prefill_offline.py b/benchmarks/prefill_offline.py index 0686b037..0a93147d 100644 --- a/benchmarks/prefill_offline.py +++ b/benchmarks/prefill_offline.py @@ -32,43 +32,6 @@ define_common_flags() define_profiling_flags() -# _TOKENIZER_PATH = flags.DEFINE_string( -# "tokenizer_path", -# "tokenizer.model", -# "The tokenizer model path", -# required=False, -# ) -# _CKPT_PATH = flags.DEFINE_string( -# "checkpoint_path", None, "Directory for .pth checkpoints", required=False -# ) -# _BF16_ENABLE = flags.DEFINE_bool( -# "bf16_enable", False, "Whether to enable bf16", required=False -# ) -# _CONTEXT_LENGTH = flags.DEFINE_integer( -# "context_length", 1024, "The context length", required=False -# ) -# _BATCH_SIZE = flags.DEFINE_integer( -# "batch_size", 32, "The batch size", required=False -# ) -# _PROFILING_OUTPUT = flags.DEFINE_string( -# "profiling_output", -# "", -# "The profiling output", -# required=False, -# ) - -# _SIZE = flags.DEFINE_string("size", "tiny", "size of model") - -# _QUANTIZE_WEIGHTS = flags.DEFINE_bool( -# "quantize_weights", False, "weight quantization" -# ) -# _QUANTIZE_KV_CACHE = flags.DEFINE_bool( -# "quantize_kv_cache", False, "kv_cache_quantize" -# ) -# _MAX_CACHE_LENGTH = flags.DEFINE_integer( -# "max_cache_length", 1024, "kv_cache_quantize" -# ) - def create_engine(): """create a pytorch engine""" diff --git a/benchmarks/run_offline.py b/benchmarks/run_offline.py index 357fe65f..a6378044 100644 --- a/benchmarks/run_offline.py +++ b/benchmarks/run_offline.py @@ -34,48 +34,6 @@ define_common_flags() define_profiling_flags() -# FLAGS = flags.FLAGS - -# _TOKENIZER_PATH = flags.DEFINE_string( -# "tokenizer_path", -# "tokenizer.model", -# "The tokenizer model path", -# required=False, -# ) -# _CKPT_PATH = flags.DEFINE_string( -# "checkpoint_path", None, "Directory for .pth checkpoints", required=False -# ) -# _BF16_ENABLE = flags.DEFINE_bool( -# "bf16_enable", False, "Whether to enable bf16", required=False -# ) -# _CONTEXT_LENGTH = flags.DEFINE_integer( -# "context_length", 1024, "The context length", required=False -# ) -# _BATCH_SIZE = flags.DEFINE_integer( -# "batch_size", 32, "The batch size", required=False -# ) -# _PROFILING_OUTPUT = flags.DEFINE_string( -# "profiling_output", -# "", -# "The profiling output", -# required=False, -# ) - -# _SIZE = flags.DEFINE_string("size", "tiny", "size of model") - -# _QUANTIZE_WEIGHTS = flags.DEFINE_bool( -# "quantize_weights", False, "weight quantization" -# ) -# _QUANTIZE_KV_CACHE = flags.DEFINE_bool( -# "quantize_kv_cache", False, "kv_cache_quantize" -# ) -# _MAX_CACHE_LENGTH = flags.DEFINE_integer( -# "max_cache_length", 1024, 
"kv_cache_quantize" -# ) -# _MODEL_NAME = flags.DEFINE_string("model_name", "", "model_name") -# _SHARDING_CONFIG = flags.DEFINE_string( -# "sharding_config", "", "path to sharding config" -# ) flags.DEFINE_string("sharegpt_path", "", "path to sharegpt json file") diff --git a/run_interactive.py b/run_interactive.py index e209f41a..24206620 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -33,6 +33,7 @@ define_common_flags() define_profiling_flags() + # pylint: disable-next=all def main(argv): diff --git a/run_interactive_multiple_host.py b/run_interactive_multiple_host.py index 55be2455..f331a5c6 100644 --- a/run_interactive_multiple_host.py +++ b/run_interactive_multiple_host.py @@ -31,50 +31,6 @@ define_common_flags() define_profiling_flags() -# _TOKENIZER_PATH = flags.DEFINE_string( -# "tokenizer_path", -# "tokenizer.model", -# "The tokenizer model path", -# required=False, -# ) -# _CKPT_PATH = flags.DEFINE_string( -# "checkpoint_path", None, "Directory for .pth checkpoints", required=False -# ) -# _BF16_ENABLE = flags.DEFINE_bool( -# "bf16_enable", False, "Whether to enable bf16", required=False -# ) -# _CONTEXT_LENGTH = flags.DEFINE_integer( -# "context_length", 1024, "The context length", required=False -# ) -# _BATCH_SIZE = flags.DEFINE_integer( -# "batch_size", 32, "The batch size", required=False -# ) -# _PROFILING_OUTPUT = flags.DEFINE_string( -# "profiling_output", -# "", -# "The profiling output", -# required=False, -# ) - -# _SIZE = flags.DEFINE_string("size", "tiny", "size of model") - -# _QUANTIZE_WEIGHTS = flags.DEFINE_bool( -# "quantize_weights", False, "weight quantization" -# ) -# _QUANTIZE_KV_CACHE = flags.DEFINE_bool( -# "quantize_kv_cache", False, "kv_cache_quantize" -# ) -# _MAX_CACHE_LENGTH = flags.DEFINE_integer( -# "max_cache_length", 1024, "kv_cache_quantize" -# ) - -# _MODEL_NAME = flags.DEFINE_string( -# "model_name", None, "model type", required=False -# ) - -# _SHARDING_CONFIG = flags.DEFINE_string( -# "sharding_config", "", "config file for sharding" -# ) def create_engine(): @@ -116,8 +72,6 @@ def main(argv): stop_tokens = [vocab.eos_id, vocab.pad_id] max_output_length = 1024 - # if _PROFILING_OUTPUT.value: - # jax.profiler.start_trace(_PROFILING_OUTPUT.value) profiling_output = FLAGS.profiling_output if profiling_output: jax.profiler.start_trace(profiling_output) diff --git a/run_server.py b/run_server.py index b9bb1b7a..3d52d384 100644 --- a/run_server.py +++ b/run_server.py @@ -24,9 +24,11 @@ FLAGS, create_engine_from_config_flags, define_common_flags, + define_profiling_flags, ) define_common_flags() +define_profiling_flags() flags.DEFINE_integer("port", 9000, "port to listen on") flags.DEFINE_integer("threads", 64, "number of worker threads in thread pool") @@ -35,90 +37,12 @@ "InterleavedCPUTestServer", "available servers", ) -flags.DEFINE_string( - "profiling_output", - "", - "The profiling output", - required=False, -) flags.DEFINE_string( "platform", "tpu=4", "The platform that the engine runs on", required=False, ) -# _PORT = flags.DEFINE_integer("port", 9000, "port to listen on") -# _THREADS = flags.DEFINE_integer( -# "threads", 64, "number of worker threads in thread pool" -# ) -# _CONFIG = flags.DEFINE_string( -# "config", -# "InterleavedCPUTestServer", -# "available servers", -# ) -# _PROFILING_OUTPUT = flags.DEFINE_string( -# "profiling_output", -# "", -# "The profiling output", -# required=False, -# ) -# _PLATFORM = flags.DEFINE_string( -# "platform", -# "tpu=4", -# "The platform that the engine runs on", -# required=False, -# ) 
- -# Common flags -# _TOKENIZER_PATH = flags.DEFINE_string( -# "tokenizer_path", -# "tokenizer.model", -# "The tokenizer model path", -# required=False, -# ) -# _CKPT_PATH = flags.DEFINE_string( -# "checkpoint_path", None, "Directory for .pth checkpoints", required=False -# ) -# _BF16_ENABLE = flags.DEFINE_bool( -# "bf16_enable", True, "Whether to enable bf16", required=False -# ) -# _CONTEXT_LENGTH = flags.DEFINE_integer( -# "context_length", 1024, "The context length", required=False -# ) -# _BATCH_SIZE = flags.DEFINE_integer( -# "batch_size", 32, "The batch size", required=False -# ) - -# _PARAM_SIZE = flags.DEFINE_string( -# "param_size", -# "7b", -# "The model size the server runs on.", -# required=False, -# ) -# _MODEL_NAME = flags.DEFINE_string( -# "model", -# "llama-2", -# "name of the model. Supported options are llama-2 and llama-3", -# ) - -# _QUANTIZE_WEIGHTS = flags.DEFINE_bool( -# "quantize_weights", False, "weight quantization" -# ) -# _QUANTIZE_KV_CACHE = flags.DEFINE_bool( -# "quantize_kv_cache", False, "kv_cache_quantize" -# ) -# _MAX_CACHE_LENGTH = flags.DEFINE_integer( -# "max_cache_length", 1024, "kv_cache_quantize" -# ) -# _SHARDING_CONFIG = flags.DEFINE_string( -# "sharding_config", "", "config file for sharding" -# ) -# _SHARD_ON_BATCH = flags.DEFINE_bool( -# "shard_on_batch", -# False, -# "whether to shard on batch dimension" -# "If set true, sharding_config will be ignored.", -# ) # pylint: disable-next=all From 75d7fc3ec919be0b7069dbdae4f6df14d35e4263 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Fri, 17 May 2024 23:45:30 +0000 Subject: [PATCH 3/8] fix run_server --- run_server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/run_server.py b/run_server.py index 3d52d384..32bd91fd 100644 --- a/run_server.py +++ b/run_server.py @@ -52,7 +52,6 @@ def main(argv: Sequence[str]): # No devices for local cpu test. A None for prefill and a None for generate. 
devices = server_lib.get_devices() print(f"devices: {devices}") - sharding_config_path = _SHARDING_CONFIG.value engine = jetstream_pt.create_pytorch_engine( model_name=FLAGS.model_name, devices=devices, From 2dd6dd2416fdf66bae30136e20b0071f9ba4a7fb Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Sat, 18 May 2024 00:14:29 +0000 Subject: [PATCH 4/8] move common flags to global --- benchmarks/run_offline.py | 2 -- jetstream_pt/config.py | 72 +++++++++++++++++++-------------------- run_server.py | 28 ++------------- 3 files changed, 38 insertions(+), 64 deletions(-) diff --git a/benchmarks/run_offline.py b/benchmarks/run_offline.py index a6378044..39e1584a 100644 --- a/benchmarks/run_offline.py +++ b/benchmarks/run_offline.py @@ -26,13 +26,11 @@ from jetstream_pt.config import ( FLAGS, create_engine_from_config_flags, - define_common_flags, define_profiling_flags, ) logging.getLogger().setLevel(logging.ERROR) -define_common_flags() define_profiling_flags() flags.DEFINE_string("sharegpt_path", "", "path to sharegpt json file") diff --git a/jetstream_pt/config.py b/jetstream_pt/config.py index 66d8c545..ab1ec79f 100644 --- a/jetstream_pt/config.py +++ b/jetstream_pt/config.py @@ -21,47 +21,45 @@ FLAGS = flags.FLAGS - -def define_common_flags(): - """Add common config flags to global FLAG.""" - flags.DEFINE_string( - "tokenizer_path", - None, - "The tokenizer model path", - required=True, - ) - flags.DEFINE_string("model_name", None, "model type", required=False) - flags.DEFINE_string( - "checkpoint_path", None, "Directory for .pth checkpoints", required=False - ) - flags.DEFINE_bool( - "bf16_enable", True, "Whether to enable bf16", required=False - ) - flags.DEFINE_integer( - "context_length", 1024, "The context length", required=False - ) - flags.DEFINE_integer("batch_size", 32, "The batch size", required=False) - flags.DEFINE_string("size", "tiny", "size of model") - flags.DEFINE_bool("quantize_weights", False, "weight quantization") - flags.DEFINE_bool("quantize_kv_cache", False, "kv_cache_quantize") - flags.DEFINE_integer("max_cache_length", 1024, "kv_cache_quantize") - flags.DEFINE_string("sharding_config", "", "config file for sharding") - flags.DEFINE_bool( - "shard_on_batch", - False, - "whether to shard on batch dimension" - "If set true, sharding_config will be ignored.", - ) +flags.DEFINE_string( + "tokenizer_path", + None, + "The tokenizer model path", + required=True, +) +flags.DEFINE_string("model_name", None, "model type", required=False) +flags.DEFINE_string( + "checkpoint_path", None, "Directory for .pth checkpoints", required=False +) +flags.DEFINE_bool( + "bf16_enable", True, "Whether to enable bf16", required=False +) +flags.DEFINE_integer( + "context_length", 1024, "The context length", required=False +) +flags.DEFINE_integer("batch_size", 32, "The batch size", required=False) +flags.DEFINE_string("size", "tiny", "size of model") +flags.DEFINE_bool("quantize_weights", False, "weight quantization") +flags.DEFINE_bool("quantize_kv_cache", False, "kv_cache_quantize") +flags.DEFINE_integer("max_cache_length", 1024, "kv_cache_quantize") +flags.DEFINE_string("sharding_config", "", "config file for sharding") +flags.DEFINE_bool( + "shard_on_batch", + False, + "whether to shard on batch dimension" + "If set true, sharding_config will be ignored.", +) +flags.DEFINE_string( + "profiling_output", + "", + "The profiling output", + required=False, +) def define_profiling_flags(): """Add profiling related config flags to global FLAG.""" - flags.DEFINE_string( - "profiling_output", - "", - 
"The profiling output", - required=False, - ) + def create_engine_from_config_flags(): diff --git a/run_server.py b/run_server.py index 32bd91fd..ed55ecf6 100644 --- a/run_server.py +++ b/run_server.py @@ -16,6 +16,7 @@ import os from typing import Sequence +import jax import jetstream_pt from absl import app, flags from jetstream.core import server_lib @@ -23,11 +24,9 @@ from jetstream_pt.config import ( FLAGS, create_engine_from_config_flags, - define_common_flags, define_profiling_flags, ) -define_common_flags() define_profiling_flags() flags.DEFINE_integer("port", 9000, "port to listen on") @@ -37,13 +36,6 @@ "InterleavedCPUTestServer", "available servers", ) -flags.DEFINE_string( - "platform", - "tpu=4", - "The platform that the engine runs on", - required=False, -) - # pylint: disable-next=all def main(argv: Sequence[str]): @@ -52,23 +44,9 @@ def main(argv: Sequence[str]): # No devices for local cpu test. A None for prefill and a None for generate. devices = server_lib.get_devices() print(f"devices: {devices}") - engine = jetstream_pt.create_pytorch_engine( - model_name=FLAGS.model_name, - devices=devices, - tokenizer_path=FLAGS.tokenizer_path, - ckpt_path=FLAGS.checkpoint_path, - bf16_enable=FLAGS.bf16_enable, - param_size=FLAGS.size, - context_length=FLAGS.context_length, - batch_size=FLAGS.batch_size, - quantize_weights=FLAGS.quantize_weights, - quantize_kv=FLAGS.quantize_kv_cache, - max_cache_length=FLAGS.max_cache_length, - sharding_config=FLAGS.sharding_config, - shard_on_batch=FLAGS.shard_on_batch, - ) + engine = create_engine_from_config_flags() server_config = ServerConfig( - interleaved_slices=(FLAGS.platform,), + interleaved_slices=(f"tpu={len(jax.devices())}",), interleaved_engine_create_fns=(lambda a: engine,), ) print(f"server_config: {server_config}") From 7e8a9c3f547394bcf6ee4dd6861c66da0cbae67e Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Sat, 18 May 2024 00:16:06 +0000 Subject: [PATCH 5/8] format --- benchmarks/prefill_offline.py | 2 -- jetstream_pt/config.py | 5 +---- run_server.py | 1 + 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/benchmarks/prefill_offline.py b/benchmarks/prefill_offline.py index 0a93147d..b922bd98 100644 --- a/benchmarks/prefill_offline.py +++ b/benchmarks/prefill_offline.py @@ -25,11 +25,9 @@ from jetstream_pt.config import ( FLAGS, create_engine_from_config_flags, - define_common_flags, define_profiling_flags, ) -define_common_flags() define_profiling_flags() diff --git a/jetstream_pt/config.py b/jetstream_pt/config.py index ab1ec79f..84ef036d 100644 --- a/jetstream_pt/config.py +++ b/jetstream_pt/config.py @@ -31,9 +31,7 @@ flags.DEFINE_string( "checkpoint_path", None, "Directory for .pth checkpoints", required=False ) -flags.DEFINE_bool( - "bf16_enable", True, "Whether to enable bf16", required=False -) +flags.DEFINE_bool("bf16_enable", True, "Whether to enable bf16", required=False) flags.DEFINE_integer( "context_length", 1024, "The context length", required=False ) @@ -59,7 +57,6 @@ def define_profiling_flags(): """Add profiling related config flags to global FLAG.""" - def create_engine_from_config_flags(): diff --git a/run_server.py b/run_server.py index ed55ecf6..4a83b1af 100644 --- a/run_server.py +++ b/run_server.py @@ -37,6 +37,7 @@ "available servers", ) + # pylint: disable-next=all def main(argv: Sequence[str]): del argv From 67af707e5c772f533131e163e23f94286c98d457 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Sat, 18 May 2024 00:18:43 +0000 Subject: [PATCH 6/8] update --- run_interactive_multiple_host.py | 2 
-- 1 file changed, 2 deletions(-) diff --git a/run_interactive_multiple_host.py b/run_interactive_multiple_host.py index f331a5c6..9146df9d 100644 --- a/run_interactive_multiple_host.py +++ b/run_interactive_multiple_host.py @@ -25,11 +25,9 @@ from jetstream_pt.config import ( FLAGS, create_engine_from_config_flags, - define_common_flags, define_profiling_flags, ) -define_common_flags() define_profiling_flags() From 3e504e2a8aa34a29794976e136d8ff77af745761 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Sat, 18 May 2024 00:19:31 +0000 Subject: [PATCH 7/8] udpate readme --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 9d3e21bf..27b17ba6 100644 --- a/README.md +++ b/README.md @@ -104,10 +104,10 @@ python run_interactive.py --model_name=$model_name --size=7b --batch_size=64 --m # Run the server -Here is an example to run the server with llama2 7B config. Note that the `--platform=tpu=8` need to specify number of tpu devices (which is 4 for v4-8 and 8 for v5light-8`). +Here is an example to run the server with llama2 7B config. ```bash -python run_server.py --model_name=$model_name --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8 --sharding_config="default_shardings/llama.yaml" +python run_server.py --model_name=$model_name --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml" ``` Now you can fire gRPC to it. From 55c6c7fefdd8614c09be931af1ee201bfe39dfed Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Sat, 18 May 2024 00:20:10 +0000 Subject: [PATCH 8/8] update run_interactive --- run_interactive.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/run_interactive.py b/run_interactive.py index 24206620..74c2f67f 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -26,11 +26,9 @@ from jetstream_pt.config import ( FLAGS, create_engine_from_config_flags, - define_common_flags, define_profiling_flags, ) -define_common_flags() define_profiling_flags()
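
After this series, the per-script flag blocks removed above all live in `jetstream_pt/config.py`: importing the module registers the common flags at module level, `define_profiling_flags()` remains as a (now effectively no-op) hook the scripts still call, and `create_engine_from_config_flags()` builds the engine straight from `FLAGS`. A minimal sketch (not part of the series) of how a new script would consume the refactored module — assuming the standard absl `app.run` entry point the touched scripts already use, with the prefill/generate loop elided:

```python
# Hypothetical consumer of the refactored config module (illustrative only).
# Flag names and engine calls mirror the scripts patched above.
import time

import jax
from absl import app

from jetstream_pt.config import (
    FLAGS,
    create_engine_from_config_flags,
    define_profiling_flags,
)

# Common flags (tokenizer_path, checkpoint_path, batch_size, quantize_*, ...)
# are registered when jetstream_pt.config is imported; the profiling flag now
# also lives at module level, so this call is kept only for parity with the
# existing scripts.
define_profiling_flags()


def main(argv):
  """Build the engine from the shared flags and run a profiled session."""
  del argv
  engine = create_engine_from_config_flags()

  start = time.perf_counter()
  params = engine.load_params()
  print("Load params", time.perf_counter() - start)

  if FLAGS.profiling_output:
    jax.profiler.start_trace(FLAGS.profiling_output)

  decode_state = engine.init_decode_state()
  # ... prefill and generate with `params` and `decode_state` ...

  if FLAGS.profiling_output:
    jax.profiler.stop_trace()


if __name__ == "__main__":
  app.run(main)
```

Because `tokenizer_path` is declared with `required=True`, such a script must be invoked with at least `--tokenizer_path` (plus `--checkpoint_path`, `--model_name`, etc. for a real model), the same way `run_interactive.py` and `run_server.py` are invoked after this refactor.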