From 5be8e7c40d461a5411a63f38e9ad549d80d3ab1e Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Mon, 3 Jun 2024 17:37:49 +0000 Subject: [PATCH 1/2] add jax profiler server --- deps/JetStream | 2 +- jetstream_pt/ray_engine.py | 4 ++++ jetstream_pt/ray_worker.py | 6 ++++++ run_server_with_ray.py | 5 +++++ 4 files changed, 16 insertions(+), 1 deletion(-) diff --git a/deps/JetStream b/deps/JetStream index ec26ec24..e19a7906 160000 --- a/deps/JetStream +++ b/deps/JetStream @@ -1 +1 @@ -Subproject commit ec26ec2427fad737f898bdec9a186f2acd49d6f1 +Subproject commit e19a7906d8cdf1cae658a4c7c4f6f516aade49f9 diff --git a/jetstream_pt/ray_engine.py b/jetstream_pt/ray_engine.py index 2d65ba15..13d11edc 100644 --- a/jetstream_pt/ray_engine.py +++ b/jetstream_pt/ray_engine.py @@ -178,6 +178,8 @@ def create_pytorch_ray_engine( is_disaggregated: bool = False, num_hosts: int = 0, decode_pod_slice_name: str = None, + enable_jax_profiler: bool = False, + jax_profiler_port: int = 9999, ) -> Any: # Return tuple as reponse: issues/107 @@ -218,6 +220,8 @@ def create_pytorch_ray_engine( quantize_kv=quantize_kv, max_cache_length=max_cache_length, sharding_config=sharding_config, + enable_jax_profiler=enable_jax_profiler, + jax_profiler_port=jax_profiler_port, ) engine_workers.append(engine_worker) diff --git a/jetstream_pt/ray_worker.py b/jetstream_pt/ray_worker.py index b386bb35..7f31d676 100644 --- a/jetstream_pt/ray_worker.py +++ b/jetstream_pt/ray_worker.py @@ -114,6 +114,8 @@ def __init__( quantize_kv=False, max_cache_length=1024, sharding_config=None, + enable_jax_profiler: bool = False, + jax_profiler_port: int = 9999, ): jax.config.update("jax_default_prng_impl", "unsafe_rbg") @@ -130,6 +132,10 @@ def __init__( f"---Jax device_count:{device_count}, local_device_count{local_device_count} " ) + if enable_jax_profiler: + jax.profiler.start_server(jax_profiler_port) + print(f"Started JAX profiler server on port {jax_profiler_port}") + checkpoint_format = "" checkpoint_path = "" diff --git a/run_server_with_ray.py b/run_server_with_ray.py index 5ec99f75..325bc108 100644 --- a/run_server_with_ray.py +++ b/run_server_with_ray.py @@ -34,6 +34,9 @@ flags.DEFINE_integer("prometheus_port", 0, "") flags.DEFINE_integer("tpu_chips", 16, "device tpu_chips") +flags.DEFINE_bool("enable_jax_profiler", False, "enable jax profiler") +flags.DEFINE_integer("jax_profiler_port", 9999, "port of JAX profiler server") + def create_engine(): """create a pytorch engine""" @@ -53,6 +56,8 @@ def create_engine(): quantize_kv=FLAGS.quantize_kv_cache, max_cache_length=FLAGS.max_cache_length, sharding_config=FLAGS.sharding_config, + enable_jax_profiler=FLAGS.enable_jax_profiler, + jax_profiler_port=FLAGS.jax_profiler_port, ) print("Initialize engine", time.perf_counter() - start) From b93717371dbd408049a03b890e55bd2b8b31270a Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Mon, 3 Jun 2024 20:23:00 +0000 Subject: [PATCH 2/2] update jetstream --- deps/JetStream | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/JetStream b/deps/JetStream index e19a7906..ec26ec24 160000 --- a/deps/JetStream +++ b/deps/JetStream @@ -1 +1 @@ -Subproject commit e19a7906d8cdf1cae658a4c7c4f6f516aade49f9 +Subproject commit ec26ec2427fad737f898bdec9a186f2acd49d6f1