Lddl bert (#6761)

* initial POC for LDDL Bert

* Finish LDDL POC

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address comments

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix merge head

* resolving merge

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add support for val/test loaders

* change to new LDDL class + add winding

* fix logging level

* fix winding

* test fix

* fixes to winding

* add file system

* add preemption optimizations

* more logging

* more prints

* better logging

* asfsf

* add barrier

* removing prints

* working with mb lddl loader

* final changes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update requirements file with LDDL

Signed-off-by: wdykas <wdykas@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert adding to requirements

---------

Signed-off-by: wdykas <wdykas@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Harper <complex451@gmail.com>
3 people committed Jun 1, 2023
1 parent 1486b12 commit aff5217
Showing 3 changed files with 126 additions and 10 deletions.
@@ -133,7 +133,7 @@ model:
    seq_length: ${model.encoder_seq_length}
    skip_warmup: True
    num_workers: 0
-   dataloader_type: single # cyclic
+   dataloader_type: single # cyclic, LDDL
    reset_position_ids: False # Reset position ids after end-of-document token
    reset_attention_mask: False # Reset attention mask after end-of-document token
    eod_mask_loss: False # Mask loss for the end of document tokens
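Usage note (editorial, not part of the diff): based on the build_LDDL_data changes below, switching to the new path would presumably mean setting dataloader_type: LDDL and giving data_prefix up to three LDDL dataset directories (train, validation, test) instead of Megatron index prefixes; and since the loader is built with prefetch_factor=2, which a standard PyTorch DataLoader rejects when num_workers is 0, num_workers would likely also need to be raised above the 0 shown here.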
5 changes: 3 additions & 2 deletions examples/nlp/language_modeling/megatron_bert_pretraining.py
@@ -29,11 +29,12 @@
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager

-mp.set_start_method("spawn", force=True)
-

@hydra_runner(config_path="conf", config_name="megatron_bert_config")
def main(cfg) -> None:
+    if cfg.model.data.dataloader_type != "LDDL":
+        mp.set_start_method("spawn", force=True)
+
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

129 changes: 122 additions & 7 deletions nemo/collections/nlp/models/language_modeling/megatron_bert_model.py
@@ -40,6 +40,7 @@
from nemo.core.neural_types import ChannelType, MaskType, NeuralType
from nemo.utils import AppState, logging

+
try:
    from apex.transformer.pipeline_parallel.utils import get_num_microbatches

@@ -49,6 +50,14 @@

    HAVE_APEX = False

+try:
+    import logging
+    from lddl.torch_mp import get_bert_pretrain_data_loader
+
+    HAVE_LDDL = True
+except (ImportError, ModuleNotFoundError):
+    HAVE_LDDL = False
+
try:
    from megatron.core import parallel_state
    from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
@@ -300,7 +309,12 @@ def training_step(self, dataloader_iter, batch_idx):
                    for param in module.embedding.parameters():
                        param.data_ptr()

-        tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size]
+        if self.cfg.data.dataloader_type == "LDDL":
+            # this is of type bert dataset
+            seq_length = dataloader_iter.iterator.loaders.get_seqlen()
+            tensor_shape = [seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size]
+        else:
+            tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size]

        # run forward and backwards passes for an entire global batch
        # we do this inside training_step to support pipeline parallelism
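Why the shape becomes dynamic (editorial note): LDDL bins pretraining samples by sequence length, so the shape handed to the pipeline schedule is read from the loader each step instead of coming from the static encoder_seq_length. A minimal sketch with hypothetical values:

    # Illustrative values only; get_seqlen() is the accessor used in the diff above.
    seq_length = 128                            # length of the current LDDL bin
    micro_batch_size, hidden_size = 4, 1024     # stand-ins for the cfg values
    tensor_shape = [seq_length, micro_batch_size, hidden_size]  # [128, 4, 1024] rather than, say, [512, 4, 1024]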
@@ -324,7 +338,10 @@ def training_step(self, dataloader_iter, batch_idx):
            loss_tensor = torch.vstack(loss_tensors_list)
            loss_mean = loss_tensor.mean(axis=0)
        else:
-            loss_mean = torch.tensor([0.0, 0.0]).cuda()
+            if self.cfg.bert_binary_head == True:
+                loss_mean = torch.tensor([0.0, 0.0, 0.0]).cuda()
+            else:
+                loss_mean = torch.tensor([0.0, 0.0]).cuda()

        # when using sequence parallelism, the sequence parallel layernorm grads must be all-reduced
        if self.cfg.get('tensor_model_parallel_size', 1) > 1 and self.cfg.get('sequence_parallel', False):
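Context for the widened placeholder (rationale inferred, not stated in the diff): with bert_binary_head enabled the reduced loss presumably carries three entries per step (total loss plus the MLM and sentence-order components), so pipeline stages that produce no loss must return a zero tensor of matching width; without the binary head only two entries are tracked.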
@@ -404,7 +421,12 @@ def allreduce_first_last_embeddings(self):
                torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())

    def validation_step(self, dataloader_iter, batch_idx):
-        tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size]
+
+        if self.cfg.data.dataloader_type == "LDDL":
+            seq_length = dataloader_iter.iterator.get_seqlen()
+            tensor_shape = [seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size]
+        else:
+            tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size]

        fwd_bwd_function = get_forward_backward_func()

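One observation on the diff (editorial): training reads the bin length via dataloader_iter.iterator.loaders.get_seqlen() while validation uses dataloader_iter.iterator.get_seqlen(); presumably the wrapping differs between the two paths, but the asymmetry is worth verifying against the objects actually returned by get_bert_pretrain_data_loader.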
@@ -476,6 +498,95 @@ def loss_func(self, loss_mask, sentence_order, output_tensor):
        # [lm_loss])
        # return loss, {'lm loss': averaged_losses[0]}

    def build_LDDL_data(self, cfg):
        if not HAVE_LDDL:
            raise ImportError(
                "LDDL was not found. Please see the LDDL README for installation instructions: https://github.com/NVIDIA/LDDL#installation."
            )
        logging.info(f'Starting building LDDL Dataloaders')
        self._train_ds = None
        self._validation_ds = None
        self._test_ds = None
        data_parallel_size = parallel_state.get_data_parallel_world_size()
        num_micro_batches = self.cfg.global_batch_size // (self.cfg.micro_batch_size * data_parallel_size)
        global_batch_size_on_this_data_parallel_rank = num_micro_batches * self.cfg.micro_batch_size
        samples_consumed_dploader = self.compute_consumed_samples(0) // data_parallel_size
        # We run under the assumption that the datapath is the prefix of the LDDL dataloader
        train_lddl_data_path = self.cfg.data.data_prefix[0]
        self._train_dl = get_bert_pretrain_data_loader(
            train_lddl_data_path,
            dp_rank=parallel_state.get_data_parallel_rank(),
            local_rank=self.local_rank,
            shuffle_buffer_size=16384,
            shuffle_buffer_warmup_factor=16,
            vocab_file=self.cfg.tokenizer.vocab_file,
            data_loader_kwargs={
                'batch_size': global_batch_size_on_this_data_parallel_rank,
                'num_workers': self.cfg.data.num_workers,
                'prefetch_factor': 2,
            },
            mlm_probability=0.15,
            base_seed=self.cfg.seed,
            log_level=logging.CRITICAL,
            log_dir="/tmp/log",
            return_raw_samples=False,
            start_epoch=0,
            sequence_length_alignment=8,
            ignore_index=-1,
            samples_seen=samples_consumed_dploader,
            micro_batch_size=self.cfg.micro_batch_size,
        )
        logging.info(f'Completed build train LDDL Dataloader')
        if len(self.cfg.data.data_prefix) > 1:
            val_lddl_data_path = self.cfg.data.data_prefix[1]
            self._validation_dl = get_bert_pretrain_data_loader(
                val_lddl_data_path,
                dp_rank=parallel_state.get_data_parallel_rank(),
                local_rank=self.local_rank,
                shuffle_buffer_size=16384,
                shuffle_buffer_warmup_factor=16,
                vocab_file=self.cfg.tokenizer.vocab_file,
                data_loader_kwargs={
                    'batch_size': global_batch_size_on_this_data_parallel_rank,
                    'num_workers': self.cfg.data.num_workers,
                    'prefetch_factor': 2,
                },
                mlm_probability=0.15,
                base_seed=self.cfg.seed,
                log_level=logging.CRITICAL,
                log_dir="/tmp/log",
                return_raw_samples=False,
                start_epoch=0,
                sequence_length_alignment=8,
                ignore_index=-1,
                micro_batch_size=self.cfg.micro_batch_size,
            )
        if len(self.cfg.data.data_prefix) > 2:
            test_lddl_data_path = self.cfg.data.data_prefix[2]
            self._test_dl = get_bert_pretrain_data_loader(
                test_lddl_data_path,
                dp_rank=parallel_state.get_data_parallel_rank(),
                local_rank=self.local_rank,
                shuffle_buffer_size=16384,
                shuffle_buffer_warmup_factor=16,
                vocab_file=self.cfg.tokenizer.vocab_file,
                data_loader_kwargs={
                    'batch_size': global_batch_size_on_this_data_parallel_rank,
                    'num_workers': self.cfg.data.num_workers,
                    'prefetch_factor': 2,
                },
                mlm_probability=0.15,
                base_seed=self.cfg.seed,
                log_level=logging.CRITICAL,
                log_dir="/tmp/log",
                return_raw_samples=False,
                start_epoch=0,
                sequence_length_alignment=8,
                ignore_index=-1,
                micro_batch_size=self.cfg.micro_batch_size,
            )
        logging.info(f'Finished building LDDL Dataloaders')

    def build_train_valid_test_datasets(self):
        logging.info('Building Bert datasets.')
        if self.trainer.limit_val_batches > 1.0 and isinstance(self.trainer.limit_val_batches, float):
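Worked example for the batch-size bookkeeping in build_LDDL_data (illustrative values, not from the commit):

    # Suppose global_batch_size=256, micro_batch_size=4, and 8-way data parallelism.
    global_batch_size, micro_batch_size, data_parallel_size = 256, 4, 8
    num_micro_batches = global_batch_size // (micro_batch_size * data_parallel_size)     # 8
    global_batch_size_on_this_data_parallel_rank = num_micro_batches * micro_batch_size  # 32
    # Each data-parallel rank therefore asks LDDL for batches of 32 samples, which the
    # pipeline schedule then consumes as 8 micro-batches of 4.

The samples_seen argument, fed from compute_consumed_samples(0) // data_parallel_size, is presumably what lets a resumed or preempted run fast-forward the train loader to where the previous run stopped.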
@@ -581,10 +692,14 @@ def setup(self, stage=None):
        else:
            # TODO: consider adding a ModelPT guard to check if model is being restored.
            # allowing restored models to optionally setup datasets
-            self.build_train_valid_test_datasets()
-            self.setup_training_data(self.cfg.data)
-            self.setup_validation_data(self.cfg.data)
-            self.setup_test_data(self.cfg.data)
+            if self.cfg.data.dataloader_type == "LDDL":
+                self.build_LDDL_data(self.cfg.data)
+                torch.distributed.barrier()
+            else:
+                self.build_train_valid_test_datasets()
+                self.setup_training_data(self.cfg.data)
+                self.setup_validation_data(self.cfg.data)
+                self.setup_test_data(self.cfg.data)

        # when using pipeline model parallel the final stage needs to initialize word embeddings
        if parallel_state.get_pipeline_model_parallel_world_size() > 1:
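A note on the barrier (rationale inferred): calling torch.distributed.barrier() after build_LDDL_data presumably makes every rank finish constructing its LDDL loaders before training continues, in line with the 'add barrier' and preemption items in the commit message.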
