Apply isort and black reformatting
Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>
marcromeyn committed Jun 17, 2024
1 parent cff1ac5 commit d620391
Showing 2 changed files with 18 additions and 18 deletions.
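
As background, isort groups imports into standard-library, third-party, and first-party sections and sorts each group alphabetically, while black normalizes whitespace and line breaks without changing behavior. A minimal sketch of the same pipeline applied programmatically (illustrative only, using both tools' public APIs with default settings; the snippet is not part of this commit):

import black
import isort

# An import block in the pre-commit order, similar to nemo/lightning/__init__.py.
source = (
    "from nemo.lightning.pytorch.plugins import MegatronDataSampler\n"
    "from nemo.lightning.pytorch.callbacks.megatron_model_checkpoint import ModelCheckpoint\n"
    "from nemo.lightning.pytorch.opt import OptimizerModule\n"
)

# isort sorts the imports alphabetically by module path (callbacks, opt, plugins);
# black then re-renders the result with its default formatting rules.
sorted_source = isort.code(source)
formatted = black.format_str(sorted_source, mode=black.Mode())
print(formatted)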
nemo/lightning/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -11,10 +11,10 @@

from nemo.lightning.base import get_vocab_size, teardown
from nemo.lightning.nemo_logger import NeMoLogger
from nemo.lightning.pytorch.callbacks.megatron_model_checkpoint import ModelCheckpoint
from nemo.lightning.pytorch.opt import LRSchedulerModule, MegatronOptimizerModule, OptimizerModule
from nemo.lightning.pytorch.plugins import MegatronDataSampler, MegatronMixedPrecision
from nemo.lightning.pytorch.plugins import data_sampler as _data_sampler
from nemo.lightning.pytorch.callbacks.megatron_model_checkpoint import ModelCheckpoint
from nemo.lightning.pytorch.strategies import MegatronStrategy
from nemo.lightning.pytorch.trainer import Trainer
from nemo.lightning.resume import AutoResume
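The only change in __init__.py is the position of the ModelCheckpoint import: isort moves it above the opt and plugins imports so the nemo.lightning.pytorch imports read alphabetically (callbacks, opt, plugins, strategies, trainer).
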
nemo/lightning/experiment.py (34 changes: 17 additions & 17 deletions)
@@ -1,21 +1,22 @@
from dataclasses import dataclass
from typing import Optional, List, Union
import sys
import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union

import lightning_fabric as fl
import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint as PTLModelCheckpoint
import lightning_fabric as fl

from nemo.constants import NEMO_ENV_VARNAME_TESTING, NEMO_ENV_VARNAME_VERSION
from nemo.lightning.pytorch.callbacks import ModelCheckpoint
from nemo.utils import logging
from nemo.utils.app_state import AppState
from nemo.utils.env_var_parsing import get_envbool
from nemo.utils.exp_manager import check_explicit_log_dir
from nemo.utils.get_rank import is_global_rank_zero
from nemo.utils.app_state import AppState
from nemo.utils.mcore_logger import add_handlers_to_mcore_logger
from nemo.lightning.pytorch.callbacks import ModelCheckpoint


@dataclass
@@ -29,13 +30,13 @@ class Experiment:
log_global_rank_0_only: bool = False
files_to_copy: Optional[List[str]] = None
update_logger_directory: bool = True

def __post_init__(self):
if self.log_local_rank_0_only is True and self.log_global_rank_0_only is True:
raise ValueError(
f"Cannot set both log_local_rank_0_only and log_global_rank_0_only to True. Please set either one or neither."
)

def setup(self, trainer: Union[pl.Trainer, fl.Fabric], resume_if_exists: bool = False):
local_rank = int(os.environ.get("LOCAL_RANK", 0))
global_rank = trainer.node_rank * trainer.world_size + local_rank
@@ -51,15 +52,15 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], resume_if_exists: bool =

if not self.name:
self.name = "default"

if isinstance(trainer, pl.Trainer) and trainer.logger is not None:
if self.update_logger_directory:
logging.warning(
f'"update_logger_directory" is True. Overwriting logger "save_dir" to {_dir} and "name" to {self.name}'
)
trainer.logger._root_dir = _dir
trainer.logger._name = self.name

version = self.version or os.environ.get(NEMO_ENV_VARNAME_VERSION, None)
if is_global_rank_zero():
if self.use_datetime_version:
@@ -68,7 +69,7 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], resume_if_exists: bool =
logging.warning(
"No version folders would be created under the log folder as 'resume_if_exists' is enabled."
)
version = None
version = None
if version:
if is_global_rank_zero():
os.environ[NEMO_ENV_VARNAME_VERSION] = version
@@ -80,11 +81,11 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], resume_if_exists: bool =
app_state.exp_dir = _dir
app_state.name = self.name
app_state.version = version

os.makedirs(log_dir, exist_ok=True) # Cannot limit creation to global zero as all ranks write to own log file
logging.info(f'Experiments will be logged at {log_dir}')
if isinstance(trainer, pl.Trainer):

if isinstance(trainer, pl.Trainer):
for callback in trainer.callbacks:
if isinstance(callback, PTLModelCheckpoint):
## TODO: make configurable
@@ -95,7 +96,6 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], resume_if_exists: bool =
callback.prefix = name
ModelCheckpoint.CHECKPOINT_NAME_LAST = callback.filename + '-last'


# This is set if the env var NEMO_TESTING is set to True.
nemo_testing = get_envbool(NEMO_ENV_VARNAME_TESTING, False)

@@ -115,8 +115,8 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], resume_if_exists: bool =

app_state.files_to_copy = self.files_to_copy
app_state.cmd_args = sys.argv

return app_state

def teardown(self):
pass
pass
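
In experiment.py, the first hunk reorders the imports within their standard-library, third-party, and nemo groups and sorts the names inside from-imports (typing now reads List, Optional, Union); the remaining hunks appear to be whitespace-only adjustments of the kind black makes, which is why lines such as version = None and pass show up twice: the diff records both the old and the reformatted version of the same line. To check a file against both tools locally, a minimal sketch (illustrative only; the helper name and default modes are assumptions, not NeMo's actual lint configuration):

import black
import isort

def is_formatted(source: str) -> bool:
    # True when isort would leave the import order untouched and black
    # would leave the layout untouched, using both tools' default settings.
    imports_ok = isort.code(source) == source
    layout_ok = black.format_str(source, mode=black.Mode()) == source
    return imports_ok and layout_ok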
