Skip to content

Commit

Permalink
Update seq par reset/restore
Browse files Browse the repository at this point in the history
Signed-off-by: Markel Sanz Ausin <markelsanz14@gmail.com>
  • Loading branch information
markelsanz14 committed May 30, 2023
1 parent eb89c4a commit 4bb046f
Showing 1 changed file with 6 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1173,9 +1173,9 @@ def _reset_sequence_parallelism_args(self):
self.cfg.sequence_parallel = False

# Reset model parameters.

for module in self.get_gpt_module_list():
module.language_model.encoder.sequence_parallel = False
for module in self.model.modules():
if hasattr(module, "sequence_parallel"):
module.sequence_parallel = False

def _restore_sequence_parallelism_args(self):
""" Restores the sequence parallelism parameters using the values saved by
Expand All @@ -1186,5 +1186,6 @@ def _restore_sequence_parallelism_args(self):
self.cfg.sequence_parallel = self.last_sequence_parallel

# Restore model parameters.
for module in self.get_gpt_module_list():
module.language_model.encoder.sequence_parallel = self.last_sequence_parallel
for module in self.model.modules():
if hasattr(module, "sequence_parallel"):
module.sequence_parallel = self.last_sequence_parallel

0 comments on commit 4bb046f

Please sign in to comment.