Skip to content

Commit c01a6ab

Browse files
committed
Add simple version
1 parent b5bc07e commit c01a6ab

File tree

1 file changed

+45
-9
lines changed

1 file changed

+45
-9
lines changed

lectures/short_path.md

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ kernelspec:
1717

1818
## Overview
1919

20-
This lecture is the extended version of the [shortest path lecture](https://python.quantecon.org/short_path.html) using JAX.
20+
This lecture is the extended version of the [shortest path lecture](https://python.quantecon.org/short_path.html) using JAX. Please see that lecture for all background and notation.
2121

2222
Let's start by importing the libraries.
2323

@@ -75,10 +75,33 @@ Q = jnp.array([[inf, 1, 5, 3, inf, inf, inf],
7575

7676
Notice that the cost of staying still (on the principle diagonal) is set to
7777

78-
* jnp.inf for non-destination nodes --- moving on is required.
79-
* 0 for the destination node --- here is where we stop.
78+
* `jnp.inf` for non-destination nodes --- moving on is required.
79+
* `0` for the destination node --- here is where we stop.
8080

81-
Let's try with this example and see how we go:
81+
Let's try with this example using python `while` loop and some `jax` vectorized code:
82+
83+
```{code-cell} ipython3
84+
%%time
85+
86+
num_nodes = Q.shape[0]
87+
J = jnp.zeros(num_nodes)
88+
89+
max_iter = 500
90+
i = 0
91+
92+
while i < max_iter:
93+
next_J = jnp.min(Q + J, axis=1)
94+
if jnp.allclose(next_J, J):
95+
break
96+
else:
97+
J = next_J.copy()
98+
i += 1
99+
100+
print("The cost-to-go function is", J)
101+
```
102+
103+
104+
We can further optimize the above code by using [jax.lax.while_loop](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html). The extra acceleration is due to the fact that the entire operation can be optimized by the JAX compiler and launched as a single kernel on the GPU.
82105

83106
```{code-cell} ipython3
84107
max_iter = 500
@@ -105,13 +128,30 @@ def cond_fun(values):
105128
return ~break_condition & (i < max_iter)
106129
```
107130

131+
132+
Let's see the timing for JIT compilation of the functions and runtime results.
133+
134+
```{code-cell} ipython3
135+
%%time
136+
137+
jax.lax.while_loop(cond_fun, body_fun, init_val=(0, J, False))[1]
138+
```
139+
140+
141+
Now, this runs faster once we have the JIT compiled JAX version of the functions.
142+
108143
```{code-cell} ipython3
109144
%%time
110145
111146
jax.lax.while_loop(cond_fun, body_fun, init_val=(0, J, False))[1]
112147
```
113148

114149

150+
```{note}
151+
Large speed gains while using `jax.lax.while_loop` won't be realized unless the shortest path problem is relatively large.
152+
```
153+
154+
+++
115155

116156
## Exercises
117157

@@ -238,7 +278,6 @@ node98, node99 0.33
238278
node99,
239279
```
240280

241-
242281
```{exercise-end}
243282
```
244283

@@ -302,7 +341,7 @@ def compute_cost_to_go(Q):
302341
```
303342

304343

305-
Finally, here's a function that uses the cost-to-go function to obtain the
344+
Finally, here's a function that uses the `cost-to-go` function to obtain the
306345
optimal path (and its cost).
307346

308347
```{code-cell} ipython3
@@ -350,6 +389,3 @@ The total cost of the path should agree with $J[0]$ so let's check this.
350389
```{code-cell} ipython3
351390
J[0].item()
352391
```
353-
354-
```{solution-end}
355-
```

0 commit comments

Comments
 (0)