In [2]:
import torch

def load_torch_checkpoint(file_path):
    """
    Load a PyTorch checkpoint from a given file path.
    
    Args:
    file_path (str): Path to the checkpoint file.
    
    Returns:
    dict: The loaded checkpoint as a dictionary.
    """
    try:
        checkpoint = torch.load(file_path, map_location=torch.device('cpu'))
        print(f"Successfully loaded checkpoint from {file_path}")
        return checkpoint
    except Exception as e:
        print(f"Error loading checkpoint from {file_path}: {str(e)}")
        return None


In [3]:
ckpt_base =  '/home/yingqi/repo/hmt_pretrained/opt-350m'
ckpt1 = load_torch_checkpoint(f'{ckpt_base}/model_weights_0_lv_4_step2.pth')
ckpt2 = load_torch_checkpoint(f'{ckpt_base}/model_weights_700_lv_1_step2.pth')


Successfully loaded checkpoint from /home/yingqi/repo/hmt_pretrained/opt-350m/model_weights_0_lv_4_step2.pth
Successfully loaded checkpoint from /home/yingqi/repo/hmt_pretrained/opt-350m/model_weights_700_lv_1_step2.pth


In [4]:
print(ckpt1.keys())
print(ckpt2.keys())

odict_keys(['module.mem', 'module.memory_cell.memory', 'module.memory_cell.model.model.decoder.embed_tokens.weight', 'module.memory_cell.model.model.decoder.embed_positions.weight', 'module.memory_cell.model.model.decoder.project_out.weight', 'module.memory_cell.model.model.decoder.project_in.weight', 'module.memory_cell.model.model.decoder.layers.0.self_attn.k_proj.weight', 'module.memory_cell.model.model.decoder.layers.0.self_attn.k_proj.bias', 'module.memory_cell.model.model.decoder.layers.0.self_attn.v_proj.weight', 'module.memory_cell.model.model.decoder.layers.0.self_attn.v_proj.bias', 'module.memory_cell.model.model.decoder.layers.0.self_attn.q_proj.weight', 'module.memory_cell.model.model.decoder.layers.0.self_attn.q_proj.bias', 'module.memory_cell.model.model.decoder.layers.0.self_attn.out_proj.weight', 'module.memory_cell.model.model.decoder.layers.0.self_attn.out_proj.bias', 'module.memory_cell.model.model.decoder.layers.0.self_attn_layer_norm.weight', 'module.memory_cell.mo

In [5]:
def sum_model_weights(checkpoint):
    sums = []
    for key, value in checkpoint.items():
        if isinstance(value, torch.Tensor):
            sums.append(torch.sum(value).item())
    return sums

sums1 = sum_model_weights(ckpt1)
sums2 = sum_model_weights(ckpt2)

print("Sums of model weights for checkpoint 1:")
print(sums1)
print("\nSums of model weights for checkpoint 2:")
print(sums2)


Sums of model weights for checkpoint 1:
[0.640625, -2.84375, -32000.0, 37.0, 45.5, 18.125, -49.25, 0.3515625, 2.078125, -0.033447265625, -58.0, 1.203125, -6.59375, -0.0296630859375, 1024.0, -2.078125, 700.0, -121.0, 19.25, -0.0859375, 1024.0, 7.1875, -13.8125, 8.125, 14.875, 0.40234375, -37.25, -1.6015625, 1.6796875, -0.11376953125, 1024.0, -2.078125, 776.0, -99.5, -11.75, 0.30078125, 1024.0, 8.875, 26.125, 9.0625, 42.0, -0.09765625, 12.6875, 1.0078125, -6.8125, 0.349609375, 1024.0, -1.9609375, 980.0, -96.0, -7.78125, 0.3203125, 1024.0, 9.375, -19.75, 8.5, -94.5, -1.015625, -17.375, -0.2890625, -1.7265625, 0.10986328125, 1024.0, -0.470703125, 568.0, -101.5, 6.53125, 0.17578125, 1024.0, 6.71875, 32.0, 9.3125, 66.5, 0.1474609375, -1.5625, 1.4609375, 5.3125, 0.333984375, 1024.0, 0.1982421875, 332.0, -95.5, 7.6875, 0.6796875, 1024.0, 2.609375, -53.75, 11.5625, 10.8125, 0.59375, 11.125, -0.69140625, 21.625, 0.2099609375, 1024.0, -0.4921875, 466.0, -89.5, 1.765625, 0.0927734375, 1024.0, 0.42