In [1]:
from furax import AbstractLinearOperator
import furax.tree
from furax.obs.stokes import StokesIQU
import jax
from jax import Array
from jaxtyping import Inexact, PyTree
import equinox
import jax.numpy as jnp
from furax.obs.landscapes import Stokes
import numpy as np
from furax.obs.operators import (
    CMBOperator, DustOperator, MixingMatrixOperator, SynchrotronOperator,
)
import healpy as hp
from furax.core._blocks import BlockDiagonalOperator, BlockRowOperator, BlockColumnOperator
from furax import HomothetyOperator
import operator
from hwp_VES import ListToStokesOperator, StokesToListOperator, MixedStokesOperator
import jaxopt
from functools import partial
from furax.obs import negative_log_likelihood
from fgbuster.observation_helpers import _jysr2rj, _rj2cmb
import pysm3
from pysm3 import units as u
from hwp_VES import ListToStokesOperator, StokesToListOperator, MixedStokesOperator

In [2]:
nside = 64  # HEALPix resolution of every map in this notebook
nfreq = 50  # number of frequency samples per channel bandpass

# Per-channel table: band centre 'frequency' (GHz) plus 'amplitude', 'phaseshift'
# and 'offset' coefficients (presumably HWP non-ideality parameters — confirm).
freq_dict = {
    'LF1': {'frequency': 27.0, 'amplitude': 1., 'phaseshift': 0., 'offset': 0.},
    'LF2': {'frequency': 39.0, 'amplitude': 1., 'phaseshift': 0., 'offset': 0.},
    'MF1': {
        'frequency': 92.0,
        'amplitude': np.float64(-0.9304176070254246),
        'phaseshift': np.float64(0.785444797556243),
        'offset': np.float64(0.004623759680303902),
    },
    'MF2': {
        'frequency': 148.0,
        'amplitude': np.float64(-0.90126293657285),
        'phaseshift': np.float64(-2.357455125723842),
        'offset': np.float64(0.0030693534485759654),
    },
    'UHF1': {
        'frequency': 225.5,
        'amplitude': np.float64(-0.9496306860180984),
        'phaseshift': np.float64(0.784989765678136),
        'offset': np.float64(0.0004266008626783833),
    },
    'UHF2': {
        'frequency': 286.5,
        'amplitude': np.float64(-0.9152820744914298),
        'phaseshift': np.float64(0.7860182289948432),
        'offset': np.float64(4.6078803292807735e-05),
    },
}

# Band centres in table order.
frequency_channels = np.array([channel["frequency"] for channel in freq_dict.values()])

def get_bp(freq_c, NFREQ, band_edges=(0.85, 1.15), grid_span=(0.8, 1.2)):
    """Top-hat bandpass sampled on a linear frequency grid around ``freq_c``.

    Parameters
    ----------
    freq_c : float
        Band centre frequency; the returned grid is in the same units.
    NFREQ : int
        Number of grid points.
    band_edges : (float, float), optional
        ``(low, high)`` band edges as fractions of ``freq_c``. The bandpass is 1
        where ``low * freq_c < nu <= high * freq_c`` and 0 elsewhere.
    grid_span : (float, float), optional
        ``(start, stop)`` of the (inclusive) grid as fractions of ``freq_c``.

    Returns
    -------
    (nu_r, bp) : tuple of ndarray
        Frequency grid and bandpass weights, both of length ``NFREQ``.
    """
    low, high = band_edges
    start, stop = grid_span
    nu_r = np.linspace(start, stop, NFREQ) * freq_c
    bp = np.zeros_like(nu_r)
    bp[nu_r > low * freq_c] = 1.
    # Applied second so the upper edge wins: points above it are switched back off.
    bp[nu_r > high * freq_c] = 0.
    # bp only ever holds 0/1, so no NaN cleanup is needed.
    return nu_r, bp

def get_bp_uKCMB(freq_c, NFREQ):
    """Top-hat bandpass of ``get_bp`` turned into normalised weights for the K_CMB model.

    The top-hat is divided by fgbuster's ``_jysr2rj`` and ``_rj2cmb`` conversion
    factors evaluated on the grid. It is then normalised so that its trapezoidal
    integral over frequency in Hz equals 1 (the grid itself is in GHz).

    Parameters
    ----------
    freq_c : float
        Band centre frequency in GHz.
    NFREQ : int
        Number of grid points.

    Returns
    -------
    [nu_r, weights] : list of ndarray
        Frequency grid (GHz) and normalised weights; any NaN weights are set to 0.
    """
    # Same grid and top-hat as the simulation bandpass, so model and data agree.
    nu_r, bp = get_bp(freq_c, NFREQ)
    weights = bp / _jysr2rj(nu_r)
    weights /= _rj2cmb(nu_r)
    # Normalise over frequency in Hz; NaNs are excluded from the integral.
    weights /= np.trapezoid(np.nan_to_num(weights, nan=0), nu_r * 1E9)
    return [nu_r, np.nan_to_num(weights, nan=0)]

# One (frequency grid, top-hat weights) pair per channel -> shape (n_channels, 2, nfreq).
bp_arrays = np.array([get_bp(nu_c, nfreq) for nu_c in frequency_channels])

In [4]:
# Simulate the PySM "d0" sky in uK_CMB, integrated over each channel's top-hat bandpass.
sky = pysm3.Sky(nside=nside, preset_strings=["d0"], output_unit=u.uK_CMB)

n_channels = len(frequency_channels)
freq_maps = np.zeros((n_channels, 3, hp.nside2npix(nside)))
for i, (nu_grid, bp_weights) in enumerate(bp_arrays):
    freq_maps[i] = sky.get_emission(nu_grid * u.GHz, weights=bp_weights)
print("freq_maps shape:", freq_maps.shape)

# Mask: values strictly between 0 and 1 are set to 0, then the map is degraded to `nside`.
mask_ = hp.read_map('binary_mask.fits')
is_fractional = (mask_ > 0) & (mask_ < 1)
mask_i = np.where(is_fractional, 0, mask_)
mask = hp.ud_grade(mask_i, nside_out=nside)
print(mask.shape)

def mask_map(input_map, mask):
    """Return ``input_map`` with every pixel where ``mask == 0`` set to 0.

    ``mask`` is broadcast against ``input_map`` (e.g. an ``(npix,)`` mask applied
    to an ``(nchan, 3, npix)`` stack). Pixels with any non-zero mask value are kept
    unchanged. The input is not modified; a new array is returned.
    """
    keep = mask != 0
    return np.where(keep, input_map, 0)

# Zero the simulated maps wherever the degraded mask is exactly 0.
freq_maps = mask_map(input_map=freq_maps, mask=mask)


freq_maps shape: (6, 3, 49152)
(49152,)


In [5]:
# Pack the simulated channel maps into a Stokes pytree with a leading channel axis.
i_maps, q_maps, u_maps = (freq_maps[:, k, :] for k in range(3))
d = Stokes.from_stokes(I=i_maps, Q=q_maps, U=u_maps)
print(d.structure)

# Structure of a single channel slice (no channel axis).
in_structure = d.structure_for(d[0].shape)
print(in_structure)

StokesIQU(i=ShapeDtypeStruct(shape=(6, 49152), dtype=float64), q=ShapeDtypeStruct(shape=(6, 49152), dtype=float64), u=ShapeDtypeStruct(shape=(6, 49152), dtype=float64))
StokesIQU(i=ShapeDtypeStruct(shape=(49152,), dtype=float64), q=ShapeDtypeStruct(shape=(49152,), dtype=float64), u=ShapeDtypeStruct(shape=(49152,), dtype=float64))


In [6]:
def _read_iqu(path):
    """Read the I, Q, U fields (0, 1, 2) of a HEALPix FITS map as float64."""
    return hp.read_map(path, field=[0, 1, 2], dtype=np.float64)


# Load, per channel, the input, VES, effective and id maps from ``ROOT/<channel>/``.
keys = list(freq_dict.keys())
ROOT = 'NFREQ100_NSIDE64_RCC'
# Degrade the mask to the configured resolution (previously a hard-coded 64).
# NOTE(review): unlike the simulation cell, fractional mask values are not zeroed
# before degrading here — confirm which masking convention is intended.
mask = hp.read_map('binary_mask.fits')
mask = hp.ud_grade(mask, nside_out=nside)

inputs, VESs, effs, TMs = [], [], [], []
for key in keys:
    fits_directory = f'{ROOT}/{key}'
    inputs.append(_read_iqu(f'{fits_directory}/map_input.fits'))
    VESs.append(_read_iqu(f'{fits_directory}/map_VES.fits'))
    effs.append(_read_iqu(f'{fits_directory}/map_eff.fits'))
    TMs.append(_read_iqu(f'{fits_directory}/map_id.fits'))

inputs = np.array(inputs)
effs = np.array(effs)
# ROOT encodes the map resolution; fail loudly if it disagrees with the config.
assert effs.shape[-1] == hp.nside2npix(nside), "loaded maps do not match `nside`"

map_input = mask_map(inputs, mask)
map_eff = mask_map(effs, mask)
map_TMs = mask_map(TMs, mask)

