In [None]:
# Shell out to `nvidia-smi` to list the machine's NVIDIA GPUs and their current state.
!nvidia-smi

In [None]:
%load_ext autoreload
%autoreload 2

import os
os.environ["CUDA_VISIBLE_DEVICES"]="2"

import pandas as pd
import numpy as np
import librosa
import seaborn as sns
import os
import json
import IPython.display as ipd
import soundfile as sf
import torch
import h5py
import onnxruntime as ort
import openvino as ov
import re

from glob import glob
from tqdm import tqdm
from matplotlib import pyplot as plt
from itertools import chain
from os.path import join as pjoin
from torchaudio.transforms import AmplitudeToDB, MelSpectrogram
from copy import deepcopy
from pprint import pprint

# from code_base.utils import parallel_librosa_load, groupby_np_array, stack_and_max_by_samples, macro_f1_similarity, N_CLASSES_2021_2022, N_CLASSES_2021, comp_metric, N_CLASSES_XC_LIGIT_SHORTEN, N_CLASSES_XC_LIGIT_EVEN_SHORTEN
# from code_base.utils.constants import SAMPLE_RATE
from code_base.utils.onnx_utils import ONNXEnsemble, convert_to_onnx
from code_base.models import WaveCNNClasifier, WaveCNNAttenClasifier
from code_base.datasets import WaveDataset, WaveAllFileDataset
from code_base.utils.swa import avarage_weights, delete_prefix_from_chkp
from code_base.inefernce import BirdsInference
from code_base.utils import load_json, compose_submission_dataframe, groupby_np_array, stack_and_max_by_samples, write_json
from code_base.utils.metrics import score_numpy
%matplotlib inline


# Export Models

In [None]:
# Long listing of ../logdirs sorted by modification time (newest first), first 20 lines.
!ls -lt ../logdirs/ | head -20

In [None]:
# Class-name -> index mappings: "source" is the mapping loaded from the
# *_PrevComp.json file, "target" the one loaded from bird2int_2024.json.
bird2id_source = load_json("/home/vova/data/exps/birdclef_2024/class_mappings/bird2int_2024_PrevComp.json")
bird2id_target = load_json("/home/vova/data/exps/birdclef_2024/class_mappings/bird2int_2024.json")

# Inverse (index -> class-name) views of both mappings.
id2bird_source = dict(zip(bird2id_source.values(), bird2id_source.keys()))
id2bird_target = dict(zip(bird2id_target.values(), bird2id_target.keys()))

# REARRANGE_INDICES[i] is the source index of the class whose target index is i.
# Requires target indices to be exactly 0..len-1 and every target class to be
# present in the source mapping (otherwise a KeyError is raised here).
REARRANGE_INDICES = np.array(
    [
        bird2id_source[id2bird_target[target_idx]]
        for target_idx in range(len(id2bird_target))
    ]
).astype(int)

MODEL_CLASS = WaveCNNAttenClasifier
TRAIN_PERIOD = 5
# NOTE(review): STRICT_LOAD is not read by any cell visible in this notebook.
STRICT_LOAD = False

def prune_checkpoint_rule(inp_chkp):
    """Reindex the head parameters of a state dict with ``REARRANGE_INDICES``.

    The weight and bias entries of ``head.attention`` and ``head.fix_scale``
    are indexed along their first axis by ``REARRANGE_INDICES`` and written
    back under the same keys.

    Args:
        inp_chkp: mapping of parameter name -> array-like that contains the
            four head keys handled below. It is modified in place.

    Returns:
        The same ``inp_chkp`` object, with the four head entries replaced.
    """
    head_param_keys = (
        "head.attention.weight",
        "head.attention.bias",
        "head.fix_scale.weight",
        "head.fix_scale.bias",
    )
    for param_key in head_param_keys:
        inp_chkp[param_key] = inp_chkp[param_key][REARRANGE_INDICES]
    return inp_chkp

In [None]:
# Number of entries in the source class mapping.
len(bird2id_source)

