In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import math
import time
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from scipy.stats import trim_mean

# ==========================================
# 0. CONFIGURATION
# ==========================================
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.set_default_dtype(torch.float32)
print(f"Running on device: {device}")

# ---- Sweep settings ----
K_TRIALS = 1        # independent trials aggregated per grid cell
SIGMA_STEPS = 1     # number of log-uniformly spaced noise levels
SIGMA_MIN = 0.01    # smallest noise level of the sweep
SIGMA_MAX = 2.5     # largest noise level of the sweep
DIMS = [2, 4, 8, 16, 32]

# Reference-set sizes to sweep over (any list of ints).
N_REFS = [1000]

# ---- Analysis parameters ----
TRIM_AMOUNT = 0.0   # proportion trimmed from each tail when aggregating trials (0 = plain mean)
SMOOTHING = 5       # rolling-window width applied to plotted curves

# ==========================================
# 1. MODELS: HELIX PRIOR & EXACT POSTERIOR
# ==========================================
class HelixPrior(torch.nn.Module):
    """Equal-weight isotropic Gaussian mixture whose means trace a helix-like curve.

    The curve parameter ``t`` runs over ``[0, 4*pi]`` (one value per component).
    The first ``dim - 1`` coordinates are cosines, each shifted by a further
    quarter period, scaled by the radius ``1 + 0.25*sin(3t)``; the last
    coordinate grows linearly as ``0.4*t``.  Every component has covariance
    ``sigma**2 * I``.
    """

    def __init__(self, dim=3, n_components=2000, sigma=0.1):
        super().__init__()
        self.dim = dim
        self.n_components = n_components

        t = torch.linspace(0, 4 * math.pi, n_components, device=device)
        radius = 1.0 + 0.25 * torch.sin(3 * t)
        # Oscillating coordinates, phase-shifted by k * pi/2.
        columns = [radius * torch.cos(t - k * math.pi / 2.0) for k in range(dim - 1)]
        # Final coordinate: linear height along the curve.
        columns.append(0.4 * t)
        self.means = torch.stack(columns, dim=1)

        self.weights = torch.ones(n_components, device=device) / n_components
        self.cov_val = sigma**2
        # Explicit per-component covariance matrices, shape (n_components, dim, dim).
        eye = torch.eye(dim, device=device)
        self.covs = eye.unsqueeze(0).repeat(n_components, 1, 1) * self.cov_val

    def sample(self, n):
        """Draw ``n`` i.i.d. samples from the mixture; returns shape ``(n, dim)``."""
        picks = torch.multinomial(self.weights, n, replacement=True)
        noise = torch.randn(n, self.dim, device=device)
        return self.means[picks] + noise * math.sqrt(self.cov_val)

    def log_prob(self, x):
        """Exact mixture log-density of each row of ``x`` (shape ``(N, D)``); returns ``(N,)``."""
        _, n_dims = x.shape
        sq_dists = ((x[:, None, :] - self.means[None, :, :]) ** 2).sum(dim=2)
        log_norm = torch.log(self.weights) - (n_dims / 2.0) * math.log(2 * math.pi * self.cov_val)
        per_component = log_norm[None, :] - 0.5 * sq_dists / self.cov_val
        return torch.logsumexp(per_component, dim=1)

    def score0(self, x):
        """Gradient of ``log_prob`` w.r.t. ``x``; works even inside ``torch.no_grad``."""
        with torch.enable_grad():
            leaf = x.detach().requires_grad_(True)
            total = self.log_prob(leaf).sum()
            (grad,) = torch.autograd.grad(total, leaf)
        return grad.detach()

class GaussianLikelihood:
    """Isotropic Gaussian likelihood ``y_obs ~ N(x, sigma_noise**2 * I)``.

    ``log_likelihood`` omits the x-independent normalising constant.
    """

    def __init__(self, y_obs, sigma_noise):
        self.y_obs = y_obs
        self.sigma_noise = sigma_noise

    def log_likelihood(self, x):
        """Unnormalised log-likelihood of each row of ``x`` (shape ``(N, D)``); returns ``(N,)``."""
        residual = x - self.y_obs.unsqueeze(0)
        return -(residual ** 2).sum(dim=1) / (2 * self.sigma_noise**2)

    def grad_log_likelihood(self, x):
        """Gradient of ``log_likelihood`` with respect to each row of ``x``."""
        return (self.y_obs.unsqueeze(0) - x) / (self.sigma_noise**2)

class ExactPosteriorGMM:
    """Closed-form posterior of a shared-variance isotropic GMM prior under a Gaussian likelihood.

    With prior components ``N(m_k, s_p I)`` and likelihood ``N(y; x, s_y I)`` the
    posterior is again a GMM with common variance ``(1/s_p + 1/s_y)^-1``, means
    ``var_post * (m_k / s_p + y / s_y)`` and weights proportional to
    ``w_k * N(y; m_k, (s_p + s_y) I)``.
    """

    def __init__(self, prior: HelixPrior, likelihood: GaussianLikelihood):
        prior_var = prior.cov_val
        noise_var = likelihood.sigma_noise**2
        y_obs = likelihood.y_obs

        self.var_post = 1.0 / ((1.0 / prior_var) + (1.0 / noise_var))
        self.std_post = math.sqrt(self.var_post)
        self.means = self.var_post * (prior.means / prior_var + y_obs / noise_var)

        # Per-component evidence of y; shared normalising constants cancel in the softmax.
        sq_resid = ((y_obs - prior.means) ** 2).sum(dim=1)
        log_evidence = -0.5 * sq_resid / (prior_var + noise_var)
        self.weights = torch.softmax(torch.log(prior.weights) + log_evidence, dim=0)
        self.dim = prior.dim

    def sample(self, n):
        """Draw ``n`` i.i.d. samples from the posterior mixture; returns shape ``(n, dim)``."""
        picks = torch.multinomial(self.weights, n, replacement=True)
        noise = torch.randn(n, self.dim, device=device)
        return self.means[picks] + noise * self.std_post

# ==========================================
# 2. METRICS
# ==========================================
def robust_clean_samples(samples):
    """Drop every row of ``samples`` that contains a NaN or an infinity."""
    has_bad_entry = (~torch.isfinite(samples)).any(dim=1)
    return samples[~has_bad_entry]

def compute_mmd_gaussian(x, y, sigma=1.0):
    """Biased (V-statistic) squared MMD between ``x`` and ``y`` with an RBF kernel.

    Only the first 1000 rows of each input are used.  Returns a 0-d tensor.
    """
    x, y = x[:1000], y[:1000]
    two_sigma_sq = 2 * sigma**2

    def mean_rbf(a, b):
        return torch.exp(-(torch.cdist(a, b, p=2) ** 2) / two_sigma_sq).mean()

    return mean_rbf(x, x) + mean_rbf(y, y) - 2 * mean_rbf(x, y)

def compute_multiscale_ksd(samples, score_func, sigmas=(0.1, 0.2, 0.5, 1.0)):
    """Kernel Stein discrepancy (V-statistic) averaged over several RBF bandwidths.

    Non-finite rows are dropped and at most 500 rows (chosen uniformly at random)
    are used.  For each bandwidth ``sigma`` and ``k(x, y) = exp(-||x-y||^2 / (2 sigma^2))``
    the Stein kernel is

        u(x, y) = s(x).s(y) k + (s(x) - s(y)).(x - y) k / sigma^2
                  + (D / sigma^2 - ||x - y||^2 / sigma^4) k,

    averaged over all N^2 pairs (diagonal included).

    Args:
        samples: (N, D) tensor of samples.
        score_func: callable mapping an (N, D) tensor to the target score at those points.
        sigmas: iterable of kernel bandwidths.

    Returns:
        float: mean KSD over ``sigmas``; NaN if fewer than 10 finite samples remain
        or ``sigmas`` is empty.
    """
    sigmas = tuple(sigmas)
    clean_x = robust_clean_samples(samples)
    N = clean_x.shape[0]
    if N > 500:
        idx = torch.randperm(N)[:500]
        X = clean_x[idx]
        N = 500
    else:
        X = clean_x
    if N < 10 or not sigmas:
        return float('nan')
    D = X.shape[1]
    with torch.no_grad():
        s = score_func(X)

    diff = X.unsqueeze(1) - X.unsqueeze(0)  # diff[i, j] = x_i - x_j
    r2 = torch.sum(diff**2, dim=-1)
    # Bandwidth-independent pieces: computed once instead of once per sigma.
    sdot = torch.mm(s, s.t())
    r_dot_ds = torch.einsum('ijd,id->ij', diff, s) - torch.einsum('ijd,jd->ij', diff, s)

    ksd_agg = 0.0
    for sigma in sigmas:
        K = torch.exp(-r2 / (2 * sigma**2))
        term1 = sdot * K
        term2 = r_dot_ds / (sigma**2) * K
        term3 = (D / (sigma**2) - r2 / (sigma**4)) * K
        ksd_agg += torch.sum(term1 + term2 + term3) / (N * N)
    return ksd_agg.item() / len(sigmas)

def compute_sliced_w2(samples_a, samples_b, num_projections=100):
    """
    Monte-Carlo estimate of the sliced Wasserstein-2 distance between two sample sets.

    Non-finite rows are removed and both sets are truncated to a common length
    ``n`` (NaN is returned when ``n < 10``).  Each random unit direction reduces
    the problem to 1-D, where W2 is the RMS gap between sorted projections; the
    result is the square root of the mean squared gap over all directions.
    """
    a = robust_clean_samples(samples_a)
    b = robust_clean_samples(samples_b)

    n = min(len(a), len(b))
    if n < 10:
        return float('nan')
    a = a[:n]
    b = b[:n]

    directions = torch.randn(num_projections, a.shape[1], device=a.device)
    directions = directions / directions.pow(2).sum(dim=1, keepdim=True).sqrt()

    # Column c holds all samples projected onto direction c, sorted ascending.
    a_sorted = (a @ directions.t()).sort(dim=0).values
    b_sorted = (b @ directions.t()).sort(dim=0).values
    return (a_sorted - b_sorted).pow(2).mean().sqrt().item()

