<a href="https://colab.research.google.com/github/JackCaoG/torch-xla-examples/blob/main/spmd_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
import torch
import torch_xla

import numpy as np

import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs

# Record the library builds this notebook was executed with (one per line).
print(torch.__version__, torch_xla.__version__, sep="\n")

2.3.0+cpu
2.3.0+libtpu


In [12]:
# Enable XLA SPMD execution mode.
# Called only for its side effect (the return value is discarded); the later
# cells that build the device mesh and call xs.mark_sharding presume it has run.
# NOTE(review): the PyTorch/XLA SPMD guide recommends enabling SPMD before any
# XLA tensors are created — keep this cell ahead of tensor creation and verify
# against your torch_xla version.
xr.use_spmd()



In [13]:
# Build a logical device mesh. The mesh shape, the partition spec later passed
# to xs.mark_sharding, and the input tensor's shape together determine the shape
# of each device's shard.
# Here every device lies along the 'data' axis and the 'model' axis has size 1.
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.arange(num_devices)  # logical device ids 0 .. num_devices-1
mesh = xs.Mesh(device_ids, mesh_shape, ('data', 'model'))

In [14]:
# Show the mesh's axis labels followed by its shape, each on its own line.
print(mesh.axis_names, mesh.mesh_shape, sep="\n")

('data', 'model')
(8, 1)


In [15]:
# Seed the CPU RNG so the tensor values printed in later cells are reproducible.
torch.manual_seed(0)
# One row per device: sharding dim 0 over the 'data' axis of the
# (num_devices, 1) mesh then leaves exactly one row on each device
# (identical to the original 8x4 tensor on an 8-device host).
t_cpu = torch.randn(num_devices, 4)
# Copy of the same values on the XLA device; this is the tensor that gets sharded.
t_xla = t_cpu.to(torch_xla.device())

In [17]:
# Mesh partitioning: tensor dim 0 is split across the 'data' mesh axis and dim 1
# across the size-1 'model' axis, so each device holds 1/num_devices of the rows
# (one row each in the output below: 8 rows over 8 devices).
partition_spec = ('data', 'model')
xt = xs.mark_sharding(t_xla, mesh, partition_spec)

# Fetch the per-device shards once.
# NOTE(review): in torch_xla `local_shards` appears to be a property that copies
# shard data back to host on each access — verify for your version.
shards = xt.local_shards
print(len(shards))
# Show the first two shards (slicing stays safe on a single-device host) and
# the full CPU tensor, so each shard can be compared with its row slice.
for shard in shards[:2]:
    print(shard)
print(t_cpu)

8
XLAShard(data=tensor([[ 1.1481,  1.0661, -0.4202, -0.7410]]), indices=[slice(0, 1, 1), slice(0, 4, 1)], shard_device='TPU:0', replica_id=0)
XLAShard(data=tensor([[ 0.2942,  0.3965, -0.2677, -0.8216]]), indices=[slice(1, 2, 1), slice(0, 4, 1)], shard_device='TPU:1', replica_id=0)
tensor([[ 1.1481,  1.0661, -0.4202, -0.7410],
        [ 0.2942,  0.3965, -0.2677, -0.8216],
        [ 0.9702, -1.0988, -0.8096,  0.3651],
        [ 0.1583, -1.0259,  0.5243, -1.5866],
        [ 1.7533,  1.9341, -0.1616,  0.2106],
        [ 0.6075, -0.3239, -0.7455,  0.7426],
        [-0.7109, -1.8223,  1.3984,  1.6798],
        [-0.7604, -2.1333,  0.7389, -0.4245]])


In [19]:
# Contrast the plain XLA tensor (type and device) with the wrapper object
# returned by xs.mark_sharding above.
print(type(t_xla), t_xla.device, type(xt), sep="\n")

<class 'torch.Tensor'>
xla:0
<class 'torch_xla.distributed.spmd.xla_sharded_tensor.XLAShardedTensor'>


In [20]:
# cos() of the sharded XLA tensor, brought back to host, should match the CPU reference.
t_xla.cos().cpu().allclose(t_cpu.cos())

True

In [None]:
import torchvision

device = torch_xla.device()

model = torchvision.models.resnet18().to(device)
# Random input batch laid out as [batch, channel, height, width].
# Named `images` (not `input`) so the Python builtin `input()` is not shadowed
# for the rest of the kernel session.
images = torch.randn(512, 3, 224, 224).to(device)

# Shard only the batch dimension across the 'data' mesh axis: this is data
# parallelism. The model parameters get no sharding annotation here.
xs.mark_sharding(images, mesh, ('data', None, None, None))

# resnet18's forward pass returns per-class logits (the final linear layer's
# output), not a loss — hence the name.
logits = model(images)
# Presumably forces the pending lazy XLA graph to compile and execute before
# the result is printed — see the torch_xla `mark_step` documentation.
xm.mark_step()

print(logits)

tensor([[-0.3702, -0.3322,  0.7581,  ...,  0.1561, -0.4892,  0.2370],
        [-0.4208, -0.2955,  0.7115,  ...,  0.1580, -0.7177,  0.0763],
        [-0.4375, -0.3585,  0.7511,  ..., -0.0321, -0.7185,  0.2437],
        ...,
        [-0.5015, -0.4109,  0.7333,  ...,  0.2299, -0.7468,  0.2354],
        [-0.4877, -0.2055,  0.7506,  ...,  0.0622, -0.8617,  0.1749],
        [-0.3511, -0.3073,  0.7137,  ...,  0.0111, -0.6867,  0.2517]],
       device='xla:0', grad_fn=<AddmmBackward0>)
