2727flags .DEFINE_integer ("max_output_length" , 1024 , "The batch size" )
2828flags .DEFINE_integer ("port" , 9000 , "port to listen on" )
2929flags .DEFINE_integer ("threads" , 64 , "number of worker threads in thread pool" )
30- flags .DEFINE_string ("benchmark_save_offline_result_to_file" , "" , "if set, then save the result to the given file name" )
30+ flags .DEFINE_string (
31+ "benchmark_save_offline_result_to_file" ,
32+ "" ,
33+ "if set, then save the result to the given file name" ,
34+ )
3135
3236
3337def shard_weights (env , weights , weight_shardings ):
@@ -115,21 +119,22 @@ def _check_model_id():
115119 list_model ()
116120 sys .exit (1 )
117121
118- def _run_prefill_time (engine , params , decode_state , seqlen , profiler_started ):
122+
123+ def _run_prefill_time (pt_engine , params , decode_state , seqlen , profiler_started ):
119124 """Run prefill and measure time."""
120- metadata = engine .get_tokenizer ()
121- tokenizer = engine .build_tokenizer (metadata )
125+ metadata = pt_engine .get_tokenizer ()
126+ tokenizer = pt_engine .build_tokenizer (metadata )
122127
123128 text = "This is a beautiful day"
124129 tokens , true_length = tokenizer .encode (
125130 text , is_bos = True , prefill_lengths = [seqlen ]
126131 )
127132
128133 for _ in range (3 ):
129- prefill_result , _ = engine .prefill (
134+ prefill_result , _ = pt_engine .prefill (
130135 params = params , padded_tokens = tokens , true_length = true_length
131136 )
132- decode_state = engine .insert (
137+ decode_state = pt_engine .insert (
133138 prefill_result , decode_state , slot = jnp .int32 (1 )
134139 )
135140
@@ -140,10 +145,10 @@ def _run_prefill_time(engine, params, decode_state, seqlen, profiler_started):
140145 jax .profiler .start_trace (FLAGS .profiling_output )
141146 profiler_started = True
142147
143- prefill_result , _ = engine .prefill (
148+ prefill_result , _ = pt_engine .prefill (
144149 params = params , padded_tokens = tokens , true_length = true_length
145150 )
146- decode_state = engine .insert (
151+ decode_state = pt_engine .insert (
147152 prefill_result , decode_state , slot = jnp .int32 (i )
148153 )
149154 jax .block_until_ready (decode_state )
@@ -244,25 +249,28 @@ def interactive():
244249 print ("---- All output text." )
245250 print (tokenizer .decode (sampled_tokens_list ))
246251
252+
247253def _save_benchmark_to_file (filename , prefill_times_ms , decode_time_ms ):
248- lines = [
249- " # Offline benchmark numbers" ,
250- " ## Model: " + FLAGS .model_id ,
251- " ## Batch size: {}" .format (FLAGS .override_batch_size ),
252- " ## Quantize: {}" .format (FLAGS .quantize_weights ),
253- " | | time (ms) |" ,
254- " |-------|-----------|" ,
255- ] + [
256- "| Prefill {} | {} |" .format (x , y ) for x , y in prefill_times_ms .items ()
257- ] + [
258- "| Decode | {} |" .format (decode_time_ms )
259- ]
260- with open (filename , 'w' ) as f :
261- f .write ('\n ' .join (lines ))
254+ lines = (
255+ [
256+ " # Offline benchmark numbers" ,
257+ " ## Model: " + FLAGS .model_id ,
258+ f" ## Batch size: { FLAGS .override_batch_size } " ,
259+ f" ## Quantize: { FLAGS .quantize_weights } " ,
260+ " | | time (ms) |" ,
261+ " |-------|-----------|" ,
262+ ]
263+ + [
264+ f"| Prefill { x } | { y } |"
265+ for x , y in prefill_times_ms .items ()
266+ ]
267+ + [f"| Decode | { decode_time_ms } |" ]
268+ )
269+ with open (filename , "w" , encoding = 'utf-8' ) as f :
270+ f .write ("\n " .join (lines ))
262271 f .flush ()
263272
264273
265-
266274def benchmark_offline ():
267275 """function to run engine offline."""
268276 _check_model_id ()
@@ -280,7 +288,7 @@ def benchmark_offline():
280288 profiler_started = False
281289 # 16 .. 1024
282290 for exp in range (4 , 11 ):
283- batch = 2 ** exp
291+ batch = 2 ** exp
284292 runtime , decode_state , profiler_started = _run_prefill_time (
285293 pt_engine , params , decode_state , batch , profiler_started
286294 )
@@ -333,13 +341,12 @@ def benchmark_offline():
333341
334342 if FLAGS .benchmark_save_offline_result_to_file :
335343 _save_benchmark_to_file (
336- FLAGS .benchmark_save_offline_result_to_file ,
337- prefill_times_ms ,
338- decode_time_ms
344+ FLAGS .benchmark_save_offline_result_to_file ,
345+ prefill_times_ms ,
346+ decode_time_ms ,
339347 )
340348
341349
342-
343350def main ():
344351 """Main function."""
345352
0 commit comments