# In [None]:
from __future__ import annotations

import os
import re
import sys
import time
import traceback
from contextlib import nullcontext
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score

# ---------------------------
# Project imports (reuse code)
# ---------------------------
def _find_project_root(start: Path) -> Path:
    cur = start.resolve()
    for _ in range(8):
        if (cur / "src").exists():
            return cur
        cur = cur.parent
    raise RuntimeError("Cannot find project root containing 'src' folder.")

PROJECT_ROOT = _find_project_root(Path.cwd())
# Put the project root first on the import path so `src.*` resolves no matter
# which sub-folder the notebook was started from. Use the public `sys` module
# rather than the undocumented `os.sys` alias.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
print("PROJECT_ROOT =", PROJECT_ROOT)

from src.config import TrainConfig
from src.dataset_utils import make_data_loaders
from src.model_utils import get_model
from src.io_utils import load_checkpoint

# ---------------------------
# USER CONFIG
# ---------------------------
# Root folder holding the training outputs to evaluate -- adjust to your outputs folder.
BASE_OUTPUT_DIR = Path("outputs/cifar10_iter_ema_voting_sweep")
# Dataset folder (expects csvs/noise_<ratio>/{train,val,test}.csv) -- adjust if different.
DATA_DIR = Path("data")

# Filtering modes evaluated for every (noise_ratio, alpha) pair.
_SWEEP_MODES = ("vote_match_noisy", "vote_relabel", "ema_hard")

# Currently one alpha per noise ratio is evaluated. The full sweep used
# alphas [0.2, 0.3, 0.6, 0.8] for each of the noise ratios 0.8 / 0.6 / 0.4 / 0.2;
# extend an entry's "alphas" list to evaluate more values.
EXPERIMENT_PLAN = [
    {"noise_ratios": [0.8], "alphas": [0.3], "modes": list(_SWEEP_MODES)},
    {"noise_ratios": [0.6], "alphas": [0.6], "modes": list(_SWEEP_MODES)},
    {"noise_ratios": [0.4], "alphas": [0.8], "modes": list(_SWEEP_MODES)},
    {"noise_ratios": [0.2], "alphas": [0.2], "modes": list(_SWEEP_MODES)},
]

# Column used to pick the "best iteration" in each run's posthoc summary.
# Keep it consistent with the training config to avoid mismatched logic: e.g. if
# training used early_stop_metric="val_acc_noisy", use "val_acc_noisy_posthoc".
# Other options: "val_acc_orig_posthoc", "test_acc_orig_posthoc", ...
BEST_PICK_KEY = "val_acc_noisy_posthoc"

# ---------------------------
# Utils
# ---------------------------
def _ensure_dir(p: Path) -> None:
    p.mkdir(parents=True, exist_ok=True)

def _format_seconds(sec: float) -> str:
    sec = float(sec)
    if sec < 60:
        return f"{sec:.1f}s"
    if sec < 3600:
        return f"{sec/60:.1f}m"
    return f"{sec/3600:.2f}h"

def _iter_index_from_name(name: str) -> Optional[int]:
    m = re.search(r"iteration_(\d+)", name)
    return int(m.group(1)) if m else None

def list_iteration_dirs(exp_dir: Path) -> List[Tuple[int, Path]]:
    """Return ``(iteration, path)`` for each ``iteration_<N>`` subdirectory, sorted by N.

    Only the canonical spelling ``iteration_<N>`` (no leading zeros, no suffix) is
    accepted. Downstream code rebuilds iteration paths as
    ``exp_dir / f"iteration_{it}"``. A looser match would therefore break things:
    ``iteration_3_bak`` would duplicate iteration 3, and ``iteration_03`` would point
    at a folder that code never reads.
    Non-directory entries are ignored.
    """
    found: List[Tuple[int, Path]] = []
    for p in exp_dir.glob("iteration_*"):
        m = re.fullmatch(r"iteration_(\d+)", p.name)
        if m is None or not p.is_dir():
            continue
        it = int(m.group(1))
        if p.name != f"iteration_{it}":
            continue
        found.append((it, p))
    found.sort(key=lambda x: x[0])
    return found

