# Lab 14: Modern Variational Inference
#### [Penn State Astroinformatics Summer School 2022](https://sites.psu.edu/astrostatistics/astroinfo-su22-program/)
#### [Jeffrey Regier](https://regier.stat.lsa.umich.edu/)

In this tutorial, we'll analyze images of stars using modern variational inference and PyTorch. First, let's import some packages that we'll use throughout.

In [None]:
!pip install torch

In [None]:
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.distributions import Poisson, Normal

_ = torch.manual_seed(0)

## Generating the data
For simplicity, and so that we can know the ground truth, we'll use synthetic data. Let's generate it. The next block of code defines a pixelated point spread function (PSF).

In [None]:
n = 16  # the number of images in our dataset
img_dim = 15  # the height and width of our images. must be odd

psf_marginal = 1 + torch.arange(img_dim, dtype=torch.float32)
half_dim = img_dim // 2
psf_marginal[half_dim:] -= 2 * torch.arange(half_dim + 1)
psf = torch.mm(psf_marginal.view(img_dim, 1), psf_marginal.view(1, img_dim))
psf /= psf.sum()

_ = plt.imshow(psf.data)
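The following plain-Python sketch checks the PSF cell. It uses the standard library only, so it runs without torch. It rebuilds the same tent-shaped marginal profile and confirms two properties the rest of the notebook relies on. First, the PSF sums to 1. Second, its peak value is 1/64, so `min_flux = 100 / psf.max()` works out to 6400.

```python
# Stdlib-only re-creation of the triangular PSF, to check it by hand.
img_dim = 15
half_dim = img_dim // 2

# marginal profile 1, 2, ..., 8, 7, ..., 1 (a "tent"), built as in the torch cell
marginal = [1 + i for i in range(img_dim)]
for j in range(half_dim + 1):
    marginal[half_dim + j] -= 2 * j

total = sum(marginal)  # 64
# outer product of the marginal with itself, normalized so all pixels sum to 1
psf = [[a * b / total**2 for b in marginal] for a in marginal]

psf_sum = sum(sum(row) for row in psf)
psf_peak = max(max(row) for row in psf)
min_flux = 100 / psf_peak

print(marginal)                    # [1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1]
print(psf_sum, psf_peak, min_flux)  # 1.0 0.015625 6400.0
```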

In our generative model, the flux for each star follows a Gaussian distribution:

In [None]:
min_flux = 100 / psf.max()
flux_prior = Normal(10 * min_flux, 2 * min_flux)

To generate our synthetic dataset, let's draw the "true" fluxes of `n` stars. These are the latent values we'll subsequently aim to infer.

In [None]:
true_fluxes = flux_prior.sample([n])
print(f"flux mean: {true_fluxes.mean().item()}")
print(f"flux sd: {true_fluxes.std().item()}")
print(f"flux min: {true_fluxes.min().item()}")
print(f"flux max: {true_fluxes.max().item()}")

In a realistic model of images of stars, a fixed background intensity is added to the flux-scaled PSF to give the expected intensity of each pixel:

In [None]:
background_intensity = 3 * min_flux
star_intensity = true_fluxes.view(n, 1, 1) * psf.view(1, img_dim, img_dim)
true_intensity = background_intensity + star_intensity
_ = plt.imshow(true_intensity[0])
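Because the PSF sums to 1, summing the expected intensity over all pixels gives two terms: the number of pixels times the background, plus the star's flux. The plain-Python sketch below checks this using the notebook's settings and a hypothetical flux of 64,000 (the prior mean). This identity is what justifies estimating a flux by subtracting the total background from an image's pixel sum.

```python
# Stdlib-only check that total expected intensity = n_pixels * background + flux.
img_dim = 15
marginal = [min(i + 1, img_dim - i) for i in range(img_dim)]  # 1, ..., 8, ..., 1
total = sum(marginal)
psf = [[a * b / total**2 for b in marginal] for a in marginal]

min_flux = 6400.0                      # 100 / psf.max() for this PSF
background_intensity = 3 * min_flux
flux = 64000.0                         # hypothetical flux, chosen equal to the prior mean

# expected intensity of each pixel: background plus flux-scaled PSF
intensity = [[background_intensity + flux * p for p in row] for row in psf]

n_pix = img_dim * img_dim
recovered = sum(sum(row) for row in intensity) - n_pix * background_intensity
print(recovered)
```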

Now let's draw some images of stars with the fluxes we've previously sampled:

In [None]:
images = Poisson(true_intensity).sample()
_ = plt.imshow(images[0])
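For each pixel, `Poisson(rate).log_prob` evaluates the Poisson log-probability $k \log \lambda - \lambda - \log k!$. The stdlib-only sketch below implements that formula; `poisson_log_pmf` is a helper name introduced here and is not part of the notebook. The sketch also shows why the formula is evaluated in log space. At the rates in these images, the background alone contributes about 19,200 expected counts per pixel. At that rate $e^{-\lambda}$ underflows to zero in double precision.

```python
import math


def poisson_log_pmf(k, rate):
    """log p(k | rate) = k log(rate) - rate - log(k!), with log(k!) = lgamma(k + 1)."""
    return k * math.log(rate) - rate - math.lgamma(k + 1)


# The pmf over k = 0, 1, 2, ... sums to 1 (truncated far into the tail here).
rate = 3.0
total_prob = sum(math.exp(poisson_log_pmf(k, rate)) for k in range(100))
print(total_prob)

# At a background-level rate the log-probability is an ordinary number...
print(poisson_log_pmf(19200, 19200.0))
# ...but the factor exp(-rate) on its own underflows to 0.0.
print(math.exp(-19200.0))
```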

## Numerical integration

Numerical integration is a precursor to variational inference.
Numerical integration approximates an integral in three steps. It partitions the domain into a grid of bins, evaluates the integrand at each grid point, and sums these values weighted by the bin width (a Riemann sum).
In Bayesian inference, the integrand is the joint distribution: $$p(\mathrm{fluxes}, \mathrm{images}) = p(\mathrm{fluxes}) \, p(\mathrm{images} \mid \mathrm{fluxes}).$$
Integrating out the fluxes gives us $p(\mathrm{images}).$
Then, using Bayes' rule, we can solve for the posterior:
$$p(\mathrm{fluxes} \mid \mathrm{images}) = \frac{p(\mathrm{images} \mid \mathrm{fluxes}) \, p(\mathrm{fluxes})}{p(\mathrm{images})}.$$
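The plain-Python sketch below applies the grid recipe to a problem with a known answer. It uses a hypothetical one-dimensional conjugate model: a normal prior and a single normally distributed observation, not the star model. The evidence is the Riemann sum of the joint density times the bin width, and the posterior is the joint divided by that evidence. Both can be checked against closed-form results.

```python
import math


def normal_log_pdf(v, mean, sd):
    return -0.5 * math.log(2 * math.pi * sd**2) - 0.5 * ((v - mean) / sd) ** 2


# Hypothetical toy model: z ~ N(0, 2^2), x | z ~ N(z, 1^2), observed x = 1.5.
prior_mean, prior_sd, noise_sd, x = 0.0, 2.0, 1.0, 1.5

# Grid over z, wide enough to cover essentially all of the posterior mass.
bin_width = 0.01
grid = [-10.0 + i * bin_width for i in range(2001)]

# Joint density p(z, x) = p(z) p(x | z) at each grid point.
joint = [math.exp(normal_log_pdf(z, prior_mean, prior_sd) + normal_log_pdf(x, z, noise_sd))
         for z in grid]

# Riemann sum: p(x) ~= sum over grid of p(z, x) * bin_width.
evidence = sum(joint) * bin_width

# Bayes' rule on the grid: posterior density = joint / evidence.
posterior = [j / evidence for j in joint]
post_mean = sum(z * p for z, p in zip(grid, posterior)) * bin_width

# Closed-form answers for this conjugate model.
exact_evidence = math.exp(normal_log_pdf(x, prior_mean, math.sqrt(prior_sd**2 + noise_sd**2)))
post_prec = 1 / prior_sd**2 + 1 / noise_sd**2
exact_post_mean = (prior_mean / prior_sd**2 + x / noise_sd**2) / post_prec

print(evidence, exact_evidence)
print(post_mean, exact_post_mean)
```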

In [None]:
bin_width = 100
grid_size = 5000
flux_grid = min_flux + torch.arange(grid_size) * bin_width

rate = psf.view(1, img_dim, img_dim, 1) * flux_grid.view(1, 1, 1, grid_size)
rate += background_intensity

# conditional log likelihood (for each observed image and each flux grid point)
images4d = images.view(n, img_dim, img_dim, 1)
log_p_images_given_fluxes = Poisson(rate).log_prob(images4d).sum([1, 2])
assert log_p_images_given_fluxes.shape == (n, grid_size)

# joint log likelihood
log_p_fluxes_and_images = log_p_images_given_fluxes + flux_prior.log_prob(flux_grid)
assert log_p_fluxes_and_images.shape == (n, grid_size)

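# log posterior density: subtract the log evidence log p(images), approximated by a
# Riemann sum over the grid (joint densities times bin_width, combined with logsumexp)
log_p_images = log_p_fluxes_and_images.logsumexp(1, keepdim=True) + torch.tensor(float(bin_width)).log()
log_p_fluxes_given_images = log_p_fluxes_and_images - log_p_images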
assert log_p_fluxes_given_images.shape == (n, grid_size)
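The grid computations stay in log space for a reason. Each log joint density sums Poisson log-probabilities over 225 pixels, each around -6 at these count levels, and the result is far too negative to exponentiate in floating point. The stdlib-only sketch below uses made-up log-joint values of a similar magnitude. It shows the naive sum of exponentials underflowing to zero, and the standard log-sum-exp trick normalizing the same values correctly. PyTorch provides the same operation as `torch.logsumexp`.

```python
import math


def logsumexp(values):
    """Stable log(sum(exp(v))): factor out the max so the largest term is exp(0) = 1."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


# Hypothetical log-joint values on a 3-point grid, of the magnitude produced
# by summing Poisson log-probabilities over a few hundred pixels.
log_joint = [-1500.0, -1498.0, -1503.0]
bin_width = 100.0

naive_sum = sum(math.exp(v) for v in log_joint)   # underflows to 0.0
log_evidence = logsumexp(log_joint) + math.log(bin_width)
log_posterior = [v - log_evidence for v in log_joint]

# posterior density times bin width = probability mass per grid cell; sums to 1
masses = [math.exp(lp) * bin_width for lp in log_posterior]
print(naive_sum)
print(sum(masses))
```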

Comparing point estimates to the ground truth is one way to assess how well various inference methods are performing.

In [None]:
def flux_rmse(est_fluxes):
    return (true_fluxes - est_fluxes).pow(2).mean().sqrt()

print(f"prior mean RMSE: {flux_rmse(flux_prior.mean)}")

ss_flux = (images - background_intensity).sum([1,2])
print(f"sky subtracted RMSE: {flux_rmse(ss_flux)}")

mle_flux = flux_grid[log_p_images_given_fluxes.argmax(1)]
print(f"grid MLE RMSE: {flux_rmse(mle_flux)}")

map_flux = flux_grid[log_p_fluxes_and_images.argmax(1)]
print(f"grid MAP RMSE: {flux_rmse(map_flux)}")

posterior_mean = (log_p_fluxes_given_images.softmax(1) * flux_grid.view(1, grid_size)).sum(1)
print(f"grid posterior mean RMSE: {flux_rmse(posterior_mean)}")

## Variational inference
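In variational inference, we attempt to find a distribution $q(\mathrm{fluxes})$ that minimizes $$\mathrm{KL}(q(\mathrm{fluxes})\, \| \, p(\mathrm{fluxes} \mid \mathrm{images})).$$
Since $$\log p(\mathrm{images}) = \mathbb{E}_{q}\left[\log p(\mathrm{fluxes}, \mathrm{images}) - \log q(\mathrm{fluxes})\right] + \mathrm{KL}(q(\mathrm{fluxes})\, \| \, p(\mathrm{fluxes} \mid \mathrm{images}))$$ and the left-hand side does not depend on $q$, minimizing the KL divergence is equivalent to maximizing the first term on the right. That term is the evidence lower bound (ELBO). In practice, we minimize a Monte Carlo estimate of the negative ELBO.
Below, we restrict $q$ to the class of $n$-dimensional multivariate normal distributions that have a diagonal covariance matrix.
The approximation $q$ is parameterized by a separate mean and standard deviation for each image.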
We compute stochastic gradients of the objective function using the reparameterization trick, and use stochastic gradient descent for optimization.
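The identity behind the objective is $\log p(\mathrm{images}) = \mathrm{ELBO}(q) + \mathrm{KL}(q \,\|\, p(\mathrm{fluxes} \mid \mathrm{images}))$. It implies two facts: the ELBO never exceeds the log evidence, and it equals the log evidence exactly when $q$ is the true posterior. The stdlib-only sketch below checks both facts on a hypothetical one-dimensional conjugate model with a normal prior and one normal observation. It estimates the ELBO by averaging $\log p(x, z) - \log q(z)$ over reparameterized samples $z = \mu + \sigma \epsilon$.

```python
import math
import random


def normal_log_pdf(v, mean, sd):
    return -0.5 * math.log(2 * math.pi * sd**2) - 0.5 * ((v - mean) / sd) ** 2


# Hypothetical toy model: z ~ N(0, 2^2), x | z ~ N(z, 1^2), observed x = 1.5.
prior_mean, prior_sd, noise_sd, x = 0.0, 2.0, 1.0, 1.5
log_evidence = normal_log_pdf(x, prior_mean, math.sqrt(prior_sd**2 + noise_sd**2))

# exact (conjugate) posterior
post_var = 1 / (1 / prior_sd**2 + 1 / noise_sd**2)
post_mean = post_var * (prior_mean / prior_sd**2 + x / noise_sd**2)
post_sd = math.sqrt(post_var)


def elbo_estimate(q_mean, q_sd, num_samples, rng):
    """Monte Carlo ELBO with reparameterized samples z = q_mean + q_sd * eps."""
    total = 0.0
    for _ in range(num_samples):
        z = q_mean + q_sd * rng.gauss(0.0, 1.0)
        log_joint = normal_log_pdf(z, prior_mean, prior_sd) + normal_log_pdf(x, z, noise_sd)
        total += log_joint - normal_log_pdf(z, q_mean, q_sd)
    return total / num_samples


rng = random.Random(0)
elbo_exact_q = elbo_estimate(post_mean, post_sd, 1000, rng)         # q = true posterior
elbo_bad_q = elbo_estimate(post_mean + 1.0, post_sd * 0.5, 1000, rng)  # mismatched q
print(log_evidence, elbo_exact_q, elbo_bad_q)
```

With the exact posterior, every single sample already equals the log evidence, so the estimate has zero variance. The mismatched $q$ falls short by its KL divergence, up to Monte Carlo noise.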

In [None]:
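# initialize each mean at its sky-subtracted flux plus an offset of 5000
q_mean = nn.Parameter((images - background_intensity).sum([1,2]) + 5000)
q_sd = nn.Parameter(torch.ones(n) * 100)

optimizer = torch.optim.SGD([q_mean, q_sd], lr=100)
num_samples = 64  # number of samples of q per image

for i in range(3000):
    # clamp keeps the standard deviation positive
    q = Normal(q_mean, q_sd.clamp(1e-4))
    # reparameterized samples of shape (num_samples, n); gradients flow to q_mean and q_sd
    z = q.rsample((num_samples,))
    zt = z.permute(1, 0)

    # expected pixel intensities, shape (n, img_dim, img_dim, num_samples)
    rate = psf.view(1, img_dim, img_dim, 1) * zt.view(n, 1, 1, num_samples)
    rate += background_intensity
    cond_ll = Poisson(rate)

    # Monte Carlo estimate of the negative ELBO, summed over images and samples:
    # log q(z) - log p(z) - log p(images | z)
    neg_elbo = q.log_prob(z).sum()
    neg_elbo -= flux_prior.log_prob(z).sum()
    neg_elbo -= cond_ll.log_prob(images.view(n, img_dim, img_dim, 1)).sum()

    if i % 200 == 0:
        obj = neg_elbo.item() / num_samples  # average over samples
        rmse = flux_rmse(q_mean).item()
        print(f"[{i}] objective: {obj}   rmse: {rmse}")

    optimizer.zero_grad()
    neg_elbo.backward()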
    optimizer.step()
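Autograd hides the gradient computation in the loop above. For intuition, here is a hand-derived version of the same reparameterization-gradient ascent in plain Python. It runs on a hypothetical one-dimensional conjugate model with a normal prior and one normal observation. Two departures from the notebook are deliberate assumptions. First, the standard deviation is parameterized on the log scale instead of being clamped. Second, each step averages 100 samples. Under those choices, the iterates should settle near the closed-form posterior mean and standard deviation.

```python
import math
import random

# Hypothetical toy model: z ~ N(0, 2^2), x | z ~ N(z, 1^2), observed x = 1.5.
prior_mean, prior_sd, noise_sd, x = 0.0, 2.0, 1.0, 1.5


def dlog_joint_dz(z):
    """d/dz [log p(z) + log p(x | z)] for the Gaussian prior and likelihood."""
    return -(z - prior_mean) / prior_sd**2 + (x - z) / noise_sd**2


rng = random.Random(0)
q_mean, q_log_sd = 0.0, 0.0   # variational parameters; sd = exp(log_sd) is an assumed parameterization
lr, num_samples = 0.05, 100

for step in range(2000):
    q_sd = math.exp(q_log_sd)
    grad_mean, grad_log_sd = 0.0, 0.0
    for _ in range(num_samples):
        eps = rng.gauss(0.0, 1.0)
        z = q_mean + q_sd * eps         # reparameterization trick
        g = dlog_joint_dz(z)
        grad_mean += g                  # chain rule: dz/d(mean) = 1
        grad_log_sd += g * eps * q_sd   # chain rule: dz/d(log_sd) = eps * sd
    grad_mean /= num_samples
    # entropy of q is log(sd) + const, whose derivative w.r.t. log_sd is exactly 1
    grad_log_sd = grad_log_sd / num_samples + 1.0
    # gradient ascent on the ELBO, i.e. descent on the negative ELBO
    q_mean += lr * grad_mean
    q_log_sd += lr * grad_log_sd

# closed-form posterior for comparison
post_var = 1 / (1 / prior_sd**2 + 1 / noise_sd**2)
post_mean = post_var * (prior_mean / prior_sd**2 + x / noise_sd**2)
print(q_mean, post_mean)
print(math.exp(q_log_sd), math.sqrt(post_var))
```

The `+ 1.0` entropy term is what keeps $q$ from collapsing to a point mass. Without it, the expected log-sd gradient is always negative.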

The approach above can be slow because it effectively requires solving a separate optimization problem for each image. Amortized inference is more efficient for large datasets. In amortized inference, a shared neural network, called an encoder, specifies the approximating distribution for each of the $n$ images. Once trained, the encoder produces the approximation for any image, including a new one, with a single forward pass.

In [None]:
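class StarEncoder(nn.Module):
    """Maps each image to a normal approximation q(flux | image)."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(img_dim * img_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, 2),  # two outputs per image: mean and log standard deviation
        )

    def forward(self, x):
        x = x.view(-1, img_dim ** 2)  # flatten each image into a vector
        out = self.net(x)
        q_mean = out[:, 0]
        # second output is a log standard deviation; clamping bounds the sd to [e^-6, e^6]
        q_sd = out[:, 1].clamp(-6, 6).exp()
        return Normal(q_mean, q_sd)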

encoder = StarEncoder()
optimizer = torch.optim.SGD(encoder.parameters(), lr=1e-3)

mb = 8   # minibatch size
num_samples = 64  # number of samples of q per image in the minibatch

for i in range(30000):
    # draw a random minibatch of images (with replacement)
    indices = torch.randint(n, (mb,))
    x = images[indices]

    q = encoder(x)

    z = q.rsample((num_samples,))
    zt = z.permute(1, 0)

    # clamp(0) zeroes the star contribution of any negative sampled flux,
    # so the Poisson rate is at least the background intensity
    rate = psf.view(1, img_dim, img_dim, 1) * zt.view(mb, 1, 1, num_samples)
    rate = rate.clamp(0) + background_intensity
    cond_ll = Poisson(rate)

    neg_elbo = q.log_prob(z).sum()
    neg_elbo -= flux_prior.log_prob(z).sum()
    neg_elbo -= cond_ll.log_prob(x.view(mb, img_dim, img_dim, 1)).sum()

    if i % 500 == 0:
        # rescale the minibatch estimate to all n images, averaged over samples
        obj = neg_elbo.item() * n / (mb * num_samples)
        rmse = flux_rmse(encoder(images).mean).item()
        print(f"[{i}] objective: {obj}   rmse: {rmse}")

    optimizer.zero_grad()
    neg_elbo.backward()
    optimizer.step()
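The encoder above is one instance of amortization: a single set of parameters maps every image to its own approximate posterior. The minimal plain-Python analogue below uses a hypothetical one-dimensional conjugate model, where the exact posterior mean is linear in the observation. It therefore replaces the neural network with a linear map $\mu(x) = wx + b$ and a shared $\log \sigma$. These parameters are trained by minibatch SGD on reparameterized ELBO gradients. If the sketch is right, $w$ should approach the closed-form slope of 0.8.

```python
import math
import random

# Hypothetical toy model: z ~ N(0, 2^2), x | z ~ N(z, 1^2). Its exact posterior is
# N(0.8 * x, 0.8), so a linear encoder can represent it exactly.
prior_mean, prior_sd, noise_sd = 0.0, 2.0, 1.0

rng = random.Random(0)
data = [rng.gauss(rng.gauss(prior_mean, prior_sd), noise_sd) for _ in range(2000)]


def dlog_joint_dz(z, x):
    """d/dz [log p(z) + log p(x | z)] for one datapoint."""
    return -(z - prior_mean) / prior_sd**2 + (x - z) / noise_sd**2


w, b, log_sd = 0.0, 0.0, 0.0   # encoder parameters, shared by every datapoint
lr, mb = 0.02, 32

for step in range(4000):
    batch = [rng.choice(data) for _ in range(mb)]  # minibatch, sampled with replacement
    sd = math.exp(log_sd)
    gw = gb = gls = 0.0
    for x in batch:
        eps = rng.gauss(0.0, 1.0)
        z = w * x + b + sd * eps        # one reparameterized sample per datapoint
        g = dlog_joint_dz(z, x)
        gw += g * x                     # dz/dw = x
        gb += g                         # dz/db = 1
        gls += g * eps * sd + 1.0       # dz/dlog_sd = eps * sd; +1 from the entropy term
    # ascend the minibatch ELBO, averaged per datapoint
    w += lr * gw / mb
    b += lr * gb / mb
    log_sd += lr * gls / mb

post_var = 1 / (1 / prior_sd**2 + 1 / noise_sd**2)
print(w, post_var / noise_sd**2)            # learned vs exact slope
print(b, math.exp(log_sd), math.sqrt(post_var))
```

The sketch averages gradients over the minibatch, whereas the notebook sums them. The two differ only by a constant factor that can be absorbed into the learning rate.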