In [None]:
# Show GPU status; the next cell restricts this process to GPU index 2 via CUDA_VISIBLE_DEVICES.
!nvidia-smi

In [None]:
%load_ext autoreload
%autoreload 2

import os
os.environ["CUDA_VISIBLE_DEVICES"]="2"

import json
import os
import re
from copy import deepcopy
from glob import glob
from itertools import chain
from os.path import join as pjoin
from pprint import pprint

import pandas as pd
import numpy as np
import librosa
import seaborn as sns
import IPython.display as ipd
import soundfile as sf
import torch
import h5py
import onnxruntime as ort
import openvino as ov

from tqdm import tqdm
from matplotlib import pyplot as plt
from torchaudio.transforms import AmplitudeToDB, MelSpectrogram

# from code_base.utils import parallel_librosa_load, groupby_np_array, stack_and_max_by_samples, macro_f1_similarity, N_CLASSES_2021_2022, N_CLASSES_2021, comp_metric, N_CLASSES_XC_LIGIT_SHORTEN, N_CLASSES_XC_LIGIT_EVEN_SHORTEN
# from code_base.utils.constants import SAMPLE_RATE
from code_base.utils.onnx_utils import ONNXEnsemble, convert_to_onnx
from code_base.models import WaveCNNClasifier, WaveCNNAttenClasifier
from code_base.datasets import WaveDataset, WaveAllFileDataset
from code_base.utils.swa import avarage_weights, delete_prefix_from_chkp
from code_base.inefernce import BirdsInference
from code_base.utils import load_json, compose_submission_dataframe, groupby_np_array, stack_and_max_by_samples, write_json
from code_base.utils.metrics import score_numpy
%matplotlib inline


# Export Models

In [None]:
# List the 20 most recently modified entries under ../logdirs/ (the directories referenced by `exp_name` below).
!ls -lt ../logdirs/ | head -20

In [None]:
# Source label map (the file name suggests the previous-competition class order)
# and the 2024 target label map: bird code -> class index.
bird2id_source = load_json("/home/vova/data/exps/birdclef_2024/class_mappings/bird2int_2024_PrevComp.json")
bird2id_target = load_json("/home/vova/data/exps/birdclef_2024/class_mappings/bird2int_2024.json")

# Inverse maps: class index -> bird code.
id2bird_source = {idx: bird for bird, idx in bird2id_source.items()}
id2bird_target = {idx: bird for bird, idx in bird2id_target.items()}

# REARRANGE_INDICES[i] is the source-map index of the bird sitting at target index i,
# so indexing a source-ordered class axis with it yields the target order.
# NOTE(review): assumes target indices are exactly 0..N-1 and that every target bird
# exists in the source map; otherwise this raises KeyError.
_target_in_source_order = [
    bird2id_source[id2bird_target[target_idx]]
    for target_idx in range(len(id2bird_target))
]
REARRANGE_INDICES = np.array(_target_in_source_order).astype(int)
del _target_in_source_order

MODEL_CLASS = WaveCNNAttenClasifier
TRAIN_PERIOD = 5  # used as both train_period and infer_period in each head_config
STRICT_LOAD = False  # NOTE(review): not read by any cell visible here; loading is non-strict

def prune_checkpoint_rule(inp_chkp, indices=None):
    """Reorder the leading (class) dimension of the head's attention/fix_scale parameters.

    Each of ``head.attention.{weight,bias}`` and ``head.fix_scale.{weight,bias}`` is
    replaced by ``tensor[indices]``, i.e. rows are selected/reordered along dim 0.
    This lets a head trained under one label order be loaded under another.

    Args:
        inp_chkp: state dict mapping parameter name -> tensor. Modified in place.
        indices: integer index array along dim 0. ``None`` (the default) uses the
            module-level ``REARRANGE_INDICES``.

    Returns:
        The same ``inp_chkp`` object, with the four head entries replaced.

    Raises:
        KeyError: if any of the four head keys is missing from ``inp_chkp``.
    """
    if indices is None:
        indices = REARRANGE_INDICES
    # Order matches the original mutation order.
    for key in (
        "head.attention.weight",
        "head.attention.bias",
        "head.fix_scale.weight",
        "head.fix_scale.bias",
    ):
        inp_chkp[key] = inp_chkp[key][indices]
    return inp_chkp

