# Script 1: prepare_data1_experiment_summary.py

In [1]:
#!/usr/bin/env python3
"""
Prepare Data1 for reporting: aggregate experiment_summary.csv across runs/noise/alpha.

Supported input folder layouts:

A) Non-namespaced:
  folder_store/
    noise_0.2/alpha_0.2/experiment_summary.csv
    noise_0.2/alpha_0.3/experiment_summary.csv
    ...

B) Namespaced by run_id:
  folder_store/
    run_2025_12_19/noise_0.2/alpha_0.2/experiment_summary.csv
    run_2025_12_19/noise_0.2/alpha_0.3/experiment_summary.csv
    ...

Baseline definition:
  Baseline for each (run_id, noise_ratio, alpha) is the row where iteration == 0.

Selection protocol (matching your pipeline):
  - Best iteration per (run_id, noise_ratio, alpha) selected by maximizing val_acc_noisy.
  - Best alpha per (run_id, noise_ratio) selected by maximizing the best-iteration val_acc_noisy.

Outputs written to:
  folder_store/data_to_report/data1/
    experiment_summary_all.csv
    baseline_rows.csv
    best_by_noisy_val_per_alpha.csv
    alpha_sweep_ready.csv
    best_alpha_per_noise.csv
    main_table_ready.csv
"""
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Set, Tuple

import pandas as pd


@dataclass(frozen=True)
class Settings:
    """Immutable configuration for build_data1."""

    # Root folder holding noise_*/alpha_* result folders (optionally nested under run_id folders).
    folder_store: Path
    # Output directory, relative to folder_store.
    out_dir: str = "data_to_report/data1"
    # When False, output CSVs that already exist are left untouched.
    overwrite: bool = True
    # Column maximized when selecting the best iteration per alpha and the best alpha per noise.
    selection_metric_col: str = "val_acc_noisy"
    # Minimum fraction of non-empty values that must parse as numbers for a CSV column to become numeric.
    numeric_ratio_threshold: float = 0.95
    # When True, noise_ratio/alpha parsed from folder names overwrite the CSV's own columns.
    enforce_path_meta: bool = False


def _to_float_safe(x: str) -> Optional[float]:
    """Parse ``x`` as a float (a comma is accepted as decimal separator); None if it does not parse."""
    try:
        normalized = str(x).replace(",", ".")
        parsed: Optional[float] = float(normalized)
    except ValueError:
        parsed = None
    return parsed


def _read_csv_normalized(
    path: Path,
    force_str_cols: Optional[Set[str]] = None,
    numeric_ratio_threshold: float = 0.95,
) -> pd.DataFrame:
    """
    Read a CSV and normalize column types.

      - Columns listed in ``force_str_cols`` are cast to ``str`` as-is.
      - Columns pandas already parsed as numeric are left alone.
      - Every other column is probed as a number after turning decimal commas
        into dots (text-like columns only). If at least ``numeric_ratio_threshold``
        of its non-empty values parse, the column becomes numeric
        (values that do not parse become NaN).
      - Otherwise the column is kept as strings with its ORIGINAL text: the
        comma->dot rewrite is used only for the numeric probe, so free text such
        as "a, b" is not silently turned into "a. b".
    """
    if force_str_cols is None:
        force_str_cols = set()

    df = pd.read_csv(path)

    for col in df.columns:
        if col in force_str_cols:
            df[col] = df[col].astype(str)
            continue

        s = df[col]
        if pd.api.types.is_numeric_dtype(s):
            continue

        s_text = s.astype(str)
        # Decimal commas can only occur in text-like (object / string dtype) columns.
        if pd.api.types.is_object_dtype(s) or pd.api.types.is_string_dtype(s):
            s_probe = s_text.str.replace(",", ".", regex=False)
        else:
            s_probe = s_text

        # Convert to numeric safely (no deprecated errors='ignore')
        s_num = pd.to_numeric(s_probe, errors="coerce")

        # Only non-empty values count toward the conversion ratio; missing floats
        # are rendered as "nan" by astype(str).
        s_norm = s_probe.str.strip()
        non_empty = s_norm.ne("") & s_norm.str.lower().ne("nan")
        denom = int(non_empty.sum())

        ok = int(s_num[non_empty].notna().sum()) if denom else 0
        if denom and ok / denom >= numeric_ratio_threshold:
            df[col] = s_num
        else:
            df[col] = s_text

    return df


def _find_experiment_summaries(folder_store: Path) -> List[Tuple[Optional[str], Path]]:
    """
    Return ``(run_id, csv_path)`` pairs for every experiment_summary.csv.

    Supported layouts (the non-namespaced one wins when it yields any file):
      - folder_store/noise_*/alpha_*/experiment_summary.csv           -> run_id None
      - folder_store/<run_id>/noise_*/alpha_*/experiment_summary.csv  -> run_id = folder name

    Results are sorted (by run folder, then by path) because ``glob``/``iterdir``
    yield entries in arbitrary filesystem order; sorting keeps the aggregated
    CSVs and any order-dependent tie-breaking reproducible across machines.
    """
    folder_store = folder_store.resolve()
    pattern = "noise_*/alpha_*/experiment_summary.csv"

    direct = sorted(folder_store.glob(pattern))
    if direct:
        return [(None, p) for p in direct]

    return [
        (run_dir.name, p)
        for run_dir in sorted(folder_store.iterdir())
        if run_dir.is_dir()
        for p in sorted(run_dir.glob(pattern))
    ]


def _extract_noise_alpha_from_path(csv_path: Path) -> Tuple[Optional[float], Optional[float]]:
    """
    Parse ``(noise_ratio, alpha)`` from the ``noise_<x>`` / ``alpha_<y>`` folders above ``csv_path``.

    Ancestors are scanned from the nearest outwards and the FIRST parseable match
    of each kind is kept. Scanning all ancestors and letting the last match win
    would let an unrelated outer folder (e.g. ``/data/noise_experiments/...``)
    overwrite the real value, possibly with None. Missing values are returned as None.
    """
    noise_val: Optional[float] = None
    alpha_val: Optional[float] = None

    for parent in csv_path.parents:
        name = parent.name
        if noise_val is None and name.startswith("noise_"):
            noise_val = _to_float_safe(name[len("noise_"):])
        elif alpha_val is None and name.startswith("alpha_"):
            alpha_val = _to_float_safe(name[len("alpha_"):])
        if noise_val is not None and alpha_val is not None:
            break

    return noise_val, alpha_val


def _ensure_meta_columns(df: pd.DataFrame, csv_path: Path, enforce_path_meta: bool) -> pd.DataFrame:
    """
    Return a copy of ``df`` with ``noise_ratio`` / ``alpha`` taken from the folder names.

    With ``enforce_path_meta`` the path-derived values overwrite existing columns;
    without it they are only added where the column is absent. A value that could
    not be parsed from the path (None) is never written.
    """
    result = df.copy()
    path_values = dict(zip(("noise_ratio", "alpha"), _extract_noise_alpha_from_path(csv_path)))

    for column, value in path_values.items():
        if value is None:
            continue
        if enforce_path_meta or column not in result.columns:
            result[column] = value

    return result


def _pick_best_row_per_group(
    df: pd.DataFrame,
    group_cols: List[str],
    metric_col: str,
    tie_test_col: str = "test_acc_reported",
    tie_kept_col: str = "kept_ratio",
    iter_col: str = "iteration",
) -> pd.DataFrame:
    """
    Pick the best row per group, ranked by:
      1) metric_col descending
      2) tie_test_col descending
      3) tie_kept_col descending
      4) iter_col ascending

    Ranking columns missing from ``df`` are skipped instead of being added as
    all-NA placeholders, so the result has exactly ``df``'s columns and ``df`` is
    not modified. NaN values rank last (pandas ``na_position="last"``). Rows whose
    group key contains NaN are dropped (``groupby`` default ``dropna=True``).
    """
    ranking = [
        (metric_col, False),
        (tie_test_col, False),
        (tie_kept_col, False),
        (iter_col, True),
    ]
    present = [(col, asc) for col, asc in ranking if col in df.columns]

    ordered = df
    if present:
        # mergesort keeps the sort stable even when only one ranking column exists
        # (multi-column sorts are already stable), so exact ties keep input order.
        ordered = df.sort_values(
            by=[col for col, _ in present],
            ascending=[asc for _, asc in present],
            kind="mergesort",
        )

    return ordered.groupby(group_cols, as_index=False).head(1).copy()


def _load_all_summaries(folder_store: Path, settings: Settings) -> pd.DataFrame:
    """
    Read, validate and concatenate every experiment_summary.csv under ``folder_store``.

    Each frame is tagged with ``run_id`` ("default" for the non-namespaced layout)
    and ``source_path``. Raises FileNotFoundError when nothing is found and
    ValueError when a file lacks ``iteration`` or ``noise_ratio``/``alpha``.
    """
    entries = _find_experiment_summaries(folder_store)
    if not entries:
        raise FileNotFoundError(
            "No experiment_summary.csv found. Expected noise_*/alpha_*/experiment_summary.csv "
            "under folder_store (or under one extra run_id folder)."
        )

    frames: List[pd.DataFrame] = []
    for run_id, csv_path in entries:
        df = _read_csv_normalized(
            csv_path,
            force_str_cols={"timestamp"},
            numeric_ratio_threshold=settings.numeric_ratio_threshold,
        )
        df = _ensure_meta_columns(df, csv_path, settings.enforce_path_meta)

        if "iteration" not in df.columns:
            raise ValueError(f"Missing 'iteration' column in: {csv_path}")
        if "noise_ratio" not in df.columns or "alpha" not in df.columns:
            raise ValueError(
                f"Missing 'noise_ratio' or 'alpha' in: {csv_path}. "
                "Either include them in CSV or ensure folder name noise_x/alpha_y."
            )

        df["run_id"] = run_id if run_id is not None else "default"
        df["source_path"] = str(csv_path)
        frames.append(df)

    return pd.concat(frames, ignore_index=True)


def _baseline_rows_for_selection(df_best_alpha: pd.DataFrame, df_baseline: pd.DataFrame) -> pd.DataFrame:
    """
    Return one "Baseline" row per row of ``df_best_alpha``.

    The iteration-0 row at the selected ``best_alpha`` is used when it exists.
    Otherwise the iteration-0 row with the smallest alpha for the same
    (run_id, noise_ratio) is substituted; its own ``alpha`` is kept. When no
    baseline exists at all for that (run_id, noise_ratio), metrics and alpha stay NaN.
    """
    keys = ["run_id", "noise_ratio"]
    df_base = (
        df_best_alpha[keys + ["best_alpha"]]
        .rename(columns={"best_alpha": "alpha"})
        .merge(df_baseline, how="left", on=keys + ["alpha"], suffixes=("", "_baseline"))
    )

    missing = df_base["iteration"].isna()
    if missing.any():
        fallback = (
            df_baseline.sort_values(keys + ["alpha"])
            .groupby(keys, as_index=False)
            .head(1)
        )
        filled = df_base.loc[missing, keys].merge(fallback, on=keys, how="left")
        # A left merge keeps the left row order and the fallback has one row per key,
        # so the filled rows map one-to-one onto the missing rows: reuse their labels
        # and splice them back in place (label-aligned, not a positional .values write).
        filled.index = df_base.index[missing]
        pieces = [filled.reindex(columns=df_base.columns)]
        if not missing.all():
            pieces.insert(0, df_base.loc[~missing])
        df_base = pd.concat(pieces).sort_index()

    df_base["Method"] = "Baseline"
    # For baseline: best_iteration concept is 0
    df_base["best_iteration"] = 0
    return df_base


