How do I save the model weights when using lightning.Fabric() to accelerate the training code? How do I load weights trained with lightning.Fabric() for resuming or distillation? And what is the difference between a weight file saved by lightning.Fabric() and one saved by plain PyTorch code?
In addition, if I use 4 GPUs for training and 1 GPU for reading, training slows down and gets stuck.
Finally, when using Fully Sharded Data Parallel (FSDP) with multiple GPUs, the model is sharded. How can the sharded parts of the model be merged back together so that the checkpoint can be loaded with torch.load()? Looking forward to your reply, thank you!
What version are you seeing the problem on?
v2.1
How to reproduce the bug
No response
Error messages and logs
Environment
Current environment
Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow): lightning.Fabric()
PyTorch Lightning Version: 2.1.0
Lightning App Version: 2.1.0
PyTorch Version: 2.1.0+cu118
Python version: 3.10
OS: Ubuntu 22.04
CUDA/cuDNN version: cuda11.8.0-cudnn8-devel
GPU models and configuration: multiple RTX 3090
How you installed Lightning (`conda`, `pip`, source): pip
Running environment of LightningApp (e.g. local, cloud):
I submitted a proposal #18884 to handle the situation where a previous checkpoint directory should be overwritten. See also my reply in #18873, where I argue against introducing a new argument (override_checkpoint=True does not exist, which is why you get the error).
Regarding the second error, could you provide a runnable example that reproduces it? I don't see under which circumstances this happens.
> How can these sharded parts of the model be merged back together so that it can be loaded using torch.load()?
Something like this doesn't exist as a direct function at the moment. You can achieve it in two ways: either use full-state checkpoints in the first place (FSDPStrategy(state_dict_type="full")), or, if you already have a sharded checkpoint, load it back into the model and then save it again with the full state-dict type:
```python
fabric = Fabric(strategy=FSDPStrategy(state_dict_type="full"))
model = fabric.setup(model)
state = {...}
fabric.load(sharded_path, state)  # loads the sharded checkpoint
fabric.save(full_path, state)     # now saves a consolidated checkpoint
```
There is currently no way to transform the sharded checkpoint directly (this will probably come in the future).
Interesting. It was not clear to me from the doc that you could load a sharded checkpoint if you had set state_dict_type="full" in the Fabric constructor. Will try it.
Bug description
When I save the checkpoint, the training code is executed once per GPU, and an error is reported when I add override_checkpoint. When I then try to load the saved file, another error occurs.
cc @awaelchli @carmocca