<a href="https://colab.research.google.com/github/anhquan-truong/PM520/blob/main/Lab_3_Optimization_PtI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Move on Up, or: Maximum likelihood Estimation & Optimization Pt I

TBD: move notes from slides to here.


In [1]:
import jax
import jax.numpy as jnp
import jax.random as rdm

## MLE for iid Normal data
Let $x_1, \dotsc, x_n \overset{\mathrm{iid}}{\sim} \mathcal{N}(\mu, \sigma^2)$ where $\mathcal{N}(\mu, \sigma^2)$ refers to the [Normal distribution](https://en.wikipedia.org/wiki/Normal_distribution) with mean parameter $\mu$ and variance parameter $\sigma^2$. The likelihood of our data is given by,
$$\begin{align*}
\mathcal{L}(\mu, \sigma^2 | x_1, \dots, x_n) &=
  \prod_{i=1}^n \mathcal{N}(x_i | \mu, \sigma^2) \\
  &= \prod_{i=1}^n \frac{1}{\sqrt{2\pi \sigma^2}} \exp\left(-\frac{(x_i - \mu)^2}{2\sigma^2}\right)\\
  &= \left(\frac{1}{\sqrt{2\pi\sigma^2}}\right)^n \exp\left(-\frac{1}{2\sigma^2} \sum_{i=1}^n (x_i - \mu)^2\right).
\end{align*}\\
$$
Thus, our _log_-likelihood is given by,
The likelihood of our data is given by,
(maximize the log mean maximize the LLH function because of monotone
$$\begin{align*}
\ell(\mu, \sigma^2 | x_1, \dots, x_n) &=
  \log \left[\left(\frac{1}{\sqrt{2\pi\sigma^2}}\right)^n \exp\left(-\frac{1}{2\sigma^2} \sum_{i=1}^n (x_i - \mu)^2\right)\right]\\
  &= -\frac{n}{2}\log(2\pi\sigma^2) - \frac{1}{2\sigma^2} \sum_{i=1}^n (x_i - \mu)^2.
\end{align*}\\
$$

In [14]:
def norm_rv(key, n: int, mu: float, sigma_sq: float):
  r"""
  Samples $n$ observations from $x_i \sim N(\mu, \sigma^2)$

  n: the number of observations
  mu: the mean parameter
  sigma_sq: the variance parameter

  returns: x, Array of observations
  """
  x = mu + jnp.sqrt(sigma_sq) * rdm.normal(key, shape=(n,))
  return x


def norm_mle(x):
  r"""
  Computes $\hat{\mu}_{MLE}$ and $\hat{\sigma^2}_{MLE}$.

  x: Array of observations

  returns:  Tuple of $\hat{\mu}_{MLE}$ and $\hat{\sigma^2}_{MLE}$.
  """
  mu_hat = jnp.mean(x)
  # mu_hat = jnp.sum(x) / len(x)
  #ssq_hat = jnp.mean(jnp.sum(x - mu_hat)**2)
  ssq_hat = jnp.var(x)

  return mu_hat, ssq_hat

seed = 0
key = rdm.PRNGKey(seed)
key, x_key = rdm.split(key)

N = 1000 # how many obs we want

mu = 58.
sigma_sq = 100.
x = norm_rv(x_key, N, mu, sigma_sq)
#print(f"x = {x}")
mu_hat, ssq_hat = norm_mle(x)
print(fr"MLE[\mu, \sigma^2] = {mu_hat}, {ssq_hat}")

MLE[\mu, \sigma^2] = 57.703948974609375, 110.18893432617188


In [19]:
def sq_diff(param, estimate): #square difference - l2norm
  return (param - estimate) ** 2

mu = 58.
sigma_sq = 10.
for N in [50, 100, 1000, 10000]:
  key, x_key = rdm.split(key)
  # generate N observations
  x_n = norm_rv(x_key, N, mu, sigma_sq)
  # estimate mu, and sigma_sq
  mu_hat, ssq_hat = norm_mle(x_n)
  # compute the sq-diff for both and report
  mu_err = sq_diff(mu, mu_hat)
  ssq_err = sq_diff(sigma_sq, ssq_hat)
  print(f"MSE[{N} | mu, sigma^2] = {mu_err}, {ssq_err}")

  # the error is descreasing as a function of N

MSE[50 | mu, sigma^2] = 0.5217084288597107, 3.3595786094665527
MSE[100 | mu, sigma^2] = 0.1484147012233734, 1.515868067741394
MSE[1000 | mu, sigma^2] = 0.0020461762323975563, 0.027577972039580345
MSE[10000 | mu, sigma^2] = 0.000588434049859643, 0.003656691173091531


## MLE for iid Exponential data
TBD: Add notes for Exponential PDF and MLE estimator

In [20]:
def exp_rv(key, n: int, rate: float):
  """
  Samples $n$ observations from $x_i \sim Exp(\lambda)$

  n: the number of observations
  rate: the $\lambda$ parameter

  returns: x, Array of observations
  """
  mean = 1 / rate
  x = mean * rdm.exponential(key, shape=(n,))
  return x


def exp_mle(x):
  """
  Computes $\hat{\lambda}_{MLE}$.

  x: Array of observations

  returns: $\hat{\lambda}_{MLE}$.
  """
  rate_hat = 1. / jnp.mean(x)
  return rate_hat

key, x_key = rdm.split(key)
N = 100
rate = 1 / 500.
x = exp_rv(x_key, N, rate)
print(f"x = {x}")
rate_hat = exp_mle(x)
print(f"MLE[\lambda = {rate}] = {rate_hat}")

  Samples $n$ observations from $x_i \sim Exp(\lambda)$
  Computes $\hat{\lambda}_{MLE}$.
  print(f"MLE[\lambda = {rate}] = {rate_hat}")


x = [8.5540985e+02 2.5296268e+02 6.9233221e+02 7.4306616e+02 2.0023449e+01
 3.0325012e+02 4.9858099e+02 4.4559082e+02 2.2487761e+02 1.7512365e+01
 1.8983838e+02 7.1858414e+01 5.3631573e+00 1.2744092e+02 4.3240015e+02
 2.5281116e+02 3.0460519e+02 9.8529327e+02 3.5058794e+03 9.7805145e+02
 5.8293109e+02 9.4777515e+02 3.6097495e+02 2.9653098e+02 5.9593036e+02
 2.6948223e+00 1.1319576e+02 1.3489890e+01 2.3469034e+02 8.8526413e+01
 5.4663782e+02 2.1881448e+02 3.5277534e+01 5.4674384e+02 2.1851385e+01
 4.5251007e+02 8.2026758e+02 6.7611023e+01 6.8925110e+01 2.2094575e+03
 7.8212195e+02 2.3713325e+03 1.7880037e+01 3.8670126e+02 1.2078967e+03
 1.9894528e+02 2.5183937e+02 6.4554199e+02 8.7256927e+01 1.3923959e+03
 1.8955074e+02 1.1624373e+03 6.1272675e+02 3.5862617e+01 1.7658134e+02
 4.1262427e+02 4.7592819e+02 5.8816719e+01 2.2342406e+02 2.8205704e+02
 1.3506290e+03 2.4395990e+02 5.0029422e+02 4.5187828e+01 3.4501726e+03
 9.8308083e+01 6.2813635e+02 9.8704968e+02 1.7907946e+03 5.0491989e+02
 6

In [None]:
rate = 1 / 50.
for N in [50, 100, 1000, 10000]:
  key, x_key = rdm.split(key)
  # generate N observations
  x_n = exp_rv(x_key, N, rate)
  # estimate rate
  rate_hat = exp_mle(x_n)
  # compute the sq-diff for rate
  rate_err = sq_diff(rate, rate_hat)
  print(f"MSE[{N} | \lambda = {rate}] = {rate_err}")

MSE[50 | \lambda = 0.02] = 9.133363164437469e-06
MSE[100 | \lambda = 0.02] = 6.241853043320589e-07
MSE[1000 | \lambda = 0.02] = 2.641813523496239e-07
MSE[10000 | \lambda = 0.02] = 4.547181120528876e-08


# Gradient descent
[Gradient descent](https://en.wikipedia.org/wiki/Gradient_descent) seeks to iteratively optimize a function $f(\beta)$ by taking steps in the steepest direction,
$$ \hat{\beta} = \beta_t - \rho_t \nabla f(\beta_t),$$
where that direction is provided by the [gradient](https://en.wikipedia.org/wiki/Gradient) of (f).

A helpful way to recast gradient descent is that we seek to perform a series of _local_ optimizations,

$$\hat{\beta} = \min_\beta \nabla f(\beta_t)^T \beta + \frac{1}{2\rho_t}\|\beta - \beta_t\|_2^2.$$

To see how these are equivalent let's solve the local problem. but using inner product notation,
$$m(\beta) = \nabla f(\beta_t)^T \beta + \frac{1}{2\rho_t} (\beta - \beta_t)^T(\beta - \beta_t).$$
Now, using calculus again,
$$\begin{align*}
\nabla m(\beta) &= \nabla [ \nabla f(\beta_t)^T \beta + \frac{1}{2\rho_t} (\beta - \beta_t)^T(\beta - \beta_t)] \\
&= \nabla [\nabla f(\beta_t)^T \beta] + \frac{1}{2\rho_t} \nabla [(\beta - \beta_t)^T(\beta - \beta_t)] \\
&= \nabla f(\beta_t) + \frac{1}{\rho_t}(\beta - \beta_t) \Rightarrow \\
\hat{\beta} &= \beta_t - \rho_t \nabla f(\beta_t).
\end{align*}
$$

Neat! However, notice that the original local objective can be thought of as minimizing the directional derivative, but with a distance penalty, where that distance is defined by the geometry of the parameter space.

$$\hat{\beta} = \min_\beta \nabla f(\beta_t)^T \beta + \frac{1}{2\rho_t}\text{dist}(\beta, \beta_t).$$

When the natural geometry is $\mathbb{R}^p$ then $\text{dist}(\cdot) = \| \cdot \|_2^2$, however there are many  geometries that can describe the natural parameter space (for future class ðŸ˜‰)

In [58]:
def sim_linear_reg(key, N, P, r2=0.5):
  key, x_key = rdm.split(key)
  X = rdm.normal(x_key, shape=(N, P))

  key, b_key = rdm.split(key)
  beta = rdm.normal(b_key, shape=(P,))

  # g = jnp.dot(X, beta)
  g = X @ beta
  s2g = jnp.var(g)

  # back out what s2e is, such that s2g / (s2g + s2e) == h2
  s2e = (1 - r2) / r2 * s2g
  key, y_key = rdm.split(key)

  # add env noise to g, but scale such that var(e) == s2e
  y = g + jnp.sqrt(s2e) * rdm.normal(y_key, shape=(N,))
  return y, X, beta

key, sim_key = rdm.split(key)

N = 1000
P = 5
y, X, beta = sim_linear_reg(sim_key, N, P)

def linreg_loss(beta_hat, y, X):
  #loss = jnp.square(jnp.linalg.norm(X @ beta_hat - y))
  loss = jnp.sum((y - X @ beta_hat)**2)
  return loss

def gradient(beta_hat, y, X):
  y_hat = X @ beta_hat
  return -1/2 * X.T @ (y-y_hat)
  #gradient = 1/2 * ((X.T @ X) @ beta_hat - (X.T @ y_hat))
  #return gradient

step_size = 1 / N
diff = 10.
last_loss = 1000.
idx = 0
beta_hat = jnp.zeros((P,))
points = []
# while delta in loss is large, continue
print(f"true beta = {beta}")
while jnp.fabs(diff) > 1e-3:

  # take a step in the direction of the gradient using step_size
  beta_hat = beta_hat - step_size * gradient(beta_hat, y, X)

  # update our current loss and compute delta
  cur_loss = linreg_loss(beta_hat, y, X)
  diff = last_loss - cur_loss
  last_loss = cur_loss

  # wave to the crowd
  print(f"Loss[{idx}]({beta_hat}) = {last_loss}")
  idx += 1

# OLS solution
beta_hat_ols = jnp.linalg.solve(X.T @ X, X.T @ y)
print(f"ols beta = {beta_hat_ols}")

true beta = [ 0.27317098  0.35787642 -0.43727624 -0.11257944  1.0787896 ]
Loss[0]([-0.2426016  -0.27000135  0.39343295  0.06244757 -1.087016  ]) = 3777.68359375
Loss[1]([-0.7305606 -0.7884005  1.1795309  0.1473572 -3.2879128]) = 12789.728515625
Loss[2]([-1.7137765  -1.7815676   2.7505023   0.23630705 -7.7456527 ]) = 49456.31640625
Loss[3]([ -3.6985135  -3.67969     5.890618    0.2505736 -16.777561 ]) = 198885.125
Loss[4]([ -7.712308    -7.2974014   12.168433    -0.05214828 -35.08352   ]) = 808855.125
Loss[5]([-15.844563  -14.171099   24.72171    -1.3284883 -72.19871  ]) = 3302788.5
Loss[6]([ -32.351723  -27.184803   49.828514   -5.24057  -147.47455 ]) = 13515814.0
Loss[7]([ -65.92068   -51.722427  100.05257   -15.820463 -300.19678 ]) = 55405400.0
Loss[8]([-134.3121    -97.76929   200.54146   -42.569073 -610.1455  ]) = 227484320.0
Loss[9]([ -273.9022   -183.7002    401.64087  -107.40582 -1239.386  ]) = 935440512.0
Loss[10]([ -559.322    -343.0067    804.16345  -260.09686 -2517.2383 ]) =

In [59]:
key, sim_key = rdm.split(key)

N = 1000
P = 5
y, X, beta = sim_linear_reg(sim_key, N, P)

step_size = 1 / N
diff = 10.
last_loss = 1000.
idx = 0
beta_hat = jnp.zeros((P,))
# while delta in loss is large, continue
print("Using JAX to compute gradient")
print(f"true beta = {beta}")
while jnp.fabs(diff) > 1e-3:
  # take a step in the direction of the gradient using step_size
  jax_gradient = jax.gradlinreg_loss)
  vandg = jax.value_and_grad(linreg_loss)
  cur_loss, g = vandg(beta_hat, y, X) # like last round loss
  beta_hat = beta_hat - step_size * jax.grad(linreg_loss)(beta_hat, y, X)

  # update our current loss and compute delta
  cur_loss = linreg_loss(beta_hat, y, X)
  diff = -last_loss + cur_loss
  last_loss = cur_loss

  # wave to the crowd
  print(f"Loss[{idx}]({beta_hat}) = {last_loss}")
  idx += 1

# OLS solution
beta_hat_ols = jnp.linalg.solve(X.T @ X, X.T @ y)
print(f"ols beta = {beta_hat_ols}")

Using JAX to compute gradient
true beta = [ 1.3997703   0.54678535  0.6298904  -1.3709987  -1.343823  ]
Loss[0]([ 1.2909186  0.5590918  0.4619145 -1.4868585 -1.2439543]) = 2938.958740234375
Loss[1]([ 1.391792    0.56393296  0.5750283  -1.4181691  -1.2588187 ]) = 2924.93408203125
Loss[2]([ 1.3918428   0.55968356  0.57975376 -1.4267044  -1.2711742 ]) = 2924.802734375
Loss[3]([ 1.3926661   0.55965734  0.5808359  -1.4258059  -1.2710974 ]) = 2924.80126953125
Loss[4]([ 1.3926512  0.5596132  0.5808643 -1.425903  -1.2712198]) = 2924.801513671875
ols beta = [ 1.3926594  0.5596131  0.5808748 -1.4258937 -1.2712185]
