# PyTorch DTensor API: distribute_tensor() and distribute_module()

PyTorch's **DTensor (Distributed Tensor)** API provides JAX-like explicit tensor sharding capabilities. The key functions `distribute_tensor()` and `distribute_module()` give you fine-grained control over how tensors and modules are distributed across devices.

## Key Concepts

**DTensor** represents a tensor that is distributed across multiple devices with explicit sharding specifications, similar to JAX's sharded arrays.


```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.tensor import (
    DeviceMesh,
    DTensor,
    Replicate,
    Shard,
    distribute_tensor,
    distribute_module,
)
from torch.distributed.tensor.placement_types import Placement
import os

# Setup for demonstration
print("PyTorch DTensor API Demo")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

# For this demo, we'll use CPU to simulate multiple devices
WORLD_SIZE = 4  # Simulate 4 devices
```


## 1. distribute_tensor()

`distribute_tensor()` converts a regular PyTorch tensor into a **DTensor** (Distributed Tensor) with explicit sharding across devices.

**Function Signature:**
```python
distribute_tensor(
    tensor: torch.Tensor,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
    *,
    src_data_rank: Optional[int] = 0,
) -> DTensor
```

**Parameters:**
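- `tensor`: The tensor to distribute. If a sharded dimension is not evenly divisible by the size of the mesh dimension, it is split with `torch.chunk` semantics, so trailing shards may be smaller or empty.
- `device_mesh`: The `DeviceMesh` that defines the device topology (similar to JAX's `Mesh`). If omitted, the call must run inside a `DeviceMesh` context manager.
- `placements`: One `Placement` such as `Shard(dim)` or `Replicate()` per **mesh** dimension, so its length must equal `device_mesh.ndim`. It plays the role of JAX's `PartitionSpec`, but note that a `PartitionSpec` has one entry per array dimension instead. If omitted, the tensor is replicated across the mesh.
- `src_data_rank`: The rank whose data defines the global tensor; its shards or replicas are scattered or broadcast to the other ranks. Passing `None` skips this communication, and each rank uses its own local data.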
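Because `placements` holds one entry per mesh dimension, it can help to work out the resulting local shard shapes by hand before launching a real multi-process job. The following is a minimal pure-Python sketch of that arithmetic, not the DTensor implementation: `ShardSpec`, `ReplicateSpec`, and `local_shape` are hypothetical stand-ins introduced here for illustration (they mirror `Shard(dim)` and `Replicate()` in spirit only), and the split rule assumes `torch.chunk`-style chunking.

```python
from dataclasses import dataclass
from typing import Sequence, Tuple, Union


@dataclass(frozen=True)
class ShardSpec:
    """Hypothetical stand-in for Shard(dim): split tensor dim `dim` over one mesh dim."""
    dim: int


@dataclass(frozen=True)
class ReplicateSpec:
    """Hypothetical stand-in for Replicate(): full copy along one mesh dim."""


PlacementSpec = Union[ShardSpec, ReplicateSpec]


def local_shape(
    global_shape: Sequence[int],
    mesh_shape: Sequence[int],
    placements: Sequence[PlacementSpec],
    coord: Sequence[int],
) -> Tuple[int, ...]:
    """Shape of the shard held by the device at mesh coordinate `coord`."""
    if not (len(mesh_shape) == len(placements) == len(coord)):
        raise ValueError("need one placement and one coordinate per mesh dimension")
    shape = list(global_shape)
    # Walk the mesh dimensions in order; each Shard splits what is left.
    for mesh_size, placement, idx in zip(mesh_shape, placements, coord):
        if isinstance(placement, ShardSpec):
            n = shape[placement.dim]
            chunk = -(-n // mesh_size)  # ceiling division, as in torch.chunk
            start = min(idx * chunk, n)
            stop = min(start + chunk, n)
            shape[placement.dim] = stop - start
        # ReplicateSpec leaves the shape unchanged.
    return tuple(shape)


if __name__ == "__main__":
    # 2 x 2 mesh, tensor of shape (8, 6).
    # Shard rows over mesh dim 0, replicate over mesh dim 1.
    print(local_shape((8, 6), (2, 2), [ShardSpec(0), ReplicateSpec()], (1, 0)))  # (4, 6)
    # Shard rows over mesh dim 0 and columns over mesh dim 1.
    print(local_shape((8, 6), (2, 2), [ShardSpec(0), ShardSpec(1)], (0, 1)))  # (4, 3)
    # Uneven case: 5 elements over 4 devices gives shards of 2, 2, 1, 0.
    print([local_shape((5,), (4,), [ShardSpec(0)], (r,)) for r in range(4)])
```

In a real job, the same placements would be passed to `distribute_tensor()` together with an initialized `DeviceMesh`, and the local shard on each rank can be inspected with `DTensor.to_local()`.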