# ==========================================
# 3. SAMPLER
# ==========================================
def get_posterior_snis_weights_log(y, t, X_ref, log_lik_ref):
    """Log self-normalised importance weights of reference points for noisy queries.

    The weight of reference ``x_j`` for query ``y_i`` is proportional to
    ``exp(log_lik_ref[j]) * N(y_i; e^{-t} x_j, (1 - e^{-2t}) I)``.

    Args:
        y: queries, shape (M, D).
        t: diffusion time (float, must be > 0).
        X_ref: reference points, shape (N, D).
        log_lik_ref: log-likelihood at each reference point, shape (N,).

    Returns:
        (M, N) log-weights; each row log-sum-exps to zero.
    """
    decay = math.exp(-t)
    noise_var = 1.0 - math.exp(-2*t)
    sq_dist = ((y[:, None, :] - (decay * X_ref)[None, :, :]) ** 2).sum(dim=2)
    logits = -sq_dist / (2 * noise_var) + log_lik_ref[None, :]
    return logits - torch.logsumexp(logits, dim=1, keepdim=True)

def eval_blend_score_batch(y, t, X_ref, s0_post_ref, log_lik_ref):
    """Variance-minimising blend of two SNIS estimators of the diffused posterior score.

    The forward process matches ``get_posterior_snis_weights_log``:
    ``y_t = e^{-t} x + sqrt(1 - e^{-2t}) z``.  For that process the score of the
    diffused posterior satisfies two identities:

    * Tweedie:      grad log p_t(y) = -(y - e^{-t} E[x | y]) / (1 - e^{-2t})
    * Score-based:  grad log p_t(y) = e^{+t} E[grad log pi(x) | y]

    Both conditional expectations are estimated with the same self-normalised
    importance weights over ``X_ref``.  Per query row, ``lam`` in
    ``lam * tweedie + (1 - lam) * score_based`` minimises the weighted variance
    of the combined per-reference estimator, clamped to [0, 1].

    Args:
        y: query points, shape (M, D).
        t: diffusion time (float); clamped below at 1e-4.
        X_ref: reference points, shape (N, D).
        s0_post_ref: grad log pi at each reference point, shape (N, D).
        log_lik_ref: log-likelihood at each reference point, shape (N,).

    Returns:
        Blended score estimate, shape (M, D).
    """
    if t < 1e-4:
        t = 1e-4
    et = math.exp(-t)
    # The score-based identity rescales by 1/alpha_t = e^{+t}; scaling by e^{-t}
    # (as before) gives -e^{-2t} y instead of -y for a standard-normal target.
    inv_et = math.exp(t)
    var_t = 1.0 - math.exp(-2*t)
    inv_v = 1.0 / var_t

    w = torch.exp(get_posterior_snis_weights_log(y, t, X_ref, log_lik_ref))  # (M, N)
    w3 = w.unsqueeze(2)

    # Tweedie estimator and per-reference deviations from its weighted mean.
    mu_x = torch.mm(w, X_ref)
    s_twd = -inv_v * (y - et * mu_x)
    diff_twd = inv_v * et * (X_ref.unsqueeze(0) - mu_x.unsqueeze(1))

    # Score-based estimator and per-reference deviations from its weighted mean.
    scaled_s0 = inv_et * s0_post_ref
    s_kss = torch.mm(w, scaled_s0)
    diff_s = scaled_s0.unsqueeze(0) - s_kss.unsqueeze(1)

    var_twd = torch.sum(w3 * (diff_twd**2), dim=(1, 2))
    var_kss = torch.sum(w3 * (diff_s**2), dim=(1, 2))
    cov_c = torch.sum(w3 * (diff_s * diff_twd), dim=(1, 2))

    # lam* = argmin_lam Var[lam*T + (1-lam)*K]; 1e-8 guards a zero denominator.
    numerator = var_kss - cov_c
    denominator = var_twd + var_kss - 2*cov_c + 1e-8
    lam = torch.clamp(numerator / denominator, 0.0, 1.0).unsqueeze(1)
    return lam * s_twd + (1.0 - lam) * s_kss

def run_heun_sweep(mode, dim, X_ref, s0_post_ref, log_lik_ref, steps=30, n_gen=500):
    """Sample the diffused posterior with a stochastic Heun reverse-SDE integrator.

    Integrates the reverse of the OU process ``dx = -x dt + sqrt(2) dW`` (reverse
    drift ``y + 2 * score``) from ``t = 10**0.4`` down to ``t = 10**-3.5`` on a
    log-spaced grid, starting from ``n_gen`` standard-normal draws.  Both Heun
    stages of a step share the same noise draw.

    Args:
        mode: ``'tweedie'`` (plain Tweedie SNIS score) or ``'blend'``
            (``eval_blend_score_batch``).
        dim: dimensionality of the generated samples.
        X_ref, s0_post_ref, log_lik_ref: reference set forwarded to the score estimators.
        steps: number of integration steps.
        n_gen: number of samples to generate (default 500, the former hard-coded value).

    Returns:
        Tensor of shape (n_gen, dim).

    Raises:
        ValueError: if ``mode`` is not ``'tweedie'`` or ``'blend'``.
    """
    if mode not in ('tweedie', 'blend'):
        raise ValueError(f"unknown mode {mode!r}; expected 'tweedie' or 'blend'")

    def get_s(yy, tt):
        """Score estimate at states ``yy`` and (float) time ``tt``."""
        if mode == 'blend':
            return eval_blend_score_batch(yy, tt, X_ref, s0_post_ref, log_lik_ref)
        if tt < 1e-4:
            tt = 1e-4
        et = math.exp(-tt)
        inv_v = 1.0 / (1 - math.exp(-2*tt))
        w = torch.exp(get_posterior_snis_weights_log(yy, tt, X_ref, log_lik_ref))
        mu_x = torch.mm(w, X_ref)
        return -inv_v * (yy - et * mu_x)

    y = torch.randn(n_gen, dim, device=device)
    # Python floats: avoids a device->host sync for every scalar use per step.
    ts = torch.logspace(0.4, -3.5, steps + 1, device=device).tolist()
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        dt = t_cur - t_next
        noise_scale = math.sqrt(2 * dt)
        d_cur = y + 2 * get_s(y, t_cur)
        z = torch.randn_like(y)
        # Predictor (Euler-Maruyama), then trapezoidal corrector with the same noise.
        y_hat = y + d_cur * dt + noise_scale * z
        d_next = y_hat + 2 * get_s(y_hat, t_next)
        y = y + 0.5 * (d_cur + d_next) * dt + noise_scale * z
    return y

# ==========================================
# 4. SWEEP
# ==========================================
def calculate_stat(data_list, trim_amount=0.0):
    """Aggregate per-trial metric values into a single number.

    Args:
        data_list: sequence of numbers (list, tuple or 1-D numpy array).
        trim_amount: proportion cut from EACH tail.  ``<= 0`` -> plain mean,
            ``>= 0.5`` -> median, otherwise ``scipy.stats.trim_mean``.

    Returns:
        The aggregate value, or NaN for an empty input.
    """
    # len() rather than truthiness, so numpy arrays are accepted as well as lists.
    if len(data_list) == 0:
        return float('nan')
    if trim_amount <= 0.0:
        return np.mean(data_list)
    if trim_amount >= 0.5:
        return np.median(data_list)
    return trim_mean(data_list, proportiontocut=trim_amount)

# ==========================================
# 4. SWEEP  (Tan-normalized by ||y_obs||)
# ==========================================
def calculate_stat(data_list, trim_amount=0.0):
    """Aggregate per-trial metric values into a single number.

    Args:
        data_list: sequence of numbers (list, tuple or 1-D numpy array).
        trim_amount: proportion cut from EACH tail.  ``<= 0`` -> plain mean,
            ``>= 0.5`` -> median, otherwise ``scipy.stats.trim_mean``.

    Returns:
        The aggregate value, or NaN for an empty input.
    """
    # len() rather than truthiness, so numpy arrays are accepted as well as lists.
    if len(data_list) == 0:
        return float('nan')
    if trim_amount <= 0.0:
        return np.mean(data_list)
    if trim_amount >= 0.5:
        return np.median(data_list)
    return trim_mean(data_list, proportiontocut=trim_amount)


