File tree Expand file tree Collapse file tree 3 files changed +15
-0
lines changed Expand file tree Collapse file tree 3 files changed +15
-0
lines changed Original file line number Diff line number Diff line change @@ -178,6 +178,8 @@ def create_pytorch_ray_engine(
178178 is_disaggregated : bool = False ,
179179 num_hosts : int = 0 ,
180180 decode_pod_slice_name : str = None ,
181+ enable_jax_profiler : bool = False ,
182+ jax_profiler_port : int = 9999 ,
181183) -> Any :
182184
183185 # Return tuple as reponse: issues/107
@@ -218,6 +220,8 @@ def create_pytorch_ray_engine(
218220 quantize_kv = quantize_kv ,
219221 max_cache_length = max_cache_length ,
220222 sharding_config = sharding_config ,
223+ enable_jax_profiler = enable_jax_profiler ,
224+ jax_profiler_port = jax_profiler_port ,
221225 )
222226 engine_workers .append (engine_worker )
223227
Original file line number Diff line number Diff line change @@ -114,6 +114,8 @@ def __init__(
114114 quantize_kv = False ,
115115 max_cache_length = 1024 ,
116116 sharding_config = None ,
117+ enable_jax_profiler : bool = False ,
118+ jax_profiler_port : int = 9999 ,
117119 ):
118120
119121 jax .config .update ("jax_default_prng_impl" , "unsafe_rbg" )
@@ -130,6 +132,10 @@ def __init__(
130132 f"---Jax device_count:{ device_count } , local_device_count{ local_device_count } "
131133 )
132134
135+ if enable_jax_profiler :
136+ jax .profiler .start_server (jax_profiler_port )
137+ print (f"Started JAX profiler server on port { jax_profiler_port } " )
138+
133139 checkpoint_format = ""
134140 checkpoint_path = ""
135141
Original file line number Diff line number Diff line change 3434flags .DEFINE_integer ("prometheus_port" , 0 , "" )
3535flags .DEFINE_integer ("tpu_chips" , 16 , "device tpu_chips" )
3636
37+ flags .DEFINE_bool ("enable_jax_profiler" , False , "enable jax profiler" )
38+ flags .DEFINE_integer ("jax_profiler_port" , 9999 , "port of JAX profiler server" )
39+
3740
3841def create_engine ():
3942 """create a pytorch engine"""
@@ -53,6 +56,8 @@ def create_engine():
5356 quantize_kv = FLAGS .quantize_kv_cache ,
5457 max_cache_length = FLAGS .max_cache_length ,
5558 sharding_config = FLAGS .sharding_config ,
59+ enable_jax_profiler = FLAGS .enable_jax_profiler ,
60+ jax_profiler_port = FLAGS .jax_profiler_port ,
5661 )
5762
5863 print ("Initialize engine" , time .perf_counter () - start )
You can’t perform that action at this time.
0 commit comments