In [None]:
import jax, pickle
jax.config.update("jax_enable_x64", True)

from matplotlib import pyplot as plt 
import numpy as np 

import s2scat, s2fft
from s2scat.core import synthesis

Let's set up the target field we aim to emulate and the hyperparameters of the scattering covariance representation we will work with.

In [None]:
L = 128                # Spherical harmonic bandlimit.
N = 3                  # Azimuthal bandlimit (directionality).
J_min = 0              # Minimum wavelet scale.
reality = True         # Input signal is real.
recursive = False      # Use the fully precomputed (non-recursive) transform.

# Load the spherical field we wish to emulate and its harmonic coefficients.
f = np.load(f"data/WL_example_f_{L}.npy")
flm = np.load(f"data/WL_example_flm_{L}.npy")

# Load the target scattering covariances, generated from f and flm, together with
# `norm`, which is later passed to `s2scat.scatter` when computing the loss.
# NOTE(review): pickle.load can execute arbitrary code while unpickling — only load
# this file if it comes from a trusted source.
targets_file = f"data/targets_{L}.pickle"
with open(targets_file, "rb") as handle:
    targets, norm = pickle.load(handle)

Before calling the scattering transform, you need to run the configuration step, which generates any precomputed arrays and caches them. When running the recursive transform this shouldn't take much memory at all. However, the fully precomputed transform can be extremely memory hungry at $L \sim 512$ and above!

In [None]:
# Generate and cache the precomputed arrays for these hyperparameters; the returned
# `config` is passed to `s2scat.scatter` inside `loss_func`.
config = s2scat.configure(L, N, J_min, reality, recursive)

Let's define a simple $\ell_2$-loss function, which computes the mean squared distance between the scattering covariances of our current iterate and those of the target. In practice any loss could be used here; we use the most straightforward choice for this demonstration.

In [None]:
def loss_func(glm):
    """Scattering-covariance l2 loss between the current iterate and the target.

    Args:
        glm: Harmonic coefficients of the current iterate.

    Returns:
        The value of ``synthesis.l2_covariance_loss`` comparing the scattering
        covariances of ``glm`` against the notebook-level ``targets``.
    """
    # L, N, J_min, reality, config, norm, recursive and targets are read from the
    # enclosing notebook scope (defined in earlier cells).
    current_covariances = s2scat.scatter(
        glm, L, N, J_min, reality, config, norm, recursive
    )
    return synthesis.l2_covariance_loss(current_covariances, targets)

We need to choose a set of harmonic coefficients $g_{\ell m}$ from which to start our optimisation. Strictly speaking, we should start from a Gaussian-distributed random (white-noise) signal to ensure we form a macro-canonical model of our target field, and we will do precisely this. In practice, however, it may be better to start from, e.g., a Gaussian random field generated from a fiducial power spectrum, as this may reduce the total number of iterations required for convergence.

In any case, let's generate a starting signal.

In [None]:
# Compute the standard deviation of the target field.
# Seed the generator so the starting noise (and hence the emulation) is reproducible
# under Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)

# Noise scale: the standard deviation of the magnitudes of the non-zero target
# harmonic coefficients (not the pixel-space std of the field).
sigma_bar = np.std(np.abs(flm)[flm != 0])

# Draw complex Gaussian harmonic coefficients whose real and imaginary parts each have
# standard deviation sigma_bar. Shape is (L, L); presumably the m >= 0 half-plane that
# make_flm_full expands below — TODO confirm against s2scat docs.
glm = (rng.standard_normal((L, L)) + 1j * rng.standard_normal((L, L))) * sigma_bar

# Save the starting noise signal for posterity and plotting!
glm_start = s2scat.operators.spherical.make_flm_full(glm, L)
g_start = s2fft.inverse(glm_start, L, reality=reality, method="jax")

Now we can pass all these components to ``optax``, which we have configured internally to use the Adam optimizer to minimise the loss and return a synthetic realisation that should approximate the target field statistics.

In [None]:
# Run the optimisation to generate a new realisation glm.
# Optimiser settings, named so they are easy to find and tune.
N_ITERS = 400
LEARNING_RATE = 1e-3

# Run the optimisation from the starting coefficients. The result is stored under a
# new name so `glm` keeps the initial noise: re-running this cell restarts from the
# same starting point instead of silently continuing from the previous result.
glm_opt, _ = synthesis.fit_optax(
    glm, loss_func, niters=N_ITERS, learning_rate=LEARNING_RATE
)

# Convert the synthetic harmonic coefficients into a pixel-space image.
glm_end = s2scat.operators.spherical.make_flm_full(glm_opt, L)
g_end = s2fft.inverse(glm_end, L, reality=reality, method="jax")

Finally, let's check how our starting and final realisations compare with the target field!

In [None]:
fields = [f, g_start, g_end]
titles = ["Target", "Initial", "Emulation"]
fig, axs = plt.subplots(1, 3, figsize=(30, 10))

# Use the target's colour range for every panel so the three maps are directly comparable.
mx, mn = np.nanmax(f), np.nanmin(f)
for ax, field, title in zip(axs, fields, titles):
    ax.imshow(field, cmap="magma", vmax=mx, vmin=mn)
    ax.set_title(title)
    ax.axis("off")
plt.show()