add DETAIL logs for batch use cases #11008

Merged: 8 commits, Jan 12, 2022
CHANGELOG.md (2 additions, 0 deletions)
@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Add new `DETAIL` log level to provide useful logs for improving monitoring and debugging of batch jobs


- Added a flag `SLURMEnvironment(auto_requeue=True|False)` to control whether Lightning handles the requeuing ([#10601](https://github.com/PyTorchLightning/pytorch-lightning/issues/10601))

pytorch_lightning/__init__.py (14 additions, 0 deletions)
@@ -1,9 +1,23 @@
"""Root package info."""

import logging
from typing import Any

from pytorch_lightning.__about__ import * # noqa: F401, F403

_DETAIL = 15 # between logging.INFO and logging.DEBUG, used for logging in production use cases


def _detail(self: Any, message: str, *args: Any, **kwargs: Any) -> None:
if self.isEnabledFor(_DETAIL):
# logger takes its '*args' as 'args'
self._log(_DETAIL, message, args, **kwargs)


logging.addLevelName(_DETAIL, "DETAIL")
logging.detail = _detail
logging.Logger.detail = _detail

_root_logger = logging.getLogger()
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
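For context, a minimal sketch of how the new level could be enabled from user code, mirroring the pattern in the test added to `tests/utilities/test_warnings.py` below; it assumes `_DETAIL` remains importable from the package root:

```python
import logging

import pytorch_lightning as pl

lightning_logger = logging.getLogger("pytorch_lightning")

# DETAIL (15) sits between DEBUG (10) and INFO (20): opting in surfaces the
# new DETAIL records without also enabling DEBUG output.
lightning_logger.setLevel(pl._DETAIL)

# After `import pytorch_lightning`, every Logger instance gains a `detail`
# method via the monkey-patch above; records below the logger's effective
# level are still filtered as usual.
lightning_logger.detail("starting batch job")
```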
pytorch_lightning/loops/fit_loop.py (3 additions, 0 deletions)
@@ -208,6 +208,7 @@ def on_advance_start(self) -> None: # type: ignore[override]

# reset train dataloader
if not self._is_fresh_start_epoch and self.trainer._data_connector._should_reload_train_dl:
log.detail(f"{self.__class__.__name__}: resetting train dataloader")
self.trainer.reset_train_dataloader(model)
self._is_fresh_start_epoch = False

@@ -240,6 +241,7 @@ def on_advance_start(self) -> None: # type: ignore[override]

def advance(self) -> None: # type: ignore[override]
"""Runs one whole epoch."""
log.detail(f"{self.__class__.__name__}: advancing loop")
assert self.trainer.train_dataloader is not None
dataloader = self.trainer.strategy.process_dataloader(self.trainer.train_dataloader)
data_fetcher = self.trainer._data_connector.get_profiled_dataloader(dataloader)
@@ -300,6 +302,7 @@ def on_advance_end(self) -> None:

def on_run_end(self) -> None:
"""Calls the ``on_train_end`` hook."""
log.detail(f"{self.__class__.__name__}: train run ended")
# NOTE: the current_epoch is already incremented
# Lightning today does not increment the current epoch at the last epoch run in Trainer.fit
# To simulate that current behavior, we decrement here.
pytorch_lightning/strategies/ddp.py (11 additions, 1 deletion)
@@ -102,6 +102,7 @@ def __init__(
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
)
log.detail(f"{self.__class__.__name__}: initializing DDP plugin")
self.interactive_ddp_procs = []
self._num_nodes = 1
self.sync_batchnorm = False
@@ -171,7 +172,9 @@ def setup(self, trainer: "pl.Trainer") -> None:

def _setup_model(self, model: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)
device_ids = self.determine_ddp_device_ids()
log.detail(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)

def _call_children_scripts(self):
# bookkeeping of spawned processes
@@ -243,6 +246,7 @@ def _call_children_scripts(self):
self._rank_0_has_called_call_children_scripts = True

def setup_distributed(self):
log.detail(f"{self.__class__.__name__}: setting up distributed...")
reset_seed()

# determine which process we are and world size
@@ -288,6 +292,7 @@ def pre_configure_ddp(self):
self._ddp_kwargs["find_unused_parameters"] = True

def _register_ddp_hooks(self) -> None:
log.detail(f"{self.__class__.__name__}: registering ddp hooks")
# In 1.8, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
# Since 1.9, DDP communication hooks can work on all backends.
if _TORCH_GREATER_EQUAL_1_9 or (
@@ -307,6 +312,7 @@ def _register_ddp_hooks(self) -> None:
self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter)

def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
log.detail(f"{self.__class__.__name__}: reinitializing optimizers with post localSGD")
optimizers = self.lightning_module.trainer.optimizers
if self._model_averaging_period is None:
raise ValueError(
@@ -350,6 +356,7 @@ def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
_convert_to_lightning_optimizers(trainer)

def configure_ddp(self) -> None:
log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel")
self.pre_configure_ddp()
self.model = self._setup_model(LightningDistributedModule(self.model))
self._register_ddp_hooks()
@@ -380,6 +387,7 @@ def pre_backward(self, closure_loss: torch.Tensor) -> None:
prepare_for_backward(self.model, closure_loss)

def model_to_device(self):
log.detail(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
self.model.to(self.root_device)

def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> torch.Tensor:
@@ -500,6 +508,7 @@ def reconciliate_processes(self, trace: str) -> None:
raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")

def teardown(self) -> None:
log.detail(f"{self.__class__.__name__}: tearing down DDP plugin")
super().teardown()
if isinstance(self.model, DistributedDataParallel):
self.model = self.lightning_module
@@ -509,6 +518,7 @@ def teardown(self) -> None:

if self.on_gpu:
# GPU teardown
log.detail(f"{self.__class__.__name__}: moving model to CPU")
self.lightning_module.cpu()
# clean up memory
torch.cuda.empty_cache()
pytorch_lightning/strategies/fully_sharded.py (8 additions, 0 deletions)
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import logging
from typing import Dict, Generator, List, Optional

import torch
@@ -30,6 +31,8 @@
from fairscale.nn import default_auto_wrap_policy, enable_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel

log = logging.getLogger(__name__)


class DDPFullyShardedStrategy(DDPStrategy):

@@ -144,6 +147,7 @@ def setup(self, trainer: "pl.Trainer") -> None:

@contextlib.contextmanager
def model_sharded_context(self) -> Generator:
log.detail(f"{self.__class__.__name__}: entered model_sharded_context.")
precision = self.precision_plugin.precision

def wrap_policy(*args, **kwargs):
@@ -165,7 +169,10 @@ def wrap_policy(*args, **kwargs):
):
yield

log.detail(f"{self.__class__.__name__}: exiting model_sharded_context.")

def configure_ddp(self) -> None:
log.detail(f"{self.__class__.__name__}: configuring DDP... (cpu_offload: [{self.cpu_offload}])")
if not self.cpu_offload:
# When using CPU Offload, FSDP will manage the CUDA movement for us.
# Note: this would be problematic for large model (which could not fit in one GPU)
@@ -177,6 +184,7 @@ def configure_ddp(self) -> None:
self.setup_optimizers(self.lightning_module.trainer)

def model_to_device(self) -> None:
log.detail(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
# ensure we update the device type in the lightning module
self.lightning_module.to(self.root_device)

pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import re
from typing import Any, Dict, Optional
@@ -35,6 +36,9 @@
from omegaconf import Container


log: logging.Logger = logging.getLogger(__name__)


class CheckpointConnector:
def __init__(self, trainer: "pl.Trainer", resume_from_checkpoint: Optional[_PATH] = None) -> None:
self.trainer = trainer
@@ -74,6 +78,7 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
self.resume_checkpoint_path = self._hpc_resume_path or self._fault_tolerant_auto_resume_path or checkpoint_path
checkpoint_path = self.resume_checkpoint_path
if not checkpoint_path:
log.detail("`checkpoint_path` not specified. Skipping checkpoint loading.")
return

rank_zero_info(f"Restoring states from the checkpoint path at {checkpoint_path}")
pytorch_lightning/trainer/trainer.py (14 additions, 0 deletions)
@@ -432,6 +432,7 @@ def __init__(
"""
super().__init__()
Trainer._log_api_event("init")
log.detail(f"{self.__class__.__name__}: Initializing trainer with parameters: {locals()}")
self.state = TrainerState()

gpu_ids, tpu_cores = self._parse_devices(gpus, auto_select_gpus, tpu_cores)
@@ -744,6 +745,7 @@ def _fit_impl(
ckpt_path: Optional[str] = None,
) -> None:
Trainer._log_api_event("fit")
log.detail(f"{self.__class__.__name__}: trainer fit stage")

self.state.fn = TrainerFn.FITTING
self.state.status = TrainerStatus.RUNNING
@@ -821,6 +823,7 @@ def _validate_impl(
# SETUP HOOK
# --------------------
Trainer._log_api_event("validate")
log.detail(f"{self.__class__.__name__}: trainer validate stage")

self.state.fn = TrainerFn.VALIDATING
self.state.status = TrainerStatus.RUNNING
@@ -906,6 +909,7 @@ def _test_impl(
# SETUP HOOK
# --------------------
Trainer._log_api_event("test")
log.detail(f"{self.__class__.__name__}: trainer test stage")

self.state.fn = TrainerFn.TESTING
self.state.status = TrainerStatus.RUNNING
@@ -992,6 +996,7 @@ def _predict_impl(
# SETUP HOOK
# --------------------
Trainer._log_api_event("predict")
log.detail(f"{self.__class__.__name__}: trainer predict stage")

self.state.fn = TrainerFn.PREDICTING
self.state.status = TrainerStatus.RUNNING
@@ -1107,19 +1112,23 @@ def _run(
verify_loop_configurations(self)

# hook
log.detail(f"{self.__class__.__name__}: preparing data")
self._data_connector.prepare_data()

# ----------------------------
# SET UP TRAINING
# ----------------------------
self._call_callback_hooks("on_before_accelerator_backend_setup")
log.detail(f"{self.__class__.__name__}: setting up strategy environment")
self.strategy.setup_environment()
self._call_setup_hook() # allow user to setup lightning_module in accelerator environment

# check if we should delay restoring checkpoint till later
if not self.strategy.restore_checkpoint_after_setup:
log.detail(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
self._restore_modules_and_callbacks(ckpt_path)

log.detail(f"{self.__class__.__name__}: configuring sharded model")
self._call_configure_sharded_model() # allow user to setup in model sharded environment

# ----------------------------
@@ -1165,14 +1174,17 @@ def _run(
self._log_hyperparams()

if self.strategy.restore_checkpoint_after_setup:
log.detail(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
self._restore_modules_and_callbacks(ckpt_path)

# restore optimizers, etc.
log.detail(f"{self.__class__.__name__}: restoring training state")
self.checkpoint_connector.restore_training_state()
Comment on lines +1181 to 1182

@awaelchli (Contributor, Jan 16, 2022):
I disagree with most changes in this PR. Instead of moving the log calls into the actual methods, this is now littering the trainer with more code that is irrelevant to the reader. Our goal is to make the Trainer more readable and easier for curious contributors and researchers to trace. This PR is going in the wrong direction IMO. Please reconsider this; I feel very strongly about this.

Contributor:
cc @PyTorchLightning/core-lightning

Contributor:
The choice of logging in the caller vs. the callee is as old as logging itself, and it is a tradeoff between flexibility and convenience. I don't think these log calls make the Trainer unreadable; they are just glorified prints. You could see their value as a replacement for comments.

> Please reconsider this, I feel very strongly about this.

However, I don't feel very strongly about it, so if you do, I'm okay with changing these.

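To make the tradeoff concrete, a minimal, hypothetical sketch of the two styles under debate; the names echo the fit-loop code above but are illustrative only, and plain `debug` stands in for the custom `detail` level:

```python
import logging

log = logging.getLogger(__name__)


def reset_train_dataloader() -> None:
    # Callee-side logging: the method announces its own work, so every
    # caller gets the record for free and call sites stay clean.
    log.debug("resetting train dataloader")


def on_advance_start() -> None:
    # Caller-side logging (the pattern used for the Trainer additions in
    # this PR): the call site states what it is about to do, keeping the
    # callee free of logging concerns but adding lines to the caller.
    log.debug("FitLoop: resetting train dataloader")
    reset_train_dataloader()
```

Either way the emitted records are equivalent; the disagreement is only about where the call lives.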
self.checkpoint_connector.resume_end()

results = self._run_stage()
log.detail(f"{self.__class__.__name__}: trainer tearing down")
self._teardown()

# ----------------------------
@@ -1183,6 +1195,7 @@ def _run(
self._call_callback_hooks("on_fit_end")
self._call_lightning_module_hook("on_fit_end")

log.detail(f"{self.__class__.__name__}: calling teardown hooks")
self._call_teardown_hook()

if self.state.status != TrainerStatus.INTERRUPTED:
@@ -1554,6 +1567,7 @@ def _call_callback_hooks(
*args: Any,
**kwargs: Any,
) -> None:
log.detail(f"{self.__class__.__name__}: calling callback hook: {hook_name}")
# TODO: remove if block in v1.8
if hook_name in ("on_init_start", "on_init_end"):
# these `Callback` hooks are the only ones that do not take a lightning module.
tests/utilities/test_warnings.py (14 additions, 0 deletions)
@@ -54,6 +54,8 @@
# check that logging is properly configured
import logging

from pytorch_lightning import _DETAIL

root_logger = logging.getLogger()
lightning_logger = logging.getLogger("pytorch_lightning")
# should have a `StreamHandler`
@@ -77,3 +79,15 @@

output = stderr.getvalue()
assert output == "test2\n", repr(output)

stderr = StringIO()
lightning_logger.handlers[0].stream = stderr
with redirect_stderr(stderr):
# Lightning should not output DETAIL level logging by default
lightning_logger.detail("test1")
lightning_logger.setLevel(_DETAIL)
lightning_logger.detail("test2")
# logger should not output anything for DEBUG statements if set to DETAIL
lightning_logger.debug("test3")
output = stderr.getvalue()
assert output == "test2\n", repr(output)