def run_fine_grained_sweep(trim_amount=0.0):
    """Compare Tweedie vs. blended-score posterior sampling over a relative-noise sweep.

    The sweep coordinate is a RELATIVE noise level ``sigma_rel``.  Each trial draws
    ``y_obs ~ N(0, I_d)`` and uses ``sigma_abs = sigma_rel * ||y_obs||`` (Tan
    normalisation).  Both samplers are scored against exact posterior samples with
    MMD, KSD and sliced W2, aggregated over ``K_TRIALS`` with ``calculate_stat``.

    Args:
        trim_amount: trimming proportion forwarded to ``calculate_stat``.

    Returns:
        pandas.DataFrame with one row per (n_ref, dim, sigma_rel).
    """
    # These are RELATIVE noise levels (sigma_rel), not absolute sigmas.
    sigma_rels = np.logspace(np.log10(SIGMA_MIN), np.log10(SIGMA_MAX), SIGMA_STEPS)

    results = []
    print(f"Starting sweep: {len(DIMS)} dims x {len(N_REFS)} refs x {len(sigma_rels)} sigma_rels x {K_TRIALS} trials")
    print(f"Aggregation: Trim={trim_amount}")

    # Context manager guarantees the progress bar is closed even if a trial raises.
    with tqdm(total=len(DIMS) * len(N_REFS) * len(sigma_rels) * K_TRIALS) as pbar:
        for nr in N_REFS:
            for d in DIMS:
                prior_model = HelixPrior(dim=d, n_components=100)
                # MMD kernel bandwidth depends only on the dimension.
                mmd_sigma = 0.5 * math.sqrt(d / 2.0)

                for sigma_rel in sigma_rels:
                    mmd_t_list, ksd_t_list, sw_t_list = [], [], []
                    mmd_b_list, ksd_b_list, sw_b_list = [], [], []
                    sigma_abs_list = []  # diagnostics only (varies across trials)

                    for _ in range(K_TRIALS):
                        obs_loc = torch.randn(d, device=device)

                        # Tan normalisation: sigma_abs = sigma_rel * ||y_obs||.
                        obs_norm = torch.norm(obs_loc).item()
                        sigma_abs = float(sigma_rel) * max(obs_norm, 1e-12)
                        sigma_abs_list.append(sigma_abs)

                        lik_model = GaussianLikelihood(obs_loc, sigma_abs)
                        true_post = ExactPosteriorGMM(prior_model, lik_model)
                        X_true = true_post.sample(1000)

                        def true_score_fn(x):
                            return prior_model.score0(x) + lik_model.grad_log_likelihood(x)

                        # Reference set drawn from the prior, with posterior scores.
                        X_ref = prior_model.sample(nr)
                        s0_post_ref = prior_model.score0(X_ref) + lik_model.grad_log_likelihood(X_ref)
                        log_lik_ref = lik_model.log_likelihood(X_ref)

                        # Tweedie run
                        samples_t = run_heun_sweep('tweedie', d, X_ref, s0_post_ref, log_lik_ref)
                        mmd_t_list.append(compute_mmd_gaussian(samples_t, X_true, sigma=mmd_sigma).item())
                        ksd_t_list.append(compute_multiscale_ksd(samples_t, true_score_fn))
                        sw_t_list.append(compute_sliced_w2(samples_t, X_true))

                        # Blend run
                        samples_b = run_heun_sweep('blend', d, X_ref, s0_post_ref, log_lik_ref)
                        mmd_b_list.append(compute_mmd_gaussian(samples_b, X_true, sigma=mmd_sigma).item())
                        ksd_b_list.append(compute_multiscale_ksd(samples_b, true_score_fn))
                        sw_b_list.append(compute_sliced_w2(samples_b, X_true))

                        pbar.update(1)

                    results.append({
                        'dim': d,
                        'n_ref': nr,
                        # the sweep coordinate (relative noise)
                        'sigma_rel': float(sigma_rel),
                        # representative absolute sigmas across trials
                        'sigma_abs_median': float(np.median(sigma_abs_list)),
                        'sigma_abs_mean': float(np.mean(sigma_abs_list)),

                        'mmd_tweedie': calculate_stat(mmd_t_list, trim_amount),
                        'ksd_tweedie': calculate_stat(ksd_t_list, trim_amount),
                        'sw_tweedie': calculate_stat(sw_t_list, trim_amount),
                        'mmd_blend': calculate_stat(mmd_b_list, trim_amount),
                        'ksd_blend': calculate_stat(ksd_b_list, trim_amount),
                        'sw_blend': calculate_stat(sw_b_list, trim_amount)
                    })

    return pd.DataFrame(results)


# ==========================================
# 5. DYNAMIC PLOTTING FUNCTION  (Tan-normalized x-axis)
# ==========================================
def plot_results(df, smoothing_window=0):
    """Plot log(Tweedie / Blend) for MMD, KSD and sliced W2 against relative noise.

    One row of three panels per entry of ``N_REFS`` and one curve per entry of
    ``DIMS``.  Positive values mean the blended sampler is closer to the exact
    posterior.  Saves the figure to ``sweep_results.png`` and ``df`` to
    ``fine_sweep_results.csv``.

    Args:
        df: DataFrame returned by ``run_fine_grained_sweep``.
        smoothing_window: width of a centred rolling mean applied to each curve (0 = none).
    """
    n_rows = len(N_REFS)
    _, axs = plt.subplots(n_rows, 3, figsize=(21, 5 * n_rows))
    if n_rows == 1:
        axs = np.array([axs])  # keep 2-D indexing axs[row, col]

    # Every entry of DIMS (including 32) gets a distinct colour; unknown dims fall back to black.
    dim_colors = {1: 'tab:purple', 2: 'tab:blue', 4: 'tab:orange', 8: 'tab:green',
                  16: 'tab:red', 32: 'tab:cyan', 128: 'tab:brown'}

    xcol = 'sigma_rel'
    xlabel = r'Normalized obs noise  $\sigma_{\mathrm{rel}} = \sigma_y / \|y_{\mathrm{obs}}\|$  (dimensionless; $\downarrow$ = higher SNR)'

    # (column prefix, title prefix, y-axis label or None, clamp y-axis bottom at -4)
    panels = (
        ('mmd', 'MMD', 'Log(Tweedie/Blend)', False),
        ('ksd', 'KSD', None, True),
        ('sw', 'Sliced W2', None, True),
    )

    for i, nr in enumerate(N_REFS):
        subset = df[df['n_ref'] == nr]
        for j, (key, name, ylabel, clamp_bottom) in enumerate(panels):
            ax = axs[i, j]
            for d in DIMS:
                sub = subset[subset['dim'] == d].sort_values(xcol)
                tweedie_vals = np.maximum(sub[f'{key}_tweedie'].to_numpy(), 1e-12)
                blend_vals = np.maximum(sub[f'{key}_blend'].to_numpy(), 1e-12)
                ratio = np.log(tweedie_vals / blend_vals)
                if smoothing_window > 0:
                    ratio = pd.Series(ratio).rolling(window=smoothing_window, min_periods=1, center=True).mean().values
                # A one-point series draws no line segment (e.g. SIGMA_STEPS == 1),
                # so mark the points explicitly in that case.
                marker = 'o' if len(ratio) < 2 else None
                ax.plot(sub[xcol], ratio, label=f'd={d}', color=dim_colors.get(d, 'black'),
                        linewidth=2, marker=marker)

            ax.axhline(0, color='black', linestyle='--', alpha=0.7)
            ax.set_xscale('log')
            ax.set_title(f'{name} Advantage (Log Ratio) vs normalized noise | N_ref={nr}')
            if ylabel is not None:
                ax.set_ylabel(ylabel)
            ax.set_xlabel(xlabel)
            if clamp_bottom:
                ax.set_ylim(bottom=-4)
            ax.legend()
            ax.grid(True, which='both', linestyle='--', alpha=0.3)

    # Standardize scales across rows (if multiple N_references) to the smallest N_ref row.
    min_ref = min(N_REFS)
    ref_row = N_REFS.index(min_ref)
    ref_ylims = [axs[ref_row, j].get_ylim() for j in range(3)]

    print(f"Standardizing Y-axis to N_ref={min_ref} range.")
    for i in range(n_rows):
        for j, ylim in enumerate(ref_ylims):
            axs[i, j].set_ylim(ylim)

    plt.tight_layout()
    plt.savefig('sweep_results.png', dpi=300)
    print("Plots saved to sweep_results.png")

    df.to_csv('fine_sweep_results.csv', index=False)
    print("Aggregated data saved to fine_sweep_results.csv")


if __name__ == "__main__":
    # Run the sweep with the module-level settings, then plot and save the results.
    # df_results is left at module scope (e.g. for inspection in later notebook cells).
    df_results = run_fine_grained_sweep(trim_amount=TRIM_AMOUNT)
    plot_results(df_results, smoothing_window=SMOOTHING)

ValueError: module functions cannot set METH_CLASS or METH_STATIC

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import math
import time
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from scipy.stats import trim_mean

# ==========================================
# 0. CONFIGURATION
# ==========================================
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.set_default_dtype(torch.float32)
print(f"Running on device: {device}")

# ---- Sweep settings ----
K_TRIALS = 1        # independent trials aggregated per grid cell
SIGMA_STEPS = 30    # number of log-uniformly spaced noise levels
SIGMA_MIN = 0.01    # smallest noise level of the sweep
SIGMA_MAX = 2.5     # largest noise level of the sweep
DIMS = [2, 4, 8, 16]

# Reference-set sizes to sweep over (any list of ints).
N_REFS = [1000]

# ---- Analysis parameters ----
TRIM_AMOUNT = 0.0   # proportion trimmed from each tail when aggregating trials (0 = plain mean)
SMOOTHING = 1       # rolling-window width applied to plotted curves

# ==========================================
# 1. MODELS: HELIX PRIOR & EXACT POSTERIOR
# ==========================================
class HelixPrior(torch.nn.Module):
    """Equal-weight isotropic Gaussian mixture whose means trace a helix-like curve.

    The curve parameter ``t`` runs over ``[0, 4*pi]`` (one value per component).
    The first ``dim - 1`` coordinates are cosines, each shifted by a further
    quarter period, scaled by the radius ``1 + 0.25*sin(3t)``; the last
    coordinate grows linearly as ``0.4*t``.  Every component has covariance
    ``sigma**2 * I``.
    """

    def __init__(self, dim=3, n_components=2000, sigma=0.1):
        super().__init__()
        self.dim = dim
        self.n_components = n_components

        t = torch.linspace(0, 4 * math.pi, n_components, device=device)
        radius = 1.0 + 0.25 * torch.sin(3 * t)
        # Oscillating coordinates, phase-shifted by k * pi/2.
        columns = [radius * torch.cos(t - k * math.pi / 2.0) for k in range(dim - 1)]
        # Final coordinate: linear height along the curve.
        columns.append(0.4 * t)
        self.means = torch.stack(columns, dim=1)

        self.weights = torch.ones(n_components, device=device) / n_components
        self.cov_val = sigma**2
        # Explicit per-component covariance matrices, shape (n_components, dim, dim).
        eye = torch.eye(dim, device=device)
        self.covs = eye.unsqueeze(0).repeat(n_components, 1, 1) * self.cov_val

    def sample(self, n):
        """Draw ``n`` i.i.d. samples from the mixture; returns shape ``(n, dim)``."""
        picks = torch.multinomial(self.weights, n, replacement=True)
        noise = torch.randn(n, self.dim, device=device)
        return self.means[picks] + noise * math.sqrt(self.cov_val)

    def log_prob(self, x):
        """Exact mixture log-density of each row of ``x`` (shape ``(N, D)``); returns ``(N,)``."""
        _, n_dims = x.shape
        sq_dists = ((x[:, None, :] - self.means[None, :, :]) ** 2).sum(dim=2)
        log_norm = torch.log(self.weights) - (n_dims / 2.0) * math.log(2 * math.pi * self.cov_val)
        per_component = log_norm[None, :] - 0.5 * sq_dists / self.cov_val
        return torch.logsumexp(per_component, dim=1)

    def score0(self, x):
        """Gradient of ``log_prob`` w.r.t. ``x``; works even inside ``torch.no_grad``."""
        with torch.enable_grad():
            leaf = x.detach().requires_grad_(True)
            total = self.log_prob(leaf).sum()
            (grad,) = torch.autograd.grad(total, leaf)
        return grad.detach()

class GaussianLikelihood:
    """Isotropic Gaussian likelihood ``y_obs ~ N(x, sigma_noise**2 * I)``.

    ``log_likelihood`` omits the x-independent normalising constant.
    """

    def __init__(self, y_obs, sigma_noise):
        self.y_obs = y_obs
        self.sigma_noise = sigma_noise

    def log_likelihood(self, x):
        """Unnormalised log-likelihood of each row of ``x`` (shape ``(N, D)``); returns ``(N,)``."""
        residual = x - self.y_obs.unsqueeze(0)
        return -(residual ** 2).sum(dim=1) / (2 * self.sigma_noise**2)

    def grad_log_likelihood(self, x):
        """Gradient of ``log_likelihood`` with respect to each row of ``x``."""
        return (self.y_obs.unsqueeze(0) - x) / (self.sigma_noise**2)

class ExactPosteriorGMM:
    """Closed-form posterior of a shared-variance isotropic GMM prior under a Gaussian likelihood.

    With prior components ``N(m_k, s_p I)`` and likelihood ``N(y; x, s_y I)`` the
    posterior is again a GMM with common variance ``(1/s_p + 1/s_y)^-1``, means
    ``var_post * (m_k / s_p + y / s_y)`` and weights proportional to
    ``w_k * N(y; m_k, (s_p + s_y) I)``.
    """

    def __init__(self, prior: HelixPrior, likelihood: GaussianLikelihood):
        prior_var = prior.cov_val
        noise_var = likelihood.sigma_noise**2
        y_obs = likelihood.y_obs

        self.var_post = 1.0 / ((1.0 / prior_var) + (1.0 / noise_var))
        self.std_post = math.sqrt(self.var_post)
        self.means = self.var_post * (prior.means / prior_var + y_obs / noise_var)

        # Per-component evidence of y; shared normalising constants cancel in the softmax.
        sq_resid = ((y_obs - prior.means) ** 2).sum(dim=1)
        log_evidence = -0.5 * sq_resid / (prior_var + noise_var)
        self.weights = torch.softmax(torch.log(prior.weights) + log_evidence, dim=0)
        self.dim = prior.dim

    def sample(self, n):
        """Draw ``n`` i.i.d. samples from the posterior mixture; returns shape ``(n, dim)``."""
        picks = torch.multinomial(self.weights, n, replacement=True)
        noise = torch.randn(n, self.dim, device=device)
        return self.means[picks] + noise * self.std_post

# ==========================================
# 2. METRICS
# ==========================================
def robust_clean_samples(samples):
    """Drop every row of ``samples`` that contains a NaN or an infinity."""
    has_bad_entry = (~torch.isfinite(samples)).any(dim=1)
    return samples[~has_bad_entry]

def compute_mmd_gaussian(x, y, sigma=1.0):
    """Biased (V-statistic) squared MMD between ``x`` and ``y`` with an RBF kernel.

    Only the first 1000 rows of each input are used.  Returns a 0-d tensor.
    """
    x, y = x[:1000], y[:1000]
    two_sigma_sq = 2 * sigma**2

    def mean_rbf(a, b):
        return torch.exp(-(torch.cdist(a, b, p=2) ** 2) / two_sigma_sq).mean()

    return mean_rbf(x, x) + mean_rbf(y, y) - 2 * mean_rbf(x, y)

def compute_multiscale_ksd(samples, score_func, sigmas=(0.1, 0.2, 0.5, 1.0)):
    """Kernel Stein discrepancy (V-statistic) averaged over several RBF bandwidths.

    Non-finite rows are dropped and at most 500 rows (chosen uniformly at random)
    are used.  For each bandwidth ``sigma`` and ``k(x, y) = exp(-||x-y||^2 / (2 sigma^2))``
    the Stein kernel is

        u(x, y) = s(x).s(y) k + (s(x) - s(y)).(x - y) k / sigma^2
                  + (D / sigma^2 - ||x - y||^2 / sigma^4) k,

    averaged over all N^2 pairs (diagonal included).

    Args:
        samples: (N, D) tensor of samples.
        score_func: callable mapping an (N, D) tensor to the target score at those points.
        sigmas: iterable of kernel bandwidths.

    Returns:
        float: mean KSD over ``sigmas``; NaN if fewer than 10 finite samples remain
        or ``sigmas`` is empty.
    """
    sigmas = tuple(sigmas)
    clean_x = robust_clean_samples(samples)
    N = clean_x.shape[0]
    if N > 500:
        idx = torch.randperm(N)[:500]
        X = clean_x[idx]
        N = 500
    else:
        X = clean_x
    if N < 10 or not sigmas:
        return float('nan')
    D = X.shape[1]
    with torch.no_grad():
        s = score_func(X)

    diff = X.unsqueeze(1) - X.unsqueeze(0)  # diff[i, j] = x_i - x_j
    r2 = torch.sum(diff**2, dim=-1)
    # Bandwidth-independent pieces: computed once instead of once per sigma.
    sdot = torch.mm(s, s.t())
    r_dot_ds = torch.einsum('ijd,id->ij', diff, s) - torch.einsum('ijd,jd->ij', diff, s)

    ksd_agg = 0.0
    for sigma in sigmas:
        K = torch.exp(-r2 / (2 * sigma**2))
        term1 = sdot * K
        term2 = r_dot_ds / (sigma**2) * K
        term3 = (D / (sigma**2) - r2 / (sigma**4)) * K
        ksd_agg += torch.sum(term1 + term2 + term3) / (N * N)
    return ksd_agg.item() / len(sigmas)

def compute_sliced_w2(samples_a, samples_b, num_projections=100):
    """
    Monte-Carlo estimate of the sliced Wasserstein-2 distance between two sample sets.

    Non-finite rows are removed and both sets are truncated to a common length
    ``n`` (NaN is returned when ``n < 10``).  Each random unit direction reduces
    the problem to 1-D, where W2 is the RMS gap between sorted projections; the
    result is the square root of the mean squared gap over all directions.
    """
    a = robust_clean_samples(samples_a)
    b = robust_clean_samples(samples_b)

    n = min(len(a), len(b))
    if n < 10:
        return float('nan')
    a = a[:n]
    b = b[:n]

    directions = torch.randn(num_projections, a.shape[1], device=a.device)
    directions = directions / directions.pow(2).sum(dim=1, keepdim=True).sqrt()

    # Column c holds all samples projected onto direction c, sorted ascending.
    a_sorted = (a @ directions.t()).sort(dim=0).values
    b_sorted = (b @ directions.t()).sort(dim=0).values
    return (a_sorted - b_sorted).pow(2).mean().sqrt().item()

# ==========================================
# 3. SAMPLER
# ==========================================
def get_posterior_snis_weights_log(y, t, X_ref, log_lik_ref):
    """Log self-normalised importance weights of reference points for noisy queries.

    The weight of reference ``x_j`` for query ``y_i`` is proportional to
    ``exp(log_lik_ref[j]) * N(y_i; e^{-t} x_j, (1 - e^{-2t}) I)``.

    Args:
        y: queries, shape (M, D).
        t: diffusion time (float, must be > 0).
        X_ref: reference points, shape (N, D).
        log_lik_ref: log-likelihood at each reference point, shape (N,).

    Returns:
        (M, N) log-weights; each row log-sum-exps to zero.
    """
    decay = math.exp(-t)
    noise_var = 1.0 - math.exp(-2*t)
    sq_dist = ((y[:, None, :] - (decay * X_ref)[None, :, :]) ** 2).sum(dim=2)
    logits = -sq_dist / (2 * noise_var) + log_lik_ref[None, :]
    return logits - torch.logsumexp(logits, dim=1, keepdim=True)

def eval_blend_score_batch(y, t, X_ref, s0_post_ref, log_lik_ref):
    """Variance-minimising blend of two SNIS estimators of the diffused posterior score.

    The forward process matches ``get_posterior_snis_weights_log``:
    ``y_t = e^{-t} x + sqrt(1 - e^{-2t}) z``.  For that process the score of the
    diffused posterior satisfies two identities:

    * Tweedie:      grad log p_t(y) = -(y - e^{-t} E[x | y]) / (1 - e^{-2t})
    * Score-based:  grad log p_t(y) = e^{+t} E[grad log pi(x) | y]

    Both conditional expectations are estimated with the same self-normalised
    importance weights over ``X_ref``.  Per query row, ``lam`` in
    ``lam * tweedie + (1 - lam) * score_based`` minimises the weighted variance
    of the combined per-reference estimator, clamped to [0, 1].

    Args:
        y: query points, shape (M, D).
        t: diffusion time (float); clamped below at 1e-4.
        X_ref: reference points, shape (N, D).
        s0_post_ref: grad log pi at each reference point, shape (N, D).
        log_lik_ref: log-likelihood at each reference point, shape (N,).

    Returns:
        Blended score estimate, shape (M, D).
    """
    if t < 1e-4:
        t = 1e-4
    et = math.exp(-t)
    # The score-based identity rescales by 1/alpha_t = e^{+t}; scaling by e^{-t}
    # (as before) gives -e^{-2t} y instead of -y for a standard-normal target.
    inv_et = math.exp(t)
    var_t = 1.0 - math.exp(-2*t)
    inv_v = 1.0 / var_t

    w = torch.exp(get_posterior_snis_weights_log(y, t, X_ref, log_lik_ref))  # (M, N)
    w3 = w.unsqueeze(2)

    # Tweedie estimator and per-reference deviations from its weighted mean.
    mu_x = torch.mm(w, X_ref)
    s_twd = -inv_v * (y - et * mu_x)
    diff_twd = inv_v * et * (X_ref.unsqueeze(0) - mu_x.unsqueeze(1))

    # Score-based estimator and per-reference deviations from its weighted mean.
    scaled_s0 = inv_et * s0_post_ref
    s_kss = torch.mm(w, scaled_s0)
    diff_s = scaled_s0.unsqueeze(0) - s_kss.unsqueeze(1)

    var_twd = torch.sum(w3 * (diff_twd**2), dim=(1, 2))
    var_kss = torch.sum(w3 * (diff_s**2), dim=(1, 2))
    cov_c = torch.sum(w3 * (diff_s * diff_twd), dim=(1, 2))

    # lam* = argmin_lam Var[lam*T + (1-lam)*K]; 1e-8 guards a zero denominator.
    numerator = var_kss - cov_c
    denominator = var_twd + var_kss - 2*cov_c + 1e-8
    lam = torch.clamp(numerator / denominator, 0.0, 1.0).unsqueeze(1)
    return lam * s_twd + (1.0 - lam) * s_kss

def run_heun_sweep(mode, dim, X_ref, s0_post_ref, log_lik_ref, steps=30, n_gen=500):
    """Sample the diffused posterior with a stochastic Heun reverse-SDE integrator.

    Integrates the reverse of the OU process ``dx = -x dt + sqrt(2) dW`` (reverse
    drift ``y + 2 * score``) from ``t = 10**0.4`` down to ``t = 10**-3.5`` on a
    log-spaced grid, starting from ``n_gen`` standard-normal draws.  Both Heun
    stages of a step share the same noise draw.

    Args:
        mode: ``'tweedie'`` (plain Tweedie SNIS score) or ``'blend'``
            (``eval_blend_score_batch``).
        dim: dimensionality of the generated samples.
        X_ref, s0_post_ref, log_lik_ref: reference set forwarded to the score estimators.
        steps: number of integration steps.
        n_gen: number of samples to generate (default 500, the former hard-coded value).

    Returns:
        Tensor of shape (n_gen, dim).

    Raises:
        ValueError: if ``mode`` is not ``'tweedie'`` or ``'blend'``.
    """
    if mode not in ('tweedie', 'blend'):
        raise ValueError(f"unknown mode {mode!r}; expected 'tweedie' or 'blend'")

    def get_s(yy, tt):
        """Score estimate at states ``yy`` and (float) time ``tt``."""
        if mode == 'blend':
            return eval_blend_score_batch(yy, tt, X_ref, s0_post_ref, log_lik_ref)
        if tt < 1e-4:
            tt = 1e-4
        et = math.exp(-tt)
        inv_v = 1.0 / (1 - math.exp(-2*tt))
        w = torch.exp(get_posterior_snis_weights_log(yy, tt, X_ref, log_lik_ref))
        mu_x = torch.mm(w, X_ref)
        return -inv_v * (yy - et * mu_x)

    y = torch.randn(n_gen, dim, device=device)
    # Python floats: avoids a device->host sync for every scalar use per step.
    ts = torch.logspace(0.4, -3.5, steps + 1, device=device).tolist()
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        dt = t_cur - t_next
        noise_scale = math.sqrt(2 * dt)
        d_cur = y + 2 * get_s(y, t_cur)
        z = torch.randn_like(y)
        # Predictor (Euler-Maruyama), then trapezoidal corrector with the same noise.
        y_hat = y + d_cur * dt + noise_scale * z
        d_next = y_hat + 2 * get_s(y_hat, t_next)
        y = y + 0.5 * (d_cur + d_next) * dt + noise_scale * z
    return y

# ==========================================
# 4. SWEEP
# ==========================================
def calculate_stat(data_list, trim_amount=0.0):
    """Aggregate per-trial metric values into a single number.

    Args:
        data_list: sequence of numbers (list, tuple or 1-D numpy array).
        trim_amount: proportion cut from EACH tail.  ``<= 0`` -> plain mean,
            ``>= 0.5`` -> median, otherwise ``scipy.stats.trim_mean``.

    Returns:
        The aggregate value, or NaN for an empty input.
    """
    # len() rather than truthiness, so numpy arrays are accepted as well as lists.
    if len(data_list) == 0:
        return float('nan')
    if trim_amount <= 0.0:
        return np.mean(data_list)
    if trim_amount >= 0.5:
        return np.median(data_list)
    return trim_mean(data_list, proportiontocut=trim_amount)





def run_fine_grained_sweep(trim_amount=0.0):
    """Compare Tweedie vs. blended-score posterior sampling with a dimension-coherent noise sweep.

    Each trial draws a ground truth ``x* ~ prior`` and observes
    ``y_obs = x* + sigma_abs * eps`` with ``sigma_abs = sigma_rel * sqrt(d)``, so
    ``sigma_rel`` acts as an inverse SNR normalised by the typical signal scale
    ``~ sqrt(d)``.  Both samplers are scored against exact posterior samples with
    MMD, KSD and sliced W2.  A per-metric "floor" is obtained by scoring an
    independent exact-posterior draw of the SAME size as the sampler output in
    exactly the same way, i.e. what a perfect sampler would get from
    finite-sample noise alone.

    Args:
        trim_amount: trimming proportion forwarded to ``calculate_stat``.

    Returns:
        pandas.DataFrame with one row per (n_ref, dim, sigma_rel).
    """
    sigma_rels = np.logspace(np.log10(SIGMA_MIN), np.log10(SIGMA_MAX), SIGMA_STEPS)
    results = []
    print(f"Starting sweep: {len(DIMS)} dims x {len(sigma_rels)} sigma_rels (dim-coherent inv-SNR)")

    # Context manager guarantees the progress bar is closed even if a trial raises.
    with tqdm(total=len(DIMS) * len(N_REFS) * len(sigma_rels) * K_TRIALS) as pbar:
        for nr in N_REFS:
            for d in DIMS:
                prior_model = HelixPrior(dim=d, n_components=100)
                # MMD kernel bandwidth, scaled with the dimension.
                mmd_sigma = 0.5 * math.sqrt(d / 2.0)
                sqrt_d = math.sqrt(d)

                for sigma_rel in sigma_rels:
                    mmd_t_list, ksd_t_list, sw_t_list = [], [], []
                    mmd_b_list, ksd_b_list, sw_b_list = [], [], []
                    mmd_f_list, ksd_f_list, sw_f_list = [], [], []
                    sigma_abs_list, y_norm_list, xstar_norm_list = [], [], []

                    for _ in range(K_TRIALS):
                        # Dimension-coherent generative observation:
                        #   x* ~ prior, y = x* + sigma_abs * eps, sigma_abs = sigma_rel * sqrt(d)
                        x_star = prior_model.sample(1).squeeze(0)  # [d]
                        eps = torch.randn(d, device=device)
                        sigma_abs = float(sigma_rel) * sqrt_d
                        sigma_abs_list.append(sigma_abs)

                        obs_loc = x_star + sigma_abs * eps  # y_obs
                        y_norm_list.append(torch.norm(obs_loc).item())
                        xstar_norm_list.append(torch.norm(x_star).item())

                        lik_model = GaussianLikelihood(obs_loc, sigma_abs)
                        true_post = ExactPosteriorGMM(prior_model, lik_model)

                        # 1. Exact posterior samples used as the comparison target.
                        X_true_1 = true_post.sample(1000)

                        def true_score_fn(x):
                            return prior_model.score0(x) + lik_model.grad_log_likelihood(x)

                        # 2. Reference set drawn from the prior, with posterior scores.
                        X_ref = prior_model.sample(nr)
                        s0_post_ref = prior_model.score0(X_ref) + lik_model.grad_log_likelihood(X_ref)
                        log_lik_ref = lik_model.log_likelihood(X_ref)

                        # 3. Tweedie
                        samples_t = run_heun_sweep('tweedie', d, X_ref, s0_post_ref, log_lik_ref)
                        mmd_t_list.append(compute_mmd_gaussian(samples_t, X_true_1, sigma=mmd_sigma).item())
                        ksd_t_list.append(compute_multiscale_ksd(samples_t, true_score_fn))
                        sw_t_list.append(compute_sliced_w2(samples_t, X_true_1))

                        # 4. Blend
                        samples_b = run_heun_sweep('blend', d, X_ref, s0_post_ref, log_lik_ref)
                        mmd_b_list.append(compute_mmd_gaussian(samples_b, X_true_1, sigma=mmd_sigma).item())
                        ksd_b_list.append(compute_multiscale_ksd(samples_b, true_score_fn))
                        sw_b_list.append(compute_sliced_w2(samples_b, X_true_1))

                        # 5. Floors.  Biased MMD and empirical sliced-W2 both depend on
                        # sample size, so the exact draw must match the sampler output size;
                        # otherwise a perfect sampler would sit above its floor.
                        X_true_2 = true_post.sample(samples_t.shape[0])
                        mmd_f_list.append(compute_mmd_gaussian(X_true_2, X_true_1, sigma=mmd_sigma).item())
                        sw_f_list.append(compute_sliced_w2(X_true_2, X_true_1))
                        ksd_f_list.append(compute_multiscale_ksd(X_true_2, true_score_fn))

                        pbar.update(1)

                    results.append({
                        'dim': d,
                        'n_ref': nr,
                        'sigma_rel': float(sigma_rel),  # inverse SNR (dimension coherent)
                        'sigma_abs_median': float(np.median(sigma_abs_list)),
                        'y_norm_median': float(np.median(y_norm_list)),
                        'x_star_norm_median': float(np.median(xstar_norm_list)),

                        'mmd_floor': calculate_stat(mmd_f_list, trim_amount),
                        'ksd_floor': calculate_stat(ksd_f_list, trim_amount),
                        'sw_floor': calculate_stat(sw_f_list, trim_amount),

                        'mmd_tweedie': calculate_stat(mmd_t_list, trim_amount),
                        'ksd_tweedie': calculate_stat(ksd_t_list, trim_amount),
                        'sw_tweedie': calculate_stat(sw_t_list, trim_amount),
                        'mmd_blend': calculate_stat(mmd_b_list, trim_amount),
                        'ksd_blend': calculate_stat(ksd_b_list, trim_amount),
                        'sw_blend': calculate_stat(sw_b_list, trim_amount),
                    })

    return pd.DataFrame(results)



# ==========================================
# 5. DYNAMIC PLOTTING FUNCTION (Absolute vs Floor)
# ==========================================
def plot_results(df, smoothing_window=0):
    """
    Plot log(metric / floor) for Blend and Tweedie against sigma_rel.

    Layout: one row per entry of N_REFS, three columns (MMD, KSD, sliced W2).
    Each dim in DIMS has its own color; Blend is drawn solid, Tweedie dashed.
    Metric and floor values are clipped at 1e-12 before the log ratio so that
    zero/negative estimates do not produce -inf/nan.

    Args:
        df: results frame with columns 'n_ref', 'dim', 'sigma_rel' and
            '{mmd,ksd,sw}_{floor,tweedie,blend}'.
        smoothing_window: centered rolling-mean window applied to each curve;
            0 disables smoothing.

    Saves the figure to 'sweep_results_absolute.png' and writes `df` to
    'fine_sweep_results_absolute.csv'. The figure is not closed.
    """
    n_rows = len(N_REFS)
    # squeeze=False keeps `axs` 2-D (n_rows x 3) even when there is a single row.
    fig, axs = plt.subplots(n_rows, 3, figsize=(21, 5 * n_rows), squeeze=False)

    # Fixed palette for the original dim grid; any other DIMS list gets distinct
    # colors from the default color cycle instead of every curve falling back to black.
    palette = {2: 'tab:blue', 4: 'tab:orange', 8: 'tab:green',
               16: 'tab:red', 32: 'tab:purple'}
    if all(d in palette for d in DIMS):
        dim_colors = palette
    else:
        dim_colors = {d: f'C{k % 10}' for k, d in enumerate(DIMS)}

    xcol = 'sigma_rel'
    # NOTE(review): the sweep's result dict describes 'sigma_rel' as an inverse SNR;
    # confirm this label's formula still matches how sigma_rel is defined there.
    xlabel = r'Normalized obs noise  $\sigma_{\mathrm{rel}} = \sigma_y / \|y_{\mathrm{obs}}\|$'

    # (column-name prefix, panel title); column order is MMD, KSD, SW2.
    panels = (
        ('mmd', 'MMD Perf (Log Ratio to Floor)'),
        ('ksd', 'KSD Perf (Log Ratio to Floor)'),
        ('sw', 'Sliced W2 Perf (Log Ratio to Floor)'),
    )

    def get_smooth(series):
        if smoothing_window > 0:
            return pd.Series(series).rolling(window=smoothing_window, min_periods=1, center=True).mean().values
        return series

    def log_ratio(values, floor):
        # Clip the metric like the floor so the log stays finite.
        return np.log(np.maximum(values, 1e-12) / floor)

    for i, nr in enumerate(N_REFS):
        subset = df[df['n_ref'] == nr]

        for j, (key, title) in enumerate(panels):
            ax = axs[i, j]
            for d in DIMS:
                sub = subset[subset['dim'] == d].sort_values(xcol)
                floor = np.maximum(sub[f'{key}_floor'].to_numpy(), 1e-12)
                y_t = log_ratio(sub[f'{key}_tweedie'].to_numpy(), floor)
                y_b = log_ratio(sub[f'{key}_blend'].to_numpy(), floor)

                c = dim_colors.get(d, 'black')
                ax.plot(sub[xcol], get_smooth(y_b), color=c, linestyle='-', label=f'd={d}')
                ax.plot(sub[xcol], get_smooth(y_t), color=c, linestyle='--')

            ax.set_xscale('log')
            ax.set_title(f'{title} | N_ref={nr}')
            ax.set_xlabel(xlabel)
            ax.grid(True, which='both', linestyle='--', alpha=0.3)

        axs[i, 0].set_ylabel('Log(Metric / Floor)')

        # One legend per row, on the last panel. Empty proxy lines explain the
        # line styles, since colors only encode the dimension.
        ax_leg = axs[i, -1]
        ax_leg.plot([], [], color='black', linestyle='-', label='Blend')
        ax_leg.plot([], [], color='black', linestyle='--', label='Tweedie')
        ax_leg.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()
    fig.savefig('sweep_results_absolute.png', dpi=300)
    print("Plots saved to sweep_results_absolute.png")

    df.to_csv('fine_sweep_results_absolute.csv', index=False)
    print("Aggregated data saved to fine_sweep_results_absolute.csv")


if __name__ == "__main__":
    # Run the sweep, then plot log(metric / floor) curves and save the PNG + CSV.
    # df_results is kept as a module-level name so it stays inspectable afterwards.
    df_results = run_fine_grained_sweep(trim_amount=TRIM_AMOUNT)
    plot_results(df_results, smoothing_window=SMOOTHING)

Running on device: cuda
Starting sweep: 4 dims x 30 sigma_rels (dim-coherent inv-SNR)


  0%|          | 0/120 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import math
import time
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from scipy.stats import trim_mean

# ==========================================
# 1'. MODELS: SPECTRAL GMM PRIOR + DIAG INVERSE PROBLEM LIKELIHOOD
# ==========================================

# ---- Sweep settings ----
K_TRIALS = 20            # independent trials per (dim, sigma_rel) cell
SIGMA_STEPS = 24         # number of log-spaced relative noise levels
SIGMA_MIN = 0.025        # smallest relative noise level
SIGMA_MAX = 1.0          # largest relative noise level
DIMS = [3, 6, 12, 24]    # ambient dimensions swept

# Reference-set sizes swept; the plotting step writes one figure per entry.
N_REFS = [2000]

# ---- Analysis parameters ----
TRIM_AMOUNT = 0.15       # passed to calculate_stat as trim_amount (outlier removal)
SMOOTHING = 5            # rolling-mean window applied when plotting



class SpectralGMMPrior(torch.nn.Module):
    """
    GMM prior in R^d: K uniformly weighted components sharing the diagonal
    covariance diag(lambdas).

    lambdas_i = i^{-2*alpha}, i = 1..d (power-law, KL-like spectrum).

    Component means are drawn uniformly on the sphere of radius `mean_radius`
    in whitened coordinates (normalized standard-normal draws) and then mapped
    by sqrt(lambdas). Every whitened mean therefore has norm `mean_radius`
    regardless of d.

    Args:
        dim: ambient dimension d.
        n_components: number of mixture components K.
        alpha: spectral decay exponent.
        mean_radius: whitened radius of the component means.
        seed: if given, seeds a dedicated generator used only for the component
            means; `sample` always uses the global RNG.
    """
    def __init__(self, dim, n_components=64, alpha=1.0, mean_radius=2.0, seed=None):
        super().__init__()
        self.dim = dim
        self.n_components = n_components
        if seed is not None:
            g = torch.Generator(device=device)
            g.manual_seed(seed)
        else:
            g = None

        i = torch.arange(1, dim + 1, device=device, dtype=torch.float32)
        lambdas = i.pow(-2.0 * alpha)              # power-law spectrum
        self.lambdas = lambdas                     # [d]
        self.inv_lambdas = 1.0 / lambdas           # [d]
        self.logdet = torch.sum(torch.log(lambdas)).item()

        # Uniform mixture weights
        self.weights = torch.ones(n_components, device=device) / n_components
        self.log_weights = torch.log(self.weights)

        # Whitened directions on the sphere of radius mean_radius, then mapped by sqrt(lambdas).
        z = torch.randn(n_components, dim, device=device, generator=g)
        z = z / (z.norm(dim=1, keepdim=True) + 1e-12)
        z = mean_radius * z
        self.means = z * torch.sqrt(lambdas).unsqueeze(0)               # [K,d]

    def sample(self, n):
        """Draw n prior samples as an [n, d] tensor."""
        comp_idxs = torch.multinomial(self.weights, n, replacement=True)
        m = self.means[comp_idxs]  # [n,d]
        eps = torch.randn(n, self.dim, device=device)
        return m + eps * torch.sqrt(self.lambdas).unsqueeze(0)

    def _component_log_probs(self, x):
        """
        Per-component joint log-densities shared by `log_prob` and `score0`.

        Args:
            x: [N, d] points (the unpack below rejects non-2-D input).

        Returns:
            (logp_k, diff): logp_k[n, k] = log w_k + log N(x_n | mu_k, diag(lambdas))
            with shape [N, K], and diff = x - mu_k with shape [N, K, d].
        """
        _, D = x.shape
        diff = x.unsqueeze(1) - self.means.unsqueeze(0)                        # [N,K,d]
        quad = torch.sum(diff * diff * self.inv_lambdas.view(1, 1, D), dim=2)  # [N,K]
        log_norm = -0.5 * (D * math.log(2 * math.pi) + self.logdet)
        logp_k = self.log_weights.view(1, -1) + log_norm - 0.5 * quad          # [N,K]
        return logp_k, diff

    def log_prob(self, x):
        """log p(x) of the mixture for each row of x [N, d]; returns [N]."""
        logp_k, _ = self._component_log_probs(x)
        return torch.logsumexp(logp_k, dim=1)

    def score0(self, x):
        """
        Analytic prior score grad_x log p(x) for x [N, d]; returns [N, d].

        score = sum_k r_k(x) * (-(x - mu_k) / lambdas), where r_k are the
        posterior responsibilities softmax_k(logp_k).
        """
        logp_k, diff = self._component_log_probs(x)
        r = torch.softmax(logp_k, dim=1)                                       # [N,K]
        comp_scores = -diff * self.inv_lambdas.view(1, 1, x.shape[1])          # [N,K,d]
        return torch.sum(r.unsqueeze(2) * comp_scores, dim=1)                  # [N,d]


