support distribute training in python v2 API #1782

Merged · 15 commits · Apr 24, 2017
1 change: 1 addition & 0 deletions demo/word2vec/api_train_v2.py
@@ -69,6 +69,7 @@ def main():
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 100 == 0:
trainer.save_parameter("output", "batch-" + str(event.batch_id))
Contributor comment:

After the refactoring the interface can stay the same; for the implementation, consider having the parameter server perform the save (trainer.save_parameter tells the parameter server to store the parameters).

result = trainer.test(
paddle.batch(
paddle.dataset.imikolov.test(word_dict, N), 32))
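The checkpoint-every-100-batches pattern in the event handler above can be sketched without Paddle itself. `EndIteration` and `save_parameter` below are illustrative stand-ins for the Paddle v2 names used in the diff, not the real API:

```python
class EndIteration:
    """Stand-in for paddle.event.EndIteration."""
    def __init__(self, batch_id):
        self.batch_id = batch_id

saved = []  # records which checkpoints were written

def save_parameter(dir_name, file_name):
    # Stand-in for trainer.save_parameter in the diff.
    saved.append((dir_name, file_name))

def event_handler(event):
    # Checkpoint every 100 batches, mirroring the demo code above.
    if isinstance(event, EndIteration):
        if event.batch_id % 100 == 0:
            save_parameter("output", "batch-" + str(event.batch_id))

# Simulate 250 training iterations.
for batch_id in range(250):
    event_handler(EndIteration(batch_id))
```

Running the loop writes checkpoints at batches 0, 100, and 200, matching the modulo check in the demo.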
14 changes: 14 additions & 0 deletions python/paddle/v2/trainer.py
@@ -1,4 +1,6 @@
import collections
import gzip
Collaborator comment:

It seems gzip is no longer needed?

import os
Collaborator comment:

Isn't os no longer needed either?


import py_paddle.swig_paddle as api

@@ -96,6 +98,18 @@ def __prepare_parameter__(self, in_args):
self.__gradient_machine__.prefetch(in_args)
self.__parameter_updater__.getParametersRemote()

def save_parameter(self, dir_name, file_name):
Collaborator comment:

Instead of passing dir_name and file_name, pass in a file object directly. That way we can save not only to a local file but also to any binary stream.

Collaborator comment:

@jacquesqiao please fix this issue.

Member (author) comment:

done
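The suggestion above — accept a writable file object instead of a directory and file name — can be sketched as follows. `save_parameters` and the fake parameter dict are illustrative assumptions, not the actual Paddle API; the point is that the same code serves both a local file and an in-memory stream:

```python
import gzip
import io
import tarfile

def save_parameters(f):
    """Write a (name -> bytes) dict of fake parameters as a .tar.gz
    archive to any writable binary stream f (a file or io.BytesIO)."""
    params = {"embedding": b"\x00" * 16, "fc_layer": b"\x01" * 8}
    with gzip.GzipFile(fileobj=f, mode="wb") as gz:
        with tarfile.open(fileobj=gz, mode="w") as tar:
            for name, blob in params.items():
                info = tarfile.TarInfo(name=name)
                info.size = len(blob)
                tar.addfile(info, io.BytesIO(blob))

# Save to an in-memory binary stream rather than a path.
buf = io.BytesIO()
save_parameters(buf)
```

Passing `open(path, "wb")` instead of the BytesIO gives the original on-disk behavior, so the caller decides the destination.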

if not os.path.exists(dir_name):
os.makedirs(dir_name)
param_file_name = dir_name + "/" + file_name + '.tar.gz'
assert not os.path.exists(param_file_name)
self.__parameter_updater__.catchUpWith()
Contributor comment (@helinwang, Apr 20, 2017):

Can you explain what this sequence of operations on __parameter_updater__ does?

self.__parameter_updater__.catchUpWith()
self.__parameter_updater__.apply()
self.__parameter_updater__.getParametersRemote(True, True)
self.__parameter_updater__.restore()

Collaborator comment:

These operations support regularization and ModelAverage.

Regularization is computed lazily, and with ModelAverage the model used during training is not the same model that is used for prediction or saving.

self.__parameter_updater__.apply()
self.__parameter_updater__.getParametersRemote(True, True)
with gzip.open(param_file_name, 'w') as f:
self.__parameters__.to_tar(f)
self.__parameter_updater__.restore()

def train(self, reader, num_passes=1, event_handler=None, feeding=None):
"""
Training method. Will train num_passes of input data.