# 4. Accelerating Python Code with Numba

## The Economist's Need for Speed

Economic models, especially those involving simulation, optimization, or repeated estimation, are often computationally intensive. A common bottleneck is the execution speed of pure Python code. Python is valued for its readability and ease of use, but the standard interpreter executes loops one bytecode instruction at a time, so an explicit loop can run orders of magnitude slower than the equivalent loop in a compiled language such as C, C++, or Fortran.

Traditionally, overcoming this involved complex workflows: writing performance-critical code in a low-level language, compiling it, and then writing Python "wrappers" to call it. This process is time-consuming and requires multi-language expertise.

**Numba** changes this paradigm. Numba is a **Just-In-Time (JIT) compiler** that translates a subset of Python and NumPy code into fast, native machine code. For numerical, loop-heavy functions it often delivers performance comparable to C or Fortran without leaving the Python ecosystem.

In this notebook, you will learn:
- What JIT compilation is and how it works.
- How to use Numba's decorators to accelerate your functions with a single line of code.
- How to benchmark and quantify the performance gains.
- Best practices for using Numba effectively.

### What is a Decorator?

Before diving into Numba, it's useful to understand the concept of a **decorator**. In Python, a decorator is a function that takes another function as input and returns a replacement for it, usually an enhanced version of the original. It lets you add behavior to a function without editing the function's own code. Decorators are applied with the `@` symbol placed directly above a function definition.

For example, `@njit` is a decorator provided by the Numba library. When you write:
```python
@njit
def my_function(x):
    return x * 2
```
You are essentially telling Python to do this:
```python
def my_function(x):
    return x * 2
my_function = njit(my_function)
```
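The `@njit` decorator takes your original `my_function` and replaces it with a Numba dispatcher object. Nothing is compiled at this point: compilation happens the first time you call the function, once Numba can see the types of the arguments. Later calls with the same argument types reuse the compiled machine code.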
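
To make the mechanics concrete, here is a minimal sketch of a hand-written decorator. The names `count_calls` and `double` are hypothetical and exist only for this illustration; they are unrelated to Numba.

```python
import functools


def count_calls(func):
    """Hypothetical decorator: counts how often `func` is called."""

    @functools.wraps(func)  # keep the original name and docstring
    def wrapper(*args, **kwargs):
        wrapper.calls += 1
        return func(*args, **kwargs)

    wrapper.calls = 0
    return wrapper


@count_calls  # equivalent to: double = count_calls(double)
def double(x):
    return x * 2


double(3)
double(4)
print(double.calls)  # 2
```

The name `double` now refers to `wrapper`, not to the original function. `@njit` follows the same pattern, except that the object it returns compiles the function instead of counting calls.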

In [None]:
import random
import numpy as np
import timeit
from numba import njit, prange

### A Classic Example: The Monte Carlo Pi Simulation

A simple way to demonstrate Numba's power is with a Monte Carlo simulation to estimate \( \pi \). The logic is as follows:
1. Imagine a square with side length 2, centered at the origin. Its area is 4.
2. Inscribe a circle with radius 1 within this square. Its area is \( \pi r^2 = \pi \).
3. Generate a large number of random points \( (x, y) \) within the square.
4. The proportion of points that fall inside the circle should be equal to the ratio of the circle's area to the square's area: \( \frac{\text{Points in Circle}}{\text{Total Points}} \approx \frac{\pi}{4} \).
5. Therefore, \( \pi \approx 4 \times \frac{\text{Points in Circle}}{\text{Total Points}} \).

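A point \( (x, y) \) is inside the circle if \( x^2 + y^2 < 1 \). The implementations below draw \( x \) and \( y \) from \( [0, 1) \), so they sample only the upper-right quarter of the square. By symmetry, the fraction of points that land inside the quarter circle is still \( \pi / 4 \), so the estimator is unchanged. Checking each point requires a loop, which is slow in pure Python.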
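
The five steps above can be sketched literally, sampling from the full square \( [-1, 1] \times [-1, 1] \). This is an illustrative version only: the function name `estimate_pi_full_square` and the `seed` parameter are assumptions made for this sketch and are not used elsewhere in the notebook.

```python
import random


def estimate_pi_full_square(num_samples, seed=0):
    """Estimate pi by sampling uniformly from the square [-1, 1] x [-1, 1]."""
    rng = random.Random(seed)  # local generator, so the result is repeatable
    inside = 0
    for _ in range(num_samples):
        x = rng.uniform(-1.0, 1.0)
        y = rng.uniform(-1.0, 1.0)
        if x * x + y * y < 1.0:  # the point falls inside the unit circle
            inside += 1
    # (points in circle) / (total points) approximates pi / 4
    return 4.0 * inside / num_samples


estimate = estimate_pi_full_square(100_000)
print(f"Estimate of pi: {estimate:.4f}")
```

The standard error of this estimator shrinks like \( 1 / \sqrt{N} \), which is why a precise estimate needs many samples and why loop speed matters.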

#### Pure Python Implementation

In [None]:
def monte_carlo_pi_python(num_samples):
    acc = 0
    for _ in range(num_samples):
        x = random.random()
        y = random.random()
        if (x**2 + y**2) < 1.0:
            acc += 1
    return 4.0 * acc / num_samples

Now, let's benchmark this function. We use the `timeit` module directly rather than the IPython magic `%timeit`, so the same code also runs unchanged as a plain Python script.
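
A single timed run can be noisy, because other processes compete for the CPU. One common remedy, sketched here on a small hypothetical function `busy_loop`, is to repeat the measurement with `timeit.repeat` and report the minimum.

```python
import timeit


def busy_loop(n):
    """Hypothetical workload: sum of squares of 0..n-1 in pure Python."""
    total = 0
    for i in range(n):
        total += i * i
    return total


# 3 independent measurements, each timing 5 consecutive calls
timings = timeit.repeat(lambda: busy_loop(100_000), number=5, repeat=3)
best_per_call = min(timings) / 5
print(f"Best of 3: {best_per_call:.6f} seconds per call")
```

The minimum is usually the most informative figure, since the slower runs mostly reflect interference from the rest of the system rather than the code itself. The benchmarks in this notebook use a single run for the slow pure-Python versions to keep the total run time reasonable.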

In [None]:
num_samples = 10_000_000
py_time = timeit.timeit(lambda: monte_carlo_pi_python(num_samples), number=1)
print(f"Pure Python execution time: {py_time:.4f} seconds")

#### Numba Implementation

To accelerate this function, we apply the `@njit` decorator. `njit` stands for **"no-python JIT,"** which is Numba's highest-performance mode. When the function is first called, Numba infers the data types of the variables (e.g., `num_samples` is an integer, `acc` is an integer, `x` and `y` are floats) and uses this information to compile a version of the function in machine code, tailored to these types.

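Numba reimplements a subset of NumPy and of the standard library in compiled form, and it works most naturally with NumPy functions and arrays. Both `random.random()` and `np.random.rand()` are supported inside an `@njit` function, and both are served by Numba's own random number generator, so neither should be expected to be much faster than the other. We switch to `np.random.rand()` here only to keep the compiled code NumPy-based. Note that Numba's generator state is separate from NumPy's: calling `np.random.seed()` outside a compiled function does not seed the draws made inside it.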

In [None]:
@njit(cache=True)
def monte_carlo_pi_numba(num_samples):
    acc = 0
    for _ in range(num_samples):
        # Inside @njit, np.random calls are served by Numba's own generator
        x = np.random.rand()
        y = np.random.rand()
        if (x**2 + y**2) < 1.0:
            acc += 1
    return 4.0 * acc / num_samples

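Let's benchmark the Numba-compiled version. The first call includes compilation, which for a small function like this can take longer than the computation itself, so we make one untimed call before measuring. We also pass `cache=True` to the decorator, which saves the compiled code to disk so that future sessions can skip the compilation step.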

In [None]:
# The first run compiles the function
print(f"First run result: {monte_carlo_pi_numba(num_samples)}")

# Now let's time it
numba_time = timeit.timeit(lambda: monte_carlo_pi_numba(num_samples), number=1)
print(f"Numba execution time: {numba_time:.4f} seconds")
print(f"Speedup: {py_time / numba_time:.1f}x")

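You should observe a large speedup, often well over an order of magnitude, although the exact figure depends on your machine and Numba version. This is the power of JIT compilation: adding a single decorator line brings the loop close to the speed of compiled code.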

### Numba with NumPy

Numba is specifically designed to work well with NumPy arrays and functions. When Numba compiles code that uses NumPy arrays, it generates specialized, fast code that can operate directly on the underlying data buffers, avoiding the overhead of Python's object model.

In [None]:
def sum_of_squares_python(arr):
    total = 0.0
    for i in range(arr.shape[0]):
        total += arr[i] ** 2
    return total

@njit(cache=True)
def sum_of_squares_numba(arr):
    total = 0.0
    for i in range(arr.shape[0]):
        total += arr[i] ** 2
    return total

my_array = np.random.randn(10_000_000)

In [None]:
py_time_sos = timeit.timeit(lambda: sum_of_squares_python(my_array), number=1)
print(f"Python sum of squares time: {py_time_sos:.4f}s")

In [None]:
# Warm-up run for compilation
sum_of_squares_numba(my_array)
numba_time_sos = timeit.timeit(lambda: sum_of_squares_numba(my_array), number=10)
numba_time_sos /= 10 # Average time
print(f"Numba sum of squares time: {numba_time_sos:.4f}s")
print(f"Speedup: {py_time_sos / numba_time_sos:.1f}x")
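
When a computation can be written as a single array operation, vectorized NumPy is also a fair baseline to compare against. The sketch below is an aside with hypothetical names (`sum_of_squares_loop`, `sum_of_squares_vectorized`) and uses a smaller array so that the pure-Python loop finishes quickly.

```python
import timeit

import numpy as np

rng = np.random.default_rng(0)
sample = rng.standard_normal(200_000)


def sum_of_squares_loop(arr):
    total = 0.0
    for i in range(arr.shape[0]):
        total += arr[i] ** 2
    return total


def sum_of_squares_vectorized(arr):
    # The dot product of a vector with itself is its sum of squares
    return float(np.dot(arr, arr))


loop_time = timeit.timeit(lambda: sum_of_squares_loop(sample), number=1)
vec_time = timeit.timeit(lambda: sum_of_squares_vectorized(sample), number=1)
print(f"Loop: {loop_time:.4f}s, vectorized: {vec_time:.6f}s")
```

Numba is most valuable when the loop body cannot be expressed this way, for example when each iteration depends on the previous one, as in many dynamic economic models.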

### Automatic Parallelization

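Numba can also automatically parallelize some loops, allowing you to take advantage of multi-core CPUs with minimal effort. By adding the `parallel=True` argument to the decorator, you instruct Numba to attempt to parallelize the function. You can then use `numba.prange` in place of `range` to mark loops whose iterations are independent and therefore safe to run in parallel. A scalar accumulation such as `total += ...` inside a `prange` loop is recognized as a reduction, so the example below is safe even though every iteration updates `total`.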

In [None]:
@njit(parallel=True, cache=True)
def sum_of_squares_parallel(arr):
    total = 0.0
    # prange indicates this loop can be parallelized
    for i in prange(arr.shape[0]):
        total += arr[i] ** 2
    return total

In [None]:
# Warm-up run for compilation
sum_of_squares_parallel(my_array)
parallel_time_sos = timeit.timeit(lambda: sum_of_squares_parallel(my_array), number=10)
parallel_time_sos /= 10 # Average time
print(f"Parallel Numba time: {parallel_time_sos:.4f}s")
print(f"Speedup vs. Python: {py_time_sos / parallel_time_sos:.1f}x")

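On a multi-core machine, you will usually see a further speedup over the serial Numba version. For a simple reduction like this one, the gain is often smaller than the number of cores, because the loop is limited by memory bandwidth rather than by arithmetic.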

## Common Pitfalls and Best Practices

Numba is powerful, but it's not a magic bullet. To use it effectively, it helps to know its limitations.

- **Isolate Numerical Code:** Numba works best on functions that are purely numerical and contain loops. It does not support many standard Python data structures or libraries. For instance, **you cannot use pandas DataFrames inside a function decorated with `@njit`**, and dictionaries are supported only in a restricted, homogeneously typed form (see `numba.typed.Dict`); a plain Python `dict` cannot be passed in as an argument. The best practice is to extract your data from DataFrames into NumPy arrays first, then pass those arrays to a specialized Numba function.

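- **Type Stability:** Numba assigns every variable a single type for the whole function. Mixing compatible numeric types is allowed: if `x = 1` is followed later by `x = 1.0`, Numba treats `x` as a float throughout. Mixing incompatible types, such as a number in one branch and a string in another, makes compilation fail with a `TypingError`. Keeping each variable's type fixed avoids both surprises.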

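- **Check Compilation Mode:** `@njit` is shorthand for `@jit(nopython=True)`, so a function decorated with `@njit` never silently falls back to the slower "object mode": if Numba cannot compile it, it raises an error. The silent fallback applied only to the bare `@jit` decorator in Numba versions before 0.59. When compilation fails, the error message usually names the unsupported operation or type.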

## Summary

Numba is a valuable tool for computational economists. It provides a remarkably simple way to break through the performance barriers of pure Python. By using the `@njit` decorator on functions that contain computationally heavy loops over numerical data, you can often achieve speedups of one to two orders of magnitude, turning a computation that takes minutes into one that takes seconds. This allows for more complex models, more extensive simulations, and faster iteration in your research.