Commit 39b6132: fix ds z3 checkpointing when `stage3_gather_16bit_weights_on_model_save=False` (huggingface#25817)

* fix ds z3 checkpointing when `stage3_gather_16bit_weights_on_model_save=False`

* refactoring
pacman100 authored and EduardoPach committed Nov 18, 2023
1 parent eaae29a commit 39b6132
Showing 2 changed files with 14 additions and 6 deletions.
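
For context: `stage3_gather_16bit_weights_on_model_save` is a DeepSpeed ZeRO-3 option that controls whether the full 16-bit weights are gathered onto a single process at save time. A minimal config sketch (illustrative values, not part of this commit); such a dict can be passed to `TrainingArguments(deepspeed=...)`:

    # Minimal ZeRO-3 config sketch (illustrative). With the flag set to False,
    # the Trainer cannot consolidate a single 16-bit state_dict when saving,
    # which is the code path this commit fixes.
    ds_config = {
        "zero_optimization": {
            "stage": 3,
            "stage3_gather_16bit_weights_on_model_save": False,
        },
    }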
src/transformers/trainer.py: 12 changes (6 additions, 6 deletions)
@@ -93,6 +93,7 @@
     nested_numpify,
     nested_xla_mesh_reduce,
     reissue_pt_warnings,
+    remove_dummy_checkpoint,
 )
 from .trainer_utils import (
     PREFIX_CHECKPOINT_DIR,
@@ -2780,12 +2781,8 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
             if self.args.should_save:
                 self._save(output_dir, state_dict=state_dict)
             if self.is_fsdp_enabled:
-                # remove the dummy state_dict saved above
-                if self.args.should_save:
-                    for filename in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]:
-                        file = os.path.join(output_dir, filename)
-                        if os.path.isfile(file):
-                            os.remove(file)
+                # remove the dummy state_dict
+                remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
                 save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)

         elif self.is_deepspeed_enabled:
@@ -2801,6 +2798,9 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
                     " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
                     " zero_to_fp32.py to recover weights"
                 )
+                self._save(output_dir, state_dict={})
+                # remove the dummy state_dict
+                remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
                 self.model_wrapped.save_checkpoint(output_dir)

         elif self.args.should_save:
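
When the gather flag is off, `self.accelerator.get_state_dict(self.deepspeed)` raises `ValueError`, so the Trainer now saves a dummy empty state_dict (so config, tokenizer, and other artifacts still reach `output_dir`), removes the dummy weight files on the main process via the `should_save` flag, and writes the full sharded ZeRO checkpoint with `save_checkpoint`. A sketch of recovering fp32 weights from that checkpoint afterwards using DeepSpeed's helper (the path below is illustrative); running the `zero_to_fp32.py` script that DeepSpeed places in the checkpoint directory is the command-line equivalent:

    # Reconstruct a consolidated fp32 state_dict from the sharded ZeRO-3
    # checkpoint written by model_wrapped.save_checkpoint(output_dir).
    from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

    state_dict = get_fp32_state_dict_from_zero_checkpoint("./my-run/checkpoint-500")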
src/transformers/trainer_pt_utils.py: 8 changes (8 additions, 0 deletions)
@@ -1089,6 +1089,14 @@ def get_module_class_from_name(module, name):
                 return module_class


+def remove_dummy_checkpoint(is_main_process, output_dir, filenames):
+    if is_main_process:
+        for filename in filenames:
+            file = os.path.join(output_dir, filename)
+            if os.path.isfile(file):
+                os.remove(file)
+
+
 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
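
The new helper is deliberately small; a sketch of calling it in isolation (the path is illustrative; in this codebase `WEIGHTS_NAME` and `SAFE_WEIGHTS_NAME` resolve to the two filenames used below):

    # Deletes leftover dummy weight files, but only on the main process,
    # mirroring the self.args.should_save flag the Trainer passes in.
    from transformers.trainer_pt_utils import remove_dummy_checkpoint

    remove_dummy_checkpoint(
        is_main_process=True,
        output_dir="./my-run/checkpoint-500",
        filenames=["pytorch_model.bin", "model.safetensors"],
    )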
