<a href="https://colab.research.google.com/github/USCbiostats/PM520/blob/main/Lab5_Optimization_PtII.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
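!pip install lineax

import jax
import jax.numpy as jnp
import jax.random as rdm
import jax.scipy.stats as stats
import lineax as lx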

## Newton's Method for Optimization
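Newton's method maximizes a twice-differentiable function $f$ by iterating $x(t+1) = x(t) - [\nabla^2 f(x(t))]^{-1} \nabla f(x(t))$. Below is a minimal sketch in JAX. It assumes the gradient and Hessian come from `jax.grad` and `jax.hessian`. `newton_maximize` and the concave test function $f(x) = \sum_j (a_j x_j - e^{x_j})$, whose maximizer is $\log a$, are hypothetical illustrations rather than part of the lab:

```python
import jax
import jax.numpy as jnp


def newton_maximize(f, x0, max_iter=50, tol=1e-6):
    """Hypothetical helper: maximize a smooth, strictly concave f via Newton's method."""
    grad_f = jax.grad(f)
    hess_f = jax.hessian(f)
    x = x0
    for _ in range(max_iter):
        g = grad_f(x)
        H = hess_f(x)
        # Newton step: solve H @ step = g, then x_new = x - step
        step = jnp.linalg.solve(H, g)
        x = x - step
        # stop once the largest coordinate update is tiny
        if jnp.max(jnp.abs(step)) < tol:
            break
    return x


# hypothetical concave objective f(x) = sum(a * x - exp(x)); its maximizer is log(a)
a = jnp.array([0.5, 2.0, 5.0])
f = lambda x: jnp.sum(a * x - jnp.exp(x))

x_hat = newton_maximize(f, jnp.zeros(3))
print(x_hat, jnp.log(a))
```

Solving the linear system with `jnp.linalg.solve` avoids forming $[\nabla^2 f]^{-1}$ explicitly, which is cheaper and numerically safer.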

## Quasi-Newton Methods for Optimization
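Quasi-Newton methods such as BFGS replace the exact Hessian with an approximation built from successive gradient evaluations. Here is a hedged sketch that uses JAX's `jax.scipy.optimize.minimize` with `method="BFGS"` to fit the Poisson regression model $y_i | x_i \sim \text{Poi}(\exp(x_i^T \beta))$ by minimizing the average negative log-likelihood. The simulated data, sizes, and seed are arbitrary assumptions:

```python
import jax.numpy as jnp
import jax.random as rdm
from jax.scipy.optimize import minimize
from jax.scipy.stats import poisson

# simulate a small Poisson regression problem (sizes and seed are arbitrary)
key_x, key_b, key_y = rdm.split(rdm.PRNGKey(1), 3)
N, P = 200, 3
X = rdm.normal(key_x, shape=(N, P))
beta_true = 0.5 * rdm.normal(key_b, shape=(P,))
y = rdm.poisson(key_y, jnp.exp(X @ beta_true))


def neg_loglike(beta, y, X):
    # average negative Poisson log-likelihood; BFGS minimizes, so we negate
    return -jnp.mean(poisson.logpmf(y, jnp.exp(X @ beta)))


# BFGS builds an inverse-Hessian approximation from successive gradients,
# so the matrix X^T Lambda X is never formed explicitly
result = minimize(neg_loglike, jnp.zeros(P), args=(y, X), method="BFGS")
beta_bfgs = result.x

# at the MLE the (averaged) score X^T (y - lambda) / N should be close to zero
score = X.T @ (y - jnp.exp(X @ beta_bfgs)) / N
print(result.success, beta_bfgs, beta_true)
```

The price for skipping the Hessian is typically more iterations than Newton's method, so it is worth checking `result.success` and the score before trusting the estimate.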

## Poisson Regression

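$$y_i | x_i \sim \text{Poi}(\lambda_i)$$ where $\lambda_i := \exp(x_i^T \beta)$. We can fit $\beta$ by maximum likelihood using iteratively reweighted least squares (IRWLS), which for this model (canonical log link) coincides with Newton's method applied to the log-likelihood.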

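$$\ell(\beta) := \log L(\beta) = \sum_i \log \text{Poi}(y_i \mid \exp(x_i^T \beta)) = \sum_i \left[ y_i x_i^T \beta - \exp(x_i^T \beta) - \log y_i! \right]$$
$$ \nabla_\beta \ell = X^T(y - \lambda)$$
$$ \nabla^2_{\beta \beta} \ell = -X^T \Lambda X $$
where $\lambda := (\lambda_1, \dotsc, \lambda_n)^T$ and $\Lambda := \text{diag}(\lambda_1, \dotsc, \lambda_n)$.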
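
The closed-form gradient and Hessian can be checked against JAX's automatic differentiation. This is a minimal sketch on a small synthetic problem; the sizes, seed, and the name `loglike` are arbitrary choices for illustration:

```python
import jax
import jax.numpy as jnp
import jax.random as rdm
from jax.scipy.stats import poisson

# small synthetic problem (sizes and seed are arbitrary)
kx, kb, ky = rdm.split(rdm.PRNGKey(0), 3)
X = rdm.normal(kx, shape=(20, 3))
beta = 0.3 * rdm.normal(kb, shape=(3,))
y = rdm.poisson(ky, jnp.exp(X @ beta))


def loglike(beta):
    # sum_i log Poi(y_i | exp(x_i^T beta))
    return jnp.sum(poisson.logpmf(y, jnp.exp(X @ beta)))


lam = jnp.exp(X @ beta)
grad_closed = X.T @ (y - lam)        # X^T (y - lambda)
hess_closed = -(X.T * lam) @ X       # -X^T Lambda X, without forming diag(lambda)

grad_auto = jax.grad(loglike)(beta)
hess_auto = jax.hessian(loglike)(beta)
print(jnp.allclose(grad_closed, grad_auto, rtol=1e-4, atol=1e-3),
      jnp.allclose(hess_closed, hess_auto, rtol=1e-4, atol=1e-3))
```

Multiplying `X.T` by `lam` via broadcasting scales each column of $X^T$ by $\lambda_i$, which gives $X^T \Lambda$ in $O(nP)$ work instead of building an $n \times n$ diagonal matrix.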

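Newton's method maximizes $\ell$ via $\beta(t+1) = \beta(t) - [\nabla^2_{\beta\beta} \ell]^{-1} \nabla_\beta \ell$, with both derivatives evaluated at $\beta(t)$. Substituting the gradient and Hessian above gives the IRWLS update
$$ \beta(t+1) = \beta(t) + (X^T \Lambda(t) X)^{-1} X^T (y - \lambda(t)) $$
$$ = (X^T \Lambda(t) X)^{-1} X^T \Lambda(t) \left( X\beta(t) + \Lambda(t)^{-1}(y - \lambda(t)) \right) $$
$$ = (X^T \Lambda(t) X)^{-1} X^T \Lambda(t) (\Lambda(t)^{-1}y + X\beta(t) - \mathbf{1})$$
where $\lambda_i(t) := \exp(x_i^T \beta(t))$, $\Lambda(t) := \text{diag}(\lambda_1(t), \dotsc, \lambda_n(t))$, and $\mathbf{1}$ is the vector of ones. Each step is therefore a weighted least-squares regression of the working response $z(t) := \Lambda(t)^{-1}y + X\beta(t) - \mathbf{1}$ on $X$ with weights $\Lambda(t)$, which is where the name "iteratively reweighted least squares" comes from.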
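
The Newton form and the weighted least-squares form of the update should agree numerically. The quick check below also confirms that rescaling by $\Lambda^{1/2}$ (the trick `irwls_fit` uses) turns the weighted problem into ordinary least squares. It is a sketch on simulated data; the seed, sizes, and the choice of current iterate `beta_t` are arbitrary assumptions:

```python
import jax.numpy as jnp
import jax.random as rdm

# simulated data (sizes and seed are arbitrary)
kx, kb, ky = rdm.split(rdm.PRNGKey(2), 3)
X = rdm.normal(kx, shape=(50, 3))
beta_true = 0.4 * rdm.normal(kb, shape=(3,))
y = rdm.poisson(ky, jnp.exp(X @ beta_true))

beta_t = 0.5 * beta_true            # an arbitrary current iterate beta(t)
lam = jnp.exp(X @ beta_t)           # lambda(t)
XtLX = (X.T * lam) @ X              # X^T Lambda(t) X

# form 1: Newton step beta(t) + (X^T Lambda X)^{-1} X^T (y - lambda)
newton = beta_t + jnp.linalg.solve(XtLX, X.T @ (y - lam))

# form 2: weighted least squares on the working response z = Lambda^{-1} y + X beta(t) - 1
z = y / lam + X @ beta_t - 1.0
wls = jnp.linalg.solve(XtLX, X.T @ (lam * z))

# form 3: ordinary least squares on Lambda^{1/2} X and Lambda^{1/2} z
sw = jnp.sqrt(lam)
ols, *_ = jnp.linalg.lstsq(X * sw[:, None], z * sw)

print(newton, wls, ols)
```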

In [None]:
@jax.jit
def loglikelihood_eta(eta, y, X):
  """
  Our loglikelihood function for $y_i | x_i ~ \text{Poi}(\exp(eta_i))$.

  eta: X @ beta; the linear component for each observation
  y: poisson-distributed observations
  X: our design matrix

  returns: sum of the logliklihoods of each sample
  """
  mean_lambda = jnp.exp(eta)
  return jnp.sum(stats.poisson.logpmf(y, mean_lambda))


@jax.jit
def irwls_fit(eta, y, X):
  """
  Perform MLE estimation for $\beta$ under the model
     $y_i | x_i ~ \text{Poi}(\exp(x_i^T \beta))$.

  eta: X @ beta; the linear component for each observation
  y: poisson-distributed observations
  X: our design matrix

  returns: updated estimate of $\beta$
  """
  # compute lambda_i := exp(x_i @ beta)
  d_i = jnp.exp(eta)
  d_sqrt = jnp.sqrt(d_i)

  # compute the scaled working response z_i := Lambda^{1/2}(Lambda^{-1} y + X @ beta - 1)
  z = (y / d_i + eta - 1) * d_sqrt

  # X* := Lambda^{1/2} X
  # composing linear operators postpones any matrix computation
  X_star = lx.DiagonalLinearOperator(d_sqrt) @ X

  # NormalCG solves the least-squares problem min ||X* @ beta - z||^2 by running
  # conjugate gradients on the normal equations t(X*) @ X* @ beta = t(X*) @ z
  solution = lx.linear_solve(X_star, z, solver=lx.NormalCG(atol=1e-4, rtol=1e-3))
  beta = solution.value
  return beta


def poiss_reg_irwls(y, X, max_iter=100, tol=1e-3):
  """
  Perform MLE estimation for $\beta$ under the model
     $y_i | x_i ~ \text{Poi}(\exp(x_i^T \beta))$.

  y: poisson-distributed observations
  X: our design matrix
  max_iter: the maximum number of iterations to perform optimization
  tol:

  returns: updated estimate of $\beta$
  """
  # intialize eta := X @ beta
  eta = jnp.log((y + jnp.mean(y))/2)

  # fake bookkeeping
  loglike = -100000
  delta = 10000

  # convert to a linear operator for lineax
  X = lx.MatrixLinearOperator(X)
  for epoch in range(max_iter):

    # fit using irwls
    beta = irwls_fit(eta, y, X)

    # update eta
    eta = X.mv(beta)

    # evaluate log likelihood
    newll = loglikelihood_eta(eta, y, X)

    # take delta and check if we can stop
    delta = jnp.fabs(newll - loglike)
    print(f"Epoch[{epoch}] = {newll}")
    if delta < tol:
      break

    # replace old value
    loglike = newll

  return beta
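
For small $P$, forming the $P \times P$ matrix $X^T \Lambda X$ and calling `jnp.linalg.solve` is a simple alternative to the matrix-free `NormalCG` solve. The sketch below keeps the same data-based initialization and log-likelihood stopping rule. `irwls_step_dense` and `poiss_reg_irwls_dense` are hypothetical names, and the simulated data in the usage lines are arbitrary:

```python
import jax
import jax.numpy as jnp
import jax.random as rdm
from jax.scipy.stats import poisson


@jax.jit
def irwls_step_dense(eta, y, X):
    # one IRWLS/Newton update using an explicit P x P solve
    lam = jnp.exp(eta)
    z = y / lam + eta - 1.0                 # working response
    XtLX = (X.T * lam) @ X                  # X^T Lambda X
    return jnp.linalg.solve(XtLX, X.T @ (lam * z))


def poiss_reg_irwls_dense(y, X, max_iter=100, tol=1e-3):
    eta = jnp.log((y + jnp.mean(y)) / 2)    # same data-based start as poiss_reg_irwls
    loglike = -jnp.inf
    for _ in range(max_iter):
        beta = irwls_step_dense(eta, y, X)
        eta = X @ beta
        newll = jnp.sum(poisson.logpmf(y, jnp.exp(eta)))
        # stop once the log-likelihood has stabilized
        if jnp.abs(newll - loglike) < tol:
            break
        loglike = newll
    return beta


# usage on simulated data (sizes and seed are arbitrary)
kx, kb, ky = rdm.split(rdm.PRNGKey(0), 3)
X = rdm.normal(kx, shape=(100, 5))
beta_true = 0.5 * rdm.normal(kb, shape=(5,))
y = rdm.poisson(ky, jnp.exp(X @ beta_true))
beta_hat = poiss_reg_irwls_dense(y, X)
score = X.T @ (y - jnp.exp(X @ beta_hat))   # should be near zero at the MLE
print(beta_true, beta_hat)
```

The dense solve costs $O(nP^2 + P^3)$ per step and is exact up to rounding. The lineax `NormalCG` route only needs matrix-vector products, which pays off when $P$ is large or $X$ is only available as an operator.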

In [None]:
# Let's simulate a Poisson regression model with N samples and P variables
N = 100
P = 5

# create a PRNG key (seed 0 is arbitrary) and split it for each random call
key = rdm.PRNGKey(0)
key, y_key, x_key, b_key = rdm.split(key, 4)
X = rdm.normal(x_key, shape=(N, P))
beta = rdm.normal(b_key, shape=(P,))

# compute lambda_i
mean_lambda = jnp.exp(X @ beta)

# sample y from Poi(lambda_i)
y = rdm.poisson(y_key, mean_lambda)

# estimate beta using our irwls function
beta_hat = poiss_reg_irwls(y, X)
print(f"beta = {beta}")
print(f"hat(beta) = {beta_hat}")

## Automatic differentiation