DDP-related improvements to datamodule and logging (#594)
* Dividing batch size by number of devices in MNISTDataModule's setup fn
* .log file is now the same across devices when training in a DDP setting
* Adding rank-aware pylogger
tesfaldet committed Sep 18, 2023
1 parent 07ce4b7 commit 1fb5405
Showing 11 changed files with 93 additions and 39 deletions.
2 changes: 1 addition & 1 deletion configs/data/mnist.yaml
@@ -1,6 +1,6 @@
_target_: src.data.mnist_datamodule.MNISTDataModule
data_dir: ${paths.data_dir}
-batch_size: 128
+batch_size: 128 # Needs to be divisible by the number of devices (e.g., if in a distributed setup)
train_val_test_split: [55_000, 5_000, 10_000]
num_workers: 0
pin_memory: False
2 changes: 1 addition & 1 deletion configs/hydra/default.yaml
@@ -16,4 +16,4 @@ job_logging:
  handlers:
    file:
      # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
-      filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
+      filename: ${hydra.runtime.output_dir}/${task_name}.log
16 changes: 13 additions & 3 deletions src/data/mnist_datamodule.py
@@ -83,6 +83,8 @@ def __init__(
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

+        self.batch_size_per_device = batch_size
+
    @property
    def num_classes(self) -> int:
        """Get the number of classes.
@@ -112,6 +114,14 @@ def setup(self, stage: Optional[str] = None) -> None:
        :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
        """
+        # Divide batch size by the number of devices.
+        if self.trainer is not None:
+            if self.hparams.batch_size % self.trainer.world_size != 0:
+                raise RuntimeError(
+                    f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
+                )
+            self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size
+
        # load and split datasets only if not loaded already
        if not self.data_train and not self.data_val and not self.data_test:
            trainset = MNIST(self.hparams.data_dir, train=True, transform=self.transforms)
@@ -130,7 +140,7 @@ def train_dataloader(self) -> DataLoader[Any]:
"""
return DataLoader(
dataset=self.data_train,
batch_size=self.hparams.batch_size,
batch_size=self.batch_size_per_device,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=True,
@@ -143,7 +153,7 @@ def val_dataloader(self) -> DataLoader[Any]:
"""
return DataLoader(
dataset=self.data_val,
batch_size=self.hparams.batch_size,
batch_size=self.batch_size_per_device,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=False,
@@ -156,7 +166,7 @@ def test_dataloader(self) -> DataLoader[Any]:
"""
return DataLoader(
dataset=self.data_test,
batch_size=self.hparams.batch_size,
batch_size=self.batch_size_per_device,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=False,
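As an illustration (not part of the commit): a minimal sketch of the arithmetic the new `setup` hook performs. With DDP, every process builds its own dataloader, so the configured `batch_size` is treated as a global batch size and split evenly across `trainer.world_size` processes; the helper name below is hypothetical and mirrors the check added above.

# Hypothetical standalone helper mirroring the divisibility check in MNISTDataModule.setup (illustration only).
def batch_size_per_device(batch_size: int, world_size: int) -> int:
    if batch_size % world_size != 0:
        raise RuntimeError(
            f"Batch size ({batch_size}) is not divisible by the number of devices ({world_size})."
        )
    return batch_size // world_size


print(batch_size_per_device(128, 1))  # 128 on a single device
print(batch_size_per_device(128, 4))  # 32 per device, i.e. 4 * 32 = 128 samples per global step
# batch_size_per_device(128, 3) raises RuntimeError: 128 is not divisible by 3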
18 changes: 12 additions & 6 deletions src/eval.py
@@ -24,12 +24,18 @@
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #

-from src import utils
+from src.utils import (
+    RankedLogger,
+    extras,
+    instantiate_loggers,
+    log_hyperparameters,
+    task_wrapper,
+)

-log = utils.get_pylogger(__name__)
+log = RankedLogger(__name__, rank_zero_only=True)


-@utils.task_wrapper
+@task_wrapper
def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Evaluates given checkpoint on a datamodule testset.
@@ -48,7 +54,7 @@ def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating loggers...")
-    logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))
+    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)
@@ -63,7 +69,7 @@ def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:

    if logger:
        log.info("Logging hyperparameters!")
-        utils.log_hyperparameters(object_dict)
+        log_hyperparameters(object_dict)

    log.info("Starting testing!")
    trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
@@ -84,7 +90,7 @@ def main(cfg: DictConfig) -> None:
"""
# apply extra utilities
# (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
utils.extras(cfg)
extras(cfg)

evaluate(cfg)

24 changes: 16 additions & 8 deletions src/train.py
@@ -26,12 +26,20 @@
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #

-from src import utils
+from src.utils import (
+    RankedLogger,
+    extras,
+    get_metric_value,
+    instantiate_callbacks,
+    instantiate_loggers,
+    log_hyperparameters,
+    task_wrapper,
+)

-log = utils.get_pylogger(__name__)
+log = RankedLogger(__name__, rank_zero_only=True)


-@utils.task_wrapper
+@task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Trains the model. Can additionally evaluate on a testset, using best weights obtained during
training.
@@ -53,10 +61,10 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
-    callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))
+    callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
-    logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))
+    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
@@ -72,7 +80,7 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:

    if logger:
        log.info("Logging hyperparameters!")
-        utils.log_hyperparameters(object_dict)
+        log_hyperparameters(object_dict)

    if cfg.get("train"):
        log.info("Starting training!")
@@ -106,13 +114,13 @@ def main(cfg: DictConfig) -> Optional[float]:
"""
# apply extra utilities
# (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
utils.extras(cfg)
extras(cfg)

# train the model
metric_dict, _ = train(cfg)

# safely retrieve metric value for hydra-based hyperparameter optimization
metric_value = utils.get_metric_value(
metric_value = get_metric_value(
metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
)

2 changes: 1 addition & 1 deletion src/utils/__init__.py
@@ -1,5 +1,5 @@
from src.utils.instantiators import instantiate_callbacks, instantiate_loggers
from src.utils.logging_utils import log_hyperparameters
-from src.utils.pylogger import get_pylogger
+from src.utils.pylogger import RankedLogger
from src.utils.rich_utils import enforce_tags, print_config_tree
from src.utils.utils import extras, get_metric_value, task_wrapper
2 changes: 1 addition & 1 deletion src/utils/instantiators.py
@@ -7,7 +7,7 @@

from src.utils import pylogger

-log = pylogger.get_pylogger(__name__)
+log = pylogger.RankedLogger(__name__, rank_zero_only=True)


def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
4 changes: 2 additions & 2 deletions src/utils/logging_utils.py
@@ -1,11 +1,11 @@
from typing import Any, Dict

-from lightning.pytorch.utilities import rank_zero_only
+from lightning_utilities.core.rank_zero import rank_zero_only
from omegaconf import OmegaConf

from src.utils import pylogger

-log = pylogger.get_pylogger(__name__)
+log = pylogger.RankedLogger(__name__, rank_zero_only=True)


@rank_zero_only
56 changes: 43 additions & 13 deletions src/utils/pylogger.py
@@ -1,21 +1,51 @@
import logging
+from typing import Mapping, Optional

-from lightning.pytorch.utilities import rank_zero_only
+from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only


-def get_pylogger(name: str = __name__) -> logging.Logger:
-    """Initializes a multi-GPU-friendly python command line logger.
-
-    :param name: The name of the logger, defaults to ``__name__``.
-    :return: A logger object.
-    """
-    logger = logging.getLogger(name)
-
-    # this ensures all logging levels get marked with the rank zero decorator
-    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
-    logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
-    for level in logging_levels:
-        setattr(logger, level, rank_zero_only(getattr(logger, level)))
-
-    return logger
+class RankedLogger(logging.LoggerAdapter):
+    """A multi-GPU-friendly python command line logger."""
+
+    def __init__(
+        self,
+        name: str = __name__,
+        rank_zero_only: bool = False,
+        extra: Optional[Mapping[str, object]] = None,
+    ) -> None:
+        """Initializes a multi-GPU-friendly python command line logger that logs on all processes
+        with their rank prefixed in the log message.
+
+        :param name: The name of the logger. Default is ``__name__``.
+        :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
+        :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
+        """
+        logger = logging.getLogger(name)
+        super().__init__(logger=logger, extra=extra)
+        self.rank_zero_only = rank_zero_only
+
+    def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None:
+        """Delegate a log call to the underlying logger, after prefixing its message with the rank
+        of the process it's being logged from. If `'rank'` is provided, then the log will only
+        occur on that rank/process.
+
+        :param level: The level to log at. Look at `logging.__init__.py` for more information.
+        :param msg: The message to log.
+        :param rank: The rank to log at.
+        :param args: Additional args to pass to the underlying logging function.
+        :param kwargs: Any additional keyword args to pass to the underlying logging function.
+        """
+        if self.isEnabledFor(level):
+            msg, kwargs = self.process(msg, kwargs)
+            current_rank = getattr(rank_zero_only, "rank", None)
+            if current_rank is None:
+                raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")
+            msg = rank_prefixed_message(msg, current_rank)
+            if self.rank_zero_only:
+                if current_rank == 0:
+                    self.logger.log(level, msg, *args, **kwargs)
+            else:
+                if rank is None:
+                    self.logger.log(level, msg, *args, **kwargs)
+                elif current_rank == rank:
+                    self.logger.log(level, msg, *args, **kwargs)
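As an illustration (not part of the commit): a minimal usage sketch of the new RankedLogger, assuming it runs inside a Lightning job where `rank_zero_only.rank` has already been set (otherwise `log` raises the RuntimeError shown above).

import logging

from src.utils.pylogger import RankedLogger

# Assumes rank_zero_only.rank is set (Lightning sets it when a Trainer is created).
# Logs on every process; each message is prefixed with the emitting rank.
log = RankedLogger(__name__, rank_zero_only=False)
log.info("visible on all ranks")

# Restrict a single call to one rank via the extra `rank` argument.
log.info("visible only on rank 1", rank=1)
log.log(logging.WARNING, "visible only on rank 0", rank=0)

# As used in train.py/eval.py: force every call to rank 0 only.
log0 = RankedLogger(__name__, rank_zero_only=True)
log0.info("visible only on rank 0")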
4 changes: 2 additions & 2 deletions src/utils/rich_utils.py
@@ -5,13 +5,13 @@
import rich.syntax
import rich.tree
from hydra.core.hydra_config import HydraConfig
-from lightning.pytorch.utilities import rank_zero_only
+from lightning_utilities.core.rank_zero import rank_zero_only
from omegaconf import DictConfig, OmegaConf, open_dict
from rich.prompt import Prompt

from src.utils import pylogger

-log = pylogger.get_pylogger(__name__)
+log = pylogger.RankedLogger(__name__, rank_zero_only=True)


@rank_zero_only
2 changes: 1 addition & 1 deletion src/utils/utils.py
@@ -6,7 +6,7 @@

from src.utils import pylogger, rich_utils

-log = pylogger.get_pylogger(__name__)
+log = pylogger.RankedLogger(__name__, rank_zero_only=True)


def extras(cfg: DictConfig) -> None:
