@@ -59,7 +59,7 @@ def create_engine(devices):
5959 model = fetch_models .instantiate_model_from_repo_id (FLAGS .model_id , env )
6060 if quant_config .enable_weight_quantization :
6161 quantize_model .quantize_model (model , quant_config )
62- print (' ====== model =======' )
62+ print (" ====== model =======" )
6363 print (model )
6464
6565 weight_shardings = model .get_sharding_annotations ()
@@ -81,11 +81,7 @@ def list_model():
8181
8282def serve ():
8383 """Run gRPC server."""
84- if FLAGS .model_id == "" :
85- print ("Please specify model_id with --model_id" )
86- print ("valid model ids are:" )
87- list_model ()
88- sys .exit (1 )
84+ _check_model_id ()
8985 devices = server_lib .get_devices ()
9086 print (f"devices: { devices } " )
9187
@@ -110,23 +106,27 @@ def serve():
110106 jetstream_server .wait_for_termination ()
111107
112108
113- def interactive ():
114- """Run interactive"""
109+ def _check_model_id ():
115110 if FLAGS .model_id == "" :
116111 print ("Please specify model_id with --model_id" )
117112 print ("valid model ids are:" )
118113 list_model ()
119114 sys .exit (1 )
115+
116+
117+ def interactive ():
118+ """Run interactive"""
119+ _check_model_id ()
120120 devices = server_lib .get_devices ()
121121 print (f"devices: { devices } " )
122- engine = create_engine (devices )
122+ pt_engine = create_engine (devices )
123123
124124 start = time .perf_counter ()
125- params = engine .load_params ()
125+ params = pt_engine .load_params ()
126126 print ("Load params " , time .perf_counter () - start )
127127
128- metadata = engine .get_tokenizer ()
129- tokenizer = engine .build_tokenizer (metadata )
128+ metadata = pt_engine .get_tokenizer ()
129+ tokenizer = pt_engine .build_tokenizer (metadata )
130130 max_output_length = 1024
131131
132132 profiling_output = FLAGS .profiling_output
@@ -139,7 +139,7 @@ def interactive():
139139 if profiling_prefill :
140140 jax .profiler .start_trace (profiling_output )
141141
142- decode_state = engine .init_decode_state ()
142+ decode_state = pt_engine .init_decode_state ()
143143
144144 if profiling_prefill :
145145 jax .profiler .stop_trace ()
@@ -167,11 +167,11 @@ def interactive():
167167 if profiling_prefill :
168168 jax .profiler .start_trace (profiling_output )
169169
170- prefill_result , _ = engine .prefill (
170+ prefill_result , _ = pt_engine .prefill (
171171 params = params , padded_tokens = tokens , true_length = true_length
172172 )
173173 # pylint: disable-next=all
174- decode_state = engine .insert (prefill_result , decode_state , slot = slot )
174+ decode_state = pt_engine .insert (prefill_result , decode_state , slot = slot )
175175
176176 if profiling_prefill :
177177 jax .profiler .stop_trace ()
@@ -183,7 +183,7 @@ def interactive():
183183 if profiling_output :
184184 jax .profiler .start_trace (profiling_output )
185185
186- decode_state , result_tokens = engine .generate (params , decode_state )
186+ decode_state , result_tokens = pt_engine .generate (params , decode_state )
187187 result_tokens = result_tokens .convert_to_numpy ()
188188
189189 if profiling_output :
@@ -214,18 +214,13 @@ def main(argv):
214214
215215 if argv [1 ] == "list" :
216216 list_model ()
217- return
218-
219217 elif argv [1 ] == "serve" :
220218 serve ()
221- return
222-
223219 elif argv [1 ] == "interactive" :
224220 interactive ()
225- return
226221 else :
227222 print (
228- "Invalid arguments. please specify 'list', 'serve', or 'interactive'."
223+ "Invalid arguments. please specify 'list', 'serve', or 'interactive'."
229224 )
230225
231226
0 commit comments