def resolve_checkpoint_for_iter(exp_dir: Path, it: int) -> Optional[Path]:
    """Return iteration *it*'s checkpoint, preferring ``best`` over ``last``; None if neither exists."""
    ckpt_dir = exp_dir / f"iteration_{it}" / "checkpoints"
    for tag in ("best", "last"):
        candidate = ckpt_dir / f"model_iter{it}_{tag}.pth"
        if candidate.exists():
            return candidate
    return None

def save_confusion_matrix_png(cm: np.ndarray, out_png: Path, title: str = "") -> None:
    """Render *cm* as a heatmap with a colorbar and save it to *out_png* at 200 dpi.

    The axis labels assume rows are true classes and columns are predictions.
    The parent directory of *out_png* is created if needed.
    """
    _ensure_dir(out_png.parent)
    fig, ax = plt.subplots()
    try:
        im = ax.imshow(cm)
        ax.set_title(title)
        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        fig.savefig(out_png, dpi=200)
    finally:
        # Always release the figure, even if rendering/saving fails. This runs once
        # per split/iteration/run, so leaked figures would pile up across a sweep.
        plt.close(fig)

@torch.no_grad()
def infer_predictions(model: torch.nn.Module, loader, device: str, use_amp: bool) -> Dict[str, Any]:
    """Run *model* over *loader* and collect argmax predictions plus any labels.

    The model is put in eval mode and run without gradients. CUDA autocast is used
    only when ``use_amp`` is true and the device name starts with ``"cuda"``.

    Args:
        model: Classifier; its output is arg-maxed along dim 1 (assumed
            ``(batch, num_classes)`` logits -- TODO confirm).
        loader: Iterable of dict batches with an ``"image"`` tensor and optionally
            ``"label_orig"`` / ``"label_noisy"`` tensors.
        device: Device string (a ``torch.device`` is also accepted).
        use_amp: Whether to enable CUDA autocast.

    Returns:
        dict with ``y_pred`` (int32 array; empty if *loader* yields nothing),
        ``infer_seconds`` (wall time of the loop), and ``y_true_orig`` /
        ``y_true_noisy`` (int32 arrays) when the batches carried those labels.
    """
    model.eval()
    amp_ok = bool(use_amp and str(device).startswith("cuda"))

    y_pred: List[np.ndarray] = []
    y_orig: List[np.ndarray] = []
    y_noisy: List[np.ndarray] = []

    t0 = time.time()
    for batch in loader:
        x = batch["image"].to(device, non_blocking=True)
        amp_ctx = torch.autocast(device_type="cuda", enabled=True) if amp_ok else nullcontext()
        with amp_ctx:
            logits = model(x)
        # .cpu() forces a device sync, so the timing below includes GPU work.
        y_pred.append(logits.argmax(dim=1).cpu().numpy().astype(np.int32))

        if "label_orig" in batch:
            y_orig.append(batch["label_orig"].cpu().numpy().astype(np.int32))
        if "label_noisy" in batch:
            y_noisy.append(batch["label_noisy"].cpu().numpy().astype(np.int32))
    t1 = time.time()

    # np.concatenate([]) raises, so an empty loader yields an empty int32 array instead.
    preds = np.concatenate(y_pred) if y_pred else np.empty(0, dtype=np.int32)
    out: Dict[str, Any] = {"y_pred": preds, "infer_seconds": float(t1 - t0)}
    if y_orig:
        out["y_true_orig"] = np.concatenate(y_orig)
    if y_noisy:
        out["y_true_noisy"] = np.concatenate(y_noisy)
    return out

def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray, num_classes: int) -> Dict[str, Any]:
    """Compute accuracy, macro-F1, confusion matrix and per-class report for one split.

    All per-class statistics use the fixed label set ``0..num_classes-1``, so the
    outputs keep the same shape even when some classes never occur in *y_true* or
    *y_pred*.

    Returns:
        dict with ``acc`` (float), ``macro_f1`` (float), ``cm`` (confusion matrix
        over the fixed label set) and ``report_dict`` (``classification_report``
        output with ``output_dict=True``).
    """
    labels = list(range(num_classes))
    acc = float(accuracy_score(y_true, y_pred))
    # zero_division=0 matches classification_report below. An undefined per-class
    # score then counts as 0 silently, instead of emitting UndefinedMetricWarning on
    # every split/iteration of the sweep (the value is 0 either way).
    macro_f1 = float(f1_score(y_true, y_pred, average="macro", labels=labels, zero_division=0))
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    rep = classification_report(
        y_true, y_pred,
        labels=labels,
        output_dict=True,
        zero_division=0,
    )
    return {"acc": acc, "macro_f1": macro_f1, "cm": cm, "report_dict": rep}

