<a href="https://colab.research.google.com/github/USCbiostats/PM520/blob/main/Lab_5_ExpFam_Divergences.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# It's a Family Affair, or: Exponential Families

[Exponential Families](https://en.wikipedia.org/wiki/Exponential_family) (sometimes abbreviated as ExpFam) provide a [succinct characterization](https://en.wikipedia.org/wiki/Exponential_family#Table_of_distributions) of many distributions (e.g., [Normal](https://en.wikipedia.org/wiki/Normal_distribution), [Gamma](https://en.wikipedia.org/wiki/Gamma_distribution), [Poisson](https://en.wikipedia.org/wiki/Poisson_distribution), [Bernoulli](https://en.wikipedia.org/wiki/Bernoulli_distribution), [Wishart](https://en.wikipedia.org/wiki/Wishart_distribution), etc.). We'll take an informal look at their properties and how to perform inference.

## Exponential Families
Let $\eta = [\eta_1, \dotsc, \eta_k]$ be a $k$-vector of parameters, and let $x$ be an observation such that $x \sim f(\eta)$. We define its [PDF](https://en.wikipedia.org/wiki/Probability_density_function) (or [PMF](https://en.wikipedia.org/wiki/Probability_mass_function) for discrete $x$) as $$f(x | \eta) = h(x)\exp(\eta \cdot T(x) - A(\eta)),$$ where $h(x)$ is a *base measure*, $\eta$ are the *natural parameters*, $T(x)$ are the [*sufficient statistics*](https://en.wikipedia.org/wiki/Sufficient_statistic), and $A(\eta) = \log \int h(x)\exp(\eta \cdot T(x))\,dx$ (a sum for discrete $x$) is the [*log-partition function*](https://en.wikipedia.org/wiki/Partition_function_%28mathematics%29), which ensures that $f$ normalizes to one. If $A(\eta)$ is *finite*, and the [*support*](https://en.wikipedia.org/wiki/Support_%28mathematics%29#In_probability_and_measure_theory) of $f$ does not depend on the value of $\eta$, then $f$ is a member of the [Exponential Families](https://en.wikipedia.org/wiki/Exponential_family).
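As a quick illustration of the definition, the sketch below writes the Bernoulli PMF $p^x(1-p)^{1-x}$ in ExpFam form, with $h(x) = 1$, $T(x) = x$, natural parameter $\eta = \log\frac{p}{1-p}$ (the log-odds), and $A(\eta) = \log(1 + e^{\eta})$, and compares it against `jax.scipy.stats.bernoulli.logpmf`. The helper name `bernoulli_expfam_logpmf` and the value $p = 0.3$ are illustrative assumptions, not part of any library.

```python
import jax.numpy as jnp
from jax.scipy.stats import bernoulli


def bernoulli_expfam_logpmf(x, p):
    """Bernoulli log-PMF written as log h(x) + eta * T(x) - A(eta)."""
    eta = jnp.log(p / (1.0 - p))      # natural parameter: the log-odds
    t_x = x                           # sufficient statistic T(x) = x
    a_eta = jnp.log1p(jnp.exp(eta))   # log-partition A(eta) = log(1 + e^eta)
    log_h_x = 0.0                     # base measure h(x) = 1, so log h(x) = 0
    return log_h_x + eta * t_x - a_eta


x = jnp.array([0.0, 1.0, 1.0, 0.0])
p = 0.3
ours = bernoulli_expfam_logpmf(x, p)
reference = bernoulli.logpmf(x, p)
print(ours)
print(jnp.allclose(ours, reference, atol=1e-5))
```

Note that $\eta$ ranges over all of $\mathbb{R}$ and $A(\eta)$ is finite everywhere, while the support $\{0, 1\}$ does not depend on $\eta$, so both conditions in the definition hold.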

### Example: Normal Distribution
Recall that if $x \sim N(\mu, \sigma^2)$, then the PDF of $x$ is given by,
$$f(x | \mu, \sigma^2) = \frac{1}{\sqrt{2 \pi \sigma^2}} \exp\left(-\frac{(x - \mu)^2}{2 \sigma^2}\right).$$ To see that the two-parameter Normal distribution is a member of the Exponential Families, define $\eta = [\frac{\mu}{\sigma^2}, -\frac{1}{2\sigma^2}]$, $h(x) = \frac{1}{\sqrt{2\pi}}$, $T(x) = [x, x^2]^T$, and $A(\eta) = \frac{\mu^2}{2\sigma^2} + \log |\sigma| = -\frac{\eta_1^2}{4\eta_2} + \frac{1}{2}\log|\frac{1}{2\eta_2}|$. Putting this all together, we have,
$$\begin{align*}
f(x | \eta) &= h(x)\exp(\eta \cdot T(x) - A(\eta)) \\
  &= \frac{1}{\sqrt{2\pi}} \exp(\eta \cdot T(x) - A(\eta)) \\
  &= \frac{1}{\sqrt{2\pi}} \exp\left([\frac{\mu}{\sigma^2}, -\frac{1}{2\sigma^2}] \cdot T(x) - A(\eta)\right) \\
  &= \frac{1}{\sqrt{2\pi}} \exp\left([\frac{\mu}{\sigma^2}, -\frac{1}{2\sigma^2}] \cdot [x, x^2]^T - A(\eta)\right) \\
  &= \frac{1}{\sqrt{2\pi}} \exp\left(\frac{\mu x}{\sigma^2} -\frac{x^2}{2\sigma^2} - A(\eta)\right) \\
  &= \frac{1}{\sqrt{2\pi}} \exp\left(\frac{\mu x}{\sigma^2} -\frac{x^2}{2\sigma^2} + \frac{\eta_1^2}{4\eta_2} - \frac{1}{2}\log\left|\frac{1}{2\eta_2}\right|\right) \\
  &= \frac{1}{\sqrt{2\pi}} \exp\left(\frac{\mu x}{\sigma^2} -\frac{x^2}{2\sigma^2} - \frac{\mu^2}{2\sigma^2} - \log |\sigma|\right) \\
  &= \frac{1}{\sqrt{2\pi}} \exp\left(-\frac{(x - \mu)^2}{2\sigma^2} - \log |\sigma|\right) \\
  &= \frac{1}{\sqrt{2\pi}\sigma} \exp\left(-\frac{(x - \mu)^2}{2\sigma^2}\right) \\
  &= \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(x - \mu)^2}{2\sigma^2}\right) \\
\end{align*}$$
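The same identity can be checked numerically. This is a minimal sketch, assuming arbitrary illustrative values $\mu = 1.5$ and $\sigma^2 = 4$; the helper names `normal_log_partition` and `normal_expfam_pdf` are hypothetical. It evaluates $h(x)\exp(\eta \cdot T(x) - A(\eta))$ on a grid, compares it with `jax.scipy.stats.norm.pdf`, and checks that $A(\eta)$ agrees with $\frac{\mu^2}{2\sigma^2} + \log|\sigma|$.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

mu, sigma_sq = 1.5, 4.0
# natural parameters eta = [mu / sigma^2, -1 / (2 sigma^2)]
eta = jnp.array([mu / sigma_sq, -1.0 / (2.0 * sigma_sq)])


def normal_log_partition(eta):
    """A(eta) = -eta_1^2 / (4 eta_2) + 0.5 log|1 / (2 eta_2)|."""
    eta_1, eta_2 = eta[0], eta[1]
    return -eta_1 ** 2 / (4.0 * eta_2) + 0.5 * jnp.log(jnp.abs(1.0 / (2.0 * eta_2)))


def normal_expfam_pdf(x, eta):
    """Evaluates h(x) exp(eta . T(x) - A(eta)) for the scalar Normal."""
    h_x = 1.0 / jnp.sqrt(2.0 * jnp.pi)
    t_x = jnp.stack([x, x ** 2], axis=-1)  # T(x) = [x, x^2], shape (n, 2)
    return h_x * jnp.exp(t_x @ eta - normal_log_partition(eta))


x = jnp.linspace(-5.0, 8.0, 7)
ours = normal_expfam_pdf(x, eta)
reference = norm.pdf(x, loc=mu, scale=jnp.sqrt(sigma_sq))
print(jnp.allclose(ours, reference, rtol=1e-4))

# A(eta) in its (mu, sigma^2) form: mu^2 / (2 sigma^2) + log sigma
a_direct = mu ** 2 / (2 * sigma_sq) + jnp.log(jnp.sqrt(sigma_sq))
print(normal_log_partition(eta), a_direct)
```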

## Expectations
A key property of ExpFam distributions is that the [moments](https://en.wikipedia.org/wiki/Moment_%28mathematics%29) of $T(x)$ can be obtained by differentiating the log-partition function $A(\eta)$. The two key identities to remember are $\mathbb{E}[T(x)] = \frac{\partial}{\partial \eta} A(\eta)$ (the mean is the gradient of $A$) and $\mathbb{V}[T(x)] = \frac{\partial^2}{\partial \eta \partial \eta^T} A(\eta)$ (the covariance is the Hessian of $A$).

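In the case of scalar Normal variables, we have $\mathbb{E}[T(x)] = \mathbb{E}\left[[x, x^2]\right] = [\mu, \sigma^2 + \mu^2]$ and $\mathbb{V}[T(x)] = \begin{bmatrix} \sigma^2 & 2\mu\sigma^2 \\ 2\mu\sigma^2 & 4\mu^2\sigma^2 + 2\sigma^4 \end{bmatrix}$, which reduces to $\begin{bmatrix} \sigma^2 & 0 \\ 0 & 2\sigma^4 \end{bmatrix}$ when $\mu = 0$. Differentiating $A(\eta)$ recovers the mean: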
$$\begin{align*}
\mathbb{E}[T(x)] &= \frac{\partial}{\partial \eta} A(\eta) \\
  &= \frac{\partial}{\partial \eta} \left[-\frac{\eta_1^2}{4\eta_2} + \frac{1}{2}\log|\frac{1}{2\eta_2}|\right] \\
  &= -\frac{\partial}{\partial \eta} \frac{\eta_1^2}{4\eta_2} + \frac{\partial}{\partial \eta}\frac{1}{2}\log|\frac{1}{2\eta_2}| \\
  &= -\frac{\partial}{\partial \eta} \frac{\eta_1^2}{4\eta_2} - \frac{\partial}{\partial \eta}\frac{1}{2}\log|2\eta_2| \\
\mathbb{E}[T(x)]_1 &= -\frac{\partial}{\partial \eta_1} \frac{\eta_1^2}{4\eta_2} - \frac{\partial}{\partial \eta_1}\frac{1}{2}\log|2\eta_2| \\
  &= - \frac{2 \eta_1}{4 \eta_2} = -\frac{\eta_1}{2 \eta_2} = -\frac{\mu / \sigma^2}{-2 / (2\sigma^2)} = \mu \\
\mathbb{E}[T(x)]_2 &= -\frac{\partial}{\partial \eta_2} \frac{\eta_1^2}{4\eta_2} - \frac{\partial}{\partial \eta_2}\frac{1}{2}\log|2\eta_2| \\
  &= \frac{\eta^2_1}{4\eta_2^2} - \frac{1}{2\eta_2} = \frac{\mu^2/\sigma^4}{1/\sigma^4} + \frac{1}{1 / \sigma^2} = \mu^2 + \sigma^2.
\end{align*}$$
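Because both moments are derivatives of $A(\eta)$, automatic differentiation can compute them directly. A minimal sketch, assuming illustrative values $\mu = 2$ and $\sigma^2 = 3$: `jax.grad` of $A$ should return $[\mu, \sigma^2 + \mu^2]$, and `jax.hessian` of $A$ should return the covariance of $[x, x^2]$, namely $\begin{bmatrix} \sigma^2 & 2\mu\sigma^2 \\ 2\mu\sigma^2 & 4\mu^2\sigma^2 + 2\sigma^4 \end{bmatrix}$.

```python
import jax
import jax.numpy as jnp


def log_partition(eta):
    """Normal A(eta) = -eta_1^2 / (4 eta_2) - 0.5 log(-2 eta_2), for eta_2 < 0."""
    eta_1, eta_2 = eta[0], eta[1]
    return -eta_1 ** 2 / (4.0 * eta_2) - 0.5 * jnp.log(-2.0 * eta_2)


mu, sigma_sq = 2.0, 3.0
eta = jnp.array([mu / sigma_sq, -1.0 / (2.0 * sigma_sq)])

mean_t = jax.grad(log_partition)(eta)    # E[T(x)] = gradient of A
cov_t = jax.hessian(log_partition)(eta)  # V[T(x)] = Hessian of A

expected_mean = jnp.array([mu, sigma_sq + mu ** 2])
expected_cov = jnp.array([
    [sigma_sq, 2 * mu * sigma_sq],
    [2 * mu * sigma_sq, 4 * mu ** 2 * sigma_sq + 2 * sigma_sq ** 2],
])
print(mean_t, expected_mean)
print(cov_t, expected_cov)
```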

## Example: Poisson
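A minimal sketch of the Poisson case, using the standard parameterization with rate $\lambda > 0$: since $f(x | \lambda) = \frac{\lambda^x e^{-\lambda}}{x!} = \frac{1}{x!}\exp(x\log\lambda - \lambda)$, the Poisson is an ExpFam with $h(x) = \frac{1}{x!}$, $T(x) = x$, $\eta = \log\lambda$, and $A(\eta) = e^{\eta}$. Differentiating $A$ then gives $\mathbb{E}[x] = \mathbb{V}[x] = e^{\eta} = \lambda$. The code compares the ExpFam form with `jax.scipy.stats.poisson.logpmf` and recovers both moments with `jax.grad`; the value $\lambda = 3.5$ and the function names are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln
from jax.scipy.stats import poisson


def log_partition(eta):
    """A(eta) = e^eta for the Poisson, with eta = log(lambda)."""
    return jnp.exp(eta)


def poisson_expfam_logpmf(x, eta):
    """log h(x) + eta * T(x) - A(eta), with h(x) = 1 / x! and T(x) = x."""
    log_h_x = -gammaln(x + 1.0)  # log(1 / x!) = -log Gamma(x + 1)
    return log_h_x + eta * x - log_partition(eta)


lam = 3.5
eta = jnp.log(lam)    # natural parameter
x = jnp.arange(10.0)  # counts 0, 1, ..., 9 stored as floats

ours = poisson_expfam_logpmf(x, eta)
reference = poisson.logpmf(x, lam)

# Mean and variance of T(x) = x as first and second derivatives of A(eta)
mean_t = jax.grad(log_partition)(eta)
var_t = jax.grad(jax.grad(log_partition))(eta)
print(mean_t, var_t)
```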

## MLE Inference for ExpFam
Given $x_1, \dotsc, x_n$, we assume the $x_i$ are independent and identically distributed draws from some ExpFam distribution $f$ with natural parameters $\eta$. Our log-likelihood is then given by,
$$ \ell(\eta) = \sum_{i=1}^n \log f(x_i | \eta).$$ To identify the maximum likelihood estimate (MLE), we first compute the gradient of $\ell$ with respect to $\eta$, which is given by
$$\begin{align*}
\nabla \ell(\eta) &= \sum_{i=1}^n \nabla \log f(x_i | \eta) \\
  &= \sum_{i=1}^n \nabla \log \left[ h(x_i)\exp(\eta \cdot T(x_i) - A(\eta))\right] \\
  &= \sum_{i=1}^n \nabla \left[\log h(x_i) + \log \exp(\eta \cdot T(x_i) - A(\eta))\right] \\
  &= \sum_{i=1}^n \left[\nabla \log h(x_i) + \nabla (\eta \cdot T(x_i) - A(\eta))\right] \\
  &= \sum_{i=1}^n \left[T(x_i) - \nabla A(\eta)\right] \qquad \text{(since } h \text{ does not depend on } \eta\text{)} \\
  &= \sum_{i=1}^n \left[T(x_i) - \mathbb{E}[T(x)]\right],
\end{align*}$$
where we used the fact that $\nabla A(\eta) = \mathbb{E}[T(x)]$. Setting the gradient to zero, we seek values of $\eta$ such that $\sum_i T(x_i) = n \mathbb{E}[T(x)]$, i.e., $\frac{1}{n} \sum_i T(x_i) = \mathbb{E}[T(x)]$. In other words, MLE under ExpFam seeks values of $\eta$ that *match the expected sufficient statistics to their empirical mean*! Moreover, since $\nabla^2 \ell(\eta) = -n \nabla^2 A(\eta) = -n\,\mathbb{V}[T(x)]$ is negative semidefinite, $\ell$ is concave in $\eta$, so any such stationary point is a global maximum.
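For the Normal, matching $\mathbb{E}[T(x)] = [\mu, \sigma^2 + \mu^2]$ to the empirical means $m_1 = \frac{1}{n}\sum_i x_i$ and $m_2 = \frac{1}{n}\sum_i x_i^2$ gives $\hat\mu = m_1$ and $\hat\sigma^2 = m_2 - m_1^2$, which map back to $\hat\eta = [\hat\mu/\hat\sigma^2, -1/(2\hat\sigma^2)]$. The sketch below, on simulated data with illustrative values $\mu = 5$, $\sigma^2 = 10$, and $n = 1000$, checks that the score $\nabla\ell(\hat\eta)$ computed by `jax.grad` is numerically zero; `normal_loglik` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm


def normal_loglik(eta, x):
    """Summed Normal log likelihood: sum_i [log h(x_i) + eta . T(x_i) - A(eta)]."""
    t_x = jnp.stack([x, x ** 2], axis=-1)  # T(x_i) = [x_i, x_i^2]
    eta_1, eta_2 = eta[0], eta[1]
    a_eta = -eta_1 ** 2 / (4.0 * eta_2) - 0.5 * jnp.log(-2.0 * eta_2)
    log_h_x = -0.5 * jnp.log(2.0 * jnp.pi)
    return jnp.sum(log_h_x + t_x @ eta - a_eta)


key = rdm.PRNGKey(0)
x = 5.0 + jnp.sqrt(10.0) * rdm.normal(key, shape=(1000,))

# Moment matching: equate E[T(x)] = [mu, sigma^2 + mu^2] with the empirical mean of T(x)
m_1, m_2 = jnp.mean(x), jnp.mean(x ** 2)
mu_hat = m_1
sigma_sq_hat = m_2 - m_1 ** 2

# Map the matched mean parameters back to natural parameters
eta_hat = jnp.array([mu_hat / sigma_sq_hat, -1.0 / (2.0 * sigma_sq_hat)])

# The score (gradient of the log likelihood) should vanish at the MLE
score = jax.grad(normal_loglik)(eta_hat, x)
print(eta_hat, score / x.size)
```

In float32 the score is zero only up to rounding in the large sums, so it is reported per observation; because $\ell$ is concave, no iterative optimizer is needed here.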

## It's an Equinox for all seasons
TBD: information on `equinox`, (functional) object-oriented programming, and abstract base classes; more documentation on the `Normal` class.

In [None]:
!pip install equinox

In [None]:
from abc import abstractmethod

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, ArrayLike


class ExpFam(eqx.Module):
  """
  Simple base class for Exponential Families. Provides methods to compute
  sufficient statistics and to evaluate the log likelihood of downstream
  implementations.
  """

  @abstractmethod
  def base_measure(self, x: ArrayLike) -> Array:
    """
    Computes the base measure (i.e. $h(x)$) for an implementation of ExpFam.

    x: ArrayLike, the observations.

    Returns:
      The base measure evaluated at each observation.
    """
    ...

  @abstractmethod
  def sufficient_statistics(self, x: ArrayLike) -> Array:
    """
    Computes the sufficient statistics (i.e. $T(x)$) for an implementation
    of ExpFam.

    x: ArrayLike, the observations.

    Returns:
      The sufficient statistics for each observation, as an (n, k) array.
    """
    ...

  @abstractmethod
  def log_partition(self, eta: ArrayLike) -> Array:
    r"""
    Computes the log partition function (i.e. $A(\eta)$) for an implementation
    of ExpFam with natural parameters $\eta$.

    eta: ArrayLike, the natural parameters to evaluate under the log partition.

    Returns:
      The scalar value of the log partition function; it does not depend on
      the observations.
    """
    ...

  def loglikelihood(self, eta: ArrayLike, x: ArrayLike) -> Array:
    r"""
    Computes the log likelihood for each observation $x$ under an implementation
    of ExpFam with natural parameters $\eta$.

    eta: ArrayLike, the natural parameters to evaluate under the log likelihood.
    x: ArrayLike, the observations.

    Returns:
      The log likelihood for each observation under an implementation of ExpFam.
    """
    eta = jnp.asarray(eta)
    t_x = self.sufficient_statistics(x)
    log_h_x = jnp.log(self.base_measure(x))
    a_eta = self.log_partition(eta)
    # log f(x | eta) = log h(x) + eta . T(x) - A(eta), one value per observation
    return t_x @ eta - a_eta + log_h_x

  def __call__(self, eta: ArrayLike, x: ArrayLike) -> Array:
    return self.loglikelihood(eta, x)

  def fit(self, x: ArrayLike) -> Array:
    r"""
    Computes the maximum likelihood estimate in *mean-parameter* space, i.e.
    the empirical mean of the sufficient statistics. By moment matching, the
    MLE of the natural parameters is the $\eta$ whose $\mathbb{E}[T(x)]$
    equals this value; mapping it back to $\eta$ is distribution-specific and
    is not done here.

    x: ArrayLike, the observations.

    Returns:
      The MLE of the mean parameters $\mathbb{E}[T(x)]$.
    """
    t_x = self.sufficient_statistics(x)
    mle = jnp.mean(t_x, axis=0)
    return mle


class Normal(ExpFam):
  """Scalar Normal with natural parameters [mu / sigma^2, -1 / (2 sigma^2)]."""

  def base_measure(self, x: ArrayLike) -> Array:
    # h(x) = 1 / sqrt(2 pi) is constant; broadcast it to one value per observation
    return jnp.broadcast_to(1. / jnp.sqrt(2 * jnp.pi), jnp.shape(x))

  def sufficient_statistics(self, x: ArrayLike) -> Array:
    # T(x) = [x, x^2], stacked into an (n, 2) array
    x = jnp.asarray(x)
    return jnp.stack((x, x ** 2), axis=-1)

  def log_partition(self, eta: ArrayLike) -> Array:
    # A(eta) = -eta_1^2 / (4 eta_2) - 0.5 log(-2 eta_2), defined for eta_2 < 0
    eta_1, eta_2 = eta
    term1 = -(eta_1**2 / (4 * eta_2))
    term2 = -0.5 * jnp.log(- 2 * eta_2)
    return term1 + term2

Great! Having implemented a general ExpFam interface for MLE, with the Normal distribution as a concrete example, let's perform some sanity checks before proceeding with inference.

In [None]:
import jax.random as rdm
import jax.scipy.stats as stats

# initialize our PRNG state
seed = 0
key = rdm.PRNGKey(seed)

# define some parameter values for N(mu, sigma2)
mu = 5.0
sigma_sq = 10.0

# generate random data
N = 10
key, x_key = rdm.split(key)
obs = mu + jnp.sqrt(sigma_sq) * rdm.normal(x_key, shape=(N,))

# transform parameters to natural parameters and create some other guess at eta
eta_true = jnp.array([mu / sigma_sq, - 1. / (2 * sigma_sq)])
eta_guess = jnp.array([-9, -100.])

# create an instance of our Normal implementation above; it stores no data,
# so the observations are passed in at call time
model = Normal()

# calculate the sum of log likelihood
# equivalent to `jnp.sum(model.loglikelihood(eta_true, obs))`
sum_ll_true = jnp.sum(model(eta_true, obs))
sum_ll_guess = jnp.sum(model(eta_guess, obs))
print(rf"true logl $\ell$({eta_true}) = {sum_ll_true} | guess logl $\ell$({eta_guess}) = {sum_ll_guess}")

# sanity check against JAX's built-in Normal log density
jax_ll_true = jnp.sum(stats.norm.logpdf(obs, mu, jnp.sqrt(sigma_sq)))
print(rf"Our $\ell$({eta_true}) = {sum_ll_true} | JAX logl $\ell$({eta_true}) = {jax_ll_true}")

true logl $\ell$([ 0.5  -0.05]) = -26.96800994873047 | guess logl $\ell$([  -9. -100.]) = -17816.912109375
Our $\ell$([ 0.5  -0.05]) = -26.96800994873047 | JAX logl $\ell$([ 0.5  -0.05]) = -26.9680118560791
