# In [None]:
### Autres exp
import os, numpy as np, torch, matplotlib.pyplot as plt

def savefig(outdir, name, dpi=200):
    """Save the current matplotlib figure to ``outdir/name`` and print the path.

    The figure is written with a tight bounding box at the requested ``dpi``.
    This helper does not create ``outdir``.
    """
    destination = os.path.join(outdir, name)
    plt.savefig(destination, bbox_inches="tight", dpi=dpi)
    print(f"saved: {destination}")

@torch.no_grad()
def collect_samples(
    model, x_true, x_init, mask_obs,
    *,
    method="mwg",                 # "pseudo" or "mwg"
    n_iters=20000,
    burn_in=5000,
    thinning=50,
    warmup_pg=50,
    proposal_scale=1.0,
    max_saved=200,                # upper bound on the number of stored samples
):
    """Run an imputation chain and collect thinned decoder outputs.

    Args:
        model: model exposing ``encode`` and ``reparameterize`` (used for the
            MwG initialisation) and accepted by the sampler step functions.
        x_true: unused here; kept for signature compatibility with callers.
        x_init: initial state of the chain (cloned, never modified).
        mask_obs: observation mask forwarded to the sampler steps.
        method: ``"pseudo"`` runs ``pseudo_gibbs_step``; any other value runs
            ``metropolis_within_gibbs_step``. Only ``"mwg"`` triggers the
            pseudo-Gibbs warm-up that initialises ``z``.
        n_iters: maximum number of chain iterations.
        burn_in: iterations discarded before any sample is stored.
        thinning: one sample is stored every ``thinning`` iterations after
            burn-in.
        warmup_pg: number of pseudo-Gibbs warm-up steps (MwG only).
        proposal_scale: forwarded to ``metropolis_within_gibbs_step``.
        max_saved: the chain stops as soon as this many samples are stored.

    Returns:
        ``(samples, last_bin)``: ``samples`` stacks the stored ``x_out``
        tensors on CPU along a new leading dimension (documented by the
        original author as (S, B, 1, 28, 28)); ``last_bin`` is the last
        binarised state on CPU.

    Raises:
        RuntimeError: if no sample was stored (e.g. ``n_iters <= burn_in``).
    """
    x = x_init.clone()
    z = None

    if method == "mwg":
        # Warm up with pseudo-Gibbs so that z is initialised from a plausible x.
        for _ in range(warmup_pg):
            x_out, _ = pseudo_gibbs_step(model, x, mask_obs)
            x = torch.bernoulli(x_out)
        mu, logvar = model.encode(x.view(-1, 784))
        z = model.reparameterize(mu, logvar)

    saved = []

    for t in range(n_iters):
        if method == "pseudo":
            x_out, _ = pseudo_gibbs_step(model, x, mask_obs)
        else:
            x_out, z, _ = metropolis_within_gibbs_step(
                model, x, z, mask_obs,
                return_accept=True,
                adaptive=False,
                proposal_scale=proposal_scale
            )

        x = torch.bernoulli(x_out)

        if t >= burn_in and (t - burn_in) % thinning == 0:
            saved.append(x_out.detach().cpu())
            if len(saved) >= max_saved:
                break

    if not saved:
        # torch.stack would fail with an opaque message on an empty list.
        raise RuntimeError(
            "collect_samples kept no sample: n_iters must be greater than "
            f"burn_in (got n_iters={n_iters}, burn_in={burn_in})"
        )

    # Only the final state is returned, so copy it to CPU once instead of
    # doing a device-to-host transfer on every iteration.
    last_bin = x.detach().cpu()
    samples = torch.stack(saved, dim=0)
    return samples, last_bin

def plot_uncertainty(samples_probs, mask_obs, x_true, x_init, idx=0, outdir=None, prefix="mwg"):
    """Plot per-pixel uncertainty maps for one batch element and summarise them.

    Args:
        samples_probs: stored probability maps with the sample dimension first
            (documented by the original author as (S, B, 1, 28, 28) on CPU).
        mask_obs: observation mask (1 = observed, 0 = missing).
        x_true: ground-truth images; only ``x_true[idx]`` is displayed.
        x_init: masked/initialised images; only ``x_init[idx]`` is displayed.
        idx: batch element shown in the figure.
        outdir: if not None, the figure is saved there through ``savefig``.
        prefix: file-name prefix of the saved figure.

    Returns:
        ``(mean, var, metrics)`` where ``mean``/``var`` are taken over the
        sample dimension and ``metrics`` holds ``S`` plus the variance and
        entropy averaged over the *missing* pixels of the whole batch.
    """
    S = samples_probs.size(0)
    mean = samples_probs.mean(dim=0)
    var = samples_probs.var(dim=0)

    miss = 1 - mask_obs.detach().cpu()

    # Uncertainty is only meaningful on the missing pixels.
    var_miss = var * miss
    p_on = mean.clamp(1e-6, 1 - 1e-6)
    p_off = (1 - mean).clamp(1e-6, 1 - 1e-6)
    ent = -(p_on * torch.log(p_on) + p_off * torch.log(p_off))
    ent_miss = ent * miss

    # Figure: truth, initial state, posterior mean, variance map, entropy map.
    fig, axes = plt.subplots(1, 5, figsize=(14, 3))
    gray = dict(cmap="gray", vmin=0, vmax=1)
    axes[0].imshow(x_true[idx].detach().cpu().squeeze(), **gray)
    axes[0].set_title("True")
    axes[1].imshow(x_init[idx].detach().cpu().squeeze(), **gray)
    axes[1].set_title("Masked init")
    axes[2].imshow(mean[idx].squeeze(), **gray)
    axes[2].set_title("Posterior mean")
    axes[3].imshow(var_miss[idx].squeeze(), cmap="viridis")
    axes[3].set_title("Var (missing)")
    axes[4].imshow(ent_miss[idx].squeeze(), cmap="viridis")
    axes[4].set_title("Entropy (missing)")
    for ax in axes:
        ax.axis("off")
    plt.tight_layout()
    if outdir is not None:
        savefig(outdir, f"{prefix}_uncertainty_maps_idx{idx}.png")
    plt.show()

    # Average over missing pixels only: dividing by the total pixel count
    # would let the zeros of the observed pixels dilute the metric, making it
    # proportional to the missing fraction.
    n_missing = miss.expand_as(var_miss).sum().clamp(min=1)
    metrics = {
        "S": S,
        "var_missing_mean": float((var_miss.sum() / n_missing).item()),
        "entropy_missing_mean": float((ent_miss.sum() / n_missing).item()),
    }
    return mean, var, metrics

def plot_samples_grid(samples_probs, idx=0, n_show=16, outdir=None, prefix="mwg"):
    """Show up to ``n_show`` evenly spaced samples of batch element ``idx``.

    ``samples_probs`` is indexed as ``[sample, batch_element]``. The samples
    are laid out on a grid with 8 columns; unused cells are left blank. When
    ``outdir`` is given the figure is also saved through ``savefig``.
    """
    n_cols = 8
    n_samples = samples_probs.size(0)
    n_panels = min(n_show, n_samples)
    # Evenly spaced sample indices over the whole stored chain.
    picked = np.linspace(0, n_samples - 1, n_panels).astype(int)

    n_rows = (n_panels + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows))
    flat_axes = np.array(axes).reshape(n_rows, n_cols).reshape(-1)

    for cell, ax in enumerate(flat_axes):
        if cell < n_panels:
            sample_id = picked[cell]
            ax.imshow(samples_probs[sample_id, idx].squeeze(), cmap="gray", vmin=0, vmax=1)
            ax.set_title(f"s{sample_id}", fontsize=8)
        ax.axis("off")

    plt.suptitle(f"Posterior samples (idx={idx})")
    plt.tight_layout()
    if outdir is not None:
        savefig(outdir, f"{prefix}_samples_grid_idx{idx}.png")
    plt.show()

def exp4_uncertainty_and_samples(
    model, x_true, device,
    *,
    mask_kind="top",         # "top" or "random50"
    missing_rate=0.5,        # used if random50
    method="mwg",
    n_iters=20000, burn_in=5000, thinning=50,
    warmup_pg=50,
    proposal_scale=1.0,
    idx=0,
    outdir_root="results"
):
    """Run experiment 4: uncertainty maps and a grid of posterior samples.

    Creates ``<outdir_root>/exp4_uncertainty_<timestamp>/``, builds a mask
    (``"random50"`` uses ``make_random_mask``; any other ``mask_kind`` uses
    the structured "top" mask), collects up to 300 samples with
    ``collect_samples``, produces the two figures and writes
    ``exp4_metrics.csv`` and ``exp4_config.json``.

    Args:
        model: model forwarded to ``collect_samples``.
        x_true: ground-truth batch.
        device: unused in this function; kept for call compatibility.
        mask_kind, missing_rate: mask selection (see above).
        method, n_iters, burn_in, thinning, warmup_pg, proposal_scale:
            forwarded to ``collect_samples``.
        idx: batch element shown in the figures.
        outdir_root: parent directory of the run directory.

    Returns:
        Path of the run directory.
    """
    # Imported locally so the function does not depend on imports that only
    # appear further down in the file (``datetime`` is imported after the
    # first call to this function).
    import datetime
    import json

    import pandas as pd

    run_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    outdir = os.path.join(outdir_root, f"exp4_uncertainty_{run_id}")
    os.makedirs(outdir, exist_ok=True)
    print("OUT:", outdir)

    if mask_kind == "random50":
        mask = make_random_mask(x_true, missing_rate=missing_rate)
    else:
        mask = make_structured_mask(x_true, kind="top")
    x_init = init_with_noise(x_true, mask)

    samples_probs, last_bin = collect_samples(
        model, x_true, x_init, mask,
        method=method,
        n_iters=n_iters, burn_in=burn_in, thinning=thinning,
        warmup_pg=warmup_pg,
        proposal_scale=proposal_scale,
        max_saved=300
    )

    mean, var, metrics = plot_uncertainty(samples_probs, mask, x_true, x_init, idx=idx, outdir=outdir, prefix=method)
    plot_samples_grid(samples_probs, idx=idx, n_show=16, outdir=outdir, prefix=method)

    # Save metrics + config
    pd.DataFrame([{
        "mask_kind": mask_kind,
        "missing_rate": missing_rate,
        "method": method,
        **metrics
    }]).to_csv(os.path.join(outdir, "exp4_metrics.csv"), index=False)

    with open(os.path.join(outdir, "exp4_config.json"), "w") as f:
        json.dump({
            "mask_kind": mask_kind,
            "missing_rate": missing_rate,
            "method": method,
            "n_iters": n_iters, "burn_in": burn_in, "thinning": thinning,
            "warmup_pg": warmup_pg,
            "proposal_scale": proposal_scale,
            "idx": idx
        }, f, indent=2)

    print("DONE exp4:", outdir)
    return outdir

# ---------- RUN exp4 (example) ----------
# Example run on the structured "top" mask with the MwG sampler; set
# mask_kind="random50" to use a random mask with the given missing_rate.
_ = exp4_uncertainty_and_samples(
    model, x_true, device,
    mask_kind="top", missing_rate=0.5, method="mwg",
    n_iters=20000, burn_in=5000, thinning=50,
    warmup_pg=50, proposal_scale=1.0, idx=0,
)
import os, datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def exp5_outdir():
    """Create ``results/exp5_mcmc_<timestamp>/`` and return its path.

    The timestamp has one-second resolution (``YYYYmmdd_HHMMSS``); the path is
    printed so the run directory is visible in the log.
    """
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.join("results", "exp5_mcmc_" + stamp)
    os.makedirs(run_dir, exist_ok=True)
    print("OUT:", run_dir)
    return run_dir

def savefig(outdir, name, dpi=200):
    """Save the current matplotlib figure to ``outdir/name`` and print the path.

    The figure is written with a tight bounding box at the requested ``dpi``.
    This helper does not create ``outdir``.
    """
    destination = os.path.join(outdir, name)
    plt.savefig(destination, bbox_inches="tight", dpi=dpi)
    print(f"saved: {destination}")

def acf_1d(x, max_lag=200):
    """Return ``(lags, acf)`` for a 1D series, for lags ``0..max_lag``.

    Non-finite values are dropped first. Each lag-``k`` autocovariance is
    normalised by ``n * var`` (the same denominator at every lag). With fewer
    than 5 finite values the ACF is all-NaN; for a (numerically) constant
    series it is all zeros.
    """
    series = np.asarray(x, dtype=np.float64)
    series = series[np.isfinite(series)]
    lags = np.arange(max_lag + 1)
    n = len(series)

    if n < 5:
        return lags, np.full(max_lag + 1, np.nan)

    centered = series - series.mean()
    variance = np.dot(centered, centered) / n
    if variance <= 1e-12:
        return lags, np.zeros(max_lag + 1)

    acf = np.zeros(max_lag + 1, dtype=np.float64)
    acf[0] = 1.0
    acf[1:] = [
        np.dot(centered[:-lag], centered[lag:]) / (n * variance)
        for lag in range(1, max_lag + 1)
    ]
    return lags, acf

def ess_from_acf(x, max_lag=200):
    """Approximate the effective sample size of a 1D series.

    The integrated autocorrelation time is ``1 + 2 * sum(acf[1:m])`` where the
    sum stops at the first non-positive autocorrelation (simple truncation).
    Returns NaN when fewer than 10 finite values are available.
    """
    values = np.asarray(x, dtype=np.float64)
    values = values[np.isfinite(values)]
    n = len(values)
    if n < 10:
        return np.nan

    _, acf = acf_1d(values, max_lag=max_lag)
    tail = acf[1:]
    # Index of the first lag whose autocorrelation is <= 0 (or the full tail).
    stops = np.flatnonzero(tail <= 0)
    cutoff = stops[0] if stops.size else tail.size
    tau = 1 + 2 * np.sum(tail[:cutoff])
    if tau > 0:
        return n / tau
    return np.nan

def rhat(chains):
    """Gelman-Rubin R-hat for several chains of one scalar statistic.

    Args:
        chains: iterable of 1D sequences (a list of arrays or an (M, N)
            array). Chains may have different lengths.

    Returns:
        The potential scale reduction factor, or NaN when fewer than 2 chains
        or fewer than 10 common samples are available, or when the
        within-chain variance is zero.
    """
    # Convert chain by chain: np.asarray on the whole list would raise on
    # chains of unequal length, before they can be truncated to a common one.
    cleaned = []
    for chain in chains:
        values = np.asarray(chain, dtype=np.float64).ravel()
        cleaned.append(values[np.isfinite(values)])

    m = len(cleaned)
    n = min((len(c) for c in cleaned), default=0)
    if m < 2 or n < 10:
        return np.nan
    X = np.stack([c[:n] for c in cleaned], axis=0)  # (m, n)

    chain_means = X.mean(axis=1)
    chain_vars = X.var(axis=1, ddof=1)
    W = chain_vars.mean()                 # within-chain variance
    B = n * chain_means.var(ddof=1)       # between-chain variance
    var_hat = (n - 1) / n * W + (1 / n) * B
    return np.sqrt(var_hat / W) if W > 0 else np.nan
import torch

@torch.no_grad()
def collect_trace(
    model, x_true, x_init, mask_obs,
    *,
    method="mwg",                 # "pseudo" | "mwg"
    n_iters=12000,
    burn_in=2000,
    thinning=20,
    warmup_pg=50,
    proposal_scale=1.0,
    idx_stat=0,                   # batch element used for the K statistic
):
    """Run one chain and record scalar diagnostics at the kept iterations.

    Args:
        model: model exposing ``encode``/``reparameterize`` and accepted by
            the sampler step functions.
        x_true: ground truth passed to ``bernoulli_ll_missing``.
        x_init: initial state of the chain (cloned, never modified).
        mask_obs: observation mask (1 = observed, 0 = missing).
        method: ``"pseudo"`` runs ``pseudo_gibbs_step``; any other value runs
            ``metropolis_within_gibbs_step``. Only ``"mwg"`` triggers the
            pseudo-Gibbs warm-up that initialises ``z``.
        n_iters, burn_in, thinning: chain length and which iterations are
            kept (every ``thinning`` iterations from ``burn_in`` on).
        warmup_pg: number of pseudo-Gibbs warm-up steps (MwG only).
        proposal_scale: forwarded to ``metropolis_within_gibbs_step``.
        idx_stat: batch element whose missing region is counted for ``k``.

    Returns:
        dict of NumPy arrays, one entry per kept iteration:
        ``steps_kept`` (iteration index), ``logp`` (batch mean of
        ``bernoulli_ll_missing(x_true, x_out, mask_obs)``), ``k`` (sum of the
        binary state over the missing pixels of ``idx_stat``) and ``acc``
        (acceptance value returned by the MwG step, NaN otherwise).
    """
    x = x_init.clone()
    z = None

    if method == "mwg":
        # Warm up with pseudo-Gibbs so that z is initialised from a plausible x.
        for _ in range(warmup_pg):
            x_out, _ = pseudo_gibbs_step(model, x, mask_obs)
            x = torch.bernoulli(x_out)
        mu, logvar = model.encode(x.view(-1, 784))
        z = model.reparameterize(mu, logvar)

    steps_kept, logp_list, k_list, acc_list = [], [], [], []
    # Boolean index of the missing pixels of the tracked element, kept on the
    # chain's device so that only one scalar is transferred per kept sample
    # (instead of copying the whole batch to CPU).
    miss_idx = (1 - mask_obs[idx_stat]).bool().to(x.device)

    for t in range(n_iters):
        if method == "pseudo":
            x_out, _ = pseudo_gibbs_step(model, x, mask_obs)
            acc = np.nan
        else:
            x_out, z, acc = metropolis_within_gibbs_step(
                model, x, z, mask_obs,
                return_accept=True,
                adaptive=False,
                proposal_scale=proposal_scale
            )

        x = torch.bernoulli(x_out)

        if t >= burn_in and (t - burn_in) % thinning == 0:
            # logp at this kept sample (mean over the batch)
            ll = bernoulli_ll_missing(x_true, x_out, mask_obs)
            logp = float(ll.mean().item())

            # K = number of ones in the missing region of idx_stat (binary x)
            k = float(x[idx_stat][miss_idx].sum().item())

            steps_kept.append(t)
            logp_list.append(logp)
            k_list.append(k)
            acc_list.append(float(acc) if method == "mwg" else np.nan)

    return {
        "steps_kept": np.array(steps_kept),
        "logp": np.array(logp_list),
        "k": np.array(k_list),
        "acc": np.array(acc_list),
    }
def _plot_series(xs, ys, *, figsize, xlabel, ylabel, title, outdir, fname):
    """Draw one marked line plot, save it through ``savefig`` and show it."""
    plt.figure(figsize=figsize)
    plt.plot(xs, ys, "-o", markersize=3)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True, alpha=0.3)
    savefig(outdir, fname)
    plt.show()


def plot_trace_and_acf(trace, outdir, tag, max_lag=200):
    """Plot traces and autocorrelations of a chain and return its ESS.

    Args:
        trace: dict with keys ``steps_kept``, ``logp``, ``k`` and ``acc``
            (as returned by ``collect_trace``).
        outdir: directory where the figures are saved.
        tag: prefix used in figure titles and file names.
        max_lag: largest lag of the autocorrelation plots and ESS estimate.

    Returns:
        dict with ``ess_logp``, ``ess_k`` and ``n_kept``.
    """
    steps = trace["steps_kept"]
    logp = trace["logp"]
    k = trace["k"]
    acc = trace["acc"]

    # Traces
    _plot_series(
        steps, logp, figsize=(9, 3),
        xlabel="iteration", ylabel="mean log p(x_miss_true | sample)",
        title=f"{tag} — trace logp",
        outdir=outdir, fname=f"{tag}_trace_logp.png",
    )
    _plot_series(
        steps, k, figsize=(9, 3),
        xlabel="iteration", ylabel="#ones in missing region (idx)",
        title=f"{tag} — trace K (#ones missing)",
        outdir=outdir, fname=f"{tag}_trace_k.png",
    )

    # Autocorrelation functions
    lags, ac = acf_1d(logp, max_lag=max_lag)
    _plot_series(
        lags, ac, figsize=(8, 3),
        xlabel="lag", ylabel="ACF",
        title=f"{tag} — ACF(logp)",
        outdir=outdir, fname=f"{tag}_acf_logp.png",
    )
    lags, ac = acf_1d(k, max_lag=max_lag)
    _plot_series(
        lags, ac, figsize=(8, 3),
        xlabel="lag", ylabel="ACF",
        title=f"{tag} — ACF(K)",
        outdir=outdir, fname=f"{tag}_acf_k.png",
    )

    # Acceptance trace: only drawn when at least one finite value exists
    # (the pseudo-Gibbs trace stores NaN everywhere).
    if np.isfinite(acc).any():
        _plot_series(
            steps, acc, figsize=(9, 3),
            xlabel="iteration", ylabel="accept",
            title=f"{tag} — MwG accept (at kept steps)",
            outdir=outdir, fname=f"{tag}_accept.png",
        )

    # ESS
    ess_logp = ess_from_acf(logp, max_lag=max_lag)
    ess_k = ess_from_acf(k, max_lag=max_lag)

    return {"ess_logp": ess_logp, "ess_k": ess_k, "n_kept": len(logp)}
# Experiment 5: MCMC diagnostics (traces, ACF, ESS, multi-chain R-hat) for the
# pseudo-Gibbs and MwG samplers on a single mask scenario.
outdir = exp5_outdir()

# ---- Scenario ----
mask_kind = "top"        # "top" or "random50"
missing_rate = 0.5       # used if random50

n_iters = 12000
burn_in = 2000
thinning = 20
warmup_pg = 50
proposal_scale = 1.0
idx_stat = 0             # batch element used for the K statistic

# mask + init
if mask_kind == "random50":
    mask = make_random_mask(x_true, missing_rate=missing_rate)
else:
    mask = make_structured_mask(x_true, kind="top")
x_init = init_with_noise(x_true, mask)

# ---- 1) Pseudo trace ----
trace_p = collect_trace(
    model, x_true, x_init, mask,
    method="pseudo",
    n_iters=n_iters, burn_in=burn_in, thinning=thinning,
    idx_stat=idx_stat
)
ess_p = plot_trace_and_acf(trace_p, outdir, tag=f"pseudo_{mask_kind}", max_lag=200)

# ---- 2) MwG trace ----
trace_m = collect_trace(
    model, x_true, x_init, mask,
    method="mwg",
    n_iters=n_iters, burn_in=burn_in, thinning=thinning,
    warmup_pg=warmup_pg,
    proposal_scale=proposal_scale,
    idx_stat=idx_stat
)
ess_m = plot_trace_and_acf(trace_m, outdir, tag=f"mwg_{mask_kind}", max_lag=200)

# Save raw traces
pd.DataFrame({
    "step": trace_p["steps_kept"],
    "pseudo_logp": trace_p["logp"],
    "pseudo_k": trace_p["k"],
}).to_csv(os.path.join(outdir, f"trace_pseudo_{mask_kind}.csv"), index=False)

pd.DataFrame({
    "step": trace_m["steps_kept"],
    "mwg_logp": trace_m["logp"],
    "mwg_k": trace_m["k"],
    "mwg_acc": trace_m["acc"],
}).to_csv(os.path.join(outdir, f"trace_mwg_{mask_kind}.csv"), index=False)

# Summary table ESS
df_ess = pd.DataFrame([
    {"method": "pseudo", "mask": mask_kind, **ess_p},
    {"method": "mwg", "mask": mask_kind, **ess_m},
])
df_ess.to_csv(os.path.join(outdir, f"ess_summary_{mask_kind}.csv"), index=False)
print(df_ess)

# ---- 3) Multi-chain R-hat (MwG) ----
n_chains = 5
chains_logp = []
chains_k = []

for _ in range(n_chains):
    # R-hat compares chains started from different points: redraw the noise
    # in the missing region for every chain instead of reusing x_init, which
    # would start all chains from the same state.
    x_init_chain = init_with_noise(x_true, mask)
    tr = collect_trace(
        model, x_true, x_init_chain, mask,
        method="mwg",
        n_iters=n_iters, burn_in=burn_in, thinning=thinning,
        warmup_pg=warmup_pg,
        proposal_scale=proposal_scale,
        idx_stat=idx_stat
    )
    chains_logp.append(tr["logp"])
    chains_k.append(tr["k"])

# Rhat on logp and k
Rhat_logp = rhat(chains_logp)
Rhat_k = rhat(chains_k)

df_rhat = pd.DataFrame([{
    "mask": mask_kind,
    "n_chains": n_chains,
    "Rhat_logp": Rhat_logp,
    "Rhat_k": Rhat_k
}])
df_rhat.to_csv(os.path.join(outdir, f"rhat_{mask_kind}.csv"), index=False)
print(df_rhat)

# Plot overlay of chains (logp)
plt.figure(figsize=(9,3))
for i, c in enumerate(chains_logp):
    plt.plot(c, label=f"chain{i}", alpha=0.7)
plt.title(f"MwG chains logp overlay (mask={mask_kind}) | Rhat={Rhat_logp:.3f}")
plt.xlabel("kept sample index")
plt.ylabel("logp")
plt.grid(True, alpha=0.3)
plt.legend(ncol=3, fontsize=8)
savefig(outdir, f"mwg_multichain_logp_overlay_{mask_kind}.png")
plt.show()
# ============================================
# EXP5-BIS : ESS / R-hat vs difficulté (p)
# ============================================
# Objectif: pour p in {0.1,0.3,0.5,0.7,0.9}, on calcule
# - ESS(logp) et ESS(K) pour Pseudo et MwG
# - Rhat(logp) et Rhat(K) sur MwG multi-chaînes
# + sauvegarde CSV + figures dans results/exp5_sweep_p_<timestamp>/

import os, datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# --------- Output dir ----------
run_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
outdir = os.path.join("results", f"exp5_sweep_p_{run_id}")
os.makedirs(outdir, exist_ok=True)
print("OUT:", outdir)

# NOTE: figures are saved with the shared savefig(outdir, name, dpi) helper
# defined earlier in this file. This section used to redefine savefig with a
# (name, dpi) signature, which shadowed the shared helper and broke later
# calls such as those made by plot_trace_and_acf.

# --------- Settings ----------
p_list = [0.1, 0.3, 0.5, 0.7, 0.9]

# Chain settings (for diagnostics it is better to keep enough samples)
n_iters   = 20000
burn_in   = 5000
thinning  = 50
warmup_pg = 50
proposal_scale = 1.0
idx_stat = 0

# Multi-chain for Rhat
n_chains = 5

# ACF / ESS settings
max_lag = 200

# --------- Run sweep ----------
rows = []
overlay_dir = os.path.join(outdir, "overlays")
os.makedirs(overlay_dir, exist_ok=True)

