forked from Lightning-Universe/lightning-bolts
/
data_monitor.py
280 lines (223 loc) · 10.5 KB
/
data_monitor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
from typing import Any, Dict, List, Optional, Sequence, Union
import numpy as np
import torch
import torch.nn as nn
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from torch import Tensor
from torch.utils.hooks import RemovableHandle
from pl_bolts.utils import _WANDB_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg
if _WANDB_AVAILABLE:
import wandb
else: # pragma: no cover
warn_missing_pkg("wandb")
wandb = None
class DataMonitorBase(Callback):
    """Base class for monitoring data histograms in a LightningModule.

    This requires a logger configured in the Trainer, otherwise no data is logged.
    The specific class that inherits from this base defines what data gets collected.
    """

    supported_loggers = (
        TensorBoardLogger,
        WandbLogger,
    )

    def __init__(self, log_every_n_steps: Optional[int] = None):
        """
        Args:
            log_every_n_steps: The interval at which histograms should be logged. This defaults to the
                interval defined in the Trainer. Use this to override the Trainer default.
        """
        super().__init__()
        self._log_every_n_steps: Optional[int] = log_every_n_steps
        self._log = False
        self._trainer: Optional[Trainer] = None
        self._train_batch_idx: Optional[int] = None

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self._log = self._is_logger_available(trainer.logger)
        # Fall back to the Trainer's logging interval when the user did not override it.
        self._log_every_n_steps = self._log_every_n_steps or trainer.log_every_n_steps
        self._trainer = trainer

    def on_train_batch_start(
        self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int
    ) -> None:
        # Track the batch index so log_histograms() can honor the logging interval.
        self._train_batch_idx = batch_idx

    def log_histograms(self, batch: Any, group: str = "") -> None:
        """
        Logs the histograms at the interval defined by ``log_every_n_steps``, given a logger is available.

        Args:
            batch: torch or numpy arrays, or a collection of it (tuple, list, dict, ...), can be nested.
                If the data appears in a dictionary, the keys are used as labels for the corresponding histogram.
                Otherwise the histograms get labelled with an integer index.
                Each label also has the tensor's shape as suffix.
            group: Name under which the histograms will be grouped.
        """
        if not self._log or (self._train_batch_idx + 1) % self._log_every_n_steps != 0:
            return

        # Convert any numpy arrays in the (nested) collection to tensors for uniform handling.
        batch = apply_to_collection(batch, dtype=np.ndarray, function=torch.from_numpy)
        named_tensors: Dict[str, Tensor] = {}
        collect_and_name_tensors(batch, output=named_tensors, parent_name=group)

        for name, tensor in named_tensors.items():
            self.log_histogram(tensor, name)

    def log_histogram(self, tensor: Tensor, name: str) -> None:
        """
        Override this method to customize the logging of histograms.
        Detaches the tensor from the graph and moves it to the CPU for logging.

        Args:
            tensor: The tensor for which to log a histogram
            name: The name of the tensor as determined by the callback. Example: ``input/0/[64, 1, 28, 28]``
        """
        logger = self._trainer.logger
        tensor = tensor.detach().cpu()
        if isinstance(logger, TensorBoardLogger):
            logger.experiment.add_histogram(
                tag=name, values=tensor, global_step=self._trainer.global_step
            )

        if isinstance(logger, WandbLogger):
            if not _WANDB_AVAILABLE:  # pragma: no cover
                raise ModuleNotFoundError(
                    "You want to use `wandb` which is not installed yet."
                )
            # commit=False so wandb batches this with the next regular log call.
            logger.experiment.log(
                data={name: wandb.Histogram(tensor)}, commit=False,
            )

    def _is_logger_available(self, logger) -> bool:
        # Early returns avoid emitting a second, misleading warning
        # ("... does not support logging with NoneType") when there is no logger at all.
        if not logger:
            rank_zero_warn("Cannot log histograms because Trainer has no logger.")
            return False
        if not isinstance(logger, self.supported_loggers):
            rank_zero_warn(
                f"{self.__class__.__name__} does not support logging with {logger.__class__.__name__}."
                f" Supported loggers are: {', '.join(map(lambda x: str(x.__name__), self.supported_loggers))}"
            )
            return False
        return True