In [None]:
# Experiment whose checkpoints are exported below (previous choice kept for reference).
# EXP_NAME = "convnextv2_tiny_fcmae_ft_in22k_in1k_384_And_tf_efficientnetv2_b2_in1k_Pred075_Timewise025"
EXP_NAME = "maxvit_rmlp_pico_rw_256_sw_in1k_Exp_FullAtten_noamp_FixedAmp2Db_Amin1e6_64bs_5sec_MergedData_TimeFlip05_FormixupAlpha05NormedBinTgtEqW_balSamplWithRep_Radamlr3e4_CosBatchLR1e6_Epoch30_SpecAugV207_FocalBCELoss_Full_NoDuplsV2"
TRAIN_PERIOD = 5

# Print the distinct *.ckpt base names of the experiment, skipping any whose
# name contains "train".
print(
    "Possible checkpoints:\n\n{}".format(
        "\n".join(
            {
                os.path.basename(ckpt_path)
                for ckpt_path in glob(f"../logdirs/{EXP_NAME}/checkpoints/*.ckpt")
                if "train" not in os.path.basename(ckpt_path)
            }
        )
    )
)

In [None]:
conf_path = glob(f"../logdirs/{EXP_NAME}/code/*train_configs*.py")
assert len(conf_path) == 1
conf_path = conf_path[0]
!cat {conf_path}

In [None]:
# One entry per model to export. "model_config" is forwarded as keyword
# arguments to the model class, so its key spelling (including
# "mel_spec_paramms") must stay exactly as written.
MODELS = [
    {
        "model_config": {
            "backbone": "maxvit_rmlp_pico_rw_256.sw_in1k",
            "mel_spec_paramms": dict(
                sample_rate=32000,
                n_mels=128,
                f_min=20,
                n_fft=2048,
                hop_length=512,
                normalized=True,
            ),
            "spec_resize": (256, 256),
            "atten_smoothing_config": dict(
                dropout=0.1,
                num_layers=1,
                n_steps=64,
            ),
            "head_type": "AttHeadSimplified",
            "head_config": dict(
                p=0.5,
                num_class=188,
                omit_pooling=True,
                output_type="clipwise_pred_long",
            ),
            "exportable": True,
            "fixed_amplitude_to_db": True,
            "amin": 1e-6,
        },
        "exp_name": EXP_NAME,
        "fold": None,
        # Used as the checkpoint file name when "swa_checkpoint_regex" is None.
        "chkp_name": "last.ckpt",
        # Set to a key/value regex to average several checkpoints instead, e.g.
        # r'(?P<key>\w+)=(?P<value>[\d.]+)(?=\.ckpt|$)'
        "swa_checkpoint_regex": None,
        "use_sigmoid": True,
        # Sort key for parsed checkpoint names: highest valid_roc_auc first.
        "swa_sort_rule": lambda ckpt_meta: -float(ckpt_meta["valid_roc_auc"]),
        "delete_prefix": "model.",
        "n_swa_models": 3,
        "model_output_key": None,
        # "prune_checkpoint_func": prune_checkpoint_rule
    },
]

In [None]:
def _load_single_state_dict(chkp_path, delete_prefix=None):
    """Load the ``"state_dict"`` entry of one checkpoint file onto the CPU.

    If ``delete_prefix`` is given, the loaded state dict is passed through
    ``delete_prefix_from_chkp`` before being returned.
    """
    print("vanilla model")
    print("Loading", chkp_path)
    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    state_dict = torch.load(
        chkp_path,
        map_location="cpu"
    )["state_dict"]
    if delete_prefix is not None:
        state_dict = delete_prefix_from_chkp(state_dict, delete_prefix)
    return state_dict


def create_model_and_upload_chkp(
    model_class,
    model_config,
    model_device,
    model_chkp_root,
    model_chkp_basename=None,
    model_chkp_regex=None,
    delete_prefix=None,
    swa_sort_rule=None,
    n_swa_to_take=3,
    prune_checkpoint_func=None
):
    """Build a model and load weights from one checkpoint or an average of several.

    Two modes, selected by ``model_chkp_basename``:

    * given: the file ``model_chkp_root/model_chkp_basename`` is loaded.
    * ``None``: every file name in ``model_chkp_root`` is parsed with
      ``re.findall(model_chkp_regex, name)`` into a ``{key: value}`` dict
      (plus ``"name"``), the parsed entries are sorted with ``swa_sort_rule``
      and the first ``n_swa_to_take`` are kept. More than one kept checkpoint
      is combined with ``avarage_weights``; a single one is loaded directly.

    Args:
        model_class: callable invoked as ``model_class(**model_config, device=model_device)``.
        model_config: keyword arguments for ``model_class``.
        model_device: value forwarded as the ``device`` keyword.
        model_chkp_root: directory containing the checkpoint files.
        model_chkp_basename: checkpoint file name, or ``None`` to select by regex.
        model_chkp_regex: pattern whose matches are ``(key, value)`` pairs;
            required when ``model_chkp_basename`` is ``None``.
        delete_prefix: optional key prefix to strip from the state dict
            (handled by ``delete_prefix_from_chkp`` / ``avarage_weights``).
        swa_sort_rule: ``key`` function for sorting the parsed checkpoint dicts.
        n_swa_to_take: how many of the sorted checkpoints to keep.
        prune_checkpoint_func: optional callable applied to the final state
            dict before it is loaded into the model.

    Returns:
        The model in eval mode, loaded with ``strict=False``; missing and
        extra keys are printed rather than raised.

    Raises:
        ValueError: if neither ``model_chkp_basename`` nor ``model_chkp_regex``
            is given, or if ``n_swa_to_take`` leaves no checkpoint selected.
        FileNotFoundError: if no file in ``model_chkp_root`` matches the regex.
    """
    if model_chkp_basename is None:
        if model_chkp_regex is None:
            raise ValueError(
                "Either model_chkp_basename or model_chkp_regex must be provided"
            )
        checkpoints = []
        # Sorted listing keeps tie-breaking in the stable sort below deterministic.
        for el in sorted(os.listdir(model_chkp_root)):
            matches = re.findall(model_chkp_regex, el)
            if not matches:
                continue
            parsed_dict = dict(matches)
            parsed_dict["name"] = el
            checkpoints.append(parsed_dict)
        if not checkpoints:
            raise FileNotFoundError(
                f"No checkpoint in {model_chkp_root!r} matches {model_chkp_regex!r}"
            )
        print("SWA checkpoints")
        pprint(checkpoints)
        checkpoints = sorted(checkpoints, key=swa_sort_rule)
        checkpoints = checkpoints[:n_swa_to_take]
        print("SWA sorted checkpoints")
        pprint(checkpoints)
        if not checkpoints:
            raise ValueError(
                f"n_swa_to_take={n_swa_to_take!r} leaves no checkpoint to load"
            )
        if len(checkpoints) > 1:
            state_dicts = [
                torch.load(os.path.join(model_chkp_root, el["name"]), map_location="cpu")["state_dict"]
                for el in checkpoints
            ]
            t_chkp = avarage_weights(
                nn_weights=state_dicts,
                delete_prefix=delete_prefix
            )
        else:
            t_chkp = _load_single_state_dict(
                os.path.join(model_chkp_root, checkpoints[0]["name"]),
                delete_prefix=delete_prefix,
            )
    else:
        t_chkp = _load_single_state_dict(
            os.path.join(model_chkp_root, model_chkp_basename),
            delete_prefix=delete_prefix,
        )

    if prune_checkpoint_func is not None:
        t_chkp = prune_checkpoint_func(t_chkp)
    t_model = model_class(**model_config, device=model_device)
    # Report key mismatches explicitly because the non-strict load below ignores them.
    model_keys = set(t_model.state_dict().keys())
    chkp_keys = set(t_chkp)
    print("Missing keys: ", model_keys - chkp_keys)
    print("Extra keys: ", chkp_keys - model_keys)
    t_model.load_state_dict(t_chkp, strict=False)
    t_model.eval()
    return t_model

