In [1]:
import matplotlib.pyplot as plt

%matplotlib inline

import numpy as np
import healpy as hp
import pickle
import numba as nb
import spherical
import pywigxjpf
import pywigxjpf_ffi
import numba_progress
import math
import scipy
import time

from tqdm.notebook import tqdm
from spherical import Wigner3j
from scipy import stats
from numba import jit, njit, prange, set_num_threads
from numba_progress import ProgressBar
from numba.core.typing import cffi_utils as cffi_support

try:
    from pywigxjpf_ffi import ffi, lib
except ImportError:
    from pywigxjpf.pywigxjpf_ffi import ffi, lib

# Register the cffi-built module with Numba so that functions exposed on its
# `lib` object can be typed and called from jitted code.
cffi_support.register_module(pywigxjpf_ffi)

# Module-level alias for the Wigner 3j routine; the @njit functions in this
# notebook call the symbol through this name.
# NOTE(review): this uses the top-level `pywigxjpf_ffi` import, not the `lib`
# obtained from the try/except fallback above — confirm both resolve to the
# same module in the target environment.
nb_wig3jj = pywigxjpf_ffi.lib.wig3jj

from helper_funcs import *

In [2]:
# Allocate the WIGXJPF lookup table (max two_j = 100, symbols up to 9j) and the
# temporary workspace before evaluating any symbol.
lib.wig_table_init(100, 9)
lib.wig_temp_init(100)

# Sanity check: evaluate a handful of 3j / 6j / 9j symbols and print each value
# on its own line. All arguments are passed as two_j / two_m (twice the value).
sanity_values = (
    lib.wig3jj(5, 6, 7, 3, 4, -7),
    lib.wig6jj(6, 6, 6, 6, 6, 6),
    lib.wig9jj(6, 6, 6, 7, 7, 6, 7, 7, 8),
    lib.wig9jj(6, 6, 6, 7, 7, 6, 7, 7, 8),
    lib.wig6jj(6, 6, 6, 6, 6, int(4 + 0.5 * 4)),
)
for symbol_value in sanity_values:
    print(symbol_value)

0.2357022603955158
-0.07142857142857142
0.004661585414003767
0.004661585414003767
-0.07142857142857142


In [3]:
# Benchmark 1: brute-force sweep that also evaluates the many trivially-zero symbols.
@njit
def benchmark(jjmax):
    """Sum the 3j symbol over every (jj1, jj2, jj3, mm1, mm2, mm3) combination.

    Each jj runs over 0..jjmax and each mm over -jjmax..jjmax (inclusive), with
    no selection rules applied, so most evaluated symbols are zero.

    Returns a tuple ``(sum_of_symbols, number_of_calls)``.
    """
    total = 0.0
    n_calls = 0
    upper = jjmax + 1
    for two_j1 in range(upper):
        for two_j2 in range(upper):
            for two_j3 in range(upper):
                for two_m1 in range(-jjmax, upper):
                    for two_m2 in range(-jjmax, upper):
                        for two_m3 in range(-jjmax, upper):
                            total += nb_wig3jj(two_j1, two_j2, two_j3,
                                               two_m1, two_m2, two_m3)
                            n_calls += 1
    return (total, n_calls)

# Benchmark 2: same sum, but skipping combinations ruled out by the triangle rules.
@njit
def benchmark_opt(jjmax):
    """Sum the 3j symbol over combinations that pass basic selection rules.

    jj3 is restricted to |jj1 - jj2| .. min(jj1 + jj2, jjmax) in steps of 2,
    mm1 and mm2 step by 2 inside [-jj, jj], and mm3 is fixed to -mm1 - mm2.
    The symbol is only evaluated when |mm3| <= jjmax (note: the bound is jjmax,
    not jj3).

    Returns a tuple ``(sum_of_symbols, number_of_calls)``.
    """
    total = 0.0
    n_calls = 0
    for two_j1 in range(jjmax + 1):
        for two_j2 in range(jjmax + 1):
            two_j3_lo = abs(two_j1 - two_j2)
            two_j3_hi = min(two_j1 + two_j2, jjmax)
            for two_j3 in range(two_j3_lo, two_j3_hi + 1, 2):
                for two_m1 in range(-two_j1, two_j1 + 1, 2):
                    for two_m2 in range(-two_j2, two_j2 + 1, 2):
                        two_m3 = -two_m1 - two_m2
                        if abs(two_m3) > jjmax:
                            continue
                        total += nb_wig3jj(two_j1, two_j2, two_j3,
                                           two_m1, two_m2, two_m3)
                        n_calls += 1
    return (total, n_calls)

In [4]:
jjmax=10

# Run each benchmark five times and report the sum, elapsed time and time per
# call. The first repetition of each benchmark includes Numba's JIT compilation,
# so it is expected to be slower than the following ones.
# time.perf_counter() is used instead of time.time(): it is monotonic and has
# the highest available resolution, which matters for millisecond-scale runs.
for bench_id, bench_fn in ((1, benchmark), (2, benchmark_opt)):
    for i in range(0,5):
        start_time = time.perf_counter()
        wigsum, total_calls = bench_fn(jjmax)
        total_time = time.perf_counter()-start_time
        print("Benchmark %d for jjmax=%d, sum=%.10f, time=%.5fs, "
              "time/call=%4.0fns [%d calls]" %
              (bench_id, jjmax, wigsum, total_time,
               total_time/total_calls*1e9, total_calls))

# Release the WIGXJPF workspace and table; no 3j/6j/9j call is valid after this
# until the tables are initialised again.
lib.wig_temp_free()
lib.wig_table_free()

print('Done')

