In [1]:
import math
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit as jaxjit
from numba import njit as numbajit

import time
from functools import wraps

import plotly.graph_objects as go

# Test JAX

In [2]:
# Grab the first CPU device and show its identifier
cpu_device = jax.devices("cpu")[0]
print(cpu_device)

# Enumerate every device JAX reports and print one summary line per device
devices = jax.devices()
for device in devices:
    print(f"Device ID: {device.id}, Platform: {device.platform}, Device Kind: {device.device_kind}")



TFRT_CPU_0
Device ID: 0, Platform: cpu, Device Kind: cpu


# Decorator to measure execution time

In [3]:
def measure_time(func):
    """Decorate ``func`` so each call is timed and reported.

    The wrapped function prints the elapsed time and returns a tuple
    ``(result, elapsed_seconds)`` instead of the bare result.

    Parameters
    ----------
    func : callable
        Function to time. Positional and keyword arguments are forwarded
        unchanged.

    Returns
    -------
    callable
        Wrapper returning ``(result, elapsed_time)`` where ``elapsed_time``
        is a float number of seconds.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # when the system clock is adjusted.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        # Results that expose block_until_ready (e.g. JAX arrays, which are
        # dispatched asynchronously) may not be computed yet when the call
        # returns; wait for them so the computation itself is timed.
        wait_until_ready = getattr(result, "block_until_ready", None)
        if callable(wait_until_ready):
            wait_until_ready()
        elapsed_time = time.perf_counter() - start_time
        print(f"Execution time for '{func.__name__}': {elapsed_time:.6f} seconds")
        return result, elapsed_time
    return wrapper

# Sinus

## Functions

### Python

In [4]:
@measure_time
def sinus_python(x):
    """Sine of every element, computed one value at a time with ``math.sin``.

    Works on a copy of ``x`` so the caller's data is left untouched.
    """
    result = x.copy()
    for index, value in enumerate(x):
        result[index] = math.sin(value)
    return result

### Numpy

In [5]:
@measure_time
def sinus_numpy(x):
    """Sine of every element of ``x`` using NumPy's vectorised ``np.sin``."""
    sines = np.sin(x)
    return sines

### Numba

In [6]:
@measure_time
@numbajit
def sinus_numba(x):
    """Numba-compiled explicit loop applying ``math.sin`` to a copy of ``x``."""
    out = x.copy()
    n_items = len(out)
    for idx in range(n_items):
        out[idx] = math.sin(out[idx])
    return out

### Numba + Numpy

In [7]:
@measure_time
@numbajit
def sinus_numba_numpy(x):
    """Numba-compiled call to ``np.sin`` over ``x``."""
    sines = np.sin(x)
    return sines

### JAX CPU

In [8]:
@measure_time
def sinus_jax_cpu(x):
    """Sine of every element of ``x`` via ``jax.numpy``, without explicit JIT."""
    sines = jnp.sin(x)
    return sines

### JAX CPU JIT

In [9]:
@measure_time
@jaxjit
def sinus_jax_jit_cpu(x):
    """Sine of every element of ``x`` via ``jax.numpy``, wrapped in ``jax.jit``."""
    sines = jnp.sin(x)
    return sines

## Comparison

In [10]:
# Benchmark input shared by every implementation.
N_SAMPLES = 100_000_000  # number of elements passed to each sine implementation
SEED = 0  # fixed seed so a fresh "Restart & Run All" regenerates the same input

rng = np.random.default_rng(SEED)

# Normally distributed samples with mean 0 and standard deviation 2*pi
x_numpy = rng.normal(0, 2*np.pi, N_SAMPLES)
# Independent copy used by the pure-Python and Numba loop versions
x_python = np.copy(x_numpy)
# NOTE(review): presumably this becomes float32 unless jax_enable_x64 is set,
# while the NumPy inputs are float64 — confirm before comparing timings.
x_jax = jnp.array(x_numpy)

In [11]:
# Each decorated function returns (result, elapsed_seconds). Only the elapsed
# time is kept: indexing [1] lets the large result array be released right
# away instead of staying alive in `_` during the next benchmark, and it
# avoids shadowing IPython's `_` last-output variable.

# Python
time_python = sinus_python(x_python)[1]

# Numpy
time_numpy = sinus_numpy(x_numpy)[1]

# Numba (run 1 includes JIT compilation, run 2 reuses the compiled code)
print("Numba run 1")
time_numba_1 = sinus_numba(x_python)[1]

print("Numba run 2")
time_numba_2 = sinus_numba(x_python)[1]

# Numba with numpy
print("Numba with numpy run 1")
time_numba_numpy_1 = sinus_numba_numpy(x_numpy)[1]
print("Numba with numpy run 2")
time_numba_numpy_2 = sinus_numba_numpy(x_numpy)[1]

# JAX CPU
time_jax_cpu = sinus_jax_cpu(x_jax)[1]

# JAX CPU with JIT (run 1 includes tracing/compilation, run 2 is the cached call)
print("JAX CPU JIT run 1")
time_jax_cpu_jit_1 = sinus_jax_jit_cpu(x_jax)[1]
print("JAX CPU JIT run 2")
time_jax_cpu_jit_2 = sinus_jax_jit_cpu(x_jax)[1]

Execution time for 'sinus_python': 19.641624 seconds
Execution time for 'sinus_numpy': 2.413217 seconds
Numba run 1
Execution time for 'sinus_numba': 3.076962 seconds
Numba run 2
Execution time for 'sinus_numba': 3.059247 seconds
Numba with numpy run 1
Execution time for 'sinus_numba_numpy': 3.074860 seconds
Numba with numpy run 2
Execution time for 'sinus_numba_numpy': 3.025301 seconds
Execution time for 'sinus_jax_cpu': 0.334587 seconds
JAX CPU JIT run 1
Execution time for 'sinus_jax_jit_cpu': 0.332503 seconds
JAX CPU JIT run 2
Execution time for 'sinus_jax_jit_cpu': 0.317630 seconds


## Plot comparison

In [12]:
# Pair every implementation label with its measured duration (insertion order
# defines the left-to-right order of the bars).
timings = {
    "Python": time_python,
    "Numpy": time_numpy,
    "Numba run 1": time_numba_1,
    "Numba run 2": time_numba_2,
    "Numba with numpy run 1": time_numba_numpy_1,
    "Numba with numpy run 2": time_numba_numpy_2,
    "JAX CPU": time_jax_cpu,
    "JAX CPU JIT run 1": time_jax_cpu_jit_1,
    "JAX CPU JIT run 2": time_jax_cpu_jit_2,
}
labels = list(timings.keys())
times = list(timings.values())

# One bar per implementation, annotated with the time to six decimals
bar_trace = go.Bar(
    x=labels,
    y=times,
    text=[f'{t:.6f}' for t in times],
    textposition='auto',
    marker_color='skyblue',
)
fig = go.Figure(data=[bar_trace])

# Log-scaled y axis so the fast and slow implementations are both readable
fig.update_layout(
    title=f"Execution Time Comparison of Sinus Implementations - {x_numpy.size} size",
    xaxis_title="Implementation",
    yaxis_title="Execution Time (seconds)",
    yaxis={"type": "log"},
    template="plotly_white",
)

# Export the chart as HTML and PNG, then render it inline
fig.write_html("sinus_benchmark_CPU_only.html")
fig.write_image("sinus_benchmark_CPU_only.png")
fig.show()
