Support for pipeline parallelism in T5 model
- Accumulate encoder hidden state gradient to handle skip connection
- Correctly compute the number of layers in encoder / decoder for T5 model
- Ensure the weights are initialized the same way in embeddings
- Synchronize embedding gradients across encoder and decoder for T5 model
- Support for checkpoint loading and saving
deepakn94 committed Jul 30, 2021
1 parent 6a68098 commit 46c74b4
Showing 20 changed files with 603 additions and 194 deletions.
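A note on the first bullet: in a T5 pipeline, every decoder stage cross-attends to the same encoder output, so that tensor behaves like a skip connection and its gradient must be summed over all decoder stages rather than overwritten. A minimal PyTorch sketch of the effect (illustrative only, not code from this commit; the toy decoder_stage function and shapes are hypothetical):

import torch

# Toy stand-in for cross-attention: each "stage" reads the shared encoder output.
encoder_output = torch.randn(4, 8, requires_grad=True)

def decoder_stage(hidden_state, enc_out):
    # Stand-in for self-attention + cross-attention in one pipeline stage.
    return hidden_state + enc_out.sum(dim=-1, keepdim=True)

hidden = torch.zeros(4, 1)
for _ in range(3):  # three decoder stages sharing one encoder output
    hidden = decoder_stage(hidden, encoder_output)

hidden.sum().backward()
# encoder_output.grad is the *sum* of contributions from all three stages
# (3.0 everywhere here); across pipeline ranks the backward pass likewise has
# to accumulate the incoming gradient instead of replacing it.
print(encoder_output.grad)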
9 changes: 9 additions & 0 deletions megatron/arguments.py
@@ -80,6 +80,12 @@ def parse_args(extra_args_provider=None, defaults={},
args.world_size, args.data_parallel_size,
args.tensor_model_parallel_size,
args.pipeline_model_parallel_size), flush=True)
if args.pipeline_model_parallel_size > 1:
if args.pipeline_model_parallel_split_rank is not None:
assert args.pipeline_model_parallel_split_rank < \
args.pipeline_model_parallel_size, 'split rank needs'\
' to be less than pipeline model parallel size ({})'.format(
args.pipeline_model_parallel_size)

# Deprecated arguments
assert args.batch_size is None, '--batch-size argument is no longer ' \
@@ -567,6 +573,9 @@ def _add_distributed_args(parser):
help='Degree of tensor model parallelism.')
group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
help='Degree of pipeline model parallelism.')
group.add_argument('--pipeline-model-parallel-split-rank',
type=int, default=None,
help='Rank where encoder and decoder should be split.')
group.add_argument('--model-parallel-size', type=int, default=None,
help='Old model parallel argument, do not use. Use '
'--tensor-model-parallel-size instead.')
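To illustrate what the new flag controls (a hypothetical helper, not part of this diff): with --pipeline-model-parallel-split-rank N, pipeline ranks below N hold encoder layers and the remaining ranks hold decoder layers, so the per-stage layer count is computed separately for the two halves.

def layers_per_stage(num_encoder_layers, num_decoder_layers,
                     pipeline_parallel_size, split_rank, stage_rank):
    """Hypothetical sketch of splitting T5 layers across pipeline stages."""
    if stage_rank < split_rank:
        # Stages [0, split_rank) hold the encoder.
        assert num_encoder_layers % split_rank == 0
        return num_encoder_layers // split_rank
    # Stages [split_rank, pipeline_parallel_size) hold the decoder.
    num_decoder_stages = pipeline_parallel_size - split_rank
    assert num_decoder_layers % num_decoder_stages == 0
    return num_decoder_layers // num_decoder_stages

# Example: a 12-layer encoder and 12-layer decoder on 4 stages split at rank 2.
print(layers_per_stage(12, 12, 4, 2, 0))  # 6 encoder layers on stage 0
print(layers_per_stage(12, 12, 4, 2, 3))  # 6 decoder layers on stage 3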
3 changes: 2 additions & 1 deletion megatron/initialize.py
@@ -193,7 +193,8 @@ def _initialize_distributed():
else:
mpu.initialize_model_parallel(args.tensor_model_parallel_size,
args.pipeline_model_parallel_size,
args.virtual_pipeline_model_parallel_size)
args.virtual_pipeline_model_parallel_size,
args.pipeline_model_parallel_split_rank)


def _init_autoresume():
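For context, a sketch of how the extended initializer could be called directly, with the same positional argument order as in the call above (values are illustrative; it assumes the process-group environment, e.g. RANK/WORLD_SIZE/MASTER_ADDR set by torchrun, is already in place):

import torch
from megatron import mpu

torch.distributed.init_process_group(backend='nccl')
# 2-way tensor parallelism, 4 pipeline stages, no interleaved schedule,
# encoder/decoder split after pipeline stage 2.
mpu.initialize_model_parallel(2, 4, None, 2)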
1 change: 1 addition & 0 deletions megatron/model/__init__.py
@@ -21,3 +21,4 @@
from .t5_model import T5Model
from .language_model import get_language_model
from .module import Float16Module
from .enums import ModelType
4 changes: 4 additions & 0 deletions megatron/model/enums.py
@@ -15,6 +15,10 @@

import enum

class ModelType(enum.Enum):
encoder_or_decoder = 1
encoder_and_decoder = 2

class LayerType(enum.Enum):
encoder = 1
decoder = 2
160 changes: 103 additions & 57 deletions megatron/model/language_model.py
@@ -45,7 +45,8 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,

def get_language_model(num_tokentypes, add_pooler,
encoder_attn_mask_type, init_method=None,
scaled_init_method=None, add_decoder=False,
scaled_init_method=None, add_encoder=True,
add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal,
pre_process=True, post_process=True):
"""Build language model and return along with the key to save."""
@@ -64,6 +65,7 @@ def get_language_model(num_tokentypes, add_pooler,
scaled_init_method,
encoder_attn_mask_type,
num_tokentypes=num_tokentypes,
add_encoder=add_encoder,
add_decoder=add_decoder,
decoder_attn_mask_type=decoder_attn_mask_type,
add_pooler=add_pooler,
@@ -159,6 +161,13 @@ def __init__(self,
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

def zero_parameters(self):
"""Zero out all parameters in embedding."""
self.word_embeddings.weight.data.fill_(0)
self.position_embeddings.weight.data.fill_(0)
if self.num_tokentypes > 0:
self.tokentype_embeddings.weight.data.fill_(0)

def add_tokentype_embeddings(self, num_tokentypes):
"""Add token-type embedding. This function is provided so we can add
token-type embeddings in case the pretrained model does not have it.
@@ -273,6 +282,7 @@ def __init__(self,
output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=0,
add_encoder=True,
add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal,
add_pooler=False,
@@ -286,10 +296,12 @@ def __init__(self,
self.hidden_size = args.hidden_size
self.num_tokentypes = num_tokentypes
self.init_method = init_method
self.add_encoder = add_encoder
self.encoder_attn_mask_type = encoder_attn_mask_type
self.add_decoder = add_decoder
self.decoder_attn_mask_type = decoder_attn_mask_type
self.add_pooler = add_pooler
self.encoder_hidden_state = None

# Embeddings.
if self.pre_process:
@@ -302,25 +314,33 @@ def __init__(self,
self._embedding_key = 'embedding'

# Transformer.
self.encoder = ParallelTransformer(
self.init_method,
output_layer_init_method,
self_attn_mask_type=self.encoder_attn_mask_type,
pre_process=self.pre_process,
post_process=self.post_process
)
self._encoder_key = 'encoder'

# Decoder
# Encoder (usually set to True, False if part of an encoder-decoder
# architecture and in decoder-only stage).
if self.add_encoder:
self.encoder = ParallelTransformer(
self.init_method,
output_layer_init_method,
self_attn_mask_type=self.encoder_attn_mask_type,
pre_process=self.pre_process,
post_process=self.post_process
)
self._encoder_key = 'encoder'
else:
self.encoder = None

# Decoder (usually set to False, True if part of an encoder-decoder
# architecture and in decoder-only stage).
if self.add_decoder:
assert args.pipeline_model_parallel_size == 1, \
'pipeline parallelism is not supported in the presence of decoder'
self.decoder = ParallelTransformer(
self.init_method,
output_layer_init_method,
layer_type=LayerType.decoder,
self_attn_mask_type=self.decoder_attn_mask_type)
self_attn_mask_type=self.decoder_attn_mask_type,
pre_process=self.pre_process,
post_process=self.post_process)
self._decoder_key = 'decoder'
else:
self.decoder = None

if self.post_process:
# Pooler.
@@ -330,28 +350,48 @@ def __init__(self,

def set_input_tensor(self, input_tensor):
""" See megatron.model.transformer.set_input_tensor()"""
self.encoder.set_input_tensor(input_tensor)
if self.add_encoder and self.add_decoder:
assert len(input_tensor) == 1, \
'input_tensor should only be length 1 for stage with both encoder and decoder'
self.encoder.set_input_tensor(input_tensor[0])
elif self.add_encoder:
assert len(input_tensor) == 1, \
'input_tensor should only be length 1 for stage with only encoder'
self.encoder.set_input_tensor(input_tensor[0])
elif self.add_decoder:
if len(input_tensor) == 2:
self.decoder.set_input_tensor(input_tensor[0])
self.encoder_hidden_state = input_tensor[1]
elif len(input_tensor) == 1:
self.decoder.set_input_tensor(None)
self.encoder_hidden_state = input_tensor[0]
else:
raise Exception('input_tensor must have either length 1 or 2')
else:
raise Exception('Stage must have at least either encoder or decoder')

def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
get_key_value=False, pooling_sequence_index=0,
enc_hidden_states=None, output_enc_hidden=False):

# Embeddings.
# Encoder embedding.
if self.pre_process:
embedding_output = self.embedding(enc_input_ids, enc_position_ids,
tokentype_ids=tokentype_ids)
encoder_input = embedding_output
encoder_input = self.embedding(enc_input_ids, enc_position_ids,
tokentype_ids=tokentype_ids)
else:
encoder_input = None

# encoder.
# Run encoder.
if enc_hidden_states is None:
encoder_output = self.encoder(encoder_input,
enc_attn_mask,
layer_past=layer_past,
get_key_value=get_key_value)
if self.encoder is not None:
encoder_output = self.encoder(encoder_input,
enc_attn_mask,
layer_past=layer_past,
get_key_value=get_key_value)
else:
encoder_output = self.encoder_hidden_state
else:
encoder_output = enc_hidden_states.to(encoder_input.dtype)

@@ -369,11 +409,15 @@ def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
else:
return encoder_output

# Decoder Embedding
dec_embedding_output = self.embedding(dec_input_ids,
dec_position_ids)
# decoder
decoder_output = self.decoder(dec_embedding_output,
# Decoder embedding.
if self.pre_process:
decoder_input = self.embedding(dec_input_ids,
dec_position_ids)
else:
decoder_input = None

# Run decoder.
decoder_output = self.decoder(decoder_input,
dec_attn_mask,
layer_past=layer_past,
get_key_value=get_key_value,
@@ -394,9 +438,10 @@ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
state_dict_[self._embedding_key] \
= self.embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._encoder_key] \
= self.encoder.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.add_encoder:
state_dict_[self._encoder_key] \
= self.encoder.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.post_process:
if self.add_pooler:
state_dict_[self._pooler_key] \
@@ -425,38 +470,39 @@ def load_state_dict(self, state_dict, strict=True):
self.embedding.load_state_dict(state_dict_, strict=strict)

# Encoder.
if self._encoder_key in state_dict:
state_dict_ = state_dict[self._encoder_key]
# for backward compatibility.
elif 'transformer' in state_dict:
state_dict_ = state_dict['transformer']
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'transformer.' in key:
state_dict_[key.split('transformer.')[1]] = state_dict[key]

# for backward compatibility.
state_dict_self_attention = {}
for key in state_dict_.keys():
if '.attention.' in key:
state_dict_self_attention[key.replace(".attention.",
".self_attention.")] = state_dict_[key]
if self.add_encoder:
if self._encoder_key in state_dict:
state_dict_ = state_dict[self._encoder_key]
# For backward compatibility.
elif 'transformer' in state_dict:
state_dict_ = state_dict['transformer']
else:
state_dict_self_attention[key] = state_dict_[key]
state_dict_ = state_dict_self_attention

self.encoder.load_state_dict(state_dict_, strict=strict)

# For backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'transformer.' in key:
state_dict_[key.split('transformer.')[1]] = state_dict[key]

# For backward compatibility.
state_dict_self_attention = {}
for key in state_dict_.keys():
if '.attention.' in key:
state_dict_self_attention[key.replace(".attention.",
".self_attention.")] = state_dict_[key]
else:
state_dict_self_attention[key] = state_dict_[key]
state_dict_ = state_dict_self_attention

self.encoder.load_state_dict(state_dict_, strict=strict)

# Pooler.
if self.post_process:
# pooler
if self.add_pooler:
assert 'pooler' in state_dict, \
'could not find data for pooler in the checkpoint'
self.pooler.load_state_dict(state_dict[self._pooler_key],
strict=strict)
# decoder
# Decoder.
if self.add_decoder:
assert 'decoder' in state_dict, \
'could not find data for decoder in the checkpoint'
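To summarize the set_input_tensor() branches above (a schematic helper, not code from the diff): encoder stages receive a single tensor, while decoder stages receive either the previous decoder activation plus the encoder output, or just the encoder output on the first decoder stage, where the decoder input still comes from the embedding.

def expected_stage_inputs(add_encoder, add_decoder, is_first_decoder_stage=False):
    """Schematic summary of what set_input_tensor() expects per pipeline stage."""
    if add_encoder and add_decoder:
        return ['encoder_input']                      # one tensor, routed to the encoder
    if add_encoder:
        return ['encoder_hidden']                     # output of the previous encoder stage
    if add_decoder and is_first_decoder_stage:
        return ['encoder_output']                     # only the skip connection
    if add_decoder:
        return ['decoder_hidden', 'encoder_output']   # activation plus skip connection
    raise ValueError('a stage must contain an encoder or a decoder')

print(expected_stage_inputs(False, True))        # ['decoder_hidden', 'encoder_output']
print(expected_stage_inputs(False, True, True))  # ['encoder_output']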
37 changes: 29 additions & 8 deletions megatron/model/module.py
@@ -51,15 +51,14 @@ def state_dict_for_save_checkpoint(self, destination=None, prefix='',


def word_embeddings_weight(self):
if mpu.is_pipeline_first_stage(ignore_virtual=True):
if not mpu.is_pipeline_last_stage(ignore_virtual=True) or \
mpu.get_pipeline_model_parallel_world_size() == 1:
return self.language_model.embedding.word_embeddings.weight
if mpu.is_pipeline_last_stage(ignore_virtual=True):
else:
if not self.share_word_embeddings:
raise Exception('word_embeddings_weight() called for last '
'stage, but share_word_embeddings is false')
return self.word_embeddings.weight
raise Exception('word_embeddings_weight() should be '
'called for first and last stage only')


def initialize_word_embeddings(self, init_method_normal):
@@ -69,12 +68,12 @@ def initialize_word_embeddings(self, init_method_normal):
'share_word_embeddings is false')

# This function just initializes the word embeddings in the final stage
# when we are using pipeline parallelism. If we aren't using pipeline
# parallelism there is nothing to do.
# when we are using pipeline parallelism. Nothing to do if we aren't
# using pipeline parallelism.
if args.pipeline_model_parallel_size == 1:
return

# Parameters are shared between the word embeddings layer, and the
# Parameters are shared between the word embeddings layers, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
# workers, so we do the following:
@@ -97,12 +96,34 @@ def initialize_word_embeddings(self, init_method_normal):
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True

# Zero out initial weights for decoder embedding.
# NOTE: We don't currently support T5 with the interleaved schedule.
if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \
not mpu.is_pipeline_last_stage(ignore_virtual=True) and \
mpu.is_rank_in_embedding_group():
self.language_model.embedding.zero_parameters()

# Ensure that first and last stages have the same initial parameter
# values.
if torch.distributed.is_initialized():
if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
if mpu.is_rank_in_embedding_group():
torch.distributed.all_reduce(self.word_embeddings_weight().data,
group=mpu.get_embedding_group())
# All-reduce other embeddings as well as necessary. The last stage
# does not have these other embeddings, so just create placeholder
# tensors of the right shape with all zeros.
# NOTE: We don't currently support T5 with the interleaved schedule.
if args.pipeline_model_parallel_split_rank is not None:
# TODO: Support tokentype embedding.
dimensions = (args.max_position_embeddings, args.hidden_size)
if mpu.is_pipeline_last_stage(ignore_virtual=True):
position_embeddings = torch.nn.Embedding(*dimensions).cuda()
position_embeddings.weight.data.fill_(0)
else:
self.language_model.embedding.cuda()
position_embeddings = self.language_model.embedding.position_embeddings
torch.distributed.all_reduce(position_embeddings.weight.data,
group=mpu.get_embedding_group())
else:
print("WARNING! Distributed processes aren't initialized, so "
"word embeddings in the last layer are not initialized. "
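The hunk above synchronizes the initial values of the shared embeddings; the commit message also mentions keeping their gradients in sync during training. Conceptually that means all-reducing the embedding weight gradients over the same embedding group before the optimizer step, roughly as in the sketch below (a hedged illustration, not the actual training-loop code; main_grad is assumed to be the gradient attribute used by Megatron's local DDP path, with plain .grad as the fallback):

import torch
from megatron import mpu

def allreduce_shared_embedding_grads(model):
    """Sketch: sum word-embedding gradients across the ranks that hold a copy
    of the shared embedding (encoder first stage, decoder first stage when a
    T5 pipeline split is used, and the last stage)."""
    if torch.distributed.is_initialized() and mpu.is_rank_in_embedding_group():
        weight = model.word_embeddings_weight()
        grad = weight.main_grad if hasattr(weight, 'main_grad') else weight.grad
        torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())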