
[in progress] Expose API for parameter server #1039

Closed
30 changes: 30 additions & 0 deletions demo/quick_start/cluster/pserver.py
@@ -0,0 +1,30 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from py_paddle import swig_paddle as api

#import pudb;pudb.set_trace()
Collaborator:

Please remove the unused line rather than commenting it out.

Member Author:

OK, in progress.



def main():
api.initPaddle("--nics=lo0", "--port=7164", "--ports_num=1",
"--num_gradient_servers=1", "--comment=paddle_pserver")
pserver = api.ParameterServer.createParameterServer()
Collaborator:

I don't quite follow the API design here. I assume the intent is to start a parameter server process by calling a Python API?

If so, shouldn't L21–L26 be simplified to something like:

psid = api.pserver.start(nics="lo0", port=7164, ports_num=1, num_gradient_server=1, comment="paddle_pserver")
api.pserver.wait(psid)

Member Author:

Right, this is being reworked. It currently looks like this for historical reasons: initPaddle actually just initializes the various gflags. A later version has already switched to a proto-based configuration; see #1051.

Collaborator:

Please follow the Python naming convention and rename the module/package ParameterServer to pserver or parameter_server.

pserver.init()
pserver.start()
pserver.join()


if __name__ == '__main__':
main()
3 changes: 2 additions & 1 deletion paddle/api/CMakeLists.txt
@@ -10,7 +10,8 @@ set(API_SOURCES
SequenceGenerator.cpp
Trainer.cpp
Util.cpp
Vector.cpp)
Vector.cpp
ParameterServer.cpp)
set(API_HEADER
PaddleAPI.h
Internal.h)
3 changes: 3 additions & 0 deletions paddle/api/Paddle.swig
@@ -178,6 +178,8 @@ namespace std {
%newobject ParameterOptimizer::create;
%newobject ParameterOptimizer::needSpecialTraversal;
%newobject ParameterUpdater::createLocalUpdater;
%newobject ParameterUpdater::createRemoteUpdater;
%newobject ParameterServer::createParameterServer;

%feature("director") UpdateCallback;
%feature("autodoc", 1); // To generate method stub, for code hint in ide
@@ -196,5 +198,6 @@ namespace std {
%ignore ParameterConfigPrivate;
%ignore OptimizationConfigPrivate;
%ignore ParameterTraverseCallbackPrivate;
%ignore ParameterServerPrivate;
%include "utils/GlobalConstants.h"
%include "api/PaddleAPI.h"
24 changes: 24 additions & 0 deletions paddle/api/PaddleAPI.h
@@ -803,6 +803,8 @@ class ParameterUpdater {

public:
static ParameterUpdater* createLocalUpdater(OptimizationConfig* config);
static ParameterUpdater* createRemoteUpdater(OptimizationConfig* config,
int passCount);
~ParameterUpdater();

/**
@@ -866,6 +868,28 @@
ParameterUpdaterPrivate* m;
};

struct ParameterServerPrivate;
class ParameterServer {
private:
ParameterServer();

public:
static ParameterServer* createParameterServer();

~ParameterServer();

/**
* @brief initialize Parameter Server.
* @param gm
*/
void init();
void start();
void join();

private:
ParameterServerPrivate* m;
};

struct EvaluatorPrivate;
class Evaluator {
private:
5 changes: 5 additions & 0 deletions paddle/api/PaddleAPIPrivate.h
@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/gserver/evaluators/Evaluator.h"
#include "paddle/gserver/gradientmachines/GradientMachine.h"
#include "paddle/parameter/ParameterUpdaterBase.h"
#include "paddle/pserver/PServerUtil.h"
#include "paddle/trainer/TrainerConfigHelper.h"

struct GradientMachinePrivate {
@@ -72,6 +73,10 @@ struct ParameterUpdaterPrivate {
std::unique_ptr<paddle::ParameterUpdater> updater;
};

struct ParameterServerPrivate {
std::unique_ptr<paddle::PServerUtil> pServerUtil;
};

struct ParameterPrivate {
std::shared_ptr<paddle::Parameter> sharedPtr;
paddle::Parameter* rawPtr; // rawPtr only used in ParameterUpdater,
33 changes: 33 additions & 0 deletions paddle/api/ParameterServer.cpp
@@ -0,0 +1,33 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Collaborator:

Please follow the source-file naming convention and rename this to parameter_server.cpp.


Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "PaddleAPI.h"

#include "PaddleAPIPrivate.h"

ParameterServer::ParameterServer() : m(new ParameterServerPrivate()) {}

ParameterServer* ParameterServer::createParameterServer() {
auto pServer = new ParameterServer();
pServer->m->pServerUtil.reset(new paddle::PServerUtil());
return pServer;
}

ParameterServer::~ParameterServer() { delete m; }

void ParameterServer::init() { m->pServerUtil->init(); }

void ParameterServer::start() { m->pServerUtil->start(); }

void ParameterServer::join() { m->pServerUtil->join(); }
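The wrapper above holds its private state through a raw pimpl pointer (`m`) that is deleted by hand in the destructor — a constraint that may come from the SWIG binding. Outside that constraint, the same pattern can be written with std::unique_ptr so no manual delete is needed. A minimal standalone sketch (Widget/WidgetPrivate are illustrative names, not Paddle classes):

```cpp
#include <cassert>
#include <memory>

// Pimpl sketch: the public class owns its private state through a
// unique_ptr, so no hand-written `delete m` is required.
struct WidgetPrivate {
  int value = 0;
};

class Widget {
 public:
  // Factory mirrors createParameterServer(): the constructor is private,
  // so instances can only be created through this entry point.
  static std::unique_ptr<Widget> create() {
    return std::unique_ptr<Widget>(new Widget());
  }
  void init() { m_->value = 1; }
  int value() const { return m_->value; }

 private:
  Widget() : m_(new WidgetPrivate()) {}
  std::unique_ptr<WidgetPrivate> m_;  // owns the private state
};
```

The default destructor releases WidgetPrivate automatically, which is the main point of the sketch.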
16 changes: 13 additions & 3 deletions paddle/api/ParameterUpdater.cpp
@@ -15,15 +15,25 @@ limitations under the License. */
#include "PaddleAPI.h"

#include "PaddleAPIPrivate.h"
#include "paddle/trainer/RemoteParameterUpdater.h"
#include "paddle/trainer/ThreadParameterUpdater.h"

ParameterUpdater::ParameterUpdater() : m(new ParameterUpdaterPrivate()) {}

ParameterUpdater *ParameterUpdater::createLocalUpdater(
OptimizationConfig *config) {
auto param = new ParameterUpdater();
param->m->updater.reset(new paddle::SgdThreadUpdater(config->m->getConfig()));
return param;
auto updater = new ParameterUpdater();
updater->m->updater.reset(
new paddle::SgdThreadUpdater(config->m->getConfig()));
return updater;
}

ParameterUpdater *ParameterUpdater::createRemoteUpdater(
OptimizationConfig *config, int passCount) {
auto updater = new ParameterUpdater();
updater->m->updater.reset(new paddle::RemoteParameterUpdater(
config->m->getConfig(), passCount, nullptr));
return updater;
}

ParameterUpdater::~ParameterUpdater() { delete m; }
6 changes: 4 additions & 2 deletions paddle/pserver/CMakeLists.txt
@@ -24,13 +24,15 @@ set(PSERVER_SOURCES
BaseClient.cpp
ParameterClient2.cpp
ParameterServer2.cpp
SparseParameterDistribution.cpp)
SparseParameterDistribution.cpp
PServerUtil.cpp)

set(PSERVER_HEADERS
BaseClient.h
ParameterClient2.h
ParameterServer2.h
SparseParameterDistribution.h)
SparseParameterDistribution.h
PServerUtil.h)

add_library(paddle_pserver STATIC
${PSERVER_SOURCES})
77 changes: 77 additions & 0 deletions paddle/pserver/PServerUtil.cpp
@@ -0,0 +1,77 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Collaborator:

Please follow the source-file naming convention and rename this to pserver_util.cpp.


Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "PServerUtil.h"

namespace paddle {

void PServerUtil::init() {
// round robin to load balance RDMA server ENGINE
std::vector<std::string> devices;
int rdmaCpu = 0;
int onlineCpus = rdma::numCpus();
int numPorts = FLAGS_ports_num + FLAGS_ports_num_for_sparse;

if (FLAGS_nics.empty()) {
pservers_.resize(numPorts);
for (int i = 0; i < numPorts; ++i) {
if (FLAGS_rdma_tcp == "rdma") {
pservers_[i].reset(
new ParameterServer2(std::string(), FLAGS_port + i, rdmaCpu++));
rdmaCpu = rdmaCpu % onlineCpus;
} else {
pservers_[i].reset(new ParameterServer2(std::string(), FLAGS_port + i));
}
CHECK(pservers_[i]->init()) << "Fail to initialize parameter server"
<< FLAGS_port + i;
}
} else {
str::split(FLAGS_nics, ',', &devices);
pservers_.resize(devices.size() * numPorts);
for (int i = 0; i < numPorts; ++i) {
for (size_t j = 0; j < devices.size(); ++j) {
if (FLAGS_rdma_tcp == "rdma") {
pservers_[i * devices.size() + j].reset(new ParameterServer2(
getIpAddr(devices[j]), FLAGS_port + i, rdmaCpu++));
rdmaCpu = rdmaCpu % onlineCpus;
} else {
pservers_[i * devices.size() + j].reset(
new ParameterServer2(getIpAddr(devices[j]), FLAGS_port + i));
}
CHECK(pservers_[i * devices.size() + j]->init())
<< "Fail to initialize parameter server" << devices[j]
<< FLAGS_port + i;
}
}
}
}

void PServerUtil::start() {
LOG(INFO) << "pserver sizes : " << pservers_.size();
int i = 0;
for (const auto &pserver : pservers_) {
LOG(INFO) << "pserver started : " << i;
pserver->start();
i++;
}
}

void PServerUtil::join() {
LOG(INFO) << "pserver sizes : " << pservers_.size();
for (const auto &pserver : pservers_) {
pserver->join();
}
}

} // namespace paddle
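The nic × port loop above lays servers out at index `i * devices.size() + j`, one per (port, device) pair, with RDMA CPUs handed out round-robin modulo the online CPU count. That indexing scheme can be exercised in isolation; the sketch below mirrors the loop structure of PServerUtil::init() but uses plain parameters instead of gflags (as one reviewer suggests below) and a ServerSlot record instead of ParameterServer2:

```cpp
#include <cstddef>
#include <string>
#include <vector>

struct ServerSlot {
  std::string ip;
  int port;
  int rdmaCpu;  // -1 when running over TCP
};

// One slot per (port, device) pair, stored at i * devices.size() + j,
// RDMA CPUs assigned round-robin over onlineCpus — same layout as
// PServerUtil::init().
std::vector<ServerSlot> layoutServers(const std::vector<std::string>& devices,
                                      int basePort, int numPorts,
                                      bool useRdma, int onlineCpus) {
  std::vector<ServerSlot> slots(devices.size() * numPorts);
  int rdmaCpu = 0;
  for (int i = 0; i < numPorts; ++i) {
    for (std::size_t j = 0; j < devices.size(); ++j) {
      ServerSlot& s = slots[i * devices.size() + j];
      s.ip = devices[j];
      s.port = basePort + i;
      s.rdmaCpu = useRdma ? rdmaCpu++ : -1;
      rdmaCpu %= onlineCpus;  // wrap around, as in the original loop
    }
  }
  return slots;
}
```

With two devices and two ports this yields four slots; slot 3 (i=1, j=1) gets the second device on basePort+1, and the RDMA CPU sequence wraps back to 0 once onlineCpus is exhausted.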
33 changes: 33 additions & 0 deletions paddle/pserver/PServerUtil.h
@@ -0,0 +1,33 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Collaborator:

Please follow the source-file naming convention and rename this to pserver_util.h.

Member Author:

OK. The code-convention issues will be handled in two separate PRs; this PR will be closed for now.


Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "ParameterServer2.h"
#include "RDMANetwork.h"
#include "paddle/utils/StringUtil.h"

namespace paddle {

class PServerUtil {
public:
void init();
Collaborator:

init => constructor.

Collaborator:

If convenient, consider extracting the GFLAGS here and passing them in as function parameters.

Member Author:

OK, that is the plan here: turn them into parameters.

void start();
void join();

Collaborator:

DISABLE_COPY(PServerUtil);

Collaborator:

Call join in the destructor.

private:
std::vector<std::shared_ptr<ParameterServer2>> pservers_;
Collaborator:

It looks like std::vector<std::unique_ptr<ParameterServer2>> pservers_; would be enough.

};

} // namespace paddle
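Putting the review suggestions together — init work in the constructor, join in the destructor, DISABLE_COPY, and unique_ptr elements — the suggested shape can be sketched standalone. FakeServer is a stand-in for ParameterServer2 (it only records calls), and PServerUtilSketch is a hypothetical name, not this PR's class:

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Stand-in for ParameterServer2: just records which calls were made.
class FakeServer {
 public:
  void start() { started = true; }
  void join() { joined = true; }
  bool started = false;
  bool joined = false;
};

// RAII shape suggested in the review: construction replaces init(),
// the destructor calls join(), and copying is disabled.
class PServerUtilSketch {
 public:
  explicit PServerUtilSketch(int numServers) {  // init work lives here
    for (int i = 0; i < numServers; ++i)
      pservers_.emplace_back(new FakeServer());
  }
  ~PServerUtilSketch() {
    for (auto& p : pservers_) p->join();  // join on destruction
  }
  PServerUtilSketch(const PServerUtilSketch&) = delete;  // DISABLE_COPY
  PServerUtilSketch& operator=(const PServerUtilSketch&) = delete;

  void start() {
    for (auto& p : pservers_) p->start();
  }
  std::size_t size() const { return pservers_.size(); }

 private:
  std::vector<std::unique_ptr<FakeServer>> pservers_;  // unique_ptr suffices
};
```

With this shape, callers cannot forget init or join: creating the object initializes the servers, and letting it go out of scope joins them.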
57 changes: 5 additions & 52 deletions paddle/pserver/ParameterServer2Main.cpp
@@ -16,63 +16,16 @@ limitations under the License. */
#include "paddle/utils/StringUtil.h"
#include "paddle/utils/Util.h"

#include "ParameterServer2.h"
#include "RDMANetwork.h"
#include "paddle/utils/Flags.h"
#include "PServerUtil.h"

using namespace paddle; // NOLINT

int main(int argc, char** argv) {
initMain(argc, argv);

std::vector<std::string> devices;
std::vector<std::shared_ptr<ParameterServer2>> pservers;

// round robin to loadbalance RDMA server ENGINE
int rdmaCpu = 0;
int onlineCpus = rdma::numCpus();
int numPorts = FLAGS_ports_num + FLAGS_ports_num_for_sparse;
if (FLAGS_nics.empty()) {
pservers.resize(numPorts);
for (int i = 0; i < numPorts; ++i) {
if (FLAGS_rdma_tcp == "rdma") {
pservers[i].reset(
new ParameterServer2(std::string(), FLAGS_port + i, rdmaCpu++));
rdmaCpu = rdmaCpu % onlineCpus;
} else {
pservers[i].reset(new ParameterServer2(std::string(), FLAGS_port + i));
}
CHECK(pservers[i]->init()) << "Fail to initialize parameter server"
<< FLAGS_port + i;
LOG(INFO) << "pserver started : " << FLAGS_port + i;
pservers[i]->start();
}
} else {
str::split(FLAGS_nics, ',', &devices);
pservers.resize(devices.size() * numPorts);
for (int i = 0; i < numPorts; ++i) {
for (size_t j = 0; j < devices.size(); ++j) {
if (FLAGS_rdma_tcp == "rdma") {
pservers[i * devices.size() + j].reset(new ParameterServer2(
getIpAddr(devices[j]), FLAGS_port + i, rdmaCpu++));
rdmaCpu = rdmaCpu % onlineCpus;
} else {
pservers[i * devices.size() + j].reset(
new ParameterServer2(getIpAddr(devices[j]), FLAGS_port + i));
}
CHECK(pservers[i * devices.size() + j]->init())
<< "Fail to initialize parameter server" << devices[j]
<< FLAGS_port + i;
LOG(INFO) << "pserver started : " << devices[j] << ":"
<< FLAGS_port + i;
pservers[i * devices.size() + j]->start();
}
}
}

for (auto& pserver : pservers) {
pserver->join();
}
PServerUtil* pserverUtil = new PServerUtil();
Collaborator:

std::unique_ptr<PServerUtil> pservers(new PServerUtil());

Implement init in the constructor and call join in the destructor.

pservers->start();

pserverUtil->init();
pserverUtil->start();
pserverUtil->join();

return 0;
}
4 changes: 2 additions & 2 deletions paddle/trainer/RemoteParameterUpdater.h
@@ -56,7 +56,7 @@ class RemoteParameterUpdater : public ParameterUpdater {
public:
RemoteParameterUpdater(
const OptimizationConfig& config,
int expectedPpassCount,
int expectedPassCount,
std::unique_ptr<ParameterUpdater>&& localUpdater = nullptr);
~RemoteParameterUpdater() {
if (controllerThread_) {
@@ -146,7 +146,7 @@ class RemoteParameterUpdater : public ParameterUpdater {
BatchStatus batchStatus_;
/// controller thread for sync-sgd
std::unique_ptr<std::thread> controllerThread_;
/// passed alread finished
/// passed already finished
int64_t passCount_;
/// expected passes to finished
int64_t expectedPassCount_;