def build_data1(settings: Settings) -> Path:
    """
    Aggregate every experiment_summary.csv under ``settings.folder_store`` and write
    the Data1 reporting CSVs into ``folder_store / settings.out_dir``.

    Selection protocol:
      - best iteration per (run_id, noise_ratio, alpha) by max ``selection_metric_col``
      - best alpha per (run_id, noise_ratio) by max of that best-iteration metric

    Returns the output directory. Raises FileNotFoundError when no summaries are
    found and ValueError when required columns are missing.
    """
    folder_store = settings.folder_store.resolve()
    out_dir = folder_store / settings.out_dir
    out_dir.mkdir(parents=True, exist_ok=True)

    df_all = _load_all_summaries(folder_store, settings)

    metric_col = settings.selection_metric_col
    if metric_col not in df_all.columns:
        raise ValueError(
            f"Selection metric column '{metric_col}' not found. "
            f"Available columns: {list(df_all.columns)}"
        )

    # Baseline rows
    df_baseline = df_all[df_all["iteration"] == 0].copy()

    # Best iteration per (run_id, noise, alpha)
    df_best_per_alpha = _pick_best_row_per_group(df_all, ["run_id", "noise_ratio", "alpha"], metric_col)
    df_best_per_alpha = df_best_per_alpha.rename(columns={"iteration": "best_iteration"})

    # Alpha sweep ready (thin)
    alpha_cols = [
        "run_id",
        "noise_ratio",
        "alpha",
        "best_iteration",
        metric_col,
        "val_acc_reported",
        "val_acc_orig",
        "test_acc_reported",
        "test_acc_orig",
        "kept_ratio",
        "samples_kept",
        "samples_removed",
        "samples_total",
        "training_samples_used",
        "timestamp",
        "source_path",
    ]
    alpha_cols = [c for c in alpha_cols if c in df_best_per_alpha.columns]
    df_alpha_sweep = df_best_per_alpha[alpha_cols].copy()

    # Best alpha per (run_id, noise), chosen among the one-row-per-alpha winners.
    # "iteration" was renamed to "best_iteration" above, so the iteration tie-break
    # must use the new name; otherwise the helper sorts on (and may emit) an all-NA
    # "iteration" column.
    df_best_alpha = _pick_best_row_per_group(
        df_best_per_alpha,
        ["run_id", "noise_ratio"],
        metric_col,
        iter_col="best_iteration",
    )
    df_best_alpha = df_best_alpha.rename(columns={"alpha": "best_alpha"})

    # Main table ready: baseline vs ours per (run_id, noise)
    df_ours = df_best_alpha.copy()
    df_ours["Method"] = "Ours"
    df_ours["alpha"] = df_ours["best_alpha"]  # unify naming

    df_base = _baseline_rows_for_selection(df_best_alpha, df_baseline)

    # Columns for main table (only keep what exists)
    main_cols = [
        "run_id",
        "noise_ratio",
        "Method",
        "alpha",
        "best_iteration",
        metric_col,
        "val_acc_reported",
        "val_acc_orig",
        "test_acc_reported",
        "test_acc_orig",
        "kept_ratio",
        "samples_kept",
        "samples_removed",
        "samples_total",
        "training_samples_used",
    ]
    main_cols = [c for c in main_cols if c in df_all.columns or c in df_ours.columns or c in df_base.columns]

    df_main = pd.concat(
        [
            df_base.reindex(columns=main_cols, fill_value=pd.NA),
            df_ours.reindex(columns=main_cols, fill_value=pd.NA),
        ],
        ignore_index=True,
    ).sort_values(["run_id", "noise_ratio", "Method"])

    # Write outputs
    outputs = {
        "experiment_summary_all.csv": df_all,
        "baseline_rows.csv": df_baseline,
        "best_by_noisy_val_per_alpha.csv": df_best_per_alpha,
        "alpha_sweep_ready.csv": df_alpha_sweep,
        "best_alpha_per_noise.csv": df_best_alpha,
        "main_table_ready.csv": df_main,
    }

    for fname, frame in outputs.items():
        out_path = out_dir / fname
        if settings.overwrite or not out_path.exists():
            frame.to_csv(out_path, index=False)

    return out_dir


if __name__ == "__main__":
    # ====== INPUT SETTINGS (edit here) ======
    folder_store = "store_output_cifar10_iter_ema_noise_validation_v2"
    # =======================================

    # Build Data1 (default location: <folder_store>/data_to_report/data1) and report where it went.
    out = build_data1(
        Settings(folder_store=Path(folder_store)),
    )
    print(f"[OK] Data1 written to: {out}")


[OK] Data1 written to: /mnt/c/Users/truon/learning/ptit/research/trung/M_10_01_2025/code_v2/project/notebooks/store_output_cifar10_iter_ema_noise_validation_v2/data_to_report/data1


# Script 2: prepare_data2_filter_quality.py

In [2]:
#!/usr/bin/env python3
"""
Prepare Data2 for reporting: compute filter quality from train_kept_*.csv files.

Supported input folder layouts:

A) Non-namespaced:
  folder_store/noise_0.2/alpha_0.2/iteration_0/train_kept_0.csv

B) Namespaced:
  folder_store/<run_id>/noise_0.2/alpha_0.2/iteration_0/train_kept_0.csv

train_kept_*.csv contains ONLY kept samples.
We compute:
  - precision_kept (exact): clean_kept_count / kept_count
  - recall_clean_total using:
        clean_total = round((1 - noise_ratio) * samples_total)
    where samples_total is read from experiment_summary.csv if available,
    else defaults to 45000.

Outputs written to:
  folder_store/data_to_report/data2/
    filter_quality_all.csv
    class_balance_long.csv
    class_balance_clean_noisy_long.csv
"""
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple

import pandas as pd


@dataclass(frozen=True)
class Settings:
    """Immutable configuration for build_data2."""

    # Root folder holding noise_*/alpha_*/iteration_* result folders (optionally nested under run_id folders).
    folder_store: Path
    # Output directory, relative to folder_store.
    out_dir: str = "data_to_report/data2"
    # samples_total used when experiment_summary.csv is missing or has no usable samples_total value.
    default_samples_total: int = 45000
    # When False, output CSVs that already exist are left untouched.
    overwrite: bool = True
    # Print a warning when the iteration_<k> folder and the train_kept_<j>.csv filename disagree.
    warn_iteration_mismatch: bool = True
    # Minimum fraction of non-empty values that must parse as numbers for a CSV column to become numeric.
    numeric_ratio_threshold: float = 0.95


def _to_float_safe(x: str) -> Optional[float]:
    """Convert ``x`` to float, treating ',' as the decimal point; return None when conversion fails."""
    result: Optional[float]
    try:
        result = float(str(x).replace(",", "."))
    except ValueError:
        result = None
    return result


def _read_csv_normalized(
    path: Path,
    force_str_cols: Optional[Set[str]] = None,
    numeric_ratio_threshold: float = 0.95,
) -> pd.DataFrame:
    """
    Read a CSV and normalize column types.

      - Columns listed in ``force_str_cols`` are cast to ``str`` as-is.
      - Columns pandas already parsed as numeric are left alone.
      - Every other column is probed as a number after turning decimal commas
        into dots (text-like columns only). If at least ``numeric_ratio_threshold``
        of its non-empty values parse, the column becomes numeric
        (values that do not parse become NaN).
      - Otherwise the column is kept as strings with its ORIGINAL text: the
        comma->dot rewrite is used only for the numeric probe, so free text such
        as "a, b" is not silently turned into "a. b".
    """
    if force_str_cols is None:
        force_str_cols = set()

    df = pd.read_csv(path)

    for col in df.columns:
        if col in force_str_cols:
            df[col] = df[col].astype(str)
            continue

        s = df[col]
        if pd.api.types.is_numeric_dtype(s):
            continue

        s_text = s.astype(str)
        # Decimal commas can only occur in text-like (object / string dtype) columns.
        if pd.api.types.is_object_dtype(s) or pd.api.types.is_string_dtype(s):
            s_probe = s_text.str.replace(",", ".", regex=False)
        else:
            s_probe = s_text

        s_num = pd.to_numeric(s_probe, errors="coerce")

        # Only non-empty values count toward the conversion ratio; missing floats
        # are rendered as "nan" by astype(str).
        s_norm = s_probe.str.strip()
        non_empty = s_norm.ne("") & s_norm.str.lower().ne("nan")
        denom = int(non_empty.sum())

        ok = int(s_num[non_empty].notna().sum()) if denom else 0
        if denom and ok / denom >= numeric_ratio_threshold:
            df[col] = s_num
        else:
            df[col] = s_text

    return df


def _find_kept_files(folder_store: Path) -> List[Tuple[Optional[str], Path]]:
    """
    Return ``(run_id, csv_path)`` pairs for every train_kept_*.csv.

    Supported layouts (the non-namespaced one wins when it yields any file):
      - folder_store/noise_*/alpha_*/iteration_*/train_kept_*.csv           -> run_id None
      - folder_store/<run_id>/noise_*/alpha_*/iteration_*/train_kept_*.csv  -> run_id = folder name

    Results are sorted (by run folder, then by path) because ``glob``/``iterdir``
    yield entries in arbitrary filesystem order; sorting keeps the output tables
    reproducible across machines.
    """
    folder_store = folder_store.resolve()
    pattern = "noise_*/alpha_*/iteration_*/train_kept_*.csv"

    direct = sorted(folder_store.glob(pattern))
    if direct:
        return [(None, p) for p in direct]

    return [
        (run_dir.name, p)
        for run_dir in sorted(folder_store.iterdir())
        if run_dir.is_dir()
        for p in sorted(run_dir.glob(pattern))
    ]


def _extract_meta_from_path(p: Path) -> Tuple[Optional[float], Optional[float], Optional[int], Optional[int]]:
    """
    Parse run metadata from a train_kept file path.

    Returns ``(noise_ratio, alpha, iteration_from_folder, iteration_from_filename)``,
    e.g. ``.../noise_0.2/alpha_0.6/iteration_4/train_kept_4.csv`` -> ``(0.2, 0.6, 4, 4)``.
    Path components are scanned root-to-leaf: noise/alpha come from the deepest
    matching folder (None if its suffix does not parse); the folder iteration comes
    from the deepest ``iteration_<int>`` component. Anything absent is None.
    """

    def _int_or_none(text: str) -> Optional[int]:
        try:
            return int(text)
        except ValueError:
            return None

    noise_val: Optional[float] = None
    alpha_val: Optional[float] = None
    iter_folder: Optional[int] = None

    for component in p.parts:
        if component.startswith("noise_"):
            noise_val = _to_float_safe(component.replace("noise_", ""))
        elif component.startswith("alpha_"):
            alpha_val = _to_float_safe(component.replace("alpha_", ""))
        elif component.startswith("iteration_"):
            parsed_iter = _int_or_none(component.replace("iteration_", ""))
            # An unparseable iteration_ folder does not discard an earlier match.
            if parsed_iter is not None:
                iter_folder = parsed_iter

    fname = p.name
    is_kept_file = fname.startswith("train_kept_") and fname.endswith(".csv")
    iter_file = (
        _int_or_none(fname.replace("train_kept_", "").replace(".csv", ""))
        if is_kept_file
        else None
    )

    return noise_val, alpha_val, iter_folder, iter_file


def _find_experiment_summary(folder_store: Path, run_id: Optional[str], noise: float, alpha: float) -> Optional[Path]:
    """
    Locate the experiment_summary.csv for ``(noise, alpha)`` under the same run folder.

    Tries the exact ``noise_{noise}/alpha_{alpha}`` folder names first, then falls
    back to scanning ``noise_*`` / ``alpha_*`` folders whose numeric suffix equals
    the target within 1e-9, so formatting differences such as "0.20" vs "0.2" still match.
    Returns None when no matching file exists.
    """
    base = folder_store / run_id if run_id else folder_store

    exact = base / f"noise_{noise}" / f"alpha_{alpha}" / "experiment_summary.csv"
    if exact.exists():
        return exact

    def _first_match(parent: Path, prefix: str, target: float) -> Optional[Path]:
        # First entry (in glob order) whose parsed suffix is numerically equal to target.
        for candidate in parent.glob(f"{prefix}*"):
            parsed = _to_float_safe(candidate.name.replace(prefix, ""))
            if parsed is not None and abs(parsed - target) < 1e-9:
                return candidate
        return None

    noise_dir = _first_match(base, "noise_", noise)
    if noise_dir is None:
        return None

    alpha_dir = _first_match(noise_dir, "alpha_", alpha)
    if alpha_dir is None:
        return None

    summary = alpha_dir / "experiment_summary.csv"
    return summary if summary.exists() else None


