# Likelihood-free Inference

_Alex Malz (LINCC@CMU)_
_LSSTC Data Science Fellowship Program_



In [None]:
# !pip install --quiet jax-cosmo numpyro dm-haiku optax sbi chainconsumer tensorflow-probability numpyro lenstools
# !pip install --quiet git+https://github.com/EiffL/powerbox-jax.git

In [None]:
import jax
import jax_cosmo as jc
import jax.numpy as jnp

import numpyro
from numpyro.handlers import seed, trace, condition
import numpyro.distributions as dist

import tensorflow_probability as tfp; tfp = tfp.substrates.jax
tfd = tfp.distributions

import haiku as hk
seq = hk.PRNGSequence(42)

import torch
from sbi import utils as utils
from sbi import analysis as analysis
from sbi.inference.base import infer

import powerbox_jax as pbj

from chainconsumer import ChainConsumer

import lenstools as lt
import astropy.units as u

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pickle

## Theory

When explicit likelihoods are unavailable, we can't do MCMC sampling, so we need alternatives.
Those that use an _implicit_ likelihood, in the form of any process that takes in the parameters we want to constrain and outputs realizations of data, form a family of closely related methods called Likelihood-free Inference (LFI), which may also be referred to as Simulation-based Inference (SBI) and which includes Approximate Bayesian Computation (ABC).

![overview](01_algorithms_tikz.png "terminology")

This diagram from [Lueckmann+ 2021](https://arxiv.org/abs/2101.04653), via [Lanusse 2022](https://eiffl.github.io/talks/EAS2022/), illustrates the nuances between terminology for these methods, which seem to change every few years, but the principle behind them is the same.
Instead of evaluating the likelihood of proposed parameters, a black-box implicit likelihood, typically a simulator or emulator, is used to forward-model mock data that is then compared with the real data.
In both approaches, sampled parameters that are accepted are used to generate subsequent samples. 

## Context

Let's try to learn the cosmological parameters from summary statistics of weak lensing mass maps, with and without using the likelihood.
This problem is distilled from [a tutorial](https://colab.research.google.com/drive/1K8cB1h3ge3kTVut81Xnkw2kNiKFIn8HI?usp=sharing) by Francois Lanusse.

## LensingForwardModelLogNormal

In [None]:
def make_power_map(pk_fn, N, field_size, zero_freq_val=0.0):
    """Evaluate an isotropic power spectrum on the 2D FFT frequency grid.

    Args:
        pk_fn: Callable mapping an array of wavenumber magnitudes to an array
            of power values with the same shape.
        N: Number of pixels along each side of the square grid.
        field_size: Side length of the field; the sample spacing handed to
            ``fftfreq`` is ``field_size / N``.
        zero_freq_val: Value written into the zero-frequency entry ``[0, 0]``
            before the final rescaling.

    Returns:
        An ``(N, N)`` array equal to ``pk_fn(|k|)`` (with the ``[0, 0]`` entry
        replaced) multiplied by ``(N / field_size)**2``.
    """
    # Angular wavenumbers along one axis, in FFT ordering
    freqs = 2 * jnp.pi * jnp.fft.fftfreq(N, d=field_size / N)
    kx, ky = jnp.meshgrid(freqs, freqs)
    k_norm = jnp.sqrt(kx**2 + ky**2)

    # Evaluate the spectrum, then overwrite the zero-frequency mode
    power = pk_fn(k_norm).at[0, 0].set(zero_freq_val)
    return power * (N / field_size)**2

def make_lognormal_power_map(power_map, shift, zero_freq_val=0.0):
    """Transform a power map for use with the shifted log-normal field model.

    The map is taken to real space with an inverse FFT, passed through
    ``log(1 + x / shift**2)``, and transformed back with a forward FFT, of
    which only the magnitude is kept.

    Args:
        power_map: 2D array of power values on the FFT frequency grid.
        shift: Shift parameter of the log-normal transform.
        zero_freq_val: Value written into the zero-frequency entry ``[0, 0]``
            of the result. Defaults to 0.

    Returns:
        A 2D array with the same shape as ``power_map``.
    """
    # Real-space counterpart of the power map (presumably the correlation function)
    real_space = jnp.fft.ifft2(power_map).real
    log_real_space = jnp.log(1 + real_space / shift**2)
    transformed = jnp.abs(jnp.fft.fft2(log_real_space))
    # Honor the caller-supplied zero-frequency value instead of a hard-coded 0
    return transformed.at[0, 0].set(zero_freq_val)

def model(N=128,               # number of pixels on the map
          map_size=10,         # map size in deg.
          gal_per_arcmin2=10,
          sigma_e=0.2,
          shift=0.05,
          model_type='lognormal'): # either 'lognormal' or 'gaussian'
    """Numpyro forward model producing a noisy N x N lensing map.

    Sample sites, in order: ``omega_c`` and ``sigma_8`` (standard-normal
    latents, rescaled as ``0.3 + 0.05 * omega_c`` and ``0.8 + 0.05 * sigma_8``),
    ``z`` (the N x N latent variables) and ``x`` (the noisy map).

    Args:
        N: Number of pixels along each side of the map.
        map_size: Side length of the map in degrees.
        gal_per_arcmin2: Galaxy density entering the per-pixel noise level.
        sigma_e: Numerator of the per-pixel noise standard deviation.
        shift: Shift parameter of the log-normal transform.
        model_type: ``'lognormal'`` applies the log-normal transform; any
            other value leaves the field untransformed.

    Returns:
        The value of the ``x`` sample site, an ``(N, N)`` array.
    """
    is_lognormal = model_type == 'lognormal'

    # The pixel area uses the size in degrees, so compute it before converting
    pix_area = (map_size * 60 / N)**2          # arcmin2
    map_size_rad = map_size / 180 * jnp.pi     # radians

    # Cosmological parameters, sampled as standardized latents and rescaled
    omega_c = numpyro.sample('omega_c', dist.Normal(0., 1.0)) * 0.05 + 0.3
    sigma_8 = numpyro.sample('sigma_8', dist.Normal(0., 1.0)) * 0.05 + 0.8
    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma_8)

    # Weak-lensing probe built on a Smail redshift distribution
    nz = jc.redshift.smail_nz(0.5, 2., 1.0)
    lensing_probe = jc.probes.WeakLensing([nz])

    # Tabulate the Cls once and interpolate, rather than evaluating them per pixel
    ell_tab = jnp.logspace(0, 4.5, 128)
    cell_tab = jc.angular_cl.angular_cl(cosmo, ell_tab, [lensing_probe])[0]

    def interpolated_cl(k):
        flat_cl = jc.scipy.interpolate.interp(k.flatten(), ell_tab, cell_tab)
        return flat_cl.reshape(k.shape)

    # Latent variables
    z = numpyro.sample('z', dist.MultivariateNormal(loc=jnp.zeros((N, N)),
                                                    precision_matrix=jnp.eye(N)))

    # Power map on the Fourier grid, adjusted for the log-normal case
    power_map = make_power_map(interpolated_cl, N, map_size_rad)
    if is_lognormal:
        power_map = make_lognormal_power_map(power_map, shift)

    # Colour the latent variables with the square root of the power map
    field = jnp.fft.ifft2(jnp.fft.fft2(z) * jnp.sqrt(power_map)).real
    if is_lognormal:
        field = shift * (jnp.exp(field - jnp.var(field) / 2) - 1)

    # "Observational noise", independent per pixel
    noise_std = sigma_e / jnp.sqrt(gal_per_arcmin2 * pix_area)
    return numpyro.sample('x', dist.Independent(dist.Normal(field, noise_std), 2))

In [None]:
# Fiducial observation: pin both standardized cosmology latents to 0
# (i.e. Omega_c = 0.3, sigma_8 = 0.8 after rescaling in `model`), seed the model, and draw a map
fiducial_model = condition(model, data={'omega_c': 0., 'sigma_8': 0.})
sample_map_fiducial = seed(fiducial_model, rng_seed=jax.random.PRNGKey(42))
m_data = sample_map_fiducial()

In [None]:
# Display the fiducial mock map; the 0-10 extent matches the default map_size of 10 deg.
plt.imshow(m_data, extent=(0,10,0,10))
plt.colorbar()

In [None]:
# Call the seeded fiducial model again and display the resulting map
other_m_data = sample_map_fiducial()
plt.imshow(other_m_data, extent=(0,10,0,10))
plt.colorbar()

In [None]:
# Checking that the power spectrum looks ok with Lenstools
cosmo = jc.Planck15(Omega_c=0.3, sigma8=0.8)
# Same redshift distribution parameters as used inside `model`
pz = jc.redshift.smail_nz(0.5, 2., 1.0)
tracer = jc.probes.WeakLensing([pz])

# Binned power spectrum measured on the mock map (10 deg. on a side)
kmap_lt = lt.ConvergenceMap(m_data, 10*u.deg)
l_edges = np.arange(100.0,3000.0,100.0)
l2,Pl2 = kmap_lt.powerSpectrum(l_edges)

# Theory prediction evaluated at the measured multipoles
cell = jc.angular_cl.angular_cl(cosmo, l2, [tracer])[0]
plt.plot(l2, cell, label='Theory')
# Label the measured curve too, so the legend identifies both lines
plt.plot(l2, Pl2, label='Measured')
plt.loglog()
plt.legend()

# yikes, it doesn't look good at all

In [None]:
# Condition the model on the observed map by fixing the 'x' sample site
observed_model = condition(model, data={'x': m_data})

### Problem 0

MCMC sample this -- __NOT!__

In [None]:
# nuts_kernel = numpyro.infer.NUTS(observed_model,
#                                  init_strategy=numpyro.infer.init_to_median,
#                                  max_tree_depth=6,
#                                  step_size=0.02)

In [None]:
# mcmc = numpyro.infer.MCMC(nuts_kernel, 
#                           num_warmup=100, 
#                           num_samples=1000)

In [None]:
# # very slow on CPU (hours), might be faster on GPU (hour)
# mcmc.run(jax.random.PRNGKey(3))

In [None]:
# res = mcmc.get_samples()

# # # Saving the trace
# # with open('lensing_fwd_mdl_lognorm.pickle', 'wb') as handle:
# #     pickle.dump(res, handle, protocol=pickle.HIGHEST_PROTOCOL)
    
# # with open('lensing_fwd_mdl_lognorm.pickle', 'rb') as handle:
# #     res = pickle.load(handle)

In [None]:
# plt.figure(figsize=[10,10])
# plt.scatter(res['omega_c']* 0.05 + 0.3, res['sigma_8']* 0.05 + 0.8, c=np.arange(len(res['sigma_8'])));
# plt.axvline(0.3)
# plt.axhline(0.8)
# plt.xlabel('Omega_c')
# plt.ylabel('sigma_8')

In [None]:
# plt.imshow(res['z'].mean(axis=0),vmin=-4,vmax=4); plt.colorbar()

In [None]:
# conditional_model = condition(model, {'z': res['z'].mean(axis=0), 'omega_c': 0., 'sigma_8': 0.})

In [None]:
# sample_map_rec = seed(conditional_model, jax.random.PRNGKey(2))
# m_data_rec = sample_map_rec()

In [None]:
# plt.imshow(m_data_rec)

In [None]:
# model_trace = trace(sample_map_rec).get_trace()

In [None]:
# plt.imshow(model_trace['x']['fn'].mean)

## DemoSBI

In [None]:
# Fiducial cosmology with jax-cosmo's default parameters
cosmo = jc.Planck15()

# Smail redshift distribution, wrapped in a weak-lensing probe
nz = jc.redshift.smail_nz(1., 2, 0.75, gals_per_arcmin2=6)
tracer = jc.probes.WeakLensing([nz])

# Parameter vector (Omega_c, sigma_8) at the fiducial cosmology
fid_params = np.array([cosmo.Omega_c, cosmo.sigma8])

# 20 log-spaced multipoles between 100 and 2000
ell = jnp.logspace(2, np.log10(2_000), 20)

# Mean and Gaussian covariance of the Cls for this cosmology and tracer
mu, cov = jc.angular_cl.gaussian_cl_covariance_and_mean(
    cosmo, ell, [tracer], f_sky=0.125
)

In [None]:
# Draw one noisy Cl data vector for a given cosmology
@jax.jit
def sample_likelihood(params, key):
    """Sample a Cl data vector from a diagonal Gaussian around the theory Cls.

    Relies on the module-level ``ell``, ``tracer`` and ``cov``; only the
    diagonal of ``cov`` is used, and it is not recomputed for ``params``.

    Args:
        params: Sequence whose first two entries are ``Omega_c`` and ``sigma8``.
        key: JAX PRNG key used to draw the sample.

    Returns:
        One sample of the flattened angular power spectrum.
    """
    cosmology = jc.Planck15(Omega_c=params[0], sigma8=params[1])
    mean_cl = jc.angular_cl.angular_cl(cosmology, ell, [tracer]).flatten()
    sigma_cl = jnp.sqrt(jnp.diag(cov))
    noisy_cl = tfd.MultivariateNormalDiag(loc=mean_cl, scale_diag=sigma_cl)
    return noisy_cl.sample(seed=key)

In [None]:
# Noisy Cl realizations at the fiducial cosmology, one per PRNG seed
for realization in range(3):
    plt.plot(ell, sample_likelihood(fid_params, jax.random.PRNGKey(realization)))
plt.loglog()

# Overlay the mean Cls as a dashed line
plt.plot(ell, mu, '--')

In [None]:
# Flat prior on [0.1, 1] for each of the two parameters
num_dim = 2
prior = utils.BoxUniform(
    low=torch.full((num_dim,), 0.1),
    high=torch.ones(num_dim),
)

def simulator(parameter_set):
    """Simulate one Cl data vector for a torch tensor of parameters.

    Args:
        parameter_set: Torch tensor of parameters, converted to NumPy before
            being passed to ``sample_likelihood``.

    Returns:
        A NumPy array holding one draw from ``sample_likelihood``, using the
        next key from the module-level ``seq``.
    """
    theta = parameter_set.cpu().detach().numpy()
    # np.asarray converts the JAX output to NumPy on any JAX version;
    # the `.to_py()` method is not available on newer `jax.Array` objects.
    return np.asarray(sample_likelihood(theta, next(seq)))

We can use this canned SBI package to perform the inference, but be aware that it takes minutes to run on a laptop.

In [None]:
# Runs the simulator 1000 times and fits an SNPE posterior (takes minutes on a laptop)
posterior = infer(
    simulator,
    prior,
    method="SNPE",
    num_simulations=1000,
)

In [None]:
# Mock "observed" data vector at the fiducial cosmology.
# np.asarray replaces `.to_py()`, which is not available on newer `jax.Array` objects.
observation = np.asarray(sample_likelihood(fid_params, jax.random.PRNGKey(0)))

In [None]:
# Draw posterior samples given the observation, and evaluate their log-probability
samples = posterior.sample(sample_shape=(10_000,), x=observation)
log_probability = posterior.log_prob(samples, x=observation)

In [None]:
# Corner plot of the SBI posterior samples, with the fiducial parameters marked as truth
c = ChainConsumer()
c.add_chain(
    samples.cpu().detach().numpy(),
    # Raw strings keep the LaTeX backslashes without invalid-escape warnings
    parameters=[r"$\Omega_c$", r"$\sigma_8$"],
    name='SBI',
)

fig = c.plotter.plot(figsize="column", truth=fid_params)

# backup idea

## Simulation

BPZ --> photometry and redshifts

distance function of any of the three point estimate statistics

## Emulation

pzflow --> photometry and redshifts

distance function could be any of the three point estimate statistics

## 

## LensingSimulator (currently broken)