Skip to content

Commit

Permalink
Fix monkeypatching of _FabricModule methods (#19705)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Mar 27, 2024
1 parent 0fb267b commit ca6c94c
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed an issue causing a TypeError when using `torch.compile` as a decorator ([#19627](https://github.com/Lightning-AI/pytorch-lightning/pull/19627))

-
- Fixed issue where some model methods couldn't be monkeypatched after being Fabric wrapped ([#19705](https://github.com/Lightning-AI/pytorch-lightning/pull/19705))

-

Expand Down
2 changes: 1 addition & 1 deletion src/lightning/fabric/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def __setattr__(self, name: str, value: Any) -> None:
original_has_attr = hasattr(original_module, name)
# Can't use super().__getattr__ because nn.Module only checks _parameters, _buffers, and _modules
# Can't use self.__getattr__ because it would pass through to the original module
fabric_has_attr = name in self.__dict__
fabric_has_attr = name in dir(self)

if not (original_has_attr or fabric_has_attr):
setattr(original_module, name, value)
Expand Down
20 changes: 20 additions & 0 deletions tests/tests_fabric/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ def __init__(self):

# Modify existing attribute on original_module
fabric_module.attribute = 101
# "attribute" is only in the original_module, so it shouldn't get set in the fabric_module
assert "attribute" not in fabric_module.__dict__
assert fabric_module.attribute == 101 # returns it from original_module
assert original_module.attribute == 101

# Check setattr of original_module
Expand All @@ -170,6 +173,23 @@ def __init__(self):
assert linear in fabric_module.modules()
assert linear in original_module.modules()

# Check monkeypatching of methods
fabric_module = _FabricModule(Mock(), Mock())
original = id(fabric_module.forward)
fabric_module.forward = lambda *_: None
assert id(fabric_module.forward) != original
# Check special methods
assert "__repr__" in dir(fabric_module)
assert "__repr__" not in fabric_module.__dict__
assert "__repr__" not in _FabricModule.__dict__
fabric_module.__repr__ = lambda *_: "test"
assert fabric_module.__repr__() == "test"
# needs to be monkeypatched on the class for `repr()` to change
assert repr(fabric_module) == "_FabricModule()"
with mock.patch.object(_FabricModule, "__repr__", return_value="test"):
assert fabric_module.__repr__() == "test"
assert repr(fabric_module) == "test"


def test_fabric_module_state_dict_access():
"""Test that state_dict access passes through to the original module."""
Expand Down

0 comments on commit ca6c94c

Please sign in to comment.