Commit cefc7f7

Feature/log computational graph (Lightning-AI#3003)
* add methods
* log in trainer
* add tests
* changelog
* fix tests (×7)
* text
* added argument
* update tests
* fix styling
* improve testing
1 parent 7b917de commit cefc7f7

7 files changed: +98 -5 lines


CHANGELOG.md

+3 -1

@@ -52,6 +52,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added saving test predictions on multiple GPUs ([#2926](https://github.com/PyTorchLightning/pytorch-lightning/pull/2926))
 
+- Auto log the computational graph for loggers that support this ([#3003](https://github.com/PyTorchLightning/pytorch-lightning/pull/3003))
+
 ### Changed
 
 - Truncated long version numbers in progress bar ([#2594](https://github.com/PyTorchLightning/pytorch-lightning/pull/2594))

@@ -110,7 +112,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed test metrics not being logged with `LoggerCollection` ([#2723](https://github.com/PyTorchLightning/pytorch-lightning/pull/2723))
 
-- Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
+- Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
 
 - Fixed shuffle argument for distributed sampler ([#2789](https://github.com/PyTorchLightning/pytorch-lightning/pull/2789))
pytorch_lightning/loggers/base.py

+15

@@ -10,6 +10,7 @@
 import torch
 
 from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.core.lightning import LightningModule
 
 
 class LightningLoggerBase(ABC):

@@ -220,6 +221,16 @@ def log_hyperparams(self, params: argparse.Namespace):
             params: :class:`~argparse.Namespace` containing the hyperparameters
         """
 
+    def log_graph(self, model: LightningModule, input_array=None) -> None:
+        """
+        Record the model's computational graph.
+
+        Args:
+            model: lightning model
+            input_array: input passed to `model.forward`
+        """
+        pass
+
     def save(self) -> None:
         """Save log data."""
         self._finalize_agg_metrics()

@@ -296,6 +307,10 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
         for logger in self._logger_iterable:
             logger.log_hyperparams(params)
 
+    def log_graph(self, model: LightningModule, input_array=None) -> None:
+        for logger in self._logger_iterable:
+            logger.log_graph(model, input_array)
+
     def save(self) -> None:
         for logger in self._logger_iterable:
             logger.save()
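
The base-class hook above is deliberately a no-op, so any logger can opt in by overriding `log_graph`, while `LoggerCollection` simply fans the call out to its children. Below is a minimal sketch of such an override; the `JitGraphLogger` class and its output path are hypothetical, and it persists the graph via `torch.jit.trace` rather than any real logger backend:

```python
import torch
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.core.lightning import LightningModule


class JitGraphLogger(LightningLoggerBase):
    """Hypothetical logger whose only job is to persist the traced graph."""

    @property
    def experiment(self):
        return None  # no backing experiment object in this sketch

    @property
    def name(self):
        return 'jit_graph'

    @property
    def version(self):
        return 0

    def log_metrics(self, metrics, step=None):
        pass  # metrics are out of scope for this sketch

    def log_hyperparams(self, params):
        pass

    def log_graph(self, model: LightningModule, input_array=None) -> None:
        # Mirror the built-in loggers: fall back to the example input
        # attached to the model when no explicit input is given.
        if input_array is None:
            input_array = model.example_input_array
        if input_array is None:
            return  # nothing to trace
        torch.jit.trace(model, input_array).save('graph.pt')
```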

pytorch_lightning/loggers/tensorboard.py

+25 -1

@@ -15,8 +15,9 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.core.saving import save_hparams_to_yaml
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
-from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import gfile, makedirs
+from pytorch_lightning.core.lightning import LightningModule
 
 try:
     from omegaconf import Container, OmegaConf

@@ -47,6 +48,9 @@ class TensorBoardLogger(LightningLoggerBase):
             directory for existing versions, then automatically assigns the next available version.
             If it is a string then it is used as the run-specific subdirectory name,
             otherwise ``'version_${version}'`` is used.
+        log_graph: Adds the computational graph to tensorboard. This requires that
+            the user has defined the `self.example_input_array` attribute in their
+            model.
         \**kwargs: Other arguments are passed directly to the :class:`SummaryWriter` constructor.
 
     """

@@ -56,11 +60,13 @@ def __init__(self,
                  save_dir: str,
                  name: Optional[str] = "default",
                  version: Optional[Union[int, str]] = None,
+                 log_graph: bool = True,
                  **kwargs):
         super().__init__()
         self._save_dir = save_dir
         self._name = name or ''
         self._version = version
+        self._log_graph = log_graph
 
         self._experiment = None
         self.hparams = {}

@@ -160,6 +166,24 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
                 v = v.item()
             self.experiment.add_scalar(k, v, step)
 
+    @rank_zero_only
+    def log_graph(self, model: LightningModule, input_array=None):
+        if self._log_graph:
+            if input_array is None:
+                input_array = model.example_input_array
+
+            if input_array is not None:
+                self.experiment.add_graph(
+                    model,
+                    model.transfer_batch_to_device(
+                        input_array, model.device)
+                )
+            else:
+                rank_zero_warn('Could not log computational graph since the'
+                               ' `model.example_input_array` attribute is not set'
+                               ' or `input_array` was not given',
+                               UserWarning)
+
     @rank_zero_only
     def save(self) -> None:
         super().save()
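
From the user's side, the TensorBoard path needs only an example input. Here is a minimal sketch of both call styles; `LitModel` is a toy module defined purely for illustration:

```python
import torch
from torch import nn
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(28 * 28, 10)
        # The dummy batch the logger traces forward() with.
        self.example_input_array = torch.rand(2, 28 * 28)

    def forward(self, x):
        return self.layer(x)


logger = TensorBoardLogger('lightning_logs', log_graph=True)
model = LitModel()

logger.log_graph(model)                          # uses model.example_input_array
logger.log_graph(model, torch.rand(2, 28 * 28))  # or pass an input explicitly
```

Either way the graph ends up in the run's event file and appears under TensorBoard's Graphs tab.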

pytorch_lightning/loggers/test_tube.py

+26 -3

@@ -13,7 +13,8 @@
     _TEST_TUBE_AVAILABLE = False
 
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
-from pytorch_lightning.utilities.distributed import rank_zero_only
+from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
+from pytorch_lightning.core.lightning import LightningModule
 
 
 class TestTubeLogger(LightningLoggerBase):

@@ -51,7 +52,9 @@ class TestTubeLogger(LightningLoggerBase):
         version: Experiment version. If version is not specified the logger inspects the save
             directory for existing versions, then automatically assigns the next available version.
         create_git_tag: If ``True`` creates a git tag to save the code used in this experiment.
-
+        log_graph: Adds the computational graph to tensorboard. This requires that
+            the user has defined the `self.example_input_array` attribute in their
+            model.
     """
 
     __test__ = False

@@ -62,7 +65,8 @@ def __init__(self,
                  description: Optional[str] = None,
                  debug: bool = False,
                  version: Optional[int] = None,
-                 create_git_tag: bool = False):
+                 create_git_tag: bool = False,
+                 log_graph: bool = True):
 
         if not _TEST_TUBE_AVAILABLE:
             raise ImportError('You want to use `test_tube` logger which is not installed yet,'

@@ -74,6 +78,7 @@ def __init__(self,
         self.debug = debug
         self._version = version
         self.create_git_tag = create_git_tag
+        self._log_graph = log_graph
         self._experiment = None
 
     @property

@@ -117,6 +122,24 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
         self.experiment.debug = self.debug
         self.experiment.log(metrics, global_step=step)
 
+    @rank_zero_only
+    def log_graph(self, model: LightningModule, input_array=None):
+        if self._log_graph:
+            if input_array is None:
+                input_array = model.example_input_array
+
+            if input_array is not None:
+                self.experiment.add_graph(
+                    model,
+                    model.transfer_batch_to_device(
+                        input_array, model.device)
+                )
+            else:
+                rank_zero_warn('Could not log computational graph since the'
+                               ' `model.example_input_array` attribute is not set'
+                               ' or `input_array` was not given',
+                               UserWarning)
+
     @rank_zero_only
     def save(self) -> None:
         super().save()
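
Since both loggers default to `log_graph=True`, opting out is a single constructor flag, which matters for models whose `forward` cannot be traced with one fixed dummy input. A short sketch (`test_tube` must be installed for the second logger; directory and experiment names are illustrative):

```python
from pytorch_lightning.loggers import TensorBoardLogger, TestTubeLogger

# Skip graph logging entirely, e.g. for dynamically shaped inputs.
tb_logger = TensorBoardLogger('lightning_logs', log_graph=False)
tt_logger = TestTubeLogger('tt_logs', name='my_exp', log_graph=False)
```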

pytorch_lightning/trainer/trainer.py

+1

@@ -1151,6 +1151,7 @@ def run_pretrain_routine(self, model: LightningModule):
         if self.logger is not None:
             # save exp to get started
             self.logger.log_hyperparams(ref_model.hparams)
+            self.logger.log_graph(ref_model)
             self.logger.save()
 
         if self.use_ddp or self.use_ddp2:
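
With this one-line trainer change, graph logging becomes automatic: `run_pretrain_routine` calls it once, right after the hyperparameters are logged. A minimal sketch of the resulting end-to-end behavior, assuming the `LitModel` from the earlier sketch has been extended with the usual `training_step`, `configure_optimizers`, and `train_dataloader` hooks:

```python
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

model = LitModel()  # from the earlier sketch, extended with training hooks
trainer = pl.Trainer(
    max_epochs=1,
    logger=TensorBoardLogger('lightning_logs'),
)

# Because LitModel sets example_input_array, the graph is recorded
# before the first training step; no explicit logger.log_graph(...)
# call is needed in user code.
trainer.fit(model)
```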

tests/loggers/test_tensorboard.py

+25

@@ -156,3 +156,28 @@ def test_tensorboard_log_omegaconf_hparams_and_metrics(tmpdir):
 
     metrics = {"abc": torch.tensor([0.54])}
     logger.log_hyperparams(hparams, metrics)
+
+
+@pytest.mark.parametrize("example_input_array", [None, torch.rand(2, 28 * 28)])
+def test_tensorboard_log_graph(tmpdir, example_input_array):
+    """ test that log graph works with both model.example_input_array and
+    if the array is passed externally
+    """
+    model = EvalModelTemplate()
+    if example_input_array is None:
+        model.example_input_array = None
+    logger = TensorBoardLogger(tmpdir)
+    logger.log_graph(model, example_input_array)
+
+
+def test_tensorboard_log_graph_warning_no_example_input_array(tmpdir):
+    """ test that log graph throws a warning if model.example_input_array is None """
+    model = EvalModelTemplate()
+    model.example_input_array = None
+    logger = TensorBoardLogger(tmpdir)
+    with pytest.warns(
+        UserWarning,
+        match='Could not log computational graph since the `model.example_input_array`'
+              ' attribute is not set or `input_array` was not given'
+    ):
+        logger.log_graph(model)

tests/models/test_cpu.py

+3

@@ -346,6 +346,7 @@ def train_dataloader(self):
     )
 
     model = BpttTestModel(**hparams)
+    model.example_input_array = torch.randn(5, truncated_bptt_steps)
 
     # fit model
     trainer = Trainer(

@@ -424,6 +425,7 @@ def train_dataloader(self):
     )
 
     model = BpttTestModel(**hparams)
+    model.example_input_array = torch.randn(5, truncated_bptt_steps)
 
     # fit model
     trainer = Trainer(

@@ -494,6 +496,7 @@ def train_dataloader(self):
     )
 
     model = BpttTestModel(**hparams)
+    model.example_input_array = torch.randn(5, truncated_bptt_steps)
 
     # fit model
     trainer = Trainer(
