Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions jetstream_pt/ray_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ def create_pytorch_ray_engine(
is_disaggregated: bool = False,
num_hosts: int = 0,
decode_pod_slice_name: str = None,
enable_jax_profiler: bool = False,
jax_profiler_port: int = 9999,
) -> Any:

# Return tuple as reponse: issues/107
Expand Down Expand Up @@ -218,6 +220,8 @@ def create_pytorch_ray_engine(
quantize_kv=quantize_kv,
max_cache_length=max_cache_length,
sharding_config=sharding_config,
enable_jax_profiler=enable_jax_profiler,
jax_profiler_port=jax_profiler_port,
)
engine_workers.append(engine_worker)

Expand Down
6 changes: 6 additions & 0 deletions jetstream_pt/ray_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ def __init__(
quantize_kv=False,
max_cache_length=1024,
sharding_config=None,
enable_jax_profiler: bool = False,
jax_profiler_port: int = 9999,
):

jax.config.update("jax_default_prng_impl", "unsafe_rbg")
Expand All @@ -130,6 +132,10 @@ def __init__(
f"---Jax device_count:{device_count}, local_device_count{local_device_count} "
)

if enable_jax_profiler:
jax.profiler.start_server(jax_profiler_port)
print(f"Started JAX profiler server on port {jax_profiler_port}")

checkpoint_format = ""
checkpoint_path = ""

Expand Down
5 changes: 5 additions & 0 deletions run_server_with_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
flags.DEFINE_integer("prometheus_port", 0, "")
flags.DEFINE_integer("tpu_chips", 16, "device tpu_chips")

flags.DEFINE_bool("enable_jax_profiler", False, "enable jax profiler")
flags.DEFINE_integer("jax_profiler_port", 9999, "port of JAX profiler server")


def create_engine():
"""create a pytorch engine"""
Expand All @@ -53,6 +56,8 @@ def create_engine():
quantize_kv=FLAGS.quantize_kv_cache,
max_cache_length=FLAGS.max_cache_length,
sharding_config=FLAGS.sharding_config,
enable_jax_profiler=FLAGS.enable_jax_profiler,
jax_profiler_port=FLAGS.jax_profiler_port,
)

print("Initialize engine", time.perf_counter() - start)
Expand Down