In [None]:
# Enable 64-bit (double) precision in JAX. This is done first, before any
# other JAX work, so that arrays created later default to float64/complex128.
# `from jax import config` works across JAX versions; the older
# `from jax.config import config` form has been removed from recent releases.
from jax import config
config.update("jax_enable_x64", True)

# Check we're running on GPU: print the default JAX backend ("cpu", "gpu", ...).
# `jax.default_backend()` replaces the deprecated `jax.lib.xla_bridge` API.
import jax
print(jax.default_backend())

import numpy as np 
import jax.numpy as jnp
from jax import device_put, local_device_count
from s2wav.transforms import jax_scattering, jax_wavelets
from s2wav.filter_factory import filters as filter_generator
import s2fft
import pyssht as ssht

In [None]:
# Analysis configuration shared by every later cell.
L: int = 256                    # harmonic bandlimit of the working field
N: int = 3                      # directional parameter passed to the wavelet filters
J_min: int = 0                  # presumably the smallest wavelet scale index — confirm in s2wav
lam: float = 2.0                # presumably the wavelet dilation factor lambda — confirm in s2wav
reality: bool = True            # fields are treated as real-valued
sampling: str = "mw"            # MW sampling; maps are (L, 2L-1) grids
multiresolution: bool = True    # forwarded to the Wigner precompute generator

In [None]:
# Load the target map, analyse it at its native bandlimit, truncate to the
# working bandlimit L, and standardise it to zero mean and unit variance.
L_data = 256  # bandlimit of the stored map (matches "L_256" in the file name)
assert L <= L_data, f"L={L} exceeds the bandlimit of the stored data ({L_data})"

f = np.load('data/CosmoML_shell_40_L_256.npy')
flm = ssht.forward(f, L_data)
flm = flm[:L**2]            # keep the first L**2 coefficients, i.e. ell < L in ssht ordering
f = ssht.inverse(flm, L)
f -= np.nanmean(f)
f /= np.nanstd(f)
flm = s2fft.forward_jax(f,L)  # target coefficients in s2fft's layout

In [None]:
# Directional wavelet filters plus the Wigner precomputations reused by
# every scattering-transform call in this notebook.
filters = filter_generator.filters_directional_vectorised(L, N, J_min, lam)
precomps = jax_wavelets.generate_wigner_precomputes(
    L,
    N,
    J_min,
    lam,
    sampling,
    None,   # presumably nside (only meaningful for HEALPix sampling) — TODO confirm
    False,  # presumably the `forward` flag — TODO confirm against the s2wav signature
    reality,
    multiresolution,
)

In [None]:
# Target scattering coefficients of the input field; the optimisation tries
# to reproduce these from a random start.
coeffs = jax_scattering.scatter_new(flm, L, N, reality, filters, precomps)
print(coeffs.shape)
# Per-ell power of the input (sum over the last axis of |flm|^2). It is only
# referenced by an optional, disabled power-spectrum term in the loss.
ps_ref = (jnp.abs(flm) ** 2).sum(axis=-1)

In [None]:
from jax import grad, jit


@jit
def scattering_func(x):
    """Squared L2 misfit between the scattering coefficients of ``x`` and ``coeffs``.

    ``x`` holds harmonic coefficients in the same layout as ``flm``. The target
    ``coeffs`` (and L, N, reality, filters, precomps) are read from the notebook's
    global scope; under ``@jit`` they are baked in when the function is traced.
    Returns a real scalar.
    """
    residual = jax_scattering.scatter_new(x, L, N, reality, filters, precomps) - coeffs
    # Optional extra term that also matches the power spectrum (disabled):
    # ps_iter = jnp.sum(jnp.abs(x)**2,axis=-1)
    # loss += jnp.sum((jnp.abs(ps_iter-ps_ref))**2)
    return jnp.sum(jnp.square(jnp.abs(residual)))


# Compiled gradient of the loss with respect to the (complex) coefficients.
grad_func = jit(grad(scattering_func))

In [None]:
# Initial guess: a white-noise map on the (L, 2L-1) sampling grid, moved to
# harmonic space. Seeded so that Restart & Run All reproduces the same run.
seed = 0
rng = np.random.default_rng(seed)
f_temp = rng.standard_normal((L, 2*L-1))
flm_temp = s2fft.forward_jax(f_temp,L)
flm_start = jnp.copy(flm_temp)      # fixed reference copy of the starting point
E0 = scattering_func(flm_start)     # loss at the initial guess

