<a href="https://colab.research.google.com/github/USCbiostats/PM520/blob/main/Lab_3_Optimization_PtI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Move on Up, or: Maximum likelihood Estimation & Optimization Pt I


In [None]:
import jax
import jax.numpy as jnp
import jax.random as rdm

## MLE for iid Normal data
Let $x_1, \dotsc, x_n \overset{\mathrm{iid}}{\sim} N(\mu, \sigma^2)$ where $N(\mu, \sigma^2)$ refers to the [Normal distribution](https://en.wikipedia.org/wiki/Normal_distribution) with mean parameter $\mu$ and variance parameter $\sigma^2$. The likelihood of our data is given by,
$$\begin{align*}
L(\mu, \sigma^2 | x_1, \dots, x_n) &=
  \prod_{i=1}^n N(x_i | \mu, \sigma^2) \\
  &= \prod_{i=1}^n \frac{1}{\sqrt{2\pi \sigma^2}} \exp\left(-\frac{(x_i - \mu)^2}{2\sigma^2}\right)\\
  &= \left(\frac{1}{\sqrt{2\pi\sigma^2}}\right)^n \exp\left(-\frac{1}{2\sigma^2} \sum_{i=1}^n (x_i - \mu)^2\right).
\end{align*}\\
$$
Thus, our _log_-likelihood is given by,
$$\begin{align*}
\ell(\mu, \sigma^2 | x_1, \dots, x_n) &=
  \log \left[\left(\frac{1}{\sqrt{2\pi\sigma^2}}\right)^n \exp\left(-\frac{1}{2\sigma^2} \sum_{i=1}^n (x_i - \mu)^2\right)\right]\\
  &= -\frac{n}{2}\log(2\pi\sigma^2) - \frac{1}{2\sigma^2} \sum_{i=1}^n (x_i - \mu)^2.
\end{align*}\\
$$

In [None]:
def norm_rv(key, n: int, mu: float, sigma_sq: float):
  r"""
  Draws $n$ iid samples $x_i \sim N(\mu, \sigma^2)$ by shifting and scaling
  standard Normal variates.

  key: the JAX PRNG key used to generate the draws
  n: the number of observations
  mu: the mean parameter
  sigma_sq: the variance parameter (its square root is used as the scale)

  returns: Array of shape (n,) holding the observations
  """
  # if z ~ N(0, 1) then mu + sigma * z ~ N(mu, sigma^2)
  std_draws = rdm.normal(key, shape=(n,))
  sigma = jnp.sqrt(sigma_sq)
  return mu + sigma * std_draws


def norm_mle(x):
  r"""
  Computes $\hat{\mu}_{MLE}$ and $\hat{\sigma^2}_{MLE}$ for iid Normal data.

  Setting the gradient of the Normal log-likelihood to zero gives the sample
  mean for $\mu$ and the mean squared deviation about that mean for
  $\sigma^2$.

  x: Array of observations

  returns: Tuple of $\hat{\mu}_{MLE}$ and $\hat{\sigma^2}_{MLE}$.
  """
  x = jnp.asarray(x)
  mu_hat = jnp.mean(x)
  # divide by n (not n - 1): the MLE of the variance is the biased estimator
  ssq_hat = jnp.mean((x - mu_hat) ** 2)

  return mu_hat, ssq_hat

# seed the PRNG and carve off a sub-key for this simulation; `key` itself is
# split again by the code that follows
seed = 0
key = rdm.PRNGKey(seed)
key, x_key = rdm.split(key)

N = 500

# true parameters that the estimates are compared against
mu = 58.
sigma_sq = 100.
x = norm_rv(x_key, N, mu, sigma_sq)
#print(f"x = {x}")
mu_hat, ssq_hat = norm_mle(x)
# raw f-string: `\m` and `\s` are not valid escape sequences in a normal
# string literal; the printed text is unchanged
print(rf"MLE[\mu, \sigma^2] = {mu_hat}, {ssq_hat}")

In [None]:
def sq_diff(param, estimate):
  """
  Squared difference between a reference value and its estimate.

  param: the reference (true) value
  estimate: the estimated value

  returns: (param - estimate) squared
  """
  err = param - estimate
  return err ** 2

# true parameters shared by every simulated data set below
mu = 58.
sigma_sq = 100.
# repeat the simulation at growing sample sizes to see how the squared error
# of each estimate changes with N
for N in [50, 100, 1000, 10000]:
  # fresh sub-key per sample size so each data set uses different randomness
  key, x_key = rdm.split(key)
  # generate N observations
  x_n = norm_rv(x_key, N, mu, sigma_sq)
  # estimate mu, and sigma_sq
  mu_hat, ssq_hat = norm_mle(x_n)
  # compute the sq-diff for both and report
  mu_err = sq_diff(mu, mu_hat)
  ssq_err = sq_diff(sigma_sq, ssq_hat)
  # NOTE(review): the label says "MSE" but each value is the squared error
  # from a single simulated data set, not an average over replicates
  print(f"MSE[{N} | mu, sigma^2] = {mu_err}, {ssq_err}")

## MLE for iid Exponential data
Let $x_1, \dotsc, x_n \overset{\mathrm{iid}}{\sim} Exp(\lambda)$ where $Exp(\lambda)$ refers to the [Exponential distribution](https://en.wikipedia.org/wiki/Exponential_distribution) with rate parameter $\lambda$. The likelihood of our data is given by,
$$\begin{align*}
L(\lambda | x_1, \dots, x_n) &=
  \prod_{i=1}^n Exp(x_i | \lambda) \\
  &= \prod_{i=1}^n \lambda \exp(-\lambda x_i).
\end{align*}\\
$$

In [None]:
def exp_rv(key, n: int, rate: float):
  r"""
  Draws $n$ iid samples $x_i \sim Exp(\lambda)$ by rescaling the draws
  returned by `rdm.exponential`.

  key: the JAX PRNG key used to generate the draws
  n: the number of observations
  rate: the $\lambda$ parameter; must be non-zero since it is inverted

  returns: Array of shape (n,) holding the observations
  """
  # multiplying the base draws by the mean 1 / rate rescales them to `rate`
  unit_draws = rdm.exponential(key, shape=(n,))
  scale = 1 / rate
  return scale * unit_draws


def exp_mle(x):
  r"""
  Computes $\hat{\lambda}_{MLE}$ for iid Exponential data.

  The log-likelihood $n \log\lambda - \lambda \sum_i x_i$ is maximized at
  $\hat{\lambda} = n / \sum_i x_i$, i.e. the reciprocal of the sample mean.

  x: Array of observations

  returns: $\hat{\lambda}_{MLE}$.
  """
  rate_hat = 1 / jnp.mean(jnp.asarray(x))
  return rate_hat

# fresh sub-key for this simulation
key, x_key = rdm.split(key)
N = 100
# true rate that the estimate is compared against
rate = 1 / 500.
x = exp_rv(x_key, N, rate)
#print(f"x = {x}")
rate_hat = exp_mle(x)
# raw f-string: `\l` is not a valid escape sequence in a normal string
# literal; the printed text is unchanged
print(rf"MLE[\lambda = {rate}] = {rate_hat}")

In [None]:
# true rate shared by every simulated data set below
rate = 1 / 50.
# repeat the simulation at growing sample sizes to see how the squared error
# of the estimate changes with N
for N in [50, 100, 1000, 10000]:
  # fresh sub-key per sample size so each data set uses different randomness
  key, x_key = rdm.split(key)
  # generate N observations
  x_n = exp_rv(x_key, N, rate)
  # estimate rate
  rate_hat = exp_mle(x_n)
  # compute the sq-diff for rate
  rate_err = sq_diff(rate, rate_hat)
  # raw f-string: `\l` is not a valid escape sequence in a normal string
  # literal; the printed text is unchanged
  print(rf"MSE[{N} | \lambda = {rate}] = {rate_err}")

# Gradient descent
[Gradient descent](https://en.wikipedia.org/wiki/Gradient_descent) seeks to iteratively minimize a function $f(\beta)$ by taking steps in the direction of steepest descent,
$$ \hat{\beta} = \beta_t - \rho_t \nabla f(\beta_t),$$
where $\rho_t > 0$ is the step size and the direction is given by the negative [gradient](https://en.wikipedia.org/wiki/Gradient) of $f$.

A helpful way to recast gradient descent is that we seek to perform a series of _local_ optimizations,

$$\hat{\beta} = \arg\min_\beta \nabla f(\beta_t)^T \beta + \frac{1}{2\rho_t}\|\beta - \beta_t\|_2^2.$$

To see how these are equivalent, let's solve the local problem, but using inner product notation,
$$m(\beta) = \nabla f(\beta_t)^T \beta + \frac{1}{2\rho_t} (\beta - \beta_t)^T(\beta - \beta_t).$$
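Now, using calculus again, we take the gradient with respect to $\beta$ and set it to zero to find the minimizer $\hat{\beta}$,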
$$\begin{align*}
\nabla m(\beta) &= \nabla [ \nabla f(\beta_t)^T \beta + \frac{1}{2\rho_t} (\beta - \beta_t)^T(\beta - \beta_t)] \\
&= \nabla [\nabla f(\beta_t)^T \beta] + \frac{1}{2\rho_t} \nabla [(\beta - \beta_t)^T(\beta - \beta_t)] \\
&= \nabla f(\beta_t) + \frac{1}{\rho_t}(\beta - \beta_t) \Rightarrow \\
\hat{\beta} &= \beta_t - \rho_t \nabla f(\beta_t).
\end{align*}
$$

Neat! However, notice that the original local objective can be thought of as minimizing the directional derivative, but with a distance penalty, where that distance is defined by the geometry of the parameter space.

$$\hat{\beta} = \arg\min_\beta \nabla f(\beta_t)^T \beta + \frac{1}{2\rho_t}\text{dist}(\beta, \beta_t).$$

When the natural geometry is $\mathbb{R}^p$ then $\text{dist}(\cdot) = \| \cdot \|_2^2$; however, there are many geometries that can describe the natural parameter space (for a future class 😉)

In [None]:
def sim_linear_reg(key, N, P, r2=0.5):
  """
  Simulates a linear regression data set in which the noise variance is set
  from the empirical signal variance and the target fraction `r2`.

  key: the JAX PRNG key; split internally, once per random draw
  N: the number of observations (rows of X)
  P: the number of predictors (columns of X)
  r2: target fraction of outcome variance attributed to the signal X @ beta;
    must be non-zero since it is divided by

  returns: Tuple (y, X, beta) of outcomes with shape (N,), design matrix
    with shape (N, P), and true effects with shape (P,)
  """
  # design matrix of standard Normal predictors
  key, design_key = rdm.split(key)
  X = rdm.normal(design_key, shape=(N, P))

  # true effects, also standard Normal
  key, effect_key = rdm.split(key)
  beta = rdm.normal(effect_key, shape=(P,))

  # noiseless signal and its empirical variance
  signal = X @ beta
  signal_var = jnp.var(signal)

  # solve signal_var / (signal_var + noise_var) == r2 for noise_var
  noise_var = (1 - r2) / r2 * signal_var

  # standard Normal noise, scaled so that its variance is noise_var
  key, noise_key = rdm.split(key)
  noise = rdm.normal(noise_key, shape=(N,))
  y = signal + jnp.sqrt(noise_var) * noise
  return y, X, beta

# dedicated sub-key for the regression simulation below
key, sim_key = rdm.split(key)

# N observations and P predictors
N = 1000
P = 5
# y: outcomes, X: design matrix, beta: the true effects used to simulate y
y, X, beta = sim_linear_reg(sim_key, N, P)

def linreg_loss(beta_hat, y, X):
  """
  Least-squares loss for linear regression: half the residual sum of squares,
  0.5 * ||y - X @ beta_hat||^2.

  The loss is summed (not averaged) over observations, which pairs with the
  1 / N step size used by the gradient descent loops in this file.

  beta_hat: Array of shape (P,), the current coefficient estimate
  y: Array of shape (N,), the observed outcomes
  X: Array of shape (N, P), the design matrix

  returns: scalar loss value
  """
  resid = y - X @ beta_hat
  return 0.5 * jnp.sum(resid ** 2)

def gradient(beta_hat, y, X):
  """
  Analytic gradient, with respect to beta_hat, of the least-squares loss
  0.5 * ||y - X @ beta_hat||^2, namely -X^T (y - X @ beta_hat).

  beta_hat: Array of shape (P,), the current coefficient estimate
  y: Array of shape (N,), the observed outcomes
  X: Array of shape (N, P), the design matrix

  returns: Array of shape (P,), the gradient at beta_hat
  """
  resid = y - X @ beta_hat
  return -(X.T @ resid)

# fixed step size for every iteration
step_size = 1 / N
# starting values: diff = 10. makes the loop condition hold on entry, and
# last_loss = 1000. is the reference the first computed loss is compared to
diff = 10.
last_loss = 1000.
idx = 0
# start the search from the zero vector
beta_hat = jnp.zeros((P,))
# while delta in loss is large, continue
print(f"true beta = {beta}")
# stop once the loss changes by at most 1e-3 between consecutive iterations
while jnp.fabs(diff) > 1e-3:

  # take a step against the gradient (the descent direction) using step_size
  beta_hat = beta_hat - step_size * gradient(beta_hat, y, X)

  # update our current loss and compute delta
  cur_loss = linreg_loss(beta_hat, y, X)
  diff = last_loss - cur_loss
  last_loss = cur_loss

  # wave to the crowd
  print(f"Loss[{idx}]({beta_hat}) = {last_loss}")
  idx += 1

# OLS solution
# solves the normal equations (X^T X) b = X^T y directly, as a reference to
# compare against the iterative estimate
beta_hat_ols = jnp.linalg.solve(X.T @ X, X.T @ y)
print(f"ols beta = {beta_hat_ols}")

In [None]:
# dedicated sub-key so this run uses a fresh simulated data set
key, sim_key = rdm.split(key)

N = 1000
P = 5
y, X, beta = sim_linear_reg(sim_key, N, P)

# fixed step size for every iteration
step_size = 1 / N
# starting values: diff = 10. makes the loop condition hold on entry, and
# last_loss = 1000. is the reference the first computed loss is compared to
diff = 10.
last_loss = 1000.
idx = 0
# start the search from the zero vector
beta_hat = jnp.zeros((P,))
# build the gradient function once, outside the loop; jax.grad differentiates
# with respect to the first argument (beta_hat) by default
loss_grad = jax.grad(linreg_loss)
print("Using JAX to compute gradient")
print(f"true beta = {beta}")
# stop once the loss changes by at most 1e-3 between consecutive iterations
while jnp.fabs(diff) > 1e-3:
  # take a step against the gradient (the descent direction) using step_size
  beta_hat = beta_hat - step_size * loss_grad(beta_hat, y, X)

  # update our current loss and compute delta
  cur_loss = linreg_loss(beta_hat, y, X)
  diff = last_loss - cur_loss
  last_loss = cur_loss

  # wave to the crowd
  print(f"Loss[{idx}]({beta_hat}) = {last_loss}")
  idx += 1

# OLS solution
# solves the normal equations (X^T X) b = X^T y directly, as a reference to
# compare against the iterative estimate
beta_hat_ols = jnp.linalg.solve(X.T @ X, X.T @ y)
print(f"ols beta = {beta_hat_ols}")