# Bayesian Retrieval Playground

This notebook is for exploring the nested sampling-based Bayesian retrieval pipeline for exoplanet atmospheric parameters.

## Overview
- **Forward Model**: `compute_binned_modulations` — converts gas concentrations into a transit spectrum
- **Free Parameters**: 5 gas concentrations (CO2, CO, NH3, CH4, H2O) with log-uniform priors [1e-8, 1]
- **Fixed Parameters**: All planetary/stellar parameters (T, P, radii, gravity)
- **Method**: Dynesty nested sampling
- **Synthetic Data**: forward-model spectrum from known concentrations, plus Gaussian noise (σ = 10 ppm)

In [None]:
! pipi install

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import sys
import os

# Set up environment
plt.style.use('dark_background')
np.random.seed(42)  # global legacy seed: makes the np.random draws in later cells reproducible

# Add the parent directory of the current working directory to the import
# path so the project-local modules below can be imported. Guarded so that
# re-running this cell does not keep appending the same entry.
parent_dir = os.path.dirname(os.getcwd())
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

# Import PyREx modules (must come after the sys.path update above)
from ariel_wavelengths import ariel_wavelengths
from nested_sampling.bayesian_retrieval import ExoplanetRetrieval, create_fixed_params
from nested_sampling.visualization import plot_corner, plot_spectrum_fit, plot_posterior_summary

print("Environment set up successfully!")
print(f"Wavelength array length: {len(ariel_wavelengths)}")

ModuleNotFoundError: No module named 'scipy'

In [None]:
# Create fixed parameters (planetary and stellar)
fixed_params = create_fixed_params()
print("Fixed parameters:")
for key, value in fixed_params.items():
    print(f"  {key}: {value:.3e}")

# Set up retrieval configuration
gas_names = ['CO2', 'CO', 'NH3', 'CH4', 'H2O']
prior_bounds = (1e-8, 1.0)  # Log-uniform prior bounds
sigma = 10e-6  # 10 ppm noise level (10e-6 == 1e-5)

# True concentrations for synthetic data (your specified values)
true_concentrations = [0.1, 0.0, 0.2, 0.05, 0.1]

# Sanity-check the configuration: exactly one truth per gas, and flag any
# truth the log-uniform prior cannot reach. A value of 0.0, for example, has
# log10 = -inf and lies outside [low, high], so the retrieval cannot recover it.
assert len(true_concentrations) == len(gas_names), "need one true concentration per gas"
prior_low, prior_high = prior_bounds
outside_prior = [
    gas for gas, conc in zip(gas_names, true_concentrations)
    if not (prior_low <= conc <= prior_high)
]
if outside_prior:
    print(f"WARNING: true values for {outside_prior} lie outside the prior bounds "
          f"{prior_bounds}; the retrieval cannot recover them exactly.")

print(f"\nRetrieval setup:")
print(f"  Gas species: {gas_names}")
print(f"  Prior bounds: {prior_bounds}")
print(f"  Noise level: {sigma:.1e}")
print(f"  True concentrations: {true_concentrations}")

In [None]:
# Initialize the retrieval object
retrieval = ExoplanetRetrieval(
    fixed_params=fixed_params,
    gas_names=gas_names,
    prior_bounds=prior_bounds,
    sigma=sigma
)

# Generate synthetic observations
y_obs, true_spectrum = retrieval.generate_synthetic_data(
    true_concentrations=true_concentrations,
    noise_seed=42
)

print(f"Generated synthetic data:")
print(f"  Spectrum length: {len(y_obs)}")
print(f"  Mean signal: {np.mean(true_spectrum):.2e}")
print(f"  Noise RMS: {np.std(y_obs - true_spectrum):.2e}")
print(f"  SNR: {np.mean(true_spectrum) / sigma:.1f}")

In [None]:
# Plot the synthetic data
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))

# Top panel: Full spectrum
ax1.errorbar(ariel_wavelengths, y_obs, yerr=sigma, fmt='o', color='white', 
            alpha=0.7, markersize=3, label='Observations')
ax1.plot(ariel_wavelengths, true_spectrum, 'r-', linewidth=2, label='True spectrum')
ax1.set_ylabel('Modulation')
ax1.legend()
ax1.grid(True, alpha=0.3)
ax1.set_title('Synthetic Transit Spectrum')

# Bottom panel: Residuals
residuals = y_obs - true_spectrum
ax2.errorbar(ariel_wavelengths, residuals, yerr=sigma, fmt='o', color='white',
            alpha=0.7, markersize=3)
ax2.axhline(0, color='red', linestyle='--', alpha=0.7)
ax2.fill_between(ariel_wavelengths, -sigma, sigma, alpha=0.3, color='gray',
                label=f'±{sigma:.0e} noise level')
ax2.set_xlabel('Wavelength (μm)')
ax2.set_ylabel('Residuals')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nData quality check:")
print(f"  Max absolute residual: {np.max(np.abs(residuals)):.2e}")
print(f"  Residual RMS: {np.sqrt(np.mean(residuals**2)):.2e}")

In [None]:
# Test the prior transform
print("Testing prior transform:")
test_u = np.array([0.0, 0.5, 1.0, 0.3, 0.7])  # Unit cube samples
test_concentrations = retrieval.prior_transform(test_u)
print(f"  Unit cube: {test_u}")
print(f"  Concentrations: {test_concentrations}")
print(f"  Log10 concentrations: {np.log10(test_concentrations)}")

# Test the likelihood function
print("\nTesting likelihood function:")
# True concentrations should give high likelihood
true_loglike = retrieval.log_likelihood(true_concentrations)
print(f"  True concentrations log-likelihood: {true_loglike:.2f}")

# Random concentrations should give lower likelihood
random_concentrations = retrieval.prior_transform(np.random.rand(5))
random_loglike = retrieval.log_likelihood(random_concentrations)
print(f"  Random concentrations: {random_concentrations}")
print(f"  Random concentrations log-likelihood: {random_loglike:.2f}")
print(f"  Likelihood ratio: {np.exp(true_loglike - random_loglike):.2e}")

In [None]:
# Run a quick test with fewer live points
print("Running quick nested sampling test...")
print("This may take a few minutes...")

# Use fewer live points for quick testing
test_results = retrieval.run_nested_sampling(
    nlive=200,  # Reduced for speed
    dlogz=0.5,  # Less strict convergence
    print_progress=True
)

print(f"\nSampling completed!")
print(f"  Number of iterations: {test_results.niter}")
print(f"  Log evidence: {test_results.logz[-1]:.2f} ± {test_results.logzerr[-1]:.2f}")
print(f"  Effective sample size: {test_results.eff}")

In [None]:
# Process the results
posterior_samples, evidence, evidence_err = retrieval.process_results(test_results)

print(f"Posterior analysis:")
print(f"  Number of posterior samples: {len(posterior_samples)}")
print(f"  Log evidence: {evidence:.2f} ± {evidence_err:.2f}")

# Compute posterior statistics
for i, gas in enumerate(gas_names):
    samples_i = posterior_samples[:, i]
    median = np.median(samples_i)
    q16, q84 = np.percentile(samples_i, [16, 84])
    true_val = true_concentrations[i]
    
    print(f"  {gas}: {median:.3e} +{q84-median:.3e} -{median-q16:.3e} (true: {true_val:.3e})")

In [None]:
# Create corner plot
print("Creating corner plot...")

fig_corner = plot_corner(
    posterior_samples,
    gas_names=gas_names,
    true_values=true_concentrations,
    range=[(1e-8, 1) for _ in range(len(gas_names))]  # Set ranges to prior bounds
)

fig_corner.suptitle('Posterior Distributions (Quick Test)', fontsize=16, y=0.98)
plt.show()

print("Corner plot created!")

In [None]:
# Create spectrum fit plot
print("Creating spectrum fit plot...")

fig_fit, ax_fit = plot_spectrum_fit(
    wavelengths=ariel_wavelengths,
    y_obs=y_obs,
    posterior_samples=posterior_samples,
    fixed_params=fixed_params,
    gas_names=gas_names,
    sigma=sigma,
    true_spectrum=true_spectrum,
    n_posterior_draws=50
)

ax_fit.set_title('Spectrum Fit (Quick Test)')
plt.show()

print("Spectrum fit plot created!")

In [None]:
# Create posterior summary
print("Creating posterior summary...")

fig_summary, axes_summary = plot_posterior_summary(
    posterior_samples,
    gas_names=gas_names,
    true_values=true_concentrations
)

fig_summary.suptitle('Marginal Posterior Distributions (Quick Test)', fontsize=14, y=0.98)
plt.show()

print("Posterior summary created!")

In [None]:
# Setup for full production run
print("\n" + "="*60)
print("PRODUCTION RUN SETUP")
print("="*60)

print("\nFor a production run with better statistics, use:")
print("")
production_code = """
# Full production run (will take longer)
production_results = retrieval.run_nested_sampling(
    nlive=1000,     # More live points for better sampling
    dlogz=0.1,      # Stricter convergence criterion
    print_progress=True
)

# Process results
posterior_samples_prod, evidence_prod, evidence_err_prod = retrieval.process_results(production_results)

# Create high-quality plots
fig_corner_prod = plot_corner(
    posterior_samples_prod,
    gas_names=gas_names,
    true_values=true_concentrations
)
fig_corner_prod.suptitle('Production Run: Posterior Distributions', fontsize=16)
"""

print(production_code)
print("\nUncomment and run the cell below for production-quality results!")

In [None]:
# Production run (takes ~10-30 minutes depending on your machine).
# Disabled by default so "Restart Kernel -> Run All" stays fast. To run it,
# set RUN_PRODUCTION = True (nothing needs uncommenting). Keeping the code
# live behind a flag means it stays syntax-checked and in sync with the
# names defined above.
RUN_PRODUCTION = False

if RUN_PRODUCTION:
    print("Starting production run...")
    print("This will take significantly longer but provide better statistics.")

    production_results = retrieval.run_nested_sampling(
        nlive=1000,
        dlogz=0.1,
        print_progress=True
    )

    # Process production results
    posterior_samples_prod, evidence_prod, evidence_err_prod = retrieval.process_results(production_results)

    print(f"\nProduction run completed!")
    print(f"  Log evidence: {evidence_prod:.3f} ± {evidence_err_prod:.3f}")
    print(f"  Posterior samples: {len(posterior_samples_prod)}")

    # Create production plots
    fig_corner_prod = plot_corner(
        posterior_samples_prod,
        gas_names=gas_names,
        true_values=true_concentrations
    )
    fig_corner_prod.suptitle('Production Run: Posterior Distributions', fontsize=16)
    plt.show()

    fig_fit_prod, _ = plot_spectrum_fit(
        wavelengths=ariel_wavelengths,
        y_obs=y_obs,
        posterior_samples=posterior_samples_prod,
        fixed_params=fixed_params,
        gas_names=gas_names,
        sigma=sigma,
        true_spectrum=true_spectrum,
        n_posterior_draws=100
    )
    plt.show()

In [None]:
# Scratch space for trying alternative retrieval scenarios.
#
# Lower noise level:
# retrieval_low_noise = ExoplanetRetrieval(
#     fixed_params=fixed_params,
#     gas_names=gas_names,
#     prior_bounds=prior_bounds,
#     sigma=1e-6  # Lower noise
# )
#
# Different true concentrations:
# different_true_concs = [0.01, 0.001, 0.05, 0.001, 0.02]
#
# Degeneracy test - fit only a subset of gases:
# subset_gas_names = ['CO2', 'H2O']  # Only fit these two

exploration_ideas = [
    "Different noise levels",
    "Different true concentrations",
    "Subset of gases",
    "Different prior bounds",
    "Parameter correlations and degeneracies",
]
print("Use this cell to explore different scenarios:")
for idea in exploration_ideas:
    print(f"- {idea}")