def _clean_mask(df: pd.DataFrame, csv_path: Path) -> pd.Series:
    """
    Return a boolean mask of rows treated as clean.

    Uses ``noise_flag == 0`` when that column exists, else ``label_noisy == label_orig``;
    raises ValueError when neither is available.
    """
    if "noise_flag" in df.columns:
        return df["noise_flag"] == 0
    if "label_noisy" in df.columns and "label_orig" in df.columns:
        return df["label_noisy"] == df["label_orig"]
    raise ValueError(
        f"Missing columns to compute clean mask in {csv_path}. "
        "Need noise_flag OR (label_noisy,label_orig)."
    )


def _samples_total_for(
    folder_store: Path,
    run_id: Optional[str],
    noise: float,
    alpha: float,
    settings: Settings,
) -> int:
    """
    Return samples_total for (run_id, noise, alpha) from the first row of the matching
    experiment_summary.csv.

    Falls back to ``settings.default_samples_total`` when the file or the column is
    missing, the table is empty, the value cannot be converted with ``int()``, or it
    is not positive. A total of 0 would otherwise raise ZeroDivisionError in the
    ratios computed from it.
    """
    fallback = settings.default_samples_total
    sum_path = _find_experiment_summary(folder_store, run_id, noise, alpha)
    if sum_path is None:
        return fallback

    df_sum = _read_csv_normalized(
        sum_path,
        force_str_cols={"timestamp"},
        numeric_ratio_threshold=settings.numeric_ratio_threshold,
    )
    if "samples_total" not in df_sum.columns or len(df_sum) == 0:
        return fallback

    try:
        value = int(df_sum["samples_total"].iloc[0])
    except (TypeError, ValueError, OverflowError):  # NaN / NA / non-numeric text / inf
        return fallback
    return value if value > 0 else fallback


def _class_balance_rows(df: pd.DataFrame, clean_mask: pd.Series, key: Dict) -> Tuple[List[Dict], List[Dict]]:
    """
    Build per-class kept counts for one train_kept file.

    The class is taken from ``class_name``, else ``label_orig``; when neither exists
    both returned lists are empty. Returns ``(long_rows, clean_noisy_rows)``, each row
    starting with the identifying fields in ``key``.
    """
    if "class_name" in df.columns:
        class_col = "class_name"
    elif "label_orig" in df.columns:
        class_col = "label_orig"
    else:
        return [], []

    vc_total = df[class_col].value_counts()
    long_rows = [dict(key, class_name=str(cls), kept_count=int(cnt)) for cls, cnt in vc_total.items()]

    vc_clean = df.loc[clean_mask, class_col].value_counts()
    vc_noisy = df.loc[~clean_mask, class_col].value_counts()
    all_classes = set(vc_total.index.tolist()) | set(vc_clean.index.tolist()) | set(vc_noisy.index.tolist())

    cn_rows = [
        dict(
            key,
            class_name=str(cls),
            kept_count_total=int(vc_total.get(cls, 0)),
            kept_count_clean=int(vc_clean.get(cls, 0)),
            kept_count_noisy=int(vc_noisy.get(cls, 0)),
        )
        for cls in sorted(all_classes, key=str)
    ]
    return long_rows, cn_rows


def build_data2(settings: Settings) -> Path:
    """
    Compute filter-quality and class-balance tables from every train_kept_*.csv under
    ``settings.folder_store`` and write them to ``folder_store / settings.out_dir``.

    Per file: precision_kept = clean_kept / kept, and recall/F1 against the estimated
    clean_total = round((1 - noise_ratio) * samples_total).

    Returns the output directory. Raises FileNotFoundError when no kept files exist.
    Files whose path lacks parseable noise/alpha/iteration folders are skipped with a
    warning. If nothing usable remains, or no class column exists, the corresponding
    CSVs are written with headers only instead of crashing.
    """
    folder_store = settings.folder_store.resolve()
    out_dir = folder_store / settings.out_dir
    out_dir.mkdir(parents=True, exist_ok=True)

    kept_files = _find_kept_files(folder_store)
    if not kept_files:
        raise FileNotFoundError(
            "No train_kept_*.csv found. Expected noise_*/alpha_*/iteration_*/train_kept_*.csv."
        )

    key_cols = ["run_id", "noise_ratio", "alpha", "iteration"]
    quality_cols = key_cols + [
        "samples_total",
        "clean_total",
        "kept_count",
        "clean_kept_count",
        "noisy_kept_count",
        "kept_ratio",
        "clean_kept_ratio_of_total",
        "precision_kept",
        "recall_clean_total",
        "f1_clean_total",
        "noisy_kept_rate",
        "source_path",
    ]
    class_long_cols = key_cols + ["class_name", "kept_count"]
    class_cn_cols = key_cols + ["class_name", "kept_count_total", "kept_count_clean", "kept_count_noisy"]

    quality_rows: List[Dict] = []
    class_long_rows: List[Dict] = []
    class_cn_rows: List[Dict] = []
    # One experiment_summary.csv serves every iteration of a (run_id, noise, alpha); read it once.
    samples_total_cache: Dict[Tuple[Optional[str], float, float], int] = {}

    for run_id, csv_path in kept_files:
        noise, alpha, iter_folder, iter_file = _extract_meta_from_path(csv_path)
        if noise is None or alpha is None or iter_folder is None:
            print(f"[WARN] Skipping {csv_path}: cannot parse noise_*/alpha_*/iteration_* from its path")
            continue

        if settings.warn_iteration_mismatch and iter_file is not None and iter_file != iter_folder:
            print(
                f"[WARN] Iteration mismatch: folder iteration_{iter_folder} but filename train_kept_{iter_file}.csv "
                f"({csv_path})"
            )

        df = _read_csv_normalized(
            csv_path,
            force_str_cols={"image_path", "class_name", "split", "filter_flag"},
            numeric_ratio_threshold=settings.numeric_ratio_threshold,
        )
        clean_mask = _clean_mask(df, csv_path)

        kept_count = int(len(df))
        clean_kept = int(clean_mask.sum())
        noisy_kept = int((~clean_mask).sum())

        precision = clean_kept / kept_count if kept_count > 0 else 0.0
        noisy_kept_rate = noisy_kept / kept_count if kept_count > 0 else 0.0

        cache_key = (run_id, noise, alpha)
        if cache_key not in samples_total_cache:
            samples_total_cache[cache_key] = _samples_total_for(folder_store, run_id, noise, alpha, settings)
        samples_total = samples_total_cache[cache_key]

        # clean_total is an estimate derived from the nominal noise ratio, not a count.
        clean_total = max(int(round((1.0 - noise) * samples_total)), 1)
        recall = clean_kept / clean_total
        f1 = (2 * precision * recall) / (precision + recall + 1e-12)

        key = dict(
            run_id=run_id if run_id else "default",
            noise_ratio=noise,
            alpha=alpha,
            iteration=iter_folder,
        )
        quality_rows.append(
            dict(
                key,
                samples_total=samples_total,
                clean_total=clean_total,
                kept_count=kept_count,
                clean_kept_count=clean_kept,
                noisy_kept_count=noisy_kept,
                kept_ratio=kept_count / samples_total,
                clean_kept_ratio_of_total=clean_kept / samples_total,
                precision_kept=precision,
                recall_clean_total=recall,
                f1_clean_total=f1,
                noisy_kept_rate=noisy_kept_rate,
                source_path=str(csv_path),
            )
        )

        long_rows, cn_rows = _class_balance_rows(df, clean_mask, key)
        class_long_rows.extend(long_rows)
        class_cn_rows.extend(cn_rows)

    if not quality_rows:
        print("[WARN] No usable train_kept_*.csv files; Data2 tables will be empty.")

    # Explicit columns keep the headers (and make sort_values valid) even when a list is empty.
    df_quality = pd.DataFrame(quality_rows, columns=quality_cols).sort_values(key_cols)
    df_class_long = pd.DataFrame(class_long_rows, columns=class_long_cols).sort_values(key_cols + ["class_name"])
    df_class_cn = pd.DataFrame(class_cn_rows, columns=class_cn_cols).sort_values(key_cols + ["class_name"])

    outputs = {
        "filter_quality_all.csv": df_quality,
        "class_balance_long.csv": df_class_long,
        "class_balance_clean_noisy_long.csv": df_class_cn,
    }
    for fname, frame in outputs.items():
        out_path = out_dir / fname
        if settings.overwrite or not out_path.exists():
            frame.to_csv(out_path, index=False)

    return out_dir


if __name__ == "__main__":
    # ====== INPUT SETTINGS (edit here) ======
    folder_store = "store_output_cifar10_iter_ema_noise_validation_v2"
    # =======================================

    # Build Data2 (default location: <folder_store>/data_to_report/data2) and report where it went.
    out = build_data2(
        Settings(folder_store=Path(folder_store)),
    )
    print(f"[OK] Data2 written to: {out}")


[OK] Data2 written to: /mnt/c/Users/truon/learning/ptit/research/trung/M_10_01_2025/code_v2/project/notebooks/store_output_cifar10_iter_ema_noise_validation_v2/data_to_report/data2


# Script 3: generate_report_assets.py

In [3]:
#!/usr/bin/env python3
"""
generate_report_assets.py (UPGRADED)

- Generate tables/figures for report from prepared Data1/Data2
- Add Table 3: filter_summary_table.csv (Precision/Recall/F1 at selected best alpha+iteration)
- Create captions.md with Table 1–3 and figure captions
- Minimal annotation: highlight only:
    ★ Selected (by noisy-val rule)
    ▲ Oracle best (by test_acc) for analysis

Prerequisites:
folder_store/
  data_to_report/
    data1/
      experiment_summary_all.csv
      alpha_sweep_ready.csv
      best_alpha_per_noise.csv
      main_table_ready.csv
    data2/
      filter_quality_all.csv
      class_balance_long.csv (optional)
      class_balance_clean_noisy_long.csv (optional)

Outputs:
folder_store/report_assets/
  tables/
    main_table_ready.csv
    best_alpha_per_noise.csv
    filter_summary_table.csv   <-- NEW
  figures/
    ...
  captions.md
"""
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Set

import matplotlib.pyplot as plt
import pandas as pd


# -------------------------
# Robust CSV normalization
# -------------------------
def read_csv_normalized(
    path: Path,
    force_str_cols: Optional[Set[str]] = None,
    numeric_ratio_threshold: float = 0.95,
) -> pd.DataFrame:
    """
    Read a CSV and normalize column types.

      - Columns listed in ``force_str_cols`` are cast to ``str`` as-is.
      - Columns pandas already parsed as numeric are left alone.
      - Every other column is probed as a number after turning decimal commas
        into dots (text-like columns only). If at least ``numeric_ratio_threshold``
        of its non-empty values parse, the column becomes numeric
        (values that do not parse become NaN).
      - Otherwise the column is kept as strings with its ORIGINAL text: the
        comma->dot rewrite is used only for the numeric probe, so free text such
        as "a, b" is not silently turned into "a. b".
    """
    if force_str_cols is None:
        force_str_cols = set()

    df = pd.read_csv(path)

    for col in df.columns:
        if col in force_str_cols:
            df[col] = df[col].astype(str)
            continue

        s = df[col]
        if pd.api.types.is_numeric_dtype(s):
            continue

        s_text = s.astype(str)
        # Decimal commas can only occur in text-like (object / string dtype) columns.
        if pd.api.types.is_object_dtype(s) or pd.api.types.is_string_dtype(s):
            s_probe = s_text.str.replace(",", ".", regex=False)
        else:
            s_probe = s_text

        s_num = pd.to_numeric(s_probe, errors="coerce")

        # Only non-empty values count toward the conversion ratio; missing floats
        # are rendered as "nan" by astype(str).
        s_norm = s_probe.str.strip()
        non_empty = s_norm.ne("") & s_norm.str.lower().ne("nan")
        denom = int(non_empty.sum())

        ok = int(s_num[non_empty].notna().sum()) if denom else 0
        if denom and ok / denom >= numeric_ratio_threshold:
            df[col] = s_num
        else:
            df[col] = s_text

    return df


