144 changes: 83 additions & 61 deletions lectures/numpy_vs_numba_vs_jax.md
@@ -141,7 +141,9 @@ grid = np.linspace(-3, 3, 3_000)
x, y = np.meshgrid(grid, grid)

with qe.Timer(precision=8):
    z_max_numpy = np.max(f(x, y))

print(f"NumPy result: {z_max_numpy}")
```

In the vectorized version, all the looping takes place in compiled code.
@@ -190,9 +192,11 @@ with qe.Timer(precision=8):
```


Depending on your machine, the Numba version can be a bit slower or a bit faster
than NumPy.

On one hand, NumPy combines efficient arithmetic (like Numba) with some
multithreading (unlike this Numba code), which provides an advantage.

On the other hand, the Numba routine uses much less memory, since we are only
working with a single one-dimensional grid.
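
As a rough back-of-the-envelope comparison (assuming 8-byte floats and counting
only the main arrays involved), the mesh approach stores three 3,000 by 3,000
arrays, while the loop needs only the one-dimensional grid:

```{code-cell} ipython3
mesh_bytes = 3 * 3_000 * 3_000 * 8   # x, y and f(x, y) as float64 arrays
grid_bytes = 3_000 * 8               # the one-dimensional grid
print(f"mesh approach: ~{mesh_bytes / 1e6:.0f} MB")
print(f"loop approach: ~{grid_bytes / 1e3:.0f} KB")
```
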
@@ -202,7 +206,7 @@

Now let's try parallelization with Numba using `prange`:

Here's a naive and *incorrect* attempt.

```{code-cell} ipython3
@numba.jit(parallel=True)
def compute_max_numba_parallel(grid):
    n = len(grid)
    m = -np.inf
    for i in numba.prange(n):
        for j in range(n):
            x = grid[i]
            y = grid[j]
            z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
            if z > m:
                m = z
    return m
```

Usually this returns an incorrect result:

```{code-cell} ipython3
with qe.Timer(precision=8):
    z_max_parallel_incorrect = compute_max_numba_parallel(grid)

print(f"Incorrect parallel Numba result: {z_max_parallel_incorrect}")
print(f"NumPy result: {z_max_numpy}")
```

The incorrect parallel implementation typically returns `-inf` (the initial value of `m`) instead of the correct maximum value of approximately `0.9999979986680024`.

The reason is that the variable `m` is shared across threads without any synchronization.

When multiple threads try to read and write `m` simultaneously, they interfere with each other, causing a race condition.

This results in lost updates—threads read stale values of `m` or overwrite each other's updates—and the variable often never gets updated from its initial value of `-inf`.

Here's a more carefully written version.

```{code-cell} ipython3
@numba.jit(parallel=True)
def compute_max_numba_parallel(grid):
    n = len(grid)
    row_maxes = np.empty(n)
    for i in numba.prange(n):
        row_max = -np.inf
        for j in range(n):
            x = grid[i]
            y = grid[j]
            z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
            if z > row_max:
                row_max = z
        row_maxes[i] = row_max
    return np.max(row_maxes)
```

Now the body of the `for i in numba.prange(n)` loop is independent across `i`.

Each thread writes to a separate element of the array `row_maxes`.

Hence the parallelization is safe.

Here are the timings.

```{code-cell} ipython3
with qe.Timer(precision=8):
    compute_max_numba_parallel(grid)
```

```{code-cell} ipython3
with qe.Timer(precision=8):
    compute_max_numba_parallel(grid)
```


Depending on your machine, you might or might not see large benefits from parallelization here.

If you have a small number of cores, the overhead of thread management and synchronization can
overwhelm the benefits of parallel execution.


For more powerful machines and larger grid sizes, parallelization can generate major speed gains, even on the CPU.
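
If you want to see how many threads are available to Numba's parallel regions
on your machine, recent versions of Numba expose this directly:

```{code-cell} ipython3
# Number of threads Numba will use in parallel regions
numba.get_num_threads()
```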


### Vectorized code with JAX
@@ -288,14 +307,14 @@ grid = jnp.linspace(-3, 3, 3_000)
x_mesh, y_mesh = np.meshgrid(grid, grid)

with qe.Timer(precision=8):
    z_max = jnp.max(f(x_mesh, y_mesh)).block_until_ready()
```

Let's run again to eliminate compile time.

```{code-cell} ipython3
with qe.Timer(precision=8):
    z_max = jnp.max(f(x_mesh, y_mesh)).block_until_ready()
```

Once compiled, JAX will be significantly faster than NumPy, especially if you are using a GPU.
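
You can check which devices JAX has detected on your machine (a GPU or TPU
will appear here if one is available):

```{code-cell} ipython3
# Devices visible to JAX
jax.devices()
```
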
@@ -331,75 +350,71 @@ Here's one way we can apply `vmap`.
```{code-cell} ipython3
# Set up f to compute f(x, y) at every x for any given y
f_vec_x = lambda y: f(grid, y)
# Create a second function that vectorizes this operation over all y
f_vec = jax.vmap(f_vec_x)
```

Now `f_vec` will compute `f(x,y)` at every `x,y` when called with the flat array `grid`.
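
As a quick sanity check on shapes (note that this call materializes the full
output array, so it is only for illustration):

```{code-cell} ipython3
# Axis 0 ranges over the y values, axis 1 over the x values
f_vec(grid).shape
```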

Let's see the timing:

```{code-cell} ipython3
with qe.Timer(precision=8):
    z_max = jnp.max(f_vec(grid))
    z_max.block_until_ready()
```

Let's run it again to eliminate compile time:


```{code-cell} ipython3
with qe.Timer(precision=8):
    z_max = jnp.max(f_vec(grid))
    z_max.block_until_ready()
```

The execution time is similar to the mesh operation but, by avoiding the large input arrays `x_mesh` and `y_mesh`,
we are using far less memory.

In addition, `vmap` allows us to break vectorization up into stages, which is
often easier to comprehend than the traditional approach.

This will become more obvious when we tackle larger problems.
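
For instance, here is one possible sketch of breaking the mesh-style evaluation
into two `vmap` stages (the helper names are illustrative only); the `in_axes`
argument tells each stage which argument to map over.

```{code-cell} ipython3
# Stage 1: vectorize f over its second argument (y), holding x fixed
f_in_y = jax.vmap(f, in_axes=(None, 0))
# Stage 2: vectorize the result over its first argument (x)
f_in_xy = jax.vmap(f_in_y, in_axes=(0, None))
# Evaluate f on the full product grid without building x_mesh and y_mesh
f_in_xy(grid, grid).shape
```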


### vmap version 2

We can be still more memory efficient using `vmap`.

While we avoided large input arrays in the preceding version,
we still create the large output array `f(x,y)` before we compute the max.

Let's use a slightly different approach that takes the max to the inside.

```{code-cell} ipython3
@jax.jit
def compute_max_vmap_v2(grid):
    # Construct a function that takes the max along each row
    f_vec_x_max = lambda y: jnp.max(f(grid, y))
    # Vectorize the function so we can call on all rows simultaneously
    f_vec_max = jax.vmap(f_vec_x_max)
    # Call the vectorized function and take the max
    return jnp.max(f_vec_max(grid))
```

Let's try it:

```{code-cell} ipython3
with qe.Timer(precision=8):
    z_max = compute_max_vmap_v2(grid).block_until_ready()
```


Let's run it again to eliminate compilation time:

```{code-cell} ipython3
with qe.Timer(precision=8):
    z_max = compute_max_vmap_v2(grid).block_until_ready()
```

We don't get much speed gain but we do save some memory.
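
As a sanity check, we can compare this memory-efficient version against the
NumPy result computed earlier:

```{code-cell} ipython3
# The vmap-based maximum should agree with the NumPy maximum
jnp.allclose(compute_max_vmap_v2(grid), z_max_numpy)
```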



@@ -414,7 +429,14 @@ Moreover, the `vmap` approach can sometimes lead to significantly clearer code.
While Numba is impressive, the beauty of JAX is that, with fully vectorized
operations, we can run exactly the
same code on machines with hardware accelerators and reap all the benefits
without extra effort.

Moreover, JAX already knows how to effectively parallelize many common array
operations, which is key to fast execution.

For almost all cases encountered in economics, econometrics, and finance, it is
far better to hand over to the JAX compiler for efficient parallelization than to
try to hand code these routines ourselves.


## Sequential operations
Expand Down Expand Up @@ -485,7 +507,7 @@ def qm_jax(x0, n, α=4.0):
    return jnp.concatenate([jnp.array([x0]), x])
```

This code is not easy to read but, in essence, `lax.scan` repeatedly calls `update` and accumulates the returns `x_new` into an array.
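
As a rough mental model, here is a pure-Python sketch of what `lax.scan` does
with a step function, an initial carry, and a length (this ignores compilation
and the way JAX stacks the outputs into an array):

```{code-cell} ipython3
def scan_sketch(step, init, length):
    # Thread a carry through the loop and collect the per-step outputs,
    # mimicking carry, ys = lax.scan(step, init, xs=None, length=length)
    carry, outputs = init, []
    for _ in range(length):
        carry, out = step(carry, None)
        outputs.append(out)
    return carry, outputs
```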

Let's time it with the same parameters:
