
Commit 9107a3d

[distributed][Tensor Parallelism] Implement early transforms for column-wise and row-wise linear and embedding (#410)
crcrpar committed May 31, 2024
1 parent fe27109 commit 9107a3d
Showing 18 changed files with 1,485 additions and 191 deletions.
2 changes: 2 additions & 0 deletions docs/source/reference/distributed/index.rst
@@ -14,3 +14,5 @@ thunder.distributed
reset_skip_data_parallel_grad_sync
get_skip_data_parallel_grad_sync
skip_data_parallel_grad_sync
column_parallel
row_parallel
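
For orientation, here is a minimal usage sketch of the two newly documented transforms. It is a hedged illustration: the exact call signature, in particular the target_modules parameter and the way submodules are selected by name, is assumed here rather than taken from this diff.

import torch
import torch.nn as nn
import thunder
from thunder.distributed import column_parallel, row_parallel

# Assumes torch.distributed is already initialized (e.g. launched via torchrun)
# so that a default process group exists.
model = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
tmodel = thunder.jit(model)

# Megatron-style MLP sharding: the first linear is split column-wise
# (dim 0 of its weight) and the second row-wise (dim 1 of its weight).
tmodel = column_parallel(tmodel, target_modules=["0"])
tmodel = row_parallel(tmodel, target_modules=["2"])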
2 changes: 1 addition & 1 deletion notebooks/dev_tutorials/fsdp_tutorial.ipynb
@@ -165,7 +165,7 @@
" shard_param(param, global_rank, world_size, param_name)\n",
" # Mark the param to denote that it is sharded.\n",
" # This is required by the synchronization primitive we will use below.\n",
" param.ddp_type = thunder.core.proxies.DDPType.FULLY_SHARDED"
" param.distparallel_type = thunder.core.proxies.DistParallelType.FULLY_SHARDED"
]
},
{
15 changes: 13 additions & 2 deletions thunder/common.py
@@ -33,7 +33,15 @@
reset_tracectx,
)
from thunder.core.transform_common import dce, cse
from thunder.core.proxies import is_proxyable, proxy, Proxy, CollectionProxy, TensorProxy, DDPType, FutureTensorProxy
from thunder.core.proxies import (
is_proxyable,
proxy,
Proxy,
CollectionProxy,
TensorProxy,
DistParallelType,
FutureTensorProxy,
)
import thunder.core.prims as prims
import thunder.distributed as dist
import thunder.torch as ltorch
@@ -559,7 +567,10 @@ def _trace(
)

def ddp_sync(arg: Any | TensorProxy) -> Any | TensorProxy:
if isinstance(arg, TensorProxy) and arg.ddp_type in (DDPType.REPLICATED, DDPType.FULLY_SHARDED):
if isinstance(arg, TensorProxy) and arg.distparallel_type in (
DistParallelType.REPLICATED,
DistParallelType.FULLY_SHARDED,
):
return dist.prims.synchronize(arg, compile_data.process_group_for_ddp)
else:
return arg
7 changes: 5 additions & 2 deletions thunder/core/jit_ext.py
@@ -54,7 +54,7 @@

import torch
from thunder.core.proxies import (
DDPType,
DistParallelType,
proxy,
Proxy,
NumberProxy,
@@ -575,7 +575,10 @@ def proxify(self, value: WrappedValue) -> Any:
# TensorProxy attributes should be considered derived quantities, so we flag TensorProxies here
value.provenance.ext_flag |= EXT_FLAG_IS_TENSOR_PROXY

if isinstance(p, TensorProxy) and p.ddp_type in (DDPType.REPLICATED, DDPType.FULLY_SHARDED):
if isinstance(p, TensorProxy) and p.distparallel_type in (
DistParallelType.REPLICATED,
DistParallelType.FULLY_SHARDED,
):
p_new = thunder.distributed.prims.synchronize(
p,
self._process_group_for_ddp,
51 changes: 33 additions & 18 deletions thunder/core/proxies.py
@@ -994,10 +994,13 @@ def __repr__(self):
return f"[FloatProxy name={self.name}, value={self.value}]"


class DDPType(Enum):
class DistParallelType(Enum):
NONE = auto()
REPLICATED = auto()
FULLY_SHARDED = auto()
# Following two are for tensor parallelism
COLUMN_WISE = auto()
ROW_WISE = auto()


def _infer_tensor_properties(
@@ -1006,14 +1009,14 @@
device: devices.Device | None = None,
dtype: dtypes.dtype | None = None,
requires_grad: bool | None = None,
ddp_type: DDPType | None = None,
distparallel_type: DistParallelType | None = None,
thunder_fsdp_padding_size: int | None = None,
):
_shape = None
_device = None
_dtype = None
_requires_grad: None | bool = None
_ddp_type = DDPType.NONE
_dist_parallel_type = DistParallelType.NONE
_thunder_fsdp_padding_size = None

if like is not None:
@@ -1022,7 +1025,7 @@
_device = like.device
_dtype = like.true_dtype
_requires_grad = like.requires_grad
_ddp_type = getattr(like, "ddp_type", DDPType.NONE)
_dist_parallel_type = getattr(like, "distparallel_type", DistParallelType.NONE)

if shape is not None:
baseutils.check_valid_shape(shape)
@@ -1033,7 +1036,7 @@
_dtype = dtypes.numbertype_to_dtype(_dtype) if dtypes.is_numbertype(_dtype) else _dtype
_requires_grad = requires_grad if requires_grad is not None else _requires_grad
_requires_grad = False if not dtypes.is_inexact_dtype(_dtype) else _requires_grad
_ddp_type = ddp_type if ddp_type is not None else _ddp_type
_dist_parallel_type = distparallel_type if distparallel_type is not None else _dist_parallel_type
_thunder_fsdp_padding_size = (
thunder_fsdp_padding_size if thunder_fsdp_padding_size is not None else _thunder_fsdp_padding_size
)
@@ -1057,11 +1060,11 @@
baseutils.check_type(_device, devices.Device)
baseutils.check_type(_dtype, dtypes.dtype)
baseutils.check_type(_requires_grad, bool)
baseutils.check_type(_ddp_type, DDPType)
baseutils.check_type(_dist_parallel_type, DistParallelType)
if isinstance(_thunder_fsdp_padding_size, int):
baseutils.check(
_ddp_type == DDPType.FULLY_SHARDED,
lambda: f"{_ddp_type = } and {_thunder_fsdp_padding_size = } do not work",
_dist_parallel_type == DistParallelType.FULLY_SHARDED,
lambda: f"{_dist_parallel_type = } and {_thunder_fsdp_padding_size = } do not work",
)
baseutils.check(
_thunder_fsdp_padding_size > 0,
@@ -1073,7 +1076,17 @@
_true_dtype = _dtype
_dtype = dtypes.to_strong_dtype(_dtype)

return _shape, _device, _dtype, _true_dtype, _numel, _ndim, _requires_grad, _ddp_type, _thunder_fsdp_padding_size
return (
_shape,
_device,
_dtype,
_true_dtype,
_numel,
_ndim,
_requires_grad,
_dist_parallel_type,
_thunder_fsdp_padding_size,
)


# NOTE A FutureTensorProxy is intentionally NOT a subclass of TensorProxy
@@ -1100,7 +1113,7 @@ def __init__(
self._numel,
self._ndim,
self._requires_grad,
_, # ddp_type
_, # distparallel_type
_, # thunder_fsdp_padding_size
) = _infer_tensor_properties(
like,
@@ -1167,7 +1180,7 @@ def __init__(
dtype: dtypes.dtype | None = None,
requires_grad: bool | None = None,
prefix: None | str = None,
ddp_type: DDPType | None = None,
distparallel_type: DistParallelType | None = None,
history: None | tuple = None,
thunder_fsdp_padding_size: int | None = None,
):
@@ -1181,9 +1194,11 @@ def __init__(
self._numel,
self._ndim,
self._requires_grad,
self._ddp_type,
self._distparallel_type,
self._thunder_fsdp_padding_size,
) = _infer_tensor_properties(like, shape, device, dtype, requires_grad, ddp_type, thunder_fsdp_padding_size)
) = _infer_tensor_properties(
like, shape, device, dtype, requires_grad, distparallel_type, thunder_fsdp_padding_size
)

# NOTE The following properties DO NOT depend on the language context or record
# themselves into the trace, so they can be used when working with tensor proxies
@@ -1213,8 +1228,8 @@ def requires_grad(self):
return self._requires_grad

@property
def ddp_type(self):
return self._ddp_type
def distparallel_type(self):
return self._distparallel_type

@property
def thunder_fsdp_padding_size(self):
@@ -1519,8 +1534,8 @@ def real(self):
def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple = None) -> TensorProxy:
device = devices.device_from_string(str(t.device))
dtype = dtypes.to_dtype(t.dtype)
# See Note [DistributedDataParallel and ddp_type]
ddp_type = getattr(t, "ddp_type", None)
# See Note [DistributedDataParallel and distparallel_type]
distparallel_type = getattr(t, "distparallel_type", None)
_thunder_fsdp_padding_size = getattr(t, "_thunder_fsdp_padding_size", None)
# NOTE Without tuple(t.shape) then the shape would be a torch.Size object
return TensorProxy(
@@ -1529,7 +1544,7 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple =
device=device,
dtype=dtype,
requires_grad=t.requires_grad,
ddp_type=ddp_type,
distparallel_type=distparallel_type,
history=history,
thunder_fsdp_padding_size=_thunder_fsdp_padding_size,
)
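
The tensorproxy path above picks the parallelism tag up from a plain attribute on the incoming torch.Tensor, defaulting to no parallelism when the attribute is absent. A self-contained sketch of that tagging pattern follows; the enum is re-declared locally here only so the snippet runs standalone (in Thunder it lives in thunder.core.proxies).

import torch
from enum import Enum, auto

# Local mirror of DistParallelType, declared here only to keep the sketch standalone.
class DistParallelType(Enum):
    NONE = auto()
    REPLICATED = auto()
    FULLY_SHARDED = auto()
    COLUMN_WISE = auto()
    ROW_WISE = auto()

weight = torch.nn.Parameter(torch.randn(256, 64))
# A tensor-parallel transform tags the parameter in place ...
weight.distparallel_type = DistParallelType.COLUMN_WISE
# ... and a tensorproxy-style reader later recovers it with a safe default.
print(getattr(weight, "distparallel_type", DistParallelType.NONE))  # DistParallelType.COLUMN_WISE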
69 changes: 40 additions & 29 deletions thunder/distributed/__init__.py
@@ -16,7 +16,9 @@
from torch.utils.weak import WeakTensorKeyDictionary

import thunder.core.utils as utils
from thunder.core.proxies import DDPType
from thunder.core.proxies import DistParallelType
from thunder.distributed.tensor_parallel import column_parallel
from thunder.distributed.tensor_parallel import row_parallel

if TYPE_CHECKING:
from torch.distributed import ProcessGroup
Expand All @@ -28,6 +30,8 @@
"fsdp",
"FSDPBucketingStrategy",
"FSDPType",
"column_parallel",
"row_parallel",
]


@@ -273,14 +277,14 @@ def main():
),
)

# Note [DistributedDataParallel and ddp_type]
# Note [DistributedDataParallel and distparallel_type]
# If model was wrapped with thunder.distributed.ddp it would have a
# .use_ddp attribute set to True and all parameters would be already
# broadcasted to all other processes. So that our tracing is aware of
# this we need to mark the ddp_type of model's parameters as
# thunder.proxies.DDPType.REPLICATED
# this we need to mark the distparallel_type of model's parameters as
# thunder.proxies.DistParallelType.REPLICATED
for p in model.parameters():
p.ddp_type = DDPType.REPLICATED
p.distparallel_type = DistParallelType.REPLICATED

if broadcast_from is None:
return model
@@ -555,14 +559,14 @@ def fsdp(
# Shard the parameters
_shard_params(model, process_group, device, broadcast_from, allow_padding_for_fsdp=True)

# See Note [DistributedDataParallel and ddp_type]
# See Note [DistributedDataParallel and distparallel_type]
# If model was wrapped with thunder.distributed.fsdp it would have a
# .use_fsdp attribute set to True and all parameters would be already
# sharded across all other processes. So that our tracing is aware of
# this we need to mark the ddp_type of model's parameters as
# thunder.proxies.DDPType.FULLY_SHARDED
# this we need to mark the distparallel_type of model's parameters as
# thunder.proxies.DistParallelType.FULLY_SHARDED
for p in model.parameters():
p.ddp_type = DDPType.FULLY_SHARDED
p.distparallel_type = DistParallelType.FULLY_SHARDED

return model

@@ -619,33 +623,40 @@ def _shard_param(
rank: int,
world_size: int,
name: str,
*,
allow_padding_for_fsdp: bool = False,
dim: int | None = None,
) -> None:

if not allow_padding_for_fsdp or (param.size(0) % world_size == 0):
if not allow_padding_for_fsdp:
utils.check(
param.shape[0] % world_size == 0,
lambda: (
f"Current sharding requires the first dimension of the parameter {name!r} ({param.shape[0]})"
f" to be divisible by the world size ({world_size})"
),
)
chunk_size = param.shape[0] // world_size
# NOTE This could be a ShardTensor to indicate other parts of the code
# that it's sharded and should be treated differently
shard = param.data.narrow(0, chunk_size * rank, chunk_size).clone()
param.data = shard
else:
dim_to_shard = 0 if dim is None else dim
if allow_padding_for_fsdp:
utils.check(dim_to_shard == 0, lambda: f"Invalid {dim=} with {allow_padding_for_fsdp=}, Only 0 is supported")
padded_param_shape = list(param.shape)
orig_0dim_size = param.size(0)
orig_0dim_size = param.size(dim_to_shard)
chunk_size = (padded_param_shape[0] + world_size - 1) // world_size
padded_param_shape[0] = chunk_size * world_size
_thunder_fsdp_padding_size = padded_param_shape[0] - param.size(0)
padded_param = torch.empty(padded_param_shape, device=param.device, dtype=param.dtype)
padded_param[:orig_0dim_size].copy_(param)
param.data = padded_param.data.narrow(0, chunk_size * rank, chunk_size).clone()
param._thunder_fsdp_padding_size = _thunder_fsdp_padding_size
if _thunder_fsdp_padding_size > 0:
padded_param = torch.empty(padded_param_shape, device=param.device, dtype=param.dtype)
padded_param[:orig_0dim_size].copy_(param)
param.data = padded_param.data.narrow(0, chunk_size * rank, chunk_size).clone()
param._thunder_fsdp_padding_size = _thunder_fsdp_padding_size
else:
param.data = param.data.narrow(0, chunk_size * rank, chunk_size).clone()
param._thunder_fsdp_padding_size = None
else:
utils.check(
param.shape[dim_to_shard] % world_size == 0,
lambda: (
f"Current sharding requires the sharded dimension of the parameter {name!r} ({param.shape[dim_to_shard]})"
f" to be divisible by the world size ({world_size})"
),
)
chunk_size = param.shape[dim_to_shard] // world_size
# NOTE This could be a ShardTensor to indicate other parts of the code
# that it's sharded and should be treated differently
shard = param.data.narrow(dim_to_shard, chunk_size * rank, chunk_size).clone()
param.data = shard


@torch.no_grad()
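
To illustrate what the new dim argument enables, here is a self-contained sketch of the narrow-based slicing that _shard_param performs in the non-padded path. The helper name shard_along_dim is ours, not part of the diff; for an nn.Linear weight, column-wise parallelism shards dim 0 (output features) and row-wise parallelism shards dim 1 (input features).

import torch

def shard_along_dim(param: torch.Tensor, rank: int, world_size: int, dim: int = 0) -> torch.Tensor:
    # Mirrors the divisibility check and narrow()-based slicing above.
    assert param.size(dim) % world_size == 0, "sharded dim must be divisible by the world size"
    chunk = param.size(dim) // world_size
    return param.narrow(dim, chunk * rank, chunk).clone()

weight = torch.randn(8, 6)
print(shard_along_dim(weight, rank=1, world_size=2, dim=0).shape)  # torch.Size([4, 6]) -- column-wise
print(shard_along_dim(weight, rank=1, world_size=2, dim=1).shape)  # torch.Size([8, 3]) -- row-wise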
(Diffs for the remaining changed files are not shown here.)
