From 4fb75ec733748ccb26fb7504266618004659ac06 Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Tue, 30 Jul 2024 02:25:00 +0000 Subject: [PATCH 1/5] add v5e-8 ray support --- install_everything.sh | 2 +- jetstream_pt/ray_engine.py | 5 +++-- run_interactive_multiple_host.py | 9 +++++++++ run_server_with_ray.py | 6 +++++- 4 files changed, 18 insertions(+), 4 deletions(-) diff --git a/install_everything.sh b/install_everything.sh index 2fcf7576..f404ef6b 100644 --- a/install_everything.sh +++ b/install_everything.sh @@ -29,7 +29,7 @@ pip install tensorflow-text pip install tensorflow pip install huggingface_hub -pip install ray[default]==2.22.0 +pip install ray[default]==2.33.0 # torch cpu pip install torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu pip install tensorflow flatbuffers absl-py sentencepiece seqio google-cloud-storage diff --git a/jetstream_pt/ray_engine.py b/jetstream_pt/ray_engine.py index 8a091b3f..fd2dc4b8 100644 --- a/jetstream_pt/ray_engine.py +++ b/jetstream_pt/ray_engine.py @@ -215,6 +215,7 @@ def create_pytorch_ray_engine( sharding_config=None, is_disaggregated: bool = False, num_hosts: int = 0, + worker_chips: int = 0, decode_pod_slice_name: str = None, enable_jax_profiler: bool = False, jax_profiler_port: int = 9999, @@ -231,7 +232,7 @@ def create_pytorch_ray_engine( ray.init(ignore_reinit_error=True) pod_name = tpu.get_current_pod_name() num_hosts = ( - num_hosts if is_disaggregated else tpu.get_current_pod_worker_count() + num_hosts if num_hosts > 0 else tpu.get_current_pod_worker_count() ) print(f"pod_name:{pod_name}, number of host: {num_hosts}") assert ( @@ -242,7 +243,7 @@ def create_pytorch_ray_engine( ), f"num_hosts (current value {num_hosts}) should be a positive number" # pylint: disable-next=all engine_worker_with_tpu_resource = PyTorchRayWorker.options( - resources={"TPU": 4}, + resources={"TPU": worker_chips if worker_chips > 0 else 4}, runtime_env=RuntimeEnv(env_vars={"JAX_PLATFORMS": "tpu,cpu"}), ) engine_workers = [] diff --git a/run_interactive_multiple_host.py b/run_interactive_multiple_host.py index 9cb364f7..a3b19e31 100644 --- a/run_interactive_multiple_host.py +++ b/run_interactive_multiple_host.py @@ -24,6 +24,13 @@ from jetstream_pt import ray_engine from jetstream_pt.config import FLAGS +_NUM_HOSTS = flags.DEFINE_integer( + "num_hosts", 0, "Number of TPU host", required=False +) + +_WORKER_CHIPS = flags.DEFINE_integer( + "worker_chips", 4, "Number of TPU chips per worker", required=False +) def create_engine(): """create a pytorch engine""" @@ -43,6 +50,8 @@ def create_engine(): quantize_kv=FLAGS.quantize_kv_cache, max_cache_length=FLAGS.max_cache_length, sharding_config=FLAGS.sharding_config, + num_hosts=_NUM_HOSTS.value, + worker_chips=_WORKER_CHIPS.value, ) print("Initialize engine", time.perf_counter() - start) diff --git a/run_server_with_ray.py b/run_server_with_ray.py index de3bdf21..5f38a087 100644 --- a/run_server_with_ray.py +++ b/run_server_with_ray.py @@ -43,7 +43,9 @@ "is_disaggregated", False, "Disaggregated serving if it's True" ) -flags.DEFINE_integer("num_hosts", 4, "Number of TPU host", required=False) +flags.DEFINE_integer("num_hosts", 0, "Number of TPU host", required=False) + +flags.DEFINE_integer("worker_chips", 4, "Number of TPU chips per worker", required=False) flags.DEFINE_string("decode_pod_slice_name", "", "Decode pod slice name") @@ -68,6 +70,8 @@ def create_engine(): sharding_config=FLAGS.sharding_config, enable_jax_profiler=FLAGS.enable_jax_profiler, jax_profiler_port=FLAGS.jax_profiler_port, + 
num_hosts=FLAGS.num_hosts, + worker_chips=FLAGS.worker_chips, ) print("Initialize engine", time.perf_counter() - start) From ba56ff81d5fb5087b31f48bac2196c4a53eb3edf Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Tue, 30 Jul 2024 02:26:33 +0000 Subject: [PATCH 2/5] format fix --- jetstream_pt/ray_engine.py | 6 ++---- run_interactive_multiple_host.py | 1 + run_server_with_ray.py | 4 +++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/jetstream_pt/ray_engine.py b/jetstream_pt/ray_engine.py index fd2dc4b8..55120468 100644 --- a/jetstream_pt/ray_engine.py +++ b/jetstream_pt/ray_engine.py @@ -215,7 +215,7 @@ def create_pytorch_ray_engine( sharding_config=None, is_disaggregated: bool = False, num_hosts: int = 0, - worker_chips: int = 0, + worker_chips: int = 0, decode_pod_slice_name: str = None, enable_jax_profiler: bool = False, jax_profiler_port: int = 9999, @@ -231,9 +231,7 @@ def create_pytorch_ray_engine( ) ray.init(ignore_reinit_error=True) pod_name = tpu.get_current_pod_name() - num_hosts = ( - num_hosts if num_hosts > 0 else tpu.get_current_pod_worker_count() - ) + num_hosts = num_hosts if num_hosts > 0 else tpu.get_current_pod_worker_count() print(f"pod_name:{pod_name}, number of host: {num_hosts}") assert ( pod_name is not None diff --git a/run_interactive_multiple_host.py b/run_interactive_multiple_host.py index a3b19e31..9d4c9930 100644 --- a/run_interactive_multiple_host.py +++ b/run_interactive_multiple_host.py @@ -32,6 +32,7 @@ "worker_chips", 4, "Number of TPU chips per worker", required=False ) + def create_engine(): """create a pytorch engine""" jax.config.update("jax_default_prng_impl", "unsafe_rbg") diff --git a/run_server_with_ray.py b/run_server_with_ray.py index 5f38a087..278e9647 100644 --- a/run_server_with_ray.py +++ b/run_server_with_ray.py @@ -45,7 +45,9 @@ flags.DEFINE_integer("num_hosts", 0, "Number of TPU host", required=False) -flags.DEFINE_integer("worker_chips", 4, "Number of TPU chips per worker", required=False) +flags.DEFINE_integer( + "worker_chips", 4, "Number of TPU chips per worker", required=False +) flags.DEFINE_string("decode_pod_slice_name", "", "Decode pod slice name") From 335820442a69a3c006e6b48f4c2d982604633046 Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Tue, 30 Jul 2024 02:30:35 +0000 Subject: [PATCH 3/5] add readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e2feb9f9..ce95f676 100644 --- a/README.md +++ b/README.md @@ -184,7 +184,7 @@ Note: Get address ip and port information from ray head. 
Here is an example to run the server with ray for llama2 7B model: ```bash -python run_server_with_ray.py --tpu_chips=16 -model_name=$model_name --size=7b --batch_size=96 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml" +python run_server_with_ray.py --tpu_chips=16 --num_hosts=4 --worker_chips= 4 -model_name=$model_name --size=7b --batch_size=96 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml" ``` # Run benchmark From 36a23cd41482876feb2126c2e9d7a31daed19bd8 Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Tue, 30 Jul 2024 02:32:32 +0000 Subject: [PATCH 4/5] add readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ce95f676..5292e0c7 100644 --- a/README.md +++ b/README.md @@ -184,7 +184,7 @@ Note: Get address ip and port information from ray head. Here is an example to run the server with ray for llama2 7B model: ```bash -python run_server_with_ray.py --tpu_chips=16 --num_hosts=4 --worker_chips= 4 -model_name=$model_name --size=7b --batch_size=96 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml" +python run_server_with_ray.py --tpu_chips=16 --num_hosts=4 --worker_chips=4 -model_name=$model_name --size=7b --batch_size=96 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml" ``` # Run benchmark From 36bcfaa2be9fa5b74a50faec758f41b88cb88fb1 Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Tue, 30 Jul 2024 16:07:47 +0000 Subject: [PATCH 5/5] add chips assert --- jetstream_pt/ray_engine.py | 8 +++++++- run_interactive_multiple_host.py | 5 +++++ run_server_with_ray.py | 3 ++- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/jetstream_pt/ray_engine.py b/jetstream_pt/ray_engine.py index 55120468..dbc1d489 100644 --- a/jetstream_pt/ray_engine.py +++ b/jetstream_pt/ray_engine.py @@ -216,6 +216,7 @@ def create_pytorch_ray_engine( is_disaggregated: bool = False, num_hosts: int = 0, worker_chips: int = 0, + tpu_chips: int = 0, decode_pod_slice_name: str = None, enable_jax_profiler: bool = False, jax_profiler_port: int = 9999, @@ -232,6 +233,7 @@ def create_pytorch_ray_engine( ray.init(ignore_reinit_error=True) pod_name = tpu.get_current_pod_name() num_hosts = num_hosts if num_hosts > 0 else tpu.get_current_pod_worker_count() + worker_chips = worker_chips if worker_chips > 0 else 4 print(f"pod_name:{pod_name}, number of host: {num_hosts}") assert ( pod_name is not None @@ -239,9 +241,13 @@ def create_pytorch_ray_engine( assert ( num_hosts > 0 ), f"num_hosts (current value {num_hosts}) should be a positive number" + assert ( + num_hosts * worker_chips == tpu_chips + ), f"num_hosts:{num_hosts} * worker_chips: {worker_chips} not equal to tpu_chips: {tpu_chips}" + # pylint: disable-next=all engine_worker_with_tpu_resource = PyTorchRayWorker.options( - resources={"TPU": worker_chips if worker_chips > 0 else 4}, + resources={"TPU": 
worker_chips},
       runtime_env=RuntimeEnv(env_vars={"JAX_PLATFORMS": "tpu,cpu"}),
   )
   engine_workers = []
diff --git a/run_interactive_multiple_host.py b/run_interactive_multiple_host.py
index 9d4c9930..24b27987 100644
--- a/run_interactive_multiple_host.py
+++ b/run_interactive_multiple_host.py
@@ -32,6 +32,10 @@
     "worker_chips", 4, "Number of TPU chips per worker", required=False
 )
 
+_TPU_CHIPS = flags.DEFINE_integer(
+    "tpu_chips", 4, "All devices TPU chips", required=False
+)
+
 
 def create_engine():
   """create a pytorch engine"""
@@ -53,6 +57,7 @@ def create_engine():
       sharding_config=FLAGS.sharding_config,
       num_hosts=_NUM_HOSTS.value,
       worker_chips=_WORKER_CHIPS.value,
+      tpu_chips=_TPU_CHIPS.value,
   )
   print("Initialize engine", time.perf_counter() - start)
 
diff --git a/run_server_with_ray.py b/run_server_with_ray.py
index 278e9647..03489e1a 100644
--- a/run_server_with_ray.py
+++ b/run_server_with_ray.py
@@ -34,7 +34,7 @@
     "available servers",
 )
 flags.DEFINE_integer("prometheus_port", 0, "")
-flags.DEFINE_integer("tpu_chips", 16, "device tpu_chips")
+flags.DEFINE_integer("tpu_chips", 16, "all devices tpu_chips")
 flags.DEFINE_bool("enable_jax_profiler", False, "enable jax profiler")
 flags.DEFINE_integer("jax_profiler_port", 9999, "port of JAX profiler server")
 
@@ -74,6 +74,7 @@ def create_engine():
       jax_profiler_port=FLAGS.jax_profiler_port,
       num_hosts=FLAGS.num_hosts,
       worker_chips=FLAGS.worker_chips,
+      tpu_chips=FLAGS.tpu_chips,
   )
   print("Initialize engine", time.perf_counter() - start)
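
Note: with the assertion added in the last commit, the new topology flags must satisfy `num_hosts * worker_chips == tpu_chips` (the README example above uses 4 hosts x 4 chips per worker = 16 chips). Below is a minimal sketch of a single-host v5e-8 invocation, the configuration this series targets; it assumes one Ray worker exposing all eight chips, and only the topology flags differ from the README example, with the remaining values carried over purely for illustration and likely needing tuning on a v5e-8:

```bash
# Hypothetical v5e-8 run: 1 host * 8 chips per worker == 8 total chips.
python run_server_with_ray.py --tpu_chips=8 --num_hosts=1 --worker_chips=8 --model_name=$model_name --size=7b --batch_size=96 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml"
```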