This repository has been archived by the owner on Nov 21, 2022. It is now read-only.

Deepspeed sharding and load from checkpoint with custom lightning module - setup() not called during checkpoint loading #290

Open
maxzvyagin opened this issue Oct 4, 2022 · 2 comments
Labels
question Further information is requested

Comments

maxzvyagin commented Oct 4, 2022

❓ Questions and Help

What is your question?

Hi, I'm training from scratch using DeepSpeed, PyTorch Lightning, and Transformers in a multi-node setting, and I want to know how to set up the code to handle loading from a PyTorch checkpoint.

Going off of the docs here, I see that the model is intended to be defined in setup(). However, this doesn't work when loading from a state dict, since setup() is not called during checkpoint loading. What's the right way to structure the code here? Does enable_transformers_pretrained_deepspeed_sharding need to be called in setup(), or can it be called in the constructor?

This has been my potential workaround in the constructor, since the sharding call does seem to fail on certain ranks:

def __init__(self, config):
    # irrelevant constructor things here
    try:
        # enable DeepSpeed sharding before the model is instantiated
        enable_transformers_pretrained_deepspeed_sharding(self)
    except AttributeError:
        pl.utilities.rank_zero.rank_zero_warn(
            "Transformers sharding initialization not enabled..."
        )
    # the model must exist in __init__ so a checkpoint can be loaded into it
    self.model = AutoModelForCausalLM.from_config(self.base_config)

As opposed to:

def setup(self, stage):
    if not hasattr(self, "model"):
        try:
            enable_transformers_pretrained_deepspeed_sharding(self)
        except AttributeError:
            # sometimes we use DDP for inference, so this call will fail
            pl.utilities.rank_zero.rank_zero_warn(
                "Transformers sharding initialization not enabled - likely not using DeepSpeed..."
            )
        self.model = AutoModelForCausalLM.from_config(self.base_config)
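
For reference, here is a minimal sketch of the load path that skips setup() (class and checkpoint names are hypothetical; it assumes the constructor-based module above is defined as a class called, say, MyCausalLMModule):

    # Hypothetical names for illustration only.
    # load_from_checkpoint() runs __init__ and then restores the state dict
    # directly; it never calls setup(), so self.model must already exist.
    model = MyCausalLMModule.load_from_checkpoint(
        "path/to/checkpoint.ckpt",
        config=base_config,  # extra kwargs are forwarded to __init__
    )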

What's your environment?

Linux, conda/pip
deepspeed==0.7.3
pytorch-lightning==1.6.5
lightning-transformers==0.2.1


Thanks in advance for the help!

maxzvyagin added the question label on Oct 4, 2022

uakarsh commented Oct 15, 2022

Hi @maxzvyagin, what I understand from your question is that you want to call enable_transformers_pretrained_deepspeed_sharding on a custom Lightning module. For that, I think one simple and straightforward strategy would be to inherit from the TaskTransformer class and modify the initialize_model method, as mentioned here. A rough sketch is below.
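
A rough sketch of that suggestion might look like the following (the initialize_model hook is the one referenced above; its exact signature in lightning-transformers 0.2.1 may differ):

    # Rough sketch only; the signature of initialize_model may differ in 0.2.1.
    from lightning_transformers.core import TaskTransformer
    from transformers import AutoConfig, AutoModelForCausalLM

    class MyCausalLMTask(TaskTransformer):
        def initialize_model(self, pretrained_model_name_or_path: str) -> None:
            # Build the model from a config instead of downloading pretrained
            # weights, so a custom checkpoint can be restored on top of it.
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
            self.model = AutoModelForCausalLM.from_config(config)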

Do correct me if I did not understand your question correctly.

Regards,
Akarsh

maxzvyagin (Author) commented

Hi Akarsh, thanks for checking this out! I guess my question is partly about why enable_transformers_pretrained_deepspeed_sharding(self) is called in the setup() function rather than in the constructor, and whether calling it in the constructor has any detrimental effect when loading from a pre-trained checkpoint. Otherwise, we're not able to load from our own PyTorch checkpoint file and continue training with a sharded DeepSpeed approach.
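
For concreteness, the flow we're aiming for is roughly the following (sketch only; MyCausalLMModule, config, dm, and the checkpoint path are placeholders, and this assumes a Lightning-format checkpoint with PL 1.6's ckpt_path API):

    # Sketch only: MyCausalLMModule, config, dm, and the checkpoint path are placeholders.
    import pytorch_lightning as pl
    from pytorch_lightning.strategies import DeepSpeedStrategy

    model = MyCausalLMModule(config)  # model is built in __init__, so its weights can be restored
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=-1,
        num_nodes=2,
        strategy=DeepSpeedStrategy(stage=3),
    )
    # ckpt_path restores model/optimizer state before training continues;
    # setup() is still called by the Trainer during fit, unlike load_from_checkpoint().
    trainer.fit(model, datamodule=dm, ckpt_path="path/to/checkpoint.ckpt")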
