<a href="https://colab.research.google.com/github/USCbiostats/PM520/blob/main/Lab_4_Optimization_PtII.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Ain't no mountain high enough, or: Optimization Pt II
Outline for today:
1. Newton's Method & Quasi-Newton Methods
2. Poisson Regression Lab
3. Automatic differentiation

## Newton's Method for Optimization
Gradient descent uses only first-order information. Can we do better by also considering second-order information (i.e. the local curvature, or geometry) of
the function $f$?

Let's consider a 2nd-order [Taylor-series approximation](https://en.wikipedia.org/wiki/Taylor_series) to $f$ around $\beta_t$ as,

$$f(\beta) \approx f(\beta_t) + \nabla f(\beta_t)^T (\beta - \beta_t) + \frac{1}{2} (\beta - \beta_t)^T H(\beta_t)(\beta - \beta_t),$$ where $H(\beta_t) = \nabla^2 f(\beta_t)$ (i.e. the [Hessian](https://en.wikipedia.org/wiki/Hessian_matrix) of $f$ at $\beta_t$). If we minimize this _local_ approximation, we see

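$$\nabla_\beta f(\beta) \approx \nabla f(\beta_t) + H(\beta_t)(\beta - \beta_t) = 0 \quad \Rightarrow \quad H(\beta_t)\beta = H(\beta_t)\beta_t - \nabla f(\beta_t),$$
where we set the gradient of the approximation to zero because it vanishes at the minimizer (assuming $H(\beta_t)$ is positive definite).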

We can recognize that this is a [system of linear equations](https://en.wikipedia.org/wiki/System_of_linear_equations) $A x = b$ where $A = H(\beta_t)$, $x = \beta$, and $b = H(\beta_t)\beta_t - \nabla f(\beta_t)$. Provided $A$ is invertible, the solution is $\hat{x} = A^{-1}b$, which in this case implies,
$$ \hat{\beta} = H(\beta_t)^{-1}\left(H(\beta_t)\beta_t - \nabla f(\beta_t)\right) = \beta_t - H(\beta_t)^{-1}\nabla f(\beta_t).$$
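As a rough illustration of this update, the sketch below applies a single Newton step to a small quadratic, with `jax.grad` and `jax.hessian` supplying the derivatives. The helper `newton_step` and the test problem (`A`, `b`) are hypothetical names chosen for this example. Because a quadratic equals its own 2nd-order Taylor expansion, one step should land on the minimizer $A^{-1}b$.

```python
import jax
import jax.numpy as jnp

def newton_step(f, beta_t):
    """One Newton update: beta_t - H(beta_t)^{-1} grad f(beta_t)."""
    g = jax.grad(f)(beta_t)
    H = jax.hessian(f)(beta_t)
    # solve H d = g instead of forming H^{-1} explicitly
    return beta_t - jnp.linalg.solve(H, g)

# hypothetical test problem: f(beta) = 1/2 beta^T A beta - b^T beta
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, -1.0])

def quadratic(beta):
    return 0.5 * beta @ A @ beta - b @ beta

beta_0 = jnp.array([10.0, -7.0])
beta_1 = newton_step(quadratic, beta_0)
```

Solving the linear system rather than inverting $H$ mirrors the $Ax = b$ view above and is usually the cheaper and more stable choice.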



[Newton's method](https://en.wikipedia.org/wiki/Newton%27s_method_in_optimization) is only guaranteed to converge _locally_, and can diverge even for _strictly_ [convex functions](https://en.wikipedia.org/wiki/Convex_function). For example, with $f(\beta) = \sqrt{\beta^2 + 1}$ the update simplifies to $\beta_{t+1} = -\beta_t^3$, which diverges whenever $|\beta_0| > 1$. To address this limitation, we can add a damping parameter (step size) $\rho_t \in (0, 1]$, which gives us our final update form,
$$ \hat{\beta} = \beta_t - \rho_t H(\beta_t)^{-1}\nabla f(\beta_t).$$
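To make the role of $\rho_t$ concrete, here is a small sketch on $f(\beta) = \sqrt{\beta^2 + 1}$, where $f'(\beta)/f''(\beta) = \beta(\beta^2 + 1)$. It uses a fixed $\rho$ instead of a per-iteration $\rho_t$, which is a simplifying assumption. The helper `damped_newton`, the starting point, and the value $\rho = 0.1$ are also choices made for this example.

```python
import jax
import jax.numpy as jnp

def f(beta):
    return jnp.sqrt(beta**2 + 1.0)

df = jax.grad(f)     # first derivative
d2f = jax.grad(df)   # second derivative (scalar Hessian)

def damped_newton(beta, rho, n_steps):
    """Run n_steps of beta <- beta - rho * f'(beta) / f''(beta) with a fixed rho."""
    for _ in range(n_steps):
        beta = beta - rho * df(beta) / d2f(beta)
    return beta

beta_0 = 2.0
# rho = 1 is plain Newton: the iterates grow in magnitude (2 -> -8 -> 512)
undamped = damped_newton(beta_0, rho=1.0, n_steps=2)
# a small fixed rho shrinks the iterates toward the minimizer at 0
damped = damped_newton(beta_0, rho=0.1, n_steps=100)
```

In practice $\rho_t$ is often chosen by a line search at each iteration, since a fixed value that is safe far from the optimum slows convergence near it.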

## Quasi-Newton Methods for Optimization
What if computing $H(\beta_t)$ is too costly? Do we need _exact_ second-order information to improve on gradient descent's convergence? Given an approximation $B(\beta_t) \approx H(\beta_t)$, [_quasi_-Newton methods](https://en.wikipedia.org/wiki/Quasi-Newton_method) minimize the local model
$$f(\beta) \approx f(\beta_t) + \nabla f(\beta_t)^T (\beta - \beta_t) + \frac{1}{2} (\beta - \beta_t)^T B(\beta_t)(\beta - \beta_t).$$ Minimizing this approximation in the same way as before gives us our update rule,
$$ \hat{\beta} = \beta_t - \rho_t B(\beta_t)^{-1}\nabla f(\beta_t).$$
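The rule above leaves the choice of $B$ open. As one concrete possibility, the sketch below uses the BFGS update, a standard quasi-Newton scheme that these notes do not commit to. It maintains an estimate of $B^{-1}$ built only from gradient differences and picks $\rho_t$ by backtracking. The function `bfgs_minimize`, the line-search constants, and the test quadratic are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

def bfgs_minimize(f, beta0, max_iter=100, tol=1e-5):
    """Minimize f with BFGS, tracking B^{-1} directly (no Hessian is ever computed)."""
    grad_f = jax.grad(f)
    eye = jnp.eye(beta0.shape[0])
    B_inv = eye                      # start from B = I, so the first step is gradient descent
    beta, g = beta0, grad_f(beta0)
    for _ in range(max_iter):
        if jnp.linalg.norm(g) < tol:
            break
        direction = -B_inv @ g
        # backtracking (Armijo) line search for the step size rho_t
        rho = 1.0
        while f(beta + rho * direction) > f(beta) + 1e-4 * rho * (g @ direction) and rho > 1e-8:
            rho *= 0.5
        s = rho * direction          # change in beta
        g_new = grad_f(beta + s)
        y = g_new - g                # change in gradient
        sy = s @ y
        if sy > 1e-10:               # curvature condition keeps B_inv positive definite
            V = eye - jnp.outer(s, y) / sy
            B_inv = V @ B_inv @ V.T + jnp.outer(s, s) / sy
        beta, g = beta + s, g_new
    return beta

# hypothetical test problem: f(beta) = 1/2 beta^T A beta - b^T beta
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, -1.0])

def quadratic(beta):
    return 0.5 * beta @ A @ beta - b @ beta

beta_hat = bfgs_minimize(quadratic, jnp.zeros(2))
```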

## Poisson Regression

In Poisson regression we model a count outcome $y_i$ given covariates $x_i$ as $$y_i | x_i \sim \text{Poi}(\lambda_i),$$ where $\lambda_i := \exp(x_i^T \beta)$, and $\text{Poi}(k | \lambda) := \frac{\lambda^k \exp(-\lambda)}{k!}$ is the [PMF](https://en.wikipedia.org/wiki/Probability_mass_function) of the [Poisson distribution](https://en.wikipedia.org/wiki/Poisson_distribution). Given $\{(y_i, x_i)\}_{i=1}^n$, we would like to identify the [maximum likelihood parameter estimate](https://en.wikipedia.org/wiki/Maximum_likelihood_estimation) for $\beta$. In other words, we would like to find a value for $\beta$ that maximizes the log-likelihood given by,
$$\begin{align*}
\log \ell(\beta) &= \sum_i \log \text{Poi}(y_i | \exp(x_i^T \beta)) \\
&= \sum_i \log \left[ \frac{\exp(y_i \cdot x_i^T \beta) \exp(-\exp(x_i^T \beta))}{y_i!} \right] \\
&= \sum_i \log \left[ \frac{\exp(y_i \cdot x_i^T \beta - \exp(x_i^T \beta))}{y_i!} \right] \\
&= \sum_i \log \left[\exp(y_i \cdot x_i^T \beta - \exp(x_i^T \beta))\right] - \log(y_i!) \\
&= \sum_i \left[y_i \cdot x_i^T \beta - \exp(x_i^T \beta) - \log(y_i!)\right] \\
&= y^T X\beta - \exp(X\beta)^T 1_n - O(1) \\
&= y^T X\beta - \lambda^T 1_n - O(1),
\end{align*}$$
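where $\lambda = (\lambda_1, \dotsc, \lambda_n)^T$, $\exp$ is applied elementwise, $1_n$ is the length-$n$ vector of ones, and $O(1)$ stands for the term $\sum_i \log(y_i!)$, which does not depend on $\beta$.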
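The final line of this derivation can be checked numerically. The sketch below compares the vectorized form $y^T X\beta - \lambda^T 1_n$ against the sum of Poisson log-PMFs from `jax.scipy.stats`, adding back the dropped constant $\sum_i \log(y_i!)$. The function name `poisson_loglik`, the small simulated data set, and the use of a dense `X` (not a lineax operator) are assumptions made for this example.

```python
import jax.numpy as jnp
import jax.random as rdm
from jax.scipy.special import gammaln
from jax.scipy.stats import poisson

def poisson_loglik(beta, y, X):
    """y^T X beta - lambda^T 1_n, i.e. the log-likelihood without the constant term."""
    eta = X @ beta
    return y @ eta - jnp.sum(jnp.exp(eta))

# small hypothetical data set
key_x, key_y = rdm.split(rdm.PRNGKey(0))
X = 0.3 * rdm.normal(key_x, (50, 3))
beta = jnp.array([0.5, -0.25, 1.0])
y = rdm.poisson(key_y, jnp.exp(X @ beta)).astype(jnp.float32)

full_ll = jnp.sum(poisson.logpmf(y, jnp.exp(X @ beta)))  # includes -log(y_i!)
constant = jnp.sum(gammaln(y + 1.0))                     # sum_i log(y_i!)
reduced_ll = poisson_loglik(beta, y, X)                  # should equal full_ll + constant
```

Dropping the constant does not change the maximizer, but it does mean the reported value can be positive even though a true log-PMF sum never is.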


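Writing $\ell$ for the log-likelihood from here on, its gradient and Hessian with respect to $\beta$ are
$$ \begin{align*}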
\nabla_\beta \ell &= \nabla_\beta \left[ y^T X\beta - \lambda^T 1_n \right] \\
&= \nabla_\beta [ y^T X\beta ] - \nabla_\beta [\lambda^T 1_n] \\
&= \nabla_\beta [ y^T X\beta ] - \nabla_\beta [\exp(X\beta)^T 1_n] \\
&= X^T y - X^T \exp(X\beta)  \\
&= X^T y - X^T \lambda  \\
&= X^T(y - \lambda) \\
\nabla^2_{\beta \beta} \ell &= \nabla_{\beta} X^T(y - \lambda) \\
&= \nabla_{\beta} \left[X^T y - X^T \lambda \right] \\
&= - X^T \nabla_{\beta}  \lambda \\
&= -X^T \nabla_{\beta}  \exp(X\beta) \\
&= -X^T \Lambda X,
\end{align*}$$
where $\Lambda = \text{diag}(\lambda)$, i.e. $\Lambda_{ii} = \lambda_i$ and $\Lambda_{ij} = 0$ for $i \neq j$.
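A quick way to sanity-check these closed forms is to compare them with automatic differentiation. The sketch below does so on a small simulated data set. The helper `loglik`, the dense `X`, and the simulation settings are assumptions made for this example.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm

def loglik(beta, y, X):
    eta = X @ beta
    return y @ eta - jnp.sum(jnp.exp(eta))

key_x, key_y = rdm.split(rdm.PRNGKey(1))
X = 0.3 * rdm.normal(key_x, (40, 3))
beta = jnp.array([0.2, -0.4, 0.6])
y = rdm.poisson(key_y, jnp.exp(X @ beta)).astype(jnp.float32)

lam = jnp.exp(X @ beta)
grad_closed = X.T @ (y - lam)              # X^T (y - lambda)
hess_closed = -(X.T @ (lam[:, None] * X))  # -X^T Lambda X without forming diag(lambda)

grad_auto = jax.grad(loglik)(beta, y, X)
hess_auto = jax.hessian(loglik)(beta, y, X)
```

Scaling the rows of `X` by `lam` avoids building the $n \times n$ matrix $\Lambda$, which matters once $n$ is large.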

To illustrate how $\nabla_{\beta} \exp(X\beta) = \Lambda X$ (i.e. the last step in the Hessian calculation), recall that the [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant) of a function $f : \mathbb{R}^n \to \mathbb{R}^m$ is the $m \times n$ matrix $J$ such that $J_{ij} = \frac{\partial f_i}{\partial x_j}$. In this case we are computing the Jacobian of $\beta \mapsto \exp(X\beta)$, which maps $\mathbb{R}^p \to \mathbb{R}^n$, so the Jacobian should have shape $n \times p$. Notice that $J_{ij} = \frac{\partial}{\partial \beta_j} \exp(x_i^T \beta) = x_{ij}\exp(x_i^T \beta)$, thus the $i$-th row is $J_{i, \cdot} = \exp(x_i^T \beta) x_i^T$. Repeating this for each $i$ we have $$\nabla_\beta \exp(X \beta) = J(\exp(X \beta)) = \begin{bmatrix} J_{1,\cdot} \\ \vdots \\ J_{n,\cdot} \end{bmatrix} =
\begin{bmatrix} \exp(x_1^T \beta) x_1^T \\ \vdots \\ \exp(x_n^T \beta) x_n^T \end{bmatrix}  =
\begin{bmatrix} \lambda_1 x_1^T \\ \vdots \\ \lambda_n x_n^T\end{bmatrix} = \Lambda X.$$
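The same identity can be confirmed on a tiny example with `jax.jacobian`. The particular `X` and `beta` below are arbitrary values chosen for illustration.

```python
import jax
import jax.numpy as jnp

X = jnp.array([[1.0, 0.5], [0.0, -1.0], [2.0, 1.0]])  # n = 3, p = 2
beta = jnp.array([0.1, -0.3])

# Jacobian of beta -> exp(X beta), one row per observation
J = jax.jacobian(lambda b: jnp.exp(X @ b))(beta)
# Lambda X, with Lambda = diag(exp(X beta))
Lambda_X = jnp.diag(jnp.exp(X @ beta)) @ X
```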

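We can fit this model using Newton's method. Because $H = -X^T \Lambda X$, the update becomes
$$\begin{align*}
\beta(t+1) &= \beta(t) - H(\beta(t))^{-1}\nabla \ell(\beta(t)) \\
&= \beta(t) + (X^T \Lambda(t) X)^{-1} X^T (y - \lambda(t)) \\
&= (X^T \Lambda(t) X)^{-1} X^T \Lambda(t) (\Lambda(t)^{-1}y + X\beta(t) - 1_n),
\end{align*}$$
where $\lambda(t) := \exp(X\beta(t))$ and $\Lambda(t) := \text{diag}(\lambda(t))$. The last line is a weighted least-squares problem with weights $\Lambda(t)$ and working response $z(t) = \Lambda(t)^{-1}y + X\beta(t) - 1_n$, which is why this procedure is known as iteratively reweighted least squares (IRWLS).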
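To see that the Newton form and the weighted least-squares form of the update agree, the sketch below computes both from the same starting point. It uses dense arrays and `jnp.linalg.solve` in place of the lineax solver used in the lab code. The simulated data and the names `newton` and `wls` are assumptions made for this example.

```python
import jax.numpy as jnp
import jax.random as rdm

key_x, key_y = rdm.split(rdm.PRNGKey(2))
X = 0.3 * rdm.normal(key_x, (60, 3))
beta_true = jnp.array([0.5, -0.5, 0.25])
y = rdm.poisson(key_y, jnp.exp(X @ beta_true)).astype(jnp.float32)

beta_t = jnp.array([0.1, 0.0, -0.1])   # arbitrary current iterate
eta = X @ beta_t
lam = jnp.exp(eta)
XtLX = X.T @ (lam[:, None] * X)        # X^T Lambda X

# form 1: Newton step, beta + (X^T Lambda X)^{-1} X^T (y - lambda)
newton = beta_t + jnp.linalg.solve(XtLX, X.T @ (y - lam))

# form 2: weighted least squares on the working response z
z = y / lam + eta - 1.0
wls = jnp.linalg.solve(XtLX, X.T @ (lam * z))
```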

In [1]:
!pip install lineax



In [2]:
import jax
import jax.numpy as jnp
import jax.random as rdm
import jax.scipy.stats as stats

import lineax as lx

@jax.jit
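def loglikelihood(beta, y, X):
  r"""
  Our log-likelihood function for $y_i | x_i \sim \text{Poi}(\exp(\eta_i))$,
  where $\eta_i = x_i^T \beta$.

  beta: current estimate of the regression coefficients
  y: Poisson-distributed observations
  X: our design matrix as lx.AbstractLinearOperator

  returns: sum of the log-likelihoods of each sample
  """
  # TODO: implement (e.g. y^T X beta - lambda^T 1_n from the derivation above)
  pass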



@jax.jit
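def irwls_fit(beta, y, X, step_size):
  r"""
  Perform one IRWLS (Newton) update of $\beta$ under the model
     $y_i | x_i \sim \text{Poi}(\exp(x_i^T \beta))$.

  beta: current estimate of the regression coefficients
  y: Poisson-distributed observations
  X: our design matrix as lx.AbstractLinearOperator
  step_size: unused here; kept so all fit functions share one signature

  returns: updated estimate of $\beta$
  """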
  # compute lambda_i := exp(x_i @ beta)
  eta = X.mv(beta)
  d_i = jnp.exp(eta)
  d_sqrt = jnp.sqrt(d_i)

  # compute z_i := Lambda^{1/2}(Lambda^-1 y + X @beta - 1)
  z = (y / d_i + eta - 1) * d_sqrt

  # X* := Lambda^{1/2} X
  # we use linear operators to postpone any computation
  X_star = lx.DiagonalLinearOperator(d_sqrt) @ X

  # NormalCG solves the least-squares problem min ||X* @ beta - z||^2 by running
  # conjugate gradients on the normal equations t(X*) @ X* @ beta = t(X*) @ z
  solution = lx.linear_solve(X_star, z, solver=lx.NormalCG(atol=1e-4, rtol=1e-3))
  beta = solution.value

  return beta


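def poiss_reg(y, X, fit_func, step_size = 1.0, max_iter=100, tol=1e-3):
  r"""
  Perform MLE estimation for $\beta$ under the model
     $y_i | x_i \sim \text{Poi}(\exp(x_i^T \beta))$.

  y: Poisson-distributed observations
  X: our design matrix
  fit_func: update function with signature (beta, y, X, step_size)
  step_size: step size passed through to fit_func
  max_iter: the maximum number of iterations to perform optimization
  tol: stop once the log-likelihood changes by less than this amount

  returns: final estimate of $\beta$
  """
  # number of samples (n) and variables (p)
  n, p = X.shape

  # placeholder values so the first convergence check cannot trigger
  loglike = -100000
  delta = 10000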

  # convert to a linear operator for lineax
  X = lx.MatrixLinearOperator(X)

  # initialize using OLS estimate and normalizing for downstream stability
  sol = lx.linear_solve(X, (y - jnp.mean(y))/2, solver=lx.NormalCG(atol=1e-4, rtol=1e-3))
  beta = sol.value
  beta = beta / jnp.linalg.norm(beta)

  for epoch in range(max_iter):

    # fit using our function
    beta = fit_func(beta, y, X, step_size)

    # evaluate log likelihood
    newll = loglikelihood(beta, y, X)

    # take delta and check if we can stop
    delta = jnp.fabs(newll - loglike)
    print(f"Log likelihood[{epoch}] = {newll}")
    if delta < tol:
      break

    # replace old value
    loglike = newll

  return beta

In [3]:
# Let's simulate a poisson regression model with N samples and P variables
# we need X (N,P), beta (P,) and y (N,)
N = 1000
P = 5

# initialize PRNG env
seed = 0
key = rdm.PRNGKey(seed)

# TODO: split key for each random call


# TODO: compute lambda_i = exp(x_i' \beta)


# TODO: sample y from Poi(lambda_i)


# estimate beta using our irwls function
# fit_func has signature (beta, y, X, step_size)
beta_hat = poiss_reg(y, X, irwls_fit)
print(f"beta = {beta}")
print(f"hat(beta) = {beta_hat}")

Log likelihood[0] = -39143916.0
Log likelihood[1] = -14508134.0
Log likelihood[2] = -5372775.0
Log likelihood[3] = -1977102.625
Log likelihood[4] = -711654.75
Log likelihood[5] = -239464.78125
Log likelihood[6] = -64012.875
Log likelihood[7] = -523.58203125
Log likelihood[8] = 20885.935546875
Log likelihood[9] = 26847.2421875
Log likelihood[10] = 27922.48828125
Log likelihood[11] = 27996.65625
Log likelihood[12] = 27997.24609375
Log likelihood[13] = 27997.25
Log likelihood[14] = 27997.24609375
Log likelihood[15] = 27997.251953125
Log likelihood[16] = 27997.248046875
Log likelihood[17] = 27997.24609375
Log likelihood[18] = 27997.24609375
beta = [ 1.2956359   1.3550105  -0.40960556 -0.77188545  0.38094172]
hat(beta) = [ 1.2962791   1.3400909  -0.40938488 -0.77928865  0.3834757 ]
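One possible way to fill in the simulation TODOs is sketched below with dense arrays. Drawing `X` and `beta` from a standard normal is an assumption, as is the way the key is split. The lab does not state how its data were generated, so the numbers will not match the printed output above.

```python
import jax.numpy as jnp
import jax.random as rdm

N, P = 1000, 5
key = rdm.PRNGKey(0)

# split the key so that each random call gets its own independent stream
key, x_key, b_key, y_key = rdm.split(key, 4)

X = rdm.normal(x_key, shape=(N, P))   # assumed covariate distribution
beta = rdm.normal(b_key, shape=(P,))  # assumed "true" effects

lam = jnp.exp(X @ beta)               # lambda_i = exp(x_i' beta)
y = rdm.poisson(y_key, lam)           # y_i ~ Poi(lambda_i)
```

Reusing a single key for all three draws would make them statistically dependent, which is why JAX asks for an explicit split.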


In [9]:
# let's implement poisson regression using _only_ gradient information to perform inference
# and measure how quickly it converges compared with the Newton method
def grad_fit(beta, y, X, step_size):
  pass

# NB: we can transpose a lx.MatrixLinearOperator (say X) as X.transpose()
# NB: we compute matrix-vector products using a lx.MatrixLinearOperator as X.mv(v)
step_size = 1e-5
beta_hat = poiss_reg(y, X, grad_fit, step_size, max_iter=1000)
print(f"beta = {beta}")
print(f"hat(beta) = {beta_hat}")

Log likelihood[0] = 18724.921875
Log likelihood[1] = 20996.13671875
Log likelihood[2] = 22946.0234375
Log likelihood[3] = 24516.162109375
Log likelihood[4] = 25680.60546875
Log likelihood[5] = 26468.087890625
Log likelihood[6] = 26960.0546875
Log likelihood[7] = 27257.806640625
Log likelihood[8] = 27444.361328125
Log likelihood[9] = 27570.61328125
Log likelihood[10] = 27662.41796875
Log likelihood[11] = 27732.294921875
Log likelihood[12] = 27786.658203125
Log likelihood[13] = 27829.423828125
Log likelihood[14] = 27863.21484375
Log likelihood[15] = 27890.0390625
Log likelihood[16] = 27911.33984375
Log likelihood[17] = 27928.25
Log likelihood[18] = 27941.796875
Log likelihood[19] = 27952.58984375
Log likelihood[20] = 27961.2109375
Log likelihood[21] = 27968.14453125
Log likelihood[22] = 27973.69921875
Log likelihood[23] = 27978.1796875
Log likelihood[24] = 27981.78515625
Log likelihood[25] = 27984.6796875
Log likelihood[26] = 27987.052734375
Log likelihood[27] = 27988.9296875
Log likelih
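For reference, a gradient-only update can be sketched as follows, using the closed-form gradient $X^T(y - \lambda)$ and dense arrays instead of lineax operators. Since we are maximizing the log-likelihood, the step moves _along_ the gradient. The names `grad_step` and `loglik`, the simulated data, and the step size are assumptions made for this example and are not the lab's solution.

```python
import jax.numpy as jnp
import jax.random as rdm

def loglik(beta, y, X):
    eta = X @ beta
    return y @ eta - jnp.sum(jnp.exp(eta))

def grad_step(beta, y, X, step_size):
    """One step of gradient ascent: beta + step_size * X^T (y - lambda)."""
    lam = jnp.exp(X @ beta)
    return beta + step_size * (X.T @ (y - lam))

key_x, key_y = rdm.split(rdm.PRNGKey(3))
X = 0.5 * rdm.normal(key_x, (200, 3))
beta_true = jnp.array([0.5, -0.5, 0.25])
y = rdm.poisson(key_y, jnp.exp(X @ beta_true)).astype(jnp.float32)

beta_hat = jnp.zeros(3)
ll_trace = [loglik(beta_hat, y, X)]
for _ in range(200):
    beta_hat = grad_step(beta_hat, y, X, step_size=1e-2)
    ll_trace.append(loglik(beta_hat, y, X))

# gradient at the final iterate; close to zero at a maximizer
final_grad = X.T @ (y - jnp.exp(X @ beta_hat))
```

The step size has to be tuned to the scale of the data: too large and the iterates overshoot, too small and many more iterations are needed than with the Newton update.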

## Automatic differentiation
Chain rules, okay! Notes TBD
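As a minimal illustration of the idea, the sketch below differentiates $f(x) = \sin(x^2)$ once by hand with the chain rule and once with `jax.grad`. The function and evaluation point are arbitrary choices for this example.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x**2)

def df_by_hand(x):
    # chain rule: d/dx sin(u) = cos(u) * du/dx, with u = x^2 and du/dx = 2x
    return jnp.cos(x**2) * 2.0 * x

x = 1.5
auto = jax.grad(f)(x)    # derivative from automatic differentiation
manual = df_by_hand(x)   # derivative from the hand-derived formula
```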

In [10]:
# let's not worry and use autodiff
auto_grad_ll = jax.grad(loglikelihood)

def jax_grad_step(beta, y, X, step_size):
  pass

# NB: we can transpose a lx.MatrixLinearOperator (say X) as X.transpose()
# NB: we compute matrix-vector products using a lx.MatrixLinearOperator as X.mv(v)
step_size = 1e-6
beta_hat = poiss_reg(y, X, jax_grad_step, step_size, max_iter=1000)
print(f"beta = {beta}")
print(f"hat(beta) = {beta_hat}")

Log likelihood[0] = 16456.615234375
Log likelihood[1] = 16712.373046875
Log likelihood[2] = 16965.998046875
Log likelihood[3] = 17217.419921875
Log likelihood[4] = 17466.59765625
Log likelihood[5] = 17713.4609375
Log likelihood[6] = 17957.978515625
Log likelihood[7] = 18200.03125
Log likelihood[8] = 18439.634765625
Log likelihood[9] = 18676.671875
Log likelihood[10] = 18911.099609375
Log likelihood[11] = 19142.8671875
Log likelihood[12] = 19371.900390625
Log likelihood[13] = 19598.134765625
Log likelihood[14] = 19821.53515625
Log likelihood[15] = 20042.001953125
Log likelihood[16] = 20259.4921875
Log likelihood[17] = 20473.96484375
Log likelihood[18] = 20685.294921875
Log likelihood[19] = 20893.521484375
Log likelihood[20] = 21098.5
Log likelihood[21] = 21300.220703125
Log likelihood[22] = 21498.61328125
Log likelihood[23] = 21693.6328125
Log likelihood[24] = 21885.2109375
Log likelihood[25] = 22073.306640625
Log likelihood[26] = 22257.884765625
Log likelihood[27] = 22438.890625
Log li

In [None]:
import jax.scipy.linalg as spla

# Great! But can we use 2nd order information?
auto_hess_ll = jax.hessian(loglikelihood)

def jax_newton_step(beta, y, X, step_size):
  pass

step_size = 1.
beta_hat = poiss_reg(y, X, jax_newton_step, step_size, max_iter=1000)
print(f"beta = {beta}")
print(f"hat(beta) = {beta_hat}")
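A Newton update built entirely from autodiff could look like the sketch below, which works on dense arrays and uses `jnp.linalg.solve` for the linear system. The names `loglik` and `newton_step`, the simulated data, and the number of iterations are assumptions made for this example and are not the lab's solution.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm

def loglik(beta, y, X):
    eta = X @ beta
    return y @ eta - jnp.sum(jnp.exp(eta))

grad_ll = jax.grad(loglik)     # differentiates w.r.t. the first argument, beta
hess_ll = jax.hessian(loglik)

@jax.jit
def newton_step(beta, y, X, step_size):
    """beta - step_size * H^{-1} grad, with H and grad obtained from autodiff."""
    g = grad_ll(beta, y, X)
    H = hess_ll(beta, y, X)
    return beta - step_size * jnp.linalg.solve(H, g)

key_x, key_y = rdm.split(rdm.PRNGKey(4))
X = 0.5 * rdm.normal(key_x, (200, 3))
beta_true = jnp.array([0.5, -0.5, 0.25])
y = rdm.poisson(key_y, jnp.exp(X @ beta_true)).astype(jnp.float32)

beta_hat = jnp.zeros(3)
for _ in range(10):
    beta_hat = newton_step(beta_hat, y, X, 1.0)

# gradient at the final iterate; close to zero at a maximizer
final_grad = grad_ll(beta_hat, y, X)
```

Because the log-likelihood is concave, $H$ is negative definite and subtracting $H^{-1}\nabla \ell$ moves uphill, matching the sign convention in the derivation.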