diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md
index 55de222b..883f2d14 100644
--- a/lectures/numpy_vs_numba_vs_jax.md
+++ b/lectures/numpy_vs_numba_vs_jax.md
@@ -141,7 +141,9 @@ grid = np.linspace(-3, 3, 3_000)
 x, y = np.meshgrid(grid, grid)
 
 with qe.Timer(precision=8):
-    np.max(f(x, y))
+    z_max_numpy = np.max(f(x, y))
+
+print(f"NumPy result: {z_max_numpy}")
 ```
 
 In the vectorized version, all the looping takes place in compiled code.
@@ -190,9 +192,11 @@ with qe.Timer(precision=8):
 ```
 
-Depending on your machine, the Numba version can be a bit slower or a bit faster than NumPy.
+Depending on your machine, the Numba version can be a bit slower or a bit faster
+than NumPy.
 
-On one hand, NumPy combines efficient arithmetic (like Numba) with some multithreading (unlike this Numba code), which provides an advantage.
+On one hand, NumPy combines efficient arithmetic (like Numba) with some
+multithreading (unlike this Numba code), which provides an advantage.
 
 On the other hand, the Numba routine uses much less memory, since we are only
 working with a single one-dimensional grid.
 
@@ -202,7 +206,7 @@ working with a single one-dimensional grid.
 
 Now let's try parallelization with Numba using `prange`:
 
-First we parallelize just the outer loop.
+Here's a naive and *incorrect* attempt.
 
 ```{code-cell} ipython3
 @numba.jit(parallel=True)
@@ -218,50 +222,65 @@ def compute_max_numba_parallel(grid):
                 m = z
     return m
 
-with qe.Timer(precision=8):
-    compute_max_numba_parallel(grid)
 ```
 
+Usually this returns an incorrect result:
+
 ```{code-cell} ipython3
-with qe.Timer(precision=8):
-    compute_max_numba_parallel(grid)
+z_max_parallel_incorrect = compute_max_numba_parallel(grid)
+print(f"Incorrect parallel Numba result: {z_max_parallel_incorrect}")
+print(f"NumPy result: {z_max_numpy}")
 ```
 
-Next we parallelize both loops.
+The incorrect parallel implementation typically returns `-inf` (the initial value of `m`) instead of the correct maximum value of approximately `0.9999979986680024`.
+
+The reason is that the variable `m` is shared across threads and updates to it are not synchronized.
+
+When multiple threads try to read and write `m` simultaneously, they interfere with each other, causing a race condition.
+
+This results in lost updates: threads read stale values of `m` or overwrite each other's writes, so the variable often never gets updated from its initial value of `-inf`.
+
+Here's a more carefully written version.
 
 ```{code-cell} ipython3
 @numba.jit(parallel=True)
-def compute_max_numba_parallel_nested(grid):
+def compute_max_numba_parallel(grid):
     n = len(grid)
-    m = -np.inf
+    row_maxes = np.empty(n)
     for i in numba.prange(n):
-        for j in numba.prange(n):
+        row_max = -np.inf
+        for j in range(n):
             x = grid[i]
             y = grid[j]
             z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
-            if z > m:
-                m = z
-    return m
+            if z > row_max:
+                row_max = z
+        row_maxes[i] = row_max
+    return np.max(row_maxes)
+```
+
+Now the block of code that `for i in numba.prange(n)` acts on is independent
+across `i`.
+
+Each thread writes to a separate element of the array `row_maxes`.
+
+Hence the parallelization is safe.
+
+Here are the timings.
 
+```{code-cell} ipython3
 with qe.Timer(precision=8):
-    compute_max_numba_parallel_nested(grid)
+    compute_max_numba_parallel(grid)
 ```
 
 ```{code-cell} ipython3
 with qe.Timer(precision=8):
-    compute_max_numba_parallel_nested(grid)
+    compute_max_numba_parallel(grid)
 ```
 
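+As a quick sanity check, let's confirm that this version agrees with the NumPy
+result `z_max_numpy` computed earlier:
+
+```{code-cell} ipython3
+# Should be True (up to floating point tolerance)
+np.allclose(compute_max_numba_parallel(grid), z_max_numpy)
+```
+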
+If you have multiple cores, you should see at least some benefits from parallelization here.
+
-Depending on your machine, you might or might not see large benefits from parallelization here.
-
-If you have a small number of cores, the overhead of thread management and synchronization can
-overwhelm the benefits of parallel execution.
-
-For more powerful machines and larger grid sizes, parallelization can generate
-large speed gains.
-
+For more powerful machines and larger grid sizes, parallelization can generate major speed gains, even on the CPU.
 
 ### Vectorized code with JAX
 
@@ -288,14 +307,14 @@ grid = jnp.linspace(-3, 3, 3_000)
 x_mesh, y_mesh = np.meshgrid(grid, grid)
 
 with qe.Timer(precision=8):
-    z_mesh = f(x_mesh, y_mesh).block_until_ready()
+    z_max = jnp.max(f(x_mesh, y_mesh)).block_until_ready()
 ```
 
 Let's run again to eliminate compile time.
 
 ```{code-cell} ipython3
 with qe.Timer(precision=8):
-    z_mesh = f(x_mesh, y_mesh).block_until_ready()
+    z_max = jnp.max(f(x_mesh, y_mesh)).block_until_ready()
 ```
 
 Once compiled, JAX will be significantly faster than NumPy, especially if you are using a GPU.
@@ -331,29 +350,28 @@ Here's one way we can apply `vmap`.
 
 ```{code-cell} ipython3
 # Set up f to compute f(x, y) at every x for any given y
 f_vec_x = lambda y: f(grid, y)
-# Vectorize this operation over all y
+# Create a second function that vectorizes this operation over all y
 f_vec = jax.vmap(f_vec_x)
-# Compute result at all y
-z_vmap = f_vec(grid)
 ```
 
+Now `f_vec` will compute `f(x, y)` at every `(x, y)` pair when called with the flat array `grid`.
+
 Let's see the timing:
 
 ```{code-cell} ipython3
 with qe.Timer(precision=8):
-    z_vmap = f_vec(grid)
-    z_vmap.block_until_ready()
+    z_max = jnp.max(f_vec(grid))
+    z_max.block_until_ready()
 ```
 
-Let's check we got the right result:
-
-
 ```{code-cell} ipython3
-jnp.allclose(z_mesh, z_vmap)
+with qe.Timer(precision=8):
+    z_max = jnp.max(f_vec(grid))
+    z_max.block_until_ready()
 ```
 
-The execution time is similar to as the mesh operation but we are using much
-less memory.
+The execution time is similar to that of the mesh operation but, by avoiding the
+large input arrays `x_mesh` and `y_mesh`, we are using far less memory.
 
 In addition, `vmap` allows us to break vectorization up into stages, which is
 often easier to comprehend than the traditional approach.
@@ -361,45 +379,42 @@ often easier to comprehend than the traditional approach.
 
 This will become more obvious when we tackle larger problems.
 
-#### Version 2
+### vmap version 2
 
-Here's a more generic approach to using `vmap` that we often use in the lectures.
+We can be still more memory-efficient using `vmap`.
 
-First we vectorize in `y`.
+While we avoided large input arrays in the preceding version,
+we still created the large output array `f(x, y)` before computing the max.
 
-```{code-cell} ipython3
-f_vec_y = jax.vmap(f, in_axes=(None, 0))
-```
-
-In the line above, `(None, 0)` indicates that we are vectorizing in the second argument, which is `y`.
-
-Next, we vectorize in the first argument, which is `x`.
+Let's use a slightly different approach that moves the max inside the vectorized function.
 
 ```{code-cell} ipython3
-f_vec = jax.vmap(f_vec_y, in_axes=(0, None))
+@jax.jit
+def compute_max_vmap_v2(grid):
+    # Construct a function that takes the max along each row
+    f_vec_x_max = lambda y: jnp.max(f(grid, y))
+    # Vectorize the function so we can call on all rows simultaneously
+    f_vec_max = jax.vmap(f_vec_x_max)
+    # Call the vectorized function and take the max
+    return jnp.max(f_vec_max(grid))
 ```
 
-With this construction, we can now call $f$ directly on flat (low memory) arrays.
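+Before timing it, here's a small sketch of the intermediate step, using
+stand-alone copies of the helper functions defined inside `compute_max_vmap_v2`:
+
+```{code-cell} ipython3
+# Max of f(x, y) over all x in grid, for one fixed y
+f_vec_x_max = lambda y: jnp.max(f(grid, y))
+# Vectorize over y, giving one maximum per row
+f_vec_max = jax.vmap(f_vec_x_max)
+# The intermediate result is a flat vector of row maxima, not an n x n array
+f_vec_max(grid).shape
+```
+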
+Let's try it:
 
 ```{code-cell} ipython3
-x, y = grid, grid
 with qe.Timer(precision=8):
-    z_vmap = f_vec(x, y).block_until_ready()
+    z_max = compute_max_vmap_v2(grid).block_until_ready()
 ```
+
 Let's run it again to eliminate compilation time:
 
 ```{code-cell} ipython3
 with qe.Timer(precision=8):
-    z_vmap = f_vec(x, y).block_until_ready()
+    z_max = compute_max_vmap_v2(grid).block_until_ready()
 ```
 
-Let's check we got the right result:
-
-
-```{code-cell} ipython3
-jnp.allclose(z_mesh, z_vmap)
-```
+We don't get much of a speed gain, but we do save some memory.
@@ -414,7 +429,14 @@ Moreover, the `vmap` approach can sometimes lead to significantly clearer code.
 While Numba is impressive, the beauty of JAX is that, with fully vectorized
 operations, we can run exactly the same code on machines with hardware
 accelerators and reap all the benefits
-without paying extra cost.
+without extra effort.
+
+Moreover, JAX already knows how to effectively parallelize many common array
+operations, which is key to fast execution.
+
+For almost all cases encountered in economics, econometrics, and finance, it is
+far better to hand parallelization over to the JAX compiler than to
+try to hand-code these routines ourselves.
 
 ## Sequential operations
 
@@ -485,7 +507,7 @@ def qm_jax(x0, n, α=4.0):
     return jnp.concatenate([jnp.array([x0]), x])
 ```
 
-This code is not easy to read but, in essence, `lax.scan` repeatedly calls `qm_jax` and accumulates the returns `x_new` into an array.
+This code is not easy to read but, in essence, `lax.scan` repeatedly calls `update` and accumulates the returned values `x_new` into an array.
 
 Let's time it with the same parameters: