Add MTEncDec Finetune support #4540

Merged · 17 commits · Jul 13, 2022
36 changes: 36 additions & 0 deletions Jenkinsfile
@@ -1889,6 +1889,7 @@ pipeline {
+exp_manager.create_checkpoint_callback=true \
+exp_manager.resume_if_exists=True \
'
sh 'rm -rf examples/nlp/machine_translation/nmt_results'
}
}

@@ -1978,6 +1979,41 @@ pipeline {
}
}
}

stage('L2: NMT Attention is All You Need Finetuning') {
when {
anyOf {
branch 'main'
changeRequest target: 'main'
}
}
failFast true
steps {
sh "cd examples/nlp/machine_translation && \
python enc_dec_nmt_finetune.py \
model_path=/home/TestData/nlp/nmt/toy_data/en_de_24x6_preln.nemo \
trainer.devices=[0] \
~trainer.max_epochs \
model.train_ds.src_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
model.train_ds.tgt_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.ref \
model.validation_ds.src_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
model.validation_ds.tgt_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
model.test_ds.src_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
model.test_ds.tgt_file_name=/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src \
+trainer.val_check_interval=10 \
+trainer.limit_val_batches=1 \
+trainer.limit_test_batches=1 \
+trainer.max_steps=10 \
+exp_manager.exp_dir=examples/nlp/machine_translation/nmt_finetune \
+exp_manager.create_checkpoint_callback=True \
+exp_manager.checkpoint_callback_params.monitor=val_sacreBLEU \
+exp_manager.checkpoint_callback_params.mode=max \
+exp_manager.checkpoint_callback_params.save_best_model=true \
"
sh "rm -rf examples/nlp/machine_translation/nmt_finetune"
}
}

stage('L2: NMT with HuggingFace') {
when {
anyOf {
77 changes: 77 additions & 0 deletions examples/nlp/machine_translation/conf/aayn_finetune.yaml
@@ -0,0 +1,77 @@
name: AttentionIsAllYouNeedFinetune
do_training: True # set to False if only preprocessing data
do_testing: False # set to True to run evaluation on test data after training
model_path: ???

model:
  train_ds:
    src_file_name: null
    tgt_file_name: null
    use_tarred_dataset: False # if true, tar_file_name and meta_file_name will be used (or created automatically)
    # config for preprocessing training data and creating a tarred dataset automatically
    tar_file_prefix: parallel # prefix for tar file names
    tar_files: null # if data has already been preprocessed (rest of config ignored)
    metadata_file: null # metadata for tarred dataset
    lines_per_dataset_fragment: 1000000 # Number of lines to consider for bucketing and padding
    num_batches_per_tarfile: 100 # Number of batches (pickle files) within each tarfile
    tar_shuffle_n: 100 # How many samples to look ahead and load to be shuffled
    shard_strategy: scatter # tarred dataset shard distribution strategy
    n_preproc_jobs: -2 # number of processes to use for data preprocessing (-2 means all but 2)
    tokens_in_batch: 512
    clean: true
    max_seq_length: 512
    shuffle: true
    num_samples: -1
    drop_last: false
    pin_memory: false
    num_workers: 8
    concat_sampling_technique: temperature # only used with ConcatTranslationDataset
    concat_sampling_temperature: 5 # only used with ConcatTranslationDataset
    concat_sampling_probabilities: null # only used with ConcatTranslationDataset

  validation_ds:
    src_file_name: ???
    tgt_file_name: ???
    tokens_in_batch: 512
    clean: false
    max_seq_length: 512
    shuffle: false
    num_samples: -1
    drop_last: false
    pin_memory: false
    num_workers: 8

  test_ds:
    src_file_name: ???
    tgt_file_name: ???
    tokens_in_batch: 512
    clean: false
    max_seq_length: 512
    shuffle: false
    num_samples: -1
    drop_last: false
    pin_memory: false
    num_workers: 8

  optim:
    name: adam
    lr: 0.00002
    betas:
      - 0.9
      - 0.98
    weight_decay: 0.0

trainer:
  devices: 4
  num_nodes: 1
  max_epochs: 200
  precision: 16 # Should be set to 16 for O1 and O2; default is 16 as PT ignores it when amp_level is O0
  accelerator: gpu
  enable_checkpointing: False
  logger: False
  log_every_n_steps: 50 # Interval of logging.
  check_val_every_n_epoch: 1

exp_manager:
  name: AAYNBaseFineTune
  files_to_copy: []
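
For reference, the sketch below shows one way to load this config outside of Hydra and fill in the mandatory model_path field. It is only an illustrative assumption: the config path and the .nemo checkpoint path are examples, not part of this PR.

# Minimal sketch (assumes the YAML above is saved at the path shown and that
# a pretrained .nemo checkpoint exists at a user-chosen location).
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/nlp/machine_translation/conf/aayn_finetune.yaml")

# model_path defaults to ??? (MISSING), so it must be provided before use,
# e.g. on the Hydra command line (model_path=/path/to/model.nemo) or in code:
cfg.model_path = "/path/to/pretrained_nmt.nemo"  # hypothetical path

# Inspect the finetuning dataset settings that will be merged into the model config.
print(OmegaConf.to_yaml(cfg.model.train_ds))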
107 changes: 107 additions & 0 deletions examples/nlp/machine_translation/enc_dec_nmt_finetune.py
@@ -0,0 +1,107 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional

from omegaconf import OmegaConf
from omegaconf.omegaconf import MISSING
from pytorch_lightning import Trainer

from nemo.collections.nlp.data.machine_translation.preproc_mt_data import MTDataPreproc
from nemo.collections.nlp.models.machine_translation.mt_enc_dec_config import MTEncDecModelConfig
from nemo.collections.nlp.models.machine_translation.mt_enc_dec_model import MTEncDecModel
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPPlugin
from nemo.core.config import hydra_runner
from nemo.core.config.modelPT import NemoConfig
from nemo.core.config.pytorch_lightning import TrainerConfig
from nemo.utils import logging
from nemo.utils.config_utils import update_model_config
from nemo.utils.exp_manager import ExpManagerConfig, exp_manager


"""
Usage:
python enc_dec_nmt_finetune.py \
model_path=/raid/models/de_en_24x6.nemo \
trainer.devices=2 \
~trainer.max_epochs \
+trainer.max_steps=4500 \
+trainer.val_check_interval=500 \
model.train_ds.tgt_file_name=/raid/data/train_lang_filtered.en \
model.train_ds.src_file_name=/raid/data/train_lang_filtered.de \
model.train_ds.tokens_in_batch=6000 \
model.validation_ds.tgt_file_name=/raid/data/2015.norm.tok.en \
model.validation_ds.src_file_name=/raid/data/2015.norm.tok.de \
model.validation_ds.tokens_in_batch=4000 \
model.test_ds.tgt_file_name=/raid/data/2015.en \
model.test_ds.src_file_name=/raid/data/2015.de \
+exp_manager.exp_dir=/raid/results/finetune-test \
+exp_manager.create_checkpoint_callback=True \
+exp_manager.checkpoint_callback_params.monitor=val_sacreBLEU \
+exp_manager.checkpoint_callback_params.mode=max \
+exp_manager.checkpoint_callback_params.save_best_model=true
"""


@dataclass
class MTFineTuneConfig(NemoConfig):
    name: Optional[str] = 'MTEncDec'
    model_path: str = MISSING
    do_training: bool = True
    do_testing: bool = False
    model: MTEncDecModelConfig = MTEncDecModelConfig()
    trainer: Optional[TrainerConfig] = TrainerConfig()
    exp_manager: Optional[ExpManagerConfig] = ExpManagerConfig(name='MTEncDec', files_to_copy=[])


@hydra_runner(config_path="conf", config_name="aayn_finetune")
def main(cfg: MTFineTuneConfig) -> None:
    # merge default config with user specified config
    default_cfg = MTFineTuneConfig()
    default_cfg.model = MTEncDecModel.restore_from(restore_path=cfg.model_path, return_config=True)
    del default_cfg.model.optim, default_cfg.model.train_ds, default_cfg.model.validation_ds, default_cfg.model.test_ds
    cfg = update_model_config(default_cfg, cfg, drop_missing_subconfigs=False)
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'Config: {OmegaConf.to_yaml(cfg)}')

    # training is managed by PyTorch Lightning
    trainer_cfg = OmegaConf.to_container(cfg.trainer)
    trainer_cfg.pop('plugins', None)
    trainer = Trainer(plugins=[NLPDDPPlugin()], **trainer_cfg)

    # experiment logs, checkpoints, and auto-resume are managed by exp_manager and PyTorch Lightning
    exp_manager(trainer, cfg.exp_manager)

    # everything needed to train translation models is encapsulated in the NeMo MTEncDecModel
    mt_model = MTEncDecModel.restore_from(restore_path=cfg.model_path, override_config_path=cfg.model, trainer=trainer)

    mt_model.setup_training_data(cfg.model.train_ds)
    mt_model.setup_multiple_validation_data(val_data_config=cfg.model.validation_ds)

    logging.info("\n\n************** Model parameters and their sizes ***********")
    for name, param in mt_model.named_parameters():
        print(name, param.size())
    logging.info("***********************************************************\n\n")

    if cfg.do_training:
        trainer.fit(mt_model)

    if cfg.do_testing:
        mt_model.setup_multiple_test_data(test_data_config=cfg.model.test_ds)
        trainer.test(mt_model)


if __name__ == '__main__':
    main()
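
Once finetuning completes, the resulting .nemo checkpoint can be loaded for inference. The sketch below is a hedged illustration only: the checkpoint path is hypothetical, and it assumes MTEncDecModel.translate() accepts a list of source sentences with source/target language codes, as in other NeMo NMT examples.

# Hedged inference sketch; the checkpoint path is hypothetical and the
# translate() call assumes the standard NeMo MTEncDecModel inference API.
from nemo.collections.nlp.models.machine_translation.mt_enc_dec_model import MTEncDecModel

model = MTEncDecModel.restore_from(
    restore_path="/raid/results/finetune-test/AAYNBaseFineTune/checkpoints/AAYNBaseFineTune.nemo"  # hypothetical
)
model.eval()

translations = model.translate(
    ["Maschinelle Übersetzung wird immer besser."],  # example German source sentence
    source_lang="de",
    target_lang="en",
)
print(translations)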