# CMB Lensing from Field-Level Inference

Ce notebook montre comment calculer la **convergence du CMB κ(θ)** à partir d'une inférence field-level sur le clustering des galaxies.

**Workflow**:
1. Inférence sur les galaxies (comme dans le notebook 04)
2. Calcul de la convergence CMB depuis le champ de densité
3. Analyse de la corrélation galaxies-CMB

**Configuration** : test sur CPU avec un petit maillage (8³) pour une exécution rapide.

## 1. Imports et Setup

In [None]:
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import numpy as np
from blackjax.adaptation.mclmc_adaptation import MCLMCAdaptationState

from desi_cmb_fli.bricks import get_cosmology
from desi_cmb_fli.cmb_lensing import density_field_to_convergence
from desi_cmb_fli.model import FieldLevelModel, default_config
from desi_cmb_fli.samplers import get_mclmc_run, get_mclmc_warmup

# Switch JAX to 64-bit floating point for the rest of the session.
jax.config.update("jax_enable_x64", True)

# Report the runtime environment (library version, backend, devices).
runtime_info = (
    ("JAX version", jax.__version__),
    ("Backend", jax.default_backend()),
    ("Devices", jax.devices()),
)
for label, value in runtime_info:
    print(f"{label}: {value}")

## 2. Configuration

In [None]:
# Model settings: start from a copy of the package defaults and override the
# entries needed for a small, CPU-friendly run.
model_config = default_config.copy()
model_config.update(
    {
        "mesh_shape": (8, 8, 8),
        "box_shape": (200.0, 200.0, 200.0),  # Mpc/h
        "evolution": "lpt",
        "lpt_order": 2,
        "a_obs": 0.5,
        "gxy_density": 0.001,
    }
)

model = FieldLevelModel(**model_config)

# Settings of the projected convergence map.
# NOTE(review): the CMB source plane is usually placed at z ~ 1100; z_source=1.0
# looks like a test value chosen for this small box - confirm before reuse.
cmb_config = dict(
    field_size_deg=5.0,
    field_npix=32,  # kept small so the projection stays fast
    z_source=1.0,
)

# Echo both configurations.
print("Model configuration:")
print(f"  Mesh: {model_config['mesh_shape']}")
print(f"  Box: {model_config['box_shape']} Mpc/h")
print("\nCMB configuration:")
print(f"  Field size: {cmb_config['field_size_deg']}°")
print(f"  Resolution: {cmb_config['field_npix']} pixels")
print(f"  Source redshift: {cmb_config['z_source']}")

## 3. Génération de la Vérité

In [None]:
# Truth parameter values, handed to the model with frombase=True below.
truth_params = {
    "Omega_m": 0.3,
    "sigma8": 0.8,
    "b1": 1.0,
    "b2": 0.0,
    "bs2": 0.0,
    "bn2": 0.0,
}

# Generate the synthetic galaxy field from the truth parameters.
# hide_det=False is passed explicitly, as in the posterior predict call later
# in this notebook: the next cell reads truth["gxy_mesh"], so that site must
# not be filtered out of the returned dictionary.
seed = 42
truth = model.predict(
    samples=truth_params,
    hide_base=False,
    hide_samp=False,
    hide_det=False,
    frombase=True,
    rng=jr.key(seed),
)

# Quick sanity summary of the observed field.
print(f"Galaxy field shape: {truth['obs'].shape}")
print(f"Mean count: {jnp.mean(truth['obs']):.4f}")
print(f"Std: {jnp.std(truth['obs']):.4f}")

## 4. Calcul de la Convergence CMB (Vérité)

In [None]:
# Cosmology and linear bias matching the truth parameters.
cosmo_truth = get_cosmology(
    Omega_m=truth_params["Omega_m"],
    sigma8=truth_params["sigma8"],
)
b1_truth = truth_params["b1"]

# Linear-bias inversion of the galaxy mesh:
#   delta_gxy = gxy_mesh - 1, delta_matter = delta_gxy / b1
delta_matter_truth = (truth["gxy_mesh"] - 1.0) / b1_truth
density_field_truth = 1.0 + delta_matter_truth

# Project the 3D density field into a convergence map.
kappa_truth = density_field_to_convergence(
    density_field_truth,
    model.box_shape,
    cosmo_truth,
    *(cmb_config[name] for name in ("field_size_deg", "field_npix", "z_source")),
)

# Summary statistics of the truth map.
print(f"Convergence shape: {kappa_truth.shape}")
print(f"Mean κ: {jnp.mean(kappa_truth):.6e}")
print(f"Std κ: {jnp.std(kappa_truth):.6e}")
kappa_lo, kappa_hi = jnp.min(kappa_truth), jnp.max(kappa_truth)
print(f"Range: [{kappa_lo:.6e}, {kappa_hi:.6e}]")