for p in p_list:
    print("\n" + "="*60)
    print(f"p_missing = {p}")
    print("="*60)

    # round() rather than int(): p * 100 is not exact in floating point
    # (e.g. 0.29 * 100 would be truncated to 28).
    p_tag = f"p{round(p * 100)}"

    # mask + init
    mask = make_random_mask(x_true, missing_rate=p)
    x_init = init_with_noise(x_true, mask)

    # ---- Pseudo trace ----
    tr_p = collect_trace(
        model, x_true, x_init, mask,
        method="pseudo",
        n_iters=n_iters, burn_in=burn_in, thinning=thinning,
        idx_stat=idx_stat
    )

    ess_p_logp = ess_from_acf(tr_p["logp"], max_lag=max_lag)
    ess_p_k    = ess_from_acf(tr_p["k"],    max_lag=max_lag)

    # ---- MwG trace (single chain for ESS) ----
    tr_m = collect_trace(
        model, x_true, x_init, mask,
        method="mwg",
        n_iters=n_iters, burn_in=burn_in, thinning=thinning,
        warmup_pg=warmup_pg,
        proposal_scale=proposal_scale,
        idx_stat=idx_stat
    )

    ess_m_logp = ess_from_acf(tr_m["logp"], max_lag=max_lag)
    ess_m_k    = ess_from_acf(tr_m["k"],    max_lag=max_lag)

    acc_mean = float(np.nanmean(tr_m["acc"])) if np.isfinite(tr_m["acc"]).any() else np.nan

    # ---- MwG multi-chain for R-hat ----
    chains_logp, chains_k = [], []
    for _ in range(n_chains):
        # R-hat compares chains started from different points: redraw the
        # noise in the missing region for every chain instead of reusing
        # x_init, which would start all chains from the same state.
        x_init_chain = init_with_noise(x_true, mask)
        tr_c = collect_trace(
            model, x_true, x_init_chain, mask,
            method="mwg",
            n_iters=n_iters, burn_in=burn_in, thinning=thinning,
            warmup_pg=warmup_pg,
            proposal_scale=proposal_scale,
            idx_stat=idx_stat
        )
        chains_logp.append(tr_c["logp"])
        chains_k.append(tr_c["k"])

    Rhat_logp = rhat(chains_logp)
    Rhat_k    = rhat(chains_k)

    # ---- Save overlay figure (logp) ----
    plt.figure(figsize=(9,3))
    for i, c in enumerate(chains_logp):
        plt.plot(c, alpha=0.7, label=f"c{i}")
    plt.title(f"MwG chains logp overlay | p={p} | Rhat={Rhat_logp:.3f}")
    plt.xlabel("kept sample index"); plt.ylabel("logp")
    plt.grid(True, alpha=0.3)
    plt.legend(ncol=5, fontsize=8)
    overlay_path = os.path.join("overlays", f"mwg_overlay_logp_{p_tag}.png")
    plt.savefig(os.path.join(outdir, overlay_path), dpi=200, bbox_inches="tight")
    plt.close()
    print("saved:", os.path.join(outdir, overlay_path))

    # ---- Save traces (optional but useful) ----
    pd.DataFrame({
        "step": tr_p["steps_kept"],
        "pseudo_logp": tr_p["logp"],
        "pseudo_k": tr_p["k"],
    }).to_csv(os.path.join(outdir, f"trace_pseudo_{p_tag}.csv"), index=False)

    pd.DataFrame({
        "step": tr_m["steps_kept"],
        "mwg_logp": tr_m["logp"],
        "mwg_k": tr_m["k"],
        "mwg_acc": tr_m["acc"],
    }).to_csv(os.path.join(outdir, f"trace_mwg_{p_tag}.csv"), index=False)

    # ---- Summary row ----
    rows.append({
        "p_missing": p,
        "n_kept": len(tr_m["logp"]),
        "acc_mwg_mean": acc_mean,
        "ess_pseudo_logp": ess_p_logp,
        "ess_pseudo_k": ess_p_k,
        "ess_mwg_logp": ess_m_logp,
        "ess_mwg_k": ess_m_k,
        "rhat_mwg_logp": Rhat_logp,
        "rhat_mwg_k": Rhat_k
    })

df = pd.DataFrame(rows)
df.to_csv(os.path.join(outdir, "exp5_sweep_p_summary.csv"), index=False)
print("\nSaved summary:", os.path.join(outdir, "exp5_sweep_p_summary.csv"))
print(df)

# --------- Plot ESS vs p (one figure per statistic) ----------
for stat_key, stat_label in (("logp", "logp"), ("k", "K")):
    plt.figure(figsize=(7,4))
    plt.plot(df["p_missing"], df[f"ess_pseudo_{stat_key}"], "--o", label=f"Pseudo ESS({stat_label})")
    plt.plot(df["p_missing"], df[f"ess_mwg_{stat_key}"], "-o", label=f"MwG ESS({stat_label})")
    plt.xlabel("p_missing")
    plt.ylabel(f"ESS ({stat_label})")
    plt.title(f"ESS({stat_label}) vs missing rate")
    plt.grid(True, alpha=0.3)
    plt.legend()
    savefig(outdir, f"ess_{stat_key}_vs_p.png")
    plt.show()

# --------- Plot Rhat vs p (one figure per statistic) ----------
for stat_key, stat_label in (("logp", "logp"), ("k", "K")):
    plt.figure(figsize=(7,4))
    plt.plot(df["p_missing"], df[f"rhat_mwg_{stat_key}"], "-o")
    plt.axhline(1.0, linestyle="--")
    plt.xlabel("p_missing")
    plt.ylabel(f"R-hat ({stat_label})")
    plt.title(f"MwG R-hat({stat_label}) vs missing rate (closer to 1 is better)")
    plt.grid(True, alpha=0.3)
    savefig(outdir, f"rhat_{stat_key}_vs_p.png")
    plt.show()

# --------- Plot acceptance vs p ----------
plt.figure(figsize=(7,4))
plt.plot(df["p_missing"], df["acc_mwg_mean"], "-o")
plt.xlabel("p_missing")
plt.ylabel("MwG acceptance (mean over kept steps)")
plt.title("MwG acceptance vs missing rate")
plt.grid(True, alpha=0.3)
savefig(outdir, "accept_vs_p.png")
plt.show()

print("\nDONE. Everything saved to:", outdir)
import os, datetime, math, time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch

from sampler import pseudo_gibbs_step, metropolis_within_gibbs_step
from sampler import metropolis_within_gibbs_step_mixture  # nouveau !

def make_random_mask(x, missing_rate: float):
    """Draw an i.i.d. Bernoulli observation mask shaped like ``x``.

    Each entry is 1 (observed) with probability ``1 - missing_rate`` and 0
    (missing) otherwise.
    """
    observed_prob = torch.full_like(x, 1.0 - missing_rate)
    return torch.bernoulli(observed_prob)

def make_structured_mask(x, kind="top"):
    """Build an observation mask (1 = observed, 0 = missing) shaped like ``x``.

    ``x`` is indexed with four dimensions; the last two are treated as rows
    and columns. Supported ``kind`` values:

    - ``"top"``: rows ``0..13`` are missing;
    - ``"bottom"``: rows ``14..`` are missing;
    - ``"center"``: the square rows ``8..19`` x columns ``8..19`` is missing.

    Raises:
        ValueError: if ``kind`` is not one of the supported values.
    """
    if kind not in ("top", "bottom", "center"):
        # Fail before allocating anything, and say what was wrong.
        raise ValueError(
            f"unknown mask kind {kind!r}; expected 'top', 'bottom' or 'center'"
        )

    m = torch.ones_like(x)
    if kind == "top":
        m[:, :, :14, :] = 0
    elif kind == "bottom":
        m[:, :, 14:, :] = 0
    else:
        m[:, :, 8:20, 8:20] = 0
    return m

def init_with_noise(x_true, mask_obs):
    """Keep the observed pixels of ``x_true`` and fill the rest with coin flips.

    The filler is Bernoulli(0.5) noise shaped like ``x_true``; it is blended
    in with weight ``1 - mask_obs``.
    """
    coin_flips = torch.bernoulli(torch.full_like(x_true, 0.5))
    observed_part = x_true * mask_obs
    missing_part = coin_flips * (1 - mask_obs)
    return observed_part + missing_part

def savefig(outdir, name, dpi=200):
    """Save the current matplotlib figure to ``outdir/name`` and print the path.

    The figure is written with a tight bounding box at the requested ``dpi``.
    This helper does not create ``outdir``.
    """
    destination = os.path.join(outdir, name)
    plt.savefig(destination, bbox_inches="tight", dpi=dpi)
    print(f"saved: {destination}")

@torch.no_grad()
def run_mwg_variant_mean_last(
    model, x_true, x_init, mask_obs,
    *,
    variant="base",        # "base" or "mixture"
    n_iters=12000,
    burn_in=2000,
    thinning=20,
    warmup_pg=50,
    proposal_scale=1.0,    # base mwg
    alpha=0.5,             # mixture
    rw_sigma=0.5,
):
    """Run one MwG chain (base or mixture proposal) and summarise it.

    Args:
        model: model exposing ``encode``/``reparameterize`` and accepted by
            the sampler step functions.
        x_true: ground truth used by ``f1_missing`` and
            ``bernoulli_ll_missing``.
        x_init: initial state of the chain (cloned, never modified).
        mask_obs: observation mask forwarded to the sampler steps.
        variant: ``"base"`` uses ``metropolis_within_gibbs_step``;
            ``"mixture"`` uses ``metropolis_within_gibbs_step_mixture``.
        n_iters, burn_in, thinning: chain length and which iterations are
            kept (every ``thinning`` iterations from ``burn_in`` on).
        warmup_pg: number of pseudo-Gibbs steps used to initialise ``z``.
        proposal_scale: forwarded to the base step.
        alpha, rw_sigma: forwarded to the mixture step.

    Returns:
        dict with ``f1_mean`` (F1 of the averaged kept ``x_out``), ``f1_last``
        (F1 of the last binary state), ``logp_mc`` (batch mean of the
        log-mean-exp of ``bernoulli_ll_missing`` over kept samples),
        ``acc_mean`` (mean acceptance over the last 10% of iterations),
        ``time`` (seconds), ``kept``, ``mean_probs`` and ``last_bin`` (CPU
        tensors). When nothing was kept (or no iteration ran) the
        corresponding scores are NaN and tensors are None.

    Raises:
        ValueError: if ``variant`` is neither ``"base"`` nor ``"mixture"``.
    """
    # Validate before the (potentially long) warm-up rather than inside the loop.
    if variant not in ("base", "mixture"):
        raise ValueError("variant must be base or mixture")

    t0 = time.perf_counter()
    x = x_init.clone()
    # Score on the device of the ground truth instead of relying on a
    # module-level ``device`` global.
    target_device = x_true.device

    # init z
    for _ in range(warmup_pg):
        x_out, _ = pseudo_gibbs_step(model, x, mask_obs)
        x = torch.bernoulli(x_out)
    mu, logvar = model.encode(x.view(-1, 784))
    z = model.reparameterize(mu, logvar)

    sum_mean = None
    n_kept = 0
    logS = None
    acc_hist = []

    for t in range(n_iters):
        if variant == "base":
            x_out, z, acc = metropolis_within_gibbs_step(
                model, x, z, mask_obs,
                return_accept=True,
                adaptive=False,
                proposal_scale=proposal_scale
            )
        else:
            x_out, z, acc = metropolis_within_gibbs_step_mixture(
                model, x, z, mask_obs,
                alpha=alpha,
                rw_sigma=rw_sigma,
                return_accept=True
            )

        x = torch.bernoulli(x_out)
        # Stored as a plain float, as collect_trace does with the same value.
        acc_hist.append(float(acc))

        if t >= burn_in and (t - burn_in) % thinning == 0:
            if sum_mean is None:
                sum_mean = x_out.detach().clone()
            else:
                sum_mean += x_out.detach()
            n_kept += 1

            # Running log-sum-exp of the per-sample log-likelihoods.
            ll = bernoulli_ll_missing(x_true, x_out, mask_obs)
            logS = ll.clone() if logS is None else torch.logaddexp(logS, ll)

    mean_probs = None
    f1_mean = float("nan")
    logp_mc = float("nan")
    if n_kept > 0:
        mean_on_device = sum_mean / n_kept
        f1_mean = f1_missing(x_true, mean_on_device.to(target_device), mask_obs)
        mean_probs = mean_on_device.detach().cpu()
        logp_mc = float((logS - math.log(n_kept)).mean().item())

    last_bin = None
    f1_last = float("nan")
    if n_iters > 0:
        f1_last = f1_missing(x_true, x.detach().to(target_device), mask_obs)
        # Copied to CPU once, after the loop, instead of at every iteration.
        last_bin = x.detach().cpu()

    # Acceptance over the last 10% of the iterations.
    acc_mean = float(np.mean(acc_hist[-max(1, len(acc_hist)//10):])) if acc_hist else float("nan")
    dt = time.perf_counter() - t0

    return {
        "f1_mean": f1_mean,
        "f1_last": f1_last,
        "logp_mc": logp_mc,
        "acc_mean": acc_mean,
        "time": dt,
        "kept": n_kept,
        "mean_probs": mean_probs,
        "last_bin": last_bin
    }
# Experiment 6a: sweep the mixture weight alpha of the MwG mixture proposal
# and compare each setting against the base MwG sampler.
run_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
outdir = os.path.join("results", "exp6_mixture_alpha_" + run_id)
os.makedirs(outdir, exist_ok=True)
print("OUT:", outdir)

# --- Mask choice ---
mask_kind = "top"     # "top" or "random50"
missing_rate = 0.5    # used if random50

mask = (
    make_random_mask(x_true, missing_rate=missing_rate)
    if mask_kind == "random50"
    else make_structured_mask(x_true, kind="top")
)
x_init = init_with_noise(x_true, mask)

# --- Chain parameters ---
n_iters, burn_in, thinning, warmup_pg = 12000, 2000, 20, 50

# Baseline: plain MwG
base = run_mwg_variant_mean_last(
    model, x_true, x_init, mask,
    variant="base",
    n_iters=n_iters, burn_in=burn_in, thinning=thinning,
    warmup_pg=warmup_pg,
    proposal_scale=1.0
)

alpha_list = [0.0, 0.25, 0.5, 0.75, 1.0]  # 0 = pure random walk, 1 = pure independence
rw_sigma = 0.5

rows = [dict(
    variant="base",
    alpha=np.nan,
    rw_sigma=np.nan,
    f1_mean=base["f1_mean"],
    f1_last=base["f1_last"],
    logp_mc=base["logp_mc"],
    acc=base["acc_mean"],
    time_s=base["time"],
    kept=base["kept"],
)]

for a in alpha_list:
    res = run_mwg_variant_mean_last(
        model, x_true, x_init, mask,
        variant="mixture",
        n_iters=n_iters, burn_in=burn_in, thinning=thinning,
        warmup_pg=warmup_pg,
        alpha=a,
        rw_sigma=rw_sigma
    )
    rows.append(dict(
        variant="mixture",
        alpha=a,
        rw_sigma=rw_sigma,
        f1_mean=res["f1_mean"],
        f1_last=res["f1_last"],
        logp_mc=res["logp_mc"],
        acc=res["acc_mean"],
        time_s=res["time"],
        kept=res["kept"],
    ))
    print(f"alpha={a:.2f} | F1={res['f1_mean']:.3f} logp={res['logp_mc']:.1f} acc={res['acc_mean']:.3f}")

df = pd.DataFrame(rows)
df.to_csv(os.path.join(outdir, "exp6_alpha_sweep.csv"), index=False)
print(df)

# Plots: one figure per metric, with the base MwG value as a dashed line.
df_mix = df[df["variant"]=="mixture"].copy()

for column, baseline, x_label, y_label, what, fname in [
    ("f1_mean", base["f1_mean"], "alpha (independence weight)", "F1 (mean)", "F1", "f1_vs_alpha.png"),
    ("logp_mc", base["logp_mc"], "alpha (independence weight)", "MC log-likelihood", "logp", "logp_vs_alpha.png"),
    ("acc", base["acc_mean"], "alpha", "acceptance (last-window)", "acceptance", "acc_vs_alpha.png"),
]:
    plt.figure(figsize=(7,4))
    plt.plot(df_mix["alpha"], df_mix[column], "-o")
    plt.axhline(baseline, linestyle="--", label="base MwG")
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.title(f"Exp6: {what} vs alpha (mask={mask_kind}, rw_sigma={rw_sigma})")
    plt.grid(True, alpha=0.3)
    plt.legend()
    savefig(outdir, fname)
    plt.show()
# Experiment 6b: sweep the random-walk scale rw_sigma of the mixture proposal
# at a fixed alpha. Reuses mask, x_init and the chain parameters defined by
# the alpha sweep.
run_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
outdir2 = os.path.join("results", "exp6_mixture_sigma_" + run_id)
os.makedirs(outdir2, exist_ok=True)
print("OUT:", outdir2)

alpha = 0.5
sigma_list = [0.1, 0.25, 0.5, 1.0]

rows = []
for s in sigma_list:
    res = run_mwg_variant_mean_last(
        model, x_true, x_init, mask,
        variant="mixture",
        n_iters=n_iters, burn_in=burn_in, thinning=thinning,
        warmup_pg=warmup_pg,
        alpha=alpha,
        rw_sigma=s
    )
    rows.append(dict(
        alpha=alpha,
        rw_sigma=s,
        f1_mean=res["f1_mean"],
        logp_mc=res["logp_mc"],
        acc=res["acc_mean"],
        time_s=res["time"],
        kept=res["kept"],
    ))
    print(f"sigma={s:.2f} | F1={res['f1_mean']:.3f} logp={res['logp_mc']:.1f} acc={res['acc_mean']:.3f}")

df2 = pd.DataFrame(rows)
df2.to_csv(os.path.join(outdir2, "exp6_sigma_sweep.csv"), index=False)

# One figure per metric as a function of rw_sigma.
for column, y_label, what, fname in [
    ("f1_mean", "F1 (mean)", "F1", "f1_vs_sigma.png"),
    ("logp_mc", "MC log-likelihood", "logp", "logp_vs_sigma.png"),
    ("acc", "acceptance", "acceptance", "acc_vs_sigma.png"),
]:
    plt.figure(figsize=(7,4))
    plt.plot(df2["rw_sigma"], df2[column], "-o")
    plt.xlabel("rw_sigma")
    plt.ylabel(y_label)
    plt.title(f"Exp6: {what} vs rw_sigma (alpha={alpha}, mask={mask_kind})")
    plt.grid(True, alpha=0.3)
    savefig(outdir2, fname)
    plt.show()