Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add checkpoint util class and implement #10532

Merged
merged 59 commits into from
May 23, 2018

Conversation

seiriosPlus
Copy link
Collaborator

Add new feature about #10376

@@ -0,0 +1,214 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see there's load_op and load_combine_op and corresponding saving ops, on python side, you can also use fluid.io.save_persistables to save all persistable variables.

In order to make save_persistables equal to save a checkpoint, make sure that the state variables are all "persistable" like step counters, learning rates, learning_rate moments etc.

So can you reuse those ops instead of writing some new one?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

load_op and save_op are designed for LodTensor variable, But checkpoint will save variables not only LodTensor, and checkpoint has some arguments particular.
At present, checkpoint load/save op and load/save op have no clear-cut distinction.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's better to reuse current operators maybe check the variable type will be fine.
So what the other variable types are saved in the checkpoint? "RAW" types and "feed" "fetch" may not need to be saved.

@@ -187,6 +187,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
for (auto &var : sparse_vars) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change may not be nessesary.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am sorry about it, I will revert it later.


if checkpoint_dir and self.is_chief:
program.global_block().create_var(
name=SERIAL_VAR_NAME,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is a "serial"?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

serial is a serial number, like 0,1,2...100, each time when paddle needs to save checkpoint, the serial will auto incrementally.
If everything goes well, the biggest serial number will be used in load checkpoint.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Call it a checkpoint ID maybe good for understanding


save_vars = []
for var in self.origin_program.list_vars():
if self._is_persistable(var):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can use fluid.io.save_persistables


serial_number = self._get_lastest_checkpoint_dir(checkpoint_load_dir)

s_prog.global_block().append_op(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do current parameter server know which parameter block to load?

# is_chief (no.0 triner) for checkpoint
# the no.0 trainer will save all variables and its own reader offset to checkpoint
# other trianers will save its own reader offset to checkpoint
self.is_chief = trainer_id == 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_is_chief

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will fix it.

except ValueError:
return -1

success_path = os.path.join(checkpoint_dir, cur_dir, SUCCESS)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is success_path used for?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need a tag to indicate that: the checkpoint content is right. So, I define a tag named _SUCCESS, when the checkpoint_save_op save all need variables successfully, it will write an empty file named _SUCCESS at last.
Because of this, when pserver/trainer need to load checkpoint, it will check _SUCCESS at first.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you only need to know when the saving ends, just wait for the executor returns, raise any error that may cause the saving fail is OK I think.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If some exceptions happened in executor running, the executor does not have any information return, how do we know saving is OK.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can catch that exception I think, please give it a try, this will make the code more simple.

save_secs=600,
main_program=None):
"""
Save Variables to Checkpint Dir
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checkpint => Checkpoint
Dir => Directory

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


def save_checkpoint(executor,
dirname,
keep_max=3,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

keep_max => max_num_checkpoints

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

def save_checkpoint(executor,
dirname,
keep_max=3,
save_secs=600,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

save_secs => interval

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


serial = _get_lastest_checkpoint_dir(dirname) + 1
cur_dir = os.path.join(dirname, str(serial))
# save_persistables(executor, cur_dir, main_program)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No commented out codes please.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done



def save_checkpoint(executor,
dirname,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can use current working directory as default?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -454,3 +452,149 @@ def get_parameter_value_by_name(name, executor, program=None):
program = default_main_program()
var = program.global_block().var(name)
return get_parameter_value(var, executor)


SUCCESS = "_SUCCESS"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SUCCESS = SUCCESS_MARK_FILENAME

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

if not os.path.isdir(dirname):
os.makedirs(dirname)

global BEGIN_SECS
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try not to use global please

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

serial = _get_lastest_checkpoint_dir(dirname) + 1
cur_dir = os.path.join(dirname, str(serial))
# save_persistables(executor, cur_dir, main_program)
save_vars(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why call save_vars instead of save_psersistables?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

save_psersistables can not filter out gradient variables.

'get_inference_program',
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
'load_persistables', 'save_inference_model', 'load_inference_model',
'get_inference_program', 'save_checkpoint', 'restore_checkpoint'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe it's better to be named as load_checkpoint or restore_from_checkpoint?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will use load_checkpoint , restore_from_checkpoint is too long I think.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

return True


def _lru_delete(dirname, keep_max=3):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

keep_max => max_num_checkpoints

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


def _lru_delete(dirname, keep_max=3):
"""
retain checkpoint nums with keep_max
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

keep_max => max_num_checkpoints

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


def _write_success(dirname):
"""
write _SUCCESS to checkpoint dir
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update _SUCCESS

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_SUCCESS is the file name ,not the var name


def _get_lastest_checkpoint_dir(checkpoint_dir):
"""
get the biggest number in checkpoint_dir, which has _SUCCESS
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update _SUCCESS

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_SUCCESS is the file name, not the var name

Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
to keep numbers of checkpoint directory, the numbers of checkpoint directory are max_num_checkpoints at most,
The interval time between two save_checkpoint must great than or equal to save_interval_secs.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The interval between two saved checkpoints must greater than save_interval_secs.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


def load_checkpoint(executor, dirname=None, main_program=None):
"""
Load checkpoint from directory by executor,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

directory => one directory

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

def load_checkpoint(executor, dirname=None, main_program=None):
"""
Load checkpoint from directory by executor,
it will find lastest checkpoint file and load it auto.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

latest => the most recent saved checkpoint

os.path.join(dirname, str(serial)), save_interval_secs):
return

serial = serial + 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can use +=

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

return

serial = serial + 1
cur_dir = os.path.join(dirname, str(serial))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The checkpoint directories will be saved in "1", "2", etc. which may not make sense to users, I think it's better to be like "checkpoint_1", "checkpoint_2", and the "_SUCCESS" file can save a timestamp when the checkpoint is saved.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

typhoonzero
typhoonzero previously approved these changes May 23, 2018
Copy link
Contributor

@typhoonzero typhoonzero left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Copy link
Contributor

@typhoonzero typhoonzero left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM again.

@seiriosPlus seiriosPlus merged commit 397a69d into PaddlePaddle:develop May 23, 2018
@seiriosPlus seiriosPlus deleted the checkpoint branch May 23, 2018 11:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants