This is section 5 in my series on using Variational Inference to speed up relatively complex Bayesian models like Multilevel Regression and Poststratification without the approximation being of disastrously poor quality. 

The general structure for this post and the ones around it will be to describe a problem with VI, and then describe how that problem can be fixed to some degree. Collectively, all the small improvements in these four posts will go a long way towards more robust variational inference. I'll also have a grab bag at the end of other interesting ideas from the literature I think are cool, but maybe not as important or interesting to me as the 3 below.

In the [last post](https://andytimm.github.io/posts/Variational%20MRP%20Pt4/variational_mrp_4.html) we saw a variety of different ways importance sampling can be used to improve VI and make it more robust, from defining a tighter bound to optimize in the importance weighted ELBO, to weighting $q(x)$ samples together efficiently to look more like $p(x)$, to combining entirely different variational approximations together to cover different parts of the posterior with multiple importance sampling.

In this post, we'll tackle the problem of how to define a deeply flexible variational
family $\mathscr{Q}$ that can adapt to each problem while still being easy to sample from.
To do this, we'll draw on normalizing flows, a technique for defining a composition
of invertible transformations on top of a simple base distribution like a normal
distribution. We'll build our way up to using increasingly complex neural networks
to define those transformations, allowing for truly complex variational
families that are problem adaptive, training as we train our variational model.

The rough plan for the series is as follows:

1.  Introducing the Problem- Why is VI useful, why VI can produce spherical cows
2.  How far does iteration on classic VI algorithms like mean-field and full-rank get us?
3.  Problem 1: KL-D prefers exclusive solutions; are there alternatives?
4.  Problem 2: Not all VI samples are of equal utility; can we weight them cleverly?
5.  **(This post)** Problem 3: How can we get deeply flexible variational approximations; are Normalizing Flows the answer?
6. Problem 4: How can we know when VI is wrong? Are there useful error bounds?
7. Better grounded diagnostics and workflow

# A problem adaptive variational family with less tinkering?

![](images/flows_stairs_meme.png){fig-alt="Something about NNs makes me meme more"}

Jumping from mean-field or full-rank Gaussians and similar distributions
to neural networks feels a little... dramatic[^1],  so I want to spend
some time justifying why this is a good idea.

For VI to work well, we need something that's still simple to sample from, but capable
of, in aggregate, representing a posterior that is probably pretty complex. Certainly,
some problems are amenable to the simple variational families $\mathscr{Q}$ we've tried so far,
but it's worth re-emphasizing that we're probably trying to represent something complex,
and even moderate success at that using a composition of normals should be
a little surprising, not the expected outcome.

If we need $\mathscr{Q}$ to be more complex, aren't there choices between what
we've seen and a neural network? There's a whole literature of them- from using
mixture distributions as variational distributions to inducing some additional
structure into a mean-field type solution if you have some specific knowledge
about your target posterior you can use. By and large though, this type of
class of solutions has been surpassed by normalizing flows in much of modern
use for more complex posteriors.

Why? A first reason is described in the paper that started the normalizing flows
for VI literature, Rezende and Mohamed's [**Variational Inference with Normalizing Flows**](https://arxiv.org/pdf/1505.05770.pdf): making our base variational distribution
more complex adds a variety of different computational costs, which add up quickly.
This isn't the most face-valid argument when I'm claiming a neural network
is a good alternative, but it gets more plausible when you think through
how poorly it'd scale to keep making your mixture distribution more and more
complex as your posteriors get harder to handle. So this is a *scalability*
argument- it might sound extreme to bring in a neural net, but as problems
get bigger, scaling matters.

The other point I'd raise is that all these other tools aren't very black box at
all- if we can make things work with a problem-adapted version of mean-field with
some structure based on the knowledge of a specific problem we have, that sounds
like it gets time consuming fast. If I'm going to have
to find a particular, problem-specific solution each time I want to use variational
inference, that feels fragile and fiddly as well- that's a poor user experience.

The novel idea with normalizing flows is that we'll start with a simple base
density like a normal distribution that is easy to sample from, but instead of only optimizing the parameters
of that normal distribution, we'll also use the training on our ELBO or
other objective to learn a transformation that reshapes that normal distribution to
look like our posterior. By having that transforming component be partially
composed of a neural network,
we give ourselves access to an incredibly expressive, automatically problem adaptive,
and heavily scalable variant of variational inference that is quite
widely used.

And if the approximation isn't expressive enough? Deep Learning researchers have
an unfussy, general purpose innovation for that: MORE LAYERS![^2]

![](images/more_layers.png){fig-alt="Wow such estimator, very deep"}

# What is a normalizing flow?

A normalizing flow transforms a simple base density into a complex one through
a sequence of invertible transformations. By stacking more and more of these
invertible transformations (having the density "flow" through them), we can create
arbitrarily complex distributions that remain valid probability distributions. Since
it isn't universal in the flows literature, let me be explicit that I'll consider
"forward" to be the direction flowing from the base density to the posterior, and
the "backward" or "normalizing" direction as towards the base density.

![Image Credit to [Simon Boehm](https://siboehm.com/articles/19/normalizing-flow-network) here](images/normalizing-flow.png)

If we have a random variable $x$ with distribution $q(x)$, and some function $f$ with an inverse
$f^{-1} = g$ (so $g \circ f(x) = x$), then the distribution of $x^\prime = f(x)$, the result of
passing $x$ through $f$ once, is:

$$
q^\prime(x^\prime) = q(x) \left\lvert \det \frac{\partial f^{-1}}{\partial x^\prime} \right\rvert = q(x) \left\lvert \det \frac{\partial f}{\partial x} \right\rvert^{-1}
$$

I won't derive this identity[^3], but it follows from the chain rule and the
properties of Jacobians of invertible functions.
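
As a quick sanity check on this identity, here is a minimal sketch in plain Python (no PyTorch needed). The affine map `f` and its parameters `a` and `b` are hypothetical choices of mine, picked because the answer is known in closed form: pushing a standard normal through $f(x) = ax + b$ gives a normal with mean $b$ and standard deviation $|a|$.

```python
import math

def normal_pdf(x, mean=0.0, std=1.0):
    """Density of a univariate normal distribution."""
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))

# Hypothetical invertible map f(x) = a * x + b, whose derivative is df/dx = a.
a, b = 2.0, 1.0

def f(x):
    return a * x + b

def transformed_pdf(x):
    """q'(f(x)) = q(x) * |df/dx|^{-1}, with q a standard normal."""
    return normal_pdf(x) / abs(a)

x = 0.3
lhs = transformed_pdf(x)                     # change of variables formula
rhs = normal_pdf(f(x), mean=b, std=abs(a))   # known answer: f(x) ~ Normal(b, |a|)
print(lhs, rhs)
```

The two printed numbers should agree up to floating point error. In more than one dimension the only change is that `abs(a)` becomes the absolute determinant of the Jacobian.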

The real power comes in here when we see that these transformations stack. If
we've got a chain of transformations (e.g. $f_K(\dots f_2(f_1(x_0)))$):

$$
x_K = f_K \circ \dots \circ f_2 \circ f_1(x_0)
$$

then the resulting density $q_K(x_K)$ looks like:

$$
\ln q_K (x_K) = \ln q_0(x_0) - \sum \limits_{k = 1}^{K} \ln \left\lvert \det \frac{\partial f_k}{\partial x_{k-1}} \right\rvert
$$
$$

Neat, and surprisingly simple! If the terms above are all easy to calculate,
we can very efficiently stack a bunch of these transformations and make
an expressive model.
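
To see the stacking in action, here is a small sketch in plain Python. Each layer returns its output together with the log absolute determinant of its Jacobian, and the log density of the result is the base log density minus the sum of those terms. The two layers (`affine` and `exponential`) and their parameters are illustrative assumptions, chosen because a normal pushed through them is a log-normal, so there is a closed form to compare against.

```python
import math

def standard_normal_log_pdf(x):
    """Log density of the base distribution q_0 = Normal(0, 1)."""
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)

# Each layer returns (output, log|det Jacobian|) for a scalar input.
def affine(x, scale=0.5, shift=1.0):
    return scale * x + shift, math.log(abs(scale))

def exponential(x):
    # d/dx exp(x) = exp(x), so log|det J| is just x.
    return math.exp(x), x

def flow_log_density(x0, layers):
    """Push x0 through the layers, accumulating ln q_K(x_K)."""
    log_q = standard_normal_log_pdf(x0)
    x = x0
    for layer in layers:
        x, log_det = layer(x)
        log_q -= log_det
    return x, log_q

def lognormal_log_pdf(y, mu=1.0, sigma=0.5):
    """Closed-form log density of LogNormal(mu, sigma), for comparison."""
    return (-math.log(y) - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)
            - (math.log(y) - mu) ** 2 / (2.0 * sigma ** 2))

x_K, log_q_K = flow_log_density(-0.7, [affine, exponential])
print(log_q_K, lognormal_log_pdf(x_K))
```

Adding a layer only adds one more term to the running sum, which is why the per-layer log determinant needs to be cheap.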

## Normalizing Flows for variational inference versus other applications

One source of confusion when I was learning about normalizing flows for
variational inference was that variational inference makes up a fairly
small proportion of the use cases for normalizing flows, and thus the academic
literature and online discussion. More common applications include density estimation, image generation,
representation learning, and reinforcement learning. In addition to making specifically applicable
discussions harder to find, often resources will make strong claims about properties of a given
flow structure that really only hold in some subset of the above applications[^4].

By taking a second to explain this crisply and compare different applications' needs,
hopefully I can save you some confusion and make engaging with the broader literature easier.

To start, consider the relevant operations we've introduced so far:

1. computing $f$, that is pushing a sample through the transformations
2. computing $g$, $f$'s inverse which undoes the manipulations
3. computing the (log) determinant of the Jacobian
 
1 and 3 definitely need to be efficient for our use case, since we need to be
able to sample and push through using the formula above efficiently to calculate
an ELBO and train our model. 2 is where things get
more subtle: we definitely need $f$ to be invertible, since our formulas above
are dependent on a property of Jacobians of invertible functions. But we don't
actually really need to explicitly compute $g$ for variational inference. Even knowing the inverse
exists but not having a formula might be fine for us!
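
Here is a minimal sketch of that point in plain Python: a Monte Carlo ELBO estimate that only ever samples, pushes forward through $f$, and evaluates a log determinant. The one-dimensional affine flow and the Normal(3, 0.5) target are hypothetical stand-ins of mine, not anything from the model in this series.

```python
import math
import random

def base_log_pdf(x):
    """Log density of the standard normal base distribution."""
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)

def target_log_density(z, mean=3.0, std=0.5):
    """Unnormalized log p(z): a Normal(3, 0.5) without its normalizing constant."""
    return -0.5 * ((z - mean) / std) ** 2

def elbo_estimate(shift, scale, n_samples=2000, seed=0):
    """Monte Carlo ELBO for the flow z = scale * x0 + shift."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        x0 = rng.gauss(0.0, 1.0)              # operation 1: sample, then compute f
        z = scale * x0 + shift
        log_det = math.log(abs(scale))        # operation 3: log|det Jacobian| of f
        log_q = base_log_pdf(x0) - log_det    # log q(z); the inverse g is never called
        total += target_log_density(z) - log_q
    return total / n_samples

elbo_matched = elbo_estimate(shift=3.0, scale=0.5)    # flow equal to the target
elbo_untrained = elbo_estimate(shift=0.0, scale=1.0)  # flow left at the base density
print(elbo_matched, elbo_untrained)
```

When the flow matches the target exactly, the estimate equals the log normalizing constant of the target, and any other setting of the flow scores lower. Nothing in the loop needed $g$.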

Contrast
this with density estimation, where the goal would not be to sample from the distribution,
but instead to estimate the density. In this case, most of the time would be
spent going in the opposite direction, so that they can evaluate the log-likelihood
of the data, and maximize it to improve the model[^5]. The need for an expressive
transformation of densities unites these two cases, but the goal is quite different!

This level of goal disagreement also shows its face in what direction papers
choose to call forward: most papers outside of variational inference applications consider forward to be the opposite of what I do here, the direction towards
the base density, the "normalizing" direction.

For our use, hopefully this short digression has clarified which operations we need to be
fast versus just exist. If you dive deeper into
further work on normalizing flows, hopefully recognizing there are two
different ways to point this thing helps you more quickly orient yourself
to how other work describes flows.

# How to train your neural net

Now, let's turn to how we actually fit a normalizing flow. A code presentation would be a bit
hard to grok if I took advantage of the full flexibility and abstraction that
something like [`vistan`](https://github.com/abhiagwl/vistan/tree/master) provides, so before
heading into general purpose tools I'll talk through a more explicit implementation
of a simpler flow called a planar flow in `PyTorch` for illustration. Rather than
reinventing the wheel, I'll leverage Edvard Hulten's implementation [here](https://github.com/e-hulten/planar-flows).

In this section,
I'll define conceptually how we're fitting the model, and build out a fun
target distribution and loss. Since I expect many people reading
this may be moderately new to PyTorch, I'll explain in more detail
than normal what each operation is doing and why we need it.


In [None]:
#| echo: false
import torch
import numpy as np
import torch.nn as nn
from torch.distributions import Uniform
from torch.distributions import MultivariateNormal
from torch import Tensor
from PIL import Image
from typing import Tuple
from typing import Callable
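import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde  # used by the plot_samples helper

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.set_default_device(device)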

Let's first make a fun target posterior distribution from an image to model. I
think it'd be a fun preview gif for the post to watch the model say Hi:

![](images/hard_to_draw_posterior.png){fig-alt="The word Hi, used as a hard-to-draw target posterior"}

It's quick to turn the 300x300 pixel image above into a 300x300 PyTorch tensor.
To represent this as a 2-D density we can fit models against, we'll read in
the image, collapse along the color dimension, and transform it into a torch
tensor:


In [None]:
raw_img = Image.open("images/hard_to_draw_posterior.png")

# Sum across the color channels
greyscale_img = np.array(raw_img).sum(axis = 2)

# Find the white background pixels (channel sum of 1020)
indices = greyscale_img == 1020

# Set the background to a low value (50), so the density ends up at the letters
greyscale_img[indices] = 50


# Shift and rescale values to help with fitting
normalized_image  = (greyscale_img - greyscale_img.min()) / (10000000)
 
# Make a torch tensor of the target to use
torch_posterior = torch.from_numpy(normalized_image).to(device)

#300 x 300 px, normalized grayscale representation of the above image
torch_posterior.sum()

Now let's define our loss for training, which will just be a slight
reformulation of our ELBO:

$$
 \mathbb{E}[\log p(z,x)] - \mathbb{E}[\log q(z)]
$$

To do this, we'll define a class for the loss.

First, we pick a simple base distribution to push through our flow, here a 
2-D Normal distribution called `base_distr`. We'll also include the interesting
target we just made above, `distr`.

Next, the forward pass structure. The `forward` method is the core of the computational graph structure in PyTorch. It defines operations that are applied to the input tensors to compute the output, and
gives PyTorch the needed information for automatic differentiation, which allows smooth calculation
and backpropagation of loss through the model to train it. This `VariationalLoss`
module will run at the end of the forward pass to calculate the loss and allow us
to pass it back through the graph for training.

Keeping with the structure above of numbering successive stages of the flow,
`z0` here is a batch of samples from our base distribution, and `z` is that same batch after the flow, our learned approximation
to the target. In addition to the terms you'd expect in the ELBO, we're also
tracking and making use of the sum of the log determinant of the Jacobians to
get a handle on the distortion of the base density the flows apply.


In [None]:
#| echo: False


# helper function to get a value for continuous density of my pixelated image
# There's definitely a better way to do this, this is just quick
def interpolate_tensor(tensor, z):
  # Get the dimensions of the image tensor (not of the batch of points z)
  height, width = tensor.shape

  # Separate x and y coordinates, clamped to the image so indices stay valid
  x = z[:, 0].clamp(min=0, max=width - 1)
  y = z[:, 1].clamp(min=0, max=height - 1)

  # Calculate the indices of the four surrounding elements
  x1 = x.floor().long()
  x2 = x1 + 1
  y1 = y.floor().long()
  y2 = y1 + 1

  # Calculate the weight for interpolation
  weight_x2 = x - x1.float()
  weight_x1 = 1 - weight_x2
  weight_y2 = y - y1.float()
  weight_y1 = 1 - weight_y2

  # Perform interpolation
  value = (
      tensor[y1.clamp(max=height - 1), x1.clamp(max=width - 1)] * weight_x1 * weight_y1 +
      tensor[y1.clamp(max=height - 1), x2.clamp(max=width - 1)] * weight_x2 * weight_y1 +
      tensor[y2.clamp(max=height - 1), x1.clamp(max=width - 1)] * weight_x1 * weight_y2 +
      tensor[y2.clamp(max=height - 1), x2.clamp(max=width - 1)] * weight_x2 * weight_y2
  )

  return value

class TargetDistribution:
    def __init__(self, name: str):
        """Define target distribution. 

        Args:
            name: The name of the target density to use. 
                  Valid choices: ["U_1", "U_2", "U_3", "U_4", "ring", "hi", "moons"].
        """
        self.func = self.get_target_distribution(name)

    def __call__(self, z: Tensor) -> Tensor:
        return self.func(z)

    @staticmethod
    def get_target_distribution(name: str) -> Callable[[Tensor], Tensor]:
        w1 = lambda z: torch.sin(2 * np.pi * z[:, 0] / 4)
        w2 = lambda z: 3 * torch.exp(-0.5 * ((z[:, 0] - 1) / 0.6) ** 2)
        w3 = lambda z: 3 * torch.sigmoid((z[:, 0] - 1) / 0.3)

        if name == "U_1":

            def U_1(z):
                u = 0.5 * ((torch.norm(z, 2, dim=1) - 2) / 0.4) ** 2
                u = u - torch.log(
                    torch.exp(-0.5 * ((z[:, 0] - 2) / 0.6) ** 2)
                    + torch.exp(-0.5 * ((z[:, 0] + 2) / 0.6) ** 2)
                )
                return u

            return U_1
        elif name == "U_2":

            def U_2(z):
                u = 0.5 * ((z[:, 1] - w1(z)) / 0.4) ** 2
                return u

            return U_2
        elif name == "U_3":

            def U_3(z):
                u = -torch.log(
                    torch.exp(-0.5 * ((z[:, 1] - w1(z)) / 0.35) ** 2)
                    + torch.exp(-0.5 * ((z[:, 1] - w1(z) + w2(z)) / 0.35) ** 2)
                    + 1e-6
                )
                return u

            return U_3
        elif name == "U_4":

            def U_4(z):
                u = -torch.log(
                    torch.exp(-0.5 * ((z[:, 1] - w1(z)) / 0.4) ** 2)
                    + torch.exp(-0.5 * ((z[:, 1] - w1(z) + w3(z)) / 0.35) ** 2)
                    + 1e-6
                )
                return u

            return U_4
        elif name == "ring":

            def ring_density(z):
                exp1 = torch.exp(-0.5 * ((z[:, 0] - 2) / 0.8) ** 2)
                exp2 = torch.exp(-0.5 * ((z[:, 0] + 2) / 0.8) ** 2)
                u = 0.5 * ((torch.norm(z, 2, dim=1) - 4) / 0.4) ** 2
                u = u - torch.log(exp1 + exp2 + 1e-6)
                return u

            return ring_density
          
        elif name == "hi":

          def hi_density(z):
              return interpolate_tensor(torch_posterior,z)
          
          return hi_density
        
        elif name == "moons":
          
          def two_moons_density(z):
            x = z[:, 0]
            y = z[:, 1]
            d = torch.sqrt(x**2 + y**2)
            density = torch.exp(-0.2 * d) * torch.cos(4 * np.pi * d)
            return density
          
          return two_moons_density

In [None]:
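# https://github.com/e-hulten/planar-flows/blob/master/loss.py
class VariationalLoss(nn.Module):
  def __init__(self,distribution):
      super().__init__()
      # Target, given as a potential U(z) so that log p(z) = -U(z) up to a constant
      self.distr = distribution
      # Simple 2-D standard normal base distribution
      self.base_distr = MultivariateNormal(torch.zeros(2), torch.eye(2))

  def forward(self, z0: Tensor, z: Tensor, sum_log_det_J: Tensor) -> Tensor:
      # log q_0(z0): log density of the samples under the base distribution
      base_log_prob = self.base_distr.log_prob(z0)
      # Unnormalized log density of the transformed samples under the target
      target_density_log_prob = -self.distr(z)
      # Negative ELBO: E[log q_K(z)] - E[log p(z)], with log q_K = log q_0 - sum log|det J|
      return (base_log_prob - target_density_log_prob - sum_log_det_J).mean()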

# A basic flow

Next, let's define the structure of the actual flow. To do this, we'll first
describe a single layer of the flow, then we'll show how to stack
the flow in layers.

The first flow we'll look at is the **planar flow** from the original
Variational Inference with Normalizing Flows paper mentioned above. The name
comes from how the function defines a (hyper)plane, and compresses or expands
the density around it:

$$
f(x) = x + u \tanh(w^T x + b), \quad w, u \in \mathbb{R}^d, \; b \in \mathbb{R}
$$

$w$ and $b$ define the hyperplane and $u$ specifies the direction and strength
of the expansion. I'll show a visualization of just one layer of that below.
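
Before the PyTorch version, here is a minimal plain-Python sketch of a single planar layer acting on one 2-D point. It uses the closed-form determinant from the paper, $\lvert 1 + u^T \psi(x) \rvert$ with $\psi(x) = (1 - \tanh^2(w^T x + b)) w$, and checks it against a finite-difference Jacobian. The helper `numerical_log_det` and the parameter values are my own illustrative choices.

```python
import math

def planar(x, w, u, b):
    """Apply f(x) = x + u * tanh(w.x + b) to a 2-D point; also return log|det J|."""
    a = w[0] * x[0] + w[1] * x[1] + b
    h = math.tanh(a)
    y = (x[0] + u[0] * h, x[1] + u[1] * h)
    # det(I + u psi^T) = 1 + u.psi, with psi = (1 - tanh^2(a)) * w
    h_prime = 1.0 - h * h
    det = 1.0 + h_prime * (u[0] * w[0] + u[1] * w[1])
    return y, math.log(abs(det))

def numerical_log_det(x, w, u, b, eps=1e-6):
    """Central finite-difference estimate of log|det J|, for comparison only."""
    cols = []
    for i in range(2):
        x_plus, x_minus = list(x), list(x)
        x_plus[i] += eps
        x_minus[i] -= eps
        y_plus, _ = planar(x_plus, w, u, b)
        y_minus, _ = planar(x_minus, w, u, b)
        # cols[i] holds the partial derivatives of both outputs w.r.t. x_i
        cols.append(((y_plus[0] - y_minus[0]) / (2 * eps),
                     (y_plus[1] - y_minus[1]) / (2 * eps)))
    det = cols[0][0] * cols[1][1] - cols[1][0] * cols[0][1]
    return math.log(abs(det))

w, u, b = (1.0, 0.5), (0.8, -0.3), 0.1
x = (0.2, -0.4)
_, analytic = planar(x, w, u, b)
numeric = numerical_log_det(x, w, u, b)
print(analytic, numeric)
```

The closed form costs one dot product per point, which is what makes the planar flow cheap to stack.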

If you're more used to working with neural nets, you might wonder why we
choose the non-linearity $\tanh$ here, which generally isn't as popular as something
like ReLU or its variants in more recent years due to its more unstable
gradient flows. As the authors show in appendix A.1, functions like the
above aren't actually always invertible, and choosing $\tanh$ allows them
to impose some constraints that make things reliably invertible. See that appendix
for more details about how that works, or take a careful look at Edvard's
implementation of the single function below:


In [None]:
# From https://github.com/e-hulten/planar-flows/blob/master/planar_transform.py

class PlanarTransform(nn.Module):
  """Implementation of the invertible transformation used in planar flow:
      f(z) = z + u * h(dot(w.T, z) + b)
  See Section 4.1 in https://arxiv.org/pdf/1505.05770.pdf. 
  """

  def __init__(self, dim: int = 2):
      """Initialise weights and bias.
      
      Args:
          dim: Dimensionality of the distribution to be estimated.
      """
      super().__init__()
      self.w = nn.Parameter(torch.randn(1, dim).normal_(0, 0.1))
      self.b = nn.Parameter(torch.randn(1).normal_(0, 0.1))
      self.u = nn.Parameter(torch.randn(1, dim).normal_(0, 0.1))

  def forward(self, z: Tensor) -> Tensor:
      if torch.mm(self.u, self.w.T) < -1:
          self.get_u_hat()

      return z + self.u * nn.Tanh()(torch.mm(z, self.w.T) + self.b)

  def log_det_J(self, z: Tensor) -> Tensor:
      if torch.mm(self.u, self.w.T) < -1:
          self.get_u_hat()
      a = torch.mm(z, self.w.T) + self.b
      psi = (1 - nn.Tanh()(a) ** 2) * self.w
      abs_det = (1 + torch.mm(self.u, psi.T)).abs()
      log_det = torch.log(1e-4 + abs_det)

      return log_det

  def get_u_hat(self) -> None:
      """Enforce w^T u >= -1. When using h(.) = tanh(.), this is a sufficient condition 
      for invertibility of the transformation f(z). See Appendix A.1.
      """
      wtu = torch.mm(self.u, self.w.T)
      m_wtu = -1 + torch.log(1 + torch.exp(wtu))
      self.u.data = (
          self.u + (m_wtu - wtu) * self.w / torch.norm(self.w, p=2, dim=1) ** 2
      )
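
The constraint in `get_u_hat` can be checked by hand with a few lines of plain Python. This is a sketch under my own assumptions: the values of `w` and `u` are made up, and the correction is applied unconditionally, whereas the class above only applies it when $w^T u < -1$.

```python
import math

def constrained_u(u, w):
    """Return u_hat such that w.u_hat = m(w.u) = -1 + log(1 + exp(w.u)) > -1."""
    wtu = sum(ui * wi for ui, wi in zip(u, w))
    m_wtu = -1.0 + math.log1p(math.exp(wtu))   # softplus shifted down by 1
    w_norm_sq = sum(wi * wi for wi in w)
    # Move u along w just far enough that the dot product becomes m(w.u)
    return [ui + (m_wtu - wtu) * wi / w_norm_sq for ui, wi in zip(u, w)]

w = [1.0, 0.5]
u = [-3.0, -2.0]                  # w.u = -4, which violates w.u >= -1
u_hat = constrained_u(u, w)
wtu_hat = sum(ui * wi for ui, wi in zip(u_hat, w))
print(wtu_hat)
```

Since the derivative of $\tanh$ lies in $(0, 1]$, keeping $w^T u$ above $-1$ keeps $1 + \tanh^\prime(\cdot) \, w^T u$ positive, so the layer never folds the space back on itself.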

Where things will start to get exciting is multiple layers of the flow; here's
how we can make an abstraction that allows us to stack up $K$ layers
of the flow to control the flexibility of our approximation.


In [None]:
class PlanarFlow(nn.Module):
    def __init__(self, dim: int = 2, K: int = 6):
        """Make a planar flow by stacking planar transformations in sequence.

        Args:
            dim: Dimensionality of the distribution to be estimated.
            K: Number of transformations in the flow. 
        """
        super().__init__()
        self.layers = [PlanarTransform(dim) for _ in range(K)]
        self.model = nn.Sequential(*self.layers)

    def forward(self, z: Tensor) -> Tuple[Tensor, float]:
        log_det_J = 0

        for layer in self.layers:
            log_det_J += layer.log_det_J(z)
            z = layer(z)

        return z, log_det_J

Let's run this for a single layer to introduce the training loop, and build some
intuition on the planar flow. Note that I'm hiding setting up the plot code.


In [None]:
#| echo: False

# https://github.com/e-hulten/planar-flows/blob/master/utils/plot.py
def plot_density(density, xlim=4, ylim=4, ax=None, cmap="Blues"):
    x = y = np.linspace(-xlim, xlim, 300)
    X, Y = np.meshgrid(x, y)
    shape = X.shape
    X_flatten, Y_flatten = np.reshape(X, (-1, 1)), np.reshape(Y, (-1, 1))
    Z = torch.from_numpy(np.concatenate([X_flatten, Y_flatten], 1))
    U = torch.exp(-density(Z))
    U = U.reshape(shape)
    if ax is None:
        fig = plt.figure(figsize=(7, 7))
        ax = fig.add_subplot(111)

    ax.set_xlim(-xlim, xlim)
    ax.set_ylim(-ylim, ylim)
    ax.set_aspect(1)

    ax.pcolormesh(X, Y, U, cmap=cmap, rasterized=True)
    ax.tick_params(
        axis="both",
        left=False,
        top=False,
        right=False,
        bottom=False,
        labelleft=False,
        labeltop=False,
        labelright=False,
        labelbottom=False,
    )
    return ax


def plot_samples(z):
    nbins = 250
    lim = 4
    # z = np.exp(-z)
    k = gaussian_kde([z[:, 0], z[:, 1]])
    xi, yi = np.mgrid[-lim : lim : nbins * 1j, -lim : lim : nbins * 1j]
    zi = k(np.vstack([xi.flatten(), yi.flatten()]))

    fig = plt.figure(figsize=[7, 7])
    ax = fig.add_subplot(111)
    ax.set_xlim(-5, 5)
    ax.set_aspect(1)
    plt.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap="Purples", rasterized=True)
    return ax


def plot_transformation(model, n=500, xlim=4, ylim=4, ax=None, cmap="Purples"):
    base_distr = torch.distributions.MultivariateNormal(torch.zeros(2), torch.eye(2))
    x = torch.linspace(-xlim, xlim, n)
    xx, yy = torch.meshgrid(x, x, indexing="ij")
    zz = torch.stack((xx.flatten(), yy.flatten()), dim=-1).squeeze()

    zk, sum_log_jacobians = model(zz)
    

    base_log_prob = base_distr.log_prob(zz)
    final_log_prob = base_log_prob - sum_log_jacobians
    qk = torch.exp(final_log_prob)
    
    if ax is None:
        fig = plt.figure(figsize=[7, 7])
        ax = fig.add_subplot(111)
    ax.set_xlim(-xlim, xlim)
    ax.set_ylim(-ylim, ylim)
    ax.set_aspect(1)
    
    ax.pcolormesh(
        zk[:, 0].detach().cpu().data.reshape(n, n),
        zk[:, 1].detach().cpu().data.reshape(n, n),
        qk.detach().cpu().data.reshape(n, n),
        cmap=cmap,
        rasterized=True,
    )
    

    plt.tick_params(
        axis="both",
        left=False,
        top=False,
        right=False,
        bottom=False,
        labelleft=False,
        labeltop=False,
        labelright=False,
        labelbottom=False,
    )
    if cmap == "Purples":
        ax.set_facecolor(plt.cm.Purples(0.0))
    elif cmap == "Reds":
        ax.set_facecolor(plt.cm.Reds(0.0))

    return ax


def plot_training(model, flow_length, batch_num, lr, axlim):
    ax = plot_transformation(model, xlim=axlim, ylim=axlim)
    ax.text(
        0,
        axlim - 2,
        "Flow length: {}\nDensity of one batch, iteration #{:06d}\nLearning rate: {}".format(
            flow_length, batch_num, lr
        ),
        horizontalalignment="center",
    )
    plt.savefig(
        f"training_plots/iteration_{batch_num:06d}.png",
        bbox_inches="tight",
        pad_inches=0.5,
    )
    plt.close()


def plot_comparison(model, target_distr, flow_length, dpi=400):
    xlim = ylim = 7 if target_distr == "ring" else 5
    fig, axes = plt.subplots(
        ncols=2, nrows=1, sharex=True, sharey=True, figsize=[10, 5], dpi=dpi
    )
    axes[0].tick_params(
        axis="both",
        left=False,
        top=False,
        right=False,
        bottom=False,
        labelleft=False,
        labeltop=False,
        labelright=False,
        labelbottom=False,
    )
    # Plot true density.
    density = TargetDistribution(target_distr)
    plot_density(density, xlim=xlim, ylim=ylim, ax=axes[0])
    axes[0].text(
        0,
        ylim - 1,
        r"True density $\exp(-{})$".format(target_distr),
        size=14,
        horizontalalignment="center",
    )

    # Plot estimated density.
    batch = torch.zeros(500, 2).normal_(mean=0, std=1)
    z = model(batch)[0].detach().cpu().numpy()
    axes[1] = plot_transformation(model, xlim=xlim, ylim=ylim, ax=axes[1], cmap="Reds")
    axes[1].text(
        0,
        ylim - 1,
        r"Estimated density $\exp(-{})$".format(target_distr),
        size=14,
        horizontalalignment="center",
    )
    fig.savefig(
        "results/" + target_distr + "_K" + str(flow_length) + "_comparison.pdf",
        bbox_inches="tight",
        pad_inches=0.1,
    )


def plot_available_distributions():
    target_distributions = ["U_1", "U_2", "U_3", "U_4", "ring"]
    cmaps = ["Reds", "Purples", "Oranges", "Greens", "Blues"]
    fig, axes = plt.subplots(1, len(target_distributions), figsize=(20, 5))
    for i, distr in enumerate(target_distributions):
        axlim = 7 if distr == "ring" else 5
        density = TargetDistribution(distr)
        plot_density(density, xlim=axlim, ylim=axlim, ax=axes[i], cmap=cmaps[i])
        axes[i].set_title(f"Name: '{distr}'", size=16)
        plt.setp(axes, xticks=[], yticks=[])
    plt.show()

In [None]:
#From https://github.com/e-hulten/planar-flows/blob/master/train.py
target_distr = "ring"  # U_1, U_2, U_3, U_4, ring
flow_length = 32
dim = 2
num_batches = 20000
batch_size = 128
lr = 6e-4
axlim = xlim = ylim = 7 if target_distr == "ring" else 5  # 5 for U_1 to U_4, 7 for ring
# ------------------------------------

density = TargetDistribution(target_distr)
model = PlanarFlow(dim, K=flow_length)
bound = VariationalLoss(density)
optimiser = torch.optim.Adam(model.parameters(), lr=lr)

# Train model.
for batch_num in range(1, num_batches + 1):
    # Get batch from N(0,I).
    batch = torch.zeros(size=(batch_size, 2)).normal_(mean=0, std=1)
    # Pass batch through flow.
    zk, log_jacobians = model(batch)
    
    # Compute loss under target distribution.
    loss = bound(batch, zk, log_jacobians)

    optimiser.zero_grad()
    loss.backward()
    optimiser.step()

    if batch_num % 100 == 0:
        print(f"(batch_num {batch_num:05d}/{num_batches}) loss: {loss}")
        #print(log_jacobians)

    if batch_num == 1 or batch_num % 100 == 0:
        # Save plots during training. Plots are saved to the 'training_plots' folder.
        plot_training(model, flow_length, batch_num, lr, axlim)
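
To see the same sample, push forward, compute loss, step cycle with the autograd machinery stripped away, here is a plain-Python sketch. It fits a one-dimensional affine flow to a hypothetical Normal(3, 0.5) target using hand-derived gradients and plain SGD. The target, learning rate, and parameter names are assumptions of mine, not part of the planar flow code above.

```python
import math
import random

rng = random.Random(0)
shift, log_scale = 0.0, 0.0       # flow parameters: z = exp(log_scale) * x0 + shift
lr, batch_size, num_batches = 0.01, 128, 2000

for batch_num in range(num_batches):
    grad_shift = grad_log_scale = 0.0
    scale = math.exp(log_scale)
    for _ in range(batch_size):
        x0 = rng.gauss(0.0, 1.0)             # sample from the base density
        z = scale * x0 + shift               # pass it through the flow
        # Per-sample loss (negative ELBO, constants dropped):
        #   2 * (z - 3)^2 - log_scale, i.e. -log p(z) - log|det J|
        d_loss_dz = 4.0 * (z - 3.0)
        grad_shift += d_loss_dz                          # dz/dshift = 1
        grad_log_scale += d_loss_dz * scale * x0 - 1.0   # dz/dlog_scale = scale * x0
    # Plain SGD step on the batch-averaged gradient
    shift -= lr * grad_shift / batch_size
    log_scale -= lr * grad_log_scale / batch_size

print(shift, math.exp(log_scale))
```

The learned shift and scale should land close to the target's mean and standard deviation. In the PyTorch loop, `loss.backward()` and `optimiser.step()` play the role of the two hand-written gradient lines and the SGD step.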

# What more complicated NNs look like

# Conclusion 

[^1]: It also almost has a bit of "no brain no pain" ML guy energy, in the sense
that we're really pulling out the biggest algorithm possible. It really is a funny
trajectory to me to go from "I'd like to still be Bayesian, but avoid MCMC because it's slow"
to "screw subtle design, let's throw a NN at it".
[^2]: This is mostly a joke, but it really is a tremendous convenience that
there's such a straightforward knob to turn for "expressivity" in this context.
We'll get into the ways that isn't completely true soon, but NNs provide fantastic
convenience in terms of workflow for improving model flexibility.
[^3]: You can see it in the original Normalizing Flows paper linked above, or
combined with a nice matrix calc review by [Lilian Weng](https://lilianweng.github.io/posts/2018-10-13-flow-models/). As a more general note, since this is a common topic on a few different talented
people's blogs, I'll try to focus on covering material I think I can provide
more intuition for, or that is most relevant for variational inference.
[^4]: A great example of this is Lilian Weng's [NF walkthrough](https://lilianweng.github.io/posts/2018-10-13-flow-models/) which
I recommended above. It
has a fantastic review of the needed linear algebra and covers a lot of different
flow types, but is a bit overly general about what properties are most desirable
in a flow, and therefore initially a bit fuzzy on the value different flows
have.
[^5]: Deriving precisely how this works would take us too far afield, but see [Kobyzev et
al. (2020)](https://arxiv.org/abs/1908.09257) if you're interested. It's a great review paper that does a lot of work to recognize there are multiple different possible applications of normalizing flows, and thus
different notations and framings that they very successfully bridge.