Closed
Description
I get this error when running convert_checkpoints.py on the Llama 2 13B checkpoint:
Loading checkpoint files from /home/yeandy/llama/llama-2-13b.
Loading checkpoints takes 9.128946957000153 seconds
Starting to merge weights.
Merging weights across 2 shards (shape = torch.Size([32000, 2560])) for tok_embeddings.weight)
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/yeandy/jetstream-pytorch/convert_checkpoints.py", line 421, in <module>
    app.run(main)
  File "/home/yeandy/.env/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/yeandy/.env/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/yeandy/jetstream-pytorch/convert_checkpoints.py", line 386, in main
    state_dict, params = _get_llama_state_dict(_INPUT_CHECKPOINT_DIR.value)
  File "/home/yeandy/jetstream-pytorch/convert_checkpoints.py", line 321, in _get_llama_state_dict
    state_dict = _merge_llama_weights(
  File "/home/yeandy/jetstream-pytorch/convert_checkpoints.py", line 182, in _merge_llama_weights
    for pattern, kind in llama_model.get_weight_sharding_type.items():
AttributeError: module 'jetstream_pt.third_party.llama.model_exportable' has no attribute 'get_weight_sharding_type'
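This is caused by an incorrect call on line 182 of convert_checkpoints.py (https://github.com/google/jetstream-pytorch/blob/main/convert_checkpoints.py#L182). The model_exportable module has no module-level get_weight_sharding_type, and the existing line also accesses it without calling it. I believe the line should be
    llama_model.Transformer.get_weight_sharding_type().items()
instead of
    llama_model.get_weight_sharding_type.items()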
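A minimal, self-contained sketch of the difference between the two calls. It uses a hypothetical stand-in module instead of the real jetstream_pt code, and it assumes that get_weight_sharding_type is a static method on Transformer returning a dict that maps weight-name patterns to sharding kinds. The patterns, the kinds, and the helper sharding_kind_for are made up for illustration and are not the project's actual values.

```python
import types


# Hypothetical stand-in for jetstream_pt.third_party.llama.model_exportable.
# Only the call shape matters; the real patterns and sharding kinds may differ.
class Transformer:
    @staticmethod
    def get_weight_sharding_type():
        # Hypothetical mapping: weight-name substring -> sharding kind.
        return {
            "tok_embeddings.": "ParallelEmbedding",
            "attention.wq.": "ColumnParallelLinear",
            "attention.wo.": "RowParallelLinear",
        }


# Build a module object that, like the real one, exposes Transformer but
# has no module-level get_weight_sharding_type.
llama_model = types.ModuleType("model_exportable")
llama_model.Transformer = Transformer


def sharding_kind_for(name):
    """Hypothetical helper: return the sharding kind for a weight name, or None."""
    # Proposed call: look the method up on the class and invoke it.
    for pattern, kind in llama_model.Transformer.get_weight_sharding_type().items():
        if pattern in name:
            return kind
    return None


# The original call raises AttributeError because the module has no such attribute.
try:
    llama_model.get_weight_sharding_type.items()
except AttributeError as err:
    print("original call:", err)

print(sharding_kind_for("tok_embeddings.weight"))
```

With the proposed call, the loop receives a real dict and can iterate its (pattern, kind) pairs, which is what the for statement on line 182 expects.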