# PyTorch/XLA DTensor Integration

This notebook focuses on the application of DTensor with PyTorch/XLA.

See internal implementation details at
[[RFC] XLA Lazy Backend Support In DistributedTensor API #92909][dtensor-rfc]

[dtensor-rfc]: https://github.com/pytorch/pytorch/issues/92909

This can be run in command line using:

```
# Install jupyter
$ apt install jupyter

# Create Jupyter ipy kernel for current development venv:
$ pip install ipykernel
$ python -m ipykernel install --user --name=ptxla.venv

# 
$ jupyter execute --kernel_name=ptxla.venv /usr/local/google/home/gleasonk/Coding/pytorch/pytorch/xla/docs/source/perf/dtensor.ipynb
```

In [1]:
from platform import python_version

print(python_version())

3.11.9


## Setup parallel environment

We'll fake an 8 CPU setup. Note this must be done before the XLA PJRT plugin is
initialized by PyTorch/XLA.

In [87]:
import os
os.environ["WORLD_SIZE"] = '4'
os.environ["RANK"] = '1'
os.environ["CPU_NUM_DEVICES"] = os.environ["WORLD_SIZE"]
os.environ["PJRT_DEVICE"] = 'CPU'

import torch_xla
torch_xla.runtime.global_runtime_device_attributes()

[{'name': 'CPU:0'}, {'name': 'CPU:1'}, {'name': 'CPU:2'}, {'name': 'CPU:3'}]

# Intro to DTensor

The following sections are intended to mirror the PyTorch natice DTensor
tutorial:

https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/README.md

In [24]:
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/README.md#introduction

import os
import torch
from torch.distributed.tensor import init_device_mesh, Shard, distribute_tensor

# Create a mesh topology with the available devices:
# 1. We can directly create the mesh using elastic launcher, (recommended)
# 2. If using mp.spawn, one need to initialize the world process_group first and set device
#   i.e. torch.distributed.init_process_group(backend="nccl", world_size=world_size)
torch_xla.runtime.use_spmd()
mesh = init_device_mesh("xla", (int(os.environ["WORLD_SIZE"]),))
big_tensor = torch.randn(100000, 88)
# Shard this tensor over the mesh by sharding `big_tensor`'s 0th dimension over the 0th dimension of `mesh`.
my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(dim=0)])

print(my_dtensor)

XLAShardedTensor(tensor([[-2.3854e-01,  5.5333e-01, -6.4550e-01,  ..., -8.1820e-01,
          1.3780e+00,  5.1097e-01],
        [ 7.3101e-01, -2.2367e+00, -8.9715e-01,  ..., -7.7787e-01,
          1.0784e+00,  7.0975e-02],
        [-1.3072e+00,  2.0384e+00, -1.9566e-01,  ..., -2.8503e-01,
         -2.2741e-01, -4.6136e-01],
        ...,
        [ 8.4810e-01,  1.3419e-01,  7.7133e-01,  ..., -1.5206e-01,
         -6.3849e-01, -2.4455e-03],
        [ 4.5250e-01, -3.7841e-01, -2.7445e+00,  ..., -4.5199e-01,
          1.4458e+00,  2.7002e-01],
        [-4.9910e-01, -1.1933e-01, -1.2691e+00,  ...,  5.2942e-01,
         -1.0121e+00, -1.4557e+00]], device='xla:0'))


## Basic DTensor Examples

https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/README.md#basic-dtensor-api-examples

In [25]:
import torch
from torch.distributed.tensor import DTensor, Shard, Replicate, distribute_tensor, distribute_module, init_device_mesh

# construct a device mesh with available devices (multi-host or single host)
device_mesh = init_device_mesh("xla", (4,))
# if we want to do row-wise sharding
rowwise_placement=[Shard(0)]
# if we want to do col-wise sharding
colwise_placement=[Shard(1)]

big_tensor = torch.randn(888, 12)
# distributed tensor returned will be sharded across the dimension specified in placements
rowwise_tensor = distribute_tensor(big_tensor, device_mesh=device_mesh, placements=rowwise_placement)

# if we want to do replication across a certain device list
replica_placement = [Replicate()]
# distributed tensor will be replicated to all four GPUs.
replica_tensor = distribute_tensor(big_tensor, device_mesh=device_mesh, placements=replica_placement)

# if we want to distributed a tensor with both replication and sharding
device_mesh = init_device_mesh("xla", (2, 2))
# replicate across the first dimension of device mesh, then sharding on the second dimension of device mesh
spec=[Replicate(), Shard(0)]
partial_replica = distribute_tensor(big_tensor, device_mesh=device_mesh, placements=spec)


print(rowwise_tensor)
print(replica_tensor)
print(partial_replica)

XLAShardedTensor(tensor([[ 0.3499,  0.7769,  0.4670,  ...,  0.2456, -0.1884,  1.0635],
        [-1.2960, -0.7200,  0.2598,  ..., -0.1129, -0.3152,  0.5862],
        [ 2.3719, -0.1890, -0.2029,  ...,  1.2554,  1.9394,  0.1867],
        ...,
        [ 1.1783, -2.0097, -0.2131,  ..., -0.2364,  0.0429, -0.1690],
        [ 0.2073,  0.0102,  1.1132,  ...,  1.1448,  0.2130,  1.2484],
        [ 0.6789,  1.9195, -0.5259,  ..., -1.0065, -1.3724,  0.6549]],
       device='xla:0'))
XLAShardedTensor(tensor([[ 0.3499,  0.7769,  0.4670,  ...,  0.2456, -0.1884,  1.0635],
        [-1.2960, -0.7200,  0.2598,  ..., -0.1129, -0.3152,  0.5862],
        [ 2.3719, -0.1890, -0.2029,  ...,  1.2554,  1.9394,  0.1867],
        ...,
        [ 1.1783, -2.0097, -0.2131,  ..., -0.2364,  0.0429, -0.1690],
        [ 0.2073,  0.0102,  1.1132,  ...,  1.1448,  0.2130,  1.2484],
        [ 0.6789,  1.9195, -0.5259,  ..., -1.0065, -1.3724,  0.6549]],
       device='xla:0'))
