# JAX PR4 vs Cobaya Likelihood Validation

This notebook compares the outputs of the JAX PR4 likelihood implementations with the official Cobaya versions to validate consistency.

In [1]:
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import time
import os
os.environ["JAX_PLATFORMS"]="cpu"

# Import the JAX PR4 package
from jax_pr4 import CamSpecPR4, HillipopPR4, LollipopPR4
from jax_pr4.config import set_jax_enabled, get_jax_enabled

# Import CLASS for generating theory spectra
from classy import Class

# Import Cobaya for official likelihood comparison
from cobaya.theory import Theory
from cobaya.input import update_info
from cobaya.model import get_model


print("JAX PR4 vs Cobaya Validation Notebook")
print("=" * 50)
print(f"JAX version: {jax.__version__}")
print(f"NumPy version: {np.__version__}")

JAX PR4 vs Cobaya Validation Notebook
JAX version: 0.4.33
NumPy version: 1.26.0


## First lets get the ground truth Cobaya results and get the $C_\ell$ computed so we can use as input to the `JAX` versions 


In [2]:

def get_camspec_results(params, l_max=2500, debug=False):
    """
    Compute lensed Cℓ (μK²) with Cobaya's CAMB backend and evaluate the
    Planck NPIPE CamSpec PR4 TTTEEE likelihood.

    Parameters
    ----------
    params : dict
        Cosmological parameters in Cobaya style, e.g.
          {"H0": {"value": 67.37}, "ombh2": {"value": 0.02237}, ... }

    l_max : int
        Maximum multipole for scalar modes (default 2500).

    debug : bool
        If True, Cobaya prints detailed debug info.

    Returns
    -------
    ells : ndarray
    ClTT, ClTE, ClEE : ndarray
        Lensed raw Cℓ in μK².
    loglike : float
        CamSpec TTTEEE log-likelihood.
    """

    # 1. Cobaya input -------------------------------------------------------
    info = {
        "debug": debug,
        "theory": {
            "camb": {
                "extra_args": {
                    "lmax": l_max,
                    "lens_potential_accuracy": 1,
                    "lens_margin": 1250,
                }
            }
        },
        "likelihood": {
            "planck_NPIPE_highl_CamSpec.TTTEEE": None
        },
        "params": params,
    }

    # 2. Build model and run ------------------------------------------------
    t0 = time.time()
    model = get_model(info)
    p_vec = model.parameterization.sampled_params()
    loglikes = model.loglikes(p_vec)
    loglike = loglikes[0][0]  # single likelihood returns tuple
    dt_ms = (time.time() - t0) * 1e3

    # 3. Fetch spectra ------------------------------------------------------
    cls = model.provider.get_Cl(ell_factor=False, units="muK2")
    ells = cls["ell"]
    ClTT = cls["tt"]
    ClTE = cls["te"]
    ClEE = cls["ee"]

    print(f"CamSpec TTTEEE log-like = {loglike:.3f}   ({dt_ms:.1f} ms)")

    return ells, ClTT, ClTE, ClEE, loglike



def get_hillipop_results(params, l_max=3000, debug=False):
    """Evaluate Planck PR4 HiLLiPoP TTTEEE likelihood with explicit nuisance parameters."""

    nuisance_defaults = {
        "A_planck": {"value": 1.0},
        "cal100A": {"value": 1.0},
        "cal100B": {"value": 1.0},
        "cal143B": {"value": 1.0},
        "cal217A": {"value": 1.0},
        "cal217B": {"value": 1.0},
        "Aradio": {"value": 60.},
        "Adusty": {"value": 6.},
        "AdustT": {"value": 1.0},
        "beta_dustT": {"value": 1.51},
        "Acib": {"value": 4.0},
        "beta_cib": {"value": 1.75},
        "Atsz": {"value": 3.0},
        "Aksz": {"value": 1.0},
        "xi": {"value": 0.1},
        "AdustP": {"value": 1.0},
        "beta_dustP": {"value": 1.59},
    }

    full_params = {**params, **nuisance_defaults}

    info = {
        "debug": debug,
        "theory": {"camb": {"extra_args": {"lmax": l_max, "lens_potential_accuracy": 1}}},
        "likelihood": {"planck_2020_hillipop.TTTEEE": None},
        "params": full_params,
    }

    t0 = time.time()
    model = get_model(info)
    p_vec = model.parameterization.sampled_params()
    loglike = model.loglikes(p_vec)[0][0]
    dt_ms = (time.time() - t0) * 1e3

    cls = model.provider.get_Cl(ell_factor=False, units="muK2")
    ells, ClTT, ClTE, ClEE = cls["ell"], cls["tt"], cls["te"], cls["ee"]

    # print("HiLLiPoP nuisance parameters:")
    # for name in nuisance_defaults:
    #     print(f"  {name} = {p_vec[name]}")

    print(f"HiLLiPoP TTTEEE log-like = {loglike:.3f}   ({dt_ms:.1f} ms)")

    return ells, ClTT, ClTE, ClEE, loglike


