import json
import os


def load_from_torch_shard_ckpt(model, ckpt_dir):
    """
    Load sharded checkpoints directly from a Hugging Face checkpoint dir.
    """
    # The index file maps every parameter name to the shard file that stores it.
    with open(os.path.join(ckpt_dir, 'pytorch_model.bin.index.json')) as fp:
        ckpt_index = json.load(fp)
    total_size = ckpt_index['metadata']['total_size']
    weight_map = ckpt_index['weight_map']
    # Invert the mapping so each shard file lists the parameters it contains.
    file_weight_map = {}
    for key, value in weight_map.items():
        # key: param name; value: filename.
        if value not in file_weight_map:
            file_weight_map[value] = []
        file_weight_map[value].append(key)
    load_from_map(model, ckpt_dir, file_weight_map)
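
The grouping loop in the function above can be checked on its own, without a model or any checkpoint files. Below is a minimal sketch of that step, assuming the index's `weight_map` is a flat dict of parameter name to shard filename. The helper name `group_params_by_file` and the sample parameter and shard names are hypothetical and used only for illustration; they are not part of the original code.

```python
from collections import defaultdict


def group_params_by_file(weight_map):
    """Invert {param_name: filename} into {filename: [param_name, ...]}."""
    file_weight_map = defaultdict(list)
    for param_name, filename in weight_map.items():
        file_weight_map[filename].append(param_name)
    # Return a plain dict so lookups of unknown files raise KeyError.
    return dict(file_weight_map)


# Hypothetical index contents, for illustration only.
sample_weight_map = {
    "embed.weight": "pytorch_model-00001-of-00002.bin",
    "layers.0.weight": "pytorch_model-00001-of-00002.bin",
    "lm_head.weight": "pytorch_model-00002-of-00002.bin",
}

grouped = group_params_by_file(sample_weight_map)
for filename, param_names in grouped.items():
    print(filename, param_names)
```

Grouping by file means each shard only has to be opened once, however many parameters it holds.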