# FGBuster and Furax Imports for Component Separation


In [1]:
# Standard library imports
import operator
from functools import partial

# JAX, FGBuster, and Furax imports
import jax
import jax.numpy as jnp
import jaxopt
from fgbuster import (
    CMB,
    Dust,
    Synchrotron,
    basic_comp_sep,
    get_instrument,
)
from furax import HomothetyOperator
from furax.obs.landscapes import Stokes
from furax.obs.operators import (
    CMBOperator,
    DustOperator,
    MixingMatrixOperator,
    SynchrotronOperator,
)

## Mixed Sky Maps Creation Using PySM

In this section, we create simulated sky maps using the `PySM` library with specified parameters for each astrophysical component. Key elements:

- **NSIDE**: Sets the HEALPix resolution, determining the number of pixels in the sky map.
- **Reference Frequencies**:
  - **Dust** at 150 GHz
  - **Synchrotron** at 20 GHz
- **Instrument Configuration**: Using the `LiteBIRD` instrument model to simulate observed frequency maps.

This setup provides the mixed sky maps required for component separation, with the shape of the output maps indicated for verification.

`generate_maps` is used on Jean-Zay to cache the frequency maps from the front-end node, since the Slurm compute nodes have no internet access.
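
The `generate_maps` module itself is not shown in this notebook. As a rough illustration of the caching idea only, the sketch below stores and reloads a set of frequency maps with NumPy. The function names `save_maps` and `load_maps`, the file naming scheme, and the placeholder arrays are all assumptions made for this example and are not the actual `generate_maps` API.

```python
import tempfile
from pathlib import Path

import numpy as np


def save_maps(cache_dir, nside, nu, freq_maps):
    """Write frequencies and maps to a single .npz file (hypothetical layout)."""
    path = Path(cache_dir) / f"freq_maps_nside_{nside}.npz"
    np.savez(path, nu=nu, freq_maps=freq_maps)
    return path


def load_maps(cache_dir, nside):
    """Read back the arrays written by save_maps."""
    path = Path(cache_dir) / f"freq_maps_nside_{nside}.npz"
    with np.load(path) as data:
        return data["nu"], data["freq_maps"]


with tempfile.TemporaryDirectory() as tmp:
    nu_demo = np.linspace(40.0, 400.0, 15)      # placeholder frequencies in GHz
    maps_demo = np.zeros((15, 3, 12 * 4**2))    # placeholder maps at nside = 4
    save_maps(tmp, 4, nu_demo, maps_demo)       # would run where internet is available
    nu_back, maps_back = load_maps(tmp, 4)      # would run on the compute node
```

Splitting the work into a save step and a load step means the download only has to succeed once, on a machine with network access.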

In [2]:
from generate_maps import save_to_cache

nsides = [32]
for nside in nsides:
    save_to_cache(nside)

Loaded freq_maps for nside 32 from cache.


In [3]:
from generate_maps import load_from_cache

nside = 32

nu, freq_maps = load_from_cache(nside)
# Check the shape of freq_maps
print("freq_maps shape:", freq_maps.shape)

Loaded freq_maps for nside 32 from cache.
freq_maps shape: (15, 3, 12288)
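
The last axis of the printed shape can be checked against the HEALPix pixel count, which is `12 * nside**2`. A small sketch of that check follows; the helper name `healpix_npix` is introduced here for illustration, and the axis interpretation (frequencies, Stokes parameters, pixels) is read off the shape above.

```python
def healpix_npix(nside: int) -> int:
    """Number of pixels in a HEALPix map of the given NSIDE."""
    return 12 * nside**2


nside = 32
n_freq, n_stokes = 15, 3  # taken from the shape printed above
expected_shape = (n_freq, n_stokes, healpix_npix(nside))
print(expected_shape)  # (15, 3, 12288)
```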


Furax expects a `Stokes` object as input, so we convert `freq_maps` into a `Stokes` object by passing its I, Q, and U slices to `Stokes.from_stokes`.

__Note__: Although Furax includes its own functions to create sky maps from PySM, we use `fgbuster` here to ensure that both methods receive identical inputs for comparison.
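
To make the layout change concrete, the sketch below splits an array of shape `(n_freq, 3, n_pix)` into three arrays of shape `(n_freq, n_pix)`. The `IQU` named tuple is a plain stand-in defined for this example and is not the Furax `Stokes` class.

```python
from typing import NamedTuple

import numpy as np


class IQU(NamedTuple):
    """Stand-in container with one array per Stokes parameter."""
    i: np.ndarray
    q: np.ndarray
    u: np.ndarray


def split_stokes(freq_maps: np.ndarray) -> IQU:
    """Slice the Stokes axis (axis 1) into separate I, Q, and U arrays."""
    return IQU(i=freq_maps[:, 0, :], q=freq_maps[:, 1, :], u=freq_maps[:, 2, :])


demo = np.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4)
iqu = split_stokes(demo)
print(iqu.q.shape)  # (2, 4)
```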


In [4]:
d = Stokes.from_stokes(I=freq_maps[:, 0, :], Q=freq_maps[:, 1, :], U=freq_maps[:, 2, :])
d.structure

StokesIQU(i=ShapeDtypeStruct(shape=(15, 12288), dtype=float64), q=ShapeDtypeStruct(shape=(15, 12288), dtype=float64), u=ShapeDtypeStruct(shape=(15, 12288), dtype=float64))

In [5]:
dust_nu0 = 150.0
synchrotron_nu0 = 20.0
instrument = get_instrument("LiteBIRD")
components = [CMB(), Dust(dust_nu0), Synchrotron(synchrotron_nu0)]

## Defining the Likelihood Function for Component Separation

In this cell, we define the `negative_log_prob` function, which calculates the negative log-likelihood of observing the given data `d` based on the model parameters.

The likelihood function is based on a quadratic form that includes the mixing matrix $A$, the inverse noise covariance $N^{-1}$, and the observed data $d$. The key term in the likelihood is:

$$
\left(A^T N^{-1} d\right)^T \left(A^T N^{-1} A\right)^{-1} \left(A^T N^{-1} d\right)
$$

### Explanation of Each Term

1. **$A$**: The mixing matrix operator, which maps the component space to the observed frequency space.
2. **$N^{-1}$**: The inverse of the noise covariance matrix, represented by `invN` in the code.
3. **$d$**: The observed data, which is structured as a `Stokes` in Furax.
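
For context, the quadratic form arises from a Gaussian likelihood once the component amplitudes are maximized out. The sketch below assumes the data model $d = A s + n$ with noise covariance $N$. The symbols $s$, $\hat{s}$, and $\beta$ (the spectral parameters) are introduced here and do not appear in the code. Up to an additive constant:

$$
-2 \log \mathcal{L}(s, \beta) = \left(d - A s\right)^T N^{-1} \left(d - A s\right)
$$

Setting the derivative with respect to $s$ to zero gives the amplitude estimate at fixed $\beta$:

$$
\hat{s} = \left(A^T N^{-1} A\right)^{-1} A^T N^{-1} d
$$

Substituting $\hat{s}$ back in:

$$
-2 \log \mathcal{L}(\hat{s}, \beta) = d^T N^{-1} d - \left(A^T N^{-1} d\right)^T \left(A^T N^{-1} A\right)^{-1} \left(A^T N^{-1} d\right)
$$

The first term does not depend on $\beta$, so maximizing the likelihood over the spectral parameters is equivalent to maximizing the quadratic form.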

### Implementation Details

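- **Transposing and Applying `A`**: `A.T(d)` applies the transpose of `A` to `d`, equivalent to the term $A^T d$. In the code, `(A.T @ invN)(d)` composes this with `invN` to give $A^T N^{-1} d$.
- **Computing the Likelihood**: The quadratic form is computed by evaluating $x = A^T N^{-1} d$, applying the inverse of $A^T N^{-1} A$ to $x$ through the `.I` property, and taking the inner product of $x$ with the result, summed over all leaves of the pytree.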
- **Negative Log-Likelihood**: The final output of `negative_log_prob` is the negative of this log-likelihood value, allowing us to use it as a loss function for optimization.
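
The steps above can be sketched with dense NumPy arrays. This is a toy version under stated assumptions: the mixing matrix uses simplified power-law scalings for dust and synchrotron rather than the physical emission laws, the noise is diagonal and uniform, and the names `toy_mixing_matrix` and `toy_negative_log_prob` are hypothetical. It is not the Furax implementation.

```python
import numpy as np


def toy_mixing_matrix(nu, beta_dust, beta_pl, dust_nu0=150.0, synchrotron_nu0=20.0):
    """Columns are (CMB, dust, synchrotron); simplified power laws, for illustration."""
    cmb = np.ones_like(nu)
    dust = (nu / dust_nu0) ** beta_dust
    synchrotron = (nu / synchrotron_nu0) ** beta_pl
    return np.stack([cmb, dust, synchrotron], axis=1)  # shape (n_freq, n_comp)


def toy_negative_log_prob(beta_dust, beta_pl, nu, d, inv_n):
    """Negative of (A^T N^-1 d)^T (A^T N^-1 A)^-1 (A^T N^-1 d), summed over pixels."""
    A = toy_mixing_matrix(nu, beta_dust, beta_pl)
    x = A.T @ (inv_n[:, None] * d)           # A^T N^-1 d, shape (n_comp, n_pix)
    AtNA = A.T @ (inv_n[:, None] * A)        # A^T N^-1 A, shape (n_comp, n_comp)
    s_hat = np.linalg.solve(AtNA, x)         # (A^T N^-1 A)^-1 x
    return -np.sum(x * s_hat)


rng = np.random.default_rng(0)
nu = np.linspace(40.0, 400.0, 15)            # placeholder frequencies in GHz
inv_n = np.ones_like(nu)                     # diagonal of N^-1
s_true = rng.normal(size=(3, 8))             # random component amplitudes, 8 pixels
d = toy_mixing_matrix(nu, 1.54, -3.0) @ s_true  # noiseless data

nll_true = toy_negative_log_prob(1.54, -3.0, nu, d, inv_n)
nll_wrong = toy_negative_log_prob(2.54, -3.0, nu, d, inv_n)
print(nll_true < nll_wrong)
```

With noiseless data and uniform noise weights, the quadratic form at the true parameters equals $d^T d$, which is its largest possible value, so any other choice of spectral index gives a higher negative log-likelihood.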


In [6]:
invN = HomothetyOperator(jnp.ones(1), _in_structure=d.structure)
DND = invN(d) @ d

in_structure = d.structure_for((d.shape[1],))
best_params = {"temp_dust": 20.0, "beta_dust": 1.54, "beta_pl": -3.0}

dust_nu0 = 150.0
synchrotron_nu0 = 20.0


@jax.jit
def negative_log_prob(params, d):
    cmb = CMBOperator(nu, in_structure=in_structure)
    dust = DustOperator(
        nu,
        frequency0=dust_nu0,
        temperature=params["temp_dust"],
        beta=params["beta_dust"],
        in_structure=in_structure,
    )
    synchrotron = SynchrotronOperator(
        nu,
        frequency0=synchrotron_nu0,
        beta_pl=params["beta_pl"],
        in_structure=in_structure,
    )

    A = MixingMatrixOperator(cmb=cmb, dust=dust, synchrotron=synchrotron)

    x = (A.T @ invN)(d)
    likelihood = jax.tree.map(lambda a, b: a @ b, x, (A.T @ invN @ A).I(x))
    summed_log_prob = jax.tree.reduce(operator.add, likelihood)

    return -summed_log_prob

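Next, we time the evaluation of the negative log-likelihood and of its gradient. Each function is called once before timing so that JIT compilation is excluded from the measurement, and `block_until_ready()` makes sure the timing covers the full computation rather than only JAX's asynchronous dispatch.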

In [7]:
print("Performance of the nll evaluation")
negative_log_prob(best_params, d).block_until_ready()
%timeit negative_log_prob(best_params, d).block_until_ready()
print("Performance of the nll grad evaluation")
jax.grad(negative_log_prob)(best_params, d)["beta_pl"].block_until_ready()
%timeit jax.grad(negative_log_prob)(best_params, d)['beta_pl'].block_until_ready()

Performance of the nll evaluation
65.2 ms ± 7.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Performance of the nll grad evaluation
186 ms ± 14.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


## Check for Correctness

In this cell, we perform a basic correctness check by comparing the gradients of the negative log-likelihood at two sets of parameters:

1. **Wrong Parameters**: A set of parameters obtained by adding a random offset to `best_params`. Because the same PRNG key is reused for every entry, each parameter receives the same offset.
2. **Correct Parameters**: The original `best_params`.

We reduce each gradient to a single number with `max`. The reduction acts on signed components rather than absolute values, so a large negative result still indicates a steep gradient. The gradient at the correct parameters should be much closer to zero, indicating proximity to an optimum.


In [8]:
wrong_params = jax.tree.map(
    lambda x: x + jax.random.normal(jax.random.PRNGKey(0)), best_params
)
print(
    f"Wrong parameters grad {jax.tree.reduce(max, jax.grad(negative_log_prob)(wrong_params, d))}"
)
print(
    f"Correct parameters grad {jax.tree.reduce(max, jax.grad(negative_log_prob)(best_params, d))}"
)

Wrong parameters grad -113186446.19540569
Correct parameters grad 154.78784358372695
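
The negative value printed for the wrong parameters shows that `max` was taken over signed gradient components. A reduction over absolute values is easier to read as a magnitude. The sketch below illustrates the difference on a plain dictionary; the gradient values are hypothetical and chosen only for illustration, and `max_abs` is a helper defined for this example.

```python
from functools import reduce


def max_abs(grads: dict) -> float:
    """Largest absolute gradient component."""
    return reduce(max, (abs(g) for g in grads.values()))


# Hypothetical gradient values, all negative, for illustration only.
grads_far = {"beta_dust": -3.0e6, "beta_pl": -1.1e8, "temp_dust": -2.0e5}

signed_max = reduce(max, grads_far.values())
print(signed_max)          # -200000.0 (the least negative entry)
print(max_abs(grads_far))  # 110000000.0 (the steepest direction)
```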


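## Off-the-Shelf Likelihoods

Furax provides a ready-made `negative_log_likelihood` in `furax.comp_sep`. Below, we bind the reference frequencies with `partial` and check that it agrees with our `negative_log_prob` at `best_params`.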



In [9]:
from furax.comp_sep import negative_log_likelihood

negative_log_likelihood = partial(
    negative_log_likelihood, dust_nu0=dust_nu0, synchrotron_nu0=synchrotron_nu0
)

L = negative_log_likelihood(best_params, nu=nu, N=invN, d=d)

assert jax.tree.all(
    jax.tree.map(
        lambda x, y: jnp.isclose(x, y, rtol=1e-5), L, negative_log_prob(best_params, d)
    )
)
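
The `partial` call in the cell above pre-binds the two reference frequencies, so later callers only pass the parameters and data. A minimal sketch of the same pattern follows, using a hypothetical `toy_nll` whose body is arbitrary and only illustrates keyword binding. It is not the Furax function.

```python
from functools import partial


def toy_nll(params, nu, dust_nu0, synchrotron_nu0):
    """Arbitrary stand-in that depends on every argument."""
    return (params["beta_dust"] * dust_nu0 + params["beta_pl"] * synchrotron_nu0) / nu


# Bind the reference frequencies once; the result takes only (params, nu).
bound_nll = partial(toy_nll, dust_nu0=150.0, synchrotron_nu0=20.0)

value = bound_nll({"beta_dust": 1.54, "beta_pl": -3.0}, nu=100.0)
direct = toy_nll({"beta_dust": 1.54, "beta_pl": -3.0}, 100.0, 150.0, 20.0)
print(value == direct)
```

Binding by keyword keeps the remaining call signature compatible with solvers that forward extra keyword arguments to the objective.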

# Validating Against FGBuster: The `c1d0s0` Model

In this section, we validate our custom implementation of the likelihood model by comparing it to the `c1d0s0` model from `fgbuster`. By aligning our results with FGBuster’s well-established component separation model, we ensure that our setup and computations are consistent and reliable.


### Case 1: Initial Validation Using `best_params` as the Starting Point

We begin the validation process by setting `best_params` as the initial point for both our custom implementation and FGBuster’s `c1d0s0` model. This allows us to directly compare the outputs and confirm that the models produce similar results when initialized with the same values.


In [10]:
components[1]._set_default_of_free_symbols(beta_d=1.54, temp=20.0)
components[2]._set_default_of_free_symbols(beta_pl=-3.0)

result = basic_comp_sep(components, instrument, freq_maps)
print(result.params)
print(result.x)

['Dust.beta_d', 'Dust.temp', 'Synchrotron.beta_pl']
[ 1.54 20.   -3.  ]


In [11]:
best_params = {"temp_dust": 20.0, "beta_dust": 1.54, "beta_pl": -3.0}

scipy_solver = jaxopt.ScipyMinimize(
    fun=negative_log_likelihood, method="TNC", jit=True, tol=1e-10
)
result = scipy_solver.run(best_params, nu=nu, N=invN, d=d)
result.params


{'beta_dust': Array(1.53999987, dtype=float64),
 'beta_pl': Array(-3., dtype=float64),
 'temp_dust': Array(19.9999997, dtype=float64)}
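
FGBuster returns its result as an ordered array with a separate list of names, while the Furax-based solver returns a dictionary. The sketch below lines the two up and compares them, using the values printed in the two cells above. The name mapping is inferred from the parameter names in this notebook, and the helper `results_agree` and its tolerance are assumptions made for this example.

```python
import math

# Inferred correspondence between FGBuster names and the dictionary keys used here.
FGBUSTER_TO_FURAX = {
    "Dust.beta_d": "beta_dust",
    "Dust.temp": "temp_dust",
    "Synchrotron.beta_pl": "beta_pl",
}


def results_agree(fg_names, fg_values, furax_params, rel_tol=1e-6):
    """True if every FGBuster value matches its Furax counterpart within rel_tol."""
    return all(
        math.isclose(value, furax_params[FGBUSTER_TO_FURAX[name]], rel_tol=rel_tol)
        for name, value in zip(fg_names, fg_values)
    )


fg_names = ["Dust.beta_d", "Dust.temp", "Synchrotron.beta_pl"]
fg_values = [1.54, 20.0, -3.0]
furax_params = {"beta_dust": 1.53999987, "beta_pl": -3.0, "temp_dust": 19.9999997}

print(results_agree(fg_names, fg_values, furax_params))
```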

### Case 2: Validation with an Incorrect Parameter: Setting `beta_dust` to a Wrong Value

In [12]:
components[1]._set_default_of_free_symbols(beta_d=2.54, temp=20.0)
components[2]._set_default_of_free_symbols(beta_pl=-3.0)

result = basic_comp_sep(components, instrument, freq_maps)
print(result.params)
print(result.x)

  return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)
  return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)
  return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)*log(0.05*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)


SVD of A failed -> logL = -inf
SVD of A failed -> logL_dB not updated
SVD of A failed -> logL = -inf
SVD of A failed -> logL = -inf
SVD of A failed -> logL = -inf


  return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)


SVD of A failed -> logL = -inf
SVD of A failed -> logL = -inf
SVD of A failed -> logL = -inf
['Dust.beta_d', 'Dust.temp', 'Synchrotron.beta_pl']
[ 1.53194524 19.97377854 -2.9430962 ]


In [13]:
params = {"temp_dust": 20.0, "beta_dust": 2.54, "beta_pl": -3.0}

scipy_solver = jaxopt.ScipyMinimize(
    fun=negative_log_likelihood, method="TNC", jit=True, tol=1e-10
)
result = scipy_solver.run(params, nu=nu, N=invN, d=d)
result.params


{'beta_dust': Array(1.53999179, dtype=float64),
 'beta_pl': Array(-2.9999352, dtype=float64),
 'temp_dust': Array(20.00027089, dtype=float64)}

### Case 3: Setting `beta_dust` and `beta_pl` to Incorrect Values


In [14]:
components[1]._set_default_of_free_symbols(beta_d=2.54, temp=20.0)
components[2]._set_default_of_free_symbols(beta_pl=-6.0)

result = basic_comp_sep(components, instrument, freq_maps)
print(result.params)
print(result.x)

  return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)
  return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)
  return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)*log(0.05*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)


SVD of A failed -> logL = -inf
SVD of A failed -> logL_dB not updated
SVD of A failed -> logL = -inf
SVD of A failed -> logL = -inf
SVD of A failed -> logL = -inf
SVD of A failed -> logL = -inf
SVD of A failed -> logL = -inf
SVD of A failed -> logL = -inf


  return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)


['Dust.beta_d', 'Dust.temp', 'Synchrotron.beta_pl']
[ 1.53034346 19.97418092 -5.99479857]


In [15]:
params = {"temp_dust": 20.0, "beta_dust": 2.54, "beta_pl": -6.0}

scipy_solver = jaxopt.ScipyMinimize(
    fun=negative_log_likelihood, method="TNC", jit=True, tol=1e-10
)
result = scipy_solver.run(params, nu=nu, N=invN, d=d)
result.params


{'beta_dust': Array(1.53999777, dtype=float64),
 'beta_pl': Array(-3.00011661, dtype=float64),
 'temp_dust': Array(20.0001393, dtype=float64)}

### Case 4: Setting All Parameters to Incorrect Values


In [16]:
components[1]._set_default_of_free_symbols(beta_d=2.54, temp=25.0)
components[2]._set_default_of_free_symbols(beta_pl=-6.0)

result = basic_comp_sep(components, instrument, freq_maps)
print(result.params)
print(result.x)

['Dust.beta_d', 'Dust.temp', 'Synchrotron.beta_pl']
[ 1.54000008 19.99999732 -3.00000111]


In [17]:
params = {"temp_dust": 25.0, "beta_dust": 2.54, "beta_pl": -6.0}

scipy_solver = jaxopt.ScipyMinimize(
    fun=negative_log_likelihood, method="TNC", jit=True, tol=1e-10
)
result = scipy_solver.run(params, nu=nu, N=invN, d=d)
result.params


{'beta_dust': Array(1.53999768, dtype=float64),
 'beta_pl': Array(-2.99996799, dtype=float64),
 'temp_dust': Array(20.00007388, dtype=float64)}
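
The results printed in Cases 2 to 4 can be summarized by their largest absolute deviation from `best_params`. The sketch below does this with the values copied from the outputs above. The helper `max_abs_error` and the tuple ordering are choices made for this example.

```python
# Reference values from best_params, ordered as (beta_dust, temp_dust, beta_pl).
TRUTH = (1.54, 20.0, -3.0)

# Values copied from the printed outputs of Cases 2 to 4, same ordering.
fgbuster = {
    2: (1.53194524, 19.97377854, -2.9430962),
    3: (1.53034346, 19.97418092, -5.99479857),
    4: (1.54000008, 19.99999732, -3.00000111),
}
furax = {
    2: (1.53999179, 20.00027089, -2.9999352),
    3: (1.53999777, 20.0001393, -3.00011661),
    4: (1.53999768, 20.00007388, -2.99996799),
}


def max_abs_error(values):
    """Largest absolute difference from the reference values."""
    return max(abs(v - t) for v, t in zip(values, TRUTH))


for case in (2, 3, 4):
    fg_err = max_abs_error(fgbuster[case])
    fx_err = max_abs_error(furax[case])
    print(f"Case {case}: FGBuster {fg_err:.2e}, Furax {fx_err:.2e}")
```

In these runs the Furax-based solver stays within about 1e-3 of the reference in every case, while the FGBuster result in Case 3 leaves `beta_pl` close to its starting value of -6.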

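# Using Optax Optimizers

The same likelihood can also be minimized with Optax solvers. Here we use L-BFGS through the `optimize` helper from `jax_grid_search`.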

In [19]:
import optax
import optax.tree_utils as otu
from jax_grid_search import optimize

In [20]:
solver = optax.lbfgs()

final_params, final_state = optimize(
    params, negative_log_likelihood, solver, max_iter=100, tol=1e-4, nu=nu, N=invN, d=d
)

print(
    f"Final parameters: {final_params}, number of evaluations: {otu.tree_get(final_state, 'count')}"
)
print(f"Final value: {negative_log_prob(final_params, d=d)}")

Final parameters: {'beta_dust': Array(1.54000002, dtype=float64), 'beta_pl': Array(-3.0000005, dtype=float64), 'temp_dust': Array(19.99999868, dtype=float64)}, number of evaluations: 39
Final value: -6114624736980.274
