Add Module-level Adapters, Save-Restore and tests (#4114)
* First draft of model level tests and support for multiple types adapters in same model

Signed-off-by: smajumdar <titu1994@gmail.com>

* Add save restore tests for adapters

Signed-off-by: smajumdar <titu1994@gmail.com>

* Add save restore tests for adapters

Signed-off-by: smajumdar <titu1994@gmail.com>

* Add adapter only save and restore

Signed-off-by: smajumdar <titu1994@gmail.com>

* Update base adapter config

Signed-off-by: smajumdar <titu1994@gmail.com>

* Add tests

Signed-off-by: smajumdar <titu1994@gmail.com>

* Fix collection of get enabled adapters, limiting to each module's scope

Signed-off-by: smajumdar <titu1994@gmail.com>

* Update docs and add support for resolution of module adapter names

Signed-off-by: smajumdar <titu1994@gmail.com>

* Update ASR adapters to only support module adapters

Signed-off-by: smajumdar <titu1994@gmail.com>

* Add state dict match test

Signed-off-by: smajumdar <titu1994@gmail.com>

* Fix name resolution for set_enabled_adapters

Signed-off-by: smajumdar <titu1994@gmail.com>

* Correct case where name is none for set adapter

Signed-off-by: smajumdar <titu1994@gmail.com>

* Correct case where there are no adapters to save

Signed-off-by: smajumdar <titu1994@gmail.com>

* Update config for training

Signed-off-by: smajumdar <titu1994@gmail.com>

* Force update to internal config upon get or set

Signed-off-by: smajumdar <titu1994@gmail.com>

* Add spec augment update support to adapters

Signed-off-by: smajumdar <titu1994@gmail.com>

* Correct config update

Signed-off-by: smajumdar <titu1994@gmail.com>

* Add dropout support to linear adapters

Signed-off-by: smajumdar <titu1994@gmail.com>

* Add type to config

Signed-off-by: smajumdar <titu1994@gmail.com>

* Add stochastic depth regularization to adapter merge strategy and related tests

Signed-off-by: smajumdar <titu1994@gmail.com>

* Add support for dynamic strategy change

Signed-off-by: smajumdar <titu1994@gmail.com>

* Add support for dynamic strategy change

Signed-off-by: smajumdar <titu1994@gmail.com>

* Add more tests

Signed-off-by: smajumdar <titu1994@gmail.com>

* Add more tests

Signed-off-by: smajumdar <titu1994@gmail.com>

* Remove logging of adapter name

Signed-off-by: smajumdar <titu1994@gmail.com>

* Update changes for reviews

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Refactor the utility methods

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Refactor the utility methods

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Fixed configs for optim and spec augment

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Fixed configs for optim and spec augment

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Rename method to subclassable private

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Add support for adapter module names to be pre-specified in config

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Fix imports

Signed-off-by: smajumdar <smajumdar@nvidia.com>

* Fix typos

Signed-off-by: smajumdar <smajumdar@nvidia.com>
titu1994 committed May 17, 2022
1 parent 8318980 commit 89994de
Showing 15 changed files with 1,939 additions and 130 deletions.
33 changes: 30 additions & 3 deletions examples/asr/asr_adapters/train_asr_adapter.py
@@ -58,6 +58,7 @@
"""

import os
from dataclasses import is_dataclass

import pytorch_lightning as pl
@@ -117,10 +118,10 @@ def main(cfg):
if cfg.model.pretrained_model is None and cfg.model.nemo_model is None:
raise ValueError("Either set `cfg.model.nemo_model` or `cfg.model.pretrained_model`")
if cfg.model.pretrained_model is not None and cfg.model.nemo_model is not None:
raise ValueError("Cannot set `cfg.model.nemo_model` and `cfg.model.pretrained_model`. Select one only.")
raise ValueError("Cannot set both `cfg.model.nemo_model` and `cfg.model.pretrained_model`. Select one only.")

trainer = pl.Trainer(**cfg.trainer)
exp_manager(trainer, cfg.get("exp_manager", None))
exp_log_dir = exp_manager(trainer, cfg.get("exp_manager", None))

if cfg.model.pretrained_model is not None:
model_cfg = ASRModel.from_pretrained(cfg.model.pretrained_model, return_config=True)
@@ -141,13 +142,25 @@ def main(cfg):
model.setup_multiple_validation_data(cfg.model.validation_ds)

# Setup optimizer
cfg.model.optim = update_model_cfg(model.cfg.optim, cfg.model.optim)
model.setup_optimization(cfg.model.optim)

# Setup spec augmentation
if 'spec_augment' in cfg.model:
model.spec_augmentation = model.from_config_dict(cfg.model.spec_augment)
else:
model.spec_augmentation = None
del model.cfg.spec_augment

# Setup adapters
with open_dict(cfg.model.adapter):
# Extract the name of the adapter (must be given for training)
adapter_name = cfg.model.adapter.pop("adapter_name")
adapter_module_name = cfg.model.adapter.pop("adapter_module_name", None)
adapter_state_dict_name = cfg.model.adapter.pop("adapter_state_dict_name", None)

# augment adapter name with module name, if not provided by user
if adapter_module_name is not None and ':' not in adapter_name:
adapter_name = f'{adapter_module_name}:{adapter_name}'

# Extract the global adapter config, if provided
adapter_global_cfg = cfg.model.adapter.pop(model.adapter_global_cfg_key, None)
@@ -168,9 +181,23 @@
# Then, unfreeze just the adapter weights that were enabled above (no part of the encoder/decoder/joint/etc.)
model.unfreeze_enabled_adapters()

# Update model config prior to training
model.cfg = model.cfg

# Finally, train model
trainer.fit(model)

# Save the adapter state dict
if adapter_state_dict_name is not None:
state_path = exp_log_dir if exp_log_dir is not None else os.getcwd()
ckpt_path = os.path.join(state_path, "checkpoints")
if os.path.exists(ckpt_path):
state_path = ckpt_path
state_path = os.path.join(state_path, adapter_state_dict_name)

# Save the adapter modules in a separate file
model.save_adapters(str(state_path))


if __name__ == '__main__':
main() # noqa pylint: disable=no-value-for-parameter
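
The adapter flow added to this script can also be driven programmatically. Below is a minimal sketch assuming a Conformer-style checkpoint (the checkpoint name "stt_en_conformer_ctc_small" and the model.cfg.encoder.d_model lookup are illustrative assumptions, not part of this diff); it mirrors the add/enable/freeze/save sequence above.

import nemo.collections.asr as nemo_asr
from omegaconf import OmegaConf

# Illustrative sketch, not code from this commit.
model = nemo_asr.models.ASRModel.from_pretrained("stt_en_conformer_ctc_small")  # assumed checkpoint

adapter_cfg = OmegaConf.create({
    "_target_": "nemo.collections.common.parts.adapter_modules.LinearAdapter",
    "in_features": model.cfg.encoder.d_model,  # must match the encoder hidden size (assumed Conformer config key)
    "dim": 32,
    "activation": "swish",
    "norm_position": "post",
    "dropout": 0.0,
})

# Module-scoped adapter name: "<module_name>:<adapter_name>"; "encoder:" targets encoder adapters.
model.add_adapter(name="encoder:my_adapter", cfg=adapter_cfg)
model.set_enabled_adapters(name="encoder:my_adapter", enabled=True)

# Freeze the base model, then unfreeze only the enabled adapter parameters.
model.freeze()
model.unfreeze_enabled_adapters()

# ... fit with a PyTorch Lightning Trainer, as in the script above ...

# Persist only the adapter weights to a standalone file.
model.save_adapters("adapters.pt")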
34 changes: 27 additions & 7 deletions examples/asr/conf/asr_adapters/asr_adaptation.yaml
@@ -58,12 +58,23 @@ model:
nemo_model: null # path to a ASR model file (.nemo)

adapter:
# Config of the adapter training/eval script.
adapter_name: ??? # Name of the adapter, used by the script
adapter_module_name: null # Name of the module to scope the adapter to (e.g. 'encoder'). Defaults to null, equivalent to ''.
adapter_state_dict_name: "adapters.pt" # If the individual adapters must be saved, a file name can be provided here. null disables this.

# Config of the adapter module itself
_target_: nemo.collections.common.parts.adapter_modules.LinearAdapter
in_features: ??? # User must provide the output dimension of the layers of the model, which is the input dimension of this adapter.
dim: 32 # The hidden dimension of the adapter, as chosen by user, but small values are preferred to reduce param count.
activation: swish
norm_position: 'post' # Can be `pre` or `post`
dropout: 0.0 # float, dropout for the adapter

# Adapter strategy config
adapter_strategy:
_target_: nemo.core.classes.mixins.adapter_mixin_strategies.ResidualAddAdapterStrategy
stochastic_depth: 0.0 # float, setting to > 0 will enable stochastic depth for each adapter block.

# Optional global config available to all adapters at a global level.
# A global config is shared across every layer of the adapters, defining global properties rather
@@ -73,6 +84,14 @@ model:
global_cfg:
encoder_adapter: True # ASR adapter key, determines whether to add encoder adapter modules

# Overrides the model's internal spec augment configuration
spec_augment:
_target_: nemo.collections.asr.modules.SpectrogramAugmentation
freq_masks: 0
time_masks: 0
freq_width: 27
time_width: 0.05

train_ds:
# train dataset + dataloader config
# sample_rate will be merged with model config
@@ -113,19 +132,20 @@ model:

optim:
# optimizer arguments
# name will be merged with model config
# betas will be merged with model config
lr: 0.01 # LR depends on the scheduler used by the base model. Noam prefers 0.5, Cosine Annealing prefers 0.02
name: adamw
betas: [0.9, 0.98]
lr: 0.001 # LR depends on the scheduler used by the base model. Noam prefers 0.5, Cosine Annealing prefers 0.02
weight_decay: 0 # During adaptation, since training run is short, WD is not required. Can be set if needed.

# scheduler setup
sched:
# name will be merged with model config
# d_model will be merged with model config
name: CosineAnnealing

# scheduler config override
# min_lr will be merged with model config
warmup_steps: 1000 # Warmup steps should be set, and smaller than the trainer.max_steps set below.
warmup_steps: 100 # Warmup steps should be set, and smaller than the trainer.max_steps set below.
warmup_ratio: null
min_lr: 1e-5
last_epoch: -1

trainer:
devices: -1 # number of GPUs, -1 would use all available GPUs
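
For reference, the adapter_strategy block in this config selects a residual-add merge that can be regularized with stochastic depth (added in this PR). A conceptual sketch of that behaviour follows; it is not NeMo's ResidualAddAdapterStrategy implementation, just the idea it encodes.

import torch

def residual_add_with_stochastic_depth(
    inputs: torch.Tensor, adapter_out: torch.Tensor, stochastic_depth: float, training: bool
) -> torch.Tensor:
    # During training, drop the adapter's contribution for the whole block
    # with probability `stochastic_depth`; otherwise do a plain residual add.
    if training and stochastic_depth > 0.0 and torch.rand(1).item() < stochastic_depth:
        return inputs
    return inputs + adapter_out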
2 changes: 1 addition & 1 deletion nemo/collections/asr/modules/conv_asr.py
@@ -881,7 +881,7 @@ def _update_adapter_cfg_input_dim(self, block: JasperBlock, cfg):
in_planes = cfg['in_features']

if in_planes != block.planes:
logging.info(f"Updating Adapter input dim from {in_planes} to {block.planes}")
logging.info(f"Updating ConvASR Encoder Adapter input dim from {in_planes} to {block.planes}")
in_planes = block.planes

cfg['in_features'] = in_planes
90 changes: 58 additions & 32 deletions nemo/collections/asr/parts/mixins/asr_adapter_mixins.py
@@ -68,20 +68,19 @@ def add_adapter(self, name: str, cfg: DictConfig):
name: A globally unique name for the adapter. Will be used to access, enable and disable adapters.
cfg: A DictConfig that contains at the bare minimum `__target__` to instantiate a new Adapter module.
"""
# setup the config for adapters
super().add_adapter(name=name, cfg=cfg)

# Try to retrieve global adapter config
global_config = DictConfig({})
if self.adapter_global_cfg_key in self.cfg.adapters:
global_config = self.adapter_cfg[self.adapter_global_cfg_key]
# Resolve module name and adapter name
module_name, _ = self.resolve_adapter_module_name_(name)

# Update the model.cfg with information about the new adapter from cfg
with open_dict(self.cfg):
# Check if encoder adapters should be added
use_encoder_adapters = global_config.get('encoder_adapter', True)
if use_encoder_adapters:

if module_name in ('', 'encoder'):
# Dispatch the call to the encoder.
self.encoder.add_adapter(name=name, cfg=self.cfg.adapters[name])
self.encoder.add_adapter(name=name, cfg=cfg)

def is_adapter_available(self) -> bool:
"""
@@ -92,7 +91,12 @@ def is_adapter_available(self) -> bool:
enabled or disabled, false only if no adapters exist.
"""
config_contains_adapter = super().is_adapter_available()
return self.encoder.is_adapter_available() and config_contains_adapter

# Forward the method call to the individual modules
if hasattr(self, 'encoder') and isinstance(self.encoder, AdapterModuleMixin):
config_contains_adapter |= self.encoder.is_adapter_available()

return config_contains_adapter

def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True):
"""
@@ -114,17 +118,17 @@ def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True):
"""
super().set_enabled_adapters(name=name, enabled=enabled)

# Try to retrieve global adapter config
global_config = DictConfig({})
if self.adapter_global_cfg_key in self.cfg.adapters:
global_config = self.cfg.adapters[self.adapter_global_cfg_key]
# Resolve the module name and adapter name
if name is not None:
module_name, _ = self.resolve_adapter_module_name_(name)
else:
module_name = None

# Check if encoder adapters should be used
use_encoder_adapters = global_config.get('encoder_adapter', True)

# Dispatch the call to the encoder.
if use_encoder_adapters:
self.encoder.set_enabled_adapters(name=name, enabled=enabled)
if name is None or module_name in ('', 'encoder'):
if self.encoder.is_adapter_available():
self.encoder.set_enabled_adapters(name=name, enabled=enabled)

def get_enabled_adapters(self) -> List[str]:
"""
@@ -135,30 +139,18 @@ def get_enabled_adapters(self) -> List[str]:
"""
enabled_adapters = super().get_enabled_adapters()

# Try to retrieve global adapter config
global_config = DictConfig({})
if self.adapter_global_cfg_key in self.cfg.adapters:
global_config = self.cfg.adapters[self.adapter_global_cfg_key]

# Check if encoder adapters should be used
use_encoder_adapters = global_config.get('encoder_adapter', True)

if use_encoder_adapters:
# Check if encoder adapters should be used or are enabled
if hasattr(self, 'encoder') and isinstance(self.encoder, AdapterModuleMixin):
enabled_adapters.extend(self.encoder.get_enabled_adapters())

return enabled_adapters

def _check_valid_model_with_adapter_support(self):
def check_valid_model_with_adapter_support_(self):
"""
Utility method to test if the subclass of this mixin is an appropriate subclass of ModelPT itself.
"""
# Obtain the global adapter config if possible, otherwise use sensible defaults.
global_cfg = DictConfig({})
if hasattr(self, 'adapter_cfg'):
global_cfg = self.adapter_cfg

if self.adapter_global_cfg_key in global_cfg:
global_cfg = global_cfg[self.adapter_global_cfg_key]
global_cfg = self._get_global_cfg()

# Test whether the encoder supports adapters
use_encoder_adapter = global_cfg.get('encoder_adapter', True)
@@ -167,3 +159,37 @@ def _check_valid_model_with_adapter_support(self):

if use_encoder_adapter and not isinstance(self.encoder, AdapterModuleMixin):
raise ValueError(f'{self.encoder.__class__.__name__} does not implement `AdapterModuleMixin`')

def resolve_adapter_module_name_(self, name: str) -> (str, str):
"""
Utility method to resolve a given global/module adapter name to its components.
Always returns a tuple representing (module_name, adapter_name). ":" is used as the
delimiter for denoting the module name vs the adapter name.
Will attempt to also resolve a given adapter_name alone back to (module_name, adapter_name)
if the metadata config exists for access.
Args:
name: A global adapter, or a module adapter name (with structure module_name:adapter_name).
Returns:
A tuple representing (module_name, adapter_name). If a global adapter is provided,
module_name is set to ''.
"""
module_name, adapter_name = super().resolve_adapter_module_name_(name)

# resolve name and module only for valid modules
valid_module_names = ['', 'encoder']
if module_name not in valid_module_names:
raise ValueError(f"Provided module name `{module_name}` is not in valid list : {valid_module_names}")

return (module_name, adapter_name)

def _get_global_cfg(self):
"""
Utility method, to either extract or construct the global config inside adapters config.
"""
global_config = DictConfig({})
if 'adapters' in self.cfg and self.adapter_global_cfg_key in self.cfg.adapters:
global_config = self.adapter_cfg[self.adapter_global_cfg_key]
return global_config
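
The convention enforced by resolve_adapter_module_name_ can be summarized with a small standalone sketch (a hypothetical helper, not the mixin's actual code): a name of the form "module:adapter" is split on the delimiter, while a bare adapter name maps to the default module ''.

from typing import Tuple

def split_adapter_name(name: str) -> Tuple[str, str]:
    # Hypothetical illustration of the "module_name:adapter_name" convention.
    if ':' in name:
        module_name, adapter_name = name.split(':', 1)
    else:
        module_name, adapter_name = '', name
    return module_name, adapter_name

assert split_adapter_name('encoder:small_adapter') == ('encoder', 'small_adapter')
assert split_adapter_name('small_adapter') == ('', 'small_adapter')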
6 changes: 3 additions & 3 deletions nemo/collections/asr/parts/submodules/jasper.py
@@ -13,7 +13,7 @@
# limitations under the License.

import math
from typing import Callable, List, Optional, Tuple
from typing import Callable, Iterable, List, Optional, Tuple

import torch
import torch.nn as nn
@@ -708,10 +708,10 @@ def __init__(
raise ValueError("currently only 'same' padding is supported")

kernel_size_factor = float(kernel_size_factor)
if type(kernel_size) in (list, tuple):
if isinstance(kernel_size, Iterable):
kernel_size = [compute_new_kernel_size(k, kernel_size_factor) for k in kernel_size]
else:
kernel_size = compute_new_kernel_size(kernel_size, kernel_size_factor)
kernel_size = [compute_new_kernel_size(kernel_size, kernel_size_factor)]

if future_context < 0:
padding_val = get_same_padding(kernel_size[0], stride[0], dilation[0])
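
The kernel-size change above normalizes both branches to a list so that downstream indexing such as kernel_size[0] is always valid. A small illustrative sketch of that normalization follows (scale_kernel is a stand-in with assumed behaviour, not NeMo's compute_new_kernel_size):

from collections.abc import Iterable
from typing import List, Union

def scale_kernel(k: int, factor: float) -> int:
    # Stand-in helper: scale the kernel, floor at 1, and keep it odd (assumption).
    new_k = max(int(k * factor), 1)
    return new_k + 1 if new_k % 2 == 0 else new_k

def normalize_kernel_sizes(kernel_size: Union[int, List[int]], factor: float) -> List[int]:
    # Mirrors the fixed branch: scalars are wrapped so kernel_size[0] always works.
    if isinstance(kernel_size, Iterable):
        return [scale_kernel(k, factor) for k in kernel_size]
    return [scale_kernel(kernel_size, factor)]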