In [None]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit as jaxjit
from numba import njit as numbajit

import time
from functools import wraps

import plotly.graph_objects as go

# Test JAX

In [2]:
# Retrieve the list of available devices
devices = jax.devices()

# Display information about each device
cpu_device = jax.devices("cpu")[0]
print(cpu_device)
for device in devices:
    print(f"Device ID: {device.id}, Platform: {device.platform}, Device Kind: {device.device_kind}")



TFRT_CPU_0
Device ID: 0, Platform: cpu, Device Kind: cpu


# Decorator to mesure execution time

In [3]:
def measure_time(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        elapsed_time = end_time - start_time
        print(f"Execution time for '{func.__name__}': {elapsed_time:.6f} seconds")
        return result, elapsed_time
    return wrapper

# Sum of squares

## Functions

### Python

In [4]:
@measure_time
def sum_of_squares_python(x):
    total = 0
    for i in range(len(x)):
        total += x[i] ** 2
    return total

### Numpy

In [5]:
@measure_time
def sum_of_squares_numpy(x):
    return np.sum(x ** 2)

### Numba

In [6]:
@measure_time
@numbajit
def sum_of_squares_jit(x):
    total = 0
    for i in range(len(x)):
        total += x[i] ** 2
    return total

### JAX CPU

In [7]:
@measure_time
def sum_of_squares_jax_cpu(x):
    return jnp.sum(x ** 2)

### JAX CPU JIT

In [8]:
@measure_time
@jaxjit
def sum_of_squares_jax_jit_cpu(x):
    return jnp.sum(x ** 2)

## Comparison

In [9]:
x_numpy = np.random.normal(0, 1, 10000000)
x_python = list(x_numpy)
x_jax = jnp.array(x_numpy)

In [10]:
# Python
_, time_python = sum_of_squares_python(x_python)

# Numpy
_, time_numpy = sum_of_squares_numpy(x_numpy)

# Numba
print("Numba run 1")
_, time_numba_1 = sum_of_squares_jit(x_python)

print("Numba run 2")
_, time_numba_2 = sum_of_squares_jit(x_python)

# JAX CPU
_, time_jax_cpu = sum_of_squares_jax_cpu(x_jax)

# JAX CPU with JIT
print("JAX CPU JIT run 1")
_, time_jax_cpu_jit_1 = sum_of_squares_jax_jit_cpu(x_jax)
print("JAX CPU JIT run 2")
_, time_jax_cpu_jit_2 = sum_of_squares_jax_jit_cpu(x_jax)

Execution time for 'sum_of_squares_python': 1.270238 seconds
Execution time for 'sum_of_squares_numpy': 0.029028 seconds
Numba run 1
Execution time for 'sum_of_squares_jit': 11.787377 seconds
Numba run 2
Execution time for 'sum_of_squares_jit': 11.664075 seconds
Execution time for 'sum_of_squares_jax_cpu': 0.066148 seconds
JAX CPU JIT run 1
Execution time for 'sum_of_squares_jax_jit_cpu': 0.035758 seconds
JAX CPU JIT run 2
Execution time for 'sum_of_squares_jax_jit_cpu': 0.000058 seconds


## Plot comparison

In [None]:
labels = [
    "Python",
    "Numpy",
    "Numba run 1",
    "Numba run 2",
    "JAX CPU",
    "JAX CPU JIT run 1",
    "JAX CPU JIT run 2",
]

times = [
    time_python,
    time_numpy,
    time_numba_1,
    time_numba_2,
    time_jax_cpu,
    time_jax_cpu_jit_1,
    time_jax_cpu_jit_2,
]

fig = go.Figure(data=[go.Bar(
    x=labels,
    y=times,
    text=[f'{t:.6f}' for t in times],
    textposition='auto',
    marker_color='skyblue'
)])

fig.update_layout(
    title=f"Execution Time Comparison of Sum of Squares Implementations - {x_numpy.size} iterations",
    xaxis_title="Implementation",
    yaxis_title="Execution Time (seconds)",
    yaxis=dict(type="log"),
    template="plotly_white"
)

fig.write_html("sum_of_squares_benchmark_cpu_only.html")
fig.write_image("sum_of_squares_benchmark_cpu_only.png")
fig.show()