From 4197ae62ef60be4252094b1e1c77266f89abb868 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 24 May 2024 15:31:22 -0700 Subject: [PATCH] Update run_interactive.py with finer control of profiler. --- run_interactive.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/run_interactive.py b/run_interactive.py index 77b3a702..ccddc9c3 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -40,7 +40,8 @@ def main(argv): max_output_length = 1024 profiling_output = FLAGS.profiling_output - if profiling_output: + profiling_prefill = FLAGS.profiling_prefill + if profiling_output and profiling_prefill: jax.profiler.start_trace(profiling_output) decode_state = engine.init_decode_state() @@ -68,7 +69,11 @@ def main(argv): print(f"---- Streaming decode started on #slot{slot}.") complete = np.zeros((1,), dtype=np.bool_) while True: + if profiling_output and not profiling_prefill: + jax.profiler.start_trace(profiling_output) decode_state, result_tokens = engine.generate(params, decode_state) + if profiling_output and not profiling_prefill: + jax.profiler.stop_trace() result_tokens = result_tokens.convert_to_numpy() output, complete = token_utils.process_result_tokens( tokenizer=tokenizer, @@ -87,7 +92,7 @@ def main(argv): print("---- All output text.") print(tokenizer.decode(sampled_tokens_list)) - if profiling_output: + if profiling_output and profiling_prefill: jax.profiler.stop_trace()