In [1]:
import math
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit as jaxjit
from numba import njit as numbajit

import time
from functools import wraps

import plotly.graph_objects as go

# Test JAX

In [2]:
# Retrieve the list of available devices
devices = jax.devices()

# Display information about each device
cpu_device = jax.devices("cpu")[0]
print(cpu_device)
for device in devices:
    print(f"Device ID: {device.id}, Platform: {device.platform}, Device Kind: {device.device_kind}")



TFRT_CPU_0
Device ID: 0, Platform: cpu, Device Kind: cpu


# Decorator to mesure execution time

In [3]:
def measure_time(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        elapsed_time = end_time - start_time
        print(f"Execution time for '{func.__name__}': {elapsed_time:.6f} seconds")
        return result, elapsed_time
    return wrapper

# Sinus

## Functions

### Python

In [4]:
def test_dot_python(A, B):
    row_lenght_A = len(A[0])
    if row_lenght_A != len(B):
        raise ValueError("Matrix A and B cannot be multiplied")

@measure_time
def dot_python(A, B):
    test_dot_python(A, B)
    
    result_rows = len(A)
    result_cols = len(B[0])
    result = [[0 for _ in range(result_cols)] for _ in range(result_rows)]
    
    for i in range(result_rows):
        for j in range(result_cols):
            for k in range(len(B)):
                result[i][j] += A[i][k] * B[k][j]
                
    return result

### Numpy

In [5]:
@measure_time
def dot_numpy(A, B):
    return np.dot(A, B)

### Numba

In [6]:
@numbajit
def test_dot_numba(A, B):
    row_lenght_A = len(A[0])
    if row_lenght_A != len(B):
        raise ValueError("Matrix A and B cannot be multiplied")

@measure_time
def dot_numba(A, B):
    test_dot_numba(A, B)
    
    result_rows = len(A)
    result_cols = len(B[0])
    result = [[0 for _ in range(result_cols)] for _ in range(result_rows)]
    
    for i in range(result_rows):
        for j in range(result_cols):
            for k in range(len(B)):
                result[i][j] += A[i][k] * B[k][j]
                
    return result

### JAX CPU

In [7]:
@measure_time
def dot_jax_cpu(A, B):
    return jnp.dot(A, B)

### JAX CPU JIT

In [8]:
@measure_time
@jaxjit
def dot_jax_jit_cpu(A, B):
    return jnp.dot(A, B)

## Comparison

In [9]:
n_rows, n_col = 1000, 1000

A_numpy = np.random.normal(0, 5, size=(n_rows, n_col))
B_numpy = np.random.normal(0, 5, size=(n_rows, n_col))

A_jax = jax.numpy.array(A_numpy)
B_jax = jax.numpy.array(B_numpy)

A_python = list(A_numpy)
B_python = list(B_numpy)

In [10]:
# Python
_, time_python = dot_python(A_python, B_python)

# Numpy
_, time_numpy = dot_numpy(A_numpy, B_numpy)

# Numba
print("Numba run 1")
_, time_numba_1 = dot_numba(A_python, B_python)

print("Numba run 2")
_, time_numba_2 = dot_numba(A_python, B_python)

# JAX CPU
_, time_jax_cpu = dot_jax_cpu(A_jax, B_jax)

# JAX CPU with JIT
print("JAX CPU JIT run 1")
_, time_jax_cpu_jit_1 = dot_jax_jit_cpu(A_jax, B_jax)
print("JAX CPU JIT run 2")
_, time_jax_cpu_jit_2 = dot_jax_jit_cpu(A_jax, B_jax)

Execution time for 'dot_python': 262.590900 seconds
Execution time for 'dot_numpy': 0.048383 seconds
Numba run 1
Execution time for 'dot_numba': 291.645644 seconds
Numba run 2
Execution time for 'dot_numba': 298.468739 seconds
Execution time for 'dot_jax_cpu': 0.022761 seconds
JAX CPU JIT run 1
Execution time for 'dot_jax_jit_cpu': 0.018260 seconds
JAX CPU JIT run 2
Execution time for 'dot_jax_jit_cpu': 0.000099 seconds


## Plot comparison

In [None]:
labels = [
    "Python",
    "Numpy",
    "Numba run 1",
    "Numba run 2",
    "JAX CPU",
    "JAX CPU JIT run 1",
    "JAX CPU JIT run 2",
]

times = [
    time_python,
    time_numpy,
    time_numba_1,
    time_numba_2,
    time_jax_cpu,
    time_jax_cpu_jit_1,
    time_jax_cpu_jit_2,
]

fig = go.Figure(data=[go.Bar(
    x=labels,
    y=times,
    text=[f'{t:.6f}' for t in times],
    textposition='auto',
    marker_color='skyblue'
)])

fig.update_layout(
    title=f"Execution Time Comparison of Matrices Product Implementations - A {n_rows, n_col} @ B{n_rows, n_col}",
    xaxis_title="Implementation",
    yaxis_title="Execution Time (seconds)",
    yaxis=dict(type="log"),
    template="plotly_white"
)

fig.write_html("matrix_product_benchmark_CPU_only.html")
fig.write_image("matrix_product_benchmark_CPU_only.png")
fig.show()
