In [1]:
import jax.numpy as jnp
import jax
from jax import lax
from jax import Array

In [2]:
# Matrix times vector
def multiply(a: Array, b: Array) -> Array:
    def body_fun(i, c):
        return lax.fori_loop(0, a.shape[1], lambda j, c: c.at[i].add(a[i, j] * b[j]), c)

    c = jnp.zeros(a.shape[0], dtype=a.dtype)
    c = lax.fori_loop(0, a.shape[0], body_fun, c)
    return c

def matrix_dot_vector(a: Array, b: Array) -> Array:
    error_value = jnp.full(a.shape[0], -1, dtype=a.dtype)

    c = lax.cond(
            jnp.equal(a.shape[1], b.shape[0]),
            lambda: multiply(a, b),
            lambda: error_value
    )
    return c

In [5]:
# Test matrix_dot_vector
a = jnp.array([[1,2],[2,4]])
b = jnp.array([1,2])
print(matrix_dot_vector(a,b))

[ 5 10]


In [6]:
# Transpose a matrix  
def transpose_matrix(a: Array) -> Array:
    def body_fun(i, b):
        return lax.fori_loop(0, a.shape[1], lambda j, b : b.at[j,i].set(a[i,j]), b)
    b = jnp.zeros([a.shape[1], a.shape[0]], dtype=a.dtype)
    b = lax.fori_loop(0, a.shape[0], body_fun, b)
    return b

In [7]:
# Test transpose_matrix
a = jnp.array([[1,2,3],[4,5,6]])
print(transpose_matrix(a))

[[1 4]
 [2 5]
 [3 6]]


In [51]:
# Reshape Matrix
def reshape_matrix(a:Array, shape: tuple) -> Array:
    """Reshapes a matrix into a new shape"""
    if jnp.prod(jnp.array(a.shape)) != jnp.prod(jnp.array(shape)):
        return jnp.full(shape, -1, dtype=a.dtype)
    temp_matrix = jnp.zeros(jnp.prod(jnp.array(a.shape)), dtype=a.dtype)
    index = 0
    for i in range(a.shape[0]):
        for j in range(a.shape[1]):
            temp_matrix = temp_matrix.at[index].set(a[i][j])
            index += 1
    reshaped_matrix = jnp.zeros(shape, dtype=a.dtype) 
    for i in range(shape[0]):
        for j in range(shape[1]):
            reshaped_matrix = reshaped_matrix.at[i,j].set(temp_matrix[i*shape[1] + j])
    return reshaped_matrix


In [52]:
# Test reshape_matrix
a = jnp.array([[1,2,3,4],[5,6,7,8]])
print(reshape_matrix(a, (4,2)))

[[1 2]
 [3 4]
 [5 6]
 [7 8]]


In [57]:
# Calculate Mean by Row or Column 
def calculate_matrix_mean(matrix: Array, mode: str) -> Array:
    """Calculates the mean of a matrix by row or column"""
    if mode == "row":
        mean = jnp.zeros(matrix.shape[0], dtype=jnp.float32)
        for i in range(matrix.shape[0]):
            mean = mean.at[i].set(jnp.mean(matrix[i]))
        return mean
    elif mode == "column":
        mean = jnp.zeros(matrix.shape[1], dtype=jnp.float32)
        for i in range(matrix.shape[1]):
            mean = mean.at[i].set(jnp.mean(matrix[:,i]))
        return mean
    else:
        return jnp.array([-1])
            
        
        

In [58]:
# Test calculate_matrix_mean
a = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(calculate_matrix_mean(a, "row"))
print(calculate_matrix_mean(a, "column"))

[2. 5. 8.]
[4. 5. 6.]


In [61]:
# Scalar Multiplication of a Matrix
def scalar_multiply(matrix: Array, scalar: int | float) -> Array:
    """Multiplies a matrix by a scalar"""
    for i in range(matrix.shape[0]):
        for j in range(matrix.shape[1]):
            matrix = matrix.at[i,j].set(matrix[i,j] * scalar)
    return matrix

