In [5]:
import time as time

import numpy as np
import tensorflow as tf
import torch as pt
import jax as j
from jax.lib import xla_bridge
import jax.tools.colab_tpu
#jax.tools.colab_tpu.setup_tpu()

# Benchmark configuration.
MATRIX_WIDTH = 1 << 10  # side length of the square matrices being multiplied (1024)
NUM_LOOPS = 10          # number of times to repeat each timed test
MULTS_PER_LOOP = 10     # matrix multiplications performed inside one timed test
NUM_TESTS = 1           # NOTE(review): overlaps with NUM_LOOPS; confirm which one the driver should pass as numLoops

# Two random square operands of float64 values drawn uniformly from the
# half-open interval [0, 10), taken in order from NumPy's global RNG
# (matrix1 first, then matrix2).
matrix1, matrix2 = (
    np.random.uniform(low=0, high=10, size=(MATRIX_WIDTH,) * 2)
    for _ in range(2)
)


#Tensorflow zone
def tensorflowBenchmark(matrix1, matrix2, numLoops = 1, multsPerIter = 1):
  """Return the average wall-clock seconds per loop of TensorFlow matmuls.

  Args:
    matrix1, matrix2: values accepted by ``tf.constant`` (the driver passes
      NumPy arrays); they are converted to tensors once, before timing, so
      conversion cost is not charged to the multiplications.
    numLoops: number of timed loops to average over.
    multsPerIter: matrix multiplications performed inside each timed loop.

  Returns:
    Total timed seconds divided by ``numLoops`` (ZeroDivisionError if it is 0).
  """
  tfMatrix1 = tf.constant(matrix1)
  tfMatrix2 = tf.constant(matrix2)
  tfTotalTime = 0.0

  for _ in range(numLoops):
    # perf_counter is monotonic and high-resolution, unlike time.time().
    tfStartTime = time.perf_counter()
    result = None
    for _ in range(multsPerIter):
      result = tf.matmul(tfMatrix1, tfMatrix2)
    if result is not None:
      # Pulling the last product to the host waits for any asynchronously
      # queued device work before the clock stops.
      # NOTE(review): assumes ops on one device execute in order, so the last
      # product finishing implies the earlier ones finished too.
      result.numpy()
    tfTotalTime += time.perf_counter() - tfStartTime

  return tfTotalTime / numLoops


#Pytorch zone; returns the avg time per loop
def pytorchBenchmark(matrix1, matrix2, numLoops = 1, multsPerIter = 1):
  """Return the average wall-clock seconds per loop of PyTorch matmuls.

  Args:
    matrix1, matrix2: NumPy arrays, wrapped with ``pt.from_numpy`` (no copy),
      so the tensors are CPU tensors with the arrays' dtype.
    numLoops: number of timed loops to average over.
    multsPerIter: matrix multiplications performed inside each timed loop.

  Returns:
    Total timed seconds divided by ``numLoops`` (ZeroDivisionError if it is 0).
  """
  ptMatrix1 = pt.from_numpy(matrix1)
  ptMatrix2 = pt.from_numpy(matrix2)
  ptTotalTime = 0.0

  for _ in range(numLoops):
    # perf_counter is monotonic and high-resolution; time.time() can jump with
    # system clock adjustments and is coarse on some platforms.
    ptStartTime = time.perf_counter()
    for _ in range(multsPerIter):
      pt.matmul(ptMatrix1, ptMatrix2)
    ptTotalTime += time.perf_counter() - ptStartTime

  return ptTotalTime / numLoops


#Jax zone
def jaxBenchmark(matrix1, matrix2, numLoops = 1, multsPerIter = 1):
  """Return the average wall-clock seconds per loop of JAX matmuls.

  Prints the name of the default JAX backend platform before timing.

  Args:
    matrix1, matrix2: array-likes (the driver passes NumPy arrays). They are
      converted to JAX arrays once, before timing, so host-to-device transfer
      is not charged to the multiplications. With JAX's default configuration
      (64-bit mode off), float64 inputs are stored as float32.
    numLoops: number of timed loops to average over.
    multsPerIter: matrix multiplications performed inside each timed loop.

  Returns:
    Total timed seconds divided by ``numLoops`` (ZeroDivisionError if it is 0).
  """
  # jax.default_backend() replaces the deprecated xla_bridge.get_backend().
  print(j.default_backend())

  jMatrix1 = j.numpy.asarray(matrix1)
  jMatrix2 = j.numpy.asarray(matrix2)

  # Warm-up: the first call for a given shape/dtype compiles the XLA
  # computation; without this, compile time lands in the first timed loop.
  j.numpy.matmul(jMatrix1, jMatrix2).block_until_ready()

  jTotalTime = 0.0
  for _ in range(numLoops):
    jStartTime = time.perf_counter()
    for _ in range(multsPerIter):
      # JAX dispatches asynchronously; without blocking, only the dispatch,
      # not the multiplication itself, would be timed.
      j.numpy.matmul(jMatrix1, jMatrix2).block_until_ready()
    jTotalTime += time.perf_counter() - jStartTime

  return jTotalTime / numLoops


#Numpy zone
def numpyBenchmark(matrix1, matrix2, numLoops = 1, multsPerIter = 1):
  """Return the average wall-clock seconds per loop of NumPy matmuls.

  Args:
    matrix1, matrix2: operands passed straight to ``np.matmul``.
    numLoops: number of timed loops to average over.
    multsPerIter: matrix multiplications performed inside each timed loop.

  Returns:
    Total timed seconds divided by ``numLoops`` (ZeroDivisionError if it is 0).
  """
  npTotalTime = 0.0

  for _ in range(numLoops):
    # perf_counter is monotonic and high-resolution; time.time() can jump with
    # system clock adjustments and is coarse on some platforms.
    npStartTime = time.perf_counter()
    for _ in range(multsPerIter):
      np.matmul(matrix1, matrix2)
    npTotalTime += time.perf_counter() - npStartTime

  return npTotalTime / numLoops


#Gather results
# Run every benchmark on the same operands. Each call times NUM_LOOPS loops of
# MULTS_PER_LOOP multiplications and returns the average seconds per loop.
# NUM_LOOPS (documented as the number of times to repeat the test) is the
# constant meant for numLoops; NUM_TESTS was passed here before and left
# NUM_LOOPS unused.
tfAvg = tensorflowBenchmark(matrix1, matrix2, numLoops = NUM_LOOPS, multsPerIter = MULTS_PER_LOOP)
ptAvg = pytorchBenchmark(matrix1, matrix2, numLoops = NUM_LOOPS, multsPerIter = MULTS_PER_LOOP)
jAvg = jaxBenchmark(matrix1, matrix2, numLoops = NUM_LOOPS, multsPerIter = MULTS_PER_LOOP)
nAvg = numpyBenchmark(matrix1, matrix2, numLoops = NUM_LOOPS, multsPerIter = MULTS_PER_LOOP)

# Report results; the tab counts keep the values roughly column-aligned.
print(f"Tensorflow average time:\t{tfAvg}")
print(f"Pytorch average time:\t\t{ptAvg}")
print(f"Jax average time:\t\t{jAvg}")
print(f"Numpy average time:\t\t{nAvg}")

ModuleNotFoundError: jax requires jaxlib to be installed. See https://github.com/google/jax#installation for installation instructions.