New API about checkpoint and models #10878

Merged Jun 10, 2018 · 39 commits

Changes shown are from 33 of the 39 commits.

Commits
b044724  update fluid Train API param_path to checkpoint_config (seiriosPlus, May 22, 2018)
dca0b6d  restore param_path (seiriosPlus, May 23, 2018)
73b6723  Merge branch 'develop' of github.com:PaddlePaddle/Paddle into new_api… (seiriosPlus, May 23, 2018)
b2cb7c6  Merge branch 'develop' of github.com:PaddlePaddle/Paddle into new_api… (seiriosPlus, May 28, 2018)
514b242  add save/load persist_vars_without_grad (seiriosPlus, May 28, 2018)
5eea5db  optimized checkpoint and save_model (seiriosPlus, May 29, 2018)
5f5d6a9  optimized checkpoint and save_model (seiriosPlus, May 29, 2018)
ad9dfeb  bug fix and optimize (seiriosPlus, May 29, 2018)
486e1e3  bug fix and optimize (seiriosPlus, May 29, 2018)
9086043  bug fix and optimize (seiriosPlus, May 29, 2018)
9357078  Merge branch 'develop' of github.com:PaddlePaddle/Paddle into new_api… (seiriosPlus, May 29, 2018)
0211c5d  bug fix (seiriosPlus, May 30, 2018)
0deb6f9  annotation optimized and code style optimized (seiriosPlus, May 30, 2018)
d712af2  add distribute config (seiriosPlus, May 30, 2018)
b44ede8  bug fix (seiriosPlus, May 30, 2018)
94eaf94  bug fix about lru and save (seiriosPlus, May 30, 2018)
e44c278  bug fix about clean (seiriosPlus, May 30, 2018)
bca4da4  cancle only chief delete files (seiriosPlus, May 30, 2018)
46f2688  bug fix (seiriosPlus, May 31, 2018)
7973d9b  bug fix (seiriosPlus, Jun 1, 2018)
55d908c  Merge branch 'develop' of github.com:PaddlePaddle/Paddle into new_api… (seiriosPlus, Jun 1, 2018)
7734034  Merge branch 'develop' of github.com:PaddlePaddle/Paddle into new_api… (seiriosPlus, Jun 1, 2018)
c06f43b  add annotation about _is_checkpoint_var (seiriosPlus, Jun 4, 2018)
08e5f0a  rename need_load_checkpoint to get_latest_checkpoint_serial (seiriosPlus, Jun 4, 2018)
bfdcf18  grammar optimized. (seiriosPlus, Jun 4, 2018)
9735f25  optimized (seiriosPlus, Jun 5, 2018)
be16af3  delete pyc (seiriosPlus, Jun 5, 2018)
eea5762  add checkpoint unittest (seiriosPlus, Jun 5, 2018)
951fa74  add checkpoint unittest (seiriosPlus, Jun 5, 2018)
3b5e3f9  update checkpoint unittest (seiriosPlus, Jun 5, 2018)
6db240d  update trainer about epoch_id and step id (seiriosPlus, Jun 5, 2018)
f28f41d  update io.py annotations and codes (seiriosPlus, Jun 5, 2018)
53409a2  code optimized (seiriosPlus, Jun 5, 2018)
2f44585  code optimized (seiriosPlus, Jun 6, 2018)
cb7c124  code optimized (seiriosPlus, Jun 6, 2018)
7fbddaa  bug fix (seiriosPlus, Jun 6, 2018)
9e026a9  remove chief (seiriosPlus, Jun 7, 2018)
5c8397a  remove chief in test (seiriosPlus, Jun 8, 2018)
bf2c53a  Merge branch 'develop' of github.com:PaddlePaddle/Paddle into new_api… (seiriosPlus, Jun 10, 2018)
1 change: 1 addition & 0 deletions python/paddle/fluid/__init__.py
@@ -26,6 +26,7 @@
from trainer import EndEpochEvent
from trainer import BeginStepEvent
from trainer import EndStepEvent
from trainer import CheckpointConfig

import inferencer
from inferencer import Inferencer
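
User code picks this export up from the top-level fluid namespace. The sketch below is hypothetical: CheckpointConfig itself is defined in trainer.py, which this diff does not show, so the constructor arguments are assumptions inferred from the io.py parameters below and from the commit "update fluid Train API param_path to checkpoint_config".

```python
# Hypothetical sketch only: CheckpointConfig lives in trainer.py, which is
# not part of this diff, so the constructor arguments shown here are
# assumptions inferred from io.py.
import paddle.fluid as fluid

config = fluid.CheckpointConfig(
    checkpoint_dir="/tmp/ckpt",  # root directory for serial-numbered checkpoints
    max_num_checkpoints=3)       # how many checkpoint serials to keep on disk

# The commit history suggests the Trainer takes this via a checkpoint_config
# argument in place of the old param_path:
# trainer = fluid.Trainer(..., checkpoint_config=config)
```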
230 changes: 168 additions & 62 deletions python/paddle/fluid/io.py
@@ -24,7 +24,8 @@
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
'load_persistables', 'save_inference_model', 'load_inference_model',
'get_inference_program', 'save_checkpoint', 'load_checkpoint',
'clean_checkpoint'
'clean_checkpoint', 'load_persist_vars_without_grad',
'save_persist_vars_without_grad', 'get_latest_checkpoint_serial'
]


@@ -457,14 +458,18 @@ def get_parameter_value_by_name(name, executor, program=None):

SUCCESS_MARK_FILENAME = "_SUCCESS"
CHECKPOINT_PREFIX = "checkpoint"
MODEL_DIR = "__model__"
TRAINER_PREFIX = "trainer"
CHECKPOINT_SEPARATOR = "_"


def save_checkpoint(executor,
checkpoint_dir=None,
max_num_checkpoints=3,
save_interval_secs=600,
main_program=None):
checkpoint_dir,
trainer_id,
is_chief=False,
trainer_args=None,
main_program=None,
max_num_checkpoints=3):
"""
Save Checkpoint will save persistable LodTensor variables from main_program in the checkpoint directory;
each checkpoint lives in a subdirectory named by serial number from 0 to (n - 1), and save_checkpoint uses an LRU strategy
@@ -473,79 +478,143 @@ def save_checkpoint(executor,

:param executor
:param checkpoint_dir
:param max_num_checkpoints
:param save_interval_secs
:param trainer_id
[review comment, Contributor] This needs more detailed comments.

:param is_chief
:param main_program
:param max_num_checkpoints
"""
if checkpoint_dir is None:
checkpoint_dir = os.getcwd()
if checkpoint_dir.strip() is None:
raise ValueError("'checkpoint_dir' should not be None")

if trainer_args:
assert isinstance(trainer_args, dict)

if not os.path.isdir(checkpoint_dir):
os.makedirs(checkpoint_dir)

serial = _get_lastest_checkpoint_dir(checkpoint_dir)
if serial >= 0 and not _interval_secs_exceed(
_get_serial_dir(serial, checkpoint_dir), save_interval_secs):
return
serial = get_latest_checkpoint_serial(checkpoint_dir) + 1
cur_dir = _get_serial_dir(checkpoint_dir, serial)

serial += 1
cur_dir = _get_serial_dir(serial, checkpoint_dir)
save_trainer_args(cur_dir, trainer_id, trainer_args)

if is_chief:
save_persist_vars_without_grad(executor, cur_dir, main_program)
[review comment, Contributor] It looks like the gradient vars are all not persistent, so maybe the function name could be shorter: save_persistent_vars? BTW, "persist" is a verb; we need the adjective: "persistent".

[reply, seiriosPlus (Collaborator, Author), Jun 4, 2018] First, I find that arguments named "X@GRAD" are persistent. Second, save_persistent_vars would not filter RAW arguments.


save_vars(
executor,
dirname=cur_dir,
main_program=main_program,
vars=None,
predicate=_is_checkpoint_var,
filename=None)
_write_success(cur_dir)
_lru_delete(checkpoint_dir, max_num_checkpoints)


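A minimal usage sketch of the new save_checkpoint signature above. The trainer_args keys ("epoch_id", "step_id") are assumptions taken from this PR's trainer commits; values are written to disk with str(), so simple scalars work.

```python
# Sketch of the new save_checkpoint call. Only the chief trainer writes the
# model variables (is_chief=True); every trainer writes its own trainer_args.
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
fluid.io.save_checkpoint(
    executor=exe,
    checkpoint_dir="/tmp/ckpt",
    trainer_id=0,
    is_chief=True,
    trainer_args={"epoch_id": 1, "step_id": 100},  # assumed keys
    main_program=fluid.default_main_program(),
    max_num_checkpoints=3)
```
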
def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
def load_checkpoint(executor, checkpoint_dir, serial, main_program):
"""
Load a checkpoint from a directory with the executor;
the checkpoint to load is identified by the given serial number.

:param executor
:param checkpoint_dir
:param serial
:param main_program
"""

if checkpoint_dir is None:
checkpoint_dir = os.getcwd()
if checkpoint_dir.strip() is None:
raise ValueError("'checkpoint_dir' should not be None")

serial = _get_lastest_checkpoint_dir(checkpoint_dir)

if serial < 0:
return
if serial is None or serial < 0:
raise ValueError("'serial' should not be None or <0 ")

cur_dir = _get_serial_dir(serial, checkpoint_dir)
if main_program is None:
raise ValueError('main_program should not be None.')

load_vars(
executor,
dirname=cur_dir,
main_program=main_program,
predicate=_is_checkpoint_var,
filename=None)
cur_dir = _get_serial_dir(checkpoint_dir, serial)
load_persist_vars_without_grad(executor, cur_dir, main_program, True)

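Unlike the old version, load_checkpoint no longer searches for the newest checkpoint itself; the caller passes a serial, typically the one returned by get_latest_checkpoint_serial. A sketch of the resume path, assuming the program has already been built and initialized:

```python
# Sketch: resume from the newest complete checkpoint, if one exists.
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
serial = fluid.io.get_latest_checkpoint_serial("/tmp/ckpt")
if serial >= 0:  # -1 means no checkpoint with a _SUCCESS marker was found
    fluid.io.load_checkpoint(
        executor=exe,
        checkpoint_dir="/tmp/ckpt",
        serial=serial,
        main_program=fluid.default_main_program())
```
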

def clean_checkpoint(checkpoint_dir, delete_dir=False):
"""
clean the checkpoint directory; when training exits normally, the trainer calls clean_checkpoint to delete the checkpoint directories saved earlier.
delete_dir only works when the directory is empty; otherwise, OSError is raised.

:param checkpoint_dir
:param delete_dir
"""
if checkpoint_dir is None:
checkpoint_dir = os.getcwd()

if checkpoint_dir.strip() is None:
raise ValueError("'checkpoint_dir' should not be None")
_lru_delete(checkpoint_dir, max_num_checkpoints=0)

if delete_dir and not os.listdir(checkpoint_dir):
os.rmdir(checkpoint_dir)

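On a normal exit the trainer tears the whole checkpoint tree down (a minimal sketch):

```python
# Sketch: remove all checkpoint serials, then the (now empty) root itself.
import paddle.fluid as fluid

fluid.io.clean_checkpoint("/tmp/ckpt", delete_dir=True)
```
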

def _get_serial_dir(serial, checkpoint_dir):
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
return os.path.join(checkpoint_dir, serial_folder)
def load_persist_vars_without_grad(executor,
[review comment, Contributor] Why is this needed?

[reply, seiriosPlus (Collaborator, Author)] I wrote load_persist_vars_without_grad just because of the filter; I needed a new predicate to filter the variables.

dirname,
program,
has_model_dir=False):
"""
load_persist_vars_without_grad will load variables from a directory with an executor;
variables whose names end with "@GRAD" will not be loaded.

:param executor executor used to load the values
:param dirname the checkpoint directory
:param program all variables in this program will be loaded
:param has_model_dir if True, variables will be loaded from the subdirectory named __model__
"""

if has_model_dir:
dirname = _get_model_dir(dirname)

load_vars(
executor,
dirname=dirname,
main_program=program,
predicate=_is_checkpoint_var,
filename=None)


def save_persist_vars_without_grad(executor, dirname, program):
"""
save_persist_vars_without_grad will save variables to a directory with an executor;
variables whose names end with "@GRAD" will not be saved.

:param executor executor used to save the values
:param dirname the checkpoint directory
:param program all variables in this program will be saved
"""
cur_dir = _get_model_dir(dirname)
save_vars(
executor,
dirname=cur_dir,
main_program=program,
vars=None,
predicate=_is_checkpoint_var,
filename=None)
_write_success(cur_dir)

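The two helpers are symmetric. A sketch of a direct round trip outside the checkpoint machinery; note that save writes into a __model__ subdirectory, so load must be told to look there with has_model_dir=True:

```python
# Sketch: save persistable, non-@GRAD variables and load them back.
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
prog = fluid.default_main_program()

# Writes variables into /tmp/my_model/__model__ plus a _SUCCESS marker.
fluid.io.save_persist_vars_without_grad(exe, "/tmp/my_model", prog)

# has_model_dir=True redirects the load into the __model__ subdirectory.
fluid.io.load_persist_vars_without_grad(exe, "/tmp/my_model", prog,
                                        has_model_dir=True)
```
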

def save_trainer_args(dirname, trainer_id, trainer_args):
assert isinstance(trainer_args, dict)

cur_dir = _get_trainer_dir(dirname, trainer_id)

for name, value in trainer_args.iteritems():
args_file = os.path.join(cur_dir, name)
with open(args_file, 'w') as f:
f.write(str(value))
_write_success(cur_dir)


def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
assert isinstance(trainer_args, list)

cur_dir = _get_serial_dir(checkpoint_dir, serial)
cur_dir = _get_trainer_dir(cur_dir, trainer_id)

ret_values = []

for arg in trainer_args:
cur_file = os.path.join(cur_dir, arg)
with open(cur_file, 'r') as f:
contents = f.read()
ret_values.append(contents.strip())
return ret_values

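Note the asymmetry in the pair above: save_trainer_args takes a dict and writes one file per key, while load_trainer_args takes a list of names and returns the stripped file contents, in order, as strings. A sketch (the epoch_id/step_id names are assumed):

```python
# Sketch of the trainer-args round trip; values come back as strings.
import paddle.fluid as fluid

serial_dir = "/tmp/ckpt/checkpoint_0"  # an existing serial directory
fluid.io.save_trainer_args(serial_dir, 0,
                           {"epoch_id": 1, "step_id": 100})

values = fluid.io.load_trainer_args("/tmp/ckpt", 0, 0,
                                    ["epoch_id", "step_id"])
# values == ["1", "100"]
```
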

def _is_checkpoint_var(var):
@@ -559,36 +628,74 @@ def _is_checkpoint_var(var):
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
var.desc.type() == core.VarDesc.VarType.RAW:
return False
# Variables with "@GRAD" in their names are gradient variables; checkpoint will not save them.
if "@GRAD" in var.name:
return False
# Names containing ".trainer_" belong to distributed-training variables; checkpoint will not save them.
if ".trainer_" in var.name:
[review comment, Contributor] Can you add some comments to explain the meaning of the hard-coded ".block" and ".trainer_"?

[reply, seiriosPlus (Collaborator, Author)] Done.
return False

if var.name.endswith("@GRAD"):
# Names containing ".block" belong to distributed-training variables; checkpoint will not save them.
if ".block" in var.name:
return False

return var.persistable


def _interval_secs_exceed(dirname, save_interval_secs):
dir_time = os.path.getmtime(dirname)
if save_interval_secs > (time.time() - dir_time):
return False
return True
def _get_dir_serial(dirname):
_, serial = dirname.split(CHECKPOINT_SEPARATOR)

try:
serial_num = int(serial)
except ValueError:
serial_num = -1
return serial_num

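For instance, a behavior sketch of the helper above:

```python
_get_dir_serial("checkpoint_0")    # -> 0
_get_dir_serial("checkpoint_abc")  # -> -1 (the ValueError is swallowed)
```
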

def _get_serial_dir(dirname, serial):
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
serial_dir = os.path.join(dirname, serial_folder)

if not os.path.isdir(serial_dir):
os.makedirs(serial_dir)

return serial_dir


def _get_model_dir(dirname):
model_dir = os.path.join(dirname, MODEL_DIR)

if not os.path.isdir(model_dir):
os.makedirs(model_dir)

return model_dir


def _get_trainer_dir(dirname, trainer_id):
trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id)
trainer_dir = os.path.join(dirname, trainer_folder)

if not os.path.isdir(trainer_dir):
os.makedirs(trainer_dir)

return trainer_dir

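Taken together, the three directory helpers produce a layout like the following (illustrative; the file names under trainer_0 depend on the trainer_args keys, here assumed to be epoch_id and step_id):

```
checkpoint_dir/
    checkpoint_0/          # _get_serial_dir(checkpoint_dir, 0)
        __model__/         # _get_model_dir: variables saved by the chief
            _SUCCESS
        trainer_0/         # _get_trainer_dir: this trainer's args
            epoch_id
            step_id
            _SUCCESS
```
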

def _lru_delete(dirname, max_num_checkpoints=3):
[review comment, Contributor] It seems this function does not implement a real LRU algorithm; scroll_delete would be a better name.

dirs = os.listdir(dirname)
serials = []
serial_map = {}
for serial in dirs:
try:
serials.append(int(serial))
except ValueError:
continue
serial_num = _get_dir_serial(serial)
serial_map[serial_num] = serial

if len(serials) <= max_num_checkpoints:
if len(serial_map.keys()) <= max_num_checkpoints:
return

serials = serial_map.keys()
serials.sort(reverse=True)
serials = serials[max_num_checkpoints:]
for serial in serials:
cur_dir = os.path.join(dirname, str(serial))
cur_dir = _get_serial_dir(dirname, serial)
shutil.rmtree(cur_dir)

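As the review comment says, this is not true LRU; it simply keeps the max_num_checkpoints highest serials. A worked sketch of the sort-and-slice step:

```python
# Worked example of the pruning logic: newest (highest) serials survive.
serials = [0, 3, 1, 4, 2]
serials.sort(reverse=True)   # [4, 3, 2, 1, 0]
doomed = serials[3:]         # max_num_checkpoints=3 -> delete [1, 0]
assert doomed == [1, 0]
```
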

@@ -604,7 +711,7 @@ def _write_success(dirname):
f.write(now)


def _get_lastest_checkpoint_dir(checkpoint_dir):
def get_latest_checkpoint_serial(checkpoint_dir):
"""
get the serial number of the latest checkpoint in the checkpoint directory; the _SUCCESS file must exist under that checkpoint's __model__ subdirectory

@@ -617,20 +724,19 @@ def has_success(checkpoint_dir, cur_dir):
"""
whether _SUCCESS exists in this dir
"""
_, serial = cur_dir.split(CHECKPOINT_SEPARATOR)

try:
int(serial)
except ValueError:
serial = _get_dir_serial(cur_dir)
if serial == -1:
[review comment, Contributor] Please merge the two conditional statements.

return -1

if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
return -1

success_path = os.path.join(
_get_serial_dir(serial, checkpoint_dir), SUCCESS_MARK_FILENAME)
_get_serial_dir(checkpoint_dir, serial), MODEL_DIR,
SUCCESS_MARK_FILENAME)
if os.path.isfile(success_path):
return int(serial)
return serial

if not os.path.isdir(checkpoint_dir):
return -1