In [1]:
%load_ext autoreload
%autoreload 2
import os

os.environ["EQX_ON_ERROR"] = "nan"
from functools import partial

import jax
import jax.numpy as jnp
import lineax as lx
import numpy as np
import optax
from furax import Config, HomothetyOperator
from furax._instruments.sky import (
    FGBusterInstrument,
    get_noise_from_instrument,
    get_observation,
)
from furax.comp_sep import (
    negative_log_likelihood,
    spectral_cmb_variance,
)
from furax.obs.landscapes import HealpixLandscape
from generate_maps import get_mixin_matrix_operator, simulate_D_from_params
from jax_grid_search import optimize
from jax_healpy import get_clusters, get_cutout_from_mask

In [2]:
# Read the three masks from one archive open instead of re-opening the file for
# each key. Indexing an NpzFile materialises the array, so the arrays stay valid
# after the `with` block closes the underlying file handle.
with np.load("../data/masks/GAL_PlanckMasks_64.npz") as planck_masks:
    GAL020 = planck_masks["GAL020"]
    GAL040 = planck_masks["GAL040"]
    GAL060 = planck_masks["GAL060"]

In [3]:
# Map resolution (also passed to HealpixLandscape below) and the derived
# full-sky pixel count.
nside = 64
npixel = 12 * nside**2

# Number of patches requested for each spectral parameter.
patch_counts = dict(
    temp_dust_patches=1,
    beta_dust_patches=100,
    beta_pl_patches=1,
)

# Forwarded to get_clusters as its max_centroids argument.
max_centroids = 300
mask = GAL020
# Single-target unpacking: fails loudly unless jnp.where returns exactly one
# index array (i.e. the mask is one-dimensional).
[indices] = jnp.where(mask == 1)

In [4]:
def _cluster_for(n_regions):
    """Call get_clusters on the active mask for one entry of patch_counts."""
    return get_clusters(
        mask, indices, n_regions, jax.random.PRNGKey(0), max_centroids=max_centroids
    )


def _cut_to_mask_int32(full_map):
    """Cut a map down to the pixels listed in `indices` and cast it to int32."""
    return get_cutout_from_mask(full_map, indices).astype(jnp.int32)


# One cluster map per entry of patch_counts, then the same maps cut to the mask.
patch_indices = jax.tree.map(_cluster_for, patch_counts)
masked_clusters = jax.tree.map(_cut_to_mask_int32, patch_indices)

nu = FGBusterInstrument.default_instrument().frequency
land_scape = HealpixLandscape(nside=nside, stokes="QU")

# One land_scape.normal draw per component, seeded 0, 1, 2 in this order.
_component_names = ("cmb", "dust", "synchrotron")
sky = {
    name: land_scape.normal(jax.random.key(seed))
    for seed, name in enumerate(_component_names)
}
masked_sky = get_cutout_from_mask(sky, indices)

In [5]:
# Constant starting value for every patch of each spectral parameter.
best_params = {
    "temp_dust": jnp.full((patch_counts["temp_dust_patches"],), 20.0),
    "beta_dust": jnp.full((patch_counts["beta_dust_patches"],), 1.54),
    "beta_pl": jnp.full((patch_counts["beta_pl_patches"],), -3.0),
}

best_params_flat, tree_struct = jax.tree.flatten(best_params)
# Perturb leaf i with normal noise scaled by 0.2, drawn from key(i), where i is
# the leaf's position in the flattened list.
_perturbed_leaves = [
    leaf + jax.random.normal(jax.random.key(leaf_idx), leaf.shape) * 0.2
    for leaf_idx, leaf in enumerate(best_params_flat)
]
best_params = jax.tree.unflatten(tree_struct, _perturbed_leaves)

In [6]:
# Show the perturbed parameter tree that is used below to simulate the data.
best_params

{'beta_dust': Array([1.49883157, 1.38304684, 1.90321733, 1.5775688 , 1.55617358,
        1.46557784, 1.77803274, 1.60772846, 1.55696517, 1.36563643,
        1.75090322, 1.22810042, 1.61350792, 2.04327043, 1.59171303,
        1.48325791, 1.43322177, 1.61358954, 1.50080716, 1.84081536,
        1.55180981, 1.56831626, 1.57770313, 1.54934587, 1.25730186,
        1.50100732, 1.76260491, 1.09459231, 1.42748075, 1.60027231,
        1.30942339, 1.54260899, 1.79598285, 1.0333889 , 1.34438666,
        1.25838338, 1.30573365, 1.7396384 , 1.45527308, 1.61711098,
        1.43685442, 1.53523056, 1.6199428 , 1.21355378, 1.51907166,
        1.41277575, 1.60968798, 1.53036035, 1.60655999, 1.31296924,
        1.14378612, 1.68883529, 1.34205039, 1.53183941, 2.02366229,
        1.37742847, 1.33009791, 1.28797115, 1.39153294, 1.43431806,
        1.60543286, 1.55500526, 1.91423056, 1.68908721, 1.27919635,
        1.57905523, 1.85323274, 1.98509326, 1.32746007, 1.75164565,
        1.67378458, 1.18646281, 1.5

In [7]:
# Reference frequencies forwarded to the simulation and to both likelihood helpers.
dust_nu0 = 150.0
synchrotron_nu0 = 20.0

# Simulated frequency maps on the masked sky, built from the perturbed parameters.
masked_d = simulate_D_from_params(
    best_params,
    masked_clusters,
    nu,
    masked_sky,
    dust_nu0=dust_nu0,
    synchrotron_nu0=synchrotron_nu0,
)

instrument = FGBusterInstrument.default_instrument()
# Use the notebook-wide `nside` rather than a second hard-coded 64, so the
# observation resolution cannot drift away from the landscape and the masks.
pysm3_d = get_observation(instrument, nside, stokes_type="QU", tag="c1d0s0")

# Cut the observation down to the selected pixels along axis 1.
masked_pysm3_d = get_cutout_from_mask(pysm3_d, indices, axis=1)

# Pre-bind the reference frequencies so later calls only pass params and data.
spectral_cmb_variance_fn = partial(
    spectral_cmb_variance, dust_nu0=dust_nu0, synchrotron_nu0=synchrotron_nu0
)
negative_log_likelihood_fn = partial(
    negative_log_likelihood, dust_nu0=dust_nu0, synchrotron_nu0=synchrotron_nu0
)

# Operator passed as `N=` to the likelihood calls; built from jnp.ones(1) with
# the structure of the simulated data (presumably unit noise weighting, TODO confirm).
N = HomothetyOperator(jnp.ones(1), _in_structure=masked_d.structure)
solver = optax.lbfgs()

# Options applied through `Config(**inverser_options)` around the optimisation.
inverser_options = {
    "solver": lx.CG(rtol=1e-6, atol=1e-6, max_steps=1000),
    "solver_throw": False,
}

In [8]:
# Re-create the default instrument (same call as in the simulation cell) and
# inspect the shape of its depth_i attribute.
instrument = FGBusterInstrument.default_instrument()

instrument.depth_i.shape

(10,)

In [9]:
# Use the notebook-wide `nside` rather than a hard-coded 64 so the noise maps
# always match the resolution of the landscape and the masks.
noise = get_noise_from_instrument(instrument, nside, stokes_type="QU")
masked_noise = get_cutout_from_mask(noise, indices, axis=1)
# Add the noise, scaled by 0.1, to the current `masked_d`.
# NOTE(review): `masked_d` is rebound to the PySM cutout in a later cell, so
# re-running this cell after that one changes what `noised_d` contains.
noised_d = masked_d + masked_noise * 0.1

In [10]:
# Mixing-matrix operator built from the same arguments as the
# simulate_D_from_params call that produced the simulated data.
A = get_mixin_matrix_operator(
    best_params,
    masked_clusters,
    nu,
    masked_sky,
    dust_nu0=dust_nu0,
    synchrotron_nu0=synchrotron_nu0,
)


def allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False):
    """Return the per-leaf mean squared error between two matching pytrees.

    Despite its name, this helper does not return a boolean: the notebook
    reduces its result with ``max`` to report the largest per-leaf error. The
    previous body had a second, unreachable ``return`` implementing a real
    ``jnp.allclose`` check and a ``-> bool`` annotation that did not match what
    was actually returned; both are removed.

    Args:
        x: Pytree of arrays.
        y: Pytree with the same structure as ``x``.
        rtol: Unused; kept so existing calls that pass it keep working.
        atol: Unused; kept so existing calls that pass it keep working.
        equal_nan: Unused; kept so existing calls that pass it keep working.

    Returns:
        A pytree with the structure of ``x`` whose leaves are
        ``((x_leaf - y_leaf) ** 2).mean()``.
    """
    return jax.tree.map(lambda a, b: ((a - b) ** 2).mean(), x, y)


# Apply (A.T @ N.I @ A).I to (A.T @ N.I) d; the shared left factor is built once.
_At_Ninv = A.T @ N.I
s = (_At_Ninv @ A).I(_At_Ninv(noised_d))
# `allclose` yields one mean squared error per leaf; show the largest one.
jax.tree.reduce(max, allclose(s, masked_sky), 0.0)

Array(0.00120374, dtype=float64)

In [35]:
# NOTE(review): rebinding `masked_d` here does not change `noised_d`, which was
# built earlier from the simulated `masked_d`; compute_minimum_variance below
# reads `noised_d`, not `masked_d`. Confirm which data set the fit should use.
masked_d = masked_pysm3_d


@partial(jax.jit, static_argnums=(5,))
def compute_minimum_variance(
    T_d_patches, B_d_patches, B_s_patches, planck_mask, indices, max_patches=25
):
    """Fit the spectral parameters for one beta_dust clustering and score the fit.

    JIT-compiled with ``max_patches`` (positional index 5) treated as static.
    The data, noise operator, frequencies, solver and solver options are read
    from notebook globals (``noised_d``, ``N``, ``nu``, ``solver``,
    ``inverser_options``, ``negative_log_likelihood_fn``,
    ``spectral_cmb_variance_fn``).

    Args:
        T_d_patches: Unused; temp_dust always gets a single value and no
            cluster map.
        B_d_patches: Third positional argument of ``get_clusters`` for the
            beta_dust cluster map.
        B_s_patches: Unused; beta_pl always gets a single value and no
            cluster map.
        planck_mask: Mask passed to ``get_clusters``.
        indices: Pixel indices passed to ``get_clusters`` and used to cut the
            cluster map down to the mask.
        max_patches: Static. Passed as ``max_centroids`` and also the length
            of the initial beta_dust vector.

    Returns:
        Dict with ``"value"`` (result of ``spectral_cmb_variance_fn``),
        ``"NLL"`` (result of ``negative_log_likelihood_fn``) and the fitted
        ``"beta_dust"``, ``"temp_dust"`` and ``"beta_pl"`` arrays.
    """

    def _cut_to_mask_int32(full_map):
        return get_cutout_from_mask(full_map, indices).astype(jnp.int32)

    # Only beta_dust is clustered; the other two entries stay None, which
    # jax.tree.map treats as an empty subtree and passes through untouched.
    patch_indices = {
        "temp_dust_patches": None,
        "beta_dust_patches": get_clusters(
            planck_mask,
            indices,
            B_d_patches,
            jax.random.PRNGKey(0),
            max_centroids=max_patches,
        ),
        "beta_pl_patches": None,
    }
    masked_clusters = jax.tree.map(_cut_to_mask_int32, patch_indices)

    # Constant initial guess for every parameter.
    params = {
        "beta_dust": jnp.full((max_patches,), 1.54),
        "temp_dust": jnp.full((1,), 20.0),
        "beta_pl": jnp.full((1,), (-3.0)),
    }

    # Keyword arguments shared by the optimiser and both evaluations below.
    eval_kwargs = dict(nu=nu, N=N, d=noised_d, patch_indices=masked_clusters)

    with Config(**inverser_options):
        final_params, _ = optimize(
            params,
            negative_log_likelihood_fn,
            solver,
            max_iter=1000,
            tol=1e-15,
            verbose=True,
            log_interval=0.01,
            **eval_kwargs,
        )

    # NOTE(review): these two evaluations run outside the Config context used
    # for the optimisation; confirm that is intended.
    cmb_var = spectral_cmb_variance_fn(final_params, **eval_kwargs)
    nll = negative_log_likelihood_fn(final_params, **eval_kwargs)

    result = {"value": cmb_var, "NLL": nll}
    for name in ("beta_dust", "temp_dust", "beta_pl"):
        result[name] = final_params[name]
    return result

In [36]:
# Same fit twice, differing only in B_d_patches: 200 first (bad_res), then 100
# (good_res). max_patches stays at 200 for both.
bad_res, good_res = (
    compute_minimum_variance(1, n_beta_dust, 1, GAL020, indices, max_patches=200)
    for n_beta_dust in (200, 100)
)

update norm 29.720699857170256 at iter 0 value inf
update norm 0.00588818531296091 at iter 10 value -3763630.7008521864
update norm 0.28775612331648676 at iter 20 value -3763630.7378949975
update norm 4.1992327069108966e-05 at iter 30 value -3763630.763445585
update norm 1.393534376870711e-08 at iter 40 value -3763630.763445588
update norm 2.0185527364572618e-10 at iter 50 value -3763630.7634455883
update norm 5.007651660843308e-12 at iter 60 value -3763630.7634455883
update norm 4.359935808996625e-13 at iter 70 value -3763630.7634455883
update norm 2.166503540299223e-13 at iter 80 value -3763630.763445588
update norm 1.2626620896361715e-13 at iter 90 value -3763630.7634455883
update norm 3.704591405038563e-13 at iter 100 value -3763630.763445588
update norm 1.0949000037006079e-12 at iter 110 value -3763630.7634455883
update norm 2.698437471468854e-13 at iter 120 value -3763630.7634455883
update norm 4.661736363540289e-12 at iter 130 value -3763630.7634455883
update norm 3.672794811517

In [37]:
# Every evaluation below shares the same data, noise operator and cluster maps.
# NOTE(review): `masked_clusters` is the notebook-level clustering, not the one
# built inside compute_minimum_variance for good_res / bad_res, and those two
# dicts also carry the extra "value" and "NLL" entries. Confirm the comparison
# is meaningful as written.
_eval_kwargs = dict(nu=nu, N=N, d=noised_d, patch_indices=masked_clusters)

# Negative log-likelihood at the simulation parameters, the 100-patch fit and
# the 200-patch fit, in that order.
bell, gll, bll = (
    negative_log_likelihood_fn(candidate, **_eval_kwargs)
    for candidate in (best_params, good_res, bad_res)
)

# spectral_cmb_variance_fn for the same three parameter sets, in the same order.
var_bell, var_gll, var_bll = (
    spectral_cmb_variance_fn(candidate, **_eval_kwargs)
    for candidate in (best_params, good_res, bad_res)
)

In [38]:
# spectral_cmb_variance_fn at: simulation parameters, 100-patch fit, 200-patch fit.
var_bell, var_gll, var_bll

(Array(2.00928344, dtype=float64),
 Array(2.0160553, dtype=float64),
 Array(2.01685719, dtype=float64))

In [None]:
# Is the 100-patch fit's NLL above the simulation parameters' and below the 200-patch fit's?
gll > bell, gll < bll

(Array(True, dtype=bool), Array(True, dtype=bool))

: 

In [34]:
# Same ordering check on the spectral_cmb_variance_fn values.
var_gll > var_bell, var_gll < var_bll

(Array(True, dtype=bool), Array(True, dtype=bool))

In [23]:
# Is the 200-patch fit's NLL below the NLL at the simulation parameters?
bll < bell

Array(True, dtype=bool)

In [None]:
# "value" entries returned by compute_minimum_variance for each fit.
# NOTE(review): the saved output shows a single array for this two-element
# tuple and the cell has no execution count; re-run before relying on it.
good_res["value"], bad_res["value"]

Array(0.9741552, dtype=float64)

In [None]:
# Does the 200-patch fit return a larger "value" than the 100-patch fit?
# NOTE(review): the saved output is a float array, not a boolean, and the cell
# has no execution count; re-run before relying on it.
bad_res["value"] > good_res["value"]

Array(0.9741552, dtype=float64)