Fix torch.compile patching when applied as decorator (#19627)
awaelchli committed Mar 15, 2024
1 parent 7555384 commit 14e98ec
Showing 4 changed files with 21 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/lightning/fabric/CHANGELOG.md
```diff
@@ -47,7 +47,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

--
+- Fixed an issue causing a TypeError when using `torch.compile` as a decorator ([#19627](https://github.com/Lightning-AI/pytorch-lightning/pull/19627))

 -
```
10 changes: 7 additions & 3 deletions src/lightning/fabric/wrappers.py
```diff
@@ -401,10 +401,14 @@ def _capture_compile_kwargs(compile_fn: Callable) -> Callable:
     # PyTorch will resolve this in the future: https://github.com/pytorch/pytorch/issues/116575

     @wraps(compile_fn)
-    def _capture(model: Any, **kwargs: Any) -> Any:
+    def _capture(*args: Any, **kwargs: Any) -> Any:
+        if not args or not isinstance(args[0], nn.Module):
+            # either torch.compile is being applied as a decorator or we're compiling something else
+            return compile_fn(*args, **kwargs)
+
+        model = args[0]
         compiled_model = compile_fn(model, **kwargs)
-        if isinstance(model, nn.Module):
-            compiled_model._compile_kwargs = deepcopy(kwargs)
+        compiled_model._compile_kwargs = deepcopy(kwargs)
         return compiled_model

     return _capture
```
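The new `_capture` accepts `*args` so that `torch.compile`'s decorator forms no longer trip over a missing positional argument: anything that is not an `nn.Module` in the first position is passed straight through, and kwargs are only recorded for modules. A minimal, torch-free sketch of this pattern, using hypothetical `FakeModule`/`fake_compile` stand-ins for `torch.nn.Module`/`torch.compile` so it runs without torch installed:

```python
from copy import deepcopy
from functools import wraps


class FakeModule:
    """Stand-in for torch.nn.Module (hypothetical, so the sketch runs without torch)."""


def fake_compile(obj, **kwargs):
    """Stand-in for torch.compile: returns its argument unchanged."""
    return obj


def _capture_compile_kwargs(compile_fn):
    @wraps(compile_fn)
    def _capture(*args, **kwargs):
        if not args or not isinstance(args[0], FakeModule):
            # decorator usage, or compiling something that is not a module:
            # pass straight through without trying to record kwargs
            return compile_fn(*args, **kwargs)
        model = args[0]
        compiled_model = compile_fn(model, **kwargs)
        # remember the kwargs so the module can be re-compiled later
        compiled_model._compile_kwargs = deepcopy(kwargs)
        return compiled_model

    return _capture


patched_compile = _capture_compile_kwargs(fake_compile)

# direct call on a module: kwargs get recorded on the result
module = patched_compile(FakeModule(), mode="reduce-overhead")
print(module._compile_kwargs)  # {'mode': 'reduce-overhead'}

# bare decorator on a plain function: passes through, no TypeError
@patched_compile
def double(x):
    return 2 * x

print(double(3))  # 6
```

The key design point is that the wrapper no longer assumes a module is always passed positionally, which is what made decorator usage safe again.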
2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
```diff
@@ -45,7 +45,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a KeyError when saving a FSDP sharded checkpoint and setting `save_weights_only=True` ([#19524](https://github.com/Lightning-AI/pytorch-lightning/pull/19524))


--
+- Fixed an issue causing a TypeError when using `torch.compile` as a decorator ([#19627](https://github.com/Lightning-AI/pytorch-lightning/pull/19627))

 -
```
12 changes: 12 additions & 0 deletions tests/tests_fabric/test_wrappers.py
```diff
@@ -601,6 +601,9 @@ def normal_method(self):
 def test_unwrap_compiled():
     model = torch.nn.Linear(1, 1)

+    # We wrap `torch.compile` on import of lightning in `wrappers.py`
+    assert torch.compile.__wrapped__
+
     with mock.patch("lightning.fabric.wrappers._TORCH_GREATER_EQUAL_2_0", False):
         unwrapped, compile_kwargs = _unwrap_compiled(model)
         assert unwrapped is model
@@ -615,3 +618,12 @@ def test_unwrap_compiled():
     del compiled._compile_kwargs
     with pytest.raises(RuntimeError, match="Failed to determine the arguments that were used to compile the module"):
         _unwrap_compiled(compiled)
+
+    # can still be applied as decorator
+    @torch.compile()
+    def cos(x):
+        return torch.cos(x)
+
+    @torch.compile
+    def sin(x):
+        return torch.sin(x)
```
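The decorator forms exercised by these new test cases are exactly what failed before this commit: the old wrapper's signature `_capture(model, **kwargs)` made a positional argument mandatory, so the decorator-factory form `torch.compile()` raised a TypeError. A torch-free sketch of the old behavior, with a hypothetical `fake_compile` standing in for `torch.compile`:

```python
from functools import wraps


def fake_compile(model=None, **kwargs):
    """Stand-in for torch.compile: with no positional arg it returns a decorator."""
    if model is None:
        return lambda fn: fn
    return model


def _old_capture(compile_fn):
    @wraps(compile_fn)
    def _capture(model, **kwargs):  # positional arg is mandatory -> the bug
        return compile_fn(model, **kwargs)

    return _capture


patched_compile = _old_capture(fake_compile)

try:
    @patched_compile()  # decorator-factory form passes no positional argument
    def cos(x):
        return x
except TypeError as exc:
    print(f"TypeError: {exc}")  # missing required positional argument 'model'
```

With the fixed `*args` signature from this commit, the same call would fall through to `compile_fn(**kwargs)` and let the real `torch.compile` build the decorator.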