def find_run_dirs(base_dir: Path, noise: float, alpha: float, mode: str) -> List[Path]:
    """Return run folders for one (noise, alpha, mode) cell, most recent first.

    Layout: ``base/noise_<noise>/alpha_<alpha>/mode_<mode>/[optional run subfolder]``.
    A folder counts as a run when it contains ``experiment_summary.csv``. Both the
    mode folder itself and its direct subfolders are considered.

    Runs are ranked by the mtime of their ``experiment_summary.csv``, not the
    folder's mtime. A directory's mtime changes whenever an entry is added inside
    it (e.g. the ``posthoc_report`` folder this script creates), which would
    reshuffle the "latest" run after evaluation.
    """
    summary_name = "experiment_summary.csv"
    root = base_dir / f"noise_{noise}" / f"alpha_{alpha}" / f"mode_{mode}"
    if not root.is_dir():
        return []
    cands: List[Path] = [root] if (root / summary_name).exists() else []
    cands.extend(
        p for p in sorted(root.iterdir())
        if p.is_dir() and (p / summary_name).exists()
    )
    cands.sort(key=lambda p: (p / summary_name).stat().st_mtime, reverse=True)
    return cands

def sum_train_time_seconds(exp_dir: Path) -> float:
    """Total the ``epoch_time`` column across every iteration's epoch metrics CSV.

    Reads ``iteration_*/metrics/metrics_epoch.csv``. Missing files, or files
    without an ``epoch_time`` column, contribute nothing; NaN entries count as 0.
    """
    per_iter_totals: List[float] = []
    for _, iter_dir in list_iteration_dirs(exp_dir):
        csv_path = iter_dir / "metrics" / "metrics_epoch.csv"
        if not csv_path.exists():
            continue
        epochs = pd.read_csv(csv_path)
        if "epoch_time" not in epochs.columns:
            continue
        per_iter_totals.append(float(epochs["epoch_time"].fillna(0).sum()))
    return float(sum(per_iter_totals, 0.0))

# ---------------------------
# Evaluate ALL iterations of 1 exp_dir
# ---------------------------
def _score_iteration(
    row: Dict[str, Any],
    split_outputs: Dict[str, Dict[str, Any]],
    iter_out_dir: Path,
    it: int,
    noise: float,
    alpha: float,
    mode: str,
    num_classes: int,
) -> None:
    """Add one iteration's posthoc metric columns to *row* and write its artifacts.

    Loops over label kinds ("orig" first, then "noisy") and, within each, over the
    splits in *split_outputs* order. When a split's inference output carries
    ``y_true_<kind>``, it adds ``<split>_acc_<kind>_posthoc`` and
    ``<split>_macro_f1_<kind>_posthoc``. For the "orig" labels it also writes the
    confusion matrix (CSV + PNG) and classification report to *iter_out_dir*.
    """
    for label_kind in ("orig", "noisy"):
        for split, out in split_outputs.items():
            true_key = f"y_true_{label_kind}"
            if true_key not in out:
                continue
            m = compute_metrics(out[true_key], out["y_pred"], num_classes)
            row[f"{split}_acc_{label_kind}_posthoc"] = m["acc"]
            row[f"{split}_macro_f1_{label_kind}_posthoc"] = m["macro_f1"]
            if label_kind != "orig":
                continue
            pd.DataFrame(m["cm"]).to_csv(iter_out_dir / f"{split}_confusion_matrix_orig.csv", index=False)
            save_confusion_matrix_png(
                m["cm"],
                iter_out_dir / f"{split}_confusion_matrix_orig.png",
                title=f"{split.upper()}(orig) it={it} noise={noise} alpha={alpha} mode={mode}",
            )
            pd.DataFrame(m["report_dict"]).to_csv(iter_out_dir / f"{split}_classification_report_orig.csv")


