Skip to content

Commit

Permalink
Set sync_batch_comm in other places (#5448)
Browse files (browse the repository at this point in the history)
Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
  • Loading branch information
MaximumEntropy committed Nov 17, 2022
1 parent 738e37d commit c170e03
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 0 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -585,6 +585,7 @@ def fwd_bwd_step(self, batch, batch_idx, forward_only):
dtype=self.autocast_dtype,
grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
sequence_parallel_enabled=self.cfg.get("sequence_parallel", False),
sync_batch_comm=self.frozen_model.cfg.get('sync_batch_comm', False),
)
else:
losses_reduced_per_micro_batch = forward_backward_no_pipelining(
Expand All @@ -595,6 +596,7 @@ def fwd_bwd_step(self, batch, batch_idx, forward_only):
tensor_shape=tensor_shape,
dtype=self.autocast_dtype,
grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
sync_batch_comm=self.frozen_model.cfg.get('sync_batch_comm', False),
)

# only the last stages of the pipeline return losses
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -203,6 +203,7 @@ def fwd_bwd_step(self, batch, batch_idx, forward_only):
dtype=self.autocast_dtype,
grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
sequence_parallel_enabled=False,
sync_batch_comm=self.frozen_model.cfg.get('sync_batch_comm', False),
)
else:
losses_reduced_per_micro_batch = forward_backward_no_pipelining(
Expand All @@ -214,6 +215,7 @@ def fwd_bwd_step(self, batch, batch_idx, forward_only):
decoder_sequence_length=dec_seq_length,
dtype=self.autocast_dtype,
grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
sync_batch_comm=self.frozen_model.cfg.get('sync_batch_comm', False),
)

# only the last stages of the pipeline return losses
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -61,6 +61,7 @@ def forward_step(self, batch, tensor_shape):
forward_only=True,
tensor_shape=tensor_shape,
dtype=self.model.autocast_dtype,
sync_batch_comm=self.model.cfg.get('sync_batch_comm', False),
)
else:
output_tensor = forward_backward_no_pipelining(
Expand All @@ -70,6 +71,7 @@ def forward_step(self, batch, tensor_shape):
forward_only=True,
tensor_shape=tensor_shape,
dtype=self.model.autocast_dtype,
sync_batch_comm=self.model.cfg.get('sync_batch_comm', False),
)
return output_tensor

Expand Down

0 comments on commit c170e03

Please sign in to comment.