# Numba for Scientific Computing

This notebook demonstrates Numba JIT compilation, vectorization, parallelization, benchmarking, and inline correctness checks.

In [None]:
import math
import time
from typing import Callable, Tuple

import numba as nb
import numpy as np

# Seed NumPy's legacy global RNG so the np.random.randn draws below are reproducible.
np.random.seed(seed=42)

## Baseline NumPy operations
We start with pure NumPy helpers to compare against Numba.

In [None]:
def pairwise_l2_numpy(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Return the Euclidean distance between every row of ``a`` and every row of ``b``.

    Entry ``[i, j]`` of the result is ``sqrt(sum_k (a[i, k] - b[j, k]) ** 2)``.
    Broadcasting materializes an ``(n_a, n_b, d)`` array of differences, so
    memory use grows with the product of all three sizes.
    """
    deltas = a[:, np.newaxis, :] - b[np.newaxis, :, :]
    squared_norms = np.square(deltas).sum(axis=2)
    return np.sqrt(squared_norms)


def row_sums_numpy(x: np.ndarray) -> np.ndarray:
    """Sum ``x`` along axis 1 (one total per row for a 2-D input)."""
    return np.sum(x, axis=1)


def make_arrays(n: int = 300, d: int = 32) -> Tuple[np.ndarray, np.ndarray]:
    """Draw two ``(n, d)`` standard-normal float64 arrays.

    Uses NumPy's global legacy RNG, so the values depend on the current
    ``np.random.seed`` state; ``a`` is drawn before ``b``.

    Args:
        n: Number of rows in each array.
        d: Number of columns in each array.

    Returns:
        Tuple ``(a, b)`` of float64 arrays, each of shape ``(n, d)``.
    """
    # ``randn`` already returns float64, so an ``astype(np.float64)`` call
    # would only make a redundant copy of each array.
    a = np.random.randn(n, d)
    b = np.random.randn(n, d)
    return a, b


# Small reference inputs reused for warm-up compilation and correctness checks.
a_ref, b_ref = make_arrays(n=64, d=16)

## `njit` for faster loops
Use nopython mode to compile tight loops without Python overhead.

In [None]:
@nb.njit
def pairwise_l2_numba(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Return the Euclidean distance between every row of ``a`` and every row of ``b``.

    Explicit triple loop compiled in nopython mode. Unlike the broadcasting
    NumPy version, it needs no ``(n_a, n_b, d)`` temporary.

    Raises:
        ValueError: If ``a`` and ``b`` have different numbers of columns.
    """
    n_a, d = a.shape
    n_b, d_b = b.shape
    # Numba does not bounds-check indexing by default. Without this check, a
    # narrower ``b`` would be read out of bounds, and a wider ``b`` would have
    # its extra columns silently ignored.
    if d_b != d:
        raise ValueError("a and b must have the same number of columns")
    out = np.empty((n_a, n_b), dtype=np.float64)
    for i in range(n_a):
        for j in range(n_b):
            s = 0.0
            for k in range(d):
                diff = a[i, k] - b[j, k]
                s += diff * diff
            out[i, j] = math.sqrt(s)
    return out


@nb.njit
def row_sums_numba(x: np.ndarray) -> np.ndarray:
    """Sum each row of ``x`` with explicit compiled loops.

    Each row is accumulated left to right into a float64 scalar. The results
    may therefore differ from ``x.sum(axis=1)`` in the last bits.
    """
    n_rows = x.shape[0]
    n_cols = x.shape[1]
    totals = np.empty(n_rows, dtype=np.float64)
    for r in range(n_rows):
        acc = 0.0
        for c in range(n_cols):
            acc += x[r, c]
        totals[r] = acc
    return totals

# Warm-up compilation
# Warm-up compilation: the first call to each @nb.njit function compiles it
# for these argument types, so the benchmark cell below times only execution.
_ = pairwise_l2_numba(a_ref, b_ref)
_ = row_sums_numba(a_ref)

## `vectorize` / `guvectorize`
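`@vectorize` compiles a scalar kernel into a NumPy ufunc that broadcasts elementwise. `@guvectorize` builds a generalized ufunc that operates on core dimensions — here, it reduces one row of length `d` to a scalar.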

In [None]:
@nb.vectorize(['float64(float64, float64)'], nopython=True)
def add_vec(x, y):
    """Elementwise ``x + y``, compiled into a float64 NumPy ufunc."""
    total = x + y
    return total


@nb.guvectorize(['void(float64[:], float64[:])'], '(d)->()', nopython=True)
def row_sum_gu(x, out):
    """Generalized ufunc that reduces a length-``d`` row ``x`` to its sum.

    The scalar ``()`` output is declared as ``float64[:]``, so the result is
    written to ``out[0]``.
    """
    total = 0.0
    for value in x:
        total += value
    out[0] = total


# NOTE: this randn draw advances the global RNG, so it changes the values of
# arrays drawn later in the notebook (e.g. small_x).
sample_small = np.random.randn(4, 3)
# Both ufuncs above were declared with explicit signatures, which Numba
# compiles at decoration time; these calls mainly act as a smoke test.
add_vec(sample_small, sample_small)  # warm-up / smoke test
row_sum_gu(sample_small)  # warm-up / smoke test

## Parallel `prange`
Parallelize independent iterations with `parallel=True` and `prange`.

In [None]:
@nb.njit(parallel=True)
def row_sums_parallel(x: np.ndarray) -> np.ndarray:
    """Sum each row of ``x``, spreading the rows across threads with ``prange``.

    Each outer iteration writes only ``totals[r]`` and uses its own
    accumulator, so iterations are independent and need no synchronization.
    """
    n_rows = x.shape[0]
    n_cols = x.shape[1]
    totals = np.empty(n_rows, dtype=np.float64)
    for r in nb.prange(n_rows):
        acc = 0.0
        for c in range(n_cols):
            acc += x[r, c]
        totals[r] = acc
    return totals

# Compile the parallel kernel now so the benchmark cell below excludes compile time.
row_sums_parallel(a_ref)  # warm-up

## Benchmark helpers
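Report the best of several timed runs to smooth out jitter. First-call compile cost is excluded because every jitted function was already called once (warmed up) in the cells above.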

In [None]:
def time_best_of(fn: Callable, *args, repeat: int = 3, number: int = 1) -> float:
    """Return the fastest per-call wall time of ``fn(*args)``, in seconds.

    The call runs in ``repeat`` rounds of ``number`` back-to-back calls. Each
    round's elapsed time is divided by ``number``, and the minimum over all
    rounds is returned, which filters out scheduling jitter. Compile cost is
    not excluded here, so warm up jitted functions before timing them.

    Args:
        fn: Callable to time; its return value is discarded.
        *args: Positional arguments forwarded to ``fn``.
        repeat: Number of timing rounds (must be >= 1).
        number: Calls per round (must be >= 1).

    Raises:
        ValueError: If ``repeat`` or ``number`` is less than 1.
    """
    # Without these checks, repeat=0 silently returns inf and number=0
    # raises ZeroDivisionError.
    if repeat < 1:
        raise ValueError(f"repeat must be >= 1, got {repeat}")
    if number < 1:
        raise ValueError(f"number must be >= 1, got {number}")
    best = float("inf")
    for _round in range(repeat):
        start = time.perf_counter()
        for _call in range(number):
            fn(*args)
        elapsed = time.perf_counter() - start
        best = min(best, elapsed / number)
    return best


def benchmark_row_sums(x: np.ndarray):
    """Time the NumPy, ``@njit`` and parallel row-sum implementations on ``x``.

    Returns a dict with the best per-call seconds for each implementation
    plus the ratio of NumPy time to parallel time.
    """
    timings = {
        "numpy_s": time_best_of(row_sums_numpy, x),
        "numba_njit_s": time_best_of(row_sums_numba, x),
        "numba_parallel_s": time_best_of(row_sums_parallel, x),
    }
    timings["speedup_parallel_vs_numpy"] = timings["numpy_s"] / timings["numba_parallel_s"]
    return timings


def benchmark_pairwise(a: np.ndarray, b: np.ndarray):
    """Time the NumPy and ``@njit`` pairwise-L2 implementations on ``(a, b)``.

    Each implementation is timed with ``repeat=2`` rounds. Returns a dict with
    the best per-call seconds for each and their ratio (NumPy / Numba).
    """
    timings = {
        "numpy_s": time_best_of(pairwise_l2_numpy, a, b, repeat=2),
        "numba_njit_s": time_best_of(pairwise_l2_numba, a, b, repeat=2),
    }
    timings["speedup"] = timings["numpy_s"] / timings["numba_njit_s"]
    return timings

# Run both benchmarks. The final bare expression makes the notebook display
# the two result dicts as this cell's output.
small_x = np.random.randn(256, 32)
bench_row = benchmark_row_sums(small_x)
bench_pair = benchmark_pairwise(a_ref, b_ref)
bench_row, bench_pair

## Correctness checks
Compare Numba outputs to NumPy baselines within a tight tolerance.

In [None]:
# Each Numba implementation must match its NumPy reference. assert_allclose
# uses its default tolerance (rtol=1e-7, atol=0).
np.testing.assert_allclose(pairwise_l2_numpy(a_ref, b_ref), pairwise_l2_numba(a_ref, b_ref))
# Non-square case: catches loops that mix up the a/b row counts.
np.testing.assert_allclose(
    pairwise_l2_numpy(a_ref[:10], b_ref), pairwise_l2_numba(a_ref[:10], b_ref)
)
np.testing.assert_allclose(row_sums_numpy(a_ref), row_sums_numba(a_ref))
np.testing.assert_allclose(row_sums_numpy(a_ref), row_sums_parallel(a_ref))
np.testing.assert_allclose(row_sums_numpy(a_ref), row_sum_gu(a_ref))
# add_vec is otherwise only called as a smoke test; verify its values too.
np.testing.assert_allclose(a_ref + b_ref, add_vec(a_ref, b_ref))
print("Correctness checks passed.")

## Inspect types and layout
Print the compiled signatures and check that the input is C-contiguous, which gives the best speed.

In [None]:
# Numba dispatchers list the argument-type signatures compiled so far.
print("pairwise signatures:", pairwise_l2_numba.signatures)
print("row_sums signatures:", row_sums_numba.signatures)
print("parallel signatures:", row_sums_parallel.signatures)
# Ufunc objects describe their compiled loops via ``types`` (e.g. 'dd->d').
# Gufuncs also carry their core-dimension layout in ``signature``. There is
# no ``signatures`` attribute on ufunc objects.
print("vectorize types:", add_vec.types)
print("guvectorize signature:", row_sum_gu.signature, "types:", row_sum_gu.types)

# Jitted kernels specialize on dtype and memory layout; confirm the reference
# input is C-contiguous float64.
print("a_ref contiguous:", a_ref.flags["C_CONTIGUOUS"], "dtype", a_ref.dtype)

## Save benchmark results
Persist a small CSV for later inspection.

In [None]:
import csv

rows = [
    {"case": "row_sums", **bench_row},
    {"case": "pairwise", **bench_pair},
]

# The two benchmark dicts have different keys ("speedup_parallel_vs_numpy"
# vs "speedup"). DictWriter raises ValueError for keys missing from
# fieldnames, so build the header from the union of all keys (in first-seen
# order) and leave missing cells empty.
fieldnames = list(dict.fromkeys(key for row in rows for key in row))

csv_path = "numba_benchmarks.csv"
with open(csv_path, "w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(f, fieldnames=fieldnames, restval="")
    writer.writeheader()
    writer.writerows(rows)

csv_path