__Generating realisations__ 
---

[![colab image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/astro-informatics/s2scat/blob/main/notebooks/synthesis.ipynb)

This tutorial gives a basic overview of how the scattering covariances can be used as a statistical generative model. A number of further applications are presented in the other notebooks, so take a look!

In the machine learning (ML) literature, a generative model is typically associated with some model $M_{\lbrace w,b \rbrace}(\theta)$, with weights and biases $\lbrace w, b \rbrace$, which takes parameters $\theta$ and from which a new realisation (very often an image) may be generated:

$$f = M_{\lbrace w,b \rbrace}(\theta)$$

The difficulty with this approach is that the quality of $f$ depends heavily on the optimality of $\lbrace w, b \rbrace$, which in turn depends entirely on the availability of realistic training data. In many (often high-dimensional) application domains such training data does not exist, fundamentally limiting naive ML approaches.

One may instead construct an expressive statistical representation from which, given at least a single fiducial realisation, many new realisations may be drawn. This concept is actually very familiar, particularly in cosmology, where it is common to, e.g., draw Gaussian realisations from a known power spectrum, $f \sim P(\ell)$. In this simple case the process of drawing a new realisation is self-evident; however, such a generative model does not capture complex non-linear structural information.
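For reference, the Gaussian case can be sketched in a few lines of NumPy. This is a minimal sketch rather than part of ``S2SCAT``: the helper `gaussian_flm_from_cl` is hypothetical, the power spectrum is written $C_\ell$, and we assume harmonic coefficients are stored in an `(L, 2L-1)` array with order $m$ at column `L-1+m` (the layout we believe ``s2fft`` uses). For a real field each $f_{\ell m}$ with $m > 0$ is complex Gaussian with $\mathbb{E}|f_{\ell m}|^2 = C_\ell$, $f_{\ell 0}$ is real, and negative orders follow from $f_{\ell,-m} = (-1)^m f_{\ell m}^*$.

```python
import numpy as np


def gaussian_flm_from_cl(cl, L, rng=None):
    """Draw harmonic coefficients of a real Gaussian random field with power spectrum cl.

    Hypothetical helper. Assumes an (L, 2L-1) layout with order m stored at column L-1+m;
    entries with |m| > ell are left as zero.
    """
    rng = np.random.default_rng() if rng is None else rng
    flm = np.zeros((L, 2 * L - 1), dtype=np.complex128)
    for ell in range(L):
        # m = 0: real coefficient with variance C_ell.
        flm[ell, L - 1] = np.sqrt(cl[ell]) * rng.standard_normal()
        for m in range(1, ell + 1):
            # m > 0: complex coefficient, real and imaginary parts each with variance C_ell / 2.
            z = np.sqrt(cl[ell] / 2) * (rng.standard_normal() + 1j * rng.standard_normal())
            flm[ell, L - 1 + m] = z
            # Reality condition fixes the negative orders: f_{ell,-m} = (-1)^m conj(f_{ell,m}).
            flm[ell, L - 1 - m] = (-1) ** m * np.conj(z)
    return flm
```

Assuming that layout matches, `s2fft.inverse(flm, L, reality=True)` should then map such coefficients to a pixel-space realisation.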

Here we will instead use the scattering covariances $\mathcal{S}$ as our statistical representation. Since $\mathcal{S}$ is non-linear, generating new realisations isn't quite so straightforward; in fact, to do so we need to solve the optimisation problem

$$ f = \underset{\hat{f}}{\arg\min} \, \big\| \mathcal{S}(\hat{f}) - \mathcal{T} \big\|_2^2 $$

where $\mathcal{T} = \mathcal{S}(f^{\text{in}})$ are the target covariances computed from the signal $f^{\text{in}}$ we aim to emulate. To solve this optimisation with gradient-based methods we need to be able to differentiate through $\mathcal{S}$, which is a complex function involving Wigner transforms and non-linearities.

As ``S2SCAT`` is a ``JAX`` package, we can readily access these gradients, so let's see exactly how this works!
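To make the idea concrete, here is a toy sketch of the same procedure in plain ``JAX``, independent of the ``S2SCAT`` API. A hypothetical function `toy_statistics` (the mean, second moment and lag-1 product of a 1D signal) stands in for $\mathcal{S}$, and we synthesise a signal matching the statistics of a target by gradient descent on the squared $\ell_2$ distance, with `jax.value_and_grad` differentiating straight through the non-linear statistics. It illustrates the principle only; the real scattering covariances are far richer.

```python
import jax
import jax.numpy as jnp


def toy_statistics(x):
    """Hypothetical stand-in for S: a few non-linear summary statistics of a 1D signal."""
    return jnp.stack([
        jnp.mean(x),               # first moment
        jnp.mean(x ** 2),          # second moment
        jnp.mean(x[1:] * x[:-1]),  # lag-1 product (non-normalised autocorrelation)
    ])


def make_loss(target_stats):
    # Squared l2 distance between the statistics of the estimate and those of the target.
    def loss(x):
        return jnp.sum((toy_statistics(x) - target_stats) ** 2)
    return loss


def synthesise(x0, target_stats, learning_rate=1.0, niters=500):
    loss = make_loss(target_stats)
    # Compile a function returning both the loss and its gradient with respect to x.
    value_and_grad = jax.jit(jax.value_and_grad(loss))
    x = x0
    for _ in range(niters):
        _, grad = value_and_grad(x)
        x = x - learning_rate * grad  # plain fixed-step gradient descent
    return x, loss(x)


key_target, key_init = jax.random.split(jax.random.PRNGKey(0))
# Toy "fiducial" signal to emulate: exponentially distributed samples.
f_in = jax.random.exponential(key_target, (64,))
target_stats = toy_statistics(f_in)
# Start the synthesis from white Gaussian noise.
x0 = jax.random.normal(key_init, (64,))
initial_loss = make_loss(target_stats)(x0)
x_fit, final_loss = synthesise(x0, target_stats)
```

Swapping `toy_statistics` for a differentiable statistic such as the scattering covariances is, in principle, the only conceptual change; this tutorial uses the Adam optimiser rather than fixed-step gradient descent.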

## Import the package

Let's first import ``S2SCAT`` and some basic plotting functions. We'll also import ``pickle`` to load the target covariances, which have been precomputed and stored to save you some time.

In [None]:
import pickle

import jax
# Enable 64-bit (double) precision in JAX before any arrays are created.
jax.config.update("jax_enable_x64", True)

import numpy as np
from matplotlib import pyplot as plt

import s2fft
import s2scat
from s2scat.core import synthesis

## Configure the problem

Let's set up the target field we aim to emulate and the hyperparameters of the scattering covariance representation we will work with.

In [None]:
L = 128                # Spherical harmonic bandlimit.
N = 3                  # Azimuthal bandlimit (directionality).
J_min = 0              # Minimum wavelet scale.
reality = True         # Input signal is real.
recursive = False      # Use the fully precomputed transform.

# Load the spherical field we wish to emulate and its harmonic coefficients.
f = np.load('data/WL_example_f_{}.npy'.format(L))
flm = np.load('data/WL_example_flm_{}.npy'.format(L))

# Load the target scattering covariances generated from f and flm, together
# with norm, which is passed to the scattering transform below.
file = 'data/targets_{}.pickle'.format(L)
with open(file, 'rb') as handle:
    targets, norm = pickle.load(handle)

Before calling the scattering transform you need to run the configuration step, which generates and caches any precomputed arrays. When running the recursive transform this shouldn't take much memory at all; however, the fully precomputed transform can be extremely memory hungry at $L \sim 512$ and above!

In [None]:
config = s2scat.configure(L, N, J_min, reality, recursive)

## Define a loss function

Let's define a simple $\ell_2$-loss function, which computes the mean squared distance between the scattering covariances of our current iterate and those of the target. In practice any loss could be used here; however, we'll stick with the most straightforward choice for this demonstration.

In [None]:
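def loss_func(glm):
    # Scattering covariances of the current estimate (config, norm and targets come from the cells above).
    predicts = s2scat.scatter(glm, L, N, J_min, reality, config, norm, recursive)
    # l2 distance between the predicted and target covariances.
    return synthesis.l2_covariance_loss(predicts, targets)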
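For intuition, a loss of this kind can be sketched as follows. This is a hypothetical stand-in, not the implementation of `synthesis.l2_covariance_loss`: we assume the covariances are held in a tuple (more generally, a JAX pytree) of arrays and, for each component, take the mean squared modulus of the difference before summing over components. The library's version may normalise or weight the terms differently.

```python
import jax
import jax.numpy as jnp


def l2_loss_sketch(predicts, targets):
    """Hypothetical l2 loss: sum over components of the mean squared modulus of the difference.

    predicts and targets are matching pytrees (e.g. tuples) of real or complex arrays.
    """
    # Reduce each pair of matching arrays to a single scalar.
    per_component = jax.tree_util.tree_map(
        lambda p, t: jnp.mean(jnp.abs(p - t) ** 2), predicts, targets
    )
    # Add up the per-component scalars.
    return sum(jax.tree_util.tree_leaves(per_component))
```

Because it is built from differentiable ``jax.numpy`` operations, `jax.grad` can be applied to a loss like this directly.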

## Generate an initial estimate

We need to choose a set of harmonic coefficients $g_{\ell m}$ from which to start our optimisation. Strictly speaking, we should start from a Gaussian-distributed random signal to ensure we form a macro-canonical model of our target field, and we will do precisely this. In practice, however, it may be better to start from, e.g., a Gaussian random field generated from a fiducial power spectrum, as this may reduce the total number of iterations required for convergence.

In any case, let's generate a starting signal.

In [None]:
# Estimate the spread of the target field's non-zero harmonic coefficient magnitudes.
sigma_bar = np.std(np.abs(flm)[flm != 0])

# Generate complex Gaussian random harmonic coefficients on a comparable scale.
glm = np.random.randn(L, L) * sigma_bar + 1j * np.random.randn(L, L) * sigma_bar

# Save the starting noise signal for posterity and plotting:
# expand glm to a full set of harmonic coefficients, then map to pixel space.
glm_start = s2scat.operators.spherical.make_flm_full(glm, L)
g_start = s2fft.inverse(glm_start, L, reality=reality, method="jax")
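If you prefer the alternative mentioned above and start from a Gaussian random field, you first need a fiducial power spectrum. Below is a minimal NumPy sketch of the standard estimator $\hat{C}_\ell = \frac{1}{2\ell+1}\sum_{m=-\ell}^{\ell}|f_{\ell m}|^2$. The helper name is hypothetical, and it assumes the full `(L, 2L-1)` harmonic layout (with zeros where $|m| > \ell$) rather than the reality-reduced `(L, L)` array used for `glm` above.

```python
import numpy as np


def empirical_power_spectrum(flm):
    """Hypothetical helper: estimate C_ell = (1 / (2 ell + 1)) * sum_m |f_{ell m}|^2.

    Assumes a full (L, 2L-1) coefficient array in which entries with |m| > ell are zero.
    """
    L = flm.shape[0]
    ells = np.arange(L)
    # Summing each row over all stored orders m gives the total power at multipole ell.
    return np.sum(np.abs(flm) ** 2, axis=1) / (2 * ells + 1)
```

Coefficients drawn from a Gaussian with variance $\hat{C}_\ell$ at each $\ell$ could then replace the white-noise start above.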

## Minimise the objective

Now we can pass all these components to ``optax``, which we have configured internally (within ``synthesis.fit_optax``) to use the Adam optimiser. This minimises the loss and returns a synthetic realisation whose statistics should approximate those of the target field.

In [None]:
# Run the optimisation to generate a new realisation glm.
glm, _ = synthesis.fit_optax(glm, loss_func, niters=400, learning_rate=1e-3)

# Convert the synthetic harmonic coefficients into a pixel-space image.
glm_end = s2scat.operators.spherical.make_flm_full(glm, L)
g_end = s2fft.inverse(glm_end, L, reality=reality, method="jax")

## Check the synthesis

Finally, let's see how our starting and final realisations compare with the target field!

In [None]:
fields = [f, g_start, g_end]
titles = ["Target", "Initial", "Emulation"]
fig, axs = plt.subplots(1, 3, figsize=(30, 10))
# Use the target's range for every panel so the colour scales are comparable.
mx, mn = np.nanmax(f), np.nanmin(f)
for ax, field, title in zip(axs, fields, titles):
    ax.imshow(field, cmap="magma", vmax=mx, vmin=mn)
    ax.set_title(title)
    ax.axis('off')
plt.show()
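Visual inspection can be complemented by a simple quantitative check. As one hedged example, unrelated to the ``S2SCAT`` API, the hypothetical helper below compares one-point statistics: the $\ell_1$ distance between the normalised pixel histograms of two fields, computed on shared bins and ignoring non-finite pixels. It ranges from 0 (identical histograms) to 2 (no overlap).

```python
import numpy as np


def histogram_l1_distance(a, b, bins=64):
    """Hypothetical one-point check: L1 distance between normalised pixel histograms of a and b."""
    a = np.asarray(a, dtype=float).ravel()
    b = np.asarray(b, dtype=float).ravel()
    # Drop NaNs/infs so masked pixels do not affect the comparison.
    a, b = a[np.isfinite(a)], b[np.isfinite(b)]
    # Shared bin edges spanning the combined value range of both fields.
    edges = np.histogram_bin_edges(np.concatenate([a, b]), bins=bins)
    pa, _ = np.histogram(a, bins=edges)
    pb, _ = np.histogram(b, bins=edges)
    # Convert counts to probabilities so fields with different pixel counts are comparable.
    pa = pa / pa.sum()
    pb = pb / pb.sum()
    return np.abs(pa - pb).sum()
```

For instance, a smaller value of `histogram_l1_distance(f, g_end)` than of `histogram_l1_distance(f, g_start)` would indicate that the emulation's pixel distribution is closer to the target's. One-point statistics are only a partial check, since they ignore spatial structure.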