
Loading large models with fabric, FSDP and empty_init=True does not work #19833

Open
RuABraun opened this issue May 1, 2024 · 1 comment
Labels: bug (Something isn't working), needs triage (Waiting to be triaged by maintainers)

RuABraun commented May 1, 2024

Bug description

I'm following litgpt's approach for loading a large model with FSDP, but I'm getting a shape-mismatch error. See the example script and error message below.

What version are you seeing the problem on?

v2.2

How to reproduce the bug

#!/usr/bin/env python
import os.path as osp

import torch
from torch.distributed.fsdp import MixedPrecision

import lightning as L
from lightning.fabric.strategies import FSDPStrategy
from litgpt.lora import Block as LoraBlock, Config
from litgpt.model import GPT, Block


def main():
    strategy = FSDPStrategy(
        auto_wrap_policy={Block, LoraBlock},
        sharding_strategy="HYBRID_SHARD",
        mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
        use_orig_params=True,
    )

    fabric = L.Fabric(
        accelerator="gpu",
        num_nodes=1,
        devices=8,
        strategy=strategy,
        precision="bf16-mixed",
    )
    fabric.launch()

    llm_config_path = '.../checkpoints/mistralai/Mistral-7B-v0.1/model_config.yaml'
    config = Config.from_file(llm_config_path)

    # Instantiate the model without materializing weights, as litgpt does
    # for large models.
    with fabric.init_module(empty_init=True):
        model = GPT(config)
    print('setting up')
    fabric.setup_module(model)

    # Load the raw (non-Fabric) single-file checkpoint -- this is where the
    # size-mismatch error below is raised.
    llm_path = osp.join(osp.dirname(llm_config_path), "lit_model.pth")
    fabric.load_raw(llm_path, model)
    print('done')


if __name__ == "__main__":
    main()

Error messages and logs

  File "test_init.py", line 42, in <module>
    main()
  File "test_init.py", line 38, in main
    fabric.load_raw(llm_path, model)
  File "lightning/fabric/fabric.py", line 816, in load_raw
    self._strategy.load_checkpoint(path=path, state=obj, strict=strict)
  File "lightning/fabric/strategies/fsdp.py", line 548, in load_checkpoint
    _load_raw_module_state_from_path(path, module=state, world_size=self.world_size, strict=strict)
  File "/lightning/fabric/strategies/fsdp.py", line 888, in _load_raw_module_state_from_path
    _load_raw_module_state(state_dict=_lazy_load(path), module=module, world_size=world_size, strict=strict)
  File "lightning/fabric/strategies/fsdp.py", line 896, in _load_raw_module_state
    module.load_state_dict(state_dict, strict=strict)
  File "torch/nn/modules/module.py", line 2153, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for GPT:
        size mismatch for lm_head.weight: copying a param with shape torch.Size([32000, 4096]) from checkpoint, the shape in current model is torch.Size([32768512]).
        size mismatch for transformer.wte.weight: copying a param with shape torch.Size([32000, 4096]) from checkpoint, the shape in current model is torch.Size([0]).
        size mismatch for transformer.ln_f.weight: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([0]).

Environment

	- lightning:         2.2.3
	- lightning-cloud:   0.5.64
	- lightning-utilities: 0.9.0
	- lion-pytorch:      0.1.4
	- pytorch-lightning: 2.2.1
	- pytorch-wpe:       0.0.1
	- torch:             2.2.2+cu121

More info

Looking into the implementation of fabric.load_raw, it first unwraps compiled objects and then calls FSDPStrategy.load_checkpoint. The docstring of the latter contains the following sentence:

The strategy currently only supports saving and loading sharded checkpoints which are stored in form of a directory of multiple files rather than a single file.

which seems to contradict the docstring of load_raw():

Use this for loading a raw PyTorch model checkpoint created without Fabric.
        This is conceptually equivalent to ``obj.load_state_dict(torch.load(path))``, but is agnostic to the strategy
        being used.
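
To make the mismatch concrete, here is my reading of the two paths, sketched out (state_dict_type is the FSDPStrategy parameter controlling the checkpoint format; fabric, model and llm_path refer to the script above):

# Sharded path, as described by the FSDPStrategy docstring: with
# state_dict_type="sharded", fabric.save() writes a directory of files
# and fabric.load() reads it back.
fabric.save("ckpt_dir", {"model": model})
fabric.load("ckpt_dir", {"model": model})

# Raw path, as described by the load_raw() docstring: a single .pth file
# created without Fabric, loaded directly into the module.
fabric.load_raw(llm_path, model)  # the call that fails above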

The shapes in the error also look telling: torch.Size([32768512]) is a 1-D (flattened) tensor and torch.Size([0]) an empty placeholder, which suggests the parameters have already been flattened and sharded by FSDP by the time load_state_dict runs against the raw, unflattened state dict. I'm very confused about what the right approach is supposed to be.
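
In case it helps triage: a minimal sketch of a workaround, assuming the full checkpoint fits in CPU memory on every rank (which defeats the purpose of empty_init=True for truly large models), would be to load the raw state dict into the unwrapped module before setup_module:

# Sketch only: materialize the model on CPU instead of using empty_init=True,
# load the raw checkpoint into the unwrapped module, then let Fabric/FSDP
# shard the already-populated weights.
config = Config.from_file(llm_config_path)
model = GPT(config)  # fully materialized on CPU on every rank

state_dict = torch.load(llm_path, map_location="cpu")
model.load_state_dict(state_dict)

model = fabric.setup_module(model)  # FSDP shards the already-loaded weights

That sidesteps the sharded-load path entirely, but it is not a real answer for models that do not fit in CPU memory per rank.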

@RuABraun added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels on May 1, 2024
@RuABraun changed the title from "Loading large models with FSDP and empty_init=True does not work" to "Loading large models with fabric, FSDP and empty_init=True does not work" on May 1, 2024
@Nilabhra

Hi! I am facing a similar issue which might be related. Is there anything I can do at the moment to make this work?
