In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import healpy as hp
import scipy
from numba import jit, njit, prange, set_num_threads
from tqdm.notebook import tqdm
from helper_funcs import *

In [3]:
def sort_alms(alms, lmax):
    '''
    Sorts healpix alm's by \ell instead of m given a fortran90 array output from hp.map2alm.

    Parameters
    ----------
    alms : fortran90 healpix alm array from hp.map2alm
    num_ls : number of l's in alms array
    
    Returns
    ----------
    sorted_alms : alm dictionary keyed by ell values with numpy arrays consisting of the corresponding m values.
    '''
    start = 0
    sorted_alms = {}

    for l in range(lmax):
        sorted_alms[l] = np.zeros(2*l+1, dtype=np.cdouble)

    for m in range(lmax):
        num_ms = lmax - m
        ms = alms[start:num_ms + start]
        start = num_ms + start
        m_sign = (-1)**m
        for l in range(num_ms):
            idx = m + l
            sorted_alms[idx][m] = ms[l]
            if m != 0:
                sorted_alms[idx][-m] = m_sign * np.conj(ms[l])
    
    return sorted_alms

In [4]:
lmax = 1000
ells = np.arange(lmax+1)
cls = np.zeros_like(ells, dtype='float')

for l in ells[1:]:
    cls[l] = (l+0.0)**(-3.)

theory_map, alms = hp.sphtfunc.synfast(cls=cls, nside=1024, lmax=lmax, alm=True)

In [5]:
sorted_alms = sort_alms(alms, len(cls))

In [10]:
def unsort_alms(sorted_alms):
    """
    Unsorts a sorted_alm dictionary from sorted_alms into a Fortran90
    array that can be processed by HEALPix.

    Inputs:
        sorted_alms (dict[ndarray]) : sorted alms from sort_alms

    Returns:
        (ndarray) : unsorted alms for HEALPix
    """
    ells = np.array(list(sorted_alms.keys()))
    lmax = ells[-1]
    unsorted_alms = np.zeros(np.sum(ells + 1), dtype=np.complex128)

    start = 0
    for m in range(lmax + 1):
        for ell in range(m, lmax + 1):
            unsorted_alms[start + ell] = sorted_alms[ell][m]
        start += lmax - m
    
    return unsorted_alms


In [16]:
unsorted_alms = unsort_alms(sorted_alms)

We divide the $\ell$-range $\left[\ell_{\min }, \ell_{\max }\right]$ into subintervals denoted by $\Delta_i=\left[\ell_i, \ell_{i+1}-1\right]$ where $i=0, \ldots,\left(N_{\text {bins }}-1\right)$ and $\ell_{N_{\text {bins }}}=\ell_{\max }+1$, so that the filtered maps are:

\begin{equation*}
M_i^p(\Omega)=\sum_{\ell \in \Delta_i} \sum_{m=-\ell}^{+\ell} a_{\ell m}^p Y_{\ell m}(\hat{\Omega}) \tag{2.5}
\end{equation*}

and we use these instead of $M_{\ell}^p$ in the expression for the bispectrum (2.5). The binned bispectrum is:

\begin{equation*}
B_{i_1 i_2 i_3}^{p_1 p_2 p_3, \mathrm{obs}}=\frac{1}{\Xi_{i_1 i_2 i_3}} \int d \hat{\Omega} M_{i_1}^{p_1, \mathrm{obs}}(\hat{\Omega}) M_{i_2}^{p_2, \mathrm{obs}}(\hat{\Omega}) M_{i_3}^{p_3, \mathrm{obs}}(\hat{\Omega}) \tag{2.6}
\end{equation*}

where $\Xi_{i_1 i_2 i_3}$ is the number of $\ell$ triplets within the $\left(i_1, i_2, i_3\right)$ bin triplet satisfying the triangle inequality and parity condition selection rule. Because of this normalization factor, $B_{i_1 i_2 i_3}^{p_1 p_2 p_3}$ may be considered an average over all valid $B_{\ell_1 \ell_2 \ell_3}^{p_1 p_2 p_3}$ inside the bin triplet.

In [27]:
def filter_map_binned(sorted_alms, lmin, lmax, nbins):
    """
    Map alms that have been binned and filtered.
    """
    filtered_alms = {}
    for ells in range(lmax):
        filtered_alms[ells] = np.zeros(2*l+1, dtype=np.cdouble)

    i_bins = np.arange(lmin, lmax + 1, nbins)

    for ell in i_bins:
        filtered_alms[ell] = sorted_alms[ell]
        # for m in range(-ell, ell + 1):
        #     filtered_alms[ell][m] = sorted_alms[ell][m]
    
    return filtered_alms

In [28]:
filtered_alms = filter_map_binned(sorted_alms, 2, 10, 1)

In [58]:
uf_alms = unsort_alms(filtered_alms)

In [62]:
uf_map = hp.alm2map(uf_alms, nside=1024)

In [69]:
pixarea = get_pix_area(uf_map)

In [71]:
pixarea * uf_map

array([-8.44863413e-08, -8.53155894e-08, -8.54289511e-08, ...,
       -3.25285993e-07, -3.25766502e-07, -3.26367384e-07])

In [7]:
# @njit(parallel=True)
# to speed up, will jit local function to compute bispec within.
def compute_binned_bispec(i1, i2, i3, map_i1, map_i2, map_i3, num_threads=16):
    """
    Compute binned bispec

    Parameters:
    """
    xi = count_valid_configs(i1, i2, i3, num_threads=num_threads)
    assert xi != 0

    pixarea_i1, pixarea_i2, pixarea_i3 = get_pix_areas(map_i1, map_i2, map_i3)
    msize_i1, msize_i2, msize_i3 = get_map_sizes(map_i1, map_i2, map_i3)

    B_i = 0

    return 1 / xi * B_i


In [66]:
print(get_pix_area(theory_map))
print(hp.get_map_size(theory_map))

9.986854087797142e-07
12582912


In [68]:
def get_pix_area(map):
    return hp.nside2pixarea(hp.get_nside(map))

def get_pix_areas(map1, map2, map3):
    return get_pix_area(map1), get_pix_area(map2), get_pix_area(map3)

def get_map_sizes(map1, map2, map3):
    return hp.get_map_size(map1). hp.get_map_size(map2), hp.get_map_size(map3)

In [50]:
@njit(parallel=True)
def count_valid_configs(i1, i2, i3, num_threads=16):
    """
    Counts the number of valid ell-triplet configurations in a single bin.
    Validity is determined as satisfying the parity condition selection
    rule and the triangle inequality.

    Inputs:
        i1 (int) : bin1
        i2 (int) : bin2
        i3 (int) : bin3
        num_threads (int) : number of threads to parallelize on

    Returns:
        (int) : number of valid configurations of ell-triplets.
    """

    set_num_threads(num_threads)

    configs = 0

    for l1 in prange(i1 + 1):
        for l2 in range(i2 + 1):
            for l3 in range(i3 + 1):
                if check_valid_triangle(l1, l2, l3):
                    configs += 1
    
    return configs

In [55]:
@njit(parallel=True)
def find_valid_configs(i1, i2, i3, num_threads=16):
    """
    Counts the number of valid ell-triplet configurations in a single bin.
    Validity is determined as satisfying the parity condition selection
    rule and the triangle inequality.

    Inputs:
        i1 (int) : bin1
        i2 (int) : bin2
        i3 (int) : bin3
        num_threads (int) : number of threads to parallelize on

    Returns:
        (int) : number of valid configurations of ell-triplets.
    """

    set_num_threads(num_threads)

    num_configs = count_valid_configs(i1, i2, i3)
    valid_configs = np.zeros(num_configs)

    i = 0
    for l1 in prange(i1 + 1):
        for l2 in range(i2 + 1):
            for l3 in range(i3 + 1):
                if check_valid_triangle(l1, l2, l3):
                    valid_configs[i] = (l1, l2, l3)
                    i += 1

    return valid_configs

In [56]:
configs = find_valid_configs(100, 100, 100)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function setitem>) found for signature:
 
 >>> setitem(array(float64, 1d, C), int64, UniTuple(int64 x 3))
 
There are 16 candidate implementations:
      - Of which 16 did not match due to:
      Overload of function 'setitem': File: <numerous>: Line N/A.
        With argument(s): '(array(float64, 1d, C), int64, UniTuple(int64 x 3))':
       No match.

During: typing of setitem at /tmp/ipykernel_53247/1994168875.py (28)

File "../../../../../../tmp/ipykernel_53247/1994168875.py", line 28:
<source missing, REPL/exec in use?>


In [31]:
@njit
def check_valid_triangle(l1, l2, l3):
    """
    Checks if a set of l1, l2, l3's form a valid triangle, which can be
    confirmed if the l's satisfy the following three metrics:
    1. bispectrum is symmetric under any l1, l2, l3 permutations
    2. even parity selection rule
    3. triangle inequality

    Inputs:
        l1 (int) : l1
        l2 (int) : l2
        l3 (int) : l3
    
    Returns:
        (bool) : returns True if triangle is valid, False otherwise
    """
    permutations = l1 <= l2 <= l3
    even_parity = (l1 + l2 + l3) % 2 == 0
    tri_inequality = np.abs(l1 - l2) <= l3 <= l1 + l2

    return permutations and even_parity and tri_inequality


In [39]:
# def P(l, m, x):
# 	pmm = 1.0
# 	if m > 0:
# 		somx2 = np.sqrt((1.0 - x) * (1.0 + x))
# 		fact = 1.0
# 		for i in range(1, m + 1):
# 			pmm *= (-fact) * somx2
# 			fact += 2.0
	
# 	if l == m:
# 		return pmm * np.ones(x.shape)
	
# 	pmmp1 = x * (2.0*m+1.0) * pmm
	
# 	if l == m + 1:
# 		return pmmp1
	
# 	pll = np.zeros(x.shape)
# 	for ll in range(m + 2, l + 1):
# 		pll = ((2.0 * ll - 1.0) * x * pmmp1 - (ll + m - 1.0) * pmm) / (ll - m)
# 		pmm = pmmp1
# 		pmmp1 = pll
	
# 	return pll

In [None]:
# def ylm_norm_factor(l, m):
#     """
#     Finds the normalization factor for the spherical harmonic function
#     a given ell and m.

#     Inputs:
#     l : (int) ell-value
#     m : (int) m-value

#     Returns:
#     (int) : normalization factor for spherical harmonic function
#     """
#     return np.sqrt((2 * l + 1)/(4 * np.pi) \
#         * ((math.factorial(l - m))/(math.factorial(l + m))))

# def ylm(omega, l, m):
#     """
#     The spherical harmonic functions.

#     Inputs:
#     omega : (var) solid-angle omega
#     l : (int) ell-value
#     m : (int) m-value

#     Returns:
#     (func) : spherical harmonic function for a given ell, m input.
#     """