# PyTorch/XLA DTensor Integration

This notebook focuses on the application of DTensor with PyTorch/XLA.

See internal implementation details at
[[RFC] XLA Lazy Backend Support In DistributedTensor API #92909][dtensor-rfc]

[dtensor-rfc]: https://github.com/pytorch/pytorch/issues/92909

This notebook can be run from the command line using:

```
# Install jupyter
$ apt install jupyter

# Create Jupyter ipy kernel for current development venv:
$ pip install ipykernel
$ python -m ipykernel install --user --name=ptxla.venv

# Execute the notebook with that kernel (path is relative to the pytorch/xla repo root):
$ jupyter execute --kernel_name=ptxla.venv docs/source/perf/dtensor.ipynb
```

In [1]:
from platform import python_version

# Record which interpreter produced the outputs in this notebook.
interpreter_version = python_version()
print(interpreter_version)

3.11.9


## Setup parallel environment

We'll fake a 4 CPU setup. Note this must be done before the XLA PJRT plugin is
initialized by PyTorch/XLA.

In [2]:
import os

# Declare the simulated topology in one place: four processes on the CPU
# PJRT backend.
os.environ.update({
    "WORLD_SIZE": "4",
    "PJRT_DEVICE": "CPU",
})
# The number of virtual CPU devices mirrors the world size.
os.environ["CPU_NUM_DEVICES"] = os.environ["WORLD_SIZE"]

# Import torch_xla only after the environment variables above are set: the
# notebook text for this section notes they must be in place before the XLA
# PJRT plugin is initialized.
import torch_xla
# Display the attributes the runtime reports for each device (cell output).
torch_xla.runtime.global_runtime_device_attributes()

[{'name': 'CPU:0'}, {'name': 'CPU:1'}, {'name': 'CPU:2'}, {'name': 'CPU:3'}]

# Intro to DTensor

The following sections are intended to mirror the PyTorch native DTensor
tutorial:

https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/README.md

In [None]:
# Mirrors the introduction of the upstream DTensor README:
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/README.md#introduction

import os
import torch
from torch.distributed.tensor import init_device_mesh, Shard, distribute_tensor

# Upstream describes two ways to obtain a mesh over the available devices:
#   1. (recommended) create it directly when running under the elastic launcher;
#   2. with mp.spawn, initialize the world process group first and set the device,
#      i.e. torch.distributed.init_process_group(backend="nccl", world_size=world_size)

# One mesh axis sized from the WORLD_SIZE environment variable.
world_size = int(os.environ["WORLD_SIZE"])
mesh = init_device_mesh("xla", (world_size,))

# Dimension 0 of `big_tensor` is sharded over dimension 0 of `mesh`.
big_tensor = torch.randn(100000, 88)
row_sharding = [Shard(0)]
my_dtensor = distribute_tensor(big_tensor, device_mesh=mesh, placements=row_sharding)

print(my_dtensor)



XLAShardedTensor(tensor([[-1.7080,  0.4637, -1.4058,  ..., -1.6900, -1.1051,  1.7210],
        [-0.3269, -1.5310,  0.4159,  ...,  0.1561,  1.0350,  0.2642],
        [ 0.8611,  0.5285,  1.2459,  ...,  0.2676, -1.2947,  1.0346],
        ...,
        [ 0.1595, -0.1214,  0.4675,  ..., -0.6620,  0.4335, -0.5811],
        [-0.1768,  0.3318,  1.4751,  ..., -0.3326,  0.0671,  1.1855],
        [-0.3218, -0.6718, -0.2988,  ...,  0.2206, -0.4058,  1.0825]],
       device='xla:0'))




## Basic DTensor Examples

https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/README.md#basic-dtensor-api-examples

In [34]:
import torch
from torch.distributed.tensor import DTensor, Shard, Replicate, distribute_tensor, distribute_module, init_device_mesh

# Construct a 1-D device mesh with the available devices (multi-host or single host).
# `device_mesh` stays 1-D for the rest of the notebook so that it keeps matching
# the single-entry placement lists defined below.
device_mesh = init_device_mesh("xla", (4,))
# if we want to do row-wise sharding
rowwise_placement = [Shard(0)]
# if we want to do col-wise sharding
colwise_placement = [Shard(1)]

big_tensor = torch.randn(888, 12)
# distributed tensor returned will be sharded across the dimension specified in placements
rowwise_tensor = distribute_tensor(big_tensor, device_mesh=device_mesh, placements=rowwise_placement)

# if we want to do replication across a certain device list
replica_placement = [Replicate()]
# distributed tensor will be replicated to every device in the mesh
replica_tensor = distribute_tensor(big_tensor, device_mesh=device_mesh, placements=replica_placement)

# If we want to distribute a tensor with both replication and sharding, use a
# 2-D mesh. It gets its own name: a 2-D mesh needs one placement per mesh
# dimension, so rebinding `device_mesh` would silently break the 1-D placements
# (`rowwise_placement`, `colwise_placement`, `replica_placement`) in later cells.
device_mesh_2d = init_device_mesh("xla", (2, 2))
# replicate across the first dimension of device mesh, then sharding on the second dimension of device mesh
spec = [Replicate(), Shard(0)]
partial_replica = distribute_tensor(big_tensor, device_mesh=device_mesh_2d, placements=spec)


print(rowwise_tensor)
print(replica_tensor)
print(partial_replica)

XLAShardedTensor(tensor([[ 2.7933, -1.2275, -0.2549,  ...,  0.9429, -0.2904, -1.2301],
        [ 0.8740, -1.5658,  2.2909,  ..., -0.3352,  0.3340,  1.7716],
        [ 1.9126, -2.0719, -2.1520,  ...,  1.7489,  0.9192, -0.9709],
        ...,
        [ 1.4185, -0.0687,  0.3937,  ...,  0.2070, -0.9086,  1.4251],
        [ 1.6532, -0.4115,  0.9860,  ..., -0.6602,  1.1456,  0.7818],
        [-0.3824,  0.0094, -0.2849,  ..., -2.6232,  0.1736, -0.6748]],
       device='xla:0'))
XLAShardedTensor(tensor([[ 2.7933, -1.2275, -0.2549,  ...,  0.9429, -0.2904, -1.2301],
        [ 0.8740, -1.5658,  2.2909,  ..., -0.3352,  0.3340,  1.7716],
        [ 1.9126, -2.0719, -2.1520,  ...,  1.7489,  0.9192, -0.9709],
        ...,
        [ 1.4185, -0.0687,  0.3937,  ...,  0.2070, -0.9086,  1.4251],
        [ 1.6532, -0.4115,  0.9860,  ..., -0.6602,  1.1456,  0.7818],
        [-0.3824,  0.0094, -0.2849,  ..., -2.6232,  0.1736, -0.6748]],
       device='xla:0'))
XLAShardedTensor(tensor([[ 2.7933, -1.2275, -0.254

### UX ISSUE: `DTensor.from_local` fails

In [35]:
# This cell demonstrates a known failure. The expected AttributeError is caught
# and printed so that a top-to-bottom run (e.g. `jupyter execute`) is not
# aborted here; any other exception type still propagates.
try:
    # create a DistributedTensor that shards on dim 0, from a local torch.Tensor
    local_tensor = torch.randn((8, 8), requires_grad=True)
    rowwise_tensor = DTensor.from_local(local_tensor, device_mesh, rowwise_placement)

    # reshard the current row-wise tensor to a colwise tensor or replicate tensor
    colwise_tensor = rowwise_tensor.redistribute(device_mesh, colwise_placement)
    replica_tensor = colwise_tensor.redistribute(device_mesh, replica_placement)
except AttributeError as err:
    # Same one-line form a notebook shows at the end of a traceback.
    print(f"{type(err).__name__}: {err}")

AttributeError: 'DeviceMesh' object has no attribute '_coordinate_on_dim'

### Attempted workaround

Stubbing the missing attribute gets past the first error, but only sort of: a new error appears.

In [37]:
# See: https://github.com/pytorch/xla/issues/8528
# Need to stub the attribute in the meantime.
# Not sure what to stub it to though.
_rank = 0
device_mesh._coordinate_on_dim = [_rank]

# The stub only moves the failure further along. The expected AttributeError is
# caught and printed so that a top-to-bottom run (e.g. `jupyter execute`) is
# not aborted here; any other exception type still propagates.
try:
    # create a DistributedTensor that shards on dim 0, from a local torch.Tensor
    local_tensor = torch.randn((8, 8), requires_grad=True)
    rowwise_tensor = DTensor.from_local(local_tensor, device_mesh, rowwise_placement)

    # reshard the current row-wise tensor to a colwise tensor or replicate tensor
    colwise_tensor = rowwise_tensor.redistribute(device_mesh, colwise_placement)
    replica_tensor = colwise_tensor.redistribute(device_mesh, replica_placement)
except AttributeError as err:
    # Same one-line form a notebook shows at the end of a traceback.
    print(f"{type(err).__name__}: {err}")

AttributeError: 'DeviceMesh' object has no attribute '_dim_group_infos'

# DummyMLP Example from Torch Native

In [24]:
import torch.nn.functional as F

class DummyMLP(torch.nn.Module):
  """Two-layer perceptron: Linear(5, 1024) -> ReLU -> Linear(1024, 4).

  Args:
    device: Device passed to both ``torch.nn.Linear`` constructors.
  """

  def __init__(self, device):
    super().__init__()
    linear = torch.nn.Linear
    self.net1 = linear(5, 1024, device=device)
    self.relu = torch.nn.ReLU()
    self.net2 = linear(1024, 4, device=device)

  def forward(self, x):
    """Apply ``net1``, the functional ReLU, then ``net2`` to ``x``."""
    hidden = self.net1(x)
    # The activation goes through F.relu; the registered `self.relu` module
    # is not called here.
    activated = F.relu(hidden)
    return self.net2(activated)

  def reset_parameters(self, *args, **kwargs):
    """Fill every weight and bias with a fixed constant under ``torch.no_grad()``.

    Any positional or keyword arguments are accepted and ignored.
    """
    constants = (
        (self.net1.weight, 0.5),
        (self.net2.weight, 1),
        (self.net1.bias, 1.5),
        (self.net2.bias, 1.2),
    )
    with torch.no_grad():
      for tensor, value in constants:
        tensor.fill_(value)

# Build the model on the XLA device; the cell output is the module repr.
DummyMLP(device="xla")

DummyMLP(
  (net1): Linear(in_features=5, out_features=1024, bias=True)
  (relu): ReLU()
  (net2): Linear(in_features=1024, out_features=4, bias=True)
)

### UX ISSUE: Can't distribute using `parallelize_module`

In [None]:
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)

device_type = "xla" # was "meta" in test

model = DummyMLP("xla")
device_mesh = init_device_mesh(device_type, (int(os.environ["WORLD_SIZE"]),))

# UX ISSUE: XLA and PyTorch handle "RANK" differently.
# See: https://github.com/pytorch/xla/issues/8528
# Need to stub the attribute in the meantime.
# Not sure what to stub it to though.
_rank = 0
device_mesh._coordinate_on_dim = [_rank]

# UX ISSUE: We can't use upstream parallelization plans
parallelize_plan = {
    "net1": ColwiseParallel(),
    "net2": RowwiseParallel(),
}

# This cell demonstrates a known failure. The expected ValueError is caught and
# printed so that a top-to-bottom run (e.g. `jupyter execute`) is not aborted
# here; any other exception type still propagates.
try:
    model_tp = parallelize_module(model, device_mesh, parallelize_plan)
    model_tp.to_empty(device=device_type)
    model_tp.reset_parameters()
    optim = torch.optim.SGD(model_tp.parameters(), lr=0.1)

    torch.manual_seed(0)
    inp = torch.randn(20, 5, device=device_type)
    output = model_tp(inp)
except ValueError as err:
    # Same one-line form a notebook shows at the end of a traceback.
    print(f"{type(err).__name__}: {err}")

ValueError: Default process group has not been initialized, please make sure to call init_process_group.

In [None]:
import torch.nn as nn
from torch.distributed.tensor import Shard, distribute_tensor, distribute_module, init_device_mesh

class MyModule(nn.Module):
    """Two parallel 8x8 linear layers whose outputs are summed and passed through ReLU."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(8, 8)
        self.fc2 = nn.Linear(8, 8)
        self.relu = nn.ReLU()

    def forward(self, input):
        """Return ``relu(fc1(input) + fc2(input))``."""
        branch_sum = self.fc1(input) + self.fc2(input)
        return self.relu(branch_sum)

# 1-D mesh of four "xla" devices used to shard the module below.
mesh_shape = (4,)
mesh = init_device_mesh("xla", mesh_shape)

def shard_params(mod_name, mod, mesh):
    col_linear_placement = [Shard(0)]
    # shard fc1 and fc2
    if isinstance(mod, nn.Linear):
        for name, param in mod.named_parameters():
            dist_param = nn.Parameter(
                distribute_tensor(param, mesh, col_linear_placement)
            )
            mod.register_parameter(name, dist_param)

sharded_module = distribute_module(MyModule(), mesh, partition_fn=shard_params)
print(sharded_module)

x = torch.rand((8, 8))

# The forward pass is a known failure. The expected RuntimeError is caught and
# printed so that a top-to-bottom run (e.g. `jupyter execute`) completes; any
# other exception type still propagates.
sharded_output = None
try:
    sharded_output = sharded_module(x)
except RuntimeError as err:
    # Same one-line form a notebook shows at the end of a traceback.
    print(f"{type(err).__name__}: {err}")

# Displayed as the cell result when the call succeeds; None shows nothing.
sharded_output

MyModule(
  (fc1): Linear(in_features=8, out_features=8, bias=True)
  (fc2): Linear(in_features=8, out_features=8, bias=True)
  (relu): ReLU()
)




RuntimeError: !at::functionalization::impl::isFunctionalTensor(t) INTERNAL ASSERT FAILED at "/usr/local/google/home/gleasonk/Coding/pytorch/pytorch/aten/src/ATen/FunctionalTensorWrapper.cpp":838, please report a bug to PyTorch. The composite op functionalization fallback expects its inputs all not to be functional tensors