File "test_init.py", line 42, in <module>
main()
File "test_init.py", line 38, in main
fabric.load_raw(llm_path, model)
File "lightning/fabric/fabric.py", line 816, in load_raw
self._strategy.load_checkpoint(path=path, state=obj, strict=strict)
File "lightning/fabric/strategies/fsdp.py", line 548, in load_checkpoint
_load_raw_module_state_from_path(path, module=state, world_size=self.world_size, strict=strict)
File "/lightning/fabric/strategies/fsdp.py", line 888, in _load_raw_module_state_from_path
_load_raw_module_state(state_dict=_lazy_load(path), module=module, world_size=world_size, strict=strict)
File "lightning/fabric/strategies/fsdp.py", line 896, in _load_raw_module_state
module.load_state_dict(state_dict, strict=strict)
File "torch/nn/modules/module.py", line 2153, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for GPT:
size mismatch for lm_head.weight: copying a param with shape torch.Size([32000, 4096]) from checkpoint, the shape in current model is torch.Size([32768512]).
size mismatch for transformer.wte.weight: copying a param with shape torch.Size([32000, 4096]) from checkpoint, the shape in current model is torch.Size([0]).
size mismatch for transformer.ln_f.weight: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([0]).
Bug description
I'm following litgpt to learn how to load a large model with FSDP, and I'm getting an error related to shapes. See the example script and error message.
What version are you seeing the problem on?
v2.2
How to reproduce the bug
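The original test_init.py was not captured here. Below is a minimal sketch of the kind of script the traceback points at; the model class, config name, checkpoint path, and device count are assumptions for illustration, not the reporter's actual code.

```python
# Hypothetical reproduction sketch -- NOT the original test_init.py.
# Model class (litgpt GPT), config name, checkpoint path, and device count
# are assumptions for illustration only.
import lightning as L
from lightning.fabric.strategies import FSDPStrategy
from litgpt import GPT, Config  # assumed import path


def main():
    fabric = L.Fabric(accelerator="cuda", devices=4, strategy=FSDPStrategy())
    fabric.launch()

    config = Config.from_name("Llama-2-7b-hf")  # assumed model size
    # empty_init=True creates the module without materializing the weights,
    # so the full model never has to fit on a single device.
    with fabric.init_module(empty_init=True):
        model = GPT(config)
    model = fabric.setup(model)  # wraps the module with FSDP

    # Raw single-file checkpoint created without Fabric; this is the call
    # that raises the size-mismatch RuntimeError shown in the traceback above.
    llm_path = "checkpoints/lit_model.pth"  # placeholder path
    fabric.load_raw(llm_path, model)


if __name__ == "__main__":
    main()
```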
Error messages and logs
Environment
More info
Looking into the implementation of fabric.load_raw, it seems to first unwrap compiled objects and then call FSDPStrategy.load_checkpoint. The docstring of the latter contains the following sentence:

The strategy currently only supports saving and loading sharded checkpoints which are stored in form of a directory of multiple files rather than a single file.

which seems to contradict the docstring of load_raw():

Use this for loading a raw PyTorch model checkpoint created without Fabric. This is conceptually equivalent to ``obj.load_state_dict(torch.load(path))``, but is agnostic to the strategy being used.

I'm very confused about what the right approach is supposed to be.
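For reference, here is a minimal sketch of the single-file equivalence that the load_raw() docstring describes, outside of Fabric/FSDP. The model construction and checkpoint path are placeholders; this is only what the documentation quoted above claims load_raw() is conceptually equivalent to, not a suggested fix.

```python
# Sketch of the documented equivalence, with no Fabric or FSDP involved.
# Model config and checkpoint path are placeholders/assumptions.
import torch
from litgpt import GPT, Config  # assumed import path

model = GPT(Config.from_name("Llama-2-7b-hf"))  # assumed config
state_dict = torch.load("checkpoints/lit_model.pth", map_location="cpu")
model.load_state_dict(state_dict)
```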