simplification
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
crcrpar committed May 17, 2024
1 parent c1e398d commit 25b2d14
Showing 2 changed files with 9 additions and 24 deletions.
31 changes: 8 additions & 23 deletions thunder/distributed/__init__.py
@@ -582,39 +582,29 @@ def _shard_params(
         # Note [FSDP Sharding]
         # All internal code will assume that the parameters are sharded on the first dimension
         for param_name, param in submodule.named_parameters(recurse=False, prefix=module_name):
-            _shard_param(
-                param, global_rank, world_size, param_name, dim=0, allow_padding_for_fsdp=allow_padding_for_fsdp
-            )
+            _shard_param(param, global_rank, world_size, param_name, allow_padding_for_fsdp=allow_padding_for_fsdp)


 def _shard_param(
     param: torch.Tensor,
     rank: int,
     world_size: int,
     name: str,
     *,
-    dim: int | None,
     allow_padding_for_fsdp: bool = False,
 ) -> None:
-
-    dim_to_shard: int = 0 if dim is None else dim
-
-    if allow_padding_for_fsdp:
-        utils.check(dim_to_shard == 0, lambda: f"{dim=} expected to be `None` for FSDP")
-
-    if not allow_padding_for_fsdp or (param.size(dim_to_shard) % world_size == 0):
+    if not allow_padding_for_fsdp or (param.size(0) % world_size == 0):
         if not allow_padding_for_fsdp:
             utils.check(
-                param.size(dim_to_shard) % world_size == 0,
+                param.shape[0] % world_size == 0,
                 lambda: (
-                    f"Current sharding requires the first dimension of the parameter {name!r} ({param.shape[dim_to_shard]})"
+                    f"Current sharding requires the first dimension of the parameter {name!r} ({param.shape[0]})"
                     f" to be divisible by the world size ({world_size})"
                 ),
             )
-        chunk_size = param.size(dim_to_shard) // world_size
+        chunk_size = param.shape[0] // world_size
         # NOTE This could be a ShardTensor to indicate other parts of the code
         # that it's sharded and should be treated differently
-        shard = param.data.narrow(dim_to_shard, chunk_size * rank, chunk_size).clone()
+        shard = param.data.narrow(0, chunk_size * rank, chunk_size).clone()
         param.data = shard
     else:
         padded_param_shape = list(param.shape)
@@ -629,12 +619,7 @@ def _shard_param(


 @torch.no_grad()
-def _unshard_params(
-    module: torch.nn.Module,
-    process_group: ProcessGroup,
-    cpu_offload: bool = False,
-    dim: int | None = None,
-) -> None:
+def _unshard_params(module: torch.nn.Module, process_group: ProcessGroup, cpu_offload: bool = False) -> None:
     """Unshard a module's parameters.
     This supports CPU offloading of parameters.
@@ -643,7 +628,7 @@ def _shard_param(

     cpu = torch.device("cpu")
     for param in module.parameters():
-        out = _all_gather_prim_impl(param.data, group=process_group, do_async=0, dim=dim)
+        out = _all_gather_prim_impl(param.data, group=process_group, do_async=0)
         if cpu_offload:
             out = out.to(device=cpu)
         param.data = out
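To make the new invariant concrete, here is a minimal local sketch of what the simplified code path assumes: every parameter is sliced along its first dimension, and concatenating the per-rank slices along dim 0, which is what the all-gather in `_unshard_params` effectively does, reconstructs the original tensor. This is not Thunder code; the helper name `shard_dim0` and the toy tensor are illustrative assumptions, and the padding branch (`allow_padding_for_fsdp`) is not shown.

```python
import torch


def shard_dim0(param: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Illustrative helper, not the Thunder API. Mirrors the simplified path:
    # the first dimension must divide evenly (padding is a separate branch).
    assert param.size(0) % world_size == 0
    chunk_size = param.size(0) // world_size
    # Each rank keeps only its contiguous slice along dim 0.
    return param.narrow(0, chunk_size * rank, chunk_size).clone()


full = torch.arange(12.0).reshape(4, 3)
world_size = 2
shards = [shard_dim0(full, rank, world_size) for rank in range(world_size)]

# Concatenating the per-rank slices along dim 0 (what an all-gather over the
# only remaining sharding dimension yields) recovers the original tensor.
assert torch.equal(torch.cat(shards, dim=0), full)
```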
2 changes: 1 addition & 1 deletion thunder/distributed/tensor_parallel.py
@@ -205,6 +205,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
     for target_mod_name in target_modules:
         mod = colwise_thunder_module.get_submodule(target_mod_name)
         for name, p in mod.named_parameters(recurse=False):
-            _shard_param(p, rank, world_size, name, dim=0)
+            _shard_param(p, rank, world_size, name, allow_padding_for_fsdp=False)

     return colwise_thunder_module
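The column-wise tensor-parallel path can reuse the dim-0-only `_shard_param` because `torch.nn.Linear.weight` has shape `(out_features, in_features)`: slicing along dim 0 splits the output features across ranks. The standalone sketch below is illustrative only (local tensors and made-up shapes, no Thunder module wrapping or process group) and is not part of this diff; it just checks that each rank's slice of the weight produces the matching slice of the output.

```python
import torch

torch.manual_seed(0)
linear = torch.nn.Linear(in_features=3, out_features=4, bias=False)
world_size, rank = 2, 1  # pretend we are rank 1 of 2

chunk = linear.weight.size(0) // world_size  # rows of weight == out_features
local_weight = linear.weight.detach().narrow(0, chunk * rank, chunk).clone()

x = torch.randn(5, 3)
full_out = x @ linear.weight.detach().t()  # (5, 4): all output features
local_out = x @ local_weight.t()           # (5, 2): this rank's output columns
assert torch.allclose(local_out, full_out[:, chunk * rank : chunk * (rank + 1)])
```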
