# Wavelet transform (JAX-SSHT)
[![colab image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/astro-informatics/s2wav/blob/main/notebooks/jax_ssht_transform.ipynb)

In [1]:
# Install s2wav
# !pip install s2wav &> /dev/null

In [22]:
import s2wav
import numpy as np
import s2wav.filters as filters
L = 10
N = 1
f = np.ones((L, 2*L-1))
# Compute wavelet coefficients
filter = filters.filters_directional_vectorised(L, N)
f_wav, f_scal = s2wav.analysis(f, L, N, filters=filter, reality=True)

# Map back to signal on the sphere 
f_build = s2wav.synthesis(f_wav, f_scal, L, N, filters=filter, reality=True)
print(np.abs(f_build -f).mean())

1.1558006014255577e-15


In [28]:
import healpy as hp
import numpy as np

# Path to your Healpix map FITS file
healpix_map_path = 'data/planck_simulation/ffp10_newdust_total_030_full_map.fits'

# Step 1: Read the Healpix map
healpix_map = hp.read_map(healpix_map_path)

# Step 2: Convert Healpix map to spherical harmonic coefficients
L_max = hp.npix2nside(len(healpix_map)) * 2  # Define maximum multipole moment
alm = hp.map2alm(healpix_map, lmax=L_max-1)

# Step 3: Reconstruct the map from spherical harmonic coefficients
reconstructed_map = hp.alm2map(alm, nside=hp.npix2nside(len(healpix_map)))

# Step 4: Compare the original and reconstructed maps
difference = np.abs(reconstructed_map - healpix_map).mean()
print(f"Mean absolute difference: {difference}")

Mean absolute difference: 3.6772237618124e-05


In [27]:
import s2wav
import numpy as np
import s2wav.filters as filters
import healpy as hp
L = 10
N = 1
f = np.ones((L, 2*L-1))
f.shape
# print(hp.get_nside(f))
# sampling = "healpix"
# nside = hp.get_nside(f)
# # Compute wavelet coefficients


# filter = filters.filters_directional_vectorised(L, N)

# f_wav, f_scal = s2wav.analysis(f, L, N, filters=filter, reality=True, sampling = "healpix", nside = nside)

# # # Map back to signal on the sphere 
# # f_build = s2wav.synthesis(f_wav, f_scal, L, N, filters=filter, reality=True)
# print(np.abs(f_build -f).mean())

(10, 19)

In [19]:
# import s2wav
# import numpy as np
# L = 128
# N = 1
# f = np.ones((L, 2*L-1))
# f_wav, f_scal = s2wav.analysis(f, L, N)
# f = s2wav.synthesis(f_wav, f_scal, L, N)

TypeError: 'NoneType' object is not subscriptable

Lets start by importing some packages which we'll be using in this notebook

In [14]:
# Make sure we configure 64 bit precision. 
# 32 bit can be faster but you will be (potentially much) less precise.
import jax
jax.config.update("jax_enable_x64", True)

import s2wav       # Wavelet transforms on the sphere and rotation group
import s2fft       # Spherical harmonic and Wigner transforms
import numpy as np

Now we'll define the constraints of the problem and generated some random data just for this example

In [7]:
import jax
jax.config.update("jax_enable_x64", True)

import s2wav       # Wavelet transforms on the sphere and rotation group
import s2fft       # Spherical harmonic and Wigner transforms
import numpy as np
L = 8               # Spherical harmonic bandlimit
N = 3                # Azimuthal (directional) bandlimit

sampling = "healpix"
nside = 4
# Generate a random bandlimited signal to work with
rng = np.random.default_rng(12346161)
flm = s2fft.utils.signal_generator.generate_flm(rng, L)
f = s2fft.inverse(flm, L, nside = nside, sampling = sampling)

filter_bank = s2wav.filters.filters_directional_vectorised(L, N)
# wavelet_coeffs, scaling_coeffs = s2wav.analysis(f, L, N, filters=filter_bank, use_c_backend=use_c_backend)
wavelet_coeffs, scaling_coeffs = s2wav.analysis(f = f, L = L, N = N, filters=filter_bank, sampling = sampling, nside = nside)

AssertionError: 

In [15]:
L = 16               # Spherical harmonic bandlimit
N = 3                # Azimuthal (directional) bandlimit
# sampling = "mw"      # Sampling scheme
# use_c_backend = True # Switches backend JAX harmonic and Wigner transforms to call underlying SSHT C libraries.

sampling = "healpix"
nside = 4
# Generate a random bandlimited signal to work with
rng = np.random.default_rng(12346161)
flm = s2fft.utils.signal_generator.generate_flm(rng, L)
f = s2fft.inverse(flm, L, nside = nside, sampling = sampling)

In [7]:
print(f.shape)

(192,)


We can calculate the wavelet and scaling coefficients by first building a bank of wavelet filters and the running the analysis transform

In [17]:
sampling = "healpix"
nside = 4
filter_bank = s2wav.filters.filters_directional_vectorised(L, N)
# wavelet_coeffs, scaling_coeffs = s2wav.analysis(f, L, N, filters=filter_bank, use_c_backend=use_c_backend)
wavelet_coeffs, scaling_coeffs = s2wav.analysis(f, L, N, filters=filter_bank, sampling = sampling, nside = nside)

AssertionError: 

In [21]:
# filter_bank

You'll notice that this first pass is very slow. That's because it is JIT compiling the function, so future calls to `s2wav.analysis` will be much fater! When an exact sampling theorem is chosen we can recover the original signal to machine precision by running

In [22]:
f_check = s2wav.synthesis(wavelet_coeffs, scaling_coeffs, L, N, filters=filter_bank)
# f_check

Again this first call is quite slow, but subsequent calls should be much faster. Lets double check that we actually got machine precision!

In [43]:
print(f"Mean absolute error = {np.nanmean(np.abs(f_check - f))}")

Mean absolute error = 2.04106194888797e-14