# -------------------------
# Settings
# -------------------------
@dataclass(frozen=True)
class Settings:
    """Immutable configuration for generating report tables/figures from the prepared Data1/Data2 CSVs."""

    # Root folder that contains data_to_report/data1 and data_to_report/data2.
    folder_store: Path
    # Output folder name; presumably resolved relative to folder_store (its use is not visible here).
    out_dir: str = "report_assets"
    # Image resolution; presumably passed to savefig / make_table_preview_png by the caller.
    dpi: int = 300
    # Minimum fraction of non-empty values that must parse as numbers for read_csv_normalized to make a column numeric.
    numeric_ratio_threshold: float = 0.95

    # columns
    # Column names of the metrics read from the prepared CSVs; kept_ratio_col must be present in
    # experiment_summary_all for the filter summary table. Use of the others is not visible here.
    test_acc_col: str = "test_acc_reported"
    val_noisy_col: str = "val_acc_noisy"
    val_orig_col: str = "val_acc_orig"
    kept_ratio_col: str = "kept_ratio"

    # annotation
    # Decimal places for numeric annotations (presumably formatted with fmt(); confirm at call sites).
    ann_decimals: int = 4


# -------------------------
# Helpers
# -------------------------
def ensure_exists(path: Path, desc: str) -> None:
    """Raise FileNotFoundError (mentioning ``desc`` and ``path``) unless ``path`` exists."""
    if path.exists():
        return
    raise FileNotFoundError(f"Missing {desc}: {path}")


def savefig(path: Path, dpi: int) -> None:
    """
    Save the current pyplot figure to ``path`` (creating parent folders), then close it.

    The figure is closed even when creating the folder or saving fails, so a failed
    save cannot leave a stale figure open for the next plot to draw onto.
    """
    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(path, dpi=dpi, bbox_inches="tight")
    finally:
        plt.close()


def expected_clean_ratio(noise_ratio: float) -> float:
    """Return ``1 - noise_ratio`` clipped into [0, 1], i.e. the clean fraction implied by the noise ratio."""
    clean = 1.0 - float(noise_ratio)
    # Upper clip first, then lower clip; this order is kept deliberately (a NaN input yields 1.0).
    upper_bounded = min(1.0, clean)
    return max(0.0, upper_bounded)


def fmt(x: float, d: int) -> str:
    """Format ``x`` (anything ``float()`` accepts) as fixed-point text with ``d`` decimals."""
    return format(float(x), f".{d}f")


def make_table_preview_png(df: pd.DataFrame, out_path: Path, dpi: int, max_rows: int = 16) -> None:
    """
    Render the first ``max_rows`` rows of ``df`` as a matplotlib table image at ``out_path``.

    A frame with no rows is rendered as its header plus one blank row, because
    matplotlib's ``table()`` reads ``cellText[0]`` and raises IndexError on zero rows.
    The figure is always closed, even if rendering or saving fails.
    """
    view = df.head(max_rows)
    cell_text = view.values
    if len(view) == 0:
        cell_text = [[""] * len(view.columns)]

    fig, ax = plt.subplots()
    try:
        ax.axis("off")
        tbl = ax.table(cellText=cell_text, colLabels=view.columns, loc="center")
        tbl.auto_set_font_size(False)
        tbl.set_fontsize(8)
        tbl.scale(1, 1.2)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, dpi=dpi, bbox_inches="tight")
    finally:
        plt.close(fig)


def load_inputs(settings: Settings):
    """
    Locate, validate and read the prepared Data1/Data2 CSVs under ``settings.folder_store``.

    Returns ``(df_all, df_alpha, df_best_alpha, df_main, df_fq, df_cb_long, df_cb_cn)``.
    The two class-balance tables are optional and come back as empty DataFrames when
    their files are absent. Every other input raises FileNotFoundError if missing.
    """
    root = settings.folder_store.resolve()
    d1 = root / "data_to_report" / "data1"
    d2 = root / "data_to_report" / "data2"

    p_all = d1 / "experiment_summary_all.csv"
    p_alpha = d1 / "alpha_sweep_ready.csv"
    p_best_alpha = d1 / "best_alpha_per_noise.csv"
    p_main = d1 / "main_table_ready.csv"

    p_fq = d2 / "filter_quality_all.csv"
    p_cb_long = d2 / "class_balance_long.csv"
    p_cb_cn = d2 / "class_balance_clean_noisy_long.csv"

    ensure_exists(p_all, "Data1 experiment_summary_all.csv")
    ensure_exists(p_alpha, "Data1 alpha_sweep_ready.csv")
    ensure_exists(p_best_alpha, "Data1 best_alpha_per_noise.csv")
    ensure_exists(p_main, "Data1 main_table_ready.csv")
    ensure_exists(p_fq, "Data2 filter_quality_all.csv")

    def _read(path: Path, str_cols: Set[str]) -> pd.DataFrame:
        return read_csv_normalized(
            path,
            force_str_cols=str_cols,
            numeric_ratio_threshold=settings.numeric_ratio_threshold,
        )

    def _read_optional(path: Path, str_cols: Set[str]) -> pd.DataFrame:
        return _read(path, str_cols) if path.exists() else pd.DataFrame()

    # alpha_sweep_ready and best_alpha_per_noise carry the same "timestamp" column as
    # experiment_summary_all, so keep it as text everywhere. Otherwise numeric-looking
    # timestamps would end up numeric in one frame and string in another.
    df_all = _read(p_all, {"timestamp", "source_path", "run_id"})
    df_alpha = _read(p_alpha, {"timestamp", "source_path", "run_id"})
    df_best_alpha = _read(p_best_alpha, {"timestamp", "source_path", "run_id"})
    df_main = _read(p_main, {"run_id", "Method"})
    df_fq = _read(p_fq, {"source_path", "run_id"})

    df_cb_long = _read_optional(p_cb_long, {"run_id", "class_name"})
    df_cb_cn = _read_optional(p_cb_cn, {"run_id", "class_name"})

    return df_all, df_alpha, df_best_alpha, df_main, df_fq, df_cb_long, df_cb_cn


# -------------------------
# Table 3: Filter summary
# -------------------------
def _best_f1_row(fq_sub: pd.DataFrame) -> pd.Series:
    """Return the row of ``fq_sub`` with the highest non-missing ``f1_clean_total``.

    Ties resolve to the first such row (``idxmax`` semantics). If every F1 value is
    missing, the row with the smallest ``iteration`` is returned instead, so an
    all-NaN F1 column does not make ``idxmax`` fail.
    """
    f1 = fq_sub["f1_clean_total"]
    if f1.notna().any():
        # .loc[[label]].iloc[0] tolerates duplicate index labels (first match wins).
        return fq_sub.loc[[f1.idxmax()]].iloc[0]
    return fq_sub.sort_values("iteration").iloc[0]


def build_filter_summary_table(
    settings: Settings,
    df_best_alpha: pd.DataFrame,
    df_fq: pd.DataFrame,
    df_all: pd.DataFrame,
) -> pd.DataFrame:
    """
    Create a compact filter summary at the selected best alpha + selected best iteration.

    For each unique (run_id, noise_ratio) in ``df_best_alpha``, the filter-quality row is
    taken at ``best_iteration`` when that column exists and the iteration is present in
    ``df_fq``. Otherwise the iteration with the highest ``f1_clean_total`` is used, and
    ``best_iteration`` is updated to that iteration. Pairs with no filter-quality rows at
    the selected alpha are skipped.

    Output columns (always present, even when no pair matches):
        run_id, noise_ratio, best_alpha, best_iteration,
        kept_ratio, training_samples_used,
        precision_kept, recall_clean_total, f1_clean_total

    ``kept_ratio`` is NaN and ``training_samples_used`` is -1 when the matching
    experiment-summary row, or its value, is missing.

    Raises:
        ValueError: if any input frame lacks a required column.
    """
    required_best = {"run_id", "noise_ratio", "best_alpha"}
    if not required_best.issubset(df_best_alpha.columns):
        raise ValueError(f"best_alpha_per_noise missing {sorted(required_best - set(df_best_alpha.columns))}")

    required_fq = {"run_id", "noise_ratio", "alpha", "iteration", "precision_kept", "recall_clean_total", "f1_clean_total"}
    if not required_fq.issubset(df_fq.columns):
        raise ValueError(f"filter_quality_all missing {sorted(required_fq - set(df_fq.columns))}")

    required_all = {"run_id", "noise_ratio", "alpha", "iteration", settings.kept_ratio_col, "training_samples_used"}
    if not required_all.issubset(df_all.columns):
        raise ValueError(f"experiment_summary_all missing {sorted(required_all - set(df_all.columns))}")

    out_cols = [
        "run_id",
        "noise_ratio",
        "best_alpha",
        "best_iteration",
        "kept_ratio",
        "training_samples_used",
        "precision_kept",
        "recall_clean_total",
        "f1_clean_total",
    ]

    rows: List[dict] = []
    for _, r in df_best_alpha.drop_duplicates(subset=["run_id", "noise_ratio"]).iterrows():
        run_id = str(r["run_id"])
        noise = float(r["noise_ratio"])
        best_alpha = float(r["best_alpha"])

        best_iter: Optional[int] = None
        if "best_iteration" in df_best_alpha.columns and pd.notna(r.get("best_iteration")):
            best_iter = int(r["best_iteration"])

        fq_sub = df_fq[(df_fq["run_id"] == run_id) & (df_fq["noise_ratio"] == noise) & (df_fq["alpha"] == best_alpha)]
        if fq_sub.empty:
            continue

        # Prefer the selected best_iter; fall back to the max-F1 iteration.
        fq_row: Optional[pd.Series] = None
        if best_iter is not None:
            at_best = fq_sub[fq_sub["iteration"] == best_iter]
            if not at_best.empty:
                fq_row = at_best.iloc[0]
        if fq_row is None:
            fq_row = _best_f1_row(fq_sub)
            best_iter = int(fq_row["iteration"])

        # kept_ratio + training_samples_used from the experiment summary at the same (alpha, iter)
        all_row = df_all[
            (df_all["run_id"] == run_id)
            & (df_all["noise_ratio"] == noise)
            & (df_all["alpha"] == best_alpha)
            & (df_all["iteration"] == best_iter)
        ].head(1)

        kept_ratio = float("nan")
        training_used = -1
        if not all_row.empty:
            kept_ratio = float(all_row[settings.kept_ratio_col].iloc[0])
            used = all_row["training_samples_used"].iloc[0]
            # The column can be float with NaN gaps; int(NaN) would raise.
            if pd.notna(used):
                training_used = int(used)

        rows.append(
            {
                "run_id": run_id,
                "noise_ratio": noise,
                "best_alpha": best_alpha,
                "best_iteration": best_iter,
                "kept_ratio": kept_ratio,
                "training_samples_used": training_used,
                "precision_kept": float(fq_row["precision_kept"]),
                "recall_clean_total": float(fq_row["recall_clean_total"]),
                "f1_clean_total": float(fq_row["f1_clean_total"]),
            }
        )

    # Explicit columns keep the schema (and the sort) valid when no pair matched.
    return pd.DataFrame(rows, columns=out_cols).sort_values(["run_id", "noise_ratio"]).reset_index(drop=True)


# -------------------------
# Export tables + captions
# -------------------------
def export_tables_and_captions(
    settings: Settings,
    df_main: pd.DataFrame,
    df_best_alpha: pd.DataFrame,
    df_filter_summary: pd.DataFrame,
    out_root: Path,
) -> None:
    """
    Write the report tables, the main-table preview image and ``captions.md``.

    Outputs:
      - ``out_root/tables/main_table_ready.csv``, ``best_alpha_per_noise.csv``,
        ``filter_summary_table.csv``
      - ``out_root/figures/main_table_preview.png``
      - ``out_root/captions.md``: table captions, then one section per unique
        (run_id, noise_ratio) of ``df_best_alpha``. Each section lists only the figures
        that already exist on disk when this function runs.
    """
    tables_dir = out_root / "tables"
    figs_dir = out_root / "figures"
    for target_dir in (tables_dir, figs_dir):
        target_dir.mkdir(parents=True, exist_ok=True)

    for frame, csv_name in (
        (df_main, "main_table_ready.csv"),
        (df_best_alpha, "best_alpha_per_noise.csv"),
        (df_filter_summary, "filter_summary_table.csv"),
    ):
        frame.to_csv(tables_dir / csv_name, index=False)

    make_table_preview_png(df_main, figs_dir / "main_table_preview.png", dpi=settings.dpi, max_rows=16)

    # (caption title, figure sub-directory); each file is <subdir>/<run_id>/noise_<noise>.png
    figure_specs = (
        ("Figure A. Ablation α vs test accuracy", "ablation_alpha_testacc"),
        ("Figure B. Kept ratio vs iteration (best α) + expected clean ratio (1-noise)", "kept_ratio_vs_iteration"),
        ("Figure C. Test accuracy vs iteration (best α)", "testacc_vs_iteration"),
        ("Figure D. Val noisy vs Val orig (scatter across α)", "val_noisy_vs_val_orig_scatter"),
        ("Figure E. Val curves across iterations (best α): val_noisy vs val_orig", "val_noisy_vs_val_orig_iteration"),
        ("Figure F. Filter quality across iterations (best α): Precision/Recall/F1", "filter_quality_prf_vs_iteration"),
        ("Figure G. Class balance at best iteration (kept samples)", "class_balance_bar_best"),
        ("Figure H. Class kept clean vs noisy at best iteration", "class_balance_clean_noisy_best"),
    )

    chunks: List[str] = [
        "# Captions for Report\n\n",
        "## Tables\n\n",
        "- **Table 1. Main results (Baseline vs Ours)** — `tables/main_table_ready.csv`\n",
        "- **Table 2. Selected best EMA momentum α per noise (by noisy-val)** — `tables/best_alpha_per_noise.csv`\n",
        "- **Table 3. Filter quality summary at selected best α & iteration (Precision/Recall/F1)** — `tables/filter_summary_table.csv`\n\n",
        "## Figures\n\n",
        "- Notation: ★ = selected by noisy-val (your selection rule), ▲ = oracle best by test_acc (analysis only)\n\n",
    ]

    pair_cols = ["run_id", "noise_ratio"]
    unique_pairs = df_best_alpha.drop_duplicates(subset=pair_cols)[pair_cols].sort_values(pair_cols)
    for _, pair in unique_pairs.iterrows():
        run_id = str(pair["run_id"])
        noise = float(pair["noise_ratio"])
        chunks.append(f"### Run `{run_id}` — Noise `{noise}`\n\n")
        for title, subdir in figure_specs:
            rel = f"{subdir}/{run_id}/noise_{noise}.png"
            if (figs_dir / rel).exists():
                chunks.append(f"- **{title}** — `figures/{rel}`\n")
        chunks.append("\n")

    (out_root / "captions.md").write_text("".join(chunks), encoding="utf-8")


# -------------------------
# Plot helpers (minimal annotation)
# -------------------------
def oracle_best_alpha_by_testacc(df_alpha: pd.DataFrame, run_id: str, noise: float, test_acc_col: str) -> Optional[pd.Series]:
    """
    Return the alpha-sweep row with the highest ``test_acc_col`` for one (run_id, noise_ratio).

    This is the oracle (test-set) choice, used for annotation/analysis only. Rows with a
    missing test accuracy are ignored. Returns ``None`` if no row matches the pair or if
    every matching row has a missing test accuracy.
    """
    in_pair = df_alpha.loc[(df_alpha["run_id"] == run_id) & (df_alpha["noise_ratio"] == noise)]
    if in_pair.empty:
        return None
    scored = in_pair.dropna(subset=[test_acc_col])
    if scored.empty:
        return None
    best_label = scored[test_acc_col].idxmax()
    return scored.loc[best_label]


def oracle_best_iter_by_testacc(df_all: pd.DataFrame, run_id: str, noise: float, alpha: float, test_acc_col: str) -> Optional[pd.Series]:
    """
    Return the experiment row with the highest ``test_acc_col`` for one (run_id, noise_ratio, alpha).

    This is the oracle (test-set) iteration, used for annotation/analysis only. Rows with a
    missing test accuracy are ignored. Returns ``None`` when nothing qualifies.
    """
    keep = (
        (df_all["run_id"] == run_id)
        & (df_all["noise_ratio"] == noise)
        & (df_all["alpha"] == alpha)
    )
    candidates = df_all.loc[keep]
    if candidates.empty:
        return None
    scored = candidates.dropna(subset=[test_acc_col])
    if scored.empty:
        return None
    return scored.loc[scored[test_acc_col].idxmax()]


# -------------------------
# Plots (minimal)
# -------------------------
def plot_ablation_alpha_testacc(settings: Settings, df_alpha: pd.DataFrame, df_best_alpha: pd.DataFrame, figs_dir: Path) -> None:
    """
    Plot test accuracy against EMA momentum α, one figure per (run_id, noise_ratio) group.

    The α selected by noisy-val (★, read from ``df_best_alpha``) is marked with a dashed
    vertical line and an annotation. The oracle α with the highest test accuracy (▲) is
    also marked. Each figure is saved through ``savefig`` to
    ``figs_dir/ablation_alpha_testacc/<run_id>/noise_<noise>.png``.

    Raises:
        ValueError: if ``df_alpha`` lacks run_id, noise_ratio, alpha or the test-accuracy column.
    """
    acc_col = settings.test_acc_col
    absent = sorted({"run_id", "noise_ratio", "alpha", acc_col}.difference(df_alpha.columns))
    if absent:
        raise ValueError(f"alpha_sweep_ready.csv missing columns: {absent}")

    for (run_key, noise_key), group in df_alpha.groupby(["run_id", "noise_ratio"]):
        run_id = str(run_key)
        noise = float(noise_key)
        curve = group.sort_values("alpha")

        # Selected (★) alpha/iteration recorded for this (run, noise).
        chosen = df_best_alpha[(df_best_alpha["run_id"] == run_id) & (df_best_alpha["noise_ratio"] == noise)].head(1)
        best_alpha = None
        best_iter = None
        if not chosen.empty:
            best_alpha = float(chosen["best_alpha"].iloc[0])
            if "best_iteration" in chosen.columns and pd.notna(chosen["best_iteration"].iloc[0]):
                best_iter = int(chosen["best_iteration"].iloc[0])

        # Oracle (▲) alpha by test accuracy — analysis only.
        oracle = oracle_best_alpha_by_testacc(df_alpha, run_id, noise, acc_col)

        plt.figure()
        plt.plot(curve["alpha"], curve[acc_col], marker="o")
        plt.xlabel("EMA momentum α")
        plt.ylabel("test acc")
        plt.title(f"Ablation: α vs test acc | run={run_id} | noise={noise}")
        plt.grid(True, alpha=0.3)

        if best_alpha is not None:
            hit = curve[curve["alpha"] == best_alpha].head(1)
            if not hit.empty:
                acc_sel = float(hit[acc_col].iloc[0])
                plt.scatter([best_alpha], [acc_sel], marker="*", s=180, label="Selected (★) by noisy-val")
                plt.axvline(best_alpha, linestyle="--")
                iter_suffix = "" if best_iter is None else f", iter={best_iter}"
                plt.annotate(
                    f"★ α={fmt(best_alpha, 2)}\nacc={fmt(acc_sel, settings.ann_decimals)}{iter_suffix}",
                    (best_alpha, acc_sel),
                    fontsize=9,
                    xytext=(10, -25),
                    textcoords="offset points",
                )

        if oracle is not None:
            alpha_or = float(oracle["alpha"])
            acc_or = float(oracle[acc_col])
            plt.scatter([alpha_or], [acc_or], marker="^", s=90, label="Oracle (▲) by test_acc")
            plt.annotate(
                f"▲ α={fmt(alpha_or, 2)}\nacc={fmt(acc_or, settings.ann_decimals)}",
                (alpha_or, acc_or),
                fontsize=9,
                xytext=(10, 10),
                textcoords="offset points",
            )

        plt.legend(loc="best", fontsize=8)
        savefig(figs_dir / "ablation_alpha_testacc" / run_id / f"noise_{noise}.png", settings.dpi)


def plot_kept_ratio_vs_iteration(settings: Settings, df_all: pd.DataFrame, df_best_alpha: pd.DataFrame, figs_dir: Path) -> None:
    """
    Plot the kept ratio per iteration at the selected α, one figure per (run_id, noise_ratio).

    A dashed horizontal line shows ``expected_clean_ratio(noise)``. The selected iteration
    (★) is annotated when ``df_best_alpha`` provides one. Pairs with no experiment rows at
    the selected α are skipped.

    Raises:
        ValueError: if ``df_all`` lacks a required column.
    """
    ratio_col = settings.kept_ratio_col
    absent = sorted({"run_id", "noise_ratio", "alpha", "iteration", ratio_col}.difference(df_all.columns))
    if absent:
        raise ValueError(f"experiment_summary_all.csv missing columns: {absent}")

    for _, sel in df_best_alpha.drop_duplicates(subset=["run_id", "noise_ratio"]).iterrows():
        run_id = str(sel["run_id"])
        noise = float(sel["noise_ratio"])
        best_alpha = float(sel["best_alpha"])
        # .get yields None when the column is absent, which pd.isna treats as missing.
        raw_iter = sel.get("best_iteration")
        best_iter = None if pd.isna(raw_iter) else int(raw_iter)

        in_scope = (
            (df_all["run_id"] == run_id)
            & (df_all["noise_ratio"] == noise)
            & (df_all["alpha"] == best_alpha)
        )
        curve = df_all.loc[in_scope]
        if curve.empty:
            continue
        curve = curve.sort_values("iteration")

        plt.figure()
        plt.plot(curve["iteration"], curve[ratio_col], marker="o", label="kept_ratio")
        plt.axhline(expected_clean_ratio(noise), linestyle="--", label="expected clean ratio (1-noise)")
        plt.xlabel("Iteration")
        plt.ylabel(ratio_col)
        plt.title(f"Kept ratio vs iteration | run={run_id} | noise={noise} | selected α={fmt(best_alpha, 2)}")
        plt.grid(True, alpha=0.3)

        if best_iter is not None:
            at_sel = curve[curve["iteration"] == best_iter].head(1)
            if not at_sel.empty:
                kept_sel = float(at_sel[ratio_col].iloc[0])
                plt.scatter([best_iter], [kept_sel], marker="*", s=180, label="Selected iter (★)")
                plt.annotate(
                    f"★ iter={best_iter}\nkept={fmt(kept_sel, settings.ann_decimals)}",
                    (best_iter, kept_sel),
                    fontsize=9,
                    xytext=(10, 10),
                    textcoords="offset points",
                )

        plt.legend(loc="best", fontsize=8)
        savefig(figs_dir / "kept_ratio_vs_iteration" / run_id / f"noise_{noise}.png", settings.dpi)


def plot_testacc_vs_iteration(settings: Settings, df_all: pd.DataFrame, df_best_alpha: pd.DataFrame, figs_dir: Path) -> None:
    """
    Plot test accuracy per iteration at the selected α, one figure per (run_id, noise_ratio).

    The selected iteration (★, from ``df_best_alpha``) and the oracle iteration with the
    highest test accuracy at that α (▲) are annotated. Pairs with no experiment rows at
    the selected α are skipped.

    Raises:
        ValueError: if ``df_all`` lacks a required column.
    """
    acc_col = settings.test_acc_col
    absent = sorted({"run_id", "noise_ratio", "alpha", "iteration", acc_col}.difference(df_all.columns))
    if absent:
        raise ValueError(f"experiment_summary_all.csv missing columns: {absent}")

    for _, sel in df_best_alpha.drop_duplicates(subset=["run_id", "noise_ratio"]).iterrows():
        run_id = str(sel["run_id"])
        noise = float(sel["noise_ratio"])
        best_alpha = float(sel["best_alpha"])
        # .get yields None when the column is absent, which pd.isna treats as missing.
        raw_iter = sel.get("best_iteration")
        best_iter = None if pd.isna(raw_iter) else int(raw_iter)

        in_scope = (
            (df_all["run_id"] == run_id)
            & (df_all["noise_ratio"] == noise)
            & (df_all["alpha"] == best_alpha)
        )
        curve = df_all.loc[in_scope]
        if curve.empty:
            continue
        curve = curve.sort_values("iteration")

        plt.figure()
        plt.plot(curve["iteration"], curve[acc_col], marker="o", label="test_acc")
        plt.xlabel("Iteration")
        plt.ylabel("test acc")
        plt.title(f"Test acc vs iteration | run={run_id} | noise={noise} | selected α={fmt(best_alpha, 2)}")
        plt.grid(True, alpha=0.3)

        if best_iter is not None:
            at_sel = curve[curve["iteration"] == best_iter].head(1)
            if not at_sel.empty:
                acc_sel = float(at_sel[acc_col].iloc[0])
                plt.scatter([best_iter], [acc_sel], marker="*", s=180, label="Selected iter (★)")
                plt.annotate(
                    f"★ iter={best_iter}\nacc={fmt(acc_sel, settings.ann_decimals)}",
                    (best_iter, acc_sel),
                    fontsize=9,
                    xytext=(10, -25),
                    textcoords="offset points",
                )

        # Oracle (▲) iteration by test accuracy at the same α — analysis only.
        oracle_row = oracle_best_iter_by_testacc(df_all, run_id, noise, best_alpha, acc_col)
        if oracle_row is not None:
            iter_or = int(oracle_row["iteration"])
            acc_or = float(oracle_row[acc_col])
            plt.scatter([iter_or], [acc_or], marker="^", s=90, label="Oracle iter (▲) by test_acc")
            plt.annotate(
                f"▲ iter={iter_or}\nacc={fmt(acc_or, settings.ann_decimals)}",
                (iter_or, acc_or),
                fontsize=9,
                xytext=(10, 10),
                textcoords="offset points",
            )

        plt.legend(loc="best", fontsize=8)
        savefig(figs_dir / "testacc_vs_iteration" / run_id / f"noise_{noise}.png", settings.dpi)


