Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid false-positive warnings about method calls on the Fabric-wrapped module #18819

Merged
merged 13 commits into from Oct 23, 2023

Conversation

awaelchli
Copy link
Member

@awaelchli awaelchli commented Oct 18, 2023

What does this PR do?

This PR implements a better way to warn users in case they call a method on the wrapper.

A warning is currently printed on any call to a method that is not forward(), given that the original module is wrapped by a distributed wrapper like DDP, FSDP or similar. See an example in Lightning-AI/litgpt#641
This PR restricts this warning to only the case when an actual module forward call occurs. A call to a method that does not run inputs through the submodules is not considered problematic anymore.

I ran a quick benchmark to measure the impact of adding the model hooks:

import lightning as L
import torch
from lightning.pytorch.demos import Transformer


class NewTransformer(Transformer):
    def new_forward(self, input, target):
        return self(input, target)


def run():
    fabric = L.Fabric(accelerator="cuda", devices=1, strategy="ddp")
    fabric.seed_everything(1)
    fabric.launch()
    model = NewTransformer(ninp=2, nhead=1, nhid=2, nlayers=1, vocab_size=4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model, optimizer = fabric.setup(model, optimizer)
    model.eval()
    input = torch.randint(0, 4, (1, 35), device=fabric.device)
    target = torch.randint(0, 4, (1, 35), device=fabric.device)
    with torch.inference_mode():
        for _ in range(100):
            model.new_forward(input, target)


if __name__ == "__main__":
    run()

Times measured with:

python -m timeit -r 10 -n 10  -vv  'from speed_test import run; run()'

GPU benchmark
master:
10 loops, best of 10: 184.9 msec per loop
branch:
10 loops, best of 10: 208.9 msec per loop

CPU benchmark
master:
10 loops, best of 10: 151 msec per loop
branch:
10 loops, best of 10: 173.7 msec per loop

The benchmarks show a ~20 msec impact when running with the hooks in this extreme example of a super small model. On regular-sized models, the impact is not measurable.


📚 Documentation preview 📚: https://pytorch-lightning--18819.org.readthedocs.build/en/18819/

cc @Borda @carmocca @justusschock @awaelchli

@github-actions github-actions bot added the fabric lightning.fabric.Fabric label Oct 18, 2023
@awaelchli awaelchli changed the title Better non-forward() warnings for Fabric Avoid false-positive warnings about method calls on the Fabric-wrapped module Oct 19, 2023
@awaelchli awaelchli added the feature Is an improvement or enhancement label Oct 21, 2023
@awaelchli awaelchli added this to the 2.1.x milestone Oct 21, 2023
@awaelchli awaelchli added the fun Staff contributions outside working hours - to differentiate from the "community" label label Oct 21, 2023
@awaelchli awaelchli marked this pull request as ready for review October 21, 2023 23:47
@github-actions
Copy link
Contributor

github-actions bot commented Oct 21, 2023

⚡ Required checks status: All passing 🟢

Groups summary

🟢 pytorch_lightning: Tests workflow
Check ID Status
pl-cpu (macOS-11, lightning, 3.8, 1.12, oldest) success
pl-cpu (macOS-11, lightning, 3.9, 1.12) success
pl-cpu (macOS-11, lightning, 3.10, 1.13) success
pl-cpu (macOS-11, lightning, 3.10, 2.0) success
pl-cpu (macOS-11, lightning, 3.10, 2.1) success
pl-cpu (ubuntu-20.04, lightning, 3.8, 1.12, oldest) success
pl-cpu (ubuntu-20.04, lightning, 3.9, 1.12) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 1.13) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 2.0) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 2.1) success
pl-cpu (windows-2022, lightning, 3.8, 1.12, oldest) success
pl-cpu (windows-2022, lightning, 3.9, 1.12) success
pl-cpu (windows-2022, lightning, 3.10, 1.13) success
pl-cpu (windows-2022, lightning, 3.10, 2.0) success
pl-cpu (windows-2022, lightning, 3.10, 2.1) success
pl-cpu (macOS-11, pytorch, 3.8, 1.13) success
pl-cpu (ubuntu-20.04, pytorch, 3.8, 1.13) success
pl-cpu (windows-2022, pytorch, 3.8, 1.13) success
pl-cpu (macOS-12, pytorch, 3.11, 2.0) success
pl-cpu (macOS-12, pytorch, 3.11, 2.1) success
pl-cpu (ubuntu-22.04, pytorch, 3.11, 2.0) success
pl-cpu (ubuntu-22.04, pytorch, 3.11, 2.1) success
pl-cpu (windows-2022, pytorch, 3.11, 2.0) success
pl-cpu (windows-2022, pytorch, 3.11, 2.1) success

These checks are required after the changes to src/lightning/fabric/wrappers.py, tests/tests_pytorch/checkpointing/test_model_checkpoint.py.

🟢 pytorch_lightning: Azure GPU
Check ID Status
pytorch-lightning (GPUs) (testing Lightning | latest) success
pytorch-lightning (GPUs) (testing PyTorch | latest) success

These checks are required after the changes to tests/tests_pytorch/checkpointing/test_model_checkpoint.py, src/lightning/fabric/wrappers.py.

🟢 pytorch_lightning: Benchmarks
Check ID Status
lightning.Benchmarks success

These checks are required after the changes to src/lightning/fabric/wrappers.py.

🟢 fabric: Docs
Check ID Status
docs-make (fabric, doctest) success
docs-make (fabric, html) success

These checks are required after the changes to src/lightning/fabric/wrappers.py.

