In [None]:
from pathlib import Path
import numpy as np, pandas as pd
import matplotlib.pyplot as plt

def infer_attribute_from_path(p: Path) -> str:
    """Infer attribute name from folders in the path."""
    parts = [s.lower() for s in p.parts]
    if "results" in parts:
        i = parts.index("results")
        candidate = parts[i + 1] if i + 1 < len(parts) else ""
    else:
        candidate = ""
    text = "/".join(parts + [candidate])
    if ("corr" in text) or ("correlation" in text): return "Correlation"
    if ("dim" in text) or ("dimension" in text):    return "Dimension"
    if ("tail" in text) or ("tails" in text) or ("nu" in text): return "Tail weight (ν)"
    if ("mm"   in text) or ("modedist" in text) or ("mode" in text): return "Mode distance"
    return "Unknown"

SAMPLER_NAME = {"HMC":"HMC", "Metro":"Metropolis", "DEMetro_Z":"DEMetropolisZ", "SMC":"SMC"}

def generate_heatmaps_for_experiment(
    experiment_path,
    metric="ws_glass_delta",                      # preferred metric
    fallback_metric="global_median_wasserstein_distance",
    which_results="chain",                        # "chain" | "pooled" | "global" | None (scan all)
    agg="median",                                 # "median" or "mean" across attribute values
    normalize="none",                             # "none" or "row" (row-normalized colors)
    out_subdir="heatmaps",
    save=False                                   # whether to save the heatmap as a PDF
):
    """
    Build a sampler-vs-attribute heatmap for a SINGLE experiment folder.

    - metric: main column to use (e.g., "ws_glass_delta")
    - fallback_metric: used if 'metric' not present in a CSV
    - which_results: choose exactly one results folder for clean inputs:
        "chain" -> **/chain_results/Global_results_*.csv
        "pooled" -> **/pooled_results/Global_results_*.csv
        "global" -> **/global_results/Global_results_*.csv
        None -> scan all three and de-duplicate by (Attribute, Sampler)
    - agg: aggregate across varying attribute values ("median" or "mean")
    - normalize: "none" = global color scale; "row" = per-row 0..1 scale
    """
    exp = Path(experiment_path).resolve()
    if not exp.exists():
        raise FileNotFoundError(exp)

    # 1) Collect CSVs
    leafs_all = ("pooled_results", "chain_results")
    if which_results in {"chain","pooled"}:
        dir_map = {"chain":"chain_results","pooled":"pooled_results"}
        leafs = (dir_map[which_results],)
    elif which_results is None:
        leafs = leafs_all
    else:
        raise ValueError("which_results must be 'chain', 'pooled', or None")

    csvs = []
    for leaf in leafs:
        csvs += list(exp.glob(f"**/{leaf}/Global_results_*.csv"))
    if not csvs:
        raise RuntimeError(f"No Global_results_*.csv found under **/{'|'.join(leafs)}/")

    # 2) Read & aggregate
    records = []
    for csv_path in csvs:
        sampler_key = csv_path.stem.replace("Global_results_","")
        sampler = SAMPLER_NAME.get(sampler_key)
        if not sampler:
            continue
        attr = infer_attribute_from_path(csv_path)
        df = pd.read_csv(csv_path)

        # pick metric column with fallback
        metric_col = metric if metric in df.columns else (fallback_metric if fallback_metric in df.columns else None)
        if metric_col is None or df.empty:
            continue

        val = getattr(df[metric_col], agg)()
        records.append({"Attribute": attr, "Sampler": sampler, "Value": float(val)})

    if not records:
        raise RuntimeError("No usable records found. Check metric column names or which_results selection.")

    df_rec = pd.DataFrame(records)

    # 3) Deduplicate if we scanned multiple result folders
    if which_results is None:
        df_rec = (df_rec
                  .groupby(["Attribute","Sampler"], as_index=False)
                  .agg(Value=("Value", agg)))

    # 4) Pivot into Attribute × Sampler
    desired_rows = ["Dimension","Correlation","Tail weight (ν)","Mode distance"]
    desired_cols = ["Metropolis","DEMetropolisZ","HMC","SMC"]
    pivot = (df_rec
             .pivot(index="Attribute", columns="Sampler", values="Value")
             .reindex(index=desired_rows, columns=desired_cols))

    # 5) Save summary tables
    out_dir = exp / out_subdir
    out_dir.mkdir(parents=True, exist_ok=True)
    pivot.to_csv(out_dir / f"pivot_{metric}.csv")

    winners = pivot.apply(lambda r: pd.Series({"Best Sampler": r.idxmin(), "Best Value": r.min()}), axis=1)
    winners.to_csv(out_dir / f"winners_{metric}.csv")

    # 6) Prepare data for heatmap
    arr = pivot.values.astype(float)
    if normalize == "row":
        rmin = np.nanmin(arr, axis=1)[:, None]
        rmax = np.nanmax(arr, axis=1)[:, None]
        rrange = np.maximum(rmax - rmin, 1e-12)
        arr_plot = (arr - rmin) / rrange
        vmin, vmax = 0.0, 1.0
        cbar_label = "Row-normalized (min→max)"
        title_norm = "row-normalized"
    else:
        arr_plot = arr.copy()
        finite = np.isfinite(arr_plot)
        if finite.any():
            vmin = np.nanpercentile(arr_plot[finite], 5)
            vmax = np.nanpercentile(arr_plot[finite], 95)
        else:
            vmin, vmax = 0.0, 1.0
        cbar_label = f"{metric} (absolute)"
        title_norm = "absolute"

    # 7) Plot heatmap
    fig, ax = plt.subplots(figsize=(8, 5.0))
    im = ax.imshow(arr_plot, aspect="auto", cmap="RdYlGn_r", vmin=vmin, vmax=vmax)

    ax.set_xticks(np.arange(len(pivot.columns))); ax.set_xticklabels(pivot.columns, rotation=30, ha="right")
    ax.set_yticks(np.arange(len(pivot.index)));   ax.set_yticklabels(pivot.index)
    ax.set_title(f"Heatmap ({title_norm}): {metric}  ↓ better", pad=10)

    # annotate with absolute numbers and star row-min
    for i in range(arr.shape[0]):
        if not np.isfinite(arr[i]).any():
            continue
        j_best = int(np.nanargmin(arr[i]))
        for j in range(arr.shape[1]):
            v = arr[i, j]
            txt = "–" if np.isnan(v) else (f"{v:.3f}" if v < 1 else f"{v:.2f}")
            if j == j_best and np.isfinite(v):
                txt = "★ " + txt
            ax.text(j, i, txt, ha="center", va="center", color="black", fontsize=10)

    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label(cbar_label)
    
    fig.tight_layout()
    pdf = out_dir / f"heatmap_{metric}_{title_norm}.pdf"

    if save:
        fig.savefig(pdf, bbox_inches="tight", format="pdf")
        print(f"Saved:\n- {pdf}\n- {out_dir / f'pivot_{metric}.csv'}\n- {out_dir / f'winners_{metric}.csv'}")

    return pivot, winners


