From 9107a3d8e9bb32345ddc5bf1df5883bc8e4ac458 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sat, 1 Jun 2024 01:47:45 +0900 Subject: [PATCH] [distributed][Tensor Parallelism] Implement early transforms for column-wise and row-wise linear and embedding (#410) --- docs/source/reference/distributed/index.rst | 2 + notebooks/dev_tutorials/fsdp_tutorial.ipynb | 2 +- thunder/common.py | 15 +- thunder/core/jit_ext.py | 7 +- thunder/core/proxies.py | 51 ++-- thunder/distributed/__init__.py | 69 +++-- thunder/distributed/prims.py | 263 +++++++++++++++-- .../distributed/tensor_parallel/__init__.py | 10 + .../tensor_parallel/column_wise.py | 269 +++++++++++++++++ thunder/distributed/tensor_parallel/common.py | 221 ++++++++++++++ .../distributed/tensor_parallel/row_wise.py | 274 ++++++++++++++++++ thunder/distributed/transforms/fsdp_v2.py | 6 +- thunder/executors/torchex.py | 20 +- thunder/tests/distributed/helper.py | 115 ++++++++ thunder/tests/distributed/test_checkpoint.py | 2 +- thunder/tests/distributed/test_ddp.py | 136 ++------- .../tests/distributed/test_tensor_parallel.py | 208 +++++++++++++ thunder/torch/__init__.py | 6 +- 18 files changed, 1485 insertions(+), 191 deletions(-) create mode 100644 thunder/distributed/tensor_parallel/__init__.py create mode 100644 thunder/distributed/tensor_parallel/column_wise.py create mode 100644 thunder/distributed/tensor_parallel/common.py create mode 100644 thunder/distributed/tensor_parallel/row_wise.py create mode 100644 thunder/tests/distributed/helper.py create mode 100644 thunder/tests/distributed/test_tensor_parallel.py diff --git a/docs/source/reference/distributed/index.rst b/docs/source/reference/distributed/index.rst index 7dfd1d915..ea9c189a6 100644 --- a/docs/source/reference/distributed/index.rst +++ b/docs/source/reference/distributed/index.rst @@ -14,3 +14,5 @@ thunder.distributed reset_skip_data_parallel_grad_sync get_skip_data_parallel_grad_sync skip_data_parallel_grad_sync + column_parallel + row_parallel diff --git a/notebooks/dev_tutorials/fsdp_tutorial.ipynb b/notebooks/dev_tutorials/fsdp_tutorial.ipynb index 41c4dee40..5f920c051 100644 --- a/notebooks/dev_tutorials/fsdp_tutorial.ipynb +++ b/notebooks/dev_tutorials/fsdp_tutorial.ipynb @@ -165,7 +165,7 @@ " shard_param(param, global_rank, world_size, param_name)\n", " # Mark the param to denote that it is sharded.\n", " # This is required by the synchronization primitive we will use below.\n", - " param.ddp_type = thunder.core.proxies.DDPType.FULLY_SHARDED" + " param.distparallel_type = thunder.core.proxies.DistParallelType.FULLY_SHARDED" ] }, { diff --git a/thunder/common.py b/thunder/common.py index 596741192..6d6aa02b2 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -33,7 +33,15 @@ reset_tracectx, ) from thunder.core.transform_common import dce, cse -from thunder.core.proxies import is_proxyable, proxy, Proxy, CollectionProxy, TensorProxy, DDPType, FutureTensorProxy +from thunder.core.proxies import ( + is_proxyable, + proxy, + Proxy, + CollectionProxy, + TensorProxy, + DistParallelType, + FutureTensorProxy, +) import thunder.core.prims as prims import thunder.distributed as dist import thunder.torch as ltorch @@ -559,7 +567,10 @@ def _trace( ) def ddp_sync(arg: Any | TensorProxy) -> Any | TensorProxy: - if isinstance(arg, TensorProxy) and arg.ddp_type in (DDPType.REPLICATED, DDPType.FULLY_SHARDED): + if isinstance(arg, TensorProxy) and arg.distparallel_type in ( + DistParallelType.REPLICATED, + DistParallelType.FULLY_SHARDED, + ): return dist.prims.synchronize(arg, 
compile_data.process_group_for_ddp) else: return arg diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index f805f74b2..3841c9dce 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -54,7 +54,7 @@ import torch from thunder.core.proxies import ( - DDPType, + DistParallelType, proxy, Proxy, NumberProxy, @@ -575,7 +575,10 @@ def proxify(self, value: WrappedValue) -> Any: # TensorProxy attributes should be considered derived quantities, so we flag TensorProxies here value.provenance.ext_flag |= EXT_FLAG_IS_TENSOR_PROXY - if isinstance(p, TensorProxy) and p.ddp_type in (DDPType.REPLICATED, DDPType.FULLY_SHARDED): + if isinstance(p, TensorProxy) and p.distparallel_type in ( + DistParallelType.REPLICATED, + DistParallelType.FULLY_SHARDED, + ): p_new = thunder.distributed.prims.synchronize( p, self._process_group_for_ddp, diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index 64180a7c2..cf0da4f84 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -994,10 +994,13 @@ def __repr__(self): return f"[FloatProxy name={self.name}, value={self.value}]" -class DDPType(Enum): +class DistParallelType(Enum): NONE = auto() REPLICATED = auto() FULLY_SHARDED = auto() + # Following two are for tensor parallelism + COLUMN_WISE = auto() + ROW_WISE = auto() def _infer_tensor_properties( @@ -1006,14 +1009,14 @@ def _infer_tensor_properties( device: devices.Device | None = None, dtype: dtypes.dtype | None = None, requires_grad: bool | None = None, - ddp_type: DDPType | None = None, + distparallel_type: DistParallelType | None = None, thunder_fsdp_padding_size: int | None = None, ): _shape = None _device = None _dtype = None _requires_grad: None | bool = None - _ddp_type = DDPType.NONE + _dist_parallel_type = DistParallelType.NONE _thunder_fsdp_padding_size = None if like is not None: @@ -1022,7 +1025,7 @@ def _infer_tensor_properties( _device = like.device _dtype = like.true_dtype _requires_grad = like.requires_grad - _ddp_type = getattr(like, "ddp_type", DDPType.NONE) + _dist_parallel_type = getattr(like, "distparallel_type", DistParallelType.NONE) if shape is not None: baseutils.check_valid_shape(shape) @@ -1033,7 +1036,7 @@ def _infer_tensor_properties( _dtype = dtypes.numbertype_to_dtype(_dtype) if dtypes.is_numbertype(_dtype) else _dtype _requires_grad = requires_grad if requires_grad is not None else _requires_grad _requires_grad = False if not dtypes.is_inexact_dtype(_dtype) else _requires_grad - _ddp_type = ddp_type if ddp_type is not None else _ddp_type + _dist_parallel_type = distparallel_type if distparallel_type is not None else _dist_parallel_type _thunder_fsdp_padding_size = ( thunder_fsdp_padding_size if thunder_fsdp_padding_size is not None else _thunder_fsdp_padding_size ) @@ -1057,11 +1060,11 @@ def _infer_tensor_properties( baseutils.check_type(_device, devices.Device) baseutils.check_type(_dtype, dtypes.dtype) baseutils.check_type(_requires_grad, bool) - baseutils.check_type(_ddp_type, DDPType) + baseutils.check_type(_dist_parallel_type, DistParallelType) if isinstance(_thunder_fsdp_padding_size, int): baseutils.check( - _ddp_type == DDPType.FULLY_SHARDED, - lambda: f"{_ddp_type = } and {_thunder_fsdp_padding_size = } do not work", + _dist_parallel_type == DistParallelType.FULLY_SHARDED, + lambda: f"{_dist_parallel_type = } and {_thunder_fsdp_padding_size = } do not work", ) baseutils.check( _thunder_fsdp_padding_size > 0, @@ -1073,7 +1076,17 @@ def _infer_tensor_properties( _true_dtype = _dtype _dtype = dtypes.to_strong_dtype(_dtype) - 
return _shape, _device, _dtype, _true_dtype, _numel, _ndim, _requires_grad, _ddp_type, _thunder_fsdp_padding_size + return ( + _shape, + _device, + _dtype, + _true_dtype, + _numel, + _ndim, + _requires_grad, + _dist_parallel_type, + _thunder_fsdp_padding_size, + ) # NOTE A FutureTensorProxy is intentionally NOT a subclass of TensorProxy @@ -1100,7 +1113,7 @@ def __init__( self._numel, self._ndim, self._requires_grad, - _, # ddp_type + _, # distparallel_type _, # thunder_fsdp_padding_size ) = _infer_tensor_properties( like, @@ -1167,7 +1180,7 @@ def __init__( dtype: dtypes.dtype | None = None, requires_grad: bool | None = None, prefix: None | str = None, - ddp_type: DDPType | None = None, + distparallel_type: DistParallelType | None = None, history: None | tuple = None, thunder_fsdp_padding_size: int | None = None, ): @@ -1181,9 +1194,11 @@ def __init__( self._numel, self._ndim, self._requires_grad, - self._ddp_type, + self._distparallel_type, self._thunder_fsdp_padding_size, - ) = _infer_tensor_properties(like, shape, device, dtype, requires_grad, ddp_type, thunder_fsdp_padding_size) + ) = _infer_tensor_properties( + like, shape, device, dtype, requires_grad, distparallel_type, thunder_fsdp_padding_size + ) # NOTE The following properties DO NOT depend on the language context or record # themselves into the trace, so they can be used when working with tensor proxies @@ -1213,8 +1228,8 @@ def requires_grad(self): return self._requires_grad @property - def ddp_type(self): - return self._ddp_type + def distparallel_type(self): + return self._distparallel_type @property def thunder_fsdp_padding_size(self): @@ -1519,8 +1534,8 @@ def real(self): def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple = None) -> TensorProxy: device = devices.device_from_string(str(t.device)) dtype = dtypes.to_dtype(t.dtype) - # See Note [DistributedDataParallel and ddp_type] - ddp_type = getattr(t, "ddp_type", None) + # See Note [DistributedDataParallel and distparallel_type] + distparallel_type = getattr(t, "distparallel_type", None) _thunder_fsdp_padding_size = getattr(t, "_thunder_fsdp_padding_size", None) # NOTE Without tuple(t.shape) then the shape would be a torch.Size object return TensorProxy( @@ -1529,7 +1544,7 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple = device=device, dtype=dtype, requires_grad=t.requires_grad, - ddp_type=ddp_type, + distparallel_type=distparallel_type, history=history, thunder_fsdp_padding_size=_thunder_fsdp_padding_size, ) diff --git a/thunder/distributed/__init__.py b/thunder/distributed/__init__.py index 91fa27dc6..1ca79ef62 100644 --- a/thunder/distributed/__init__.py +++ b/thunder/distributed/__init__.py @@ -16,7 +16,9 @@ from torch.utils.weak import WeakTensorKeyDictionary import thunder.core.utils as utils -from thunder.core.proxies import DDPType +from thunder.core.proxies import DistParallelType +from thunder.distributed.tensor_parallel import column_parallel +from thunder.distributed.tensor_parallel import row_parallel if TYPE_CHECKING: from torch.distributed import ProcessGroup @@ -28,6 +30,8 @@ "fsdp", "FSDPBucketingStrategy", "FSDPType", + "column_parallel", + "row_parallel", ] @@ -273,14 +277,14 @@ def main(): ), ) - # Note [DistributedDataParallel and ddp_type] + # Note [DistributedDataParallel and distparallel_type] # If model was wrapped with thunder.distributed.ddp it would have a # .use_ddp attribute set to True and all parameters would be already # broadcasted to all other processes. 
So that our tracing is aware of - # this we need to mark the ddp_type of model's parameters as - # thunder.proxies.DDPType.REPLICATED + # this we need to mark the distparallel_type of model's parameters as + # thunder.proxies.DistParallelType.REPLICATED for p in model.parameters(): - p.ddp_type = DDPType.REPLICATED + p.distparallel_type = DistParallelType.REPLICATED if broadcast_from is None: return model @@ -555,14 +559,14 @@ def fsdp( # Shard the parameters _shard_params(model, process_group, device, broadcast_from, allow_padding_for_fsdp=True) - # See Note [DistributedDataParallel and ddp_type] + # See Note [DistributedDataParallel and distparallel_type] # If model was wrapped with thunder.distributed.fsdp it would have a # .use_fsdp attribute set to True and all parameters would be already # sharded across all other processes. So that our tracing is aware of - # this we need to mark the ddp_type of model's parameters as - # thunder.proxies.DDPType.FULLY_SHARDED + # this we need to mark the distparallel_type of model's parameters as + # thunder.proxies.DistParallelType.FULLY_SHARDED for p in model.parameters(): - p.ddp_type = DDPType.FULLY_SHARDED + p.distparallel_type = DistParallelType.FULLY_SHARDED return model @@ -619,33 +623,40 @@ def _shard_param( rank: int, world_size: int, name: str, + *, allow_padding_for_fsdp: bool = False, + dim: int | None = None, ) -> None: - if not allow_padding_for_fsdp or (param.size(0) % world_size == 0): - if not allow_padding_for_fsdp: - utils.check( - param.shape[0] % world_size == 0, - lambda: ( - f"Current sharding requires the first dimension of the parameter {name!r} ({param.shape[0]})" - f" to be divisible by the world size ({world_size})" - ), - ) - chunk_size = param.shape[0] // world_size - # NOTE This could be a ShardTensor to indicate other parts of the code - # that it's sharded and should be treated differently - shard = param.data.narrow(0, chunk_size * rank, chunk_size).clone() - param.data = shard - else: + dim_to_shard = 0 if dim is None else dim + if allow_padding_for_fsdp: + utils.check(dim_to_shard == 0, lambda: f"Invalid {dim=} with {allow_padding_for_fsdp=}, Only 0 is supported") padded_param_shape = list(param.shape) - orig_0dim_size = param.size(0) + orig_0dim_size = param.size(dim_to_shard) chunk_size = (padded_param_shape[0] + world_size - 1) // world_size padded_param_shape[0] = chunk_size * world_size _thunder_fsdp_padding_size = padded_param_shape[0] - param.size(0) - padded_param = torch.empty(padded_param_shape, device=param.device, dtype=param.dtype) - padded_param[:orig_0dim_size].copy_(param) - param.data = padded_param.data.narrow(0, chunk_size * rank, chunk_size).clone() - param._thunder_fsdp_padding_size = _thunder_fsdp_padding_size + if _thunder_fsdp_padding_size > 0: + padded_param = torch.empty(padded_param_shape, device=param.device, dtype=param.dtype) + padded_param[:orig_0dim_size].copy_(param) + param.data = padded_param.data.narrow(0, chunk_size * rank, chunk_size).clone() + param._thunder_fsdp_padding_size = _thunder_fsdp_padding_size + else: + param.data = param.data.narrow(0, chunk_size * rank, chunk_size).clone() + param._thunder_fsdp_padding_size = None + else: + utils.check( + param.shape[dim_to_shard] % world_size == 0, + lambda: ( + f"Current sharding requires the sharded dimension of the parameter {name!r} ({param.shape[dim_to_shard]})" + f" to be divisible by the world size ({world_size})" + ), + ) + chunk_size = param.shape[dim_to_shard] // world_size + # NOTE This could be a ShardTensor to 
indicate other parts of the code + # that it's sharded and should be treated differently + shard = param.data.narrow(dim_to_shard, chunk_size * rank, chunk_size).clone() + param.data = shard @torch.no_grad() diff --git a/thunder/distributed/prims.py b/thunder/distributed/prims.py index 7d6cad511..cb93c36f9 100644 --- a/thunder/distributed/prims.py +++ b/thunder/distributed/prims.py @@ -2,18 +2,20 @@ from enum import auto, Enum from numbers import Number from typing import TYPE_CHECKING +from abc import ABC, abstractmethod import torch.distributed import thunder.core.utils as utils from thunder.core.prims import make_prim -from thunder.core.proxies import DDPType, FutureTensorProxy, pytype, TensorProxy +from thunder.core.proxies import DistParallelType, FutureTensorProxy, pytype, TensorProxy from thunder.core.transforms import register_augmented_forward, register_backward from thunder.distributed import get_skip_data_parallel_grad_sync if TYPE_CHECKING: from thunder.common import CompileData + from thunder.distributed.tensor_parallel.common import TensorParallelLayerType class PrimIDs(Enum): @@ -31,6 +33,9 @@ class PrimIDs(Enum): PACK_FOR_FSDP = auto() STASH_GRAD_FOR_FSDP = auto() + SYNCHRONIZE_TENSOR_PARALLEL_OUTPUT = auto() + SYNCHRONIZE_TENSOR_PARALLEL_INPUT = auto() + # This enum describes what all_reduce (below) will actually do # These operations are performed elementwise on all the "versions" of @@ -59,15 +64,20 @@ def all_gather_meta( /, group: torch.distributed.ProcessGroup, do_async: Number, + dim: int | None = None, ) -> TensorProxy: check_if_distributed_available() utils.check_type(a, TensorProxy) utils.check_type(group, torch.distributed.ProcessGroup) utils.check(pytype(do_async) is bool, lambda: f"Expected {do_async=} to be a boolean value") - # PyTorch's all_gather_into_tensor supports also other modes of gathering - # but we only do concatenation on the first dimension for now - result_shape = a.shape[0] * group.size(), *a.shape[1:] + if dim is not None: + utils.check_type(dim, int) + utils.check(dim >= 0 and dim < a.ndim, lambda: f"dim must satisfy 0 <= {dim=} < {a.ndim=}") + result_shape = list(a.shape) + result_shape[dim] *= group.size() + else: + result_shape = a.shape[0] * group.size(), *a.shape[1:] if do_async: return FutureTensorProxy(shape=result_shape, like=a) @@ -93,6 +103,7 @@ def all_reduce_meta( utils.check_type(op, DistributedReduceOps) utils.check_type(group, torch.distributed.ProcessGroup) utils.check(pytype(do_async) is bool, lambda: f"Expected {do_async=} to be a boolean value") + utils.check_type(skip_clone, bool) if do_async: return FutureTensorProxy(like=a) @@ -121,6 +132,7 @@ def reduce_scatter( op: DistributedReduceOps, group: torch.distributed.ProcessGroup, do_async: Number, + dim: int | None = None, ) -> TensorProxy: check_if_distributed_available() utils.check_type(a, TensorProxy) @@ -128,11 +140,19 @@ def reduce_scatter( utils.check_type(group, torch.distributed.ProcessGroup) utils.check(pytype(do_async) is bool, lambda: f"Expected {do_async=} to be a boolean value") - utils.check(a.shape[0] % group.size() == 0, lambda: f"Expected {a.shape[0]=} to be divisible by {group.size()=}") - - # PyTorch's reduce_scatter_tensor supports also other modes of scattering - # but we only do splitting on the first dimension for now - result_shape = a.shape[0] // group.size(), *a.shape[1:] + result_shape = list(a.shape) + if dim is not None: + utils.check_type(dim, int) + utils.check(dim >= 0 and dim < a.ndim, lambda: f"dim must satisfy 0 <= {dim=} < {a.ndim=}") + 
utils.check( + a.shape[dim] % group.size() == 0, lambda: f"Expected {a.shape[dim]=} to be divisible by {group.size()=}" + ) + result_shape[dim] //= group.size() + else: + result_shape[0] //= group.size() + utils.check( + a.shape[0] % group.size() == 0, lambda: f"Expected {a.shape[0]=} to be divisible by {group.size()=}" + ) if do_async: return FutureTensorProxy(shape=result_shape, like=a) @@ -157,16 +177,16 @@ def synchronize_meta( utils.check_type(a, TensorProxy) utils.check_type(group, torch.distributed.ProcessGroup) - match a.ddp_type: - case DDPType.REPLICATED: + match a.distparallel_type: + case DistParallelType.REPLICATED: return TensorProxy(like=a) - case DDPType.FULLY_SHARDED: + case DistParallelType.FULLY_SHARDED: # Assuming that the sharding is done on the first dimension # See [FSDP Sharding] in distributed/__init__.py unsharded_shape = a.shape[0] * group.size(), *a.shape[1:] - return TensorProxy(shape=unsharded_shape, like=a, ddp_type=DDPType.REPLICATED) + return TensorProxy(shape=unsharded_shape, like=a, distparallel_type=DistParallelType.REPLICATED) case _: - utils.check(False, lambda: f"Proxy {a} has unexpected {a.ddp_type=}") + utils.check(False, lambda: f"Proxy {a} has unexpected {a.distparallel_type=}") def pack_meta( @@ -271,6 +291,53 @@ def stash_grad_for_fsdp_meta( return TensorProxy(like=grad) +# see [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053)'s Code 1. +def synchronize_tensor_parallel_output_meta( + t: TensorProxy, + group: torch.distributed.ProcessGroup, + layer_type: TensorParallelLayerType, +) -> TensorProxy: + from thunder.distributed.tensor_parallel.common import TensorParallelLayerType + + utils.check_type(t, TensorProxy) + utils.check_type(group, torch.distributed.ProcessGroup) + utils.check_type(layer_type, TensorParallelLayerType) + + supported_ops = ( + TensorParallelLayerType.COLUMN_PARALLEL_EMBED, + TensorParallelLayerType.COLUMN_PARALLEL_LINEAR, + TensorParallelLayerType.ROW_PARALLEL_LINEAR, + TensorParallelLayerType.ROW_PARALLEL_EMBED, + ) + utils.check( + layer_type in supported_ops, + lambda: f"Unsupported {layer_type=}, supported ones are {supported_ops=}", + ) + return TensorProxy(like=t) + + +def synchronize_tensor_parallel_input_meta( + t: TensorProxy, + group: torch.distributed.ProcessGroup, + layer_type: TensorParallelLayerType, +) -> TensorProxy: + from thunder.distributed.tensor_parallel.common import TensorParallelLayerType + + utils.check_type(t, TensorProxy) + utils.check_type(group, torch.distributed.ProcessGroup) + utils.check_type(layer_type, TensorParallelLayerType) + + supported_ops = ( + TensorParallelLayerType.COLUMN_PARALLEL_LINEAR, + TensorParallelLayerType.ROW_PARALLEL_LINEAR, + ) + utils.check( + layer_type in supported_ops, + lambda: f"Unsupported {layer_type=}, supported ones are {supported_ops=}", + ) + return TensorProxy(like=t) + + all_gather = make_prim(PrimIDs.ALL_GATHER, "all_gather", meta=all_gather_meta) all_reduce = make_prim(PrimIDs.ALL_REDUCE, "all_reduce", meta=all_reduce_meta) broadcast = make_prim(PrimIDs.BROADCAST, "broadcast", meta=broadcast_meta) @@ -287,6 +354,16 @@ def stash_grad_for_fsdp_meta( "stash_grad_for_fsdp", meta=stash_grad_for_fsdp_meta, ) +synchronize_tensor_parallel_output = make_prim( + PrimIDs.SYNCHRONIZE_TENSOR_PARALLEL_OUTPUT, + "synchronize_tensor_parallel_output", + meta=synchronize_tensor_parallel_output_meta, +) +synchronize_tensor_parallel_input = make_prim( + PrimIDs.SYNCHRONIZE_TENSOR_PARALLEL_INPUT, + 
"synchronize_tensor_parallel_input", + meta=synchronize_tensor_parallel_input_meta, +) @register_augmented_forward(PrimIDs.SYNCHRONIZE) @@ -294,42 +371,178 @@ def synchronize_augmented_forward_rule( a: TensorProxy, group: torch.distributed.ProcessGroup, ) -> tuple[TensorProxy, tuple]: - match a.ddp_type: - case DDPType.REPLICATED: + match a.distparallel_type: + case DistParallelType.REPLICATED: # Assuming that the input is a replicated tensor, so no need to do anything # in the forward pass return a, ( - a.ddp_type, + a.distparallel_type, group, ) - case DDPType.FULLY_SHARDED: + case DistParallelType.FULLY_SHARDED: # Assuming that the sharding is done on the first dimension. # We do the communication on the side CUDA stream and wait is # immediately called on the result with the hope that the execution # passes would reorder the wait operation to be closer to the actual # usage of the tensor. return all_gather(a, group, True).wait(), ( - a.ddp_type, + a.distparallel_type, group, ) case _: - utils.check(False, lambda: f"Proxy {a} has unexpected {a.ddp_type=}") + utils.check(False, lambda: f"Proxy {a} has unexpected {a.distparallel_type=}") @register_backward(PrimIDs.SYNCHRONIZE) def synchronize_backward_rule( - ddp_type: DDPType, + distparallel_type: DistParallelType, group: torch.distributed.ProcessGroup, grad: TensorProxy, ) -> tuple[TensorProxy, None]: if get_skip_data_parallel_grad_sync() and ddp_type == DDPType.REPLICATED: return grad, None preaverage_grad = grad / group.size() - match ddp_type: - case DDPType.REPLICATED: + match distparallel_type: + case DistParallelType.REPLICATED: synced_grad = all_reduce(preaverage_grad, DistributedReduceOps.SUM, group, do_async=True).wait() - case DDPType.FULLY_SHARDED: + case DistParallelType.FULLY_SHARDED: synced_grad = reduce_scatter(preaverage_grad, DistributedReduceOps.SUM, group, do_async=True).wait() case _: - utils.check(False, lambda: f"synchronize with unexpected {ddp_type=}") + utils.check(False, lambda: f"synchronize with unexpected {distparallel_type=}") return synced_grad, None + + +class _TensorParallelOutputPostProcessFwdBwdInterface(ABC): + + @staticmethod + @abstractmethod + def forward(t: TensorProxy, group: torch.distributed.ProcessGroup) -> TensorProxy: ... + + @staticmethod + @abstractmethod + def backward(grad: TensorProxy, group: torch.distributed.ProcessGroup) -> TensorProxy: ... 
+ + +class FwdGatherBwdSplitAlongLastDim(_TensorParallelOutputPostProcessFwdBwdInterface): + + @staticmethod + def forward(t: TensorProxy, group: torch.distributed.ProcessGroup) -> TensorProxy: + """Gather along last dim""" + import thunder.torch as ltorch + + all_gathered = all_gather(t, group, True, 0).wait() + chunked = ltorch.chunk(all_gathered, group.size(), 0) + gathered = ltorch.cat(chunked, dim=-1) + return gathered + + @staticmethod + def backward(grad: TensorProxy, group: torch.distributed.ProcessGroup) -> TensorProxy: + """Split along last dim""" + from torch.distributed import distributed_c10d as c10d + import thunder.torch as ltorch + + local_grad = ltorch.chunk(grad, c10d.get_world_size(group), dim=grad.ndim - 1)[c10d.get_rank(group)] + return local_grad + + +class FwdAllReduceBwdIdentity(_TensorParallelOutputPostProcessFwdBwdInterface): + + @staticmethod + def forward(t: TensorProxy, group: torch.distributed.ProcessGroup) -> TensorProxy: + return all_reduce(t, DistributedReduceOps.SUM, group, do_async=True, skip_clone=True).wait() + + @staticmethod + def backward(grad: TensorProxy, _: torch.distributed.ProcessGroup) -> TensorProxy: + return grad + + +@register_augmented_forward(PrimIDs.SYNCHRONIZE_TENSOR_PARALLEL_OUTPUT) +def synchronize_tensor_parallel_output_forward_rule( + t: TensorProxy, + group: torch.distributed.ProcessGroup, + layer_type: TensorParallelLayerType, +) -> tuple[TensorProxy, tuple[torch.distributed.ProcessGroup, TensorParallelLayerType]]: + from thunder.distributed.tensor_parallel.common import TensorParallelLayerType + import thunder.torch as ltorch + + match layer_type: + case TensorParallelLayerType.COLUMN_PARALLEL_LINEAR: + return FwdGatherBwdSplitAlongLastDim.forward(t, group), (group, layer_type) + case TensorParallelLayerType.ROW_PARALLEL_LINEAR: + return FwdAllReduceBwdIdentity.forward(t, group), (group, layer_type) + case TensorParallelLayerType.COLUMN_PARALLEL_EMBED: + return FwdAllReduceBwdIdentity.forward(t, group), (group, layer_type) + case TensorParallelLayerType.ROW_PARALLEL_EMBED: + return FwdGatherBwdSplitAlongLastDim.forward(t, group), (group, layer_type) + case _: + utils.check(False, lambda: f"Invalid {layer_type=}") + + +@register_backward(PrimIDs.SYNCHRONIZE_TENSOR_PARALLEL_OUTPUT) +def synchronize_tensor_parallel_output_backward_rule( + group: torch.distributed.ProcessGroup, + layer_type: TensorParallelLayerType, + grad: TensorProxy, +) -> tuple[TensorProxy, None, None]: + from thunder.distributed.tensor_parallel.common import TensorParallelLayerType + + match layer_type: + case TensorParallelLayerType.COLUMN_PARALLEL_LINEAR: + return FwdGatherBwdSplitAlongLastDim.backward(grad, group), None, None + case TensorParallelLayerType.ROW_PARALLEL_LINEAR: + return FwdAllReduceBwdIdentity.backward(grad, group), None, None + case TensorParallelLayerType.COLUMN_PARALLEL_EMBED: + return FwdAllReduceBwdIdentity.backward(grad, group), None, None + case TensorParallelLayerType.ROW_PARALLEL_EMBED: + return FwdGatherBwdSplitAlongLastDim.backward(grad, group), None, None + case _: + utils.check(False, lambda: f"Invalid {layer_type=}") + + +@register_augmented_forward(PrimIDs.SYNCHRONIZE_TENSOR_PARALLEL_INPUT) +def synchronize_tensor_parallel_input_forward_rule( + t: TensorProxy, + group: torch.distributed.ProcessGroup, + layer_type: TensorParallelLayerType, +) -> tuple[TensorProxy, tuple[torch.distributed.ProcessGroup, TensorParallelLayerType]]: + from thunder.distributed.tensor_parallel.common import TensorParallelLayerType + import thunder.torch 
as ltorch + + match layer_type: + case TensorParallelLayerType.COLUMN_PARALLEL_LINEAR: + return t, (group, layer_type) + case TensorParallelLayerType.ROW_PARALLEL_LINEAR: + from torch.distributed import distributed_c10d as c10d + from thunder import clang + + chunk_size = t.shape[t.ndim - 1] // group.size() + start_idx = chunk_size * c10d.get_rank(group) + return clang.slice_in_dim(t, start_idx, start_idx + chunk_size, dim=t.ndim - 1), (group, layer_type) + case _: + utils.check(False, lambda: f"Invalid {layer_type=}") + + +@register_backward(PrimIDs.SYNCHRONIZE_TENSOR_PARALLEL_INPUT) +def synchronize_tensor_parallel_input_backward_rule( + group: torch.distributed.ProcessGroup, + layer_type: TensorParallelLayerType, + grad: TensorProxy, +) -> tuple[TensorProxy, None, None]: + from thunder.distributed.tensor_parallel.common import TensorParallelLayerType + + match layer_type: + case TensorParallelLayerType.COLUMN_PARALLEL_LINEAR: + return ( + all_reduce(grad, DistributedReduceOps.SUM, group, do_async=True, skip_clone=True).wait(), + None, + None, + ) + case TensorParallelLayerType.ROW_PARALLEL_LINEAR: + import thunder.torch as ltorch + + all_gathered = all_gather(grad, group, True, 0).wait() + chunked = ltorch.chunk(all_gathered, group.size(), 0) + gathered_grad = ltorch.cat(chunked, dim=-1) + return gathered_grad, None, None + case _: + utils.check(False, lambda: f"Invalid {layer_type=}") diff --git a/thunder/distributed/tensor_parallel/__init__.py b/thunder/distributed/tensor_parallel/__init__.py new file mode 100644 index 000000000..abaeff587 --- /dev/null +++ b/thunder/distributed/tensor_parallel/__init__.py @@ -0,0 +1,10 @@ +from thunder.distributed.tensor_parallel.common import TensorParallelLayerType +from thunder.distributed.tensor_parallel.column_wise import column_parallel +from thunder.distributed.tensor_parallel.row_wise import row_parallel + + +__all__ = [ + "TensorParallelLayerType", + "column_parallel", + "row_parallel", +] diff --git a/thunder/distributed/tensor_parallel/column_wise.py b/thunder/distributed/tensor_parallel/column_wise.py new file mode 100644 index 000000000..0f01e4a93 --- /dev/null +++ b/thunder/distributed/tensor_parallel/column_wise.py @@ -0,0 +1,269 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import TYPE_CHECKING +from typing import ClassVar + +import torch.nn as nn +from torch.distributed import distributed_c10d + +from thunder.core import utils +from thunder.core.proxies import TensorProxy +from thunder.core.proxies import DistParallelType +from thunder.distributed.tensor_parallel.common import PrePostProcessInterface +from thunder.distributed.tensor_parallel.common import ComputationTraceTransformVisitorForTensorParallel +from thunder.distributed.tensor_parallel.common import TransformForTensorParallel +from thunder.distributed.tensor_parallel.common import TensorParallelLayerType + +if TYPE_CHECKING: + from typing import Any + from collections.abc import Callable + from collections.abc import Sequence + from torch.distributed import ProcessGroup + from thunder.core.trace import TraceCtx + from thunder.core.trace import TraceProvenance + from thunder.core.transforms import VISIT_TYPE + from thunder.core.symbol import BoundSymbol + from thunder.core.module import ThunderModule + + +__all__ = [ + "column_parallel", +] + + +@dataclass(frozen=True) +class ColumnParallelLinearPrePostProcess(PrePostProcessInterface): + process_group: ProcessGroup + + layer_type: ClassVar[TensorParallelLayerType] = 
TensorParallelLayerType.COLUMN_PARALLEL_LINEAR + + def preprocess(self, x: TensorProxy) -> tuple[TensorProxy, tuple[Any, ...]]: + from thunder.distributed import prims as dist_prims + + return ( + dist_prims.synchronize_tensor_parallel_input( + x, self.process_group, ColumnParallelLinearPrePostProcess.layer_type + ), + None, + ) + + def postprocess(self, y: TensorProxy, _: Any) -> TensorProxy: + from thunder.distributed import prims as dist_prims + + return dist_prims.synchronize_tensor_parallel_output( + y, + self.process_group, + ColumnParallelLinearPrePostProcess.layer_type, + ) + + +@dataclass +class ColumnParallelEmbeddingPrePostProcess(PrePostProcessInterface): + num_local_embeddings: int + process_group: ProcessGroup + + layer_type: ClassVar[TensorParallelLayerType] = TensorParallelLayerType.COLUMN_PARALLEL_EMBED + + def __post_init__(self) -> None: + from torch.distributed import distributed_c10d + + rank = distributed_c10d.get_rank(self.process_group) + + self.vocab_start_index: int = rank * self.num_local_embeddings + self.vocab_end_index: int = (rank + 1) * self.num_local_embeddings - 1 + + def preprocess(self, x: TensorProxy) -> tuple[TensorProxy, tuple[TensorProxy, TensorProxy]]: + import thunder.torch as ltorch + + x = ltorch.sub(x, self.vocab_start_index) + mask1 = ltorch.ge(x, self.num_local_embeddings) + masked1 = ltorch.masked_fill(x, mask1, 0) + mask2 = ltorch.le(x, -1) + masked2 = ltorch.masked_fill(masked1, mask2, 0) + return masked2, (mask1, mask2) + + def postprocess(self, y: TensorProxy, masks: Any) -> TensorProxy: + from thunder.distributed import prims as dist_prims + import thunder.torch as ltorch + + utils.check(len(masks) == 2, lambda: f"Expected 2 masks but {len(masks)=}") + for mask in masks: + utils.check( + mask.shape == y.shape[: mask.ndim], + lambda: f"{mask.shape = }, {y.shape = }", + ) + mask = ltorch.unsqueeze(mask, mask.ndim) + unflattened_mask = ltorch.repeat(mask, (1, 1, y.shape[-1])) + utils.check( + unflattened_mask.shape == y.shape, + lambda: f"{unflattened_mask.shape = }, {y.shape = }", + ) + y = ltorch.masked_fill(y, unflattened_mask, 0.0) + return dist_prims.synchronize_tensor_parallel_output( + y, + self.process_group, + ColumnParallelEmbeddingPrePostProcess.layer_type, + ) + + +@dataclass +class TransformForColumnWiseParallel(TransformForTensorParallel): + + @property + def distparallel_type(self) -> DistParallelType: + return DistParallelType.COLUMN_WISE + + def _calc_new_shape(self, orig_shape: list[int]) -> tuple[int, ...]: + new_shape = orig_shape[:] + new_shape[0] //= self.process_group.size() + return tuple(new_shape) + + def get_visitor_of_computation_trace_and_provenance( + self, + computation_trace: TraceCtx, + ) -> tuple[Callable[[BoundSymbol], VISIT_TYPE], TraceProvenance | str]: + from thunder.core.pytree import tree_flatten + + consumers = utils.consumers(computation_trace) + flat_args, _ = tree_flatten((computation_trace.args, computation_trace.kwargs)) + bsym_to_prepostprocess: dict[BoundSymbol, PrePostProcessInterface] = {} + for proxy in filter(lambda p: isinstance(p, TensorProxy), flat_args): + if (layer_type := self.chunked_param_name_to_layer_type.get(proxy.name, None)) is not None: + consumer_bsym = consumers[proxy][0] + if consumer_bsym not in bsym_to_prepostprocess: + match layer_type: + case nn.Linear: + bsym_to_prepostprocess[consumer_bsym] = ColumnParallelLinearPrePostProcess( + process_group=self.process_group + ) + case nn.Embedding: + bsym_to_prepostprocess[consumer_bsym] = ColumnParallelEmbeddingPrePostProcess( 
+ num_local_embeddings=proxy.shape[0], process_group=self.process_group + ) + case _: + utils.check( + False, + lambda: f"{self.chunked_param_name_to_layer_type[proxy.name]=} is not supported", + ) + utils.check(bsym_to_prepostprocess, lambda: f"{bsym_to_prepostprocess} must not be empty") + + visit = ComputationTraceTransformVisitorForTensorParallel(bsym_to_prepostprocess) + return visit, "transform into column-wise tensor parallel" + + +# TODO(crcrpar): Add an option to turn off output all-gather. +def column_parallel( + thunder_module: ThunderModule, + target_modules: Sequence[str], + process_group: ProcessGroup | None = None, +) -> ThunderModule: + """Convert specified modules into column-wise parallel ones. + + This method has two effects: + 1. Chunks target modules' parameters in 0-th dimension. + 2. Insert preprocess and postprocess around modified module ops. + + Args: + thunder_module: + target_modules: Names of modules to convert into column-wise. + process_group: + + + Example: + .. code-block:: python + + import os + + import torch + import torch.nn + import torch.nn.functional as F + from torch.distributed import distributed_c10d + + import thunder + from thunder.distributed import column_parallel + + + class Model(nn.Module): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + n_hidden: int, + n_out: int, + ) -> None: + super().__init__() + self.embed = nn.Embedding(num_embeddings, embedding_dim) + self.l1 = nn.Linear(embedding_dim, n_hidden) + self.l2 = nn.Linear(n_hidden, n_out) + + def forward(self, tokens: torch.Tensor) -> torch.Tensor: + feature = self.embed(tokens) + h = F.gelu(self.l1(feature), approximate='tanh') + return self.l2(h) + + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ["LOCAL_RANK"]) + device = torch.device(f"cuda:{local_rank}") + distributed_c10d.init_process_group() + model = Model().to(device) + jitted_model = thunder.jit(model) + # `l2`'s output size (= n_out) needs to be divisible by `world_size` + tp_model = column_parallel( + jitted_model, + target_modules=("embed", "l2",), + ) + + x = torch.randn(4, n_in, device=device) + out = tp_model(x) # shape: [4, n_out] + """ + from thunder import compile_data as get_compile_data + from thunder.distributed import _shard_param + from thunder.core.transforms import add_transform + from thunder.core.module import ThunderModule + + utils.check_type(thunder_module, ThunderModule) + + if process_group is None: + process_group = distributed_c10d._get_default_group() + rank = distributed_c10d.get_rank(process_group) + world_size = distributed_c10d.get_world_size(process_group) + + chunked_param_name_to_layer_type: dict[str, Any] = {} + for target_mod_name in target_modules: + mod = thunder_module.get_submodule(target_mod_name) + utils.check_type( + mod, + ( + nn.Linear, + nn.Embedding, + ), + ) + for name, p in mod.named_parameters(recurse=False): + chunked_param_name_to_layer_type["t_" + f"{target_mod_name}.{name}".replace(".", "_")] = type(mod) + + import copy + + # Modify module + for module_name, _ in thunder_module._model.named_modules(): + if module_name not in target_modules: + continue + submodule = thunder_module.get_submodule(module_name) + + for pn, p in submodule.named_parameters(recurse=False, prefix=module_name): + # if we don't have an override or it is just the original, do create a copy + if thunder_module._overrides_parameters.get(pn, p) is p: + thunder_module._overrides_parameters[pn] = copy.copy(p) + 
_shard_param(thunder_module._overrides_parameters[pn], rank, world_size, pn, allow_padding_for_fsdp=False) + + colwise_thunder_module = add_transform( + thunder_module, + early_transform=TransformForColumnWiseParallel( + rank=rank, + world_size=world_size, + compile_data=get_compile_data(thunder_module), + chunked_param_name_to_layer_type=chunked_param_name_to_layer_type, + process_group=process_group, + ), + ) + + return colwise_thunder_module diff --git a/thunder/distributed/tensor_parallel/common.py b/thunder/distributed/tensor_parallel/common.py new file mode 100644 index 000000000..590f72219 --- /dev/null +++ b/thunder/distributed/tensor_parallel/common.py @@ -0,0 +1,221 @@ +from __future__ import annotations +from abc import ABC +from abc import abstractmethod +from enum import Enum +from enum import auto +from dataclasses import dataclass +from dataclasses import field +from typing import TYPE_CHECKING + +from thunder.core.proxies import DistParallelType + +if TYPE_CHECKING: + from typing import Any + from collections.abc import Callable + from torch.distributed import ProcessGroup + from thunder.common import CompileData + from thunder.core.proxies import ProxyInterface + from thunder.core.proxies import TensorProxy + from thunder.core.symbol import BoundSymbol + from thunder.core.trace import TraceCtx + from thunder.core.trace import TraceProvenance + from thunder.core.trace import VariableInterface + from thunder.core.transforms import VISIT_TYPE + + +__all__ = [ + "ComputationTraceTransformVisitorForTensorParallel", + "TensorParallelLayerType", + "NoOp", + "PrePostProcessInterface", + "TransformForTensorParallel", +] + + +class TensorParallelLayerType(Enum): + COLUMN_PARALLEL_LINEAR = auto() + ROW_PARALLEL_LINEAR = auto() + + COLUMN_PARALLEL_EMBED = auto() + ROW_PARALLEL_EMBED = auto() + + +class PrePostProcessInterface(ABC): + """Defining interfaces of pre/post-process of tensor parallelized ops.""" + + @abstractmethod + def preprocess(self, x: TensorProxy) -> tuple[TensorProxy, tuple[Any, ...]]: + """Apply preprocessing to tensor parallel op's inputs. + + The second return value could be consumed by :func:`PrePostProcessInterface.postprocess`. + """ + return x, (None,) + + @abstractmethod + def postprocess(self, y: TensorProxy, _: Any) -> TensorProxy: + """Apply postprocessing to tensor parallel op's outputs.""" + return y + + def maybe_modify_args_and_kwargs(self, bsym: BoundSymbol) -> BoundSymbol: + """No-op. Mainly for row-wise parallel linear.""" + return bsym + + +@dataclass(frozen=True) +class NoOp(PrePostProcessInterface): + def preprocess(self, x: TensorProxy) -> tuple[TensorProxy, tuple[Any, ...]]: + return super().preprocess(x) + + def postprocess(self, y: TensorProxy, _: Any) -> TensorProxy: + return super().postprocess(y) + + +@dataclass +class ComputationTraceTransformVisitorForTensorParallel: + """Wrap tensor parallel ops with necessary preprocessing and postprocessing. + + With the reference of ``bsyms_before_allgather``, this takes care of inputs and outputs of + tensor parallel ops by applying defined processings. Each pair of them is supposed to be defined + based on :clss:`PrePostProcessInterface`. + + Args: + bsym_to_prepostprocess: + + Attributes: + swap_map: A map from the original output of a tensor-parallel opt to the post-processed output. 
+ """ + + bsym_to_prepostprocess: dict[BoundSymbol, PrePostProcessInterface] + swap_map: dict[VariableInterface, ProxyInterface] = field(init=False, default_factory=dict) + + def __call__(self, bsym: BoundSymbol) -> VISIT_TYPE: + from thunder.core.transforms import VISIT_TYPE + from thunder.core.trace import get_tracectx + from thunder.core.proxies import variableify + + input_swap_map: dict[VariableInterface, ProxyInterface] = {} + pre_post_process: PrePostProcessInterface | None = None + if bsym in self.bsym_to_prepostprocess: + pre_post_process = self.bsym_to_prepostprocess[bsym] + orig_arg = bsym.flat_proxy_args[0] + new_arg, preprocess_artifacts = pre_post_process.preprocess(orig_arg) + if new_arg.name != orig_arg.name: + input_swap_map[variableify(orig_arg)] = new_arg + + new_bsym = bsym.from_bsym_swap_proxies(self.swap_map, skip_output=True) + if pre_post_process is not None: + new_bsym = new_bsym.from_bsym_swap_proxies(input_swap_map) + new_bsym = pre_post_process.maybe_modify_args_and_kwargs(new_bsym) + trace = get_tracectx() + trace.scopes[-1].append(new_bsym) + + if pre_post_process is not None: + y = bsym.flat_proxy_outs[0] + processed_y = pre_post_process.postprocess(y, preprocess_artifacts) + self.swap_map[variableify(y)] = processed_y + + return VISIT_TYPE.REPLACE + + +@dataclass +class TransformForTensorParallel: + rank: int + world_size: int + compile_data: CompileData + chunked_param_name_to_layer_type: dict[str, Any] + process_group: ProcessGroup + + def __post_init__(self): + from thunder.common import CompileData + from thunder.core import utils + + utils.check_type(self.compile_data, CompileData) + if getattr(self.compile_data, "use_fsdp", False) or getattr(self.compile_data.fn, "use_fsdp", False): + raise NotImplementedError("Currently thunder does not support the combination of fsdp and tensor parallel") + + @abstractmethod + def get_visitor_of_computation_trace_and_provenance( + self, + computation_trace: TraceCtx, + ) -> tuple[Callable[[BoundSymbol], VISIT_TYPE], TraceProvenance | str]: ... + + @abstractmethod + def _calc_new_shape(self, orig_shape) -> tuple[int, ...]: ... + + @property + def distparallel_type(self) -> DistParallelType: ... 
+ + def __call__( + self, + prologue_trace: TraceCtx, + computation_trace: TraceCtx, + epilogue_trace: TraceCtx, + **kwargs, + ) -> tuple[TraceCtx, TraceCtx, TraceCtx]: + from thunder.core import prims + from thunder.core import utils + from thunder.core.transforms import visitor_transform + + modules_and_thunder_modules = [ + (bsym.args[0], bsym.output) + for bsym in prologue_trace.bound_symbols + if bsym.sym is prims.unpack_thunder_module + ] + ((_, thunder_module_proxy),) = modules_and_thunder_modules + + prologue_producers, prologue_consumers = utils.producers_and_consumers(prologue_trace) + pro_out_p: TensorProxy + comp_inp_p: TensorProxy + for pro_out_p, comp_inp_p in zip(prologue_trace.output, computation_trace.args): + if pro_out_p.name not in self.chunked_param_name_to_layer_type: + continue + bsym = prologue_producers[pro_out_p] + if bsym.sym.id == prims.PrimIDs.UNPACK_PARAMETER: + param_thunder_module, param_name = bsym.args + assert param_thunder_module is thunder_module_proxy + + if ( + proxy_like_param_name := f"""t_{param_name.replace(".", "_")}""" + ) in self.chunked_param_name_to_layer_type: + + orig_shape = list(pro_out_p._shape) + new_shape = self._calc_new_shape(orig_shape) + pro_out_p._shape = new_shape + utils.check( + comp_inp_p.distparallel_type in (self.distparallel_type, DistParallelType.NONE), + lambda: f"{comp_inp_p.distparallel_type = } is not compatible with {self.distparallel_type=}", + ) + pro_out_p._distparallel_type = self.distparallel_type + if comp_inp_p is not pro_out_p: + comp_inp_p._shape = new_shape + comp_inp_p._distparallel_type = self.distparallel_type + + for c in prologue_consumers[pro_out_p]: + if c.sym is prims.check_tensor_shape_and_metadata: + # TODO have a more principled way to update this? + a0, _, _, *a2pp = c.args + c.args = (a0, tuple(new_shape), str(a0.device), *a2pp) + + for bsym in prologue_trace.bound_symbols: + if bsym.sym is prims.check_tensor_shape_and_metadata and prologue_producers[bsym.args[0]].sym in ( + prims.unpack_parameter, + prims.unpack_buffer, + ): + param_thunder_module, name = prologue_producers[bsym.args[0]].args + assert param_thunder_module is thunder_module_proxy + if name not in self.chunked_param_name_to_layer_type: + a0, shape, _, *a2pp = bsym.args + bsym.args = (a0, shape, str(a0.device), *a2pp) + + if len(modules_and_thunder_modules) != 1: + raise NotImplementedError("cannot deal with modules other than the compiled module") + + visit, provenance = self.get_visitor_of_computation_trace_and_provenance( + computation_trace=computation_trace, + ) + new_computation_trace = visitor_transform( + computation_trace, + visit=visit, + provenance=provenance, + ) + return prologue_trace, new_computation_trace, epilogue_trace diff --git a/thunder/distributed/tensor_parallel/row_wise.py b/thunder/distributed/tensor_parallel/row_wise.py new file mode 100644 index 000000000..0cf6cc44c --- /dev/null +++ b/thunder/distributed/tensor_parallel/row_wise.py @@ -0,0 +1,274 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import TYPE_CHECKING +from typing import ClassVar + +import torch.nn as nn +from torch.distributed import distributed_c10d + +from thunder.core import utils +from thunder.core.proxies import DistParallelType +from thunder.core.proxies import TensorProxy +from thunder.core.proxies import variableify +from thunder.distributed.tensor_parallel.common import PrePostProcessInterface +from thunder.distributed.tensor_parallel.common import 
ComputationTraceTransformVisitorForTensorParallel +from thunder.distributed.tensor_parallel.common import TransformForTensorParallel +from thunder.distributed.tensor_parallel.common import TensorParallelLayerType + +if TYPE_CHECKING: + from typing import Any + from collections.abc import Sequence + from collections.abc import Callable + from torch.distributed import ProcessGroup + from thunder.core.module import ThunderModule + from thunder.core.symbol import BoundSymbol + from thunder.core.trace import TraceCtx + from thunder.core.trace import TraceProvenance + from thunder.core.transforms import VISIT_TYPE + + +__all__ = [ + "row_parallel", +] + + +@dataclass(frozen=True) +class RowParallelLinearPrePostProcess(PrePostProcessInterface): + process_group: ProcessGroup + bias_or_none: TensorProxy | None + + layer_type: ClassVar[TensorParallelLayerType] = TensorParallelLayerType.ROW_PARALLEL_LINEAR + + def preprocess(self, x: TensorProxy) -> tuple[TensorProxy, tuple[Any, ...]]: + from thunder.distributed import prims as dist_prims + + # split `x` in the last dim. + return ( + dist_prims.synchronize_tensor_parallel_input( + x, self.process_group, RowParallelLinearPrePostProcess.layer_type + ), + None, + ) + + def postprocess(self, y: TensorProxy, _: Any) -> TensorProxy: + # gather `y` along the last dimension + import thunder.torch as ltorch + from thunder.distributed import prims as dist_prims + + all_reduced = dist_prims.synchronize_tensor_parallel_output( + y, + self.process_group, + RowParallelLinearPrePostProcess.layer_type, + ) + if (bias := self.bias_or_none) is not None: + return ltorch.add(all_reduced, bias) + else: + return all_reduced + + def maybe_modify_args_and_kwargs(self, bsym: BoundSymbol) -> BoundSymbol: + """Replace `bias` of `bsym` with `None` if it's Tensor to avoid redundant accumulation. + + The removed `bias` is added by `postprocess` after all-reduce. + The local row-wise parallel linear operation with bias could be + y_ = linear(x_, weight_, bias) + which leads to a wrong result after all_reduce as `bias` is added times. 
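+        (that is, the bias would be accumulated once per rank, ``world_size`` times in total)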
+ y = all_reduce(y_) + """ + if self.bias_or_none is not None: + return bsym.from_bsym_swap_proxies({variableify(self.bias_or_none): None}, skip_output=True) + else: + return super().maybe_modify_args_and_kwargs(bsym) + + +@dataclass +class RowParallelEmbeddingPreProcess(PrePostProcessInterface): + process_group: ProcessGroup + + layer_type: ClassVar[TensorParallelLayerType] = TensorParallelLayerType.ROW_PARALLEL_EMBED + + def preprocess(self, x: TensorProxy) -> tuple[TensorProxy, tuple[Any, ...]]: + return super().preprocess(x) + + def postprocess(self, y: TensorProxy, _: Any) -> TensorProxy: + from thunder.distributed import prims as dist_prims + + return dist_prims.synchronize_tensor_parallel_output( + y, + self.process_group, + RowParallelEmbeddingPreProcess.layer_type, + ) + + +@dataclass +class TransformForRowWiseParallel(TransformForTensorParallel): + + @property + def distparallel_type(self) -> DistParallelType: + return DistParallelType.ROW_WISE + + def _calc_new_shape(self, orig_shape: list[int]) -> tuple[int, ...]: + new_shape = orig_shape[:] + new_shape[1] //= self.process_group.size() + return tuple(new_shape) + + def get_visitor_of_computation_trace_and_provenance( + self, + computation_trace: TraceCtx, + ) -> tuple[Callable[[BoundSymbol], VISIT_TYPE], TraceProvenance | str]: + from thunder.core.pytree import tree_flatten + + consumers = utils.consumers(computation_trace) + flat_args, _ = tree_flatten((computation_trace.args, computation_trace.kwargs)) + bsym_to_prepostprocess: dict[BoundSymbol, PrePostProcessInterface] = {} + for proxy in filter(lambda p: isinstance(p, TensorProxy), flat_args): + if (layer_type := self.chunked_param_name_to_layer_type.get(proxy.name, None)) is not None: + consumer_bsym = consumers[proxy][0] + if consumer_bsym not in bsym_to_prepostprocess: + match layer_type: + case nn.Linear: + orig_args = consumer_bsym.args + utils.check( + len(orig_args) == 3, + lambda: f"{consumer_bsym.sym.id} expected to have 3 args but {orig_args}", + ) + bias_or_none = orig_args[2] + utils.check( + isinstance(bias_or_none, TensorProxy) or bias_or_none is None, + lambda: f"{orig_args[-1]} expected to be either `None` or `TensorProxy`", + ) + bsym_to_prepostprocess[consumer_bsym] = RowParallelLinearPrePostProcess( + process_group=self.process_group, + bias_or_none=bias_or_none, + ) + case nn.Embedding: + bsym_to_prepostprocess[consumer_bsym] = RowParallelEmbeddingPreProcess( + process_group=self.process_group + ) + case _: + utils.check( + False, + lambda: f"{self.chunked_param_name_to_layer_type[proxy.name]=} is not supported", + ) + utils.check(bsym_to_prepostprocess, lambda: f"{bsym_to_prepostprocess} must not be empty") + + visit = ComputationTraceTransformVisitorForTensorParallel(bsym_to_prepostprocess) + return visit, "transform into row-wise tensor parallel" + + +def row_parallel( + thunder_module: ThunderModule, + target_modules: Sequence[str], + process_group: ProcessGroup | None = None, +) -> ThunderModule: + """Convert specified modules into row-wise parallel ones. + + This method has two effects: + 1. Chunks target modules' parameters in 1st dimension. + 2. Insert preprocess and postprocess around modified module ops. + + Args: + thunder_module: + target_modules: Names of modules to convert into row-wise. + process_group: + + Example: + .. 
code-block:: python + + import os + + import torch + import torch.nn + import torch.nn.functional as F + from torch.distributed import distributed_c10d + + import thunder + from thunder.distributed import row_parallel + + + class Model(nn.Module): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + n_hidden: int, + n_out: int, + ) -> None: + super().__init__() + self.embed = nn.Embedding(num_embeddings, embedding_dim) + self.l1 = nn.Linear(embedding_dim, n_hidden) + self.l2 = nn.Linear(n_hidden, n_out) + + def forward(self, tokens: torch.Tensor) -> torch.Tensor: + feature = self.embed(tokens) + h = F.gelu(self.l1(feature), approximate='tanh') + return self.l2(h) + + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ["LOCAL_RANK"]) + device = torch.device(f"cuda:{local_rank}") + distributed_c10d.init_process_group() + model = Model().to(device) + jitted_model = thunder.jit(model) + # ``embedding_dim`` and `l2`'s input size (= n_hidden) need to be divisible by `world_size` + tp_model = column_parallel( + jitted_model, + target_modules=("embed", "l2",), + ) + + x = torch.randn(4, n_in, device=device) + out = tp_model(x) # shape: [4, n_out] + """ + from thunder import compile_data as get_compile_data + from thunder.distributed import _shard_param + from thunder.core.transforms import add_transform + from thunder.core.module import ThunderModule + + utils.check_type(thunder_module, ThunderModule) + + if process_group is None: + process_group = distributed_c10d._get_default_group() + rank = distributed_c10d.get_rank(process_group) + world_size = distributed_c10d.get_world_size(process_group) + + chunked_param_name_to_layer_type: dict[str, Any] = {} + for target_mod_name in target_modules: + mod = thunder_module.get_submodule(target_mod_name) + utils.check_type( + mod, + (nn.Linear, nn.Embedding), + ) + for name, p in mod.named_parameters(recurse=False): + if p.ndim < 2: + continue + chunked_param_name_to_layer_type["t_" + f"{target_mod_name}.{name}".replace(".", "_")] = type(mod) + + import copy + + # Modify module + for module_name, _ in thunder_module._model.named_modules(): + if module_name not in target_modules: + continue + submodule = thunder_module.get_submodule(module_name) + + for pn, p in submodule.named_parameters(recurse=False, prefix=module_name): + # if we don't have an override or it is just the original, do create a copy + if thunder_module._overrides_parameters.get(pn, p) is p: + thunder_module._overrides_parameters[pn] = copy.copy(p) + if p.ndim < 2: + continue + _shard_param( + thunder_module._overrides_parameters[pn], rank, world_size, pn, dim=1, allow_padding_for_fsdp=False + ) + + rowwise_thunder_module = add_transform( + thunder_module, + early_transform=TransformForRowWiseParallel( + rank=rank, + world_size=world_size, + compile_data=get_compile_data(thunder_module), + chunked_param_name_to_layer_type=chunked_param_name_to_layer_type, + process_group=process_group, + ), + ) + + return rowwise_thunder_module diff --git a/thunder/distributed/transforms/fsdp_v2.py b/thunder/distributed/transforms/fsdp_v2.py index 40d7b5280..3c2a45f6a 100644 --- a/thunder/distributed/transforms/fsdp_v2.py +++ b/thunder/distributed/transforms/fsdp_v2.py @@ -7,7 +7,7 @@ from thunder.core import devices from thunder.core import prims from thunder.core import utils -from thunder.core.proxies import DDPType +from thunder.core.proxies import DistParallelType from thunder.core.trace import from_trace from thunder.core.trace import tracectx from 
thunder.core.trace import TraceProvenance
@@ -59,11 +59,11 @@ def __call__(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
                 thunder_device = devices.to_device(new_torch_device)
                 thunder_device_str = str(thunder_device)
-                pro_out_p._ddp_type = DDPType.FULLY_SHARDED
+                pro_out_p._distparallel_type = DistParallelType.FULLY_SHARDED
                 pro_out_p._shape = tuple(new_shape)
                 pro_out_p._device = thunder_device
                 if comp_inp_p is not pro_out_p:
-                    comp_inp_p._ddp_type = DDPType.FULLY_SHARDED
+                    comp_inp_p._distparallel_type = DistParallelType.FULLY_SHARDED
                     comp_inp_p._shape = tuple(new_shape)
                     comp_inp_p._device = thunder_device

         with tracectx(computation_trace):
diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index a6e8d21bd..e7f3f4c2b 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -1806,8 +1806,16 @@ def _all_gather_prim_impl(
     /,
     group: torch.distributed.ProcessGroup,
     do_async: Number,
+    dim: int | None = None,
 ) -> torch.Tensor | tuple[torch.distributed.distributed_c10d.Work, torch.Tensor]:
-    out: torch.Tensor = torch.empty((group.size() * a.shape[0],) + a.shape[1:], dtype=a.dtype, device=a.device)
+    result_shape = list(a.shape)
+    if dim is not None:
+        utils.check_type(dim, int)
+        utils.check(dim >= 0 and dim < a.dim(), lambda: f"dim must satisfy 0 <= {dim=} < {a.dim()=}")
+        result_shape[dim] *= group.size()
+    else:
+        result_shape[0] *= group.size()
+    out: torch.Tensor = torch.empty(result_shape, dtype=a.dtype, device=a.device)
     do_async: bool = bool(do_async)

     handle: None | torch.distributed.distributed_c10d.Work = torch.distributed.all_gather_into_tensor(
@@ -1860,8 +1868,16 @@ def _reduce_scatter_prim_impl(
     op: DistributedReduceOps,
     group: torch.distributed.ProcessGroup,
     do_async: Number,
+    dim: int | None,
 ) -> torch.Tensor | tuple[torch.distributed.distributed_c10d.Work, torch.Tensor]:
-    out = torch.empty((a.shape[0] // group.size(),) + a.shape[1:], dtype=a.dtype, device=a.device)
+    result_shape = list(a.shape)
+    if dim is not None:
+        utils.check_type(dim, int)
+        utils.check(dim >= 0 and dim < a.dim(), lambda: f"dim must satisfy 0 <= {dim=} < {a.dim()=}")
+        result_shape[dim] //= group.size()
+    else:
+        result_shape[0] //= group.size()
+    out = torch.empty(result_shape, dtype=a.dtype, device=a.device)
     op: torch.distributed.ReduceOp = ltorch.to_torch_distributed_reduce_op(op)
     do_async: bool = bool(do_async)

diff --git a/thunder/tests/distributed/helper.py b/thunder/tests/distributed/helper.py
new file mode 100644
index 000000000..3eaec5da9
--- /dev/null
+++ b/thunder/tests/distributed/helper.py
@@ -0,0 +1,115 @@
+import math
+import os
+import sys
+from typing import ClassVar
+
+import torch
+import torch.nn as nn
+
+try:
+    import expecttest
+    import hypothesis
+except ImportError:
+    raise ImportError(
+        "Required packages `expecttest` and/or `hypothesis` are missing. "
+        "Install them with `pip install expecttest hypothesis`"
+    )
+from torch.testing._internal import common_distributed, common_utils
+
+
+__all__ = [
+    "new_gelu",
+    "ToyModel",
+    "DataParallelTestCase",
+]
+
+
+def new_gelu(x):
+    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+
+
+class ToyModel(nn.Module):
+    """Linear(12, 16) -> gelu -> Linear(16, 8)."""
+
+    N_IN: ClassVar[int] = 12
+    N_HIDDEN: ClassVar[int] = 16
+    N_OUT: ClassVar[int] = 8
+    LAYER_NAMES: ClassVar[tuple[str, ...]] = ("net2", "net1")
+
+    def __init__(self, bias: bool = True):
+        super().__init__()
+        self.net1 = nn.Linear(ToyModel.N_IN, ToyModel.N_HIDDEN, bias=bias)
+        self.net2 = nn.Linear(ToyModel.N_HIDDEN, ToyModel.N_OUT, bias=bias)
+
+    def forward(self, x):
+        return self.net2(new_gelu(self.net1(x)))
+
+
+# note(crcrpar): How to write a test with `DDP`
+# Just add a method to :class:`CompileDDPTest`. The class is responsible for
+# - calling `torch.distributed.init_process_group` with the NCCL backend
+# - assigning a rank and device to each process
+# so all you need to do is prepare a model and tensors, wrap the model with DDP,
+# `thunder.jit` either the original model or the DDP-wrapped one, run some computation, and/or
+# examine the traces of the jitted model.
+# If a test needs more than 2 GPUs, inherit from `CompileDDPTest`
+# and override `world_size`, e.g. with `max(torch.cuda.device_count(), 2)`.
+# note(crcrpar): Why inherit from `common_distributed.MultiProcessTestCase`?
+# If we were sure we would only ever use `pytest` instead of `unittest`,
+# a test that depends on `DistributedDataParallel` and/or `torch.distributed` could be run
+# by launching the test file with [`torchrun`](https://pytorch.org/docs/stable/elastic/run.html),
+# but (a) requiring `torchrun` to run a test is not very intuitive and
+# (b) it is not friendly to our CI, which currently just runs `pytest thunder/tests`.
+# Writing a `torch.distributed` test with `torch.multiprocessing` would also be feasible,
+# but it would require the function that defines the test logic to be picklable and would
+# lead to boilerplate test functions.
+# Ref: https://github.com/NVIDIA/apex/blob/7b2e71b0d4013f8e2f9f1c8dd21980ff1d76f1b6/apex/transformer/testing/distributed_test_base.py#L22
+class DataParallelTestCase(common_distributed.MultiProcessTestCase):
+    DISTRIBUTED_BACKEND = "nccl"
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+    def setUp(self) -> None:
+        super().setUp()
+        self._spawn_processes()
+
+    def tearDown(self) -> None:
+        torch.cuda.empty_cache()
+        super().tearDown()
+
+    # note(crcrpar): This means the world_size is at most two.
+ @property + def world_size(self) -> int: + return min(torch.cuda.device_count(), 2) + + @property + def init_method(self): + return f"{common_utils.FILE_SCHEMA}{self.file_name}" + + @classmethod + def _run(cls, rank, test_name, file_name, pipe): + self = cls(test_name) + self.rank = rank + self.file_name = file_name + + torch.distributed.init_process_group( + init_method=self.init_method, + backend=self.DISTRIBUTED_BACKEND, + world_size=self.world_size, + rank=self.rank, + ) + + local_rank = self.rank % torch.cuda.device_count() + torch.cuda.set_device(local_rank) + os.environ["LOCAL_RANK"] = str(local_rank) + + torch.distributed.barrier() + try: + self.run_test(test_name, pipe) + except Exception: + raise + finally: + torch.distributed.barrier() + torch.distributed.destroy_process_group() + sys.exit(0) diff --git a/thunder/tests/distributed/test_checkpoint.py b/thunder/tests/distributed/test_checkpoint.py index fbc1b27e4..c3eaf0159 100644 --- a/thunder/tests/distributed/test_checkpoint.py +++ b/thunder/tests/distributed/test_checkpoint.py @@ -20,7 +20,7 @@ get_model_state_dict, _TORCH_GREATER_EQUAL_2_3, ) -from thunder.tests.distributed.test_ddp import DataParallelTestCase +from thunder.tests.distributed.helper import DataParallelTestCase class Submodule(torch.nn.Module): diff --git a/thunder/tests/distributed/test_ddp.py b/thunder/tests/distributed/test_ddp.py index 835b6af9a..f90ef4984 100644 --- a/thunder/tests/distributed/test_ddp.py +++ b/thunder/tests/distributed/test_ddp.py @@ -1,4 +1,3 @@ -import math import multiprocessing as mp import os import sys @@ -22,6 +21,7 @@ from torch.testing import assert_close, make_tensor import thunder +import thunder.executors import thunder.torch as ltorch from thunder.core import devices from thunder.distributed import FSDPBucketingStrategy, FSDPType @@ -49,15 +49,8 @@ is_fp8_supported, fp8_support_reason = check_fp8_support() -try: - import expecttest # noqa: F401 - import hypothesis # noqa: F401 -except ImportError: - raise ImportError( - "Required packages of `expecttest` and/or `hypothesis` are missing. " - "Install them with `pip install expecttest hypothesis`" - ) -from torch.testing._internal import common_distributed, common_utils +from thunder.tests.distributed.helper import ToyModel, DataParallelTestCase, new_gelu +from torch.testing._internal import common_utils executors_map = { TorchExecutor.name: TorchExecutor, @@ -66,90 +59,6 @@ executors_map[nvFuserExecutor.name] = nvFuserExecutor -# Compile - DDP tests -def new_gelu(x): - return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) - - -class ToyModel(nn.Module): - def __init__(self): - super().__init__() - self.net1 = nn.Linear(12, 12) - self.net2 = nn.Linear(12, 8) - - def forward(self, x): - return self.net2(new_gelu(self.net1(x))) - - -# note(crcrpar): How to write a test with `DDP` -# Just add a method to :class:`CompileDDPTest`. The class is responsible for -# - calling `torch.distributed.init_process_group` with NCCL backend -# - setting rank to each process group / device -# so what you'd need to do is to prepare a model and tensors, wrap the model with DDP, and -# `thunder.jit` the original model or the DDP'd model, and do some computation and/or -# examine the traces of the `thunder.jit`d. -# If you force a test to be run with >2 GPUs for a test, you might want to inherit `CompileDDPTest` -# and modify `world_size` to e.g. `max(torch.cuda.device_count(), 2)`. 
- - -# note(crcrpar): Why inheriting `common_distributed.MultiProcessTestCase`? -# When we're quite sure that we would only use `pytest` instead of `unittest`, -# IIUC it's possible to run a test that is dependent on `DistributedDataParallel` and/or -# `torch.distributed` by running the test file with [`torchrun`](https://pytorch.org/docs/stable/elastic/run.html), -# but I don't think (a) it's quite intuitive to require `torchrun` explicitly to run a test and -# (b) it's quite friendly to our CI as it's currently simply runs `pytest thunder/tests`. -# I would say it's feasible to write a test with `torch.distributed` by using `torch.multiprocessing`, -# but it would require us to make the function which defines the test logic picklable and would -# lead to boilerplate test functions. -# Ref: https://github.com/NVIDIA/apex/blob/7b2e71b0d4013f8e2f9f1c8dd21980ff1d76f1b6/apex/transformer/testing/distributed_test_base.py#L22 -class DataParallelTestCase(common_distributed.MultiProcessTestCase): - DISTRIBUTED_BACKEND = "nccl" - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - - def setUp(self) -> None: - super().setUp() - self._spawn_processes() - - def tearDown(self) -> None: - torch.cuda.empty_cache() - super().tearDown() - - # note(crcrpar): This means the world_size is up to two. - @property - def world_size(self) -> int: - return min(torch.cuda.device_count(), 2) - - @property - def init_method(self): - return f"{common_utils.FILE_SCHEMA}{self.file_name}" - - @classmethod - def _run(cls, rank, test_name, file_name, pipe): - self = cls(test_name) - self.rank = rank - self.file_name = file_name - - torch.distributed.init_process_group( - init_method=self.init_method, - backend=self.DISTRIBUTED_BACKEND, - world_size=self.world_size, - rank=self.rank, - ) - - local_rank = self.rank % torch.cuda.device_count() - torch.cuda.set_device(local_rank) - os.environ["LOCAL_RANK"] = str(local_rank) - - torch.distributed.barrier() - self.run_test(test_name, pipe) - torch.distributed.barrier() - - torch.distributed.destroy_process_group() - sys.exit(0) - - @unittest.skipUnless( torch.cuda.is_available() and torch.distributed.is_available() and torch.distributed.is_nccl_available(), "DDP test requires CUDA and NCCL `torch.distributed` backend", @@ -287,8 +196,8 @@ def lc_foo( self.assertEqual(actual, expected) - @common_utils.parametrize("executor", tuple(executors_map.keys())) - def test_all_gather(self, executor): + @common_utils.parametrize("executor,dim", product(tuple(executors_map.keys()), (None, 0, 1))) + def test_all_gather(self, executor, dim: int | None): _executor = executors_map[executor] # NOTE torch.distributed.all_gather is an inplace operation @@ -297,9 +206,16 @@ def foo( b, process_group: torch.distributed.ProcessGroup, async_op: bool, + dim: int | None, ): c = a + b - d = torch.empty((c.shape[0] * process_group.size(), *c.shape[1:]), device=c.device, dtype=c.dtype) + + result_shape = list(c.shape) + if dim is not None: + result_shape[dim] *= process_group.size() + else: + result_shape[0] *= process_group.size() + d = torch.empty(result_shape, device=c.device, dtype=c.dtype) handle = torch.distributed.all_gather_into_tensor(d, c, group=process_group, async_op=async_op) if async_op: @@ -314,10 +230,11 @@ def lc_foo( b, process_group: torch.distributed.ProcessGroup, async_op: bool, + dim: int | None, ): c = a + b - d = ltorch.all_gather(c, group=process_group, async_op=async_op) + d = ltorch.all_gather(c, group=process_group, async_op=async_op, dim=dim) if 
async_op: d = prims.wait(d) @@ -334,8 +251,8 @@ def lc_foo( cfoo = thunder.jit(lc_foo, executors=_executor.executors_list()) for async_op in (True, False): - expected = foo(a, b, process_group, async_op) - actual = cfoo(a, b, process_group, async_op) + expected = foo(a, b, process_group, async_op, dim) + actual = cfoo(a, b, process_group, async_op, dim) self.assertEqual(actual, expected) @@ -397,8 +314,8 @@ def lc_foo( self.assertEqual(actual, expected) - @common_utils.parametrize("executor", tuple(executors_map.keys())) - def test_reduce_scatter(self, executor): + @common_utils.parametrize("executor,dim", product(tuple(executors_map.keys()), (None, 0, 1))) + def test_reduce_scatter(self, executor, dim): _executor = executors_map[executor] # NOTE torch.distributed.all_gather is an inplace operation @@ -408,9 +325,15 @@ def foo( op, process_group: torch.distributed.ProcessGroup, async_op: bool, + dim: int | None, ): c = a + b - d = torch.empty((c.shape[0] // process_group.size(), *c.shape[1:]), device=c.device, dtype=c.dtype) + result_shape = list(a.shape) + if dim is None: + result_shape[0] //= process_group.size() + else: + result_shape[dim] //= process_group.size() + d = torch.empty(result_shape, device=c.device, dtype=c.dtype) if op is not None: handle = torch.distributed.reduce_scatter_tensor(d, c, op, group=process_group, async_op=async_op) else: @@ -429,10 +352,11 @@ def lc_foo( op, process_group: torch.distributed.ProcessGroup, async_op: bool, + dim: int | None, ): c = a + b - d = ltorch.reduce_scatter(c, op, group=process_group, async_op=async_op) + d = ltorch.reduce_scatter(c, op, group=process_group, async_op=async_op, dim=dim) if async_op: d = prims.wait(d) @@ -449,8 +373,8 @@ def lc_foo( cfoo = thunder.jit(lc_foo, executors=_executor.executors_list()) for op, async_op in product((None, torch.distributed.ReduceOp.SUM), (False, True)): - expected = foo(a, b, op, process_group, async_op) - actual = cfoo(a, b, op, process_group, async_op) + expected = foo(a, b, op, process_group, async_op, dim=dim) + actual = cfoo(a, b, op, process_group, async_op, dim=dim) self.assertEqual(actual, expected) diff --git a/thunder/tests/distributed/test_tensor_parallel.py b/thunder/tests/distributed/test_tensor_parallel.py new file mode 100644 index 000000000..71b737867 --- /dev/null +++ b/thunder/tests/distributed/test_tensor_parallel.py @@ -0,0 +1,208 @@ +from itertools import product + +import pytest +import torch +import torch.nn as nn + +import thunder +from thunder.distributed import column_parallel, row_parallel +import thunder.executors +from thunder.tests.distributed.helper import ToyModel, DataParallelTestCase + +from torch.testing._internal import common_utils +from torch.distributed import distributed_c10d as c10d + +_COL = "column" +_ROW = "row" +_name_to_transform = { + _COL: column_parallel, + _ROW: row_parallel, +} + + +class TensorParallelTest(DataParallelTestCase): + + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="") + @common_utils.parametrize("name,bias", product(tuple(_name_to_transform.keys()), (True, False))) + def test_tensor_parallel_linear(self, name, bias): + device = torch.device("cuda", self.rank) + x = torch.randn(2, 12).to(device).requires_grad_() + x_ref = x.clone().detach().requires_grad_() + + process_group = None + ref_model = ToyModel(bias).to(device) + + ref_state_dict = ref_model.state_dict() + expected = ref_model(x_ref) + + transform = _name_to_transform[name] + model = ToyModel(bias=bias).to(device) + model.load_state_dict(ref_state_dict) + 
jitted_model = thunder.jit(model) + tp_jitted_model = transform( + jitted_model, + target_modules=("net1", "net2"), + process_group=process_group, + ) + y = tp_jitted_model(x) + torch.testing.assert_close(expected=expected, actual=y) + + expected.mean().backward() + y.mean().backward() + + if self.rank == 0: + fwd_extrace = thunder.last_traces(tp_jitted_model)[-1] + bwd_extrace = thunder.last_backward_traces(tp_jitted_model)[-1] + for bsym in fwd_extrace.bound_symbols + bwd_extrace.bound_symbols: + bsym.subsymbols = [] + + with open("./fwd_extrace_1.py", "w") as f: + f.write(str(fwd_extrace)) + with open("./bwd_extrace_1.py", "w") as f: + f.write(str(bwd_extrace)) + + dim = 1 if name == _ROW else 0 + for layer_name in ("net1", "net2"): + param_name = f"{layer_name}.weight" + expected_full_grad: torch.Tensor = ref_model.get_parameter(param_name).grad + expected = torch.chunk(expected_full_grad, self.world_size, dim)[self.rank] + torch.testing.assert_close( + expected=expected, + actual=tp_jitted_model.get_parameter(param_name).grad, + ) + if bias: + param_name = f"{layer_name}.bias" + expected_bias_grad: torch.Tensor = ref_model.get_parameter(param_name).grad + if name == _COL: + expected = torch.chunk(expected_bias_grad, self.world_size, 0)[self.rank] + else: + expected = expected_bias_grad + torch.testing.assert_close( + expected=expected, + actual=tp_jitted_model.get_parameter(param_name).grad, + ) + torch.testing.assert_close(expected=x_ref.grad, actual=x.grad) + + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="") + @common_utils.parametrize("name", tuple(_name_to_transform.keys())) + def test_tensor_parallel_embedding(self, name): + num_embeddings = 128 + embedding_dim = 32 + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.embed = nn.Embedding(num_embeddings, embedding_dim) + + def forward(self, x): + return self.embed(x) + + device = torch.device(f"cuda:{self.rank}") + x = torch.randint(0, num_embeddings - 1, (16, 16), device=device) + x_ref = x.clone().detach() + + process_group = None + ref_model = Model().to(device) + + ref_state_dict = ref_model.state_dict() + expected = ref_model(x_ref) + + transform = _name_to_transform[name] + model = Model().to(device) + model.load_state_dict(ref_state_dict) + jitted_model = thunder.jit(model) + tp_jitted_model = transform( + jitted_model, + target_modules=("embed",), + process_group=process_group, + ) + y = tp_jitted_model(x) + + dim: int + orig_size: int + if name == _COL: + dim = 0 + orig_size = num_embeddings + else: + dim = 1 + orig_size = embedding_dim + torch.testing.assert_close( + tp_jitted_model.get_parameter("embed.weight").size(dim), orig_size // self.world_size + ) + torch.testing.assert_close(expected=expected, actual=y) + + expected.mean().backward() + y.mean().backward() + + torch.testing.assert_close( + expected=ref_model.embed.weight.grad.chunk(self.world_size, dim)[self.rank], + actual=tp_jitted_model.get_parameter("embed.weight").grad, + ) + + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="") + @common_utils.parametrize("bias", (True, False)) + def test_tensor_parallel_both_column_and_row(self, bias): + num_embeddings = 128 + embedding_dim = 32 + n_hidden = 96 + + class Model(nn.Module): + def __init__(self, bias: bool = True): + super().__init__() + self.embed_1 = nn.Embedding(num_embeddings, embedding_dim) + self.embed_2 = nn.Embedding(num_embeddings, embedding_dim) + self.linear1_0 = nn.Linear(embedding_dim, n_hidden, bias=bias) + self.linear1_1 = nn.Linear(n_hidden, 
n_hidden, bias=bias) + + def forward(self, x): + feat_1 = self.embed_1(x) + feat_2 = self.embed_2(x) + sum_of_feat = feat_1 + feat_2 + h = self.linear1_1(torch.relu(self.linear1_0(sum_of_feat))) + return h + + device = torch.device("cuda", self.rank) + x = torch.randint(0, num_embeddings - 1, (16, 16), device=device) + x_ref = x.clone().detach() + + process_group = None + ref_model = Model(bias=bias).to(device) + ref_state_dict = ref_model.state_dict() + expected = ref_model(x_ref) + + model = Model(bias=bias).to(device) + model.load_state_dict(ref_state_dict) + tp_model = thunder.jit(model) + + column_parallel_layers = ["embed_1", "linear1_0"] + tp_model = column_parallel(tp_model, column_parallel_layers, process_group) + row_parallel_layers = ["embed_2", "linear1_1"] + tp_model = row_parallel(tp_model, row_parallel_layers, process_group) + actual = tp_model(x) + torch.testing.assert_close(actual=actual, expected=expected) + + with torch.no_grad(): + g_ref = torch.rand_like(expected) + g = g_ref.clone().detach() + expected.backward(g_ref) + actual.backward(g) + + for l_name, layer in reversed(list(ref_model.named_modules())): + dim = int(l_name in row_parallel_layers) + is_tensor_parallel = l_name in row_parallel_layers or l_name in column_parallel_layers + prefix = "row-parallel" if dim else "column-parallel" + for p_name, p_ref in layer.named_parameters(recurse=False): + param_fqn = f"{l_name}.{p_name}" + ref_grad = p_ref.grad + msg = lambda err_msg: f"[{prefix} {param_fqn}] {err_msg}" + if is_tensor_parallel and (ref_grad.ndim > 1 or dim == 0): + ref_grad = ref_grad.chunk(self.world_size, dim)[self.rank] + grad = tp_model.get_parameter(param_fqn).grad + torch.testing.assert_close(actual=grad, expected=ref_grad, msg=msg) + + +common_utils.instantiate_parametrized_tests(TensorParallelTest) + + +if __name__ == "__main__": + common_utils.run_tests() diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 8048ee7da..67a6cb6d5 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -4497,10 +4497,11 @@ def all_gather( a: TensorLike, group: torch.distributed.ProcessGroup | None = None, async_op: bool = False, + dim: int | None = None, ) -> TensorLike | FutureTensorLike: group = group if group is not None else torch.distributed.new_group() - return dist_prims.all_gather(a, group, async_op) + return dist_prims.all_gather(a, group, async_op, dim=dim) # NOTE torch.distributed.all_reduce is an inplace operation (although the underlying NCCL # call does not need to be inplace). This, however, is modeled as an out-of-place functional @@ -4548,11 +4549,12 @@ def reduce_scatter( op: DistributedReduceOpLike | None = None, group: torch.distributed.ProcessGroup | None = None, async_op: bool = False, + dim: int | None = None, ) -> TensorLike | FutureTensorLike: op = to_thunder_distributed_reduce_op(op) group = group if group is not None else torch.distributed.new_group() - return dist_prims.reduce_scatter(a, op, group, async_op) + return dist_prims.reduce_scatter(a, op, group, async_op, dim=dim) else: