55# import torch_xla2 first!
66import torch_xla2 # pylint: disable
77import jax
8+ from jax import numpy as jnp
89from absl import app , flags
910from jetstream .engine import token_utils
1011from jetstream .core import server_lib
2627flags .DEFINE_integer ("max_output_length" , 1024 , "The batch size" )
2728flags .DEFINE_integer ("port" , 9000 , "port to listen on" )
2829flags .DEFINE_integer ("threads" , 64 , "number of worker threads in thread pool" )
30+ flags .DEFINE_string ("benchmark_save_offline_result_to_file" , "" , "if set, then save the result to the given file name" )
2931
3032
3133def shard_weights (env , weights , weight_shardings ):
@@ -113,6 +115,42 @@ def _check_model_id():
113115 list_model ()
114116 sys .exit (1 )
115117
def _run_prefill_time(engine, params, decode_state, seqlen, profiler_started):
  """Measure the average latency of one prefill + insert at padded length `seqlen`.

  Args:
    engine: inference engine exposing get_tokenizer/build_tokenizer/prefill/insert.
    params: model parameters forwarded to `engine.prefill`.
    decode_state: current decode state; each `engine.insert` returns an updated one.
    seqlen: padded prefill length to benchmark (passed as `prefill_lengths`).
    profiler_started: True if a JAX profiler trace is already active, so we
      never call `start_trace` twice.

  Returns:
    A tuple of (average seconds per prefill+insert iteration, the updated
    decode_state, the possibly-updated profiler_started flag).
  """
  metadata = engine.get_tokenizer()
  tokenizer = engine.build_tokenizer(metadata)

  text = "This is a beautiful day"
  # Encode and pad the prompt to exactly `seqlen` tokens; `true_length` is the
  # unpadded token count.
  tokens, true_length = tokenizer.encode(
      text, is_bos=True, prefill_lengths=[seqlen]
  )

  # Warm-up: three untimed iterations so compilation/tracing cost is excluded
  # from the measurement below. The slot value is irrelevant here.
  for _ in range(3):
    prefill_result, _ = engine.prefill(
        params=params, padded_tokens=tokens, true_length=true_length
    )
    decode_state = engine.insert(
        prefill_result, decode_state, slot=jnp.int32(1)
    )

  nums = 5
  start = time.perf_counter()
  for i in range(nums):
    # Start the profiler trace only on the final timed iteration, and only if
    # no trace is already running (short-circuit keeps FLAGS untouched before
    # the last iteration).
    if i == nums - 1 and FLAGS.profiling_prefill and not profiler_started:
      jax.profiler.start_trace(FLAGS.profiling_output)
      profiler_started = True

    prefill_result, _ = engine.prefill(
        params=params, padded_tokens=tokens, true_length=true_length
    )
    decode_state = engine.insert(
        prefill_result, decode_state, slot=jnp.int32(i)
    )
  # JAX dispatch is asynchronous: block once so the timer covers the real
  # device work, not just enqueueing.
  # NOTE(review): measures pipelined throughput across the 5 iterations; if
  # per-iteration latency is wanted, this belongs inside the loop — confirm.
  jax.block_until_ready(decode_state)

  end = time.perf_counter()
  return (end - start) / nums, decode_state, profiler_started
153+
116154
117155def interactive ():
118156 """Run interactive"""
@@ -206,6 +244,101 @@ def interactive():
206244 print ("---- All output text." )
207245 print (tokenizer .decode (sampled_tokens_list ))
208246
def _save_benchmark_to_file(filename, prefill_times_ms, decode_time_ms):
  """Write the offline benchmark results as a markdown table to `filename`.

  Args:
    filename: path of the output file (overwritten if it exists).
    prefill_times_ms: mapping of prefill length -> latency in milliseconds.
    decode_time_ms: decode-step latency in milliseconds.
  """
  header = [
      " # Offline benchmark numbers",
      " ## Model: " + FLAGS.model_id,
      " ## Batch size: {}".format(FLAGS.override_batch_size),
      " ## Quantize: {}".format(FLAGS.quantize_weights),
      " | | time (ms) |",
      " |-------|-----------|",
  ]
  prefill_rows = [
      "| Prefill {} | {} |".format(length, latency)
      for length, latency in prefill_times_ms.items()
  ]
  decode_row = ["| Decode | {} |".format(decode_time_ms)]
  with open(filename, "w") as out:
    out.write("\n".join(header + prefill_rows + decode_row))
    out.flush()