In [None]:
# EXP_NAME = "eca_nfnet_l0_Exp_noamp_64bs_5sec_PrevCompXCScoredDataNoSecLab_BackGroundSoundScapeP05_mixupP05_RandomFiltering_balancedSampler_Adamlr1e3_TailCosBatchLR1e6_Epoch40_SpecAugV1_FocalLoss_DPR02_Full_NoDuplsV1"
# TRAIN_PERIOD = 5
# print("Possible checkpoints:\n\n{}".format("\n".join(set([os.path.basename(el) for el in glob(f"../logdirs/{EXP_NAME}/checkpoints/*.ckpt") if "train" not in os.path.basename(el)]))))

In [None]:
# conf_path = glob(f"../logdirs/{EXP_NAME}/code/*train_configs*.py")
# assert len(conf_path) == 1
# conf_path = conf_path[0]
# !cat {conf_path}

In [None]:
# Ensemble definition. All members share the mel front-end and head settings; only
# the backbone, the experiment directory and a few optional settings differ.
# `_make_model_entry` deep-copies shared sub-configs so a model constructor that
# mutates its config cannot leak changes into other members.

# Mel-spectrogram parameters shared by all members (passed as `mel_spec_paramms`).
_MEL_SPEC_PARAMS = {
    "sample_rate": 32000,
    "n_mels": 128,
    "f_min": 20,
    "n_fft": 2048,
    "hop_length": 512,
    "normalized": True,
}

# Frequency/time masking config of the experiments tagged `SpecAugV1` in their names.
_SPEC_AUGMENT_V1 = {
    "freq_mask": {
        "mask_max_length": 10,
        "mask_max_masks": 3,
        "p": 0.3,
        "inplace": True,
    },
    "time_mask": {
        "mask_max_length": 20,
        "mask_max_masks": 3,
        "p": 0.3,
        "inplace": True,
    },
}


def _make_model_entry(backbone, exp_name, spec_augment_config=None, add_backbone_config=None):
    """Return one MODELS entry: the model constructor kwargs plus checkpoint-loading options.

    Args:
        backbone: backbone name, forwarded to the model class as ``backbone``.
        exp_name: directory name under ``../logdirs`` that holds the checkpoints.
        spec_augment_config: optional ``spec_augment_config``; the key is omitted when None.
        add_backbone_config: optional ``add_backbone_config``; the key is omitted when None.
    """
    model_config = {"backbone": backbone}
    if add_backbone_config is not None:
        model_config["add_backbone_config"] = deepcopy(add_backbone_config)
    model_config["mel_spec_paramms"] = deepcopy(_MEL_SPEC_PARAMS)
    if spec_augment_config is not None:
        model_config["spec_augment_config"] = deepcopy(spec_augment_config)
    model_config["head_config"] = {
        "p": 0.5,
        # NOTE(review): must equal the class count the checkpoints were trained with.
        "num_class": 188,
        "train_period": TRAIN_PERIOD,
        "infer_period": TRAIN_PERIOD,
        "output_type": "clipwise_pred_long",
        # "output_type": "clipwise_logits_long"
    }
    model_config["exportable"] = True
    return {
        "model_config": model_config,
        "exp_name": exp_name,
        "fold": None,
        "chkp_name": "last.ckpt",
        # Set e.g. r'(?P<key>\w+)=(?P<value>[\d.]+)(?=\.ckpt|$)' to average several checkpoints.
        "swa_checkpoint_regex": None,
        # NOTE(review): `fold`, `use_sigmoid` and `model_output_key` are not read by the
        # model-loading cell; inference takes these settings from INFERENCE_CONFIG.
        "use_sigmoid": True,
        # Only used when swa_checkpoint_regex is set. NOTE(review): this metric key looks
        # carried over from another project; confirm it appears in checkpoint file names.
        "swa_sort_rule": lambda x: -float(x["valid_tresh01_dice3d"]),
        "delete_prefix": "model.",
        "n_swa_models": 3,
        "model_output_key": None,
        # "prune_checkpoint_func": prune_checkpoint_rule
    }


