In [None]:
# compare_ic.py
import os, re, pickle, math
from pathlib import Path
import numpy as np
import pandas as pd

FITDIR = Path("./PKL_DATA/fitdata")   # directory holding the fitresults_*.pkl files (relative to the CWD); edit if your path differs
OUTDIR = FITDIR                        # CSV outputs are written into this same directory (not a sibling)

# Filename pattern, e.g. fitresults_ALL_SUBJECTS_lossarbi.pkl
# <group> stops at the FIRST underscore; <model> takes everything after it and may
# contain underscores. So the example above parses as group="ALL", model="SUBJECTS_lossarbi".
# NOTE(review): if group names are meant to contain underscores (e.g. "ALL_SUBJECTS"),
# this split is wrong — confirm the intended naming scheme.
FNAME_RE = re.compile(r"^fitresults_(?P<group>[^_]+)_(?P<model>.+)\.pkl$")


def _akaike_weights(delta):
    # delta: 一组 ΔAIC（或 ΔBIC）
    weights = np.exp(-0.5 * delta)
    s = weights.sum()
    return weights / s if s > 0 else np.ones_like(weights) / len(weights)

def load_all_results(fitdir=FITDIR):
    """Load every per-model fit-result pickle in *fitdir*, grouped by subject group.

    Files are found with the ``fitresults_*.pkl`` glob and parsed with
    ``FNAME_RE`` into a group name and a model name. Non-matching names are
    skipped. Each pickle is read as a list of dicts, one per subject, using
    the keys ``negloglike``, ``n_param``, ``aic``, ``bic`` and, if present,
    ``subj``.

    Parameters
    ----------
    fitdir : str or os.PathLike, default FITDIR
        Directory containing the pickles.

    Returns
    -------
    dict[str, pandas.DataFrame]
        ``group -> DataFrame`` with columns subj, model, nll, k, aic, bic,
        sorted by (subj, model) with a fresh RangeIndex. A group whose
        pickles are all empty lists still gets an empty DataFrame with those
        columns.
    """
    columns = ["subj", "model", "nll", "k", "aic", "bic"]
    rows_by_group = {}

    # sorted() makes processing (and hence group) order deterministic.
    for pkl in sorted(Path(fitdir).glob("fitresults_*.pkl")):
        match = FNAME_RE.match(pkl.name)
        if match is None:
            continue
        group = match.group("group")
        model = match.group("model")

        # NOTE: pickle.load can execute arbitrary code — only load trusted files.
        with open(pkl, "rb") as f:
            results = pickle.load(f)

        # Collect plain dict rows and build each DataFrame once at the end,
        # instead of re-concatenating a growing frame for every file.
        rows = rows_by_group.setdefault(group, [])
        for i, res in enumerate(results):
            # Some fit versions did not store the subject name in the result.
            # The positional fallback only lines up across models if every
            # pickle lists subjects in the same order (presumably true — verify).
            subj = res.get("subj", f"subj_{i:03d}")
            rows.append({
                "subj": subj,
                "model": model,
                "nll": float(res["negloglike"]),
                "k": int(res["n_param"]),
                "aic": float(res["aic"]),
                "bic": float(res["bic"]),
            })

    return {
        group: (
            pd.DataFrame(rows, columns=columns)
            .sort_values(["subj", "model"])
            .reset_index(drop=True)
        )
        for group, rows in rows_by_group.items()
    }

