Skip to content

Commit

Permalink
Add RETRO model for pretraining (#4121)
Browse files Browse the repository at this point in the history
* add model

Signed-off-by: Yi Dong <yidong@nvidia.com>

* added model and config

Signed-off-by: Yi Dong <yidong@nvidia.com>

* added mock dataset

Signed-off-by: Yi Dong <yidong@nvidia.com>

* added mock dataset

Signed-off-by: Yi Dong <yidong@nvidia.com>

* added model changes

Signed-off-by: Yi Dong <yidong@nvidia.com>

* working training

Signed-off-by: Yi Dong <yidong@nvidia.com>

* fix style and some issue

Signed-off-by: Yi Dong <yidong@nvidia.com>

* make sure mock dataset behave

Signed-off-by: Yi Dong <yidong@nvidia.com>

* calculate the embedding from retrieved ids

Signed-off-by: Yi Dong <yidong@nvidia.com>

* added eos attention mask

Signed-off-by: Yi Dong <yidong@nvidia.com>

* added unit test for attention mask

Signed-off-by: Yi Dong <yidong@nvidia.com>

* fix unittest error

Signed-off-by: Yi Dong <yidong@nvidia.com>

* added requirments

Signed-off-by: Yi Dong <yidong@nvidia.com>

* added apex guard

Signed-off-by: Yi Dong <yidong@nvidia.com>

* added jenkins test

Signed-off-by: Yi Dong <yidong@nvidia.com>

* fix Jenkins

Signed-off-by: Yi Dong <yidong@nvidia.com>

* fix lgtm

Signed-off-by: Yi Dong <yidong@nvidia.com>

* fix typo

Signed-off-by: Yi Dong <yidong@nvidia.com>

* fix the filename

Signed-off-by: Yi Dong <yidong@nvidia.com>

* addressed the comments

Signed-off-by: Yi Dong <yidong@nvidia.com>

* remove unused imports

Signed-off-by: Yi Dong <yidong@nvidia.com>

* fix bf16 support

Signed-off-by: Yi Dong <yidong@nvidia.com>

* add pad if the id is negative

Signed-off-by: Yi Dong <yidong@nvidia.com>

* fix dataset

Signed-off-by: Yi Dong <yidong@nvidia.com>

* use MegatronBase model for all models

Signed-off-by: Yi Dong <yidong@nvidia.com>

* moved tokenzizer and vocab_size to the base model

Signed-off-by: Yi Dong <yidong@nvidia.com>

* move more shared things to the base

Signed-off-by: Yi Dong <yidong@nvidia.com>

* changed the name

Signed-off-by: Yi Dong <yidong@nvidia.com>

* move the special tokens out of baseclass

Signed-off-by: Yi Dong <yidong@nvidia.com>

* fix unit test error

Signed-off-by: Yi Dong <yidong@nvidia.com>

* move the t5 specific check out of base

Signed-off-by: Yi Dong <yidong@nvidia.com>

* add guard for tokenizer

Signed-off-by: Yi Dong <yidong@nvidia.com>

* remove bad import

Signed-off-by: Yi Dong <yidong@nvidia.com>

* make sure graident clipping works for prompt learning

Signed-off-by: Yi Dong <yidong@nvidia.com>

* make sure setup_optimizer_param groups works for t5

Signed-off-by: Yi Dong <yidong@nvidia.com>

* added doc string

Signed-off-by: Yi Dong <yidong@nvidia.com>

* add missing property

Signed-off-by: Yi Dong <yidong@nvidia.com>

* added default gradient behavior

Signed-off-by: Yi Dong <yidong@nvidia.com>

* gradient clip for all parameters

Signed-off-by: Yi Dong <yidong@nvidia.com>

* use pl default gradient clip

Signed-off-by: Yi Dong <yidong@nvidia.com>

* fix configure optimizer param

Signed-off-by: Yi Dong <yidong@nvidia.com>

Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
  • Loading branch information
yidong72 and okuchaiev committed May 14, 2022
1 parent 08df199 commit 27129ab
Show file tree
Hide file tree
Showing 15 changed files with 1,072 additions and 369 deletions.
68 changes: 68 additions & 0 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -2343,6 +2343,74 @@ pipeline {
sh "rm -rf examples/nlp/language_modeling/bert_index_mappings"
}
}
stage('L2: Megatron RETRO Pretraining and Resume Training') {
when {
anyOf {
branch 'main'
changeRequest target: 'main'
}
}
failFast true
steps {
sh "python examples/nlp/language_modeling/megatron_retro_pretraining.py \
trainer.devices=2 \
trainer.num_nodes=1 \
trainer.accelerator=gpu \
trainer.accumulate_grad_batches=1 \
trainer.limit_val_batches=2 \
exp_manager.resume_if_exists=True \
trainer.max_steps=10 \
trainer.precision=16 \
trainer.gradient_clip_val=1.0 \
trainer.val_check_interval=20 \
exp_manager.exp_dir=examples/nlp/language_modeling/retro_results \
model.data.data_prefix='' \
model.tensor_model_parallel_size=2 \
model.micro_batch_size=4 \
model.optim.name=fused_adam \
model.optim.lr=2e-4 \
model.optim.sched.warmup_steps=2 \
model.optim.sched.constant_steps=2 \
model.optim.sched.min_lr=8e-5 \
model.max_position_embeddings=128 \
model.encoder_seq_length=128 \
model.chunk_size=32 \
model.enc_num_layers=2 \
model.dec_num_layers=2 \
model.enc_cross_attention=[1] \
model.dec_cross_attention=[1] \
model.data.mock=True"
sh "python examples/nlp/language_modeling/megatron_retro_pretraining.py \
trainer.devices=2 \
trainer.num_nodes=1 \
trainer.accelerator=gpu \
trainer.accumulate_grad_batches=1 \
trainer.limit_val_batches=2 \
exp_manager.resume_if_exists=True \
trainer.max_steps=30 \
trainer.precision=16 \
trainer.gradient_clip_val=1.0 \
trainer.val_check_interval=20 \
exp_manager.exp_dir=examples/nlp/language_modeling/retro_results \
model.data.data_prefix='' \
model.tensor_model_parallel_size=2 \
model.micro_batch_size=4 \
model.optim.name=fused_adam \
model.optim.lr=2e-4 \
model.optim.sched.warmup_steps=2 \
model.optim.sched.constant_steps=2 \
model.optim.sched.min_lr=8e-5 \
model.max_position_embeddings=128 \
model.encoder_seq_length=128 \
model.chunk_size=32 \
model.enc_num_layers=2 \
model.dec_num_layers=2 \
model.enc_cross_attention=[1] \
model.dec_cross_attention=[1] \
model.data.mock=True"
sh "rm -rf examples/nlp/language_modeling/retro_results"
}
}
stage('L2: BioMegatron Bert NER Task') {
when {
anyOf {
Expand Down
142 changes: 142 additions & 0 deletions examples/nlp/language_modeling/conf/megatron_retro_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
name: test_retro
restore_from_path: null # used when starting from a .nemo file

trainer:
devices: 2
num_nodes: 1
accelerator: gpu
precision: 16
logger: False # logger provided by exp_manager
enable_checkpointing: False
replace_sampler_ddp: False
max_epochs: 1000 # PTL default. In practice we don't usually train for more than 1 epoch.
max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
log_every_n_steps: 10
val_check_interval: 100
limit_val_batches: 50
limit_test_batches: 500
accumulate_grad_batches: 1
gradient_clip_val: 1.0

exp_manager:
explicit_log_dir: null
exp_dir: null
name: megatron_retro
create_wandb_logger: False
wandb_logger_kwargs:
project: null
name: null
resume_if_exists: True
resume_ignore_no_checkpoint: True
create_checkpoint_callback: True
checkpoint_callback_params:
monitor: val_loss
save_top_k: 10
mode: min
always_save_nemo: False # saves nemo file during validation, not implemented for model parallel
filename: 'megatron_retro--{val_loss:.2f}-{step}-{consumed_samples}'
model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}}


model:
# model parallelism
micro_batch_size: 4
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1 # has to be one. not supporting pipeline parallel yet

# model architecture
encoder_seq_length: 2048
max_position_embeddings: ${.encoder_seq_length}
num_layers: 12
hidden_size: 768
ffn_hidden_size: 3072 # Transformer FFN hidden size. Usually 4 * hidden_size.
num_attention_heads: 12
init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.')
hidden_dropout: 0.1 # Dropout probability for hidden state transformer.
attention_dropout: 0.1 # drop out probability for attention matrix
kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null
apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
layernorm_epsilon: 1e-5
gradient_as_bucket_view: True # Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)
persist_layer_norm: False
bias_gelu_fusion: True
bias_dropout_add_fusion: True
masked_softmax_fusion: True
activation: 'gelu'
bias: True

# retro architecture
chunk_size: 64 # the chunk size used to retrive
enc_num_layers: 4 # total number of encoder layers
dec_num_layers: 6 # total number of decoder layers
enc_cross_attention: [3] # layer numbers for cross attention in encoder
dec_cross_attention: [3, 5] # layer numbers for chunked cross attention in decoder
add_position_embedding: False # whether use the absolute position encoding

make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency.
pre_process: True # add embedding
post_process: True # add pooler
bert_binary_head: True # BERT binary head

tokenizer:
library: 'megatron'
type: 'GPT2BPETokenizer'
model: null
vocab_file: null
merge_file: null
delimiter: null # only used for tabular tokenizer

# precision
native_amp_init_scale: 4294967296 # 2 ** 32
native_amp_growth_interval: 1000
fp32_residual_connection: False # Move residual connections to fp32
fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16


# miscellaneous
seed: 1234
use_cpu_initialization: False # Init weights on the CPU (slow for large models)
onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter.

# not implemented in NeMo yet
activations_checkpoint_method: null # 'uniform', 'block'
activations_checkpoint_num_layers: 1

data:
# Path to data must be specified by the user.
# can override from the CLI: "model.data.data_prefix=[.5,/raid/data/pile/my-gpt3_00_text_document,.5,/raid/data/pile/my-gpt3_01_text_document]",
# Or see example below:
# data_prefix:
# - .5
# - /raid/data/pile/my-gpt3_00_text_document
# - .5
# - /raid/data/pile/my-gpt3_01_text_document
data_prefix: ???
index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix
data_impl: mmap
splits_string: 900,50,50
seq_length: ${model.encoder_seq_length}
skip_warmup: True
num_workers: 0
dataloader_type: single # cyclic
reset_position_ids: False # Reset position ids after end-of-document token
reset_attention_mask: False # Reset attention mask after end-of-document token
eod_mask_loss: False # Mask loss for the end of document tokens
masked_lm_prob: 0.15 # Probability of replacing a token with mask.
short_seq_prob: 0.1 # Probability of producing a short sequence.
neighbors: 2 # number of retrieved neighbors
mock: True # whether use mock dataset
mock_data_size: 10000 # the mock dataset size

optim:
name: fused_adam
lr: 2e-4
weight_decay: 0.01
betas:
- 0.9
- 0.98
sched:
name: CosineAnnealing
warmup_steps: 500
constant_steps: 50000
min_lr: 2e-5
80 changes: 80 additions & 0 deletions examples/nlp/language_modeling/megatron_retro_pretraining.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector

from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel
from nemo.collections.nlp.parts.nlp_overrides import GradScaler, NLPDDPPlugin
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import StatelessTimer, exp_manager


@hydra_runner(config_path="conf", config_name="megatron_retro_config")
def main(cfg) -> None:
logging.info("\n\n************** Experiment configuration ***********")
logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

plugins = [
NLPDDPPlugin(
no_ddp_communication_hook=False,
gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
find_unused_parameters=False,
)
]

if cfg.trainer.precision in [16, 'bf16']:
scaler = None
if cfg.trainer.precision == 16:
scaler = GradScaler(
init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
hysteresis=cfg.model.get('hysteresis', 2),
)
plugins.append(NativeMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))

if cfg.get('cluster_type', None) == 'BCP':
plugins.append(TorchElasticEnvironment())

trainer = Trainer(plugins=plugins, **cfg.trainer)

exp_manager(trainer, cfg.exp_manager)

# update resume from checkpoint found by exp_manager
resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
# resume_from_checkpoint = uninject_model_parallel_rank(resume_from_checkpoint)
logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')

trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)
# Override timer callback to a stateless one
for idx, callback in enumerate(trainer.callbacks):
if isinstance(callback, Timer):
trainer.callbacks[idx] = StatelessTimer(cfg.trainer.max_time,)

# hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams
with open_dict(cfg):
cfg.model.precision = cfg.trainer.precision

model = MegatronRetrievalModel(cfg.model, trainer)

trainer.fit(model)


if __name__ == '__main__':
main()
Loading

0 comments on commit 27129ab

Please sign in to comment.