In [62]:
# Test scalar_multiply
a = jnp.array([[1,2,3],[4,5,6]])
print(scalar_multiply(a, 2))

[[ 2  4  6]
 [ 8 10 12]]


In [65]:
# Transformation Matrix from Basis B to C
def transform_basis(matrix: Array, basis: Array) -> Array:
    """Transforms a matrix from basis B to basis C"""
    matrix_inv = jnp.linalg.inv(matrix)
  
    return  jnp.dot(matrix_inv, basis)

In [66]:
# Test transform_basis
b = jnp.array([[1, 0, 0],
               [0, 1, 0],
               [0, 0, 1]])
a = jnp.array([[1, 2.3, 3],
               [4.4, 25, 6],
               [7.4, 8, 9]])
print(transform_basis(a, b))

[[-0.6772268  -0.01262626  0.23415977]
 [-0.01836547  0.05050505 -0.02754821]
 [ 0.5731558  -0.03451178 -0.05693296]]


In [82]:
# Calculate Determinant of a Matrix
def make_diagonal(matrix: Array) -> Array:
    """Calculates the determinant of a matrix"""
    identity_matrix = jnp.identity(jnp.size(matrix), dtype=matrix.dtype)
    return identity_matrix * matrix

In [83]:
# Test make_diagonal
a = jnp.array([1,2,3])
print(make_diagonal(a))

[[1 0 0]
 [0 2 0]
 [0 0 3]]


In [84]:
# Implement Compressed Row Sparse Matrix (CSR) Format Conversion
 
def compressed_row_sparse_matrix (dense_matrix: Array) -> [Array, Array, Array]:
    """Converts a dense matrix to a compressed row sparse matrix"""
    crs_row = []
    crs_col = []
    crs_data = []
    for i in range(dense_matrix.shape[0]):
        for j in range(dense_matrix.shape[1]):
            if dense_matrix[i,j] != 0:
                crs_data.append(dense_matrix[i,j])
                crs_col.append(j)
        crs_row.append(len(crs_data))
    return jnp.array(crs_data), jnp.array(crs_col), jnp.array(crs_row)

In [85]:
# Test compressed_row_sparse_matrix
a = jnp.array([
        [1, 0, 0, 0],
        [0, 2, 0, 0],
        [3, 0, 4, 0],
        [1, 0, 0, 5]
]
)
print(compressed_row_sparse_matrix(a))

(Array([1, 2, 3, 4, 1, 5], dtype=int32), Array([0, 1, 0, 2, 0, 3], dtype=int32), Array([1, 2, 4, 6], dtype=int32))


In [86]:
# Implement Orthogonal Projection of a Vector onto a Line
def orthogonal_projection(v: Array, L: Array) -> Array:
    """Projects a vector onto a line"""
    return jnp.dot(jnp.dot(v, L) / jnp.dot(L, L), L)

In [87]:
# Test orthogonal_projection
v = jnp.array([3, 4])
L = jnp.array([1, 0])
print(orthogonal_projection(v, L))

[3. 0.]


In [107]:
# Implement Compressed Column Sparse Matrix Format (CSC)
def compressed_col_sparse_matrix (dense_matrix: Array) -> [Array, Array, Array]:

    vals = []
    row_idx = []
    col_ptr = [0]

    rows, cols = len(dense_matrix), len(dense_matrix[0])

    for i in range(cols):
        for j in range(rows):
            val = dense_matrix[j][i]
            if val != 0:
                vals.append(val)
                row_idx.append(j)
        col_ptr.append(len(vals))

    return jnp.array(vals), jnp.array(row_idx), jnp.array(col_ptr)

In [108]:
# Test compressed_col_sparse_matrix
a = jnp.array([
        [0, 0, 3, 0],
        [1, 0, 0, 4],
        [0, 2, 0, 0]
]
)
print(compressed_col_sparse_matrix(a))

(Array([1, 2, 3, 4], dtype=int32), Array([1, 2, 0, 1], dtype=int32), Array([0, 1, 2, 3, 4], dtype=int32))
