PyTorch/XLA is a Python package that uses the XLA deep learning compiler to connect the PyTorch deep learning framework and Cloud TPUs. You can try it right now, for free, on a single Cloud TPU VM with Kaggle!
Take a look at one of our Kaggle notebooks to get started:
PyTorch/XLA is now on PyPI!
To install PyTorch/XLA a new TPU VM:
pip install torch~=2.3.0 torch_xla[tpu]~=2.3.0 -f https://storage.googleapis.com/libtpu-releases/index.html
To update your existing training loop, make the following changes:
-import torch.multiprocessing as mp
+import torch_xla as xla
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.xla_multiprocessing as xmp
def _mp_fn(index):
...
+ # Move the model paramters to your XLA device
+ model.to(xla.device())
for inputs, labels in train_loader:
+ with xla.step():
+ # Transfer data to the XLA device. This happens asynchronously.
+ inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_fn(outputs, labels)
loss.backward()
- optimizer.step()
+ # `xm.optimizer_step` combines gradients across replicas
+ xm.optimizer_step(optimizer)
if __name__ == '__main__':
- mp.spawn(_mp_fn, args=(), nprocs=world_size)
+ # xmp.spawn automatically selects the correct world size
+ xmp.spawn(_mp_fn, args=())
If you're using DistributedDataParallel
, make the following changes:
import torch.distributed as dist
-import torch.multiprocessing as mp
+import torch_xla as xla
+import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.distributed.xla_backend
def _mp_fn(rank):
...
- os.environ['MASTER_ADDR'] = 'localhost'
- os.environ['MASTER_PORT'] = '12355'
- dist.init_process_group("gloo", rank=rank, world_size=world_size)
+ # Rank and world size are inferred from the XLA device runtime
+ dist.init_process_group("xla", init_method='xla://')
+
+ model.to(xm.xla_device())
+ # `gradient_as_bucket_view=True` required for XLA
+ ddp_model = DDP(model, gradient_as_bucket_view=True)
- model = model.to(rank)
- ddp_model = DDP(model, device_ids=[rank])
for inputs, labels in train_loader:
+ with xla.step():
+ inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
optimizer.zero_grad()
outputs = ddp_model(inputs)
loss = loss_fn(outputs, labels)
loss.backward()
optimizer.step()
if __name__ == '__main__':
- mp.spawn(_mp_fn, args=(), nprocs=world_size)
+ xmp.spawn(_mp_fn, args=())
Additional information on PyTorch/XLA, including a description of its semantics and functions, is available at PyTorch.org. See the API Guide for best practices when writing networks that run on XLA devices (TPU, CUDA, CPU and...).
Our comprehensive user guides are available at:
Documentation for the latest release
Documentation for master branch
PyTorch/XLA releases starting with version r2.1 will be available on PyPI. You
can now install the main build with pip install torch_xla
. To also install the
Cloud TPU plugin, install the optional tpu
dependencies after installing the main build with
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
GPU and nightly builds are available in our public GCS bucket.
Version | Cloud TPU/GPU VMs Wheel |
---|---|
2.3 (Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.3.0-cp38-cp38-manylinux_2_28_x86_64.whl |
2.3 (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.3 (CUDA 12.1 + Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp38-cp38-manylinux_2_28_x86_64.whl |
2.3 (CUDA 12.1 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl |
nightly (Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl |
nightly (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl |
nightly (CUDA 12.1 + Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-nightly-cp38-cp38-linux_x86_64.whl |
You can also add +yyyymmdd
after torch_xla-nightly
to get the nightly wheel of a specified date. To get the companion pytorch nightly wheel, replace the torch_xla
with torch
on above wheel links.
older versions
Version | Cloud TPU VMs Wheel |
---|---|
2.2 (Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl |
2.2 (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.2 (CUDA 12.1 + Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl |
2.2 (CUDA 12.1 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.1 (XRT + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/xrt/tpuvm/torch_xla-2.1.0%2Bxrt-cp310-cp310-manylinux_2_28_x86_64.whl |
2.1 (Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.1.0-cp38-cp38-linux_x86_64.whl |
2.0 (Python 3.8) | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-2.0-cp38-cp38-linux_x86_64.whl |
1.13 | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.13-cp38-cp38-linux_x86_64.whl |
1.12 | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.12-cp38-cp38-linux_x86_64.whl |
1.11 | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.11-cp38-cp38-linux_x86_64.whl |
1.10 | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.10-cp38-cp38-linux_x86_64.whl |
Note: For TPU Pod customers using XRT (our legacy runtime), we have custom
wheels for torch
and torch_xla
at
https://storage.googleapis.com/tpu-pytorch/wheels/xrt
.
Package | Cloud TPU VMs Wheel (XRT on Pod, Legacy Only) |
---|---|
torch_xla | https://storage.googleapis.com/pytorch-xla-releases/wheels/xrt/tpuvm/torch_xla-2.1.0%2Bxrt-cp310-cp310-manylinux_2_28_x86_64.whl |
torch | https://storage.googleapis.com/pytorch-xla-releases/wheels/xrt/tpuvm/torch-2.1.0%2Bxrt-cp310-cp310-linux_x86_64.whl |
Version | GPU Wheel + Python 3.8 |
---|---|
2.1+ CUDA 11.8 | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/11.8/torch_xla-2.1.0-cp38-cp38-manylinux_2_28_x86_64.whl |
2.0 + CUDA 11.8 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/118/torch_xla-2.0-cp38-cp38-linux_x86_64.whl |
2.0 + CUDA 11.7 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/117/torch_xla-2.0-cp38-cp38-linux_x86_64.whl |
1.13 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-1.13-cp38-cp38-linux_x86_64.whl |
nightly + CUDA 12.0 >= 2023/06/27 | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.0/torch_xla-nightly-cp38-cp38-linux_x86_64.whl |
nightly + CUDA 11.8 <= 2023/04/25 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/118/torch_xla-nightly-cp38-cp38-linux_x86_64.whl |
nightly + CUDA 11.8 >= 2023/04/25 | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/11.8/torch_xla-nightly-cp38-cp38-linux_x86_64.whl |
Version | GPU Wheel + Python 3.7 |
---|---|
1.13 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-1.13-cp37-cp37m-linux_x86_64.whl |
1.12 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-1.12-cp37-cp37m-linux_x86_64.whl |
1.11 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-1.11-cp37-cp37m-linux_x86_64.whl |
Version | Colab TPU Wheel |
---|---|
2.0 | https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl |
Version | Cloud TPU VMs Docker |
---|---|
2.3 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_tpuvm |
2.2 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_tpuvm |
2.1 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_tpuvm |
2.0 | gcr.io/tpu-pytorch/xla:r2.0_3.8_tpuvm |
1.13 | gcr.io/tpu-pytorch/xla:r1.13_3.8_tpuvm |
nightly python | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm |
To use the above dockers, please pass --privileged --net host --shm-size=16G
along. Here is an example:
docker run --privileged --net host --shm-size=16G -it us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm /bin/bash
Version | GPU CUDA 12.1 Docker |
---|---|
2.3 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_cuda_12.1 |
2.2 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_cuda_12.1 |
2.1 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_12.1 |
nightly | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1 |
nightly at date | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1_YYYYMMDD |
Version | GPU CUDA 11.8 + Docker |
---|---|
2.1 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_11.8 |
2.0 | gcr.io/tpu-pytorch/xla:r2.0_3.8_cuda_11.8 |
nightly | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.8 |
nightly at date | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.8_YYYYMMDD |
older versions
Version | GPU CUDA 11.7 + Docker |
---|---|
2.0 | gcr.io/tpu-pytorch/xla:r2.0_3.8_cuda_11.7 |
Version | GPU CUDA 11.2 + Docker |
---|---|
1.13 | gcr.io/tpu-pytorch/xla:r1.13_3.8_cuda_11.2 |
Version | GPU CUDA 11.2 + Docker |
---|---|
1.13 | gcr.io/tpu-pytorch/xla:r1.13_3.7_cuda_11.2 |
1.12 | gcr.io/tpu-pytorch/xla:r1.12_3.7_cuda_11.2 |
To run on compute instances with GPUs.
If PyTorch/XLA isn't performing as expected, see the troubleshooting guide, which has suggestions for debugging and optimizing your network(s).
The PyTorch/XLA team is always happy to hear from users and OSS contributors! The best way to reach out is by filing an issue on this Github. Questions, bug reports, feature requests, build issues, etc. are all welcome!
See the contribution guide.
This repository is jointly operated and maintained by Google, Meta and a number of individual contributors listed in the CONTRIBUTORS file. For questions directed at Meta, please send an email to opensource@fb.com. For questions directed at Google, please send an email to pytorch-xla@googlegroups.com. For all other questions, please open up an issue in this repository here.
You can find additional useful reading materials in