MODELS = [
    _make_model_entry(
        backbone="convnextv2_tiny.fcmae_ft_in22k_in1k_384",
        exp_name="convnextv2_tiny_fcmae_ft_in22k_in1k_384_Exp_noamp_64bs_5sec_PrevCompXCScoredDataNoSecLab_BackGroundSoundScapeP05_mixupP05_RandomFiltering_balancedSampler_Adamlr1e4_TailCosBatchLR1e6_Epoch30_FocalLoss_Full_NoDuplsV1",
    ),
    _make_model_entry(
        # The backbone must match the architecture encoded in `exp_name` (convnext_small).
        # With non-strict loading a mismatched backbone loads only partially without error.
        # NOTE(review): confirm against the experiment's saved train config.
        backbone="convnext_small.fb_in22k_ft_in1k_384",
        exp_name="convnext_small_fb_in22k_ft_in1k_384_Exp_noamp_64bs_5sec_PrevCompXCScoredDataNoSecLab_BackGroundSoundScapeP05_mixupP05_RandomFiltering_balancedSampler_Adamlr1e4_TailCosBatchLR1e6_Epoch40_SpecAugV1_FocalLoss_Full_NoDuplsV1",
        spec_augment_config=_SPEC_AUGMENT_V1,
    ),
    _make_model_entry(
        backbone="eca_nfnet_l0",
        exp_name="eca_nfnet_l0_Exp_noamp_64bs_5sec_PrevCompXCScoredDataNoSecLab_BackGroundSoundScapeP05_mixupP05_RandomFiltering_balancedSampler_Adamlr1e3_TailCosBatchLR1e6_Epoch40_SpecAugV1_FocalLoss_DPR02_Full_NoDuplsV1",
        spec_augment_config=_SPEC_AUGMENT_V1,
        add_backbone_config={"drop_path_rate": 0.2},
    ),
]

INFERENCE_CONFIG = {
    # Inference class (BirdsInference)
    "use_sigmoid": False,
    # Data
    "test_data_root": "/home/vova/data/exps/birdclef_2024/birdclef_2024/unlabeled_soundscapes/*.ogg",
    "label_map_data_path": "/home/vova/data/exps/birdclef_2024/class_mappings/bird2int_2024.json",
    "scored_birds_path": "/home/vova/data/exps/birdclef_2024/scored_birds/sb_2024.json",
    "lookback": None,
    "lookahead": None,
    "segment_len": 5,
    "step": None,
    "late_normalize": True,

    "model_output_key": None,
}

In [None]:
def _load_state_dict(chkp_path, delete_prefix=None):
    """Load ``["state_dict"]`` from a checkpoint file onto CPU, optionally stripping a key prefix."""
    print("Loading", chkp_path)
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(chkp_path, map_location="cpu")["state_dict"]
    if delete_prefix is not None:
        state_dict = delete_prefix_from_chkp(state_dict, delete_prefix)
    return state_dict


def _find_swa_checkpoints(model_chkp_root, model_chkp_regex):
    """Return one dict per file in ``model_chkp_root`` whose name matches ``model_chkp_regex``.

    Each regex match must yield a ``(key, value)`` pair; the pairs become the dict's
    items, and ``"name"`` holds the file's basename.
    """
    checkpoints = []
    for basename in os.listdir(model_chkp_root):
        matches = re.findall(model_chkp_regex, basename)
        if not matches:
            continue
        parsed = {key: value for key, value in matches}
        parsed["name"] = basename
        checkpoints.append(parsed)
    return checkpoints