def get_lollipop_results(params, l_max=200, debug=False, marginalised=False, Nsim=400, fsky=0.52, lmin=2, lmax=30):
    """Evaluate Planck PR4 LoLLiPoP low-ell EE likelihood."""
    nuisance_defaults = {"A_planck": {"value": 1.0}}
    full_params = {**params, **nuisance_defaults}

    info = {
        "debug": debug,
        "theory": {"camb": {"extra_args": {"lmax": l_max, "lens_potential_accuracy": 1}}},
        "likelihood": {
            "planck_2020_lollipop.lowlE": {
                "lmin": lmin,
                "lmax": lmax,
                "marginalised_over_covariance": marginalised,
                "Nsim": Nsim,
            }
        },
        "params": full_params,
    }

    t0 = time.time()
    model = get_model(info)
    p_vec = model.parameterization.sampled_params()
    loglike = model.loglikes(p_vec)[0][0]
    dt_ms = (time.time() - t0) * 1e3

    # Internals + cross-checks
    like = [v for k, v in model.likelihood.items() if "lollipop" in k and "lowlE" in k][0]
    ee_data, ee_fid, ee_off = like.cldata[0], like.clfid[0], like.cloff[0]
    inv_cov = like.invclcov
    cls = model.provider.get_Cl(ell_factor=False, units="muK2")
    ClEE = cls["ee"]

    ells = cls["ell"]
    return ells, cls["tt"], cls["te"], ClEE, loglike

In [3]:
params = {
    # Cosmological parameters
    "H0": {"value": 67.37},
    "ombh2": {"value": 0.02237},
    "omch2": {"value": 0.1200},
    "As": {"value": 2.1e-9},
    "ns": {"value": 0.9649},
    "tau": {"value": 0.0544},
    "mnu": {"value": 0.06},

    # CamSpec nuisance parameters (set explicitly!)
    "A_planck": {"value": 1.0},
    "calTE": {"value": 1.0},
    "calEE": {"value": 1.0},
    "amp_143": {"value": 1.0},
    "amp_217": {"value": 1.0},
    "amp_143x217": {"value": 1.0},
    "n_143": {"value": 1.0},
    "n_217": {"value": 1.0},
    "n_143x217": {"value": 1.0},
}

ells_camspec, tt_camspec, te_camspec, ee_camspec, loglike_camspec = get_camspec_results(params)

[camb] `camb` module loaded successfully from /cluster/project/refregier/areeves/mcmc_jax_new/lib/python3.11/site-packages/camb
[planck_npipe_highl_camspec.ttteee] L-range for 143x143: 30 2000
[planck_npipe_highl_camspec.ttteee] L-range for 217x217: 500 2500
[planck_npipe_highl_camspec.ttteee] L-range for 143x217: 500 2500
[planck_npipe_highl_camspec.ttteee] L-range for TE: 30 2000
[planck_npipe_highl_camspec.ttteee] L-range for EE: 30 2000
[planck_npipe_highl_camspec.ttteee] Number of data points: 9915
CamSpec TTTEEE log-like = -5908.789   (4993.1 ms)


In [4]:
params = {
    # Cosmological parameters
    "H0": {"value": 67.37},
    "ombh2": {"value": 0.02237},
    "omch2": {"value": 0.1200},
    "As": {"value": 2.1e-9},
    "ns": {"value": 0.9649},
    "tau": {"value": 0.0544},
    "mnu": {"value": 0.06},
}

ells_hillipop, tt_hillipop, te_hillipop, ee_hillipop, loglike_hillipop = get_hillipop_results(params)

[camb] `camb` module loaded successfully from /cluster/project/refregier/areeves/mcmc_jax_new/lib/python3.11/site-packages/camb
[planck_2020_hillipop.ttteee] Initialized!
HiLLiPoP TTTEEE log-like = -16173.720   (5058.8 ms)


In [5]:
params = {
    # Cosmological parameters
    "H0": {"value": 67.37},
    "ombh2": {"value": 0.02237},
    "omch2": {"value": 0.1200},
    "As": {"value": 2.1e-9},
    "ns": {"value": 0.9649},
    "tau": {"value": 0.0544},
    "mnu": {"value": 0.06},
}

ells_lollipop, tt_lollipop, te_lollipop, ee_lollipop, loglike_lollipop = get_lollipop_results(params)

[camb] `camb` module loaded successfully from /cluster/project/refregier/areeves/mcmc_jax_new/lib/python3.11/site-packages/camb
[planck_2020_lollipop.lowle] Initialized!


### Great now with these inputs lets go for the `JAX` versions