## 5. Visualisation Vérité

In [None]:
# Three panels: galaxy field slice, matter density slice, convergence map.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Central slice index (taken from the first mesh dimension).
idx = model_config["mesh_shape"][0] // 2

# Angular extent of the convergence map, in degrees.
extent = [0, cmb_config["field_size_deg"], 0, cmb_config["field_size_deg"]]

# One entry per panel: (data, imshow kwargs, title, x label, y label, colorbar label).
panels = [
    (
        truth["obs"][:, :, idx],
        {"cmap": "viridis"},
        f"Galaxy Field (z={idx})",
        "x",
        "y",
        "1+δ_gxy",
    ),
    (
        density_field_truth[:, :, idx],
        {"cmap": "RdBu_r"},
        f"Matter Density (z={idx})",
        "x",
        "y",
        "1+δ_m",
    ),
    (
        kappa_truth,
        {"cmap": "RdBu_r", "extent": extent},
        f"CMB Convergence (z_s={cmb_config['z_source']})",
        "θ_x [deg]",
        "θ_y [deg]",
        "κ",
    ),
]

images = []
for ax, (data, imshow_kwargs, title, xlabel, ylabel, cbar_label) in zip(axes, panels):
    image = ax.imshow(data, origin="lower", **imshow_kwargs)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    plt.colorbar(image, ax=ax, label=cbar_label)
    images.append(image)
im0, im1, im2 = images

plt.tight_layout()
plt.show()

## 6. Inférence MCMC (Mini Version)

### Stage 1: Warmup Mesh Only

In [None]:
# Stage 1: condition on the observation plus the fiducial values in
# model.loc_fid, so that only the initial mesh is left to warm up.
model.reset()
fiducial_conditioning = {"obs": truth["obs"]} | model.loc_fid
model.condition(fiducial_conditioning, frombase=True)
model.block()

# Starting point from the Kaiser posterior; keep only the mesh entry for
# this stage.
init_params_ = model.kaiser_post(jr.key(45), delta_obs=truth["obs"] - 1)
init_mesh_ = {"init_mesh_": init_params_["init_mesh_"]}

# Build and run the mesh-only MCLMC warmup.
print("Warming mesh (2^10 steps)...")
mesh_warmup_settings = dict(
    n_steps=2**10,
    config=None,
    desired_energy_var=1e-6,
    diagonal_preconditioning=False,
)
warmup_mesh_fn = get_mclmc_warmup(model.logpdf, **mesh_warmup_settings)

state_mesh, config_mesh = warmup_mesh_fn(jr.key(43), init_mesh_)

print("✓ Mesh warmup done")
print(f"  Logdens: {state_mesh.logdensity:.2f}")
print(f"  L: {config_mesh.L:.6f}")

# Carry the warmed-up mesh position into the full set of initial parameters.
init_params_ |= state_mesh.position

### Stage 2: Warmup All Parameters

In [None]:
# Stage 2: condition on the observation only, so every parameter is sampled.
model.reset()
model.condition({"obs": truth["obs"]})
model.block()

# Joint warmup over all parameters.
num_warmup = 50  # Minimal for testing
print(f"Warming all params ({num_warmup} steps)...")

full_warmup_settings = dict(
    n_steps=num_warmup,
    config=None,
    desired_energy_var=5e-4,
    diagonal_preconditioning=False,
)
warmup_all_fn = get_mclmc_warmup(model.logpdf, **full_warmup_settings)

# NOTE(review): jr.key(43) is also the key used for the mesh-only warmup;
# consider a distinct key if independent random streams are wanted.
state, config = warmup_all_fn(jr.key(43), init_params_)

print("✓ Full warmup done")
print(f"  Logdens: {state.logdensity:.2f}")

# Override the adapted L with a value derived from the tuned step size and a
# target number of evaluations per effective sample.
eval_per_ess = 1e3
recalc_L = 0.4 * eval_per_ess / 2 * config.step_size

tuned_step_size = config.step_size
tuned_inverse_mass = config.inverse_mass_matrix
config = MCLMCAdaptationState(
    L=recalc_L,
    step_size=tuned_step_size,
    inverse_mass_matrix=tuned_inverse_mass,
)

print(f"  L: {recalc_L:.6f}")

### Stage 3: Sample

In [None]:
# Stage 3: draw samples with the warmed-up state and tuned configuration.
num_samples = 50  # Minimal for testing
print(f"Sampling {num_samples} steps...")

sampler_settings = dict(n_samples=num_samples, thinning=1, progress_bar=False)
run_fn = get_mclmc_run(model.logpdf, **sampler_settings)

state, samples_dict = run_fn(jr.key(42), state, config)

print("✓ Sampling done!")
print(f"  Mean MSE/dim: {jnp.mean(samples_dict['mse_per_dim']):.6e}")

# Separate the sampled parameters from the sampler diagnostics.
diagnostic_keys = ["logdensity", "mse_per_dim"]
param_names = [name for name in samples_dict if name not in diagnostic_keys]
samples = {name: samples_dict[name] for name in param_names}

print(f"\nSampled parameters: {param_names}")

## 7. CMB Lensing depuis la Postérieure

In [None]:
# Compute one convergence map per posterior sample.
print("Computing κ from posterior...")

n_kappa = min(20, num_samples)  # Small subset for speed

# Loop invariants: the key passed to predict and the map geometry.
predict_key = jr.key(0)
kappa_geometry = (
    cmb_config["field_size_deg"],
    cmb_config["field_npix"],
    cmb_config["z_source"],
)

kappa_samples = []
for i in range(n_kappa):
    if i % 5 == 0:
        print(f"  Sample {i}/{n_kappa}")

    # i-th posterior draw, in the sampler's parameterization ("_"-suffixed keys).
    sample_i = {k: samples[k][i] for k in param_names}

    # Forward model for this draw; hide_base=False asks for the base-space
    # parameters and hide_det=False for gxy_mesh to be part of the output.
    pred_i = model.predict(
        samples=sample_i,
        hide_base=False,
        hide_det=False,
        frombase=False,
        rng=predict_key,
    )

    # Build the cosmology and bias from the base-space values returned by the
    # prediction, not from the "_"-suffixed sample-space values: the truth
    # branch of this notebook uses base-space Omega_m / sigma8 / b1, and the
    # two must be in the same parameterization to be comparable.
    cosmo_i = get_cosmology(Omega_m=pred_i["Omega_m"], sigma8=pred_i["sigma8"])
    b1_i = pred_i["b1"]

    # Linear-bias inversion: delta_matter = (gxy_mesh - 1) / b1
    delta_matter_i = (pred_i["gxy_mesh"] - 1.0) / b1_i
    density_field_i = 1.0 + delta_matter_i

    # Project to a convergence map with the same geometry as the truth map.
    kappa_i = density_field_to_convergence(
        density_field_i,
        model.box_shape,
        cosmo_i,
        *kappa_geometry,
    )

    kappa_samples.append(kappa_i)

# Stack the per-sample maps along a new leading axis.
kappa_samples = jnp.array(kappa_samples)

print(f"\n✓ Computed {len(kappa_samples)} κ maps")
print(f"  Mean κ: {jnp.mean(kappa_samples):.6e}")
print(f"  Std κ: {jnp.std(kappa_samples):.6e}")

## 8. Analyse des Résultats

In [None]:
# Table comparing posterior summaries with the truth values.
# NOTE(review): the samples are keyed with a trailing "_" while the truth
# values are given in the base parameterization; if the "_" values live in a
# transformed sample space, they should be mapped back before this comparison
# - confirm against the model.
comp_params = ["Omega_m", "sigma8", "b1", "b2", "bs2", "bn2"]

print("Parameter recovery:")
print(f"{'Param':<10} {'Truth':<10} {'Mean':<10} {'Std':<10} {'Bias(σ)':<10}")
print("-" * 60)

for p in comp_params:
    pk = p + "_"
    if pk not in samples:
        # Parameter was not sampled: nothing to report.
        continue

    vals = samples[pk]
    mean = np.mean(vals)
    std = np.std(vals)
    truth_val = truth_params[p]

    # Offset from the truth in units of the posterior std (0 when std is 0).
    if std > 0:
        bias = (mean - truth_val) / std
    else:
        bias = 0
    print(f"{p:<10} {truth_val:<10.4f} {mean:<10.4f} {std:<10.4f} {bias:<10.2f}")

In [None]:
# Per-pixel posterior mean and standard deviation over the κ samples.
kappa_mean = jnp.mean(kappa_samples, axis=0)
kappa_std = jnp.std(kappa_samples, axis=0)

# Map-averaged summaries.
print("\nCMB Convergence:")
summary_rows = (
    ("Truth mean", jnp.mean(kappa_truth)),
    ("Posterior mean", jnp.mean(kappa_mean)),
    ("Posterior std", jnp.mean(kappa_std)),
)
for label, value in summary_rows:
    print(f"  {label}: {value:.6e}")

## 9. Visualisations Finales

In [None]:
# One trace panel per compared parameter, with the truth as a dashed line.
n_panels = len(comp_params)
fig, axes = plt.subplots(n_panels, 1, figsize=(12, 2 * n_panels))

for i, (ax, p) in enumerate(zip(axes, comp_params)):
    pk = p + "_"
    if pk not in samples:
        # Leave the panel empty when the parameter was not sampled.
        continue

    ax.plot(samples[pk], alpha=0.7, lw=0.5, color="blue")
    ax.axhline(truth_params[p], color="red", ls="--", lw=2, label="Truth")
    ax.set_ylabel(p)
    ax.set_xlabel("Iteration")
    if i == 0:
        # A single legend on the first panel is enough.
        ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Side-by-side κ maps: truth, posterior mean, posterior standard deviation.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# One entry per panel: (map, colormap, title, colorbar label).
kappa_panels = [
    (kappa_truth, "RdBu_r", "Truth κ", "κ"),
    (kappa_mean, "RdBu_r", "Posterior Mean κ", "κ"),
    (kappa_std, "viridis", "Posterior Std κ", "σ_κ"),
]

images = []
for ax, (kappa_map, cmap, title, cbar_label) in zip(axes, kappa_panels):
    image = ax.imshow(kappa_map, origin="lower", cmap=cmap, extent=extent)
    ax.set_title(title)
    ax.set_xlabel("θ_x [deg]")
    ax.set_ylabel("θ_y [deg]")
    plt.colorbar(image, ax=ax, label=cbar_label)
    images.append(image)
im0, im1, im2 = images

plt.tight_layout()
plt.show()

In [None]:
# Normalized residuals: (truth - posterior mean) in units of the posterior std.
# Pixels whose posterior std is exactly zero have no usable uncertainty;
# dividing there would produce inf/NaN and turn every summary statistic below
# into NaN. Those pixels are masked (set to NaN) and skipped by the nan-aware
# reductions instead.
valid = kappa_std > 0
safe_std = jnp.where(valid, kappa_std, 1.0)
residuals = jnp.where(valid, (kappa_truth - kappa_mean) / safe_std, jnp.nan)

fig, ax = plt.subplots(1, 1, figsize=(8, 6))
im = ax.imshow(residuals, origin="lower", cmap="seismic", extent=extent, vmin=-3, vmax=3)
ax.set_title("Normalized Residuals: (Truth - Mean) / Std")
ax.set_xlabel("θ_x [deg]")
ax.set_ylabel("θ_y [deg]")
plt.colorbar(im, ax=ax, label="σ")
plt.show()

# Summary statistics over the unmasked pixels.
print("Residual statistics:")
print(f"  Mean: {jnp.nanmean(residuals):.3f} σ")
print(f"  Std: {jnp.nanstd(residuals):.3f} σ")
print(f"  Max |residual|: {jnp.nanmax(jnp.abs(residuals)):.3f} σ")

In [None]:
# 1D cuts through the κ maps along the central row and central column.
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

center = cmb_config["field_npix"] // 2
x_coords = np.linspace(0, cmb_config["field_size_deg"], cmb_config["field_npix"])

# One entry per panel: (index selecting the cut, x-axis label, title).
profile_specs = [
    ((center, slice(None)), "θ_x [deg]", f"Row Profile (y={center})"),
    ((slice(None), center), "θ_y [deg]", f"Column Profile (x={center})"),
]

for ax, (cut, xlabel, title) in zip(axes, profile_specs):
    ax.plot(x_coords, kappa_truth[cut], 'r-', lw=2, label="Truth")
    ax.plot(x_coords, kappa_mean[cut], 'b-', lw=2, label="Posterior Mean")
    # Shaded band: posterior mean plus/minus one posterior std.
    ax.fill_between(
        x_coords,
        kappa_mean[cut] - kappa_std[cut],
        kappa_mean[cut] + kappa_std[cut],
        alpha=0.3,
        label="1σ",
    )
    ax.set_xlabel(xlabel)
    ax.set_ylabel("κ")
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Résumé

Ce notebook a démontré :

1. ✅ **Inférence field-level** sur le clustering des galaxies (mini version)
2. ✅ **Calcul de convergence CMB** κ(θ) depuis le champ de densité 3D
3. ✅ **Propagation d'incertitude** : κ calculé à partir des échantillons de la distribution a posteriori des paramètres
4. ✅ **Validation** : comparaison vérité vs postérieure

**Physique implémentée** :
- Approximation de Born : $\kappa(\theta) = \frac{3H_0^2\Omega_m}{2c^2} \int d\chi \frac{\chi(\chi_s-\chi)}{\chi_s a(\chi)} \delta(\chi\theta,\chi)$, où $\chi$ est la distance comobile, $\chi_s$ celle de la source et $a(\chi)$ le facteur d'échelle
- Conversion galaxies → matière : $\delta_m = \delta_{gxy} / b_1$
- Intégration sur la ligne de visée

**Pour la production** :
- Utiliser `scripts/05_cmb_lensing.py` avec GPU
- Mesh plus large (48³ ou 64³)
- Plus d'échantillons κ (100-1000)
- Comparaison avec données réelles (Planck)