Skip to content

Conversation

@jstac
Copy link
Contributor

@jstac jstac commented Nov 26, 2025

Summary

Enhanced the NumPy vs Numba vs JAX lecture by improving the parallel Numba section to better teach thread safety concepts through practical examples.

Changes

This PR improves the pedagogical value of the parallel programming examples by:

  1. Adding incorrect parallel implementation example - Shows a naive attempt at parallelization that fails due to race conditions
  2. Demonstrating the bug - Includes code that runs the incorrect version and shows it typically returns -inf instead of the correct maximum
  3. Explaining race conditions - Added detailed explanation of why the shared variable m causes thread interference and lost updates
  4. Improved correct implementation - Refactored to use row-wise maxima array for thread-safe parallelization
  5. Better code organization - Both incorrect and correct versions now use the same function name for clarity
  6. Minor improvements - Added result storage/printing for NumPy benchmark, improved text formatting

Test plan

  • Build documentation and verify all code cells execute correctly
  • Verify the incorrect parallel example demonstrates the race condition bug
  • Confirm the correct parallel example produces accurate results

🤖 Generated with Claude Code

Added demonstration of incorrect parallel implementation to teach thread
safety concepts, with detailed explanation of race conditions and how to
avoid them.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
@jstac
Copy link
Contributor Author

jstac commented Nov 26, 2025

Detailed Changes to lectures/numpy_vs_numba_vs_jax.md

1. NumPy Benchmark Enhancement (lines 143-146)

Before:

with qe.Timer(precision=8):
    np.max(f(x, y))

After:

with qe.Timer(precision=8):
    z_max_numpy = np.max(f(x, y))

print(f"NumPy result: {z_max_numpy}")
  • Store the result for later comparison
  • Print the value for verification

2. Incorrect Parallel Implementation (lines 206-233)

Changed: "First we parallelize just the outer loop." → "Here's a naive and incorrect attempt."

Added: Demonstration that shows the bug in action:

z_max_parallel_incorrect = compute_max_numba_parallel(grid)
print(f"Incorrect parallel Numba result: {z_max_parallel_incorrect}")
print(f"NumPy result: {z_max_numpy}")

Added: Detailed race condition explanation:

  • Why the incorrect version typically returns -inf
  • Explanation of shared variable m causing thread interference
  • Description of how race conditions lead to lost updates

3. Correct Parallel Implementation (lines 235-268)

Refactored from compute_max_numba_parallel_nested to compute_max_numba_parallel:

Key changes:

  • Uses row_maxes = np.empty(n) array instead of shared m variable
  • Each thread computes row_max independently for its row
  • Only parallelizes outer loop (prange on i, regular range on j)
  • Each thread writes to separate array element (thread-safe)

Added explanation:

  • Why the code block is independent across i
  • How separate array elements prevent race conditions
  • Clear statement that "the parallelization is safe"

4. Minor Text Improvements

  • Line wrapping for readability (lines 193-198)
  • Simplified performance expectations (line 281-283)
  • Changed "large speed gains" to "major speed gains" for clarity

5. JAX Section Updates (lines 309-357)

  • Changed variable name from z_mesh to z_max for consistency
  • Improved clarity in vmap example comments

All changes enhance the educational value by explicitly showing what not to do and why, making the correct approach more understandable.

- Fix incorrect function name: compute_max_numba_parallel_nested → compute_max_numba_parallel
- Fix incorrect variable name: z_vmap → z_max
- Fix grammar: "similar to as" → "similar to"
- Fix technical description: lax.scan calls update function, not qm_jax

All fixes verified by converting to Python and running successfully with ipython.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
@jstac
Copy link
Contributor Author

jstac commented Nov 26, 2025

Fixed several code errors and a grammar issue found during testing:

  1. Fixed undefined function calls: Changed compute_max_numba_parallel_nested to compute_max_numba_parallel (the _nested version was never defined)
  2. Fixed undefined variable: Changed z_vmap to z_max in JAX timing blocks
  3. Fixed grammar: Changed "similar to as the mesh operation" to "similar to the mesh operation"
  4. Fixed technical inaccuracy: Corrected description to say lax.scan calls the update function, not qm_jax

All fixes have been verified by converting the notebook to Python using jupytext and running it successfully with ipython.

@github-actions
Copy link

@github-actions github-actions bot temporarily deployed to pull request November 26, 2025 11:26 Inactive
@jstac jstac merged commit ce3de00 into main Nov 26, 2025
5 checks passed
@jstac jstac deleted the nvnvj branch November 26, 2025 11:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants