<a href="https://colab.research.google.com/github/anhquan-truong/PM520/blob/main/HW/PM520_HW2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Homework 2. Maximum likelihood & Optimization Crash Course

In [None]:
!pip install lineax

In [None]:
import jax
import jax.numpy as jnp
import jax.numpy.linalg as jnpla
import jax.scipy as jsp
import jax.scipy.linalg as jspla

## 1. Ordinary least squares (i.e. OLS)
OLS is an approach to fit a linear regression model $$y = X \beta + ɛ,$$
such that $\mathbb{E}[ɛ'ɛ]$ is minimized, where $\mathbb{E}[ɛ_i]=0$ and
$\mathbb{V}[ɛ_i] = \sigma^2$, for each $i=1,\dotsc,n$.

**1.1 Derive the OLS solution $\hat{\beta}$ under the above objective. Show step by step.**

We minimize the sample analogue of $ɛ'ɛ$, the residual sum of squares (RSS):
$$
\begin{align*}
RSS(\hat{\beta}) &= (y - X \hat{\beta})^T (y - X \hat{\beta})\\
&= (y^T - \hat{\beta}^TX^T)(y - X \hat{\beta}) \\
&= y^Ty - y^T X \hat{\beta} - \hat{\beta}^TX^Ty +\hat{\beta}^TX^TX\hat{\beta}\\
&= y^Ty - 2 \hat{\beta}^TX^Ty + \hat{\beta}^TX^TX\hat{\beta}
\end{align*}
$$
The last step uses the fact that $y^TX\hat{\beta}$ is a scalar, so it equals its own transpose $\hat{\beta}^TX^Ty$.

To find the OLS solution, we find the $\hat{\beta}$ that minimizes the loss function, i.e., the RSS.

Since the RSS is a differentiable quadratic function of $\hat{\beta}$, this amounts to finding the stationary point at which its derivative with respect to $\hat{\beta}$ (the gradient) is zero.

We first express the RSS in terms of its coordinates, where $X$ is $n \times p$ and $X_{ij}$ is its $(i,j)$ entry:
$$RSS(\hat{\beta}) = \sum_{i=1}^n y_i^2 - 2\sum_{j=1}^p \hat{\beta}_j \sum_{i=1}^n X_{ij}y_i + \sum_{j=1}^p\sum_{l=1}^p \hat{\beta}_j\hat{\beta}_l \sum_{i=1}^n X_{ij}X_{il}$$

Then we take the derivative with respect to each $\hat{\beta}_k$:
$$
\begin{align*}
\frac{\partial RSS(\hat{\beta})} {\partial \hat{\beta}_k} &= 0-2\sum_{i=1}^n X_{ik} y_{i} + 2\sum_{l=1}^p \hat{\beta}_l \sum_{i=1}^n X_{ik}X_{il} \\
&= -2\sum_{i=1}^n X_{ik}\left(y_{i} - \sum_{l=1}^p X_{il}\hat{\beta}_l\right)
\end{align*}
$$

Stacking these components over $k$, the gradient with respect to $\hat{\beta}$ is
$$ \frac{\partial RSS(\hat{\beta})} {\partial \hat{\beta}} = \nabla_{\hat{\beta}}RSS(\hat{\beta}) = -2(X^Ty - X^TX\hat{\beta})$$

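Setting the gradient to zero gives the normal equations $X^TX\hat{\beta} = X^Ty$. Assuming $X$ has full column rank, $X^TX$ is invertible and

$$\hat{\beta}=(X^TX)^{-1}X^Ty$$

This stationary point is a minimum because the Hessian $2X^TX$ is positive definite. This is the OLS estimator.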
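As a numerical sanity check of this derivation, here is a minimal sketch on small simulated data (the sizes, seed, and noise level are illustrative assumptions, not part of the assignment). It solves the normal equations with `jnp.linalg.solve`, compares against `jnp.linalg.lstsq`, and uses `jax.grad` and `jax.hessian` to confirm that the gradient of the RSS vanishes at $\hat{\beta}$ and that the Hessian is $2X^TX$.

```python
import jax
import jax.numpy as jnp

# Small simulated design, used only for this illustration
key = jax.random.PRNGKey(0)
kx, kb, ke = jax.random.split(key, 3)
n, p = 200, 3
X = jax.random.normal(kx, (n, p))
beta_true = jax.random.normal(kb, (p,))
y = X @ beta_true + 0.1 * jax.random.normal(ke, (n,))

def rss(b):
  # RSS(b) = (y - X b)^T (y - X b)
  r = y - X @ b
  return r @ r

# Solve the normal equations X^T X b = X^T y without forming the inverse explicitly
beta_hat = jnp.linalg.solve(X.T @ X, X.T @ y)

# Reference solution from a generic least-squares routine
beta_lstsq, *_ = jnp.linalg.lstsq(X, y)

# First- and second-order conditions via autodiff
grad_at_hat = jax.grad(rss)(beta_hat)  # should be ~0 up to float32 round-off
hess = jax.hessian(rss)(beta_hat)      # should equal 2 X^T X

print("beta_hat: ", beta_hat)
print("lstsq:    ", beta_lstsq)
print("grad norm:", jnp.linalg.norm(grad_at_hat))
```

Solving the linear system rather than computing $(X^TX)^{-1}$ explicitly is the usual numerically safer choice; both give the same estimator.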

**1.2 Re-write the objective using a likelihood formulation assuming $ɛ_i \sim N(0, \sigma^2)$, for each $i=1,\dotsc,n$.**

For each observation $i$, we have
$$y_i = x^T_i \beta + ɛ_i$$

The expectation and variance of $y_i$ given $x_i$ are
$$
\begin{align*}
\mathbb{E}[y_i|x_i] & =\mathbb{E}[x^T_i \beta + ɛ_i]\\
&=\mathbb{E}[x^T_i \beta] + \mathbb{E}[ɛ_i]\\
&=x^T_i \beta + 0\\
&=x^T_i \beta\\
\end{align*}
$$

$$
\begin{align*}
\mathbb{V}[y_i|x_i] & =\mathbb{V}[x^T_i \beta + ɛ_i]\\
&=\mathbb{V}[x^T_i \beta] + \mathbb{V}[ɛ_i]\\
&=0 + \sigma^2\\
&=\sigma^2\\
\end{align*}
$$

Since $x_i^T\beta$ is a constant given $x_i$ and $ɛ_i$ is Gaussian, $y_i$ is a shifted Gaussian: $$y_i|x_i \sim \mathcal{N}(x_i^T\beta, \sigma^2)$$

Assuming the $n$ observations are independent, the likelihood of the data given $X$ is

$$
\begin{align*}
L(\beta, \sigma^2 |y_1,…,y_n) &=
  \prod_{i=1}^n \mathcal{N}(y_i | x_i^T\beta, \sigma^2) \\
  &= \prod_{i=1}^n \frac{1}{\sqrt{2\pi \sigma^2}} \exp\left(-\frac{(y_i - x_i^T\beta)^2}{2\sigma^2}\right)\\
  &= \left(\frac{1}{\sqrt{2\pi\sigma^2}}\right)^n \exp\left(-\frac{1}{2\sigma^2} \sum_{i=1}^n (y_i - x_i^T\beta)^2\right).
\end{align*}
$$

Thus, our _log_-likelihood is given by,
$$\begin{align*}
\ell(\beta, \sigma^2 |y_1,…,y_n) &=
  \log \left[\left(\frac{1}{\sqrt{2\pi\sigma^2}}\right)^n \exp\left(-\frac{1}{2\sigma^2} \sum_{i=1}^n (y_i - x_i^T\beta)^2\right)\right]\\
  &= -\frac{n}{2}\log(2\pi\sigma^2) - \frac{1}{2\sigma^2} \sum_{i=1}^n (y_i - x_i^T\beta)^2.
\end{align*}
$$
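The closed-form log-likelihood can be checked numerically against a sum of per-observation Gaussian log-densities. A minimal sketch, assuming `jax.scipy.stats.norm.logpdf` as the reference and small simulated data chosen only for illustration:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def gaussian_loglik(beta, sigma2, y, X):
  # l(beta, sigma^2) = -n/2 log(2 pi sigma^2) - RSS(beta) / (2 sigma^2)
  n = y.shape[0]
  r = y - X @ beta
  return -0.5 * n * jnp.log(2 * jnp.pi * sigma2) - (r @ r) / (2 * sigma2)

# Simulated data for the check
key = jax.random.PRNGKey(1)
kx, kb, ky = jax.random.split(key, 3)
X = jax.random.normal(kx, (50, 2))
beta = jax.random.normal(kb, (2,))
sigma2 = 0.5
y = X @ beta + jnp.sqrt(sigma2) * jax.random.normal(ky, (50,))

closed_form = gaussian_loglik(beta, sigma2, y, X)
# Reference: sum over observations of log N(y_i | x_i^T beta, sigma^2)
reference = jnp.sum(norm.logpdf(y, loc=X @ beta, scale=jnp.sqrt(sigma2)))
print(closed_form, reference)
```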

**1.3 Derive the OLS solution $\hat{\beta}_{MLE}$ using the above objective. Show step by step.**

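Since $\log$ is strictly increasing, maximizing the _log_-likelihood is equivalent to maximizing the likelihood. For fixed $\sigma^2$, only the second term of $\ell$ depends on $\beta$, so maximizing $\ell$ over $\beta$ is the same as minimizing the RSS.

First, we take the derivative of the _log_-likelihood with respect to $\beta$, using $\frac{\partial}{\partial \beta}(y_i - x_i^T\beta)^2 = -2x_i(y_i - x_i^T\beta)$:

$$
\frac{\partial}{\partial \beta}\ell(\beta, \sigma^2|y_1,…,y_n) =
  \frac{1}{2\sigma^2}\sum_{i=1}^n 2 x_i (y_i - x_i^T\beta) =\frac{1}{\sigma^2} (X^Ty - X^TX\beta)
$$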

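Setting the gradient to zero (for any $\sigma^2 > 0$) and assuming $X^TX$ is invertible gives

$$
\hat{\beta}_{MLE} = (X^TX)^{-1}X^Ty,
$$

which coincides with the OLS solution from 1.1.
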
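The sign of this gradient is easy to get wrong, so it can help to compare the analytic expression $\frac{1}{\sigma^2}(X^Ty - X^TX\beta)$ with `jax.grad` of the log-likelihood at an arbitrary $\beta$, and to confirm that the gradient vanishes at the closed-form $\hat{\beta}_{MLE}$. A minimal sketch; the simulated data and the value of $\sigma^2$ are illustrative assumptions:

```python
import jax
import jax.numpy as jnp

def loglik(beta, sigma2, y, X):
  # Gaussian log-likelihood from 1.2
  n = y.shape[0]
  r = y - X @ beta
  return -0.5 * n * jnp.log(2 * jnp.pi * sigma2) - (r @ r) / (2 * sigma2)

key = jax.random.PRNGKey(2)
kx, ky, kb = jax.random.split(key, 3)
X = jax.random.normal(kx, (100, 3))
y = X @ jnp.array([1.0, -2.0, 0.5]) + 0.3 * jax.random.normal(ky, (100,))
sigma2 = 0.09
beta0 = jax.random.normal(kb, (3,))  # arbitrary evaluation point

# jax.grad differentiates with respect to the first argument (beta)
grad_autodiff = jax.grad(loglik)(beta0, sigma2, y, X)
grad_analytic = (X.T @ y - X.T @ X @ beta0) / sigma2

# Closed-form MLE; the gradient should vanish here
beta_mle = jnp.linalg.solve(X.T @ X, X.T @ y)
grad_at_mle = jax.grad(loglik)(beta_mle, sigma2, y, X)
print(grad_autodiff, grad_analytic, grad_at_mle)
```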
**1.4 Using [lineax](https://docs.kidger.site/lineax/), implement a solver for OLS.**

In [None]:
import lineax as lx

from jax import Array
from jax.typing import ArrayLike
import jax.random as rdm
import matplotlib.pyplot as plt

seed = 27102000
key = rdm.PRNGKey(seed)

key, y_key, x_key, b_key = rdm.split(key, num = 4)
N = 1000
D = 5

X = rdm.normal(x_key, shape = (N,D))
beta = rdm.normal(b_key, shape = (D,))
s2e = 0.04
y = X @ beta + jnp.sqrt(s2e) * rdm.normal(y_key, shape = (N,))

def solve_ols(y: ArrayLike, X: ArrayLike) -> Array:
  r"""
  Solves ordinary least squares using lineax.

  y: ArrayLike of observations
  X: ArrayLike of covariates

  returns: $\hat{\beta}$ for OLS
  """
  # Wrap the (N, D) design matrix as a lineax linear operator
  X_op = lx.MatrixLinearOperator(X)
  # X is not square, so lx.Normal applies CG to the normal equations X^T X beta = X^T y
  solution = lx.linear_solve(X_op, y, solver=lx.Normal(lx.CG(atol=1e-6, rtol=1e-6)))
  beta_hat = solution.value
  return beta_hat

beta_hat = solve_ols(y, X)
print(f"beta hat: {beta_hat}")
print(f"True beta: {beta}")


beta hat: [-2.4730766   0.47487995  1.2987322   1.4924115   0.7521332 ]
True beta: [-2.4790854   0.46948895  1.2997216   1.5003587   0.7489997 ]


## 2. Weighted least squares (i.e. WLS)
WLS is an approach to fit a slightly more general linear model where, $$y = X \beta + ɛ,$$ where $\mathbb{E}[ɛ_i] = 0$ and $\mathbb{V}[ɛ_i] = \sigma^2_i$. We can model all variances jointly as $\mathbb{V}[ɛ] = D$ where $D$ is a diagonal matrix such that $D_{ii} = \sigma^2_i$.

**2.1 Write the WLS objective.**

In weighted least squares, each squared residual is weighted by the inverse of its variance, so noisier observations count less. Since $D$ is the diagonal matrix of variances, the weight matrix is $W = D^{-1}$, i.e., $W_{ii} = 1/\sigma^2_i$.

The objective of WLS is to minimize the weighted residual sum of squares (WRSS), given by

$$
\begin{align*}
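WRSS(\beta) &= (y-X\beta)^TW(y-X\beta) \\
&= (y^TW -\beta^TX^TW)(y-X\beta) \\
&= y^TWy -y^TWX\beta - \beta^TX^TWy + \beta^TX^TWX\beta \\
&= y^TWy -2\beta^TX^TWy + \beta^TX^TWX\beta
\end{align*}
$$

The last step uses that $y^TWX\beta$ is a scalar and $W$ is symmetric, so $y^TWX\beta = \beta^TX^TW^Ty = \beta^TX^TWy$.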

The WLS objective is

$$
\begin{align*}
\hat{\beta} = \arg\min_\beta\;WRSS(\beta)
\end{align*}
$$
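Because $W = D^{-1}$ is diagonal, the quadratic form $(y-X\beta)^TW(y-X\beta)$ reduces to $\sum_i (y_i - x_i^T\beta)^2/\sigma^2_i$, so in code the $n \times n$ matrix $W$ never has to be formed. A minimal sketch comparing both forms, with the variances $\sigma^2_i$ simulated here purely for illustration:

```python
import jax
import jax.numpy as jnp

def wrss_dense(beta, y, X, sigma2):
  # (y - X beta)^T W (y - X beta) with an explicit diagonal W = D^{-1}
  W = jnp.diag(1.0 / sigma2)
  r = y - X @ beta
  return r @ W @ r

def wrss_weighted_sum(beta, y, X, sigma2):
  # Same quantity as a weighted sum of squared residuals, O(n) memory
  r = y - X @ beta
  return jnp.sum(r**2 / sigma2)

key = jax.random.PRNGKey(3)
kx, ky, kb, ks = jax.random.split(key, 4)
X = jax.random.normal(kx, (40, 2))
y = jax.random.normal(ky, (40,))
beta = jax.random.normal(kb, (2,))
sigma2 = jax.random.uniform(ks, (40,), minval=0.5, maxval=2.0)  # per-observation variances

dense = wrss_dense(beta, y, X, sigma2)
weighted = wrss_weighted_sum(beta, y, X, sigma2)
print(dense, weighted)
```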


**2.2. Derive the WLS solution $\hat{\beta}$ under the above objective. Show step by step.**

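To find $\hat{\beta}$, we differentiate the WRSS with respect to $\beta$ and set the gradient to zero.

We use two standard matrix-calculus identities, for a vector variable $w$, a constant vector $a$, and a constant square matrix $A$:

$$
\begin{align*}
&\frac{\partial}{\partial w} (w^Ta) = a \quad \text{(derivative of a linear function is its slope)} \\
&\frac{\partial}{\partial w} (w^TAw) = (A + A^T)w \quad \text{(derivative of a quadratic form)} \\
\end{align*}
$$

Applying them with $a = X^TWy$ and $A = X^TWX$, which is symmetric so that $A + A^T = 2X^TWX$, gives
$$
\begin{align*}
\nabla_\beta\;WRSS(\beta) &= \frac{\partial}{\partial \beta} \left(y^TWy -2\beta^TX^TWy + \beta^TX^TWX\beta \right) \\
& = -2X^TWy +2 X^TWX\beta
\end{align*}
$$

Setting the gradient to zero gives the weighted normal equations $X^TWX\beta = X^TWy$. Assuming $X^TWX$ is invertible (true when $X$ has full column rank and every $\sigma^2_i > 0$),

$$\hat{\beta} = (X^TWX)^{-1}X^TWy$$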
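One way to check the WLS formula is to note that scaling each row of $X$ and $y$ by $1/\sigma_i$ (i.e., multiplying by $W^{1/2}$) turns WLS into an ordinary least-squares problem. The sketch below, on simulated heteroscedastic data chosen only for illustration, compares $(X^TWX)^{-1}X^TWy$ with `jnp.linalg.lstsq` on the rescaled data:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(4)
kx, ks, ke = jax.random.split(key, 3)
n = 300
X = jax.random.normal(kx, (n, 3))
beta_true = jnp.array([0.5, -1.0, 2.0])
sigma2 = jax.random.uniform(ks, (n,), minval=0.1, maxval=3.0)  # heteroscedastic variances
y = X @ beta_true + jnp.sqrt(sigma2) * jax.random.normal(ke, (n,))

w = 1.0 / sigma2  # diagonal of W = D^{-1}

# Closed form (X^T W X)^{-1} X^T W y, applying W by row scaling instead of a dense matrix
beta_wls = jnp.linalg.solve(X.T @ (w[:, None] * X), X.T @ (w * y))

# Equivalent OLS on whitened data: rows scaled by sqrt(w_i) = 1 / sigma_i
sqrt_w = jnp.sqrt(w)
beta_whitened, *_ = jnp.linalg.lstsq(sqrt_w[:, None] * X, sqrt_w * y)
print(beta_wls, beta_whitened)
```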


**2.3. Re-write the objective using a likelihood formulation assuming $ɛ \sim N(0, D)$.**

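Assuming the observations are independent, with $y_i|x_i \sim \mathcal{N}(x_i^T\beta, \sigma^2_i)$ and $\sigma^2_i = D_{ii}$, the likelihood is

$$
\begin{align*}
L(\beta, D|y_1,…,y_n) &=
  \prod_{i=1}^n \mathcal{N}(y_i | x_i^T\beta, \sigma^2_i) \\
  &= \prod_{i=1}^n \frac{1}{\sqrt{2\pi \sigma^2_i}} \exp\left(-\frac{(y_i - x_i^T\beta)^2}{2\sigma^2_i}\right)\\
  &= \left(\prod_{i=1}^n \frac{1}{\sqrt{2\pi\sigma^2_i}}\right) \exp\left(-\frac{1}{2} \sum_{i=1}^n \frac{(y_i - x_i^T\beta)^2}{\sigma^2_i}\right)\\
  &= (2\pi)^{-n/2}\,|D|^{-1/2} \exp\left(-\frac{1}{2} (y - X\beta)^TD^{-1}(y - X\beta)\right).
\end{align*}
$$

Since $\sigma^2_i$ varies with $i$, it stays inside the product and the sum. Taking logs gives
$$
\ell(\beta, D|y_1,…,y_n) = -\frac{n}{2}\log(2\pi) - \frac{1}{2}\sum_{i=1}^n \log D_{ii} - \frac{1}{2}(y - X\beta)^TD^{-1}(y - X\beta),
$$
so for fixed $D$, maximizing $\ell$ over $\beta$ is equivalent to minimizing the WRSS with $W = D^{-1}$.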
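A minimal numerical check of the heteroscedastic log-likelihood, assuming simulated data and using `jax.scipy.stats` densities as references; `d` below is an illustrative name for the diagonal of $D$:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm, multivariate_normal

def wls_loglik(beta, d, y, X):
  # d holds the diagonal of D, i.e. the per-observation variances sigma_i^2
  n = y.shape[0]
  r = y - X @ beta
  # sum(r**2 / d) equals the quadratic form r^T D^{-1} r for diagonal D
  return (-0.5 * n * jnp.log(2 * jnp.pi)
          - 0.5 * jnp.sum(jnp.log(d))
          - 0.5 * jnp.sum(r**2 / d))

key = jax.random.PRNGKey(5)
kx, kb, ks, ke = jax.random.split(key, 4)
n = 30
X = jax.random.normal(kx, (n, 2))
beta = jax.random.normal(kb, (2,))
d = jax.random.uniform(ks, (n,), minval=0.2, maxval=2.0)
y = X @ beta + jnp.sqrt(d) * jax.random.normal(ke, (n,))

closed_form = wls_loglik(beta, d, y, X)
# Reference 1: independent univariate normals, each with its own variance
ref_univariate = jnp.sum(norm.logpdf(y, loc=X @ beta, scale=jnp.sqrt(d)))
# Reference 2: one multivariate normal with diagonal covariance D
ref_multivariate = multivariate_normal.logpdf(y, X @ beta, jnp.diag(d))
print(closed_form, ref_univariate, ref_multivariate)
```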

2.4 Derive the WLS solution $\hat{\beta}_{MLE}$ using the above objective. Show step by step.

2.5 Using [lineax](https://docs.kidger.site/lineax/), implement a solver for WLS.

In [None]:
import lineax as lx

from jax import Array
from jax.typing import ArrayLike


def solve_wls(y: ArrayLike, X: ArrayLike, D: ArrayLike) -> Array:
  """
  Solves weighted least squares using lineax.

  y: ArrayLike of observations
  X: ArrayLike of covariates
  D: ArrayLike of weights per observation

  returns: $\hat{\beta}$ for WLS
  """
  pass

## 3. MLE for scalar Poisson observations
Given $x_1, \dotsc, x_n$, assume that $x_i \sim \text{Poi}(\lambda)$ for $i=1,\dotsc,n$ where $\text{Poi}(\lambda)$ is the Poisson distribution with rate $\lambda$.

3.1 Write a likelihood-based formulation of the objective.

3.2 Derive the MLE for the above objective. Show step by step.

3.3 Implement a function that simulates Poisson distributed data with rate $\lambda$ using JAX.

3.4 Implement a function that computes the MLE $\hat{\lambda}$ given observations $x_1, \dotsc, x_n$.

In [None]:
import lineax as lx
import jax.random as rdm

from jax import Array
from jax.typing import ArrayLike


def simulate_poisson(key, rate: ArrayLike, n: int) -> Array:
  """
  Simulates Poisson distributed data.

  key: PRNGKey to generate
  rate: rate specifying the Poisson distribution; can be either a scalar, or
    ArrayLike (i.e. unique to each observation)
  n: the number of samples to generate

  returns: $x_i \sim \text{Poi}(\lambda_i)$
  """
  pass


def fit_poisson(x: ArrayLike) -> float:
  """
  Fits Poisson distributed data.

  x: ArrayLike observations

  returns: estimate of $\lambda$.
  """
  pass
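One possible sketch of the two stubs, assuming `jax.random.poisson` for sampling and the standard result that the Poisson MLE is the sample mean: with $\ell(\lambda) = \sum_i x_i \log\lambda - n\lambda - \sum_i \log x_i!$, setting $\ell'(\lambda) = \sum_i x_i/\lambda - n = 0$ gives $\hat{\lambda} = \bar{x}$. The `poisson_loglik` helper is a hypothetical addition used only to check with `jax.grad` that the derivative vanishes at $\hat{\lambda}$. Note that this `fit_poisson` returns a 0-d JAX array rather than a Python `float`, which keeps it usable under `jax.jit`.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm
from jax import Array
from jax.scipy.stats import poisson
from jax.typing import ArrayLike


def simulate_poisson(key, rate: ArrayLike, n: int) -> Array:
  r"""Sketch: draw n samples x_i ~ Poi(lambda_i); rate may be a scalar or a length-n array."""
  # Broadcast a scalar rate to one rate per observation (errors if shapes are incompatible)
  rate = jnp.broadcast_to(jnp.asarray(rate, dtype=jnp.float32), (n,))
  return rdm.poisson(key, rate, shape=(n,))


def fit_poisson(x: ArrayLike) -> Array:
  r"""Sketch: the Poisson MLE is the sample mean, returned as a 0-d JAX array."""
  return jnp.mean(jnp.asarray(x, dtype=jnp.float32))


def poisson_loglik(lam, x):
  # Hypothetical helper: sum_i log Poi(x_i | lam)
  return jnp.sum(poisson.logpmf(x, lam))


key = rdm.PRNGKey(0)
x = simulate_poisson(key, 3.0, 10_000)
lam_hat = fit_poisson(x)
# The derivative of the log-likelihood at the MLE should be ~0
score_at_mle = jax.grad(poisson_loglik)(lam_hat, x)
print(lam_hat, score_at_mle)
```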