support distribute training in python v2 API #1782
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
import collections | ||
import gzip | ||
Review comment: It seems gzip is no longer needed? |
||
import os | ||
Review comment: os is probably no longer needed either, right? |
||
|
||
import py_paddle.swig_paddle as api | ||
|
||
|
@@ -96,6 +98,18 @@ def __prepare_parameter__(self, in_args): | |
self.__gradient_machine__.prefetch(in_args) | ||
self.__parameter_updater__.getParametersRemote() | ||
|
||
def save_parameter(self, dir_name, file_name): | ||
Review comment: Don't pass dirname and filename as parameters; pass in an fp directly. That way we can save not only to a local file but also to a binary stream.
Review comment: @jacquesqiao please fix this issue.
Reply: done |
||
if not os.path.exists(dir_name): | ||
os.makedirs(dir_name) | ||
param_file_name = dir_name + "/" + file_name + '.tar.gz' | ||
assert not os.path.exists(param_file_name) | ||
self.__parameter_updater__.catchUpWith() | ||
Review comment: Could you explain this sequence of calls?
self.__parameter_updater__.catchUpWith()
self.__parameter_updater__.apply()
self.__parameter_updater__.getParametersRemote(True, True)
self.__parameter_updater__.restore()
Reply: These operations are there to support regularization and ModelAverage. Regularization is computed lazily, and with ModelAverage the model currently used for training is not the same model as the one actually used for prediction or saved to disk. |
||
self.__parameter_updater__.apply() | ||
self.__parameter_updater__.getParametersRemote(True, True) | ||
with gzip.open(param_file_name, 'w') as f: | ||
self.__parameters__.to_tar(f) | ||
self.__parameter_updater__.restore() | ||
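
The save path in this hunk swaps in the values to be written with `apply()` and swaps the training values back with `restore()`. A toy sketch of that pattern follows. `ToyUpdater` and `averaged` are hypothetical names invented for illustration, not the py_paddle updater, and the real `catchUpWith` / `getParametersRemote` steps are not modelled here.

```python
import contextlib


class ToyUpdater(object):
    """Hypothetical stand-in for an updater that keeps a running average."""

    def __init__(self, value):
        self.value = value   # weight currently used for training
        self.total = 0.0     # running sum of the weight after each step
        self.steps = 0
        self._backup = None

    def update(self, grad, lr=0.5):
        # Plain SGD step, then accumulate for the average.
        self.value -= lr * grad
        self.total += self.value
        self.steps += 1

    def apply(self):
        # Swap the averaged weight in; remember the training weight.
        self._backup = self.value
        if self.steps:
            self.value = self.total / self.steps

    def restore(self):
        # Swap the training weight back in.
        self.value = self._backup
        self._backup = None


@contextlib.contextmanager
def averaged(updater):
    """Guarantee restore() runs even if saving raises."""
    updater.apply()
    try:
        yield updater
    finally:
        updater.restore()


updater = ToyUpdater(4.0)
for grad in (2.0, 2.0):
    updater.update(grad)

with averaged(updater) as u:
    saved_value = u.value        # what a save would write out

training_value = updater.value   # training continues from here
```

Wrapping the pair in a context manager is one design option: it keeps a failed write from leaving the averaged values in place for the next training step.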
|
||
def train(self, reader, num_passes=1, event_handler=None, feeding=None): | ||
""" | ||
Training method. Will train num_passes of input data. | ||
|
Review comment: After the refactoring the interface can stay the same; for the implementation, consider having the parameter server do the parameter saving (trainer.save_parameter tells the parameter server to store the parameters).
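
The earlier review suggestion to accept a file object instead of `dir_name` and `file_name` could look roughly like the sketch below. `ToyParameters` and `save_parameter_to_tar` are hypothetical names used only for illustration; the sketch assumes just that the parameters object exposes a `to_tar(f)` method, as the diff does, and the tar layout shown is an assumption rather than the project's format.

```python
import gzip
import io
import tarfile


class ToyParameters(object):
    """Hypothetical stand-in for the parameters object; only to_tar is modelled."""

    def __init__(self, blobs):
        self.blobs = blobs  # parameter name -> raw bytes

    def to_tar(self, f):
        # Write each parameter as one tar member into the file-like object f.
        with tarfile.open(fileobj=f, mode="w") as tar:
            for name, data in sorted(self.blobs.items()):
                info = tarfile.TarInfo(name=name)
                info.size = len(data)
                tar.addfile(info, io.BytesIO(data))


def save_parameter_to_tar(parameters, f):
    """Save to any writable binary file-like object chosen by the caller."""
    parameters.to_tar(f)


params = ToyParameters({"w": b"\x00\x01", "b": b"\x02"})

# Save to an in-memory binary stream.
buf = io.BytesIO()
save_parameter_to_tar(params, buf)

# Save through a gzip stream; gzip.open(path, "wb") would give a local file.
gz_buf = io.BytesIO()
with gzip.GzipFile(fileobj=gz_buf, mode="wb") as gz:
    save_parameter_to_tar(params, gz)
```

With this shape the caller decides about directories, file names, and compression, so the trainer no longer needs `os` or `gzip` for saving.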