<a href="https://colab.research.google.com/github/USCbiostats/PM520/blob/main/HW/PM520_HW2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Homework 2. Maximum likelihood & Optimization Crash Course

In [None]:
!pip install lineax

In [None]:
import jax
import jax.numpy as jnp
import jax.numpy.linalg as jnpla
import jax.scipy as jsp
import jax.scipy.linalg as jspla

## 1. Ordinary least squares (i.e. OLS)
OLS is an approach to fit a linear regression model $$y = X \beta + ɛ,$$
such that $\mathbb{E}[ɛ'ɛ]$ is minimized, where $\mathbb{E}[ɛ_i]=0$ and
$\mathbb{V}[ɛ_i] = \sigma^2$, for each $i=1,\dotsc,n$.
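
To make the model concrete, here is a minimal sketch that simulates data from $y = X \beta + ɛ$ and evaluates the sample sum of squared residuals at a candidate coefficient vector. The sizes `n` and `p`, the values of `beta` and `sigma`, the Gaussian noise, and the helper `sum_of_squares` are all assumptions made for illustration; they are not part of the assignment.

```python
import jax.numpy as jnp
import jax.random as rdm

# Hypothetical sizes and parameters, chosen only for illustration.
n, p = 200, 3
sigma = 0.5
beta = jnp.array([1.0, -2.0, 0.5])

key = rdm.PRNGKey(0)
x_key, eps_key = rdm.split(key)

X = rdm.normal(x_key, shape=(n, p))            # covariates
eps = sigma * rdm.normal(eps_key, shape=(n,))  # mean 0, variance sigma^2
y = X @ beta + eps                             # y = X beta + eps


def sum_of_squares(b):
  """Sample counterpart of eps' eps, evaluated at candidate coefficients b."""
  resid = y - X @ b
  return resid @ resid
```

Evaluating `sum_of_squares` at the true `beta` and at a poor guess such as `jnp.zeros(p)` gives a feel for the surface that the OLS estimate minimizes.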

1.1 Derive the OLS solution $\hat{\beta}$ under the above objective. Show step by step.

1.2 Re-write the objective using a likelihood formulation assuming $ɛ_i \sim N(0, \sigma^2)$, for each $i=1,\dotsc,n$.

1.3 Derive the OLS solution $\hat{\beta}_{MLE}$ using the above objective. Show step by step.

1.4 Using [lineax](https://docs.kidger.site/lineax/), implement a solver for OLS.

In [None]:
import lineax as lx

from jax import Array
from jax.typing import ArrayLike


def solve_ols(y: ArrayLike, X: ArrayLike) -> Array:
  r"""
  Solves ordinary least squares using lineax.

  y: ArrayLike of observations
  X: ArrayLike of covariates

  returns: $\hat{\beta}$ for OLS
  """
  pass

## 2. Weighted least squares (i.e. WLS)
WLS is an approach to fit a slightly more general linear model, $$y = X \beta + ɛ,$$ where $\mathbb{E}[ɛ_i] = 0$ and $\mathbb{V}[ɛ_i] = \sigma^2_i$ for each $i=1,\dotsc,n$. We can model all variances jointly as $\mathbb{V}[ɛ] = D$, where $D$ is a diagonal matrix such that $D_{ii} = \sigma^2_i$.
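
The heteroskedastic setup can be sketched by simulating one variance per observation and collecting them on the diagonal of $D$. The sizes, the coefficient values, the uniform range for the variances, and the Gaussian noise are assumptions made purely for illustration.

```python
import jax.numpy as jnp
import jax.random as rdm

# Hypothetical sizes and coefficients, chosen only for illustration.
n, p = 200, 2
beta = jnp.array([0.5, -1.0])

key = rdm.PRNGKey(1)
x_key, var_key, eps_key = rdm.split(key, 3)

X = rdm.normal(x_key, shape=(n, p))                                  # covariates
sigma_sq = rdm.uniform(var_key, shape=(n,), minval=0.1, maxval=2.0)  # sigma_i^2
D = jnp.diag(sigma_sq)                                               # D_ii = sigma_i^2

# Each error has mean 0 and its own variance sigma_i^2.
eps = jnp.sqrt(sigma_sq) * rdm.normal(eps_key, shape=(n,))
y = X @ beta + eps
```

Because $D$ is diagonal, storing only the length-$n$ vector `sigma_sq` carries the same information as the full $n \times n$ matrix and is cheaper for large $n$.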

2.1 Write the WLS objective.

2.2 Derive the WLS solution $\hat{\beta}$ under the above objective. Show step by step.

2.3 Re-write the objective using a likelihood formulation assuming $ɛ \sim N(0, D)$.

2.4 Derive the WLS solution $\hat{\beta}_{MLE}$ using the above objective. Show step by step.

2.5 Using [lineax](https://docs.kidger.site/lineax/), implement a solver for WLS.

In [None]:
import lineax as lx

from jax import Array
from jax.typing import ArrayLike


def solve_wls(y: ArrayLike, X: ArrayLike, D: ArrayLike) -> Array:
  r"""
  Solves weighted least squares using lineax.

  y: ArrayLike of observations
  X: ArrayLike of covariates
  D: ArrayLike of per-observation error variances (the entries
    $D_{ii} = \sigma^2_i$); the WLS weights are their reciprocals

  returns: $\hat{\beta}$ for WLS
  """
  pass

## 3. MLE for scalar Poisson observations
Given $x_1, \dotsc, x_n$, assume that $x_i \sim \text{Poi}(\lambda)$ for $i=1,\dotsc,n$ where $\text{Poi}(\lambda)$ is the Poisson distribution with rate $\lambda$.

3.1 Write a likelihood-based formulation of the objective.

3.2 Derive the MLE for the above objective. Show step by step.

3.3 Implement a function that simulates Poisson distributed data with rate $\lambda$ using JAX.

3.4 Implement a function that computes the MLE $\hat{\lambda}$ given observations $x_1, \dotsc, x_n$.
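
For orientation on the JAX random API, the sketch below shows explicit key handling together with `jax.random.poisson`, once with a shared scalar rate and once with one rate per observation. The rate values and sample size are made up for illustration, and this is only a sketch of the calling pattern rather than the requested function.

```python
import jax.numpy as jnp
import jax.random as rdm

key = rdm.PRNGKey(0)

# JAX random functions are pure: split the key before each independent draw.
key, sample_key = rdm.split(key)
draws = rdm.poisson(sample_key, 3.0, shape=(1000,))  # shared scalar rate

key, sample_key = rdm.split(key)
rates = jnp.array([0.5, 2.0, 10.0])                  # one rate per observation
per_obs = rdm.poisson(sample_key, rates)             # shape follows `rates`
```

Reusing the same key for two draws would return identical samples, which is why the key is split before each call.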

In [None]:
import lineax as lx
import jax.random as rdm

from jax import Array
from jax.typing import ArrayLike


def simulate_poisson(key, rate: ArrayLike, n: int) -> Array:
  r"""
  Simulates Poisson distributed data.

  key: PRNGKey used to generate the random samples
  rate: rate specifying the Poisson distribution; can be either a scalar, or
    ArrayLike (i.e. unique to each observation)
  n: the number of samples to generate

  returns: $x_i \sim \text{Poi}(\lambda_i)$
  """
  pass


def fit_poisson(x: ArrayLike) -> float:
  r"""
  Fits Poisson distributed data.

  x: ArrayLike observations

  returns: estimate of $\lambda$.
  """
  pass