This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
tensorboard_writer.py
341 lines (301 loc) · 15.7 KB
/
tensorboard_writer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
from typing import Any, Callable, Dict, List, Optional, Set
import logging
import os
from tensorboardX import SummaryWriter
import torch
from allennlp.common.from_params import FromParams
from allennlp.data.dataloader import TensorDict
from allennlp.nn import util as nn_util
from allennlp.training.optimizers import Optimizer
from allennlp.training import util as training_util
from allennlp.models.model import Model
logger = logging.getLogger(__name__)
class TensorboardWriter(FromParams):
    """
    Class that handles Tensorboard (and other) logging.

    # Parameters

    serialization_dir : `str`, optional (default = `None`)
        If provided, this is where the Tensorboard logs will be written.

        In a typical AllenNLP configuration file, this parameter does not get an entry under the
        "tensorboard_writer", it gets passed in separately.
    summary_interval : `int`, optional (default = `100`)
        Most statistics will be written out only every this many batches.
    histogram_interval : `int`, optional (default = `None`)
        If provided, activation histograms will be written out every this many batches.
        If None, activation histograms will not be written out.
        When this parameter is specified, the following additional logging is enabled:
            * Histograms of model parameters
            * The ratio of parameter update norm to parameter norm
            * Histogram of layer activations
        We log histograms of the parameters returned by
        `model.get_parameters_for_histogram_tensorboard_logging`.
        The layer activations are logged for any modules in the `Model` that have
        the attribute `should_log_activations` set to `True`. Logging
        histograms requires a number of GPU-CPU copies during training and is typically
        slow, so we recommend logging histograms relatively infrequently.
        Note: only Modules that return tensors, tuples of tensors or dicts
        with tensors as values currently support activation logging.
    batch_size_interval : `int`, optional, (default = `None`)
        If defined, how often to log the average batch size.
    should_log_parameter_statistics : `bool`, optional (default = `True`)
        Whether to log parameter statistics (mean and standard deviation of parameters and
        gradients).
    should_log_learning_rate : `bool`, optional (default = `False`)
        Whether to log (parameter-specific) learning rate.
    get_batch_num_total : `Callable[[], int]`, optional (default = `None`)
        A thunk that returns the number of batches so far. Most likely this will
        be a closure around an instance variable in your `Trainer` class. Because of circular
        dependencies in constructing this object and the `Trainer`, this is typically `None` when
        you construct the object, but it gets set inside the constructor of our `Trainer`.
    """

    def __init__(
        self,
        serialization_dir: Optional[str] = None,
        summary_interval: int = 100,
        histogram_interval: Optional[int] = None,
        batch_size_interval: Optional[int] = None,
        should_log_parameter_statistics: bool = True,
        should_log_learning_rate: bool = False,
        get_batch_num_total: Optional[Callable[[], int]] = None,
    ) -> None:
        if serialization_dir is not None:
            # Create log directories prior to creating SummaryWriter objects
            # in order to avoid race conditions during distributed training.
            train_ser_dir = os.path.join(serialization_dir, "log", "train")
            os.makedirs(train_ser_dir, exist_ok=True)
            self._train_log = SummaryWriter(train_ser_dir)
            val_ser_dir = os.path.join(serialization_dir, "log", "validation")
            os.makedirs(val_ser_dir, exist_ok=True)
            self._validation_log = SummaryWriter(val_ser_dir)
        else:
            # No serialization directory: all add_* methods become no-ops.
            self._train_log = self._validation_log = None
        self._summary_interval = summary_interval
        self._histogram_interval = histogram_interval
        self._batch_size_interval = batch_size_interval
        self._should_log_parameter_statistics = should_log_parameter_statistics
        self._should_log_learning_rate = should_log_learning_rate
        self.get_batch_num_total = get_batch_num_total

        # Running totals used to compute the mean batch size within an epoch;
        # reset by `reset_epoch`.
        self._cumulative_batch_group_size = 0
        self._batches_this_epoch = 0
        # Lazily populated (once, on first `log_histograms` call) set of
        # parameter names to write histograms for.
        self._histogram_parameters: Optional[Set[str]] = None

    @staticmethod
    def _item(value: Any):
        """
        Convert a zero-dimensional tensor (or anything with an `.item()` method)
        to a plain Python scalar; pass other values through unchanged.
        """
        if hasattr(value, "item"):
            val = value.item()
        else:
            val = value
        return val

    def log_memory_usage(self, cpu_memory_usage: Dict[int, int], gpu_memory_usage: Dict[int, int]):
        """
        Log per-worker CPU and per-device GPU memory usage (input is in bytes;
        logged values are in MB).
        """
        cpu_memory_usage_total = 0.0
        for worker, mem_bytes in cpu_memory_usage.items():
            memory = mem_bytes / (1024 * 1024)
            self.add_train_scalar(f"memory_usage/worker_{worker}_cpu", memory)
            cpu_memory_usage_total += memory
        self.add_train_scalar("memory_usage/cpu", cpu_memory_usage_total)
        for gpu, mem_bytes in gpu_memory_usage.items():
            memory = mem_bytes / (1024 * 1024)
            self.add_train_scalar(f"memory_usage/gpu_{gpu}", memory)

    def log_batch(
        self,
        model: Model,
        optimizer: Optimizer,
        batch_grad_norm: Optional[float],
        metrics: Dict[str, float],
        batch_group: List[List[TensorDict]],
        param_updates: Optional[Dict[str, torch.Tensor]],
    ) -> None:
        """
        Perform all per-batch logging: parameter/gradient statistics, learning
        rates, loss, metrics, histograms, and (optionally) batch sizes.
        """
        if self.should_log_this_batch():
            self.log_parameter_and_gradient_statistics(model, batch_grad_norm)
            self.log_learning_rates(model, optimizer)

            self.add_train_scalar("loss/loss_train", metrics["loss"])
            self.log_metrics({"epoch_metrics/" + k: v for k, v in metrics.items()})

        if self.should_log_histograms_this_batch():
            # The caller is responsible for computing parameter updates when
            # histogram logging is enabled.
            assert param_updates is not None
            self.log_histograms(model)
            self.log_gradient_updates(model, param_updates)

        if self._batch_size_interval:
            # We're assuming here that `log_batch` will get called every batch, and only every
            # batch. This is true with our current usage of this code (version 1.0); if that
            # assumption becomes wrong, this code will break.
            batch_group_size = sum(training_util.get_batch_size(batch) for batch in batch_group)  # type: ignore
            self._batches_this_epoch += 1
            self._cumulative_batch_group_size += batch_group_size

            if (self._batches_this_epoch - 1) % self._batch_size_interval == 0:
                average = self._cumulative_batch_group_size / self._batches_this_epoch
                logger.info(f"current batch size: {batch_group_size} mean batch size: {average}")
                self.add_train_scalar("current_batch_size", batch_group_size)
                self.add_train_scalar("mean_batch_size", average)

    def reset_epoch(self) -> None:
        """
        Reset the per-epoch batch-size accumulators. Call at the start of each epoch.
        """
        self._cumulative_batch_group_size = 0
        self._batches_this_epoch = 0

    def should_log_this_batch(self) -> bool:
        """
        Whether scalar statistics should be logged for the current batch
        (every `summary_interval` batches).
        """
        assert self.get_batch_num_total is not None
        return self.get_batch_num_total() % self._summary_interval == 0

    def should_log_histograms_this_batch(self) -> bool:
        """
        Whether histograms should be logged for the current batch (every
        `histogram_interval` batches; never, if `histogram_interval` is None).
        """
        assert self.get_batch_num_total is not None
        return (
            self._histogram_interval is not None
            and self.get_batch_num_total() % self._histogram_interval == 0
        )

    def add_train_scalar(self, name: str, value: float, timestep: Optional[int] = None) -> None:
        """
        Write a scalar to the train log. If `timestep` is not given, the
        current total batch count is used.
        """
        assert self.get_batch_num_total is not None
        # Note: an explicit `is None` check (rather than `timestep or ...`) so
        # that a caller-supplied timestep of 0 (e.g. epoch 0) is respected.
        if timestep is None:
            timestep = self.get_batch_num_total()
        # get the scalar
        if self._train_log is not None:
            self._train_log.add_scalar(name, self._item(value), timestep)

    def add_train_histogram(self, name: str, values: torch.Tensor) -> None:
        """
        Write a histogram of `values` to the train log at the current batch count.
        Non-tensor values are silently ignored.
        """
        assert self.get_batch_num_total is not None
        if self._train_log is not None:
            if isinstance(values, torch.Tensor):
                # Move to CPU and flatten before handing off to tensorboardX.
                values_to_write = values.cpu().data.numpy().flatten()
                self._train_log.add_histogram(name, values_to_write, self.get_batch_num_total())

    def add_validation_scalar(self, name: str, value: float, timestep: Optional[int] = None) -> None:
        """
        Write a scalar to the validation log. If `timestep` is not given, the
        current total batch count is used.
        """
        assert self.get_batch_num_total is not None
        # Explicit `is None` check so that `timestep=0` is respected.
        if timestep is None:
            timestep = self.get_batch_num_total()
        if self._validation_log is not None:
            self._validation_log.add_scalar(name, self._item(value), timestep)

    def log_parameter_and_gradient_statistics(
        self, model: Model, batch_grad_norm: Optional[float] = None
    ) -> None:
        """
        Send the mean and std of all parameters and gradients to tensorboard, as well
        as logging the average gradient norm.
        """
        if self._should_log_parameter_statistics:
            # Log parameter values to Tensorboard
            for name, param in model.named_parameters():
                if param.data.numel() > 0:
                    self.add_train_scalar("parameter_mean/" + name, param.data.mean().item())
                if param.data.numel() > 1:
                    # std is undefined for a single element, hence the > 1 guard.
                    self.add_train_scalar("parameter_std/" + name, param.data.std().item())
                if param.grad is not None:
                    if param.grad.is_sparse:
                        # For sparse gradients, only the stored (non-zero) values
                        # are meaningful for statistics.
                        grad_data = param.grad.data._values()
                    else:
                        grad_data = param.grad.data

                    # skip empty gradients
                    if torch.prod(torch.tensor(grad_data.shape)).item() > 0:
                        self.add_train_scalar("gradient_mean/" + name, grad_data.mean())
                        if grad_data.numel() > 1:
                            self.add_train_scalar("gradient_std/" + name, grad_data.std())
                    else:
                        # no gradient for a parameter with sparse gradients
                        logger.info("No gradient for %s, skipping tensorboard logging.", name)
            # norm of gradients
            if batch_grad_norm is not None:
                self.add_train_scalar("gradient_norm", batch_grad_norm)

    def log_learning_rates(self, model: Model, optimizer: Optimizer):
        """
        Send current parameter specific learning rates to tensorboard
        """
        if self._should_log_learning_rate:
            # optimizer stores lr info keyed by parameter tensor
            # we want to log with parameter name
            names = {param: name for name, param in model.named_parameters()}
            for group in optimizer.param_groups:
                if "lr" not in group:
                    continue
                rate = group["lr"]
                for param in group["params"]:
                    # check whether params has requires grad or not
                    # (frozen parameters get an effective rate of 0)
                    effective_rate = rate * float(param.requires_grad)
                    self.add_train_scalar("learning_rate/" + names[param], effective_rate)

    def log_histograms(self, model: Model) -> None:
        """
        Send histograms of parameters to tensorboard.
        """
        if self._histogram_parameters is None:
            # Avoiding calling this every batch.  If we ever use two separate models with a single
            # writer, this is wrong, but I doubt that will ever happen.
            # (`is None` rather than falsiness, so an empty set is also cached.)
            self._histogram_parameters = set(
                model.get_parameters_for_histogram_tensorboard_logging()
            )
        for name, param in model.named_parameters():
            if name in self._histogram_parameters:
                self.add_train_histogram("parameter_histogram/" + name, param)

    def log_gradient_updates(self, model: Model, param_updates: Dict[str, torch.Tensor]) -> None:
        """
        Log the ratio of the parameter-update norm to the parameter norm for
        every parameter in `model` (a proxy for effective learning rate).
        """
        for name, param in model.named_parameters():
            update_norm = torch.norm(param_updates[name].view(-1))
            param_norm = torch.norm(param.view(-1)).cpu()
            self.add_train_scalar(
                "gradient_update/" + name,
                # Tiny epsilon in the denominator avoids division by zero.
                update_norm / (param_norm + nn_util.tiny_value_of_dtype(param_norm.dtype)),
            )

    def log_metrics(
        self,
        train_metrics: dict,
        val_metrics: Optional[dict] = None,
        epoch: Optional[int] = None,
        log_to_console: bool = False,
    ) -> None:
        """
        Sends all of the train metrics (and validation metrics, if provided) to tensorboard.
        """
        metric_names = set(train_metrics.keys())
        if val_metrics is not None:
            metric_names.update(val_metrics.keys())
        val_metrics = val_metrics or {}

        # For logging to the console
        if log_to_console:
            dual_message_template = "%s |  %8.3f  |  %8.3f"
            no_val_message_template = "%s |  %8.3f  |  %8s"
            no_train_message_template = "%s |  %8s  |  %8.3f"
            header_template = "%s |  %-10s"
            name_length = max(len(x) for x in metric_names)
            logger.info(header_template, "Training".rjust(name_length + 13), "Validation")

        for name in sorted(metric_names):
            # Log to tensorboard
            train_metric = train_metrics.get(name)
            if train_metric is not None:
                self.add_train_scalar(name, train_metric, timestep=epoch)
            val_metric = val_metrics.get(name)
            if val_metric is not None:
                self.add_validation_scalar(name, val_metric, timestep=epoch)

            # And maybe log to console
            if log_to_console and val_metric is not None and train_metric is not None:
                logger.info(
                    dual_message_template, name.ljust(name_length), train_metric, val_metric
                )
            elif log_to_console and val_metric is not None:
                logger.info(no_train_message_template, name.ljust(name_length), "N/A", val_metric)
            elif log_to_console and train_metric is not None:
                logger.info(no_val_message_template, name.ljust(name_length), train_metric, "N/A")

    def enable_activation_logging(self, model: Model) -> None:
        """
        Register forward hooks on any module with `should_log_activations = True`
        so that its outputs are written as histograms when histogram logging is due.
        """
        if self._histogram_interval is not None:
            # To log activation histograms to the forward pass, we register
            # a hook on forward to capture the output tensors.
            # This uses a closure to determine whether to log the activations,
            # since we don't want them on every call.
            for _, module in model.named_modules():
                if not getattr(module, "should_log_activations", False):
                    # skip it
                    continue

                def hook(module_, inputs, outputs):
                    log_prefix = "activation_histogram/{0}".format(module_.__class__)
                    if self.should_log_histograms_this_batch():
                        self.log_activation_histogram(outputs, log_prefix)

                module.register_forward_hook(hook)

    def log_activation_histogram(self, outputs, log_prefix: str) -> None:
        """
        Write histograms for a module's forward outputs. Handles a bare tensor,
        a list/tuple of tensors, or a dict with tensor values; anything else is
        silently skipped.
        """
        if isinstance(outputs, torch.Tensor):
            log_name = log_prefix
            self.add_train_histogram(log_name, outputs)
        elif isinstance(outputs, (list, tuple)):
            for i, output in enumerate(outputs):
                log_name = "{0}_{1}".format(log_prefix, i)
                self.add_train_histogram(log_name, output)
        elif isinstance(outputs, dict):
            for k, tensor in outputs.items():
                log_name = "{0}_{1}".format(log_prefix, k)
                self.add_train_histogram(log_name, tensor)
        else:
            # skip it
            pass

    def close(self) -> None:
        """
        Calls the `close` method of the `SummaryWriter` s which makes sure that pending
        scalars are flushed to disk and the tensorboard event files are closed properly.
        """
        if self._train_log is not None:
            self._train_log.close()
        if self._validation_log is not None:
            self._validation_log.close()