You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Notice that the cost of staying still (on the principle diagonal) is set to
77
90
78
91
*`jnp.inf` for non-destination nodes --- moving on is required.
@@ -101,6 +114,7 @@ print("The cost-to-go function is", J)
101
114
```
102
115
103
116
117
+
104
118
We can further optimize the above code by using [jax.lax.while_loop](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html). The extra acceleration is due to the fact that the entire operation can be optimized by the JAX compiler and launched as a single kernel on the GPU.
105
119
106
120
```{code-cell} ipython3
@@ -129,6 +143,7 @@ def cond_fun(values):
129
143
```
130
144
131
145
146
+
132
147
Let's see the timing for JIT compilation of the functions and runtime results.
0 commit comments