Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions run_interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def main(argv):
max_output_length = 1024

profiling_output = FLAGS.profiling_output
if profiling_output:
profiling_prefill = FLAGS.profiling_prefill
Copy link
Collaborator

@lsy323 lsy323 May 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Current logic only handles:

  1. profile both output and prefill
  2. profile output but not prefill

If profile prefill but not output, the current logic won't profile at all. Not sure if we want to handle this case as well. We can add in future PR if there is such use case as well.

if profiling_output and profiling_prefill:
jax.profiler.start_trace(profiling_output)

decode_state = engine.init_decode_state()
Expand Down Expand Up @@ -68,7 +69,11 @@ def main(argv):
print(f"---- Streaming decode started on #slot{slot}.")
complete = np.zeros((1,), dtype=np.bool_)
while True:
if profiling_output and not profiling_prefill:
jax.profiler.start_trace(profiling_output)
decode_state, result_tokens = engine.generate(params, decode_state)
if profiling_output and not profiling_prefill:
jax.profiler.stop_trace()
result_tokens = result_tokens.convert_to_numpy()
output, complete = token_utils.process_result_tokens(
tokenizer=tokenizer,
Expand All @@ -87,7 +92,7 @@ def main(argv):
print("---- All output text.")
print(tokenizer.decode(sampled_tokens_list))

if profiling_output:
if profiling_output and profiling_prefill:
jax.profiler.stop_trace()


Expand Down