diff --git a/README.md b/README.md
index e2feb9f..5292e0c 100644
--- a/README.md
+++ b/README.md
@@ -184,7 +184,7 @@
 Note: Get address ip and port information from ray head.
 Here is an example to run the server with ray for llama2 7B model:
 ```bash
-python run_server_with_ray.py --tpu_chips=16 -model_name=$model_name --size=7b --batch_size=96 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml"
+python run_server_with_ray.py --tpu_chips=16 --num_hosts=4 --worker_chips=4 -model_name=$model_name --size=7b --batch_size=96 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml"
 ```
 
 # Run benchmark
diff --git a/install_everything.sh b/install_everything.sh
index 2fcf757..f404ef6 100644
--- a/install_everything.sh
+++ b/install_everything.sh
@@ -29,7 +29,7 @@
 pip install tensorflow-text
 pip install tensorflow
 pip install huggingface_hub
-pip install ray[default]==2.22.0
+pip install ray[default]==2.33.0
 # torch cpu
 pip install torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu
 pip install tensorflow flatbuffers absl-py sentencepiece seqio google-cloud-storage
diff --git a/jetstream_pt/ray_engine.py b/jetstream_pt/ray_engine.py
index 8a091b3..dbc1d48 100644
--- a/jetstream_pt/ray_engine.py
+++ b/jetstream_pt/ray_engine.py
@@ -215,6 +215,8 @@ def create_pytorch_ray_engine(
     sharding_config=None,
     is_disaggregated: bool = False,
     num_hosts: int = 0,
+    worker_chips: int = 0,
+    tpu_chips: int = 0,
     decode_pod_slice_name: str = None,
     enable_jax_profiler: bool = False,
     jax_profiler_port: int = 9999,
@@ -230,9 +232,8 @@
   )
   ray.init(ignore_reinit_error=True)
   pod_name = tpu.get_current_pod_name()
-  num_hosts = (
-      num_hosts if is_disaggregated else tpu.get_current_pod_worker_count()
-  )
+  num_hosts = num_hosts if num_hosts > 0 else tpu.get_current_pod_worker_count()
+  worker_chips = worker_chips if worker_chips > 0 else 4
   print(f"pod_name:{pod_name}, number of host: {num_hosts}")
   assert (
       pod_name is not None
@@ -240,9 +241,13 @@
   assert (
       num_hosts > 0
   ), f"num_hosts (current value {num_hosts}) should be a positive number"
+  assert (
+      num_hosts * worker_chips == tpu_chips
+  ), f"num_hosts:{num_hosts} * worker_chips: {worker_chips} not equal to tpu_chips: {tpu_chips}"
+
   # pylint: disable-next=all
   engine_worker_with_tpu_resource = PyTorchRayWorker.options(
-      resources={"TPU": 4},
+      resources={"TPU": worker_chips},
       runtime_env=RuntimeEnv(env_vars={"JAX_PLATFORMS": "tpu,cpu"}),
   )
   engine_workers = []
diff --git a/run_interactive_multiple_host.py b/run_interactive_multiple_host.py
index 9cb364f..24b2798 100644
--- a/run_interactive_multiple_host.py
+++ b/run_interactive_multiple_host.py
@@ -24,6 +24,18 @@
 from jetstream_pt import ray_engine
 from jetstream_pt.config import FLAGS
 
+_NUM_HOSTS = flags.DEFINE_integer(
+    "num_hosts", 0, "Number of TPU hosts", required=False
+)
+
+_WORKER_CHIPS = flags.DEFINE_integer(
+    "worker_chips", 4, "Number of TPU chips per worker", required=False
+)
+
+_TPU_CHIPS = flags.DEFINE_integer(
+    "tpu_chips", 4, "Total number of TPU chips across all hosts", required=False
+)
+
 
 def create_engine():
   """create a pytorch engine"""
@@ -43,6 +55,9 @@
       quantize_kv=FLAGS.quantize_kv_cache,
       max_cache_length=FLAGS.max_cache_length,
       sharding_config=FLAGS.sharding_config,
+      num_hosts=_NUM_HOSTS.value,
+      worker_chips=_WORKER_CHIPS.value,
+      tpu_chips=_TPU_CHIPS.value,
   )
 
   print("Initialize engine", time.perf_counter() - start)
diff --git a/run_server_with_ray.py b/run_server_with_ray.py
index de3bdf2..03489e1 100644
--- a/run_server_with_ray.py
+++ b/run_server_with_ray.py
@@ -34,7 +34,7 @@
     "available servers",
 )
 flags.DEFINE_integer("prometheus_port", 0, "")
 
-flags.DEFINE_integer("tpu_chips", 16, "device tpu_chips")
+flags.DEFINE_integer("tpu_chips", 16, "total number of TPU chips across all hosts")
 flags.DEFINE_bool("enable_jax_profiler", False, "enable jax profiler")
 flags.DEFINE_integer("jax_profiler_port", 9999, "port of JAX profiler server")
@@ -43,7 +43,11 @@
     "is_disaggregated", False, "Disaggregated serving if it's True"
 )
 
-flags.DEFINE_integer("num_hosts", 4, "Number of TPU host", required=False)
+flags.DEFINE_integer("num_hosts", 0, "Number of TPU hosts", required=False)
+
+flags.DEFINE_integer(
+    "worker_chips", 4, "Number of TPU chips per worker", required=False
+)
 
 flags.DEFINE_string("decode_pod_slice_name", "", "Decode pod slice name")
 
@@ -68,6 +72,9 @@ def create_engine():
       sharding_config=FLAGS.sharding_config,
       enable_jax_profiler=FLAGS.enable_jax_profiler,
       jax_profiler_port=FLAGS.jax_profiler_port,
+      num_hosts=FLAGS.num_hosts,
+      worker_chips=FLAGS.worker_chips,
+      tpu_chips=FLAGS.tpu_chips,
   )
 
   print("Initialize engine", time.perf_counter() - start)
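
Note on the new topology flags (not part of the diff above): with this change, `create_pytorch_ray_engine` requires that `num_hosts * worker_chips == tpu_chips`. The sketch below is a minimal, standalone illustration of that flag-resolution logic; `resolve_topology` and `detected_hosts` are hypothetical names used only for this example, and in the real code the host count falls back to `tpu.get_current_pod_worker_count()`.

```python
# Hypothetical standalone sketch mirroring the checks added in
# jetstream_pt/ray_engine.py; not part of the library API.


def resolve_topology(tpu_chips, num_hosts=0, worker_chips=0, detected_hosts=4):
  """Fall back to detected/default values, then validate the slice topology."""
  # 0 means "not set"; the real code uses tpu.get_current_pod_worker_count()
  # instead of the illustrative detected_hosts parameter.
  num_hosts = num_hosts if num_hosts > 0 else detected_hosts
  worker_chips = worker_chips if worker_chips > 0 else 4
  assert num_hosts * worker_chips == tpu_chips, (
      f"num_hosts:{num_hosts} * worker_chips:{worker_chips}"
      f" not equal to tpu_chips:{tpu_chips}"
  )
  return num_hosts, worker_chips


# Example: a 16-chip slice exposed as 4 hosts with 4 chips each, matching the
# README command (--tpu_chips=16 --num_hosts=4 --worker_chips=4).
print(resolve_topology(tpu_chips=16, num_hosts=4, worker_chips=4))  # (4, 4)
```

Each Ray worker is then requested with `resources={"TPU": worker_chips}`, so `worker_chips` should match the number of TPU chips attached to each host in the slice.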