def _pick_best_iteration(per_iter_df: pd.DataFrame, key: str) -> Optional[pd.Series]:
    """Return the best successfully evaluated iteration row, or None if there is none.

    The highest *key* wins (NaN sorts last). Ties go to the earliest iteration so
    the choice is deterministic. When *key* is absent or all-NaN, it falls back to
    the latest iteration.
    """
    ok_df = per_iter_df[per_iter_df["status"] == "ok"]
    if ok_df.empty:
        return None
    if key in ok_df.columns and ok_df[key].notna().any():
        ranked = ok_df.sort_values([key, "iteration"], ascending=[False, True], kind="mergesort")
    else:
        ranked = ok_df.sort_values("iteration", ascending=False, kind="mergesort")
    return ranked.iloc[0]


def evaluate_all_iterations(exp_dir: Path, noise: float, alpha: float, mode: str, data_dir: Path) -> Dict[str, Any]:
    """Run posthoc val/test inference for every iteration checkpoint of one experiment.

    Builds a default ``TrainConfig`` with data_dir/noise/alpha/mode/exp_dir
    overridden, and loads ``data_dir/csvs/noise_<noise>/{train,val,test}.csv``
    through ``make_data_loaders``. It then evaluates each ``iteration_*`` folder's
    checkpoint (``best``, else ``last``).

    Writes under ``exp_dir / "posthoc_report"``:
      - ``per_iteration_metrics.csv``: one row per iteration folder (rows for
        iterations without a checkpoint have status ``missing_checkpoint``);
      - ``iter_<it>/``: confusion matrices and classification reports (orig labels);
      - ``best_iteration_summary.csv``: the row chosen via ``BEST_PICK_KEY``, when
        at least one iteration was evaluated.

    Returns:
        Summary dict for the sweep status table, including ``n_iterations_found``
        and ``n_iterations_evaluated``.

    Raises:
        FileNotFoundError: if any of the train/val/test CSVs is missing.
        RuntimeError: if *exp_dir* contains no ``iteration_*`` folders.
    """
    cfg = TrainConfig()
    cfg.data_dir = str(data_dir)
    cfg.noise_ratio = float(noise)
    cfg.alpha = float(alpha)
    cfg.filter_mode = str(mode)
    cfg.exp_dir = str(exp_dir)

    device = cfg.device
    use_amp = bool(cfg.use_amp)

    csv_dir = data_dir / "csvs" / f"noise_{noise}"
    train_csv = str(csv_dir / "train.csv")
    val_csv = str(csv_dir / "val.csv")
    test_csv = str(csv_dir / "test.csv")
    if not all(Path(p).exists() for p in (train_csv, val_csv, test_csv)):
        raise FileNotFoundError(f"Missing CSVs at {csv_dir}")

    # Reuse the training loaders; only the val/test ones are consumed here.
    dls = make_data_loaders(
        train_csv=train_csv,
        val_csv=val_csv,
        test_csv=test_csv,
        config=cfg,
        train_full_csv=train_csv,
        train_label_col="label_noisy",
    )
    split_loaders = {"val": dls["val"], "test": dls["test"]}

    it_dirs = list_iteration_dirs(exp_dir)
    if not it_dirs:
        raise RuntimeError(f"No iteration_* folders in {exp_dir}")

    report_dir = exp_dir / "posthoc_report"
    _ensure_dir(report_dir)

    rows: List[Dict[str, Any]] = []
    n_evaluated = 0
    total_train_time = sum_train_time_seconds(exp_dir)

    for it, _ in it_dirs:
        ckpt = resolve_checkpoint_for_iter(exp_dir, it)
        if ckpt is None:
            rows.append({"iteration": it, "status": "missing_checkpoint"})
            continue

        # NOTE(review): load_checkpoint overwrites the weights right after, so
        # pretrained=True may only cost an extra (possibly downloaded) weight load --
        # confirm get_model's semantics before changing it.
        model = get_model(num_classes=cfg.num_classes, pretrained=True, device=device)
        load_checkpoint(str(ckpt), model, optimizer=None, scheduler=None, map_location=device)

        split_outputs = {
            split: infer_predictions(model, loader, device=device, use_amp=use_amp)
            for split, loader in split_loaders.items()
        }
        # Drop this iteration's model before the next get_model() call so two models
        # are never alive at once.
        del model

        iter_out_dir = report_dir / f"iter_{it}"
        _ensure_dir(iter_out_dir)

        row: Dict[str, Any] = {
            "noise_ratio": float(noise),
            "alpha": float(alpha),
            "mode": str(mode),
            "exp_dir": str(exp_dir),
            "iteration": int(it),
            "checkpoint": str(ckpt),
            "infer_val_seconds": float(split_outputs["val"]["infer_seconds"]),
            "infer_test_seconds": float(split_outputs["test"]["infer_seconds"]),
            "train_time_seconds_sum": float(total_train_time),
            "status": "ok",
        }
        # Noisy-label metrics are useful when training early-stopped on a noisy metric.
        _score_iteration(row, split_outputs, iter_out_dir, it, noise, alpha, mode, cfg.num_classes)
        rows.append(row)
        n_evaluated += 1

    per_iter_df = pd.DataFrame(rows)
    per_iter_df.to_csv(report_dir / "per_iteration_metrics.csv", index=False)

    # Compact one-row summary; the full per-iteration detail is saved above.
    best_row = _pick_best_iteration(per_iter_df, BEST_PICK_KEY)
    if best_row is not None:
        pd.DataFrame([best_row.to_dict()]).to_csv(report_dir / "best_iteration_summary.csv", index=False)

    return {
        "exp_dir": str(exp_dir),
        "noise_ratio": float(noise),
        "alpha": float(alpha),
        "mode": str(mode),
        "n_iterations_found": int(len(it_dirs)),
        "n_iterations_evaluated": int(n_evaluated),
        "report_dir": str(report_dir),
        "status": "ok",
    }

