Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Logging and metrics changes for distributed training (#3372)
Browse files Browse the repository at this point in the history
* Refactor logging setup to support distributed attrs

* `cleanup_logging()` is replaced with stdlib's `logging.shutdown()`
* Remove `TeeLogger` and use standard log handlers
* Remove `replace_cr_with_newline` and use the standard logging practice of using
`logging.Filter`
* Introduce `rank` and `world_size` optional attributes to support
distributed workers

* Support for distributed training in `get_metrics`

* Remove bad import

* Fix duplicate log messages in stdout

* Remove preemptive `logging.shutdown`

`logging.shutdown` is called by the logging module
by default during exit, which makes it unnecessary to
call it from `train_model`

* Fix black formatting issues

* Remove `tee_logger` references in API doc

* Set log level from `ALLENNLP_DEBUG` env
  • Loading branch information
scarecrow1123 authored and DeNeutoy committed Oct 23, 2019
1 parent 2850579 commit 2d7a51b
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 135 deletions.
11 changes: 2 additions & 9 deletions allennlp/commands/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,7 @@
from allennlp.commands.subcommand import Subcommand
from allennlp.common.checks import check_for_gpu
from allennlp.common import Params
from allennlp.common.util import (
prepare_environment,
prepare_global_logging,
cleanup_global_logging,
dump_metrics,
)
from allennlp.common.util import prepare_environment, prepare_global_logging, dump_metrics
from allennlp.models.archival import archive_model, CONFIG_NAME
from allennlp.models.model import Model, _DEFAULT_WEIGHTS
from allennlp.training.trainer import Trainer
Expand Down Expand Up @@ -241,7 +236,7 @@ def train_model(
create_serialization_dir(params, serialization_dir, recover, force)
params.to_file(os.path.join(serialization_dir, CONFIG_NAME))

stdout_handler = prepare_global_logging(serialization_dir, file_friendly_logging)
prepare_global_logging(serialization_dir, file_friendly_logging)
prepare_environment(params)

cuda_device = params.params.get("trainer").get("cuda_device", -1)
Expand Down Expand Up @@ -318,8 +313,6 @@ def train_model(
"'evaluate_on_test' flag, or use the 'allennlp evaluate' command."
)

cleanup_global_logging(stdout_handler)

# Now tar up results
archive_model(serialization_dir, files_to_archive=params.files_to_archive)
dump_metrics(os.path.join(serialization_dir, "metrics.json"), metrics, log=True)
Expand Down
1 change: 0 additions & 1 deletion allennlp/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from allennlp.common.from_params import FromParams
from allennlp.common.params import Params
from allennlp.common.registrable import Registrable
from allennlp.common.tee_logger import TeeLogger
from allennlp.common.tqdm import Tqdm
from allennlp.common.util import JsonDict
63 changes: 0 additions & 63 deletions allennlp/common/tee_logger.py

This file was deleted.

150 changes: 98 additions & 52 deletions allennlp/common/util.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
"""
Various utilities that don't fit anywhere else.
"""
from itertools import zip_longest, islice
from typing import Any, Callable, Dict, List, Tuple, TypeVar, Iterable, Iterator
import importlib
import json
import logging
import os
import pkgutil
import random
import subprocess
import sys
import os
from itertools import zip_longest, islice
from logging import Filter
from typing import Any, Callable, Dict, List, Tuple, TypeVar, Iterable, Iterator

try:
import resource
Expand All @@ -29,7 +30,6 @@
from allennlp.common.checks import log_pytorch_version_info
from allennlp.common.params import Params
from allennlp.common.tqdm import Tqdm
from allennlp.common.tee_logger import TeeLogger

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -219,71 +219,117 @@ def prepare_environment(params: Params):
log_pytorch_version_info()


def prepare_global_logging(
serialization_dir: str, file_friendly_logging: bool
) -> logging.FileHandler:
class FileFriendlyLogFilter(Filter):
"""
TQDM and requests use carriage returns to get the training line to update for each batch
without adding more lines to the terminal output. Displaying those in a file won't work
correctly, so we'll just make sure that each batch shows up on its own line.
"""
This function configures 3 global logging attributes - streaming stdout and stderr
to a file as well as the terminal, setting the formatting for the python logging
library and setting the interval frequency for the Tqdm progress bar.

Note that this function does not set the logging level, which is set in ``allennlp/run.py``.
def filter(self, record):
    """
    Strip carriage returns from the record's message so it reads cleanly in a file.

    Parameters
    ----------
    record : ``logging.LogRecord``, required.
        The record about to be emitted. Its ``msg`` is rewritten in place.

    Returns
    -------
    ``bool``
        Always ``True``; this filter only rewrites messages, it never drops records.
    """
    message = record.msg
    # ``logging`` accepts arbitrary objects as ``msg`` (e.g. an exception instance
    # or a number). Only string messages can contain a carriage return, and a
    # membership test on a non-iterable object would raise ``TypeError``.
    if isinstance(message, str) and "\r" in message:
        message = message.replace("\r", "")
        # A progress-bar update that relied on "\r" still has to end its line.
        if not message or message[-1] != "\n":
            message += "\n"
        record.msg = message
    return True

Parameters
----------
serialization_dir : ``str``, required.
The directory to stream logs to.
file_friendly_logging : ``bool``, required.
Whether logs should clean the output to prevent carriage returns
(used to update progress bars on a single terminal line). This
option is typically only used if you are running in an environment
without a terminal.

Returns
-------
``logging.FileHandler``
A logging file handler that can later be closed and removed from the global logger.
"""
class WorkerLogFilter(Filter):
    """
    Prefixes each log message with the rank of the worker that produced it, so that
    per-worker log files can still be told apart after they are combined.

    Parameters
    ----------
    rank : ``int``, optional (default = -1)
        The rank of the current worker. ``-1`` disables the prefix entirely.
    """

    # Attribute set on a record once it has been prefixed, holding the rank used.
    _MARKER = "_worker_log_filter_rank"

    def __init__(self, rank=-1):
        super().__init__()
        self._rank = rank

    def filter(self, record):
        """
        Add the ``Rank <rank> | `` prefix to ``record.msg`` and let the record through.

        Always returns ``True``; this filter never drops records.
        """
        if self._rank == -1:
            return True

        # The very same ``LogRecord`` object is handed to every handler in turn, and
        # this filter may be attached to several of them (e.g. both the stdout and
        # the stderr file handlers). Without this guard, a record accepted by more
        # than one handler would come out as "Rank 0 | Rank 0 | ...".
        if getattr(record, self._MARKER, None) != self._rank:
            record.msg = f"Rank {self._rank} | {record.msg}"
            setattr(record, self._MARKER, self._rank)
        return True


def prepare_global_logging(
serialization_dir: str, file_friendly_logging: bool, rank: int = 0, world_size: int = 1
) -> None:
# If we don't have a terminal as stdout,
# force tqdm to be nicer.
if not sys.stdout.isatty():
file_friendly_logging = True

Tqdm.set_slower_interval(file_friendly_logging)
std_out_file = os.path.join(serialization_dir, "stdout.log")
sys.stdout = TeeLogger( # type: ignore
std_out_file, sys.stdout, file_friendly_logging
)
sys.stderr = TeeLogger( # type: ignore
os.path.join(serialization_dir, "stderr.log"), sys.stderr, file_friendly_logging
)

stdout_handler = logging.FileHandler(std_out_file)
stdout_handler.setFormatter(
logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")
)
logging.getLogger().addHandler(stdout_handler)
# Handlers for stdout/err logging
output_stream_log_handler = logging.StreamHandler(sys.stdout)
error_stream_log_handler = logging.StreamHandler(sys.stderr)

return stdout_handler
if world_size == 1:
# This case is not distributed training and hence will stick to the older
# log file names
output_file_log_handler = logging.FileHandler(
filename=os.path.join(serialization_dir, "stdout.log")
)
error_file_log_handler = logging.FileHandler(
filename=os.path.join(serialization_dir, "stderr.log")
)
else:
# Create log files with worker ids
output_file_log_handler = logging.FileHandler(
filename=os.path.join(serialization_dir, f"stdout_worker{rank}.log")
)
error_file_log_handler = logging.FileHandler(
filename=os.path.join(serialization_dir, f"stderr_worker{rank}.log")
)

# This adds the worker's rank to messages being logged to files.
# This will help when combining multiple worker log files using `less` command.
worker_filter = WorkerLogFilter(rank)
output_file_log_handler.addFilter(worker_filter)
error_file_log_handler.addFilter(worker_filter)

def cleanup_global_logging(stdout_handler: logging.FileHandler) -> None:
"""
This function closes any open file handles and logs set up by `prepare_global_logging`.
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")

Parameters
----------
stdout_handler : ``logging.FileHandler``, required.
The file handler returned from `prepare_global_logging`, attached to the global logger.
"""
stdout_handler.close()
logging.getLogger().removeHandler(stdout_handler)
root_logger = logging.getLogger()

# Remove the already set stream handler in root logger.
# Not doing this will result in duplicate log messages
# printed in the console
root_logger.removeHandler(root_logger.handlers[0])

# file handlers need to be handled for tqdm's \r char
file_friendly_log_filter = FileFriendlyLogFilter()

if os.environ.get("ALLENNLP_DEBUG"):
LEVEL = logging.DEBUG
else:
LEVEL = logging.INFO

if rank == 0:
# stdout/stderr handlers are added only for the
# master worker. This is to avoid cluttering the console
# screen with too many log messages from all workers.
output_stream_log_handler.setFormatter(formatter)
error_stream_log_handler.setFormatter(formatter)

output_stream_log_handler.setLevel(LEVEL)
error_stream_log_handler.setLevel(logging.ERROR)

if file_friendly_logging:
output_stream_log_handler.addFilter(file_friendly_log_filter)
error_stream_log_handler.addFilter(file_friendly_log_filter)

root_logger.addHandler(output_stream_log_handler)
root_logger.addHandler(error_stream_log_handler)

output_file_log_handler.addFilter(file_friendly_log_filter)
error_file_log_handler.addFilter(file_friendly_log_filter)

output_file_log_handler.setFormatter(formatter)
error_file_log_handler.setFormatter(formatter)

output_file_log_handler.setLevel(LEVEL)
error_file_log_handler.setLevel(logging.ERROR)

root_logger.addHandler(output_file_log_handler)
root_logger.addHandler(error_file_log_handler)

if isinstance(sys.stdout, TeeLogger):
sys.stdout = sys.stdout.cleanup()
if isinstance(sys.stderr, TeeLogger):
sys.stderr = sys.stderr.cleanup()
root_logger.setLevel(LEVEL)


LOADED_SPACY_MODELS: Dict[Tuple[str, bool, bool, bool], SpacyModelType] = {}
Expand Down
21 changes: 19 additions & 2 deletions allennlp/training/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Helper functions for Trainers
"""
import torch.distributed as dist
from typing import Any, Union, Dict, Iterable, List, Optional, Tuple
import datetime
import json
Expand Down Expand Up @@ -379,7 +380,12 @@ def rescale_gradients(model: Model, grad_norm: Optional[float] = None) -> Option


def get_metrics(
model: Model, total_loss: float, num_batches: int, reset: bool = False
model: Model,
total_loss: float,
num_batches: int,
reset: bool = False,
world_size: int = 1,
rank: int = 0,
) -> Dict[str, float]:
"""
Gets the metrics but sets ``"loss"`` to
Expand All @@ -388,7 +394,18 @@ def get_metrics(
"""
metrics = model.get_metrics(reset=reset)
metrics["loss"] = float(total_loss / num_batches) if num_batches > 0 else 0.0
return metrics

if world_size > 1:
# In distributed mode, average out all metrics across GPUs
aggregated_metrics = {}
for metric_name, metric_val in metrics.items():
metric_tensor = torch.tensor(metric_val).to(torch.device(rank))
dist.all_reduce(metric_tensor, op=dist.ReduceOp.SUM)
reduced_metric = metric_tensor.item() / world_size
aggregated_metrics[metric_name] = reduced_metric
return aggregated_metrics
else:
return metrics


def evaluate(
Expand Down
1 change: 0 additions & 1 deletion doc/api/allennlp.common.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ that's used by datasets, models, trainers, and so on.
allennlp.common.from_params
allennlp.common.params
allennlp.common.registrable
allennlp.common.tee_logger
allennlp.common.testing
allennlp.common.tqdm
allennlp.common.util
Expand Down
7 changes: 0 additions & 7 deletions doc/api/allennlp.common.tee_logger.rst

This file was deleted.

0 comments on commit 2d7a51b

Please sign in to comment.