# Homework 2. Maximum likelihood & Optimization Crash Course

In [1]:
!pip install lineax

In [2]:
import jax
import jax.numpy as jnp
import jax.numpy.linalg as jnpla
import jax.scipy as jsp
import jax.scipy.linalg as jspla

## 1. Ordinary least squares (OLS)
OLS is an approach to fit a linear regression model $$y = X \beta + \varepsilon,$$
such that the residual sum of squares $(y - X\beta)^\top (y - X\beta)$ is minimized, where $\mathbb{E}[\varepsilon_i]=0$ and
$\mathbb{V}[\varepsilon_i] = \sigma^2$, for each $i=1,\dotsc,n$.

1.1 Derive the OLS solution $\hat{\beta}$ under the above objective. Show step by step.

1.2 Re-write the objective using a likelihood formulation assuming $\varepsilon_i \sim N(0, \sigma^2)$, for each $i=1,\dotsc,n$.

1.3 Derive the OLS solution $\hat{\beta}_{MLE}$ using the above objective. Show step by step.

1.4 Using [lineax](https://docs.kidger.site/lineax/), implement a solver for OLS.

### 1.1
Define the least squares objective
$$
S(\beta) := \|y - X\beta\|_2^2 = (y - X\beta)^\top (y - X\beta).
$$

Expand, using that $y^\top X\beta$ is a scalar and therefore equals its transpose $\beta^\top X^\top y$:
$$
\begin{align*}
S(\beta)
&= (y - X\beta)^\top (y - X\beta) \\
&= y^\top y - 2\beta^\top X^\top y + \beta^\top X^\top X \beta.
\end{align*}
$$

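Differentiate w.r.t. $\beta$, using $\nabla_\beta(\beta^\top a) = a$ and $\nabla_\beta(\beta^\top A \beta) = 2A\beta$ for symmetric $A$ (here $A = X^\top X$), and set the gradient to zero: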
$$
\begin{align*}
\nabla_\beta S(\beta)
&= -2X^\top y + 2X^\top X\beta \\
&= 0
\quad\Longrightarrow\quad
X^\top X \hat\beta = X^\top y.
\end{align*}
$$
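The gradient formula above can be sanity-checked against automatic differentiation. This is a minimal sketch, assuming simulated data: the sizes, the seed, and the evaluation point `beta` are arbitrary illustrative choices, not part of the assignment.

```python
import jax
import jax.numpy as jnp

# Illustrative simulated data (arbitrary sizes and seed).
key_X, key_y, key_b = jax.random.split(jax.random.PRNGKey(0), 3)
n, p = 50, 3
X = jax.random.normal(key_X, (n, p))
y = jax.random.normal(key_y, (n,))
beta = jax.random.normal(key_b, (p,))  # arbitrary evaluation point


def S(beta):
    """Least squares objective ||y - X beta||_2^2."""
    r = y - X @ beta
    return r @ r


# Gradient from autodiff vs. the analytic gradient -2 X^T y + 2 X^T X beta.
grad_auto = jax.grad(S)(beta)
grad_analytic = -2.0 * X.T @ y + 2.0 * X.T @ X @ beta
print(jnp.allclose(grad_auto, grad_analytic, rtol=1e-4, atol=1e-3))
```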

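The Hessian $\nabla^2_\beta S(\beta) = 2X^\top X$ is positive semidefinite, so $S$ is convex and any solution of the normal equations $X^\top X \hat\beta = X^\top y$ is a minimizer. If $X$ has full column rank, $X^\top X$ is invertible and the minimizer is unique:

$$ \begin{align*}
\hat{\beta} = (X^\top X)^{-1} X^\top y
\end{align*}
$$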
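A quick numerical check of the closed form, as a sketch on simulated data (the sizes, seed, noise level, and `beta_true` are illustrative assumptions). It solves the normal equations with `jnp.linalg.solve` instead of forming $(X^\top X)^{-1}$ explicitly, and compares the result with JAX's built-in `jnp.linalg.lstsq`.

```python
import jax
import jax.numpy as jnp

# Simulate y = X beta_true + noise (illustrative values only).
key_X, key_e = jax.random.split(jax.random.PRNGKey(1))
n, p = 200, 3
X = jax.random.normal(key_X, (n, p))
beta_true = jnp.array([1.0, -2.0, 0.5])
y = X @ beta_true + 0.1 * jax.random.normal(key_e, (n,))

# Normal equations X^T X beta = X^T y, solved without an explicit inverse.
beta_normal = jnp.linalg.solve(X.T @ X, X.T @ y)

# Reference: JAX's least-squares routine (returns solution, residuals, rank, singular values).
beta_lstsq, *_ = jnp.linalg.lstsq(X, y)

print(beta_normal)
print(jnp.allclose(beta_normal, beta_lstsq, atol=1e-4))
```

Forming $X^\top X$ squares the condition number of $X$, so for ill-conditioned designs an orthogonal-factorization approach (QR or SVD) is usually the safer choice.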

### 1.2
Assuming the $\varepsilon_i \sim N(0, \sigma^2)$ are independent, the model $y = X\beta + \varepsilon$ implies $y \sim N(X\beta, \sigma^2 I)$. The log-likelihood function $\ell(\beta, \sigma^2)$ is:

$$\ell(\beta, \sigma^2) = -\frac{n}{2} \ln(2\pi\sigma^2) - \frac{1}{2\sigma^2} (y - X\beta)^\top (y - X\beta)$$
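As a sketch, the closed-form log-likelihood can be compared with the sum of per-observation normal log-densities from `jax.scipy.stats.norm.logpdf`. The data, `beta`, and `sigma2` values below are arbitrary illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Illustrative simulated data and parameter values.
key_X, key_y = jax.random.split(jax.random.PRNGKey(2))
n, p = 100, 2
X = jax.random.normal(key_X, (n, p))
y = jax.random.normal(key_y, (n,))
beta = jnp.array([0.3, -0.7])
sigma2 = 1.5


def loglik(beta, sigma2):
    """Closed-form Gaussian log-likelihood l(beta, sigma^2)."""
    r = y - X @ beta
    return -0.5 * n * jnp.log(2.0 * jnp.pi * sigma2) - (r @ r) / (2.0 * sigma2)


# Same quantity as a sum of independent N(x_i^T beta, sigma^2) log-densities.
ll_reference = jnp.sum(norm.logpdf(y, loc=X @ beta, scale=jnp.sqrt(sigma2)))

print(loglik(beta, sigma2), ll_reference)
```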

### 1.3 MLE Solution
To find the Maximum Likelihood Estimator $\hat{\beta}_{MLE}$, we maximize $\ell(\beta, \sigma^2)$ with respect to $\beta$.

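1. The first term of $\ell$ does not depend on $\beta$, and $\frac{1}{2\sigma^2} > 0$, so maximizing $\ell$ over $\beta$ is equivalent to minimizing the residual sum of squares:
   $$\hat{\beta}_{MLE} = \arg\min_\beta (y - X\beta)^\top (y - X\beta)$$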
2. Taking the derivative with respect to $\beta$ yields:
   $$\nabla_\beta \ell = \frac{1}{\sigma^2}(X^\top y - X^\top X\beta)$$
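3. Setting the gradient to zero gives the normal equations
   $$X^\top X \hat{\beta}_{MLE} = X^\top y.$$
   Assuming $X$ has full column rank, $X^\top X$ is invertible, so
   $$\hat{\beta}_{MLE} = (X^\top X)^{-1} X^\top y,$$
   which coincides with the OLS estimator $\hat{\beta}$ from 1.1.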
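To illustrate that maximizing $\ell$ recovers the OLS solution, the sketch below runs plain gradient ascent on $\ell$ with `jax.grad` and compares the result with the closed form. The simulated data, the fixed `sigma2`, the step size $1/L$ with $L = \lambda_{\max}(X^\top X)/\sigma^2$, and the iteration count are illustrative assumptions, not part of the assignment.

```python
import jax
import jax.numpy as jnp

# Illustrative simulated data.
key_X, key_e = jax.random.split(jax.random.PRNGKey(3))
n, p = 200, 3
X = jax.random.normal(key_X, (n, p))
y = X @ jnp.array([2.0, 0.0, -1.0]) + jax.random.normal(key_e, (n,))
sigma2 = 1.0  # the maximizer in beta does not depend on this value


def loglik(beta):
    r = y - X @ beta
    return -0.5 * n * jnp.log(2.0 * jnp.pi * sigma2) - (r @ r) / (2.0 * sigma2)


# Step size 1/L, where L = lambda_max(X^T X) / sigma2 bounds the curvature of -loglik.
step = sigma2 / jnp.linalg.eigvalsh(X.T @ X)[-1]
grad_ll = jax.jit(jax.grad(loglik))

beta = jnp.zeros(p)
for _ in range(500):
    beta = beta + step * grad_ll(beta)  # gradient ascent on the log-likelihood

beta_closed = jnp.linalg.solve(X.T @ X, X.T @ y)
print(beta, beta_closed)
```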

In [4]:
import lineax as lx

from jax import Array
from jax.typing import ArrayLike


def solve_ols(y: ArrayLike, X: ArrayLike) -> Array:
  r"""
  Solves ordinary least squares using lineax.

  y: ArrayLike of observations
  X: ArrayLike of covariates

  returns: $\hat{\beta}$ for OLS
  """

  Xop = lx.MatrixLinearOperator(jnp.asarray(X))

  # Minimize ||y - X beta||^2: lx.Normal applies the inner CG solver to the
  # normal equations X^T X beta = X^T y.
  sol = lx.linear_solve(
      Xop,
      jnp.asarray(y),
      solver=lx.Normal(lx.CG(atol=1e-6, rtol=1e-6))
  )
  return sol.value

## 2. Weighted least squares (WLS)
WLS is an approach to fit a slightly more general linear model, $$y = X \beta + \varepsilon,$$ where $\mathbb{E}[\varepsilon_i] = 0$ and $\mathbb{V}[\varepsilon_i] = \sigma^2_i$. We can model all variances jointly as $\mathbb{V}[\varepsilon] = D$, where $D$ is a diagonal matrix with $D_{ii} = \sigma^2_i$.

2.1 Write the WLS objective.

2.2 Derive the WLS solution $\hat{\beta}$ under the above objective. Show step by step.

2.3 Re-write the objective using a likelihood formulation assuming $\varepsilon \sim N(0, D)$.

2.4 Derive the WLS solution $\hat{\beta}_{MLE}$ using the above objective. Show step by step.

2.5 Using [lineax](https://docs.kidger.site/lineax/), implement a solver for WLS.

In [None]:
import lineax as lx

from jax import Array
from jax.typing import ArrayLike


def solve_wls(y: ArrayLike, X: ArrayLike, D: ArrayLike) -> Array:
  r"""
  Solves weighted least squares using lineax.

  y: ArrayLike of observations
  X: ArrayLike of covariates
  D: ArrayLike of per-observation variances sigma_i^2 (the diagonal of D)

  returns: $\hat{\beta}$ for WLS
  """
  # Whitening weights w_i = 1 / sigma_i, i.e. W = D^{-1/2}.
  w = 1.0 / jnp.sqrt(jnp.asarray(D))

  Xop = lx.MatrixLinearOperator(jnp.asarray(X))
  Wop = lx.DiagonalLinearOperator(w)

  # Whitened system: y* = W y, X* = W X. OLS on (X*, y*) minimizes
  # (y - X beta)^T D^{-1} (y - X beta).
  y_star = Wop.mv(jnp.asarray(y))
  X_star = Wop @ Xop

  sol = lx.linear_solve(
      X_star,
      y_star,
      solver=lx.Normal(lx.CG(atol=1e-6, rtol=1e-6))
  )
  return sol.value
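The whitening step in `solve_wls` rests on the identity $(Wy - WX\beta)^\top (Wy - WX\beta) = (y - X\beta)^\top D^{-1} (y - X\beta)$ for $W = D^{-1/2}$. The sketch below checks numerically that plain least squares on the rescaled data matches the closed-form weighted estimator $(X^\top D^{-1} X)^{-1} X^\top D^{-1} y$. It uses only `jax.numpy` (not lineax), and the simulated variances `d` and coefficients are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Illustrative simulated heteroscedastic data; d holds the diagonal of D.
key_X, key_d, key_e = jax.random.split(jax.random.PRNGKey(4), 3)
n, p = 100, 3
X = jax.random.normal(key_X, (n, p))
d = jax.random.uniform(key_d, (n,), minval=0.5, maxval=4.0)
y = X @ jnp.array([1.0, 2.0, 3.0]) + jnp.sqrt(d) * jax.random.normal(key_e, (n,))

# Closed form: (X^T D^{-1} X)^{-1} X^T D^{-1} y, applying D^{-1} row-wise.
Dinv_X = X / d[:, None]
beta_wls = jnp.linalg.solve(X.T @ Dinv_X, Dinv_X.T @ y)

# Whitening: scale row i by w_i = 1 / sqrt(d_i), then run ordinary least squares.
w = 1.0 / jnp.sqrt(d)
beta_white, *_ = jnp.linalg.lstsq(w[:, None] * X, w * y)

print(jnp.allclose(beta_wls, beta_white, atol=1e-4))
```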

## 3. MLE for scalar Poisson observations
Given $x_1, \dotsc, x_n$, assume that $x_i \sim \text{Poi}(\lambda)$ for $i=1,\dotsc,n$ where $\text{Poi}(\lambda)$ is the Poisson distribution with rate $\lambda$.

3.1 Write a likelihood-based formulation of the objective.

3.2 Derive the MLE for the above objective. Show step by step.

3.3 Implement a function that simulates Poisson distributed data with rate $\lambda$ using JAX.

3.4 Implement a function that computes the MLE $\hat{\lambda}$ given observations $x_1, \dotsc, x_n$.
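As a sketch of the likelihood view for 3.1 and 3.2, the Poisson log-likelihood $\ell(\lambda) = \sum_i \left[x_i \log\lambda - \lambda - \log(x_i!)\right]$ can be written directly in JAX. Its derivative $\sum_i x_i/\lambda - n$ vanishes at the sample mean, and its second derivative $-\sum_i x_i/\lambda^2$ is negative there. The simulated counts and the true rate of 3.0 are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln

# Simulated counts with an arbitrary true rate of 3.0 (illustration only).
x = jax.random.poisson(jax.random.PRNGKey(5), 3.0, shape=(500,)).astype(jnp.float32)


def loglik(lam):
    """Poisson log-likelihood: sum_i [x_i log(lam) - lam - log(x_i!)]."""
    return jnp.sum(x * jnp.log(lam) - lam - gammaln(x + 1.0))


lam_hat = jnp.mean(x)                            # candidate MLE: the sample mean
score = jax.grad(loglik)(lam_hat)                # sum(x)/lam - n, zero at lam_hat
curvature = jax.grad(jax.grad(loglik))(lam_hat)  # -sum(x)/lam^2 < 0, so a maximum
print(lam_hat, score, curvature)
```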

In [None]:
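import jax.numpy as jnp
import jax.random as rdm

from jax import Array
from jax.typing import ArrayLike


def simulate_poisson(key, rate: ArrayLike, n: int) -> Array:
  r"""
  Simulates Poisson distributed data.

  key: PRNGKey to generate
  rate: rate specifying the Poisson distribution; can be either a scalar, or
    ArrayLike (i.e. unique to each observation)
  n: the number of samples to generate

  returns: $x_i \sim \text{Poi}(\lambda_i)$
  """
  # jax.random.poisson broadcasts the rate against shape (n,), so a scalar
  # rate or a length-n array of per-observation rates both work.
  return rdm.poisson(key, jnp.asarray(rate, dtype=jnp.float32), shape=(n,))


def fit_poisson(x: ArrayLike) -> Array:
  r"""
  Fits Poisson distributed data.

  x: ArrayLike observations

  returns: estimate of $\lambda$ (the MLE, i.e. the sample mean).
  """
  # Setting d/dlam sum_i [x_i log(lam) - lam] = sum_i x_i / lam - n to zero
  # gives lam_hat = mean(x).
  return jnp.mean(jnp.asarray(x))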