Skip to content

Commit

Permalink
Uncomment the TPU_DRIVER_MODE cell
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 committed Jun 23, 2020
1 parent 769df92 commit d5c6160
Showing 1 changed file with 26 additions and 16 deletions.
42 changes: 26 additions & 16 deletions docs/notebooks/quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,7 @@
"colab_type": "text"
},
"source": [
"### Optional: enable Cloud TPUs if you're running this in Google Colab"
"### Enable Cloud TPUs if you're running this in Google Colab"
]
},
{
Expand All @@ -943,7 +943,7 @@
"colab_type": "text"
},
"source": [
"You can use Google's TPUs in this notebook to take advantage of JAX NumPy and `pmap`. First, change your Colab runtime by going to **Edit** > **Notebook settings** and selecting **Hardware acceleration: TPU**. Then, run the cell below to further configure the TPU support. Your output should have the Colab TPU IP address (gRPC) and port number."
"You can use Google's TPUs in this notebook to take advantage of JAX NumPy and `pmap`. Click on **Open in Colab** at the top of this page is you're viewing this from the Read the Docs site. Then, change your Google Colab runtime by going to **Edit** > **Notebook settings** and selecting **Hardware acceleration: TPU**. Finally, run the cell below to further configure the TPU support. Your output should have the Colab TPU IP address (gRPC) and port number."
]
},
{
Expand All @@ -954,20 +954,20 @@
"colab": {}
},
"source": [
"# Run this inside Google Colab with TPU enabled\n",
"import requests\n",
"import os\n",
"\n",
"if 'TPU_DRIVER_MODE' not in globals():\n",
" url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'\n",
" resp = requests.post(url)\n",
" TPU_DRIVER_MODE = 1\n",
"\n",
"# This is required to use TPU Driver as JAX's backend\n",
"from jax.config import config\n",
"config.FLAGS.jax_xla_backend = \"tpu_driver\"\n",
"config.FLAGS.jax_backend_target = \"grpc://\" + os.environ['COLAB_TPU_ADDR']\n",
"print(config.FLAGS.jax_backend_target)"
"## Uncomment and run this inside Google Colab with TPU enabled\n",
"# import requests\n",
"# import os\n",
"\n",
"# if 'TPU_DRIVER_MODE' not in globals():\n",
"# url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'\n",
"# resp = requests.post(url)\n",
"# TPU_DRIVER_MODE = 1\n",
"\n",
"## This is required to use TPU Driver as JAX's backend\n",
"# from jax.config import config\n",
"# config.FLAGS.jax_xla_backend = \"tpu_driver\"\n",
"# config.FLAGS.jax_backend_target = \"grpc://\" + os.environ['COLAB_TPU_ADDR']\n",
"# print(config.FLAGS.jax_backend_target)"
],
"execution_count": null,
"outputs": []
Expand Down Expand Up @@ -1187,6 +1187,16 @@
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xd5nGzbUPEi2",
"colab_type": "text"
},
"source": [
"The code here is simple. For a neural network example where you can do some data-parallel neural network training, check out [the SPMD MNIST example](https://github.com/google/jax/blob/master/examples/spmd_mnist_classifier_fromscratch.py) or the much more capable [Trax library](https://github.com/google/trax/)."
]
},
{
"cell_type": "markdown",
"metadata": {
Expand Down

0 comments on commit d5c6160

Please sign in to comment.