class ModuleDataMonitor(DataMonitorBase):
    """Logs histograms of the in- and outputs of modules via forward hooks."""

    GROUP_NAME_INPUT = "input"
    GROUP_NAME_OUTPUT = "output"

    def __init__(
        self,
        submodules: Optional[Union[bool, List[str]]] = None,
        log_every_n_steps: Optional[int] = None,
    ):
        """
        Args:
            submodules: If `True`, logs the in- and output histograms of every submodule in the
                LightningModule, including the root module itself.
                This parameter can also take a list of names of specific submodules (see example below).
                Default: `None`, logs only the in- and output of the root module.
            log_every_n_steps: The interval at which histograms should be logged. This defaults to the
                interval defined in the Trainer. Use this to override the Trainer default.

        Note:
            A too low value for `log_every_n_steps` may have a significant performance impact
            especially when many submodules are involved, since the logging occurs during the forward pass.
            It should only be used for debugging purposes.

        Example:

            .. code-block:: python

                # log the in- and output histograms of the `forward` in LightningModule
                trainer = Trainer(callbacks=[ModuleDataMonitor()])

                # all submodules in LightningModule
                trainer = Trainer(callbacks=[ModuleDataMonitor(submodules=True)])

                # specific submodules
                trainer = Trainer(callbacks=[ModuleDataMonitor(submodules=["generator", "generator.conv1"])])
        """
        super().__init__(log_every_n_steps=log_every_n_steps)
        self._submodule_names = submodules
        self._hook_handles: List[RemovableHandle] = []

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        super().on_train_start(trainer, pl_module)
        submodule_dict = dict(pl_module.named_modules())
        self._hook_handles = []
        for name in self._get_submodule_names(pl_module):
            if name not in submodule_dict:
                rank_zero_warn(
                    f"{name} is not a valid identifier for a submodule in {pl_module.__class__.__name__},"
                    " skipping this key."
                )
                continue
            handle = self._register_hook(name, submodule_dict[name])
            self._hook_handles.append(handle)

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        # Remove the forward hooks so the module is left unmodified after training.
        for handle in self._hook_handles:
            handle.remove()

    def _get_submodule_names(self, root_module: nn.Module) -> List[str]:
        # Default is the root module only; "" is the root's name in named_modules().
        names = [""]

        if isinstance(self._submodule_names, list):
            names = self._submodule_names

        if self._submodule_names is True:
            names = [name for name, _ in root_module.named_modules()]

        return names

    def _register_hook(self, module_name: str, module: nn.Module) -> RemovableHandle:
        # The root module has the empty name; avoid a trailing "/" in its group labels.
        input_group_name = (
            f"{self.GROUP_NAME_INPUT}/{module_name}"
            if module_name
            else self.GROUP_NAME_INPUT
        )
        output_group_name = (
            f"{self.GROUP_NAME_OUTPUT}/{module_name}"
            if module_name
            else self.GROUP_NAME_OUTPUT
        )

        def hook(_: nn.Module, inp: Any, out: Any) -> None:
            # Unwrap single-element input tuples so they are logged without an index suffix.
            inp = inp[0] if len(inp) == 1 else inp
            self.log_histograms(inp, group=input_group_name)
            self.log_histograms(out, group=output_group_name)

        handle = module.register_forward_hook(hook)
        return handle
class TrainingDataMonitor(DataMonitorBase):
    """Callback that logs the histogram of values in the batched data passed to ``training_step``."""

    GROUP_NAME = "training_step"

    def __init__(self, log_every_n_steps: Optional[int] = None):
        """
        Args:
            log_every_n_steps: The interval at which histograms should be logged. This defaults to the
                interval defined in the Trainer. Use this to override the Trainer default.

        Example:

            .. code-block:: python

                # log histogram of training data passed to `LightningModule.training_step`
                trainer = Trainer(callbacks=[TrainingDataMonitor()])
        """
        super().__init__(log_every_n_steps=log_every_n_steps)

    def on_train_batch_start(
        self, trainer: Trainer, pl_module: LightningModule, batch: Any, *args: Any, **kwargs: Any
    ) -> None:
        # Let the base class record the batch index first, then log the batch itself.
        super().on_train_batch_start(trainer, pl_module, batch, *args, **kwargs)
        self.log_histograms(batch, group=self.GROUP_NAME)
def collect_and_name_tensors(
    data: Any, output: Dict[str, Tensor], parent_name: str = "input"
) -> None:
    """
    Recursively fetches all tensors in a (nested) collection of data (depth-first search) and names them.

    Tensors found under dictionary keys are named by those keys; tensors in other sequences are
    named by their integer position. The bracketed shape of each tensor is appended to its name.

    Args:
        data: A collection of data (potentially nested).
        output: A dictionary in which the outputs will be stored.
        parent_name: Used when called recursively on a nested input data.

    Example:
        >>> data = {"x": torch.zeros(2, 3), "y": {"z": torch.zeros(5)}, "w": 1}
        >>> output = {}
        >>> collect_and_name_tensors(data, output)
        >>> output  # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
        {'input/x/[2, 3]': ..., 'input/y/z/[5]': ...}
    """
    assert isinstance(output, dict)

    # Base case: a bare tensor is stored under its accumulated path plus its shape.
    if isinstance(data, Tensor):
        output[f"{parent_name}/{shape2str(data)}"] = data
        return

    # Dictionaries contribute their keys to the name path.
    if isinstance(data, dict):
        for key, value in data.items():
            collect_and_name_tensors(value, output, parent_name=f"{parent_name}/{key}")
        return

    # Other sequences (but not strings) contribute their positional index.
    if isinstance(data, Sequence) and not isinstance(data, str):
        for index, element in enumerate(data):
            collect_and_name_tensors(element, output, parent_name=f"{parent_name}/{index:d}")
def shape2str(tensor: Tensor) -> str:
    """
    Returns the shape of a tensor in bracket notation as a string.

    Example:
        >>> shape2str(torch.rand(1, 2, 3))
        '[1, 2, 3]'
        >>> shape2str(torch.rand(4))
        '[4]'
    """
    dims = ", ".join(str(dim) for dim in tensor.shape)
    return f"[{dims}]"