Skip to content

Commit 287849f

Browse files
committed
Add a cell to check JAX backend
1 parent c01a6ab commit 287849f

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

lectures/short_path.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ kernelspec:
1212
---
1313

1414

15+
1516
# Shortest Paths
1617

1718

@@ -28,6 +29,17 @@ import jax
2829
```
2930

3031

32+
33+
Let’s check the backend used by JAX and the devices available.
34+
35+
```{code-cell} ipython3
36+
# Check if JAX is using GPU
37+
print(f"JAX backend: {jax.devices()[0].platform}")
38+
# Check the devices available for JAX
39+
print(jax.devices())
40+
```
41+
42+
3143
## Solving for Minimum Cost-to-Go
3244

3345
Let $J(v)$ denote the minimum cost-to-go from node $v$,
@@ -73,6 +85,7 @@ Q = jnp.array([[inf, 1, 5, 3, inf, inf, inf],
7385
```
7486

7587

88+
7689
Notice that the cost of staying still (on the principle diagonal) is set to
7790

7891
* `jnp.inf` for non-destination nodes --- moving on is required.
@@ -101,6 +114,7 @@ print("The cost-to-go function is", J)
101114
```
102115

103116

117+
104118
We can further optimize the above code by using [jax.lax.while_loop](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html). The extra acceleration is due to the fact that the entire operation can be optimized by the JAX compiler and launched as a single kernel on the GPU.
105119

106120
```{code-cell} ipython3
@@ -129,6 +143,7 @@ def cond_fun(values):
129143
```
130144

131145

146+
132147
Let's see the timing for JIT compilation of the functions and runtime results.
133148

134149
```{code-cell} ipython3
@@ -138,6 +153,7 @@ jax.lax.while_loop(cond_fun, body_fun, init_val=(0, J, False))[1]
138153
```
139154

140155

156+
141157
Now, this runs faster once we have the JIT compiled JAX version of the functions.
142158

143159
```{code-cell} ipython3
@@ -147,6 +163,7 @@ jax.lax.while_loop(cond_fun, body_fun, init_val=(0, J, False))[1]
147163
```
148164

149165

166+
150167
```{note}
151168
Large speed gains while using `jax.lax.while_loop` won't be realized unless the shortest path problem is relatively large.
152169
```
@@ -311,6 +328,7 @@ def map_graph_to_distance_matrix(in_file):
311328
```
312329

313330

331+
314332
Let's write a function `compute_cost_to_go` that returns $J$ given any valid $Q$.
315333

316334
```{code-cell} ipython3
@@ -341,6 +359,7 @@ def compute_cost_to_go(Q):
341359
```
342360

343361

362+
344363
Finally, here's a function that uses the `cost-to-go` function to obtain the
345364
optimal path (and its cost).
346365

@@ -359,13 +378,15 @@ def print_best_path(J, Q):
359378
```
360379

361380

381+
362382
Okay, now we have the necessary functions, let's call them to do the job we were assigned.
363383

364384
```{code-cell} ipython3
365385
Q = map_graph_to_distance_matrix('graph.txt')
366386
```
367387

368388

389+
369390
Let's see the timings for jitting the function and runtime results.
370391

371392
```{code-cell} ipython3
@@ -384,6 +405,7 @@ print_best_path(J, Q)
384405
```
385406

386407

408+
387409
The total cost of the path should agree with $J[0]$ so let's check this.
388410

389411
```{code-cell} ipython3

0 commit comments

Comments
 (0)