# Basic Matrix Operations with JAX

In [1]:
import jax as J
import jax.numpy as jnp

## Addition and Subtraction

In [2]:
# First 3x5 integer matrix used throughout the notebook. No dtype is passed,
# so JAX infers an integer dtype from the Python int literals.
A = jnp.array([
    [4, 5, 6, 8, 9],
    [3, 5, 7, 8, 3],
    [1, 4, 6, 0, 2],
])

A

DeviceArray([[4, 5, 6, 8, 9],
             [3, 5, 7, 8, 3],
             [1, 4, 6, 0, 2]], dtype=int32)

In [3]:
# Second 3x5 integer matrix, the same shape as A so that element-wise
# addition and subtraction are defined.
B = jnp.array([
    [10, 1, 6, 8, 9],
    [3, 12, 7, 8, 3],
    [1, 4, 6, -1, 2],
])

B

DeviceArray([[10,  1,  6,  8,  9],
             [ 3, 12,  7,  8,  3],
             [ 1,  4,  6, -1,  2]], dtype=int32)

In [4]:
# addition: element-wise, since A and B share the shape (3, 5).
# The `+` operator on JAX arrays dispatches to jnp.add.
A_plus_B = A + B
A_plus_B

DeviceArray([[14,  6, 12, 16, 18],
             [ 6, 17, 14, 16,  6],
             [ 2,  8, 12, -1,  4]], dtype=int32)

In [5]:
# subtraction: element-wise A minus B.
# The `-` operator on JAX arrays dispatches to jnp.subtract.
A_minus_B = A - B
A_minus_B

DeviceArray([[-6,  4,  0,  0,  0],
             [ 0, -7,  0,  0,  0],
             [ 0,  0,  0,  1,  0]], dtype=int32)

## Multiplication

In [6]:
# Check shapes first: a matrix product X @ Y is only defined when
# X has as many columns as Y has rows.
print(f"A : {A.shape}  ||| B : {B.shape}")

A : (3, 5)  ||| B : (3, 5)


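The dimensions of A and B don't satisfy the rule of matrix multiplication: a product XY is only defined when the number of columns of X equals the number of rows of Y. A 3x5 matrix therefore cannot be multiplied by another 3x5 matrix (5 ≠ 3).
So let's create a new matrix with dimensions 5x3.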

In [7]:
# A 5x3 integer matrix. Its 5 rows match A's 5 columns, so the product
# A · C is defined and yields a square 3x3 matrix.
C = jnp.array([
    [1, 2, 3],
    [2, 3, 5],
    [3, 4, 6],
    [4, 5, 7],
    [5, 6, 8],
])

C

DeviceArray([[1, 2, 3],
             [2, 3, 5],
             [3, 4, 6],
             [4, 5, 7],
             [5, 6, 8]], dtype=int32)

In [8]:
# Matrix product: (3x5) @ (5x3) -> 3x3.
# For 2-D arrays the `@` operator (jnp.matmul) computes the same product as jnp.dot.
A_mult_C = A @ C
A_mult_C

DeviceArray([[109, 141, 201],
             [ 81, 107, 156],
             [ 37,  50,  75]], dtype=int32)

## Inverting

Erm, we need an invertible matrix here. You can't just invert every matrix out there! A matrix is invertible exactly when it has both of these properties:

1. It's a square matrix (equal number of rows and columns)
2. The determinant of the matrix is NOT 0

[What is an invertible matrix](https://www.studypug.com/algebra-help/2-x-2-invertible-matrix#:~:text=An%20invertible%20matrix%20is%20a,the%20matrix%20is%20not%200.)

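Luckily, `A_mult_C` is a square (3x3) matrix. Let's check whether it can be inverted. Note that its entries are integers, but the determinant is computed in floating point, so it carries round-off error.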

In [9]:
D = A_mult_C

# Invertibility needs (1) a square matrix and (2) a non-zero determinant.
# `jnp.linalg.det` works in floating point, so a singular matrix can come back
# as a tiny non-zero value from round-off, and `determinant != 0` would then
# wrongly report "Invertible". Testing for full rank instead lets
# `jnp.linalg.matrix_rank` apply its SVD-based tolerance for round-off.
is_square = D.ndim == 2 and D.shape[0] == D.shape[1]

# det is only defined for square matrices; skip it rather than raise.
determinant = jnp.linalg.det(D) if is_square else None

if is_square and int(jnp.linalg.matrix_rank(D)) == D.shape[0]:
    print("Invertible")
else:
    print("Not Invertible")

Invertible


In [10]:
# Time to invert D. The result is floating point (float32 in the output
# below), so entries differ slightly from the exact rational inverse.
# For example, the exact top-left entry is 225/93 ~= 2.41935.
inverse = jnp.linalg.inv(D)

inverse

DeviceArray([[ 2.4192986 , -5.6450286 ,  5.2579417 ],
             [-3.257984  ,  7.9352956 , -7.7740183 ],
             [ 0.97846866, -2.5053158 ,  2.6020942 ]], dtype=float32)