@@ -59,7 +59,7 @@ def create_engine(devices):
5959  model  =  fetch_models .instantiate_model_from_repo_id (FLAGS .model_id , env )
6060  if  quant_config .enable_weight_quantization :
6161    quantize_model .quantize_model (model , quant_config )
62-     print (' ====== model ======='  )
62+     print (" ====== model ======="  )
6363    print (model )
6464
6565  weight_shardings  =  model .get_sharding_annotations ()
@@ -81,11 +81,7 @@ def list_model():
8181
8282def  serve ():
8383  """Run gRPC server.""" 
84-   if  FLAGS .model_id  ==  "" :
85-     print ("Please specify model_id with --model_id" )
86-     print ("valid model ids are:" )
87-     list_model ()
88-     sys .exit (1 )
84+   _check_model_id ()
8985  devices  =  server_lib .get_devices ()
9086  print (f"devices: { devices }  " )
9187
@@ -110,23 +106,27 @@ def serve():
110106  jetstream_server .wait_for_termination ()
111107
112108
113- def  interactive ():
114-   """Run interactive""" 
def _check_model_id():
  """Abort the program when --model_id was left empty.

  Returns normally if FLAGS.model_id is anything other than the empty string.
  Otherwise prints a usage hint, calls list_model() to show the valid ids,
  and exits the process with status 1.
  """
  if FLAGS.model_id != "":
    return

  # No model selected: explain what is missing, show the choices, then stop.
  hints = (
      "Please specify model_id with --model_id",
      "valid model ids are:",
  )
  for hint in hints:
    print(hint)
  list_model()
  sys.exit(1)
115+ 
116+ 
117+ def  interactive ():
118+   """Run interactive""" 
119+   _check_model_id ()
120120  devices  =  server_lib .get_devices ()
121121  print (f"devices: { devices }  " )
122-   engine  =  create_engine (devices )
122+   pt_engine  =  create_engine (devices )
123123
124124  start  =  time .perf_counter ()
125-   params  =  engine .load_params ()
125+   params  =  pt_engine .load_params ()
126126  print ("Load params " , time .perf_counter () -  start )
127127
128-   metadata  =  engine .get_tokenizer ()
129-   tokenizer  =  engine .build_tokenizer (metadata )
128+   metadata  =  pt_engine .get_tokenizer ()
129+   tokenizer  =  pt_engine .build_tokenizer (metadata )
130130  max_output_length  =  1024 
131131
132132  profiling_output  =  FLAGS .profiling_output 
@@ -139,7 +139,7 @@ def interactive():
139139  if  profiling_prefill :
140140    jax .profiler .start_trace (profiling_output )
141141
142-   decode_state  =  engine .init_decode_state ()
142+   decode_state  =  pt_engine .init_decode_state ()
143143
144144  if  profiling_prefill :
145145    jax .profiler .stop_trace ()
@@ -167,11 +167,11 @@ def interactive():
167167    if  profiling_prefill :
168168      jax .profiler .start_trace (profiling_output )
169169
170-     prefill_result , _  =  engine .prefill (
170+     prefill_result , _  =  pt_engine .prefill (
171171        params = params , padded_tokens = tokens , true_length = true_length 
172172    )
173173    # pylint: disable-next=all 
174-     decode_state  =  engine .insert (prefill_result , decode_state , slot = slot )
174+     decode_state  =  pt_engine .insert (prefill_result , decode_state , slot = slot )
175175
176176    if  profiling_prefill :
177177      jax .profiler .stop_trace ()
@@ -183,7 +183,7 @@ def interactive():
183183      if  profiling_output :
184184        jax .profiler .start_trace (profiling_output )
185185
186-       decode_state , result_tokens  =  engine .generate (params , decode_state )
186+       decode_state , result_tokens  =  pt_engine .generate (params , decode_state )
187187      result_tokens  =  result_tokens .convert_to_numpy ()
188188
189189      if  profiling_output :
@@ -214,18 +214,13 @@ def main(argv):
214214
215215  if  argv [1 ] ==  "list" :
216216    list_model ()
217-     return 
218- 
219217  elif  argv [1 ] ==  "serve" :
220218    serve ()
221-     return 
222- 
223219  elif  argv [1 ] ==  "interactive" :
224220    interactive ()
225-     return 
226221  else :
227222    print (
228-       "Invalid arguments. please specify 'list', 'serve', or 'interactive'." 
223+          "Invalid arguments. please specify 'list', 'serve', or 'interactive'." 
229224    )
230225
231226
0 commit comments