def plot_val_noisy_vs_val_orig_scatter(settings: Settings, df_alpha: pd.DataFrame, df_best_alpha: pd.DataFrame, figs_dir: Path) -> None:
    """
    Scatter noisy-val accuracy against original-val accuracy across α, per (run_id, noise_ratio).

    The selected α (★) and the oracle α by test accuracy (▲) are highlighted. If a
    required column is missing, a warning is printed and nothing is plotted.
    """
    noisy_col = settings.val_noisy_col
    orig_col = settings.val_orig_col
    acc_col = settings.test_acc_col
    absent = sorted({"run_id", "noise_ratio", "alpha", noisy_col, orig_col, acc_col}.difference(df_alpha.columns))
    if absent:
        print(f"[WARN] Skip val scatter: missing {absent}")
        return

    for (run_key, noise_key), points in df_alpha.groupby(["run_id", "noise_ratio"]):
        run_id = str(run_key)
        noise = float(noise_key)

        chosen = df_best_alpha[(df_best_alpha["run_id"] == run_id) & (df_best_alpha["noise_ratio"] == noise)].head(1)
        best_alpha = None if chosen.empty else float(chosen["best_alpha"].iloc[0])

        oracle = oracle_best_alpha_by_testacc(df_alpha, run_id, noise, acc_col)

        plt.figure()
        plt.scatter(points[noisy_col], points[orig_col])
        plt.xlabel(f"{noisy_col} (best-by-noisy-val)")
        plt.ylabel(orig_col)
        plt.title(f"Val relation across α | run={run_id} | noise={noise}")
        plt.grid(True, alpha=0.3)

        if best_alpha is not None:
            hit = points[points["alpha"] == best_alpha].head(1)
            if not hit.empty:
                x_sel = float(hit[noisy_col].iloc[0])
                y_sel = float(hit[orig_col].iloc[0])
                plt.scatter([x_sel], [y_sel], marker="*", s=180, label="Selected α (★)")
                plt.annotate(
                    f"★ α={fmt(best_alpha, 2)}",
                    (x_sel, y_sel),
                    fontsize=9,
                    xytext=(10, -20),
                    textcoords="offset points",
                )

        if oracle is not None:
            alpha_or = float(oracle["alpha"])
            x_or = float(oracle[noisy_col])
            y_or = float(oracle[orig_col])
            plt.scatter([x_or], [y_or], marker="^", s=90, label="Oracle α (▲)")
            plt.annotate(
                f"▲ α={fmt(alpha_or, 2)}",
                (x_or, y_or),
                fontsize=9,
                xytext=(10, 10),
                textcoords="offset points",
            )

        plt.legend(loc="best", fontsize=8)
        savefig(figs_dir / "val_noisy_vs_val_orig_scatter" / run_id / f"noise_{noise}.png", settings.dpi)


def plot_val_noisy_vs_val_orig_iteration(settings: Settings, df_all: pd.DataFrame, df_best_alpha: pd.DataFrame, figs_dir: Path) -> None:
    """
    Plot noisy-val and original-val accuracy per iteration at the selected α.

    One figure per (run_id, noise_ratio) in ``df_best_alpha``. When a selected iteration is
    known, it gets a dashed vertical line and a label at the curves' maximum height. If a
    required column is missing, a warning is printed and nothing is plotted.
    """
    noisy_col = settings.val_noisy_col
    orig_col = settings.val_orig_col
    absent = sorted({"run_id", "noise_ratio", "alpha", "iteration", noisy_col, orig_col}.difference(df_all.columns))
    if absent:
        print(f"[WARN] Skip val curves: missing {absent}")
        return

    for _, sel in df_best_alpha.drop_duplicates(subset=["run_id", "noise_ratio"]).iterrows():
        run_id = str(sel["run_id"])
        noise = float(sel["noise_ratio"])
        best_alpha = float(sel["best_alpha"])
        # .get yields None when the column is absent, which pd.isna treats as missing.
        raw_iter = sel.get("best_iteration")
        best_iter = None if pd.isna(raw_iter) else int(raw_iter)

        in_scope = (
            (df_all["run_id"] == run_id)
            & (df_all["noise_ratio"] == noise)
            & (df_all["alpha"] == best_alpha)
        )
        curve = df_all.loc[in_scope]
        if curve.empty:
            continue
        curve = curve.sort_values("iteration")

        plt.figure()
        for col in (noisy_col, orig_col):
            plt.plot(curve["iteration"], curve[col], marker="o", label=col)
        plt.xlabel("Iteration")
        plt.ylabel("Validation accuracy")
        plt.title(f"Val curves vs iteration | run={run_id} | noise={noise} | selected α={fmt(best_alpha, 2)}")
        plt.legend()
        plt.grid(True, alpha=0.3)

        if best_iter is not None:
            top_y = float(curve[[noisy_col, orig_col]].max().max())
            plt.axvline(best_iter, linestyle="--")
            plt.annotate(f"★ iter={best_iter}", (best_iter, top_y), fontsize=9, xytext=(8, 0), textcoords="offset points")

        savefig(figs_dir / "val_noisy_vs_val_orig_iteration" / run_id / f"noise_{noise}.png", settings.dpi)


def plot_filter_quality_prf(settings: Settings, df_fq: pd.DataFrame, df_best_alpha: pd.DataFrame, figs_dir: Path) -> None:
    """
    Plot filter precision / recall / F1 per iteration at the selected α.

    One figure per (run_id, noise_ratio) in ``df_best_alpha``. Pairs with no
    filter-quality rows at the selected α are skipped. When a selected iteration is
    known, it is marked with a dashed vertical line and a ★ on the F1 curve.

    Raises:
        ValueError: if ``df_fq`` lacks a required column.
    """
    req = {"run_id", "noise_ratio", "alpha", "iteration", "precision_kept", "recall_clean_total", "f1_clean_total"}
    missing = req - set(df_fq.columns)
    if missing:
        raise ValueError(f"filter_quality_all.csv missing columns: {sorted(missing)}")

    for _, r in df_best_alpha.drop_duplicates(subset=["run_id", "noise_ratio"]).iterrows():
        run_id = str(r["run_id"])
        noise = float(r["noise_ratio"])
        best_alpha = float(r["best_alpha"])
        best_iter = int(r["best_iteration"]) if ("best_iteration" in df_best_alpha.columns and pd.notna(r.get("best_iteration"))) else None

        sub = df_fq[
            (df_fq["run_id"] == run_id)
            & (df_fq["noise_ratio"] == noise)
            & (df_fq["alpha"] == best_alpha)
        ].copy()
        if sub.empty:
            continue
        sub = sub.sort_values("iteration")

        plt.figure()
        plt.plot(sub["iteration"], sub["precision_kept"], marker="o", label="precision_kept")
        plt.plot(sub["iteration"], sub["recall_clean_total"], marker="o", label="recall_clean_total")
        plt.plot(sub["iteration"], sub["f1_clean_total"], marker="o", label="f1_clean_total")
        plt.xlabel("Iteration")
        plt.ylabel("Score")
        plt.title(f"Filter quality P/R/F1 | run={run_id} | noise={noise} | selected α={fmt(best_alpha, 2)}")
        plt.grid(True, alpha=0.3)

        if best_iter is not None:
            plt.axvline(best_iter, linestyle="--")
            row_it = sub[sub["iteration"] == best_iter].head(1)
            if not row_it.empty:
                y = float(row_it["f1_clean_total"].iloc[0])
                plt.scatter([best_iter], [y], marker="*", s=180, label="Selected iter (★) on F1")
                plt.annotate(f"★ F1={fmt(y, settings.ann_decimals)}", (best_iter, y),
                             fontsize=9, xytext=(10, 10), textcoords="offset points")

        # Build the legend only after every labelled artist exists, so the ★ marker's
        # label is included (same placement as the other plotting functions).
        plt.legend(loc="best", fontsize=8)
        savefig(figs_dir / "filter_quality_prf_vs_iteration" / run_id / f"noise_{noise}.png", settings.dpi)


def plot_class_balance_best(settings: Settings, df_cb_long: pd.DataFrame, df_best_alpha: pd.DataFrame, figs_dir: Path) -> None:
    """
    Bar chart of kept samples per class at the selected (α, iteration), per (run_id, noise_ratio).

    This is a silent no-op when ``df_cb_long`` is empty, lacks a required column, or when
    ``df_best_alpha`` has no ``best_iteration`` column. Pairs with a missing selected
    iteration, or with no class-balance rows, are skipped.
    """
    needed = {"run_id", "noise_ratio", "alpha", "iteration", "class_name", "kept_count"}
    if df_cb_long.empty or not needed.issubset(df_cb_long.columns) or "best_iteration" not in df_best_alpha.columns:
        return

    for _, sel in df_best_alpha.drop_duplicates(subset=["run_id", "noise_ratio"]).iterrows():
        run_id = str(sel["run_id"])
        noise = float(sel["noise_ratio"])
        best_alpha = float(sel["best_alpha"])
        raw_iter = sel.get("best_iteration")
        if pd.isna(raw_iter):
            continue
        best_iter = int(raw_iter)

        at_selection = (
            (df_cb_long["run_id"] == run_id)
            & (df_cb_long["noise_ratio"] == noise)
            & (df_cb_long["alpha"] == best_alpha)
            & (df_cb_long["iteration"] == best_iter)
        )
        counts = df_cb_long.loc[at_selection]
        if counts.empty:
            continue

        counts = counts.sort_values("class_name")
        class_labels = counts["class_name"].astype(str)
        plt.figure()
        plt.bar(class_labels, counts["kept_count"])
        plt.xlabel("Class")
        plt.ylabel("Kept count")
        plt.title(f"Class balance (kept) | run={run_id} | noise={noise}\nselected α={fmt(best_alpha, 2)}, selected iter={best_iter}")
        plt.xticks(rotation=45, ha="right")
        plt.grid(True, axis="y", alpha=0.3)
        savefig(figs_dir / "class_balance_bar_best" / run_id / f"noise_{noise}.png", settings.dpi)


