 # import torch_xla2 first!
 import torch_xla2  # pylint: disable
 import jax
+from jax import numpy as jnp
 from absl import app, flags
 from jetstream.engine import token_utils
 from jetstream.core import server_lib
@@ -26,6 +27,11 @@
 flags.DEFINE_integer("max_output_length", 1024, "The maximum output length")
 flags.DEFINE_integer("port", 9000, "port to listen on")
 flags.DEFINE_integer("threads", 64, "number of worker threads in thread pool")
+flags.DEFINE_string(
+    "benchmark_save_offline_result_to_file",
+    "",
+    "if set, then save the result to the given file name",
+)


 def shard_weights(env, weights, weight_shardings):
@@ -114,6 +120,45 @@ def _check_model_id():
     sys.exit(1)


+def _run_prefill_time(
+    pt_engine, params, decode_state, seqlen, profiler_started
+):
+  """Run prefill and measure time."""
+  metadata = pt_engine.get_tokenizer()
+  tokenizer = pt_engine.build_tokenizer(metadata)
+
+  text = "This is a beautiful day"
+  tokens, true_length = tokenizer.encode(
+      text, is_bos=True, prefill_lengths=[seqlen]
+  )
+
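+  # Warm-up: a few prefill + insert calls that are excluded from the timing.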
+  for _ in range(3):
+    prefill_result, _ = pt_engine.prefill(
+        params=params, padded_tokens=tokens, true_length=true_length
+    )
+    decode_state = pt_engine.insert(
+        prefill_result, decode_state, slot=jnp.int32(1)
+    )
+
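+  # Timed runs: measure the average prefill + insert latency over `nums` iterations.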
+  nums = 5
+  start = time.perf_counter()
+  for i in range(nums):
+    if i == nums - 1 and FLAGS.profiling_prefill and not profiler_started:
+      jax.profiler.start_trace(FLAGS.profiling_output)
+      profiler_started = True
+
+    prefill_result, _ = pt_engine.prefill(
+        params=params, padded_tokens=tokens, true_length=true_length
+    )
+    decode_state = pt_engine.insert(
+        prefill_result, decode_state, slot=jnp.int32(i)
+    )
+  jax.block_until_ready(decode_state)
+
+  end = time.perf_counter()
+  return (end - start) / nums, decode_state, profiler_started
+
+
 def interactive():
   """Run interactive"""
   _check_model_id()
@@ -207,6 +252,100 @@ def interactive():
     print(tokenizer.decode(sampled_tokens_list))


+def _save_benchmark_to_file(filename, prefill_times_ms, decode_time_ms):
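+  """Write prefill/decode timings to `filename` as a Markdown table."""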
+  lines = (
+      [
+          " # Offline benchmark numbers",
+          " ## Model: " + FLAGS.model_id,
+          f" ## Batch size: {FLAGS.override_batch_size}",
+          f" ## Quantize: {FLAGS.quantize_weights}",
+          " | | time (ms) |",
+          " |-------|-----------|",
+      ]
+      + [f"| Prefill {x} | {y} |" for x, y in prefill_times_ms.items()]
+      + [f"| Decode | {decode_time_ms} |"]
+  )
+  with open(filename, "w", encoding="utf-8") as f:
+    f.write("\n".join(lines))
+    f.flush()
+
+
+def benchmark_offline():
+  """Run the engine offline and benchmark prefill and decode times."""
+  _check_model_id()
+  devices = server_lib.get_devices()
+  print(f"devices: {devices}")
+  pt_engine = create_engine(devices)
+
+  start = time.perf_counter()
+  params = pt_engine.load_params()
+  print("Load params ", time.perf_counter() - start)
+
+  prefill_times = {}
+
+  decode_state = pt_engine.init_decode_state()
+  profiler_started = False
+  # Prefill sequence lengths 16, 32, ..., 1024.
+  for exp in range(4, 11):
+    batch = 2**exp
+    runtime, decode_state, profiler_started = _run_prefill_time(
+        pt_engine, params, decode_state, batch, profiler_started
+    )
+    prefill_times[batch] = runtime
+
+  sampled_tokens_list = []
+
+  for i in range(3):  # warm up
+    # pylint: disable-next=all
+    decode_state, sampled_tokens = pt_engine.generate(
+        params=params, decode_state=decode_state
+    )
+    sampled_tokens_list.append(sampled_tokens)
+
+  profiling_output = FLAGS.profiling_output
+  print("======= decode starting ===")
+
+  dec_times = []
+  for i in range(10):
+    if profiling_output and i == 7 and not profiler_started:
+      jax.profiler.start_trace(profiling_output)
+      profiler_started = True
+    start = time.perf_counter()
+    # pylint: disable-next=all
+    decode_state, sampled_tokens = pt_engine.generate(params, decode_state)
+    jax.block_until_ready(decode_state)
+    sampled_tokens_list.append(sampled_tokens)
+    end = time.perf_counter()
+    dec_times.append(end - start)
+    print(i, "decode time", (end - start))
+
+  if profiler_started:
+    jax.profiler.stop_trace()
+
+  print("prefill ", prefill_times)
+  avg_decode_times = sum(dec_times[2:]) / len(dec_times[2:])
+  print("decode", avg_decode_times)
+
+  prefill_times_ms = {k: v * 1000 for k, v in prefill_times.items()}
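+  # Average decode step in ms over the 8 timed iterations (dec_times[2:]).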
+  decode_time_ms = sum(dec_times[2:]) * 1000 / 8
+
+  largest_prefill = max(prefill_times.items())
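+  # Decode-only throughput estimate: batch_size tokens per decode step.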
+  print("MAX tokens:", FLAGS.batch_size / avg_decode_times)
+
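+  # Throughput including prefill cost: total decoded tokens divided by the
+  # time to prefill a full batch at the longest length plus the time for
+  # max_decode_length decode steps.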
+  time2 = (FLAGS.batch_size * FLAGS.max_decode_length) / (
+      FLAGS.batch_size * largest_prefill[1]
+      + FLAGS.max_decode_length * avg_decode_times
+  )
+  print("MAX tokens 2:", time2)
+
+  if FLAGS.benchmark_save_offline_result_to_file:
+    _save_benchmark_to_file(
+        FLAGS.benchmark_save_offline_result_to_file,
+        prefill_times_ms,
+        decode_time_ms,
+    )
+
+
 def main():
   """Main function."""

@@ -221,6 +360,8 @@ def main_real(argv):
     serve()
   elif argv[1] == "interactive":
     interactive()
+  elif argv[1] == "benchmark_offline":
+    benchmark_offline()
   else:
     print(
         "Invalid arguments. Please specify 'list', 'serve', 'interactive',"
         " or 'benchmark_offline'."