# FGBuster vs FURAX: Framework Comparison for CMB Component Separation

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/CMBSciPol/furax-compsep-paper/blob/main/notebooks/01_FGBuster_vs_FURAX_Comparison.ipynb)

## 🎯 Learning Objectives

By the end of this notebook, you will:
- ✅ Understand the differences between traditional (FGBuster) and modern (FURAX) component separation frameworks
- ✅ See the performance advantages of JAX over NumPy for CMB analysis
- ✅ Learn how to implement and benchmark likelihood functions
- ✅ Understand automatic differentiation benefits for parameter optimization

## 📚 Background

### The Component Separation Problem

CMB observations contain multiple astrophysical components:
- **CMB signal**: What we want to measure
- **Galactic dust**: Modified blackbody emission
- **Synchrotron**: Power-law emission from cosmic rays
- **Instrumental noise**: Detector and systematic effects

The challenge is to separate these components accurately to recover the CMB signal.
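
As a rough illustration of the problem, the observed maps can be sketched as a linear mixture `d = A s + n` of component amplitudes `s` and noise `n`. The toy below is only a sketch: the band centres, array sizes, and power-law scalings are made-up stand-ins, not the PySM or FGBuster emission laws used later in this notebook.

```python
import numpy as np

rng = np.random.default_rng(0)

n_freq, n_comp, n_pix = 5, 3, 4          # toy sizes, not the LiteBIRD setup
nu = np.linspace(40.0, 400.0, n_freq)    # hypothetical band centres in GHz

# Toy mixing matrix: one column per component (CMB, dust-like, synchrotron-like).
# The power laws are illustrative stand-ins for the real emission laws.
A = np.stack(
    [np.ones(n_freq), (nu / 150.0) ** 1.54, (nu / 20.0) ** -3.0], axis=1
)

s = rng.normal(size=(n_comp, n_pix))           # component amplitudes per pixel
n = 1e-4 * rng.normal(size=(n_freq, n_pix))    # instrumental noise
d = A @ s + n                                  # what the instrument records

# With A known, the amplitudes follow from a least-squares solve.
s_hat, *_ = np.linalg.lstsq(A, d, rcond=None)
print(d.shape, np.abs(s_hat - s).max())
```

In the real problem the spectral parameters inside `A` are unknown as well, which is why the rest of the notebook fits them through a likelihood.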


In [None]:
import os
os.environ["EQX_ON_ERRORS"] = "nan"
# Core scientific computing
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from functools import partial
import operator

# JAX for high-performance computing
import jax
import jax.numpy as jnp
import jaxopt

# FGBuster - Traditional component separation framework
from fgbuster import (
    CMB, Dust, Synchrotron,
    basic_comp_sep, get_instrument,
)

# FURAX - Modern JAX-based framework
from furax import HomothetyOperator
from furax.obs.landscapes import Stokes
from furax.obs.operators import (
    CMBOperator, DustOperator, MixingMatrixOperator, SynchrotronOperator,
)

# Set JAX to use 64-bit precision for scientific accuracy
jax.config.update("jax_enable_x64", True)


## 🌌 Step 1: Generate Simulated Sky Maps

We start by creating realistic CMB and foreground simulations using PySM (Python Sky Model). These simulated observations will serve as our test data for comparing the two frameworks.

### Key Parameters:
- **NSIDE = 32**: HEALPix resolution (creates 12,288 pixels)
- **Instrument**: LiteBIRD frequency configuration (15 bands: 40-402 GHz)
- **Components**: CMB + dust + synchrotron emission
- **Stokes**: I, Q, U polarization parameters
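
The pixel count quoted above follows from the HEALPix relation `npix = 12 * nside**2`. The short check below also spells out the array shape we would expect for 15 bands and 3 Stokes parameters; the variable names are illustrative only.

```python
nside = 32
npix = 12 * nside**2          # HEALPix: 12 base pixels, each split nside x nside
n_freq, n_stokes = 15, 3      # LiteBIRD bands and I, Q, U

expected_shape = (n_freq, n_stokes, npix)
print(npix, expected_shape)
```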

### Why Use Simulations?
1. **Ground truth**: We know the input parameters
2. **Controlled testing**: Compare framework accuracy
3. **Reproducibility**: Same data for fair comparison

> **💡 Pro Tip**: On HPC clusters without internet access, these maps are pre-cached using `generate_maps.py`.

In [4]:
import sys

sys.path.append("../data")
from generate_maps import save_to_cache

nsides = [32]
for nside in nsides:
    save_to_cache(nside)

Loaded freq_maps for nside 32 from cache and noise False.


In [5]:
from generate_maps import load_from_cache

nside = 32

nu, freq_maps = load_from_cache(nside)
# Check the shape of freq_maps
print("freq_maps shape:", freq_maps.shape)

Loaded freq_maps for nside 32 from cache.
freq_maps shape: (15, 3, 12288)


Furax expects a `Stokes` object as input, so we convert `freq_maps` into one before passing it to the Furax operators.

__Note__: Although Furax includes its own functions to create sky maps from PySM, we use `fgbuster` here to ensure that both methods receive identical inputs for comparison.


In [6]:
d = Stokes.from_stokes(I=freq_maps[:, 0, :], Q=freq_maps[:, 1, :], U=freq_maps[:, 2, :])
d.structure

StokesIQU(i=ShapeDtypeStruct(shape=(15, 12288), dtype=float64), q=ShapeDtypeStruct(shape=(15, 12288), dtype=float64), u=ShapeDtypeStruct(shape=(15, 12288), dtype=float64))

In [7]:
dust_nu0 = 150.0
synchrotron_nu0 = 20.0
instrument = get_instrument("LiteBIRD")
components = [CMB(), Dust(dust_nu0), Synchrotron(synchrotron_nu0)]

## Defining the Likelihood Function for Component Separation

In this cell, we define the `negative_log_prob` function, which calculates the negative log-likelihood of observing the given data `d` based on the model parameters.

The likelihood function is based on a quadratic form that includes the mixing matrix `A`, inverse noise covariance `N^{-1}`, and observed data `d`. The key term in the likelihood is:

$$
\left(A^T N^{-1} d\right)^T \left(A^T N^{-1} A\right)^{-1} \left(A^T N^{-1} d\right)
$$

### Explanation of Each Term

1. **$A$**: The mixing matrix operator, which maps the component space to the observed frequency space.
2. **$N^{-1}$**: The inverse of the noise covariance matrix, represented by `invN` in the code.
3. **$d$**: The observed data, which is structured as a `Stokes` in Furax.

### Implementation Details

- **Transposing and Applying `A`**: `A.T(d)` applies the transpose of `A` to `d`, equivalent to the term $A^T d$.
- **Computing the Likelihood**: We first compute $x = A^T N^{-1} d$, then apply the inverse of $A^T N^{-1} A$ to $x$, and finally take the inner product of $x$ with that result, summing over all components.
- **Negative Log-Likelihood**: The final output of `negative_log_prob` is the negative of this log-likelihood value, allowing us to use it as a loss function for optimization.
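
To make the quadratic form concrete, here is a minimal dense NumPy analogue. It is a sketch under simplifying assumptions (a small explicit matrix `A` and a diagonal `N^{-1}`), not the Furax operator implementation; `spectral_quadratic_form` is a hypothetical helper name.

```python
import numpy as np

def spectral_quadratic_form(A, inv_n, d):
    """(A^T N^-1 d)^T (A^T N^-1 A)^-1 (A^T N^-1 d), summed over pixels.

    A:     (n_freq, n_comp) dense mixing matrix
    inv_n: (n_freq,) diagonal of N^-1
    d:     (n_freq, n_pix) data
    """
    x = A.T @ (inv_n[:, None] * d)     # A^T N^-1 d, shape (n_comp, n_pix)
    M = A.T @ (inv_n[:, None] * A)     # A^T N^-1 A, shape (n_comp, n_comp)
    y = np.linalg.solve(M, x)          # M^-1 x without forming the inverse
    return np.sum(x * y)

rng = np.random.default_rng(1)
A_true = rng.normal(size=(6, 3))
A_wrong = rng.normal(size=(6, 3))
d = A_true @ rng.normal(size=(3, 10))  # noiseless data in the span of A_true
inv_n = np.ones(6)

best = spectral_quadratic_form(A_true, inv_n, d)
worse = spectral_quadratic_form(A_wrong, inv_n, d)
total = np.sum(inv_n[:, None] * d * d)  # d^T N^-1 d
print(np.isclose(best, total), worse <= best)
```

For noiseless data the form reaches its upper bound $d^T N^{-1} d$ when `A` is the true mixing matrix and is smaller otherwise, which is why maximizing it (or minimizing its negative) recovers the spectral parameters.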


In [8]:
invN = HomothetyOperator(jnp.ones(1), _in_structure=d.structure)
DND = invN(d) @ d

in_structure = d.structure_for((d.shape[1],))
best_params = {"temp_dust": 20.0, "beta_dust": 1.54, "beta_pl": -3.0}

dust_nu0 = 150.0
synchrotron_nu0 = 20.0


@jax.jit
def negative_log_prob(params, d):
    cmb = CMBOperator(nu, in_structure=in_structure)
    dust = DustOperator(
        nu,
        frequency0=dust_nu0,
        temperature=params["temp_dust"],
        beta=params["beta_dust"],
        in_structure=in_structure,
    )
    synchrotron = SynchrotronOperator(
        nu,
        frequency0=synchrotron_nu0,
        beta_pl=params["beta_pl"],
        in_structure=in_structure,
    )

    A = MixingMatrixOperator(cmb=cmb, dust=dust, synchrotron=synchrotron)

    x = (A.T @ invN)(d)
    likelihood = jax.tree.map(lambda a, b: a @ b, x, (A.T @ invN @ A).I(x))
    summed_log_prob = jax.tree.reduce(operator.add, likelihood)

    return -summed_log_prob

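Evaluate the performance of the likelihood and its gradient. Each function is called once before timing so that JIT compilation is excluded from the measurement.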

In [9]:
print("Performance of the nll evaluation")
negative_log_prob(best_params, d).block_until_ready()
%timeit negative_log_prob(best_params, d).block_until_ready()
print("Performance of the nll grad evaluation")
jax.grad(negative_log_prob)(best_params, d)["beta_pl"].block_until_ready()
%timeit jax.grad(negative_log_prob)(best_params, d)['beta_pl'].block_until_ready()

Performance of the nll evaluation
23.1 ms ± 4.34 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Performance of the nll grad evaluation
63.8 ms ± 599 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


## Check for Correctness

In this cell, we perform a basic correctness check by comparing the gradients of the negative log-likelihood at two sets of parameters:

1. **Wrong Parameters**: A set of parameters obtained by adding random noise to `best_params`.
2. **Correct Parameters**: The original `best_params`.

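We reduce each gradient to a single number with `max`, which returns the largest signed component rather than an absolute magnitude. The gradient at the correct parameters should be close to zero, while the gradient at the perturbed parameters should be many orders of magnitude larger in absolute value, indicating that `best_params` sits at or near an optimum.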
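
The same idea can be tried on a toy problem. The sketch below uses a hypothetical one-component model with a single spectral index and reduces the gradient with the maximum absolute value, so that a large negative gradient is not mistaken for a small one. The frequencies and the `toy_nll` function are assumptions made for illustration, not part of Furax.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

nu = jnp.linspace(40.0, 400.0, 8)          # hypothetical band centres in GHz
true_beta = 1.54
d = (nu / 150.0) ** true_beta              # noiseless single-component "data"

def toy_nll(params):
    a = (nu / 150.0) ** params["beta"]     # one-column mixing "matrix"
    x = a @ d                              # A^T d with N equal to the identity
    return -(x * x) / (a @ a)              # minus the quadratic form

def max_abs_grad(params):
    grads = jax.grad(toy_nll)(params)
    # Use absolute values so the sign cannot hide a large gradient.
    return max(float(jnp.abs(g)) for g in jax.tree.leaves(grads))

grad_at_truth = max_abs_grad({"beta": true_beta})
grad_off_truth = max_abs_grad({"beta": true_beta + 1.0})
print(grad_at_truth, grad_off_truth)
```

The first number should be zero up to floating-point error and the second clearly non-zero.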


In [10]:
wrong_params = jax.tree.map(lambda x: x + jax.random.normal(jax.random.PRNGKey(0)), best_params)
print(f"Wrong parameters grad {jax.tree.reduce(max, jax.grad(negative_log_prob)(wrong_params, d))}")
print(
    f"Correct parameters grad {jax.tree.reduce(max, jax.grad(negative_log_prob)(best_params, d))}"
)

Wrong parameters grad -113186444.45561369
Correct parameters grad 0.0027135219067152814


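## Off-the-shelf likelihoods

Furax also provides a ready-made `negative_log_likelihood` in `furax.obs`. Here we check that it agrees with our hand-written `negative_log_prob` at `best_params`.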

In [11]:
from furax.obs import negative_log_likelihood

negative_log_likelihood = partial(
    negative_log_likelihood, dust_nu0=dust_nu0, synchrotron_nu0=synchrotron_nu0
)

L = negative_log_likelihood(best_params, nu=nu, N=invN, d=d)

assert jax.tree.all(
    jax.tree.map(lambda x, y: jnp.isclose(x, y, rtol=1e-5), L, negative_log_prob(best_params, d))
)

# Validating Against FGBuster: The `c1d0s0` Model

In this section, we validate our custom implementation of the likelihood model by comparing it to the `c1d0s0` model from `fgbuster`. By aligning our results with FGBuster’s well-established component separation model, we ensure that our setup and computations are consistent and reliable.


## Case 1 : Initial Validation: Using `best_params` as the Starting Point

We begin the validation process by setting `best_params` as the initial point for both our custom implementation and FGBuster’s `c1d0s0` model. This allows us to directly compare the outputs and confirm that the models produce similar results when initialized with the same parameters.


In [12]:
components[1]._set_default_of_free_symbols(beta_d=1.54, temp=20.0)
components[2]._set_default_of_free_symbols(beta_pl=-3.0)

result = basic_comp_sep(components, instrument, freq_maps)
print(result.params)
print(result.x)

['Dust.beta_d', 'Dust.temp', 'Synchrotron.beta_pl']
[ 1.54 20.   -3.  ]


In [13]:
best_params = {"temp_dust": 20.0, "beta_dust": 1.54, "beta_pl": -3.0}

scipy_solver = jaxopt.ScipyMinimize(fun=negative_log_likelihood, method="TNC", jit=True, tol=1e-10)
result = scipy_solver.run(best_params, nu=nu, N=invN, d=d)
result.params

  res = osp.optimize.minimize(scipy_fun, jnp_to_onp(init_params, self.dtype),


{'beta_dust': Array(1.53999953, dtype=float64),
 'beta_pl': Array(-2.99999996, dtype=float64),
 'temp_dust': Array(19.99999945, dtype=float64)}

## Case 2 : Validation with Incorrect Parameter: Setting `beta_dust` to a Wrong Value

In [14]:
components[1]._set_default_of_free_symbols(beta_d=2.54, temp=20.0)
components[2]._set_default_of_free_symbols(beta_pl=-3.0)

result = basic_comp_sep(components, instrument, freq_maps)
print(result.params)
print(result.x)

  return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)
  return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)
  return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)*log(0.05*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)
  return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)


SVD of A failed -> logL = -inf
SVD of A failed -> logL_dB not updated
SVD of A failed -> logL = -inf
SVD of A failed -> logL = -inf
SVD of A failed -> logL = -inf
SVD of A failed -> logL = -inf
SVD of A failed -> logL = -inf
SVD of A failed -> logL = -inf
['Dust.beta_d', 'Dust.temp', 'Synchrotron.beta_pl']
[ 1.53194524 19.97377854 -2.9430962 ]


In [15]:
params = {"temp_dust": 20.0, "beta_dust": 2.54, "beta_pl": -3.0}

scipy_solver = jaxopt.ScipyMinimize(fun=negative_log_likelihood, method="TNC", jit=True, tol=1e-10)
result = scipy_solver.run(params, nu=nu, N=invN, d=d)
result.params

  res = osp.optimize.minimize(scipy_fun, jnp_to_onp(init_params, self.dtype),


{'beta_dust': Array(1.53991301, dtype=float64),
 'beta_pl': Array(-2.99768003, dtype=float64),
 'temp_dust': Array(20.00280906, dtype=float64)}
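
To read the two Case 2 results side by side, the sketch below copies the printed values into plain Python and compares each fit with the simulation inputs, taken here to be `best_params`. The mapping between FGBuster parameter names and the Furax dictionary keys is our own assumption, inferred from the names.

```python
import numpy as np

truth = {"beta_dust": 1.54, "temp_dust": 20.0, "beta_pl": -3.0}

# Values copied from the Case 2 outputs above.
fgbuster_names = ["Dust.beta_d", "Dust.temp", "Synchrotron.beta_pl"]
fgbuster_x = np.array([1.53194524, 19.97377854, -2.9430962])
furax_fit = {"beta_dust": 1.53991301, "beta_pl": -2.99768003, "temp_dust": 20.00280906}

# Assumed name mapping between the two conventions.
to_furax_key = {
    "Dust.beta_d": "beta_dust",
    "Dust.temp": "temp_dust",
    "Synchrotron.beta_pl": "beta_pl",
}
fgbuster_fit = {to_furax_key[n]: float(v) for n, v in zip(fgbuster_names, fgbuster_x)}

for key, true_value in truth.items():
    err_fgb = abs(fgbuster_fit[key] - true_value)
    err_fx = abs(furax_fit[key] - true_value)
    print(f"{key:>10s}  FGBuster |err| = {err_fgb:.2e}   FURAX |err| = {err_fx:.2e}")
```

Both fits recover the inputs from the perturbed start. For these printed values the TNC fit lands closer on all three parameters, although the two runs use different optimizers and tolerances, so this is not a like-for-like benchmark.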

## Case 3 : Setting `beta_dust` and `beta_pl` to Incorrect Values


In [16]:
components[1]._set_default_of_free_symbols(beta_d=2.54, temp=20.0)
components[2]._set_default_of_free_symbols(beta_pl=-6.0)

result = basic_comp_sep(components, instrument, freq_maps)
print(result.params)
print(result.x)

  return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)
  return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)
  return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)*log(0.05*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)
  return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)


SVD of A failed -> logL = -inf
SVD of A failed -> logL_dB not updated
SVD of A failed -> logL = -inf
SVD of A failed -> logL = -inf
SVD of A failed -> logL = -inf
SVD of A failed -> logL = -inf
SVD of A failed -> logL = -inf
SVD of A failed -> logL = -inf
['Dust.beta_d', 'Dust.temp', 'Synchrotron.beta_pl']
[ 1.53034346 19.97418092 -5.99479857]


In [17]:
params = {"temp_dust": 20.0, "beta_dust": 2.54, "beta_pl": -6.0}

scipy_solver = jaxopt.ScipyMinimize(fun=negative_log_likelihood, method="TNC", jit=True, tol=1e-10)
result = scipy_solver.run(params, nu=nu, N=invN, d=d)
result.params

  res = osp.optimize.minimize(scipy_fun, jnp_to_onp(init_params, self.dtype),


{'beta_dust': Array(1.53998646, dtype=float64),
 'beta_pl': Array(-2.99989118, dtype=float64),
 'temp_dust': Array(20.00048047, dtype=float64)}

## Case 4 : Setting All Parameters to Incorrect Values


In [18]:
components[1]._set_default_of_free_symbols(beta_d=2.54, temp=25.0)
components[2]._set_default_of_free_symbols(beta_pl=-6.0)

result = basic_comp_sep(components, instrument, freq_maps)
print(result.params)
print(result.x)

['Dust.beta_d', 'Dust.temp', 'Synchrotron.beta_pl']
[ 1.54 20.   -3.  ]


# Using Optax optimizers

In [20]:
import optax
import optax.tree_utils as otu
from jax_grid_search import optimize

In [23]:
solver = optax.lbfgs()

params = {"temp_dust": 25.0, "beta_dust": 2.54, "beta_pl": -6.0}


final_params, final_state = optimize(
    params, negative_log_likelihood, solver, max_iter=100, tol=1e-4, nu=nu, N=invN, d=d
)

print(
    f"Final parameters: {final_params}, number of evaluations: {otu.tree_get(final_state, 'count')}"
)
print(f"Final value: {negative_log_prob(final_params, d=d)}")

Final parameters: {'beta_dust': Array(1.53999998, dtype=float64), 'beta_pl': Array(-2.99999997, dtype=float64), 'temp_dust': Array(20.00000052, dtype=float64)}, number of evaluations: 39
Final value: -6114624700347.492