def plot_class_balance_clean_noisy_best(settings: Settings, df_cb_cn: pd.DataFrame, df_best_alpha: pd.DataFrame, figs_dir: Path) -> None:
    """
    Stacked bar chart of clean vs noisy kept samples per class at the selected (α, iteration).

    Noisy counts are stacked on top of clean counts. This is a silent no-op when
    ``df_cb_cn`` is empty, lacks a required column, or when ``df_best_alpha`` has no
    ``best_iteration`` column. Pairs with a missing selected iteration, or with no rows,
    are skipped.
    """
    needed = {"run_id", "noise_ratio", "alpha", "iteration", "class_name", "kept_count_clean", "kept_count_noisy"}
    if df_cb_cn.empty or not needed.issubset(df_cb_cn.columns) or "best_iteration" not in df_best_alpha.columns:
        return

    for _, sel in df_best_alpha.drop_duplicates(subset=["run_id", "noise_ratio"]).iterrows():
        run_id = str(sel["run_id"])
        noise = float(sel["noise_ratio"])
        best_alpha = float(sel["best_alpha"])
        raw_iter = sel.get("best_iteration")
        if pd.isna(raw_iter):
            continue
        best_iter = int(raw_iter)

        at_selection = (
            (df_cb_cn["run_id"] == run_id)
            & (df_cb_cn["noise_ratio"] == noise)
            & (df_cb_cn["alpha"] == best_alpha)
            & (df_cb_cn["iteration"] == best_iter)
        )
        counts = df_cb_cn.loc[at_selection]
        if counts.empty:
            continue

        counts = counts.sort_values("class_name")
        class_labels = counts["class_name"].astype(str)
        clean_counts = counts["kept_count_clean"]
        plt.figure()
        plt.bar(class_labels, clean_counts, label="clean_kept")
        plt.bar(class_labels, counts["kept_count_noisy"], bottom=clean_counts, label="noisy_kept")
        plt.xlabel("Class")
        plt.ylabel("Kept count")
        plt.title(f"Class kept clean vs noisy | run={run_id} | noise={noise}\nselected α={fmt(best_alpha, 2)}, selected iter={best_iter}")
        plt.xticks(rotation=45, ha="right")
        plt.legend()
        plt.grid(True, axis="y", alpha=0.3)
        savefig(figs_dir / "class_balance_clean_noisy_best" / run_id / f"noise_{noise}.png", settings.dpi)


# -------------------------
# Main pipeline
# -------------------------
def generate_report_assets(settings: Settings) -> Path:
    """
    Build report tables, figures, ``captions.md`` and ``MANIFEST.txt`` under
    ``folder_store/out_dir``.

    Figures are rendered BEFORE tables and captions are exported, because
    ``export_tables_and_captions`` only links figures that already exist on disk. With the
    opposite order, a fresh output directory got captions with no figure entries.

    Returns:
        The output root directory (``folder_store`` resolved, joined with ``out_dir``).
    """
    root = settings.folder_store.resolve()
    out_root = root / settings.out_dir
    figs_dir = out_root / "figures"
    figs_dir.mkdir(parents=True, exist_ok=True)

    df_all, df_alpha, df_best_alpha, df_main, df_fq, df_cb_long, df_cb_cn = load_inputs(settings)

    # Table 3 (built first so malformed inputs fail before any rendering)
    df_filter_summary = build_filter_summary_table(settings, df_best_alpha, df_fq, df_all)

    # Figures (must exist before captions.md is written)
    plot_ablation_alpha_testacc(settings, df_alpha, df_best_alpha, figs_dir)
    plot_kept_ratio_vs_iteration(settings, df_all, df_best_alpha, figs_dir)
    plot_testacc_vs_iteration(settings, df_all, df_best_alpha, figs_dir)
    plot_val_noisy_vs_val_orig_scatter(settings, df_alpha, df_best_alpha, figs_dir)
    plot_val_noisy_vs_val_orig_iteration(settings, df_all, df_best_alpha, figs_dir)
    plot_filter_quality_prf(settings, df_fq, df_best_alpha, figs_dir)
    plot_class_balance_best(settings, df_cb_long, df_best_alpha, figs_dir)
    plot_class_balance_clean_noisy_best(settings, df_cb_cn, df_best_alpha, figs_dir)

    # Export tables + captions (captions link the figures rendered above)
    export_tables_and_captions(settings, df_main, df_best_alpha, df_filter_summary, out_root)

    (out_root / "MANIFEST.txt").write_text(
        "Generated tables/figures for report.\n"
        "- tables/main_table_ready.csv\n"
        "- tables/best_alpha_per_noise.csv\n"
        "- tables/filter_summary_table.csv\n"
        "- figures/*\n"
        "- captions.md\n",
        encoding="utf-8",
    )
    return out_root




if __name__ == "__main__":
    # ====== INPUT SETTINGS (edit here) ======
    folder_store = "store_output_cifar10_iter_ema_noise_validation_v2"
    # =======================================

    # Build settings with defaults, render every asset, then report the output root.
    report_settings = Settings(folder_store=Path(folder_store))
    out = generate_report_assets(report_settings)
    print(f"[OK] Report assets written to: {out}")


[OK] Report assets written to: /mnt/c/Users/truon/learning/ptit/research/trung/M_10_01_2025/code_v2/project/notebooks/store_output_cifar10_iter_ema_noise_validation_v2/report_assets


# Script 4: finalize_report_package.py

In [4]:
#!/usr/bin/env python3
"""
Finalize report package (tables + results template) from report_assets.

Prerequisite:
  - Run generate_report_assets.py first to create:
      folder_store/report_assets/tables/*.csv
      folder_store/report_assets/figures/**/*.png

This script generates:
  folder_store/report_assets/tables/
    main_table_ready.xlsx
    best_alpha_per_noise.xlsx
    main_table_ready.tex
    best_alpha_per_noise.tex

  folder_store/report_assets/
    results.md   (template to paste into Word/LaTeX, with links to figures/tables)

Notes:
- No seaborn used. No custom colors. PEP8-friendly.
- Excel formatting uses openpyxl.
"""
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple

import pandas as pd
from openpyxl import Workbook
from openpyxl.styles import Alignment, Font
from openpyxl.utils import get_column_letter


# -------------------------
# Robust CSV normalization
# -------------------------
def read_csv_normalized(
    path: Path,
    force_str_cols: Optional[Set[str]] = None,
    numeric_ratio_threshold: float = 0.95,
) -> pd.DataFrame:
    """
    Read a CSV and normalize column types.

    Per column:
      - columns in ``force_str_cols`` are cast to ``str`` (missing values become ``"nan"``);
      - columns pandas already parsed as numeric are left unchanged;
      - text columns (object or pandas string dtype) are re-parsed as numbers after
        replacing decimal commas with dots (``"0,5"`` -> ``0.5``). The conversion is kept
        when at least ``numeric_ratio_threshold`` of their non-empty, non-"nan" values
        parse; unparseable cells become NaN;
      - every other column is stored as ``str`` with its ORIGINAL text. The comma
        replacement is only used for the numeric test, so e.g. ``"a, b"`` stays ``"a, b"``.

    Args:
        path: CSV file to read.
        force_str_cols: column names that must always be strings (e.g. ids, paths).
        numeric_ratio_threshold: minimum parse ratio for converting a column to numeric.

    Returns:
        The normalized DataFrame.
    """
    if force_str_cols is None:
        force_str_cols = set()

    df = pd.read_csv(path)

    for col in df.columns:
        if col in force_str_cols:
            df[col] = df[col].astype(str)
            continue

        s = df[col]
        if pd.api.types.is_numeric_dtype(s):
            continue

        s_raw = s.astype(str)
        # Decimal-comma handling applies to text columns only (object or pandas string dtype).
        is_text = pd.api.types.is_object_dtype(s) or pd.api.types.is_string_dtype(s)
        s_str = s_raw.str.replace(",", ".", regex=False) if is_text else s_raw

        s_num = pd.to_numeric(s_str, errors="coerce")

        s_norm = s_str.str.strip()
        non_empty = s_norm.ne("") & s_norm.str.lower().ne("nan")
        denom = int(non_empty.sum())
        if denom == 0:
            df[col] = s_raw
            continue

        ok = int(s_num[non_empty].notna().sum())
        # Non-numeric columns keep their original text, not the comma-replaced probe.
        df[col] = s_num if ok / denom >= numeric_ratio_threshold else s_raw

    return df


# -------------------------
# Settings
# -------------------------
@dataclass(frozen=True)
class Settings:
    """Immutable configuration for the report-finalization step."""

    # Experiment output root; per the module docstring the assets live under
    # folder_store/report_assets/.
    folder_store: Path
    # Name of the assets sub-directory (default matches "report_assets" in the module
    # docstring) — presumably joined onto folder_store; confirm at the call site.
    out_dir: str = "report_assets"
    # Same meaning as read_csv_normalized's numeric_ratio_threshold (minimum share of
    # non-empty values that must parse as numbers); presumably forwarded to it.
    numeric_ratio_threshold: float = 0.95
    # Decimal places for float cells; presumably passed as float_decimals to
    # write_excel_pretty / write_latex_table — confirm at the call site.
    float_decimals: int = 4

    # Table files expected from previous step
    main_table_csv: str = "main_table_ready.csv"
    best_alpha_csv: str = "best_alpha_per_noise.csv"


# -------------------------
# Excel helpers (openpyxl)
# -------------------------
def _is_float_like(col: pd.Series) -> bool:
    """Return True when ``col`` has a float or integer dtype.

    Despite the name, integer columns count too; callers use this to decide whether to
    apply a numeric Excel number format.
    """
    dtype_checks = (pd.api.types.is_float_dtype, pd.api.types.is_integer_dtype)
    return any(check(col) for check in dtype_checks)


def _best_width_for_values(values: List[str], min_w: int = 10, max_w: int = 40) -> int:
    """Return a column width of (longest string + 2), clamped to ``[min_w, max_w]``.

    An empty ``values`` list yields ``min_w``.
    """
    if not values:
        return min_w
    longest = max(map(len, values))
    capped = min(max_w, longest + 2)
    return max(min_w, capped)


def _excel_cell_value(value):
    """Return ``None`` for missing scalars (None/NaN/NaT/pd.NA), else ``value`` unchanged.

    openpyxl does not map missing values to empty cells: a float NaN is stored as a
    non-standard numeric value, and ``pd.NA`` is rejected. Missing values are therefore
    blanked explicitly.
    """
    try:
        if pd.isna(value):
            return None
    except (TypeError, ValueError):
        # pd.isna on a non-scalar returns an array with no single truth value; keep as-is.
        pass
    return value


def write_excel_pretty(df: pd.DataFrame, out_path: Path, sheet_name: str, float_decimals: int) -> None:
    """
    Write a DataFrame to Excel with clean, report-friendly formatting.

      - header bold + centered (wrapped)
      - freeze top row
      - autofilter over the written range
      - column widths from header + first 200 values
      - numeric format for float/integer columns (``float_decimals`` places for floats)
      - missing values (NaN / pd.NA / None) written as empty cells

    Args:
        df: table to write (index is not written).
        out_path: destination ``.xlsx``; parent directories are created.
        sheet_name: worksheet title, truncated to Excel's 31-character limit.
        float_decimals: decimal places in the number format of float columns.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)

    wb = Workbook()
    ws = wb.active
    ws.title = sheet_name[:31]  # Excel sheet name limit

    # Write header
    header_font = Font(bold=True)
    center = Alignment(horizontal="center", vertical="center", wrap_text=True)

    for j, col_name in enumerate(df.columns, start=1):
        cell = ws.cell(row=1, column=j, value=str(col_name))
        cell.font = header_font
        cell.alignment = center

    # Write rows (missing values become empty cells)
    for i, row in enumerate(df.itertuples(index=False), start=2):
        for j, value in enumerate(row, start=1):
            ws.cell(row=i, column=j, value=_excel_cell_value(value))

    # Freeze header
    ws.freeze_panes = "A2"

    # Auto filter
    ws.auto_filter.ref = ws.dimensions

    # Column widths + formats
    for j, col_name in enumerate(df.columns, start=1):
        col_letter = get_column_letter(j)
        col_series = df[col_name]

        # width
        sample_vals = [str(col_name)]
        sample_vals += [str(v) for v in col_series.head(200).tolist()]
        ws.column_dimensions[col_letter].width = _best_width_for_values(sample_vals)

        # number format
        if _is_float_like(col_series):
            fmt = "0" if pd.api.types.is_integer_dtype(col_series) else ("0." + "0" * float_decimals)
            for i in range(2, 2 + len(df)):
                ws.cell(row=i, column=j).number_format = fmt

    wb.save(out_path)


# -------------------------
# LaTeX helpers
# -------------------------
def write_latex_table(
    df: pd.DataFrame,
    out_path: Path,
    caption: str,
    label: str,
    float_decimals: int,
) -> None:
    """
    Export DataFrame to LaTeX (booktabs). You can paste into LaTeX directly.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)

    float_fmt = f"%.{float_decimals}f"

    latex_body = df.to_latex(
        index=False,
        escape=False,
        float_format=lambda x: float_fmt % x if pd.notna(x) else "",
        longtable=False,
        caption=caption,
        label=label,
        na_rep="",
        bold_rows=False,
    )

    # Ensure booktabs is used (pandas already uses \toprule etc when available)
    out_path.write_text(latex_body, encoding="utf-8")


