Commit ce3de00

jstac and claude authored
Improve parallel Numba examples with race condition explanation (#436)
* Improve parallel Numba examples with race condition explanation

  Added demonstration of incorrect parallel implementation to teach thread
  safety concepts, with detailed explanation of race conditions and how to
  avoid them.

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude <noreply@anthropic.com>

* Fix code errors and grammar in numpy_vs_numba_vs_jax lecture

  - Fix incorrect function name: compute_max_numba_parallel_nested → compute_max_numba_parallel
  - Fix incorrect variable name: z_vmap → z_max
  - Fix grammar: "similar to as" → "similar to"
  - Fix technical description: lax.scan calls update function, not qm_jax

  All fixes verified by converting to Python and running successfully with ipython.

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
1 parent f0c480e commit ce3de00


lectures/numpy_vs_numba_vs_jax.md

Lines changed: 83 additions & 61 deletions
@@ -141,7 +141,9 @@ grid = np.linspace(-3, 3, 3_000)
 x, y = np.meshgrid(grid, grid)
 
 with qe.Timer(precision=8):
-    np.max(f(x, y))
+    z_max_numpy = np.max(f(x, y))
+
+print(f"NumPy result: {z_max_numpy}")
 ```
 
 In the vectorized version, all the looping takes place in compiled code.
@@ -190,9 +192,11 @@ with qe.Timer(precision=8):
 ```
 
 
-Depending on your machine, the Numba version can be a bit slower or a bit faster than NumPy.
+Depending on your machine, the Numba version can be a bit slower or a bit faster
+than NumPy.
 
-On one hand, NumPy combines efficient arithmetic (like Numba) with some multithreading (unlike this Numba code), which provides an advantage.
+On one hand, NumPy combines efficient arithmetic (like Numba) with some
+multithreading (unlike this Numba code), which provides an advantage.
 
 On the other hand, the Numba routine uses much less memory, since we are only
 working with a single one-dimensional grid.
@@ -202,7 +206,7 @@ working with a single one-dimensional grid.
 
 Now let's try parallelization with Numba using `prange`:
 
-First we parallelize just the outer loop.
+Here's a naive and *incorrect* attempt.
 
 ```{code-cell} ipython3
 @numba.jit(parallel=True)
@@ -218,50 +222,65 @@ def compute_max_numba_parallel(grid):
                 m = z
     return m
 
-with qe.Timer(precision=8):
-    compute_max_numba_parallel(grid)
 ```
 
+Usually this returns an incorrect result:
 
 ```{code-cell} ipython3
-with qe.Timer(precision=8):
-    compute_max_numba_parallel(grid)
+z_max_parallel_incorrect = compute_max_numba_parallel(grid)
+print(f"Incorrect parallel Numba result: {z_max_parallel_incorrect}")
+print(f"NumPy result: {z_max_numpy}")
 ```
 
-Next we parallelize both loops.
+The incorrect parallel implementation typically returns `-inf` (the initial value of `m`) instead of the correct maximum value of approximately `0.9999979986680024`.
+
+The reason is that the variable $m$ is shared across threads and not properly controlled.
+
+When multiple threads try to read and write `m` simultaneously, they interfere with each other, causing a race condition.
+
+This results in lost updates—threads read stale values of `m` or overwrite each other's updates—and the variable often never gets updated from its initial value of `-inf`.
+
+Here's a more carefully written version.
 
 ```{code-cell} ipython3
 @numba.jit(parallel=True)
-def compute_max_numba_parallel_nested(grid):
+def compute_max_numba_parallel(grid):
     n = len(grid)
-    m = -np.inf
+    row_maxes = np.empty(n)
     for i in numba.prange(n):
-        for j in numba.prange(n):
+        row_max = -np.inf
+        for j in range(n):
            x = grid[i]
            y = grid[j]
            z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
-            if z > m:
-                m = z
-    return m
+            if z > row_max:
+                row_max = z
+        row_maxes[i] = row_max
+    return np.max(row_maxes)
+```
 
+Now the code block that `for i in numba.prange(n)` acts over is independent
+across `i`.
+
+Each thread writes to a separate element of the array `row_maxes`.
+
+Hence the parallelization is safe.
+
+Here's the timings.
+
+```{code-cell} ipython3
 with qe.Timer(precision=8):
-    compute_max_numba_parallel_nested(grid)
+    compute_max_numba_parallel(grid)
 ```
 
 ```{code-cell} ipython3
 with qe.Timer(precision=8):
-    compute_max_numba_parallel_nested(grid)
+    compute_max_numba_parallel(grid)
 ```
 
+If you have multiple cores, you should see at least some benefits from parallelization here.
 
-Depending on your machine, you might or might not see large benefits from parallelization here.
-
-If you have a small number of cores, the overhead of thread management and synchronization can
-overwhelm the benefits of parallel execution.
-
-For more powerful machines and larger grid sizes, parallelization can generate
-large speed gains.
-
+For more powerful machines and larger grid sizes, parallelization can generate major speed gains, even on the CPU.
 
 
 ### Vectorized code with JAX
@@ -288,14 +307,14 @@ grid = jnp.linspace(-3, 3, 3_000)
 x_mesh, y_mesh = np.meshgrid(grid, grid)
 
 with qe.Timer(precision=8):
-    z_mesh = f(x_mesh, y_mesh).block_until_ready()
+    z_max = jnp.max(f(x_mesh, y_mesh)).block_until_ready()
 ```
 
 Let's run again to eliminate compile time.
 
 ```{code-cell} ipython3
 with qe.Timer(precision=8):
-    z_mesh = f(x_mesh, y_mesh).block_until_ready()
+    z_max = jnp.max(f(x_mesh, y_mesh)).block_until_ready()
 ```
 
 Once compiled, JAX will be significantly faster than NumPy, especially if you are using a GPU.
@@ -331,75 +350,71 @@ Here's one way we can apply `vmap`.
 ```{code-cell} ipython3
 # Set up f to compute f(x, y) at every x for any given y
 f_vec_x = lambda y: f(grid, y)
-# Vectorize this operation over all y
+# Create a second function that vectorizes this operation over all y
 f_vec = jax.vmap(f_vec_x)
-# Compute result at all y
-z_vmap = f_vec(grid)
 ```
 
+Now `f_vec` will compute `f(x,y)` at every `x,y` when called with the flat array `grid`.
+
 Let's see the timing:
 
 ```{code-cell} ipython3
 with qe.Timer(precision=8):
-    z_vmap = f_vec(grid)
-    z_vmap.block_until_ready()
+    z_max = jnp.max(f_vec(grid))
+    z_max.block_until_ready()
 ```
 
-Let's check we got the right result:
-
-
 ```{code-cell} ipython3
-jnp.allclose(z_mesh, z_vmap)
+with qe.Timer(precision=8):
+    z_max = jnp.max(f_vec(grid))
+    z_max.block_until_ready()
 ```
 
-The execution time is similar to as the mesh operation but we are using much
-less memory.
+The execution time is similar to the mesh operation but, by avoiding the large input arrays `x_mesh` and `y_mesh`,
+we are using far less memory.
 
 In addition, `vmap` allows us to break vectorization up into stages, which is
 often easier to comprehend than the traditional approach.
 
 This will become more obvious when we tackle larger problems.
 
 
-#### Version 2
+### vmap version 2
 
-Here's a more generic approach to using `vmap` that we often use in the lectures.
+We can be still more memory efficient using vmap.
 
-First we vectorize in `y`.
+While we avoided large input arrays in the preceding version,
+we still create the large output array `f(x,y)` before we compute the max.
 
-```{code-cell} ipython3
-f_vec_y = jax.vmap(f, in_axes=(None, 0))
-```
-
-In the line above, `(None, 0)` indicates that we are vectorizing in the second argument, which is `y`.
-
-Next, we vectorize in the first argument, which is `x`.
+Let's use a slightly different approach that takes the max to the inside.
 
 ```{code-cell} ipython3
-f_vec = jax.vmap(f_vec_y, in_axes=(0, None))
+@jax.jit
+def compute_max_vmap_v2(grid):
+    # Construct a function that takes the max along each row
+    f_vec_x_max = lambda y: jnp.max(f(grid, y))
+    # Vectorize the function so we can call on all rows simultaneously
+    f_vec_max = jax.vmap(f_vec_x_max)
+    # Call the vectorized function and take the max
+    return jnp.max(f_vec_max(grid))
 ```
 
-With this construction, we can now call $f$ directly on flat (low memory) arrays.
+Let's try it
 
 ```{code-cell} ipython3
-x, y = grid, grid
 with qe.Timer(precision=8):
-    z_vmap = f_vec(x, y).block_until_ready()
+    z_max = compute_max_vmap_v2(grid).block_until_ready()
 ```
 
+
 Let's run it again to eliminate compilation time:
 
 ```{code-cell} ipython3
 with qe.Timer(precision=8):
-    z_vmap = f_vec(x, y).block_until_ready()
+    z_max = compute_max_vmap_v2(grid).block_until_ready()
 ```
 
-Let's check we got the right result:
-
-
-```{code-cell} ipython3
-jnp.allclose(z_mesh, z_vmap)
-```
+We don't get much speed gain but we do save some memory.
 
 
 
@@ -414,7 +429,14 @@ Moreover, the `vmap` approach can sometimes lead to significantly clearer code.
 While Numba is impressive, the beauty of JAX is that, with fully vectorized
 operations, we can run exactly the
 same code on machines with hardware accelerators and reap all the benefits
-without paying extra cost.
+without extra effort.
+
+Moreover, JAX already knows how to effectively parallelize many common array
+operations, which is key to fast execution.
+
+For almost all cases encountered in economics, econometrics, and finance, it is
+far better to hand over to the JAX compiler for efficient parallelization than to
+try to hand code these routines ourselves.
 
 
 ## Sequential operations
@@ -485,7 +507,7 @@ def qm_jax(x0, n, α=4.0):
     return jnp.concatenate([jnp.array([x0]), x])
 ```
 
-This code is not easy to read but, in essence, `lax.scan` repeatedly calls `qm_jax` and accumulates the returns `x_new` into an array.
+This code is not easy to read but, in essence, `lax.scan` repeatedly calls `update` and accumulates the returns `x_new` into an array.
 
 Let's time it with the same parameters:
 
0 commit comments
