diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 750001f527954..a1b491528b7cd 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -150,6 +150,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed FSDP re-applying activation checkpointing when the user had manually applied it already ([#18006](https://github.com/Lightning-AI/lightning/pull/18006))
 
+- Fixed `TensorBoardLogger.log_graph` not unwrapping the `_FabricModule` ([#17844](https://github.com/Lightning-AI/lightning/pull/17844))
+
+
 ## [2.0.5] - 2023-07-07
 
 ### Added
diff --git a/src/lightning/fabric/loggers/tensorboard.py b/src/lightning/fabric/loggers/tensorboard.py
index 7cd15042635c0..a3f84315c9185 100644
--- a/src/lightning/fabric/loggers/tensorboard.py
+++ b/src/lightning/fabric/loggers/tensorboard.py
@@ -27,6 +27,7 @@
 from lightning.fabric.utilities.logger import _sanitize_params as _utils_sanitize_params
 from lightning.fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn
 from lightning.fabric.utilities.types import _PATH
+from lightning.fabric.wrappers import _unwrap_objects
 
 log = logging.getLogger(__name__)
 
@@ -246,6 +247,7 @@ def log_hyperparams(  # type: ignore[override]
     def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None:
         model_example_input = getattr(model, "example_input_array", None)
         input_array = model_example_input if input_array is None else input_array
+        model = _unwrap_objects(model)
 
         if input_array is None:
             rank_zero_warn(
@@ -262,8 +264,10 @@ def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None
             getattr(model, "_apply_batch_transfer_handler", None)
         ):
             # this is probably is a LightningModule
-            input_array = model._on_before_batch_transfer(input_array)  # type: ignore[operator]
-            input_array = model._apply_batch_transfer_handler(input_array)  # type: ignore[operator]
+            input_array = model._on_before_batch_transfer(input_array)
+            input_array = model._apply_batch_transfer_handler(input_array)
+            self.experiment.add_graph(model, input_array)
+        else:
             self.experiment.add_graph(model, input_array)
 
     @rank_zero_only
diff --git a/tests/tests_fabric/loggers/test_tensorboard.py b/tests/tests_fabric/loggers/test_tensorboard.py
index 05889452ceeb2..243f08cd144bf 100644
--- a/tests/tests_fabric/loggers/test_tensorboard.py
+++ b/tests/tests_fabric/loggers/test_tensorboard.py
@@ -23,6 +23,7 @@
 
 from lightning.fabric.loggers import TensorBoardLogger
 from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE
+from lightning.fabric.wrappers import _FabricModule
 from tests_fabric.test_fabric import BoringModel
 
 
@@ -153,8 +154,18 @@ def test_tensorboard_log_graph(tmpdir, example_input_array):
     if example_input_array is not None:
         model.example_input_array = None
 
-    logger = TensorBoardLogger(tmpdir, log_graph=True)
+    logger = TensorBoardLogger(tmpdir)
+    logger._experiment = Mock()
     logger.log_graph(model, example_input_array)
+    if example_input_array is not None:
+        logger.experiment.add_graph.assert_called_with(model, example_input_array)
+    logger._experiment.reset_mock()
+
+    # model wrapped in `FabricModule`
+    wrapped = _FabricModule(model, precision=Mock())
+    logger.log_graph(wrapped, example_input_array)
+    if example_input_array is not None:
+        logger.experiment.add_graph.assert_called_with(model, example_input_array)
 
 
 @pytest.mark.skipif(not _TENSORBOARD_AVAILABLE, reason=str(_TENSORBOARD_AVAILABLE))
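
For context, a minimal usage sketch of the behavior this diff fixes (not part of the PR itself; `TinyModel`, the `logs` directory, and the `accelerator`/`loggers` arguments are illustrative assumptions based on the Lightning 2.0 Fabric API): `fabric.setup()` returns a `_FabricModule` wrapper, and with this change `log_graph` unwraps it via `_unwrap_objects` before the model reaches `SummaryWriter.add_graph`, so graph tracing sees the underlying `nn.Module` rather than the wrapper.

```python
# Hypothetical example; TinyModel and "logs" are illustrative, not from the PR.
import torch
from lightning.fabric import Fabric
from lightning.fabric.loggers import TensorBoardLogger


class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)


logger = TensorBoardLogger(root_dir="logs")
fabric = Fabric(accelerator="cpu", loggers=[logger])
fabric.launch()

# `fabric.setup` wraps the module in a `_FabricModule`
model = fabric.setup(TinyModel())

# With the fix, the wrapper is unwrapped before `SummaryWriter.add_graph`
# is called, so the underlying `nn.Module` is traced.
logger.log_graph(model, torch.randn(1, 32))
```

This mirrors what the updated test asserts: `add_graph` is called with the unwrapped `model` even when `log_graph` receives the `_FabricModule`.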