In [1]:
%load_ext autoreload
%autoreload 2

from Shared.shared import *
from Shared.specific_CNB_sim import *

sim_name = f"no_gravity"
sim_folder = f"sim_output/{sim_name}"
fig_folder = f"figures_local/{sim_name}"
Cl_folder = f"Shared/Cls"
Delta_folder = f"Shared/Deltas"
nu_m_range = jnp.load(f"{sim_folder}/neutrino_massrange_eV.npy")
nu_m_picks = jnp.array([
    0.01, 0.05, 0.1, 
    # 0.2, 0.3
    ])*Params.eV
simdata = SimData(sim_folder)

# Using Fermi-Dirac (FD) PSD

## Overdensity band

In [None]:
tot_dens_halos = jnp.load(f"{sim_folder}/total_densities.npy")
ic(tot_dens_halos.shape)

eta = 1e-3
SimPlot.overdensity_band(
    dens_arr=tot_dens_halos,
    m_arr=nu_m_range,
    plot_ylims=((0-eta),(0+eta)),
    plot_log=False,
    fig_path=f"{fig_folder}/overdensity_band_{sim_name}.pdf",
    args=Params())

## Anisotropy Skymaps and Power Spectra

In [None]:
nu_vectors = jnp.load(f"{sim_folder}/vectors_halo1.npy")
ic(nu_vectors.shape)

nu_dens = Physics.number_densities_all_sky(
    v_arr=nu_vectors[..., 3:],
    m_arr=nu_m_range,
    pix_sr=simdata.pix_sr,
    args=Params())
ic(nu_dens.shape)

In [None]:
# Choose which halo to plot
halo_i = 1

# Load associated origID of chosen halo
halo_indices = jnp.load(
    f"{sim_folder}/halo_batch_0.6-2.0x1e+12.0_Msun_indices.npy")
haloID = halo_indices[halo_i-1]

# Choose which neutrino mass to plot
m_val = 0.3
m_idx = jnp.abs(nu_m_range - m_val).argmin()

# Load corresponding number densities
CNB_dens_1mass = nu_dens[m_idx]
ic(CNB_dens_1mass.shape)
ic(CNB_dens_1mass.min(), CNB_dens_1mass.max())

CNB_skymap, DM_halo_skymap = SimPlot.skymaps_CNBxDM(
    nu_mass=m_val, 
    CNB_dens=CNB_dens_1mass, 
    DM_halo_pos=jnp.load(f"Data/halo_grids/DM_pos_origID{haloID}_snap_0036.npy"), 
    init_xyz=jnp.load(f"{sim_folder}/init_xyz_halo{halo_i}.npy"), 
    fig_path=f"{fig_folder}/skymaps_CNBxDM_{sim_name}.pdf",
    args=Params())

## Phase-space distributions today

In [14]:
def most_clustered():
    ...
    # splits up momentum/velocity arrays into blocks?
def halo_parameters():
    ...
def escape_momentum():
    ...


def get_Deltas_halo_0th_axis(y_arr, m_num, Deltas_arr, args):

    ic(Deltas_arr.shape, m_num)

    # q-momenta at z=4
    q_z4 = y_arr[...,-1] / (1+4)  #? in terms of T_CNB(z=0) or T_CNB(z=4) ?
    ic(q_z4.shape)

    # Pixel indices for all neutrinos
    # (looks like [0,...,0,1,...,1,...,Npix-1,...,Npix-1])
    p_idx = jnp.repeat(jnp.arange(args.Npix), args.p_num)[None, :]
    ic(p_idx.shape)

    # Reshape, s.t. arrays have the right number of singleton dimensions
    Cl_qs = Primordial.Cl_qs
    new_shape = (1,) * jnp.ndim(q_z4) + Cl_qs.shape
    Cl_sync = jnp.reshape(Cl_qs, new_shape)
    ic(Cl_sync.shape)

    # Find indices to match neutrino momenta to Cl momenta
    q_idx = jnp.abs(Cl_sync - q_z4[..., None]).argmin(axis=-1)
    ic(q_idx.shape)
    q_idx = jnp.reshape(q_idx, (1, m_num, -1))
    ic(q_idx.shape)

    # Mass indices adjusted for broadcasting / fancy indexing of Delta matrix
    m_idx = jnp.arange(m_num)[:, None]
    ic(m_idx.shape)

    # Select corresponding pixels, i.e. temp. perturbations, for all neutrinos
    Deltas_fin = jnp.reshape(
        Deltas_arr[m_idx, q_idx, p_idx], (m_num, args.Npix, args.p_num))
    ic(Deltas_fin.shape)

    return Deltas_fin


def phase_space_2x2_FD_vs_PF(sim_dir, halo_num):

    # Load neutrino vectors for all halos
    halo_vectors = SimData.load_vectors(sim_dir=sim_dir, halo_num=halo_num)
    ic(halo_vectors.shape)

    # Convert velocities to momenta
    p_arr, y_arr = Physics.velocities_to_momenta_ND_halo_0th_axis(
        v_arr=halo_vectors[...,3:], 
        m_arr=nu_m_picks,
        args=Params())
    ic(y_arr.shape)
    del halo_vectors

    p_z0 = p_arr[...,0]
    p_z4 = p_arr[...,-1]
    y_z0 = y_arr[...,0]
    del p_arr

    # Sort in ascending order of momentum array today
    ind = p_z0.argsort(axis=-1)
    p_z4_sort = jnp.take_along_axis(p_z4, ind, axis=-1)
    y_z0_sort = jnp.take_along_axis(y_z0, ind, axis=-1)

    # Each velocity has a batch of neutrinos:
    # Take min. of each to represent most (likely) clustered ones.
    #? needs better explanation...
    p_z4_select = jnp.min(jnp.swapaxes(p_z4_sort, -1, -2), axis=-1)

    # y_arr has repeated elements, take first element to sync with p_arr
    y_z0_select = jnp.swapaxes(y_z0_sort, -1, -2)[..., 0]

    del p_z4_sort, y_z0_sort

    # Compute medians
    p_z4_median   = jnp.median(p_z4_select, axis=0)
    y_z0_median   = jnp.median(y_z0_select, axis=0)

    # Compute percentile ranges
    p_z4_perc2p5  = jnp.percentile(p_z4_select, q=2.5,  axis=0)
    p_z4_perc97p5 = jnp.percentile(p_z4_select, q=97.5, axis=0)
    p_z4_perc16   = jnp.percentile(p_z4_select, q=16,   axis=0)
    p_z4_perc84   = jnp.percentile(p_z4_select, q=84,   axis=0)

    # PSD using Fermi-Dirac assumption
    FD_median   = Physics.Fermi_Dirac(p_z4_median, Params())
    FD_perc2p5  = Physics.Fermi_Dirac(p_z4_perc2p5, Params())
    FD_perc97p5 = Physics.Fermi_Dirac(p_z4_perc97p5, Params())
    FD_perc16   = Physics.Fermi_Dirac(p_z4_perc16, Params())
    FD_perc84   = Physics.Fermi_Dirac(p_z4_perc84, Params())

    # PSD including primordial fluctuations
    Deltas = get_Deltas_halo_0th_axis(
        y_arr=y_arr[0, ...], 
        m_num=len(nu_m_picks),
        Deltas_arr=jnp.load(f"{Delta_folder}/Delta_matrix_z4.npy")[None, ...], 
        args=SimData(sim_folder))
    # PF_median   = Physics.Fermi_Dirac_Delta(p_z4_median, Deltas, Params())
    # PF_perc2p5  = Physics.Fermi_Dirac_Delta(p_z4_perc2p5, Deltas, Params())
    # PF_perc97p5 = Physics.Fermi_Dirac_Delta(p_z4_perc97p5, Deltas, Params())
    # PF_perc16   = Physics.Fermi_Dirac_Delta(p_z4_perc16, Deltas, Params())
    # PF_perc84   = Physics.Fermi_Dirac_Delta(p_z4_perc84, Deltas, Params())

    fig = plt.figure()
    for i in range(4):
        ax = plt.subplot(221 + i)

        FD_color = "mediumblue"
        PF_color = "orangered"

        # Median of all halos
        ax.plot(
            y_z0_median[i], FD_median[i], 
            color=FD_color, alpha=0.9, label="halos median")

        # Percentile ranges of all halos
        ax.fill_between(
            y_z0_median[i], FD_perc2p5[i], FD_perc97p5[i],
            color=FD_color, alpha=0.2, label=r"$2.5-97.5$ % C.L.")
        ax.fill_between(
            y_z0_median[i], FD_perc16[i], FD_perc84[i],
            color=FD_color, alpha=0.3, label=r"$16-84$ % C.L.")

        ax.set_title(fr"PSD for $m_\nu=${nu_m_picks[i]} eV")
        ax.set_ylim(1e-2, 1e0)
        ax.set_xlim(Params.p_start, 1e1)
        ax.set_xscale("log")
        ax.set_yscale("log")
        ax.set_xlabel(r"$p/T_{\nu,0}$")
        ax.set_ylabel(r"$f_{\oplus,\mathrm{today}}$")

        if i==3:
            ax.legend(
                loc='lower left', borderpad=0.5, labelspacing=0.5, 
                fontsize='medium', prop={"size":10})

        plt.tight_layout()

    plt.savefig(f"{fig_folder}/PSD_FD_vs_PF_no_gravity.png", bbox_inches="tight")
    plt.savefig(f"{fig_folder}/PSD_FD_vs_PF_no_gravity.pdf", bbox_inches="tight")
    plt.show(); plt.close()


phase_space_2x2_FD_vs_PF(sim_dir=sim_folder, halo_num=1)

ic| halo_vectors.shape: (1, 768, 1000, 2, 6)
ic| y_arr.shape: (1, 3, 768, 1000, 2)
ic| Deltas_arr.shape: (1, 3, 50, 768), m_num: 3
ic| q_z4.shape: (3, 768, 1000)
ic| p_idx.shape: (1, 768000)
ic| Cl_sync.shape: (1, 1, 1, 50)
ic| q_idx.shape: (3, 768, 1000)
ic| q_idx.shape: (1, 3, 768000)
ic| m_idx.shape: (3, 1)


In [None]:
Deltas_z4 = jnp.load(f"{Delta_folder}/Delta_matrix_z4.npy")
Deltas_z0 = jnp.load(f"{Delta_folder}/Delta_matrix_z0.npy")
ic(Deltas_z0.shape)

# Using Primordial Fluctuations (PF) PSD

## Temperature fluctuation skymaps from Cl

In [None]:
SimPlot.temp_deltas_skymaps(
    m_Cl=0.01, 
    q_select=Primordial.q_select, 
    Cl_qi=Primordial.Cl_qi, 
    Cl_dir=Cl_folder, 
    fig_dir=fig_folder,
    seed=5, 
    Nside=simdata.Nside,
    args=Params)

SimPlot.temp_deltas_skymaps(
    m_Cl=0.1, 
    q_select=Primordial.q_select, 
    Cl_qi=Primordial.Cl_qi, 
    Cl_dir=Cl_folder, 
    fig_dir=fig_folder,
    seed=5, 
    Nside=simdata.Nside,
    args=Params)

## Inspect number density files

In [None]:
# Fermi-Dirac distribution assumption
# note: total densities are computed for length 50 mass range
pix_dens_FD = jnp.load(f"{sim_folder}/pixel_densities.npy")
tot_dens_FD = jnp.load(f"{sim_folder}/total_densities.npy")
ic(pix_dens_FD.shape)
ic(tot_dens_FD.shape)
print("\n")