# -------------------------
# Results.md generator
# -------------------------
def _figure_path(figs_dir: Path, rel: str) -> Optional[Path]:
    p = figs_dir / rel
    return p if p.exists() else None


def _collect_run_noise_pairs(best_alpha: pd.DataFrame) -> List[Tuple[str, float]]:
    pairs = []
    for _, r in best_alpha.drop_duplicates(subset=["run_id", "noise_ratio"]).iterrows():
        run_id = str(r["run_id"])
        noise = float(r["noise_ratio"])
        pairs.append((run_id, noise))
    pairs.sort(key=lambda x: (x[0], x[1]))
    return pairs


def _safe_float(x) -> Optional[float]:
    try:
        return float(x)
    except Exception:
        return None


# (title, figure sub-folder) pairs embedded per (run_id, noise_ratio) section;
# files are expected at figures/<subdir>/<run_id>/noise_<noise>.png.
_RESULTS_MD_FIGURES: Tuple[Tuple[str, str], ...] = (
    ("Ablation α vs test acc", "ablation_alpha_testacc"),
    ("Kept ratio vs iteration (best α) + expected clean ratio line", "kept_ratio_vs_iteration"),
    ("Test acc vs iteration (best α)", "testacc_vs_iteration"),
    ("Noisy-val vs Orig-val (scatter across α)", "val_noisy_vs_val_orig_scatter"),
    ("Val curves across iterations (best α): val_noisy vs val_orig", "val_noisy_vs_val_orig_iteration"),
    ("Filter quality across iterations (best α): Precision/Recall/F1", "filter_quality_prf_vs_iteration"),
    ("Class balance at best iteration (kept samples)", "class_balance_bar_best"),
    ("Class kept: clean vs noisy at best iteration", "class_balance_clean_noisy_best"),
)


def _results_md_table_links(tables_dir: Path) -> List[str]:
    """Bullets naming the main-table exports, with paths relative to results.md."""
    # Relative paths (like the figure embeds) keep the report folder portable
    # and avoid leaking machine-specific absolute paths into the report.
    rel_dir = Path(tables_dir.name)
    out = [f"- Main table (CSV): `{(rel_dir / 'main_table_ready.csv').as_posix()}`\n"]
    for kind, fname in (("Excel", "main_table_ready.xlsx"), ("LaTeX", "main_table_ready.tex")):
        if (tables_dir / fname).exists():
            out.append(f"- Main table ({kind}): `{(rel_dir / fname).as_posix()}`\n")
    return out


def _results_md_key_observations(main_table: pd.DataFrame) -> List[str]:
    """
    Bullets with Δ test_acc_reported (Ours − Baseline) per noise ratio, and per
    run when ``main_table`` has a ``run_id`` column. Returns [] when the needed
    columns are missing.
    """
    if not {"noise_ratio", "Method", "test_acc_reported"}.issubset(main_table.columns):
        return []

    out = ["**Key observations (from the main table):**\n\n"]
    try:
        # Group per run as well when possible: with several runs, the first
        # baseline/ours rows of a noise-only group could come from different runs.
        if "run_id" in main_table.columns:
            groups = [
                (f"Run=`{run_id}`, Noise={noise}", sub)
                for (run_id, noise), sub in main_table.groupby(["run_id", "noise_ratio"])
            ]
        else:
            groups = [(f"Noise={noise}", sub) for noise, sub in main_table.groupby("noise_ratio")]

        for group_label, sub in groups:
            method = sub["Method"].astype(str).str.strip().str.lower()
            base = sub.loc[method == "baseline", "test_acc_reported"]
            ours = sub.loc[method == "ours", "test_acc_reported"]
            if base.empty or ours.empty:
                continue
            base_acc = _safe_float(base.iloc[0])
            ours_acc = _safe_float(ours.iloc[0])
            if base_acc is None or ours_acc is None:
                continue
            out.append(f"- {group_label}: Δ test_acc_reported = {ours_acc - base_acc:.4f}\n")
    except Exception:
        # Deliberate best effort: a malformed table must not block results.md.
        out.append("- (Could not auto-compute deltas reliably; check main table.)\n")
    out.append("\n")
    return out


def _results_md_section(row: pd.DataFrame, figs_dir: Path, run_id: str, noise: float) -> List[str]:
    """Markdown section for one (run_id, noise_ratio): selection summary, existing figures, note placeholders."""
    best_a = float(row["best_alpha"].iloc[0])
    best_it: Optional[int] = None
    if "best_iteration" in row.columns and pd.notna(row["best_iteration"].iloc[0]):
        best_it = int(row["best_iteration"].iloc[0])

    out = [
        f"### Run: `{run_id}` — Noise ratio: `{noise}`\n\n",
        f"- Best EMA momentum α: **{best_a}**\n",
    ]
    if best_it is not None:
        out.append(f"- Best iteration (selected by noisy-val): **{best_it}**\n")
    out.append("\n")

    # Embed only figures that exist; image paths are relative to results.md.
    for title, subdir in _RESULTS_MD_FIGURES:
        rel = f"{subdir}/{run_id}/noise_{noise}.png"
        if _figure_path(figs_dir, rel) is None:
            continue
        out.append(f"**{title}**\n\n")
        out.append(f"![{title}]({(Path('figures') / rel).as_posix()})\n\n")

    # Placeholders for the interpretation paragraphs (filled in by hand).
    out += [
        "**Interpretation notes (fill in):**\n\n",
        "- Expected behavior: …\n",
        "- Observed behavior: …\n",
        "- Why EMA-Predict helps at this noise level: …\n",
        "- Failure modes / limitations: …\n\n",
    ]
    return out


def generate_results_md(
    out_root: Path,
    main_table: pd.DataFrame,
    best_alpha: pd.DataFrame,
) -> None:
    """
    Write ``out_root/results.md``, a report-ready Results template.

    Contents:
      - paths of the main-table exports found under ``out_root/tables``
        (relative to results.md)
      - per-(run, noise) Δ test_acc_reported between the "baseline" and "ours"
        rows of ``main_table``, when those columns exist
      - one section per (run_id, noise_ratio) in ``best_alpha`` with the
        selected α, the selected iteration (if present), existing figures under
        ``out_root/figures``, and interpretation placeholders

    Args:
        out_root: Report folder; results.md is written here.
        main_table: Main results table as read from the main-table CSV.
        best_alpha: Best-alpha table; needs ``run_id``, ``noise_ratio`` and
            ``best_alpha`` columns, optionally ``best_iteration``.
    """
    tables_dir = out_root / "tables"
    figs_dir = out_root / "figures"

    lines: List[str] = ["# Results\n", "\n", "## Summary of main results\n", "\n"]
    lines.extend(_results_md_table_links(tables_dir))
    lines.append("\n")
    lines.extend(_results_md_key_observations(main_table))
    lines.append("## Detailed analysis by noise ratio\n\n")

    if not {"run_id", "noise_ratio", "best_alpha"}.issubset(best_alpha.columns):
        lines.append(
            "⚠️ Missing required columns in best_alpha_per_noise.csv. "
            "Expected: run_id, noise_ratio, best_alpha.\n"
        )
    else:
        # _collect_run_noise_pairs yields str run ids and float noise ratios.
        # Compare against equally normalised columns, so a non-str run_id or a
        # str noise_ratio column still matches instead of silently dropping sections.
        run_keys = best_alpha["run_id"].astype(str)
        noise_keys = pd.to_numeric(best_alpha["noise_ratio"], errors="coerce")
        for run_id, noise in _collect_run_noise_pairs(best_alpha):
            row = best_alpha[(run_keys == run_id) & (noise_keys == noise)].head(1)
            if not row.empty:
                lines.extend(_results_md_section(row, figs_dir, run_id, noise))

    out_root.joinpath("results.md").write_text("".join(lines), encoding="utf-8")


# -------------------------
# Main
# -------------------------
def finalize_report(settings: Settings) -> Path:
    """
    Export the report tables to Excel/LaTeX and write results.md.

    Reads ``settings.main_table_csv`` and ``settings.best_alpha_csv`` from
    ``<folder_store>/<out_dir>/tables`` and writes, relative to the report folder:
      - tables/main_table_ready.xlsx, tables/best_alpha_per_noise.xlsx
      - tables/main_table_ready.tex,  tables/best_alpha_per_noise.tex
      - results.md
      - MANIFEST_finalize.txt listing the files above

    Args:
        settings: Uses ``folder_store``, ``out_dir``, ``main_table_csv``,
            ``best_alpha_csv``, ``numeric_ratio_threshold`` and ``float_decimals``.

    Returns:
        The report folder ``settings.folder_store.resolve() / settings.out_dir``.

    Raises:
        FileNotFoundError: If the report folder, its ``tables`` subfolder, or
            either input CSV does not exist.
    """
    root = settings.folder_store.resolve()
    out_root = root / settings.out_dir
    tables_dir = out_root / "tables"

    # Both folders come from the earlier asset-generation step.
    for required_dir in (out_root, tables_dir):
        if not required_dir.exists():
            raise FileNotFoundError(
                f"Missing {required_dir}. Run generate_report_assets.py first."
            )

    main_csv = tables_dir / settings.main_table_csv
    best_alpha_csv = tables_dir / settings.best_alpha_csv
    if not main_csv.exists():
        raise FileNotFoundError(f"Missing main table CSV: {main_csv}")
    if not best_alpha_csv.exists():
        raise FileNotFoundError(f"Missing best alpha CSV: {best_alpha_csv}")

    df_main = read_csv_normalized(
        main_csv,
        force_str_cols={"run_id", "Method"},
        numeric_ratio_threshold=settings.numeric_ratio_threshold,
    )
    df_best_alpha = read_csv_normalized(
        best_alpha_csv,
        force_str_cols={"run_id"},
        numeric_ratio_threshold=settings.numeric_ratio_threshold,
    )

    # Outputs relative to out_root, in write order; reused for the manifest so
    # the manifest cannot drift from what is actually produced.
    written: List[str] = []

    # Excel exports
    write_excel_pretty(
        df_main,
        out_path=tables_dir / "main_table_ready.xlsx",
        sheet_name="MainTable",
        float_decimals=settings.float_decimals,
    )
    written.append("tables/main_table_ready.xlsx")
    write_excel_pretty(
        df_best_alpha,
        out_path=tables_dir / "best_alpha_per_noise.xlsx",
        sheet_name="BestAlpha",
        float_decimals=settings.float_decimals,
    )
    written.append("tables/best_alpha_per_noise.xlsx")

    # LaTeX exports
    write_latex_table(
        df_main,
        out_path=tables_dir / "main_table_ready.tex",
        caption="Main results: Baseline vs EMA-Predict (Ours).",
        label="tab:main_results",
        float_decimals=settings.float_decimals,
    )
    written.append("tables/main_table_ready.tex")
    write_latex_table(
        df_best_alpha,
        out_path=tables_dir / "best_alpha_per_noise.tex",
        caption="Selected best EMA momentum $\\alpha$ for each noise ratio.",
        label="tab:best_alpha",
        float_decimals=settings.float_decimals,
    )
    written.append("tables/best_alpha_per_noise.tex")

    # results.md
    generate_results_md(out_root=out_root, main_table=df_main, best_alpha=df_best_alpha)
    written.append("results.md")

    # Manifest
    manifest_lines = ["Finalize report outputs generated:\n"]
    manifest_lines += [f"- {name}\n" for name in written]
    (out_root / "MANIFEST_finalize.txt").write_text("".join(manifest_lines), encoding="utf-8")
    return out_root


if __name__ == "__main__":
    # ====== INPUT SETTINGS (edit here) ======
    folder_store = "store_output_cifar10_iter_ema_noise_validation_v2"
    # =======================================

    # Finalize the report for the configured folder, then say where it landed.
    run_settings = Settings(folder_store=Path(folder_store))
    report_dir = finalize_report(run_settings)
    print(f"[OK] Final report package written to: {report_dir}")


[OK] Final report package written to: /mnt/c/Users/truon/learning/ptit/research/trung/M_10_01_2025/code_v2/project/notebooks/store_output_cifar10_iter_ema_noise_validation_v2/report_assets
