In [1]:
import jax.numpy as jnp

from furax._base.core import HomothetyOperator
from furax.landscapes import StokesPyTree
from functools import partial

In [2]:
from generate_maps import save_to_cache

nsides = [32, 64, 128, 256, 512]
for nside in nsides:
    save_to_cache(nside)

Loaded freq_maps for nside 32 from cache.
Loaded freq_maps for nside 64 from cache.
Loaded freq_maps for nside 128 from cache.
Loaded freq_maps for nside 256 from cache.
Loaded freq_maps for nside 512 from cache.


In [3]:
from generate_maps import load_from_cache

nside = 32

nu, freq_maps = load_from_cache(nside)
# Check the shape of freq_maps
print('freq_maps shape:', freq_maps.shape)

Loaded freq_maps for nside 32 from cache.
freq_maps shape: (15, 3, 12288)


In [4]:
d = StokesPyTree.from_stokes(Q=freq_maps[:, 1, :], U=freq_maps[:, 2, :])
N = HomothetyOperator(jnp.ones(1), _in_structure=d.structure)

In [6]:
from furax.comp_sep import spectral_cmb_variance

dust_nu0 = 150.0
synchrotron_nu0 = 20.0

spectral_cmb_variance = partial(
    spectral_cmb_variance, dust_nu0=dust_nu0, synchrotron_nu0=synchrotron_nu0
)

In [7]:
from furax.comp_sep import optimize
import optax
import optax.tree_utils as otu

solver = optax.lbfgs()

params = {'beta_dust': 1.59, 'beta_pl': -3.1, 'temp_dust': 19.6}

final_params, final_state = optimize(
    params, spectral_cmb_variance, solver, max_iter=100, tol=1e-4, nu=nu, N=N, d=d
)

print(
    f"Final parameters: {final_params}, number of evaluations: {otu.tree_get(final_state, 'count')}"
)
print(f'Initial Value: {spectral_cmb_variance(final_params , nu=nu, N=N, d=d)}')

Final parameters: {'beta_dust': Array(1.54888217, dtype=float64), 'beta_pl': Array(-3.05748183, dtype=float64), 'temp_dust': Array(19.59906649, dtype=float64)}, number of evaluations: 12
Initial Value: 0.30217391179530706


In [8]:
from jax_grid_search import DistributedGridSearch


def objective_function(beta_dust, temp_dust, beta_pl):
    params = {
        'beta_dust': beta_dust,
        'temp_dust': temp_dust,
        'beta_pl': beta_pl,
    }

    return spectral_cmb_variance(params, nu=nu, N=N, d=d)


# Put the good values for the grid search
search_space = {
    'beta_dust': jnp.linspace(1.5, 3.5, 10).at[3].set(final_params['beta_dust']),
    'temp_dust': jnp.linspace(5.0, 50.0, 10).at[2].set(final_params['temp_dust']),
    'beta_pl': jnp.linspace(-4.5, -1.5, 10).at[4].set(final_params['beta_pl']),
}

grid_search = DistributedGridSearch(
    objective_function, search_space, batch_size=25, progress_bar=True, log_every=0.1
)

results = grid_search.run()

for keys, values in results.items():
    print(f'{keys} : {values[0]}')

Selecting batch size of 25
log_interval: 4


Processing batches: 100%|██████████| 40/40 [00:20<00:00,  1.93it/s]


Done .. Stacking the results
beta_dust : 1.548882170771958
temp_dust : 19.599066490764134
beta_pl : -3.057481830864701
value : 0.30217391179530706
