In [None]:
# Specify CUDA device
import os, pickle
import time
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

import jax, optax, jaxopt
jax.config.update("jax_enable_x64", True)

from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

import jax.numpy as jnp
import numpy as np 
from matplotlib import pyplot as plt 

import s2scat 
import s2fft
import s2wav
from s2fft.sampling import s2_samples as samples

In [None]:
L = 512
N = 3
J_min = 0
# The maps are real-valued; compute_ps / generate_grf below work with m >= 0 coefficients
# and make_flm_full is used to rebuild the full set, so the real-signal code path is used.
reality = True
recursive = False
delta_j = None
isotropic = False

cmap = "magma"

map_type = 'weak_lensing'
f = np.load('texture_maps/{}/CosmoML_f_{}.npy'.format(map_type, L))
flm = np.load('texture_maps/{}/CosmoML_flm_{}.npy'.format(map_type, L))

# Target scattering statistics and their normalisation, generated offline.
# NOTE: pickle.load can execute arbitrary code — only load trusted, locally generated files.
# (`norm` is recomputed from `flm` in the next cell.)
file = 'texture_maps/{}/targets_{}.pickle'.format(map_type, L)
with open(file, 'rb') as fh:
    targets, norm = pickle.load(fh)

In [None]:
# Directional wavelet filters (first element of the returned tuple).
wavelets = s2wav.filters.filters_directional_vectorised(L, N)[0]
# NOTE(review): `recursive` is False in the config cell, yet the *recursive* matrix generator
# is used here — confirm this is the intended operator set for the non-recursive code path.
matrices = s2scat.operators.matrices.generate_recursive_matrices(L, N, J_min, reality)
quads = s2scat.operators.spherical.quadrature(L, J_min)
norm = s2scat.utility.normalisation.compute_norm(flm, L, N, J_min, reality, wavelets, matrices, recursive)
# Per-ell filter weights for the extra P00' constraint: wavelets with a narrower dilation
# (lam = sqrt(2.5)), dropping the last two scales and taking index L-1 of the last axis
# (presumably the m = 0 column — TODO confirm). Result shape: [Nfilters, L].
wavelets_linear = s2wav.filters.filters_directional_vectorised(
    L=L, N=N, J_min=J_min, lam=np.sqrt(2.5)
)[0][J_min:-2, :, L-1]

In [None]:
@jax.jit
def get_P00prime(flm, filter_lin, normalisation):
    """Filtered power per ell and its ell-average (the P00' statistic).

    Args:
        flm: harmonic coefficients indexed [ell, m].
        filter_lin: per-ell filter weights, shape [Nfilters, L].
        normalisation: array broadcastable to [Nfilters] to divide the averaged
            statistic by, or None to skip normalisation.

    Returns:
        (P00prime_ell, P00prime): filtered power summed over m for every filter and
        ell, shape [Nfilters, L], and its mean over ell, shape [Nfilters]. Only the
        averaged statistic is normalised.
    """
    # Weight each ell row of flm by every filter, then sum |.|^2 over m (axis 2).
    P00prime_ell = jnp.sum(jnp.abs(flm[None, :, :] * filter_lin[:, :, None]) ** 2, axis=2)
    P00prime = jnp.mean(P00prime_ell, axis=1)
    if normalisation is not None:
        # None is an empty pytree under jit, so this branch is resolved at trace time.
        P00prime = P00prime / normalisation
    return P00prime_ell, P00prime

# Reference P00' of the target map: the first pass gives the normalisation, the second
# the normalised target (which is therefore all ones by construction).
_, tP00prime_norm = get_P00prime(flm, wavelets_linear, None)
_, tP00prime = get_P00prime(flm, wavelets_linear, tP00prime_norm)

In [None]:
@jax.jit
def chi2_loss(predicts, targets):
    """Sum over statistics of the mean squared difference between prediction and target.

    Args:
        predicts: sequence of arrays (scattering statistics of the current estimate).
        targets: sequence of arrays with the same length and matching shapes.

    Returns:
        Scalar loss.

    Raises:
        ValueError: if the two sequences have different lengths (checked at trace time).
    """
    if len(predicts) != len(targets):
        raise ValueError(
            f"predicts has {len(predicts)} statistics but targets has {len(targets)}"
        )
    return sum(jnp.mean(jnp.abs(p - t) ** 2) for p, t in zip(predicts, targets))

