-
Notifications
You must be signed in to change notification settings - Fork 595
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
DDP-related improvements to datamodule and logging (#594)
* Dividing batch size by number of devices in MNISTDataModule's setup fn
* .log file is now the same across devices when training in a DDP setting
* Adding rank-aware pylogger
- Loading branch information
Showing 11 changed files with 93 additions and 39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
_target_: src.data.mnist_datamodule.MNISTDataModule | ||
data_dir: ${paths.data_dir} | ||
batch_size: 128 | ||
batch_size: 128 # Needs to be divisible by the number of devices (e.g., if in a distributed setup) | ||
train_val_test_split: [55_000, 5_000, 10_000] | ||
num_workers: 0 | ||
pin_memory: False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
from src.utils.instantiators import instantiate_callbacks, instantiate_loggers | ||
from src.utils.logging_utils import log_hyperparameters | ||
from src.utils.pylogger import get_pylogger | ||
from src.utils.pylogger import RankedLogger | ||
from src.utils.rich_utils import enforce_tags, print_config_tree | ||
from src.utils.utils import extras, get_metric_value, task_wrapper |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,21 +1,51 @@ | ||
import logging | ||
from typing import Mapping, Optional | ||
|
||
from lightning.pytorch.utilities import rank_zero_only | ||
from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only | ||
|
||
|
||
def get_pylogger(name: str = __name__) -> logging.Logger: | ||
"""Initializes a multi-GPU-friendly python command line logger. | ||
class RankedLogger(logging.LoggerAdapter): | ||
"""A multi-GPU-friendly python command line logger.""" | ||
|
||
:param name: The name of the logger, defaults to ``__name__``. | ||
def __init__( | ||
self, | ||
name: str = __name__, | ||
rank_zero_only: bool = False, | ||
extra: Optional[Mapping[str, object]] = None, | ||
) -> None: | ||
"""Initializes a multi-GPU-friendly python command line logger that logs on all processes | ||
with their rank prefixed in the log message. | ||
:return: A logger object. | ||
""" | ||
logger = logging.getLogger(name) | ||
:param name: The name of the logger. Default is ``__name__``. | ||
:param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. | ||
:param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. | ||
""" | ||
logger = logging.getLogger(name) | ||
super().__init__(logger=logger, extra=extra) | ||
self.rank_zero_only = rank_zero_only | ||
|
||
# this ensures all logging levels get marked with the rank zero decorator | ||
# otherwise logs would get multiplied for each GPU process in multi-GPU setup | ||
logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") | ||
for level in logging_levels: | ||
setattr(logger, level, rank_zero_only(getattr(logger, level))) | ||
def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None: | ||
"""Delegate a log call to the underlying logger, after prefixing its message with the rank | ||
of the process it's being logged from. If `'rank'` is provided, then the log will only | ||
occur on that rank/process. | ||
return logger | ||
:param level: The level to log at. Look at `logging.__init__.py` for more information. | ||
:param msg: The message to log. | ||
:param rank: The rank to log at. | ||
:param args: Additional args to pass to the underlying logging function. | ||
:param kwargs: Any additional keyword args to pass to the underlying logging function. | ||
""" | ||
if self.isEnabledFor(level): | ||
msg, kwargs = self.process(msg, kwargs) | ||
current_rank = getattr(rank_zero_only, "rank", None) | ||
if current_rank is None: | ||
raise RuntimeError("The `rank_zero_only.rank` needs to be set before use") | ||
msg = rank_prefixed_message(msg, current_rank) | ||
if self.rank_zero_only: | ||
if current_rank == 0: | ||
self.logger.log(level, msg, *args, **kwargs) | ||
else: | ||
if rank is None: | ||
self.logger.log(level, msg, *args, **kwargs) | ||
elif current_rank == rank: | ||
self.logger.log(level, msg, *args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters