__Generating realisations (Basic)__ 
---

[![colab image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/astro-informatics/s2scat/blob/main/notebooks/synthesis.ipynb)

This tutorial is a basic overview of how one may use the scattering covariances as a statistical generative model. 

Generative AI models typically require an abundance of realistic training data. In many (often high-dimensional) application domains, such as the sciences, such training data does not exist, which limits generative AI approaches.

One may instead construct an expressive statistical representation from which, provided at least a single fiducial realisation, many realisations may be drawn. This concept is actually very familiar, particularly in cosmology where it is typical to draw Gaussian realisations from a known power spectrum.  However, this generative model does not capture complex non-linear structural information.
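
To make the familiar Gaussian case concrete, here is a minimal sketch of drawing a real Gaussian realisation from a known power spectrum $C_\ell$. The function names (`draw_gaussian_flm`, `empirical_cl`), the toy spectrum and the `(L, 2L-1)` array layout are assumptions made for this illustration and are not part of ``S2SCAT``.

```python
import numpy as np

def draw_gaussian_flm(cl, rng):
    """Draw harmonic coefficients of a real Gaussian field with power spectrum cl."""
    L = len(cl)
    # Illustrative layout: row = ell, column = m + L - 1.
    flm = np.zeros((L, 2 * L - 1), dtype=np.complex128)
    for ell in range(L):
        # The m = 0 mode is real with variance C_ell.
        flm[ell, L - 1] = rng.normal(scale=np.sqrt(cl[ell]))
        for m in range(1, ell + 1):
            # Real and imaginary parts each carry half of the variance.
            re, im = rng.normal(scale=np.sqrt(cl[ell] / 2), size=2)
            flm[ell, L - 1 + m] = re + 1j * im
            # Reality condition: f_{ell,-m} = (-1)^m conj(f_{ell,m}).
            flm[ell, L - 1 - m] = (-1) ** m * (re - 1j * im)
    return flm

def empirical_cl(flm):
    """Estimate the power spectrum by averaging |f_{ell m}|^2 over m."""
    L = flm.shape[0]
    return np.array([
        np.sum(np.abs(flm[ell, L - 1 - ell : L + ell]) ** 2) / (2 * ell + 1)
        for ell in range(L)
    ])

L = 64
cl = 1.0 / (1.0 + np.arange(L)) ** 2   # toy power spectrum (assumed)
rng = np.random.default_rng(0)
flm = draw_gaussian_flm(cl, rng)
cl_hat = empirical_cl(flm)             # scatters around cl
```

Every draw shares the same two-point statistics, but nothing constrains higher-order structure, which is the limitation noted above.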

Here we will instead use the scattering covariances $\Phi(x)$ as our statistical representation. Because $\Phi$ is a non-linear function of the data $x$, generating new realisations isn't quite so straightforward. In fact, to do so we'll need to minimise the loss function:

$$ \mathcal{L}(x) = ||\Phi(x) - \Phi(x_t)||^2_2$$

where $\Phi(x_t)$ are the target covariances computed from $x_t$, the signal we are aiming to emulate. To solve this optimisation with gradient-based methods we need to be able to differentiate through $\Phi$, which is a complex function involving wavelet transforms, non-linearities, spherical harmonic and Wigner transforms.

As ``S2SCAT`` is a ``JAX`` package, we can readily access these gradients, so let's see exactly how this works!
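
As a rough illustration of the idea, the sketch below minimises the same form of loss for a toy, hypothetical $\Phi$ (three simple summary statistics of a 1D signal) using `jax.grad` and plain gradient descent. The function `phi`, the step size and the iteration count are assumptions chosen for this example. They stand in for the scattering covariances and are not the optimiser ``S2SCAT`` uses internally.

```python
import jax
import jax.numpy as jnp

def phi(x):
    # Toy stand-in for the scattering covariances (assumed for illustration):
    # mean, second moment and lag-1 correlation of a 1D signal.
    return jnp.array([jnp.mean(x), jnp.mean(x**2), jnp.mean(x[:-1] * x[1:])])

def loss(x, phi_target):
    # L(x) = ||Phi(x) - Phi(x_t)||_2^2
    return jnp.sum((phi(x) - phi_target) ** 2)

# JAX differentiates straight through phi, with respect to the first argument x.
grad_loss = jax.jit(jax.grad(loss))

key_target, key_init = jax.random.split(jax.random.PRNGKey(0))
n = 256
x_t = jnp.sin(jnp.linspace(0.0, 8.0 * jnp.pi, n)) + 0.1 * jax.random.normal(key_target, (n,))
phi_t = phi(x_t)

x = jax.random.normal(key_init, (n,))   # random initial signal
loss_start = loss(x, phi_t)
step = 5.0                              # hand-tuned for this toy problem
for _ in range(500):
    x = x - step * grad_loss(x, phi_t)  # plain gradient descent update
loss_end = loss(x, phi_t)
```

The optimised `x` matches the target statistics without being a copy of `x_t`, which is the sense in which minimising this loss generates a new realisation.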

## Import the package

Let's first import ``S2SCAT`` and some basic plotting functions. We'll also pick up ``pickle`` to load the targets, which have been precomputed and stored to save you some time.

In [None]:
import sys
IN_COLAB = 'google.colab' in sys.modules

# Install a spherical plotting package.
!pip install cartopy &> /dev/null

# Install s2scat and create the data directory if running on Google Colab.
if IN_COLAB:
    !pip install s2scat &> /dev/null
    !mkdir data/
    # Update below to include data files that we need on colab
    # !wget https://github.com/astro-informatics/s2fft/raw/main/notebooks/data/Gaia_EDR3_flux.npy -P data/ &> /dev/null

In [None]:
import jax, pickle
jax.config.update("jax_enable_x64", True)

from matplotlib import pyplot as plt 
import numpy as np 
import cartopy.crs as ccrs 
import s2scat, s2fft

## Configure the problem

Let's set up the target field we are aiming to emulate, and the hyperparameters of the scattering covariance representation we will work with.

In [None]:
L = 128                # Spherical harmonic bandlimit.
N = 3                  # Azimuthal bandlimit (directionality).
J_min = 0              # Minimum wavelet scale.
reality = True         # Input signal is real.
recursive = False      # Use the fully precomputed transform.

# Load the spherical harmonic coefficients of the field we want to model.
xlm = np.load('data/WL_example_flm_{}.npy'.format(L))

## Build the generative model

We have included an easy-to-use function, ``build_model``, which returns the generative model as a callable.

In [None]:
model = s2scat.generation.build_model(xlm, L, N, J_min, reality, recursive)

## Decode some textures from the model

The model function takes a JAX random key which is used to reliably generate random arrays from which to initialise the optimisation. It is also vectorised to allow multiple generations simultaneously.
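
For readers new to JAX random keys, the following sketch shows how splitting one key gives reproducible yet independent random initialisations, and how `jax.vmap` batches them. The `init_field` helper is hypothetical and only illustrates the pattern; it is not how the model initialises its optimisation.

```python
import jax
import jax.numpy as jnp

def init_field(key, n=16):
    # Hypothetical initialiser: white noise drawn from the given key.
    return jax.random.normal(key, (n,))

key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 3)        # three independent sub-keys
batch = jax.vmap(init_field)(keys)     # three initialisations in one call

# Re-running with the same seed reproduces exactly the same arrays.
again = jax.vmap(init_field)(jax.random.split(jax.random.PRNGKey(0), 3))
```

Because all randomness flows from the key, passing the same key back in reproduces the same starting points.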

In [None]:
key = jax.random.PRNGKey(0)
xlm_new = model(key, 1, 5, 1e-3)

We've generated some new harmonic coefficients $x_{\ell m}$ so we now need to map these back to spherical images which we'll do in HEALPix sampling.

## View these textures

Finally, let's check how our generated textures shape up against the target field!

In [None]:
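# NOTE: x_t, x_start and x_end are not defined in the cells above. They are
# assumed to be pixel-space maps of the target, the initial field and the
# generated field, and must be computed before running this cell.
fields = [x_t, x_start, x_end]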
titles = ["Target", "Start", "Emulation"]
fig, axs = plt.subplots(1, 3, subplot_kw={'projection': ccrs.Mollweide()}, figsize=(30,10))
mx, mn = np.nanmax(x_t), np.nanmin(x_t)
for i in range(3):   
    axs[i].imshow(fields[i], transform=ccrs.PlateCarree(), cmap='magma', vmax=mx, vmin=mn)
    axs[i].set_title(titles[i])
    axs[i].axis('off')
plt.show()