# Bayesian IGNO: HMC Posterior Sampling Experiment

This notebook implements Hamiltonian Monte Carlo (via Pyro's NUTS) for uncertainty quantification in IGNO inverse problems.

## Key Idea

Instead of finding a point estimate $\beta^* = \arg\min F(\beta)$, we **sample** from the posterior:

$$p(\beta | u_{obs}) \propto \exp(-U(\beta))$$

where the potential energy is:

$$U(\beta) = \underbrace{w_{data} \cdot L_{data}(\beta)}_{\text{likelihood}} + \underbrace{w_{pde} \cdot L_{pde}(\beta)}_{\text{physics prior}} + \underbrace{(-\log p_{NF}(\beta))}_{\text{flow prior}}$$

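The flow prior term is $-\log p_{NF}(\beta) = \frac{1}{2}\|z\|^2 - \log|\det J| + \text{const}$, where $z = NF(\beta)$ is the flow's image of $\beta$ and $J = \partial z / \partial \beta$ is the Jacobian of the flow.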

## 1. Setup & Imports

In [None]:
import sys
from pathlib import Path

# Add project root to path (adjust as needed)
PROJECT_ROOT = Path(".").resolve().parent
sys.path.insert(0, str(PROJECT_ROOT))

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm, trange

# Pyro for NUTS
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

# Our modules
from src.solver.config import TrainingConfig
from src.problems import create_problem

# Notebook-wide matplotlib defaults: wide, short figures at screen resolution
plt.rcParams.update({
    'figure.figsize': (12, 4),
    'figure.dpi': 100,
})

# Record library versions and GPU availability alongside the results
print(f"PyTorch: {torch.__version__}")
print(f"Pyro: {pyro.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 2. Load Pretrained Model

In [None]:
# Run-level settings: RNG seed, compute device, and the pretrained checkpoint
SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CHECKPOINT_PATH = PROJECT_ROOT / "runs" / "18_dims" / "weights" / "best.pt"

# Minimal configuration built in code instead of being read from a config file.
problem_settings = {
    'type': 'darcy_continuous',
    'train_data': None,  # Not needed for evaluation
    'test_data': str(PROJECT_ROOT / 'data' / 'darcy' / 'test'),  # Adjust path
}
config_dict = {
    'problem': problem_settings,
    'device': DEVICE,
    'seed': SEED,
}

print(f"Checkpoint: {CHECKPOINT_PATH}")
print(f"Device: {DEVICE}")

In [None]:
# Build the problem from the in-code config and restore the pretrained weights.
# Note: You may need to adjust this based on your actual config setup
config = TrainingConfig.from_dict(config_dict)
problem = create_problem(config, load_train_data=False)
problem.load_checkpoint(CHECKPOINT_PATH)

# Switch every sub-model to eval mode and turn off gradient tracking for its
# weights, reporting the parameter count of each one as we go.
problem.eval_mode()
for name, model in problem.model_dict.items():
    n_params = 0
    for p in model.parameters():
        p.requires_grad = False
        n_params += p.numel()
    print(f"{name}: {n_params:,} params (frozen)")

print(f"\nLatent dimension: {problem.BETA_SIZE}")
print(f"Test samples available: {problem.get_n_test_samples()}")

## 3. Select Test Sample and Prepare Observations

In [None]:
# Pick a single test sample to work with
SAMPLE_IDX = 0  # Change this to experiment with different samples

# Observation setup (matching your evaluation config)
N_OBS = 100  # Number of observation points
SNR_DB = 30  # Signal-to-noise ratio (None for clean observations)

# Sample observation indices
n_points = problem.get_n_points()
obs_indices = problem.sample_observation_indices(
    n_total=n_points,
    n_obs=N_OBS,
    method='random'
)

# Prepare observations for this sample
obs_data = problem.prepare_observations(
    sample_indices=[SAMPLE_IDX],
    obs_indices=obs_indices,
    snr_db=SNR_DB
)

# Extract tensors. The leading batch dimension (size 1, one sample) is kept;
# later cells index it explicitly with [0].
x_obs = obs_data['x_obs']      # (1, n_obs, 2)
u_obs = obs_data['u_obs']      # (1, n_obs, 1)
x_full = obs_data['x_full']    # (1, n_points, 2)
u_true = obs_data['u_true']    # (1, n_points, 1)
a_true = obs_data['a_true']    # (1, n_points, 1)

print(f"Sample index: {SAMPLE_IDX}")
print(f"Observation points: {N_OBS} / {n_points}")
# Compare against None explicitly: an SNR of 0 dB is a valid (very noisy)
# setting and must not be reported as "clean".
print(f"SNR: {SNR_DB} dB" if SNR_DB is not None else "Clean observations")
print(f"\nShapes:")
print(f"  x_obs: {x_obs.shape}")
print(f"  u_obs: {u_obs.shape}")
print(f"  x_full: {x_full.shape}")

In [None]:
# Show the ground-truth fields side by side with the sensor locations
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Side length of the plotting grid; this assumes the points form a square lattice
grid_size = int(np.sqrt(n_points))
a_true_2d = a_true[0].cpu().reshape(grid_size, grid_size)
u_true_2d = u_true[0].cpu().reshape(grid_size, grid_size)

# Panels 0 and 1: one image per field, each with its own colormap and colorbar
field_panels = [
    (a_true_2d, 'viridis', 'True Coefficient $a(x)$'),
    (u_true_2d, 'coolwarm', 'True Solution $u(x)$'),
]
for ax, (field, cmap, title) in zip(axes, field_panels):
    im = ax.imshow(field, origin='lower', extent=[0, 1, 0, 1], cmap=cmap)
    ax.set_title(title)
    plt.colorbar(im, ax=ax)

# Panel 2: where the observations were taken, on the unit square
obs_ax = axes[2]
x_obs_np = x_obs[0].cpu().numpy()
obs_ax.scatter(x_obs_np[:, 0], x_obs_np[:, 1], c='red', s=10, alpha=0.7)
obs_ax.set_xlim(0, 1)
obs_ax.set_ylim(0, 1)
obs_ax.set_aspect('equal')
obs_ax.set_title(f'Observation Locations (n={N_OBS})')

plt.tight_layout()
plt.show()

## 4. Define Potential Energy Function

The potential energy (negative log-posterior) is:

$$U(\beta) = w_{data} \cdot L_{data}(\beta) + w_{pde} \cdot L_{pde}(\beta) - \log p_{NF}(\beta)$$

Where:
- $L_{data}$ = data mismatch loss (relative L2)
- $L_{pde}$ = PDE residual loss  
- $\log p_{NF}(\beta)$ = log probability from normalizing flow

In [None]:
# Loss weights (same as your evaluation config)
W_DATA = 1.0
W_PDE = 1.0

def compute_potential_energy(beta: torch.Tensor) -> torch.Tensor:
    """
    Evaluate the potential energy U(β) = -log p(β | u_obs) + const.

    U is the weighted sum of the data-misfit loss, the PDE-residual loss and
    the negative log-density of β under the normalizing flow. The observations
    ``x_obs`` / ``u_obs`` and the weights ``W_DATA`` / ``W_PDE`` are read from
    the notebook's global scope.

    Args:
        beta: Latent vector, either (latent_dim,) or (1, latent_dim).

    Returns:
        The potential energy as a tensor that stays differentiable w.r.t. β.
    """
    # The problem's loss helpers are called with a batched β, so promote a
    # bare vector to a batch of one.
    beta_batched = beta.unsqueeze(0) if beta.dim() == 1 else beta

    # Weighted data-misfit term, evaluated at the observation points
    data_term = W_DATA * problem.loss_data_from_beta(
        beta_batched, x_obs, u_obs, target_type='u'
    )

    # Weighted PDE-residual term
    pde_term = W_PDE * problem.loss_pde_from_beta(beta_batched)

    # log_prob_latent gives log p_NF(β) per batch entry; negate for -log p_NF(β)
    prior_term = -problem.log_prob_latent(beta_batched).mean()

    return data_term + pde_term + prior_term


def compute_log_prob(beta: torch.Tensor) -> torch.Tensor:
    """
    Return the unnormalised log-posterior log p(β | u_obs) = -U(β).

    Convenience wrapper for callers that want a log-density rather than a
    potential energy; it simply negates ``compute_potential_energy``.
    """
    potential = compute_potential_energy(beta)
    return -potential

In [None]:
# Test the potential energy function
print("Testing potential energy computation...")

# Sample a random beta from the NF prior
beta_test = problem.sample_latent_from_nf(num_samples=1)
print(f"β shape: {beta_test.shape}")
print(f"β range: [{beta_test.min():.3f}, {beta_test.max():.3f}]")

# Compute potential energy
beta_test.requires_grad_(True)
U = compute_potential_energy(beta_test)
print(f"\nU(β) = {U.item():.4f}")

# Verify gradient exists
U.backward()
print(f"∇U(β) norm: {beta_test.grad.norm().item():.4f}")
print(f"∇U(β) range: [{beta_test.grad.min():.4f}, {beta_test.grad.max():.4f}]")

# Check individual loss components at the SAME β used for U above, so the
# breakdown can be cross-checked against the printed total. (Evaluating them
# at a fresh random β would give numbers unrelated to U.)
beta_test2 = beta_test.detach()
with torch.no_grad():
    l_data = problem.loss_data_from_beta(beta_test2, x_obs, u_obs, 'u')
    l_pde = problem.loss_pde_from_beta(beta_test2)
    l_prior = -problem.log_prob_latent(beta_test2).mean()
    u_from_parts = W_DATA * l_data + W_PDE * l_pde + l_prior

print(f"\nLoss components:")
print(f"  L_data: {l_data.item():.4f}")
print(f"  L_pde:  {l_pde.item():.4f}")
print(f"  -log p(β): {l_prior.item():.4f}")
print(f"  weighted sum: {u_from_parts.item():.4f} (should match U(β) above)")

## 5. Run Point Estimation (Baseline)

First, let's run the standard IGNO inversion to get a baseline point estimate.

In [None]:
from src.evaluation import IGNOInverter
from src.solver.config import InversionConfig, LossWeights, OptimizerConfig

# Optimisation settings for the baseline inversion; the loss weights are the
# same ones used by the potential energy.
point_loss_weights = LossWeights(pde=W_PDE, data=W_DATA)
point_optimizer = OptimizerConfig(type='Adam', lr=0.01)
inv_config = InversionConfig(
    epochs=500,
    loss_weights=point_loss_weights,
    optimizer=point_optimizer,
)

# The inverter is constructed from the problem and its normalizing flow
nf = problem.model_dict['nf']
inverter = IGNOInverter(problem, nf)

print("Running IGNO point estimation...")
beta_point = inverter.invert(
    x_obs=x_obs,
    u_obs=u_obs,
    x_full=x_full,
    config=inv_config,
    verbose=True,
)

print(f"\nPoint estimate β shape: {beta_point.shape}")

In [None]:
from src.evaluation import compute_all_metrics

# Get predictions from point estimate.
# Decoding is pure inference, so run it without autograd (as the posterior
# decoding loop does). This keeps a_point / u_point free of any graph, which
# matters because they are later handed to matplotlib and torch.save.
with torch.no_grad():
    preds_point = problem.predict_from_beta(beta_point, x_full)
a_point = preds_point['a_pred']  # (1, n_points, 1)
u_point = preds_point['u_pred']  # (1, n_points, 1)

# Compute metrics on the single sample (batch index 0) against the ground truth
metrics_a_point = compute_all_metrics(a_point[0], a_true[0])
metrics_u_point = compute_all_metrics(u_point[0], u_true[0])

print("Point Estimate Metrics:")
print(f"  Coefficient a: RMSE={metrics_a_point['rmse']:.6f}, RelL2={metrics_a_point['rel_l2']:.6f}")
print(f"  Solution u:    RMSE={metrics_u_point['rmse']:.6f}, RelL2={metrics_u_point['rel_l2']:.6f}")

## 6. Define Pyro Model for NUTS

In [None]:
LATENT_DIM = problem.BETA_SIZE

def pyro_potential_fn(params):
    """
    Adapter between Pyro's ``potential_fn`` interface and the potential energy.

    When a NUTS kernel is built with ``potential_fn``, Pyro calls it with a
    dict of parameter tensors and expects the potential energy (negative log
    probability) back. The only entry used here is ``'beta'``.
    """
    return compute_potential_energy(params['beta'])


def run_hmc_sampling(
    num_samples: int = 500,
    warmup_steps: int = 200,
    num_chains: int = 1,
    init_beta: torch.Tensor = None,
):
    """
    Run NUTS sampling using Pyro.

    Args:
        num_samples: Number of posterior samples to collect (per chain)
        warmup_steps: Number of warmup/adaptation steps
        num_chains: Number of chains. NOTE(review): Pyro runs multiple chains
            via multiprocessing; that path has not been exercised from this
            notebook.
        init_beta: Initial beta value. Either a single vector (latent_dim,),
            which is shared by all chains, or one row per chain
            (num_chains, latent_dim). If None, sample from the NF prior.

    Returns:
        Dict with:
            'beta_samples': samples returned by ``mcmc.get_samples()`` for the
                'beta' site ((num_samples, latent_dim) for a single chain)
            'mcmc': the fitted ``MCMC`` object
            'diagnostics': dict with 'step_size', 'accept_prob' (mean
                post-warmup acceptance rate over chains) and 'divergences'
                (as reported by Pyro, keyed by chain)
    """
    # Initialize from NF prior if not provided: one independent draw per chain.
    # Assumes sample_latent_from_nf returns (num_samples, latent_dim).
    if init_beta is None:
        init_beta = problem.sample_latent_from_nf(num_samples=num_chains)
        if num_chains == 1:
            init_beta = init_beta.squeeze(0)

    # Detach so the sampler starts from a plain leaf tensor even if the caller
    # passes something that still carries an autograd graph.
    init_beta = init_beta.detach().clone()

    # With several chains Pyro requires the leading dimension of every tensor
    # in `initial_params` to equal num_chains; replicate a single start point.
    if num_chains > 1 and init_beta.dim() == 1:
        init_beta = init_beta.unsqueeze(0).repeat(num_chains, 1)

    # Initial params dict
    init_params = {'beta': init_beta}

    # Create NUTS kernel with potential function
    nuts_kernel = NUTS(
        potential_fn=pyro_potential_fn,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        full_mass=False,  # Diagonal mass matrix (faster)
        max_tree_depth=10,
        target_accept_prob=0.8,
    )

    # Run MCMC
    mcmc = MCMC(
        nuts_kernel,
        num_samples=num_samples,
        warmup_steps=warmup_steps,
        num_chains=num_chains,
        initial_params=init_params,
        disable_progbar=False,
    )

    print(f"Running NUTS: {warmup_steps} warmup + {num_samples} samples")
    mcmc.run()

    # Get samples
    samples = mcmc.get_samples()
    beta_samples = samples['beta']  # (num_samples, latent_dim) for one chain

    # Get diagnostics. The kernel has no `num_accepts` attribute; Pyro exposes
    # the post-warmup acceptance rate and divergences per chain through
    # MCMC.diagnostics().
    sampler_diag = mcmc.diagnostics()
    accept_rates = [float(rate) for rate in sampler_diag.get('acceptance rate', {}).values()]
    diagnostics = {
        'step_size': nuts_kernel.step_size,
        'accept_prob': sum(accept_rates) / len(accept_rates) if accept_rates else float('nan'),
        'divergences': sampler_diag.get('divergences', {}),
    }

    return {
        'beta_samples': beta_samples,
        'mcmc': mcmc,
        'diagnostics': diagnostics,
    }

## 7. Run NUTS Sampling

In [None]:
# Clear any previous Pyro state
pyro.clear_param_store()

# Seed the RNGs right before sampling so that re-running this cell reproduces
# the same chain regardless of how much randomness earlier cells consumed.
pyro.set_rng_seed(SEED)

# Run NUTS
# Start with fewer samples for testing, increase later
results = run_hmc_sampling(
    num_samples=500,    # Posterior samples to collect
    warmup_steps=200,   # Adaptation phase
    num_chains=1,
    init_beta=beta_point.squeeze(0),  # Initialize at point estimate
)

beta_samples = results['beta_samples']
print(f"\nCollected {beta_samples.shape[0]} samples")
print(f"Sample shape: {beta_samples.shape}")

In [None]:
# Print MCMC diagnostics
# Pull the fitted MCMC object back out of the results dict and let Pyro print
# its summary table for the sampled site(s).
mcmc = results['mcmc']
mcmc.summary()

## 8. Analyze Posterior Samples

In [None]:
# Chain diagnostics for the leading latent dimensions:
# top row = trace over sample index, bottom row = marginal histogram.
n_dims_to_plot = min(6, LATENT_DIM)
fig, axes = plt.subplots(2, n_dims_to_plot, figsize=(15, 6))

for i in range(n_dims_to_plot):
    trace_ax, hist_ax = axes[0, i], axes[1, i]

    # Samples of this coordinate and the matching point-estimate value
    chain_i = beta_samples[:, i].cpu().numpy()
    point_i = beta_point[0, i].cpu().item()

    trace_ax.plot(chain_i, alpha=0.7)
    trace_ax.axhline(point_i, color='r', linestyle='--', label='Point est.')
    trace_ax.set_title(f'β[{i}] trace')
    trace_ax.set_xlabel('Sample')

    hist_ax.hist(chain_i, bins=30, density=True, alpha=0.7)
    hist_ax.axvline(point_i, color='r', linestyle='--', label='Point est.')
    hist_ax.set_title(f'β[{i}] posterior')
    hist_ax.set_xlabel('Value')

# One legend is enough; every panel uses the same red dashed reference line
axes[0, 0].legend()
plt.tight_layout()
plt.show()

In [None]:
# Per-dimension mean and standard deviation of the β samples
beta_mean, beta_std = beta_samples.mean(dim=0), beta_samples.std(dim=0)

# Print the first five coordinates next to the point estimate for comparison
print("Posterior statistics for β:")
print(f"  Mean: {beta_mean[:5].cpu().numpy()} ...")
print(f"  Std:  {beta_std[:5].cpu().numpy()} ...")
print(f"\n  Point estimate: {beta_point[0, :5].cpu().numpy()} ...")

## 9. Transform to Coefficient Field Samples

In [None]:
# Push every β sample through the decoders to obtain the a and u fields.
# Done in mini-batches; this can take a moment for many samples.

print(f"Decoding {len(beta_samples)} β samples to coefficient fields...")

batch_size = 50  # Number of β samples decoded per forward pass
a_batches, u_batches = [], []

with torch.no_grad():
    for start in trange(0, len(beta_samples), batch_size, desc="Decoding"):
        batch_beta = beta_samples[start:start + batch_size]
        # The same coordinates are used for every sample: broadcast x_full
        # along the batch dimension to match this mini-batch.
        batch_x = x_full.expand(batch_beta.shape[0], -1, -1)

        preds = problem.predict_from_beta(batch_beta, batch_x)
        a_batches.append(preds['a_pred'])
        u_batches.append(preds['u_pred'])

# Stitch the mini-batches back together along the sample dimension
a_samples = torch.cat(a_batches, dim=0)  # (n_samples, n_points, 1)
u_samples = torch.cat(u_batches, dim=0)

print(f"\na_samples shape: {a_samples.shape}")
print(f"u_samples shape: {u_samples.shape}")

In [None]:
# Point-wise posterior summaries, reducing over the sample dimension (dim 0)
a_posterior_mean = a_samples.mean(dim=0)  # (n_points, 1)
u_posterior_mean = u_samples.mean(dim=0)

a_posterior_std = a_samples.std(dim=0)
u_posterior_std = u_samples.std(dim=0)

# Bounds of the central 95% credible interval for a (2.5% and 97.5% quantiles)
a_lower, a_upper = (
    torch.quantile(a_samples, q, dim=0) for q in (0.025, 0.975)
)

print("Posterior statistics computed!")
print(f"  a mean range: [{a_posterior_mean.min():.4f}, {a_posterior_mean.max():.4f}]")
print(f"  a std range:  [{a_posterior_std.min():.4f}, {a_posterior_std.max():.4f}]")

## 10. Visualize Results

In [None]:
# Reshape for 2D plotting
def to_2d(tensor, size=grid_size):
    """Move a flattened field to the CPU and view it as a (size, size) grid.

    Singleton dimensions are dropped before reshaping. ``size`` defaults to
    the value ``grid_size`` had when this function was defined.
    """
    flat = tensor.cpu().squeeze()
    return flat.reshape(size, size)

# Comparison figure: truth / point estimate / posterior mean / posterior std,
# with the coefficient a on the top row and the solution u on the bottom row.
fig, axes = plt.subplots(2, 4, figsize=(16, 8))

# Colour limits come from the ground truth so that the first three panels of
# each row share a scale; the std panels are left to auto-scale.
vmin_a, vmax_a = a_true[0].min().item(), a_true[0].max().item()
vmin_u, vmax_u = u_true[0].min().item(), u_true[0].max().item()

a_style = dict(vmin=vmin_a, vmax=vmax_a, cmap='viridis')
u_style = dict(vmin=vmin_u, vmax=vmax_u, cmap='coolwarm')
std_style = dict(cmap='hot')

# (row, column, field, title, imshow styling), listed in drawing order
panels = [
    (0, 0, a_true[0], 'True $a(x)$', a_style),
    (0, 1, a_point[0], 'Point Estimate $a(x)$', a_style),
    (0, 2, a_posterior_mean, 'Posterior Mean $a(x)$', a_style),
    (0, 3, a_posterior_std, 'Posterior Std $a(x)$', std_style),
    (1, 0, u_true[0], 'True $u(x)$', u_style),
    (1, 1, u_point[0], 'Point Estimate $u(x)$', u_style),
    (1, 2, u_posterior_mean, 'Posterior Mean $u(x)$', u_style),
    (1, 3, u_posterior_std, 'Posterior Std $u(x)$', std_style),
]

for row, col, field, title, style in panels:
    ax = axes[row, col]
    im = ax.imshow(to_2d(field), origin='lower', extent=[0,1,0,1], **style)
    ax.set_title(title)
    plt.colorbar(im, ax=ax)

plt.tight_layout()
plt.savefig('hmc_results.png', dpi=150)
plt.show()

In [None]:
# Gallery of individual posterior draws of a(x), taken at evenly spaced
# positions along the chain.
n_gallery = 9
sample_indices = np.linspace(0, len(a_samples)-1, n_gallery, dtype=int)

fig, axes = plt.subplots(3, 3, figsize=(10, 10))
axes = axes.flatten()

# Same colour scale as the ground-truth panel so the draws are comparable
for ax, idx in zip(axes, sample_indices):
    ax.imshow(
        to_2d(a_samples[idx]),
        origin='lower',
        extent=[0,1,0,1],
        vmin=vmin_a,
        vmax=vmax_a,
        cmap='viridis',
    )
    ax.set_title(f'Sample {idx}')
    ax.axis('off')

plt.suptitle('Posterior Samples of $a(x)$', fontsize=14)
plt.tight_layout()
plt.show()

## 11. Compute Bayesian Metrics

In [None]:
# Compare point estimate vs posterior mean.
# The posterior means are already (n_points, 1) after reducing over the sample
# dimension, i.e. the same shape as a_true[0] / u_true[0]. Pass them as-is so
# the metrics are computed on matching shapes, exactly like the point-estimate
# metrics (adding a leading batch dimension would make prediction and target
# disagree in rank).
metrics_a_posterior = compute_all_metrics(a_posterior_mean, a_true[0])
metrics_u_posterior = compute_all_metrics(u_posterior_mean, u_true[0])

print("=" * 60)
print("COMPARISON: Point Estimate vs Bayesian Posterior Mean")
print("=" * 60)
print(f"\nCoefficient a:")
print(f"  Point estimate: RMSE={metrics_a_point['rmse']:.6f}, RelL2={metrics_a_point['rel_l2']:.6f}")
print(f"  Posterior mean: RMSE={metrics_a_posterior['rmse']:.6f}, RelL2={metrics_a_posterior['rel_l2']:.6f}")
print(f"\nSolution u:")
print(f"  Point estimate: RMSE={metrics_u_point['rmse']:.6f}, RelL2={metrics_u_point['rel_l2']:.6f}")
print(f"  Posterior mean: RMSE={metrics_u_posterior['rmse']:.6f}, RelL2={metrics_u_posterior['rel_l2']:.6f}")

In [None]:
# Calibration check: fraction of grid points whose true coefficient lies
# inside the point-wise 95% credible interval. For well-calibrated UQ this
# should be close to 95%.

a_true_flat = a_true[0].squeeze().cpu()
a_lower_flat = a_lower.squeeze().cpu()
a_upper_flat = a_upper.squeeze().cpu()

inside_ci = (a_lower_flat <= a_true_flat) & (a_true_flat <= a_upper_flat)
coverage = inside_ci.float().mean()
print(f"\n95% Credible Interval Coverage: {coverage.item()*100:.1f}%")
print(f"  (Should be ~95% for well-calibrated uncertainty)")

# Mean interval width over the grid: a measure of how sharp the posterior is
ci_width = (a_upper_flat - a_lower_flat).mean()
print(f"\nAverage 95% CI width: {ci_width.item():.4f}")

## 12. Summary & Next Steps

### What We Did
1. Loaded a pretrained IGNO model (encoder, decoders, NF)
2. Defined the potential energy $U(\beta) = w_{data} L_{data} + w_{pde} L_{pde} - \log p_{NF}(\beta)$
3. Ran NUTS sampling to get posterior samples $\{\beta^{(i)}\}$
4. Decoded to coefficient field samples $\{a^{(i)}(x)\}$
5. Computed posterior statistics (mean, std, credible intervals)

### Key Results
- **Posterior mean** often similar to point estimate (sanity check ✓)
- **Posterior std** gives spatial uncertainty map
- **Coverage** tells us if uncertainty is calibrated

### Next Steps
1. Run on multiple test samples to get aggregate statistics
2. Compare performance across noise levels (SNR)
3. Tune MCMC parameters (more samples, longer warmup)
4. Implement proper Bayesian metrics (CRPS, calibration curves)
5. Integrate into main evaluation pipeline

In [None]:
# Bundle the experiment settings, tensors and metrics into one file for
# later analysis.
run_info = {
    'sample_idx': SAMPLE_IDX,
    'n_obs': N_OBS,
    'snr_db': SNR_DB,
}

# Tensors are moved to the CPU before being written to disk
tensors_to_store = {
    'beta_samples': beta_samples,
    'beta_point': beta_point,
    'a_posterior_mean': a_posterior_mean,
    'a_posterior_std': a_posterior_std,
    'a_point': a_point,
    'a_true': a_true,
}

results_to_save = {
    **run_info,
    **{key: value.cpu() for key, value in tensors_to_store.items()},
    'metrics_point': metrics_a_point,
    'metrics_posterior': metrics_a_posterior,
    'coverage_95': coverage.item(),
}

torch.save(results_to_save, 'hmc_results.pt')
print("Results saved to hmc_results.pt")