<a href="https://colab.research.google.com/github/USCbiostats/PM520/blob/main/Lab5_Optimization_PtII.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install lineax

Collecting lineax
  Downloading lineax-0.0.4-py3-none-any.whl (65 kB)
Collecting equinox>=0.11.0 (from lineax)
  Downloading equinox-0.11.3-py3-none-any.whl (167 kB)
Collecting jaxtyping>=0.2.20 (from lineax)
  Downloading jaxtyping-0.2.25-py3-none-any.whl (39 kB)
Collecting typeguard<3,>=2.13.3 (from jaxtyping>=0.2.20->lineax)
  Downloading typeguard-2.13.3-py3-none-any.whl (17 kB)
Installing collected packages: typeguard, jaxtyping, equinox, lineax
Successfully installed equinox-0.11.3 jaxtyping-0.2.25 lineax-0.0.4 typeguard-2.13.3


## Newton's Method for Optimization
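Recall that [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent) uses a first-order approximation of $f(\beta)$ around the current iterate $\beta_t$,
$$f(\beta) \approx f(\beta_t) + \nabla f(\beta_t)^T(\beta - \beta_t).$$
The gradient of this linear model is constant, $\nabla_\beta f(\beta) \approx \nabla f(\beta_t)$, so the model has no minimizer of its own; gradient descent instead takes a step of size $\rho_t$ along the steepest-descent direction $-\nabla f(\beta_t)$.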

Can we do better by using higher-order information (i.e., the curvature) of the function $f$?

Let's consider a 2nd-order [Taylor-series approximation](https://en.wikipedia.org/wiki/Taylor_series) to $f$ around $\beta_t$ as,

$$f(\beta) \approx f(\beta_t) + \nabla f(\beta_t)^T (\beta - \beta_t) + \frac{1}{2} (\beta - \beta_t)^T H(\beta_t)(\beta - \beta_t),$$ where $H(\beta_t) = \nabla^2 f(\beta_t)$ (i.e., the [Hessian](https://en.wikipedia.org/wiki/Hessian_matrix) of $f$ at $\beta_t$). To minimize this _local_ approximation, we set its gradient with respect to $\beta$ to zero:

$$\nabla f(\beta_t) + H(\beta_t)(\beta - \beta_t) = 0 \quad\Rightarrow\quad H(\beta_t)\beta = H(\beta_t)\beta_t - \nabla f(\beta_t).$$

We can recognize that this is a [system of linear equations](https://en.wikipedia.org/wiki/System_of_linear_equations) $A x = b$ where $A = H(\beta_t)$, $x = \beta$, and $b = H(\beta_t)\beta_t - \nabla f(\beta_t)$. When $A$ is invertible, the solution is $\hat{x} = A^{-1}b$, which in this case implies,
$$ \hat{\beta} = H(\beta_t)^{-1}\left(H(\beta_t)\beta_t - \nabla f(\beta_t)\right) = \beta_t - H(\beta_t)^{-1}\nabla f(\beta_t).$$

Contrast this with [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent), which is given by,
$$ \hat{\beta} = \beta_t - \rho_t \nabla f(\beta_t).$$
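To see the difference concretely, here is a small sketch on an assumed quadratic $f(\beta) = \tfrac{1}{2}\beta^T A \beta - b^T\beta$ (the matrix `A`, vector `b`, starting point, and step size are illustrative choices, not from this lab). Because the second-order Taylor expansion of a quadratic is exact, a single Newton step lands on the minimizer $A^{-1}b$, while gradient descent with a fixed $\rho$ is still approaching it after ten steps. The code solves $H d = \nabla f$ rather than forming $H^{-1}$ explicitly.

```python
import jax
import jax.numpy as jnp

# illustrative quadratic f(beta) = 0.5 beta^T A beta - b^T beta, with A symmetric positive definite
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, 1.0])

def f(beta):
    return 0.5 * beta @ A @ beta - b @ beta

grad_f = jax.grad(f)
hess_f = jax.hessian(f)

beta0 = jnp.array([5.0, -3.0])
beta_star = jnp.linalg.solve(A, b)  # exact minimizer A^{-1} b

# one Newton step: solve H(beta0) d = grad f(beta0) instead of inverting H
beta_newton = beta0 - jnp.linalg.solve(hess_f(beta0), grad_f(beta0))

# ten gradient-descent steps with a fixed step size rho
rho = 0.2
beta_gd = beta0
for _ in range(10):
    beta_gd = beta_gd - rho * grad_f(beta_gd)

print("Newton error:", jnp.linalg.norm(beta_newton - beta_star))
print("GD error:    ", jnp.linalg.norm(beta_gd - beta_star))
```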

[Newton's method](https://en.wikipedia.org/wiki/Newton%27s_method_in_optimization) is only guaranteed to converge _locally_, and the pure update above can diverge even for _strictly_ [convex functions](https://en.wikipedia.org/wiki/Convex_function). For example, with $f(\beta) = \sqrt{\beta^2 + 1}$ the Newton update simplifies to $\beta_{t+1} = -\beta_t^3$, which diverges whenever $|\beta_0| > 1$. To address this limitation, we can add a damping parameter (step size) $\rho_t \in (0, 1]$, typically chosen by a line search, which gives our final update form,
$$ \hat{\beta} = \beta_t - \rho_t H(\beta_t)^{-1}\nabla f(\beta_t).$$
Setting $\rho_t = 1$ recovers the pure Newton step.
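The example $f(\beta) = \sqrt{\beta^2 + 1}$ is easy to check numerically. The sketch below compares the pure Newton step with a damped variant in which $\rho_t$ is picked by backtracking until the Armijo sufficient-decrease condition holds; the backtracking rule and its constants (`shrink`, `c`, `max_backtracks`) are assumptions for illustration, one common way to choose $\rho_t$ rather than the only one.

```python
import jax
import jax.numpy as jnp

def f(beta):
    return jnp.sqrt(beta**2 + 1.0)

grad_f = jax.grad(f)
hess_f = jax.grad(jax.grad(f))  # scalar second derivative

def newton(beta, n_steps):
    # pure Newton: rho_t = 1 at every step
    for _ in range(n_steps):
        beta = beta - grad_f(beta) / hess_f(beta)
    return beta

def damped_newton(beta, n_steps, shrink=0.5, c=1e-4, max_backtracks=50):
    # damped Newton: rho_t halved until f decreases enough (Armijo condition)
    for _ in range(n_steps):
        g = grad_f(beta)
        direction = -g / hess_f(beta)
        rho = 1.0
        for _ in range(max_backtracks):
            if f(beta + rho * direction) <= f(beta) + c * rho * g * direction:
                break
            rho *= shrink
        beta = beta + rho * direction
    return beta

print(newton(2.0, 3))          # pure Newton maps beta to -beta**3: 2 -> -8 -> 512 -> about -1.3e8
print(damped_newton(2.0, 10))  # close to the minimizer 0
```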

## Quasi-Newton Methods for Optimization
What if computing $H(\beta_t)$ is too expensive? Do we need _exact_ second-order information to improve on gradient descent's convergence? [_Quasi_-Newton methods](https://en.wikipedia.org/wiki/Quasi-Newton_method) replace the Hessian with an approximation $B(\beta_t) \approx H(\beta_t)$ and minimize the local model
$$f(\beta) \approx f(\beta_t) + \nabla f(\beta_t)^T (\beta - \beta_t) + \frac{1}{2} (\beta - \beta_t)^T B(\beta_t)(\beta - \beta_t).$$ Minimizing this model gives the update rule
$$ \hat{\beta} = \beta_t - \rho_t B(\beta_t)^{-1}\nabla f(\beta_t).$$
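One widely used quasi-Newton scheme is BFGS, which updates an approximation of $B^{-1}$ directly from successive steps $s$ and gradient differences $y$. The sketch below assumes an identity initialization and a backtracking (Armijo) line search for $\rho_t$; the function `bfgs` and its parameters are our own illustrative names, not a library API. For real use, `jax.scipy.optimize.minimize(..., method="BFGS")` provides an implementation.

```python
import jax
import jax.numpy as jnp

def bfgs(f, x0, max_iter=100, tol=1e-5, c=1e-4):
    """Minimal BFGS sketch that maintains Hinv, an approximation of B^{-1}."""
    grad_f = jax.grad(f)
    eye = jnp.eye(x0.shape[0])
    x, g = x0, grad_f(x0)
    Hinv = eye  # B^{-1} initialised to the identity
    for _ in range(max_iter):
        if jnp.linalg.norm(g) < tol:
            break
        p = -Hinv @ g  # quasi-Newton direction -B^{-1} grad f
        rho = 1.0
        for _ in range(50):  # backtracking line search for rho_t (Armijo)
            if f(x + rho * p) <= f(x) + c * rho * (g @ p):
                break
            rho *= 0.5
        s = rho * p
        x_new = x + s
        g_new = grad_f(x_new)
        y = g_new - g
        sy = y @ s
        if sy > 1e-10:  # update only when the curvature condition y^T s > 0 holds
            r = 1.0 / sy
            Hinv = (eye - r * jnp.outer(s, y)) @ Hinv @ (eye - r * jnp.outer(y, s)) + r * jnp.outer(s, s)
        x, g = x_new, g_new
    return x

# usage on an illustrative quadratic whose minimizer is A^{-1} b = [0.2, 0.4]
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, 1.0])
x_hat = bfgs(lambda x: 0.5 * x @ A @ x - b @ x, jnp.zeros(2))
print(x_hat)
```

Only gradients are evaluated here; the curvature information in `Hinv` is accumulated from the observed changes in the gradient.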

## Poisson Regression

$$y_i | x_i \sim \text{Poi}(\lambda_i)$$ where $\lambda_i := \exp(x_i^T \beta)$, and $\text{Poi}(k | \lambda) := \frac{\lambda^k \exp(-\lambda)}{k!}$. Given $\{(y_i, x_i)\}_{i=1}^n$, we would like to find the maximum likelihood estimate of $\beta$, i.e., the value of $\beta$ that maximizes the log-likelihood,
$$\begin{align*}
\log \ell(\beta) &= \sum_i \log \text{Poi}(y_i | \exp(x_i^T \beta)) \\
&= \sum_i \log \left[ \frac{\exp(y_i \cdot x_i^T \beta) \exp(-\exp(x_i^T \beta))}{y_i!} \right] \\
&= \sum_i \log \left[ \frac{\exp(y_i \cdot x_i^T \beta - \exp(x_i^T \beta))}{y_i!} \right] \\
&= \sum_i \left[ \log \exp(y_i \cdot x_i^T \beta - \exp(x_i^T \beta)) - \log(y_i!) \right] \\
&= \sum_i \left[y_i \cdot x_i^T \beta - \exp(x_i^T \beta) - \log(y_i!)\right] \\
&= y^T X\beta - \exp(X\beta)^T 1_n - \sum_i \log(y_i!) \\
&= y^T X\beta - \lambda^T 1_n + \text{const},
\end{align*}$$
where $\lambda = (\lambda_1, \dotsc, \lambda_n)^T = \exp(X\beta)$ (applied elementwise) and $\text{const} = -\sum_i \log(y_i!)$ does not depend on $\beta$.
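As a quick numerical sanity check of the vectorized form, the sketch below compares $y^T X\beta - \lambda^T 1_n - \sum_i \log(y_i!)$, with $\log(y_i!)$ computed as `gammaln(y_i + 1)`, against a sum of `jax.scipy.stats.poisson.logpmf` terms. The data set sizes, seed, and true $\beta$ are arbitrary assumptions for illustration.

```python
import jax.numpy as jnp
import jax.random as rdm
import jax.scipy.stats as stats
from jax.scipy.special import gammaln

# small simulated data set (sizes and seed are arbitrary)
key = rdm.PRNGKey(0)
kx, kb, ky = rdm.split(key, 3)
X = rdm.normal(kx, (100, 3))
beta = 0.5 * rdm.normal(kb, (3,))
lam = jnp.exp(X @ beta)  # lambda = exp(X beta), elementwise
y = rdm.poisson(ky, lam).astype(jnp.float32)

# vectorized form: y^T X beta - lambda^T 1_n - sum_i log(y_i!), using log(y!) = gammaln(y + 1)
loglik_vec = y @ (X @ beta) - jnp.sum(lam) - jnp.sum(gammaln(y + 1.0))

# reference: sum of per-observation Poisson log-pmfs
loglik_ref = jnp.sum(stats.poisson.logpmf(y, lam))
print(loglik_vec, loglik_ref)
```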


$$ \begin{align*}
\nabla_\beta \log \ell &= \nabla_\beta \left[ y^T X\beta - \lambda^T 1_n \right] \\
&= \nabla_\beta [ y^T X\beta ] - \nabla_\beta [\exp(X\beta)^T 1_n] \\
&= X^T y - X^T \exp(X\beta) \\
&= X^T(y - \lambda) \\
\nabla^2_{\beta \beta} \log \ell &= \nabla_{\beta} \left[X^T y - X^T \lambda \right] \\
&= - X^T \nabla_{\beta} \lambda \\
&= -X^T \nabla_{\beta} \exp(X\beta) \\
&= -X^T \Lambda X,
\end{align*}$$
where $\Lambda = \text{diag}(\lambda)$, i.e. $\Lambda_{ii} = \lambda_i$ and $\Lambda_{ij} = 0$ for $i \neq j$.
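Since JAX can differentiate the log-likelihood automatically, the closed forms $X^T(y - \lambda)$ and $-X^T \Lambda X$ can be checked against `jax.grad` and `jax.hessian`. This is a sketch on an assumed small simulated data set; it also shows that $X^T \Lambda X$ can be computed as `X.T @ (lam[:, None] * X)` without materializing the $n \times n$ diagonal matrix.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm

def loglik(beta, y, X):
    # log-likelihood without the beta-independent term -sum(log(y_i!))
    eta = X @ beta
    return jnp.sum(y * eta - jnp.exp(eta))

key = rdm.PRNGKey(1)
kx, kb, ky = rdm.split(key, 3)
X = rdm.normal(kx, (50, 3))
beta = 0.3 * rdm.normal(kb, (3,))
y = rdm.poisson(ky, jnp.exp(X @ beta)).astype(jnp.float32)

lam = jnp.exp(X @ beta)
grad_closed = X.T @ (y - lam)            # X^T (y - lambda)
hess_closed = -X.T @ (lam[:, None] * X)  # -X^T Lambda X, without forming diag(lambda)

grad_ad = jax.grad(loglik)(beta, y, X)
hess_ad = jax.hessian(loglik)(beta, y, X)
print(jnp.max(jnp.abs(grad_closed - grad_ad)), jnp.max(jnp.abs(hess_closed - hess_ad)))
```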

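We can fit this model with Newton's method. Because we are _maximizing_ $\log \ell$ and its Hessian $-X^T\Lambda X$ is negative semi-definite, the step $-H^{-1}\nabla \log \ell$ becomes $+(X^T\Lambda X)^{-1}X^T(y-\lambda)$:
$$\begin{align*}
\beta_{t+1} &= \beta_t - H(\beta_t)^{-1}\nabla \log \ell(\beta_t) \\
&= \beta_t + (X^T \Lambda_t X)^{-1} X^T (y - \lambda_t) \\
&= (X^T \Lambda_t X)^{-1} X^T \Lambda_t (\Lambda_t^{-1}y + X\beta_t - 1_n),
\end{align*}$$
where $\lambda_t := \exp(X\beta_t)$ and $\Lambda_t := \text{diag}(\lambda_t)$; the last line uses $\Lambda_t 1_n = \lambda_t$. It is the solution of the weighted least-squares problem $\min_\beta \lVert \Lambda_t^{1/2}(z_t - X\beta) \rVert_2^2$ with working response $z_t := \Lambda_t^{-1}y + X\beta_t - 1_n$, so each Newton step is a reweighted least-squares fit. This is why the procedure is called iteratively reweighted least squares (IRLS).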
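The equivalence between the Newton form and the weighted least-squares form can be checked directly. The sketch below uses an assumed simulated data set and an arbitrary current iterate `beta_t`, and uses dense solvers (`jnp.linalg.solve`, `jnp.linalg.lstsq`) for clarity; the lab code that follows solves the same weighted problem with lineax's iterative `NormalCG` instead.

```python
import jax.numpy as jnp
import jax.random as rdm

key = rdm.PRNGKey(0)
kx, kb, ky = rdm.split(key, 3)
X = rdm.normal(kx, (200, 3))
beta_true = 0.5 * rdm.normal(kb, (3,))
y = rdm.poisson(ky, jnp.exp(X @ beta_true)).astype(jnp.float32)

beta_t = jnp.array([0.1, -0.2, 0.3])  # arbitrary current iterate
lam = jnp.exp(X @ beta_t)             # lambda_t

# (1) Newton form: beta_t + (X^T Lambda_t X)^{-1} X^T (y - lambda_t)
XtLX = X.T @ (lam[:, None] * X)
beta_newton = beta_t + jnp.linalg.solve(XtLX, X.T @ (y - lam))

# (2) weighted least squares: regress Lambda_t^{1/2} z_t on Lambda_t^{1/2} X,
#     with working response z_t = Lambda_t^{-1} y + X beta_t - 1_n
z = y / lam + X @ beta_t - 1.0
w = jnp.sqrt(lam)
beta_wls, *_ = jnp.linalg.lstsq(w[:, None] * X, w * z)
print(beta_newton, beta_wls)
```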

In [7]:
import jax
import jax.numpy as jnp
import jax.random as rdm
import jax.scipy.stats as stats

import lineax as lx

@jax.jit
def loglikelihood_eta(eta, y, X):
  """
  Our loglikelihood function for $y_i | x_i ~ \text{Poi}(\exp(eta_i))$.

  eta: X @ beta; the linear component for each observation
  y: poisson-distributed observations
  X: our design matrix

  returns: sum of the logliklihoods of each sample
  """
  mean_lambda = jnp.exp(eta)
  return jnp.sum(stats.poisson.logpmf(y, mean_lambda))


@jax.jit
def irwls_fit(eta, y, X, step_size):
  """
  Perform MLE estimation for $\beta$ under the model
     $y_i | x_i ~ \text{Poi}(\exp(x_i^T \beta))$.

  eta: X @ beta; the linear component for each observation
  y: poisson-distributed observations
  X: our design matrix

  returns: updated estimate of $\beta$
  """
  # compute lambda_i := exp(x_i @ beta)
  d_i = jnp.exp(eta)
  d_sqrt = jnp.sqrt(d_i)

  # compute z_i := Lambda^{1/2}(Lambda^-1 y + X @beta - 1)
  z = (y / d_i + eta - 1) * d_sqrt

  # X* := Lambda^{1/2} X
  # we use linear operators to postpone any computation
  X_star = lx.DiagonalLinearOperator(d_sqrt) @ X

  # NormalCG solves the normal equations t(X*) @ X* @ beta = t(X*) @ z with conjugate gradients
  solution = lx.linear_solve(X_star, z, solver=lx.NormalCG(atol=1e-4, rtol=1e-3))
  beta = solution.value
  return beta


def poiss_reg(y, X, fit_func, step_size = 1.0, max_iter=100, tol=1e-3):
  """
  Perform MLE estimation for $\beta$ under the model
     $y_i | x_i ~ \text{Poi}(\exp(x_i^T \beta))$.

  y: poisson-distributed observations
  X: our design matrix
  max_iter: the maximum number of iterations to perform optimization
  tol:

  returns: updated estimate of $\beta$
  """
  # intialize eta := X @ beta
  eta = jnp.log((y + jnp.mean(y))/2)

  # fake bookkeeping
  loglike = -100000
  delta = 10000

  # convert to a linear operator for lineax
  X = lx.MatrixLinearOperator(X)
  for epoch in range(max_iter):

    # fit using our function
    beta = fit_func(eta, y, X, step_size)

    # update eta
    eta = X.mv(beta)

    # evaluate log likelihood
    newll = loglikelihood_eta(eta, y, X)

    # take delta and check if we can stop
    delta = jnp.fabs(newll - loglike)
    print(f"Epoch[{epoch}] = {newll}")
    if delta < tol:
      break

    # replace old value
    loglike = newll

  return beta

In [8]:
# Let's simulate a poisson regression model with N samples and P variables
N = 1000
P = 5

# initialize PRNG env
seed = 0
key = rdm.PRNGKey(seed)

# split key for each random call
key, y_key, x_key, b_key = rdm.split(key, 4)
X = rdm.normal(x_key, shape=(N, P))
beta = rdm.normal(b_key, shape=(P,))

# compute lambda_i
mean_lambda = jnp.exp(X @ beta)

# sample y from Poi(lambda_i)
y = rdm.poisson(y_key, mean_lambda)

# estimate beta using our irwls function
# fit_func has signature (eta, y, X, step_size)
beta_hat = poiss_reg(y, X, irwls_fit)
print(f"beta = {beta}")
print(f"hat(beta) = {beta_hat}")

Epoch[0] = -1393.6072998046875
Epoch[1] = -1374.225341796875
Epoch[2] = -1374.12451171875
Epoch[3] = -1374.12451171875
beta = [ 0.85658115  0.36212853  1.1522139   0.15692838 -0.35338545]
hat(beta) = [ 0.8356752   0.3820638   1.1501597   0.15936355 -0.35638028]


In [9]:
# let's implement Poisson regression using _only_ gradient information to perform inference
# and measure how quickly it converges compared with Newton's method
def grad_fit(eta, y, X, step_size):
  beta = ...
  return beta
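One possible sketch of the gradient-only approach is shown below. It is not the lab's intended solution: it works with `beta` directly rather than through the `(eta, y, X, step_size)` signature of `grad_fit`, uses gradient _ascent_ on the mean log-likelihood with a fixed step size of 0.2 (an assumed value that is stable for this small simulated problem), and counts iterations until the coefficients stop changing. The helper names `gd_step`, `newton_step`, and `run` are hypothetical.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm

# small simulated problem (sizes, seed and step size are illustrative choices)
key = rdm.PRNGKey(42)
kx, kb, ky = rdm.split(key, 3)
N, P = 200, 3
X = rdm.normal(kx, (N, P))
beta_true = 0.3 * rdm.normal(kb, (P,))
y = rdm.poisson(ky, jnp.exp(X @ beta_true)).astype(jnp.float32)

def mean_loglik(beta):
    # average log-likelihood, dropping the beta-independent -log(y_i!) term
    eta = X @ beta
    return jnp.mean(y * eta - jnp.exp(eta))

@jax.jit
def gd_step(beta):
    # gradient *ascent*, since we maximize the log-likelihood
    return beta + 0.2 * jax.grad(mean_loglik)(beta)

@jax.jit
def newton_step(beta):
    lam = jnp.exp(X @ beta)
    H = X.T @ (lam[:, None] * X)
    return beta + jnp.linalg.solve(H, X.T @ (y - lam))

def run(step_fn, beta0, max_iter=5000, tol=1e-5):
    # iterate until the largest coordinate change drops below tol
    beta = beta0
    for it in range(1, max_iter + 1):
        beta_new = step_fn(beta)
        if jnp.max(jnp.abs(beta_new - beta)) < tol:
            return beta_new, it
        beta = beta_new
    return beta, max_iter

beta_gd, iters_gd = run(gd_step, jnp.zeros(P))
beta_nt, iters_nt = run(newton_step, jnp.zeros(P))
print(f"gradient ascent: {iters_gd} iterations, Newton: {iters_nt} iterations")
print(beta_gd, beta_nt)
```

Both routines should reach the same estimate, but the gradient-only version needs many more (cheaper) iterations, and its usable step size depends on the curvature of the log-likelihood.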

## Automatic differentiation