## Let's try numba data structures for bispec calculator

In [1]:
import numpy as np
import healpy as hp
import scipy
from numba import jit, njit, prange, set_num_threads
import numba.typed
import numba.types
from tqdm.notebook import tqdm
from helper_funcs import *

In [2]:
def sort_alms(alms, lmax):
    '''
    Sorts healpix alm's by \ell instead of m given a fortran90 array output from hp.map2alm.

    Parameters
    ----------
    alms : fortran90 healpix alm array from hp.map2alm
    num_ls : number of l's in alms array
    
    Returns
    ----------
    sorted_alms : alm dictionary keyed by ell values with numpy arrays 
                    consisting of the corresponding m values.
    '''
    start = 0
    sorted_alms = {}

    for l in range(lmax + 1):
        sorted_alms[l] = np.zeros(2*l+1, dtype=np.cdouble)

    for m in range(lmax + 1):
        num_ms = lmax + 1 - m
        ms = alms[start:num_ms + start]
        start = num_ms + start
        m_sign = (-1)**m
        for l in range(num_ms):
            idx = m + l
            sorted_alms[idx][m] = ms[l]
            if m != 0:
                sorted_alms[idx][-m] = m_sign * np.conj(ms[l])
    
    return sorted_alms

In [10]:
key_type = numba.types.int64
value_type = numba.types.complex128[:]
sorted_alms = numba.typed.Dict.empty(key_type, value_type)

In [11]:
sorted_alms

DictType[int64,array(complex128, 1d, A)]<iv=None>({})

In [28]:
@njit
def sort_alms_typed(alms, lmax):
    '''
    Sorts healpix alm's by \ell instead of m given a fortran90 array output from hp.map2alm.

    Parameters
    ----------
    alms : fortran90 healpix alm array from hp.map2alm
    num_ls : number of l's in alms array
    
    Returns
    ----------
    sorted_alms : alm dictionary keyed by ell values with numpy arrays 
                    consisting of the corresponding m values.
    '''
    start = 0
    key_type = numba.types.int64
    value_type = numba.types.complex128[:]
    sorted_alms = numba.typed.Dict.empty(key_type, value_type)

    for l in range(lmax + 1):
        sorted_alms[l] = np.zeros(2*l+1, dtype=np.complex128)

    for m in range(lmax + 1):
        num_ms = lmax + 1 - m
        ms = alms[start:num_ms + start]
        start = num_ms + start
        m_sign = (-1)**m
        for l in range(num_ms):
            idx = m + l
            sorted_alms[idx][m] = ms[l]
            if m != 0:
                sorted_alms[idx][-m] = m_sign * np.conj(ms[l])
    
    return sorted_alms

In [4]:
lmax = 1024
ells = np.arange(1, lmax+1)

cls = (ells + 0.0)**(-3.)

theory_map, alms = hp.sphtfunc.synfast(cls=cls, nside=1024, lmax=lmax, alm=True)

In [29]:
sorted_alms = sort_alms_typed(alms, lmax)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of <class 'numba.core.types.npytypes.Array'> with parameters (class(complex128), Literal[int](1), Literal[str](A))
No type info available for <class 'numba.core.types.npytypes.Array'> as a callable.
During: resolving callee type: typeref[<class 'numba.core.types.npytypes.Array'>]
During: typing of call at /tmp/ipykernel_11125/3732006500.py (18)


File "../../../../../../tmp/ipykernel_11125/3732006500.py", line 18:
<source missing, REPL/exec in use?>
