<a href="https://colab.research.google.com/github/anhquan-truong/PM520/blob/main/HW/PM520_HW1_AnhQuanTRUONG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Homework 1. Linear regression and normal equations

In [2]:
import jax
import jax.numpy as jnp
import jax.random as rdm
import jax.numpy.linalg as jnpla
import matplotlib.pyplot as plt


# 1. Linear model simulation
In class we defined a Python function that simulates $N$ $P\times 1$ variables $X$ (i.e. an $N \times P$ matrix $X$) and outcome $y$ as a linear function of $X$. Please include its definition here and use it for problem 2.

Given an $N \times P$ matrix $X$, a $P \times 1$ coefficient vector $\beta$, an $N \times 1$ outcome vector $y$, and an $N \times 1$ random error vector $\epsilon$ with $\mathbb{E}[\epsilon_i]=0$ and $Var(\epsilon_i) = \sigma^2$, we can describe $y$ as a linear function of $X$:

$$
y = X\beta + \epsilon
$$

In [20]:
def sim_linear_reg(key, N, P, r2=0.5):
  # True effect sizes b ~ N(0, 1)
  key, b_key = rdm.split(key)
  b = rdm.normal(b_key, shape=(P,))

  # Design matrix X with i.i.d. N(0, 1) entries
  key, X_key = rdm.split(key)
  X = rdm.normal(X_key, shape=(N, P))

  y_hat = X @ b  # noiseless signal, i.e. y without the error eps
  s2pred = jnp.var(y_hat)  # variance explained by X
  s2tot = s2pred / r2  # total variance, s2tot = s2pred + s2e
  s2e = s2tot - s2pred  # noise variance, so var(y_hat) / var(y) is roughly r2

  # Gaussian noise with variance s2e
  key, e_key = rdm.split(key)
  eps = rdm.normal(e_key, shape=(N,)) * jnp.sqrt(s2e)
  y = y_hat + eps
  return y, y_hat, b, X, eps

seed = 91227102000
key = rdm.PRNGKey(seed) # creating key from seed

N, P = 1000, 150 # a matrix 1000 x 150

y, y_hat, b, X, eps = sim_linear_reg(key, N, P, r2=0.5)

# Sanity check: eps should have mean close to 0 and variance close to s2e
print(f"Mean of epsilon: {jnp.mean(eps)}\n Var of epsilon: {jnp.var(eps)}")


Mean of epsilon: -0.3946656584739685
 Var of epsilon: 124.35021209716797
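As a sanity check on how `r2` sets the noise level, a minimal sketch compares the realized share of variance explained by `X @ b` with the target `r2`. It re-declares `sim_linear_reg` so it runs on its own, and it assumes a hypothetical seed `0` and an illustrative helper `realized_r2` that are not part of the assignment:

```python
import jax.numpy as jnp
import jax.random as rdm


def sim_linear_reg(key, N, P, r2=0.5):
    # Same simulation as above, re-declared so this block is self-contained
    key, b_key = rdm.split(key)
    b = rdm.normal(b_key, shape=(P,))
    key, X_key = rdm.split(key)
    X = rdm.normal(X_key, shape=(N, P))
    y_hat = X @ b
    s2pred = jnp.var(y_hat)
    s2e = s2pred / r2 - s2pred  # noise variance implied by the target r2
    key, e_key = rdm.split(key)
    eps = rdm.normal(e_key, shape=(N,)) * jnp.sqrt(s2e)
    return y_hat + eps, y_hat, b, X, eps


def realized_r2(y, y_hat):
    # Hypothetical helper: share of var(y) explained by the noiseless signal
    return jnp.var(y_hat) / jnp.var(y)


key = rdm.PRNGKey(0)  # hypothetical seed, not the one used above
y, y_hat, b, X, eps = sim_linear_reg(key, 1000, 150, r2=0.5)
r = realized_r2(y, y_hat)
print(f"target r2 = 0.5, realized r2 = {float(r):.3f}")
```

The realized value will not equal `r2` exactly, because `eps` is a finite sample and its sample covariance with `y_hat` is not exactly zero.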


# 2. Just-in-time decorator and ordinary least squares
Complete the definition of `ordinary_least_squares` below, which estimates the effects and their standard errors. `@jit` wraps a function for just-in-time compilation with XLA: the first call compiles the function, and later calls with the same input shapes and dtypes reuse the compiled code, which can make them substantially faster.

Compare the run times with and without JIT.
Hint: use [`block_until_ready()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.block_until_ready.html) to get correct timing estimates.
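JAX dispatches work asynchronously, so a call can return before the computation has finished; without `block_until_ready()` a timer may only measure dispatch. A minimal sketch with `time.perf_counter` separates the first call, which includes tracing and compilation, from later cached calls. The function `gram` is a hypothetical example, not part of the assignment:

```python
import time

import jax
import jax.random as rdm


@jax.jit
def gram(X):
    # Hypothetical example function: X^T X, the matrix in the normal equations
    return X.T @ X


X = rdm.normal(rdm.PRNGKey(0), shape=(1000, 150))
X.block_until_ready()  # make sure generating X is not counted below

t0 = time.perf_counter()
gram(X).block_until_ready()  # first call: tracing + XLA compilation + execution
first_call = time.perf_counter() - t0

t0 = time.perf_counter()
gram(X).block_until_ready()  # same shape/dtype: reuses the compiled kernel
second_call = time.perf_counter() - t0

t0 = time.perf_counter()
out = gram(X)  # no blocking: may return before the result is computed
dispatch_only = time.perf_counter() - t0
out.block_until_ready()

print(f"first: {first_call:.2e} s, cached: {second_call:.2e} s, "
      f"dispatch only: {dispatch_only:.2e} s")
```

For `%timeit` comparisons, calling the jitted function once beforehand keeps compilation out of the measured loop.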

In [None]:
import jax

from jax import jit


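def ordinary_least_squares(X, y):
  r"""
  Computes the OLS solution to the linear system y ~ X.
  Returns a tuple of $\hat{\beta}$ and $\text{se}(\hat{\beta})$.
  """
  N, P = X.shape
  XtX_inv = jnpla.inv(X.T @ X)
  b_hat = XtX_inv @ (X.T @ y)  # solution of the normal equations
  resid = y - X @ b_hat
  s2_hat = resid @ resid / (N - P)  # unbiased estimate of the noise variance
  se_hat = jnp.sqrt(s2_hat * jnp.diag(XtX_inv))  # sqrt of diag of s2_hat (X^T X)^{-1}
  return b_hat, se_hat

jit_ordinary_least_squares = jit(ordinary_least_squares)

# The first call triggers compilation, so %timeit below measures cached runs
b_hat, se_hat = jit_ordinary_least_squares(X, y)
print(b_hat)
print(b)

# jax.block_until_ready also accepts tuples (pytrees) of arrays
%timeit jax.block_until_ready(ordinary_least_squares(X, y))
%timeit jax.block_until_ready(jit_ordinary_least_squares(X, y))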


[-0.66653067  1.2053611  -0.00617148  2.3890886   0.07242878  1.4051487
 -1.3037715  -0.69175434 -0.6884035   0.88761115]
[-1.4279191   0.85496247 -0.6770563   2.2756033  -0.29681712  0.44081363
 -1.2529446  -0.41230848 -0.6117764   1.1318142 ]
130 μs ± 43.7 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
50.7 μs ± 15.4 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
13.3 μs ± 2.74 μs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


# 3. OLS derivation
Assume that $y = X \beta + \epsilon$ where $y$ is an $N \times 1$ vector, $X$ is an $N \times P$ matrix with $P < N$, and $\epsilon$ is a random vector such that $\mathbb{E}[\epsilon_i] = 0$ and $\mathbb{V}[\epsilon_i] = \sigma^2$ for all $i = 1, \dots, N$. Derive the OLS "normal equations".

The goal is to find the $\beta$ that minimizes the residual sum of squares (RSS), i.e., the sum of the squared residuals $y_i - x_i^T\beta$, where $x_i^T$ is the $i$-th row of $X$. The $RSS(\beta)$ is

$$RSS(\beta)=\sum_{i=1}^N (y_i - x_i^T\beta)^2 = (y - X\beta)^T (y - X\beta)$$

We want to find

$$\hat{\beta} = \arg\min_{\beta} RSS(\beta)$$
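To make the two ways of writing RSS concrete, a small sketch checks on random data that the row-by-row sum equals the vectorized form $(y - X\beta)^T(y - X\beta)$. The helper names `rss_sum` and `rss_matrix` are illustrative, not from the assignment:

```python
import jax.numpy as jnp
import jax.random as rdm


def rss_sum(beta, X, y):
    # Literal translation of sum_i (y_i - x_i^T beta)^2, where x_i^T is row i of X
    return sum((y[i] - X[i] @ beta) ** 2 for i in range(X.shape[0]))


def rss_matrix(beta, X, y):
    # Vectorized form (y - X beta)^T (y - X beta)
    r = y - X @ beta
    return r @ r


kX, ky, kb = rdm.split(rdm.PRNGKey(0), 3)
X = rdm.normal(kX, (50, 3))
y = rdm.normal(ky, (50,))
beta = rdm.normal(kb, (3,))
print(rss_sum(beta, X, y), rss_matrix(beta, X, y))
```

The vectorized form is the one to use in practice; the Python loop is only there to mirror the summation.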

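**Approach**: We look for the stationary point, i.e., the $\beta$ at which the gradient of $RSS(\beta)$ is zero. Because $RSS(\beta)$ is a convex quadratic function of $\beta$, any stationary point is a global minimum. We take the derivative of $RSS(\beta)$ with respect to $\beta$: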

$$
\begin{align*}
\frac{\partial RSS(\beta)}{\partial \beta} &= 2\sum_{i=1}^N (x_i^T\beta-y_i)x_i \\
&=2\sum_{i=1}^N \left[(x_i^T \beta) x_i - x_i y_i\right] \\
&=2\sum_{i=1}^N \left[x_i (x_i^T \beta) - x_i y_i\right] \\
&=2\Big(\sum_{i=1}^N x_i x_i^T\Big) \beta - 2\sum_{i=1}^N x_i y_i \\
&=2[(X^T X) \beta - X^T y] \\
\end{align*}
$$
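The closed-form gradient can be checked against automatic differentiation. A minimal sketch on random data, with illustrative helpers `rss` and `rss_grad_analytic`, compares `jax.grad` of the RSS with $2[(X^T X)\beta - X^T y]$:

```python
import jax
import jax.numpy as jnp
import jax.random as rdm


def rss(beta, X, y):
    # RSS(beta) = (y - X beta)^T (y - X beta)
    r = y - X @ beta
    return r @ r


def rss_grad_analytic(beta, X, y):
    # Closed form derived above: 2[(X^T X) beta - X^T y]
    return 2.0 * (X.T @ X @ beta - X.T @ y)


kX, ky, kb = rdm.split(rdm.PRNGKey(1), 3)
X = rdm.normal(kX, (100, 4))
y = rdm.normal(ky, (100,))
beta = rdm.normal(kb, (4,))

g_auto = jax.grad(rss)(beta, X, y)  # differentiates w.r.t. the first argument
g_closed = rss_grad_analytic(beta, X, y)
print(g_auto)
print(g_closed)
```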

Setting the gradient to zero gives the normal equations
$$
\nabla RSS(\beta) = 0 \iff 2[(X^T X) \beta - X^T y] = 0 \iff (X^T X)\beta = X^T y.
$$

Assuming $X^T X$ is invertible (i.e., $X$ has full column rank $P$), the unique solution is
$$
\hat{\beta} = (X^T X)^{-1} X^T y.
$$
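In code, the normal equations are usually solved without forming the inverse explicitly. A sketch on random data (not the simulation above) compares the textbook inverse, `jnp.linalg.solve`, and the SVD-based `jnp.linalg.lstsq`, and checks that the gradient vanishes at the solution:

```python
import jax.numpy as jnp
import jax.numpy.linalg as jnpla
import jax.random as rdm

kX, ky = rdm.split(rdm.PRNGKey(2))
X = rdm.normal(kX, (200, 5))
y = rdm.normal(ky, (200,))

XtX, Xty = X.T @ X, X.T @ y
beta_inv = jnpla.inv(XtX) @ Xty      # textbook form (X^T X)^{-1} X^T y
beta_solve = jnpla.solve(XtX, Xty)   # solves (X^T X) beta = X^T y directly
beta_lstsq = jnpla.lstsq(X, y)[0]    # least squares on X itself, never forms X^T X

# At the solution the gradient 2[(X^T X) beta - X^T y] should be (numerically) zero
grad_at_solution = 2.0 * (XtX @ beta_solve - Xty)
print(beta_solve, jnp.max(jnp.abs(grad_at_solution)))
```

`solve` is generally cheaper and more accurate than `inv` followed by a product. `lstsq` works on $X$ directly and so avoids squaring its condition number, which matters when columns of $X$ are nearly collinear.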

Next, we derive the variance, and hence the standard error, of $\hat{\beta}$.

From above, we have
$$
\begin{align*}
\hat{\beta} &= (X^T X)^{-1} (X^T y) \\
&= (X^T X)^{-1} X^T (X \beta + \epsilon)\\
&= (X^T X)^{-1} X^T X \beta + (X^T X)^{-1} X^T \epsilon\\
&= I_P \beta + (X^T X)^{-1} X^T \epsilon\\
\implies \hat{\beta} - \beta &= (X^T X)^{-1} X^T \epsilon\\
\implies Var(\hat{\beta}) &= Var[(X^T X)^{-1} X^T \epsilon] \quad \text{(since } \beta \text{ is a constant)}\\
\end{align*}
$$

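Applying the variance rule for a linear transformation of a random vector, $Var(A\epsilon) = A\,Var(\epsilon)\,A^T$ for a constant matrix $A$ (here $A = (X^T X)^{-1} X^T$, treating $X$ as fixed), we have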
$$
\begin{align*}
Var(\hat{\beta}) &= [(X^T X)^{-1} X^T] Var(\epsilon) [(X^T X)^{-1} X^T]^T\\
&= [(X^T X)^{-1} X^T] Var(\epsilon) X [(X^T X)^{-1}]^T\\
\end{align*}
$$
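If we further assume the errors are uncorrelated with equal variance, $Var(\epsilon) = \sigma^2 I_N$, this reduces to $Var(\hat{\beta}) = \sigma^2 (X^T X)^{-1} X^T X (X^T X)^{-1} = \sigma^2 (X^T X)^{-1}$, using the symmetry of $(X^T X)^{-1}$. A Monte Carlo sketch holds $X$ fixed, redraws $\epsilon$ many times, and compares the empirical covariance of $\hat{\beta}$ with $\sigma^2 (X^T X)^{-1}$; the sizes, `sigma`, and the replicate count `R` are arbitrary choices for illustration:

```python
import jax
import jax.numpy as jnp
import jax.numpy.linalg as jnpla
import jax.random as rdm

N, P, sigma, R = 200, 3, 2.0, 5000  # R = number of simulated noise draws
kX, kb, ke = rdm.split(rdm.PRNGKey(3), 3)
X = rdm.normal(kX, (N, P))  # X is held fixed across replicates
beta = rdm.normal(kb, (P,))
XtX_inv = jnpla.inv(X.T @ X)


def beta_hat_for(key):
    # One replicate: fresh noise eps ~ N(0, sigma^2 I_N), then the OLS fit
    eps = sigma * rdm.normal(key, (N,))
    y = X @ beta + eps
    return XtX_inv @ (X.T @ y)


b_hats = jax.vmap(beta_hat_for)(rdm.split(ke, R))  # shape (R, P)
cov_mc = jnp.cov(b_hats, rowvar=False)              # empirical Var(beta_hat)
cov_theory = sigma**2 * XtX_inv                     # sigma^2 (X^T X)^{-1}
print(jnp.diag(cov_mc))
print(jnp.diag(cov_theory))
```

The standard error of $\hat{\beta}_j$ is then $\sqrt{\sigma^2 [(X^T X)^{-1}]_{jj}}$; in practice $\sigma^2$ is unknown and is replaced by an estimate such as $RSS(\hat{\beta})/(N-P)$.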