Benchmark 1 for jjmax=10, sum=3.5305263227, time=0.51836s, time/call=  42ns [12326391 calls]
Benchmark 1 for jjmax=10, sum=3.5305263227, time=0.32428s, time/call=  26ns [12326391 calls]
Benchmark 1 for jjmax=10, sum=3.5305263227, time=0.32342s, time/call=  26ns [12326391 calls]
Benchmark 1 for jjmax=10, sum=3.5305263227, time=0.32205s, time/call=  26ns [12326391 calls]
Benchmark 1 for jjmax=10, sum=3.5305263227, time=0.32293s, time/call=  26ns [12326391 calls]
Benchmark 2 for jjmax=10, sum=3.5305263227, time=0.27278s, time/call=17627ns [15475 calls]
Benchmark 2 for jjmax=10, sum=3.5305263227, time=0.00278s, time/call= 179ns [15475 calls]
Benchmark 2 for jjmax=10, sum=3.5305263227, time=0.00279s, time/call= 180ns [15475 calls]
Benchmark 2 for jjmax=10, sum=3.5305263227, time=0.00271s, time/call= 175ns [15475 calls]
Benchmark 2 for jjmax=10, sum=3.5305263227, time=0.00267s, time/call= 172ns [15475 calls]
Done


In [5]:
lmax = 1000

# Fixed seed so the synthetic realisation (and every number derived from it)
# is reproducible on Restart & Run All.
RNG_SEED = 0

ells = np.arange(lmax+1)

# Power-law spectrum C_l = l^-3 for l >= 1; the monopole (l = 0) is left at 0.
cls = np.zeros_like(ells, dtype='float')
cls[1:] = ells[1:].astype(float) ** -3.0

# healpy's synfast draws the a_lm from NumPy's legacy global random state,
# so seed it immediately before generating the map.
np.random.seed(RNG_SEED)
theory_map, alms = hp.sphtfunc.synfast(cls=cls, nside=1024, lmax=lmax, alm=True)

# Regroup the a_lm by multipole using the project helper.
# NOTE(review): the exact layout returned by sort_alms is defined in
# helper_funcs; later cells index it as sorted_alms[l][m].
sorted_alms = sort_alms(alms, len(cls))

In [25]:
@njit(parallel=True)
def compute_bispec_wig(l1, l2, l3, alms_l1, alms_l2, alms_l3, num_threads=16):
    """Accumulate a Wigner-3j weighted sum of a_lm triple products for (l1, l2, l3).

    Parameters
    ----------
    l1, l2, l3 : int
        Multipoles. Their sum must be even and they must satisfy the triangle
        inequality |l1 - l2| <= l3 <= l1 + l2 (both enforced by assertions).
    alms_l1, alms_l2, alms_l3 : indexable
        Coefficients for each multipole, indexed directly by the signed m
        (e.g. ``alms_l1[m1]`` with m1 in [-l1, l1]).
        NOTE(review): negative m relies on how the caller laid out these
        containers — confirm against sort_alms.
    num_threads : int, optional
        Passed to ``set_num_threads``; Numba rejects values larger than the
        number of threads it was launched with.

    Returns
    -------
    float
        ``sqrt(norm_factor) * sum_{m1, m2} w3j * |a1 * a2 * a3|`` with
        ``m3 = -(m1 + m2)``, or 0.0 when the normalisation factor vanishes.
    """
    assert (l1 + l2 + l3) % 2 == 0, "even parity not satisfied" # even parity
    assert np.abs(l1-l2) <= l3 <= l1+l2, "triangle inequality not satisfied" # triangle inequality

    # Float accumulator keeps the prange reduction type-stable.
    bispec_sum = 0.0
    # Table/workspace size in two_j units: 2 * (max(l) + 1).
    val_init = (max(l1, l2, l3) + 1) * 2

    set_num_threads(num_threads) # default of 16 was chosen for the "Roomba" machine

    # Global 3j table; must be released on every exit path below.
    lib.wig_table_init(val_init, 3)

    # WIGXJPF takes two_j / two_m arguments, hence the factors of 2.
    lib.wig_temp_init(val_init)
    norm_factor = ((l1*2+1) * (l2*2+1) * (l3*2+1))/(4*np.pi) \
                        * (nb_wig3jj(2*l1, 2*l2, 2*l3, 0, 0, 0))**2
    lib.wig_temp_free()

    if not norm_factor:
        # Previously this path returned without freeing the table, leaking it.
        lib.wig_table_free()
        return 0.0

    for m1 in prange(-l1, l1+1):
        # Temporary workspace is set up and torn down inside each parallel
        # iteration, on whichever thread executes it.
        lib.wig_temp_init(val_init)
        for m2 in range(-l2, l2+1):
            m3 = -(m1 + m2) # noting that m1 + m2 + m3 == 0
            w3j = nb_wig3jj(2*l1, 2*l2, 2*l3, 2*m1, 2*m2, 2*m3)
            # A zero symbol (including |m3| > l3) skips the a_lm lookup entirely.
            if w3j:
                exp_alms = alms_l1[m1] * alms_l2[m2] * alms_l3[m3]
                # NOTE(review): the magnitude of the triple product is summed,
                # not its signed/real value — confirm this is intended.
                bispec_sum += w3j * np.abs(exp_alms)
        lib.wig_temp_free()

    lib.wig_table_free()

    return np.sqrt(norm_factor) * bispec_sum

In [27]:
# Evaluate the estimator for a single multipole triplet and print the result.
l1, l2, l3 = 10, 20, 30

b_wig = compute_bispec_wig(
    l1, l2, l3,
    sorted_alms[l1], sorted_alms[l2], sorted_alms[l3],
)

print(b_wig)

9.850683559600588e-06
