# Using Neural Posterior Estimation with Embedding Networks for high-dimensional data

## Introduction

### The importance of summary statistics
Neural Posterior Estimation (NPE) trains a neural network to learn the relationship between model parameters and the data they generate.
However, when the data are very high-dimensional (that is, they have many features), training this network can be very expensive.
In such cases, it is common to use *summary statistics* to provide a lower-dimensional representation of the data.
The neural network can then learn the distribution of the data in this subspace, and real and simulated data can be compared there too.
This can speed up inference dramatically and may also make it more reliable, but only if the statistics are well crafted.
More information on SBI and summary statistics can be found in [the SBI documentation](https://www.mackelab.org/sbi/tutorial/10_crafting_summary_statistics/).
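
To make the idea concrete, here is a minimal sketch of hand-crafted summary statistics, assuming a noisy straight-line spectrum like the one used later in this tutorial. The function name `summarise` and the choice of statistics (least-squares slope, intercept and residual scatter) are illustrative assumptions, not part of SBI or Ampere.

```python
import numpy as np

rng = np.random.default_rng(42)

# A hypothetical high-dimensional observation: 2000 flux values on a log-spaced grid
wavelengths = np.logspace(0.0, 1.9, 2000)
flux = 1.0 * wavelengths + 1.0
noisy_flux = flux + rng.normal(scale=0.1 * flux)  # 10% fractional Gaussian noise

def summarise(wave, flux):
    """Compress a spectrum into three numbers: slope, intercept and residual scatter."""
    slope, intercept = np.polyfit(wave, flux, deg=1)
    residuals = flux - (slope * wave + intercept)
    return np.array([slope, intercept, residuals.std()])

summary = summarise(wavelengths, noisy_flux)
print(noisy_flux.shape, "->", summary.shape)
```

Here three numbers stand in for 2000 flux values. This only works because we already know the data follow a straight line; for realistic data such a compact description is rarely known in advance.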

### Automatic summary statistics
However, for astronomical data, it may be hard to define reasonable summary statistics *a priori*.
Astronomical data are also often high-dimensional. Among the kinds of data that Ampere typically handles, photometry is usually not a concern, but most spectra will exhibit these issues.
Hence, it would be very useful to have a way to construct summary statistics automatically.

One way to do this is to introduce another neural network, called an *embedding network*, so called because it *embeds* the high-dimensional data in a lower-dimensional subspace that can represent it with minimal loss.
This embedding then acts as a set of automatically determined summary statistics, without you needing to know what to look for ahead of time. 
There are a range of architectures that can do this, but Fully Connected, Recurrent and Convolutional Neural Networks are common choices.

This tutorial will introduce one of the ways you can use this approach in Ampere, and point you to resources that can help you understand other approaches to defining an embedding network.

## Using embeddings in your code

Now we will discuss how to actually put this into practice. 

### Setup

First, we have to make a variety of imports.

In [None]:
import numpy as np
import os
import ampere
from ampere.data import Spectrum, Photometry
from ampere.infer.sbi import SBI_SNPE
from ampere.models import Model
from spectres import spectres
import pyphot

As usual, we have to define our model. Since embeddings are most useful for high-dimensional data, we will return to the simple straight-line model defined in the quickstart guide, and use it to create some spectra and photometry.

In [None]:
class ASimpleModel(Model):
    '''This is a very simple model in which the flux is linear in wavelength.

    This model shows you the basics of writing a model for ampere

    '''
    def __init__(self, wavelengths, flatprior=True,
                 lims=np.array([[-10, 10],
                                [-10, 10]])):
        '''The model constructor, which will set everything up

        This method does essential setup actions, primarily things that 
        may change from one fit to another, but will stay constant throughout 
        the fit. This may be things like the grid of wavelengths to calculate
        the model output on, or establishing the dust opacities if involved.
        There are also several important variables it *MUST* define here
        '''
        self.wavelength = wavelengths
        self.npars = 2 #Number of free parameters for the model (__call__()). For some 
        # models this can be determined through introspection, but it is still strongly 
        # recommended to define this explicitly here. Introspection will only be 
        # attempted if self.npars is not defined.
        self.npars_ptform = 2 #Sometimes the number of free parameters is different 
        # when using the prior transform instead of the prior. In that case, 
        # self.npars_ptform should also be defined.

        # You can do any other set up you need in this method.
        # For example, we could define some cases to set up different priors
        # But that's for a slightly more complex example.
        # Here we'll just use a simple flat prior
        self.lims = lims
        self.flatprior = flatprior
        self.parLabels = ["slope", "intercept"]

    def __call__(self, slope, intercept, **kwargs):
        '''The model itself, using the callable class functionality of python.

        This is an essential method. It should do any steps required to 
        calculate the output fluxes. Once done, it should store the output fluxes
        in self.modelFlux.
        '''

        self.modelFlux = slope*self.wavelength + intercept
        return {"spectrum": {"wavelength": self.wavelength, "flux": self.modelFlux}}

    def lnprior(self, theta, **kwargs):
        """The model prior probability distribution
       
        The prior is essential to most types of inference with ampere. The
        prior describes the relative weights (or probabilities if normalised)
        of different parameter combinations. Using a normalised prior with
        SNPE is strongly recommended, otherwise ampere will attempt
        approximate normalisation using Monte Carlo integration.
        """
        if not self.flatprior:
            raise NotImplementedError()
        slope = theta[0]
        #print(slope)
        intercept = theta[1]
        return (
            0
            if self.lims[0, 0] < slope < self.lims[0, 1]
            and self.lims[1, 0] < intercept < self.lims[1, 1]
            else -np.inf
        )
        

    def prior_transform(self, u, **kwargs):
        '''The prior transform, which takes samples from the Uniform(0,1)
        distribution to the desired distribution.

        Prior transforms are essential for SNPE. SNPE needs to be able to
        generate samples from the prior, and this method is integral to doing 
        so. Therefore, unlike other inference methods, if you want to use 
        SNPE (or other SBI approaches) you need to define *both* lnprior and 
        prior_transform.
        '''
        if self.flatprior:
            return (self.lims[:,1] - self.lims[:,0]) * u + self.lims[:,0]
        else:
            raise NotImplementedError()
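
As a quick sanity check of how the two prior methods relate, the sketch below re-implements the flat prior and its transform as standalone functions (so it does not call Ampere) and confirms that every sample produced by the transform has finite prior probability. The limits are assumed to match the default `lims` above; the vectorised `lnprior` is a rewrite for illustration, not the method defined in the class.

```python
import numpy as np

# Same limits as the default `lims` in ASimpleModel: one (lower, upper) row per parameter
lims = np.array([[-10, 10],
                 [-10, 10]])

def prior_transform(u):
    """Map samples from Uniform(0, 1) onto the flat prior box."""
    return (lims[:, 1] - lims[:, 0]) * u + lims[:, 0]

def lnprior(theta):
    """Flat (unnormalised) log-prior: 0 inside the box, -inf outside."""
    inside = np.all((lims[:, 0] < theta) & (theta < lims[:, 1]))
    return 0.0 if inside else -np.inf

rng = np.random.default_rng(0)
u = rng.uniform(size=(1000, 2))   # 1000 draws from the unit square
theta = prior_transform(u)        # broadcasts over the leading axis

# Every transformed sample should have finite prior probability
all_finite = all(np.isfinite(lnprior(t)) for t in theta)
print(theta.min(axis=0), theta.max(axis=0), all_finite)
```

If the two methods ever disagree (for example, after changing the limits in only one of them), the simulated training data would be drawn from a different prior than the one used to evaluate the posterior.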

### Creating synthetic data

Now that we have our model, we can generate some ground-truth synthetic observations to test the approach with. Firstly, we compute the model for fixed values of the parameters, with the slope and intercept both set to 1.

In [None]:
wavelengths = 10**np.linspace(0., 1.9, 2000)

""" Choose some model parameters """
slope = 1.  # Keep it super simple for now
intercept = 1.

# Now init the model:
model = ASimpleModel(wavelengths)
# And call it to produce the fluxes for our chosen parameters
model_dict = model(slope, intercept)
model_flux = model_dict['spectrum']['flux']

Now that we have a noise-free spectrum, we can convolve it with some filters to extract synthetic photometry. In this case, we will use the WISE W1 and *Spitzer* MIPS 70 filters. Then we will add some Gaussian noise to it and create a `Photometry` object for it.

In [None]:
filterName = np.array(['WISE_RSR_W1', 'SPITZER_MIPS_70'])  

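# Locate the filter library that ships with ampere.
# os.path.dirname is used because str.strip removes a set of characters, not a suffix.
libDir = os.path.dirname(ampere.__file__)
libname = os.path.join(libDir, 'ampere_allfilters.hd5')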
filterLibrary = pyphot.get_library(fname=libname)
filters = filterLibrary.load_filters(filterName, interp=True, 
                                        lamb=wavelengths*pyphot.unit['micron'])
# Now we need to extract the photometry with pyphot
# first we need to convert the flux from Fnu to Flambda
flam = model_flux / wavelengths**2
modSed = []
for i, f in enumerate(filters):
    lp = f.lpivot.to("micron").value
    fphot = f.get_flux(wavelengths*pyphot.unit['micron'], flam*pyphot.unit['flam'], axis=-1).value
    modSed.append(fphot*lp**2)

modSed = np.array(modSed)

input_noise_phot = 0.1  # Fractional uncertainty
photunc = input_noise_phot * modSed  # Absolute uncertainty
# Now perturb data by drawing from a Gaussian distribution
modSed = modSed + np.random.randn(len(filterName)) * photunc 

photometry = Photometry(filterName=filterName, value=modSed, 
                        uncertainty=photunc, photunits="Jy", 
                        libName=libname)
# print(photometry.filterMask)
photometry.reloadFilters(wavelengths)
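
The conversion steps in the cell above can be sketched without pyphot as follows. This is a simplified illustration under stated assumptions, not pyphot's implementation: the filter is a hypothetical top-hat between 3 and 4 micron, the band average uses a photon-counting weighting, and constant factors such as the speed of light are ignored. The helper names `trapezoid` and `synthetic_fnu` are invented for this example.

```python
import numpy as np

def trapezoid(y, x):
    """Trapezoidal integration, written out to avoid depending on a NumPy version."""
    return np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x))

def synthetic_fnu(wave, fnu, transmission):
    """Band-averaged F_nu through a filter, ignoring constant factors."""
    flam = fnu / wave**2                                     # F_nu -> F_lambda
    norm = trapezoid(wave * transmission, wave)
    band_flam = trapezoid(wave * transmission * flam, wave) / norm
    pivot_sq = norm / trapezoid(transmission / wave, wave)   # pivot wavelength squared
    return band_flam * pivot_sq                              # F_lambda -> F_nu at the pivot

wave = np.logspace(0.0, 1.9, 2000)
tophat = ((wave > 3.0) & (wave < 4.0)).astype(float)        # hypothetical filter

flat = synthetic_fnu(wave, np.ones_like(wave), tophat)       # a flat spectrum stays flat
line = synthetic_fnu(wave, 1.0 * wave + 1.0, tophat)         # the straight-line model

rng = np.random.default_rng(1)
unc = 0.1 * line                    # 10% fractional uncertainty, as in the cell above
observed = line + rng.normal() * unc
print(flat, line, observed)
```

Multiplying by the square of the pivot wavelength undoes the division by wavelength squared, which is why a spectrum that is flat in F_nu returns the same value after the round trip.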

Now we repeat the process for some spectra. We choose to create a synthetic *Spitzer*/IRS spectrum, and set it to use fast resampling.

In [None]:
dataDir = f'{os.getcwd()}/test_data/'
specFileExample = 'cassis_yaaar_spcfw_14191360t.fits'
irsEx = Spectrum.fromFile(os.path.normpath(dataDir+specFileExample),
                            format='SPITZER-YAAAR')
spec0 = spectres(irsEx[0].wavelength,wavelengths,model_flux)
spec1 = spectres(irsEx[1].wavelength,wavelengths,model_flux)

# And again, add some noise to it
input_noise_spec = 0.1
unc0 = input_noise_spec*spec0
unc1 = input_noise_spec*spec1
spec0 = spec0 + np.random.randn(len(spec0))*unc0
spec1 = spec1 + np.random.randn(len(spec1))*unc1

spec0 = Spectrum(irsEx[0].wavelength, spec0, unc0, "um", "Jy", 
                    calUnc=0.0025, 
                    scaleLengthPrior=0.01)  # , resampleMethod=resmethod)
spec1 = Spectrum(irsEx[1].wavelength, spec1, unc1, "um", "Jy", 
                    calUnc=0.0025, 
                    scaleLengthPrior=0.01)  # , resampleMethod=resmethod)

# Now let's try changing the resampling method so it's faster
# This model is very simple so exact flux conservation is not important
resmethod = "fast"  # the alternative is "exact"
spec0.setResampler(resampleMethod=resmethod)
spec1.setResampler(resampleMethod=resmethod)

Now we combine our synthetic data into a dataset.

In [None]:
dataset = [photometry,
           # spec0, #Fitting spectra is slow because it needs to do a lot of resampling
           spec1   # As a result, we're leaving some of them out
           ]

### Defining an embedding network and doing inference

Now that we have our model and (synthetic) data, the last thing we need before we can do inference is our embedding. 
SBI is built on torch, so we need to define our embedding as a torch neural network, or at least as one that torch understands. 
We *could* do that from scratch, with something like the following:

In [None]:
from torch import Tensor, nn


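class FullyConnectedNetwork(nn.Module):
    '''A simple fully connected embedding network.'''
    def __init__(
        self,
        input_dim,
        output_dim = 20,
        num_layers = 3,
        num_hiddens = 100,
    ):
        super().__init__()
        # The first layer maps the raw data onto the hidden units
        layers = [nn.Linear(input_dim, num_hiddens), nn.ReLU()]
        # The first and last layers are fixed by the input and output sizes,
        # so there are num_layers - 2 hidden-to-hidden layers in between
        for _ in range(num_layers - 2):
            layers.append(nn.Linear(num_hiddens, num_hiddens))
            layers.append(nn.ReLU())
        # The last layer produces the learned summary statistics
        layers.append(nn.Linear(num_hiddens, output_dim))
        layers.append(nn.ReLU())
        self.net = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        return self.net(x)

# The input size is the total number of data points across the dataset
input_dim = int(np.sum([data.wavelength.shape[0] for data in dataset]))
embedding = FullyConnectedNetwork(input_dim)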

You can find some further information on defining embedding networks in the SBI documentation at https://www.mackelab.org/sbi/tutorial/05_embedding_net/.
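
As an illustration of a different architecture, the following is a minimal sketch of a one-dimensional convolutional embedding written from scratch in torch. The class name `TinyCNNEmbedding` and all of its hyperparameters are hypothetical choices made for this example; it is not the CNN that SBI or Ampere provide.

```python
import torch
from torch import nn

class TinyCNNEmbedding(nn.Module):
    """Hypothetical 1D CNN embedding: (batch, n_points) -> (batch, output_dim)."""
    def __init__(self, output_dim=20, channels=8, kernel_size=5):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv1d(1, channels, kernel_size, padding=kernel_size // 2),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(channels, channels, kernel_size, padding=kernel_size // 2),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(4),  # fixed-length output whatever the input length
        )
        self.head = nn.Linear(channels * 4, output_dim)

    def forward(self, x):
        x = x.unsqueeze(1)            # add the channel axis expected by Conv1d
        x = self.features(x)
        return self.head(x.flatten(start_dim=1))

torch.manual_seed(0)
cnn_embedding = TinyCNNEmbedding()
batch = torch.randn(16, 300)          # 16 simulated spectra with 300 points each
summaries = cnn_embedding(batch)
print(tuple(batch.shape), "->", tuple(summaries.shape))
```

Convolutional layers share weights along the wavelength axis, so they can suit spectra where neighbouring points are correlated. The adaptive pooling layer means the same network accepts spectra of different lengths.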

However, SBI also packages a small selection of pre-defined networks that you can use, and we can specify the same network by just importing one of those and passing it the same setup, for example:

In [None]:
from sbi.neural_nets.embedding_nets import FCEmbedding

# Note that FCEmbedding calls the width of the hidden layers `num_hiddens`
input_dim = int(np.sum([data.wavelength.shape[0] for data in dataset]))
embedding = FCEmbedding(input_dim, output_dim=20, num_layers=3, num_hiddens=100)

This is already a lot easier, at least as long as one of the architectures provided by SBI (Fully Connected and Convolutional Neural Networks) is suitable for your problem.

To make life even easier, Ampere can construct an embedding for you, using the architectures packaged by SBI!
You can accept the default hyperparameters for these networks by passing either `embedding_net='CNN'` or `embedding_net='FC'` when instantiating the inference object. Alternatively, you can pass Ampere a dictionary giving the type of network you want to use along with the values of any hyperparameters you want to change:

In [None]:
embedding = {'type': 'FC', 'num_hiddens': 100, 'n_layers': 3, "output_dim": 20}

And now things proceed in a very similar fashion to a normal NPE run! 
The only difference is that our dictionary telling Ampere how to define the embedding is passed along with the model and dataset when instantiating our inference object. 
The dict could also just be replaced with a user-defined network or the strings indicated above for default hyperparameters.

In [None]:
optimizer = SBI_SNPE(model=model, data=dataset, embedding_net=embedding)

# Then we tell it to explore the parameter space
optimizer.optimise(nsamples=10000, nsamples_post=10000, n_rounds=1)

# now we call the postprocessing to produce some figures
optimizer.postProcess()