Updates for Hydra and OmegaConf updates
Signed-off-by: smajumdar <titu1994@gmail.com>
titu1994 committed Jun 11, 2021
1 parent 6cba93c commit d21ef44
Showing 5 changed files with 30 additions and 10 deletions.
nemo/core/config/hydra_runner.py (6 changes: 4 additions & 2 deletions)
@@ -24,14 +24,17 @@


def hydra_runner(
- config_path: Optional[str] = None, config_name: Optional[str] = None, schema: Optional[Any] = None
+ config_path: Optional[str] = ".", config_name: Optional[str] = None, schema: Optional[Any] = None
) -> Callable[[TaskFunction], Any]:
"""
Decorator used for passing the Config paths to main function.
Optionally registers a schema used for validation/providing default values.
Args:
config_path: Optional path that will be added to config search directory.
+ NOTE: The default value of `config_path` has changed between Hydra 1.0 and Hydra 1.1+.
+ Please refer to https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_hydra_main_config_path/
+ for details.
config_name: Pathname of the config file.
schema: Structured config type representing the schema used for validation/providing default values.
"""
@@ -100,7 +103,6 @@ def parse_args(self, args=None, namespace=None):
task_function=task_function,
config_path=config_path,
config_name=config_name,
- strict=None,
)

return wrapper
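For context, a minimal sketch of how the decorator is typically used after this change (the script and `my_config` name are hypothetical, not from the commit): under Hydra 1.1 a `config_path` of `None` no longer adds the calling script's directory to the search path, so defaulting to `"."` keeps a config that lives next to the decorated script discoverable without any extra arguments.

```python
# Hypothetical training script; "my_config" is an assumed config name, expected
# to live as my_config.yaml next to this script because config_path now
# defaults to "." (resolved relative to the script under Hydra 1.1+).
from omegaconf import DictConfig, OmegaConf

from nemo.core.config import hydra_runner


@hydra_runner(config_name="my_config")
def main(cfg: DictConfig) -> None:
    # Print the composed configuration that Hydra resolved for this run.
    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    main()
```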
nemo/core/config/schedulers.py (1 change: 1 addition & 0 deletions)
@@ -34,6 +34,7 @@ class WarmupSchedulerParams(SchedulerParams):
It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
"""

+ max_steps: int = 0
warmup_steps: Optional[float] = None
warmup_ratio: Optional[float] = None

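As a quick illustration of why the default matters, the sketch below (assuming `WarmupSchedulerParams` is an OmegaConf-compatible dataclass, like its siblings in this module) builds a structured config without setting `max_steps`; the new `0` default lets validation pass so the value can be filled in later, e.g. computed from trainer and dataloader settings.

```python
# Minimal sketch: max_steps no longer has to be provided up front.
from omegaconf import OmegaConf

from nemo.core.config.schedulers import WarmupSchedulerParams

params = OmegaConf.structured(WarmupSchedulerParams(warmup_steps=1000))
print(OmegaConf.to_yaml(params))
# Expected to include (roughly):
#   max_steps: 0
#   warmup_steps: 1000.0
#   warmup_ratio: null
```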
nemo/core/optim/lr_scheduler.py (7 changes: 5 additions & 2 deletions)
@@ -27,6 +27,7 @@
from torch.optim.lr_scheduler import _LRScheduler

from nemo.core.config import SchedulerParams, get_scheduler_config, register_scheduler_params
+ from nemo.utils.model_utils import maybe_update_config_version
from nemo.utils import logging


@@ -453,6 +454,8 @@ def prepare_lr_scheduler(
A dictionary containing the LR Scheduler implementation if the config was successfully parsed
along with other parameters required by Pytorch Lightning, otherwise None.
"""
+ scheduler_config = maybe_update_config_version(scheduler_config)
+
# Build nested dictionary for convenience out of structured objects
if isinstance(scheduler_config, DictConfig):
scheduler_config = OmegaConf.to_container(scheduler_config, resolve=True)
@@ -493,7 +496,7 @@ def prepare_lr_scheduler(
return None

# Try instantiation of scheduler params from config class path
- try:
+ if '_target_' in scheduler_args:
scheduler_args_cfg = OmegaConf.create(scheduler_args)
scheduler_conf = hydra.utils.instantiate(scheduler_args_cfg)
scheduler_args = vars(scheduler_conf)
@@ -504,7 +507,7 @@
if 'Params' in scheduler_name:
scheduler_name = scheduler_name.replace('Params', '')

- except Exception:
+ else:
# Class path instantiation failed; try resolving "name" component

# Get name of the scheduler
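The new control flow keys off an explicit `_target_` membership test instead of swallowing arbitrary exceptions, and the scheduler config is normalized by `maybe_update_config_version` before that test runs. A hedged sketch of the `_target_` branch in isolation (the class path and field values are illustrative, not taken from the commit):

```python
# Sketch of the branch prepare_lr_scheduler takes when the scheduler args carry
# an explicit _target_: instantiate the params dataclass via Hydra, then
# flatten it back into a plain dict of keyword arguments.
import hydra
from omegaconf import OmegaConf

scheduler_args = {
    "_target_": "nemo.core.config.schedulers.CosineAnnealingParams",  # assumed class path
    "warmup_steps": 500,
    "min_lr": 1.0e-5,
}

if "_target_" in scheduler_args:
    scheduler_args_cfg = OmegaConf.create(scheduler_args)
    scheduler_conf = hydra.utils.instantiate(scheduler_args_cfg)
    scheduler_args = vars(scheduler_conf)

print(scheduler_args)  # plain dict of the params dataclass fields
```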
nemo/core/optim/optimizers.py (9 changes: 5 additions & 4 deletions)
@@ -24,6 +24,7 @@

from nemo.core.config import OptimizerParams, get_optimizer_config, register_optimizer_params
from nemo.core.optim.novograd import Novograd
+ from nemo.utils.model_utils import maybe_update_config_version
from nemo.utils import logging

AVAILABLE_OPTIMIZERS = {
@@ -74,20 +75,19 @@ def parse_optimizer_args(
return kwargs

optimizer_kwargs = copy.deepcopy(optimizer_kwargs)
+ optimizer_kwargs = maybe_update_config_version(optimizer_kwargs)

if isinstance(optimizer_kwargs, DictConfig):
optimizer_kwargs = OmegaConf.to_container(optimizer_kwargs, resolve=True)

# If it is a dictionary, perform stepwise resolution
if hasattr(optimizer_kwargs, 'keys'):
# Attempt class path resolution
- try:
+ if '_target_' in optimizer_kwargs: # captures (target, _target_)
optimizer_kwargs_config = OmegaConf.create(optimizer_kwargs)
optimizer_instance = hydra.utils.instantiate(optimizer_kwargs_config) # type: DictConfig
optimizer_instance = vars(optimizer_instance)
return optimizer_instance
- except Exception:
- pass

# If class path was not provided, perhaps `name` is provided for resolution
if 'name' in optimizer_kwargs:
@@ -114,7 +114,8 @@

# If we are provided just a Config object, simply return the dictionary of that object
if optimizer_params_name is None:
- optimizer_params = vars(optimizer_params_cls)
+ optimizer_params = optimizer_params_cls
+ optimizer_params = vars(optimizer_params)
return optimizer_params

else:
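Together with the `model_utils.py` change below, this means optimizer kwargs written in the legacy `cls`/`params` layout are normalized before the explicit `'_target_'` check runs, rather than relying on a silent try/except. A hedged sketch (the class path and hyperparameter values are illustrative, not from the commit):

```python
# Sketch: old-style optimizer kwargs are normalized (cls -> _target_, params
# flattened) and then instantiated through the explicit _target_ branch.
from nemo.core.optim.optimizers import parse_optimizer_args

old_style_kwargs = {
    "cls": "nemo.core.config.optimizers.NovogradParams",  # assumed class path
    "params": {"lr": 0.01, "betas": [0.8, 0.5], "weight_decay": 0.001},
}

kwargs = parse_optimizer_args("novograd", old_style_kwargs)
print(kwargs)  # flat dict of Novograd hyperparameters
```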
nemo/utils/model_utils.py (17 changes: 15 additions & 2 deletions)
@@ -373,8 +373,12 @@ def _convert_config(cfg: OmegaConf):
""" Recursive function convertint the configuration from old hydra format to the new one. """

# Get rid of cls -> _target_.
if 'cls' in cfg and "_target_" not in cfg:
cfg._target_ = cfg.pop("cls")
if 'cls' in cfg and '_target_' not in cfg:
cfg._target_ = cfg.pop('cls')

# Get rid of target -> _target_.
if 'target' in cfg and '_target_' not in cfg:
cfg._target_ = cfg.pop('target')

# Get rid of params.
if 'params' in cfg:
@@ -397,6 +401,7 @@ def maybe_update_config_version(cfg: DictConfig):
Changes include:
- `cls` -> `_target_`.
+ - `target` -> `_target_`
- `params` -> drop params and shift all arguments to parent.
Args:
@@ -405,6 +410,14 @@ def maybe_update_config_version(cfg: DictConfig):
Returns:
An updated DictConfig that conforms to Hydra 1.x format.
"""
+ if cfg is not None and not isinstance(cfg, DictConfig):
+ try:
+ temp_cfg = OmegaConf.create(cfg)
+ cfg = temp_cfg
+ except omegaconf_errors.OmegaConfBaseException:
+ # Cannot be cast to DictConfig, skip updating.
+ return cfg
+
# Make a copy of model config.
cfg = copy.deepcopy(cfg)
OmegaConf.set_struct(cfg, False)
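A short sketch of the updated helper (the class path and parameters are hypothetical): plain dicts are now cast to `DictConfig` before conversion, and both `cls` and `target` keys are rewritten to `_target_`, with any `params` block flattened into the parent.

```python
# Sketch: normalizing a legacy-style config with maybe_update_config_version.
from omegaconf import OmegaConf

from nemo.utils.model_utils import maybe_update_config_version

legacy = {
    "target": "my_pkg.MyModule",  # hypothetical class path
    "params": {"hidden_size": 256, "dropout": 0.1},
}

updated = maybe_update_config_version(legacy)  # a plain dict is accepted now
print(OmegaConf.to_yaml(updated))
# Expected (roughly):
#   _target_: my_pkg.MyModule
#   hidden_size: 256
#   dropout: 0.1
```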