XLAShardedTensor(tensor([[ 0.3499,  0.7769,  0.467

### UX ISSUE: `DTensor.from_local` fails

In [26]:
# create a DistributedTensor that shards on dim 0, from a local torch.Tensor
local_tensor = torch.randn((8, 8), requires_grad=True)
rowwise_tensor = DTensor.from_local(local_tensor, device_mesh, rowwise_placement)

# reshard the current row-wise tensor to a colwise tensor or replicate tensor
colwise_tensor = rowwise_tensor.redistribute(device_mesh, colwise_placement)
replica_tensor = colwise_tensor.redistribute(device_mesh, replica_placement)

AttributeError: 'DeviceMesh' object has no attribute '_coordinate_on_dim'

### Attempted workaround

Sort of... New error.

In [37]:
# See: https://github.com/pytorch/xla/issues/8528
# Need to stub the method in the meantime.
# Not sure what to stub it to though.
_rank = 0
device_mesh._coordinate_on_dim = [_rank]  

# create a DistributedTensor that shards on dim 0, from a local torch.Tensor
local_tensor = torch.randn((8, 8), requires_grad=True)
rowwise_tensor = DTensor.from_local(local_tensor, device_mesh, rowwise_placement)

# reshard the current row-wise tensor to a colwise tensor or replicate tensor
colwise_tensor = rowwise_tensor.redistribute(device_mesh, colwise_placement)
replica_tensor = colwise_tensor.redistribute(device_mesh, replica_placement)

AttributeError: 'DeviceMesh' object has no attribute '_dim_group_infos'

# DummyMLP Example from Torch Native

In [24]:
import torch.nn.functional as F

class DummyMLP(torch.nn.Module):
  def __init__(self, device):
    super().__init__()
    self.net1 = torch.nn.Linear(5, 1024, device=device)
    self.relu = torch.nn.ReLU()
    self.net2 = torch.nn.Linear(1024, 4, device=device)

  def forward(self, x):
    return self.net2(F.relu(self.net1(x)))

  def reset_parameters(self, *args, **kwargs):
    with torch.no_grad():
      self.net1.weight.fill_(0.5)
      self.net2.weight.fill_(1)
      self.net1.bias.fill_(1.5)
      self.net2.bias.fill_(1.2)

DummyMLP("xla")

DummyMLP(
  (net1): Linear(in_features=5, out_features=1024, bias=True)
  (relu): ReLU()
  (net2): Linear(in_features=1024, out_features=4, bias=True)
)

### UX ISSUE: Can't distribute using `parallelize_module`

In [None]:
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)

device_type = "xla" # was "meta" in test

model = DummyMLP("xla")
device_mesh = init_device_mesh(device_type, (int(os.environ["WORLD_SIZE"]),))

# UX ISSUE: XLA and PyTorch handle "RANK" differently.
# See: https://github.com/pytorch/xla/issues/8528
# Need to stub the method in the meantime.
# Not sure what to stub it to though.
_rank = 0
device_mesh._coordinate_on_dim = [_rank]  

# UX ISSUE: We can't use upstream parallelization plans
parallelize_plan = {
    "net1": ColwiseParallel(),
    "net2": RowwiseParallel(),
}
model_tp = parallelize_module(model, device_mesh, parallelize_plan)
model_tp.to_empty(device=device_type)
model_tp.reset_parameters()
optim = torch.optim.SGD(model_tp.parameters(), lr=0.1)

torch.manual_seed(0)
inp = torch.randn(20, 5, device=device_type)
output = model_tp(inp)

ValueError: Default process group has not been initialized, please make sure to call init_process_group.

In [None]:
import torch.nn as nn
from torch.distributed.tensor import Shard, distribute_tensor, distribute_module, init_device_mesh

class MyModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(8, 8)
        self.fc2 = nn.Linear(8, 8)
        self.relu = nn.ReLU()

    def forward(self, input):
        return self.relu(self.fc1(input) + self.fc2(input))

mesh = init_device_mesh("xla", (4,))

def shard_params(mod_name, mod, mesh):
    col_linear_placement = [Shard(0)]
    # shard fc1 and fc2
    if isinstance(mod, nn.Linear):
        for name, param in mod.named_parameters():
            dist_param = nn.Parameter(
                distribute_tensor(param, mesh, col_linear_placement)
            )
            mod.register_parameter(name, dist_param)

sharded_module = distribute_module(MyModule(), mesh, partition_fn=shard_params)
print(sharded_module)

x = torch.rand((8, 8))
sharded_module(x)

MyModule(
  (fc1): Linear(in_features=8, out_features=8, bias=True)
  (fc2): Linear(in_features=8, out_features=8, bias=True)
  (relu): ReLU()
)




RuntimeError: !at::functionalization::impl::isFunctionalTensor(t) INTERNAL ASSERT FAILED at "/usr/local/google/home/gleasonk/Coding/pytorch/pytorch/aten/src/ATen/FunctionalTensorWrapper.cpp":838, please report a bug to PyTorch. The composite op functionalization fallback expects its inputs all not to be functional tensors

In [31]:
# test_dtensor_toy_model_forward.py
import os
import torch
from torch import nn
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
from torch_xla.core import xla_model as xm


class ToyModel(nn.Module):
    """MLP based model"""

    def __init__(self):
        super(ToyModel, self).__init__()
        self.in_proj = nn.Linear(10, 32)
        #self.relu = nn.ReLU()
        self.out_proj = nn.Linear(32, 5)

    def forward(self, x):
        return self.out_proj(self.in_proj(x))
        #return self.out_proj(self.relu(self.in_proj(x)))

torch_xla.runtime.use_spmd()
_world_size = int(os.environ["WORLD_SIZE"])
_rank = int(os.environ.get("RANK", 1))
device_name = 'xla' # 'cuda'


device_mesh = init_device_mesh(device_type=device_name, mesh_shape=(_world_size,))
device_mesh._coordinate_on_dim = [_rank]  # workaround for https://github.com/pytorch/xla/issues/8528
torch.manual_seed(1)
inp = torch.rand(20, 10).to(device=device_name)
tp_model = ToyModel().to(device=device_name)

DIST_MODE = 2
if DIST_MODE==0:
    distribute_tensor(tp_model.in_proj.weight, device_mesh, [Shard(0)])
    distribute_tensor(tp_model.in_proj.bias, device_mesh, [Shard(0)])
    distribute_tensor(tp_model.out_proj.weight, device_mesh, [Shard(1)])
    distribute_tensor(tp_model.out_proj.bias, device_mesh, [Replicate()])
elif DIST_MODE == 1:
    # This is a hack to just call mark sharding...
    # This doesn't replace the weights with XLAShardedTensors
    tp_model.in_proj.weight = distribute_tensor(tp_model.in_proj.weight, device_mesh, [Shard(0)])
    tp_model.in_proj.bias = distribute_tensor(tp_model.in_proj.bias, device_mesh, [Shard(0)])
    tp_model.out_proj.weight = distribute_tensor(tp_model.out_proj.weight, device_mesh, [Shard(1)])
    tp_model.out_proj.bias = distribute_tensor(tp_model.out_proj.bias, device_mesh, [Replicate()])
    print(type(tp_model.in_proj.weight))
elif DIST_MODE == 2:
    # This replaces all weights with XLAShardedTensors, however XLAShardedTensors
    # dont execute properly today.
    tp_model = parallelize_module(
            module=tp_model,
            device_mesh=device_mesh,
            parallelize_plan={
                "in_proj": ColwiseParallel(),
                "out_proj": RowwiseParallel(),
            },
        )
    print(type(tp_model.in_proj.weight))
    print(tp_model.in_proj.weight.sharding_spec)
    print(tp_model.out_proj.weight.sharding_spec)

out = tp_model(inp)
print("STABLEHLO\n", xm.get_stablehlo([out]))

print(out.cpu())
xm.mark_step()



NameError: name 'torch_xla' is not defined

# Native PyTorch Tensor Parallel

In [30]:
import torch
from torch.testing._internal.common_distributed import spawn_threads_and_init_comms
WORLD_SIZE=4

import torch.nn as nn
import torch.distributed as dist
from torch.distributed._tensor import (
    DeviceMesh,
)
from torch.distributed.tensor.parallel import (
    RowwiseParallel,
    ColwiseParallel,
    parallelize_module,
)

ITER_TIME = 20

class ToyModel(nn.Module):
    """MLP based model"""

    def __init__(self):
        super(ToyModel, self).__init__()
        self.in_proj = nn.Linear(10, 32)
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(32, 5)

    def forward(self, x):
        return self.out_proj(self.relu(self.in_proj(x)))

def print0(msg, rank):
    if rank == 0:
        print(msg)

def printR(msg, rank):
    print(f"Rank{rank}: {msg}")

@spawn_threads_and_init_comms
def demo_tp(world_size):
    """
    Main body of the demo of a basic version of tensor parallel by using
    PyTorch native APIs.
    """
    rank = dist.get_rank()
    print("Create a sharding plan based on the given world_size", rank)
    # create a sharding plan based on the given world_size.
    device_mesh = DeviceMesh(
        "cpu",
        torch.arange(world_size),
    )

    # create model and move it to GPU with id rank
    model = ToyModel()
    tp_model = parallelize_module(
            module=model,
            device_mesh=device_mesh,
            parallelize_plan={
                "in_proj": ColwiseParallel(),
                "out_proj": RowwiseParallel(),
            },
        )
    from torch.fx import symbolic_trace
    traced = symbolic_trace(tp_model)
    print0(traced.graph, rank)
    printR(tp_model.in_proj.weight, rank)

    # Create a optimizer for the parallelized module.
    LR = 0.25
    optimizer = torch.optim.SGD(tp_model.parameters(), lr=LR)
    print0("Parallelize the module based on the given Parallel Style", rank)
    # Parallelize the module based on the given Parallel Style.

    # Perform a num of iterations of forward/backward
    # and optimizations for the sharded module.
    for i in range(ITER_TIME):
        inp = torch.rand(20, 10)
        output = tp_model(inp)
        print0(f"FWD Step: iter {i}", rank)
        output.sum().backward()
        print0(f"BWD Step: iter {i}", rank)
        optimizer.step()
        print0(f"Optimization Step: iter {i}", rank)
    print0("Training finished", rank)

demo_tp(WORLD_SIZE)

Create a sharding plan based on the given world_sizeCreate a sharding plan based on the given world_size 2
Create a sharding plan based on the given world_size 3
Create a sharding plan based on the given world_size 0
 1
graph():
    %x : [num_users=1] = placeholder[target=x]
    %in_proj : [num_users=1] = call_module[target=in_proj](args = (%x,), kwargs = {})
    %relu : [num_users=1] = call_module[target=relu](args = (%in_proj,), kwargs = {})
    %out_proj : [num_users=1] = call_module[target=out_proj](args = (%relu,), kwargs = {})
    return out_proj
Rank0: DTensor(local_tensor=tensor([[ 0.0714,  0.2321, -0.2109,  0.0709, -0.0369, -0.0149, -0.2217,  0.0326,
         -0.3099,  0.0043],
        [ 0.0765, -0.0198, -0.1454,  0.2559,  0.1785, -0.0122,  0.1996,  0.0504,
         -0.1649, -0.1638],
        [ 0.2190,  0.0206,  0.0890,  0.2985, -0.2653, -0.1647,  0.1480, -0.2851,
          0.2455, -0.1108],
        [-0.2619, -0.2695,  0.0860,  0.1910,  0.1654, -0.0758,  0.2099,  0.0857,
     