<a href="https://colab.research.google.com/github/USCbiostats/PM520/blob/main/Lab_2_LinearSolve.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import jax
import jax.numpy as jnp
import jax.numpy.linalg as jnpla
import jax.random as rdm
import jax.scipy.linalg as jspla

## Sum(mer) Madness, or Solving Linear Equations
Or how I learned not to take the inverse and perform [matrix-vector products](https://en.wikipedia.org/wiki/Operator_(mathematics)).

Given an $n \times n$ non-singular (i.e. "nice") matrix $A$ and an $n \times 1$ column vector $b$, we can describe a [linear system of equations](https://en.wikipedia.org/wiki/System_of_linear_equations#General_form) relating $A$ and $b$ as $$Ax=b.$$
Algebraically, we can solve for $x$ as $x = A^{-1}b$, but how should we go about this [_numerically_](https://en.wikipedia.org/wiki/Numerical_linear_algebra)?

In [6]:
def sim_linear_system(key, n: int):
  key, x_key = rdm.split(key)
  A = rdm.normal(key, shape=(n,n))
  x = rdm.normal(x_key, shape=(n,))
  b = A @ x
  return A, x, b

seed = 0
N = 100

key = rdm.PRNGKey(seed)
key, sim_key = rdm.split(key)
A, x, b = sim_linear_system(sim_key, N)

# solve using algebraic approach
Ainv = jnpla.inv(A)
x_hat_direct = Ainv @ b

# solve using 'blackbox(!?)' solver
x_hat = jnpla.solve(A, b)

# measure distance from truth using 2norm
direct_dist = jnpla.norm(x - x_hat_direct)
solve_dist = jnpla.norm(x - x_hat)

print(f"direct dist = {direct_dist} | solve_dist = {solve_dist}")

direct dist = 0.0003477725840639323 | solve_dist = 0.0002090586203848943


We see less [numerical error](https://en.wikipedia.org/wiki/Numerical_analysis#Numerical_stability_and_well-posed_problems) in the solution using the 'solve' approach compared to our attempt at using the 'algebraic' solution. What happens as `N` increases?

In [3]:
for N in [100, 1000, 5000]:

  key, sim_key = rdm.split(key)
  A, x, b = sim_linear_system(sim_key, N)

  # solve using algebraic approach
  Ainv = jnpla.inv(A)
  x_hat_direct = Ainv @ b

  # solve using 'blackbox(!?)' solver
  x_hat = jnpla.solve(A, b)

  # measure distance from truth using 2norm
  direct_dist = jnpla.norm(x - x_hat_direct)
  solve_dist = jnpla.norm(x - x_hat)

  print(f"N = {N} | direct dist = {direct_dist} | solve_dist = {solve_dist}")

N = 100 | direct dist = 0.0002750760759226978 | solve_dist = 4.570868986775167e-05
N = 1000 | direct dist = 0.02133580483496189 | solve_dist = 0.003495120210573077
N = 5000 | direct dist = 0.3149307668209076 | solve_dist = 0.04819103330373764


### How is `jnpla.solve` _solving_ the system?

Numerical solvers typically perform some _decomposition_ of $A$ into a product of structured matrices (e.g., orthogonal, lower/upper triangular, permutations, etc.) which permit direct solutions through forward/backward solvers. A couple of prominent examples are given below.

[QR decomposition](https://en.wikipedia.org/wiki/QR_decomposition#Using_for_solution_to_linear_inverse_problems)


$A = QR$ where $Q$ is an $n \times n$ [orthonormal matrix](https://en.wikipedia.org/wiki/Orthogonal_matrix) (i.e. $Q Q' = Q'Q = I$, thus $Q^{-1} = Q'$) and $R$ is an $n \times n$ upper [triangular matrix](https://en.wikipedia.org/wiki/Triangular_matrix) (i.e. $R_{ij} = 0$ for $i > j$).

Solving $Ax = b$ using QR amounts to noticing,
$$QRx = b \Rightarrow Rx = Q^{-1}b = Q'b,$$ which can then be solved using a backward solve. See [jax.numpy.linalg.qr](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.qr.html) and [jax.scipy.linalg.solve_triangular](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.solve_triangular.html) for reference.
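To make the "backward solve" concrete, here is a minimal sketch of back substitution written as an explicit loop. The helper name `back_substitution` and the small shifted test matrix are assumptions made for illustration; in practice `jspla.solve_triangular` does this job.

```python
import jax.numpy as jnp
import jax.numpy.linalg as jnpla
import jax.random as rdm

def back_substitution(R, c):
    """Solve R x = c for upper-triangular R, working from the last row up."""
    n = R.shape[0]
    x = jnp.zeros_like(c)
    for i in range(n - 1, -1, -1):
        # x[i+1:] is already solved; subtract its contribution from c[i]
        known = jnp.sum(R[i, i + 1:] * x[i + 1:])
        x = x.at[i].set((c[i] - known) / R[i, i])
    return x

key = rdm.PRNGKey(0)  # arbitrary seed
a_key, x_key = rdm.split(key)
n = 5
# shift the diagonal so this toy system is well conditioned (illustrative choice)
A = rdm.normal(a_key, shape=(n, n)) + 2 * n * jnp.eye(n)
x_true = rdm.normal(x_key, shape=(n,))
b = A @ x_true

Q, R = jnpla.qr(A)
x_qr = back_substitution(R, Q.T @ b)  # solve R x = Q'b
print(jnpla.norm(x_true - x_qr))
```

Each row costs one dot product with the already-solved entries, so the triangular solve is $O(n^2)$; the $O(n^3)$ work sits in the factorization itself.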

[LU decomposition](https://en.wikipedia.org/wiki/LU_decomposition#Solving_linear_equations)

$A = PLU$ where $P$ is an $n \times n$ [permutation matrix](https://en.wikipedia.org/wiki/Permutation_matrix), $L$ is an $n \times n$ lower triangular matrix and $U$ is an $n \times n$ upper triangular matrix. Solving $Ax = b$ using LU amounts to $$PLUx = b \Rightarrow LUx = P^{-1}b = P'b,$$ which can be solved with one forward solve $Ly = P'b$ and then a backward solve $Ux = y$. See [jax.scipy.linalg.lu_factor](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.lu_factor.html) and [jax.scipy.linalg.lu_solve](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.lu_solve.html) for reference.
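The two triangular solves can also be spelled out by hand. The sketch below uses `jspla.lu`, which returns the explicit factors with $A = PLU$, rather than the compact `lu_factor`/`lu_solve` pair; the small shifted test matrix and the seed are illustrative assumptions.

```python
import jax.numpy as jnp
import jax.random as rdm
import jax.scipy.linalg as jspla

key = rdm.PRNGKey(1)  # arbitrary seed
a_key, x_key = rdm.split(key)
n = 5
# shifted diagonal keeps this toy system well conditioned (illustrative choice)
A = rdm.normal(a_key, shape=(n, n)) + 2 * n * jnp.eye(n)
x_true = rdm.normal(x_key, shape=(n,))
b = A @ x_true

# explicit factors: permutation P, unit lower-triangular L, upper-triangular U
P, L, U = jspla.lu(A)

y = jspla.solve_triangular(L, P.T @ b, lower=True)  # forward solve:  L y = P'b
x_lu = jspla.solve_triangular(U, y, lower=False)    # backward solve: U x = y
print(jnp.linalg.norm(x_true - x_lu))
```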

Let's try to assess which is used by `jnpla.solve` by comparing the error of the solutions (not ideal, but not the worst idea possible; see [jax.numpy.linalg.solve](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.solve.html) for reference).

In [None]:
N = 25
key, sim_key = rdm.split(key)
A, x, b = sim_linear_system(sim_key, N)

# solve using algebraic approach
Ainv = jnpla.inv(A)
x_direct = Ainv @ b

# QR
Q, R = jnpla.qr(A)
x_qr = jspla.solve_triangular(R, Q.T @ b)

# LU (lu_factor returns the packed L/U factors and the pivot indices, not a permutation matrix)
LU, piv = jspla.lu_factor(A)
x_lu = jspla.lu_solve((LU, piv), b)

# direct solve for baseline
x_solve = jnpla.solve(A, b)

direct_dist = jnpla.norm(x - x_direct)
qr_dist = jnpla.norm(x - x_qr)
lu_dist = jnpla.norm(x - x_lu)
solve_dist = jnpla.norm(x - x_solve)


print(f"direct dist = {direct_dist} | qr dist = {qr_dist} | lu dist = {lu_dist} | solve_dist = {solve_dist}")

direct dist = 2.836631210811902e-05 | qr dist = 2.3936559955473058e-05 | lu dist = 4.185290890745819e-06 | solve_dist = 4.185290890745819e-06


### How much time does it take to solve a system of linear equations?
Complexity is $O(n^3)$ when $A$ is $n \times n$ (under exact numerical precision, i.e. ignoring bit complexity of numerical results), but can we do better in terms of constants? Pls no [mention](https://en.wikipedia.org/wiki/Computational_complexity_of_matrix_multiplication#Matrix_multiplication_exponent) of $O(n^\omega)$.

What if we know that our matrix $A$ has additional structure from the start?

[Cholesky decomposition](https://en.wikipedia.org/wiki/Cholesky_decomposition)

When $A$ is an $n \times n$ symmetric [_positive definite_](https://en.wikipedia.org/wiki/Definite_matrix) matrix (i.e. $A = A'$ and $x'Ax > 0$ for all nonzero $x \in R^n$), the Cholesky decomposition gives $A = LL'$ where $L$ is an $n \times n$ lower triangular matrix.

Solving $Ax=b$ using Cholesky decomposition amounts to $$Ax=b \Rightarrow LL'x = b,$$
which can be solved with one forward solve $Ly=b$ and then a backward solve $L'x=y$. The factorization costs roughly half as many floating-point operations as LU, so this will be roughly twice as fast as LU-based approaches. See [jax.scipy.linalg.cho_factor](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.cho_factor.html) and [jax.scipy.linalg.cho_solve](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.cho_solve.html) for reference.
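As a small illustration of the two-step solve, the sketch below factors a toy positive definite matrix with `jnpla.cholesky` and applies the forward and backward triangular solves explicitly. The matrix construction and seed are assumptions made for the example. The usual leading-order operation counts are about $n^3/3$ for Cholesky versus $2n^3/3$ for LU, which is where the factor of two comes from.

```python
import jax.numpy as jnp
import jax.numpy.linalg as jnpla
import jax.random as rdm
import jax.scipy.linalg as jspla

key = rdm.PRNGKey(2)  # arbitrary seed
m_key, x_key = rdm.split(key)
n = 5
M = rdm.normal(m_key, shape=(2 * n, n))
A = M.T @ M + jnp.eye(n)  # symmetric positive definite by construction
x_true = rdm.normal(x_key, shape=(n,))
b = A @ x_true

L = jnpla.cholesky(A)                                 # lower triangular, A = L @ L.T
y = jspla.solve_triangular(L, b, lower=True)          # forward solve:  L y = b
x_cho = jspla.solve_triangular(L.T, y, lower=False)   # backward solve: L' x = y
print(jnpla.norm(x_true - x_cho))
```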

Let's compare the error of the solutions.

In [7]:
def sim_sym_linear_system(key, n: int):
  key, x_key = rdm.split(key, 2)
  # sample more rows than columns (2n x n) to improve conditioning
  A = rdm.normal(key, shape=(2*n,n))
  # form the n x n Gram matrix A'A, which is symmetric positive definite
  A = A.T @ A
  x = rdm.normal(x_key, shape=(n,))
  b = A @ x
  return A, x, b

key, sim_key = rdm.split(key)
N = 500

A, x, b = sim_sym_linear_system(sim_key, N)

# solve using algebraic approach
Ainv = jnpla.inv(A)
x_direct = Ainv @ b

# QR
Q, R = jnpla.qr(A)
x_qr = jspla.solve_triangular(R, Q.T @ b)

# Cholesky
L, lower = jspla.cho_factor(A)
x_cho = jspla.cho_solve((L, lower), b)

# direct solve for baseline
x_solve = jnpla.solve(A, b)

direct_dist = jnpla.norm(x - x_direct)
qr_dist = jnpla.norm(x - x_qr)
cho_dist = jnpla.norm(x - x_cho)
solve_dist = jnpla.norm(x - x_solve)

print(f"direct dist = {direct_dist} | qr dist = {qr_dist} | cho dist = {cho_dist} | solve_dist = {solve_dist}")

direct dist = 2.7083056920673698e-05 | qr dist = 2.6027397325378843e-05 | cho dist = 1.7303784261457622e-05 | solve_dist = 1.732515011099167e-05


### LAB portion

What is the average runtime for each of these operations?
Hint: use the `%timeit` magic command and [`block_until_ready`](https://jax.readthedocs.io/en/latest/_autosummary/jax.block_until_ready.html) function.

In [None]:
N = 1000
key, sim_key = rdm.split(key)
A, x, b = sim_sym_linear_system(sim_key, N)

# solve using algebraic approach
def _direct(A, b):
  pass

# QR
def _qr(A, b):
  pass

# Cholesky
def _cholesky(A, b):
  pass

# general solver (uses LU)
def _solve(A, b):
  return jnpla.solve(A, b)

%timeit x_solve = _solve(A, b).block_until_ready()
%timeit x_direct = _direct(A, b).block_until_ready()
%timeit x_qr = _qr(A, b).block_until_ready()
%timeit x_cho = _cholesky(A, b).block_until_ready()

48 ms ± 15.1 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
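One possible way to fill in the stubs, reusing the calls from the earlier cells, is sketched below. It is not the only valid answer. To stay runnable outside a notebook it swaps the `%timeit` magic for the standard-library `timeit` module, and it uses a smaller, hypothetical problem size (`n = 200`) so that it finishes quickly.

```python
import timeit

import jax.numpy as jnp
import jax.numpy.linalg as jnpla
import jax.random as rdm
import jax.scipy.linalg as jspla

def _direct(A, b):
    return jnpla.inv(A) @ b

def _qr(A, b):
    Q, R = jnpla.qr(A)
    return jspla.solve_triangular(R, Q.T @ b)

def _cholesky(A, b):
    return jspla.cho_solve(jspla.cho_factor(A), b)

def _solve(A, b):
    return jnpla.solve(A, b)

key = rdm.PRNGKey(0)  # arbitrary seed
m_key, x_key = rdm.split(key)
n = 200  # smaller than the lab's N, chosen only to keep the example fast
M = rdm.normal(m_key, shape=(2 * n, n))
A = M.T @ M  # symmetric positive definite, as in sim_sym_linear_system
x_true = rdm.normal(x_key, shape=(n,))
b = A @ x_true

solvers = {"direct": _direct, "qr": _qr, "cholesky": _cholesky, "solve": _solve}
results = {}
for name, fn in solvers.items():
    # warm-up call so one-time compilation is not counted; also keeps the answer
    results[name] = fn(A, b).block_until_ready()
    secs = timeit.timeit(lambda: fn(A, b).block_until_ready(), number=10) / 10
    print(f"{name}: {secs * 1e3:.2f} ms per call")
```

Calling `block_until_ready()` inside the timed statement matters because JAX dispatches work asynchronously; without it the timer would mostly measure dispatch overhead.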


Great! Cholesky is the fastest and has error on par with the general `solve` command. What can we gain in terms of speed by using [`jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html)?

In [None]:
N = 1000
key, sim_key = rdm.split(key)
A, x, b = sim_sym_linear_system(sim_key, N)

jit_direct = jax.jit(_direct)
jit_qr = jax.jit(_qr)
jit_cholesky = jax.jit(_cholesky)
jit_solve = jax.jit(jnpla.solve)

%timeit x_direct = jit_direct(A, b).block_until_ready()
%timeit x_qr = jit_qr(A, b).block_until_ready()
%timeit x_cho = jit_cholesky(A, b).block_until_ready()
%timeit x_solve = jit_solve(A, b).block_until_ready()

Not much improvement! What gives?

It turns out that JAX already JIT-compiles these lower-level functions (e.g., solve, decompositions), and our code is not much larger in scope, so the benefit is not clearly demonstrated. The takeaway here is not that JIT doesn't help, but rather that to get the most out of JIT'd code we should apply it to larger programs/functions, where the compiler has more of the computational graph to optimize.
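As a sketch of what "larger in scope" might look like, the hypothetical function `gram_solve` below builds the Gram matrix, forms the right-hand side, and solves with Cholesky, all inside a single `jax.jit`. The function name and problem sizes are illustrative assumptions; whether the speedup is noticeable will depend on the hardware and the problem.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm
import jax.scipy.linalg as jspla

def gram_solve(M, x):
    """Hypothetical end-to-end routine: build A = M'M, form b = A x, then solve A z = b."""
    A = M.T @ M
    b = A @ x
    return jspla.cho_solve(jspla.cho_factor(A), b)

# one compiled program covering the matrix products and the solve
jit_gram_solve = jax.jit(gram_solve)

key = rdm.PRNGKey(0)  # arbitrary seed
m_key, x_key = rdm.split(key)
n = 100
M = rdm.normal(m_key, shape=(2 * n, n))
x_true = rdm.normal(x_key, shape=(n,))

x_eager = gram_solve(M, x_true).block_until_ready()
x_jit = jit_gram_solve(M, x_true).block_until_ready()
print(jnp.linalg.norm(x_eager - x_jit))
```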



## Conjugate Gradient (CG)
TBD

## And now for something completely different... [Lineax](https://docs.kidger.site/lineax/)
Library for solving linear systems and related linear computational tasks.

In [1]:
!pip install lineax



In [16]:
import lineax as lx

N = 1000
key, sim_key = rdm.split(key)
A, x, b = sim_sym_linear_system(sim_key, N)

A_op = lx.MatrixLinearOperator(A)
sol = lx.linear_solve(A_op, b)
x_sol = sol.value

A_psd_op = lx.MatrixLinearOperator(A, lx.positive_semidefinite_tag)
psd_sol = lx.linear_solve(A_psd_op, b)
x_psd_sol = psd_sol.value

# initialize our CG solver
cg_solver = lx.CG(atol=1e-5, rtol=1e-4)
y0 = b / jnp.diag(A)
cg_sol = lx.linear_solve(A_psd_op, b, solver=cg_solver, options={"y0":y0})
x_cg_sol = cg_sol.value

x_solve = jnpla.solve(A, b)

dist_sol = jnpla.norm(x - x_sol)
dist_psd = jnpla.norm(x - x_psd_sol)
dist_cg = jnpla.norm(x - x_cg_sol)
dist_solve = jnpla.norm(x - x_solve)

print(f"dist lineax solve = {dist_sol}")
print(f"psd lineax solve = {dist_psd}")
print(f"cg lineax solve = {dist_cg}")
print(f"jax solve = {dist_solve}")

dist lineax solve = 2.793022031255532e-05
psd lineax solve = 3.014422873093281e-05
cg lineax solve = 3.1980012863641605e-05
jax solve = 2.793022031255532e-05
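For intuition about what an iterative solver such as `lx.CG` is doing, here is a minimal sketch of the textbook conjugate gradient iteration for a symmetric positive definite system. It is not Lineax's implementation: the function name `cg_sketch`, the relative stopping rule, and the test matrix are all assumptions made for illustration, and it uses a plain Python loop rather than a JIT-friendly `lax.while_loop`.

```python
import jax.numpy as jnp
import jax.numpy.linalg as jnpla
import jax.random as rdm

def cg_sketch(A, b, x0=None, rtol=1e-5, max_iter=200):
    """Textbook CG for symmetric positive definite A (illustrative sketch)."""
    x = jnp.zeros_like(b) if x0 is None else x0
    r = b - A @ x          # residual
    p = r                  # first search direction
    rs_old = r @ r
    threshold = rtol * jnpla.norm(b)
    for _ in range(max_iter):
        if jnp.sqrt(rs_old) < threshold:
            break
        Ap = A @ p
        alpha = rs_old / (p @ Ap)      # step length along p
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p  # next A-conjugate direction
        rs_old = rs_new
    return x

key = rdm.PRNGKey(0)  # arbitrary seed
m_key, x_key = rdm.split(key)
n = 50
M = rdm.normal(m_key, shape=(2 * n, n))
A = M.T @ M / (2 * n) + jnp.eye(n)  # well-conditioned SPD test matrix
x_true = rdm.normal(x_key, shape=(n,))
b = A @ x_true

x_cg = cg_sketch(A, b)
print(jnpla.norm(x_true - x_cg))
```

Note that the loop only ever touches $A$ through products $Ap$, which is why CG pairs naturally with linear operators that are never materialized as dense matrices.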


Let's look at a slightly more sophisticated linear operator. Let $A = B B' + I_n$, where $B$ has shape $n \times k$ for $k < n$; we want to solve $Ax = b$.

In [23]:
N, K = 1000, 50
key, b_key, x_key = rdm.split(key, 3)

B = rdm.normal(b_key, shape=(N, K))
x = rdm.normal(x_key, shape=(N,))
b = B @ (B.T @ x) + x

B_op = lx.MatrixLinearOperator(B)
I_op = lx.IdentityLinearOperator(jax.eval_shape(lambda: x))
A_op = lx.TaggedLinearOperator(B_op @ B_op.T + I_op, lx.positive_semidefinite_tag)

sol = lx.linear_solve(A_op, b)
cg_sol = lx.linear_solve(A_op, b, solver=cg_solver)

dist_sol = jnpla.norm(x - sol.value)
dist_cg = jnpla.norm(x - cg_sol.value)
print(f"dist lineax solve = {dist_sol}")
print(f"cg lineax solve = {dist_cg}")

dist lineax solve = 0.003705778392031789
cg lineax solve = 0.0006713042384944856


In [24]:
%timeit lx.linear_solve(A_op, b).value.block_until_ready()
%timeit lx.linear_solve(A_op, b, solver=cg_solver).value.block_until_ready()

22.2 ms ± 10.2 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.99 ms ± 48.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
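A plausible reason CG does well here is that it only needs products $Av$, and for $A = BB' + I_n$ these can be computed as $B(B'v) + v$ in $O(nk)$ operations without ever forming the $n \times n$ matrix. In exact arithmetic CG would also need at most $k + 1$ iterations, since $A$ has at most $k + 1$ distinct eigenvalues. The sketch below only checks that the matrix-free product agrees with the dense one; the helper name `matvec_free` is hypothetical, and how Lineax evaluates the composed operator internally is not something this example demonstrates.

```python
import jax.numpy as jnp
import jax.numpy.linalg as jnpla
import jax.random as rdm

n, k = 1000, 50
key = rdm.PRNGKey(0)  # arbitrary seed
b_key, v_key = rdm.split(key)
B = rdm.normal(b_key, shape=(n, k))
v = rdm.normal(v_key, shape=(n,))

def matvec_free(v):
    # two thin products, O(n k) work, no n x n matrix is formed
    return B @ (B.T @ v) + v

A_dense = B @ B.T + jnp.eye(n)  # O(n^2) storage, O(n^2) work per product
dense_out = A_dense @ v
free_out = matvec_free(v)

rel_gap = jnpla.norm(dense_out - free_out) / jnpla.norm(dense_out)
print(f"relative gap = {rel_gap}")
```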


In [None]:
def sim_linear_reg(key, N, P, r2=0.5):

  key, x_key = rdm.split(key)
  X = rdm.normal(x_key, shape=(N, P))

  key, b_key = rdm.split(key)
  beta = rdm.normal(b_key, shape=(P,))

  # g = jnp.dot(X, beta)
  g = X @ beta
  s2g = jnp.var(g)

  # back out what s2e is, such that s2g / (s2g + s2e) == r2
  s2e = (1 - r2) / r2 * s2g
  key, y_key = rdm.split(key)

  # add env noise to g, but scale such that var(e) == s2e
  y = g + jnp.sqrt(s2e) * rdm.normal(y_key, shape=(N,))
  return y, X, beta


N, P = 1000, 150
key, sim_key = rdm.split(key)
y, X, beta = sim_linear_reg(sim_key, N, P)

XtX = X.T @ X
Xty = X.T @ y

# dist func
def _dist(sol):
  return jnpla.norm(beta - sol.value)

X_op = lx.MatrixLinearOperator(X)
XtX_op = lx.MatrixLinearOperator(XtX, lx.positive_semidefinite_tag)

# 1st: linear operator over `XtX` and solve against Xty
sol_1 = lx.linear_solve(XtX_op, Xty)
sol_1_dist = _dist(sol_1)

# 2nd: linear operator over `X` and solve against `y`
sol_2 = lx.linear_solve(X_op, y, solver=lx.AutoLinearSolver(well_posed=False))
sol_2_dist = _dist(sol_2)

# 3rd: linear operator over `X`, then compose with its transpose (i.e. X_op.T @ X_op) and solve against Xty
sol_3 = lx.linear_solve(X_op.T @ X_op, Xty)
sol_3_dist = _dist(sol_3)

# 4th: linear operator over `X` then solve against `y` using NormalCG
solver = lx.NormalCG(atol=1e-4, rtol=1e-4)
sol_4 = lx.linear_solve(X_op, y, solver=solver)
sol_4_dist = _dist(sol_4)
print(f" XtX Xty = {sol_1_dist} | X y = {sol_2_dist} | XtX Xty #2 = {sol_3_dist} | XtX Xty CG = {sol_4_dist}")

 XtX Xty = 4.768477916717529 | X y = 4.768476486206055 | XtX Xty #2 = 4.768477916717529 | XtX Xty CG = 4.768479347229004
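All four distances are nearly identical and far larger than the solver errors seen earlier. That is consistent with the gap to `beta` being dominated by the simulated noise (with `r2 = 0.5`) rather than by numerical error. A hedged way to check this is to compare two estimates with each other instead of with the truth. The sketch below re-simulates comparable data inline (seed and variable names are illustrative), solves the normal equations $X'X\beta = X'y$ with Cholesky, and compares against `jnpla.lstsq` applied to `X` directly.

```python
import jax.numpy as jnp
import jax.numpy.linalg as jnpla
import jax.random as rdm
import jax.scipy.linalg as jspla

key = rdm.PRNGKey(0)  # arbitrary seed
x_key, b_key, e_key = rdm.split(key, 3)
N, P = 1000, 150
X = rdm.normal(x_key, shape=(N, P))
beta = rdm.normal(b_key, shape=(P,))
g = X @ beta
# noise variance equal to var(g), which corresponds to r2 = 0.5
y = g + jnp.sqrt(jnp.var(g)) * rdm.normal(e_key, shape=(N,))

# normal equations X'X beta = X'y, solved via Cholesky
beta_ne = jspla.cho_solve(jspla.cho_factor(X.T @ X), X.T @ y)
# least squares applied to X directly
beta_ls, *_ = jnpla.lstsq(X, y)

gap_between = jnpla.norm(beta_ne - beta_ls)  # numerical disagreement between solvers
gap_to_truth = jnpla.norm(beta - beta_ne)    # statistical error from the noise
print(f"between solvers = {gap_between} | to truth = {gap_to_truth}")
```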