def add_deltas_and_weights(df):
    """Add per-subject ΔIC, IC-weight and winner columns to one group's table.

    For each subject (rows sharing ``subj``) the following columns are added,
    in this order:

    * ``dAIC`` / ``dBIC``: the row's AIC/BIC minus that subject's minimum;
    * ``wAIC`` / ``wBIC``: weights from those deltas via ``_akaike_weights``,
      summing to 1 within the subject;
    * ``winner_AIC`` / ``winner_BIC``: the ``model`` with the lowest AIC/BIC
      for that subject (first row wins ties), repeated on each of its rows.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``subj``, ``model``, ``aic`` and ``bic`` columns and have
        a unique index (as produced by ``load_all_results``).

    Returns
    -------
    pandas.DataFrame
        A new frame; the input is not modified. Rows are stably ordered by
        ``subj`` and keep their original index labels.
    """
    # A stable sort by subject reproduces the row order groupby(...).apply
    # produced: subjects sorted, rows within a subject in original order.
    out = df.sort_values("subj", kind="stable")
    subj = out["subj"]

    # Everything below uses per-column groupby ops rather than
    # DataFrameGroupBy.apply, whose access to the grouping column is
    # deprecated (pandas 2.2) and would drop "subj" from the result in pandas 3.
    def _weights(delta):
        # Normalise within each subject using the shared helper, so the
        # NaN / degenerate-sum fallback stays identical.
        return delta.groupby(subj).transform(lambda d: _akaike_weights(d.to_numpy()))

    def _winner(ic):
        # idxmin per subject -> row label of that subject's best model.
        best_rows = out[ic].groupby(subj).idxmin()
        best_model = pd.Series(
            out.loc[best_rows.to_numpy(), "model"].to_numpy(),
            index=best_rows.index,
        )
        return subj.map(best_model)

    out["dAIC"] = out["aic"] - out["aic"].groupby(subj).transform("min")
    out["dBIC"] = out["bic"] - out["bic"].groupby(subj).transform("min")
    out["wAIC"] = _weights(out["dAIC"])
    out["wBIC"] = _weights(out["dBIC"])
    out["winner_AIC"] = _winner("aic")
    out["winner_BIC"] = _winner("bic")
    return out

def summarize_winners(df):
    """Count how often each model is the best fit across subjects.

    For every subject, the winning model is the row with the lowest ``aic``
    (resp. ``bic``); on ties the first such row wins.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs ``subj``, ``model``, ``aic`` and ``bic`` columns and a unique
        index.

    Returns
    -------
    tuple[pandas.Series, pandas.Series]
        ``(aic_winners, bic_winners)``: win counts indexed by model name, in
        ``Series.value_counts`` order (most wins first). Models that never
        win are absent.
    """
    def _win_counts(ic):
        # Vectorised per-subject argmin. This avoids groupby.apply, which
        # warns about operating on the grouping column in pandas >= 2.2.
        best_rows = df.groupby("subj")[ic].idxmin()
        winners = df.loc[best_rows.to_numpy(), "model"]
        # Drop the "model" axis name so the output matches the old apply-based result.
        return winners.value_counts().rename_axis(None)

    return _win_counts("aic"), _win_counts("bic")

if __name__ == "__main__":
    groups = load_all_results(FITDIR)
    if not groups:
        # Otherwise the script would exit silently. This usually means FITDIR
        # points at the wrong place or the files are named differently.
        print(f"No files matching fitresults_<group>_<model>.pkl found in {FITDIR.resolve()}")
    else:
        # OUTDIR may be edited to differ from FITDIR; make sure it exists.
        OUTDIR.mkdir(parents=True, exist_ok=True)

    for gname, df in groups.items():
        df2 = add_deltas_and_weights(df)

        # Detail table: one row per (subject, model)
        detail_path = OUTDIR / f"IC_detail_{gname}.csv"
        df2.to_csv(detail_path, index=False)

        # Per-model count of subjects for which that model wins
        aic_w, bic_w = summarize_winners(df2)
        summary = pd.DataFrame({
            "AIC_wins": aic_w,
            "BIC_wins": bic_w
        }).fillna(0).astype(int).sort_values(["AIC_wins","BIC_wins"], ascending=False)

        summary_path = OUTDIR / f"IC_summary_{gname}.csv"
        summary.to_csv(summary_path)

        # Also echo to the console
        print(f"\n=== Group: {gname} ===")
        print("Top models by AIC/BIC (counts):")
        print(summary)
        print(f"\nSaved detail -> {detail_path}")
        print(f"Saved summary -> {summary_path}")
