Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support accessing the module reference for the process group #96

Merged
Merged 4 commits on Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 15 additions & 1 deletion thunder/distributed/__init__.py
Expand Up @@ -71,9 +71,23 @@ def skip_data_parallel_grad_sync() -> Generator[Any, Any, Any]:
def _sync_grads(module: torch.nn.Module) -> None:
import thunder

if hasattr(module, "process_group_for_ddp"):
# This branch is required when a function that takes the model as an input is jitted instead
# of the model itself. In that case, the user won't have a reference to a `ThunderModule` so this needs to use
# the reference set by ddp and fsdp on the module directly
process_group = module.process_group_for_ddp
elif (cd := thunder.compile_data(module)) is not None:
# The ordinary jitted module branch
process_group = cd.process_group_for_ddp
else:
raise RuntimeError(
f"Expected `{type(module).__name__}` to have been jitted or to contain a `process_group_for_ddp` attribute"
)

params_with_grad = [p for p in module.parameters() if p.grad is not None]
if not params_with_grad:
return
grads = [p.grad for p in params_with_grad]
process_group = thunder.compile_data(module).process_group_for_ddp
torch._foreach_div_(grads, process_group.size())
with tdist.distributed_c10d._coalescing_manager(group=process_group, async_ops=True) as cm:
for g in grads:
Expand Down
4 changes: 4 additions & 0 deletions thunder/tests/distributed/test_ddp.py
Expand Up @@ -880,6 +880,10 @@ def fwd_loss(m, x):
fwd_loss = thunder.jit(fwd_loss)
fwd_loss(model, x)

# notice how we cannot do `model.no_sync()` because it's not a ThunderModule
with thunder.ThunderModule.no_sync(model):
fwd_loss(model, x)


common_utils.instantiate_parametrized_tests(CompileDDPTest)

Expand Down