In [14]:
import numpy as np
import timeit
import torch

try:
    import cupy as cp
except ImportError:
    cp = None

try:
    import jax.numpy as jnp
    from jax import jit
    import jax
except ImportError:
    jnp = None

import pandas as pd
from IPython.display import display


In [15]:
# Create a benchmarking function
def benchmark(func, setup="pass", number=10, namespace=None):
    """Return the average wall-clock time of ``func`` in milliseconds per run.

    Args:
        func: Statement string or zero-argument callable, passed to
            ``timeit.timeit`` as the timed statement.
        setup: Setup statement (string or callable) passed to ``timeit.timeit``.
        number: How many times ``func`` is executed; the total is divided by it.
        namespace: Optional dict forwarded as ``globals`` to ``timeit.timeit``
            so string statements can reference names defined by the caller
            (e.g. ``namespace=globals()``). ``None`` (the default) keeps
            timeit's own isolated namespace, matching the previous behaviour.

    Returns:
        float: Mean time per run in milliseconds, rounded to 4 decimals.
    """
    total_seconds = timeit.timeit(func, setup=setup, number=number, globals=namespace)
    # seconds for `number` runs -> milliseconds for a single run
    return round(total_seconds / number * 1000, 4)


In [16]:
# Shared state for the benchmark cells
# Each cell appends one (library, operation, array size, ms per run) tuple here.
results = list()

# Edge length of the square matrices used by every benchmark.
size = 1_000

In [17]:
# NumPy
# The operands are created inside the timeit setup string, so only the
# np.dot call itself is timed.
numpy_func = "np.dot(x, y)"
numpy_setup = f"import numpy as np; x = np.random.rand({size}, {size}); y = np.random.rand({size}, {size})"
numpy_time = benchmark(numpy_func, setup=numpy_setup)
results.append(
    ('NumPy', 'dot', size, numpy_time)
)

In [20]:
# Torch
# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the operands once, outside the timed region, so that only the matrix
# product is measured (the NumPy cell does the same via its setup string).
torch_x = torch.rand(size, size, device=device)
torch_y = torch.rand(size, size, device=device)


def torch_func():
    """Compute the matrix product of the pre-built operands on ``device``.

    The result row is labelled 'dot', so this performs a real matrix product
    (``torch.matmul``) rather than an elementwise multiply followed by a sum.
    On CUDA the call waits for the kernel to finish; without that the timer
    would only capture the asynchronous launch.
    """
    product = torch.matmul(torch_x, torch_y)
    if device.type == "cuda":
        torch.cuda.synchronize()
    return product


torch_func()  # warm-up run so one-time initialisation is not timed
# Passing the callable directly avoids relying on `from __main__ import ...`.
torch_time = benchmark(torch_func)
results.append(('Torch', 'dot', size, torch_time))

In [21]:
# CuPy
if cp is not None:  # cp is set to None when the cupy import failed
    # Build the operands once, outside the timed region, so that only the
    # dot product is measured (the NumPy cell does the same via its setup).
    cupy_x = cp.random.rand(size, size)
    cupy_y = cp.random.rand(size, size)

    def cupy_func():
        """Compute the dot product of the pre-built CuPy operands.

        Waits on the default stream before returning so the measured time
        covers the GPU work and not just the asynchronous kernel launch.
        """
        product = cp.dot(cupy_x, cupy_y)
        cp.cuda.Stream.null.synchronize()
        return product

    cupy_func()  # warm-up run so one-time initialisation is not timed
    # Passing the callable directly avoids relying on `from __main__ import ...`.
    cupy_time = benchmark(cupy_func)
    results.append(('CuPy', 'dot', size, cupy_time))
else:
    # Keep a placeholder row so the table still lists the library.
    results.append(('CuPy (not installed)', 'dot', size, None))

In [23]:
# JAX
def benchmark(func, setup="pass", number=10):
    """Return the average time of ``func`` in milliseconds per run.

    NOTE(review): this shadows the ``benchmark`` defined in an earlier cell.
    The only difference is ``globals=globals()``, which lets string
    statements such as ``"jax_func()"`` resolve names from this module's
    namespace.

    Args:
        func: Statement string or callable passed to ``timeit.timeit``.
        setup: Setup statement passed to ``timeit.timeit``.
        number: Number of executions; the total time is divided by it.

    Returns:
        float: Mean time per run in milliseconds, rounded to 4 decimals.
    """
    total_seconds = timeit.timeit(
        func,
        setup=setup,
        number=number,
        globals=globals(),
    )
    per_run_ms = total_seconds / number * 1000
    return round(per_run_ms, 4)

if jnp is not None:  # jnp is set to None when the jax import failed
    @jax.jit
    def jax_dot(x, y):
        """JIT-compiled dot product of ``x`` and ``y``."""
        return jnp.dot(x, y)

    # Build the operands once, outside the timed region, so that only the
    # dot product is measured (the NumPy cell does the same via its setup).
    jax_key_x, jax_key_y = jax.random.split(jax.random.PRNGKey(0))
    jax_x = jax.random.uniform(jax_key_x, shape=(size, size))
    jax_y = jax.random.uniform(jax_key_y, shape=(size, size))

    def jax_func():
        """Run the jitted dot product and wait for the result to be ready."""
        return jax_dot(jax_x, jax_y).block_until_ready()

    # Warm-up run: the first call of a jitted function includes compilation,
    # which would otherwise be averaged into the reported per-run time.
    jax_func()

    # Passing the callable works regardless of which namespace timeit uses.
    jax_time = benchmark(jax_func)
    results.append(('JAX', 'dot (JIT)', size, jax_time))
else:
    # Keep a placeholder row so the table still lists the library.
    results.append(('JAX (not installed)', 'dot', size, None))

In [24]:
# Turn the collected result tuples into a table, one row per library.
df = pd.DataFrame(
    results,
    columns=["Library", "Operation", "Array Size", "Time (ms)"],
)

# Hand the table to the ace_tools_open viewer for display.
import ace_tools_open as tools
tools.display_dataframe_to_user(
    name="Array Dot Product Benchmark",
    dataframe=df,
)


Array Dot Product Benchmark


Library,Operation,Array Size,Time (ms)
Loading ITables v2.3.0 from the internet... (need help?),,,