def create_model_and_upload_chkp(
    model_class,
    model_config,
    model_device,
    model_chkp_root,
    model_chkp_basename=None,
    model_chkp_regex=None,
    delete_prefix=None,
    swa_sort_rule=None,
    n_swa_to_take=3,
    prune_checkpoint_func=None,
    strict=False,
):
    """Instantiate ``model_class`` and load checkpoint weights into it.

    Loading modes:

    * ``model_chkp_basename`` given: load ``<model_chkp_root>/<model_chkp_basename>``.
    * otherwise: collect files in ``model_chkp_root`` matching ``model_chkp_regex``,
      sort them with ``swa_sort_rule`` and keep the first ``n_swa_to_take``. Several
      are averaged with ``avarage_weights``; a single one is loaded directly.

    Args:
        model_class: model constructor, called as ``model_class(**model_config, device=model_device)``.
        model_config: keyword arguments for ``model_class``.
        model_device: forwarded to ``model_class`` as ``device``.
        model_chkp_root: directory containing the checkpoint files.
        model_chkp_basename: explicit checkpoint file name; disables the regex mode.
        model_chkp_regex: file-name regex whose matches yield ``(key, value)`` pairs.
        delete_prefix: prefix stripped from state-dict keys, if not None.
        swa_sort_rule: sort key over the parsed checkpoint dicts (regex mode only).
        n_swa_to_take: number of sorted checkpoints to keep (regex mode only).
        prune_checkpoint_func: optional ``state_dict -> state_dict`` transform applied before loading.
        strict: forwarded to ``load_state_dict``; defaults to the previous lenient behavior.

    Returns:
        The model with weights loaded, switched to eval mode.

    Raises:
        ValueError: neither ``model_chkp_basename`` nor ``model_chkp_regex`` is given.
        FileNotFoundError: regex mode found no matching checkpoint.
    """
    if model_chkp_basename is not None:
        print("vanilla model")
        t_chkp = _load_state_dict(os.path.join(model_chkp_root, model_chkp_basename), delete_prefix)
    else:
        if model_chkp_regex is None:
            raise ValueError("Either model_chkp_basename or model_chkp_regex must be provided")
        checkpoints = _find_swa_checkpoints(model_chkp_root, model_chkp_regex)
        print("SWA checkpoints")
        pprint(checkpoints)
        checkpoints = sorted(checkpoints, key=swa_sort_rule)[:n_swa_to_take]
        print("SWA sorted checkpoints")
        pprint(checkpoints)
        if not checkpoints:
            raise FileNotFoundError(
                f"No checkpoint in {model_chkp_root!r} matches regex {model_chkp_regex!r}"
            )
        if len(checkpoints) > 1:
            state_dicts = [
                torch.load(os.path.join(model_chkp_root, el["name"]), map_location="cpu")["state_dict"]
                for el in checkpoints
            ]
            t_chkp = avarage_weights(
                nn_weights=state_dicts,
                delete_prefix=delete_prefix,
            )
        else:
            print("vanilla model")
            t_chkp = _load_state_dict(os.path.join(model_chkp_root, checkpoints[0]["name"]), delete_prefix)

    if prune_checkpoint_func is not None:
        t_chkp = prune_checkpoint_func(t_chkp)
    t_model = model_class(**model_config, device=model_device)
    model_keys = set(t_model.state_dict().keys())
    print("Missing keys: ", model_keys - set(t_chkp))
    print("Extra keys: ", set(t_chkp) - model_keys)
    t_model.load_state_dict(t_chkp, strict=strict)
    t_model.eval()
    return t_model

