In [None]:
from pathlib import Path
from typing import List, Tuple
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from sklearn.metrics import (
    recall_score,
    precision_score,
    f1_score,
    accuracy_score,
)

In [None]:
def metrics_gt_upsampled(
    gt_csv: str | Path,
    pred_csv: str | Path,
    repeat_factor: int = 3,        # 10 Hz ➜ 30 Hz
    gt_hz: float = 10.0,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Score every Keypoint-MoSeq syllable against an upsampled ground-truth label.

    Each ground-truth row is repeated ``repeat_factor`` times, both sequences are
    truncated to a common length (alignment is purely by frame index), and each
    syllable is treated as a binary predictor ("frame has this syllable") of the
    ``trophalaxis`` label.

    Parameters
    ----------
    gt_csv : str | Path
        Ground-truth CSV sampled at ``gt_hz`` (columns: ``time``, ``trophalaxis``;
        any extra columns are carried along and ignored).
    pred_csv : str | Path
        Keypoint-MoSeq CSV (column: ``syllable``), expected at
        ``gt_hz * repeat_factor`` Hz.
    repeat_factor : int
        Prediction frames per ground-truth frame (3 for 10 Hz → 30 Hz).
    gt_hz : float
        Ground-truth sampling rate; only used to build the upsampled ``time``
        column. Defaults to 10 Hz.

    Returns
    -------
    df_metrics : pd.DataFrame
        One row per syllable: Recall, Precision, F1, Accuracy against the label,
        and ``WingOverlap(%)`` = percentage of that syllable's frames whose label
        is positive (assumes a 0/1 label).
    summary_df : pd.DataFrame
        One-row summary of the positive / negative label ratio over the aligned
        frames (NaN percentages when there are no aligned frames).
    """
    gt_csv, pred_csv = Path(gt_csv), Path(pred_csv)
    gt = pd.read_csv(gt_csv)
    pred = pd.read_csv(pred_csv)

    # ---- upsample gt_hz -> gt_hz * repeat_factor ----------------------------
    # Repeat rows positionally instead of going through ``gt.values``: that
    # collapses mixed-dtype columns into one object array, and scikit-learn
    # rejects an object-dtype label column as an "unknown" target type.
    row_positions = np.repeat(np.arange(len(gt)), repeat_factor)
    gt_up = gt.iloc[row_positions].reset_index(drop=True)
    offsets = np.arange(repeat_factor) / (repeat_factor * gt_hz)  # 0, 1/30, 2/30 …
    gt_up["time"] = np.add.outer(gt["time"].to_numpy(dtype=float), offsets).ravel()

    # ---- truncate both sequences to a common length -------------------------
    n = min(len(gt_up), len(pred))
    gt_up = gt_up.iloc[:n].reset_index(drop=True)
    syll = pred["syllable"].iloc[:n].reset_index(drop=True)
    wing = gt_up["trophalaxis"].astype(int)

    # ---- per-syllable metrics ------------------------------------------------
    metric_columns = [
        "gt_file", "pred_file", "syllable",
        "Recall", "Precision", "F1", "Accuracy", "WingOverlap(%)",
    ]
    rows = []
    for s in sorted(syll.unique()):
        mask = syll == s
        y = mask.astype(int)
        n_frames = int(mask.sum())  # >= 1 because s comes from syll.unique()
        rows.append(
            {
                "gt_file": gt_csv.name,
                "pred_file": pred_csv.name,
                "syllable": int(s),
                "Recall": recall_score(wing, y, zero_division=0),
                "Precision": precision_score(wing, y, zero_division=0),
                "F1": f1_score(wing, y, zero_division=0),
                "Accuracy": accuracy_score(wing, y),
                "WingOverlap(%)": wing[mask].sum() / n_frames * 100
                                  if n_frames else 0.0,
            }
        )

    # Explicit columns keep sort_values working when there are no rows.
    df_metrics = pd.DataFrame(rows, columns=metric_columns).sort_values(
        ["gt_file", "pred_file", "syllable"]
    ).reset_index(drop=True)

    # ---- overall positive / negative label ratio -----------------------------
    total = len(gt_up)
    wing_frames = wing.sum()
    if total:
        wing_pct = wing_frames / total * 100
        no_wing_pct = (total - wing_frames) / total * 100
    else:
        wing_pct = no_wing_pct = float("nan")
    summary_df = pd.DataFrame(
        {
            "gt_file":            [gt_csv.name],
            "pred_file":          [pred_csv.name],
            "trophalaxis(%)":     [wing_pct],
            "no_wing(%)":         [no_wing_pct],
            "total_frames":       [total],
        }
    )

    return df_metrics, summary_df

In [None]:
def load_aligned_vectors(
    gt_csv: str | Path,
    pred_csv: str | Path,
    repeat_factor: int = 3,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Ground-truth と予測を同じ長さにそろえた 1 次元配列を返す
    """
    gt, pred = pd.read_csv(gt_csv), pd.read_csv(pred_csv)

    offsets = np.arange(repeat_factor) / (repeat_factor * 10)
    gt_up = pd.DataFrame(np.repeat(gt.values, repeat_factor, axis=0),
                         columns=gt.columns)
    gt_up["time"] = np.add.outer(gt["time"].values, offsets).ravel()[: len(gt_up)]

    n = min(len(gt_up), len(pred))
    return (
        gt_up["trophalaxis"].iloc[:n].to_numpy(dtype=int),
        pred["syllable"].iloc[:n].to_numpy(dtype=int),
    )

In [None]:
def analyse_all(
    file_list: List[Tuple[str, ...]],
    repeat_factor: int = 3,
    moseq_dir: str | Path = "../data/keypoint_moseq_data",
    boris_dir: str | Path = "../data/BORIS_data",
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Score every video in ``file_list`` and also score all videos pooled together.

    Parameters
    ----------
    file_list : sequence of rows ``(moseq_csv_filename, boris_csv_filename, ...)``
        Bare filenames; the directories are prepended here. Only the first two
        fields of each row are used, so rows may carry extra fields. A 2-D
        ``np.ndarray`` whose rows have this layout also works.
    repeat_factor : int
        Prediction frames per ground-truth frame (3 for 10 Hz → 30 Hz).
    moseq_dir : str | Path
        Directory holding the Keypoint-MoSeq CSVs.
    boris_dir : str | Path
        Directory holding the BORIS ground-truth CSVs.

    Returns
    -------
    result_df : pd.DataFrame
        Per-video × per-syllable metrics (from ``metrics_gt_upsampled``).
    summary_df : pd.DataFrame
        Per-video label ratio summary (from ``metrics_gt_upsampled``).
    global_df : pd.DataFrame
        Per-syllable metrics over the aligned frames of all videos concatenated.

    Raises
    ------
    ValueError
        If ``file_list`` is empty.
    """
    rows = list(file_list)
    if not rows:
        raise ValueError("file_list is empty; nothing to analyse")

    moseq_dir, boris_dir = Path(moseq_dir), Path(boris_dir)
    result_frames, summary_frames = [], []
    all_wing, all_syll = [], []

    for row in tqdm(rows, desc="scoring"):
        # Any fields after the first two are ignored.
        moseq_name, boris_name = row[0], row[1]
        moseq_path = moseq_dir / moseq_name
        boris_path = boris_dir / boris_name

        # ---- per-video metrics ------------------------------------------------
        metrics_df, summary_df = metrics_gt_upsampled(
            boris_path, moseq_path, repeat_factor
        )
        result_frames.append(metrics_df)
        summary_frames.append(summary_df)

        # ---- aligned vectors for the pooled (all-video) metrics ----------------
        wing_vec, syll_vec = load_aligned_vectors(
            boris_path, moseq_path, repeat_factor
        )
        all_wing.append(wing_vec)
        all_syll.append(syll_vec)

    # ---- stack the per-video tables ------------------------------------------
    result_df = pd.concat(result_frames, ignore_index=True)
    summary_df = pd.concat(summary_frames, ignore_index=True)

    # ---- pooled metrics across all videos --------------------------------------
    wing_all = np.concatenate(all_wing)
    syll_all = np.concatenate(all_syll)

    global_columns = ["syllable", "Recall", "Precision", "F1", "Accuracy", "WingOverlap(%)"]
    global_rows = []
    for s in np.unique(syll_all):
        mask = syll_all == s
        y = mask.astype(int)
        global_rows.append(
            {
                "syllable": int(s),
                "Recall": recall_score(wing_all, y, zero_division=0),
                "Precision": precision_score(wing_all, y, zero_division=0),
                "F1": f1_score(wing_all, y, zero_division=0),
                "Accuracy": accuracy_score(wing_all, y),
                # mask.sum() >= 1 because s comes from np.unique(syll_all)
                "WingOverlap(%)": wing_all[mask].sum() / mask.sum() * 100,
            }
        )

    # Explicit columns keep sort_values working when no frames were aligned.
    global_df = (
        pd.DataFrame(global_rows, columns=global_columns)
        .sort_values("syllable")
        .reset_index(drop=True)
    )
    return result_df, summary_df, global_df

In [None]:
# Each row of file_list.csv describes one video; analyse_all uses the first two
# fields (moseq CSV filename, BORIS CSV filename) and ignores the rest.
FILE_LIST_CSV = Path("../data/file_list.csv")
df_file_list = pd.read_csv(FILE_LIST_CSV)
# Plain tuples keep each field's native type (an np.array would coerce every
# column to one common string dtype) and match analyse_all's List[Tuple] hint.
file_list = list(df_file_list.itertuples(index=False, name=None))

In [None]:


result_df, summary_df, global_df = analyse_all(file_list)

# ------------ show the first rows of each result table -------------
_report_sections = (
    ("\n--- per-video syllable metrics (head) ---", result_df),
    ("\n--- per-video summary (head) ---", summary_df),
    ("\n=== Global syllable metrics across ALL videos ===", global_df),
)
for _title, _table in _report_sections:
    print(_title)
    print(_table.head())

In [None]:
OUTPUT_DIR = Path("../outputs")
# DataFrame.to_csv does not create missing directories; make sure it exists.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

for _filename, _table in {
    "trophallaxis_compared_keypoint_moseq_data.csv": result_df,
    "summary_trophallaxis_compared_keypoint_moseq_data.csv": summary_df,
    "global_trophallaxis_compared_keypoint_moseq_data.csv": global_df,
}.items():
    _table.to_csv(OUTPUT_DIR / _filename)