In [11]:
# Kurze Aliase für die Metrik-Spalten in den CSVs
METRIC_MAP = {
    "ws":         "global_median_wasserstein_distance",  # kleiner = besser
    "ws_delta":   "ws_glass_delta",                      # Glass' Δ zu IID, kleiner = besser

    "mmd":        "global_median_mmd_rff",               # kleiner = besser
    "mmd_delta":  "mmd_rff_glass_delta",                 # Glass' Δ zu IID, kleiner = besser

    "rt":         "global_median_runtime",               # Laufzeit
    "ess":        "global_median_ess",                   # größer = besser
    "mt":         "global_median_mode_transitions",      # Übergänge zwischen Modi

    "rmse_mean":  "global_median_mean_rmse",             # kleiner = besser
    "rmse_var":   "global_median_var_rmse",              # kleiner = besser
}

def resolve_metric(key_or_name: str) -> str:
    """Erlaubt kurze Keys oder bereits volle Spaltennamen."""
    return METRIC_MAP.get(key_or_name, key_or_name)

In [20]:
EXPERIMENTS_ROOT = Path("/home/fabian/python_files/MA/experiments")  
print("Available experiments:", sorted(p.name for p in EXPERIMENTS_ROOT.iterdir() if p.is_dir()))

experiment_name = "first_half" 


exp_path = EXPERIMENTS_ROOT / f"exp_{experiment_name}"  
pivot, winners = generate_heatmaps_for_experiment(
    experiment_path=exp_path,
    metric=resolve_metric("ws"),       # or "global_median_wasserstein_distance"
    which_results="chain",          # "chain" (per-chain), "pooled", or "global"
    agg="median",
    normalize="row",               # or "row"
    out_subdir="heatmaps",
    save=False  
)

Available experiments: ['exp_30_runs_two_attr', 'exp_Base_and_single_attr', 'exp_SMC_dim_tails_alt', 'exp_SMC_dim_tails_neu', 'exp_SMC_heavy_tails', 'exp_SMC_new', 'exp_first_half', 'exp_label_check', 'exp_log_only_dim_5', 'exp_plot_ratio', 'exp_second_half_short', 'exp_thesis_run', 'exp_two_attr_log_scale_15_runs']


ValueError: Index contains duplicate entries, cannot reshape