
# It's a Family Affair, or: Exponential Families, Natural Gradient Descent, & Statistical Divergences

[Exponential Families](https://en.wikipedia.org/wiki/Exponential_family) (sometimes abbreviated as ExpFam) provide a [succinct characterization](https://en.wikipedia.org/wiki/Exponential_family#Table_of_distributions) of many distributions (e.g., [Normal](https://en.wikipedia.org/wiki/Normal_distribution), [Gamma](https://en.wikipedia.org/wiki/Gamma_distribution), [Poisson](https://en.wikipedia.org/wiki/Poisson_distribution), [Bernoulli](https://en.wikipedia.org/wiki/Bernoulli_distribution), [Wishart](https://en.wikipedia.org/wiki/Wishart_distribution), etc.). We'll take an informal look at their properties and how to perform inference.

We'll also informally review [statistical divergences](https://en.wikipedia.org/wiki/Divergence_(statistics)), which capture a notion of "distance" between parametric distributions.

Lastly, we look at another way to perform optimization. Unlike previous approaches, *natural* gradient descent generalizes the notion of the "steepest" direction to account for the underlying geometry of the space of distributions.

## Exponential Families
Let $\eta = [\eta_1, \dotsc, \eta_k]$ be a $k$-vector of parameters, and $x$ be an observation such that $x \sim f(\eta)$. We can define its [PDF](https://en.wikipedia.org/wiki/Probability_density_function) (or [PMF](https://en.wikipedia.org/wiki/Probability_mass_function) in the case of discrete $x$) as $$f(x | \eta) = h(x)\exp(\eta \cdot T(x) - A(\eta)),$$ where $h(x)$ is a *base measure*, $\eta$ are the *natural parameters*, $T(x)$ are the [*sufficient statistics*](https://en.wikipedia.org/wiki/Sufficient_statistic), and $A(\eta)$ is the [*log-partition function*](https://en.wikipedia.org/wiki/Partition_function_(mathematics)), which normalizes $f$ so that it sums (or integrates) to one. If $\eta$ is *finite-dimensional*, and the [*support*](https://en.wikipedia.org/wiki/Support_(mathematics)#In_probability_and_measure_theory) of $f$ does not depend on the value of $\eta$, then $f$ is said to be a member of the [Exponential Families](https://en.wikipedia.org/wiki/Exponential_family).

### Example: Normal Distribution
Recall if $x \sim N(\mu, \sigma^2)$, then the PDF of $x$ is given by,
$$f(x | \mu, \sigma^2) = \frac{1}{\sqrt{2 \pi \sigma^2}} \exp\left(-\frac{(x - \mu)^2}{2 \sigma^2}\right).$$ To see that the two-parameter Normal distribution is a member of the Exponential Families, define $\eta = [\frac{\mu}{\sigma^2}, -\frac{1}{2\sigma^2}]$, $h(x) = \frac{1}{\sqrt{2\pi}}$, $T(x) = [x, x^2]^T$, and $A(\eta) = \frac{\mu^2}{2\sigma^2} + \log |\sigma| = -\frac{\eta_1^2}{4\eta_2} + \frac{1}{2}\log\left|\frac{1}{2\eta_2}\right|$. Placing this all together we have,
$$\begin{align*}
f(x | \eta) &= h(x)\exp(\eta \cdot T(x) - A(\eta)) \\
  &= \frac{1}{\sqrt{2\pi}} \exp(\eta \cdot T(x) - A(\eta)) \\
  &= \frac{1}{\sqrt{2\pi}} \exp\left([\frac{\mu}{\sigma^2}, -\frac{1}{2\sigma^2}] \cdot T(x) - A(\eta)\right) \\
  &= \frac{1}{\sqrt{2\pi}} \exp\left([\frac{\mu}{\sigma^2}, -\frac{1}{2\sigma^2}] \cdot [x, x^2]^T - A(\eta)\right) \\
  &= \frac{1}{\sqrt{2\pi}} \exp\left(\frac{\mu x}{\sigma^2} -\frac{x^2}{2\sigma^2} - A(\eta)\right) \\
  &= \frac{1}{\sqrt{2\pi}} \exp\left(\frac{\mu x}{\sigma^2} -\frac{x^2}{2\sigma^2} + \frac{\eta_1^2}{4\eta_2} - \frac{1}{2}\log\left|\frac{1}{2\eta_2}\right|\right) \\
  &= \frac{1}{\sqrt{2\pi}} \exp\left(\frac{\mu x}{\sigma^2} -\frac{x^2}{2\sigma^2} - \frac{\mu^2}{2\sigma^2} - \log |\sigma|\right) \\
  &= \frac{1}{\sqrt{2\pi}} \exp\left(-\frac{(x - \mu)^2}{2\sigma^2} - \log |\sigma|\right) \\
  &= \frac{1}{\sqrt{2\pi}\sigma} \exp\left(-\frac{(x - \mu)^2}{2\sigma^2}\right) \\
  &= \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(x - \mu)^2}{2\sigma^2}\right).
\end{align*}$$
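
The algebra above can also be checked numerically. The following is a minimal sketch, not part of the original lab code: the helper name `normal_expfam_logpdf` and the test values `mu=0.5`, `sigma=1.5` are illustrative assumptions. It evaluates $\log h(x) + \eta \cdot T(x) - A(\eta)$ and compares it against `jax.scipy.stats.norm.logpdf`.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def normal_expfam_logpdf(x, mu, sigma):
    """Log-density of N(mu, sigma^2) as log h(x) + eta . T(x) - A(eta)."""
    # Natural parameters eta = [mu / sigma^2, -1 / (2 sigma^2)]
    eta = jnp.array([mu / sigma**2, -1.0 / (2.0 * sigma**2)])
    # Sufficient statistics T(x) = [x, x^2], one row per observation
    t_x = jnp.stack([x, x**2], axis=-1)
    # Log base measure, log h(x) = -log(2 pi) / 2
    log_h = -0.5 * jnp.log(2.0 * jnp.pi)
    # Log-partition function in terms of (mu, sigma)
    a_eta = mu**2 / (2.0 * sigma**2) + jnp.log(sigma)
    return log_h + t_x @ eta - a_eta


x = jnp.linspace(-3.0, 3.0, 7)  # illustrative evaluation points
expfam_vals = normal_expfam_logpdf(x, mu=0.5, sigma=1.5)
reference_vals = norm.logpdf(x, loc=0.5, scale=1.5)

# Largest discrepancy between the two forms; expected to be at rounding level
max_abs_diff = jnp.max(jnp.abs(expfam_vals - reference_vals))
```

Working on the log scale avoids underflow in the tails and makes the three ExpFam ingredients show up as separate additive terms.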

## Expectations
A key property of ExpFam distributions is that one can define the moments of $T(x)$ from the log-partition function $A(\eta)$. Two key moments to remember are, $\mathbb{E}[T(x)] = \frac{\partial}{\partial \eta} A(\eta)$ and $\mathbb{V}[T(x)] = \frac{\partial^2}{\partial \eta \partial \eta^T} A(\eta)$.

In the case of scalar Normal variables, we have $\mathbb{E}[T(x)] = \mathbb{E}\left[[x, x^2]\right] = [\mu, \sigma^2 + \mu^2]$, and the diagonal of $\mathbb{V}[T(x)]$ (the variances of $x$ and $x^2$) is $[\sigma^2, 2\sigma^4 + 4\mu^2\sigma^2]$.
$$\begin{align*}
\mathbb{E}[T(x)] &= \frac{\partial}{\partial \eta} A(\eta) \\
  &= \frac{\partial}{\partial \eta} \left[-\frac{\eta_1^2}{4\eta_2} + \frac{1}{2}\log\left|\frac{1}{2\eta_2}\right|\right] \\
  &= -\frac{\partial}{\partial \eta} \frac{\eta_1^2}{4\eta_2} + \frac{\partial}{\partial \eta}\frac{1}{2}\log\left|\frac{1}{2\eta_2}\right| \\
  &= -\frac{\partial}{\partial \eta} \frac{\eta_1^2}{4\eta_2} - \frac{\partial}{\partial \eta}\frac{1}{2}\log|2\eta_2| \\
\mathbb{E}[T(x)]_1 &= -\frac{\partial}{\partial \eta_1} \frac{\eta_1^2}{4\eta_2} - \frac{\partial}{\partial \eta_1}\frac{1}{2}\log|2\eta_2| \\
  &= - \frac{2 \eta_1}{4 \eta_2} = -\frac{\eta_1}{2 \eta_2} = -\frac{\mu / \sigma^2}{-2 / (2\sigma^2)} = \mu \\
\mathbb{E}[T(x)]_2 &= -\frac{\partial}{\partial \eta_2} \frac{\eta_1^2}{4\eta_2} - \frac{\partial}{\partial \eta_2}\frac{1}{2}\log|2\eta_2| \\
  &= \frac{\eta^2_1}{4\eta_2^2} - \frac{1}{2\eta_2} = \frac{\mu^2/\sigma^4}{1/\sigma^4} + \frac{1}{1 / \sigma^2} = \mu^2 + \sigma^2\\
\mathbb{V}[T(x)] &= \frac{\partial^2}{\partial \eta \partial \eta^T} A(\eta) \\
  &= \text{Ha, ha! For the HW} .
\end{align*}$$
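
The identity $\mathbb{E}[T(x)] = \frac{\partial}{\partial \eta} A(\eta)$ lends itself to automatic differentiation. The sketch below is illustrative rather than part of the original lab: the function name `log_partition` and the values `mu = 1.0`, `sigma2 = 4.0` are assumptions. It writes the Normal log-partition as $-\frac{\eta_1^2}{4\eta_2} - \frac{1}{2}\log(-2\eta_2)$, which equals $\frac{\mu^2}{2\sigma^2} + \log\sigma$, and differentiates it with `jax.grad`.

```python
import jax
import jax.numpy as jnp


def log_partition(eta):
    """Normal log-partition A(eta); only valid for eta[1] < 0."""
    return -eta[0] ** 2 / (4.0 * eta[1]) - 0.5 * jnp.log(-2.0 * eta[1])


# Illustrative mean and variance, mapped to natural parameters
mu, sigma2 = 1.0, 4.0
eta = jnp.array([mu / sigma2, -1.0 / (2.0 * sigma2)])

# Gradient of A gives E[T(x)] = [E[x], E[x^2]]
mean_t = jax.grad(log_partition)(eta)

# Closed-form moments to compare against: [mu, mu^2 + sigma^2]
expected = jnp.array([mu, mu**2 + sigma2])
```

The second-derivative matrix can be obtained the same way with `jax.hessian(log_partition)(eta)`, which is one way to check a hand derivation of $\mathbb{V}[T(x)]$.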

## Inference
Given $x_1, \dotsc, x_n$ we assume that the $x_i$ are independent and identically distributed draws from some ExpFam distribution $f$ with natural parameters $\eta$, i.e., $x_i \sim f(\eta)$.

Our log-likelihood is then given by,
$$ \ell(\eta) = \sum_{i=1}^n \log f(x_i | \eta).$$ To identify the
maximum likelihood estimates we must first compute the gradient with respect to $\eta$, which is given by
$$\begin{align*}
\nabla \ell(\eta) &= \sum_{i=1}^n \nabla \log f(x_i | \eta) \\
  &= \sum_{i=1}^n \nabla \log \left[ h(x_i)\exp(\eta \cdot T(x_i) - A(\eta))\right] \\
  &= \sum_{i=1}^n \nabla \left[\log h(x_i) + \log \exp(\eta \cdot T(x_i) - A(\eta))\right] \\
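  &= \sum_{i=1}^n \left[\nabla \log h(x_i) + \nabla \log \exp(\eta \cdot T(x_i) - A(\eta))\right] \\
  &= \sum_{i=1}^n \left[\nabla [\eta \cdot T(x_i)] - \nabla A(\eta)\right] \\
  &= \sum_{i=1}^n \left[T(x_i) - \mathbb{E}[T(x_i)]\right],
\end{align*}$$
where $\nabla \log h(x_i) = 0$ because $h$ does not depend on $\eta$. Setting the gradient to zero, we would like to find values $\eta$ such that $\sum_i T(x_i) = \sum_i \mathbb{E}[T(x_i)] = n \mathbb{E}[T(X)]$; that is, the expected sufficient statistics under $\eta$ match their sample average.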
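
For the Normal, this moment-matching condition can be solved in closed form, since $\mathbb{E}[x] = \mu$ and $\mathbb{E}[x^2] = \mu^2 + \sigma^2$. The following is a minimal sketch on simulated data; the sample size, seed, true parameters, and helper names are illustrative assumptions. It matches moments, converts to natural parameters, and checks that the per-observation score is approximately zero there.

```python
import jax
import jax.numpy as jnp


def log_partition(eta):
    """Normal log-partition A(eta); only valid for eta[1] < 0."""
    return -eta[0] ** 2 / (4.0 * eta[1]) - 0.5 * jnp.log(-2.0 * eta[1])


def loglikelihood(eta, x):
    """Sum over observations of log h(x_i) + eta . T(x_i) - A(eta)."""
    t_x = jnp.stack([x, x**2], axis=-1)
    log_h = -0.5 * jnp.log(2.0 * jnp.pi)
    return jnp.sum(t_x @ eta - log_partition(eta) + log_h)


# Simulated data from N(2, 1.5^2); these values are arbitrary
key = jax.random.PRNGKey(0)
x = 2.0 + 1.5 * jax.random.normal(key, (500,))

# Match sample moments: mean(x) = mu and mean(x^2) = mu^2 + sigma^2
mu_hat = jnp.mean(x)
sigma2_hat = jnp.mean(x**2) - mu_hat**2
eta_hat = jnp.array([mu_hat / sigma2_hat, -1.0 / (2.0 * sigma2_hat)])

# Average score at the estimate; expected to vanish up to floating-point error
score = jax.grad(loglikelihood)(eta_hat, x) / x.shape[0]
```

Dividing the gradient by $n$ keeps the check on the scale of a single observation, so a small tolerance is meaningful regardless of sample size.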

In [1]:
!pip install equinox


In [None]:
from abc import abstractmethod

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, ArrayLike


class ExpFam(eqx.Module):
  x: Array

  @abstractmethod
  def base_measure(self) -> Array:
    ...

  @abstractmethod
  def sufficient_statistics(self) -> Array:
    ...

  @abstractmethod
  def log_partition(self, eta: ArrayLike) -> Array:
    ...

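  def loglikelihood(self, eta: ArrayLike) -> Array:
    # The abstract methods take no arguments; they read the data from self.x.
    eta = jnp.asarray(eta)
    t_x = self.sufficient_statistics()
    log_h_x = jnp.log(self.base_measure())
    a_eta = self.log_partition(eta)
    # Sum over observations to match ell(eta) = sum_i log f(x_i | eta).
    return jnp.sum(t_x @ eta - a_eta + log_h_x)

  def __call__(self, eta: ArrayLike) -> Array:
    return self.loglikelihood(eta)


class Normal(ExpFam):

  def base_measure(self) -> Array:
    # h(x) = 1 / sqrt(2 pi), constant in x
    return 1. / jnp.sqrt(2 * jnp.pi)

  # The two methods below complete the class from T(x) and A(eta)
  # as defined in the Normal example above.
  def sufficient_statistics(self) -> Array:
    # T(x) = [x, x^2], one row per observation
    return jnp.stack([self.x, self.x ** 2], axis=-1)

  def log_partition(self, eta: ArrayLike) -> Array:
    # A(eta) = -eta_1^2 / (4 eta_2) - log(-2 eta_2) / 2, valid for eta_2 < 0
    eta = jnp.asarray(eta)
    return -eta[0] ** 2 / (4 * eta[1]) - 0.5 * jnp.log(-2 * eta[1])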

## Statistical Divergences
[Divergences](https://en.wikipedia.org/wiki/Divergence_(statistics)) capture a notion of "[statistical distance](https://en.wikipedia.org/wiki/Statistical_distance)" between parameterized distribution functions. Their full definition is beyond the scope of this course, but the key properties to recall are that, given distributions $p$ and $q$ over the same [sample space](https://en.wikipedia.org/wiki/Sample_space), a divergence satisfies $D(p || q) \geq 0$, with $D(p || q) = 0$ iff $p = q$. Unlike a true distance, a divergence need not be symmetric: in general $D(p || q) \neq D(q || p)$. There are many [different kinds](https://en.wikipedia.org/wiki/Divergence_(statistics)#Examples) of divergences, but a very commonly used one is the [Kullback-Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence), or KL divergence.

For discrete $x$ with sample space $\mathcal{X}$, we have, $$D_{KL}(p || q) = \sum_{x \in \mathcal{X}} p(x) \log \frac{p(x)}{q(x)} = - \sum_{x \in \mathcal{X}} p(x) \log \frac{q(x)}{p(x)}.$$

For continuous $x \in \mathbb{R}$, we have, $$D_{KL}(p || q) = \int_{-\infty}^\infty p(x) \log \frac{p(x)}{q(x)}dx = -\int_{-\infty}^\infty p(x) \log \frac{q(x)}{p(x)}dx.$$
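
The discrete definition translates directly into a few lines of JAX. This is a small illustrative sketch: the function name `kl_discrete` and the two PMFs are made up for the example, and it assumes every entry of `p` and `q` is strictly positive so that the logarithm is finite.

```python
import jax.numpy as jnp


def kl_discrete(p, q):
    """D_KL(p || q) for PMFs on a shared finite sample space (all entries > 0)."""
    return jnp.sum(p * jnp.log(p / q))


# Two arbitrary PMFs over a three-element sample space
p = jnp.array([0.5, 0.3, 0.2])
q = jnp.array([0.2, 0.5, 0.3])

kl_pq = kl_discrete(p, q)  # divergence from q as seen under p
kl_qp = kl_discrete(q, p)  # arguments swapped; generally a different value
kl_pp = kl_discrete(p, p)  # zero, since the two distributions coincide
```

Comparing `kl_pq` and `kl_qp` shows that the KL divergence is not symmetric in its arguments, while `kl_pp` illustrates that it vanishes when the distributions are equal.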

### Example: Normal Distribution
Let $p := N(\mu_p, \sigma^2_p)$ and $q := N(\mu_q, \sigma^2_q).$ The KL divergence between $p, q$ is given by,
$$D_{KL}(p || q) = \frac{(\mu_p - \mu_q)^2}{2 \sigma^2_q} + \frac{1}{2}\left(\frac{\sigma^2_p}{\sigma^2_q} - 1 - \ln \frac{\sigma^2_p}{\sigma^2_q} \right).$$
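
As a sanity check, the closed form can be compared against a direct numerical approximation of the integral definition. The sketch below is an illustration under stated assumptions: the parameter values, the integration range $[-12, 12]$, and the grid size are arbitrary choices, and the integral is approximated with a simple Riemann sum.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def kl_normal(mu_p, sigma2_p, mu_q, sigma2_q):
    """Closed-form D_KL(N(mu_p, sigma2_p) || N(mu_q, sigma2_q))."""
    ratio = sigma2_p / sigma2_q
    return (mu_p - mu_q) ** 2 / (2.0 * sigma2_q) + 0.5 * (ratio - 1.0 - jnp.log(ratio))


# Illustrative parameters for p and q
mu_p, sigma2_p = 0.0, 1.0
mu_q, sigma2_q = 1.0, 4.0

closed_form = kl_normal(mu_p, sigma2_p, mu_q, sigma2_q)

# Riemann-sum approximation of the integral of p(x) log(p(x) / q(x))
lo, hi, n = -12.0, 12.0, 4001
dx = (hi - lo) / (n - 1)
grid = jnp.linspace(lo, hi, n)
log_p = norm.logpdf(grid, loc=mu_p, scale=jnp.sqrt(sigma2_p))
log_q = norm.logpdf(grid, loc=mu_q, scale=jnp.sqrt(sigma2_q))
numeric = jnp.sum(jnp.exp(log_p) * (log_p - log_q)) * dx
```

Working with log-densities keeps the integrand finite in the tails, where $p(x)$ underflows to zero but $\log p(x) - \log q(x)$ remains well defined.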

## Natural Gradient Descent
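Recall under [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent) we can iteratively minimize a function $f(\theta)$ by taking steps in the direction of steepest descent,
$$ \theta_{t+1} = \theta_t - \rho_t \nabla f(\theta_t).$$ We can re-write this update as,
$$ \theta_{t+1} = {\arg \min}_{\theta} \ \ \theta \cdot \nabla f(\theta_t) + \frac{1}{2\rho_t}||\theta - \theta_t||^2.$$ To see their equivalence, note that for $\rho_t > 0$ the objective is strictly convex in $\theta$, so its unique minimizer is the point where its gradient vanishes,
$$\begin{align*}
0 &= \nabla_\theta \left[\theta \cdot \nabla f(\theta_t) + \frac{1}{2\rho_t}||\theta - \theta_t||^2\right] \\
  &= \nabla f(\theta_t) + \frac{1}{\rho_t}(\theta - \theta_t) \\
\Rightarrow \theta_{t+1} &= \theta_t - \rho_t \nabla f(\theta_t).
\end{align*}$$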
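
The equivalence can also be checked numerically. The sketch below assumes the penalized form $\theta \cdot \nabla f(\theta_t) + \frac{1}{2\rho_t}||\theta - \theta_t||^2$ with a positive step size; the objective `f`, the point `theta_t`, and the step size `rho` are hypothetical choices made only for illustration. It verifies that the plain gradient step is a stationary point of this surrogate and that minimizing the surrogate lands on the same point.

```python
import jax
import jax.numpy as jnp


def f(theta):
    """Hypothetical smooth objective, used only for illustration."""
    return jnp.sum((theta - jnp.array([1.0, -2.0])) ** 2) + jnp.sum(jnp.sin(theta))


theta_t = jnp.array([0.3, 0.7])  # current iterate (arbitrary)
rho = 0.1                        # step size (arbitrary, must be positive)
grad_t = jax.grad(f)(theta_t)


def surrogate(theta):
    """Linearized objective plus a squared Euclidean penalty around theta_t."""
    return theta @ grad_t + jnp.sum((theta - theta_t) ** 2) / (2.0 * rho)


# Ordinary gradient descent step
gd_step = theta_t - rho * grad_t

# The surrogate's gradient should vanish at the gradient descent step
stationarity = jax.grad(surrogate)(gd_step)

# The surrogate is quadratic with Hessian I / rho, so a single Newton step
# from any starting point reaches its minimizer exactly
start = jnp.zeros(2)
prox_step = start - rho * jax.grad(surrogate)(start)
```

Seen this way, the squared Euclidean norm is the only place where geometry enters the update, which is the term a divergence-based variant would replace.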

What if rather than considering *Euclidean distance* we used a notion of *statistical distance* or *divergence*?