-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
New api about checkpoint and models #10878
Changes from 36 commits
b044724
dca0b6d
73b6723
b2cb7c6
514b242
5eea5db
5f5d6a9
ad9dfeb
486e1e3
9086043
9357078
0211c5d
0deb6f9
d712af2
b44ede8
94eaf94
e44c278
bca4da4
46f2688
7973d9b
55d908c
7734034
c06f43b
08e5f0a
bfdcf18
9735f25
be16af3
eea5762
951fa74
3b5e3f9
6db240d
f28f41d
53409a2
2f44585
cb7c124
7fbddaa
9e026a9
5c8397a
bf2c53a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,7 +24,8 @@ | |
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params', | ||
'load_persistables', 'save_inference_model', 'load_inference_model', | ||
'get_inference_program', 'save_checkpoint', 'load_checkpoint', | ||
'clean_checkpoint' | ||
'clean_checkpoint', 'load_persist_vars_without_grad', | ||
'save_persist_vars_without_grad', 'get_latest_checkpoint_serial' | ||
] | ||
|
||
|
||
|
@@ -457,95 +458,163 @@ def get_parameter_value_by_name(name, executor, program=None): | |
|
||
SUCCESS_MARK_FILENAME = "_SUCCESS" | ||
CHECKPOINT_PREFIX = "checkpoint" | ||
MODEL_DIR = "__model__" | ||
TRAINER_PREFIX = "trainer" | ||
CHECKPOINT_SEPARATOR = "_" | ||
|
||
|
||
def save_checkpoint(executor, | ||
checkpoint_dir=None, | ||
max_num_checkpoints=3, | ||
save_interval_secs=600, | ||
main_program=None): | ||
checkpoint_dir, | ||
trainer_id, | ||
is_chief=False, | ||
trainer_args=None, | ||
main_program=None, | ||
max_num_checkpoints=3): | ||
""" | ||
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory, | ||
the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy | ||
to keep numbers of checkpoint directory, the numbers of checkpoint directory are max_num_checkpoints at most, | ||
The interval between two saved checkpoints must greater than save_interval_secs. | ||
|
||
:param executor | ||
:param checkpoint_dir | ||
:param max_num_checkpoints | ||
:param save_interval_secs | ||
:param main_program | ||
:param executor executor for save the value | ||
:param checkpoint_dir the checkpoint directory | ||
:param trainer_id currect trainer id | ||
:param is_chief if the trainer id equals 0, the is_chief will be true | ||
:param main_program will save all variables in program | ||
:param max_num_checkpoints will keep numbers of checkpoint serials not bigger than max_num_checkpoints | ||
""" | ||
if checkpoint_dir is None: | ||
checkpoint_dir = os.getcwd() | ||
raise ValueError("'checkpoint_dir' should not be None") | ||
|
||
if trainer_args: | ||
assert isinstance(trainer_args, dict) | ||
|
||
if not os.path.isdir(checkpoint_dir): | ||
os.makedirs(checkpoint_dir) | ||
|
||
serial = _get_lastest_checkpoint_dir(checkpoint_dir) | ||
if serial >= 0 and not _interval_secs_exceed( | ||
_get_serial_dir(serial, checkpoint_dir), save_interval_secs): | ||
return | ||
serial = get_latest_checkpoint_serial(checkpoint_dir) + 1 | ||
cur_dir = _get_serial_dir(checkpoint_dir, serial) | ||
|
||
serial += 1 | ||
cur_dir = _get_serial_dir(serial, checkpoint_dir) | ||
save_trainer_args(cur_dir, trainer_id, trainer_args) | ||
|
||
save_vars( | ||
executor, | ||
dirname=cur_dir, | ||
main_program=main_program, | ||
vars=None, | ||
predicate=_is_checkpoint_var, | ||
filename=None) | ||
_write_success(cur_dir) | ||
_lru_delete(checkpoint_dir, max_num_checkpoints) | ||
if is_chief: | ||
save_persist_vars_without_grad(executor, cur_dir, main_program) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks all gradient vars are all There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. First, I find arguments named "X@GRAD" are |
||
|
||
_scroll_delete(checkpoint_dir, max_num_checkpoints) | ||
|
||
|
||
def load_checkpoint(executor, checkpoint_dir=None, main_program=None): | ||
def load_checkpoint(executor, checkpoint_dir, serial, main_program): | ||
""" | ||
Load checkpoint from a directory by executor, | ||
it will find the most recent saved checkpoint file and load it auto. | ||
|
||
:param executor | ||
:param checkpoint_dir | ||
:param main_program | ||
:param executor executor for load the value | ||
:param checkpoint_dir the checkpoint directory | ||
:param serial the serial folder in checkpoint directory will be load | ||
:param main_program will load all variables in program | ||
""" | ||
|
||
if checkpoint_dir is None: | ||
checkpoint_dir = os.getcwd() | ||
raise ValueError("'checkpoint_dir' should not be None") | ||
|
||
serial = _get_lastest_checkpoint_dir(checkpoint_dir) | ||
if serial is None or serial < 0: | ||
raise ValueError("'serial' should not be None or <0 ") | ||
|
||
if serial < 0: | ||
return | ||
if main_program is None: | ||
raise ValueError('main_program should not be None.') | ||
|
||
cur_dir = _get_serial_dir(serial, checkpoint_dir) | ||
|
||
load_vars( | ||
executor, | ||
dirname=cur_dir, | ||
main_program=main_program, | ||
predicate=_is_checkpoint_var, | ||
filename=None) | ||
cur_dir = _get_serial_dir(checkpoint_dir, serial) | ||
load_persist_vars_without_grad(executor, cur_dir, main_program, True) | ||
|
||
|
||
def clean_checkpoint(checkpoint_dir, delete_dir=False): | ||
""" | ||
clean the checkpoint dir, when the train exits normally, the trainer will call clean_checkpoint to delete checkpoint directory saved before. | ||
delete_dir only works when the directory is empty, otherwise, OSError is raised. | ||
|
||
:param checkpoint_dir | ||
:param delete_dir | ||
""" | ||
|
||
if checkpoint_dir is None: | ||
checkpoint_dir = os.getcwd() | ||
_lru_delete(checkpoint_dir, max_num_checkpoints=0) | ||
raise ValueError("'checkpoint_dir' should not be None") | ||
_scroll_delete(checkpoint_dir, max_num_checkpoints=0) | ||
|
||
if delete_dir and not os.listdir(checkpoint_dir): | ||
os.rmdir(checkpoint_dir) | ||
|
||
|
||
def _get_serial_dir(serial, checkpoint_dir): | ||
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial) | ||
return os.path.join(checkpoint_dir, serial_folder) | ||
def load_persist_vars_without_grad(executor, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why this is needed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. write |
||
dirname, | ||
program, | ||
has_model_dir=False): | ||
""" | ||
load_persist_vars_without_grad will load variables from a directory by an executor, | ||
the variable named end with "@GRAD" will not be loaded. | ||
|
||
:param executor executor for load the value | ||
:param dirname the checkpoint directory | ||
:param program will load all variables in program | ||
:param has_model_dir if has_model_dir is True, will load variables from sub directory named __model__ | ||
""" | ||
|
||
if has_model_dir: | ||
dirname = _get_model_dir(dirname) | ||
|
||
load_vars( | ||
executor, | ||
dirname=dirname, | ||
main_program=program, | ||
predicate=_is_checkpoint_var, | ||
filename=None) | ||
|
||
|
||
def save_persist_vars_without_grad(executor, dirname, program): | ||
""" | ||
save_persist_vars_without_grad will save variables to a directory by an executor, | ||
the variable named end with "@GRAD" will not be saved. | ||
|
||
:param executor executor for load the value | ||
:param dirname the checkpoint directory | ||
:param program will load all variables in program | ||
""" | ||
cur_dir = _get_model_dir(dirname) | ||
save_vars( | ||
executor, | ||
dirname=cur_dir, | ||
main_program=program, | ||
vars=None, | ||
predicate=_is_checkpoint_var, | ||
filename=None) | ||
_write_success(cur_dir) | ||
|
||
|
||
def save_trainer_args(dirname, trainer_id, trainer_args): | ||
assert isinstance(trainer_args, dict) | ||
|
||
cur_dir = _get_trainer_dir(dirname, trainer_id) | ||
|
||
for name, value in trainer_args.iteritems(): | ||
args_file = os.path.join(cur_dir, name) | ||
with open(args_file, 'w') as f: | ||
f.write(str(value)) | ||
_write_success(cur_dir) | ||
|
||
|
||
def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args): | ||
assert isinstance(trainer_args, list) | ||
|
||
cur_dir = _get_serial_dir(checkpoint_dir, serial) | ||
cur_dir = _get_trainer_dir(cur_dir, trainer_id) | ||
|
||
ret_values = [] | ||
|
||
for arg in trainer_args: | ||
cur_file = os.path.join(cur_dir, arg) | ||
with open(cur_file, 'r') as f: | ||
contents = f.read() | ||
ret_values.append(contents.strip()) | ||
return ret_values | ||
|
||
|
||
def _is_checkpoint_var(var): | ||
|
@@ -559,36 +628,74 @@ def _is_checkpoint_var(var): | |
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ | ||
var.desc.type() == core.VarDesc.VarType.RAW: | ||
return False | ||
# @GRAD are named for gradient variables, checkpoint will not save it. | ||
if "@GRAD" in var.name: | ||
return False | ||
# .trainer_ are named for distribute train variables, checkpoint will not save it. | ||
if ".trainer_" in var.name: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add some comments to explain what's the meaning of the hard code There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
return False | ||
|
||
if var.name.endswith("@GRAD"): | ||
# .block is named for distribute train variables, checkpoint will not save it. | ||
if ".block" in var.name: | ||
return False | ||
|
||
return var.persistable | ||
|
||
|
||
def _interval_secs_exceed(dirname, save_interval_secs): | ||
dir_time = os.path.getmtime(dirname) | ||
if save_interval_secs > (time.time() - dir_time): | ||
return False | ||
return True | ||
def _get_dir_serial(dirname): | ||
_, serial = dirname.split(CHECKPOINT_SEPARATOR) | ||
|
||
try: | ||
serial_num = int(serial) | ||
except ValueError: | ||
serial_num = -1 | ||
return serial_num | ||
|
||
|
||
def _get_serial_dir(dirname, serial): | ||
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial) | ||
serial_dir = os.path.join(dirname, serial_folder) | ||
|
||
if not os.path.isdir(serial_dir): | ||
os.makedirs(serial_dir) | ||
|
||
return serial_dir | ||
|
||
|
||
def _get_model_dir(dirname): | ||
model_dir = os.path.join(dirname, MODEL_DIR) | ||
|
||
def _lru_delete(dirname, max_num_checkpoints=3): | ||
if not os.path.isdir(model_dir): | ||
os.makedirs(model_dir) | ||
|
||
return model_dir | ||
|
||
|
||
def _get_trainer_dir(dirname, trainer_id): | ||
trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id) | ||
trainer_dir = os.path.join(dirname, trainer_folder) | ||
|
||
if not os.path.isdir(trainer_dir): | ||
os.makedirs(trainer_dir) | ||
|
||
return trainer_dir | ||
|
||
|
||
def _scroll_delete(dirname, max_num_checkpoints=3): | ||
dirs = os.listdir(dirname) | ||
serials = [] | ||
serial_map = {} | ||
for serial in dirs: | ||
try: | ||
serials.append(int(serial)) | ||
except ValueError: | ||
continue | ||
serial_num = _get_dir_serial(serial) | ||
serial_map[serial_num] = serial | ||
|
||
if len(serials) <= max_num_checkpoints: | ||
if len(serial_map.keys()) <= max_num_checkpoints: | ||
return | ||
|
||
serials = serial_map.keys() | ||
serials.sort(reverse=True) | ||
serials = serials[max_num_checkpoints:] | ||
for serial in serials: | ||
cur_dir = os.path.join(dirname, str(serial)) | ||
cur_dir = _get_serial_dir(dirname, serial) | ||
shutil.rmtree(cur_dir) | ||
|
||
|
||
|
@@ -604,33 +711,30 @@ def _write_success(dirname): | |
f.write(now) | ||
|
||
|
||
def _get_lastest_checkpoint_dir(checkpoint_dir): | ||
def get_latest_checkpoint_serial(checkpoint_dir): | ||
""" | ||
get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory | ||
|
||
:param checkpoint_dir | ||
""" | ||
if not checkpoint_dir.strip(): | ||
if not checkpoint_dir: | ||
return -1 | ||
|
||
def has_success(checkpoint_dir, cur_dir): | ||
""" | ||
is _SUCCESS in this dir | ||
""" | ||
_, serial = cur_dir.split(CHECKPOINT_SEPARATOR) | ||
|
||
try: | ||
int(serial) | ||
except ValueError: | ||
return -1 | ||
|
||
if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)): | ||
serial = _get_dir_serial(cur_dir) | ||
if serial == -1 or not os.path.isdir( | ||
os.path.join(checkpoint_dir, cur_dir)): | ||
return -1 | ||
|
||
success_path = os.path.join( | ||
_get_serial_dir(serial, checkpoint_dir), SUCCESS_MARK_FILENAME) | ||
_get_serial_dir(checkpoint_dir, serial), MODEL_DIR, | ||
SUCCESS_MARK_FILENAME) | ||
if os.path.isfile(success_path): | ||
return int(serial) | ||
return serial | ||
|
||
if not os.path.isdir(checkpoint_dir): | ||
return -1 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If have
is_chief
why still need to passtrainer_id
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
each
trainer
need to save its arguments practicality.Only
chief
need to savevariables
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have deleted code about
chief