@@ -278,7 +278,7 @@ def _load_orig_llama_weight(input_ckpt_dir: epath.Path):
 
 def _load_hf_llama_weight(input_ckpt_dir: epath.Path):
   print(f"Loading checkpoint files from {input_ckpt_dir}.")
-  safetensors_files = input_ckpt_dir.glob("*.safetensors")
+  safetensors_files = list(input_ckpt_dir.glob("*.safetensors"))
   if len(list(safetensors_files)) == 0:
     raise ValueError(
         f"No *.safetensors found in the input dir {input_ckpt_dir}"
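Note on the change above: like pathlib's Path.glob, epath.Path.glob presumably returns a one-shot iterator rather than a list, so the pre-patch code exhausted it inside len(list(...)) and any later loop over safetensors_files saw an empty sequence. A minimal sketch of that failure mode, using a plain generator as a stand-in for glob:

files = (f for f in ["a.safetensors", "b.safetensors"])  # stand-in for input_ckpt_dir.glob("*.safetensors")
print(len(list(files)))  # 2 -- but counting consumes the iterator
print(list(files))       # [] -- a second pass over `files` finds nothing

Materializing the glob result with list() up front makes it safe to both count and iterate.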
@@ -419,14 +419,23 @@ def _get_llama_state_dict(input_ckpt_dir):
   return state_dict, params
 
 
+def fix_json(text):
+  text = text.replace("'", '"')
+  lines = text.split("\n")
+  lines[-3] = lines[-3].replace(",", "")
+  return "\n".join(lines)
+
+
 def _get_gemma_state_dict(input_ckpt_dir):
   ckpt_file = list(input_ckpt_dir.glob("*.ckpt"))
   assert len(ckpt_file) == 1, "only expect 1 ckpt file for Gemma model."
   ckpt_file = ckpt_file[0]
   state_dict = torch.load(str(ckpt_file), map_location=torch.device("cpu"))[
       "model_state_dict"
   ]
-  model_config = json.loads((input_ckpt_dir / "config.json").read_text())
+  config_text = fix_json((input_ckpt_dir / "config.json").read_text())
+  print("gemma config is", config_text)
+  model_config = json.loads(config_text)
   for key in list(state_dict.keys()):
     if state_dict[key].dtype.is_complex and _OUTPUT_SAFETENSORS.value:
       assert (
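For illustration only, not part of the patch: judging from what fix_json corrects, the Gemma checkpoint's config.json appears to use single quotes and a trailing comma, neither of which json.loads accepts, so fix_json swaps the quotes and strips the comma from the third-to-last line. A small sketch with a made-up config snippet (the "dim" key is hypothetical, and fix_json is the helper defined in the patch above):

import json

sample = "{\n  'dim': 2048,\n}\n"    # hypothetical single-quoted config with a trailing comma
print(fix_json(sample))              # '{\n  "dim": 2048\n}\n' -- now valid JSON
print(json.loads(fix_json(sample)))  # {'dim': 2048}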