<a href="https://colab.research.google.com/github/anhquan-truong/PM520/blob/main/Lab_4_Optimization_PtII_AnhQuanTruong.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Ain't no mountain high enough, or: Optimization Pt II
Outline for today:
1. Newton's Method & Quasi-Newton Methods
2. Poisson Regression Lab
3. Automatic differentiation

Before we _climb_ into second-order methods, keep the inferential target in view. We assume some parametric model
$$
x_1,\dots,x_n \sim p(\cdot \mid \theta),
$$
and estimate $\theta$ via [maximum likelihood](https://en.wikipedia.org/wiki/Maximum_likelihood_estimation):
$$
\hat{\theta}_{\mathrm{MLE}} \in \arg\max_{\theta\in\Theta} \ell(\theta \mid x_{1:n}),
$$
where $\ell(\theta \mid x_{1:n})$ is the log-likelihood of the data. Last lecture, for cases where no closed-form solution exists, we used [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent), an iterative procedure that uses only first-order (i.e., gradient) information to improve upon our initial guess for $\theta$.

Today we'll cover a class of approaches that go beyond first-order information to improve the [convergence rate](https://en.wikipedia.org/wiki/Rate_of_convergence) (roughly, how many iterations are needed before the procedure can stop).

## Newton's Method for Optimization
Let $f(\beta)$ be the function we wish to optimize (e.g., a log-likelihood, a loss function, etc.). Can we do better than gradient descent by considering higher-order information (i.e., the local geometry) of the function $f$? Here, "better" is with respect to the convergence rate.

Let's consider a 2nd-order [Taylor-series approximation](https://en.wikipedia.org/wiki/Taylor_series) to $f$ around $\beta_t$ as,

$$f(\beta) \approx f(\beta_t) + \nabla f(\beta_t)^T (\beta - \beta_t) + \frac{1}{2} (\beta - \beta_t)^T H(\beta_t)(\beta - \beta_t),$$ where $H(\beta_t) = \nabla^2 f(\beta_t)$ (i.e. the [Hessian](https://en.wikipedia.org/wiki/Hessian_matrix) of $f$ at $\beta_t$). To minimize this _local_ approximation, we take its gradient with respect to $\beta$ and set it to zero:

$$\nabla_\beta f(\beta) \approx \nabla f(\beta_t) + H(\beta_t)(\beta - \beta_t) = \nabla f(\beta_t) + H(\beta_t)\beta - H(\beta_t)\beta_t = 0,$$
which rearranges to
$$ H(\beta_t)\beta = H(\beta_t)\beta_t - \nabla f(\beta_t).$$

We can recognize that this is a [system of linear equations](https://en.wikipedia.org/wiki/System_of_linear_equations) $A x = b$ where $A = H(\beta_t)$, $x = \beta$, and $b = H(\beta_t)\beta_t - \nabla f(\beta_t)$. Assuming $A$ is invertible, the solution is $\hat{x} = A^{-1}b$. Taking this solution as the next iterate $\beta_{t+1}$ gives the Newton update,
$$ \beta_{t+1} = H(\beta_t)^{-1}\left(H(\beta_t)\beta_t - \nabla f(\beta_t)\right) = \beta_t - H(\beta_t)^{-1}\nabla f(\beta_t).$$
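
The update above can be sketched in a few lines of JAX. This is a minimal illustration, not part of the lab code: the quadratic `f`, the matrix `A`, the vector `b`, and the helper name `newton_step` are all assumptions made up for the example. Because the second-order Taylor approximation of a quadratic is exact, a single Newton step from any starting point should land on the minimizer.

```python
import jax
import jax.numpy as jnp

# Hypothetical test function: a strictly convex quadratic with a known minimizer.
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, -1.0])

def f(beta):
    return 0.5 * beta @ A @ beta - b @ beta

def newton_step(f, beta_t):
    """One Newton update: beta_t - H(beta_t)^{-1} grad f(beta_t)."""
    g = jax.grad(f)(beta_t)
    H = jax.hessian(f)(beta_t)
    # Solve H d = g instead of forming the inverse explicitly.
    return beta_t - jnp.linalg.solve(H, g)

beta0 = jnp.array([5.0, -7.0])
beta1 = newton_step(f, beta0)
beta_star = jnp.linalg.solve(A, b)  # exact minimizer, where A beta = b
print(beta1, beta_star)
```

Solving the linear system rather than inverting $H$ mirrors the $Ax = b$ view in the derivation and is generally the cheaper and more stable choice.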


### Caveats
[Newton's method](https://en.wikipedia.org/wiki/Newton%27s_method_in_optimization) is only guaranteed to converge _locally_ for convex/concave functions, and can diverge for some [convex functions](https://en.wikipedia.org/wiki/Convex_function) (e.g., $f(\beta) = \sqrt{\beta^2 + 1}$). To address this limitation, we can add a damping parameter, $\rho_t \in (0,1]$, which gives us,
$$ \beta_{t+1} = \beta_t - \rho_t H(\beta_t)^{-1}\nabla f(\beta_t).$$
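
To make the caveat concrete, here is a small sketch on the example $f(\beta) = \sqrt{\beta^2 + 1}$ mentioned above. The starting point `2.0`, the fixed damping value `rho=0.1`, and the helper name `newton_1d` are assumptions chosen for illustration; in practice $\rho_t$ is usually chosen adaptively (e.g. by a line search) rather than held fixed.

```python
import jax
import jax.numpy as jnp

def f(beta):
    return jnp.sqrt(beta**2 + 1.0)

# First and second derivatives of the scalar function via autodiff.
df = jax.jit(jax.grad(f))
d2f = jax.jit(jax.grad(jax.grad(f)))

def newton_1d(beta0, rho, n_steps):
    """Run n_steps of (damped) Newton on the scalar function f."""
    beta = beta0
    for _ in range(n_steps):
        beta = beta - rho * df(beta) / d2f(beta)
    return float(beta)

# Undamped (rho = 1): from beta0 = 2 the iterates grow in magnitude.
full = newton_1d(2.0, rho=1.0, n_steps=3)
# Damped (rho = 0.1): from the same start the iterates shrink toward the minimizer at 0.
damped = newton_1d(2.0, rho=0.1, n_steps=100)
print(full, damped)
```

For this function the undamped update simplifies to $\beta_{t+1} = -\beta_t^3$, which explains the blow-up whenever $|\beta_0| > 1$. The price of a small fixed $\rho$ is that the fast local convergence of the undamped method is lost.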

## Quasi-Newton Methods for Optimization
What if computing $H(\beta_t)$ is too costly? Do we need _exact_ second-order information to improve on gradient descent's convergence? Given an approximation $B(\beta_t) \approx H(\beta_t)$, [_quasi_-Newton methods](https://en.wikipedia.org/wiki/Quasi-Newton_method) work with the local model
$$f(\beta) \approx f(\beta_t) + \nabla f(\beta_t)^T (\beta - \beta_t) + \frac{1}{2} (\beta - \beta_t)^T B(\beta_t)(\beta - \beta_t).$$ Minimizing this approximation, exactly as in the Newton derivation, gives us our update rule,
$$ \beta_{t+1} = \beta_t - \rho_t B(\beta_t)^{-1}\nabla f(\beta_t).$$
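
As a hedged illustration, the simplest quasi-Newton scheme is the one-dimensional secant method, which replaces $f''(\beta_t)$ with the slope of $f'$ between the last two iterates. The objective `f`, the starting points, and the helper name `secant_minimize` below are assumptions for the sketch; multivariate methods such as BFGS build $B$ from the analogous condition $B_{t+1}(\beta_{t+1} - \beta_t) = \nabla f(\beta_{t+1}) - \nabla f(\beta_t)$.

```python
import jax
import jax.numpy as jnp

def f(beta):
    # Hypothetical 1-D objective with minimizer at log(2).
    return jnp.exp(beta) - 2.0 * beta

df = jax.jit(jax.grad(f))

def secant_minimize(beta_prev, beta, max_iter=50, tol=1e-5):
    """Quasi-Newton in 1-D: approximate f'' by a finite-difference slope of f'."""
    g_prev, g = float(df(beta_prev)), float(df(beta))
    for _ in range(max_iter):
        # Stop at convergence, or if the secant slope cannot be formed.
        if abs(g) < tol or beta == beta_prev or g == g_prev:
            break
        B = (g - g_prev) / (beta - beta_prev)  # B approximates f''(beta)
        beta_prev, g_prev = beta, g
        beta = beta - g / B                    # update with rho_t = 1
        g = float(df(beta))
    return beta

beta_hat = secant_minimize(0.0, 1.0)
print(beta_hat)  # should be close to log(2)
```

Only gradient evaluations are needed here; no second derivative is ever computed.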

## Poisson Regression
Assume $y_i | x_i \sim \text{Poi}(\lambda_i)$ where $\lambda_i := \exp(x_i^T \beta)$, and $\text{Poi}(k | \lambda) := \frac{\lambda^k \exp(-\lambda)}{k!}$ is the [PMF](https://en.wikipedia.org/wiki/Probability_mass_function) of the [Poisson distribution](https://en.wikipedia.org/wiki/Poisson_distribution). Given $\{(y_i, x_i)\}_{i=1}^n$, we would like to identify the [maximum likelihood parameter estimate](https://en.wikipedia.org/wiki/Maximum_likelihood_estimation) for $\beta$. In other words, we would like to find a value of $\beta$ that maximizes the log-likelihood. Since $\lambda_i^{y_i} = \exp(y_i \cdot x_i^T \beta)$, the log-likelihood is given by,
$$\begin{align*}
\ell(\beta) &= \sum_i \log \text{Poi}(y_i | \exp(x_i^T \beta)) \\
&= \sum_i \log \left[ \frac{\exp(y_i \cdot x_i^T \beta) \exp(-\exp(x_i^T \beta))}{y_i!} \right] \\
&= \sum_i \log \left[ \frac{\exp(y_i \cdot x_i^T \beta - \exp(x_i^T \beta))}{y_i!} \right] \\
&= \sum_i \log \left[\exp(y_i \cdot x_i^T \beta - \exp(x_i^T \beta))\right] - \log(y_i!) \\
&= \sum_i \left[y_i \cdot x_i^T \beta - \exp(x_i^T \beta) - \log(y_i!)\right] \\
&= y^T X\beta - \exp(X\beta)^T 1_n - O(1) \\
&= y^T X\beta - \lambda^T 1_n - O(1),
\end{align*}$$
where $\lambda = (\lambda_1, \dotsc, \lambda_n)^T$ and the $O(1)$ term stands for $\sum_i \log(y_i!)$, which is constant in $\beta$.
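
As a sanity check on the algebra, the matrix form of the log-likelihood can be compared against a direct sum of Poisson log-PMFs. This is a sketch on simulated data; the problem sizes, the `0.3` scaling of `beta`, and the variable names are arbitrary choices for illustration. The constant term is computed with `gammaln`, using $\log(k!) = \log\Gamma(k+1)$.

```python
import jax.numpy as jnp
import jax.random as rdm
from jax.scipy.special import gammaln
from jax.scipy.stats import poisson

# Small simulated problem (sizes and scales are illustrative only).
x_key, b_key, y_key = rdm.split(rdm.PRNGKey(0), 3)
X = rdm.normal(x_key, shape=(50, 3))
beta = 0.3 * rdm.normal(b_key, shape=(3,))
lam = jnp.exp(X @ beta)
y = rdm.poisson(y_key, lam)

# Matrix form: y^T X beta - lambda^T 1_n - sum_i log(y_i!)
ll_matrix = y @ (X @ beta) - jnp.sum(lam) - jnp.sum(gammaln(y + 1.0))
# Direct form: sum of per-sample Poisson log-PMFs
ll_pmf = jnp.sum(poisson.logpmf(y, lam))
print(ll_matrix, ll_pmf)
```

The two values should agree up to floating-point error.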


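Differentiating with respect to $\beta$ (the constant term drops out) gives the gradient and the Hessian:

$$ \begin{align*}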
\nabla_\beta \ell &= \nabla_\beta \left[ y^T X\beta - \lambda^T 1_n \right] \\
&= \nabla_\beta [ y^T X\beta ] - \nabla_\beta [\lambda^T 1_n] \\
&= \nabla_\beta [ y^T X\beta ] - \nabla_\beta [\exp(X\beta)^T 1_n] \\
&= X^T y - X^T \exp(X\beta)  \\
&= X^T y - X^T \lambda  \\
&= X^T(y - \lambda) \\
\nabla^2_{\beta \beta} \ell &= \nabla_{\beta} X^T(y - \lambda) \\
&= \nabla_{\beta} \left[X^T y - X^T \lambda \right] \\
&= - X^T \nabla_{\beta}  \lambda \\
&= -X^T \nabla_{\beta}  \exp(X\beta) \\
&= -X^T \Lambda X,
\end{align*}$$
where $\Lambda = \text{diag}(\lambda)$, i.e. $\Lambda_{ii} = \lambda_i$ and $\Lambda_{ij} = 0$ for $i \neq j$.

To illustrate how $\nabla_{\beta}  \exp(X\beta) = \Lambda X$ (i.e. the last step in the Hessian calculation), recall that the [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant) of a function $f : \mathbb{R}^n \to \mathbb{R}^m$ is the $m \times n$ matrix $J$ such that $J_{ij} = \frac{\partial f_i}{\partial x_j}$. In this case we are computing the Jacobian of $\exp(X\beta)$, which maps $\mathbb{R}^p \to \mathbb{R}^n$, so the Jacobian should have shape $n \times p$. Notice that $J_{i,j} = \frac{\partial}{\partial \beta_j} \exp(x_i^T \beta) = x_{ij}\exp(x_i^T \beta)$, thus $J_{i, \cdot} = \exp(x_i^T \beta) x_i^T$. Repeating this for each $i$ we have $$\nabla_\beta \exp(X \beta) = J(\exp(X \beta)) = \begin{bmatrix} J_{1,\cdot} \\ \vdots \\ J_{n,\cdot} \end{bmatrix} =
\begin{bmatrix} \exp(x_1^T \beta) x_1^T \\ \vdots \\ \exp(x_n^T \beta) x_n^T \end{bmatrix}  =
\begin{bmatrix} \lambda_1 x_1^T \\ \vdots \\ \lambda_n x_n^T\end{bmatrix} = \Lambda X.$$
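
The identity $\nabla_\beta \exp(X\beta) = \Lambda X$ can also be checked numerically. The sketch below uses `jax.jacfwd` on a small random problem; the sizes `n = 6`, `p = 3` and the variable names are assumptions made for the example.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm

x_key, b_key = rdm.split(rdm.PRNGKey(1))
X = rdm.normal(x_key, shape=(6, 3))          # n = 6, p = 3
beta = 0.5 * rdm.normal(b_key, shape=(3,))

def rate(b):
    # Maps R^p -> R^n: the vector of Poisson rates lambda.
    return jnp.exp(X @ b)

J = jax.jacfwd(rate)(beta)                   # Jacobian, shape (n, p)
lam = rate(beta)
Lambda_X = lam[:, None] * X                  # diag(lam) @ X without forming diag
print(J.shape, jnp.allclose(J, Lambda_X, atol=1e-5))
```

Scaling the rows of `X` by `lam` avoids materializing the $n \times n$ diagonal matrix $\Lambda$.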

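We can fit this model using Newton's method. Substituting the gradient and Hessian into the update and writing $\beta_t = (X^T \Lambda_t X)^{-1} X^T \Lambda_t X\beta_t$ gives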
$$\begin{align*}
\beta_{t+1} &= \beta_t - H(\beta_t)^{-1}\nabla \ell(\beta_t) \\
&= \beta_t + (X^T \Lambda_t X)^{-1} X^T (y - \lambda_t) \\
&= (X^T \Lambda_t X)^{-1} X^T \Lambda_t X\beta_t + (X^T \Lambda_t X)^{-1} X^T (y - \lambda_t)\\
&= (X^T \Lambda_t X)^{-1} X^T \Lambda_t X\beta_t + (X^T \Lambda_t X)^{-1} X^T \Lambda_t\Lambda_t^{-1}(y - \lambda_t)\\
&= (X^T \Lambda_t X)^{-1} X^T \Lambda_t X\beta_t + (X^T \Lambda_t X)^{-1} X^T \Lambda_t(\Lambda_t^{-1}y - \Lambda_t^{-1}\lambda_t)\\
&= (X^T \Lambda_t X)^{-1} X^T \Lambda_t X\beta_t + (X^T \Lambda_t X)^{-1} X^T \Lambda_t(\Lambda_t^{-1}y - 1_n)\\
&= (X^T \Lambda_t X)^{-1} X^T \Lambda_t (\Lambda_t^{-1}y + X\beta_t - 1_n)
\end{align*}$$
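where $\lambda_t := \exp(X\beta_t)$ and $\Lambda_t := \text{diag}(\lambda_t)$. Each Newton step is therefore a weighted least-squares solve with weights $\Lambda_t$ and working response $z_t := \Lambda_t^{-1}y + X\beta_t - 1_n$, which is why this procedure is known as iteratively reweighted least squares (IRWLS).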
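
The equivalence between the first and last lines of this derivation can be verified numerically. The following sketch uses dense `jnp.linalg.solve` rather than the lineax solvers used in the lab code, purely to keep the example self-contained; the simulated data, the iterate `beta_t`, and the working-response name `z` are assumptions for illustration.

```python
import jax.numpy as jnp
import jax.random as rdm

x_key, b_key, y_key = rdm.split(rdm.PRNGKey(2), 3)
X = rdm.normal(x_key, shape=(200, 3))
beta_true = 0.3 * rdm.normal(b_key, shape=(3,))
y = rdm.poisson(y_key, jnp.exp(X @ beta_true))

beta_t = jnp.array([0.1, -0.2, 0.05])     # arbitrary current iterate
eta = X @ beta_t
lam = jnp.exp(eta)
XtLX = (X.T * lam) @ X                    # X^T Lambda_t X

# Newton form: beta_t + (X^T Lambda X)^{-1} X^T (y - lambda)
newton = beta_t + jnp.linalg.solve(XtLX, X.T @ (y - lam))

# Weighted least-squares form with z = Lambda^{-1} y + X beta_t - 1
z = y / lam + eta - 1.0
wls = jnp.linalg.solve(XtLX, X.T @ (lam * z))
print(newton, wls)
```

Both expressions should produce the same next iterate up to floating-point error.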

In [71]:
!pip install lineax



In [70]:
import jax
import jax.numpy as jnp
import jax.random as rdm
import jax.scipy.stats as stats

import lineax as lx

@jax.jit
def loglikelihood(beta, y, X):
  r"""
  Our log-likelihood function for $y_i | x_i ~ \text{Poi}(\exp(x_i^T \beta))$.

  beta: beta
  y: poisson-distributed observations
  X: our design matrix as lx.MatrixLinearOperator

  returns: sum of the log-likelihoods of each sample
  """
  rate = jnp.exp(X.mv(beta))
  return jnp.sum(stats.poisson.logpmf(y, rate))



@jax.jit
def irwls_fit(beta, y, X, step_size):
  r"""
  Perform one IRWLS (Newton) update for $\beta$ under the model
     $y_i | x_i ~ \text{Poi}(\exp(x_i^T \beta))$.

  beta: beta
  y: poisson-distributed observations
  X: our design matrix as lx.MatrixLinearOperator
  step_size: unused here; kept so all fit functions share one signature

  returns: updated estimate of $\beta$
  """
  # compute lambda_i := exp(x_i @ beta)
  eta = X.mv(beta)
  d_i = jnp.exp(eta)
  d_sqrt = jnp.sqrt(d_i)

  # compute z_i := Lambda^{1/2}(Lambda^-1 y + X @beta - 1)
  z = (y / d_i + eta - 1) * d_sqrt

  # X* := Lambda^{1/2} X
  # we use linear operators to postpone any computation
  X_star = lx.DiagonalLinearOperator(d_sqrt) @ X

  # lineax solves the normal equations t(X*) @ X* @ beta = t(X*) @ z iteratively with CG
  solution = lx.linear_solve(X_star, z, solver=lx.Normal(lx.CG(atol=1e-4, rtol=1e-3)))
  beta = solution.value

  return beta


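def poiss_reg(y, X, fit_func, step_size = 1.0, max_iter=100, tol=1e-3):
  r"""
  Perform MLE estimation for $\beta$ under the model
     $y_i | x_i ~ \text{Poi}(\exp(x_i^T \beta))$.

  y: poisson-distributed observations
  X: our design matrix
  fit_func: update function with signature (beta, y, X, step_size)
  step_size: step size passed through to fit_func
  max_iter: the maximum number of iterations to perform optimization
  tol: stop once the absolute change in log-likelihood falls below this value

  returns: final estimate of $\beta$
  """
  # number of samples and number of variables
  n, p = X.shape

  # placeholder starting values for the convergence check
  loglike = -100000
  delta = 10000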

  # convert to a linear operator for lineax
  X = lx.MatrixLinearOperator(X)

  # initialize using OLS estimate and normalizing for downstream stability
  sol = lx.linear_solve(X, (y - jnp.mean(y))/2, solver=lx.Normal(lx.CG(atol=1e-4, rtol=1e-3)))
  beta = sol.value
  beta = beta / jnp.linalg.norm(beta)

  for epoch in range(max_iter):

    # fit using our function
    beta = fit_func(beta, y, X, step_size)

    # evaluate log likelihood
    newll = loglikelihood(beta, y, X)

    # take delta and check if we can stop
    delta = jnp.fabs(newll - loglike)
    print(f"Log likelihood[{epoch}] = {newll}")
    if delta < tol:
      break

    # replace old value
    loglike = newll

  return beta



In [73]:
# Let's simulate a poisson regression model with N samples and P variables
# we need X (N,P), beta (P,) and y (N,)
N = 1000
P = 5

# initialize PRNG env
seed = 0
key = rdm.PRNGKey(seed)

# split key for each random call
key, x_key, b_key, y_key = rdm.split(key, 4)
X = rdm.normal(x_key, shape=(N, P))
beta = rdm.normal(b_key, shape=(P,))


# compute lambda_i = exp(x_i' \beta)
lam = jnp.exp(X @ beta)


# sample y from Poi(lambda_i)
y = rdm.poisson(y_key, lam)

# estimate beta using our irwls function
# fit_func has signature (beta, y, X, step_size)
beta_hat = poiss_reg(y, X, irwls_fit) # 2nd order methods
print(f"beta = {beta}")
print(f"hat(beta) = {beta_hat}")

Log likelihood[0] = -39173304.0
Log likelihood[1] = -14537620.0
Log likelihood[2] = -5402229.5
Log likelihood[3] = -2006548.625
Log likelihood[4] = -741087.25
Log likelihood[5] = -268893.0
Log likelihood[6] = -93487.3125
Log likelihood[7] = -29938.27734375
Log likelihood[8] = -8535.572265625
Log likelihood[9] = -2576.5654296875
Log likelihood[10] = -1502.05126953125
Log likelihood[11] = -1427.998779296875
Log likelihood[12] = -1427.3929443359375
Log likelihood[13] = -1427.393310546875
beta = [ 1.2956359   1.3550105  -0.40960556 -0.77188545  0.38094172]
hat(beta) = [ 1.2962788   1.3400885  -0.40938386 -0.77929157  0.38347557]


In [77]:
# let's implement poisson regression using _only_ gradient information to perform inference
# and measure how quickly it converges compared with the Newton method
def grad_fit(beta, y, X, step_size):
  r"""
  Update MLE estimate $\beta$ under the model
     $y_i | x_i ~ \text{Poi}(\exp(x_i^T \beta))$.

  Performs a single gradient update step.

  beta: beta
  y: poisson-distributed observations
  X: our design matrix as lx.MatrixLinearOperator
  step_size: step size for the gradient update

  returns: updated estimate of $\beta$
  """
  eta = X.mv(beta)
  d_i = jnp.exp(eta)
  # the gradient of the log-likelihood is X^T (y - lambda); we add it because
  # we are maximizing, i.e. this is gradient ascent
  return beta + step_size * X.transpose().mv(y - d_i)


# NB: we can transpose a lx.MatrixLinearOperator (say X) as X.transpose()
# NB: we compute matrix-vector products using a lx.MatrixLinearOperator as X.mv(v)
step_size = 1e-5
beta_hat = poiss_reg(y, X, grad_fit, step_size, max_iter=1000)
print(f"beta = {beta}")
print(f"hat(beta) = {beta_hat}")

Log likelihood[0] = -10699.7275390625
Log likelihood[1] = -8428.5107421875
Log likelihood[2] = -6478.630859375
Log likelihood[3] = -4908.4990234375
Log likelihood[4] = -3744.06103515625
Log likelihood[5] = -2956.5849609375
Log likelihood[6] = -2464.587646484375
Log likelihood[7] = -2166.8525390625
Log likelihood[8] = -1980.2850341796875
Log likelihood[9] = -1854.0355224609375
Log likelihood[10] = -1762.221435546875
Log likelihood[11] = -1692.367919921875
Log likelihood[12] = -1637.998291015625
Log likelihood[13] = -1595.2332763671875
Log likelihood[14] = -1561.4241943359375
Log likelihood[15] = -1534.61962890625
Log likelihood[16] = -1513.324462890625
Log likelihood[17] = -1496.37548828125
Log likelihood[18] = -1482.861083984375
Log likelihood[19] = -1472.06689453125
Log likelihood[20] = -1463.42919921875
Log likelihood[21] = -1456.504638671875
Log likelihood[22] = -1450.94384765625
Log likelihood[23] = -1446.469970703125
Log likelihood[24] = -1442.864990234375
Log likelihood[25] = -14



## Automatic differentiation
Automatic differentiation (AD) applies the chain rule to the exact computational graph of a function, so we get derivatives of implemented code without hand-deriving every intermediate expression. For the Poisson objective,
$$
\ell(\beta)=\sum_{i=1}^n \left[y_i x_i^T\beta-\exp(x_i^T\beta)-\log(y_i!)\right],
$$
AD gives gradient and Hessian operators directly:
$$
\nabla_\beta \ell(\beta)=\mathrm{AD}(\ell)(\beta), \qquad \nabla_{\beta\beta}^2 \ell(\beta)=\nabla_\beta[\mathrm{AD}(\ell)](\beta).
$$
In JAX this corresponds to `jax.grad` for first derivatives and `jax.hessian` (or nested `jax.grad`) for second-order structure.
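
As a quick check that AD reproduces the hand-derived expressions, the sketch below compares `jax.grad` and `jax.hessian` against $X^T(y - \lambda)$ and $-X^T \Lambda X$. The function `loglik` is a dense-matrix stand-in for the notebook's `loglikelihood` (it takes a plain array instead of a lineax operator), and the simulated data are assumptions for illustration.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm
from jax.scipy.stats import poisson

x_key, b_key, y_key = rdm.split(rdm.PRNGKey(3), 3)
X = rdm.normal(x_key, shape=(100, 4))
beta = 0.3 * rdm.normal(b_key, shape=(4,))
y = rdm.poisson(y_key, jnp.exp(X @ beta))

def loglik(b, y, X):
    # Dense-matrix stand-in for the notebook's loglikelihood (no lineax operator).
    return jnp.sum(poisson.logpmf(y, jnp.exp(X @ b)))

# jax.grad and jax.hessian differentiate w.r.t. the first argument by default.
g_ad = jax.grad(loglik)(beta, y, X)
H_ad = jax.hessian(loglik)(beta, y, X)

lam = jnp.exp(X @ beta)
g_closed = X.T @ (y - lam)                 # X^T (y - lambda)
H_closed = -(X.T * lam) @ X                # -X^T Lambda X
print(jnp.max(jnp.abs(g_ad - g_closed)), jnp.max(jnp.abs(H_ad - H_closed)))
```

The printed maximum absolute differences should be at the level of floating-point error.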

In [None]:
# let's not worry and use autodiff
auto_grad_ll = jax.grad(loglikelihood)

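def jax_grad_step(beta, y, X, step_size):
  # gradient ascent step using the autodiff gradient of the log-likelihood
  # (jax.grad differentiates with respect to the first argument, beta)
  return beta + step_size * auto_grad_ll(beta, y, X)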

# NB: we can transpose a lx.MatrixLinearOperator (say X) as X.transpose()
# NB: we compute matrix-vector products using a lx.MatrixLinearOperator as X.mv(v)
step_size = 1e-6
beta_hat = poiss_reg(y, X, jax_grad_step, step_size, max_iter=1000)
print(f"beta = {beta}")
print(f"hat(beta) = {beta_hat}")

In [None]:
import jax.scipy.linalg as spla

# Great! But can we use 2nd order information?
auto_hess_ll = jax.hessian(loglikelihood)

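def jax_newton_step(beta, y, X, step_size):
  # damped Newton step using autodiff: solve H d = g, then update beta - step_size * d
  # (H = -X^T Lambda X is negative definite for full-rank X, so this is an ascent step)
  grad = auto_grad_ll(beta, y, X)
  hess = auto_hess_ll(beta, y, X)
  return beta - step_size * spla.solve(hess, grad)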

step_size = 1.
beta_hat = poiss_reg(y, X, jax_newton_step, step_size, max_iter=1000)
print(f"beta = {beta}")
print(f"hat(beta) = {beta_hat}")