Avoid registering pytree when using FSDP #39325
base: main
Conversation
@SunMarc @ArthurZucker @IlyasMoutawwakil pls help review, thx!
@kaixuanliu do you mean this PR #36311 where it was added?
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
@IlyasMoutawwakil Oh yes, it's 36311, not 35873. As for the extra memory, I ran an experiment: I used 4 processes to do FSDP finetuning with the llama2-7b model and compared the maximum memory consumption for the two configurations. The result shows w/
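For reference, a minimal sketch of how such a per-rank peak-memory comparison could be collected (an assumed CUDA setup and hypothetical helper name; the actual script and numbers from the experiment are not part of this thread):

import torch
import torch.distributed as dist

def report_peak_memory(tag: str) -> None:
    # max_memory_allocated() tracks the high-water mark of CUDA memory
    # allocated by this process since the last reset_peak_memory_stats()
    peak_gib = torch.cuda.max_memory_allocated() / 1024**3
    rank = dist.get_rank() if dist.is_initialized() else 0
    print(f"[rank {rank}] {tag}: peak allocated memory = {peak_gib:.2f} GiB")

# Called once per configuration after the finetuning run, e.g.:
# report_peak_memory("with DynamicCache pytree registration")
# report_peak_memory("without DynamicCache pytree registration")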
Thanks for discovering that! Left a comment.
@@ -680,7 +681,8 @@ def _flatten_dynamic_cache_for_fx(cache, spec):
     return torch.fx._pytree._dict_flatten_spec(dictionary, spec)


-if is_torch_greater_or_equal("2.3"):
+# Register pytree node for DynamicCache if torch version is >= 2.3 and FSDP is not imported, FSDP will need more extra memory when using pytree node
+if is_torch_greater_or_equal("2.3") and "torch.distributed.fsdp" not in sys.modules:
Can we register the pytree node for DynamicCache somewhere else so that we can perform a better check than just testing whether "torch.distributed.fsdp" is in sys.modules? cc @IlyasMoutawwakil @gante
We also have the is_fsdp_enabled function in modeling_utils that could be used to perform the check, roughly as sketched below.
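A rough sketch of what that alternative check might look like (illustrative only; it assumes is_fsdp_enabled can be imported where the registration happens without creating a circular import with modeling_utils, which is exactly the placement question raised above):

# Assumed imports: is_fsdp_enabled lives in transformers.modeling_utils,
# is_torch_greater_or_equal in transformers.utils.
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils import is_torch_greater_or_equal

if is_torch_greater_or_equal("2.3") and not is_fsdp_enabled():
    # ... register the DynamicCache pytree node here, as in the current code path
    pass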
When using FSDP, this register_pytree_node operation costs a lot of extra memory. We found that after PR #35873 we cannot finetune a 70b model using FSDP due to an OOM issue.
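For context, a minimal, self-contained sketch of the kind of registration being discussed (a toy cache class rather than the actual DynamicCache code; torch.utils._pytree is the utility the real registration uses):

import torch.utils._pytree as pytree


class ToyCache:
    # stands in for DynamicCache: a container of per-layer key/value tensor lists
    def __init__(self, key_cache=None, value_cache=None):
        self.key_cache = key_cache if key_cache is not None else []
        self.value_cache = value_cache if value_cache is not None else []


def _flatten_toy_cache(cache):
    # return the leaves (the traced values) plus a static "context" (none needed here)
    return [cache.key_cache, cache.value_cache], None


def _unflatten_toy_cache(children, _context):
    return ToyCache(key_cache=children[0], value_cache=children[1])


# Registration lets torch's tracing/compilation machinery flatten and rebuild
# the container like any other pytree; this is the step the PR gates on FSDP
# not being in use.
pytree.register_pytree_node(ToyCache, _flatten_toy_cache, _unflatten_toy_cache)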