Lddl bert #6761

Merged
merged 36 commits, Jun 1, 2023
Changes from 32 commits

Commits (36)
24ef67f  initial POC for LDDL Bert (wdykas, Jan 26, 2023)
4c2aa3d  Finish LDDL POC (wdykas, Feb 8, 2023)
33ed786  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 9, 2023)
debc4ac  address comments (wdykas, Feb 13, 2023)
405525e  resolve mergre (wdykas, Feb 13, 2023)
1fd7b7b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 13, 2023)
41183dc  fix merge head (wdykas, Feb 13, 2023)
d0ceb82  Merge branch 'LDDL-Bert' of github.com:wdykas/NeMo into LDDL-Bert (wdykas, Feb 13, 2023)
7914d9e  resolving merge (wdykas, Feb 13, 2023)
555518c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 13, 2023)
bc8555b  Merge branch 'main' of github.com:wdykas/NeMo into LDDL-Bert (wdykas, Mar 1, 2023)
6813273  Merge branch 'NVIDIA:main' into LDDL-Bert (wdykas, Mar 8, 2023)
7465a75  add support for val/test loaders (wdykas, Mar 13, 2023)
f9684a9  Merge branch 'LDDL-Bert' of github.com:wdykas/NeMo into LDDL-Bert (wdykas, Mar 13, 2023)
34f67b9  Merge branch 'NVIDIA:main' into LDDL-Bert (wdykas, Mar 15, 2023)
e0b5d54  change to new LDDL class + add winding (wdykas, Mar 31, 2023)
60dd17d  fix logging level (wdykas, Mar 31, 2023)
d13e34a  fix winding (wdykas, Apr 6, 2023)
b909b1e  test fix (wdykas, Apr 24, 2023)
a7efa66  fixes to winding (wdykas, Apr 25, 2023)
edbd657  add file system (wdykas, Apr 26, 2023)
95f671f  add prepemption optimizations (wdykas, Apr 28, 2023)
90cb15f  more logging (wdykas, May 1, 2023)
c0735d8  more prints (wdykas, May 1, 2023)
05f2a68  better logging (wdykas, May 1, 2023)
bd10702  asfsf (wdykas, May 1, 2023)
046b0f4  add barrier (wdykas, May 1, 2023)
2950be4  removing prints (wdykas, May 2, 2023)
0dbe532  merge and broken code (wdykas, May 4, 2023)
7ed30a1  working with mb lddl loader (wdykas, May 10, 2023)
fc6e499  final changes (wdykas, May 30, 2023)
c071789  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 30, 2023)
7a4e381  update requirements file with LDDL (wdykas, Jun 1, 2023)
461cf6c  Merge branch 'r1.19.0' into LDDL-Bert (ericharper, Jun 1, 2023)
033dc88  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 1, 2023)
e08140e  revert adding to requirements (wdykas, Jun 1, 2023)
@@ -133,7 +133,7 @@ model:
seq_length: ${model.encoder_seq_length}
skip_warmup: True
num_workers: 0
-    dataloader_type: single # cyclic
+    dataloader_type: single # cyclic, LDDL
reset_position_ids: False # Reset position ids after end-of-document token
reset_attention_mask: False # Reset attention mask after end-of-document token
eod_mask_loss: False # Mask loss for the end of document tokens
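
The hunk above (presumably conf/megatron_bert_config.yaml) adds "LDDL" as a third dataloader_type next to "single" and "cyclic". For illustration only, here is a minimal sketch of selecting it programmatically with OmegaConf; the config location and the shard directories are assumptions, standing in for LDDL-preprocessed output (train first, then optional validation and test, matching how data_prefix is indexed later in this diff).

    from omegaconf import OmegaConf

    # Sketch only: the config location and data paths below are placeholders, not
    # part of this PR. Point data_prefix at LDDL-preprocessed shard directories.
    cfg = OmegaConf.load("examples/nlp/language_modeling/conf/megatron_bert_config.yaml")
    cfg.model.data.dataloader_type = "LDDL"
    cfg.model.data.data_prefix = ["/data/lddl/train", "/data/lddl/val", "/data/lddl/test"]
    print(OmegaConf.to_yaml(cfg.model.data))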
5 changes: 3 additions & 2 deletions examples/nlp/language_modeling/megatron_bert_pretraining.py
@@ -29,11 +29,12 @@
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager

-mp.set_start_method("spawn", force=True)


@hydra_runner(config_path="conf", config_name="megatron_bert_config")
def main(cfg) -> None:
+    if cfg.model.data.dataloader_type != "LDDL":
+        mp.set_start_method("spawn", force=True)

logging.info("\n\n************** Experiment configuration ***********")
logging.info(f'\n{OmegaConf.to_yaml(cfg)}')
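
The launcher now defers mp.set_start_method("spawn", force=True) until the config is known and skips it for LDDL runs. A minimal sketch of the same guard as a standalone helper; the rationale in the comment is an assumption, the PR itself does not state one.

    import multiprocessing as mp

    def maybe_force_spawn(dataloader_type: str) -> None:
        # Assumption: LDDL drives its own loader worker processes, so "spawn" is
        # only forced for the non-LDDL dataloaders.
        if dataloader_type != "LDDL":
            mp.set_start_method("spawn", force=True)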

129 changes: 122 additions & 7 deletions nemo/collections/nlp/models/language_modeling/megatron_bert_model.py
@@ -40,6 +40,7 @@
from nemo.core.neural_types import ChannelType, MaskType, NeuralType
from nemo.utils import AppState, logging


try:
from apex.transformer.pipeline_parallel.utils import get_num_microbatches

@@ -49,6 +50,14 @@

HAVE_APEX = False

try:
import logging
from lddl.torch_mp import get_bert_pretrain_data_loader

HAVE_LDDL = True
except (ImportError, ModuleNotFoundError):
HAVE_LDDL = False

try:
from megatron.core import parallel_state
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
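
Like the Apex and Megatron-Core guards around it, the LDDL import is wrapped in try/except and recorded in a HAVE_LDDL flag, keeping the dependency optional (LDDL was added to the requirements file in one commit and reverted in the final one). For illustration, a fail-fast check a user script could run before requesting dataloader_type=LDDL; it only mirrors the guard above.

    # Sketch only; mirrors the import guard in this diff rather than adding new API.
    try:
        from lddl.torch_mp import get_bert_pretrain_data_loader  # noqa: F401
        HAVE_LDDL = True
    except (ImportError, ModuleNotFoundError):
        HAVE_LDDL = False

    if not HAVE_LDDL:
        raise SystemExit("LDDL is not installed. See https://github.com/NVIDIA/LDDL#installation")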
@@ -295,7 +304,12 @@ def training_step(self, dataloader_iter, batch_idx):
for param in module.embedding.parameters():
param.data_ptr()

-        tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size]
+        if self.cfg.data.dataloader_type == "LDDL":
+            # this is of type bert dataset
+            seq_length = dataloader_iter.iterator.loaders.get_seqlen()
+            tensor_shape = [seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size]
+        else:
+            tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size]

# run forward and backwards passes for an entire global batch
# we do this inside training_step to support pipeline parallelism
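
With LDDL, the pipeline communication shape in training_step is no longer fixed to cfg.encoder_seq_length: LDDL serves length-binned batches, so the current length is read from the live iterator via get_seqlen(). A condensed sketch of the branch above (the attribute chain is the one this diff uses; validation_step below follows the same pattern).

    def get_tensor_shape(cfg, dataloader_iter=None):
        # Sketch only: condenses the branch above, no new behaviour.
        if cfg.data.dataloader_type == "LDDL" and dataloader_iter is not None:
            seq_length = dataloader_iter.iterator.loaders.get_seqlen()  # varies per length bin
        else:
            seq_length = cfg.encoder_seq_length  # fixed, padded sequence length
        return [seq_length, cfg.micro_batch_size, cfg.hidden_size]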
@@ -319,7 +333,10 @@ def training_step(self, dataloader_iter, batch_idx):
loss_tensor = torch.vstack(loss_tensors_list)
loss_mean = loss_tensor.mean(axis=0)
else:
-            loss_mean = torch.tensor([0.0, 0.0]).cuda()
+            if self.cfg.bert_binary_head == True:
+                loss_mean = torch.tensor([0.0, 0.0, 0.0]).cuda()
+            else:
+                loss_mean = torch.tensor([0.0, 0.0]).cuda()

# when using sequence parallelism, the sequence parallel layernorm grads must be all-reduced
if self.cfg.get('tensor_model_parallel_size', 1) > 1 and self.cfg.get('sequence_parallel', False):
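
The zero placeholder returned when a step yields no loss tensors now depends on the binary (NSP/SOP) head: three components with it, two without, presumably matching the (total, LM, SOP) versus (total, LM) losses reported per step. A small sketch of that sizing rule with the assumption spelled out.

    import torch

    def zero_loss_placeholder(bert_binary_head: bool) -> torch.Tensor:
        # Assumption: (total, LM, SOP) losses with the binary head, (total, LM) without.
        n = 3 if bert_binary_head else 2
        return torch.zeros(n)  # the model puts this on CUDA; CPU is fine for a sketch

    zero_loss_placeholder(True)   # tensor([0., 0., 0.])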
@@ -399,7 +416,12 @@ def allreduce_first_last_embeddings(self):
torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())

def validation_step(self, dataloader_iter, batch_idx):
-        tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size]
+
+        if self.cfg.data.dataloader_type == "LDDL":
+            seq_length = dataloader_iter.iterator.get_seqlen()
+            tensor_shape = [seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size]
+        else:
+            tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size]

fwd_bwd_function = get_forward_backward_func()

@@ -471,6 +493,95 @@ def loss_func(self, loss_mask, sentence_order, output_tensor):
# [lm_loss])
# return loss, {'lm loss': averaged_losses[0]}

def build_LDDL_data(self, cfg):
if not HAVE_LDDL:
raise ImportError(
"LDDL was not found. Please see the LDDL README for installation instructions: https://github.com/NVIDIA/LDDL#installation."
)
logging.info(f'Starting building LDDL Dataloaders')
self._train_ds = None
self._validation_ds = None
self._test_ds = None
data_parallel_size = parallel_state.get_data_parallel_world_size()
num_micro_batches = self.cfg.global_batch_size // (self.cfg.micro_batch_size * data_parallel_size)
global_batch_size_on_this_data_parallel_rank = num_micro_batches * self.cfg.micro_batch_size
samples_consumed_dploader = self.compute_consumed_samples(0) // data_parallel_size
# We run under the assumption that the data path is the prefix of the LDDL dataloader
train_lddl_data_path = self.cfg.data.data_prefix[0]
self._train_dl = get_bert_pretrain_data_loader(
train_lddl_data_path,
dp_rank=parallel_state.get_data_parallel_rank(),
local_rank=self.local_rank,
shuffle_buffer_size=16384,
shuffle_buffer_warmup_factor=16,
vocab_file=self.cfg.tokenizer.vocab_file,
data_loader_kwargs={
'batch_size': global_batch_size_on_this_data_parallel_rank,
'num_workers': self.cfg.data.num_workers,
'prefetch_factor': 2,
},
mlm_probability=0.15,
base_seed=self.cfg.seed,
log_level=logging.CRITICAL,
log_dir="/tmp/log",
return_raw_samples=False,
start_epoch=0,
sequence_length_alignment=8,
ignore_index=-1,
samples_seen=samples_consumed_dploader,
micro_batch_size=self.cfg.micro_batch_size,
)
logging.info(f'Completed build train LDDL Dataloader')
if len(self.cfg.data.data_prefix) > 1:
val_lddl_data_path = self.cfg.data.data_prefix[1]
self._validation_dl = get_bert_pretrain_data_loader(
val_lddl_data_path,
dp_rank=parallel_state.get_data_parallel_rank(),
local_rank=self.local_rank,
shuffle_buffer_size=16384,
shuffle_buffer_warmup_factor=16,
vocab_file=self.cfg.tokenizer.vocab_file,
data_loader_kwargs={
'batch_size': global_batch_size_on_this_data_parallel_rank,
'num_workers': self.cfg.data.num_workers,
'prefetch_factor': 2,
},
mlm_probability=0.15,
base_seed=self.cfg.seed,
log_level=logging.CRITICAL,
log_dir="/tmp/log",
return_raw_samples=False,
start_epoch=0,
sequence_length_alignment=8,
ignore_index=-1,
micro_batch_size=self.cfg.micro_batch_size,
)
if len(self.cfg.data.data_prefix) > 2:
test_lddl_data_path = self.cfg.data.data_prefix[2]
self._test_dl = get_bert_pretrain_data_loader(
test_lddl_data_path,
dp_rank=parallel_state.get_data_parallel_rank(),
local_rank=self.local_rank,
shuffle_buffer_size=16384,
shuffle_buffer_warmup_factor=16,
vocab_file=self.cfg.tokenizer.vocab_file,
data_loader_kwargs={
'batch_size': global_batch_size_on_this_data_parallel_rank,
'num_workers': self.cfg.data.num_workers,
'prefetch_factor': 2,
},
mlm_probability=0.15,
base_seed=self.cfg.seed,
log_level=logging.CRITICAL,
log_dir="/tmp/log",
return_raw_samples=False,
start_epoch=0,
sequence_length_alignment=8,
ignore_index=-1,
micro_batch_size=self.cfg.micro_batch_size,
)
logging.info(f'Finished building LDDL Dataloaders')

def build_train_valid_test_datasets(self):
logging.info('Building Bert datasets.')
if self.trainer.limit_val_batches > 1.0 and isinstance(self.trainer.limit_val_batches, float):
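
The core of the new path is build_LDDL_data above: each data-parallel rank gets its own get_bert_pretrain_data_loader for train and, when data_prefix has more entries, for validation and test. The batch size handed to LDDL is the per-data-parallel-rank share of the global batch, and samples_seen passes the consumed-samples counter so a resumed (for example, preempted) run can skip what it already processed. A worked example of that arithmetic with illustrative numbers, not taken from the PR.

    # Illustrative numbers only.
    global_batch_size = 256
    micro_batch_size = 4
    data_parallel_size = 8
    consumed_samples = 51200  # what compute_consumed_samples(0) might report on restart

    num_micro_batches = global_batch_size // (micro_batch_size * data_parallel_size)  # 8
    batch_size_per_dp_rank = num_micro_batches * micro_batch_size                     # 32
    samples_seen_per_dp_rank = consumed_samples // data_parallel_size                 # 6400

    assert batch_size_per_dp_rank * data_parallel_size == global_batch_size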
@@ -576,10 +687,14 @@ def setup(self, stage=None):
else:
# TODO: consider adding a ModelPT guard to check if model is being restored.
# allowing restored models to optionally setup datasets
-            self.build_train_valid_test_datasets()
-            self.setup_training_data(self.cfg.data)
-            self.setup_validation_data(self.cfg.data)
-            self.setup_test_data(self.cfg.data)
+            if self.cfg.data.dataloader_type == "LDDL":
+                self.build_LDDL_data(self.cfg.data)
+                torch.distributed.barrier()
+            else:
+                self.build_train_valid_test_datasets()
+                self.setup_training_data(self.cfg.data)
+                self.setup_validation_data(self.cfg.data)
+                self.setup_test_data(self.cfg.data)

# when using pipeline model parallel the final stage need to initialize word embeddings
if parallel_state.get_pipeline_model_parallel_world_size() > 1: