# In [4]:
import numpy as np
from numba import njit, prange
import time


# NumPy implementation
def van_leer_numpy(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Return the element-wise value ``2*a*b / (a + b)`` using NumPy.

    Despite the name, no same-sign guard is applied: elements where ``a`` and
    ``b`` have opposite signs are NOT zeroed, and elements where
    ``a + b == 0`` follow NumPy's floating-point division rules
    (``inf``/``nan`` plus a RuntimeWarning).

    Args:
        a: First operand.
        b: Second operand; must be broadcast-compatible with ``a``.

    Returns:
        A newly allocated array holding ``2*a*b / (a + b)`` for every element.
    """
    # The vectorized expression allocates its own output, so pre-allocating a
    # zeroed buffer first would only be thrown away.
    return (2 * a * b) / (a + b)


# Numba implementation
@njit(parallel=True)
def van_leer_numba(a: np.ndarray, b: np.ndarray):
    """Return the element-wise value ``2*a*b / (a + b)`` for 3-D arrays.

    Numba-compiled counterpart of the NumPy version: the outermost axis is
    iterated with ``prange`` (parallel), the two inner axes sequentially.
    No same-sign guard is applied, so opposite-sign inputs are not zeroed.

    Args:
        a: 3-D array (indexed as ``a[i, j, k]``).
        b: Array indexed with the same ``[i, j, k]`` ranges as ``a``.

    Returns:
        A new array shaped and typed like ``a`` holding the result.
    """
    # NOTE(review): behavior when a + b == 0 depends on Numba's error model
    # (it may raise instead of producing inf/nan) -- confirm if that matters.

    # Every element is assigned in the loops below, so the output buffer does
    # not need to be zero-filled first.
    result = np.empty_like(a)
    n_i, n_j, n_k = a.shape
    for i in prange(n_i):
        for j in range(n_j):
            for k in range(n_k):
                # Read each operand once instead of indexing twice.
                x = a[i, j, k]
                y = b[i, j, k]
                result[i, j, k] = (2 * x * y) / (x + y)
    return result


# Test data
a = np.random.rand(10**3, 10**3, 5) - 0.5  # Random data with positive/negative values
b = np.random.rand(10**3, 10**3, 5) - 0.5

# Timing uses time.perf_counter(): a monotonic, high-resolution clock meant for
# measuring short intervals, unlike time.time(), which follows the (adjustable)
# system wall clock.

# Benchmark NumPy
start = time.perf_counter()
result_numpy = van_leer_numpy(a, b)
print(f"NumPy time: {time.perf_counter() - start:.4f} seconds")

# Benchmark Numba (first run includes compilation time)
start = time.perf_counter()
result_numba = van_leer_numba(a, b)
print(f"Numba time (first run): {time.perf_counter() - start:.4f} seconds")

# Benchmark Numba (subsequent runs)
start = time.perf_counter()
result_numba = van_leer_numba(a, b)
print(f"Numba time (second run): {time.perf_counter() - start:.4f} seconds")

# Verify correctness: both implementations must agree element-wise
assert np.allclose(result_numpy, result_numba), "Results differ!"

# Sample output:
# NumPy time: 0.0326 seconds
# Numba time (first run): 0.6878 seconds
# Numba time (second run): 0.0264 seconds
