## Variational inference: the core problem

We consider a latent variable model with observations $x$, latent variables $z$, and joint density $p_\theta(x, z) = p_\theta(x \mid z)\,p(z)$.

The posterior
$$
p_\theta(z \mid x)
$$
is typically **intractable**, because normalizing it requires the evidence $p_\theta(x) = \int p_\theta(x \mid z)\,p(z)\,dz$.


### Variational idea

Introduce a tractable approximation
$$
q_\phi(z \mid x) \approx p_\theta(z \mid x),
$$
and maximize the evidence lower bound (ELBO) with respect to $\phi$ (and $\theta$, when the model is learned):
$$
\log p_\theta(x)
\;\ge\;
\mathbb{E}_{q_\phi(z \mid x)}\!\left[\log p_\theta(x \mid z)\right]
-
\mathrm{D_{KL}}\!\left(q_\phi(z \mid x)\,\|\,p(z)\right).
$$
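
As a sanity check on the bound, here is a minimal sketch under an assumed toy model that is not part of these slides: $z \sim \mathcal{N}(0,1)$ and $x \mid z \sim \mathcal{N}(z,1)$, so that $p(x) = \mathcal{N}(x; 0, 2)$ and the exact posterior is $\mathcal{N}(x/2, 1/2)$. With a Gaussian $q$, both ELBO terms are available in closed form. The function names `logEvidence` and `elbo` are illustrative only.

```js
// Toy model (an assumption for illustration): z ~ N(0,1), x | z ~ N(z,1).
// Then p(x) = N(x; 0, 2) and the exact posterior is N(x/2, 1/2).
const LOG_2PI = Math.log(2 * Math.PI);

// Exact log evidence log p(x).
function logEvidence(x) {
  return -0.5 * (LOG_2PI + Math.log(2) + (x * x) / 2);
}

// ELBO for q(z | x) = N(m, s2); both terms are in closed form.
function elbo(x, m, s2) {
  const expectedLogLik = -0.5 * (LOG_2PI + (x - m) ** 2 + s2); // E_q[log p(x | z)]
  const kl = 0.5 * (s2 + m * m - 1 - Math.log(s2));            // KL(q || N(0,1))
  return expectedLogLik - kl;
}

const x = 1.3;
// At the exact posterior the gap vanishes (up to rounding): the bound is tight.
const gapAtPosterior = logEvidence(x) - elbo(x, x / 2, 0.5);
// For any other q the gap is strictly positive.
const gapElsewhere = logEvidence(x) - elbo(x, 0.0, 1.0);
```

The gap equals $\mathrm{D_{KL}}\!\left(q_\phi(z \mid x)\,\|\,p_\theta(z \mid x)\right)$, which is why maximizing the ELBO over $\phi$ pulls $q_\phi$ toward the true posterior.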



## Why amortized variational inference?

Naively, variational inference requires **separate variational parameters for each datapoint**.

The number of parameters and the optimization cost then grow linearly with the dataset size, so this does not scale.

<br> 


### Amortized inference

Use an inference (recognition) network:
$$
x \;\longmapsto\; q_\phi(z \mid x),
$$
with **global parameters** $\phi$ shared across all datapoints.
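
The contrast can be sketched on an assumed toy model (not from these slides): $z \sim \mathcal{N}(0,1)$, $x \mid z \sim \mathcal{N}(z,1)$, whose exact posterior is $\mathcal{N}(x/2, 1/2)$. The names `perDatapointParams`, `encoder`, and `phi` are hypothetical, and the linear map stands in for what would normally be a neural network.

```js
// Toy model (assumed): z ~ N(0,1), x | z ~ N(z,1), so p(z | x) = N(x/2, 1/2).

// Non-amortized VI: one (mean, std) pair per datapoint; storage grows with N.
function perDatapointParams(xs) {
  return xs.map((x) => ({ mean: x / 2, std: Math.sqrt(0.5) }));
}

// Amortized VI: a single parameter set phi defines q_phi(z | x) for every x.
// A linear map is enough for this toy model; real encoders are neural networks.
const phi = { weight: 0.5, bias: 0.0, logStd: 0.5 * Math.log(0.5) };

function encoder(x, params) {
  return {
    mean: params.weight * x + params.bias,
    std: Math.exp(params.logStd),
  };
}

const xs = [-1.0, 0.3, 2.5];
const table = perDatapointParams(xs);             // 2 * N numbers
const amortized = xs.map((x) => encoder(x, phi)); // 3 numbers, reused for every x
const unseen = encoder(10.0, phi);                // a new x needs no extra fitting
```

In this sketch `phi` is set to its optimal value by hand; in practice it would be learned by maximizing the ELBO over minibatches.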

<br>

### Benefits

- Scales to large datasets
- Fast inference at test time
- Compatible with minibatch SGD

## The problem with standard amortized VI

In practice, $q_\phi(z \mid x)$ is chosen to be **simple**:

$$
q_\phi(z \mid x)
=
\mathcal{N}\!\big(\mu_\phi(x), \mathrm{diag}(\sigma_\phi^2(x))\big).
$$



### Consequences

- Mean-field assumption: the coordinates of $z$ are independent given $x$
- Cannot represent:
  - multimodality
  - strong correlations
  - complex posterior geometry



### Empirical issues (well documented)

- Underestimation of posterior variance
- Biased parameter estimates
- Poor uncertainty quantification
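
The variance underestimation can be seen in a small worked case. This is a sketch under an assumed target, not a result from these slides: take the "posterior" to be a zero-mean 2D Gaussian with unit marginal variances and correlation $\rho$. The factorized Gaussian that minimizes $\mathrm{D_{KL}}(q\,\|\,p)$ then has variances $1/\Lambda_{ii}$, where $\Lambda = \Sigma^{-1}$ is the precision matrix. Function names are illustrative.

```js
// Variances of the best factorized Gaussian for the target N(0, [[1, rho], [rho, 1]]).
function meanFieldVariances(rho) {
  const det = 1 - rho * rho;
  // Precision matrix Lambda = Sigma^{-1}.
  const precision = [
    [1 / det, -rho / det],
    [-rho / det, 1 / det],
  ];
  // Minimizing KL(q || p) over diagonal covariances gives v_i = 1 / Lambda_ii.
  return precision.map((row, i) => 1 / row[i]);
}

// KL(N(0, diag(v)) || N(0, Sigma)) for the same 2D target.
function klFactorizedToTarget(v, rho) {
  const det = 1 - rho * rho;
  const trace = (v[0] + v[1]) / det; // tr(Lambda * diag(v))
  return 0.5 * (trace - 2 + Math.log(det) - Math.log(v[0]) - Math.log(v[1]));
}

const rho = 0.9;
const trueMarginalVariance = 1.0;
const fitted = meanFieldVariances(rho); // each entry is 1 - rho^2, about 0.19
const klAtFitted = klFactorizedToTarget(fitted, rho);
const klAtTrueMarginals = klFactorizedToTarget([1.0, 1.0], rho);
```

With $\rho = 0.9$ the fitted variance is roughly a fifth of the true marginal variance, and matching the true marginals would give a larger KL than the shrunken fit.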


## Change of variables

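Let $z_k = f_k(z_{k-1})$ for invertible, differentiable maps $f_1, \dots, f_K$, starting from $z_0 \sim q_0(z_0 \mid x)$. By the change-of-variables formula: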
$$
\log q_K(z_K \mid x)
=
\log q_0(z_0 \mid x)
-
\sum_{k=1}^K
\log\left|
\det
\frac{\partial f_k}{\partial z_{k-1}}
\right|.
$$
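
A one-layer, one-dimensional check of the formula, as a minimal sketch: pushing $\mathcal{N}(0,1)$ through the affine map $f(z_0) = a z_0 + b$ gives $\mathcal{N}(b, a^2)$, so the flow density can be compared against a known answer. The values of `a`, `b`, and `z0` are arbitrary choices for illustration.

```js
// Log density of a univariate Gaussian.
const logNormal = (z, mean, std) =>
  -0.5 * Math.log(2 * Math.PI) - Math.log(std) - 0.5 * ((z - mean) / std) ** 2;

// One invertible map f(z0) = a * z0 + b, with Jacobian df/dz0 = a.
const a = 2.0;
const b = -1.0;
const z0 = 0.7;
const z1 = a * z0 + b;

// Change of variables: log q1(z1) = log q0(z0) - log|det df/dz0|.
const logQ1ViaFlow = logNormal(z0, 0, 1) - Math.log(Math.abs(a));

// Known answer: an affine map sends N(0,1) to N(b, a^2).
const logQ1Exact = logNormal(z1, b, Math.abs(a));
```

The two values agree up to floating-point rounding; the $-\log|a|$ term is exactly the correction for stretching the axis by a factor of $|a|$.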



### Interpretation

- Each transformation **warps** the density
- Stacking transformations yields rich posteriors
- Flexibility is controlled by the flow length $K$

### Why normalizing flows help

- Start from a simple, reparameterizable base distribution
- Add flexibility through invertible transformations
- Preserve tractable densities and gradients, provided each Jacobian determinant is cheap to compute

## Finite normalizing flows

Rather than a single transformation, we compose $K$ invertible maps:

$$
z_K = f_K \circ f_{K-1} \circ \cdots \circ f_1(z_0),
$$

where
$$
z_0 \sim q_0(z_0).
$$


The resulting log-density is
$$
\log q_K(z_K)
=
\log q_0(z_0)
-
\sum_{k=1}^K
\log\left|
\det\frac{\partial f_k}{\partial z_{k-1}}
\right|.
$$
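
The bookkeeping can be sketched as follows. The slides do not fix a particular $f_k$; planar layers $f(z) = z + u \tanh(w^\top z + b)$ are an assumed choice here, with $\det \partial f / \partial z = 1 + u^\top w \,(1 - \tanh^2(w^\top z + b))$. The layer parameters below are arbitrary and chosen so that $u^\top w > -1$, which keeps each layer invertible.

```js
const dot = (p, q) => p.reduce((acc, pi, i) => acc + pi * q[i], 0);

// One planar layer: returns the transformed point and log|det Jacobian|.
function planar(z, { u, w, b }) {
  const t = Math.tanh(dot(w, z) + b);
  const out = z.map((zi, i) => zi + u[i] * t);
  const logAbsDet = Math.log(Math.abs(1 + dot(u, w) * (1 - t * t)));
  return { z: out, logAbsDet };
}

// Push z0 through f_1, ..., f_K while tracking log q_K(z_K).
function flow(z0, logQ0, layers) {
  let z = z0;
  let logQ = logQ0;
  for (const layer of layers) {
    const step = planar(z, layer);
    z = step.z;
    logQ -= step.logAbsDet; // subtract each layer's log|det Jacobian|
  }
  return { zK: z, logQK: logQ };
}

const layers = [
  { u: [0.8, 0.0], w: [1.0, 0.5], b: 0.0 },  // u.w = 0.8
  { u: [0.0, 0.6], w: [-0.5, 1.0], b: 0.3 }, // u.w = 0.6
];
const z0 = [0.4, -0.2];
const logQ0 = -Math.log(2 * Math.PI) - 0.5 * dot(z0, z0); // standard normal in 2D
const { zK, logQK } = flow(z0, logQ0, layers);
```

Each layer costs one dot product for its determinant, so evaluating $\log q_K$ is linear in $K$ for this family.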

## Expectations under a flow (LOTUS)

A key property, the law of the unconscious statistician (LOTUS):

For any function $h(z)$,
$$
\mathbb{E}_{q_K}[h(z)]
=
\mathbb{E}_{q_0}\!\left[
h\big(f_K \circ \cdots \circ f_1(z_0)\big)
\right].
$$

<br>

### Consequence

- We can compute expectations **without explicitly evaluating** $q_K$
- Jacobian terms matter only when evaluating $\log q_K$ itself, as in the entropy term of the ELBO
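
A Monte Carlo sketch of this property, under assumed choices that are not from the slides: the flow is $f_1(z) = \mu + \sigma z$ followed by $f_2(z) = e^{z}$, so $z_2$ is log-normal and $\mathbb{E}[z_2] = e^{\mu + \sigma^2/2}$ is known exactly. `mulberry32` is a small public-domain PRNG, included only so that the run is reproducible.

```js
// Seeded PRNG (mulberry32), returns uniform numbers in [0, 1).
function mulberry32(seed) {
  let a = seed | 0;
  return function () {
    a = (a + 0x6d2b79f5) | 0;
    let t = Math.imul(a ^ (a >>> 15), 1 | a);
    t = (t + Math.imul(t ^ (t >>> 7), 61 | t)) ^ t;
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
  };
}

// Box-Muller draw from N(0, 1).
function standardNormal(rand) {
  const u1 = 1 - rand(); // in (0, 1], so log(u1) is finite
  const u2 = rand();
  return Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
}

// Flow: f1 is affine, f2 is exp, so z_2 = exp(mu + sigma * z0) is log-normal.
const mu = 0.2;
const sigma = 0.5;
const f1 = (z) => mu + sigma * z;
const f2 = (z) => Math.exp(z);

// LOTUS: average h(f2(f1(z0))) over base samples; q_2 is never evaluated.
function expectationUnderFlow(h, numSamples, seed) {
  const rand = mulberry32(seed);
  let total = 0;
  for (let i = 0; i < numSamples; i++) {
    total += h(f2(f1(standardNormal(rand))));
  }
  return total / numSamples;
}

const estimate = expectationUnderFlow((z) => z, 50000, 1);
const exact = Math.exp(mu + 0.5 * sigma * sigma); // log-normal mean
```

No Jacobian appears anywhere in the estimator, which is what makes the reconstruction term of the ELBO cheap to estimate under a flow posterior.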

## Geometric intuition (visual)

*Expansions ($|\det J| > 1$) spread mass and lower the density; contractions ($|\det J| < 1$) concentrate mass and raise it.*

<div id="flow-demo" style="width: 100%; height: 420px;"></div>


```{js}
//| echo: false
//| output: true

// ---- D3 must be available. If it's not already in your deck,
// add a header include loading d3 (see notes below).

const el = document.getElementById("flow-demo");
el.innerHTML = "";  // reset if slide revisited

const W = el.clientWidth || 900;
const H = 420;

const svg = d3.select(el).append("svg")
  .attr("width", W)
  .attr("height", H)
  .style("border", "1px solid rgba(255,255,255,0.15)")
  .style("border-radius", "12px");

const pad = 30;
const cx = W/2, cy = H/2;

// coordinate system: map [-2,2]x[-1.4,1.4] to screen
const xScale = d3.scaleLinear().domain([-2, 2]).range([pad, W-pad]);
const yScale = d3.scaleLinear().domain([-1.4, 1.4]).range([H-pad, pad]);

// build a grid of points (acts like "mass")
const pts = [];
for (let x=-2; x<=2.001; x+=0.18) {
  for (let y=-1.4; y<=1.4001; y+=0.18) {
    pts.push({x, y, x0:x, y0:y});
  }
}

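// two warps: radial contraction and expansion about the point (ridge, 0)
function warp(z, mode){
  // z = {x,y}
  const ridge = 0.0;
  const eps = 0.15;  // softening constant, keeps the denominator positive
  const dx = z.x - ridge;

  // The map z -> z + amp * (z - c) / (|z - c|^2 + eps) is one-to-one only
  // for -eps < amp < 8*eps. A contraction with |amp| >= eps would throw
  // points near the center through to the opposite side.
  const amp = mode === "expand" ? 0.55 : -0.12; // expand vs contract

  // push (amp > 0) or pull (amp < 0) along the direction from the center
  const r2 = dx*dx + z.y*z.y + eps;
  const s = amp / r2;

  return { x: z.x + s*dx, y: z.y + s*z.y };
}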

// draw points
const g = svg.append("g");
const dots = g.selectAll("circle")
  .data(pts)
  .enter()
  .append("circle")
  .attr("cx", d => xScale(d.x0))
  .attr("cy", d => yScale(d.y0))
  .attr("r", 2.0)
  .attr("opacity", 0.7);

const label = svg.append("text")
  .attr("x", 14)
  .attr("y", 24)
  .attr("font-size", 18)
  .attr("font-weight", 600)
  .text("Base density (points) → warp");

// animate between base and warped states
let mode = "contract";
let t = 0;            // 0..1 interpolation
let dir = +1;

function step(){
  t += dir * 0.02;
  if (t >= 1){ t = 1; dir = -1; }
  if (t <= 0){
    t = 0; dir = +1;
    mode = (mode === "contract") ? "expand" : "contract";
  }

  label.text(mode === "contract"
    ? "Contraction: mass concentrates"
    : "Expansion: mass spreads"
  );

  dots
    .attr("cx", d => {
      const w = warp({x:d.x0, y:d.y0}, mode);
      const x = (1-t)*d.x0 + t*w.x;
      return xScale(x);
    })
    .attr("cy", d => {
      const w = warp({x:d.x0, y:d.y0}, mode);
      const y = (1-t)*d.y0 + t*w.y;
      return yScale(y);
    });

  requestAnimationFrame(step);
}
requestAnimationFrame(step);

``` 