[distributed][Tensor Parallelism] Implement early transforms for column-wise and row-wise linear and embedding #410

Merged
95 commits merged on May 31, 2024

Commits (95)
638c09c
comms before/after tensor parallel op (column wise)
crcrpar May 11, 2024
74b528c
add `dim` arg to `all_gather` and `reduce_scatter`
crcrpar May 12, 2024
07871cd
dim arg to shard/unshard params
crcrpar May 12, 2024
2f1ba9f
init col wise transform
crcrpar May 12, 2024
20d6b2e
fix circular import & add test
crcrpar May 13, 2024
6bcf899
`split_forward_backward` discarding the sync for colwise :/
crcrpar May 13, 2024
6712f08
fix sync of colwise output impl
crcrpar May 13, 2024
fb71513
use visitor_transform
crcrpar May 13, 2024
071ee7c
fix typo: input -> output
crcrpar May 13, 2024
38100f8
tentatively commenting out for green run
crcrpar May 13, 2024
89236e0
del unused `synchronize_input_for_column_wise_tensor_parallel`
crcrpar May 13, 2024
bfef29f
docstring
crcrpar May 13, 2024
ff01803
clean up test
crcrpar May 13, 2024
c6c19f0
remove `dim` from `_shard_params`
crcrpar May 13, 2024
f25b4af
fix gather in last dim
crcrpar May 13, 2024
e6c6aa4
fwd numerical check
crcrpar May 13, 2024
47249df
bwd check
crcrpar May 13, 2024
0c8c18f
docs
crcrpar May 13, 2024
1f56e24
fix rebase conflicts
crcrpar May 17, 2024
799fe82
docstring update
crcrpar May 17, 2024
04de709
make `TransformForColumnWiseParallel` frozen
crcrpar May 17, 2024
49cafef
simplification
crcrpar May 17, 2024
965ac7e
refactor colwise linear, initialized embedding
crcrpar May 20, 2024
5e66bb9
colwise embedding
crcrpar May 20, 2024
934710a
as a class variable
crcrpar May 20, 2024
eba0cd0
fix return type
crcrpar May 20, 2024
36b7dd4
del outdated todo comment
crcrpar May 20, 2024
4222660
organize
crcrpar May 22, 2024
2e21ef8
common interface for structure
crcrpar May 22, 2024
9f0becb
initial row-wise
crcrpar May 22, 2024
9b1b879
merge col/row output sync into one
crcrpar May 23, 2024
af7c139
row parallel embedding
crcrpar May 23, 2024
c93b73b
row wise, part 2
crcrpar May 23, 2024
45d976b
test
crcrpar May 23, 2024
0a62b0e
test of both column/row-wise
crcrpar May 23, 2024
d9d545d
dataclass inheritance must obey parent being frozen or not
crcrpar May 23, 2024
f612771
fix
crcrpar May 23, 2024
6c4d161
fix
crcrpar May 23, 2024
b346fc3
move model instantiation to inside of the loop
crcrpar May 23, 2024
4e4c2ff
modify how to chunk params
crcrpar May 23, 2024
9edf065
cosmetic
crcrpar May 24, 2024
365f19e
properly modify prologue&computation traces
crcrpar May 24, 2024
855ae9f
remove print
crcrpar May 24, 2024
9850a78
set `allow_padding_for_fsdp` to `False`
crcrpar May 24, 2024
960f703
fix preprocess
crcrpar May 24, 2024
ddb1a4c
fix row-wise linear bias accumulation
crcrpar May 24, 2024
4806ead
parametrize column or row, as the caching looks polluted
crcrpar May 24, 2024
90960e7
brief note on row-wise parallel linear
crcrpar May 24, 2024
29f1fed
`convert_module_to_rownwise_parallel` to docs
crcrpar May 24, 2024
82f6268
docs
crcrpar May 24, 2024
42b8bb7
temporary off
crcrpar May 24, 2024
7cace21
fix padding value
crcrpar May 24, 2024
3adad7c
fix embedding test
crcrpar May 24, 2024
c64790f
Apply suggestions from code review
crcrpar May 29, 2024
1b0e7a4
Apply suggestions from code review
crcrpar May 29, 2024
4b06c0e
`convert_module_to_(column|row)wise_parallel` -> `(column|row)_parallel`
crcrpar May 29, 2024
389ea66
no need to specifically mark a new id as experimental
crcrpar May 29, 2024
813e738
`bsym2prepostprocess` -> `bsym_to_prepostprocess`
crcrpar May 29, 2024
106ca70
`chunked_param_name2layer_type` -> `chunked_param_name_to_layer_type`
crcrpar May 29, 2024
27494c0
`get_visitor_of_computation_trc_and_provenance` -> `get_visitor_of_co…
crcrpar May 29, 2024
3f6be16
`converter` -> `transformer`
crcrpar May 29, 2024
10ee435
rename `DDPType` to `DistParallelType` to be friendly to tensor paral…
crcrpar May 26, 2024
a0aa61c
set dist parallel type
crcrpar May 26, 2024
65924da
check dist parallel type
crcrpar May 27, 2024
e41019b
`abstractproperty` is deprecated, import `DistParallelType`
crcrpar May 27, 2024
f68dddb
remove unused `synchronize_input_for_column_wise_tensor_parallel_meta`
crcrpar May 29, 2024
679cb64
split test
crcrpar May 29, 2024
1c13f01
import new_gelu
crcrpar May 29, 2024
8223203
`_overrides` -> `_overrides_parameters`
crcrpar May 29, 2024
d7f99af
update signature
crcrpar May 29, 2024
36aac4e
`ddp_type` -> `distparallel_type`
crcrpar May 29, 2024
e2a8a08
simplify bsym_to_prepostprocess construction
crcrpar May 29, 2024
51a365f
avoid using strings
crcrpar May 29, 2024
ab511a4
fix column-wise linear preprocessing
crcrpar May 29, 2024
a9e18c8
fix preprocess of row-wise parallel linear
crcrpar May 29, 2024
f6d2680
parametrize tp linear bias
crcrpar May 30, 2024
a9078ab
cosmetic
crcrpar May 30, 2024
a36cdb8
split `swap_map` into `input_swap_map` and `swap_map`
crcrpar May 30, 2024
b5909ae
no initial param sync, input grad check
crcrpar May 30, 2024
632ca68
input grad check
crcrpar May 30, 2024
6a5a6e5
grads check
crcrpar May 30, 2024
ea70ae3
create input_swap_map per bsym
crcrpar May 30, 2024
ae30fc6
destroy process group, always
crcrpar May 30, 2024
0500919
multiple row-parallel linears
crcrpar May 30, 2024
0f57a0c
local split in last dim
crcrpar May 31, 2024
605d701
expose hardcoded params as class var
crcrpar May 31, 2024
88bbbe2
simplify test
crcrpar May 31, 2024
9ff8848
bit more informative err msg
crcrpar May 31, 2024
264b549
fix bwd of row-embed postprocess
crcrpar May 31, 2024
7f13ea0
organized
crcrpar May 31, 2024
8231177
organized
crcrpar May 31, 2024
46bd367
Merge branch 'main' into crpa/tensor-parallel
crcrpar May 31, 2024
5bf797b
use gelu for better numeric
crcrpar May 31, 2024
46183d8
longer name
crcrpar May 31, 2024
f146205
no nvfuser linear/matmul
crcrpar May 31, 2024
2 changes: 2 additions & 0 deletions docs/source/reference/distributed/index.rst
@@ -14,3 +14,5 @@ thunder.distributed
reset_skip_data_parallel_grad_sync
get_skip_data_parallel_grad_sync
skip_data_parallel_grad_sync
column_parallel
row_parallel
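
The two entries above are the user-facing transforms this PR adds to the docs. A minimal usage sketch follows; the exact signatures (whether the transforms take the jitted module or the original one, and the name of the argument listing target submodules) are assumptions rather than something stated in this diff.

```python
# Hypothetical sketch: shard an MLP block with the new transforms.
# Assumes torchrun has initialized the default process group and that
# column_parallel/row_parallel accept (jitted_module, sequence_of_module_names).
import torch
import torch.nn as nn
import thunder
from thunder.distributed import column_parallel, row_parallel


class MLP(nn.Module):
    def __init__(self, hidden: int = 128) -> None:
        super().__init__()
        self.fc1 = nn.Linear(hidden, 4 * hidden)
        self.fc2 = nn.Linear(4 * hidden, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(torch.nn.functional.gelu(self.fc1(x)))


model = MLP().to("cuda")
tp_model = thunder.jit(model)
# Megatron-style pairing: fc1 column-wise (shards output features),
# fc2 row-wise (shards input features, with a reduction of the output).
tp_model = column_parallel(tp_model, ["fc1"])
tp_model = row_parallel(tp_model, ["fc2"])

out = tp_model(torch.randn(4, 128, device="cuda"))
```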
2 changes: 1 addition & 1 deletion notebooks/dev_tutorials/fsdp_tutorial.ipynb
@@ -165,7 +165,7 @@
" shard_param(param, global_rank, world_size, param_name)\n",
" # Mark the param to denote that it is sharded.\n",
" # This is required by the synchronization primitive we will use below.\n",
" param.ddp_type = thunder.core.proxies.DDPType.FULLY_SHARDED"
" param.distparallel_type = thunder.core.proxies.DistParallelType.FULLY_SHARDED"
]
},
{
15 changes: 13 additions & 2 deletions thunder/common.py
@@ -33,7 +33,15 @@
reset_tracectx,
)
from thunder.core.transform_common import dce, cse
from thunder.core.proxies import is_proxyable, proxy, Proxy, CollectionProxy, TensorProxy, DDPType, FutureTensorProxy
from thunder.core.proxies import (
is_proxyable,
proxy,
Proxy,
CollectionProxy,
TensorProxy,
DistParallelType,
FutureTensorProxy,
)
import thunder.core.prims as prims
import thunder.distributed as dist
import thunder.torch as ltorch
@@ -559,7 +567,10 @@ def _trace(
)

def ddp_sync(arg: Any | TensorProxy) -> Any | TensorProxy:
if isinstance(arg, TensorProxy) and arg.ddp_type in (DDPType.REPLICATED, DDPType.FULLY_SHARDED):
if isinstance(arg, TensorProxy) and arg.distparallel_type in (
DistParallelType.REPLICATED,
DistParallelType.FULLY_SHARDED,
):
return dist.prims.synchronize(arg, compile_data.process_group_for_ddp)
else:
return arg
7 changes: 5 additions & 2 deletions thunder/core/jit_ext.py
@@ -54,7 +54,7 @@

import torch
from thunder.core.proxies import (
DDPType,
DistParallelType,
proxy,
Proxy,
NumberProxy,
@@ -575,7 +575,10 @@ def proxify(self, value: WrappedValue) -> Any:
# TensorProxy attributes should be considered derived quantities, so we flag TensorProxies here
value.provenance.ext_flag |= EXT_FLAG_IS_TENSOR_PROXY

if isinstance(p, TensorProxy) and p.ddp_type in (DDPType.REPLICATED, DDPType.FULLY_SHARDED):
if isinstance(p, TensorProxy) and p.distparallel_type in (
DistParallelType.REPLICATED,
DistParallelType.FULLY_SHARDED,
):
p_new = thunder.distributed.prims.synchronize(
p,
self._process_group_for_ddp,
51 changes: 33 additions & 18 deletions thunder/core/proxies.py
@@ -994,10 +994,13 @@ def __repr__(self):
return f"[FloatProxy name={self.name}, value={self.value}]"


class DDPType(Enum):
class DistParallelType(Enum):
NONE = auto()
REPLICATED = auto()
FULLY_SHARDED = auto()
# Following two are for tensor parallelism
COLUMN_WISE = auto()
ROW_WISE = auto()


def _infer_tensor_properties(
@@ -1006,14 +1009,14 @@
device: devices.Device | None = None,
dtype: dtypes.dtype | None = None,
requires_grad: bool | None = None,
ddp_type: DDPType | None = None,
distparallel_type: DistParallelType | None = None,
thunder_fsdp_padding_size: int | None = None,
):
_shape = None
_device = None
_dtype = None
_requires_grad: None | bool = None
_ddp_type = DDPType.NONE
_dist_parallel_type = DistParallelType.NONE
_thunder_fsdp_padding_size = None

if like is not None:
@@ -1022,7 +1025,7 @@
_device = like.device
_dtype = like.true_dtype
_requires_grad = like.requires_grad
_ddp_type = getattr(like, "ddp_type", DDPType.NONE)
_dist_parallel_type = getattr(like, "distparallel_type", DistParallelType.NONE)

if shape is not None:
baseutils.check_valid_shape(shape)
@@ -1033,7 +1036,7 @@
_dtype = dtypes.numbertype_to_dtype(_dtype) if dtypes.is_numbertype(_dtype) else _dtype
_requires_grad = requires_grad if requires_grad is not None else _requires_grad
_requires_grad = False if not dtypes.is_inexact_dtype(_dtype) else _requires_grad
_ddp_type = ddp_type if ddp_type is not None else _ddp_type
_dist_parallel_type = distparallel_type if distparallel_type is not None else _dist_parallel_type
_thunder_fsdp_padding_size = (
thunder_fsdp_padding_size if thunder_fsdp_padding_size is not None else _thunder_fsdp_padding_size
)
@@ -1057,11 +1060,11 @@
baseutils.check_type(_device, devices.Device)
baseutils.check_type(_dtype, dtypes.dtype)
baseutils.check_type(_requires_grad, bool)
baseutils.check_type(_ddp_type, DDPType)
baseutils.check_type(_dist_parallel_type, DistParallelType)
if isinstance(_thunder_fsdp_padding_size, int):
baseutils.check(
_ddp_type == DDPType.FULLY_SHARDED,
lambda: f"{_ddp_type = } and {_thunder_fsdp_padding_size = } do not work",
_dist_parallel_type == DistParallelType.FULLY_SHARDED,
lambda: f"{_dist_parallel_type = } and {_thunder_fsdp_padding_size = } do not work",
)
baseutils.check(
_thunder_fsdp_padding_size > 0,
@@ -1073,7 +1076,17 @@
_true_dtype = _dtype
_dtype = dtypes.to_strong_dtype(_dtype)

return _shape, _device, _dtype, _true_dtype, _numel, _ndim, _requires_grad, _ddp_type, _thunder_fsdp_padding_size
return (
_shape,
_device,
_dtype,
_true_dtype,
_numel,
_ndim,
_requires_grad,
_dist_parallel_type,
_thunder_fsdp_padding_size,
)


# NOTE A FutureTensorProxy is intentionally NOT a subclass of TensorProxy
@@ -1100,7 +1113,7 @@ def __init__(
self._numel,
self._ndim,
self._requires_grad,
_, # ddp_type
_, # distparallel_type
_, # thunder_fsdp_padding_size
) = _infer_tensor_properties(
like,
@@ -1167,7 +1180,7 @@ def __init__(
dtype: dtypes.dtype | None = None,
requires_grad: bool | None = None,
prefix: None | str = None,
ddp_type: DDPType | None = None,
distparallel_type: DistParallelType | None = None,
history: None | tuple = None,
thunder_fsdp_padding_size: int | None = None,
):
@@ -1181,9 +1194,11 @@ def __init__(
self._numel,
self._ndim,
self._requires_grad,
self._ddp_type,
self._distparallel_type,
self._thunder_fsdp_padding_size,
) = _infer_tensor_properties(like, shape, device, dtype, requires_grad, ddp_type, thunder_fsdp_padding_size)
) = _infer_tensor_properties(
like, shape, device, dtype, requires_grad, distparallel_type, thunder_fsdp_padding_size
)

# NOTE The following properties DO NOT depend on the language context or record
# themselves into the trace, so they can be used when working with tensor proxies
@@ -1213,8 +1228,8 @@ def requires_grad(self):
return self._requires_grad

@property
def ddp_type(self):
return self._ddp_type
def distparallel_type(self):
return self._distparallel_type

@property
def thunder_fsdp_padding_size(self):
@@ -1519,8 +1534,8 @@ def real(self):
def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple = None) -> TensorProxy:
device = devices.device_from_string(str(t.device))
dtype = dtypes.to_dtype(t.dtype)
# See Note [DistributedDataParallel and ddp_type]
ddp_type = getattr(t, "ddp_type", None)
# See Note [DistributedDataParallel and distparallel_type]
distparallel_type = getattr(t, "distparallel_type", None)
_thunder_fsdp_padding_size = getattr(t, "_thunder_fsdp_padding_size", None)
# NOTE Without tuple(t.shape) then the shape would be a torch.Size object
return TensorProxy(
@@ -1529,7 +1544,7 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple =
device=device,
dtype=dtype,
requires_grad=t.requires_grad,
ddp_type=ddp_type,
distparallel_type=distparallel_type,
history=history,
thunder_fsdp_padding_size=_thunder_fsdp_padding_size,
)
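
To make the rename above concrete, here is a small sketch of how the `distparallel_type` attribute is consumed: `ddp()`/`fsdp()` (see the `thunder/distributed/__init__.py` diff below) tag every parameter with a `DistParallelType` member, and `tensorproxy` above copies that tag onto the proxy. The layer used here is made up; only the attribute and enum names come from the diff.

```python
# Sketch only: tagging parameters the way ddp()/fsdp() do after the rename.
import torch.nn as nn
from thunder.core.proxies import DistParallelType

layer = nn.Linear(16, 16)
for p in layer.parameters():
    # ddp() uses REPLICATED, fsdp() uses FULLY_SHARDED; the tensor-parallel
    # transforms introduced in this PR rely on the two new members instead.
    p.distparallel_type = DistParallelType.FULLY_SHARDED

assert all(p.distparallel_type is DistParallelType.FULLY_SHARDED for p in layer.parameters())
print(list(DistParallelType))  # NONE, REPLICATED, FULLY_SHARDED, COLUMN_WISE, ROW_WISE
```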
69 changes: 40 additions & 29 deletions thunder/distributed/__init__.py
@@ -16,7 +16,9 @@
from torch.utils.weak import WeakTensorKeyDictionary

import thunder.core.utils as utils
from thunder.core.proxies import DDPType
from thunder.core.proxies import DistParallelType
from thunder.distributed.tensor_parallel import column_parallel
from thunder.distributed.tensor_parallel import row_parallel

if TYPE_CHECKING:
from torch.distributed import ProcessGroup
@@ -28,6 +30,8 @@
"fsdp",
"FSDPBucketingStrategy",
"FSDPType",
"column_parallel",
"row_parallel",
]


@@ -273,14 +277,14 @@ def main():
),
)

# Note [DistributedDataParallel and ddp_type]
# Note [DistributedDataParallel and distparallel_type]
# If model was wrapped with thunder.distributed.ddp it would have a
# .use_ddp attribute set to True and all parameters would be already
# broadcasted to all other processes. So that our tracing is aware of
# this we need to mark the ddp_type of model's parameters as
# thunder.proxies.DDPType.REPLICATED
# this we need to mark the distparallel_type of model's parameters as
# thunder.proxies.DistParallelType.REPLICATED
for p in model.parameters():
p.ddp_type = DDPType.REPLICATED
p.distparallel_type = DistParallelType.REPLICATED

if broadcast_from is None:
return model
@@ -555,14 +559,14 @@ def fsdp(
# Shard the parameters
_shard_params(model, process_group, device, broadcast_from, allow_padding_for_fsdp=True)

# See Note [DistributedDataParallel and ddp_type]
# See Note [DistributedDataParallel and distparallel_type]
# If model was wrapped with thunder.distributed.fsdp it would have a
# .use_fsdp attribute set to True and all parameters would be already
# sharded across all other processes. So that our tracing is aware of
# this we need to mark the ddp_type of model's parameters as
# thunder.proxies.DDPType.FULLY_SHARDED
# this we need to mark the distparallel_type of model's parameters as
# thunder.proxies.DistParallelType.FULLY_SHARDED
for p in model.parameters():
p.ddp_type = DDPType.FULLY_SHARDED
p.distparallel_type = DistParallelType.FULLY_SHARDED

return model

@@ -619,33 +623,40 @@ def _shard_param(
rank: int,
world_size: int,
name: str,
*,
allow_padding_for_fsdp: bool = False,
dim: int | None = None,
) -> None:

if not allow_padding_for_fsdp or (param.size(0) % world_size == 0):
if not allow_padding_for_fsdp:
utils.check(
param.shape[0] % world_size == 0,
lambda: (
f"Current sharding requires the first dimension of the parameter {name!r} ({param.shape[0]})"
f" to be divisible by the world size ({world_size})"
),
)
chunk_size = param.shape[0] // world_size
# NOTE This could be a ShardTensor to indicate other parts of the code
# that it's sharded and should be treated differently
shard = param.data.narrow(0, chunk_size * rank, chunk_size).clone()
param.data = shard
else:
dim_to_shard = 0 if dim is None else dim
if allow_padding_for_fsdp:
utils.check(dim_to_shard == 0, lambda: f"Invalid {dim=} with {allow_padding_for_fsdp=}, Only 0 is supported")
padded_param_shape = list(param.shape)
orig_0dim_size = param.size(0)
orig_0dim_size = param.size(dim_to_shard)
chunk_size = (padded_param_shape[0] + world_size - 1) // world_size
padded_param_shape[0] = chunk_size * world_size
_thunder_fsdp_padding_size = padded_param_shape[0] - param.size(0)
padded_param = torch.empty(padded_param_shape, device=param.device, dtype=param.dtype)
padded_param[:orig_0dim_size].copy_(param)
param.data = padded_param.data.narrow(0, chunk_size * rank, chunk_size).clone()
param._thunder_fsdp_padding_size = _thunder_fsdp_padding_size
if _thunder_fsdp_padding_size > 0:
padded_param = torch.empty(padded_param_shape, device=param.device, dtype=param.dtype)
padded_param[:orig_0dim_size].copy_(param)
param.data = padded_param.data.narrow(0, chunk_size * rank, chunk_size).clone()
param._thunder_fsdp_padding_size = _thunder_fsdp_padding_size
else:
param.data = param.data.narrow(0, chunk_size * rank, chunk_size).clone()
param._thunder_fsdp_padding_size = None
else:
utils.check(
param.shape[dim_to_shard] % world_size == 0,
lambda: (
f"Current sharding requires the sharded dimension of the parameter {name!r} ({param.shape[dim_to_shard]})"
f" to be divisible by the world size ({world_size})"
),
)
chunk_size = param.shape[dim_to_shard] // world_size
# NOTE This could be a ShardTensor to indicate other parts of the code
# that it's sharded and should be treated differently
shard = param.data.narrow(dim_to_shard, chunk_size * rank, chunk_size).clone()
param.data = shard


@torch.no_grad()
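
The reshaped `_shard_param` above now shards along an arbitrary `dim` (0 for FSDP and, presumably, for column-wise weights; a different dimension, such as the input-feature dimension of a row-wise linear, otherwise) and only takes the padded path when FSDP padding is both allowed and needed. A standalone sketch of the divisible, non-padded path, using plain `torch` and made-up sizes:

```python
# Sketch of the non-padded sharding rule: split the chosen dimension into
# world_size equal chunks and keep only this rank's slice.
import torch


def shard_along_dim(param: torch.Tensor, rank: int, world_size: int, dim: int = 0) -> torch.Tensor:
    assert param.size(dim) % world_size == 0, "sharded dim must be divisible by the world size"
    chunk_size = param.size(dim) // world_size
    # narrow() returns a view; clone() materializes the rank-local shard
    return param.narrow(dim, chunk_size * rank, chunk_size).clone()


weight = torch.randn(8, 12)
print(shard_along_dim(weight, rank=1, world_size=4, dim=0).shape)  # torch.Size([2, 12])
print(shard_along_dim(weight, rank=1, world_size=4, dim=1).shape)  # torch.Size([8, 3])
```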