In [None]:
import os

import healpy as hp
import jax
import jax.numpy as jnp
import numpy as np
import s2wav
import s2wav.filters as filters
# Run JAX in double precision so the s2wav transforms below use 64-bit floats.
jax.config.update("jax_enable_x64", True)

# Step 1: Read the HEALPix map (a 1-D array of pixel values).
healpix_map_path = 'data/planck_simulation/ffp10_newdust_total_030_full_map.fits'
# healpy is NumPy-based, so request a NumPy dtype rather than a JAX one.
healpix_map = hp.read_map(healpix_map_path, dtype=np.float64)

# Step 2: Convert the HEALPix map to spherical-harmonic coefficients.
# L_max is the *band-limit* (exclusive): multipoles run over ell = 0 .. L_max - 1,
# so the largest multipole passed to healpy is L_max - 1.
# 2 * nside is the choice made here; healpy's own default would be lmax = 3 * nside - 1.
nside = hp.npix2nside(len(healpix_map))
L_max = 2 * nside
alm = hp.map2alm(healpix_map, lmax=L_max - 1)  # healpy packed ordering, m >= 0 only

# Step 3: Rearrange healpy's packed alm (which stores m >= 0 only) into the dense
# 2-D layout f[ell, L_max - 1 + m] of shape (L_max, 2 * L_max - 1).
# Entries with |m| > ell are never written and stay 0.
# Negative-m coefficients come from the symmetry of a real signal:
#     a_{ell, -m} = (-1)^m * conj(a_{ell, m}),
# which applies because the HEALPix map read in Step 1 is real-valued.
# The loop runs over m only and fills every ell >= m at once with array indexing.
# This avoids one Python iteration (and one getidx call) per (ell, m) pair.
lmax = L_max - 1
f = np.zeros((L_max, 2 * L_max - 1), dtype=np.complex128)

for m in range(L_max):
    ells = np.arange(m, L_max)
    alm_m = alm[hp.Alm.getidx(lmax, ells, m)]  # a_{ell, m} for ell = m .. lmax
    f[ells, L_max - 1 + m] = alm_m
    if m > 0:
        f[ells, L_max - 1 - m] = (-1) ** m * np.conj(alm_m)

# Step 4: Perform the directional wavelet analysis with s2wav.
N = 1  # Number of directional wavelets
# Precompute the wavelet/scaling filters for band-limit L_max and N directions.
# The same `filter` object is passed to s2wav.synthesis in a later cell.
# NOTE(review): this name shadows Python's built-in `filter`; consider renaming it
# (together with every cell that uses it).
filter = filters.filters_directional_vectorised(L_max, N)

# f_wav: wavelet coefficients (a list of arrays, see Step 5); f_scal: scaling coefficients.
# NOTE(review): s2wav.analysis documents its first argument as a signal sampled on
# the sphere (for the default MW sampling, shape [L, 2L-1]). `f` here holds harmonic
# coefficients that only happen to share that shape. Confirm against the s2wav docs.
# If so, pass a pixel-space signal instead: either the HEALPix map with
# sampling="healpix" and nside, or an MW map synthesised from `f`.
f_wav, f_scal = s2wav.analysis(f, L_max, N, filters=filter)

# Step 5: Store the wavelet and scaling coefficients for future use.
# f_wav is a list of JAX arrays (jaxlib.xla_extension.ArrayImpl); f_scal is a single JAX array.
# Save and reload from the SAME relative directory, so the notebook works on any machine.
# The reload previously used a hard-coded absolute /Users/... path.
output_dir = os.path.join("convolution", "30")
os.makedirs(output_dir, exist_ok=True)  # np.save does not create missing directories

for i, wav in enumerate(f_wav):
    np.save(os.path.join(output_dir, f"f_wav_{i}.npy"), np.array(wav))  # JAX -> NumPy

np.save(os.path.join(output_dir, "f_scal.npy"), np.array(f_scal))  # JAX -> NumPy

# Reload exactly as many wavelet arrays as were written (previously hard-coded to 12).
# The arrays are numeric, so pickle support (allow_pickle) is not needed.
Stored_f_wav = [np.load(os.path.join(output_dir, f"f_wav_{i}.npy")) for i in range(len(f_wav))]
stored_f_scal = np.load(os.path.join(output_dir, "f_scal.npy"))
Sotred_f_sacl = stored_f_scal  # original (misspelled) name, kept in case other cells use it
# Step 6: reconstruct the signal if needed


In [77]:
# f_wav_1 = f_wav.copy()

# for i, wav in enumerate(f_wav_1):
#     wav = np.array(wav)  # Convert JAX array to numpy array
#     # np.save(f"/convolution/30/f_wav_{i}.npy", np_wav)  # Save numpy array to a file


# for i in f_wav_1:
#     print(i.shape)


In [30]:
# Invert the wavelet transform with the same band-limit, N and filters used in the analysis.
# f_check is expected to reproduce the analysis input `f` up to numerical error.
f_check = s2wav.synthesis(f_wav, f_scal, L_max, N, filters=filter)


In [88]:
# Step 7: Store the output of s2wav.synthesis and reload it to confirm the file round-trips.
# Uses a relative output directory instead of an absolute, user-specific path.
output_dir = os.path.join("convolution", "30")
os.makedirs(output_dir, exist_ok=True)  # np.save does not create missing directories

reconstructed_map = np.array(f_check)  # JAX -> NumPy
reconstructed_map_path = os.path.join(output_dir, "reconstructed_map.npy")
np.save(reconstructed_map_path, reconstructed_map)

load_map = np.load(reconstructed_map_path)
assert np.array_equal(load_map, reconstructed_map, equal_nan=True), "saved reconstruction did not round-trip"

In [39]:
# Round-trip error of analysis -> synthesis, ignoring any NaN entries.
mean_abs_error = np.nanmean(np.abs(f_check - f))
print(f"Mean absolute error = {mean_abs_error}")

Mean absolute error = 2.0528992438406644e-07


In [89]:
# is_real = np.isreal(f)
# print(is_real.all()) 
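#  Setting reality=True in the context of spherical harmonic analysis indicates that the input data is real-valued, and the function leverages the conjugate symmetry property of the harmonic coefficients to optimize storage and computation.
#  Note: `f` holds complex harmonic coefficients, so `np.isreal(f).all()` is False (see the output below). reality=True only applies when the analysis input itself is real-valued, e.g. the pixel-space map, not `f`.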

False
