Add RETRO model for pretraining (#4121)
* add model
* added model and config
* added mock dataset
* added mock dataset
* added model changes
* working training
* fix style and some issues
* make sure mock dataset behaves
* calculate the embedding from retrieved ids
* added eos attention mask
* added unit test for attention mask
* fix unit test error
* added requirements
* added apex guard
* added Jenkins test
* fix Jenkins
* fix lgtm
* fix typo
* fix the filename
* addressed the comments
* remove unused imports
* fix bf16 support
* add pad if the id is negative
* fix dataset
* use MegatronBase model for all models
* moved tokenizer and vocab_size to the base model
* move more shared things to the base
* changed the name
* move the special tokens out of the base class
* fix unit test error
* move the T5-specific check out of the base
* add guard for tokenizer
* remove bad import
* make sure gradient clipping works for prompt learning
* make sure setup_optimizer_param_groups works for T5
* added doc string
* add missing property
* added default gradient behavior
* gradient clip for all parameters
* use PL default gradient clip
* fix configure optimizer param

Signed-off-by: Yi Dong <yidong@nvidia.com>
Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
Showing 15 changed files with 1,072 additions and 369 deletions.
examples/nlp/language_modeling/conf/megatron_retro_config.yaml (142 additions, 0 deletions)
name: test_retro
restore_from_path: null # used when starting from a .nemo file

trainer:
  devices: 2
  num_nodes: 1
  accelerator: gpu
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
  replace_sampler_ddp: False
  max_epochs: 1000 # PTL default. In practice we don't usually train for more than 1 epoch.
  max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
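  # e.g. with the defaults in this file (micro_batch_size=4, devices=2, num_nodes=1,
  # tensor/pipeline model parallel size 1, accumulate_grad_batches=1) each global step
  # consumes 4 * 2 * 1 = 8 samples, so 100000 steps correspond to about 800k samples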
  log_every_n_steps: 10
  val_check_interval: 100
  limit_val_batches: 50
  limit_test_batches: 500
  accumulate_grad_batches: 1
  gradient_clip_val: 1.0

exp_manager:
  explicit_log_dir: null
  exp_dir: null
  name: megatron_retro
  create_wandb_logger: False
  wandb_logger_kwargs:
    project: null
    name: null
  resume_if_exists: True
  resume_ignore_no_checkpoint: True
  create_checkpoint_callback: True
  checkpoint_callback_params:
    monitor: val_loss
    save_top_k: 10
    mode: min
    always_save_nemo: False # saves nemo file during validation, not implemented for model parallel
    filename: 'megatron_retro--{val_loss:.2f}-{step}-{consumed_samples}'
    model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}}


model:
  # model parallelism
  micro_batch_size: 4
  tensor_model_parallel_size: 1
  pipeline_model_parallel_size: 1 # has to be one; pipeline parallelism is not supported yet

  # model architecture
  encoder_seq_length: 2048
  max_position_embeddings: ${.encoder_seq_length}
  num_layers: 12
  hidden_size: 768
  ffn_hidden_size: 3072 # Transformer FFN hidden size. Usually 4 * hidden_size.
  num_attention_heads: 12
  init_method_std: 0.02 # Standard deviation of the zero-mean normal distribution used for weight initialization.
  hidden_dropout: 0.1 # Dropout probability for the transformer hidden states.
  attention_dropout: 0.1 # Dropout probability for the attention matrix.
  kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null.
  apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
  layernorm_epsilon: 1e-5
  gradient_as_bucket_view: True # Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory).
  persist_layer_norm: False
  bias_gelu_fusion: True
  bias_dropout_add_fusion: True
  masked_softmax_fusion: True
  activation: 'gelu'
  bias: True

  # retro architecture
  chunk_size: 64 # the chunk size used for retrieval
  enc_num_layers: 4 # total number of encoder layers
  dec_num_layers: 6 # total number of decoder layers
  enc_cross_attention: [3] # layer numbers with cross attention in the encoder
  dec_cross_attention: [3, 5] # layer numbers with chunked cross attention in the decoder
  add_position_embedding: False # whether to use absolute position embeddings
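  # with encoder_seq_length=2048 and chunk_size=64, each sequence is split into
  # 2048 / 64 = 32 chunks; every decoder chunk cross-attends to its retrieved
  # neighbors (the number of neighbors is set under model.data below)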

  make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computational efficiency.
  pre_process: True # add embedding
  post_process: True # add pooler
  bert_binary_head: True # BERT binary head

  tokenizer:
    library: 'megatron'
    type: 'GPT2BPETokenizer'
    model: null
    vocab_file: null
    merge_file: null
    delimiter: null # only used for the tabular tokenizer

  # precision
  native_amp_init_scale: 4294967296 # 2 ** 32
  native_amp_growth_interval: 1000
  fp32_residual_connection: False # Move residual connections to fp32
  fp16_lm_cross_entropy: False # Move the unreduced cross-entropy loss calculation for the LM head to fp16

  # miscellaneous
  seed: 1234
  use_cpu_initialization: False # Init weights on the CPU (slow for large models)
  onnx_safe: False # Use work-arounds for known problems with the Torch ONNX exporter.

  # not implemented in NeMo yet
  activations_checkpoint_method: null # 'uniform', 'block'
  activations_checkpoint_num_layers: 1

  data:
    # Path to the data must be specified by the user, either here or via a CLI override, e.g.:
    # "model.data.data_prefix=[.5,/raid/data/pile/my-gpt3_00_text_document,.5,/raid/data/pile/my-gpt3_01_text_document]"
    # or as a YAML list:
    # data_prefix:
    #   - .5
    #   - /raid/data/pile/my-gpt3_00_text_document
    #   - .5
    #   - /raid/data/pile/my-gpt3_01_text_document
    data_prefix: ???
    index_mapping_dir: null # path to save index mapping .npy files; by default saved in the same location as data_prefix
    data_impl: mmap
    splits_string: 900,50,50
    seq_length: ${model.encoder_seq_length}
    skip_warmup: True
    num_workers: 0
    dataloader_type: single # cyclic
    reset_position_ids: False # Reset position ids after end-of-document token
    reset_attention_mask: False # Reset attention mask after end-of-document token
    eod_mask_loss: False # Mask the loss for end-of-document tokens
    masked_lm_prob: 0.15 # Probability of replacing a token with the mask token.
    short_seq_prob: 0.1 # Probability of producing a short sequence.
    neighbors: 2 # number of retrieved neighbors
    mock: True # whether to use the mock dataset
    mock_data_size: 10000 # size of the mock dataset

  optim:
    name: fused_adam
    lr: 2e-4
    weight_decay: 0.01
    betas:
      - 0.9
      - 0.98
    sched:
      name: CosineAnnealing
      warmup_steps: 500
      constant_steps: 50000
      min_lr: 2e-5
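To make the retrieval-specific settings above concrete, here is a minimal sketch (plain PyTorch, not NeMo code) of the batch shapes this configuration implies. The assumption that each retrieved neighbor carries the matching chunk plus its continuation (2 * chunk_size tokens) follows the RETRO paper; the variable names are illustrative only.

import torch

micro_batch_size = 4      # model.micro_batch_size
seq_length = 2048         # model.encoder_seq_length
chunk_size = 64           # model.chunk_size
neighbors = 2             # model.data.neighbors

num_chunks = seq_length // chunk_size   # 32 chunks per sequence
retrieved_length = 2 * chunk_size       # neighbor chunk + its continuation (assumption)

# decoder input tokens
tokens = torch.randint(0, 50257, (micro_batch_size, seq_length))

# retrieved neighbor tokens: `neighbors` passages for each of the 32 chunks
retrieved = torch.randint(0, 50257, (micro_batch_size, num_chunks, neighbors, retrieved_length))

print(tokens.shape)      # torch.Size([4, 2048])
print(retrieved.shape)   # torch.Size([4, 32, 2, 128])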
examples/nlp/language_modeling/megatron_retro_pretraining.py (80 additions, 0 deletions)
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector

from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel
from nemo.collections.nlp.parts.nlp_overrides import GradScaler, NLPDDPPlugin
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import StatelessTimer, exp_manager


@hydra_runner(config_path="conf", config_name="megatron_retro_config")
def main(cfg) -> None:
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    plugins = [
        NLPDDPPlugin(
            no_ddp_communication_hook=False,
            gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
            find_unused_parameters=False,
        )
    ]

    if cfg.trainer.precision in [16, 'bf16']:
        scaler = None
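        # bf16 does not need a GradScaler; loss scaling is only used for fp16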
        if cfg.trainer.precision == 16:
            scaler = GradScaler(
                init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
                growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
                hysteresis=cfg.model.get('hysteresis', 2),
            )
        plugins.append(NativeMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))

    if cfg.get('cluster_type', None) == 'BCP':
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, **cfg.trainer)

    exp_manager(trainer, cfg.exp_manager)

    # update resume from checkpoint found by exp_manager
    resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
    # resume_from_checkpoint = uninject_model_parallel_rank(resume_from_checkpoint)
    logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')

    trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)
    # Override timer callback to a stateless one
    for idx, callback in enumerate(trainer.callbacks):
        if isinstance(callback, Timer):
            trainer.callbacks[idx] = StatelessTimer(cfg.trainer.max_time,)

    # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams
    with open_dict(cfg):
        cfg.model.precision = cfg.trainer.precision

    model = MegatronRetrievalModel(cfg.model, trainer)

    trainer.fit(model)


if __name__ == '__main__':
    main()
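A sketch of how this script could be launched with Hydra-style overrides, using the mock dataset so no preprocessed corpus is required. The override keys mirror the YAML config above; the data_prefix value is a hypothetical placeholder, and the command assumes it is run from the NeMo repository root.

import subprocess

cmd = [
    "python",
    "examples/nlp/language_modeling/megatron_retro_pretraining.py",
    "trainer.devices=2",
    "trainer.max_steps=1000",
    "trainer.precision=16",
    "model.micro_batch_size=4",
    "model.data.mock=True",             # train on the synthetic mock dataset
    "model.data.mock_data_size=10000",
    # hypothetical placeholder; only read when mock=False, but it keeps the
    # mandatory data_prefix key resolved
    "model.data.data_prefix=[1.0,/path/to/retro_corpus_text_document]",
]
subprocess.run(cmd, check=True)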