
Issues regarding model saving and loading when using lightning.Fabric with multiple GPUs and the FSDP strategy #18881

Closed
Williamwsk opened this issue Oct 27, 2023 · 2 comments
Labels
bug (Something isn't working) · strategy: fsdp (Fully Sharded Data Parallel) · ver: 2.1.x

Comments

@Williamwsk

Williamwsk commented Oct 27, 2023

Bug description

How do I save model weights when using lightning.Fabric() to accelerate the training code? How do I load weights trained with lightning.Fabric() when resuming training or doing distillation? And what is the difference between a checkpoint saved by lightning.Fabric() and one saved by plain PyTorch code?

When I save the checkpoint with code like the following:

from lightning.fabric import Fabric

fabric = Fabric(
    accelerator="cuda", precision="bf16-mixed",
    devices=4, strategy="fsdp"
)
fabric.launch()
# ... model/optimizer creation and fabric.setup(...) omitted ...
state = {
    "model": model,
    "optimizer": optimizer
}
# override_checkpoint is what triggers the TypeError below
fabric.save("checkpoint.ckpt", state, override_checkpoint=True)

Since the training code runs once per GPU process and the checkpoint may already exist when it is saved again, I tried passing override_checkpoint, but this raises an error:

Traceback (most recent call last):
  File "/project/Code/Skin/train_distill_accelerate.py", line 421, in <module>
    main()
  File "/project/Code/Skin/train_distill_accelerate.py", line 337, in main
    fabric.save(os.path.join(model_dir, "model_iter_best.ckpt"), state, override_checkpoint=True)
TypeError: Fabric.save() got an unexpected keyword argument 'override_checkpoint'
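
A workaround I am considering (just a sketch, assuming it is acceptable to delete the previous checkpoint before saving) is to remove it on rank zero instead of passing a new argument:

import os
import shutil

ckpt_path = os.path.join(model_dir, "model_iter_best.ckpt")

# Only rank zero touches the filesystem; the barrier keeps the other ranks
# from saving into a half-deleted directory. Sharded FSDP checkpoints are
# directories, hence rmtree.
if fabric.is_global_zero and os.path.isdir(ckpt_path):
    shutil.rmtree(ckpt_path)
fabric.barrier()

fabric.save(ckpt_path, state)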

When I load the checkpoint with code like the following:

from lightning.fabric import Fabric

fabric = Fabric(
    accelerator="cuda", precision="bf16-mixed",
    devices=4, strategy="fsdp"
)
fabric.launch()
# model and optimizer are constructed as usual (omitted) before setup
model, optimizer = fabric.setup(model, optimizer)

state = {
    "model": model,
    "optimizer": optimizer
}

fabric.load("checkpoint.ckpt", state)

the error is as follows:

Traceback (most recent call last):
  File "/project/Code/Skin/train_distill_accelerate_load.py", line 427, in <module>
    main()
  File "/project/Code/Skin/train_distill_accelerate_load.py", line 248, in main
    fabric.load(weight_file, state)
  File "/root/anaconda3/envs/wsk-py310/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 764, in load
    remainder = self._strategy.load_checkpoint(path=path, state=unwrapped_state, strict=strict)
  File "/root/anaconda3/envs/wsk-py310/lib/python3.10/site-packages/lightning/fabric/strategies/fsdp.py", line 569, in load_checkpoint
    optim_state = load_sharded_optimizer_state_dict(
  File "/root/anaconda3/envs/wsk-py310/lib/python3.10/site-packages/torch/distributed/checkpoint/optimizer.py", line 294, in load_sharded_optimizer_state_dict
    state_dict[key] = _shard_tensor(
  File "/root/anaconda3/envs/wsk-py310/lib/python3.10/site-packages/torch/distributed/_shard/api.py", line 68, in _shard_tensor
    st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=process_group)
  File "/root/anaconda3/envs/wsk-py310/lib/python3.10/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py", line 181, in shard
    dist.scatter(
  File "/root/anaconda3/envs/wsk-py310/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 47, in wrapper
    return func(*args, **kwargs)
  File "/root/anaconda3/envs/wsk-py310/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 3139, in scatter
    _check_tensor_list(scatter_list, "scatter_list")
  File "/root/anaconda3/envs/wsk-py310/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 807, in _check_tensor_list
    raise RuntimeError(
RuntimeError: Invalid function argument. Expected parameter `scatter_list` to be of type List[torch.Tensor].

In addition, if I train with 4 GPUs but load the checkpoint with 1 GPU, training slows down and eventually gets stuck.
Finally, when using Fully Sharded Data Parallel (FSDP) with multiple GPUs, the model checkpoint is sharded. How can these shards be merged back together so that the checkpoint can be loaded with torch.load()? Looking forward to your reply, thank you!

What version are you seeing the problem on?

v2.1

How to reproduce the bug

No response

Error messages and logs


Environment

Current environment
Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow): lightning.Fabric()
PyTorch Lightning Version: 2.1.0
Lightning App Version: 2.1.0
PyTorch Version: 2.1.0+cu118
Python version: 3.10
OS: Ubuntu 22.04
CUDA/cuDNN version: cuda11.8.0-cudnn8-devel
GPU models and configuration: multiple NVIDIA RTX 3090
How you installed Lightning (`conda`, `pip`, source): pip

More info

No response

cc @awaelchli @carmocca

@Williamwsk added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels on Oct 27, 2023
@awaelchli
Member

I submitted a proposal #18884 to handle the situation where a previous checkpoint directory should be overwritten. See also my reply in #18873, where I raise a point against introducing a new argument (override_checkpoint=True does not exist, which is why you get the error).

Regarding the second error, could you provide a runnable example to reproduce it? I don't see under which circumstances this happens.
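
In the meantime, one thing that might help narrow it down (just a sketch, assuming the failure is specific to restoring the sharded optimizer state) is to load only the model weights and skip the optimizer:

# Diagnostic only: restore just the model weights. strict=False lets the
# checkpoint contain keys (e.g. "optimizer") that are not in the passed state.
fabric.load("checkpoint.ckpt", {"model": model}, strict=False)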

How can these shards be merged back together so that the checkpoint can be loaded with torch.load()? Looking forward to your reply, thank you!

There is no dedicated function for this at the moment, but you can achieve it in two ways:

Either use full-state checkpoints in the first place (FSDPStrategy(state_dict_type="full")), or, if you already have a sharded checkpoint, load it back into the model and then save it again with the full state-dict type:

from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy

fabric = Fabric(strategy=FSDPStrategy(state_dict_type="full"))
model = fabric.setup(model)
state = {...}  # same keys you used when saving, e.g. "model", "optimizer"
fabric.load(sharded_path, state)  # loads the sharded checkpoint
fabric.save(full_path, state)     # now saves a consolidated checkpoint

There is currently no way to transform the sharded checkpoint directly (this will probably come in the future).
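
Once you have the consolidated checkpoint, it should be loadable like a regular PyTorch checkpoint. A minimal sketch, assuming full_path from above points to the single consolidated file and the model was saved under the key "model":

import torch

# The consolidated checkpoint is a plain dictionary keyed by the names used
# in `state` ("model", "optimizer", ...), so it can be read directly.
checkpoint = torch.load(full_path, map_location="cpu")
model.load_state_dict(checkpoint["model"])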

@awaelchli added the strategy: fsdp (Fully Sharded Data Parallel) label and removed the needs triage (Waiting to be triaged by maintainers) label on Oct 28, 2023
@fabienGenhealth

Interesting. It was not clear to me from the docs that you could load a sharded checkpoint after setting state_dict_type="full" in the FSDPStrategy passed to Fabric. Will try it.
