
DCP sees 1/2 of the expected size of each tensor in 3D parallel #126595

Closed
wconstab opened this issue May 18, 2024 · 3 comments
Assignees
Labels
module: distributed_checkpoint oncall: distributed Add this issue/PR to distributed oncall triage queue

Comments

@wconstab
Contributor

wconstab commented May 18, 2024

Quite possibly this is a bug in the 3D parallel implementation itself, but I'm trying to debug why I see this warning and subsequently fail with ValueError: Failed to validate global plan:

torch/distributed/checkpoint/default_planner.py:495] key:model.layers.0.attention.wq.weight invalid fill tensor-volume: 65536 chunks-volume: 32768

The repro is on the 8-GPU CI for pytorch/torchtitan#344 (log link).

The same warning is issued for every weight in the model. For the remainder I'll focus on just one, model.layers.0.attention.wq.weight.

DCP sees the shape of wq.weight as [256, 256], which is the correct full shape of wq.weight per the model code.

Some debugging:

(Pdb) p key
[rank0]:'model.layers.0.attention.wq.weight'

(Pdb) p value
[rank0]:TensorStorageMetadata(properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False), size=torch.Size([256, 256]), chunks=[ChunkStorageMetadata(offsets=torch.Size([0, 0]), sizes=torch.Size([64, 256])), ChunkStorageMetadata(offsets=torch.Size([64, 0]), sizes=torch.Size([64, 256]))])

It looks like DCP only sees two chunks of size [64, 256]. I'm wondering if sharding for both FSDP and TP is happening on the same dim and one of those shardings is being ignored here?
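For context, the failing validation is essentially a volume comparison; a minimal sketch of the arithmetic (the actual planner code in default_planner.py differs):

```python
# Sketch of the check behind "invalid fill tensor-volume: 65536 chunks-volume: 32768".
# The chunk sizes below come from the TensorStorageMetadata above; this is not the real planner code.
import math

full_size = (256, 256)                   # size reported for wq.weight
chunks = [(64, 256), (64, 256)]          # the two chunks DCP actually records

tensor_volume = math.prod(full_size)                  # 65536
chunks_volume = sum(math.prod(c) for c in chunks)     # 32768

assert chunks_volume != tensor_volume    # half the tensor is unaccounted for
# With a 2x2 (FSDP x TP) mesh sharding dim 0 twice, four [64, 256] chunks would be
# needed to cover the full tensor; only two show up in the metadata.
```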

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @yf225 @chauhang @d4l3k @LucasLLC

@mikaylagawarecki mikaylagawarecki added the oncall: distributed Add this issue/PR to distributed oncall triage queue label May 20, 2024
@fegin
Contributor

fegin commented May 21, 2024

What is model.layers.0.attention.wq.weight._spec on each of the ranks that have this weight when saving the model? I would like to understand whether the issue happens during saving or loading.
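(A rough, hedged sketch of how one might dump that spec on every rank right before checkpointing; the FQN and hook point are assumptions about the torchtitan training loop.)

```python
# Debugging sketch (not from the issue): print the DTensor spec of one parameter
# on every rank before saving. The FQN below is an assumption about how the
# parameter is named from the top-level module.
import torch.distributed as dist

def dump_spec(model, fqn="layers.0.attention.wq.weight"):
    params = dict(model.named_parameters())
    p = params.get(fqn)
    if p is not None and hasattr(p, "_spec"):   # DTensor parameters carry a _spec
        print(f"[rank{dist.get_rank()}] {fqn}: {p._spec}")
    else:
        print(f"[rank{dist.get_rank()}] {fqn}: not present on this rank")
```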

@wconstab
Contributor Author

model.layers.0.attention.wq.weight._spec prints out on ranks 0, 1, 2, and 3 as
Spec((Shard(dim=0), Shard(dim=0)) on (256, 256)). (Ranks 4-7 don't have this param.)
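A small illustration of what that placement implies, assuming the 2x2 DP x TP sub-mesh implied by the 8-GPU (PP=2) run:

```python
# Illustration only: with both mesh dimensions sharding dim 0, each rank holds
# 256 / (2 * 2) = 64 rows, so four [64, 256] shards together cover the tensor.
global_shape = (256, 256)
mesh_shape = (2, 2)            # assumed (FSDP, TP) sub-mesh for ranks 0-3
rows_per_rank = global_shape[0] // (mesh_shape[0] * mesh_shape[1])
local_shape = (rows_per_rank, global_shape[1])
print(local_shape)             # (64, 256) -> matches the chunk sizes in the metadata,
                               # but only two of the four expected chunks were recorded
```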

@fegin
Contributor

fegin commented May 28, 2024

#127071 will fix the issue.

@fegin fegin self-assigned this May 28, 2024
BoyuanFeng pushed a commit to BoyuanFeng/pytorch that referenced this issue May 31, 2024
…lattening when loading (pytorch#127071)

Fixes pytorch#126595

**What does this PR do?**
This PR flattens the optimizer state_dict when saving and unflattens it when loading, similar to what TorchRec does. The current `get_optimizer_state_dict()` converts parameter IDs to FQNs in order to avoid conflicts between different optimizers on different ranks. The returned optimizer state_dict currently looks like the following:
```
{
    "state": {
          "layer1.weight": {"step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor},
          "layer2.weight": {"step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor},
    },
    "param_group": [
         {"lr": 0.0, "betas": (0.9, 0.95), ..., "params": ["layer1.weight", "layer2.weight"]}
    ]
}
```
While this avoids the conflict and supports merging multiple optimizers (e.g., optimizer in backward), the current optimizer state_dict still cannot support MPMD (e.g., pipeline parallelism). The root cause is `param_group`: it cannot generate unique keys during saving. DCP will flatten the dict, but for `param_group` it produces keys like `param_group.lr` or `param_group.params`, and these keys conflict when using pipeline parallelism.
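A hypothetical sketch of why the keys collide (DCP's real flattening logic differs; the helper below is illustrative only):

```python
# Illustrative nested-dict flattener: each pipeline stage has its own optimizer
# over a disjoint set of parameters, yet flattening param_group yields identical keys.
def flatten(d, prefix=""):
    out = {}
    for k, v in d.items():
        key = f"{prefix}.{k}" if prefix else str(k)
        if isinstance(v, dict):
            out.update(flatten(v, key))
        else:
            out[key] = v
    return out

stage0_pg = {"param_group": {"lr": 0.1, "params": ["layer1.weight"]}}
stage1_pg = {"param_group": {"lr": 0.1, "params": ["layer2.weight"]}}
print(flatten(stage0_pg).keys())  # dict_keys(['param_group.lr', 'param_group.params'])
print(flatten(stage1_pg).keys())  # identical keys -> collision when DCP merges the global plan
```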

This PR flattens the optimizer state_dict into the following form:
```
{
    "state.layer1.weight.step": 10,
    "state.layer2.weight.step": 10,
    "state.layer1.weight.exp_avg": SomeTensor,
    "state.layer2.weight.exp_avg": SomeTensor,
    "state.layer1.weight.exp_avg_sq": SomeTensor,
    "state.layer2.weight.exp_avg_sq": SomeTensor,
    "param_group.layer1.weight.lr" : 0.1,
    "param_group.layer2.weight.lr" : 0.1,
    "param_group.layer1.weight.betas" : (0.9, 0.95),
    "param_group.layer2.weight.betas" : (0.9, 0.95),
}
```
This allows distributed state_dict (DSD) to support MPMD (e.g., pipeline parallelism).
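A hypothetical sketch of the flattening described above (the real DSD implementation differs; `flatten_optim_state_dict` is an illustrative name):

```python
def flatten_optim_state_dict(osd):
    """Illustrative only: turn the nested optimizer state_dict into FQN-qualified flat keys."""
    flat = {}
    # State entries become "state.<fqn>.<name>".
    for fqn, state in osd["state"].items():
        for name, value in state.items():
            flat[f"state.{fqn}.{name}"] = value
    # Hyperparameters are duplicated per parameter as "param_group.<fqn>.<name>".
    for group in osd["param_group"]:
        hyperparams = {k: v for k, v in group.items() if k != "params"}
        for fqn in group["params"]:
            for name, value in hyperparams.items():
                flat[f"param_group.{fqn}.{name}"] = value
    return flat

osd = {
    "state": {"layer1.weight": {"step": 10}},
    "param_group": [{"lr": 0.1, "params": ["layer1.weight"]}],
}
print(flatten_optim_state_dict(osd))
# {'state.layer1.weight.step': 10, 'param_group.layer1.weight.lr': 0.1}
```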

**Pros and Cons**
*Pros*
1. Can support optimizer resharding (e.g., changing the parallelisms from 3D to 2D or changing the number of workers).
2. Users don't need to manually add prefixes to different optimizers.
3. Allows users to merge optimizer states easily. One use case is loop-based pipeline parallelism.

*Cons*
1. The implementation makes strong assumptions about the structure of `param_groups` and its values. If those assumptions change, or some customized optimizers do not meet them, the implementation will break.
2. Extra values will be saved in the checkpoints. The assumption here is that `param_group` generally contains scalars, which are cheap to save.

Pull Request resolved: pytorch#127071
Approved by: https://github.com/wconstab, https://github.com/wz337
ghstack dependencies: pytorch#127070