Improve parallel Numba examples with race condition explanation (#436)
* Improve parallel Numba examples with race condition explanation
Added demonstration of incorrect parallel implementation to teach thread
safety concepts, with detailed explanation of race conditions and how to
avoid them.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
* Fix code errors and grammar in numpy_vs_numba_vs_jax lecture
- Fix incorrect function name: compute_max_numba_parallel_nested → compute_max_numba_parallel
- Fix incorrect variable name: z_vmap → z_max
- Fix grammar: "similar to as" → "similar to"
- Fix technical description: lax.scan calls update function, not qm_jax
All fixes verified by converting to Python and running successfully with ipython.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
---------
Co-authored-by: Claude <noreply@anthropic.com>
The incorrect parallel implementation typically returns `-inf` (the initial value of `m`) instead of the correct maximum value of approximately `0.9999979986680024`.
The reason is that the variable $m$ is shared across threads and not properly controlled.
When multiple threads try to read and write `m` simultaneously, they interfere with each other, causing a race condition.
This results in lost updates—threads read stale values of `m` or overwrite each other's updates—and the variable often never gets updated from its initial value of `-inf`.
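To make the failure mode concrete, here's a sketch of the kind of unsafe implementation being described (the function name is illustrative; the key point is that every thread reads and writes the single shared variable `m`):

```{code-cell} ipython3
# Racy version -- for illustration only; do not use.
@numba.jit(parallel=True)
def compute_max_numba_parallel_racy(grid):
    n = len(grid)
    m = -np.inf                  # shared across all threads
    for i in numba.prange(n):
        for j in range(n):
            x = grid[i]
            y = grid[j]
            z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
            if z > m:            # unsynchronized read of m ...
                m = z            # ... and unsynchronized write
    return m
```

Numba does not recognize this `if`/assign pattern as a reduction, so nothing protects `m` from concurrent access.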
Here's a more carefully written version.
```{code-cell} ipython3
@numba.jit(parallel=True)
def compute_max_numba_parallel(grid):
    n = len(grid)
    row_maxes = np.empty(n)      # one slot per row; no sharing across threads
    for i in numba.prange(n):
        row_max = -np.inf        # private to this iteration
        for j in range(n):
            x = grid[i]
            y = grid[j]
            z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
            if z > row_max:
                row_max = z
        row_maxes[i] = row_max   # each i writes to a distinct element
    return np.max(row_maxes)
```
Now the code block that `for i in numba.prange(n)` acts over is independent across `i`.
Each thread writes to a separate element of the array `row_maxes`.
Hence the parallelization is safe.
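As a quick sanity check (this call is illustrative, not part of the original timing sequence), the safe version should recover the correct maximum reported above:

```{code-cell} ipython3
# Should be approximately 0.9999979986680024, not -inf
compute_max_numba_parallel(grid)
```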
Here are the timings.
```{code-cell} ipython3
with qe.Timer(precision=8):
    compute_max_numba_parallel(grid)
```

The first call includes JIT compilation time, so we run the function again to time execution alone.

```{code-cell} ipython3
with qe.Timer(precision=8):
    compute_max_numba_parallel(grid)
```
If you have multiple cores, you should see at least some benefits from parallelization here.
For more powerful machines and larger grid sizes, parallelization can generate major speed gains, even on the CPU.
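As a rough way to check this on your own machine, you could re-time the function on a denser grid. The grid construction below is a hypothetical example: the bounds and number of points are assumptions, not the values used earlier in the lecture.

```{code-cell} ipython3
# Illustrative only: the bounds and size here are assumed, not the
# lecture's actual grid parameters.
larger_grid = np.linspace(-3, 3, 10_000)
with qe.Timer(precision=8):
    compute_max_numba_parallel(larger_grid)
```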