In [1]:
import jax.numpy as jnp
import jax
from jax import lax
from jax import Array

In [2]:
# Calculate Eigenvalues of a Matrix

def calculate_eigenvalues(matrix: Array) -> Array:
    det = matrix[0, 0] * matrix[1, 1] - matrix[0, 1] * matrix[1, 0]
    trace = matrix[0, 0] + matrix[1, 1]
    discriminant = jnp.sqrt(trace ** 2 - 4 * det)
    lambda1 = (trace + discriminant) / 2
    lambda2 = (trace - discriminant) / 2
    return jnp.array([lambda1, lambda2])

In [3]:
# Test the function
matrix = jnp.array([[2, 1], [1, 2]])
calculate_eigenvalues(matrix)

Array([3., 1.], dtype=float32)

In [6]:
# Matrix Transformation
def transform_matrix(A: Array, T: Array, S: Array) -> Array:
    if jnp.linalg.det(T) == 0 or jnp.linalg.det(S) == 0:
        return jnp.array([-1])
    T_inv = jnp.linalg.inv(T)
    return jnp.dot(jnp.dot(T_inv, A), S)

In [7]:
# Test the function
A = jnp.array([[1, 2], [3, 4]])
T = jnp.array([[2, 0], [0, 2]])
S = jnp.array([[1, 1], [0, 1]])
transform_matrix(A, T, S)

Array([[0.5, 1.5],
       [1.5, 3.5]], dtype=float32)

In [8]:
# Calculate 2x2 Matrix Inverse 
def inverse_2x2(matrix: Array) -> Array:
    det = matrix[0, 0] * matrix[1, 1] - matrix[0, 1] * matrix[1, 0]
    if det == 0:
        return jnp.array([-1])
    A_inv = (jnp.array([[matrix[1, 1], -matrix[0, 1]], [-matrix[1, 0], matrix[0, 0]]])) / det
    return A_inv

In [9]:
# Test the function
matrix = jnp.array([[4, 7], [2, 6]])
inverse_2x2(matrix)

Array([[ 0.6, -0.7],
       [-0.2,  0.4]], dtype=float32)

In [12]:
# Matrix times Matrix
def matrix_times_matrix(A: Array, B: Array) -> Array:
    if A.shape[1] != B.shape[0]:
        return jnp.array([-1])
    C = jnp.zeros([A.shape[0], B.shape[1]], dtype=A.dtype)
    for i in range(A.shape[0]):
        for j in range(B.shape[1]):
            C = C.at[i, j].set(jnp.dot(A[i, :], B[:, j]))
    return C

In [13]:
# Test the function matrix_times_matrix 
A = jnp.array([[1,2],
               [2,4]])
B = jnp.array([[2,1],
               [3,4]])
matrix_times_matrix(A, B)

Array([[ 8,  9],
       [16, 18]], dtype=int32)

In [14]:
# Test the function matrix_times_matrix 
A = jnp.array([[1,2],
               [2,4]])
B = jnp.array([[2,1],
               [3,4],
               [4,5]])
matrix_times_matrix(A, B)

Array([-1], dtype=int32)

In [31]:
# Calculate Covariance Matrix
def calculate_covariance_matrix(X: Array) -> Array:
    n_samples, n_features = X.shape
    mean = jnp.mean(X, axis=0)
    covariance_matrix = jnp.zeros([n_features, n_features])
    for i in range(n_features):
        for j in range(i, n_features):
            covariance = sum((X[i][k] - mean[i]) * (X[j][k] - mean[j]) for k in range(n_features)) / (n_features - 1)
            covariance_matrix = covariance_matrix.at[i, j].set(covariance)
            covariance_matrix = covariance_matrix.at[j, i].set(covariance)
    return covariance_matrix
            
                              

In [32]:
# Test the function calculate_covariance_matrix
X = jnp.array([[1, 2, 3], [4, 5, 6]])
calculate_covariance_matrix(X)

[2.5 3.5 4.5]
1.375
-0.125
0.625
4.375
2.125
1.375


Array([[ 1.375, -0.125,  0.625],
       [-0.125,  4.375,  2.125],
       [ 0.625,  2.125,  1.375]], dtype=float32)

In [None]:
# Calculate Correlation Matrix
def calculate_correlation_matrix(X: Array) -> Array:
    n = X.shape[0]
    X_mean = jnp.mean(X, axis=0)
    X_centered = X - X_mean
    covariance_matrix = jnp.dot(X_centered.T, X_centered) / n
    correlation_matrix = covariance_matrix / jnp.sqrt(jnp.dot(jnp.diag(covariance_matrix), jnp.diag(covariance_matrix).T))
    return correlation_matrix