class LinearDiagGaussianLikelihood:
    """
    Gaussian observation model with a diagonal forward operator.

        y = A x + sigma * eps,   A = diag(a_diag),   eps ~ N(0, I)

    Log-likelihood, dropping the x-independent normalizer:
        log p(y | x) = -||A x - y||^2 / (2 sigma^2)
    Its gradient with respect to x:
        -A^T (A x - y) / sigma^2 = -(a^2 * x - a * y) / sigma^2
    """
    def __init__(self, y_obs, sigma_noise, a_diag):
        # These attribute names are read directly by ExactPosteriorDiagGMM.
        self.y_obs = y_obs
        self.sigma_noise = float(sigma_noise)
        self.a = a_diag
        self.a2 = a_diag * a_diag

    def log_likelihood(self, x):
        """Return the log-likelihood (normalizer dropped) of each row of x [N, d] as [N]."""
        residual = self.a.view(1, -1) * x - self.y_obs.view(1, -1)
        denom = 2.0 * (self.sigma_noise ** 2)
        return -torch.sum(residual * residual, dim=1) / denom

    def grad_log_likelihood(self, x):
        """Return grad_x of `log_likelihood` for each row of x [N, d] as [N, d]."""
        noise_var = self.sigma_noise ** 2
        pull = self.a2.view(1, -1) * x - self.a.view(1, -1) * self.y_obs.view(1, -1)
        return -pull / noise_var


class ExactPosteriorDiagGMM:
    """
    Closed-form posterior of a diagonal-covariance GMM prior under a diagonal
    linear-Gaussian observation model.

      prior:       x ~ sum_k w_k N(mu_k, diag(lambdas))
      likelihood:  y | x ~ N(A x, sigma^2 I),  A = diag(a)

    Conditioning each Gaussian component on y keeps the mixture form, so the
    posterior is again a K-component GMM. All components share one diagonal
    covariance diag(var_post); only the means and the weights depend on k.
    """
    def __init__(self, prior: SpectralGMMPrior, lik: LinearDiagGaussianLikelihood):
        self.dim = prior.dim
        self.K = prior.n_components

        noise_var = lik.sigma_noise ** 2

        # Posterior precision = prior precision + data precision, per axis.
        self.var_post = 1.0 / (prior.inv_lambdas + lik.a2 / noise_var)   # [d]
        self.std_post = torch.sqrt(self.var_post)                        # [d]

        # Component means: m_k = var_post * (mu_k / lambda + a * y / sigma^2)
        data_drive = (lik.a * lik.y_obs) / noise_var                     # [d]
        shifted = prior.means * prior.inv_lambdas.view(1, -1) + data_drive.view(1, -1)
        self.means = self.var_post.view(1, -1) * shifted                 # [K,d]

        # Weights: w_k ∝ w_k_prior * N(y | A mu_k, diag(a^2 * lambda + sigma^2)).
        # The Gaussian normalizer is identical for every k and cancels in the softmax.
        evid_var = lik.a2 * prior.lambdas + noise_var                    # [d]
        resid = lik.y_obs.view(1, -1) - lik.a.view(1, -1) * prior.means  # [K,d]
        mahal = torch.sum(resid * resid * (1.0 / evid_var).view(1, -1), dim=1)  # [K]
        const = -0.5 * (self.dim * math.log(2 * math.pi) + torch.sum(torch.log(evid_var)).item())
        self.weights = torch.softmax(prior.log_weights + const - 0.5 * mahal, dim=0)  # [K]

    def sample(self, n):
        """Draw n i.i.d. posterior samples as an [n, d] tensor."""
        picks = torch.multinomial(self.weights, n, replacement=True)
        noise = torch.randn(n, self.dim, device=device)
        return self.means[picks] + noise * self.std_post.view(1, -1)


# ==========================================
# 4'. SWEEP (Spectral-GMM inverse problem family)
# ==========================================
# Spectral-GMM family hyperparameters: tune once, then keep fixed across runs.
K_COMPONENTS = 64      # mixture components in the prior
ALPHA_LAMBDA = 1.0     # prior spectrum decay: lambda_i = i^(-2 * alpha)
BETA_A = 1.0           # forward-operator decay: a_i = a0 * i^(-beta)
MEAN_RADIUS = 2.0      # whitened radius of the component means (mixture separation)
SIGNAL_MC = 4096       # Monte-Carlo draws used to estimate E||Ax||^2 for each d

def build_a_diag(d, beta=BETA_A, a0=1.0):
    """Return the diagonal of A as a float32 [d] tensor: a_i = a0 * i^(-beta), i = 1..d."""
    idx = torch.arange(1, d + 1, device=device, dtype=torch.float32)
    return idx.pow(-beta) * a0

def estimate_signal_scale(prior_model: SpectralGMMPrior, a_diag, n_mc=SIGNAL_MC):
    """
    Monte-Carlo estimate of sqrt(E_{x ~ prior} ||A x||^2) with A = diag(a_diag).

    Uses n_mc prior draws; the mean squared norm is floored at 1e-12 before
    the square root. Returns a Python float.
    """
    with torch.no_grad():
        draws = prior_model.sample(n_mc)                    # [n,d]
        scaled = draws * a_diag.view(1, -1)
        mean_sq = (scaled * scaled).sum(dim=1).mean().item()
    return math.sqrt(max(mean_sq, 1e-12))

# Per-trial metric names, in the column order written to the results frame.
_SPECTRAL_METRIC_KEYS = (
    'mmd_floor', 'ksd_floor', 'sw_floor',
    'mmd_tweedie', 'ksd_tweedie', 'sw_tweedie',
    'mmd_blend', 'ksd_blend', 'sw_blend',
)


def _run_spectral_gmm_trial(prior_model, a_diag, sigma_abs, nr, d, mmd_sigma):
    """
    Run one inverse-problem trial and return its metrics keyed by _SPECTRAL_METRIC_KEYS.

    Draws x* from the prior and observes y = A x* + sigma_abs * eps. It then
    builds the exact posterior and scores the 'tweedie' and 'blend' samplers
    against exact posterior samples. The '*_floor' entries compare two
    independent exact-posterior sample sets (KSD: exact samples vs. the true score).
    """
    # Generative observation: x* ~ prior, y = A x* + sigma eps
    x_star = prior_model.sample(1).squeeze(0)         # [d]
    y_obs = a_diag * x_star + sigma_abs * torch.randn(d, device=device)

    lik_model = LinearDiagGaussianLikelihood(y_obs, sigma_abs, a_diag)
    true_post = ExactPosteriorDiagGMM(prior_model, lik_model)

    # Two independent exact sample sets: X_true_1 is the comparison target,
    # X_true_2 only serves to measure each metric's finite-sample floor.
    X_true_1 = true_post.sample(1000)
    X_true_2 = true_post.sample(1000)

    def true_score_fn(x):
        return prior_model.score0(x) + lik_model.grad_log_likelihood(x)

    out = {}
    # Floors (true vs true)
    out['mmd_floor'] = compute_mmd_gaussian(X_true_2, X_true_1, sigma=mmd_sigma).item()
    out['sw_floor'] = compute_sliced_w2(X_true_2, X_true_1)
    out['ksd_floor'] = compute_multiscale_ksd(X_true_1, true_score_fn)

    # Reference set drawn from the prior, with its exact posterior score and log-likelihood.
    X_ref = prior_model.sample(nr)
    s0_post_ref = prior_model.score0(X_ref) + lik_model.grad_log_likelihood(X_ref)
    log_lik_ref = lik_model.log_likelihood(X_ref)

    for method in ('tweedie', 'blend'):
        samples = run_heun_sweep(method, d, X_ref, s0_post_ref, log_lik_ref)
        out[f'mmd_{method}'] = compute_mmd_gaussian(samples, X_true_1, sigma=mmd_sigma).item()
        out[f'ksd_{method}'] = compute_multiscale_ksd(samples, true_score_fn)
        out[f'sw_{method}'] = compute_sliced_w2(samples, X_true_1)
    return out


def run_fine_grained_sweep_spectral_gmm(trim_amount=0.0, seed=None):
    """
    Sweep the spectral-GMM inverse-problem family over N_REFS x DIMS x sigma_rel.

    For each (n_ref, dim) a fresh SpectralGMMPrior and A = diag(build_a_diag(d))
    are built. The dimensionless noise level sigma_rel is converted to an absolute
    noise std via sigma_abs = sigma_rel * sqrt(E||Ax||^2) (Monte-Carlo estimate).
    K_TRIALS trials are run per cell and aggregated with calculate_stat.

    Args:
        trim_amount: passed to calculate_stat for every metric list.
        seed: optional global torch seed for a reproducible sweep; None keeps
            the previous (unseeded) behavior.

    Returns:
        pandas DataFrame, one row per (n_ref, dim, sigma_rel).
    """
    if seed is not None:
        torch.manual_seed(seed)

    sigma_rels = np.logspace(np.log10(SIGMA_MIN), np.log10(SIGMA_MAX), SIGMA_STEPS)
    results = []
    print(f"Starting spectral-GMM sweep: |DIMS|={len(DIMS)} x |sigma|={len(sigma_rels)} x trials={K_TRIALS}")

    total = len(DIMS) * len(N_REFS) * len(sigma_rels) * K_TRIALS
    # Context manager closes the bar even on KeyboardInterrupt/exception; a leaked
    # bar makes later tqdm bars in the same kernel print cursor-up artifacts.
    with tqdm(total=total) as pbar:
        for nr in N_REFS:
            for d in DIMS:
                prior_model = SpectralGMMPrior(
                    dim=d,
                    n_components=K_COMPONENTS,
                    alpha=ALPHA_LAMBDA,
                    mean_radius=MEAN_RADIUS,
                )

                a_diag = build_a_diag(d, beta=BETA_A, a0=1.0)     # [d]
                signal_scale = estimate_signal_scale(prior_model, a_diag)

                # MMD bandwidth (keep your previous scaling)
                mmd_sigma = 0.5 * math.sqrt(d / 2.0)

                for sigma_rel in sigma_rels:
                    # Dimensionless noise -> absolute sigma via signal scale (trial-invariant).
                    sigma_abs = float(sigma_rel) * signal_scale
                    per_metric = {key: [] for key in _SPECTRAL_METRIC_KEYS}

                    for _ in range(K_TRIALS):
                        trial = _run_spectral_gmm_trial(prior_model, a_diag, sigma_abs, nr, d, mmd_sigma)
                        for key, value in trial.items():
                            per_metric[key].append(value)
                        pbar.update(1)

                    res = {
                        'dim': d,
                        'n_ref': nr,
                        'sigma_rel': float(sigma_rel),                # dimensionless
                        'sigma_abs_median': float(sigma_abs),         # identical across trials
                        'signal_scale': float(signal_scale),
                    }
                    res.update({key: calculate_stat(vals, trim_amount) for key, vals in per_metric.items()})
                    results.append(res)

    return pd.DataFrame(results)