def loss_func(glm_float):
    """Scattering-statistics + P00' matching loss for candidate harmonic coefficients.

    Args:
        glm_float: either a complex coefficient array (as passed by ``fit_optax``) or a
            real array stacking real and imaginary parts along axis 0, shape [2, ...]
            (as passed by ``fit_jaxopt_Scipy``).

    Returns:
        Scalar loss: chi2 over the scattering statistics plus the mean squared error
        between the normalised P00' of the candidate and of the target.
    """
    # Build complex coefficients; dtype is static, so this branch is fine under tracing.
    if jnp.iscomplexobj(glm_float):
        glm = glm_float
    else:
        glm = glm_float[0] + 1j * glm_float[1]

    predicts = s2scat.core.scatter.directional(glm, L, N, J_min, reality, wavelets, norm, quads, matrices, recursive, isotropic, delta_j)

    # Match statistics
    loss = chi2_loss(predicts, targets)
    # get_P00prime returns (per-ell, averaged); only the averaged statistic is compared.
    _, P00prime_new = get_P00prime(glm, wavelets_linear, tP00prime_norm)
    loss += jnp.mean(jnp.abs(P00prime_new - tP00prime) ** 2)

    return loss

In [None]:
def compute_ps(flm):
    """Angular power spectrum of a real field from its m >= 0 harmonic coefficients.

    For a real field |f_{l,-m}| = |f_{lm}|, so the sum over m in [-l, l] equals
    2 * sum_{m>=0} |f_lm|^2 - |f_l0|^2 (the m = 0 term must not be doubled). Hence

        C_l = (2 * sum_{m>=0} |f_lm|^2 - |f_l0|^2) / (2l + 1).

    Args:
        flm: coefficients indexed [ell, m] for m = 0..L-1 (entries with m > ell are
            assumed to be zero). NaN entries are treated as zero.

    Returns:
        Array of length L holding C_l for ell = 0..L-1.
    """
    L = flm.shape[0]
    ell = np.arange(L)
    power = jnp.abs(flm) ** 2
    power = jnp.where(jnp.isnan(power), 0.0, power)
    # Double the m > 0 contributions (negative-m mirror), counting m = 0 once.
    Cls = (2.0 * jnp.sum(power, axis=-1) - power[:, 0]) / (2 * ell + 1)
    return Cls

def generate_grf(ps, L, rng=None):
    """Draw m >= 0 harmonic coefficients of a real Gaussian random field.

    a_l0 is real with variance C_l. For 1 <= m <= l, a_lm is complex with real and
    imaginary parts each of variance C_l / 2. Entries with m > l are zero.

    Args:
        ps: power spectrum C_l; at least L entries are needed. Negative values, for
            example from round-off, are clipped to zero.
        L: band-limit; the output has shape (L, L), indexed [ell, m].
        rng: optional ``numpy.random.Generator``. When None, the legacy global
            ``np.random`` state is used, so ``np.random.seed`` controls the draw.

    Returns:
        complex128 array of shape (L, L).
    """
    ps = np.clip(np.asarray(ps, dtype=np.float64)[:L], 0.0, None)
    normal = np.random.standard_normal if rng is None else rng.standard_normal

    ell = np.arange(L)
    m = np.arange(L)
    alm = np.zeros((L, L), dtype=np.complex128)
    alm[:, 0] = np.sqrt(ps) * normal(L)

    # Fill every 1 <= m <= l (inclusive of m = l) in one vectorised draw.
    noise = normal((L, L)) + 1j * normal((L, L))
    amp = np.sqrt(ps / 2.0)[:, None]
    mask = (m[None, :] >= 1) & (m[None, :] <= ell[:, None])
    alm[mask] = (amp * noise)[mask]
    return alm

In [None]:
# Initial guess: a Gaussian random field sharing the target's power spectrum.
target_ps = compute_ps(flm)
glm = generate_grf(target_ps, L)
glm_start = jnp.copy(glm)  # snapshot of the starting coefficients

# Rebuild the full (-m and +m) coefficient set, then synthesise the starting map.
glm_start_full = s2scat.operators.spherical.make_flm_full(glm_start, L)
g_start = s2fft.inverse(glm_start_full, L, reality=reality, method="jax")

