diff --git a/.pylintrc b/.pylintrc
index a03b49b3..66a6589e 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -1,2 +1,2 @@
 [MESSAGES CONTROL]
-disable=C0114,R0801,E1102,W0613,R1711
\ No newline at end of file
+disable=C0114,R0801,E1102,W0613,R1711,too-many-locals
diff --git a/README.md b/README.md
index be6c8d65..27b17ba6 100644
--- a/README.md
+++ b/README.md
@@ -104,10 +104,10 @@ python run_interactive.py --model_name=$model_name --size=7b --batch_size=64 --m
 
 # Run the server
 
-Here is an example to run the server with llama2 7B config. Note that the `--platform=tpu=8` need to specify number of tpu devices (which is 4 for v4-8 and 8 for v5light-8`).
+Here is an example to run the server with llama2 7B config.
 
 ```bash
-python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8 --model=$model_name --sharding_config="default_shardings/llama.yaml"
+python run_server.py --model_name=$model_name --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml"
 ```
 
 Now you can fire gRPC to it.
diff --git a/benchmarks/prefill_offline.py b/benchmarks/prefill_offline.py
index 03bf4180..b922bd98 100644
--- a/benchmarks/prefill_offline.py
+++ b/benchmarks/prefill_offline.py
@@ -12,57 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import functools
 import os
 import time
-import functools
-import humanize
-# pylint: disable-next=all
-from absl import app
-from absl import flags
-import numpy as np
+import humanize
 import jax
-
+import numpy as np
+# pylint: disable-next=all
+from absl import app, flags
 from jetstream_pt import engine as je
-
-FLAGS = flags.FLAGS
-
-_TOKENIZER_PATH = flags.DEFINE_string(
-    "tokenizer_path",
-    "tokenizer.model",
-    "The tokenizer model path",
-    required=False,
-)
-_CKPT_PATH = flags.DEFINE_string(
-    "checkpoint_path", None, "Directory for .pth checkpoints", required=False
-)
-_BF16_ENABLE = flags.DEFINE_bool(
-    "bf16_enable", False, "Whether to enable bf16", required=False
-)
-_CONTEXT_LENGTH = flags.DEFINE_integer(
-    "context_length", 1024, "The context length", required=False
-)
-_BATCH_SIZE = flags.DEFINE_integer(
-    "batch_size", 32, "The batch size", required=False
-)
-_PROFILING_OUTPUT = flags.DEFINE_string(
-    "profiling_output",
-    "",
-    "The profiling output",
-    required=False,
+from jetstream_pt.config import (
+    FLAGS,
+    create_engine_from_config_flags,
+    define_profiling_flags,
 )
-_SIZE = flags.DEFINE_string("size", "tiny", "size of model")
-
-_QUANTIZE_WEIGHTS = flags.DEFINE_bool(
-    "quantize_weights", False, "weight quantization"
-)
-_QUANTIZE_KV_CACHE = flags.DEFINE_bool(
-    "quantize_kv_cache", False, "kv_cache_quantize"
-)
-_MAX_CACHE_LENGTH = flags.DEFINE_integer(
-    "max_cache_length", 1024, "kv_cache_quantize"
-)
+define_profiling_flags()
 
 
@@ -73,16 +39,19 @@ def create_engine():
   devices = jax.devices()
   start = time.perf_counter()
   engine = je.create_pytorch_engine(
+      model_name=FLAGS.model_name,
       devices=devices,
-      tokenizer_path=_TOKENIZER_PATH.value,
-      ckpt_path=_CKPT_PATH.value,
-      bf16_enable=True,
-      param_size=_SIZE.value,
-      context_length=_CONTEXT_LENGTH.value,
-      batch_size=_BATCH_SIZE.value,
-      quantize_weights=_QUANTIZE_WEIGHTS.value,
-      quantize_kv=_QUANTIZE_KV_CACHE.value,
-      max_cache_length=_MAX_CACHE_LENGTH.value,
+      tokenizer_path=FLAGS.tokenizer_path,
+      ckpt_path=FLAGS.checkpoint_path,
+      bf16_enable=FLAGS.bf16_enable,
+      param_size=FLAGS.size,
+      context_length=FLAGS.context_length,
+      batch_size=FLAGS.batch_size,
+      quantize_weights=FLAGS.quantize_weights,
+      quantize_kv=FLAGS.quantize_kv_cache,
+      max_cache_length=FLAGS.max_cache_length,
+      sharding_config=FLAGS.sharding_config,
+      shard_on_batch=FLAGS.shard_on_batch,
   )
 
   print("Initialize engine", time.perf_counter() - start)
@@ -163,14 +132,16 @@ def prefill_benchmark(tokens_list, engine, params, warmup):
 
 # pylint: disable-next=all
 def main(argv):
-  engine = create_engine()
+  engine = create_engine_from_config_flags()
 
   start = time.perf_counter()
   params = engine.load_params()
   print("Load params ", time.perf_counter() - start)
 
-  if _PROFILING_OUTPUT.value:
-    jax.profiler.start_trace(_PROFILING_OUTPUT.value)
+  profiling_output = FLAGS.profiling_output
+  if profiling_output:
+    jax.profiler.start_trace(profiling_output)
 
+  print_mem_usage()
   tokens_list = create_prefill_tokens()
   for _ in range(3):
@@ -189,7 +160,7 @@ def main(argv):
       tokens_list=tokens_list, engine=engine, params=params, warmup=False
   )
 
-  if _PROFILING_OUTPUT.value:
+  if profiling_output:
     jax.profiler.stop_trace()
 
 
diff --git a/benchmarks/run_offline.py b/benchmarks/run_offline.py
index df788591..39e1584a 100644
--- a/benchmarks/run_offline.py
+++ b/benchmarks/run_offline.py
@@ -15,66 +15,24 @@
 import logging
 import os
 import time
-# pylint: disable-next=all
-from absl import app
-from absl import flags
 import jax
 import jax.numpy as jnp
-
-from jetstream_pt import engine as je
+# pylint: disable-next=all
+from absl import app, flags
 # pylint: disable-next=all
 from benchmarks import analyze_sharegpt
-
-
-logging.getLogger().setLevel(logging.ERROR)
-
-
-FLAGS = flags.FLAGS
-
-_TOKENIZER_PATH = flags.DEFINE_string(
-    "tokenizer_path",
-    "tokenizer.model",
-    "The tokenizer model path",
-    required=False,
-)
-_CKPT_PATH = flags.DEFINE_string(
-    "checkpoint_path", None, "Directory for .pth checkpoints", required=False
-)
-_BF16_ENABLE = flags.DEFINE_bool(
-    "bf16_enable", False, "Whether to enable bf16", required=False
-)
-_CONTEXT_LENGTH = flags.DEFINE_integer(
-    "context_length", 1024, "The context length", required=False
-)
-_BATCH_SIZE = flags.DEFINE_integer(
-    "batch_size", 32, "The batch size", required=False
-)
-_PROFILING_OUTPUT = flags.DEFINE_string(
-    "profiling_output",
-    "",
-    "The profiling output",
-    required=False,
+from jetstream_pt import engine as je
+from jetstream_pt.config import (
+    FLAGS,
+    create_engine_from_config_flags,
+    define_profiling_flags,
 )
-_SIZE = flags.DEFINE_string("size", "tiny", "size of model")
+logging.getLogger().setLevel(logging.ERROR)
 
-_QUANTIZE_WEIGHTS = flags.DEFINE_bool(
-    "quantize_weights", False, "weight quantization"
-)
-_QUANTIZE_KV_CACHE = flags.DEFINE_bool(
-    "quantize_kv_cache", False, "kv_cache_quantize"
-)
-_MAX_CACHE_LENGTH = flags.DEFINE_integer(
-    "max_cache_length", 1024, "kv_cache_quantize"
-)
-_MODEL_NAME = flags.DEFINE_string("model_name", "", "model_name")
-_SHARDING_CONFIG = flags.DEFINE_string(
-    "sharding_config", "", "path to sharding config"
-)
-_SHAREGPT_PATH = flags.DEFINE_string(
-    "sharegpt_path", "", "path to sharegpt json file"
-)
+define_profiling_flags()
+flags.DEFINE_string("sharegpt_path", "", "path to sharegpt json file")
 
 
@@ -85,18 +43,19 @@ def create_engine():
   devices = jax.devices()
   start = time.perf_counter()
   engine = je.create_pytorch_engine(
+      model_name=FLAGS.model_name,
       devices=devices,
-      tokenizer_path=_TOKENIZER_PATH.value,
-      ckpt_path=_CKPT_PATH.value,
-      bf16_enable=True,
-      param_size=_SIZE.value,
-      context_length=_CONTEXT_LENGTH.value,
-      batch_size=_BATCH_SIZE.value,
-      quantize_weights=_QUANTIZE_WEIGHTS.value,
-      quantize_kv=_QUANTIZE_KV_CACHE.value,
-      max_cache_length=_MAX_CACHE_LENGTH.value,
-      model_name=_MODEL_NAME.value,
-      sharding_config=_SHARDING_CONFIG.value,
+      tokenizer_path=FLAGS.tokenizer_path,
+      ckpt_path=FLAGS.checkpoint_path,
+      bf16_enable=FLAGS.bf16_enable,
+      param_size=FLAGS.size,
+      context_length=FLAGS.context_length,
+      batch_size=FLAGS.batch_size,
+      quantize_weights=FLAGS.quantize_weights,
+      quantize_kv=FLAGS.quantize_kv_cache,
+      max_cache_length=FLAGS.max_cache_length,
+      sharding_config=FLAGS.sharding_config,
+      shard_on_batch=FLAGS.shard_on_batch,
   )
 
   print("Initialize engine", time.perf_counter() - start)
@@ -148,7 +107,7 @@ def run_prefill_time(engine, params, decode_state, seqlen):
 
 def main(argv):
   """Main function to run engine offline."""
-  engine = create_engine()
+  engine = create_engine_from_config_flags()
 
   start = time.perf_counter()
   params = engine.load_params()
@@ -156,8 +115,9 @@ def main(argv):
 
   prefill_times = {}
 
-  if _PROFILING_OUTPUT.value:
-    jax.profiler.start_trace(_PROFILING_OUTPUT.value)
+  profiling_output = FLAGS.profiling_output
+  if profiling_output:
+    jax.profiler.start_trace(profiling_output)
   decode_state = engine.init_decode_state()
   for batch, _ in MAXTEXT_PREFILL.items():
     runtime, decode_state = run_prefill_time(
@@ -186,18 +146,19 @@ def main(argv):
     dec_times.append(end - start)
     print(i, "decode time", (end - start))
 
-  if _PROFILING_OUTPUT.value:
+  if profiling_output:
     jax.profiler.stop_trace()
 
   print("prefill ", prefill_times)
   print("decode", sum(dec_times) / 10)
 
   prefill_times_ms = {k: v * 1000 for k, v in prefill_times.items()}
-  decode_time_ms = sum(dec_times) * 1000 / 10 / _BATCH_SIZE.value
+  decode_time_ms = sum(dec_times) * 1000 / 10 / FLAGS.batch_size
 
-  if _SHAREGPT_PATH.value:
+  sharegpt_path = FLAGS.sharegpt_path
+  if sharegpt_path:
     analyze_sharegpt.do_simulation(
-        _SHAREGPT_PATH.value, prefill_times_ms, decode_time_ms
+        sharegpt_path, prefill_times_ms, decode_time_ms
     )
 
diff --git a/jetstream_pt/config.py b/jetstream_pt/config.py
index 04638198..84ef036d 100644
--- a/jetstream_pt/config.py
+++ b/jetstream_pt/config.py
@@ -12,35 +12,75 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from jetstream.core.config_lib import ServerConfig
+import os
+import time
+
+import jax
+from absl import flags
 
 from jetstream_pt.engine import create_pytorch_engine
 
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string(
+    "tokenizer_path",
+    None,
+    "The tokenizer model path",
+    required=True,
+)
+flags.DEFINE_string("model_name", None, "model type", required=False)
+flags.DEFINE_string(
+    "checkpoint_path", None, "Directory for .pth checkpoints", required=False
+)
+flags.DEFINE_bool("bf16_enable", True, "Whether to enable bf16", required=False)
+flags.DEFINE_integer(
+    "context_length", 1024, "The context length", required=False
+)
+flags.DEFINE_integer("batch_size", 32, "The batch size", required=False)
+flags.DEFINE_string("size", "tiny", "size of model")
+flags.DEFINE_bool("quantize_weights", False, "weight quantization")
+flags.DEFINE_bool("quantize_kv_cache", False, "kv_cache_quantize")
+flags.DEFINE_integer("max_cache_length", 1024, "kv_cache_quantize")
+flags.DEFINE_string("sharding_config", "", "config file for sharding")
+flags.DEFINE_bool(
+    "shard_on_batch",
+    False,
+    "whether to shard on batch dimension"
+    "If set true, sharding_config will be ignored.",
+)
+flags.DEFINE_string(
+    "profiling_output",
+    "",
+    "The profiling output",
+    required=False,
+)
+
+
+def define_profiling_flags():
+  """Add profiling related config flags to global FLAG."""
 
-# pylint: disable-next=all
-def create_config(
-    devices,
-    tokenizer_path,
-    ckpt_path,
-    bf16_enable,
-    param_size,
-    context_length,
-    batch_size,
-    platform,
-):
-  """Create a server config"""
-
-  def func():
-    return create_pytorch_engine(
-        devices=devices,
-        tokenizer_path=tokenizer_path,
-        ckpt_path=ckpt_path,
-        bf16_enable=bf16_enable,
-        param_size=param_size,
-        context_length=context_length,
-        batch_size=batch_size,
-    )
-
-  return ServerConfig(
-      interleaved_slices=(platform,),
-      interleaved_engine_create_fns=(func,),
+
+def create_engine_from_config_flags():
+  """create a pytorch engine from config flag"""
+  jax.config.update("jax_default_prng_impl", "unsafe_rbg")
+  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
+
+  devices = jax.devices()
+  start = time.perf_counter()
+  engine = create_pytorch_engine(
+      model_name=FLAGS.model_name,
+      devices=devices,
+      tokenizer_path=FLAGS.tokenizer_path,
+      ckpt_path=FLAGS.checkpoint_path,
+      bf16_enable=FLAGS.bf16_enable,
+      param_size=FLAGS.size,
+      context_length=FLAGS.context_length,
+      batch_size=FLAGS.batch_size,
+      quantize_weights=FLAGS.quantize_weights,
+      quantize_kv=FLAGS.quantize_kv_cache,
+      max_cache_length=FLAGS.max_cache_length,
+      sharding_config=FLAGS.sharding_config,
+      shard_on_batch=FLAGS.shard_on_batch,
   )
+
+  print("Initialize engine", time.perf_counter() - start)
+  return engine
diff --git a/run_interactive.py b/run_interactive.py
index e6be9548..74c2f67f 100644
--- a/run_interactive.py
+++ b/run_interactive.py
@@ -15,105 +15,27 @@
 import os
 import random
 import time
-
 from typing import List
-from absl import app
-from absl import flags
-from colorama import Fore, Style
 import jax
-
-from jetstream.engine import token_utils
-from colorama import Fore, Style
 import numpy as np
-
-import os
-
+from absl import app, flags
+from colorama import Fore, Style
+from jetstream.engine import token_utils
 from jetstream_pt import engine as je
-
-FLAGS = flags.FLAGS
-
-_TOKENIZER_PATH = flags.DEFINE_string(
-    "tokenizer_path",
-    "tokenizer.model",
-    "The tokenizer model path",
-    required=False,
-)
-_MODEL_NAME = flags.DEFINE_string(
-    "model_name", None, "model type", required=False
-)
-_CKPT_PATH = flags.DEFINE_string(
-    "checkpoint_path", None, "Directory for .pth checkpoints", required=False
-)
-_BF16_ENABLE = flags.DEFINE_bool(
-    "bf16_enable", False, "Whether to enable bf16", required=False
+from jetstream_pt.config import (
+    FLAGS,
+    create_engine_from_config_flags,
+    define_profiling_flags,
 )
-_CONTEXT_LENGTH = flags.DEFINE_integer(
-    "context_length", 1024, "The context length", required=False
-)
-_BATCH_SIZE = flags.DEFINE_integer(
-    "batch_size", 32, "The batch size", required=False
-)
-_PROFILING_OUTPUT = flags.DEFINE_string(
-    "profiling_output",
-    "",
-    "The profiling output",
-    required=False,
-)
-
-_SIZE = flags.DEFINE_string("size", "tiny", "size of model")
-
-_QUANTIZE_WEIGHTS = flags.DEFINE_bool(
-    "quantize_weights", False, "weight quantization"
-)
-_QUANTIZE_KV_CACHE = flags.DEFINE_bool(
-    "quantize_kv_cache", False, "kv_cache_quantize"
-)
-_MAX_CACHE_LENGTH = flags.DEFINE_integer(
-    "max_cache_length", 1024, "kv_cache_quantize"
-)
-_SHARDING_CONFIG = flags.DEFINE_string(
-    "sharding_config", "", "config file for sharding"
-)
-_SHARD_ON_BATCH = flags.DEFINE_bool(
-    "shard_on_batch",
-    False,
-    "whether to shard on batch dimension."
-    "If set true, sharding_config will be ignored.",
-)
-
-def create_engine():
-  """create a pytorch engine"""
-  jax.config.update("jax_default_prng_impl", "unsafe_rbg")
-  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
-
-  devices = jax.devices()
-  start = time.perf_counter()
-  engine = je.create_pytorch_engine(
-      model_name=_MODEL_NAME.value,
-      devices=devices,
-      tokenizer_path=_TOKENIZER_PATH.value,
-      ckpt_path=_CKPT_PATH.value,
-      bf16_enable=True,
-      param_size=_SIZE.value,
-      context_length=_CONTEXT_LENGTH.value,
-      batch_size=_BATCH_SIZE.value,
-      quantize_weights=_QUANTIZE_WEIGHTS.value,
-      quantize_kv=_QUANTIZE_KV_CACHE.value,
-      max_cache_length=_MAX_CACHE_LENGTH.value,
-      sharding_config=_SHARDING_CONFIG.value,
-      shard_on_batch=_SHARD_ON_BATCH.value,
-  )
-
-  print("Initialize engine", time.perf_counter() - start)
-  return engine
+define_profiling_flags()
 
 
 # pylint: disable-next=all
 def main(argv):
-  engine = create_engine()
+  engine = create_engine_from_config_flags()
 
   start = time.perf_counter()
   params = engine.load_params()
@@ -123,8 +45,9 @@ def main(argv):
   tokenizer = engine.build_tokenizer(metadata)
   max_output_length = 1024
 
-  if _PROFILING_OUTPUT.value:
-    jax.profiler.start_trace(_PROFILING_OUTPUT.value)
+  profiling_output = FLAGS.profiling_output
+  if profiling_output:
+    jax.profiler.start_trace(profiling_output)
 
   decode_state = engine.init_decode_state()
   prompts: List[str] = [
@@ -139,7 +62,7 @@ def main(argv):
       "[INST] <>\nYou are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n<>\n\nContinue the following story.\n\nKay didn't have shoes that fit her feet properly. She only wore sneakers, because the \nChoose from: [I] shoes fitted badly. [II] sneakers fitted badly. [/INST]",
   ]
   for prompt in prompts:
-    slot = random.randint(0, _BATCH_SIZE.value - 1)
+    slot = random.randint(0, FLAGS.batch_size - 1)
     tokens, true_length = tokenizer.encode(prompt)
     print(f"---- Input prompts are: {prompt}")
 
@@ -174,7 +97,7 @@ def main(argv):
     print("---- All output text.")
     print(tokenizer.decode(sampled_tokens_list))
 
-  if _PROFILING_OUTPUT.value:
+  if profiling_output:
     jax.profiler.stop_trace()
 
 
diff --git a/run_interactive_multiple_host.py b/run_interactive_multiple_host.py
index 9de0c492..9146df9d 100644
--- a/run_interactive_multiple_host.py
+++ b/run_interactive_multiple_host.py
@@ -15,63 +15,20 @@
 import os
 import random
 import time
-
 from typing import List
-from absl import app
-from absl import flags
-from colorama import Fore, Style
 import jax
-
+from absl import app, flags
+from colorama import Fore, Style
 from jetstream.engine import token_utils
 from jetstream_pt import ray_engine
-
-FLAGS = flags.FLAGS
-
-_TOKENIZER_PATH = flags.DEFINE_string(
-    "tokenizer_path",
-    "tokenizer.model",
-    "The tokenizer model path",
-    required=False,
-)
-_CKPT_PATH = flags.DEFINE_string(
-    "checkpoint_path", None, "Directory for .pth checkpoints", required=False
-)
-_BF16_ENABLE = flags.DEFINE_bool(
-    "bf16_enable", False, "Whether to enable bf16", required=False
-)
-_CONTEXT_LENGTH = flags.DEFINE_integer(
-    "context_length", 1024, "The context length", required=False
-)
-_BATCH_SIZE = flags.DEFINE_integer(
-    "batch_size", 32, "The batch size"
+from jetstream_pt.config import (
+    FLAGS,
+    create_engine_from_config_flags,
+    define_profiling_flags,
 )
-_PROFILING_OUTPUT = flags.DEFINE_string(
-    "profiling_output",
-    "",
-    "The profiling output",
-    required=False,
-)
-
-_SIZE = flags.DEFINE_string("size", "tiny", "size of model")
-_QUANTIZE_WEIGHTS = flags.DEFINE_bool(
-    "quantize_weights", False, "weight quantization"
-)
-_QUANTIZE_KV_CACHE = flags.DEFINE_bool(
-    "quantize_kv_cache", False, "kv_cache_quantize"
-)
-_MAX_CACHE_LENGTH = flags.DEFINE_integer(
-    "max_cache_length", 1024, "kv_cache_quantize"
-)
-
-_MODEL_NAME = flags.DEFINE_string(
-    "model_name", None, "model type", required=False
-)
-
-_SHARDING_CONFIG = flags.DEFINE_string(
-    "sharding_config", "", "config file for sharding"
-)
+define_profiling_flags()
 
 
@@ -81,17 +38,18 @@ def create_engine():
 
   start = time.perf_counter()
   engine = ray_engine.create_pytorch_ray_engine(
-      model_name=_MODEL_NAME.value,
-      tokenizer_path=_TOKENIZER_PATH.value,
-      ckpt_path=_CKPT_PATH.value,
-      bf16_enable=True,
-      param_size=_SIZE.value,
-      context_length=_CONTEXT_LENGTH.value,
-      batch_size=_BATCH_SIZE.value,
-      quantize_weights=_QUANTIZE_WEIGHTS.value,
-      quantize_kv=_QUANTIZE_KV_CACHE.value,
-      max_cache_length=_MAX_CACHE_LENGTH.value,
-      sharding_config=_SHARDING_CONFIG.value,
+      model_name=FLAGS.model_name,
+      tokenizer_path=FLAGS.tokenizer_path,
+      ckpt_path=FLAGS.checkpoint_path,
+      bf16_enable=FLAGS.bf16_enable,
+      param_size=FLAGS.size,
+      context_length=FLAGS.context_length,
+      batch_size=FLAGS.batch_size,
+      quantize_weights=FLAGS.quantize_weights,
+      quantize_kv=FLAGS.quantize_kv_cache,
+      max_cache_length=FLAGS.max_cache_length,
+      sharding_config=FLAGS.sharding_config,
+      shard_on_batch=FLAGS.shard_on_batch,
   )
 
   print("Initialize engine", time.perf_counter() - start)
@@ -101,7 +59,7 @@ def create_engine():
 
 # pylint: disable-next=all
 def main(argv):
-  engine = create_engine()
+  engine = create_engine_from_config_flags()
 
   start = time.perf_counter()
   engine.load_params()
@@ -112,8 +70,9 @@ def main(argv):
   stop_tokens = [vocab.eos_id, vocab.pad_id]
   max_output_length = 1024
 
-  if _PROFILING_OUTPUT.value:
-    jax.profiler.start_trace(_PROFILING_OUTPUT.value)
+  profiling_output = FLAGS.profiling_output
+  if profiling_output:
+    jax.profiler.start_trace(profiling_output)
 
   engine.init_decode_state()
   prompts: List[str] = [
@@ -128,7 +87,7 @@ def main(argv):
       "[INST] <>\nYou are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n<>\n\nContinue the following story.\n\nKay didn't have shoes that fit her feet properly. She only wore sneakers, because the \nChoose from: [I] shoes fitted badly. [II] sneakers fitted badly. [/INST]",
   ]
   for prompt in prompts:
-    slot = random.randint(0, _BATCH_SIZE.value - 1)
+    slot = random.randint(0, FLAGS.batch_size - 1)
     tokens, true_length = token_utils.tokenize_and_pad(
         prompt, vocab, is_bos=True, jax_padding=False
     )
@@ -161,7 +120,7 @@ def main(argv):
     print("---- All output text.")
     print(vocab.tokenizer.decode(sampled_tokens_list))
 
-  if _PROFILING_OUTPUT.value:
+  if profiling_output:
     jax.profiler.stop_trace()
 
 
diff --git a/run_server.py b/run_server.py
index 1194da5c..4a83b1af 100644
--- a/run_server.py
+++ b/run_server.py
@@ -16,86 +16,27 @@
 import os
 from typing import Sequence
 
-from absl import app
-from absl import flags
-
+import jax
+import jetstream_pt
+from absl import app, flags
 from jetstream.core import server_lib
 from jetstream.core.config_lib import ServerConfig
+from jetstream_pt.config import (
+    FLAGS,
+    create_engine_from_config_flags,
+    define_profiling_flags,
+)
 
-import jetstream_pt
-
+define_profiling_flags()
 
-_PORT = flags.DEFINE_integer("port", 9000, "port to listen on")
-_THREADS = flags.DEFINE_integer(
-    "threads", 64, "number of worker threads in thread pool"
-)
-_CONFIG = flags.DEFINE_string(
+flags.DEFINE_integer("port", 9000, "port to listen on")
+flags.DEFINE_integer("threads", 64, "number of worker threads in thread pool")
+flags.DEFINE_string(
     "config",
     "InterleavedCPUTestServer",
     "available servers",
 )
-_TOKENIZER_PATH = flags.DEFINE_string(
-    "tokenizer_path",
-    "tokenizer.model",
-    "The tokenizer model path",
-    required=False,
-)
-_CKPT_PATH = flags.DEFINE_string(
-    "checkpoint_path", None, "Directory for .pth checkpoints", required=False
-)
-_BF16_ENABLE = flags.DEFINE_bool(
-    "bf16_enable", True, "Whether to enable bf16", required=False
-)
-_CONTEXT_LENGTH = flags.DEFINE_integer(
-    "context_length", 1024, "The context length", required=False
-)
-_BATCH_SIZE = flags.DEFINE_integer(
-    "batch_size", 32, "The batch size", required=False
-)
-_PROFILING_OUTPUT = flags.DEFINE_string(
-    "profiling_output",
-    "",
-    "The profiling output",
-    required=False,
-)
-_PLATFORM = flags.DEFINE_string(
-    "platform",
-    "tpu=4",
-    "The platform that the engine runs on",
-    required=False,
-)
-_PARAM_SIZE = flags.DEFINE_string(
-    "param_size",
-    "7b",
-    "The model size the server runs on.",
-    required=False,
-)
-_MODEL_NAME = flags.DEFINE_string(
-    "model",
-    "llama-2",
-    "name of the model. Supported options are llama-2 and llama-3",
-)
-
-_QUANTIZE_WEIGHTS = flags.DEFINE_bool(
-    "quantize_weights", False, "weight quantization"
-)
-_QUANTIZE_KV_CACHE = flags.DEFINE_bool(
-    "quantize_kv_cache", False, "kv_cache_quantize"
-)
-_MAX_CACHE_LENGTH = flags.DEFINE_integer(
-    "max_cache_length", 1024, "kv_cache_quantize"
-)
-_SHARDING_CONFIG = flags.DEFINE_string(
-    "sharding_config", "", "config file for sharding"
-)
-_SHARD_ON_BATCH = flags.DEFINE_bool(
-    "shard_on_batch",
-    False,
-    "whether to shard on batch dimension"
-    "If set true, sharding_config will be ignored.",
-)
-
 
 # pylint: disable-next=all
 def main(argv: Sequence[str]):
@@ -104,24 +45,9 @@ def main(argv: Sequence[str]):
   # No devices for local cpu test. A None for prefill and a None for generate.
   devices = server_lib.get_devices()
   print(f"devices: {devices}")
-  sharding_config_path = _SHARDING_CONFIG.value
-  engine = jetstream_pt.create_pytorch_engine(
-      devices=devices,
-      tokenizer_path=_TOKENIZER_PATH.value,
-      ckpt_path=_CKPT_PATH.value,
-      bf16_enable=_BF16_ENABLE.value,
-      param_size=_PARAM_SIZE.value,
-      context_length=_CONTEXT_LENGTH.value,
-      batch_size=_BATCH_SIZE.value,
-      model_name=_MODEL_NAME.value,
-      quantize_weights=_QUANTIZE_WEIGHTS.value,
-      quantize_kv=_QUANTIZE_KV_CACHE.value,
-      max_cache_length=_MAX_CACHE_LENGTH.value,
-      sharding_config=sharding_config_path,
-      shard_on_batch=_SHARD_ON_BATCH.value,
-  )
+  engine = create_engine_from_config_flags()
   server_config = ServerConfig(
-      interleaved_slices=(_PLATFORM.value,),
+      interleaved_slices=(f"tpu={len(jax.devices())}",),
       interleaved_engine_create_fns=(lambda a: engine,),
   )
   print(f"server_config: {server_config}")
@@ -129,8 +55,8 @@ def main(argv: Sequence[str]):
   # We separate credential from run so that we can unit test it with local credentials.
   # We would like to add grpc credentials for OSS.
   jetstream_server = server_lib.run(
-      threads=_THREADS.value,
-      port=_PORT.value,
+      threads=FLAGS.threads,
+      port=FLAGS.port,
       config=server_config,
       devices=devices,
   )
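For reference, a minimal sketch of a driver script written against the consolidated flags module that this patch introduces. It only uses names defined in `jetstream_pt/config.py` above; the file name and the printed message are illustrative, not part of the patch.

```python
# example_driver.py (hypothetical name): sketch of the new pattern where all
# engine flags live in jetstream_pt.config and scripts build the engine with
# create_engine_from_config_flags() instead of defining their own flags.
from absl import app

from jetstream_pt.config import (
    FLAGS,
    create_engine_from_config_flags,
    define_profiling_flags,
)

# Importing the config module registers the shared flags; this call keeps the
# same shape as the updated scripts (run_interactive.py, run_server.py, ...).
define_profiling_flags()


def main(argv):
  del argv  # unused

  # Engine options (model_name, size, batch_size, quantize_weights, ...) are
  # read from the shared FLAGS object inside this helper.
  engine = create_engine_from_config_flags()
  params = engine.load_params()
  del params  # a real script would go on to prefill/decode with these
  print("engine ready, batch_size =", FLAGS.batch_size)


if __name__ == "__main__":
  app.run(main)
```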