From 4e2aa6a45a8735b5c25749229df216aa1f6cb476 Mon Sep 17 00:00:00 2001
From: Andy Ye
Date: Tue, 4 Jun 2024 23:00:06 +0000
Subject: [PATCH 1/2] Fix

---
 convert_checkpoints.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/convert_checkpoints.py b/convert_checkpoints.py
index 4f1ade16..0994fda1 100644
--- a/convert_checkpoints.py
+++ b/convert_checkpoints.py
@@ -179,7 +179,7 @@ def _merge_llama_weights(
         f"{len(tensors)} shards (shape = {tensors[0].shape}) for {key})"
     )
     state_dict_for_key = {}
-    for pattern, kind in llama_model.get_weight_sharding_type.items():
+    for pattern, kind in llama_model.Transformer.get_weight_sharding_type().items():
       if not key.endswith(pattern):
         continue
       with torch.no_grad():

From aac6aaa572c063bcaca2d17cd722385b46f51c6a Mon Sep 17 00:00:00 2001
From: Andy Ye
Date: Tue, 4 Jun 2024 23:12:25 +0000
Subject: [PATCH 2/2] Format

---
 convert_checkpoints.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/convert_checkpoints.py b/convert_checkpoints.py
index 0994fda1..d3e05ee5 100644
--- a/convert_checkpoints.py
+++ b/convert_checkpoints.py
@@ -179,7 +179,10 @@ def _merge_llama_weights(
         f"{len(tensors)} shards (shape = {tensors[0].shape}) for {key})"
     )
     state_dict_for_key = {}
-    for pattern, kind in llama_model.Transformer.get_weight_sharding_type().items():
+    weight_sharding_type = (
+        llama_model.Transformer.get_weight_sharding_type().items()
+    )
+    for pattern, kind in weight_sharding_type:
       if not key.endswith(pattern):
         continue
       with torch.no_grad():
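
For context, a minimal sketch of what these two commits change inside _merge_llama_weights: the original line looked up get_weight_sharding_type on the llama_model module and never called it, so .items() was taken from the function object instead of the dictionary it returns; the fix resolves the helper on llama_model.Transformer and calls it, and the second commit only rewraps the line for formatting. The stand-in Transformer class and the mapping contents below are assumptions for illustration, not the real llama_model code.

# Illustrative sketch (not the real llama_model module). Assumes
# get_weight_sharding_type() is a class-level helper on Transformer that
# maps weight-name suffixes to the sharding kind used when merging shards.
class Transformer:
  @staticmethod
  def get_weight_sharding_type():
    # Hypothetical contents; the real mapping lives in llama_model.
    return {
        "tok_embeddings.weight": "ParallelEmbedding",
        "wq.weight": "ColumnParallelLinear",
        "wo.weight": "RowParallelLinear",
        "norm.weight": "replicated",
    }


# Usage mirroring the patched loop in _merge_llama_weights:
key = "layers.0.attention.wq.weight"
weight_sharding_type = Transformer.get_weight_sharding_type().items()
for pattern, kind in weight_sharding_type:
  if not key.endswith(pattern):
    continue
  print(f"{key} matches {pattern!r}; merge its shards as {kind}")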