
# Homework 1. Linear regression and normal equations

```python
import jax
import jax.numpy as jnp
import jax.random as rdm
import jax.numpy.linalg as jnpla
```


# 1. Linear model simulation
In class we defined a Python function that simulates $N$ observations of $P$ variables (i.e. an $N \times P$ matrix $X$) and an outcome $y$ as a linear function of $X$. Please include its definition here and use it for problem 2.

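Given an $N \times P$ matrix $X$ with full column rank ($P < N$) and a $P \times 1$ coefficient vector $a$, the outcome $y$ is a linear function of $X$ plus noise: $$y = Xa + \epsilon.$$
Because $X$ is not square, it has no ordinary inverse. The least-squares estimate $\hat{a}$ instead solves the normal equations $X^TX\hat{a} = X^Ty$, which gives $$\hat{a} = (X^TX)^{-1}X^Ty = X^{+}y,$$ where $X^{+}$ is the Moore-Penrose pseudo-inverse of $X$.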
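For an $N \times P$ matrix $X$ with full column rank, the Moore-Penrose pseudo-inverse satisfies $X^{+} = (X^TX)^{-1}X^T$, and it is only a *left* inverse: $X^{+}X = I_P$, while $XX^{+}$ is the $N \times N$ projection onto the column space of $X$. The sketch below checks these facts numerically, assuming a random Gaussian $X$ (full column rank with probability one); the names `N_demo`, `P_demo`, and `X_demo` are illustrative only.

```python
import jax.numpy as jnp
import jax.numpy.linalg as jnpla
import jax.random as rdm

N_demo, P_demo = 50, 5  # hypothetical sizes with P < N
X_demo = rdm.normal(rdm.PRNGKey(0), shape=(N_demo, P_demo))

X_pinv = jnpla.pinv(X_demo)                          # P x N pseudo-inverse (computed via SVD)
X_normal = jnpla.inv(X_demo.T @ X_demo) @ X_demo.T   # (X^T X)^{-1} X^T

# For full-column-rank X the two constructions agree
print(jnp.allclose(X_pinv, X_normal, atol=1e-4))

# X^+ is a left inverse of X ...
print(jnp.allclose(X_pinv @ X_demo, jnp.eye(P_demo), atol=1e-4))

# ... but H = X X^+ is an idempotent N x N projection (rank P), not I_N
H = X_demo @ X_pinv
print(jnp.allclose(H @ H, H, atol=1e-4))
print(jnp.allclose(H, jnp.eye(N_demo), atol=1e-4))
```

This is why the identity check in the cell below compares $X^{+}X$ (a $P \times P$ matrix) with $I_P$, rather than $XX^{+}$ with $I_N$.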

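```python
def sim_linear_reg(key, N, P, r2=0.5):
  """Simulate y = X @ a + e, with X an N x P matrix and a a length-P effect vector.

  Assumption: r2 is the fraction of Var(y) explained by X @ a, so the noise
  variance is scaled to make Var(X @ a) / Var(y) approximately r2.
  """
  x_key, a_key, e_key = rdm.split(key, 3)
  X = rdm.normal(x_key, shape=(N, P))    # N x P design matrix
  a = rdm.normal(a_key, shape=(P,))      # P x 1 true effects
  g = X @ a                              # noise-free linear predictor
  s2e = (1.0 - r2) / r2 * jnp.var(g)     # noise variance so that R^2 is about r2
  y = g + jnp.sqrt(s2e) * rdm.normal(e_key, shape=(N,))
  return a, X, y

seed = 912
key = rdm.PRNGKey(seed)  # create a PRNG key from the seed
print(key)

N = 100  # number of observations (rows of X)
P = 10   # number of features (columns of X)

a, X, y = sim_linear_reg(key, N, P, r2=0.5)
print(f"true a: {a}")

# Algebraic approach: X is N x P (not square), so use the Moore-Penrose pseudo-inverse
Xpinv = jnpla.pinv(X)  # P x N
is_same = jnp.allclose(jnp.eye(P), Xpinv @ X, atol=1e-5)  # left inverse: X^+ X = I_P
print(f"Identity check? {is_same}")

a_hat_direct = Xpinv @ y
print(f"a_hat_direct: {a_hat_direct}")

# Solver approach: jnpla.solve needs a square system, so solve the normal equations
a_hat = jnpla.solve(X.T @ X, X.T @ y)
print(f"a_hat: {a_hat}")
print(f"Same estimate? {jnp.allclose(a_hat_direct, a_hat, atol=1e-4)}")
```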


# 2. Just-in-time decorator and ordinary least squares
Complete the definition of `ordinary_least_squares` below, which estimates the effects and their standard errors. `@jit` wraps a function so that JAX compiles it just in time with XLA; after the first (compiling) call, later calls with inputs of the same shapes and dtypes typically run faster.

Compare the run times with and without JIT.
Hint: use [`block_until_ready()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.block_until_ready.html) to get correct timing estimates.
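The hint matters because JAX dispatches work asynchronously: a call can return before the computation has finished, so a timer has to wait on the result with `block_until_ready()`. The first call to a jitted function also includes compilation, so it should be excluded as a warm-up. A minimal timing sketch under those assumptions follows; `normal_eq_solve` and `time_call` are hypothetical helpers, and the problem size and repetition count are arbitrary choices.

```python
import time

import jax.numpy as jnp
import jax.random as rdm
from jax import jit


def normal_eq_solve(X, y):
  # Illustrative workload: solve the normal equations X^T X b = X^T y
  return jnp.linalg.solve(X.T @ X, X.T @ y)


jit_normal_eq_solve = jit(normal_eq_solve)


def time_call(fn, *args, n_reps=20):
  """Average wall-clock seconds per call, waiting for each result to finish."""
  fn(*args).block_until_ready()  # warm-up call: triggers compilation for jitted fns
  start = time.perf_counter()
  for _ in range(n_reps):
    fn(*args).block_until_ready()  # without this we would mostly time async dispatch
  return (time.perf_counter() - start) / n_reps


x_key, y_key = rdm.split(rdm.PRNGKey(0))
X_t = rdm.normal(x_key, shape=(1000, 50))
y_t = rdm.normal(y_key, shape=(1000,))

t_plain = time_call(normal_eq_solve, X_t, y_t)
t_jit = time_call(jit_normal_eq_solve, X_t, y_t)
print(f"without jit: {t_plain * 1e3:.3f} ms, with jit: {t_jit * 1e3:.3f} ms")
```

The measured speed-up depends on hardware and problem size; for small workloads like this one, JIT mainly removes per-operation dispatch overhead, so the gap may be modest.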

```python
from jax import jit


def ordinary_least_squares(X, y):
  r"""
  Computes the OLS solution to the linear system y ~ X.
  Returns a tuple of $\hat{\beta}$ and $\text{se}(\hat{\beta})$.
  """
  pass

jit_ordinary_least_squares = jit(ordinary_least_squares)
```
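One possible way to fill in the stub, sketched under the usual homoskedastic assumptions of problem 3: solve the normal equations for $\hat{\beta}$, estimate $\hat{\sigma}^2 = \lVert y - X\hat{\beta}\rVert^2/(N-P)$, and take $\text{se}(\hat{\beta}_j) = \sqrt{\hat{\sigma}^2\,[(X^TX)^{-1}]_{jj}}$. This is not presented as the course's reference solution; the name `ols_sketch` is hypothetical, and it swaps the explicit inverse for a Cholesky factorization of $X^TX$ via `jax.scipy.linalg.cho_factor` / `cho_solve`, which assumes $X$ has full column rank.

```python
import jax.numpy as jnp
import jax.random as rdm
import jax.scipy.linalg as jspla
from jax import jit


def ols_sketch(X, y):
  """Return (beta_hat, se(beta_hat)) for y = X beta + eps, assuming Var(eps_i) = sigma^2."""
  N, P = X.shape                                 # static shapes, so this is fine under jit
  factor = jspla.cho_factor(X.T @ X)             # Cholesky of X^T X (positive definite if full column rank)
  beta_hat = jspla.cho_solve(factor, X.T @ y)    # solves X^T X beta = X^T y
  resid = y - X @ beta_hat
  sigma2_hat = resid @ resid / (N - P)           # unbiased estimate of sigma^2
  XtX_inv = jspla.cho_solve(factor, jnp.eye(P))  # (X^T X)^{-1}
  se = jnp.sqrt(sigma2_hat * jnp.diag(XtX_inv))
  return beta_hat, se


jit_ols_sketch = jit(ols_sketch)

# Usage on simulated data (hypothetical sizes and true effects)
x_key, e_key = rdm.split(rdm.PRNGKey(1))
X_s = rdm.normal(x_key, shape=(200, 3))
beta_true = jnp.array([1.0, -0.5, 0.0])
y_s = X_s @ beta_true + rdm.normal(e_key, shape=(200,))

beta_hat, se = jit_ols_sketch(X_s, y_s)
print(beta_hat, se)
```

The jitted and plain versions can then be timed against each other with `block_until_ready()`, as the hint suggests.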

# 3. OLS derivation
Assume that $y = X \beta + \epsilon$, where $y$ is an $N \times 1$ vector, $X$ is an $N \times P$ matrix with $P < N$, and $\epsilon$ is a random vector such that $\mathbb{E}[\epsilon_i] = 0$ and $\mathbb{V}[\epsilon_i] = \sigma^2$ for all $i = 1, \dots, N$. Derive the OLS "normal equations".
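A sketch of the standard least-squares derivation, which minimizes the residual sum of squares $S(\beta)$. The moment assumptions on $\epsilon$ are not needed to obtain the normal equations themselves; they matter later for properties of $\hat{\beta}$ such as unbiasedness and its variance.

$$
\begin{aligned}
S(\beta) &= (y - X\beta)^T(y - X\beta) = y^Ty - 2\beta^TX^Ty + \beta^TX^TX\beta,\\
\frac{\partial S}{\partial \beta} &= -2X^Ty + 2X^TX\beta = 0
\;\implies\; X^TX\hat{\beta} = X^Ty \quad \text{(normal equations)},\\
\hat{\beta} &= (X^TX)^{-1}X^Ty \quad \text{when } X \text{ has full column rank } P.
\end{aligned}
$$

The two cross terms combine because $y^TX\beta$ is a scalar and therefore equals $\beta^TX^Ty$. Since $\partial^2 S / \partial\beta\,\partial\beta^T = 2X^TX$ is positive definite when $X$ has full column rank, this stationary point is the unique minimizer.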