From e30763869df2e7e494fbbb87611776b3da6dfbb2 Mon Sep 17 00:00:00 2001 From: Shihao Yin Date: Mon, 26 Jun 2023 09:35:44 +0800 Subject: [PATCH 1/6] fixes no graph recorded bug --- src/lightning/fabric/loggers/tensorboard.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/lightning/fabric/loggers/tensorboard.py b/src/lightning/fabric/loggers/tensorboard.py index 3a5b468f3a6b7..07724dfa17a74 100644 --- a/src/lightning/fabric/loggers/tensorboard.py +++ b/src/lightning/fabric/loggers/tensorboard.py @@ -265,6 +265,11 @@ def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None input_array = model._on_before_batch_transfer(input_array) # type: ignore[operator] input_array = model._apply_batch_transfer_handler(input_array) # type: ignore[operator] self.experiment.add_graph(model, input_array) + elif hasattr(model, "module") and callable(getattr(model, "_redirection_through_forward", None)): + # this is probably is a _FabricModule + self.experiment.add_graph(model.module, input_array) + else: + self.experiment.add_graph(model, input_array) @rank_zero_only def save(self) -> None: From f9c1af4c84b93d5f701ce17b9e6f4a4c0deb1e02 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 17 Jul 2023 15:37:05 +0200 Subject: [PATCH 2/6] unwrap --- src/lightning/fabric/loggers/tensorboard.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/lightning/fabric/loggers/tensorboard.py b/src/lightning/fabric/loggers/tensorboard.py index b6a46650651c0..c69b6f037bb46 100644 --- a/src/lightning/fabric/loggers/tensorboard.py +++ b/src/lightning/fabric/loggers/tensorboard.py @@ -21,6 +21,7 @@ from torch import Tensor from torch.nn import Module +from lightning.fabric.wrappers import _unwrap_objects from lightning.fabric.loggers.logger import Logger, rank_zero_experiment from lightning.fabric.utilities.cloud_io import get_filesystem from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict @@ -246,6 +247,7 @@ def log_hyperparams( # type: ignore[override] def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None: model_example_input = getattr(model, "example_input_array", None) input_array = model_example_input if input_array is None else input_array + model = _unwrap_objects(model) if input_array is None: rank_zero_warn( @@ -265,9 +267,6 @@ def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None input_array = model._on_before_batch_transfer(input_array) # type: ignore[operator] input_array = model._apply_batch_transfer_handler(input_array) # type: ignore[operator] self.experiment.add_graph(model, input_array) - elif hasattr(model, "module") and callable(getattr(model, "_redirection_through_forward", None)): - # this is probably is a _FabricModule - self.experiment.add_graph(model.module, input_array) else: self.experiment.add_graph(model, input_array) From c44ced77e7977f27f54fceeb80a7ca1b95c2a6db Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 17 Jul 2023 15:47:48 +0200 Subject: [PATCH 3/6] add test --- tests/tests_fabric/loggers/test_tensorboard.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/tests_fabric/loggers/test_tensorboard.py b/tests/tests_fabric/loggers/test_tensorboard.py index 05889452ceeb2..243f08cd144bf 100644 --- a/tests/tests_fabric/loggers/test_tensorboard.py +++ b/tests/tests_fabric/loggers/test_tensorboard.py @@ -23,6 +23,7 @@ from lightning.fabric.loggers import TensorBoardLogger from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE +from lightning.fabric.wrappers import _FabricModule from tests_fabric.test_fabric import BoringModel @@ -153,8 +154,18 @@ def test_tensorboard_log_graph(tmpdir, example_input_array): if example_input_array is not None: model.example_input_array = None - logger = TensorBoardLogger(tmpdir, log_graph=True) + logger = TensorBoardLogger(tmpdir) + logger._experiment = Mock() logger.log_graph(model, example_input_array) + if example_input_array is not None: + logger.experiment.add_graph.assert_called_with(model, example_input_array) + logger._experiment.reset_mock() + + # model wrapped in `FabricModule` + wrapped = _FabricModule(model, precision=Mock()) + logger.log_graph(wrapped, example_input_array) + if example_input_array is not None: + logger.experiment.add_graph.assert_called_with(model, example_input_array) @pytest.mark.skipif(not _TENSORBOARD_AVAILABLE, reason=str(_TENSORBOARD_AVAILABLE)) From ef31bf54de7d4f82b6b1bec4279eeb677a15ce55 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 17 Jul 2023 15:49:27 +0200 Subject: [PATCH 4/6] chlog --- src/lightning/fabric/CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 750001f527954..a1b491528b7cd 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -150,6 +150,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed FSDP re-applying activation checkpointing when the user had manually applied it already ([#18006](https://github.com/Lightning-AI/lightning/pull/18006)) +- Fixed `TensorBoardLogger.log_graph` not unwrapping the `_FabricModule` ([#17844](https://github.com/Lightning-AI/lightning/pull/17844)) + + ## [2.0.5] - 2023-07-07 ### Added From ca95cd7aa94f10f4fea49265aefe427226c7524c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 17 Jul 2023 13:51:26 +0000 Subject: [PATCH 5/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/fabric/loggers/tensorboard.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/fabric/loggers/tensorboard.py b/src/lightning/fabric/loggers/tensorboard.py index c69b6f037bb46..d6b8981e02c22 100644 --- a/src/lightning/fabric/loggers/tensorboard.py +++ b/src/lightning/fabric/loggers/tensorboard.py @@ -21,13 +21,13 @@ from torch import Tensor from torch.nn import Module -from lightning.fabric.wrappers import _unwrap_objects from lightning.fabric.loggers.logger import Logger, rank_zero_experiment from lightning.fabric.utilities.cloud_io import get_filesystem from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict from lightning.fabric.utilities.logger import _sanitize_params as _utils_sanitize_params from lightning.fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn from lightning.fabric.utilities.types import _PATH +from lightning.fabric.wrappers import _unwrap_objects log = logging.getLogger(__name__) From ed19120f857ed81b4d027f10f7725fbcc2debe90 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 17 Jul 2023 18:37:00 +0200 Subject: [PATCH 6/6] mypy --- src/lightning/fabric/loggers/tensorboard.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/fabric/loggers/tensorboard.py b/src/lightning/fabric/loggers/tensorboard.py index d6b8981e02c22..a3f84315c9185 100644 --- a/src/lightning/fabric/loggers/tensorboard.py +++ b/src/lightning/fabric/loggers/tensorboard.py @@ -264,8 +264,8 @@ def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None getattr(model, "_apply_batch_transfer_handler", None) ): # this is probably is a LightningModule - input_array = model._on_before_batch_transfer(input_array) # type: ignore[operator] - input_array = model._apply_batch_transfer_handler(input_array) # type: ignore[operator] + input_array = model._on_before_batch_transfer(input_array) + input_array = model._apply_batch_transfer_handler(input_array) self.experiment.add_graph(model, input_array) else: self.experiment.add_graph(model, input_array)