[Cherry pick] Sharding reshard function enhancement (#8544)
* fix bug of sharding format (#8483)

* Optimize the speed of set_state_dict (#8532)

* fix sharding reshard save (#8535)

* Fix ignore_data_skip bug when timer is enabled (#8536)

* Save parameter shape and dtype when using sharding reshard (#8543)

* save parameter shape and dtype

* refactor

* format pre-commit

---------

Co-authored-by: ShenLiang <2282912238@qq.com>
sneaxiy and ForFishes committed Jun 11, 2024
1 parent 162d8d3 commit bcacc6a
Showing 2 changed files with 35 additions and 14 deletions.
paddlenlp/trainer/trainer.py (31 changes: 23 additions & 8 deletions)
@@ -569,7 +569,12 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
                 base_weight_name=weight_name,
                 model_wrapped=self.model_wrapped,
             )
-            self.model.set_state_dict(state_dict)
+            old_state_dict = self.model.state_dict()
+            new_state_dict = {}
+            for k, v in state_dict.items():
+                if k not in old_state_dict or id(v) != id(old_state_dict[k]):
+                    new_state_dict[k] = v
+            self.model.set_state_dict(new_state_dict)
         else:
             if resume_from_checkpoint is not None and (self.args.dataset_rank == 0 or self.args.use_expert_parallel):

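The hunk above carries the "Optimize the speed of set_state_dict" change: instead of handing the entire resharded state dict to set_state_dict, the loop keeps only entries whose tensor object differs from the one the model already holds, presumably so unchanged parameters are not copied again. A minimal standalone sketch of that identity-filtering idea (the helper name is hypothetical, not part of the Trainer API):

def filter_changed_entries(loaded_state_dict, current_state_dict):
    # Keep entries that are missing from, or not the very same object as,
    # the model's current tensors. id() compares object identity, so tensors
    # the reshard step returned unchanged are skipped.
    changed = {}
    for name, tensor in loaded_state_dict.items():
        if name not in current_state_dict or id(tensor) != id(current_state_dict[name]):
            changed[name] = tensor
    return changed
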
@@ -891,7 +896,8 @@ def _inner_training_loop(
 
         npu_accelerate_plugin(self.optimizer)
 
-        self.timers and self.timers("read-data").start()
+        if self.args.ignore_data_skip:
+            self.timers and self.timers("read-data").start()
 
         for epoch in range(epochs_trained, num_train_epochs):
             if isinstance(train_dataloader, paddle.io.DataLoader) and isinstance(

@@ -907,7 +913,9 @@
                     inputs = split_inputs_sequence_dim(inputs)
                 if self.args.use_hybrid_parallel and self.args.context_parallel_degree > 1:
                     inputs = split_inputs_sequence_dim_load_balance(inputs)
-                self.timers and self.timers("read-data").stop()
+                if self.args.ignore_data_skip:
+                    self.timers and self.timers("read-data").stop()
+
                 os.environ["TRAINER_GLOBAL_STEP"] = str(self.state.global_step)
                 self.callback_handler.on_load_data_end(args, self.state, self.control, inputs=inputs)
 
@@ -1098,7 +1106,9 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
 
                 if self.control.should_epoch_stop or self.control.should_training_stop:
                     break
-                self.timers and self.timers("read-data").start()
+
+                if self.args.ignore_data_skip:
+                    self.timers and self.timers("read-data").start()
 
             if step < 0:
                 logger.warning(

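The hunks guarding the "read-data" timer belong to the "Fix ignore_data_skip bug when timer is enabled" change: the timer is now only started and stopped when args.ignore_data_skip is set, presumably so that the batch-skipping paths taken while resuming cannot leave a start() without its matching stop(). A rough toy illustration of why the calls must stay paired (a standalone sketch, not Paddle's timer implementation):

class ToyTimer:
    # Mimics a timer that refuses to be started twice or stopped while idle.
    def __init__(self):
        self.running = False

    def start(self):
        assert not self.running, "timer already started"
        self.running = True

    def stop(self):
        assert self.running, "timer is not running"
        self.running = False

timer = ToyTimer()
timer.start()
# A `continue` that jumps over stop() would leave the timer running, and the
# next iteration's start() would then trip the assertion above.
timer.stop()
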
Expand Down Expand Up @@ -2462,10 +2472,15 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
if state_dict is None:
state_dict = self.model.state_dict()

self._save_ckpt_func(
state_dict,
os.path.join(output_dir, _add_variant(PADDLE_WEIGHTS_NAME, self.args.weight_name_suffix)),
)
if self.args.should_save_sharding_stage1_model:
state_dict, _, _ = self.sharding_io.manipulate_state_dict_and_config(
unwrap_model(self.model), merge_tensor_parallel=False, state_dict=state_dict
)
variant = _add_variant(PADDLE_WEIGHTS_NAME, self.args.sharded_name_suffix())
else:
variant = _add_variant(PADDLE_WEIGHTS_NAME, self.args.weight_name_suffix)

self._save_ckpt_func(state_dict, os.path.join(output_dir, variant))
else:
if isinstance(self.model, PretrainedModel) and self.args.should_save_sharding_stage1_model:
config_to_save = None
Expand Down
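With the hunk above, _save takes a dedicated path when should_save_sharding_stage1_model is enabled: the state dict is first trimmed by sharding_io.manipulate_state_dict_and_config (which now accepts the prepared state_dict, see the sharding_io.py diff below) and the weights file is written under the sharded variant name instead of the plain weight_name_suffix, so each sharding rank persists only its own slice. For orientation, a hypothetical re-implementation of the variant-naming convention assumed here (not the actual _add_variant helper, and the example suffix is made up):

from typing import Optional

def add_variant(weights_name: str, variant: Optional[str]) -> str:
    # Splice the rank-specific suffix in front of the file extension, e.g.
    # ("model_state.pdparams", "tp00_pp00_shard00")
    #     -> "model_state.tp00_pp00_shard00.pdparams".
    if not variant:
        return weights_name
    stem, ext = weights_name.rsplit(".", 1)
    return f"{stem}.{variant}.{ext}"
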
paddlenlp/trainer/utils/sharding_io.py (18 changes: 12 additions & 6 deletions)
@@ -321,12 +321,13 @@ def reshard_sharding(node_model_state):
         node_model_state = reshard_pp(node_model_state)
         return reshard_sharding(node_model_state)
 
-    def manipulate_state_dict_and_config(self, model_to_save, merge_tensor_parallel=False):
+    def manipulate_state_dict_and_config(self, model_to_save, merge_tensor_parallel=False, state_dict=None):
         weight_name_suffix = self.args.sharded_name_suffix()
 
-        state_dict = model_to_save.state_dict()
-        if self.args.should_save_sharding_stage1_model:
-            state_dict = filter_sharded_params(state_dict, self.optimizer, self.sharding_group)
+        if state_dict is None:
+            state_dict = model_to_save.state_dict()
+        if self.args.should_save_sharding_stage1_model:
+            state_dict = filter_sharded_params(state_dict, self.optimizer, self.sharding_group)
 
         config_to_save = None
         merge_tensor_parallel = merge_tensor_parallel and self.args.use_hybrid_parallel

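The new optional state_dict argument lets callers such as Trainer._save above pass in an already prepared state dict instead of having the method call model_to_save.state_dict() a second time; when nothing is passed, the previous behaviour is unchanged. Under sharding stage 1 the dict is then reduced to the parameters owned by the local rank via filter_sharded_params. An illustrative sketch of that ownership-based filtering (a hypothetical helper, not paddlenlp's filter_sharded_params):

def filter_owned_params(state_dict, param_to_rank, local_rank):
    # Keep only the parameters mapped to this sharding rank; names missing
    # from the ownership map are treated as unsharded and kept everywhere.
    return {
        name: value
        for name, value in state_dict.items()
        if param_to_rank.get(name, local_rank) == local_rank
    }
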
@@ -384,7 +385,7 @@ def save_distributed_model_meta(self, dir):
 
         path = os.path.join(dir, MODEL_META_NAME)
         with open(path, "w") as f:
-            json.dump(model_meta, f, indent=4)
+            json.dump(model_meta, f)
 
     def _get_distributed_strategy(self):
         pp_degree = 1

@@ -544,13 +545,18 @@ def _gather_sharding_metas(self):
             pp_overlap = unwrap_optimizer(self.optimizer, DygraphShardingOptimizerV2).pp_overlap
 
         model = self.model
-        structure_name_mapping = {k: v.name for (k, v) in model.state_dict().items()}
+        structure_name_mapping = {}
+        param_meta = {}
+        for k, v in model.state_dict().items():
+            structure_name_mapping[k] = v.name
+            param_meta[k] = (v.shape, int(v.dtype))
 
         sharding_metas = {}
         sharding_meta = {}
 
         sharding_meta["param2rank"] = param2rank
         sharding_meta["structure_name_mapping"] = structure_name_mapping
+        sharding_meta["param_meta"] = param_meta
         sharding_meta["sharding_strategy"] = sharding_strategy
         sharding_meta["enable_overlap"] = pp_overlap
         suffix = f"tp{self.args.tensor_parallel_rank:0>2d}_pp{self.args.pipeline_parallel_rank:0>2d}"

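The last hunk carries the "Save parameter shape and dtype when using sharding reshard" change: alongside the structure-name mapping, _gather_sharding_metas now records a param_meta entry of (shape, dtype code) per parameter, storing the dtype as int(v.dtype) so it serialises cleanly into the model meta JSON. A hypothetical consumer of that metadata (not code from this commit) could validate a reloaded state dict against the recorded shapes and dtypes before resharding:

def check_param_meta(state_dict, param_meta):
    # Return the parameter names whose tensors do not match the shape/dtype
    # recorded at save time; int(tensor.dtype) mirrors how the code was stored.
    mismatched = []
    for name, tensor in state_dict.items():
        if name not in param_meta:
            continue
        shape, dtype_code = param_meta[name]
        if list(tensor.shape) != list(shape) or int(tensor.dtype) != dtype_code:
            mismatched.append(name)
    return mismatched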