# ---------------------------
# MAIN: sweep posthoc
# ---------------------------
all_runs = []
t0 = time.time()

for plan in EXPERIMENT_PLAN:
    for noise in plan["noise_ratios"]:
        for alpha in plan["alphas"]:
            for mode in plan["modes"]:
                cands = find_run_dirs(BASE_OUTPUT_DIR, noise=noise, alpha=alpha, mode=mode)
                if not cands:
                    print(f"[MISS] noise={noise} alpha={alpha} mode={mode} -> no exp_dir")
                    all_runs.append({
                        "noise_ratio": noise, "alpha": alpha, "mode": mode,
                        "status": "missing_exp_dir", "exp_dir": None
                    })
                    continue

                # find_run_dirs orders candidates most recent first; evaluate only that one.
                exp_dir = cands[0]
                print(f"\n[POSTHOC] noise={noise} alpha={alpha} mode={mode}")
                print("         exp_dir =", exp_dir)

                try:
                    out = evaluate_all_iterations(exp_dir, noise=noise, alpha=alpha, mode=mode, data_dir=DATA_DIR)
                except Exception as e:
                    # One failing run must not abort the sweep, but surface the failure
                    # (with traceback) instead of only recording it in the status CSV.
                    print(f"[ERROR] noise={noise} alpha={alpha} mode={mode}: {e}")
                    traceback.print_exc()
                    all_runs.append({
                        "noise_ratio": noise, "alpha": alpha, "mode": mode,
                        "status": f"error: {e}", "exp_dir": str(exp_dir)
                    })
                else:
                    all_runs.append(out)

agg_dir = BASE_OUTPUT_DIR / "_POSTHOC_AGGREGATE"
_ensure_dir(agg_dir)
df_runs = pd.DataFrame(all_runs)
df_runs.to_csv(agg_dir / "posthoc_run_status.csv", index=False)

t1 = time.time()
print("\nDONE posthoc.")
print("Saved:", agg_dir / "posthoc_run_status.csv")
print("Elapsed:", _format_seconds(t1 - t0))

# `display` is injected by IPython/Jupyter; fall back to plain printing elsewhere.
try:
    display(df_runs)
except NameError:
    print(df_runs.to_string())