time.sleep(0.5)
# Incl. primordial fluctuations
# note: all sky densities are computed for 5 specific masses
pix_dens_PF = jnp.load(f"{sim_folder}/pixel_densities_incl_PFs.npy")
tot_dens_PF = jnp.load(f"{sim_folder}/total_densities_incl_PFs.npy")
ic(pix_dens_PF.shape)
ic(tot_dens_PF.shape)

# Deltas = jnp.load(f"{sim_folder}/Deltas_halos.npy")
# ic(Deltas.shape)
print("END")

## Comparison of FD & PF skymaps

In [None]:
# Choose a halo and mass to display
halo_i = 1  #note: only halo_i=1 for no_gravity (just a label, sim doens't contain any halos)
nu_m_picks = jnp.array([0.01, 0.05, 0.1, 0.2, 0.3])*Params.eV
m_range = jnp.load(f"{sim_folder}/neutrino_massrange_eV.npy")

m_pick = 0.1
mPF_i = jnp.abs(nu_m_picks - m_pick).argmin()
mFD_i = jnp.abs(m_range - m_pick).argmin()

N0_pix = Params.N0 / simdata.Npix
skymap_norm = N0_pix / (Params.cm**-3)

In [None]:
### ------------------------------------------- ###
### Plot parameters for Fermi-Dirac (FD) skymap ###
### ------------------------------------------- ###

FD_skymap = pix_dens_FD[halo_i, mFD_i, :] / skymap_norm
ic(FD_skymap.sum()*N0_pix/(Params.cm**-3))

# Make center value to be 1 (no overdensity).
mid_FD = 1.
divnorm_FD = mcolors.TwoSlopeNorm(vcenter=mid_FD)

# Plot style healpy dictionary for left skymap
dict_FD = dict(
    title=fr"Overdensities FD of $m_\nu={m_pick}$ eV at z=0", 
    unit=r"$n_{\nu, pix} / n_{\nu, pix, 0}$",
    cmap="coolwarm",
    override_plot_properties={"cbar_pad": 0.1},
    cbar_ticks=[np.min(FD_skymap), mid_FD, np.max(FD_skymap)],
    norm=divnorm_FD,
    sub=121)


### ------------------------------------------------------- ###
### Plot parameters for Primordial Fluctuations (FD) skymap ###
### ------------------------------------------------------- ###

PF_skymap = pix_dens_PF[halo_i, mPF_i, :] / skymap_norm
ic(PF_skymap.sum()*N0_pix/(Params.cm**-3))

# Make center value to be 1 (no overdensity).
mid_PF = 1.
divnorm_PF = mcolors.TwoSlopeNorm(vcenter=mid_PF)

# Plot style healpy dictionary for right skymap
dict_PF = dict(
    title=fr"Overdensities PF of $m_\nu={m_pick}$ eV at z=0", 
    unit=r"$n_{\nu, pix} / n_{\nu, pix, 0}$",
    cmap="coolwarm",
    override_plot_properties={"cbar_pad": 0.1},
    cbar_ticks=[np.min(PF_skymap), mid_PF, np.max(PF_skymap)],
    norm=divnorm_PF,
    sub=122)


### ------------ ###
### Plot skymaps ###
### ------------ ###

SimPlot.healpix_side_by_side(
    map_left=FD_skymap,
    map_right=PF_skymap,
    dict_left=dict_FD,
    dict_right=dict_PF,
    fig_path=fig_folder,
    args=Params)  

## Comparison of FD & PF power spectra

In [None]:
skymap_norm = N0_pix / (Params.cm**-3)
for m_pick in (0.01, 0.05, 0.1):
    SimPlot.power_spectra_FDxPF(
        m_pick=m_pick, 
        pix_dens_FD=pix_dens_FD / skymap_norm, 
        pix_dens_PF=pix_dens_PF / skymap_norm, 
        sim_dir=sim_folder, 
        fig_dir=fig_folder,
        args=Params)