Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add async save for optimizer #8557

Merged
merged 3 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 91 additions & 6 deletions paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

import collections
import contextlib
import copy
import inspect
import math
import multiprocessing
import os
import random
import re
Expand Down Expand Up @@ -184,6 +186,75 @@

__all__ = ["Trainer"]

async_save_queue = []
g_cpu_optimizer_state_dict = {}


def _save_func(obj, path, saved_signal_path, protocol):
paddle.save(obj, path, protocol)
# dump savd_siganl
with open(saved_signal_path, mode="w+") as f:
f.write("1")

Check warning on line 198 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L198

Added line #L198 was not covered by tests

def check_exitcode(task):
exitcode = task.exitcode

Check warning on line 201 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L200-L201

Added lines #L200 - L201 were not covered by tests
if exitcode != 0:
print(f"Error: save ckpt process failed with exitcode {exitcode}!!!")


def clear_async_save_task_queue():
"""

Check warning on line 207 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L205-L207

Added lines #L205 - L207 were not covered by tests
wait until all async save task to be done.
"""
while len(async_save_queue) > 0:
task = async_save_queue.pop()
if task and task.is_alive():
task.join(timeout=60)
if task.is_alive():
logger.error("Error: save ckpt process timeout!!!")
async_save_queue.append(task)
else:
check_exitcode(task)
else:
check_exitcode(task)

Check warning on line 220 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L214-L220

Added lines #L214 - L220 were not covered by tests


Check warning on line 222 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L222

Added line #L222 was not covered by tests
def async_save_optimizer(optimizer_state_dict, path, saved_signal_path, protocol=4):
global g_cpu_optimizer_state_dict

Check warning on line 224 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L224

Added line #L224 was not covered by tests
g_cpu_optimizer_state_dict.clear()
for k, v in optimizer_state_dict.items():
if k == "master_weights":
g_cpu_optimizer_state_dict[k] = {}
for kk, vv in v.items():
tensor_name = vv.name
g_cpu_optimizer_state_dict[k][kk] = vv.pin_memory()
g_cpu_optimizer_state_dict[k][kk].name = tensor_name
elif k == "LR_Scheduler":
g_cpu_optimizer_state_dict[k] = copy.deepcopy(v)
else:
g_cpu_optimizer_state_dict[k] = v.pin_memory()
paddle.device.synchronize()
clear_async_save_task_queue()

Check warning on line 238 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L229-L238

Added lines #L229 - L238 were not covered by tests

attempt = 0
ctx = multiprocessing.get_context("spawn")

Check warning on line 242 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L240-L242

Added lines #L240 - L242 were not covered by tests
def start_process():
nonlocal attempt
try:

Check warning on line 245 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L244-L245

Added lines #L244 - L245 were not covered by tests
p = ctx.Process(target=_save_func, args=(g_cpu_optimizer_state_dict, path, saved_signal_path, protocol))
p.start()

Check warning on line 247 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L247

Added line #L247 was not covered by tests
return p
except Exception as e:
print(f"Attempt {attempt + 1} failed with error: {e}")
attempt += 1
time.sleep(1)
return start_process()

p = start_process()
async_save_queue.append(p)

Check warning on line 257 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L249-L257

Added lines #L249 - L257 were not covered by tests

class Trainer:
"""
Expand Down Expand Up @@ -1332,7 +1403,7 @@
metrics = None
if self.control.should_evaluate:
if isinstance(self.optimizer, GroupShardedOptimizerStage2) and self.optimizer._broadcast_overlap:
paddle.device.cuda.synchronize()
paddle.device.synchronize()

if isinstance(self.eval_dataset, dict):
for eval_dataset_name, eval_dataset in self.eval_dataset.items():
Expand All @@ -1346,7 +1417,7 @@

if self.control.should_save:
if isinstance(self.optimizer, GroupShardedOptimizerStage2) and self.optimizer._broadcast_overlap:
paddle.device.cuda.synchronize()
paddle.device.synchronize()

Check warning on line 1420 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L1420

Added line #L1420 was not covered by tests

self._save_checkpoint(model, metrics=metrics)
logger.info(f"{self.runtime_timer.log()}")
Expand Down Expand Up @@ -2208,6 +2279,10 @@
def _save_checkpoint(self, model, metrics=None):
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
self.runtime_timer.start("checkpoint saving time")

if self.args.use_async_save:
clear_async_save_task_queue()

Check warning on line 2285 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L2285

Added line #L2285 was not covered by tests
# Save model checkpoint
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

Expand All @@ -2225,6 +2300,7 @@
# only save model state dict, ignore optimizer and scheduler
if not self.args.ignore_save_lr_and_optim:
optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
saved_signal_path = os.path.join(output_dir, f"saved_signal_{dist.get_rank()}")

if self.args.use_hybrid_parallel:
if self.dp_group.rank <= 0 or self.args.use_expert_parallel:
Expand All @@ -2245,10 +2321,19 @@
os.path.join(output_dir, optimizer_name),
)
else:
self._save_ckpt_func(
self.optimizer.state_dict(),
os.path.join(output_dir, optimizer_name),
)
state_dict = self.optimizer.state_dict()
save_path = os.path.join(output_dir, optimizer_name)
if self.args.use_async_save:
assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC"
async_save_optimizer(
state_dict,
save_path,
saved_signal_path=saved_signal_path,
)

Check warning on line 2332 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L2328-L2332

Added lines #L2328 - L2332 were not covered by tests
else:
self._save_ckpt_func(state_dict, save_path)
with open(saved_signal_path, mode="w+") as f:
f.write("1")

if self.args.should_save or self.args.use_expert_parallel:
if not self.args.use_hybrid_parallel:
Expand Down
4 changes: 4 additions & 0 deletions paddlenlp/trainer/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,10 @@ class TrainingArguments:
default=True,
metadata={"help": "Whether use lazy data processing."},
)
use_async_save: Optional[bool] = field(
default=False,
metadata={"help": "Whether to use async_save instead of paddle.save."},
)
skip_profile_timer: Optional[bool] = field(
default=True,
metadata={"help": "enable framework timer, will output timeline informatoin in logging and visualdl."},
Expand Down
Loading