In [2]:
import torch
import os
import re
from collections import OrderedDict

In [10]:
def average_checkpoints(checkpoint_root, last_n: int = 5, **kwargs):
    """
    Average the last 'last_n' checkpoints' model state_dicts.
    If a tensor is of type torch.int, perform sum instead of average.
    """
    checkpoint_paths = []
    for i,_, files in os.walk(checkpoint_root):
        for file in files:
            if file.endswith('.pt'):
                ckpt_path = os.path.join(i, file)
                checkpoint_paths.append(ckpt_path)
    
    print(f"average_checkpoints: {checkpoint_paths}")
    state_dicts = []

    # Load state_dicts from checkpoints
    for path in checkpoint_paths:
        if os.path.isfile(path):
            # state_dicts.append(torch.load(path, map_location="cpu")["state_dict"])
            state_dicts.append(torch.load(path, map_location="cpu", weights_only=True))

        else:
            print(f"Checkpoint file {path} not found.")

    # Check if we have any state_dicts to average
    if len(state_dicts) < 1:
        print("No checkpoints found for averaging.")
        return

    # Average or sum weights
    avg_state_dict = OrderedDict()
    for key in state_dicts[0].keys():
        tensors = [state_dict[key].cpu() for state_dict in state_dicts]
        # Check the type of the tensor
        if str(tensors[0].dtype).startswith("torch.int"):
            # Perform sum for integer tensors
            summed_tensor = sum(tensors)
            avg_state_dict[key] = summed_tensor
        else:
            # Perform average for other types of tensors
            stacked_tensors = torch.stack(tensors)
            avg_state_dict[key] = torch.mean(stacked_tensors, dim=0)
    checkpoint_outpath = os.path.join(checkpoint_root, f"model.pt.avg{last_n}")
    torch.save({"state_dict": avg_state_dict}, checkpoint_outpath)
    return checkpoint_outpath

# 首先先把ckpt转换成pt格式，然后再进行平均

In [4]:
root = "/ssd/zhuang/code/FunASR/examples/ChineseCorrection/tb_logs/stage_1/version_0/checkpoints/"
for i,_, files in os.walk(root):
    for file in files:
        if file.endswith('.ckpt'):
            ckpt_path = os.path.join(i, file)
            pt_path = ckpt_path.replace('.ckpt', '.pt')
            model = torch.load(ckpt_path, map_location='cpu')['state_dict']
            pt_model = {}
            for k, v in model.items():
                pt_model[k.replace('model.', '')] = v
            torch.save(pt_model, pt_path)
            print(f'Convert {ckpt_path} to {pt_path}')

/ssd/zhuang/code/FunASR/examples/ChineseCorrection/tb_logs/stage_1/version_0/checkpoints/epoch=96-val_loss=7.5840-val_asr_acc=0.9475.ckpt


  model = torch.load(ckpt_path, map_location='cpu')['state_dict']
  from .autonotebook import tqdm as notebook_tqdm
  @autocast(enabled = False)
  @autocast(enabled = False)


Convert /ssd/zhuang/code/FunASR/examples/ChineseCorrection/tb_logs/stage_1/version_0/checkpoints/epoch=96-val_loss=7.5840-val_asr_acc=0.9475.ckpt to /ssd/zhuang/code/FunASR/examples/ChineseCorrection/tb_logs/stage_1/version_0/checkpoints/epoch=96-val_loss=7.5840-val_asr_acc=0.9475.pt
/ssd/zhuang/code/FunASR/examples/ChineseCorrection/tb_logs/stage_1/version_0/checkpoints/epoch=98-val_loss=7.5216-val_asr_acc=0.9480.ckpt
Convert /ssd/zhuang/code/FunASR/examples/ChineseCorrection/tb_logs/stage_1/version_0/checkpoints/epoch=98-val_loss=7.5216-val_asr_acc=0.9480.ckpt to /ssd/zhuang/code/FunASR/examples/ChineseCorrection/tb_logs/stage_1/version_0/checkpoints/epoch=98-val_loss=7.5216-val_asr_acc=0.9480.pt
/ssd/zhuang/code/FunASR/examples/ChineseCorrection/tb_logs/stage_1/version_0/checkpoints/epoch=92-val_loss=7.7163-val_asr_acc=0.9465.ckpt


  model = torch.load(ckpt_path, map_location='cpu')['state_dict']


Convert /ssd/zhuang/code/FunASR/examples/ChineseCorrection/tb_logs/stage_1/version_0/checkpoints/epoch=92-val_loss=7.7163-val_asr_acc=0.9465.ckpt to /ssd/zhuang/code/FunASR/examples/ChineseCorrection/tb_logs/stage_1/version_0/checkpoints/epoch=92-val_loss=7.7163-val_asr_acc=0.9465.pt
/ssd/zhuang/code/FunASR/examples/ChineseCorrection/tb_logs/stage_1/version_0/checkpoints/epoch=95-val_loss=7.6161-val_asr_acc=0.9472.ckpt
Convert /ssd/zhuang/code/FunASR/examples/ChineseCorrection/tb_logs/stage_1/version_0/checkpoints/epoch=95-val_loss=7.6161-val_asr_acc=0.9472.ckpt to /ssd/zhuang/code/FunASR/examples/ChineseCorrection/tb_logs/stage_1/version_0/checkpoints/epoch=95-val_loss=7.6161-val_asr_acc=0.9472.pt
/ssd/zhuang/code/FunASR/examples/ChineseCorrection/tb_logs/stage_1/version_0/checkpoints/epoch=93-val_loss=7.6823-val_asr_acc=0.9467.ckpt
Convert /ssd/zhuang/code/FunASR/examples/ChineseCorrection/tb_logs/stage_1/version_0/checkpoints/epoch=93-val_loss=7.6823-val_asr_acc=0.9467.ckpt to /ssd/

In [6]:
model_path = "/ssd/zhuang/code/FunASR/examples/ChineseCorrection/tb_logs/stage_1/version_0/checkpoints/epoch=100-val_loss=7.4615-val_asr_acc=0.9484.pt"
model = torch.load(model_path, map_location='cpu')
model.keys()

  model = torch.load(model_path, map_location='cpu')


dict_keys(['encoder.embed.0.weight', 'encoder.encoders.0.self_attn.linear_q.weight', 'encoder.encoders.0.self_attn.linear_q.bias', 'encoder.encoders.0.self_attn.linear_k.weight', 'encoder.encoders.0.self_attn.linear_k.bias', 'encoder.encoders.0.self_attn.linear_v.weight', 'encoder.encoders.0.self_attn.linear_v.bias', 'encoder.encoders.0.self_attn.linear_out.weight', 'encoder.encoders.0.self_attn.linear_out.bias', 'encoder.encoders.0.feed_forward.w_1.weight', 'encoder.encoders.0.feed_forward.w_1.bias', 'encoder.encoders.0.feed_forward.w_2.weight', 'encoder.encoders.0.feed_forward.w_2.bias', 'encoder.encoders.0.norm1.weight', 'encoder.encoders.0.norm1.bias', 'encoder.encoders.0.norm2.weight', 'encoder.encoders.0.norm2.bias', 'encoder.encoders.1.self_attn.linear_q.weight', 'encoder.encoders.1.self_attn.linear_q.bias', 'encoder.encoders.1.self_attn.linear_k.weight', 'encoder.encoders.1.self_attn.linear_k.bias', 'encoder.encoders.1.self_attn.linear_v.weight', 'encoder.encoders.1.self_attn.l

In [11]:
average_checkpoints(root, last_n=10)

average_checkpoints: ['/ssd/zhuang/code/FunASR/examples/ChineseCorrection/tb_logs/stage_1/version_0/checkpoints/epoch=100-val_loss=7.4615-val_asr_acc=0.9484.pt', '/ssd/zhuang/code/FunASR/examples/ChineseCorrection/tb_logs/stage_1/version_0/checkpoints/epoch=99-val_loss=7.4913-val_asr_acc=0.9482.pt', '/ssd/zhuang/code/FunASR/examples/ChineseCorrection/tb_logs/stage_1/version_0/checkpoints/epoch=101-val_loss=7.4322-val_asr_acc=0.9487.pt', '/ssd/zhuang/code/FunASR/examples/ChineseCorrection/tb_logs/stage_1/version_0/checkpoints/epoch=92-val_loss=7.7163-val_asr_acc=0.9465.pt', '/ssd/zhuang/code/FunASR/examples/ChineseCorrection/tb_logs/stage_1/version_0/checkpoints/epoch=97-val_loss=7.5525-val_asr_acc=0.9477.pt', '/ssd/zhuang/code/FunASR/examples/ChineseCorrection/tb_logs/stage_1/version_0/checkpoints/epoch=95-val_loss=7.6161-val_asr_acc=0.9472.pt', '/ssd/zhuang/code/FunASR/examples/ChineseCorrection/tb_logs/stage_1/version_0/checkpoints/epoch=98-val_loss=7.5216-val_asr_acc=0.9480.pt', '/s

'/ssd/zhuang/code/FunASR/examples/ChineseCorrection/tb_logs/stage_1/version_0/checkpoints/model.pt.avg10'