In [None]:
# Side-by-side comparison of the target map and the initial GRF, on a shared colour scale.
fig, (ax1, ax2) = plt.subplots(1, 2)
mx, mn = np.nanmax(f), np.nanmin(f)
im = ax1.imshow(f, cmap=cmap, vmax=mx, vmin=mn)
ax1.set_title("Target map")
ax2.imshow(g_start, cmap=cmap, vmax=mx, vmin=mn)
ax2.set_title("Initial GRF")
fig.colorbar(im, ax=[ax1, ax2], shrink=0.6)
plt.show()

In [None]:
def fit_jaxopt_Scipy(params, loss_func, method='L-BFGS-B', niter: int = 10, loss_history: list = None,
                     log_every: int = 10):
    """Minimise ``loss_func`` with SciPy (via jaxopt), running one SciPy iteration per outer step.

    Args:
        params: initial parameters (real-valued; e.g. stacked real/imag parts).
        loss_func: scalar loss to minimise.
        method: SciPy ``minimize`` method name.
        niter: number of outer steps.
        loss_history: list to append logged losses to. If None, a new list is started
            with the initial loss.
        log_every: log and record the loss every ``log_every`` steps.

    Returns:
        (params, loss_history).

    NOTE(review): every ``optimizer.run`` call starts a fresh SciPy minimisation with
    ``maxiter=1``, so quasi-Newton methods such as L-BFGS-B discard their curvature
    history at every outer step.
    """
    if loss_history is None:
        loss_history = [loss_func(params)]

    optimizer = jaxopt.ScipyMinimize(fun=loss_func, method=method, jit=False, maxiter=1)

    for i in range(niter):
        start = time.perf_counter()
        params, opt_state = optimizer.run(params)
        end = time.perf_counter()
        if i % log_every == 0:
            loss_history.append(opt_state.fun_val)
            print(
                f'Iter {i}, Success: {opt_state.success}, Loss = {opt_state.fun_val}, Time = {end - start:.5f} s/iter')

    return params, loss_history

def fit_optax(params: optax.Params, optimizer: optax.GradientTransformation, loss_func,
              niter: int = 10, loss_history: list = None,
              log_every: int = 10) -> tuple[optax.Params, list]:
    """Gradient-descent fit of ``loss_func`` with an optax optimizer.

    Args:
        params: initial parameters (complex or real arrays, or a pytree of them).
        optimizer: optax gradient transformation, e.g. ``optax.adam(lr)``.
        loss_func: scalar, real-valued loss of ``params``.
        niter: number of optimisation steps.
        loss_history: list to append logged losses to; a new list is created if None.
        log_every: every ``log_every`` steps, evaluate the loss at the updated params,
            record it and print it.

    Returns:
        (params, loss_history).
    """
    ### Gradient of the loss function
    grad_func = jax.grad(loss_func)

    if loss_history is None:
        loss_history = []
    opt_state = optimizer.init(params)
    for i in range(niter):
        start = time.perf_counter()
        # For a real loss of complex inputs, jax.grad returns the conjugate of the
        # descent direction optax expects, so conjugate it back (no-op for real leaves).
        grads = jax.tree_util.tree_map(jnp.conj, grad_func(params))
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        # JAX dispatches asynchronously; wait so the timing measures compute, not dispatch.
        params = jax.block_until_ready(params)
        end = time.perf_counter()
        if i % log_every == 0:
            loss_value = loss_func(params)
            loss_history.append(loss_value)
            print(f'Iter {i}, Loss: {loss_value:.10f}, Time = {end - start:.5f} s/iter')

    return params, loss_history

In [None]:
### OPTAX
# Note: this cell warm-starts from the current `glm`, so re-running it continues the fit.
niter = 400
learning_rate = 1e-4
optimizer = optax.adam(learning_rate)
glm, loss_history = fit_optax(glm, optimizer, loss_func, niter=niter, loss_history=None)
glm_end = jnp.copy(glm)

# ### JAXOPT  (alternative: optimise stacked real/imag parts with SciPy L-BFGS-B)
# niter=400
# glm_float = jnp.array([jnp.real(glm), jnp.imag(glm)]) # [2, L, L]          
# glm, loss_history = fit_jaxopt_Scipy(glm_float, loss_func, method='L-BFGS-B', niter=niter, loss_history=None)
# glm_end = glm[0, :, :] + 1j * glm[1, :, :]