
FSDP: _optimizer_has_flat_params only checks first parameter group #17817

Closed · schmidt-ai opened this issue Jun 12, 2023 · 1 comment · Fixed by #17914
Labels: bug (Something isn't working), strategy: fsdp (Fully Sharded Data Parallel), ver: 2.0.x
Milestone: 2.0.x
Comments

schmidt-ai (Contributor) commented Jun 12, 2023

Bug description

The function _optimizer_has_flat_params only checks the first parameter group for _fsdp_flattened parameters. There is an edge case where the first parameter group has no _fsdp_flattened parameters but subsequent groups do. It would be a small change to this function to check all groups in optimizer.param_groups:

def _optimizer_has_flat_params(optimizer: Optimizer) -> bool:
    _FSDP_FLATTENED = "_fsdp_flattened"
    if _TORCH_GREATER_EQUAL_1_13:
-        return any(getattr(param, _FSDP_FLATTENED, False) for param in optimizer.param_groups[0]["params"])
+        return any(getattr(param, _FSDP_FLATTENED, False) for group in optimizer.param_groups for param in group["params"])

    from torch.distributed.fsdp import FlatParameter
-    return any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"])
+    return any(isinstance(param, FlatParameter) for group in optimizer.param_groups for param in group["params"])

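For illustration, here is a minimal sketch of the edge case (an assumption for demonstration, not a real FSDP setup: it manually sets the _fsdp_flattened attribute that FSDP applies to flattened parameters on torch >= 1.13). The first parameter group holds no flattened parameters while the second does, so the current check misses them:

import torch
from torch.optim import SGD

# Simulated edge case: first param group has no FSDP-flattened params, a later group does.
plain = torch.nn.Parameter(torch.zeros(4))
flat = torch.nn.Parameter(torch.zeros(4))
flat._fsdp_flattened = True  # attribute FSDP sets on flattened parameters (simulated here)

optimizer = SGD([{"params": [plain]}, {"params": [flat]}], lr=1e-2)

# Current check: only inspects the first group -> False, even though group 1 has a flat param
print(any(getattr(p, "_fsdp_flattened", False) for p in optimizer.param_groups[0]["params"]))

# Proposed check: inspects all groups -> True
print(any(getattr(p, "_fsdp_flattened", False) for g in optimizer.param_groups for p in g["params"]))
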
What version are you seeing the problem on?

v2.0

How to reproduce the bug

No response

Error messages and logs

No response

Environment

Current environment
Python version: 3.10.11 (main, Apr 20 2023, 19:02:41) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.10.178-162.673.amzn2.x86_64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 525.85.12
[pip3] numpy==1.24.3
[pip3] pytorch-lightning==2.0.3
[pip3] sagemaker-pytorch-training==2.8.0
[pip3] torch==2.0.1
[pip3] torchaudio==2.0.2
[pip3] torchdata==0.6.1
[pip3] torchelastic==0.2.2
[pip3] torchmetrics==0.11.4
[pip3] torchtext==0.15.2
[pip3] torchvision==0.15.2

More info

No response

cc @awaelchli @carmocca

schmidt-ai added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels on Jun 12, 2023
awaelchli (Member) commented Jun 12, 2023

Hi @schmidt-ai, thanks for reporting.
A PR for this would be very welcome. Are you open to contributing this fix?

awaelchli added the strategy: fsdp (Fully Sharded Data Parallel) label, removed the needs triage label, and added this to the 2.0.x milestone on Jun 12, 2023