In [6]:
# Test JAX CamSpec with the same Cl spectra from Cobaya
print("Testing JAX CamSpec...")
print("=" * 40)

# Prepare Cl spectra for JAX (add batch dimension)
# we also need to remove ell-0, ell=1 and divide by 1e12 to put back in units of Kelvin! 

TCMB0=2.7255
ClTT_batch = tt_camspec[np.newaxis, 2:]/1e12
ClTE_batch = te_camspec[np.newaxis, 2:]/1e12
ClEE_batch = ee_camspec[np.newaxis, 2:]/1e12

# CamSpec parameters (from your example)
camspec_params = {
    'A_planck': np.array([1.0]),
    'calTE': np.array([1.0]),
    'calEE': np.array([1.0]),
    'amp_143': np.array([1.0]),
    'amp_217': np.array([1.0]),
    'amp_143x217': np.array([1.0]),
    'n_143': np.array([1.0]),
    'n_217': np.array([1.0]),
    'n_143x217': np.array([1.0]),
}

# Initialize JAX CamSpec
set_jax_enabled(False)  # Use NumPy mode
jax_camspec = CamSpecPR4()

# Compute likelihood
start_time = time.time()
loglike_jax_camspec = jax_camspec.compute_like(ClTT_batch, ClTE_batch, ClEE_batch, camspec_params)
jax_camspec_time = time.time() - start_time

print(f"JAX CamSpec log-like = {float(loglike_jax_camspec[0]):.3f}   ({jax_camspec_time*1000:.1f} ms)")
print(f"Cobaya CamSpec log-like = {loglike_camspec:.3f}")
print(f"Difference = {float(loglike_jax_camspec[0]) - loglike_camspec:.6f}")
print(f"Relative difference = {(float(loglike_jax_camspec[0]) - loglike_camspec) / abs(loglike_camspec):.2e}")

Testing JAX CamSpec...
✓ CamSpec data loaded: 9915 data points
  Covariance matrix shape: (9915, 9915)
  Foreground array size: 2501
CamSpec PR4 likelihood initialized successfully!
IS JAX: False
JAX CamSpec log-like = -5908.788   (123.1 ms)
Cobaya CamSpec log-like = -5908.789
Difference = 0.001507
Relative difference = 2.55e-07


In [7]:
print("Testing JAX Hillipop...")
print("=" * 40)

# Prepare Cl spectra for JAX (add batch dimension)
ClTT_batch_hill = tt_hillipop[np.newaxis, 2:]/1e12
ClTE_batch_hill = te_hillipop[np.newaxis, 2:]/1e12
ClEE_batch_hill = ee_hillipop[np.newaxis, 2:]/1e12

hillipop_params = {
    'A_planck': np.array([1.0]),
    # Calibration parameters
    'cal100A': np.array([1.0]), 
    'cal100B': np.array([1.0]),
    'cal143A': np.array([1.0]), 
    'cal143B': np.array([1.0]),
    'cal217A': np.array([1.0]), 
    'cal217B': np.array([1.0]),
    'pe100A': np.array([1.0]), 
    'pe100B': np.array([1.0]),
    'pe143A': np.array([1.0]), 
    'pe143B': np.array([1.0]),
    'pe217A': np.array([0.975]), 
    'pe217B': np.array([0.975]),
    'Aradio': np.array([60.0]),
    'beta_radio': np.array([-0.8]),
    'Adusty': np.array([6.0]), 
    'AdustT': np.array([1.0]), 
    'beta_dustT': np.array([1.51]),
    'AsyncT': np.array([0.0]),
    'Acib': np.array([4.0]), 
    'beta_cib': np.array([1.75]), 
    'Atsz': np.array([3.0]), 
    'Aksz': np.array([1.0]),
    'xi': np.array([0.1]),
    'AdustP': np.array([1.0]),
    'beta_dustP': np.array([1.59]),
    'AsyncP': np.array([0.0]),
    

}

# Initialize JAX Hillipop (with low-ell TT)
hillipop_config = {'add_lowl_tt': False}
jax_hillipop = HillipopPR4(additional_args=hillipop_config)

# Compute likelihood
start_time = time.time()
loglike_jax_hillipop = jax_hillipop.compute_like(ClTT_batch_hill, ClTE_batch_hill, ClEE_batch_hill, hillipop_params)
jax_hillipop_time = time.time() - start_time

print(f"JAX Hillipop log-like = {float(loglike_jax_hillipop[0]):.3f}   ({jax_hillipop_time*1000:.1f} ms)")
print(f"Cobaya Hillipop log-like = {loglike_hillipop:.3f}")
print(f"Difference = {float(loglike_jax_hillipop[0]) - loglike_hillipop:.6f}")
print(f"Relative difference = {(float(loglike_jax_hillipop[0]) - loglike_hillipop) / abs(loglike_hillipop):.2e}")

