# TPUs in Colab
**Authors**

* Gerardo Durán-Martín
* Mahmoud Soliman

Before start this tutorial, make sure to configure your session correctly.

### 1. First we authenticate GCP to our current session

In [16]:
from google.colab import auth
auth.authenticate_user()

### 2. Next, we install GCloud SDK

In [None]:
!curl -S https://sdk.cloud.google.com | bash

### 3. Finally, we initialise all the variables we will be using throughout this tutorial.

We will create a `.sh` file that must be called at every cell that begins with `%%bash` as follows:

```bash
%%bash
source /content/commands.sh
# ... rest of the commands
```

In [82]:
%%writefile commands.sh
gcloud="/root/google-cloud-sdk/bin/gcloud"
gtpu="gcloud alpha compute tpus tpu-vm"
instance_name="probml-01-gerdm" # Modify for your instance name 
tpu_zone="us-east1-d"
jax_install="pip install 'jax[tpu]>=0.2.16' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"

Overwriting commands.sh


# gcloud

This first section introduces the gloud command line. We can work in the cloud in one of two ways:

1. Using the command line (this tutorial)
2. Using the google cloud console ([console.cloud.google.com](https://console.cloud.google.com/))

## Setup

Our first step is to install `gcloud alpha`.

- Installing `gcloud alpha`

    We begin by installing the `gcloud alpha` command line. This will allow us to work with TPUs at Google cloud. Run the following command

In [65]:
%%bash
source /content/commands.sh

$gcloud components install alpha


All components are up to date.


Next, we set the project to `probml` 

In [None]:
%%bash
source /content/commands.sh

$gcloud config set project probml

- Verify installation

Finally, we verify that you've successfully installed `gcloud alpha` by running the following command. Make sure to have version `alpha 2021.06.25` or later.

In [83]:
%%bash
source /content/commands.sh

$gcloud -v 

Google Cloud SDK 351.0.0
alpha 2021.07.30
bq 2.0.70
core 2021.07.30
gsutil 4.66


# TPUS

## The basics

### Creating an instance

Each GSoC member obtains 8 v3-32 cores (or a Slice) when following the instructions outlined below.

To create our first TPU instance, we run the following command. Note that `instance_name` should be unique (it was defined at the top of this tutorial)

In [67]:
%%bash
source /content/commands.sh
$gtpu create $instance_name \
    --accelerator-type v3-32 \
    --version v2-alpha \
    --zone $tpu_zone

Create request issued for: [probml-01-gerdm]
Waiting for operation [projects/probml/locations/us-east1-d/operations/operation-1628065808121-5c8b79c2a006b-a528f872-851a3d0d] to complete...
.......................................................................................................................................................................................................................................................................................................................................................................................done.
Created tpu [probml-01-gerdm].


You can verify whether your instance has been created by running the following cell

In [68]:
%%bash
source /content/commands.sh
$gcloud alpha compute tpus list --zone $tpu_zone

NAME              ZONE        ACCELERATOR_TYPE  NETWORK  RANGE          STATUS  API_VERSION
probml-01-gerdm   us-east1-d  v3-32             default  10.142.0.0/20  READY   V2_ALPHA1
murphyk-tpu       us-east1-d  v3-32             default  10.142.0.0/20  READY   V2_ALPHA1
probml-05-srikar  us-east1-d  v3-32             default  10.142.0.0/20  READY   V2_ALPHA1
probml-00-mjsml   us-east1-d  v3-32             default  10.142.0.0/20  READY   V2_ALPHA1


### Deleting an instance

To avoid extra costs, it is important to delete the instance after use (training, testing experimenting, etc.).

To delete an instance, we create and run a cell with the following content

```bash
%%bash
source /content/commands.sh

$gtpu delete --quiet $instance_name --zone=$tpu_zone
```

**Make sure to delete your instance once you finish!!**

# Jax

### Installing Jax

When connecting to an instance directly via ssh, it is important to note that running any Jax command will wait for the other hosts to be active. To void this, we have to run the desired code simultaneously on all the hosts.

> To run JAX code on a TPU Pod slice, you must run the code **on each host in the TPU Pod slice.**

In the next cell, we install Jax on each host of our slice.

In [None]:
%%bash
source /content/commands.sh
$gtpu ssh $instance_name \
    --zone $tpu_zone \
    --command "$jax_install" \
    --worker all # or machine instance 1..3

### Example 1: Hello, TPUs!

In this example, we create a `hello_tpu.sh` that asserts whether we can connect to all of the hosts. First, we create the `.sh` file that will be run **in each of the workers**.

In [73]:
%%writefile hello_tpu.sh
#!/bin/bash
# file: hello_tpu.sh

export gist_url="https://gist.github.com/1e8d226e7a744d22d010ca4980456c3a.git"
git clone $gist_url hello_gsoc
python3 hello_gsoc/hello_tpu.py

Writing hello_tpu.sh


The content of `$gist_url` is the following

You do not need to store the following file. Our script `hello_tpu.sh` will download the file to each of the hosts and run it.

```python
# Taken from https://cloud.google.com/tpu/docs/jax-pods
# To be used by the Pyprobml GSoC 2021 team
# The following code snippet will be run on all TPU hosts
import jax

# The total number of TPU cores in the pod
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)%
```

Next, we run the code across all workers

In [78]:
%%bash
source /content/commands.sh
$gtpu ssh $instance_name \
    --zone $tpu_zone \
    --command "$(<./hello_tpu.sh)" \
    --worker all

global device count: 32
local device count: 8
pmap result: [32. 32. 32. 32. 32. 32. 32. 32.]


SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
Cloning into 'hello_gsoc'...
Cloning into 'hello_gsoc'...
Cloning into 'hello_gsoc'...
Cloning into 'hello_gsoc'...


### Example 2: 🚧K-nearest neighbours🚧

In this example we train the MNIST dataset using the KNN algorithm `pmap`. Our program clones a Github gist into each of the hosts. We use the multi-device availability of our slice to delegate a part of the training to each of the workers.

First, we create the script that will be run on each of the workers

In [79]:
%%writefile knn_tpu.sh
#!/bin/bash
# file: knn_tpu.sh

export gist_url="https://gist.github.com/716a7bfd4c5c0c0e1949072f7b2e03a6.git"
pip3 install -q tensorflow_datasets
git clone $gist_url demo
python3 demo/knn_tpu.py

Writing knn_tpu.sh


Next, we run the script

In [81]:
%%bash
source /content/commands.sh

$gtpu ssh $instance_name \
    --zone $tpu_zone \
    --command "$(<./knn_tpu.sh)" \
    --worker all

(8, 10, 20)
class_rate=0.9125


SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command.
You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command.
You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command.
You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command.
fatal: destination path 'demo' already exists and is not an empty directory.
fatal: destination path 'demo' already exists and is not an empty directory.
fatal: destination path 'demo' already exists and is not an empty directory.
fatal: destination path 'demo' already exists and is not an empty directory.
Instructions for updating:
Use `tf.data.Dataset.get_single_element()`.
Instructions for updating:
Use `tf.data.Dataset.get_single_element()`.
In

# 🔪TPUs - The Sharp Bits🔪


## Service accounts

Before creating a new TPU instance, make sure that the Admin of the project grants the correct IAM user/group roles for your service account

- `TPU Admin`
- `Service Account User`

This prevents you from running into the following error

![error](https://imgur.com/sMAV2A5.png)

## Running Jax on a Pod

When creating an instance, we obtain different *slices*. Running a parallel operation on a single slice will not perform any computation until all of the slices have been run in sync. In Jax, this is done using `jax.pmap` function

## `pmap`ing a function

> *The mapped axis size must be less than or equal to the number of local XLA devices available, as returned by jax.local_device_count() (unless devices is specified, [...])*

## Misc

- [Padding can tank your performance](https://github.com/google/jax/tree/main/cloud_tpu_colabs#padding)

# References

- gcloud
    - [gcloud CLI cheatsheet](https://cloud.google.com/sdk/docs/cheatsheet)
    - [gcloud update components](https://cloud.google.com/sdk/gcloud/reference/components/update)
- TPUs
    - [Jax cloud TPU](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm)
    - [TPU VM User's guide](https://cloud.google.com/tpu/docs/users-guide-tpu-vm)
    - [Jax TPUs on Slices](https://cloud.google.com/tpu/docs/jax-pods)
- Jax
    - [MNIST example with Flax](https://github.com/google/flax/tree/master/examples/mnist)
    - [Parallelism in Jax](https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html)
    - [Jax multi-hosts](https://jax.readthedocs.io/en/latest/multi_process.html)
    - [ColCollective communication operations](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/JAX_demo.ipynb#scrollTo=f-FBsWeo1AXE&uniqifier=1)