Implement no_sync for thunder.distributed.fsdp (PR2457) #45

Conversation
thunder/executors/torch_autograd.py (Outdated)

```diff
-def forward(ctx, compiled_backward, saved_tensors, saved_other, flat_output, *flat_args):
+def forward(
+    ctx,
+    return_none_instead_of_grads,
```
[RFC] new argument to ThunderFunction.apply
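For context, here is a minimal, hypothetical sketch of the pattern the RFC proposes (`FlaggedFunction` is a stand-in, not Thunder's actual `ThunderFunction`): a leading flag threaded through `torch.autograd.Function.apply` that makes `backward` return `None` in place of the real gradients.

```python
import torch

class FlaggedFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, return_none_instead_of_grads, x):
        # Stash the flag for backward; backward must return one gradient
        # slot per forward input, including the flag itself.
        ctx.return_none_instead_of_grads = return_none_instead_of_grads
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        if ctx.return_none_instead_of_grads:
            # Behave as if no gradients flowed to the inputs.
            return None, None
        # First slot is the flag (never differentiable), second is x.
        return None, grad_out * 2


x = torch.randn(3, requires_grad=True)
FlaggedFunction.apply(False, x).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```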
```python
if self.skip_data_sync:
    data_sync_ctx = self.model.no_sync
else:
    data_sync_ctx = nullcontext
```
cc @parthmannan
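A sketch of how a context manager selected this way is typically used for gradient accumulation; `model`, `micro_batches`, and `skip_data_sync` are stand-ins rather than names taken verbatim from this PR, and `model` is assumed to be a DDP/FSDP-wrapped module exposing `no_sync`.

```python
from contextlib import nullcontext

def train_microbatches(model, micro_batches, skip_data_sync: bool):
    # Mirrors the selection in the snippet above.
    data_sync_ctx = model.no_sync if skip_data_sync else nullcontext
    for micro_batch in micro_batches[:-1]:
        with data_sync_ctx():
            # Under no_sync, gradient communication is skipped and
            # local grads accumulate instead.
            model(micro_batch).sum().backward()
    # The final micro-batch runs outside no_sync so grads synchronize.
    model(micro_batches[-1]).sum().backward()
```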
@t-vi this is ready for merge
Thank you, @crcrpar @IvanYashchuk
tldr: Enables `no_sync` for `thunder.jit(thunder.distributed.fsdp(model))`. The accompanying changes are:

- `return_none_instead_of_grads` argument of `ThunderFunction.forward`
- `TraceCtx`'s bound symbols are not deleted even if it just returns one or more `None`s
- `no_sync` check before applying `dist_prims.synchronize` to args and kwargs
- `visitor_transform` removes `dist_prims.all_reduce`, `dist_prims.wait`, and pre-averaging when `no_sync`
- `visitor_transform` removes comms and puts `dist_prims.stash_grad_for_fsdp` and optional param `AllGather` when `no_sync`
- Unsharded grads are stashed to `param._thunder_fsdp_unsharded_grad`; `ThunderFunction`'s `backward` returns `None`s instead of such grads to avoid a shape mismatch between params and unsharded grads.
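Putting the pieces together, a gradient-accumulation sketch of what this PR enables. It assumes distributed init is handled by torchrun and that the compiled module exposes `no_sync` as the title indicates; `model`, `batches`, `optimizer`, and `accum_steps` are placeholders.

```python
from contextlib import nullcontext

import torch
import thunder
import thunder.distributed

def train(model: torch.nn.Module, batches, optimizer, accum_steps: int):
    # Shard the model, then jit-compile it.
    cmodel = thunder.jit(thunder.distributed.fsdp(model))
    for step, batch in enumerate(batches):
        sync_now = (step + 1) % accum_steps == 0
        ctx = nullcontext() if sync_now else cmodel.no_sync()
        with ctx:
            loss = cmodel(batch).sum()
            # Under no_sync, unsharded grads are stashed to
            # param._thunder_fsdp_unsharded_grad instead of being
            # reduce-scattered across ranks.
            loss.backward()
        if sync_now:
            optimizer.step()
            optimizer.zero_grad()
```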
Benchmark results as of fa61c49

zero2
command:
```
torchrun --nproc-per-node=8 thunder/benchmarks/benchmark_litgpt.py --compile=thunder_inductor --distributed_mode=fsdp --nsys_enabled=False --micro_batch_size=1 --global_batch_size=32 --skip_data_sync <false|true> --model_name=Llama-2-7b-hf --shard_mode=zero2 --bucketing_mode=none --json_path "<filename>.json" --return_metrics_as_json=true
```
[results table: with vs. without no_sync]

zero3
command:
```
torchrun --nproc-per-node=8 thunder/benchmarks/benchmark_litgpt.py --compile=thunder_inductor --distributed_mode=fsdp --nsys_enabled=False --micro_batch_size=1 --global_batch_size=32 --skip_data_sync <false|true> --model_name=Llama-2-7b-hf --shard_mode=zero3 --bucketing_mode=none --json_path "<filename>.json" --return_metrics_as_json=true
```
[results table: with vs. without no_sync]