In [None]:
import os

import pandas as pd
import numpy as np
import seaborn as sns
import random

from matplotlib import pyplot as plt
from scipy.special import expit
from tqdm import tqdm

from code_base.utils.metrics import score_numpy

In [None]:
# Ordered class-label vocabulary: numeric string IDs first, then alphabetic codes.
# The ORDER is significant: convert_fernando_to_vova() maps a prediction column
# whose name ends in the integer k to BIRDS[k], so do not sort or reorder this list.
# NOTE(review): presumably the BirdCLEF 2025 label set in the order used at training time — confirm.
BIRDS = [
    '1139490', '1192948', '1194042', '126247', '1346504', '134933', '135045', '1462711', '1462737', '1564122', '21038', '21116',
    '21211', '22333', '22973', '22976', '24272', '24292', '24322', '41663', '41778', '41970', '42007', '42087', '42113', '46010',
    '47067', '476537', '476538', '48124', '50186', '517119', '523060', '528041', '52884', '548639', '555086', '555142', '566513',
    '64862', '65336', '65344', '65349', '65373', '65419', '65448', '65547', '65962', '66016', '66531', '66578', '66893', '67082',
    '67252', '714022', '715170', '787625', '81930', '868458', '963335', 
    'amakin1', 'amekes', 'ampkin1', 'anhing', 'babwar', 'bafibi1', 'banana', 'baymac', 'bbwduc', 'bicwre1', 'bkcdon', 'bkmtou1', 
    'blbgra1', 'blbwre1', 'blcant4', 'blchaw1', 'blcjay1', 'blctit1', 'blhpar1', 'blkvul', 'bobfly1', 'bobher1', 'brtpar1', 'bubcur1',
    'bubwre1', 'bucmot3', 'bugtan', 'butsal1', 'cargra1', 'cattyr', 'chbant1', 'chfmac1', 'cinbec1', 'cocher1', 'cocwoo1', 'colara1',
    'colcha1', 'compau', 'compot1', 'cotfly1', 'crbtan1', 'crcwoo1', 'crebob1', 'cregua1', 'creoro1', 'eardov1', 'fotfly', 'gohman1',
    'grasal4', 'grbhaw1', 'greani1', 'greegr', 'greibi1', 'grekis', 'grepot1', 'gretin1', 'grnkin', 'grysee1', 'gybmar', 'gycwor1', 
    'labter1', 'laufal1', 'leagre', 'linwoo1', 'littin1', 'mastit1', 'neocor', 'norscr1', 'olipic1', 'orcpar', 'palhor2', 'paltan1',
    'pavpig2', 'piepuf1', 'pirfly1', 'piwtyr1', 'plbwoo1', 'plctan1', 'plukit1', 'purgal2', 'ragmac1', 'rebbla1', 'recwoo1', 'rinkin1',
    'roahaw', 'rosspo1', 'royfly1', 'rtlhum', 'rubsee1', 'rufmot1', 'rugdov', 'rumfly1', 'ruther1', 'rutjac1', 'rutpuf1', 'saffin',
    'sahpar1', 'savhaw1', 'secfly1', 'shghum1', 'shtfly1', 'smbani', 'snoegr', 'sobtyr1', 'socfly1', 'solsan', 'soulap1', 'spbwoo1',
    'speowl1', 'spepar1', 'srwswa1', 'stbwoo2', 'strcuc1', 'strfly1', 'strher', 'strowl1', 'tbsfin1', 'thbeup1', 'thlsch3', 'trokin',
    'tropar', 'trsowl', 'turvul', 'verfly', 'watjac1', 'wbwwre1', 'whbant1', 'whbman1', 'whfant1', 'whmtyr1', 'whtdov', 'whttro1',
    'whwswa1', 'woosto', 'y00678', 'yebela1', 'yebfly1', 'yebsee1', 'yecspi2', 'yectyr1', 'yehbla2', 'yehcar1', 'yelori1', 'yeofly1',
    'yercac1', 'ywcpar', 
]

In [None]:
def convert_fernando_to_vova(input_df, birds=None):
    """Convert a prediction frame keyed by (filename, order) into the row_id layout.

    The input must contain a ``filename`` column, an ``order`` column and one
    prediction column per class. Each prediction column name is one arbitrary
    leading character followed by the integer class index (e.g. ``"p17"``).

    Parameters
    ----------
    input_df : pd.DataFrame
        Frame to convert. It is not modified.
    birds : sequence of str, optional
        Class labels indexed by the integer suffix of the prediction columns.
        Defaults to the notebook-level ``BIRDS`` list.

    Returns
    -------
    pd.DataFrame
        New frame whose first column is ``row_id`` (``"<filename>_<(order + 1) * 5>"``)
        followed by the prediction columns renamed to their class labels, in the
        input's column order and with the input's index.

    Raises
    ------
    KeyError
        If ``filename`` or ``order`` is missing.
    ValueError, IndexError
        If a prediction column name does not end in a valid class index.
    """
    if birds is None:
        birds = BIRDS

    # Build row_id column-wise instead of with a row-wise apply: this is far
    # cheaper on wide frames, avoids the per-row dtype coercion that
    # apply(axis=1) performs, and also works when the frame is empty.
    # NOTE(review): the factor 5 is presumably the chunk length in seconds,
    # making the suffix the chunk's end time — confirm.
    row_id = (
        input_df["filename"].astype(str)
        + "_"
        + ((input_df["order"] + 1) * 5).astype(str)
    )

    # drop() returns a new frame, so the caller's frame is left untouched.
    output_df = input_df.drop(columns=["filename", "order"])
    output_df.columns = [birds[int(col[1:])] for col in output_df.columns]

    # Place row_id explicitly in front rather than relying on it being the
    # last column after an assignment.
    output_df.insert(0, "row_id", row_id)
    return output_df

# Load Data

In [None]:
# Directory holding the per-model prediction pickles ("res_<model>.pkl" and
# "oof_unlabeled_<model>_f0.pkl"). The default is the original cluster path;
# set the FERNANDO_PSEUDO_ROOT environment variable to run the notebook elsewhere.
ROOT = os.environ.get(
    "FERNANDO_PSEUDO_ROOT",
    "/gpfs/space/projects/BetterMedicine/volodymyr1/exps/bird_clef_2025/kaggle_datasets/fernando_pseudo",
)

In [None]:
# Show the entries of ROOT in sorted order to see which files are available.
sorted(os.listdir(ROOT))

In [None]:
# Model names are recovered from the result files, which are named "res_<model>.pkl".
# The exact prefix and suffix are stripped (instead of splitting on "_") so that
# model names containing underscores survive intact, and entries that merely
# start with "res" or are not pickles are ignored. The list is sorted because
# os.listdir() returns entries in arbitrary order.
all_models = sorted(
    el[len("res_"):-len(".pkl")]
    for el in os.listdir(ROOT)
    if el.startswith("res_") and el.endswith(".pkl")
)
all_models

In [None]:
# Sanity check: for every model, the stored probabilities (all columns of the
# "res_" frame after the first two) must equal the sigmoid of the fold-0 logits.
# Only mismatching models are reported; silence means every model passed.
for model in tqdm(all_models):
    probs_path = os.path.join(ROOT, f"res_{model}.pkl")
    logits_path = os.path.join(ROOT, f"oof_unlabeled_{model}_f0.pkl")

    probs = pd.read_pickle(probs_path)
    logits = pd.read_pickle(logits_path)

    stored_probs = probs.iloc[:, 2:].values
    matches = np.allclose(expit(logits), stored_probs)
    if not matches:
        print("Not close logits and probs for", model)

In [None]:
# Reference pseudo-label table; its "row_id" column and column order are the
# template every model frame is aligned to further down.
vova_pseudo = pd.read_csv("../data/pseudo/four_ecas_from_GoodPretrains_879_869_867_866/v0.csv")

In [None]:
# Display the reference pseudo-label frame loaded above.
vova_pseudo

In [None]:
# Convert every model's predictions to the row_id layout and align them with
# vova_pseudo: same columns in the same order, same rows in the same order.
df_dict = dict()
for model_name in tqdm(all_models):
    # NOTE: read_pickle can execute arbitrary code; only load trusted files.
    model_probs = pd.read_pickle(os.path.join(ROOT, "res_" + model_name + ".pkl"))

    aligned_df = convert_fernando_to_vova(model_probs)[vova_pseudo.columns]
    # .loc raises KeyError when a reference row_id is missing from the model frame.
    aligned_df = aligned_df.set_index("row_id").loc[vova_pseudo["row_id"]].reset_index()

    # Duplicate row_ids in the model frame would make .loc return extra rows
    # and silently break the row-wise correspondence used for scoring.
    if len(aligned_df) != len(vova_pseudo):
        raise ValueError(
            f"Row alignment failed for {model_name}: "
            f"{len(aligned_df)} rows after alignment, expected {len(vova_pseudo)}"
        )

    # Store only the fully aligned frame, never a half-processed one.
    df_dict[model_name] = aligned_df

# Compute metrics

In [None]:
# Show the model names for which an aligned prediction frame was built.
df_dict.keys()

In [None]:
# Inspect the aligned prediction frame of one model.
# NOTE(review): the key is hard-coded; this raises KeyError unless ROOT
# contains a result file for "eca.124" — confirm.
df_dict["eca.124"]

In [None]:
# Hard (0/1) version of the reference pseudo labels: every value in the
# columns after the first is set to 1 if it is strictly above 0.5, else 0.
vova_pseudo_hard = vova_pseudo.copy()
is_positive = np.greater(vova_pseudo_hard.iloc[:, 1:].values, 0.5)
vova_pseudo_hard.iloc[:, 1:] = is_positive.astype(int)

In [None]:
# Score each model's predictions against the thresholded reference labels.
for model_name in all_models:
    # The first column is row_id in both frames; everything after it is per-class values.
    targets = vova_pseudo_hard.iloc[:, 1:].values.astype(int)
    predictions = df_dict[model_name].iloc[:, 1:].values
    model_score = score_numpy(targets, predictions)
    print(f"Score for {model_name}:", model_score)