1. Fixed PP and TP; PP+TP still fails.
Signed-off-by: Micha Livne <mlivne@cs.toronto.edu>
michalivne committed Jul 28, 2022
1 parent ca45e6d commit 61acb6e
Showing 2 changed files with 19 additions and 10 deletions.
@@ -900,13 +900,14 @@ def dummy():
num_micro_batches_before_decode = get_num_microbatches()
# Reconfigure microbatch calculator here to set num microbatches to 1 while decoding since its not clear how to decode with "grad acc".
# reconfigure back to how things were before encode
- _reconfigure_microbatch_calculator(
-     rank=app_state.global_rank,
-     rampup_batch_size=None,
-     global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
-     micro_batch_size=global_batch_per_gpu,  # Make sure that there is no "grad acc" while decoding.
-     data_parallel_size=parallel_state.get_data_parallel_world_size(),
- )
+ if reconfigure_microbatch:
+     _reconfigure_microbatch_calculator(
+         rank=app_state.global_rank,
+         rampup_batch_size=None,
+         global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
+         micro_batch_size=global_batch_per_gpu,  # Make sure that there is no "grad acc" while decoding.
+         data_parallel_size=parallel_state.get_data_parallel_world_size(),
+     )
tensor_shape = [encoder_seq_length, global_batch_per_gpu, self.cfg.hidden_size]

# build input arguments description
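For readers unfamiliar with the microbatch calculator: passing `micro_batch_size=global_batch_per_gpu` forces exactly one microbatch per data-parallel rank, i.e. no gradient accumulation during decoding. A minimal sketch of that arithmetic (the `num_microbatches_for` helper below is hypothetical, not the apex/NeMo API):

```python
# Hypothetical helper (not the apex _reconfigure_microbatch_calculator implementation)
# illustrating the relation between global batch, micro batch, and data parallelism.
def num_microbatches_for(global_batch_size: int, micro_batch_size: int, data_parallel_size: int) -> int:
    assert global_batch_size % (micro_batch_size * data_parallel_size) == 0
    return global_batch_size // (micro_batch_size * data_parallel_size)

global_batch_per_gpu = 8
data_parallel_size = 4

# micro_batch_size == global_batch_per_gpu  ->  exactly one microbatch, i.e. no "grad acc" while decoding.
assert num_microbatches_for(global_batch_per_gpu * data_parallel_size, global_batch_per_gpu, data_parallel_size) == 1
```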
@@ -948,11 +949,11 @@ def dummy():
dtype=self.autocast_dtype,
)

- # get output tensor of encoder [batch, seq_len, hidden]
+ # get output tensor of encoder [seq_len, batch, hidden]
if parallel_state.is_pipeline_last_stage():
output_tensor = output_tensor[0]['hiddens']
else:
- output_tensor = torch.zeros(tensor_shape, dtype=self.autocast_dtype).cuda().transpose(0, 1).contiguous()
+ output_tensor = torch.zeros(tensor_shape, dtype=self.autocast_dtype).cuda()

if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
# Broadcast from the last pipeline stage to all other model-parallel ranks.
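On every pipeline stage except the last, `output_tensor` at this point is only the zero placeholder allocated above; the real encoder hiddens are obtained by broadcasting from the last stage. A minimal sketch of that broadcast idiom, assuming the apex `parallel_state` helpers named below are available (this is the general pattern, not necessarily the exact call the model makes here):

```python
import torch
from apex.transformer import parallel_state  # assumption: apex is installed, as in NeMo


def broadcast_from_last_pipeline_stage(tensor: torch.Tensor) -> torch.Tensor:
    # All ranks in the pipeline group call this with a buffer of identical shape/dtype;
    # the last stage holds the real hiddens, the other stages receive them in place.
    src = parallel_state.get_pipeline_model_parallel_last_rank()
    group = parallel_state.get_pipeline_model_parallel_group()
    torch.distributed.broadcast(tensor, src, group=group)
    return tensor
```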
@@ -972,7 +973,8 @@ def dummy():
data_parallel_size=parallel_state.get_data_parallel_world_size(),
)

- return output_tensor
+ # Return the encoder output tensor, transposed from the internal [seq_len, batch, hidden] layout to [batch, seq_len, hidden].
+ return output_tensor.transpose(1, 0)

def decode(
self,
@@ -1032,6 +1034,7 @@ def dummy():

# get encoder hiddens (output)
if enc_output is None:
+ # Encode returns a tensor of shape [batch, seq_len, hidden].
enc_output = self.encode(
tokens_enc=tokens_enc, enc_mask=enc_mask, encoder_input=encoder_input, reconfigure_microbatch=False
)
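The comment above makes the layout contract explicit: the pipeline works internally with [seq_len, batch, hidden], while `encode()` hands its caller [batch, seq_len, hidden]. A plain-PyTorch shape check of the two transposes involved (illustrative only, no model code):

```python
import torch

seq_len, batch, hidden = 16, 4, 8

# Internal layout used by the pipeline: [seq_len, batch, hidden].
internal = torch.zeros(seq_len, batch, hidden)

# What encode() now returns to the caller: [batch, seq_len, hidden].
returned = internal.transpose(1, 0)
assert returned.shape == (batch, seq_len, hidden)

# What forward() does with a caller-provided enc_output: back to [seq_len, batch, hidden].
assert returned.transpose(0, 1).shape == (seq_len, batch, hidden)

# For rank-3 tensors, transpose(1, 0) and transpose(0, 1) are the same swap,
# so the round trip restores the original layout.
assert torch.equal(returned.transpose(0, 1), internal)
```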
@@ -379,6 +379,8 @@ def forward(
decoder_cross_attention_relative_position_bias,
) = (None, None, None)

+ enc_output_provided = enc_output is not None
+
if (enc_input is None) and (enc_input_ids is not None):
if self.pre_process and self.add_encoder:
# We don't need position ids for RPE, because the embedding layer does not have position embeddings.
@@ -426,6 +428,10 @@ def forward(
# Note: This is when the decoder itself is split across PP ranks.
dec_input = None

+ # If enc_output is provided in `batch_for_pipeline`, we need to transpose it from [B x S x H] -> [S x B x H].
+ if enc_output_provided:
+     enc_output = enc_output.transpose(0, 1)
+
output = self.enc_dec_model(
enc_input=enc_input,
enc_attn_mask=enc_attn_mask,
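In the second file, the key detail is that `enc_output_provided` is captured before the model may compute encoder hiddens internally, so only an externally supplied, batch-first tensor is transposed to the internal layout. A self-contained sketch of that guard pattern (names and shapes illustrative, not the actual `forward` signature):

```python
import torch


def forward_sketch(enc_output=None, seq_len=16, batch=4, hidden=8):
    # Decide up front whether the caller supplied encoder hiddens.
    enc_output_provided = enc_output is not None
    if enc_output is None:
        # Computed internally: already in the internal [seq_len, batch, hidden] layout.
        enc_output = torch.zeros(seq_len, batch, hidden)
    if enc_output_provided:
        # Caller passed [batch, seq_len, hidden]; convert to the internal layout.
        enc_output = enc_output.transpose(0, 1)
    return tuple(enc_output.shape)


assert forward_sketch() == (16, 4, 8)
assert forward_sketch(enc_output=torch.zeros(4, 16, 8)) == (16, 4, 8)
```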
