[FSDP2] Supported set_all_reduce_gradients=False for HSDP (pytorch#126166)

**Context**
For FSDP, gradient accumulation across microbatches has two flavors: (1) reduce-scatter or (2) no reduce-scatter. (1) incurs the collective in each microbatch's backward but saves gradient memory (only the sharded gradients are stored), while (2) avoids the communication but uses more gradient memory (the unsharded gradients are stored).
- FSDP2 offers (1) without any intervention. The user should simply make sure to run the optimizer step after `K` microbatches for `K > 1`.
- FSDP2 offers (2) via `module.set_requires_gradient_sync()` (e.g. `module.set_requires_gradient_sync(is_last_microbatch)`); a minimal sketch follows below.
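
A minimal sketch of flavor (2), assuming `model` is a module wrapped with `fully_shard` and `microbatches` is a list of input batches (placeholder names for illustration):
```
for microbatch_idx, microbatch in enumerate(microbatches):
    is_last_microbatch = microbatch_idx == len(microbatches) - 1
    # Skip reduce-scatter (keep unsharded gradients) until the last microbatch
    model.set_requires_gradient_sync(is_last_microbatch)
    # Run forward/backward
```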

For HSDP, since we reduce-scatter and then all-reduce, we have additional flexibility and get three flavors: (1) reduce-scatter and all-reduce, (2) reduce-scatter but no all-reduce, and (3) no reduce-scatter and no all-reduce. This PR adds support for (2).
- FSDP2 offers (1) without any intervention, as mentioned above.
- FSDP2 offers (3) via `module.set_requires_gradient_sync()`, as mentioned above.
- FSDP2 offers (2) via `module.set_requires_all_reduce()` similar to `set_requires_gradient_sync()`.

**Overview**
For HSDP, to reduce-scatter but not all-reduce during gradient accumulation, the user can do something like:
```
for microbatch_idx, microbatch in enumerate(microbatches):
    is_last_microbatch = microbatch_idx == len(microbatches) - 1
    model.set_requires_all_reduce(is_last_microbatch)
    # Run forward/backward
```
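
For a fuller picture, a hedged end-to-end sketch of one accumulation window, assuming `model` is wrapped with `fully_shard` on a 2D (HSDP) mesh and `optim` and `microbatches` are placeholders:
```
for microbatch_idx, microbatch in enumerate(microbatches):
    is_last_microbatch = microbatch_idx == len(microbatches) - 1
    # Reduce-scatter every microbatch; all-reduce only on the last one
    model.set_requires_all_reduce(is_last_microbatch)
    loss = model(microbatch).sum()
    loss.backward()
# Gradients are fully reduced across both mesh dimensions only after the last
# microbatch, so the optimizer steps once per accumulation window
optim.step()
optim.zero_grad()
```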

This PR also makes a minor change: the `recurse: bool` argument in these setter methods is now keyword-only.
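
For example (a hypothetical call for illustration), `recurse` must now be passed by keyword:
```
# OK: `recurse` passed as a keyword argument
model.set_requires_gradient_sync(False, recurse=False)
# TypeError after this PR: `recurse` can no longer be passed positionally
# model.set_requires_gradient_sync(False, False)
```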

**Developer Notes**
For simplicity, we implement this by saving the partial reduce output on the `FSDPParamGroup`, under the assumption that the set of parameters receiving gradients does not change across microbatches. An alternative would be to view into the partial reduce output per parameter and save each view on its parameter. We avoid this alternative for now because it adds complexity: extra views are needed when saving the partial reduce output to each parameter, when accumulating into those views, and when accumulating back into the last microbatch's reduce output.
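
As a rough sketch of the resulting control flow (streams, pre/post-divide factors, and dtype casts omitted; names loosely follow `foreach_reduce` in the diff below, but this is a simplification, not the actual implementation):
```
import torch.distributed as dist

def post_reduce(reduce_output, partial_reduce_output, all_reduce_grads, all_reduce_group):
    if all_reduce_group is not None:  # HSDP
        if not all_reduce_grads:
            # Accumulate the reduce-scatter output across microbatches and
            # defer the all-reduce to a later microbatch
            if partial_reduce_output is not None:
                partial_reduce_output += reduce_output
            else:
                partial_reduce_output = reduce_output
            return partial_reduce_output  # saved on the FSDPParamGroup
        if partial_reduce_output is not None:
            # Fold previously accumulated microbatches into this one before
            # the final all-reduce
            reduce_output += partial_reduce_output
        dist.all_reduce(reduce_output, group=all_reduce_group)
    # ... view `reduce_output` out into per-parameter sharded gradients ...
    return None  # no partial reduce output to carry forward
```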

Pull Request resolved: pytorch#126166
Approved by: https://github.com/weifengpy, https://github.com/wanchaol
ghstack dependencies: pytorch#126067, pytorch#126070, pytorch#126161
awgu authored and ZelboK committed May 19, 2024
1 parent 8288174 commit 5dd875a
Showing 5 changed files with 98 additions and 64 deletions.
6 changes: 4 additions & 2 deletions test/distributed/_composable/fsdp/test_fully_shard_comm.py
@@ -244,7 +244,7 @@ def _test_reduce_scatter(
group = fsdp_param_group.mesh_info.shard_process_group
self.assertEqual(group.size(), self.world_size)
all_reduce_stream = torch.cuda.Stream()
view_out_event = foreach_reduce(
post_reduce_event, _ = foreach_reduce(
fsdp_params,
unsharded_grads,
group,
@@ -254,8 +254,10 @@ def _test_reduce_scatter(
device=self.device,
all_reduce_group=None,
all_reduce_stream=all_reduce_stream,
all_reduce_grads=True,
partial_reduce_output=None,
)
torch.cuda.current_stream().wait_event(view_out_event)
torch.cuda.current_stream().wait_event(post_reduce_event)

# Check reduce-scatter correctness
predivide_factor, postdivide_factor = _get_gradient_divide_factors(
72 changes: 47 additions & 25 deletions test/distributed/_composable/fsdp/test_fully_shard_training.py
@@ -672,6 +672,12 @@ def test_gradient_accumulation(self):
"mode": ["all", "root_only", "some_mlps"],
"reshard_after_backward": [False, True],
"offload_policy": [OffloadPolicy(), CPUOffloadPolicy()],
# For HSDP only:
# `True`: reduce-scatter only (no all-reduce) each microbatch
# until the last microbatch
# `False`: neither reduce-scatter nor all-reduce each
# microbatch until the last microbatch
"reduce_scatter_only": [False, True],
},
self._test_gradient_accumulation,
)
@@ -683,15 +689,20 @@ def _test_gradient_accumulation(
mode: str,
reshard_after_backward: bool,
offload_policy: OffloadPolicy,
reduce_scatter_only: bool, # for HSDP
):
if (
not reshard_after_backward
and (reshard_after_forward is not False or mode == "some_mlps")
) or (
isinstance(offload_policy, CPUOffloadPolicy)
and reshard_after_forward is not True
(
not reshard_after_backward
and (reshard_after_forward is not False or mode == "some_mlps")
)
or (
isinstance(offload_policy, CPUOffloadPolicy)
and reshard_after_forward is not True
)
or (mesh.ndim != 2 and reduce_scatter_only)
):
return # skip since not common
return # skip since not common or applicable

torch.manual_seed(42)
batch_size, lin_dim, num_mlps, num_microbatches = (2, 32, 3, 3)
@@ -713,29 +724,35 @@ def _test_gradient_accumulation(
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)

def set_grad_sync_flag(
module: nn.Module, is_last_microbatch: bool, recurse: bool = True
):
if reduce_scatter_only:
module.set_requires_all_reduce(is_last_microbatch, recurse=recurse)
else:
module.set_requires_gradient_sync(is_last_microbatch, recurse=recurse)

def set_backward_flags(_model: nn.Module, is_last_microbatch: bool):
if mode == "all":
set_grad_sync_flag(_model, is_last_microbatch)
if not reshard_after_backward:
_model.set_reshard_after_backward(is_last_microbatch)
elif mode == "some_mlps":
for mlp in model[1 : 1 + num_mlps_to_disable_reduce_scatter]:
set_grad_sync_flag(mlp, is_last_microbatch)
if not reshard_after_backward:
mlp.set_reshard_after_backward(is_last_microbatch)
elif mode == "root_only":
set_grad_sync_flag(model, is_last_microbatch, recurse=False)
if not reshard_after_backward:
model.set_reshard_after_backward(is_last_microbatch, recurse=False)

torch.manual_seed(42 + self.rank + 1)
for iter_idx in range(5):
with CommDebugMode() as comm_mode:
for microbatch_idx in range(num_microbatches):
is_last_microbatch = microbatch_idx == num_microbatches - 1
if mode == "all":
model.set_requires_gradient_sync(is_last_microbatch)
if not reshard_after_backward:
model.set_reshard_after_backward(is_last_microbatch)
elif mode == "some_mlps":
for mlp in model[1 : 1 + num_mlps_to_disable_reduce_scatter]:
mlp.set_requires_gradient_sync(is_last_microbatch)
if not reshard_after_backward:
mlp.set_reshard_after_backward(is_last_microbatch)
elif mode == "root_only":
model.set_requires_gradient_sync(
is_last_microbatch, recurse=False
)
if not reshard_after_backward:
model.set_reshard_after_backward(
is_last_microbatch, recurse=False
)

set_backward_flags(model, is_last_microbatch)
inp = torch.randn(batch_size, lin_dim, device="cuda")
losses: List[torch.Tensor] = []
for _model in (ref_model, model):
@@ -760,10 +777,15 @@ def _test_gradient_accumulation(
elif mode == "root_only":
# Expect additional reduce-scatters for all MLPs
expected_reduce_scatter_count += (num_mlps) * (num_microbatches - 1)
self.assertEqual(reduce_scatter_count, expected_reduce_scatter_count)
expected_all_reduce_count = (
expected_reduce_scatter_count if mesh.ndim == 2 else 0
)
if reduce_scatter_only:
# Specially for HSDP if only reduce-scattering but not
# all-reducing until the last microbatch, expect one
# reduce-scatter per MLP plus for the root per microbatch
expected_reduce_scatter_count = (num_mlps + 1) * num_microbatches
self.assertEqual(reduce_scatter_count, expected_reduce_scatter_count)
self.assertEqual(all_reduce_count, expected_all_reduce_count)

# Expect one all-gather per MLP plus one for the root's linear in
55 changes: 32 additions & 23 deletions torch/distributed/_composable/fsdp/_fsdp_collectives.py
@@ -125,9 +125,11 @@ def foreach_reduce(
orig_dtype: torch.dtype,
reduce_dtype: Optional[torch.dtype],
device: torch.device,
all_reduce_group: Optional[dist.ProcessGroup],
all_reduce_group: Optional[dist.ProcessGroup], # not `None` iff HSDP
all_reduce_stream: torch.cuda.Stream,
) -> torch.cuda.Event:
all_reduce_grads: bool,
partial_reduce_output: Optional[torch.Tensor], # only used for HSDP
) -> Tuple[torch.cuda.Event, Optional[torch.Tensor]]:
"""
``unsharded_grads`` owns the references to the gradients computed by
autograd, so clearing the list frees the gradients.
@@ -163,36 +165,43 @@ def foreach_reduce(
# computed in the default stream
current_stream.wait_stream(reduce_scatter_stream)
unsharded_grads.clear()
post_reduce_output = reduce_scatter_input.new_empty(
(reduce_scatter_output_numel,)
)
reduce_output = reduce_scatter_input.new_empty((reduce_scatter_output_numel,))
_div_if_needed(reduce_scatter_input, predivide_factor)
dist.reduce_scatter_tensor(
output=post_reduce_output,
output=reduce_output,
input=reduce_scatter_input,
group=reduce_scatter_group,
op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM,
)
view_out_stream = reduce_scatter_stream
if all_reduce_group is not None:
view_out_stream = all_reduce_stream
all_reduce_stream.wait_stream(reduce_scatter_stream)
with torch.cuda.stream(all_reduce_stream):
dist.all_reduce(
post_reduce_output,
group=all_reduce_group,
op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM,
)
with torch.cuda.stream(view_out_stream):
_div_if_needed(post_reduce_output, postdivide_factor)
post_reduce_output = _to_dtype_if_needed(post_reduce_output, orig_dtype)
# - View out and accumulate
post_reduce_stream = reduce_scatter_stream
if all_reduce_group is not None: # HSDP
# Accumulations must run in the reduce-scatter stream
if not all_reduce_grads:
if partial_reduce_output is not None:
partial_reduce_output += reduce_output
else:
partial_reduce_output = reduce_output
return post_reduce_stream.record_event(), partial_reduce_output
if partial_reduce_output is not None:
reduce_output += partial_reduce_output
post_reduce_stream = all_reduce_stream
all_reduce_stream.wait_stream(reduce_scatter_stream)
with torch.cuda.stream(all_reduce_stream):
dist.all_reduce(
reduce_output,
group=all_reduce_group,
op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM,
)
with torch.cuda.stream(post_reduce_stream):
_div_if_needed(reduce_output, postdivide_factor)
reduce_output = _to_dtype_if_needed(reduce_output, orig_dtype)
# View out and accumulate sharded gradients
flat_grad_offset = 0 # [0, reduce_scatter_output_numel - 1]
for padded_unsharded_size, fsdp_param in zip(
padded_unsharded_sizes, fsdp_params
):
new_sharded_grad = torch.as_strided(
post_reduce_output,
reduce_output,
size=fsdp_param.sharded_size,
stride=fsdp_param.contiguous_sharded_stride,
storage_offset=flat_grad_offset,
Expand Down Expand Up @@ -220,12 +229,12 @@ def foreach_reduce(
fsdp_param.sharded_param.grad = new_sharded_dtensor_grad
padded_sharded_numel = padded_unsharded_size.numel() // world_size
flat_grad_offset += padded_sharded_numel
post_reduce_view_out_event = view_out_stream.record_event()
post_reduce_event = post_reduce_stream.record_event()
# The RS output is allocated in the RS stream and used in the default
# stream (for optimizer). To ensure its memory is not reused for later
# RSs, we do not need extra synchronization since the sharded parameters
# hold refs through the end of backward.
return post_reduce_view_out_event
return post_reduce_event, None


def foreach_reduce_scatter_copy_in(
20 changes: 12 additions & 8 deletions torch/distributed/_composable/fsdp/_fsdp_param_group.py
@@ -138,11 +138,15 @@ def __init__(
# Holds the reduce-scatter/all-reduce view-out CUDA event that marks the end of
# the group's post-backward (e.g. reduce-scatter, all-reduce and div), which
# should be waited on at the end of backward
self._post_reduce_view_out_event: Optional[torch.cuda.Event] = None
self._post_reduce_event: Optional[torch.cuda.Event] = None
# Holds the reshard-after-forward CUDA event when resharding to a
# different world size, which should be waited on in the next unshard
self._reshard_after_forward_event: Optional[torch.cuda.Event] = None

# Only for HSDP, if accumulating gradients without all-reduce, save the
# partial reduce output (only reduce-scattered but not all-reduced)
self._partial_reduce_output: Optional[torch.Tensor] = None

# Initialization #
def _init_mp_dtypes(self) -> None:
for fsdp_param in self.fsdp_params:
@@ -311,24 +315,24 @@ def post_backward(self, *unused: Any):
if len(fsdp_params_with_grad) == 0:
return
with torch.profiler.record_function("FSDP::post_backward_reduce"):
self._post_reduce_view_out_event = foreach_reduce(
self._post_reduce_event, self._partial_reduce_output = foreach_reduce(
fsdp_params_with_grad,
unsharded_grads,
self._reduce_scatter_process_group,
self.comm_ctx.reduce_scatter_stream,
self._orig_dtype,
self._reduce_dtype,
self.device,
self._all_reduce_process_group
if self._is_hsdp and self.all_reduce_grads
else None,
self._all_reduce_process_group if self._is_hsdp else None,
self.comm_ctx.all_reduce_stream,
self.all_reduce_grads,
self._partial_reduce_output,
)

def finalize_backward(self):
if self._post_reduce_view_out_event is not None:
torch.cuda.current_stream().wait_event(self._post_reduce_view_out_event)
self._post_reduce_view_out_event = None
if self._post_reduce_event is not None:
torch.cuda.current_stream().wait_event(self._post_reduce_event)
self._post_reduce_event = None
for fsdp_param in self.fsdp_params:
if fsdp_param.grad_offload_event is not None:
fsdp_param.grad_offload_event.synchronize()
9 changes: 3 additions & 6 deletions torch/distributed/_composable/fsdp/fully_shard.py
@@ -208,7 +208,7 @@ def set_is_last_backward(self, is_last_backward: bool) -> None:
state._state_ctx.is_last_backward = is_last_backward

def set_requires_gradient_sync(
self, requires_gradient_sync: bool, recurse: bool = True
self, requires_gradient_sync: bool, *, recurse: bool = True
) -> None:
"""
Sets if the module should sync gradients. This can be used to implement
@@ -231,16 +231,13 @@ def set_requires_gradient_sync(
fsdp_param_group.all_reduce_grads = requires_gradient_sync

def set_requires_all_reduce(
self, requires_all_reduce: bool, recurse: bool = True
self, requires_all_reduce: bool, *, recurse: bool = True
) -> None:
"""
Sets if the module should all-reduce gradients. This can be used to
implement gradient accumulation with only reduce-scatter but not
all-reduce for HSDP.
"""
# TODO: post_reduce_output += fsdp_param.sharded_param.grad
# after reduce-scatter and before all-reduce
raise NotImplementedError("requires_all_reduce is not yet supported in HSDP")
self_module = cast(nn.Module, self)
modules = list(self_module.modules()) if recurse else [self_module]
for module in modules:
@@ -250,7 +247,7 @@ def set_requires_all_reduce(
fsdp_param_group.all_reduce_grads = requires_all_reduce

def set_reshard_after_backward(
self, reshard_after_backward: bool, recurse: bool = True
self, reshard_after_backward: bool, *, recurse: bool = True
) -> None:
"""
Sets if the module should reshard parameters after backward. This can
