<a href="https://colab.research.google.com/github/anhquan-truong/PM520/blob/main/HW/PM520_HW1_AnhQuanTRUONG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Homework 1. Linear regression and normal equations

In [21]:
import jax
import jax.numpy as jnp
import jax.random as rdm
import jax.numpy.linalg as jnpla
import matplotlib.pyplot as plt


# 1. Linear model simulation
In class we defined a Python function that simulates $N$ $P\times 1$ variables $X$ (i.e. an $N \times P$ matrix $X$) and an outcome $y$ as a linear function of $X$. Please include its definition here and use it for problem 2.

Given an $N \times P$ matrix $X$, a $P \times 1$ vector $\beta$, an $N \times 1$ outcome vector $y$, and a random variable $\epsilon$ with $\mathbb{E}[\epsilon]=0$ and $Var(\epsilon) = \sigma^2$, we can describe $y$ as a linear function of $X$ as

$$
y = X \beta + \epsilon
$$

In [49]:
def sim_linear_reg(key, N, P, r2=0.5):
  key, b_key = rdm.split(key)
  b = rdm.normal(b_key, shape=(P,))

  key, X_key = rdm.split(key)
  X = rdm.normal(X_key, shape = (N, P))

  y_hat = X @ b # noise-free signal X b, before adding the error eps
  s2pred = jnp.var(y_hat)
  s2tot = s2pred / r2 # r2 = s2pred / s2tot, with s2tot = s2pred + s2e
  s2e = s2tot - s2pred

  key, e_key = rdm.split(key)
  eps = rdm.normal(e_key, shape = (N,)) * jnp.sqrt(s2e)
  y = y_hat + eps
  return y, y_hat, b, X, eps

seed = 912
key = rdm.PRNGKey(seed) # creating key from seed

N, P = 1000, 150 # a matrix 1000 x 150

y, y_hat, b, X, eps = sim_linear_reg(key, N, P, r2=0.5)

#Double check eps
print(f"Mean of epsilon: {jnp.mean(eps)}\n Var of epsilon: {jnp.var(eps)}")


Mean of epsilon: 0.9396470785140991
 Var of epsilon: 158.33473205566406
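
As a sanity check on the simulation, the noise variance is chosen as $\sigma^2_\epsilon = \sigma^2_{pred}(1 - r^2)/r^2$, so the realized share of variance explained by $X\beta$ should be close to `r2`. The sketch below repeats the same construction in a compact, self-contained form. `realized_r2` is an illustrative helper name, not part of the homework code, and it splits the key in one call, so its random draws differ from those of `sim_linear_reg`.

```python
import jax.numpy as jnp
import jax.random as rdm


def realized_r2(key, N, P, r2=0.5):
  # Same construction as sim_linear_reg: draw b and X, then scale the noise.
  b_key, X_key, e_key = rdm.split(key, 3)
  b = rdm.normal(b_key, shape=(P,))
  X = rdm.normal(X_key, shape=(N, P))
  y_hat = X @ b
  s2pred = jnp.var(y_hat)
  s2e = s2pred * (1.0 - r2) / r2  # algebraically equal to s2pred / r2 - s2pred
  eps = rdm.normal(e_key, shape=(N,)) * jnp.sqrt(s2e)
  y = y_hat + eps
  # Share of the sample variance of y that comes from the noise-free signal.
  return s2pred / jnp.var(y)


print(realized_r2(rdm.PRNGKey(0), 1000, 150, r2=0.5))
```

The result will not equal `r2` exactly, because the sample variance of `eps` and its sample covariance with `y_hat` fluctuate from draw to draw.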


# 2. Just-in-time decorator and ordinary least squares
Complete the definition of `ordinary_least_squares` below, that estimates the effect and its standard error. `@jit` wraps a function to perform just-in-time compilation, which boosts computational performance/speed.

Compare the run times with and without JIT.
Hint: use [`block_until_ready()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.block_until_ready.html) to get correct timing estimates.
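
The hint matters because JAX dispatches work asynchronously: a call can return before the result is computed, so a timer that does not wait would mostly measure dispatch overhead. A minimal sketch of timing outside of `%timeit`, assuming a small stand-in workload (`gram` and `time_call` are hypothetical names used only for illustration), with one warm-up call so that compilation time is not counted:

```python
import time

import jax
import jax.numpy as jnp
import jax.random as rdm


def gram(X):
  # Small stand-in workload (hypothetical); any array-valued function works.
  return X.T @ X


gram_jit = jax.jit(gram)
X = rdm.normal(rdm.PRNGKey(0), shape=(500, 50))


def time_call(f, *args, repeats=20):
  f(*args).block_until_ready()  # warm-up: triggers compilation for a jitted f
  start = time.perf_counter()
  for _ in range(repeats):
    f(*args).block_until_ready()  # wait for the asynchronous result
  return (time.perf_counter() - start) / repeats


print(f"no jit: {time_call(gram, X):.6f} s per call")
print(f"jit:    {time_call(gram_jit, X):.6f} s per call")
```

Which version is faster depends on the workload and hardware; for a single matrix product the difference may be small.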

In [50]:
import jax

from jax import jit


def ordinary_least_squares(X, y):
  N, P = X.shape
  XtX_inv = jnpla.inv(X.T @ X) # (X'X)^-1, reused for b_hat and its standard error
  b_hat = XtX_inv @ (X.T @ y)

  # unbiased estimate of the error variance from the residuals
  s2_e = jnp.sum(jnp.square(y - X @ b_hat)) / (N - P)
  # standard errors: sqrt of the diagonal of s2_e * (X'X)^-1
  se_b_hat = jnp.sqrt(s2_e * jnp.diag(XtX_inv))
  return b_hat, se_b_hat

jit_ordinary_least_squares = jit(ordinary_least_squares)

N, P = 1000, 150 # a matrix 1000 x 150
y, y_hat, b, X, eps = sim_linear_reg(key, N, P, r2=0.5)

b_hat, se_b_hat = jit_ordinary_least_squares(X,y)
print(b_hat)
print(se_b_hat)

%timeit ordinary_least_squares(X,y)[0].block_until_ready()
%timeit jit_ordinary_least_squares(X,y)[0].block_until_ready()

[ 4.50251251e-01 -1.52883574e-01  3.14414471e-01  6.80672228e-02
  1.28357494e+00  1.76152444e+00 -4.74031895e-01 -1.39597678e+00
 -7.98884481e-02 -8.96839857e-01  3.49870294e-01 -2.20914531e+00
  1.69369769e+00 -2.27455348e-01  1.13967729e+00 -1.47309911e+00
  1.51594985e+00 -7.81703591e-01  1.00653470e+00 -8.47649097e-01
  1.00606728e+00  7.41412103e-01 -2.59430408e-01 -1.71335351e+00
 -8.58722329e-01  5.26010871e-01  2.40069509e+00 -1.58846056e+00
  1.67933500e+00  1.24910533e+00 -5.66803336e-01  1.70093790e-01
  3.76371473e-01  1.71797872e+00  1.33770967e+00  5.92871234e-02
  3.70338678e-01  8.80754888e-01 -6.68964863e-01  1.21331072e+00
 -3.95349264e-01 -1.06813574e+00 -3.53555024e-01 -8.05687383e-02
  8.37355196e-01 -1.10281193e+00 -1.52880931e+00 -1.47320771e+00
 -3.56925368e-01 -1.51822805e+00 -8.36969793e-01 -1.57295883e+00
  1.04633021e+00  1.55464292e+00 -1.33921242e+00  1.22519922e+00
 -1.79591924e-01  4.88310516e-01  1.62782475e-01  5.51790930e-02
  1.22392941e+00  1.92954

# 3. OLS derivation
Assume that $y = X \beta + \epsilon$ where $y$ is $N \times 1$ vector, $X$ is an $N \times P$ matrix where $P < N$ and $\epsilon$ is a random variable such that $\mathbb{E}[\epsilon_i] = 0$ and $\mathbb{V}[\epsilon_i] = \sigma^2$ for all $i = 1 \dots n$. Derive the OLS "normal equations".

The goal is to find the $\beta$ that minimizes the residual sum of squares (RSS), i.e. the sum of the squared errors $\epsilon_i$. With $x_i^T$ denoting the $i$-th row of $X$, the $RSS(\beta)$ is

$$RSS(\beta)=\sum_{i=1}^n (y_i - x_i^T\beta)^2$$

We want to find

$$\beta^*=\arg\min_{\beta} \ RSS(\beta)$$

**Approach**: We find the stationary point, i.e. the point where the gradient is zero. Taking the derivative of $RSS(\beta)$ with respect to $\beta$ gives

$$
\begin{align*}
\frac{\partial RSS(\beta)}{\partial \beta} &= 2\sum_{i=1}^n (x_i^T\beta-y_i)x_i \\
&=2\sum_{i=1}^n \left[x_i (x_i^T \beta) - x_i y_i\right] \\
&=2\sum_{i=1}^n (x_i x_i^T) \beta - 2\sum_{i=1}^n (x_i y_i) \\
&=2[(X^T X) \beta - (X^T Y)] \\
\end{align*}
$$

From that we have

$$
\begin{align*}
\nabla RSS(\beta) &= 0 \iff 2[(X^T X) \beta - (X^T Y)]  = 0 \iff \beta = (X^T X)^{-1} (X^T Y)\\
\end{align*}
$$

assuming $(X^T X)$ is invertible.
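
The gradient formula and the resulting normal equations can be checked numerically. The sketch below uses small simulated data (the sizes, the noise scale 0.1, and the names `rss`, `b0`, and `beta_hat` are assumptions made for illustration). It compares JAX's automatic gradient of the RSS with the closed form $2[(X^T X)\beta - X^T Y]$, then confirms that the gradient is near zero at the solution of the normal equations.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm

k_X, k_b, k_e = rdm.split(rdm.PRNGKey(0), 3)
X = rdm.normal(k_X, shape=(200, 5))
beta = rdm.normal(k_b, shape=(5,))
Y = X @ beta + 0.1 * rdm.normal(k_e, shape=(200,))


def rss(b):
  return jnp.sum(jnp.square(Y - X @ b))


# Solve (X^T X) b = X^T Y directly instead of forming the inverse.
beta_hat = jnp.linalg.solve(X.T @ X, X.T @ Y)

# Autodiff gradient vs. the closed form, evaluated at an arbitrary point b0.
b0 = jnp.ones(5)
grad_auto = jax.grad(rss)(b0)
grad_formula = 2.0 * (X.T @ X @ b0 - X.T @ Y)

print(jnp.max(jnp.abs(grad_auto - grad_formula)))  # small, up to float32 error
print(jnp.max(jnp.abs(jax.grad(rss)(beta_hat))))   # near zero at the stationary point
```

Using `jnp.linalg.solve` here is a design choice for numerical stability; it solves the same normal equations as the explicit inverse.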

Now we calculate the variance of $\hat{\beta}$, from which its standard error follows.

From above, we have

$$
\begin{align*}
\hat{\beta} &= (X^T X)^{-1} (X^T Y) \\
&= (X^T X)^{-1} X^T (X \beta + \epsilon)\\
&= (X^T X)^{-1} X^T X \beta + (X^T X)^{-1} X^T \epsilon\\
&= I_P \beta + (X^T X)^{-1} X^T \epsilon\\
\implies \hat{\beta} - \beta &= (X^T X)^{-1} X^T \epsilon\\
\implies Var(\hat{\beta}) &= Var[(X^T X)^{-1} X^T \epsilon]\\
\end{align*}
$$

Applying the variance rule for a matrix-vector product, $Var(Ax) = A\,Var(x)\,A^T$ for a fixed matrix $A$, we have

$$
\begin{align*}
Var(\hat{\beta}) &= [(X^T X)^{-1} X^T]\ Var(\epsilon)\ [(X^T X)^{-1} X^T]^T\\
&= [(X^T X)^{-1} X^T]\  Var(\epsilon)\ X [(X^T X)^{-1}]^T\\
&= [(X^T X)^{-1} X^T]\ Var(\epsilon)\ X [(X^T X)^T]^{-1}\\
\implies Var(\hat{\beta}) &= [(X^T X)^{-1} X^T]\ Var(\epsilon)\ X (X^T X)^{-1}\\
\end{align*}
$$
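
The last expression still contains $Var(\epsilon)$. As a sketch of the remaining step, assume in addition that the errors are uncorrelated (an assumption beyond the stated $\mathbb{V}[\epsilon_i] = \sigma^2$), so that $Var(\epsilon) = \sigma^2 I_N$. Then

$$
\begin{align*}
Var(\hat{\beta}) &= (X^T X)^{-1} X^T (\sigma^2 I_N) X (X^T X)^{-1}\\
&= \sigma^2 (X^T X)^{-1} (X^T X) (X^T X)^{-1}\\
&= \sigma^2 (X^T X)^{-1}
\end{align*}
$$

Under that assumption the standard error of $\hat{\beta}_j$ is $\sqrt{\sigma^2 [(X^T X)^{-1}]_{jj}}$. Replacing $\sigma^2$ with the residual-based estimate $RSS(\hat{\beta})/(N - P)$ gives the quantity computed as `se_b_hat` in `ordinary_least_squares`.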
