
Avoid registering pytree when using FSDP #39325


Open: wants to merge 4 commits into main

Conversation

kaixuanliu (Contributor)

When using FSDP, this register_pytree_node operation costs a lot of extra memory. We found that after this PR: #35873, we cannot finetune a 70B model with FSDP due to an OOM issue.

kaixuanliu (Contributor, Author)

@SunMarc @ArthurZucker @IlyasMoutawwakil please help review, thanks!

IlyasMoutawwakil (Member) commented Jul 10, 2025

@kaixuanliu do you mean PR #36311, where it was added?
Do you have exact measurements of how much it costs when FSDP is used and when it is not? The operation itself has nothing to do with FSDP.

kaixuanliu (Contributor, Author)

@IlyasMoutawwakil Oh yes, it's #36311, not #35873. As for the extra memory, I ran an experiment: I used 4 processes to do FSDP finetuning with the llama2-7b model and compared the maximum memory consumption for the two configurations. The results show that with the register_pytree_node operation, each card uses roughly 12 GB of extra memory.
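
For reference, a minimal sketch of how the per-card peak memory can be compared between the two configurations; the helper below is illustrative only and is not the exact benchmark script behind the numbers above:

```python
# Illustrative sketch: record peak GPU memory per rank during an FSDP finetuning run.
# Model and dataloader setup are omitted; only the measurement pattern is shown.
import torch
import torch.distributed as dist


def report_peak_memory(tag: str) -> None:
    """Print the peak allocated GPU memory (in GB) for the current rank."""
    torch.cuda.synchronize()
    peak_gb = torch.cuda.max_memory_allocated() / 1024**3
    print(f"[rank {dist.get_rank()}] {tag}: peak GPU memory = {peak_gb:.2f} GB")


# Usage inside the training script, once per configuration:
#   torch.cuda.reset_peak_memory_stats()
#   ... run the same number of training steps with / without the pytree registration ...
#   report_peak_memory("with register_pytree_node")
```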

SunMarc (Member) left a comment


Thanks for discovering that! Left a comment.

@@ -680,7 +681,8 @@ def _flatten_dynamic_cache_for_fx(cache, spec):
     return torch.fx._pytree._dict_flatten_spec(dictionary, spec)


-if is_torch_greater_or_equal("2.3"):
+# Register the pytree node for DynamicCache only when torch >= 2.3 and FSDP has not been imported; FSDP needs extra memory when the pytree node is registered
+if is_torch_greater_or_equal("2.3") and "torch.distributed.fsdp" not in sys.modules:
SunMarc (Member)
Can we register the pytree node for DynamicCache somewhere else so that we can perform a better check than just testing whether "torch.distributed.fsdp" is in sys.modules? cc @IlyasMoutawwakil @gante
We also have the is_fsdp_enabled function in modeling_utils that could be used to perform the check.
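
A rough sketch of that suggestion, assuming the registration can be deferred out of module import time so that is_fsdp_enabled() is consulted once the distributed environment is actually configured; register_fn is a placeholder for the real registration logic:

```python
# Sketch: gate the registration on is_fsdp_enabled() instead of probing sys.modules.
# Note that is_fsdp_enabled() inspects the distributed setup (process group and
# accelerate-set environment variables), so it only gives a meaningful answer after
# that setup has happened; calling it at import time, where the registration
# currently lives, may be too early.
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils import is_torch_greater_or_equal


def maybe_register_dynamic_cache_pytree(register_fn) -> None:
    """Run `register_fn` (a placeholder for the actual pytree registration)
    only when torch is recent enough and FSDP is not enabled."""
    if is_torch_greater_or_equal("2.3") and not is_fsdp_enabled():
        register_fn()
```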
