# Ray on TPUs - A Gentle Introduction

This notebook is a gentle introduction to running Ray on TPUs.

It covers:
- The Ray Cluster Launcher: provisioning a VM-based Ray cluster
- The basics of Ray Core and running it on TPUs

## Ray Cluster Launcher

The [Ray Cluster Launcher](https://docs.ray.io/en/latest/cluster/vms/references/ray-cluster-cli.html) is a CLI tool for deploying [Ray clusters](https://docs.ray.io/en/latest/cluster/getting-started.html), which are the foundation of Ray applications.

The Ray Cluster Launcher operates on a YAML file that you provide. The cluster described in that file is deployed with `ray up cluster.yaml` and torn down with `ray down cluster.yaml`.
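If you prefer to drive the launcher from Python (for example from a notebook cell), the two commands can be assembled as argument lists. This is only a sketch: the helper name `launcher_command` is hypothetical, and only `ray up`, `ray down`, and the `-y` flag are taken from this notebook.

```python
import shlex
from typing import List


def launcher_command(action: str, yaml_path: str, assume_yes: bool = True) -> List[str]:
    """Build the argument list for `ray up` / `ray down` (hypothetical helper)."""
    if action not in ("up", "down"):
        raise ValueError(f"unsupported action: {action!r}")
    cmd = ["ray", action]
    if assume_yes:
        cmd.append("-y")  # skip the interactive confirmation prompt
    cmd.append(yaml_path)
    return cmd


up_cmd = launcher_command("up", "tmp/cluster.yaml")
down_cmd = launcher_command("down", "tmp/cluster.yaml")
print(shlex.join(up_cmd))
print(shlex.join(down_cmd))
# To actually run a command: subprocess.run(up_cmd, check=True)
```

Building a list rather than a single string avoids shell-quoting problems if the YAML path contains spaces.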

Take a look at the example below:

In [18]:
from ipywidgets import widgets

text = widgets.Text(
    value="my-gcp-project",
    description="GCP project ID: ",
    disabled=False,
)
display(text)

Text(value='my-gcp-project', description='GCP project ID: ')

In [23]:
project_name = text.value

print("Using project name: ", project_name)

Using project name:  mlperf-high-priority-project


In [49]:
import os

cluster_def = """
cluster_name: tpu-demo

max_workers: 3

available_node_types:
    ray_head_default:
        min_workers: 0
        max_workers: 0
        resources: {{"CPU": 0}}
        node_config:
            machineType: n1-standard-4
            disks:
              - boot: true
                autoDelete: true
                type: PERSISTENT
                initializeParams:
                  diskSizeGb: 50
                  sourceImage: projects/ubuntu-os-cloud/global/images/family/ubuntu-2004-lts
    ray_tpu_v4_8:
        min_workers: 1
        max_workers: 2
        resources: {{"TPU": 4, "tpu-v4-8": 1}}
        node_config:
            acceleratorType: v4-8
            runtimeVersion: tpu-vm-v4-base
    ray_tpu_v4_16:
        min_workers: 1
        max_workers: 1
        resources: {{"TPU": 4, "tpu-v4-16": 1}}
        node_config:
            acceleratorConfig:
                type: V4
                topology: 2x2x2
            runtimeVersion: tpu-vm-v4-base

provider:
    type: gcp
    region: us-central2
    availability_zone: us-central2-b
    project_id: {project}

initialization_commands:
  - sudo apt-get update
  - sudo apt-get install -y python3-pip python-is-python3

setup_commands:
  - pip install "ray[default]"

head_setup_commands:
  - pip install google-api-python-client
 
worker_setup_commands:
  - pip install 'jax[tpu]==0.4.11' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

head_node_type: ray_head_default
""".format(project=project_name)

root_dir = os.path.join(os.getcwd(), "tmp")
os.makedirs(root_dir, exist_ok=True)
yaml_fpath = os.path.join(root_dir, "cluster.yaml")
with open(yaml_fpath, "w") as file:
    file.write(cluster_def)

print(f"Run: ray up -y {yaml_fpath}")

Run: ray up -y /home/allencwang/ray-tpu-hello/tmp/cluster.yaml




In [12]:
! ray up -y /home/allencwang/ray-tpu-hello/tmp/cluster.yaml

Cluster: tpu-demo

Checking GCP environment settings
2023-10-09 15:23:37,153	INFO config.py:556 -- _configure_key_pair: Private key not specified in config, using/home/allencwang/.ssh/ray-autoscaler_gcp_us-central2_mlperf-high-priority-project_ubuntu_0.pem
Updating cluster configuration and running full setup.
Cluster Ray runtime will be restarted. Confirm [y/N]: y [automatic, due to --yes]

Usage stats collection is enabled. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details.

<1/1> Setting up head node
  Prepared bootstrap config
2023-10-09 15:23:42,034	INFO node.py:321 -- wait_for_compute_zone_operation: Waiting for operation operation-1696865021785-6074a2cdaf5f8-28aaed21-4b379b3f to finish...
2023-10-09 15:23:47,225	INFO node.

As shown above, we defined a YAML spec for our Ray cluster and used that spec to launch the cluster.

Notice the following:
- We create three node types on top of a GCP-based provider:
  1) a `ray_head_default` node: a CPU VM (`n1-standard-4`).
  2) a `ray_tpu_v4_8` node: a TPU VM v4-8. It starts with a minimum of 1 replica and can scale up to a maximum of 2 replicas.
  3) a `ray_tpu_v4_16` node: a TPU VM v4-16. It is fixed at 1 replica and cannot scale up or down.
- Every node installs `python3-pip` and Ray. The head node additionally installs `google-api-python-client`, and each worker node installs JAX.
- At most 3 workers can be active at any given time across the whole Ray cluster.
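The worker limits above can be sanity-checked with a few lines of plain Python. The dictionary below is a hand-copied subset of the YAML (it is not parsed from the file), and `worker_bounds` is a hypothetical helper written only for this illustration.

```python
# Hand-copied subset of the cluster YAML above (assumption: kept in sync manually).
cluster = {
    "max_workers": 3,
    "available_node_types": {
        "ray_head_default": {"min_workers": 0, "max_workers": 0},
        "ray_tpu_v4_8": {"min_workers": 1, "max_workers": 2},
        "ray_tpu_v4_16": {"min_workers": 1, "max_workers": 1},
    },
}


def worker_bounds(config):
    """Return (sum of per-type minimums, sum of per-type maximums capped by the cluster limit)."""
    node_types = list(config["available_node_types"].values())
    min_total = sum(t["min_workers"] for t in node_types)
    max_total = min(sum(t["max_workers"] for t in node_types), config["max_workers"])
    return min_total, max_total


min_total, max_total = worker_bounds(cluster)
print(f"workers at rest: {min_total}, workers at full scale: {max_total}")
```

In this config the per-type maximums add up to exactly the cluster-level `max_workers`, so the cluster-level cap never cuts into the per-type limits.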

Now that the cluster is provisioned, let's move on to some basic Ray concepts.

## Connecting to the Cluster

Once your cluster is provisioned, the first step for developing a Ray application will be to connect to the cluster.

From a Jupyter notebook like this one, we connect in "client mode", which means passing the cluster address (a `ray://` URL) to `ray.init`. If we were instead running the workload as a Ray job or Serve deployment, or from within the Ray head node, a bare `ray.init()` would be sufficient.
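One way to keep both cases in a single code path is to compute the address first. This is a minimal sketch, assuming the Ray Client port 10001 used in this notebook; `ray_address` is a hypothetical helper, not part of Ray.

```python
from typing import Optional


def ray_address(head_ip: Optional[str], port: int = 10001) -> Optional[str]:
    """Return a Ray Client address, or None when running on the cluster itself."""
    if head_ip is None:
        # On the head node or inside a Ray job, a bare ray.init() is enough.
        return None
    return f"ray://{head_ip}:{port}"


address = ray_address("10.130.0.74")
print(address)
# Usage (requires Ray and a reachable cluster):
#   import ray
#   ray.init(address) if address else ray.init()
```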

In [55]:
import ray

# We connect to the internal IP at port 10001 for "client mode"
ray.init("ray://10.130.0.74:10001")

Python version: 3.8.10
Ray version: 2.7.0
Dashboard: http://10.130.0.74:8265


## Ray Core Basics
Ray lets developers scale workloads from an interactive notebook like this one out to large clusters, for example in a cloud environment. To support this, Ray Core provides a small set of core primitives for building applications.

Let's walk through them together.

### Ray Tasks
Ray lets you run functions as remote tasks on a cluster. This takes three steps:
1) decorate your function with `@ray.remote`,
2) invoke that function with `.remote()`, which returns a reference,
3) fetch the value of the reference with `ray.get`.

Let's see this in action:

In [46]:
# Define the square task.
@ray.remote
def square(x):
    return x * x

# Launch four parallel square tasks.
futures = [square.remote(i) for i in range(4)]

# Retrieve results.
print(ray.get(futures))

[0, 1, 4, 9]


### Ray Actors
Ray actors extend the API from functions (tasks) to classes, letting you parallelize stateful computation across multiple actor instances.

When you instantiate a class that is a Ray actor, Ray will start a remote instance of that class in the cluster.

This actor can then execute remote method calls and maintain its own internal state:



In [47]:
# Define the Counter actor.
@ray.remote
class Counter:
    def __init__(self):
        self.i = 0

    def get(self):
        return self.i

    def incr(self, value):
        self.i += value

# Create a Counter actor.
c = Counter.remote()

# Submit calls to the actor. These calls run asynchronously but in
# submission order on the remote actor process.
for _ in range(10):
    c.incr.remote(1)

# Retrieve final actor state.
print(ray.get(c.get.remote()))

10
[2m[1m[36m(autoscaler +8m59s)[0m Resized to 720 CPUs, 12 TPUs.


### Ray Resources

Ray resources abstract away physical machines and let you express your computation in terms of logical resources. The system manages the complexities of scheduling and autoscaling based on resource requests.

Resources in Ray are key-value pairs where the key is a resource name, and the value is a float quantity.

Ray natively supports the CPU, GPU, TPU, AWS Neuron, and memory resource types, and it also supports user-defined "custom resources".

If we were to take a look at our available resources:

In [45]:
ray.available_resources()

{'node:10.130.0.92': 1.0,
 'TPU-V4': 8.0,
 'node:__internal_head__': 1.0,
 'TPU': 8.0,
 'tpu-v4-16': 2.0,
 'CPU': 480.0,
 'memory': 608839377307.0,
 'object_store_memory': 261567048498.0,
 'node:10.130.0.74': 1.0,
 'node:10.130.0.90': 1.0}

We can see that:
1) Although we did not explicitly define a "TPU-V4" resource in the cluster YAML, it shows up in our available resources. This is a side effect of TPUs being a native resource. Similarly, we did not explicitly define the amount of memory or CPU available in our node types.
2) The `tpu-v4-16` resource that we declared in the YAML shows up as well. Like `tpu-v4-8`, it is an example of a custom resource. (`tpu-v4-8` is absent from this particular listing because `ray.available_resources()` only reports resources that currently have free capacity.)

Ray resources are logical and do not require a 1-to-1 mapping to physical resources. They are declared per Ray node through the `ray start` command that runs on that node during cluster setup. As a consequence, it is the user's responsibility to keep this mapping consistent and to not violate it in the application.
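To make the "logical, not physical" point concrete, here is a toy ledger in plain Python. It is only a mental model under simplifying assumptions, not Ray's scheduler: `fits` and `claim` are hypothetical helpers, and the numbers are a subset of the `ray.available_resources()` output above.

```python
from typing import Dict

Resources = Dict[str, float]


def fits(available: Resources, request: Resources) -> bool:
    """True if every requested quantity is currently available."""
    return all(available.get(name, 0.0) >= qty for name, qty in request.items())


def claim(available: Resources, request: Resources) -> Resources:
    """Return a new ledger with the request subtracted; nothing physical is touched."""
    if not fits(available, request):
        raise ValueError(f"request {request} does not fit in {available}")
    remaining = dict(available)
    for name, qty in request.items():
        remaining[name] -= qty
    return remaining


available = {"CPU": 480.0, "TPU": 8.0, "tpu-v4-16": 2.0}
after = claim(available, {"TPU": 4, "tpu-v4-16": 1})
print(after)
print(fits(available, {"TPU": 4, "tpu-v4-8": 1}))
```

The ledger only counts names and numbers. Whether a task that claimed `"TPU": 4` actually uses four TPU chips is entirely up to the code inside the task.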

### Requesting Resources
Resources are the foundation of Ray's autoscaling capabilities - the Ray autoscaler scales based on the demand of the running Ray application.

The Ray application makes its demands by specifying the logical resource requirements of Ray tasks and actors.

To see this in action, check out the following example:

In [56]:
@ray.remote(resources={"TPU": 4, "tpu-v4-8": 1})
def my_function():
    import socket
    return socket.gethostname()

ray.get(my_function.remote())

't1v-n-c7cc0a56-w-0'

[2m[1m[36m(autoscaler +28s)[0m Tip: use `ray status` to view detailed cluster status. To disable these messages, set RAY_SCHEDULER_EVENTS=0.
[2m[1m[36m(autoscaler +28s)[0m Resized to 720 CPUs, 12 TPUs.


Based on the returned hostname, we can tell that this indeed executed on a TPU VM. Notice that:
1) Nothing prevented us from claiming the TPU logical resources, even though the function never touched the physical TPUs.
2) We can use the custom resource `tpu-v4-8` to target a particular type of TPU.

We can also actually use the physical resources:

In [57]:
@ray.remote(resources={"TPU": 4, "tpu-v4-8": 1})
def device_count():
    import jax
    return jax.device_count()

ray.get(device_count.remote())

4

We can run on a TPU pod as well:

In [58]:
@ray.remote(resources={"TPU": 4, "tpu-v4-16": 1})
def device_count():
    import jax
    return jax.device_count()

handles = [device_count.remote() for _ in range(2)]
ray.get(handles)

[8, 8]
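The `[8, 8]` result follows from the slice geometry: two tasks were launched, each claiming one host's worth of resources, and `jax.device_count()` reports devices across the whole slice rather than just the local host. The sketch below spells out the arithmetic, assuming the usual TPU v4 layout (two TensorCores per chip, four chips per host VM). `v4_slice_shape` is a hypothetical helper.

```python
from typing import Tuple

CORES_PER_CHIP = 2  # assumption: a TPU v4 chip has two TensorCores
CHIPS_PER_HOST = 4  # assumption: each TPU v4 VM host drives four chips


def v4_slice_shape(accelerator_type: str) -> Tuple[int, int]:
    """Map a name like 'v4-16' to (chips, hosts)."""
    generation, cores = accelerator_type.split("-")
    if generation != "v4":
        raise ValueError(f"only v4 is modelled here, got {accelerator_type!r}")
    chips = int(cores) // CORES_PER_CHIP
    hosts = max(1, chips // CHIPS_PER_HOST)
    return chips, hosts


for name in ("v4-8", "v4-16"):
    chips, hosts = v4_slice_shape(name)
    print(f"{name}: {chips} chips on {hosts} host(s)")
```

Under these assumptions a v4-16 spans two hosts, which matches the `'tpu-v4-16': 2.0` entry in the resource listing earlier (one unit per host) and explains why two tasks are needed to cover the slice.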

We could also trigger autoscaling from the application, as long as our Ray cluster supports it. In our cluster config, we indicated that it's ok to create multiple v4-8s:

In [59]:
@ray.remote(resources={"TPU": 4, "tpu-v4-8": 1})
def my_long_running_fn():
    import time
    # Sleep for 5 min just to claim the resources and trigger autoscaling...
    time.sleep(5 * 60)


h1 = my_long_running_fn.remote()
h2 = my_long_running_fn.remote()

[2m[1m[36m(autoscaler +11m29s)[0m Adding 1 node(s) of type ray_tpu_v4_8.
[2m[1m[36m(autoscaler +13m33s)[0m Resized to 960 CPUs, 16 TPUs.
[2m[1m[36m(autoscaler +19m54s)[0m Removing 1 nodes of type ray_tpu_v4_8 (idle).
[2m[1m[36m(autoscaler +20m4s)[0m Resized to 720 CPUs, 12 TPUs.


Ray interprets these two requests as follows: one task claims the existing v4-8, and the other is recorded as pending demand:

```
$ ray exec tmp/cluster.yaml "ray status"
...
Resources
---------------------------------------------------------------
Usage:
 1.0/720.0 CPU
 4.0/12.0 TPU
 0.0/12.0 TPU-V4
 0B/846.37GiB memory
 2.14KiB/363.33GiB object_store_memory
 0.0/2.0 tpu-v4-16
 1.0/1.0 tpu-v4-8

Demands:
 {'CPU': 1.0, 'TPU': 4.0, 'tpu-v4-8': 1.0}: 1+ pending tasks/actors
```

The autoscaler then uses this demand to trigger a scale-up:
```
2023-10-09 16:48:15,909 INFO autoscaler.py:1379 -- StandardAutoscaler: Queue 1 new nodes for launch
2023-10-09 16:48:15,910 INFO autoscaler.py:464 -- The autoscaler took 0.162 seconds to complete the update iteration.
2023-10-09 16:48:15,910 INFO node_launcher.py:177 -- NodeLauncher0: Got 1 nodes to launch.
2023-10-09 16:48:18,331 INFO node.py:578 -- wait_for_tpu_operation: Waiting for operation projects/mlperf-high-priority-project/locations/us-central2-b/operations/operation-1696870095949-6074b5b4c9063-63d36f8f-7982f037 to finish...
```
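Conceptually, the autoscaler matches each pending demand against the node types it is allowed to launch. The sketch below is a deliberately simplified model of that matching, not the autoscaler's actual algorithm. The per-node CPU count of 240 is an assumption inferred from the `Resized to 720 CPUs` message with three TPU hosts, and `candidate_node_types` is a hypothetical helper.

```python
from typing import Dict, List

# Resources that one node of each type would contribute (CPU counts are assumed).
NODE_TYPES: Dict[str, Dict[str, float]] = {
    "ray_head_default": {"CPU": 0},
    "ray_tpu_v4_8": {"CPU": 240, "TPU": 4, "tpu-v4-8": 1},
    "ray_tpu_v4_16": {"CPU": 240, "TPU": 4, "tpu-v4-16": 1},
}


def candidate_node_types(demand: Dict[str, float]) -> List[str]:
    """Node types whose per-node resources cover the whole demand."""
    return [
        name
        for name, provided in NODE_TYPES.items()
        if all(provided.get(res, 0) >= qty for res, qty in demand.items())
    ]


pending = {"CPU": 1.0, "TPU": 4.0, "tpu-v4-8": 1.0}  # the demand shown by `ray status`
print(candidate_node_types(pending))
```

Only `ray_tpu_v4_8` covers the custom `tpu-v4-8` resource, which is consistent with the autoscaler message `Adding 1 node(s) of type ray_tpu_v4_8`.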

Finally, to tear down the cluster, we could run something like:

```
$ ray down -y tmp/cluster.yaml
```

The world of Ray is vast, and this notebook has only scratched the surface. Still, you should now understand some of the basic Ray concepts and the fundamentals for building large, distributed, and accelerated applications!

For further code references for Ray with TPUs, check out: 
- [Distributed Hyperparameter Tuning with Ray Tune and TPUs](https://github.com/tensorflow/tpu/tree/master/tools/ray_tpu/src/tune)
- [Distributed Serving with Ray Serve and TPUs](https://github.com/tensorflow/tpu/tree/master/tools/ray_tpu/src/serve)
- [Training and serving LLaMa2 with PyTorch/XLA on TPUs](https://github.com/pytorch-tpu/ray-llama)