In [None]:
# Build one torch model per MODELS entry. When an entry defines
# "swa_checkpoint_regex" no base name is passed, so the loader selects
# checkpoints by regex; otherwise "chkp_name" is loaded directly.
model = [
    create_model_and_upload_chkp(
        model_class=MODEL_CLASS,
        model_config=spec["model_config"],
        model_device="cpu",
        model_chkp_root=f"../logdirs/{spec['exp_name']}/checkpoints",
        model_chkp_basename=(
            None if spec["swa_checkpoint_regex"] is not None else spec["chkp_name"]
        ),
        model_chkp_regex=spec.get("swa_checkpoint_regex"),
        delete_prefix=spec.get("delete_prefix"),
        swa_sort_rule=spec.get("swa_sort_rule"),
        n_swa_to_take=spec.get("n_swa_models", 3),
        prune_checkpoint_func=spec.get("prune_checkpoint_func"),
    )
    for spec in MODELS
]

In [None]:

# model = [create_model_and_upload_chkp(
#     model_class=MODEL_CLASS,
#     model_config=MODELS[0]['model_config'],
#     model_device="cpu",
#     model_chkp_root=f"../logdirs/{MODELS[0]['exp_name']}/fold_{m_i}/checkpoints",
#     # model_chkp_root=f"../logdirs/{CONFIG['exp_name']}/checkpoints",
#     model_chkp_basename=MODELS[0]["chkp_name"] if MODELS[0]["swa_checkpoint_regex"] is None else None,
#     model_chkp_regex=MODELS[0].get("swa_checkpoint_regex"),
#     swa_sort_rule=MODELS[0].get("swa_sort_rule"),
#     n_swa_to_take=MODELS[0].get("n_swa_models", 3),
#     delete_prefix=MODELS[0].get("delete_prefix"),
#     prune_checkpoint_func=MODELS[0].get("prune_checkpoint_func")
# ) for m_i in range(5)]


In [None]:
# Wrap fresh copies of the model configs in an ONNXEnsemble, copy the trained
# weights into it member by member, and export it.
exportable_ensem = ONNXEnsemble(
    model_class=MODEL_CLASS,
    configs=[deepcopy(entry["model_config"]) for entry in MODELS],
    # final_activation="softmax"
    # avarage_type="gaus"
    # weights=[0.75,0.4,0.5]
)
# exportable_ensem = ONNXEnsemble(
#     model_class=MODEL_CLASS,
#     configs=[deepcopy(MODELS[0]['model_config']) for _ in range(5)],
#     # final_activation="softmax"
#     # avarage_type="gaus"
#     # weights=[0.75,0.4,0.5]
# )
assert len(exportable_ensem.models) == len(model)
for model_id, trained_model in enumerate(model):
    exportable_ensem.models[model_id].load_state_dict(trained_model.state_dict())
exportable_ensem.eval()

# The sample input is a batch of 5 random waveforms of TRAIN_PERIOD * 32_000 samples each.
# NOTE(review): later cells read from ".../onnx_ensem_openvino_fp16" and
# ".../onnx_ensem_fp16", which do not obviously derive from this base_path —
# confirm how convert_to_onnx names its output directories.
convert_to_onnx(
    model_to_convert=exportable_ensem,
    sample_input=torch.randn(5, TRAIN_PERIOD * 32_000),
    base_path=f"../logdirs/{EXP_NAME}/onnx_ensem_logits",
    use_fp16=True,
    use_openvino=True,
    opset_version=14
    # base_path="test_onnx"
)

In [None]:
# !rm ../logdirs/convnext_small_fb_in22k_ft_in1k_384__convnextv2_tiny_fcmae_ft_in22k_in1k_384__eca_nfnet_l0_noval_v27_075Clipwise025TimeMax_GausMean -rf

In [None]:
# !rm ../logdirs/convnext_small_fb_in22k_ft_in1k_384__convnextv2_tiny_fcmae_ft_in22k_in1k_384__eca_nfnet_l0_noval_v27_075Clipwise025TimeMax_GausMean -rf

In [None]:
# Long listing of ../logdirs sorted by modification time (newest first), first 10 lines.
!ls -lt ../logdirs/ | head

# Check

In [None]:
# List the contents of the experiment's log directory.
os.listdir(f"../logdirs/{EXP_NAME}/")

In [None]:
# Read the exported OpenVINO model and compile it for CPU inference.
# The network gets its own name (`ov_model`) so that `model`, the list of torch
# models built earlier, is not overwritten and the export cell stays re-runnable.
core = ov.Core()
ov_model = core.read_model(model=f"../logdirs/{EXP_NAME}/onnx_ensem_openvino_fp16/model_simpl.xml")
compiled_model = core.compile_model(model=ov_model, device_name="CPU")

In [None]:
# Random float16 batch: 5 rows of 32000*5 values drawn from a standard normal.
sample = np.random.standard_normal((5, 32000 * 5)).astype(np.float16)

In [None]:
# Inspect the first output port of the compiled OpenVINO model.
compiled_model.output(0)

In [None]:
%%time
# Time one forward pass of the compiled OpenVINO model on `sample` and keep
# the result of its first output port.
output = compiled_model([sample])[compiled_model.output(0)]

In [None]:
# Shape of the OpenVINO output.
output.shape

In [None]:
# dtype of the OpenVINO output.
output.dtype

In [None]:
# Cast the batch to float16 before feeding ONNX Runtime.
# NOTE(review): `sample` is already created as float16 above, so this only makes a copy.
sample = sample.astype(np.float16)

In [None]:
# Open the exported ONNX file with ONNX Runtime.
# NOTE(review): the export cell uses base_path ".../onnx_ensem_logits"; confirm
# that convert_to_onnx actually writes this "onnx_ensem_fp16" directory.
onnx_model = ort.InferenceSession(
    f"../logdirs/{EXP_NAME}/onnx_ensem_fp16/model_simpl.onnx"
)

In [None]:
# Single random waveform of 32000*5 values as a float16 torch tensor.
wave = torch.from_numpy(np.random.randn(32000 * 5)).to(torch.float16)

In [None]:
# Display the waveform as a NumPy array.
wave.numpy()

In [None]:
%%time
# Time one ONNX Runtime pass over `sample`; passing None requests all outputs.
# NOTE(review): assumes the exported graph's input tensor is named "input" — confirm.
out = onnx_model.run(
    None,
    {"input": sample}
)

In [None]:
# Shape of the first ONNX Runtime output.
out[0].shape

In [None]:
# Optional end-to-end check: run the ONNX model over a test loader and build a
# submission dataframe from the predictions.
# NOTE(review): `CONFIG`, `loader_test` and `bird2id` are not defined in any cell
# visible in this notebook, so this cell raises NameError on a fresh kernel
# unless they are provided elsewhere. It also reads a different export
# directory ("onnx_ensem_2first_folds") than the cells above and rebinds
# `onnx_model`.
if CONFIG["run_test"]:
    inference_class_onnx = BirdsInference(
        device="cpu",
        verbose_tqdm=True,
        use_sigmoid=CONFIG["use_sigmoid"],
        model_output_key=CONFIG["model_output_key"],
    )
    onnx_model = ort.InferenceSession(
        f"../logdirs/{CONFIG['exp_name']}/onnx_ensem_2first_folds/model_simpl.onnx"
    )
    # predict_test_loader returns four values; the second one
    # (test_preds_long_onnx) is not used further in this cell.
    test_preds_onnx, test_preds_long_onnx, test_dfidx_onnx, test_end_onnx = inference_class_onnx.predict_test_loader(
        nn_models=onnx_model,
        data_loader=loader_test,
        is_onnx_model=True
    )
    # File names are taken from the loader's dataset dataframe, using the
    # column named by the dataset's `name_col` attribute.
    test_pred_df_onnx = compose_submission_dataframe(
        probs=test_preds_onnx,
        dfidxs=test_dfidx_onnx,
        end_seconds=test_end_onnx,
        filenames=loader_test.dataset.df[loader_test.dataset.name_col].copy(),
        bird2id=bird2id
    )