263+
264+
265+
def benchmark_offline():
  """Run the engine offline and report prefill/decode latency numbers.

  Benchmarks prefill at padded lengths 16..1024 (powers of two), then times
  10 decode steps, prints the averages, and optionally writes the results to
  the file named by --benchmark_save_offline_result_to_file.
  """
  _check_model_id()
  devices = server_lib.get_devices()
  print(f"devices: {devices}")
  pt_engine = create_engine(devices)

  start = time.perf_counter()
  params = pt_engine.load_params()
  print("Load params ", time.perf_counter() - start)

  prefill_times = {}

  decode_state = pt_engine.init_decode_state()
  profiler_started = False
  # Benchmark prefill for padded sequence lengths 16 .. 1024.
  for exp in range(4, 11):
    seqlen = 2**exp
    runtime, decode_state, profiler_started = _run_prefill_time(
        pt_engine, params, decode_state, seqlen, profiler_started
    )
    prefill_times[seqlen] = runtime

  sampled_tokens_list = []

  # Warm up the generate step so compilation cost is excluded from the
  # timed loop below.
  for _ in range(3):
    # pylint: disable-next=all
    decode_state, sampled_tokens = pt_engine.generate(
        params=params, decode_state=decode_state
    )
    sampled_tokens_list.append(sampled_tokens)

  profiling_output = FLAGS.profiling_output
  print("======= decode starting ===")

  dec_times = []
  for i in range(10):
    # Start a trace late in the run (iteration 7) if none is active yet.
    if profiling_output and i == 7 and not profiler_started:
      jax.profiler.start_trace(profiling_output)
      profiler_started = True
    start = time.perf_counter()
    # pylint: disable-next=all
    decode_state, sampled_tokens = pt_engine.generate(params, decode_state)
    # Block so each timing covers the actual device work, not just dispatch.
    jax.block_until_ready(decode_state)
    sampled_tokens_list.append(sampled_tokens)
    end = time.perf_counter()
    dec_times.append(end - start)
    print(i, "decode time", (end - start))

  if profiler_started:
    jax.profiler.stop_trace()

  print("prefill ", prefill_times)
  # Drop the first two timed iterations to exclude residual warm-up noise.
  steady_dec_times = dec_times[2:]
  avg_decode_times = sum(steady_dec_times) / len(steady_dec_times)
  print("decode", avg_decode_times)

  prefill_times_ms = {k: v * 1000 for k, v in prefill_times.items()}
  # Fix: this was `sum(dec_times[2:]) * 1000 / 8`, where the hard-coded 8 was
  # only correct because len(dec_times[2:]) happens to be 8; derive the mean
  # from the already-computed average so changing the loop counts stays safe.
  decode_time_ms = avg_decode_times * 1000

  # max() over (seqlen, runtime) pairs picks the entry with the largest seqlen.
  largest_prefill = max(prefill_times.items())
  print("MAX tokens:", FLAGS.batch_size / avg_decode_times)

  time2 = (FLAGS.batch_size * FLAGS.max_decode_length) / (
      FLAGS.batch_size * largest_prefill[1]
      + FLAGS.max_decode_length * avg_decode_times
  )
  print("MAX tokens 2:", time2)

  if FLAGS.benchmark_save_offline_result_to_file:
    _save_benchmark_to_file(
        FLAGS.benchmark_save_offline_result_to_file,
        prefill_times_ms,
        decode_time_ms,
    )
340+
341+
209342
210343def main ():
211344 """Main function."""
@@ -221,6 +354,8 @@ def main_real(argv):
221354 serve ()
222355 elif argv [1 ] == "interactive" :
223356 interactive ()
357+ elif argv [1 ] == "benchmark_offline" :
358+ benchmark_offline ()
224359 else :
225360 print (
226361 "Invalid arguments. please specify 'list', 'serve', or 'interactive'."
0 commit comments