import matplotlib.gridspec as gridspec

def plot_advantage_and_perf(df, smoothing_window=0, out_prefix="spectral_gmm"):
    """
    Clean regime figure (paper-friendly), one PNG per entry of N_REFS.

      - DIMS is split into consecutive pairs. Each pair becomes a 2x2 block
        (cols = the two dims, rows = [MMD, KSD]), and blocks are stacked
        vertically. With four dims this is a 4 rows x 2 cols figure. A
        trailing unpaired dim leaves its block's right-hand column hidden.
      - Lines: Blend = blue solid, Tweedie = red solid
      - Only plots log(metric / floor); metric and floor are clipped at 1e-12
      - Shared x/y limits across ALL panels, from the (unsmoothed) data that is
        actually plotted, with 5% y-padding
      - Short labels; fixed margins so labels are not clipped

    Args:
        df: results frame with columns 'n_ref', 'dim', 'sigma_rel' and
            '{mmd,ksd}_{floor,blend,tweedie}'.
        smoothing_window: centered rolling-mean window; 0 disables smoothing.
        out_prefix: prefix for output file names.

    Raises:
        ValueError: if DIMS is empty, if an N_ref has no rows for any dim in
            DIMS, or if all plotted values are non-finite.

    Saves: {out_prefix}_perf_grid_Nref{nr}.png
    Also writes: {out_prefix}_results.csv
    """
    xcol = 'sigma_rel'
    # Short bottom label
    xlabel = r'Noise level $\sigma/\sqrt{\mathbb{E}\|Ax\|^2}$'

    dims_show = list(DIMS)
    if not dims_show:
        raise ValueError("DIMS is empty; nothing to plot.")
    # Consecutive pairs of dims; each pair is drawn as one 2x2 block.
    dim_blocks = [dims_show[k:k + 2] for k in range(0, len(dims_show), 2)]
    n_blocks = len(dim_blocks)

    def get_smooth(arr):
        if smoothing_window > 0:
            return pd.Series(arr).rolling(window=smoothing_window, min_periods=1, center=True).mean().values
        return arr

    def safe_log_ratio(num, den):
        # Clip both sides so zero/negative estimates never produce -inf/nan.
        num = np.maximum(np.asarray(num), 1e-12)
        den = np.maximum(np.asarray(den), 1e-12)
        out = np.log(num / den)
        out[~np.isfinite(out)] = np.nan
        return out

    def log_ratios(sub):
        """Return (mmd_blend, mmd_tweedie, ksd_blend, ksd_tweedie) log-ratios to floor."""
        m_floor = sub['mmd_floor'].to_numpy()
        k_floor = sub['ksd_floor'].to_numpy()
        return (
            safe_log_ratio(sub['mmd_blend'].to_numpy(), m_floor),
            safe_log_ratio(sub['mmd_tweedie'].to_numpy(), m_floor),
            safe_log_ratio(sub['ksd_blend'].to_numpy(), k_floor),
            safe_log_ratio(sub['ksd_tweedie'].to_numpy(), k_floor),
        )

    # Paper fonts, slightly smaller y-tick text to reduce clutter
    rc = {
        "font.size": 18,
        "axes.titlesize": 18,
        "axes.labelsize": 18,
        "xtick.labelsize": 12,
        "ytick.labelsize": 11,
        "legend.fontsize": 13,
    }

    for nr in N_REFS:
        sub_nr = df[df['n_ref'] == nr]
        sub_nr = sub_nr[sub_nr['dim'].isin(dims_show)]

        # Curves per dim, computed once and reused for both the limits and the plots.
        curves = {}
        for d in dims_show:
            sub = sub_nr[sub_nr['dim'] == d].sort_values(xcol)
            if len(sub) > 0:
                curves[d] = (sub[xcol].to_numpy(), log_ratios(sub))

        if not curves:
            raise ValueError("No data found for the requested dims to set limits.")

        # ---- shared limits over exactly the plotted data ----
        y_all = np.concatenate([y for _, ys in curves.values() for y in ys])
        y_all = y_all[np.isfinite(y_all)]
        if y_all.size == 0:
            raise ValueError("All values are non-finite; cannot set y-limits.")
        y_min, y_max = float(np.min(y_all)), float(np.max(y_all))
        span = y_max - y_min
        pad = 0.05 * span if span > 0 else 0.5   # avoid a degenerate (flat) y-range
        ylims = (y_min - pad, y_max + pad)

        x_all = np.concatenate([x for x, _ in curves.values()])
        x_all = x_all[np.isfinite(x_all)]
        xlims = (float(np.min(x_all)), float(np.max(x_all)))
        if xlims[0] == xlims[1]:
            # Single noise level: widen multiplicatively (x axis is log-scaled).
            xlims = (xlims[0] / 1.5, xlims[1] * 1.5)

        def style_ax(ax):
            ax.set_xscale('log')
            ax.set_xlim(*xlims)
            ax.set_ylim(*ylims)
            ax.grid(True, which='both', linestyle='--', alpha=0.25)

        with plt.rc_context(rc):
            # ---- layout: 2x2 blocks stacked with a larger gap between blocks ----
            fig = plt.figure(figsize=(14, 7 * n_blocks))
            outer = gridspec.GridSpec(
                n_blocks, 1, height_ratios=[1] * n_blocks,
                hspace=0.42,
            )

            for b, dims_block in enumerate(dim_blocks):
                inner = gridspec.GridSpecFromSubplotSpec(
                    2, 2, subplot_spec=outer[b],
                    wspace=0.12, hspace=0.14,
                )
                axs_block = np.array([
                    [fig.add_subplot(inner[r, c]) for c in range(2)]
                    for r in range(2)
                ])

                for j in range(2):
                    ax_m, ax_k = axs_block[0, j], axs_block[1, j]
                    if j >= len(dims_block):
                        # Odd number of dims: nothing belongs in this column.
                        ax_m.set_visible(False)
                        ax_k.set_visible(False)
                        continue

                    d = dims_block[j]
                    ax_m.set_title(f"d={d}", pad=6)

                    if d in curves:
                        x, (m_b, m_t, k_b, k_t) = curves[d]
                        ax_m.plot(x, get_smooth(m_b), color='tab:blue', linewidth=2.6, label='Blend')
                        ax_m.plot(x, get_smooth(m_t), color='tab:red',  linewidth=2.6, label='Tweedie')
                        ax_k.plot(x, get_smooth(k_b), color='tab:blue', linewidth=2.6)
                        ax_k.plot(x, get_smooth(k_t), color='tab:red',  linewidth=2.6)

                    style_ax(ax_m)
                    style_ax(ax_k)
                    # xlabels only on the bottom row of each block
                    ax_k.set_xlabel(xlabel)

                # ---- short, less-busy y labels ----
                axs_block[0, 0].set_ylabel("MMD: log(/floor)")
                axs_block[1, 0].set_ylabel("KSD: log(/floor)")

                if b == 0:
                    axs_block[0, 0].legend(loc='upper left', frameon=True)

            # Explicit margins so xlabels never clip (no bbox_inches='tight').
            fig.subplots_adjust(left=0.10, right=0.98, top=0.95, bottom=0.08)

            fig.savefig(f"{out_prefix}_perf_grid_Nref{nr}.png", dpi=300)
            plt.close(fig)

    df.to_csv(f"{out_prefix}_results.csv", index=False)
    print(f"Saved: {out_prefix}_perf_grid_*.png and {out_prefix}_results.csv")




# ==========================================
# MAIN
# ==========================================
if __name__ == "__main__":
    # Run the spectral-GMM sweep, then write the per-N_ref figures and the results CSV.
    # df_results is kept as a module-level name so it stays inspectable afterwards.
    df_results = run_fine_grained_sweep_spectral_gmm(trim_amount=TRIM_AMOUNT)
    plot_advantage_and_perf(df_results, smoothing_window=SMOOTHING, out_prefix="spectral_gmm_sweep")

Starting spectral-GMM sweep: |DIMS|=4 x |sigma|=24 x trials=20



  0%|          | 0/1920 [00:00<?, ?it/s][A
  0%|          | 1/1920 [00:00<08:23,  3.81it/s][A
  0%|          | 2/1920 [00:00<06:45,  4.73it/s][A
  0%|          | 3/1920 [00:00<06:02,  5.28it/s][A
  0%|          | 4/1920 [00:00<05:42,  5.59it/s][A
  0%|          | 5/1920 [00:00<05:29,  5.81it/s][A
  0%|          | 6/1920 [00:01<05:21,  5.95it/s][A
  0%|          | 7/1920 [00:01<05:16,  6.05it/s][A
  0%|          | 8/1920 [00:01<05:12,  6.11it/s][A
  0%|          | 9/1920 [00:01<05:10,  6.16it/s][A
  1%|          | 10/1920 [00:01<05:08,  6.18it/s][A
  1%|          | 11/1920 [00:01<05:07,  6.21it/s][A
  1%|          | 12/1920 [00:02<05:06,  6.22it/s][A
  1%|          | 13/1920 [00:02<05:05,  6.23it/s][A
  1%|          | 14/1920 [00:02<05:05,  6.24it/s][A
  1%|          | 15/1920 [00:02<05:04,  6.25it/s][A
  1%|          | 16/1920 [00:02<05:04,  6.25it/s][A
  1%|          | 17/1920 [00:02<05:05,  6.24it/s][A
  1%|          | 18/1920 [00:03<05:05,  6.23it/s][A
  1%|     

Saved: spectral_gmm_sweep_perf_grid_*.png and spectral_gmm_sweep_results.csv
