### Using `numba.jit` to speed up the computation of the Euclidean distance matrix


<a href="https://colab.research.google.com/github/Ziaeemehr/workshop_hpcpy/blob/main/notebooks/numba/euclidean-distance-matrix-numba.jit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


### Requirements for a Function to be Jittable by Numba

For a Python function to be successfully compiled and optimized by Numba's `@numba.jit` decorator, it must adhere to certain constraints. Numba translates Python code to efficient machine code, but it only supports a subset of Python features. Here are the key requirements:

#### 1. **Supported Data Types and Operations**
   - Use NumPy arrays, scalars (int, float, complex), and tuples.
   - Avoid heterogeneous Python lists, dictionaries, sets, or other dynamic containers (unless they are compile-time constants); Numba needs every container to have a single, statically known element type.
   - Stick to basic arithmetic, comparisons, and NumPy functions that Numba supports.


In [2]:
import numba 
@numba.jit
def add_arrays(a, b):
    return a + b  # Works with NumPy arrays

#### 2. **Static Typing**
   - Numba infers types at compile time. Avoid code that relies on dynamic typing or type changes at runtime.
   - Use explicit type annotations if needed (e.g., via `numba.types`).

   **How to use `numba.types`:**


In [None]:
import numba.types as types

@numba.jit(types.float64(types.float64, types.int32))
def explicit_types(x, n):
    """Function with explicit type signature."""
    return x ** n

# Or using string signatures (more common):
@numba.jit('f8(i4)')
def string_signature(x):
    return x * 2.0

#### 3. **Control Flow**
   - Loops (`for`, `while`), conditionals (`if`, `else`), and recursion are supported, but keep them simple.
   - Avoid exceptions or try-except blocks unless necessary.

In [4]:
@numba.jit
def sum_if_positive(arr):
    total = 0.0
    for val in arr:
        if val > 0:
            total += val
    return total

#### 4. **Function Calls**
   - Functions called from jitted code must themselves be compiled with `@numba.jit`, or be built-in and NumPy functions that Numba supports.
   - Avoid calling arbitrary Python functions or libraries not supported by Numba.

In [5]:
@numba.jit
def helper(x):
    return x ** 2

@numba.jit
def caller(arr):
    return helper(arr[0])  # Calling another jitted function


#### 5. **No Side Effects**
   - Keep functions as pure as possible: file I/O is not supported in nopython mode, and global variables are treated as compile-time constants. Modifying arrays passed as arguments is fine.
   - `print` works in nopython mode only for numbers and strings. For anything richer, return the values and print them outside the compiled function.


In [6]:
@numba.jit
def pure_add(a, b):
    return a + b  # No side effects, just computation

#### 6. **Compilation Modes**
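   - Use `nopython=True` for best performance (forces "no Python" mode). Since Numba 0.59 this is the default for `@numba.jit`.
   - With `nopython=True`, code that cannot be compiled raises an error instead of silently falling back to object mode, which is slower. Object mode has to be requested explicitly with `forceobj=True`.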



In [7]:
@numba.jit(nopython=True)
def fast_func(x):
    return x * 2  # Compiled without Python fallback

#### 7. **Parallelism**
   - For parallel execution, use `parallel=True` and `numba.prange` for loops.
   - Ensure no data races or dependencies between loop iterations.


In [None]:
import numba
numba.set_num_threads(4)

@numba.jit(parallel=True)
def parallel_sum(arr):
    total = 0.0
    for i in numba.prange(len(arr)): # Parallel execution
        total += arr[i]
    return total

@numba.jit(parallel=True)  
def sequential_sum(arr):
    total = 0.0
    for i in range(len(arr)):  # Sequential execution despite parallel=True
        total += arr[i]
    return total

- **Conditional logic**

In [None]:
import numpy as np

@numba.jit(nopython=True)
def example_func(x, y):
    result = np.zeros_like(x)
    for i in range(len(x)):
        if x[i] > 0:
            result[i] = x[i] + y[i]
    return result


### Common Pitfalls
- Using unsupported NumPy functions (check Numba documentation).
- Creating lists with `append` (use pre-allocated arrays).
- Complex data structures or object-oriented code.

For more details, refer to the [Numba documentation](https://numba.readthedocs.io/en/stable/user/jit.html).

### Examples

In this notebook we implement a function to compute the Euclidean distance matrix using Numba's *just-in-time* compilation decorator. We compare it with the NumPy function we wrote before.

We will use two Numba features here: the decorator `@numba.jit` and `numba.prange`.


$$
d(i, j) = \sqrt{\sum_{k=1}^{m} \left( x_{ik} - y_{jk} \right)^2}
$$

Note that the implementations below return the squared distance $d(i, j)^2$, that is, they omit the final square root.

In [None]:
import numpy as np
import numba

numba.set_num_threads(4)

In [None]:
@numba.jit(nopython=True, parallel=True)
def euclidean_numba_prange(x, y):
    """Implementation with numba using prange in inner loop."""

    nrows, ncols = x.shape
    dist_matrix = np.zeros((nrows, nrows))
    for i in range(nrows):
        for j in range(nrows):
            r = 0.0
            for k in numba.prange(ncols):
                r += (x[i][k] - y[j][k])**2
            dist_matrix[i][j] = r

    return dist_matrix


@numba.jit(nopython=True, parallel=True)
def euclidean_numba_vectorized(x, y):
    """Implementation with numba using vectorized numpy operations."""

    nrows, ncols = x.shape
    dist_matrix = np.zeros((nrows, nrows))
    for i in range(nrows):
        for j in numba.prange(nrows):
            dist_matrix[i][j] = ((x[i] - y[j])**2).sum()

    return dist_matrix

Let's include our NumPy implementation here for comparison.

In [None]:
def euclidean_numpy(x, y):
    """Squared Euclidean distance matrix.
    
    Inputs:
    x: (N, m) numpy array
    y: (N, m) numpy array
    
    Output:
    (N, N) squared Euclidean distance matrix:
    r_ij = sum_k (x_ik - y_jk)^2
    """

    x2 = np.einsum('ij,ij->i', x, x)[:, np.newaxis]
    y2 = np.einsum('ij,ij->i', y, y)[:, np.newaxis].T

    xy = np.dot(x, y.T)

    return np.abs(x2 + y2 - 2. * xy)
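
The NumPy version relies on expanding the square. Written with the symbols of the formula above (row $i$ of `x`, row $j$ of `y`, and $m$ columns):

$$
\sum_{k=1}^{m} \left( x_{ik} - y_{jk} \right)^2 = \sum_{k=1}^{m} x_{ik}^2 + \sum_{k=1}^{m} y_{jk}^2 - 2 \sum_{k=1}^{m} x_{ik} y_{jk}
$$

The three terms on the right correspond to `x2`, `y2`, and `2. * xy` in the code, combined by broadcasting. Because of floating-point rounding, the right-hand side can come out slightly negative when two rows are nearly identical; the `np.abs` call is presumably there to guard against that.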

### Note
In `euclidean_numba_prange`, the inner loop, which is a reduction over `k`, is written with `numba.prange`. `numba.prange` automatically takes care of data privatization and reductions.

### Exercise 1
Before running the different functions, can you predict which of the two Numba implementations will be faster?

In [None]:
# Let's check that they all give the same result
rng = np.random.default_rng()
x = 10. * rng.random((100, 10))

print("Max diff numpy vs prange:", np.abs(euclidean_numpy(x, x) - euclidean_numba_prange(x, x)).max())
print("Max diff numpy vs vectorized:", np.abs(euclidean_numpy(x, x) - euclidean_numba_vectorized(x, x)).max())

Max diff numpy vs prange: 3.126388037344441e-13
Max diff numpy vs vectorized: 3.126388037344441e-13


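Depending on the machine, our Numba implementations can be competitive with the NumPy one for a list of small vectors. In the run shown below, however, NumPy is still the fastest, and the version with `prange` in the inner loop is more than a hundred times slower: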

In [None]:
import timeit

def timing(nrow, ncols):
    x = 10. * rng.random((nrow, ncols))
    
    times = []
    for name, func in [("numpy", lambda: euclidean_numpy(x, x)), 
                       ("prange", lambda: euclidean_numba_prange(x, x)), 
                       ("vectorized", lambda: euclidean_numba_vectorized(x, x))]:
        t = timeit.timeit(func, number=10)
        times.append((name, t / 10 * 1e6))  # Convert to microseconds
    
    print(f"{'Method':<12} {'Time (µs)':<10}")
    print("-" * 25)
    for name, t in times:
        print(f"{name:<12} {t:<10.2f}")

In [None]:
timing(100, 3)

Method       Time (µs) 
-------------------------
numpy        141.70    
prange       22675.43  
vectorized   234.54    


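With somewhat longer vectors the ranking is the same: our NumPy implementation is about twice as fast as the vectorized Numba version and far faster than the `prange` version: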

In [None]:
timing(100, 10)

Method       Time (µs) 
-------------------------
numpy        135.59    
prange       24748.91  
vectorized   257.67    


In [None]:
nrows = 1000
ncols = 50

x = 10. * rng.random((nrows, ncols))

times = []
for name, func in [("numpy", lambda: euclidean_numpy(x, x)), 
                   ("prange", lambda: euclidean_numba_prange(x, x)),
                   ("vectorized", lambda: euclidean_numba_vectorized(x, x))]:
    t = timeit.timeit(func, number=5)  # Fewer runs for larger arrays
    times.append((name, t / 5 * 1e3))  # Convert to milliseconds

print(f"{'Method':<12} {'Time (ms)':<10}")
print("-" * 25)
for name, t in times:
    print(f"{name:<12} {t:<10.2f}")

Method       Time (ms) 
-------------------------
numpy        13.06     
prange       2246.61   
vectorized   14.13     