🟢 lightning_fabric: CPU workflow
Check ID Status
fabric-cpu (macOS-11, lightning, 3.8, 1.12, oldest) success
fabric-cpu (macOS-11, lightning, 3.9, 1.12) success
fabric-cpu (macOS-11, lightning, 3.10, 1.13) success
fabric-cpu (macOS-11, lightning, 3.10, 2.0) success
fabric-cpu (macOS-11, lightning, 3.11, 2.1) success
fabric-cpu (ubuntu-20.04, lightning, 3.8, 1.12, oldest) success
fabric-cpu (ubuntu-20.04, lightning, 3.9, 1.12) success
fabric-cpu (ubuntu-20.04, lightning, 3.10, 1.13) success
fabric-cpu (ubuntu-20.04, lightning, 3.10, 2.0) success
fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.1) success
fabric-cpu (windows-2022, lightning, 3.8, 1.12, oldest) success
fabric-cpu (windows-2022, lightning, 3.9, 1.12) success
fabric-cpu (windows-2022, lightning, 3.10, 1.13) success
fabric-cpu (windows-2022, lightning, 3.10, 2.0) success
fabric-cpu (windows-2022, lightning, 3.11, 2.1) success
fabric-cpu (macOS-11, fabric, 3.8, 1.13) success
fabric-cpu (ubuntu-20.04, fabric, 3.8, 1.13) success
fabric-cpu (windows-2022, fabric, 3.8, 1.13) success
fabric-cpu (macOS-12, fabric, 3.11, 2.0) success
fabric-cpu (macOS-12, fabric, 3.11, 2.1) success
fabric-cpu (ubuntu-22.04, fabric, 3.11, 2.0) success
fabric-cpu (ubuntu-22.04, fabric, 3.11, 2.1) success
fabric-cpu (windows-2022, fabric, 3.11, 2.0) success
fabric-cpu (windows-2022, fabric, 3.11, 2.1) success

These checks are required after the changes to src/lightning/fabric/wrappers.py, tests/tests_fabric/test_wrappers.py.

🟢 lightning_fabric: Azure GPU
Check ID Status
lightning-fabric (GPUs) (testing Fabric | latest) success
lightning-fabric (GPUs) (testing Lightning | latest) success

These checks are required after the changes to src/lightning/fabric/wrappers.py, tests/tests_fabric/test_wrappers.py.

🟢 mypy
Check ID Status
mypy success

These checks are required after the changes to src/lightning/fabric/wrappers.py.

🟢 install
Check ID Status
install-pkg (ubuntu-22.04, app, 3.8) success
install-pkg (ubuntu-22.04, app, 3.11) success
install-pkg (ubuntu-22.04, fabric, 3.8) success
install-pkg (ubuntu-22.04, fabric, 3.11) success
install-pkg (ubuntu-22.04, pytorch, 3.8) success
install-pkg (ubuntu-22.04, pytorch, 3.11) success
install-pkg (ubuntu-22.04, lightning, 3.8) success
install-pkg (ubuntu-22.04, lightning, 3.11) success
install-pkg (ubuntu-22.04, notset, 3.8) success
install-pkg (ubuntu-22.04, notset, 3.11) success
install-pkg (macOS-12, app, 3.8) success
install-pkg (macOS-12, app, 3.11) success
install-pkg (macOS-12, fabric, 3.8) success
install-pkg (macOS-12, fabric, 3.11) success
install-pkg (macOS-12, pytorch, 3.8) success
install-pkg (macOS-12, pytorch, 3.11) success
install-pkg (macOS-12, lightning, 3.8) success
install-pkg (macOS-12, lightning, 3.11) success
install-pkg (macOS-12, notset, 3.8) success
install-pkg (macOS-12, notset, 3.11) success
install-pkg (windows-2022, app, 3.8) success
install-pkg (windows-2022, app, 3.11) success
install-pkg (windows-2022, fabric, 3.8) success
install-pkg (windows-2022, fabric, 3.11) success
install-pkg (windows-2022, pytorch, 3.8) success
install-pkg (windows-2022, pytorch, 3.11) success
install-pkg (windows-2022, lightning, 3.8) success
install-pkg (windows-2022, lightning, 3.11) success
install-pkg (windows-2022, notset, 3.8) success
install-pkg (windows-2022, notset, 3.11) success

These checks are required after the changes to src/lightning/fabric/wrappers.py.


Thank you for your contribution! 💜

Note
This comment is automatically generated and updates for 60 minutes every 180 seconds. If you have any other questions, contact carmocca for help.

@awaelchli awaelchli requested a review from Borda as a code owner October 22, 2023 00:12
@github-actions github-actions bot added the pl Generic label for PyTorch Lightning package label Oct 22, 2023
src/lightning/fabric/wrappers.py Outdated Show resolved Hide resolved
src/lightning/fabric/wrappers.py Outdated Show resolved Hide resolved
@mergify mergify bot added the ready PRs ready to be merged label Oct 22, 2023
@awaelchli awaelchli merged commit 97303b0 into master Oct 23, 2023
119 checks passed
@awaelchli awaelchli deleted the fabric/forward-warnings branch October 23, 2023 02:26
tsenst pushed a commit to tsenst/lightning that referenced this pull request Oct 26, 2023
Borda pushed a commit that referenced this pull request Nov 2, 2023
lantiga pushed a commit that referenced this pull request Nov 6, 2023
@ardywibowo
Copy link

This seems to be breaking model.generate(...) for HuggingFace models. Any ideas?

RuntimeError: You are calling the method `BloomForCausalLM.generate()` from outside the model. This will bypass the wrapper from the strategy and result in incorrect behavior in `.backward()`. You should pass your inputs through `forward()`.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
fabric lightning.fabric.Fabric feature Is an improvement or enhancement fun Staff contributions outside working hours - to differentiate from the "community" label pl Generic label for PyTorch Lightning package ready PRs ready to be merged
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants