In [1]:
# from funasr.train_utils.average_nbest_models import average_checkpoints

# average_checkpoints(
#     "./outputs_1", 12, use_deepspeed=False
# )
# average_checkpoints(
#     "./outputs_1", 5, use_deepspeed=False
# )

In [None]:
import logging
from pathlib import Path
from typing import Optional
from typing import Sequence
from typing import Union,List
import warnings
import os
from io import BytesIO

import torch
from typing import Collection
import os
import torch
import re
from collections import OrderedDict
from functools import cmp_to_key



@torch.no_grad()
def average_checkpoints(
    checkpoint_paths: List[str], output_dir: str = "outputs_1"
) -> Optional[str]:
    """Combine the model ``state_dict``s of several checkpoint files into one.

    Each existing file is loaded with ``torch.load`` onto the CPU and must hold a
    dict with a ``"state_dict"`` entry. For every key of the first loaded
    state_dict, the tensors from all loaded checkpoints are combined as follows:

    * floating-point and complex tensors are averaged element-wise;
    * integer tensors (signed or unsigned) are summed instead of averaged;
    * boolean tensors, which cannot be averaged, are copied from the first
      loaded checkpoint.

    Missing files are reported and skipped. The result may therefore combine
    fewer checkpoints than were requested.

    Args:
        checkpoint_paths: Paths of the checkpoint files to combine.
        output_dir: Directory that receives the combined checkpoint. It is
            created if it does not exist. Defaults to ``"outputs_1"``.

    Returns:
        Path of the written ``model_emsemble.pt`` file, or ``None`` when none of
        the given files exist.

    Raises:
        KeyError: if a loaded checkpoint has no ``"state_dict"`` entry, or lacks
            a key present in the first loaded state_dict.
    """
    print(f"average_checkpoints: {checkpoint_paths}")
    state_dicts = []

    # Load state_dicts from checkpoints; missing files are skipped (best effort).
    for path in checkpoint_paths:
        if os.path.isfile(path):
            # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
            state_dicts.append(torch.load(path, map_location="cpu")["state_dict"])
        else:
            print(f"Checkpoint file {path} not found.")

    if not state_dicts:
        print("No checkpoints found for averaging.")
        return None

    avg_state_dict = OrderedDict()
    for key in state_dicts[0].keys():
        tensors = [state_dict[key].cpu() for state_dict in state_dicts]
        dtype = tensors[0].dtype
        if dtype.is_floating_point or dtype.is_complex:
            # Element-wise mean over the checkpoints (stack adds a leading axis).
            avg_state_dict[key] = torch.mean(torch.stack(tensors), dim=0)
        elif dtype == torch.bool:
            # torch.mean rejects bool tensors; keep the first checkpoint's value.
            avg_state_dict[key] = tensors[0]
        else:
            # Integer tensors (e.g. counters) are summed, not averaged.
            # Python's sum() starts at int 0, which preserves the tensor dtype.
            avg_state_dict[key] = sum(tensors)

    # torch.save does not create missing parent directories.
    os.makedirs(output_dir, exist_ok=True)
    # The file name (including its spelling) is kept for downstream consumers.
    checkpoint_outpath = os.path.join(output_dir, "model_emsemble.pt")
    torch.save({"state_dict": avg_state_dict}, checkpoint_outpath)
    return checkpoint_outpath

In [5]:
# Combine the `model.pt.avg10` checkpoints from archive(1)/outputs0 ... outputs4.
# NOTE(review): the five directories are presumably separate training runs/folds — confirm.
checkpoint_paths = [f"archive(1)/outputs{run_idx}/model.pt.avg10" for run_idx in range(5)]
average_checkpoints(checkpoint_paths)

average_checkpoints: ['archive(1)/outputs0/model.pt.avg10', 'archive(1)/outputs1/model.pt.avg10', 'archive(1)/outputs2/model.pt.avg10', 'archive(1)/outputs3/model.pt.avg10', 'archive(1)/outputs4/model.pt.avg10']


'outputs_1/model_emsemble.pt'