diff --git a/jetstream_pt/ray_engine.py b/jetstream_pt/ray_engine.py index de142932..9b0f6e4d 100644 --- a/jetstream_pt/ray_engine.py +++ b/jetstream_pt/ray_engine.py @@ -4,6 +4,7 @@ import numpy as np import ray +from ray.runtime_env import RuntimeEnv from ray.util.accelerators import tpu from jetstream.engine import engine_api, tokenizer_pb2 @@ -241,7 +242,8 @@ def create_pytorch_ray_engine( ), f"num_hosts (current value {num_hosts}) should be a positive number" # pylint: disable-next=all engine_worker_with_tpu_resource = PyTorchRayWorker.options( - resources={"TPU": 4} + resources={"TPU": 4}, + runtime_env=RuntimeEnv(env_vars={"JAX_PLATFORMS": "tpu,cpu"}), ) engine_workers = [] for _ in range(num_hosts):