<a href="https://colab.research.google.com/github/trefftzc/cis677/blob/main/Matrix_multiplication_in_python.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Matrix Multiplication in Python

Matrix multiplication is a fundamental operation in linear algebra.

Because it is so widely used, implementations of it have been carefully optimized for many different hardware environments.

The code below is a straightforward implementation of matrix multiplication that follows the definition of the operation.
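Written as a formula, each entry of the product C = AB of an m × n matrix A and an n × p matrix B is the dot product of row i of A with column j of B:

$$C_{ij} = \sum_{k=1}^{n} A_{ik} B_{kj}, \qquad 1 \le i \le m,\ 1 \le j \le p$$

For example, with the 5 × 5 matrices used below, the first row of A and the first column of B are both (1, 2, 3, 4, 5), so

$$C_{11} = 1\cdot 1 + 2\cdot 2 + 3\cdot 3 + 4\cdot 4 + 5\cdot 5 = 55,$$

which is `C[0][0]` in the code, since Python indexes from 0.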

In [1]:
#
# A very simple initial version
#
import time
A = [[1,2,3,4,5],
     [6,7,8,9,10],
     [11,12,13,14,15],
     [16,17,18,19,20],
     [21,22,23,24,25]]

B = [[1,6,11,16,21],
     [2,7,12,17,22],
     [3,8,13,18,23],
     [4,9,14,19,24],
     [5,10,15,20,25]]

C = [[0,0,0,0,0],
     [0,0,0,0,0],
     [0,0,0,0,0],
     [0,0,0,0,0],
     [0,0,0,0,0]]


start_time = time.time()
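N = 5
# Definition of the product: C[i][j] is the sum over k of A[i][k]*B[k][j]
for i in range(N):        # each row of A
  for j in range(N):      # each column of B
    for k in range(N):    # accumulate the dot product of row i and column j
      C[i][j] += A[i][k]*B[k][j]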

end_time = time.time()
elapsed_time = end_time - start_time
print("Time required to carry out the computation: ",elapsed_time)

for i in range(N):
  for j in range(N):
    print(C[i][j]," ",end="")
  print("\n")



Time required to carry out the computation:  0.0
55  130  205  280  355  

130  330  530  730  930  

205  530  855  1180  1505  

280  730  1180  1630  2080  

355  930  1505  2080  2655  
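The cell above handles only 5 × 5 matrices, and a product this small finishes too quickly for `time.time()` to measure reliably (it reported 0.0). The sketch below packages the same definition as a function for any compatible shapes (an m × n matrix times an n × p matrix gives an m × p matrix) and times it on a larger matrix with `time.perf_counter()`, the clock the Python documentation recommends for short durations. The function name `matmul_naive`, the size `N = 100`, and the random test data are assumptions for illustration only.

```python
import random
import time

def matmul_naive(A, B):
    """Multiply an m x n matrix A by an n x p matrix B, both stored as lists of rows."""
    m, n, p = len(A), len(B), len(B[0])
    if len(A[0]) != n:
        raise ValueError("number of columns of A must equal number of rows of B")
    # One fresh list per row; [[0] * p] * m would alias the same row m times
    C = [[0] * p for _ in range(m)]
    for i in range(m):
        for j in range(p):
            for k in range(n):
                C[i][j] += A[i][k] * B[k][j]
    return C

# A 2 x 3 matrix times a 3 x 2 matrix gives a 2 x 2 matrix
print(matmul_naive([[1, 2, 3], [4, 5, 6]], [[7, 8], [9, 10], [11, 12]]))
# prints [[58, 64], [139, 154]]

# Timing on a larger matrix (N = 100 is an arbitrary, hypothetical size)
N = 100
random.seed(0)
A_big = [[random.randint(0, 9) for _ in range(N)] for _ in range(N)]
B_big = [[random.randint(0, 9) for _ in range(N)] for _ in range(N)]
start_time = time.perf_counter()
C_big = matmul_naive(A_big, B_big)
elapsed_time = time.perf_counter() - start_time
print("Time required to carry out the computation: ", elapsed_time)
```

The triple loop performs m·n·p multiply-add steps, so for square N × N matrices doubling N multiplies the work by about eight.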



## NumPy
NumPy is a Python library for numerical computing, including linear algebra operations such as matrix multiplication.

The code below calls NumPy's `matmul` function to multiply the two matrices.

The two input matrices are first converted from Python lists into NumPy arrays with NumPy's `array()` function. The result matrix C is converted as well, so that `matmul` can write the product into it through its `out` parameter.
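A small sketch of the NumPy calls involved, using hypothetical 2 × 2 matrices: `np.array()` infers an integer dtype from lists of Python ints, `np.matmul(A, B)` can also be written with the `@` operator, and when `out=` is given it must be an existing array with the shape of the result.

```python
import numpy as np

A = np.array([[1, 2], [3, 4]])   # integer dtype inferred from the Python ints
B = np.array([[5, 6], [7, 8]])

C1 = np.matmul(A, B)             # returns a new array
C2 = A @ B                       # the @ operator computes the same product
C3 = np.zeros((2, 2), dtype=A.dtype)
np.matmul(A, B, out=C3)          # writes the product into the preallocated C3

print(C1)
# [[19 22]
#  [43 50]]
print(np.array_equal(C1, C2), np.array_equal(C1, C3))
# True True
```

Here `np.zeros` builds the output buffer directly; converting a list of zeros with `np.array()`, as in the cell below, works just as well as long as the shape matches.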

In [3]:
#
# Now with numpy
#
import time
import numpy as np

A = [[1,2,3,4,5],
     [6,7,8,9,10],
     [11,12,13,14,15],
     [16,17,18,19,20],
     [21,22,23,24,25]]

B = [[1,6,11,16,21],
     [2,7,12,17,22],
     [3,8,13,18,23],
     [4,9,14,19,24],
     [5,10,15,20,25]]

C = [[0,0,0,0,0],
     [0,0,0,0,0],
     [0,0,0,0,0],
     [0,0,0,0,0],
     [0,0,0,0,0]]

N = 5

A_numpy = np.array(A)
B_numpy = np.array(B)
C_numpy = np.array(C)

start_time = time.time()

np.matmul(A_numpy,B_numpy,out=C_numpy)

end_time = time.time()
elapsed_time = end_time - start_time
print("Time required to carry out the computation: ",elapsed_time)

for i in range(N):
  for j in range(N):
    print(C_numpy[i][j]," ",end="")
  print("\n")


Time required to carry out the computation:  9.942054748535156e-05
55  130  205  280  355  

130  330  530  730  930  

205  530  855  1180  1505  

280  730  1180  1630  2080  

355  930  1505  2080  2655  
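At 5 × 5 the measured times are likely dominated by call overhead rather than arithmetic. The sketch below, assuming random integer matrices of a hypothetical size `N = 100`, checks `np.matmul` against the triple loop and times both with `time.perf_counter()`. Integer inputs make the two results exactly comparable.

```python
import time
import numpy as np

N = 100  # hypothetical size
rng = np.random.default_rng(0)
A_numpy = rng.integers(0, 10, size=(N, N))
B_numpy = rng.integers(0, 10, size=(N, N))

# Pure-Python triple loop on plain lists holding the same values
A = A_numpy.tolist()
B = B_numpy.tolist()
C = [[0] * N for _ in range(N)]
start_time = time.perf_counter()
for i in range(N):
    for j in range(N):
        for k in range(N):
            C[i][j] += A[i][k] * B[k][j]
loop_time = time.perf_counter() - start_time

# NumPy matmul on the same data
start_time = time.perf_counter()
C_numpy = np.matmul(A_numpy, B_numpy)
numpy_time = time.perf_counter() - start_time

print("Triple loop:", loop_time)
print("np.matmul:  ", numpy_time)
# Integer arithmetic is exact, so the two products must match entry by entry
print("Results equal:", np.array_equal(C_numpy, np.array(C)))
```

The actual speedup depends on the machine and the NumPy build, so no particular ratio should be expected.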



## JAX

JAX is a Python library that lets Python programmers run numerical code on accelerators such as Google's TPUs (Tensor Processing Units) and NVIDIA's GPUs (Graphics Processing Units); it also runs on ordinary CPUs.
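Which hardware JAX actually uses depends on the installed JAX build and on the machine. A quick check, using JAX's `default_backend()` and `devices()` functions:

```python
import jax

# Name of the platform JAX uses by default, for example 'cpu', 'gpu' or 'tpu'
print(jax.default_backend())
# All devices visible to JAX on that platform
print(jax.devices())
```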

Among other features, JAX offers a NumPy-compatible API (`jax.numpy`), including functions to create a JAX array from a Python list (`jnp.array`) and to multiply matrices (`jnp.matmul`).
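One difference from NumPy matters here: JAX arrays are immutable, which is why the JAX cell below assigns the result of `jnp.matmul` to a variable instead of writing into a preallocated array with `out=` as the NumPy cell does. Element updates go through the `.at[...]` syntax, which returns a new array. A small sketch with hypothetical 2 × 2 matrices:

```python
import jax.numpy as jnp

A = jnp.array([[1, 2], [3, 4]])   # JAX array built from a Python list
B = jnp.array([[5, 6], [7, 8]])

C = jnp.matmul(A, B)              # same call as np.matmul; A @ B also works
print(C)

# Item assignment such as C[0, 0] = 0 raises a TypeError for JAX arrays;
# .at[...].set(...) returns an updated copy and leaves C unchanged
C_updated = C.at[0, 0].set(0)
print(C[0, 0], C_updated[0, 0])
```

After the update, `C[0, 0]` still holds 19 while `C_updated[0, 0]` is 0.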


The cell below repeats the computation with JAX.

In [4]:
#
# Now with JAX
#
import jax
import jax.numpy as jnp
import time

A = [[1,2,3,4,5],
     [6,7,8,9,10],
     [11,12,13,14,15],
     [16,17,18,19,20],
     [21,22,23,24,25]]

B = [[1,6,11,16,21],
     [2,7,12,17,22],
     [3,8,13,18,23],
     [4,9,14,19,24],
     [5,10,15,20,25]]

C = [[0,0,0,0,0],
     [0,0,0,0,0],
     [0,0,0,0,0],
     [0,0,0,0,0],
     [0,0,0,0,0]]

N = 5

A_jaxnumpy = jnp.array(A)
B_jaxnumpy = jnp.array(B)
C_jaxnumpy = jnp.array(C)


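# JAX arrays are immutable, so the product is assigned to C_jaxnumpy
# rather than written into the array created from C above.
# JAX dispatches work asynchronously: block_until_ready() waits for the result,
# so the timing covers the computation (the first call may also include compilation).
start_time = time.time()
C_jaxnumpy = jnp.matmul(A_jaxnumpy,B_jaxnumpy).block_until_ready()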
end_time = time.time()
elapsed_time = end_time - start_time
print("Time required to carry out the computation: ",elapsed_time)


for i in range(N):
  for j in range(N):
    print(C_jaxnumpy[i][j]," ",end="")
  print("\n")

Time required to carry out the computation:  0.041764259338378906
55  130  205  280  355  

130  330  530  730  930  

205  530  855  1180  1505  

280  730  1180  1630  2080  

355  930  1505  2080  2655  
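A single timed call like the one above can be dominated by overhead: JAX dispatches operations asynchronously, and the first call to an operation also pays for tracing and compilation. JAX's documentation recommends calling `block_until_ready()` before stopping the clock. The sketch below compiles the product with `jax.jit`, runs one warm-up call, and then times a second call. The size `N = 1000`, the random floating-point test data, and the function name `matmul_jit` are assumptions for illustration.

```python
import time
import jax
import jax.numpy as jnp

N = 1000  # hypothetical size, large enough for the device to do real work
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(key_a, (N, N))
B = jax.random.normal(key_b, (N, N))

@jax.jit
def matmul_jit(x, y):
    return jnp.matmul(x, y)

# Warm-up: the first call traces and compiles matmul_jit
matmul_jit(A, B).block_until_ready()

start_time = time.perf_counter()
C = matmul_jit(A, B).block_until_ready()  # wait for the result before stopping the clock
elapsed_time = time.perf_counter() - start_time
print("Time required to carry out the computation: ", elapsed_time)
```

Timing the warm-up call separately would show how much of the first measurement is compilation rather than arithmetic.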

