Commit

Alternative mechanism to detect missing Fabric.backward() call (#19493)

awaelchli committed Feb 27, 2024
1 parent ea89133 commit 7880c11
Showing 7 changed files with 92 additions and 56 deletions.
2 changes: 1 addition & 1 deletion src/lightning/fabric/CHANGELOG.md
@@ -23,7 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- The `Fabric.rank_zero_first` context manager now uses a barrier without timeout to avoid long-running tasks to be interrupted ([#19448](https://github.com/Lightning-AI/lightning/pull/19448))


- Fabric now raises an error if you forget to call `fabric.backward()` when it is needed by the strategy or precision selection ([#19447](https://github.com/Lightning-AI/lightning/pull/19447))
- Fabric now raises an error if you forget to call `fabric.backward()` when it is needed by the strategy or precision selection ([#19447](https://github.com/Lightning-AI/lightning/pull/19447), [#19493](https://github.com/Lightning-AI/lightning/pull/19493))


-
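For context, a minimal usage sketch of the behavior this changelog entry describes (not part of the commit; the accelerator/precision choice is illustrative, and the check only fires when the selected strategy or precision plugin customizes the backward call):

```python
import torch
from lightning.fabric import Fabric

# Assumption: "bf16-mixed" selects a precision plugin that customizes backward,
# which is what makes Fabric insist on routing backward through fabric.backward().
fabric = Fabric(accelerator="cpu", precision="bf16-mixed")
fabric.launch()

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

loss = model(torch.randn(4, 2)).sum()
# loss.backward()  # would raise: "... requires you to call `fabric.backward(loss)` instead of `loss.backward()`."
fabric.backward(loss)  # correct: the strategy/precision plugin participates in the backward pass
optimizer.step()
```

With the default strategy and full precision, neither backward hook is customized, so a plain `loss.backward()` is expected to keep working.
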
40 changes: 9 additions & 31 deletions src/lightning/fabric/fabric.py
@@ -40,6 +40,7 @@
from torch.optim import Optimizer
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, SequentialSampler

import lightning.fabric
from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.connector import _PLUGIN_INPUT, _PRECISION_INPUT, _Connector, _is_using_cli
from lightning.fabric.loggers import Logger
@@ -142,7 +143,6 @@ def __init__(
self._loggers = loggers if isinstance(loggers, list) else [loggers]
self._models_setup: int = 0
self._launched: bool = False
self._backward_called: bool = False

self._prepare_run_method()
if _is_using_cli():
@@ -253,19 +253,15 @@ def setup(

if compile_kwargs is not None:
module = _to_compiled(module, compile_kwargs)
module = _FabricModule(module, self._precision, original_module=original_module)
self._require_fabric_backward(module)
module = _FabricModule(module, self._strategy, original_module=original_module)

# Update the _DeviceDtypeModuleMixin's device parameter
# NOTE: for sharded strategies or manual device placement, there's no single root device
_update_properties(
module, device=self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device
)

optimizers = [
_FabricOptimizer(optimizer=optimizer, strategy=self._strategy, callbacks=self._callbacks)
for optimizer in optimizers
]
optimizers = [_FabricOptimizer(optimizer, self._strategy, self._callbacks) for optimizer in optimizers]

self._models_setup += 1

@@ -318,8 +314,7 @@ def setup_module(

if compile_kwargs is not None:
module = _to_compiled(module, compile_kwargs)
module = _FabricModule(module, self._precision, original_module=original_module)
self._require_fabric_backward(module)
module = _FabricModule(module, self._strategy, original_module=original_module)

# Update the _DeviceDtypeModuleMixin's device parameter
# NOTE: for sharded strategies or manual device placement, there's no single root device
@@ -448,9 +443,11 @@ def backward(self, tensor: Tensor, *args: Any, model: Optional[_FabricModule] =
# requires to attach the current `DeepSpeedEngine` for the `_FabricOptimizer.step` call.
self._strategy._deepspeed_engine = module

self._backward_called = True
self._strategy.backward(tensor, module, *args, **kwargs)
self._backward_called = False
lightning.fabric.wrappers._in_fabric_backward = True
try:
self._strategy.backward(tensor, module, *args, **kwargs)
finally:
lightning.fabric.wrappers._in_fabric_backward = False

def clip_gradients(
self,
@@ -1092,25 +1089,6 @@ def _validate_setup_dataloaders(self, dataloaders: Sequence[DataLoader]) -> None
if any(not isinstance(dl, DataLoader) for dl in dataloaders):
raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.")

def _require_fabric_backward(self, module: _FabricModule) -> None:
strategy_requires = is_overridden("backward", self._strategy, parent=Strategy)
precision_requires = any(
is_overridden(method, self._precision, parent=Precision)
for method in ("pre_backward", "backward", "post_backward")
)

def _backward_hook(*_: Any, **__: Any) -> None:
if (strategy_requires or precision_requires) and not self._backward_called:
raise RuntimeError(
"The current strategy and precision selection requires you to call `fabric.backward(loss)`"
" instead of `loss.backward()`."
)

if _TORCH_GREATER_EQUAL_2_0:
module.register_full_backward_pre_hook(_backward_hook, prepend=True)
else:
module.register_full_backward_hook(_backward_hook)

@staticmethod
def _configure_callbacks(callbacks: Optional[Union[List[Any], Any]]) -> List[Any]:
callbacks = callbacks if callbacks is not None else []
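The module-level `_in_fabric_backward` flag plus try/finally above replaces the old per-instance `_backward_called` attribute. A small sketch (the `failing_backward` helper is hypothetical, not from the commit) of why the `finally` matters: the flag is cleared even when the strategy's backward raises, so a later direct `loss.backward()` is still caught:

```python
import pytest

import lightning.fabric.wrappers as wrappers


def failing_backward(loss):
    # Hypothetical stand-in for a strategy backward that raises mid-call.
    raise RuntimeError("simulated failure inside strategy.backward")


def guarded_backward(loss):
    # Same guard shape as Fabric.backward in the diff above.
    wrappers._in_fabric_backward = True
    try:
        failing_backward(loss)
    finally:
        wrappers._in_fabric_backward = False


with pytest.raises(RuntimeError, match="simulated failure"):
    guarded_backward(loss=None)
assert wrappers._in_fabric_backward is False  # reset despite the error
```
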
41 changes: 34 additions & 7 deletions src/lightning/fabric/wrappers.py
@@ -13,7 +13,7 @@
# limitations under the License.
import inspect
from copy import deepcopy
from functools import wraps
from functools import partial, wraps
from typing import (
TYPE_CHECKING,
Any,
@@ -31,6 +31,7 @@
)

import torch
from lightning_utilities import is_overridden
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch import nn as nn
@@ -53,6 +54,8 @@
T_destination = TypeVar("T_destination", bound=Dict[str, Any])
_LIGHTNING_MODULE_STEP_METHODS = ("training_step", "validation_step", "test_step", "predict_step")

_in_fabric_backward: bool = False


class _FabricOptimizer:
def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional[List[Callable]] = None) -> None:
@@ -105,7 +108,7 @@ def __getattr__(self, item: Any) -> Any:

class _FabricModule(_DeviceDtypeModuleMixin):
def __init__(
self, forward_module: nn.Module, precision: Precision, original_module: Optional[nn.Module] = None
self, forward_module: nn.Module, strategy: Strategy, original_module: Optional[nn.Module] = None
) -> None:
"""The FabricModule is a thin wrapper around the :class:`torch.nn.Module` and handles precision / autocast
automatically for the forward pass.
@@ -114,7 +117,7 @@ def __init__(
Args:
forward_module: The module to wrap the ``forward`` method on.
precision: Reference to the precision plugin for handling precision context
strategy: Reference to the strategy for handling precision etc.
original_module: The original, unmodified module as passed into the
:meth:`lightning.fabric.fabric.Fabric.setup` method. This is needed when attribute lookup
on this wrapper should pass through to the original module.
@@ -123,7 +126,7 @@ def __init__(
super().__init__()
self._forward_module = forward_module
self._original_module = original_module or forward_module
self._precision = precision
self._strategy = strategy
self._fabric_module_initialized = True

@property
@@ -133,12 +136,15 @@ def module(self) -> nn.Module:
@override
def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Casts all inputs to the right precision and handles autocast for operations in the module forward method."""
args, kwargs = self._precision.convert_input((args, kwargs))
precision = self._strategy.precision
args, kwargs = precision.convert_input((args, kwargs))

with self._precision.forward_context():
with precision.forward_context():
output = self._forward_module(*args, **kwargs)

output = self._precision.convert_output(output)
output = precision.convert_output(output)

apply_to_collection(output, dtype=Tensor, function=self._register_backward_hook)
return output

@overload
@@ -214,6 +220,19 @@ def _wrapped_method(*args: Any, **kwargs: Any) -> Any:

return _wrapped_method

def _register_backward_hook(self, tensor: Tensor) -> Tensor:
if not tensor.requires_grad:
return tensor

strategy_requires = is_overridden("backward", self._strategy, parent=Strategy)
precision_requires = any(
is_overridden(method, self._strategy.precision, parent=Precision)
for method in ("pre_backward", "backward", "post_backward")
)
hook = partial(_backward_hook, (strategy_requires or precision_requires))
tensor.register_hook(hook)
return tensor

@override
def __getattr__(self, item: Any) -> Any:
if item in _LIGHTNING_MODULE_STEP_METHODS and self._forward_module != self._original_module:
@@ -347,6 +366,14 @@ def _to_compiled(module: nn.Module, compile_kwargs: Dict[str, Any]) -> "Optimize
return torch.compile(module, **compile_kwargs) # type: ignore[return-value]


def _backward_hook(requires_backward: bool, *_: Any) -> None:
if requires_backward and not _in_fabric_backward:
raise RuntimeError(
"The current strategy and precision selection requires you to call `fabric.backward(loss)`"
" instead of `loss.backward()`."
)


def is_wrapped(obj: object) -> bool:
"""Checks if an object was set up by Fabric.
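A condensed, self-contained sketch of the detection mechanism added above (simplified names, plain PyTorch; the real implementation walks the forward output with `apply_to_collection` and shares the flag with `Fabric.backward` via `lightning.fabric.wrappers._in_fabric_backward`):

```python
from functools import partial

import torch

_in_fabric_backward = False  # toggled by the fabric_backward() sketch below


def _backward_hook(requires_fabric_backward: bool, grad: torch.Tensor) -> None:
    # Runs during autograd; complains if backward was started outside fabric_backward().
    if requires_fabric_backward and not _in_fabric_backward:
        raise RuntimeError("Call `fabric.backward(loss)` instead of `loss.backward()`.")


def register_on_outputs(output, requires_fabric_backward: bool = True):
    # Hook every grad-requiring tensor in the (possibly nested) forward output.
    if isinstance(output, torch.Tensor):
        if output.requires_grad:
            output.register_hook(partial(_backward_hook, requires_fabric_backward))
        return output
    if isinstance(output, dict):
        return {k: register_on_outputs(v, requires_fabric_backward) for k, v in output.items()}
    if isinstance(output, (list, tuple)):
        return type(output)(register_on_outputs(v, requires_fabric_backward) for v in output)
    return output


def fabric_backward(loss: torch.Tensor) -> None:
    global _in_fabric_backward
    _in_fabric_backward = True
    try:
        loss.backward()
    finally:
        _in_fabric_backward = False


model = torch.nn.Linear(2, 2)
out = register_on_outputs(model(torch.randn(4, 2)))
fabric_backward(out.sum())   # hook fires with the flag set -> no error
# out.sum().backward()       # hook would fire with the flag unset -> RuntimeError
```

This is also what the tests further down assert: outputs of a wrapped model carry `_backward_hooks`, while tensors that do not require grad (the `"other"` entry in `DictReturnModel`) are left untouched.
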
10 changes: 10 additions & 0 deletions tests/tests_fabric/conftest.py
@@ -114,6 +114,16 @@ def thread_police_duuu_daaa_duuu_daaa():
raise AssertionError(f"Test left zombie thread: {thread}")


@pytest.fixture(autouse=True)
def reset_in_fabric_backward():
"""Ensures that the wrappers.in_fabric_backward global variable gets reset after each test."""
import lightning.fabric.wrappers as wrappers

assert hasattr(wrappers, "_in_fabric_backward")
yield
wrappers._in_fabric_backward = False


@pytest.fixture()
def reset_deterministic_algorithm():
"""Ensures that torch determinism settings are reset before the next test runs."""
2 changes: 1 addition & 1 deletion tests/tests_fabric/loggers/test_tensorboard.py
@@ -164,7 +164,7 @@ def test_tensorboard_log_graph(tmp_path, example_input_array):
logger._experiment.reset_mock()

# model wrapped in `FabricModule`
wrapped = _FabricModule(model, precision=Mock())
wrapped = _FabricModule(model, strategy=Mock())
logger.log_graph(wrapped, example_input_array)
if example_input_array is not None:
logger.experiment.add_graph.assert_called_with(model, example_input_array)
38 changes: 29 additions & 9 deletions tests/tests_fabric/test_fabric.py
@@ -34,7 +34,6 @@
)
from lightning.fabric.strategies.strategy import _Sharded
from lightning.fabric.utilities.exceptions import MisconfigurationException
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.seed import pl_worker_init_function, seed_everything
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer
@@ -646,25 +645,27 @@ def test_backward_required(_, strategy, precision, error_expected, setup_method)

# One model
model1 = nn.Linear(2, 2)
assert not (model1._backward_pre_hooks if _TORCH_GREATER_EQUAL_2_0 else model1._backward_hooks)
model1 = getattr(fabric, setup_method)(model1)
assert model1._backward_pre_hooks if _TORCH_GREATER_EQUAL_2_0 else model1._backward_hooks
loss = model1(batch).sum()
output = model1(batch)
assert output._backward_hooks is not None
loss = output.sum()
with error_context:
loss.backward()
loss = model1(batch).sum()
assert not lightning.fabric.wrappers._in_fabric_backward
fabric.backward(loss) # no error
assert not fabric._backward_called
assert not lightning.fabric.wrappers._in_fabric_backward

# Two models chained
model2 = torch.nn.Linear(2, 2)
model2 = getattr(fabric, setup_method)(model2)
loss = model2(model1(batch)).sum()
output = model2(model1(batch))
assert output._backward_hooks is not None
loss = output.sum()
with error_context:
loss.backward()
loss = model2(model1(batch)).sum()
fabric.backward(loss) # no error
assert not fabric._backward_called

# Two independent models
loss1 = model1(batch).sum()
@@ -676,9 +677,28 @@ def test_backward_required(_, strategy, precision, error_expected, setup_method)
loss1 = model1(batch).sum()
loss2 = model2(batch).sum()
fabric.backward(loss1) # no error
assert not fabric._backward_called
fabric.backward(loss2) # no error
assert not fabric._backward_called

# Model that returns a datastructure of tensors
class DictReturnModel(nn.Linear):
def forward(self, x):
return {
"loss": super().forward(x).sum(),
"other": torch.rand(2, 2), # does not require grad
}

model3 = DictReturnModel(2, 2)
model3 = getattr(fabric, setup_method)(model3)
output = model3(batch)
loss = output["loss"]
other = output["other"]
assert loss._backward_hooks is not None
assert other._backward_hooks is None

with error_context:
(loss * 2).backward()
loss = model3(batch)["loss"]
fabric.backward(loss * 2) # no error


@RunIf(deepspeed=True, mps=False)
15 changes: 8 additions & 7 deletions tests/tests_fabric/test_wrappers.py
@@ -103,13 +103,13 @@ def __init__(self, module):

# Regular case: forward_module == original_module -> no warnings
original_module = OriginalModule()
fabric_module = _FabricModule(forward_module=original_module, precision=Mock(), original_module=original_module)
fabric_module = _FabricModule(forward_module=original_module, strategy=Mock(), original_module=original_module)
assert fabric_module.method_without_module_invocation() == 100

# Special case: original module wrapped by forward module: -> warn if method accepts args
original_module = OriginalModule()
wrapped_module = ModuleWrapper(original_module)
fabric_module = _FabricModule(forward_module=wrapped_module, precision=Mock(), original_module=original_module)
fabric_module = _FabricModule(forward_module=wrapped_module, strategy=Mock(), original_module=original_module)
assert fabric_module.method_without_module_invocation() == 100
with pytest.raises(
RuntimeError, match=r"You are calling the method `OriginalModule.method_with_submodule_invocation\(\)` from"
@@ -254,7 +254,7 @@ def check_autocast(forward_input):
return forward_input

module = Mock(wraps=torch.nn.Identity(), side_effect=check_autocast)
fabric_module = _FabricModule(module, fabric._precision).to(device)
fabric_module = _FabricModule(module, fabric._strategy).to(device)
out = fabric_module(torch.tensor([1, 2, 3], dtype=input_type, device=device))
assert module.call_args[0][0].dtype == expected_type
assert out.dtype == input_type or out.dtype == torch.get_default_dtype()
@@ -560,10 +560,11 @@ def validation_step(self, arg, kwarg=None):
def normal_method(self):
pass

precision = Mock(wraps=Precision())
strategy = Mock()
strategy.precision = Mock(wraps=Precision())
original_module = LightningModule()
forward_module = DDP(original_module)
fabric_module = _FabricModule(forward_module=forward_module, precision=precision, original_module=original_module)
fabric_module = _FabricModule(forward_module=forward_module, strategy=strategy, original_module=original_module)

# Regular methods on the original_module are visible and identical on the fabric_module ...
assert fabric_module.normal_method.__wrapped__ == original_module.normal_method
@@ -585,13 +586,13 @@ def normal_method(self):
assert fabric_module.training_step("train_arg", kwarg="train_kwarg") == "training_step_return"
assert fabric_module.training_step("train_arg", kwarg="train_kwarg") == "training_step_return" # call 2nd time
assert fabric_module.validation_step("val_arg", kwarg="val_kwarg") == "validation_step_return"
precision.forward_context.assert_called()
strategy.precision.forward_context.assert_called()

# The forward method remains untouched/unpatched after the special methods have been called
assert original_module.forward.__name__ == "forward"

# Special case: forward_module == original_module -> no special treatment applied
fabric_module = _FabricModule(forward_module=original_module, precision=Mock(), original_module=original_module)
fabric_module = _FabricModule(forward_module=original_module, strategy=Mock(), original_module=original_module)
assert fabric_module.training_step == original_module.training_step
assert fabric_module.validation_step == original_module.validation_step

