<a href="https://colab.research.google.com/github/USCbiostats/PM520/blob/main/Lab_3_Optimization_PtI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Move on Up, or: Maximum Likelihood Estimation & Optimization Pt I

TBD: move notes from slides to here.


In [None]:
import jax
import jax.numpy as jnp
import jax.random as rdm

## MLE for iid Normal data
Let $x_1, \dotsc, x_n \overset{\mathrm{iid}}{\sim} \mathcal{N}(\mu, \sigma^2)$ where $\mathcal{N}(\mu, \sigma^2)$ refers to the [Normal distribution](https://en.wikipedia.org/wiki/Normal_distribution) with mean parameter $\mu$ and variance parameter $\sigma^2$. The likelihood of our data is given by,
$$\begin{align*}
\mathcal{L}(\mu, \sigma^2 | x_1, \dots, x_n) &=
  \prod_{i=1}^n \mathcal{N}(x_i | \mu, \sigma^2) \\
  &= \prod_{i=1}^n \frac{1}{\sqrt{2\pi \sigma^2}} \exp\left(-\frac{(x_i - \mu)^2}{2\sigma^2}\right)\\
  &= \left(\frac{1}{\sqrt{2\pi\sigma^2}}\right)^n \exp\left(-\frac{1}{2\sigma^2} \sum_{i=1}^n (x_i - \mu)^2\right).
\end{align*}
$$
Thus, our _log_-likelihood is given by,
$$\begin{align*}
\ell(\mu, \sigma^2 | x_1, \dots, x_n) &=
  \log \left[\left(\frac{1}{\sqrt{2\pi\sigma^2}}\right)^n \exp\left(-\frac{1}{2\sigma^2} \sum_{i=1}^n (x_i - \mu)^2\right)\right]\\
  &= -\frac{n}{2}\log(2\pi\sigma^2) - \frac{1}{2\sigma^2} \sum_{i=1}^n (x_i - \mu)^2.
\end{align*}
$$
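Setting both partial derivatives of $\ell$ to zero gives the closed-form estimators $\hat{\mu}_{MLE} = \bar{x} = \frac{1}{n}\sum_{i=1}^n x_i$ and $\hat{\sigma}^2_{MLE} = \frac{1}{n}\sum_{i=1}^n (x_i - \bar{x})^2$ (note the $1/n$, not the unbiased $1/(n-1)$). A minimal sketch that checks this with `jax.grad`; `norm_loglik` is a hypothetical helper that transcribes $\ell$ above, and the simulated data are illustrative.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm

def norm_loglik(params, x):
  # hypothetical helper: Normal log-likelihood ell(mu, sigma^2 | x) from the derivation above
  mu, sigma_sq = params
  n = x.shape[0]
  return -0.5 * n * jnp.log(2 * jnp.pi * sigma_sq) - jnp.sum((x - mu) ** 2) / (2 * sigma_sq)

# illustrative data: 200 draws from N(3, 4)
x = 3.0 + 2.0 * rdm.normal(rdm.PRNGKey(1), shape=(200,))

# closed-form MLEs: sample mean and the 1/n sample variance
mu_hat = jnp.mean(x)
ssq_hat = jnp.mean((x - mu_hat) ** 2)

# gradient of ell w.r.t. (mu, sigma^2), evaluated at the closed-form MLEs
grad_mu, grad_ssq = jax.grad(norm_loglik)((mu_hat, ssq_hat), x)
print(grad_mu, grad_ssq)  # both ~0 up to float32 round-off
```

Because `ssq_hat` uses the $1/n$ normalization, it coincides with `jnp.var(x)`, whose default `ddof=0` is also the MLE.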

In [None]:
def norm_rv(key, n: int, mu: float, sigma_sq: float):
  """
  Samples $n$ observations from $x_i \sim N(\mu, \sigma^2)$

  n: the number of observations
  mu: the mean parameter
  sigma_sq: the variance parameter

  returns: x, Array of observations
  """
  x = mu + jnp.sqrt(sigma_sq) * rdm.normal(key, shape=(n,))
  return x


def norm_mle(x):
  """
  Computes $\hat{\mu}_{MLE}$ and $\hat{\sigma^2}_{MLE}$.

  x: Array of observations

  returns: Tuple of $\hat{\mu}_{MLE}$ and $\hat{\sigma^2}_{MLE}$.
  """
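  mu_hat = jnp.mean(x)                   # sample mean
  ssq_hat = jnp.mean((x - mu_hat) ** 2)  # 1/n variance: the MLE, not the unbiased 1/(n-1) estimator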

  return mu_hat, ssq_hat

seed = 0
key = rdm.PRNGKey(seed)
key, x_key = rdm.split(key)

N = 500

mu = 58.
sigma_sq = 100.
x = norm_rv(x_key, N, mu, sigma_sq)
#print(f"x = {x}")
mu_hat, ssq_hat = norm_mle(x)
print(fr"MLE[\mu, \sigma^2] = {mu_hat}, {ssq_hat}")

MLE[\mu, \sigma^2] = 58.59890365600586, 98.35617065429688


In [None]:
def sq_diff(param, estimate):
  return (param - estimate) ** 2

mu = 58.
sigma_sq = 10.
for N in [50, 100, 1000, 10000]:
  key, x_key = rdm.split(key)
  # generate N observations
  x_n = norm_rv(x_key, N, mu, sigma_sq)
  # estimate mu, and sigma_sq
  mu_hat, ssq_hat = norm_mle(x_n)
  # compute the sq-diff for both and report
  mu_err = sq_diff(mu, mu_hat)
  ssq_err = sq_diff(sigma_sq, ssq_hat)
  print(f"MSE[{N} | mu, sigma^2] = {mu_err}, {ssq_err}")

MSE[50 | mu, sigma^2] = 1.0336028337478638, 6.931415557861328
MSE[100 | mu, sigma^2] = 0.11436349898576736, 3.418867826461792
MSE[1000 | mu, sigma^2] = 1.0024814400821924e-05, 0.01468171738088131
MSE[10000 | mu, sigma^2] = 0.004498897586017847, 0.015031831339001656
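Each line above is the squared error from a _single_ simulated data set, so it is noisy; for example, the $\mu$ error at $N = 10000$ is larger than at $N = 1000$. To estimate the mean squared error itself, we can average over many independent replicates. A minimal sketch, assuming 2,000 replicates per $N$ and using `jax.vmap` to run them in parallel; `mle` and `mc_mse` are hypothetical helpers that restate the sampler and the closed-form estimators so the block runs on its own.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm

def mle(x):
  # closed-form Normal MLEs: sample mean and 1/n sample variance
  mu_hat = jnp.mean(x)
  return mu_hat, jnp.mean((x - mu_hat) ** 2)

def mc_mse(key, n, mu, sigma_sq, reps=2000):
  # hypothetical helper: average squared error over `reps` independent data sets
  keys = rdm.split(key, reps)

  def one_rep(k):
    x = mu + jnp.sqrt(sigma_sq) * rdm.normal(k, shape=(n,))
    mu_hat, ssq_hat = mle(x)
    return (mu - mu_hat) ** 2, (sigma_sq - ssq_hat) ** 2

  mu_err, ssq_err = jax.vmap(one_rep)(keys)  # one squared error per replicate
  return jnp.mean(mu_err), jnp.mean(ssq_err)

key = rdm.PRNGKey(0)
results = {}
for n in [50, 100, 1000]:
  key, sub = rdm.split(key)
  results[n] = mc_mse(sub, n, 58., 10.)
  print(f"MC-MSE[{n} | mu, sigma^2] = {results[n]}")
```

For comparison, theory gives $\mathrm{MSE}(\hat{\mu}) = \sigma^2 / N$ (here $10/N$), so the Monte Carlo averages for $\mu$ should fall close to 0.2, 0.1, and 0.01.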


## MLE for iid Exponential data
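Let $x_1, \dotsc, x_n \overset{\mathrm{iid}}{\sim} \mathrm{Exp}(\lambda)$, where $\lambda > 0$ is the rate parameter, so each observation has density $f(x | \lambda) = \lambda e^{-\lambda x}$ for $x \ge 0$ and mean $1/\lambda$. The log-likelihood is
$$\ell(\lambda | x_1, \dots, x_n) = \sum_{i=1}^n \log\left(\lambda e^{-\lambda x_i}\right) = n \log \lambda - \lambda \sum_{i=1}^n x_i.$$
Setting $\frac{d\ell}{d\lambda} = \frac{n}{\lambda} - \sum_{i=1}^n x_i = 0$ gives $\hat{\lambda}_{MLE} = n / \sum_{i=1}^n x_i = 1 / \bar{x}$.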

In [None]:
def exp_rv(key, n: int, rate: float):
  """
  Samples $n$ observations from $x_i \sim Exp(\lambda)$

  n: the number of observations
  rate: the $\lambda$ parameter

  returns: x, Array of observations
  """
  mean = 1 / rate
  x = mean * rdm.exponential(key, shape=(n,))
  return x


def exp_mle(x):
  """
  Computes $\hat{\lambda}_{MLE}$.

  x: Array of observations

  returns: $\hat{\lambda}_{MLE}$.
  """
  rate_hat = 1.0 / jnp.mean(x)  # MLE: reciprocal of the sample mean
  return rate_hat

key, x_key = rdm.split(key)
N = 100
rate = 1 / 500.
x = exp_rv(x_key, N, rate)
print(f"x = {x}")
rate_hat = exp_mle(x)
print(fr"MLE[\lambda = {rate}] = {rate_hat}")

x = [4.35397430e+02 4.72187927e+02 1.38809436e+03 2.51057831e+02
 7.86665115e+01 5.10070496e+02 6.66986755e+02 5.30637817e+02
 6.95734985e+02 3.18214294e+02 5.19751167e+01 2.74846916e+01
 3.22565138e-01 4.53188721e+02 6.49053101e+02 6.87411118e+01
 3.77136902e+02 4.52345795e+02 8.62563843e+02 4.71622925e+02
 1.00533356e+03 6.46890991e+02 2.95587494e+02 1.70733911e+03
 8.86031799e+01 3.67736015e+01 1.70877844e+03 1.35494446e+02
 6.13494530e+01 6.64146118e+02 2.19395728e+03 1.05528516e+03
 9.19410034e+02 3.70226440e+01 2.18483673e+02 5.09372223e+02
 2.02266724e+02 4.67761505e+02 1.08360504e+02 2.41796265e+03
 4.80211212e+02 9.36368164e+02 5.30847473e+02 2.47510696e+02
 3.98499237e+02 2.79448181e+02 1.04075806e+03 9.33896729e+02
 3.96697357e+02 4.65119263e+02 7.25767761e+02 9.85023422e+01
 5.21523010e+02 3.69823975e+02 1.16978165e+02 2.81428162e+02
 1.20678604e+02 1.87231949e+02 1.82771103e+02 9.60759338e+02
 7.37772095e+02 6.62937164e+01 3.19520782e+02 2.58933685e+02
 1.22909531e+02 1.70

In [None]:
rate = 1 / 50.
for N in [50, 100, 1000, 10000]:
  key, x_key = rdm.split(key)
  # generate N observations
  x_n = exp_rv(x_key, N, rate)
  # estimate rate
  rate_hat = exp_mle(x_n)
  # compute the sq-diff for rate
  rate_err = sq_diff(rate, rate_hat)
  print(fr"MSE[{N} | \lambda = {rate}] = {rate_err}")

MSE[50 | \lambda = 0.02] = 9.133363164437469e-06
MSE[100 | \lambda = 0.02] = 6.241853043320589e-07
MSE[1000 | \lambda = 0.02] = 2.641813523496239e-07
MSE[10000 | \lambda = 0.02] = 4.547181120528876e-08
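As with the Normal model, we can confirm numerically that $\hat{\lambda} = 1/\bar{x}$ maximizes the Exponential log-likelihood $\ell(\lambda) = n \log \lambda - \lambda \sum_{i=1}^n x_i$: the first derivative should vanish there, and the second derivative, $-n/\lambda^2$, should be negative. A minimal sketch; `exp_loglik` is a hypothetical helper and the seed is arbitrary.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm

def exp_loglik(rate, x):
  # hypothetical helper: Exponential log-likelihood, n * log(rate) - rate * sum(x)
  return x.shape[0] * jnp.log(rate) - rate * jnp.sum(x)

# 100 draws with rate 1/500 (mean 500), scaled from unit-rate exponentials
rate = 1 / 500.
x = (1 / rate) * rdm.exponential(rdm.PRNGKey(2), shape=(100,))

rate_hat = 1.0 / jnp.mean(x)                         # closed-form MLE
score = jax.grad(exp_loglik)(rate_hat, x)            # first derivative at the MLE
curv = jax.grad(jax.grad(exp_loglik))(rate_hat, x)   # second derivative, -n / rate^2
print(rate_hat, score, curv)
```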


# Gradient descent
[Gradient descent](https://en.wikipedia.org/wiki/Gradient_descent) seeks to iteratively minimize a function $f(\beta)$ by taking steps in the direction of steepest descent,
$$ \hat{\beta} = \beta_t - \rho_t \nabla f(\beta_t),$$
where $\hat{\beta}$ is the next iterate $\beta_{t+1}$, $\rho_t > 0$ is the step size, and the direction of steepest descent is the negative [gradient](https://en.wikipedia.org/wiki/Gradient) of $f$.
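The update rule translates almost directly into JAX. Below is a minimal sketch, assuming a constant step size $\rho_t = \rho$ and a fixed iteration budget instead of a convergence check; the helper `gradient_descent` and the toy problem $f(\beta) = \frac{1}{2}\|A\beta - b\|_2^2$ (with illustrative `A` and `b`) are not part of the lab.

```python
import jax
import jax.numpy as jnp

def gradient_descent(f, beta0, step_size, num_steps):
  # apply beta_{t+1} = beta_t - rho * grad f(beta_t) with a constant step size rho
  grad_f = jax.jit(jax.grad(f))
  beta = beta0
  for _ in range(num_steps):
    beta = beta - step_size * grad_f(beta)
  return beta

# toy least-squares objective; A and b are illustrative values
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, -1.0])
f = lambda beta: 0.5 * jnp.sum((A @ beta - b) ** 2)

beta_gd = gradient_descent(f, jnp.zeros(2), step_size=0.1, num_steps=500)
beta_exact = jnp.linalg.solve(A, b)  # f attains its minimum (zero) where A @ beta == b
print(beta_gd, beta_exact)
```

On a quadratic like this one, a constant step size must stay below $2 / \lambda_{\max}(A^T A)$; larger steps make the iterates diverge instead of converge.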

A helpful way to recast gradient descent is as a series of _local_ optimizations,

$$\hat{\beta} = \arg\min_\beta \nabla f(\beta_t)^T \beta + \frac{1}{2\rho_t}\|\beta - \beta_t\|_2^2.$$

To see that these are equivalent, let's solve the local problem, writing the squared norm in inner-product notation,
$$m(\beta) = \nabla f(\beta_t)^T \beta + \frac{1}{2\rho_t} (\beta - \beta_t)^T(\beta - \beta_t).$$
Differentiating with respect to $\beta$,
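$$\begin{align*}
\nabla m(\beta) &= \nabla \left[ \nabla f(\beta_t)^T \beta + \frac{1}{2\rho_t} (\beta - \beta_t)^T(\beta - \beta_t)\right] \\
&= \nabla \left[\nabla f(\beta_t)^T \beta\right] + \frac{1}{2\rho_t} \nabla \left[(\beta - \beta_t)^T(\beta - \beta_t)\right] \\
&= \nabla f(\beta_t) + \frac{1}{\rho_t}(\beta - \beta_t).
\end{align*}
$$
Setting $\nabla m(\hat{\beta}) = 0$ and solving for $\hat{\beta}$ gives
$$\hat{\beta} = \beta_t - \rho_t \nabla f(\beta_t),$$
which is exactly the gradient-descent update. Because $m$ is strictly convex (its Hessian is $\frac{1}{\rho_t} I$), this stationary point is its unique minimizer.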
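We can also check this equivalence numerically: at the gradient-descent step, $\nabla m$ should vanish, and moving away from it should increase $m$. A minimal sketch, assuming an arbitrary smooth test function $f(\beta) = \sum_j e^{\beta_j} + \beta_1^2$, a hypothetical current iterate `beta_t`, and step size $\rho_t = 0.5$.

```python
import jax
import jax.numpy as jnp

# arbitrary smooth test function; beta[0] plays the role of beta_1
f = lambda beta: jnp.sum(jnp.exp(beta)) + beta[0] ** 2
beta_t = jnp.array([0.5, -1.0, 2.0])  # hypothetical current iterate
rho = 0.5
g_t = jax.grad(f)(beta_t)  # gradient at beta_t, held fixed inside m

def m(beta):
  # local model: linear term plus squared-distance penalty around beta_t
  return g_t @ beta + jnp.sum((beta - beta_t) ** 2) / (2 * rho)

beta_step = beta_t - rho * g_t   # closed-form gradient-descent step
print(jax.grad(m)(beta_step))    # ~[0, 0, 0]

# m is strictly convex, so a perturbation away from beta_step increases it
delta = jnp.array([0.01, -0.02, 0.03])
print(m(beta_step) < m(beta_step + delta))  # True
```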

Neat! Notice that the local objective minimizes a linear term, $\nabla f(\beta_t)^T \beta$ (the first-order Taylor approximation of $f$ at $\beta_t$, up to terms that are constant in $\beta$), plus a distance penalty, where that distance is defined by the geometry of the parameter space:

$$\hat{\beta} = \arg\min_\beta \nabla f(\beta_t)^T \beta + \frac{1}{2\rho_t}\text{dist}(\beta, \beta_t).$$

When the parameter space is $\mathbb{R}^p$ with its usual Euclidean geometry, $\text{dist}(\beta, \beta_t) = \|\beta - \beta_t\|_2^2$; however, many other geometries can describe a natural parameter space (a topic for a future class 😉).

In [None]:
def sim_linear_reg(key, N, P, r2=0.5):
  key, x_key = rdm.split(key)
  X = rdm.normal(x_key, shape=(N, P))

  key, b_key = rdm.split(key)
  beta = rdm.normal(b_key, shape=(P,))

  # g = jnp.dot(X, beta)
  g = X @ beta
  s2g = jnp.var(g)

  # back out what s2e is, such that s2g / (s2g + s2e) == r2
  s2e = (1 - r2) / r2 * s2g
  key, y_key = rdm.split(key)

  # add env noise to g, but scale such that var(e) == s2e
  y = g + jnp.sqrt(s2e) * rdm.normal(y_key, shape=(N,))
  return y, X, beta

key, sim_key = rdm.split(key)

N = 1000
P = 5
y, X, beta = sim_linear_reg(sim_key, N, P)

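def linreg_loss(beta_hat, y, X):
  # half the residual sum of squares: 0.5 * ||y - X beta_hat||^2
  resid = y - X @ beta_hat
  return 0.5 * jnp.sum(resid ** 2)

def gradient(beta_hat, y, X):
  # analytic gradient of linreg_loss w.r.t. beta_hat: -X^T (y - X beta_hat)
  return -X.T @ (y - X @ beta_hat)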

step_size = 1 / N
diff = 10.
last_loss = 1000.
idx = 0
beta_hat = jnp.zeros((P,))
# while delta in loss is large, continue
print(f"true beta = {beta}")
while jnp.fabs(diff) > 1e-3:

  # take a step along the negative gradient, scaled by step_size
  beta_hat = beta_hat - step_size * gradient(beta_hat, y, X)

  # update our current loss and compute delta
  cur_loss = linreg_loss(beta_hat, y, X)
  diff = last_loss - cur_loss
  last_loss = cur_loss

  # wave to the crowd
  print(f"Loss[{idx}]({beta_hat}) = {last_loss}")
  idx += 1

# OLS solution
beta_hat_ols = jnp.linalg.solve(X.T @ X, X.T @ y)
print(f"ols beta = {beta_hat_ols}")

true beta = [-0.7478904   0.07329974  0.05228964 -1.0152605  -0.6137889 ]
Loss[0]([-0.6955502   0.12941222 -0.00297071 -0.99138546 -0.57778835]) = 1002.5853271484375
Loss[1]([-0.7247923   0.04623848  0.05397529 -0.9596198  -0.64419246]) = 994.5390014648438
Loss[2]([-0.7290263   0.05385752  0.0567775  -0.96942943 -0.64901537]) = 994.4423217773438
Loss[3]([-0.7293095   0.05279243  0.05733057 -0.96860015 -0.6497474 ]) = 994.4410400390625
Loss[4]([-0.7293439   0.05291931  0.05735086 -0.96873957 -0.64979035]) = 994.4409790039062
ols beta = [-0.72934663  0.05290483  0.05735677 -0.9687273  -0.64979887]
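Why does `step_size = 1 / N` reach the OLS solution in only a few iterations? Assuming `linreg_loss` is half the residual sum of squares, $\frac{1}{2}\|y - X\hat{\beta}\|_2^2$, its Hessian is $X^T X$, and for a standard-normal design like the one `sim_linear_reg` draws, $X^T X / N \approx I$. On a quadratic, each gradient-descent step shrinks the error $\hat{\beta} - \hat{\beta}_{OLS}$ by at least a factor of $\max_j |1 - \rho \lambda_j|$, where $\lambda_j$ are the Hessian's eigenvalues. A minimal sketch computing that factor (the seed and sizes are illustrative):

```python
import jax.numpy as jnp
import jax.random as rdm

# standard-normal design, like the one sim_linear_reg draws (seed and sizes are illustrative)
N, P = 1000, 5
X = rdm.normal(rdm.PRNGKey(3), shape=(N, P))

# Hessian of 0.5 * ||y - X @ beta||^2 is X^T X; it depends on neither y nor beta
eigvals = jnp.linalg.eigvalsh(X.T @ X)
step_size = 1 / N

# worst-case per-iteration shrinkage of the error for gradient descent on this quadratic
contraction = jnp.max(jnp.abs(1 - step_size * eigvals))
print(eigvals / N, contraction)
```

If the loss were instead averaged over observations (divided by $N$), the Hessian would be $X^T X / N$ and the analogous step size would be close to 1.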


In [None]:
key, sim_key = rdm.split(key)

N = 1000
P = 5
y, X, beta = sim_linear_reg(sim_key, N, P)

step_size = 1 / N
diff = 10.
last_loss = 1000.
idx = 0
beta_hat = jnp.zeros((P,))
# while delta in loss is large, continue
print("Using JAX to compute gradient")
print(f"true beta = {beta}")
while jnp.fabs(diff) > 1e-3:
  # take a step along the negative gradient, computed by JAX autodiff
  beta_hat = beta_hat - step_size * jax.grad(linreg_loss)(beta_hat, y, X)

  # update our current loss and compute delta
  cur_loss = linreg_loss(beta_hat, y, X)
  diff = last_loss - cur_loss
  last_loss = cur_loss

  # wave to the crowd
  print(f"Loss[{idx}]({beta_hat}) = {last_loss}")
  idx += 1

# OLS solution
beta_hat_ols = jnp.linalg.solve(X.T @ X, X.T @ y)
print(f"ols beta = {beta_hat_ols}")

Using JAX to compute gradient
true beta = [-0.49276644  0.36893874  0.8708915  -0.22680327  0.31571558]
Loss[0]([-0.5512781   0.26551786  0.9118643  -0.22698347  0.35736504]) = 699.7918701171875
Loss[1]([-0.5188158   0.29075423  0.8899959  -0.22931792  0.3277857 ]) = 698.3126220703125
Loss[2]([-0.5198399   0.29171753  0.8909349  -0.23179689  0.33005244]) = 698.305419921875
Loss[3]([-0.5196871   0.2918358   0.89090306 -0.23202236  0.32988426]) = 698.305419921875
ols beta = [-0.5196884   0.29184714  0.8909022  -0.23206532  0.3298884 ]
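Both loops above run in Python, so JAX dispatches every iteration separately. A minimal sketch of a compiled alternative that keeps the same stopping rule (absolute change in loss below `1e-3`), using `jax.lax.while_loop` inside `jax.jit`. It assumes `linreg_loss` is half the residual sum of squares (restated so the block runs on its own); `fit_gd` and the simulated data are illustrative, not part of the lab.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm
from jax import lax

def linreg_loss(beta_hat, y, X):
  # assumed loss: half the residual sum of squares
  return 0.5 * jnp.sum((y - X @ beta_hat) ** 2)

@jax.jit
def fit_gd(y, X, step_size, tol=1e-3):
  # hypothetical helper: gradient descent until |change in loss| <= tol
  grad_fn = jax.grad(linreg_loss)

  def cond(state):
    _, _, diff = state
    return jnp.abs(diff) > tol

  def body(state):
    beta_hat, last_loss, _ = state
    beta_hat = beta_hat - step_size * grad_fn(beta_hat, y, X)
    cur_loss = linreg_loss(beta_hat, y, X)
    return beta_hat, cur_loss, last_loss - cur_loss

  beta0 = jnp.zeros(X.shape[1], dtype=X.dtype)
  init = (beta0, linreg_loss(beta0, y, X), jnp.asarray(jnp.inf, dtype=X.dtype))
  beta_hat, _, _ = lax.while_loop(cond, body, init)
  return beta_hat

# illustrative data with true coefficients 1..P and unit-variance noise
kx, ky = rdm.split(rdm.PRNGKey(4))
N, P = 1000, 5
X = rdm.normal(kx, shape=(N, P))
y = X @ jnp.arange(1., P + 1) + rdm.normal(ky, shape=(N,))

beta_gd = fit_gd(y, X, 1 / N)
beta_ols = jnp.linalg.solve(X.T @ X, X.T @ y)
print(beta_gd, beta_ols)
```

`lax.while_loop` requires the loop state to keep the same shapes and dtypes on every iteration, so the initial `diff` is created with the data's dtype rather than left as a plain Python float.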