In [None]:
import optax

# Adam on the complex harmonic coefficients. The run always restarts from
# flm_start, so re-executing this cell reproduces the same trajectory instead
# of silently continuing from the previous result.
lr = 1e-2
n_iter = 1000     # number of Adam steps
log_every = 10    # print the loss every `log_every` steps

flm_temp = flm_start
optimizer = optax.adam(lr)
opt_state = optimizer.init(flm_temp)

for i in range(n_iter):
    # For a real loss of complex inputs, jax.grad returns the conjugate of the
    # descent direction (JAX autodiff convention), hence jnp.conj.
    grads = jnp.conj(grad_func(flm_temp))
    updates, opt_state = optimizer.update(grads, opt_state)
    flm_temp = optax.apply_updates(flm_temp, updates)
    if i % log_every == 0:
        print(f"Iteration {i}: Loss/Loss-0 = {scattering_func(flm_temp)}/{E0}")

In [None]:
# Scattering coefficients of the starting guess and of the optimised field,
# computed in that order, for comparison against the target `coeffs`.
start_coeffs, optimised_coeffs = (
    jax_scattering.scatter_new(flm_x, L, N, reality, filters, precomps)
    for flm_x in (flm_start, flm_temp)
)

In [None]:
# Short aliases for the printout in the next cell:
# c1 = target, c2 = initial guess, c3 = optimised result.
c1, c2, c3 = coeffs, start_coeffs, optimised_coeffs

In [None]:
# One line per coefficient index: target, initial, optimised.
for i, target_c in enumerate(c1):
    print(target_c, c2[i], c3[i])

In [None]:
from matplotlib import pyplot as plt

# Real-space maps of the optimised result, the random start and the input.
f_temp = np.real(s2fft.inverse_jax(flm_temp, L, reality=reality))
f_start = np.real(s2fft.inverse_jax(flm_start, L, reality=reality))
f_input = np.real(f)   # new name: don't rebind the loaded map `f` in place

# Shared colour limits taken from the input map so the panels are comparable.
mx, mn = np.nanmax(f_input), np.nanmin(f_input)
# Modest size/dpi: 40x20 in at dpi=200 would embed an 8000x4000 px image.
fig, axes = plt.subplots(1, 3, figsize=(18, 4), dpi=100)
panels = [(f_input, "Input"), (f_start, "Initial guess"), (f_temp, "Optimised")]
for ax, (img, title) in zip(axes, panels):
    ax.imshow(img, vmax=mx, vmin=mn, cmap='magma')
    ax.set_title(title)
plt.show()

In [None]:
def power_spectrum(flm, normalise=False):
    """Angular power per multipole of harmonic coefficients.

    Args:
        flm: Harmonic coefficients, possibly complex, whose last axis runs over
            m for each ell (presumably the (L, 2L-1) array returned by
            ``s2fft.forward_jax`` — confirm for other layouts).
        normalise: If False (default), return sum_m |f_lm|^2 for each ell,
            exactly as before. If True, divide by (2*ell + 1) to obtain the
            conventional C_ell; this assumes the second-to-last axis is indexed
            by ell = 0, 1, 2, ...

    Returns:
        Array of shape ``flm.shape[:-1]`` holding the (optionally normalised) power.
    """
    power = jnp.sum(jnp.abs(flm)**2, axis=-1)
    if normalise:
        ell = jnp.arange(flm.shape[-2])
        power = power / (2 * ell + 1)
    return power
ps = power_spectrum(flm)
ps_temp = power_spectrum(flm_temp)
ps_start = power_spectrum(flm_start)

# Per-ell power of the input, the optimised result and the initial guess.
fig, ax = plt.subplots()
ax.plot(ps, label="input")
ax.plot(ps_temp, label="converged")
ax.plot(ps_start, label="initial")
ax.set_xlabel(r"$\ell$")
ax.set_ylabel(r"$\sum_m |f_{\ell m}|^2$")
ax.set_title("Power per multipole")
ax.legend()
plt.show()