
Implement no_sync for thunder.distributed.fsdp (PR2457) #45

Merged: 17 commits into main on May 3, 2024

Conversation

@crcrpar (Collaborator) commented on Mar 22, 2024

tldr

Enables no_sync for thunder.jit(thunder.distributed.fsdp(model)); a usage sketch follows the list below. The accompanying changes are:

  • new argument return_none_instead_of_grads for ThunderFunction.forward
    • This could be eliminated once a TraceCtx's bound symbols are no longer deleted when they just return one or more Nones
  • removal of the no_sync check before applying dist_prims.synchronize to args and kwargs
    • FSDP's forward needs this prim for its param AllGather
    • [ddp] visitor_transform removes dist_prims.all_reduce, dist_prims.wait, and pre-averaging when no_sync
    • [fsdp] visitor_transform removes comms and inserts dist_prims.stash_grad_for_fsdp and an optional param AllGather when no_sync
      • The generated trace and its executable Python code return unsynchronized unsharded gradients.
      • The prim's implementation accumulates the grads as param._thunder_fsdp_unsharded_grad.
      • ThunderFunction's backward returns Nones instead of such grads to avoid a shape mismatch between params and unsharded grads.
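
Usage sketch (not taken from this PR; model, optimizer, and micro_batches are placeholder names) of gradient accumulation with the no_sync support this change enables:

    from contextlib import nullcontext

    import thunder
    import thunder.distributed

    # Placeholder setup: assumes torch.distributed is initialized and that
    # `model`, `optimizer`, and `micro_batches` are defined elsewhere.
    cmodel = thunder.jit(thunder.distributed.fsdp(model))

    for i, batch in enumerate(micro_batches):
        # Skip gradient synchronization for all but the last micro batch;
        # unsharded grads are accumulated locally in the meantime.
        ctx = cmodel.no_sync() if i < len(micro_batches) - 1 else nullcontext()
        with ctx:
            cmodel(batch).sum().backward()

    optimizer.step()
    optimizer.zero_grad()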

as of fa61c49

  • llama-2-7b-hf
  • world size 8 H100s
  • micro batch size 1
  • global batch size 32
  • gradient accumulation 4
  • no bucketing (of AllGather and ReduceScatter)
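
For reference, these settings are consistent: global batch size 32 = micro batch size 1 × 8 ranks × 4 gradient-accumulation steps.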

zero2

command: torchrun --nproc-per-node=8 thunder/benchmarks/benchmark_litgpt.py --compile=thunder_inductor --distributed_mode=fsdp --nsys_enabled=False --micro_batch_size=1 --global_batch_size=32 --skip_data_sync <false|true> --model_name=Llama-2-7b-hf --shard_mode=zero2 --bucketing_mode=none --json_path "<filename>.json" --return_metrics_as_json=true

                            w/ no_sync    w/o no_sync
tokens/sec                     82713.0        80341.0
memory consumption [GB]           65.6           40.3

zero3

command: torchrun --nproc-per-node=8 thunder/benchmarks/benchmark_litgpt.py --compile=thunder_inductor --distributed_mode=fsdp --nsys_enabled=False --micro_batch_size=1 --global_batch_size=32 --skip_data_sync <false|true> --model_name=Llama-2-7b-hf --shard_mode=zero3 --bucketing_mode=none --json_path "<filename>.json" --return_metrics_as_json=true

                            w/ no_sync    w/o no_sync
tokens/sec                     77839.0        75511.9
memory consumption [GB]           52.5           27.1

- def forward(ctx, compiled_backward, saved_tensors, saved_other, flat_output, *flat_args):
+ def forward(
+     ctx,
+     return_none_instead_of_grads,

@crcrpar (Collaborator, Author) commented:

[RFC] new argument to ThunderFunction.apply
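
For illustration only, a minimal self-contained sketch (not thunder's actual ThunderFunction) of how a leading flag argument to a torch.autograd.Function can make backward return Nones in the gradient slots:

    import torch

    class SketchFunction(torch.autograd.Function):
        """Toy stand-in for the pattern; the real ThunderFunction differs."""

        @staticmethod
        def forward(ctx, return_none_instead_of_grads, *flat_args):
            ctx.return_none_instead_of_grads = return_none_instead_of_grads
            # Toy computation standing in for the compiled forward.
            return tuple(a * 2 for a in flat_args)

        @staticmethod
        def backward(ctx, *grad_outputs):
            # The flag argument itself never receives a gradient, hence the
            # leading None. When the flag is set (the no_sync case), grads are
            # stashed on the params instead of being returned, which avoids a
            # shape mismatch between sharded params and unsharded grads.
            if ctx.return_none_instead_of_grads:
                return (None,) + (None,) * len(grad_outputs)
            return (None,) + tuple(2 * g for g in grad_outputs)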

Comment on lines +297 to +362
if self.skip_data_sync:
    data_sync_ctx = self.model.no_sync
else:
    data_sync_ctx = nullcontext
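
A minimal sketch of how the selected context-manager factory can then be entered uniformly (model, skip_data_sync, and batch are placeholder names; the actual loop lives in thunder/benchmarks/benchmark_litgpt.py and is not reproduced here):

    from contextlib import nullcontext

    # Both `model.no_sync` and `nullcontext` are callables that return a
    # context manager, so the accumulation step can enter `data_sync_ctx()`
    # without branching on --skip_data_sync again.
    data_sync_ctx = model.no_sync if skip_data_sync else nullcontext
    with data_sync_ctx():
        model(batch).sum().backward()

Passing the context-manager factory around, rather than a boolean, keeps the call site branch-free.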
Resolved review threads on thunder/distributed/__init__.py (one marked outdated).
@crcrpar (Collaborator, Author) commented on Apr 23, 2024:

@t-vi this is ready for merge

@crcrpar (Collaborator, Author) commented on Apr 23, 2024:

How are HF's Accelerate and Lightning's Fabric wrapping PyTorch's no_sync?

@crcrpar force-pushed the crpa/fsdp-no-sync branch 3 times, most recently from e9bde84 to 078e7b1 on May 1, 2024 at 06:44.
crcrpar and others added 17 commits on May 3, 2024 at 17:47, each signed off by Masaki Kozuki <mkozuki@nvidia.com>; one commit was co-authored by Ivan Yashchuk <IvanYashchuk@users.noreply.github.com>.
@t-vi (Collaborator) left a review comment:

Thank you, @crcrpar @IvanYashchuk

@t-vi merged commit 85b2cd8 into main on May 3, 2024; 37 of 39 checks passed.
@t-vi deleted the crpa/fsdp-no-sync branch on May 3, 2024 at 12:26.