diff --git a/lectures/_admonition/gpu.md b/lectures/_admonition/gpu.md new file mode 100644 index 00000000..dd218be3 --- /dev/null +++ b/lectures/_admonition/gpu.md @@ -0,0 +1,5 @@ +```{admonition} GPU +:class: warning + +This lecture is designed to run on a GPU. To use Google Colab's free GPUs, click the play icon top right, select Colab, and set the runtime to include a GPU. For local GPU setup, see the [JAX installation guide](https://github.com/google/jax). +``` diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index c9e1870f..adf3b80c 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -33,16 +33,7 @@ In addition to what's in Anaconda, this lecture will need the following librarie !pip install jax quantecon ``` -```{admonition} GPU -:class: warning - -This lecture is accelerated via [hardware](status:machine-details) that has access to a GPU and targets JAX for GPU programming. - -Free GPUs are available on Google Colab. -To use this option, please click on the play icon top right, select Colab, and set the runtime environment to include a GPU. - -Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax) for installing JAX with GPU support. -If you would like to install JAX running on the `cpu` only you can use `pip install jax[cpu]` +```{include} _admonition/gpu.md ``` ## JAX as a NumPy Replacement diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md index 1fa83116..aa9c85a9 100644 --- a/lectures/numpy_vs_numba_vs_jax.md +++ b/lectures/numpy_vs_numba_vs_jax.md @@ -48,16 +48,7 @@ tags: [hide-output] !pip install quantecon jax ``` -```{admonition} GPU -:class: warning - -This lecture is accelerated via [hardware](status:machine-details) that has access to a GPU and target JAX for GPU programming. - -Free GPUs are available on Google Colab. -To use this option, please click on the play icon top right, select Colab, and set the runtime environment to include a GPU. - -Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax) for installing JAX with GPU support. -If you would like to install JAX running on the `cpu` only you can use `pip install jax[cpu]` +```{include} _admonition/gpu.md ``` We will use the following imports.