# NOTE(review): the seven lines that were here ("Skip to content", "Branch: master",
# "Find file Copy path", contributor notice, "127 lines (104 sloc) 5.25 KB") were
# GitHub web-page chrome captured by the scraper — not part of the source file.
# These should probably all live in separate files
from typing import Set, Dict, TYPE_CHECKING
import logging

import torch

# NOTE(review): the scrape dropped the package paths from these imports
# ("from import util as training_util" etc.); restored from the upstream
# allennlp layout — confirm against the repository this file came from.
from allennlp.common.params import Params
from allennlp.training import util as training_util
from allennlp.training.callbacks.callback import Callback, handle_event
from allennlp.training.callbacks.events import Events
from allennlp.training.tensorboard_writer import TensorboardWriter

# Imported only for annotations: a runtime import would be circular, since
# CallbackTrainer imports the callback classes. This is why TYPE_CHECKING
# is imported above.
if TYPE_CHECKING:
    from allennlp.training.callback_trainer import CallbackTrainer

logger = logging.getLogger(__name__)
@Callback.register("log_to_tensorboard")
class LogToTensorboard(Callback):
    """
    Callback that handles all Tensorboard logging.

    Parameters
    ----------
    tensorboard : ``TensorboardWriter``
        The TensorboardWriter instance to write to.
    log_batch_size_period : int, optional (default: None)
        If provided, we'll log the average batch sizes to Tensorboard
        every this-many batches.
    """

    # NOTE(review): the scraped copy of this class was truncated (missing
    # decorators, unbalanced parens, dropped statement lines). The bodies below
    # are reconstructed to be syntactically complete; lines marked NOTE(review)
    # should be confirmed against the upstream file.

    def __init__(self, tensorboard: TensorboardWriter, log_batch_size_period: int = None) -> None:
        self.log_batch_size_period = log_batch_size_period
        self.tensorboard = tensorboard
        # Running total of examples seen, for mean-batch-size logging.
        self.cumulative_batch_size = 0

        # For logging histograms
        self.histogram_parameters: Set[str] = set()
        # CPU copies of parameters taken at batch start, used to compute
        # per-batch update magnitudes at batch end.
        self.param_updates: Dict[str, torch.Tensor] = {}

    @handle_event(Events.TRAINING_START)
    def training_start(self, trainer: "CallbackTrainer") -> None:
        # This is an ugly hack to get the tensorboard instance to know about the trainer, because
        # the callbacks are defined before the trainer.
        # TODO: figure out a better way to handle this.
        self.tensorboard._get_batch_num_total = lambda: trainer.batch_num_total

        # Get histogram parameters
        # NOTE(review): the argument to set(...) was scraped away; the model's
        # histogram-parameter accessor is the upstream behavior — confirm.
        self.histogram_parameters = set(
            trainer.model.get_parameters_for_histogram_tensorboard_logging()
        )

        # Enable activation logging.
        if self.tensorboard._histogram_interval is not None:
            self.tensorboard.enable_activation_logging(trainer.model)

    @handle_event(Events.BATCH_START)
    def copy_current_parameters(self, trainer: "CallbackTrainer") -> None:
        if self.tensorboard.should_log_histograms_this_batch():
            # Get the magnitude of parameter updates for logging
            # We need a copy of current parameters to compute magnitude of updates,
            # and copy them to CPU so large models won't go OOM on the GPU.
            self.param_updates = {
                name: param.detach().cpu().clone()
                for name, param in trainer.model.named_parameters()
            }

    @handle_event(Events.BATCH_END)
    def batch_end_logging(self, trainer: "CallbackTrainer") -> None:
        # Log parameter values to tensorboard
        if self.tensorboard.should_log_this_batch():
            self.tensorboard.log_parameter_and_gradient_statistics(
                trainer.model, trainer.batch_grad_norm
            )
            self.tensorboard.log_learning_rates(trainer.model, trainer.optimizer)

            self.tensorboard.add_train_scalar("loss/loss_train", trainer.train_metrics["loss"])
            self.tensorboard.log_metrics(
                {"epoch_metrics/" + k: v for k, v in trainer.train_metrics.items()}
            )

        if self.log_batch_size_period:
            cur_batch = sum(training_util.get_batch_size(batch) for batch in trainer.batch_group)
            self.cumulative_batch_size += cur_batch
            if (trainer.batches_this_epoch - 1) % self.log_batch_size_period == 0:
                average = self.cumulative_batch_size / trainer.batches_this_epoch
                logger.debug(f"current batch size: {cur_batch} mean batch size: {average}")
                self.tensorboard.add_train_scalar("current_batch_size", cur_batch)
                self.tensorboard.add_train_scalar("mean_batch_size", average)

        if self.tensorboard.should_log_histograms_this_batch():
            for name, param in trainer.model.named_parameters():
                # Turn the stored batch-start copy into the update itself
                # (in place), then log ||update|| / ||param||. Without the
                # sub_ line (missing from the scraped copy) the "update norm"
                # would just be the norm of the old parameter values.
                self.param_updates[name].sub_(param.detach().cpu())
                update_norm = torch.norm(self.param_updates[name].view(-1))
                param_norm = torch.norm(param.view(-1)).cpu()
                self.tensorboard.add_train_scalar(
                    "gradient_update/" + name, update_norm / (param_norm + 1e-7)
                )
            # Drop the CPU copies so they don't linger until the next log step.
            self.param_updates.clear()
            self.tensorboard.log_histograms(trainer.model, self.histogram_parameters)

    @handle_event(Events.EPOCH_END)
    def epoch_end_logging(self, trainer: "CallbackTrainer") -> None:
        # NOTE(review): call target was scraped away; logging train + validation
        # metrics for the just-finished epoch is the upstream behavior — confirm.
        self.tensorboard.log_metrics(
            trainer.train_metrics,
            val_metrics=trainer.val_metrics,
            log_to_console=True,
            epoch=trainer.epoch_number + 1,
        )

    @handle_event(Events.TRAINING_END)
    def training_end(self, trainer: "CallbackTrainer") -> None:
        # NOTE(review): body was scraped away; closing the writer releases the
        # underlying SummaryWriter file handles — confirm against upstream.
        self.tensorboard.close()

    @classmethod
    def from_params(  # type: ignore
        cls, serialization_dir: str, params: Params
    ) -> "LogToTensorboard":
        log_batch_size_period = params.pop_int("log_batch_size_period", None)
        # get_batch_num_total is a placeholder here; training_start patches in
        # the real accessor once the trainer exists (see the hack above).
        tensorboard = TensorboardWriter.from_params(
            params=params,
            serialization_dir=serialization_dir,
            get_batch_num_total=lambda: None,
        )
        return LogToTensorboard(tensorboard, log_batch_size_period)
# NOTE(review): trailing GitHub page text ("You can't perform that action at
# this time.") — scraper artifact, not part of the source file.