In [2]:
import jax
import numpy as np
import matplotlib.pyplot as plt

In [5]:
def matrix_inverse(M, y):
    """Solve a linear system in ``M`` via its Cholesky factorisation.

    ``M`` is factorised as ``M = L L^T`` with ``L`` lower triangular, so it is
    assumed to be symmetric positive definite. Two triangular solves then
    yield ``X`` such that ``M X = y``, and the transpose of ``X`` is returned.

    Args:
        M: Square matrix handed to the Cholesky factorisation.
        y: Right-hand side of the system ``M X = y``.

    Returns:
        ``A = X.T`` where ``M X = y``. For symmetric ``M`` this is
        equivalently the solution of ``A M = y.T``.
    """
    solve_tri = jax.scipy.linalg.solve_triangular
    chol_lower = jax.scipy.linalg.cholesky(M, lower=True)

    # Forward substitution: chol_lower @ forward = y
    forward = solve_tri(chol_lower, y, lower=True)

    # Back substitution: chol_lower.T @ solution = forward, hence M @ solution = y
    solution = solve_tri(chol_lower.T, forward, lower=False)

    # Callers receive the transposed solution (A rather than A^T).
    return solution.T


# Test case: a small symmetric positive definite system
test_M = np.array([[4, 2], [2, 5]])  # symmetric positive definite matrix
test_y = np.array([[1], [2]])

# Solve the same system twice: Cholesky-based helper vs. NumPy's direct solver
chol_inv = matrix_inverse(test_M, test_y)
direct_inv = np.linalg.solve(test_M, test_y)

print("Matrix M:")
print(test_M)
print("\nVector y:")
print(test_y)
print("\nCholesky-based solution:")
print(chol_inv)
print("\nDirect solve solution:")
print(direct_inv)
print("\nDifference:")
# The two results have transposed shapes, so both are flattened before
# taking the element-wise difference for display.
print(np.abs(chol_inv.flatten() - direct_inv.flatten()).max())

# Fail loudly on a fresh "Run All" instead of relying on someone reading the
# printed difference. matrix_inverse returns the transposed solution, so it is
# transposed back here; assert_allclose then checks shape and values together.
np.testing.assert_allclose(np.asarray(chol_inv).T, direct_inv, rtol=1e-5, atol=1e-6)


Matrix M:
[[4 2]
 [2 5]]

Vector y:
[[1]
 [2]]

Cholesky-based solution:
[[0.0625 0.375 ]]

Direct solve solution:
[[0.0625]
 [0.375 ]]

Difference:
0.0
