You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: lectures/short_path.md
+45-9Lines changed: 45 additions & 9 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -17,7 +17,7 @@ kernelspec:
17
17
18
18
## Overview
19
19
20
-
This lecture is the extended version of the [shortest path lecture](https://python.quantecon.org/short_path.html) using JAX.
20
+
This lecture is the extended version of the [shortest path lecture](https://python.quantecon.org/short_path.html) using JAX. Please see that lecture for all background and notation.
Notice that the cost of staying still (on the principle diagonal) is set to
77
77
78
-
* jnp.inf for non-destination nodes --- moving on is required.
79
-
*0 for the destination node --- here is where we stop.
78
+
*`jnp.inf` for non-destination nodes --- moving on is required.
79
+
*`0` for the destination node --- here is where we stop.
80
80
81
-
Let's try with this example and see how we go:
81
+
Let's try with this example using python `while` loop and some `jax` vectorized code:
82
+
83
+
```{code-cell} ipython3
84
+
%%time
85
+
86
+
num_nodes = Q.shape[0]
87
+
J = jnp.zeros(num_nodes)
88
+
89
+
max_iter = 500
90
+
i = 0
91
+
92
+
while i < max_iter:
93
+
next_J = jnp.min(Q + J, axis=1)
94
+
if jnp.allclose(next_J, J):
95
+
break
96
+
else:
97
+
J = next_J.copy()
98
+
i += 1
99
+
100
+
print("The cost-to-go function is", J)
101
+
```
102
+
103
+
104
+
We can further optimize the above code by using [jax.lax.while_loop](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html). The extra acceleration is due to the fact that the entire operation can be optimized by the JAX compiler and launched as a single kernel on the GPU.
82
105
83
106
```{code-cell} ipython3
84
107
max_iter = 500
@@ -105,13 +128,30 @@ def cond_fun(values):
105
128
return ~break_condition & (i < max_iter)
106
129
```
107
130
131
+
132
+
Let's see the timing for JIT compilation of the functions and runtime results.
0 commit comments