Testing JAX Hillipop...
Define multipole ranges
Reading cross-spectra
Reading cross-spectra
Covariance matrix file: /cluster/work/refregier/alexree/local_packages/jax_pr4/data/planck_pr4_hillipop/invfll_PR4_v4.2_TTTEEE.fits
Adding 'dust_model' foreground for TT
Adding 'tsz' foreground for TT
Adding 'ksz' foreground for TT
Adding 'cib' foreground for TT
Adding 'szxcib' foreground for TT
Adding 'ps_radio' foreground for TT
Adding 'ps_dusty' foreground for TT
Adding 'dust_model' foreground for EE
Adding 'dust_model' foreground for TE
Initialized!
JAX Hillipop log-like = -16173.734   (884.4 ms)
Cobaya Hillipop log-like = -16173.720
Difference = -0.014648
Relative difference = -9.06e-07


In [8]:
# Test JAX Lollipop with the same Cl spectra from Cobaya
print("Testing JAX Lollipop...")
print("=" * 40)

# Prepare EE spectrum for JAX (Lollipop only uses EE)
ClEE_batch_lolli = ee_lollipop[np.newaxis, 2:]/1e12

# Lollipop parameters (simple - only A_planck)
lollipop_params = {
    'A_planck': np.array([1.0])
}

# Initialize JAX Lollipop
jax_lollipop = LollipopPR4()

# Compute likelihood
start_time = time.time()
loglike_jax_lollipop = jax_lollipop.compute_like(ClEE_batch_lolli, params=lollipop_params)
jax_lollipop_time = time.time() - start_time

print(f"JAX Lollipop log-like = {float(loglike_jax_lollipop[0]):.3f}   ({jax_lollipop_time*1000:.1f} ms)")
print(f"Cobaya Lollipop log-like = {loglike_lollipop:.3f}")
print(f"Difference = {float(loglike_jax_lollipop[0]) - loglike_lollipop:.6f}")
print(f"Relative difference = {(float(loglike_jax_lollipop[0]) - loglike_lollipop) / abs(loglike_lollipop):.2e}")

Testing JAX Lollipop...
NBINS 29
LMAX 30
Reading model
Reading covariance
Compute offsets
Initialized!
XSHAPE 29
JAX Lollipop log-like = -16.892   (0.8 ms)
Cobaya Lollipop log-like = -16.892
Difference = 0.000000
Relative difference = 2.10e-16


In [9]:
# Summary table
print("\\nJAX PR4 vs Cobaya Validation Results")
print("=" * 60)
print(f"{'Likelihood':<15} {'Cobaya':<12} {'JAX PR4':<12} {'Rel. Difference':<12} {'Status':<10}")
print("-" * 60)

# CamSpec
camspec_diff = (float(loglike_jax_camspec[0]) - loglike_camspec)/loglike_camspec
camspec_status = "PASS" if abs(camspec_diff) < 1e-3 else "FAIL"
print(f"{'CamSpec':<15} {loglike_camspec:<12.3f} {float(loglike_jax_camspec[0]):<12.3f} {camspec_diff:<12.6f} {camspec_status:<10}")

# Hillipop
hillipop_diff = (float(loglike_jax_hillipop[0]) - loglike_hillipop)/loglike_hillipop
hillipop_status = "PASS" if abs(hillipop_diff) < 1e-3 else "FAIL"
print(f"{'Hillipop':<15} {loglike_hillipop:<12.3f} {float(loglike_jax_hillipop[0]):<12.3f} {hillipop_diff:<12.6f} {hillipop_status:<10}")

# Lollipop
lollipop_diff = (float(loglike_jax_lollipop[0]) - loglike_lollipop)/loglike_lollipop
lollipop_status = "PASS" if abs(lollipop_diff) < 1e-3 else "FAIL"
print(f"{'Lollipop':<15} {loglike_lollipop:<12.3f} {float(loglike_jax_lollipop[0]):<12.3f} {lollipop_diff:<12.6f} {lollipop_status:<10}")

print("\\n" + "=" * 60)
print("Validation Summary:")
print(f"  - All differences should be < 1e-3 for numerical consistency")
print(f"  - JAX implementations are validated against official Cobaya versions")
print(f"  - Results can be used with confidence for cosmological analyses")

\nJAX PR4 vs Cobaya Validation Results
Likelihood      Cobaya       JAX PR4      Rel. Difference Status    
------------------------------------------------------------
CamSpec         -5908.789    -5908.788    -0.000000    PASS      
Hillipop        -16173.720   -16173.734   0.000001     PASS      
Lollipop        -16.892      -16.892      -0.000000    PASS      
Validation Summary:
  - All differences should be < 1e-3 for numerical consistency
  - JAX implementations are validated against official Cobaya versions
  - Results can be used with confidence for cosmological analyses