# The first two channels (LF1, LF2: amplitude 1, phaseshift 0 in freq_dict) use the
# input maps directly instead of the effective maps.
assert keys[:2] == ['LF1', 'LF2'], "channel order changed; update the LF override"
map_eff[0] = map_input[0]
map_eff[1] = map_input[1]

d = Stokes.from_stokes(I=map_eff[:, 0, :], Q=map_eff[:, 1, :], U=map_eff[:, 2, :])
print(d.structure)

in_structure = d.structure_for(d[0].shape)
print(in_structure)

StokesIQU(i=ShapeDtypeStruct(shape=(49152,), dtype=float64), q=ShapeDtypeStruct(shape=(49152,), dtype=float64), u=ShapeDtypeStruct(shape=(49152,), dtype=float64))


In [18]:
dust_nu0 = 148.0                    # dust SED reference frequency in GHz (same as the MF2 band centre)
synchrotron_nu0 = None              # synchrotron is not part of the fitted model
best_params = dict(beta_dust=2.0)   # starting point of the spectral-index fit

In [19]:
import transfer_matrixJAX as tm
# Materials of the layered stack (presumably the HWP and its AR coating — confirm).
sapphire = tm.material(3.05, 3.38, 2.3e-4, 1.25e-4, 'Sapphire', materialType='uniaxial')
duroid = tm.material(1.41, 1.41, 1.2e-3, 1.2e-3, 'RT Duroid', materialType='isotropic')
mullite = tm.material(2.52, 2.52, 0.0121, 0.0121, 'Mullite', materialType='isotropic')
epoteck = tm.material(1.7, 1.7, 0., 0., 'Epoteck', materialType='isotropic')

# Layer thicknesses (mm). The stack is symmetric about the three central sapphire plates.
# The `_HF` variants are the values the (disabled) HWP code selects for UHF channels.
thicknesses = [t * tm.mm for t in (0.394, 0.04, 0.212, 3.75, 3.75, 3.75, 0.212, 0.04, 0.394)]
thicknesses_HF = [t * tm.mm for t in (0.183, 0.04, 0.097, 1.60, 1.60, 1.60, 0.097, 0.04, 0.183)]
materials = [duroid, epoteck, mullite, sapphire, sapphire, sapphire, mullite, epoteck, duroid]
# Only the middle sapphire plate (index 4) is rotated.
angles = [0.0] * 4 + [54.0 * tm.deg] + [0.0] * 4
angles_HF = [0.0] * 4 + [57.0 * tm.deg] + [0.0] * 4

GHz = 1e9
deg = np.pi / 180.
angleIncidence = 5.            # presumably degrees (cf. angle_incidence below)

nE, nO = 3.38, 3.05            # same refractive indices as passed to the sapphire material
c = 2.998e8                    # speed of light [m/s]
angle_incidence = 5 * deg      # the same 5 degrees, in radians

shape = (hp.nside2npix(nside),)

In [20]:
def eval_A(params, f_c, nu_r, in_structure, dust_temperature=20.0):
    """Build the CMB + dust mixing-matrix operator on a frequency grid.

    Parameters
    ----------
    params : dict
        Spectral parameters; must contain ``"beta_dust"``.
    f_c : float
        Reference frequency passed to the dust template as ``frequency0``.
    nu_r : array
        Frequencies at which the component SEDs are evaluated.
    in_structure : PyTree
        Input structure of each component template.
    dust_temperature : float, optional
        Temperature passed to ``DustOperator`` (default 20.0).

    Returns
    -------
    MixingMatrixOperator
        Operator combining the ``cmb`` and ``dust`` templates, both in ``K_CMB`` units.
    """
    cmb_template = CMBOperator(nu_r, in_structure=in_structure, units="K_CMB")
    dust_template = DustOperator(
        nu_r,
        frequency0=f_c,
        temperature=dust_temperature,
        beta=params["beta_dust"],
        in_structure=in_structure,
        units="K_CMB",
    )
    # TODO: synchrotron is not modelled. To add it, build
    # SynchrotronOperator(nu_r, frequency0=..., beta_pl=params["beta_pl"],
    # in_structure=in_structure, units="K_CMB") and pass it as `synchrotron=`.
    return MixingMatrixOperator(cmb=cmb_template, dust=dust_template)


def Operator(params):
    """Build the bandpass-integrated forward operator for every channel in ``freq_dict``.

    For each channel:
      1. Evaluate the CMB + dust mixing matrix (``eval_A``) on the channel's
         frequency grid from ``get_bp_uKCMB``.
      2. Split the result into one Stokes map per grid frequency.
      3. Sum those maps with the bandpass weights, normalised to unit sum.
    The per-channel operators are stacked into a column operator and converted
    back into a Stokes pytree with a leading channel axis.

    Parameters
    ----------
    params : dict
        Spectral parameters forwarded to ``eval_A`` (e.g. ``{"beta_dust": 2.0}``).

    Returns
    -------
    Composite operator mapping component maps to per-channel Stokes maps.

    Notes
    -----
    Reads the notebook globals ``freq_dict``, ``nfreq``, ``dust_nu0``,
    ``in_structure`` and ``nside``. No HWP operator is applied (see the disabled
    block in the loop).
    """
    # Structure of one Stokes IQU map; identical for every channel, so build it once.
    map_structure = StokesIQU.structure_for(shape=(hp.nside2npix(nside),), dtype=np.float64)

    channel_ops = []
    for key, channel in freq_dict.items():
        nu_grid, bp_weights = get_bp_uKCMB(channel["frequency"], nfreq)

        AOp = eval_A(params, dust_nu0, nu_grid, in_structure)
        ListOp = StokesToListOperator(axis=0, in_structure=AOp.out_structure())

        # HWP model (disabled). One MixedStokesOperator per grid frequency:
        #     thickness, alpha_2 = ((thicknesses_HF[4], angles_HF[4]) if 'UHF' in key
        #                           else (thicknesses[4], angles[4]))
        #     hwp_list = [MixedStokesOperator.create(
        #                     shape=shape, stokes='IQU', frequency=nu,
        #                     angleIncidence=angle_incidence,
        #                     epsilon=channel["amplitude"], phi=channel["phaseshift"],
        #                     thickness=thickness, alpha_2=alpha_2)
        #                 for nu in nu_grid]
        #     HDiag = BlockDiagonalOperator(hwp_list)
        # TODO: HDiag was never inserted into the chain below — confirm where it belongs.

        # Weighted sum over the frequency grid with weights normalised to unit sum.
        norm_weights = bp_weights / np.sum(bp_weights)
        Bandpass_RowOp = BlockRowOperator(
            [HomothetyOperator(w, map_structure) for w in norm_weights]
        )

        channel_ops.append(Bandpass_RowOp @ ListOp @ AOp)

    OpBlock = BlockColumnOperator(channel_ops)
    StokesOp = ListToStokesOperator(axis=0, in_structure=OpBlock.out_structure())
    return StokesOp @ OpBlock

In [21]:
# Forward operator evaluated at the starting spectral parameters.
BMA = Operator(params=best_params)

In [22]:
# Unit-weight (identity) inverse-noise operator on the data space, i.e. the input of BMA.T.
data_structure = BMA.T.in_structure()
invN = HomothetyOperator(jnp.ones(1), _in_structure=data_structure)
# Data term d^T N^{-1} d.
DND = invN(d) @ d
from likelihoods_kwitchi import negative_log_likelihood_bma

In [23]:
def get_nll(params, nu, N, d, dust_nu0, synchrotron_nu0):
    """Negative log-likelihood of the data ``d`` for spectral parameters ``params``.

    Rebuilds the forward operator with ``Operator(params)``. Passes it, together
    with all other arguments unchanged, to ``negative_log_likelihood_bma``.
    """
    return negative_log_likelihood_bma(
        params, nu, N, d, dust_nu0, synchrotron_nu0, Operator(params)
    )

In [24]:
import optax
import optax.tree_utils as otu
from jax_grid_search import optimize

solver = optax.lbfgs()

# Keyword arguments shared by the optimiser and the direct NLL evaluations below.
nll_kwargs = dict(
    nu=frequency_channels, N=invN, d=d, dust_nu0=dust_nu0, synchrotron_nu0=synchrotron_nu0
)

# NOTE(review): max_iter=2 looks like a debugging value — confirm before trusting the fit.
final_params, final_state = optimize(
    best_params, get_nll, solver, max_iter=2, tol=1e-4, **nll_kwargs
)

print(
    f"Final parameters: {final_params}, number of evaluations: {otu.tree_get(final_state, 'count')}"
)
# Report the NLL at the starting point and at the fitted parameters, each correctly labelled.
print(f"Initial NLL: {get_nll(best_params, **nll_kwargs)}")
print(f"Final NLL: {get_nll(final_params, **nll_kwargs)}")

In [1]:
# Display the fitted spectral parameters. This cell only works after the optimisation
# cell has run in the same kernel session. The saved NameError output came from
# running it alone in a fresh kernel (In [1]).
final_params

NameError: name 'final_params' is not defined