Does jit understand monkeypatched methods? #85

Closed
carmocca opened this issue Mar 27, 2024 · 1 comment
Labels: bug (Something isn't working)

Comments

carmocca commented Mar 27, 2024

🐛 Bug

Tensor.register_hook is currently not supported by Thunder.

In Lightning Fabric, we use it once to check that the user properly called backward: https://github.com/Lightning-AI/pytorch-lightning/blob/096b063d6eeb41567409f4a6b9bac6f5af28ed93/src/lightning/fabric/wrappers.py#L232-L233

Since this hook is not critical (it only exists to guard against user errors), I would like to be able to monkeypatch it out externally.
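
For context, the registration in wrappers.py is roughly of this shape (a simplified sketch, not the exact Fabric code; _backward_error_hook is a hypothetical stand-in for the real check):

import torch

def _register_backward_hook(tensor: torch.Tensor) -> None:
    # Sketch: attach a hook to the forward output so Fabric can later
    # detect whether backward() was called the supported way. The hook
    # only observes; it returns the gradient unchanged.
    def _backward_error_hook(grad: torch.Tensor) -> torch.Tensor:
        # ... warn or raise if backward wasn't routed through Fabric ...
        return grad

    if tensor.requires_grad:
        tensor.register_hook(_backward_error_hook)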

However, the monkeypatch doesn't seem to have any effect with Thunder:

To Reproduce

import torch
from lightning import Fabric
from lightning.fabric.wrappers import _FabricModule
import thunder

model = torch.nn.Linear(1, 1, bias=False, device="cuda")
x = torch.randn(1, 1, device="cuda", requires_grad=True)

fabric = Fabric(accelerator="cuda", devices=1)
model = fabric.setup(model)

# monkeypatch what's causing trouble
assert isinstance(model, _FabricModule)
assert model._register_backward_hook is not None
model._register_backward_hook = lambda *_: None

model = thunder.jit(model)

y = model(x)
y.backward()
print(y)
print(x.grad)

This fails because Thunder doesn't support register_hook:

AttributeError: The torch language context has no method register_hook

Interestingly, an equivalent non-Fabric snippet doesn't fail, so something odd is going on:

import thunder
import torch

class Wrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(1, 1, bias=False)

    def forward(self, x):
        y = self.model(x)
        self.register_hook(y)
        return y

    def register_hook(self, tensor):
        tensor.register_hook(self.hook)

    def hook(self, _):
        print("hi")

model = Wrapper()
x = torch.randn(1, 1)

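# replace the instance method with a no-op; the instance attribute
# shadows Wrapper.register_hook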
model.register_hook = lambda *_: None

model = thunder.jit(model)

y = model(x)
y.backward()

carmocca added the bug label Mar 27, 2024
carmocca (Author) commented:

Ignore this. The answer is yes. Looks like this is a Fabric bug.
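
(Why the second snippet works: assigning model.register_hook = lambda *_: None creates an instance attribute that shadows the class method, so the jitted forward calls the no-op and Thunder never encounters the unsupported Tensor.register_hook. Presumably the Fabric repro fails because the patched attribute on the _FabricModule isn't what the traced code path ends up calling.)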
