In [None]:
# Enable 64-bit precision. This must run before any JAX arrays are created,
# otherwise JAX defaults to 32-bit floats.
# (`from jax.config import config` and `jax.lib.xla_bridge` are deprecated /
# removed in recent JAX releases; the top-level APIs below work across versions.)
import jax
jax.config.update("jax_enable_x64", True)

# Report which backend JAX is using (this notebook expects "gpu")
print(jax.default_backend())

import numpy as np 
import jax.numpy as jnp
from jax import device_put, local_device_count
from s2wav.transforms import jax_scattering, jax_wavelets
from s2wav.filter_factory import filters as filter_generator
import s2fft
import pyssht as ssht

In [None]:
# Band-limit and number of wavelet directions
L, N = 16, 1
# Depth of the scattering network
nlayers = 2
# Wavelet scale parameters passed to the filter generator and scatter calls
J_min, lam = 0, 2.0
# Flags forwarded to ssht / s2wav (reality=True: real-valued signals)
reality = True
sampling = "mw"
multiresolution = True
spmd = False  # presumably toggles multi-device execution in scatter — confirm

In [None]:
# Target field: white noise band-limited to L by an ssht forward/inverse
# round trip, then centred and scaled so that max |f| == 1.
# Seeded so that Restart & Run All reproduces the same target.
rng_target = np.random.default_rng(0)
noise = rng_target.standard_normal((L, 2 * L - 1))
f = ssht.inverse(ssht.forward(noise, L, Reality=reality), L, Reality=reality)
f = f - np.nanmean(f)
f = f / np.nanmax(np.abs(f))

In [None]:
# Directional wavelet filters shared by every scatter call in this notebook
filters = filter_generator.filters_directional_vectorised(L, N, J_min, lam)

# Wigner precomputes for the wavelet transform.
# NOTE(review): `precomps` is never passed to any scatter call below; either
# forward it (if scatter accepts precomputes) or drop this line to save time.
precomps = jax_wavelets.generate_wigner_precomputes(
    L, N, J_min, lam, sampling, None, False, reality, multiresolution
)

In [None]:
# Scattering coefficients of the target field; the synthesis below tries to
# reproduce these from a random starting field.
coeffs = jax_scattering.scatter(
    f=jnp.array(f),
    L=L, N=N, J_min=J_min, lam=lam, nlayers=nlayers,
    reality=reality, multiresolution=multiresolution,
    filters=filters, spmd=spmd,
)

# Quick look: the [:, 0, 0] slice of the coefficients, then the full shape
print(coeffs[:, 0, 0])
print(coeffs.shape)

In [None]:
from jax import grad

def mse_loss(y):
    """Mean squared error between `y` and the target coefficients `coeffs`.

    `coeffs` is the notebook-level array of target scattering coefficients; it
    is looked up when the function is called, not when it is defined. The mean
    is taken over the number of elements of `y`.
    """
    residual = jnp.abs(y - coeffs)
    return jnp.sum(residual ** 2) * (1.0 / y.size)

def power_spectrum(flm, L, normalise=False):
    """Per-degree power of spherical-harmonic coefficients.

    Parameters
    ----------
    flm : array_like
        Harmonic coefficients in 2D layout. Power is summed over the last axis,
        presumably the m axis of the ``flm_1d_to_2d`` layout used in this notebook.
    L : int
        Band-limit. Currently unused; kept so existing calls keep working.
    normalise : bool, optional
        If True, divide each degree's summed power by ``2*ell + 1``, where
        ``ell`` is the index along the last axis of the result. This gives the
        conventional angular power spectrum C_ell. Default False (raw sum, the
        original behaviour).

    Returns
    -------
    np.ndarray
        Real array with ``flm``'s shape minus its last axis.
    """
    power = np.sum(np.abs(flm) ** 2, axis=-1)
    if normalise:
        ell = np.arange(power.shape[-1])
        power = power / (2 * ell + 1)
    return power

# ps_true = power_spectrum(s2fft.forward(np.array(f), L, 0, reality=reality), L)
# def ps_loss(x):
#     z = s2fft.forward_jax(x, L, 0,reality=reality)
#     ps = jnp.sum(jnp.abs(z)**2, axis=-1)
#     return (1./ps.size)*jnp.sum((ps-ps_true)**2)
    
def scattering_func(x):
    """Loss between the scattering coefficients of `x` and the target `coeffs`.

    `x` is converted to a JAX array, transformed with `jax_scattering.scatter`
    using the notebook-level settings, and scored with `mse_loss`. This is the
    objective that `grad` is applied to.
    """
    scatter_settings = dict(
        L=L, N=N, J_min=J_min, lam=lam, nlayers=nlayers,
        reality=reality, multiresolution=multiresolution,
        filters=filters, spmd=spmd,
    )
    x_coeffs = jax_scattering.scatter(jnp.array(x), **scatter_settings)
    # Alternative objective with a power-spectrum penalty:
    # return mse_loss(x_coeffs) + ps_loss(x)
    return mse_loss(x_coeffs)

# Gradient of the scattering loss with respect to the input field
grad_func = grad(scattering_func)

# Random starting field for the synthesis, seeded so the run is reproducible
f_temp = np.random.default_rng(1).standard_normal((L, 2 * L - 1))
# Sanity check: evaluate the gradient once before the long optimisation loop
print(grad_func(f_temp))
# Untouched copy of the starting field, used for later comparison
f_start = np.copy(f_temp)

In [None]:
from jax import value_and_grad

# Plain gradient descent on the scattering loss. Despite its name, `momentum`
# is the fixed step size: no momentum term is accumulated.
momentum = 100
max_iterations = 1000000  # effectively "run until interrupted"
print_every = 10

# value_and_grad returns the loss together with its gradient, so logging the
# energy needs no extra forward pass through the scattering transform.
loss_and_grad = value_and_grad(scattering_func)
E0 = scattering_func(f_start)

i = 0
try:
    for i in range(max_iterations):
        energy, g = loss_and_grad(f_temp)
        f_temp -= momentum * g
        if i % print_every == 0:
            # `energy` is the loss *before* this iteration's update
            print(f"Iteration {i}: Energy/E0 = {energy / E0} (E = {energy}, E0 = {E0}), Momentum = {momentum}")
except KeyboardInterrupt:
    # Interrupting the kernel ends the descent cleanly; f_temp keeps the
    # latest iterate so the following cells can still run.
    print(f"Stopped by user at iteration {i}")

In [None]:
def _scatter_with_notebook_config(field):
    """Run jax_scattering.scatter on `field` with this notebook's settings."""
    return jax_scattering.scatter(
        f=field,
        L=L, N=N, J_min=J_min, lam=lam, nlayers=nlayers,
        reality=reality, multiresolution=multiresolution,
        filters=filters, spmd=spmd,
    )

# Coefficients of the initial guess and of the optimised field, for
# comparison against the target `coeffs`.
start_coeffs = _scatter_with_notebook_config(f_start)
optimised_coeffs = _scatter_with_notebook_config(f_temp)

In [None]:
# [:, 0, 0] slice of each coefficient set:
# c1 = target, c2 = initial guess, c3 = optimised result
c1, c2, c3 = (arr[:, 0, 0] for arr in (coeffs, start_coeffs, optimised_coeffs))


In [None]:
# One row per coefficient: target, initial, optimised
for row in zip(c1, c2, c3):
    print(*row)

In [None]:
from matplotlib import pyplot as plt

# Project the optimised and initial fields onto band-limit L (ssht
# forward/inverse round trip) so they are comparable with the band-limited
# target. NOTE: this reassigns f_temp; later cells use this band-limited version.
f_temp = ssht.inverse(ssht.forward(np.array(f_temp), L, Reality=reality), L, Reality=reality)
f_start2 = ssht.inverse(ssht.forward(np.array(f_start), L, Reality=reality), L, Reality=reality)

# Use the target's value range for all panels so intensities are comparable
mx, mn = np.nanmax(f), np.nanmin(f)
field_panels = [("Target field", f), ("Initial field", f_start2), ("Optimised field", f_temp)]
fig, axes = plt.subplots(1, 3, figsize=(20, 10))
for ax, (title, field) in zip(axes, field_panels):
    im = ax.imshow(field, vmax=mx, vmin=mn, cmap='magma')
    ax.set_title(title)
fig.colorbar(im, ax=axes, shrink=0.5)
plt.show()

In [None]:
from matplotlib import pyplot as plt

# Harmonic coefficients of the target, optimised and initial fields, rearranged
# by s2fft into 2D arrays (presumably rows = ell, columns = m).
flm = s2fft.sampling.s2_samples.flm_1d_to_2d(ssht.forward(f, L, Reality=reality), L)
flm_temp = s2fft.sampling.s2_samples.flm_1d_to_2d(ssht.forward(np.array(f_temp), L, Reality=reality), L)
flm_start = s2fft.sampling.s2_samples.flm_1d_to_2d(ssht.forward(np.array(f_start2), L, Reality=reality), L)

# Keep the coefficients complex. Overwriting them with np.real() would make any
# power spectrum computed from them ignore the imaginary parts. Only the real
# parts are displayed here.
flm_panels = [("Target", flm), ("Initial", flm_start), ("Optimised", flm_temp)]
mx, mn = np.nanmax(np.real(flm)), np.nanmin(np.real(flm))
fig, axes = plt.subplots(1, 3, figsize=(20, 10))
for ax, (title, flm_2d) in zip(axes, flm_panels):
    ax.imshow(np.real(flm_2d), vmax=mx, vmin=mn, cmap='magma')
    ax.set_title(f"Re(flm): {title}")
plt.show()

In [None]:
# Per-degree power of the target, optimised and initial fields
ps = power_spectrum(flm, L)
ps_temp = power_spectrum(flm_temp, L)
ps_start = power_spectrum(flm_start, L)

fig, ax = plt.subplots()
ax.plot(ps, label="input")
ax.plot(ps_temp, label="converged")
ax.plot(ps_start, label="initial")
# Log axis so differences at low power remain visible
ax.set_yscale("log")
ax.set(xlabel=r"degree $\ell$", ylabel="power per degree", title="Power spectrum comparison")
ax.legend()
plt.show()