In [None]:
# Build every ensemble member from its experiment's checkpoint directory.
# The explicit checkpoint name is used unless an SWA regex is configured.
model = [
    create_model_and_upload_chkp(
        model_class=MODEL_CLASS,
        model_config=cfg["model_config"],
        model_device="cuda",
        model_chkp_root=f"../logdirs/{cfg['exp_name']}/checkpoints",
        model_chkp_basename=None if cfg["swa_checkpoint_regex"] is not None else cfg["chkp_name"],
        model_chkp_regex=cfg.get("swa_checkpoint_regex"),
        swa_sort_rule=cfg.get("swa_sort_rule"),
        n_swa_to_take=cfg.get("n_swa_models", 3),
        delete_prefix=cfg.get("delete_prefix"),
        prune_checkpoint_func=cfg.get("prune_checkpoint_func"),
    )
    for cfg in MODELS
]

In [None]:
# Sanity check: one loaded model per MODELS entry.
len(model)

# Prepare Data

In [None]:
bird2id = load_json(INFERENCE_CONFIG["label_map_data_path"])

# Sorted so row order (and every CSV written below) is reproducible: glob's order
# depends on the file system.
test_au_pathes = sorted(glob(INFERENCE_CONFIG["test_data_root"]))
assert test_au_pathes, f"No audio files match {INFERENCE_CONFIG['test_data_root']!r}"

test_df = pd.DataFrame({
    "filename": test_au_pathes,
    # librosa >= 0.10 renamed `filename=` to `path=`; the old keyword is deprecated.
    "duration_s": [librosa.get_duration(path=el) for el in tqdm(test_au_pathes)],
})

In [None]:
# Dataset options for whole-file inference; windowing settings come from INFERENCE_CONFIG.
ds_config_test = dict(
    root="",
    label_str2int_mapping_path=INFERENCE_CONFIG["label_map_data_path"],
    n_cores=64,
    use_audio_cache=True,
    test_mode=True,
    segment_len=INFERENCE_CONFIG["segment_len"],
    lookback=INFERENCE_CONFIG["lookback"],
    lookahead=INFERENCE_CONFIG["lookahead"],
    sample_id=None,
    late_normalize=INFERENCE_CONFIG["late_normalize"],
    step=INFERENCE_CONFIG["step"],
    validate_sr=32_000,
    verbose=False,
)
# Keep every sample, preserve order, load batches in the main process (num_workers=0).
loader_config = dict(batch_size=64, drop_last=False, shuffle=False, num_workers=0)

In [None]:
ds_test = WaveAllFileDataset(df=test_df, **ds_config_test)
# Ordered, non-shuffled iteration as configured in loader_config.
loader_test = torch.utils.data.DataLoader(dataset=ds_test, **loader_config)

# Inference Class

In [None]:
# Inference wrapper; sigmoid usage and output-key selection come from INFERENCE_CONFIG.
inference_class = BirdsInference(
    model_output_key=INFERENCE_CONFIG["model_output_key"],
    use_sigmoid=INFERENCE_CONFIG["use_sigmoid"],
    verbose_tqdm=True,
    device="cuda",
)

# Prediction

In [None]:
# Run the whole ensemble over the loader, then assemble the predictions into a DataFrame.
test_preds, test_dfidx, test_end = inference_class.predict_test_loader(
    data_loader=loader_test,
    nn_models=model,
)
test_pred_df = compose_submission_dataframe(
    bird2id=bird2id,
    filenames=loader_test.dataset.df[loader_test.dataset.name_col].copy(),
    end_seconds=test_end,
    dfidxs=test_dfidx,
    probs=test_preds,
)

In [None]:
# Distribution of each row's highest class score; a quick check on how confident
# the ensemble is before choosing a pseudo-label threshold.
fig, ax = plt.subplots()
ax.hist(test_preds.max(axis=1), bins=30)
ax.set(
    title="Most 'Probable' class probability distribution",
    xlabel="Highest class probability per prediction row",
    ylabel="Number of rows",
)
plt.show()

print(
    "Max Prob: ", test_preds.max(),
    "Min Prob: ", test_preds.min(),
    "Median Prob: ", np.median(test_preds),
)

In [None]:
# Inspect the raw prediction frame.
test_pred_df

In [None]:
# Persist the raw predictions. Create the output directory first so a fresh
# checkout does not fail with "No such file or directory".
os.makedirs("../pseudo/v1", exist_ok=True)
test_pred_df.to_csv(
    "../pseudo/v1/prod_df.csv", index=False
)

# Prepare Ready2Use DF

In [None]:
# Class-probability columns: everything after the first column, minus the columns this
# notebook adds to test_pred_df further down. Excluding those keeps this cell correct
# when it is re-run after them.
_DERIVED_COLUMNS = {"all_labels", "primary_label", "secondary_labels", "filename"}
CLASSES = [col for col in test_pred_df.columns[1:] if col not in _DERIVED_COLUMNS]
assert CLASSES, "test_pred_df has no class columns"

In [None]:
def get_aranged_classes(probs, ordered_classes, tresh):
    """Return the class names whose probability exceeds ``tresh``, most probable first.

    Args:
        probs: 1-D array-like of per-class scores (a pandas Series, numpy array or list),
            positionally aligned with ``ordered_classes``.
        ordered_classes: sequence of class names; ``ordered_classes[i]`` names ``probs[i]``.
        tresh: strict lower bound; only scores ``> tresh`` are accepted.

    Returns:
        A list of class names sorted by descending score. Ties keep their order in
        ``ordered_classes`` because the sort is stable. Returns ``[]`` when nothing passes.
    """
    values = np.asarray(probs)
    accepted_idx = np.nonzero(values > tresh)[0]
    ranked_idx = sorted(accepted_idx, key=lambda idx: -values[idx])
    return [ordered_classes[idx] for idx in ranked_idx]

## V1.1: pseudo-labels from classes with probability > 0.7

In [None]:
# Per row: class names scoring above 0.7, most confident first.
test_pred_df["all_labels"] = test_pred_df[CLASSES].apply(
    get_aranged_classes,
    axis=1,
    ordered_classes=CLASSES,
    tresh=0.7,
)

In [None]:
# First accepted class is the primary label (None when nothing passed the threshold);
# the remaining accepted classes become the secondary labels.
test_pred_df["primary_label"] = test_pred_df["all_labels"].map(
    lambda labels: None if len(labels) == 0 else labels[0]
)
test_pred_df["secondary_labels"] = test_pred_df["all_labels"].map(lambda labels: labels[1:])

In [None]:
# Audio file name = text of row_id before its first "_", plus ".ogg".
test_pred_df["filename"] = test_pred_df["row_id"].map(
    lambda row_id: f"{row_id.partition('_')[0]}.ogg"
)

In [None]:
# Inspect the frame after the label and filename columns were added.
test_pred_df

In [None]:
# Keep only rows that received a primary label, renumbered from 0.
positive_test_pred_df = (
    test_pred_df
    .dropna(subset=["primary_label"])
    .reset_index(drop=True)
)

In [None]:
# Keep only the pseudo-label columns. `.copy()` yields an independent frame, so the
# column assignments in the following cells don't hit pandas' SettingWithCopyWarning.
positive_test_pred_df = positive_test_pred_df[[
    "row_id", "all_labels", "primary_label", "secondary_labels", "filename"
]].copy()

In [None]:
# Stratification key: primary label tagged with the data source.
positive_test_pred_df["stratify_col"] = positive_test_pred_df["primary_label"].add("_soundscapes")

In [None]:
# Inspect the filtered pseudo-label frame.
positive_test_pred_df

In [None]:
# Pseudo-labels at threshold 0.7 with the "<label>_soundscapes" stratification key.
positive_test_pred_df.to_csv(path_or_buf="../pseudo/v1/positive_filtered_07.csv", index=False)

In [None]:
# Coarser stratification key: keep only the text before the first "_".
positive_test_pred_df["stratify_col"] = positive_test_pred_df["stratify_col"].map(
    lambda key: key.partition("_")[0]
)

In [None]:
# Same pseudo-labels, saved with the coarser (V2) stratification key.
positive_test_pred_df.to_csv(path_or_buf="../pseudo/v1/positive_filtered_07_stratifyV2.csv", index=False)