Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ParameterServerController for parameter server python api #1051

Merged
merged 11 commits into from
Jan 11, 2017
Merged
1 change: 1 addition & 0 deletions demo/quick_start/cluster/cluster_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ log_file="$bin_dir/train.log"
pushd "$home_dir"
cfg=trainer_config.lr.py
paddle train \
--start_pserver=false \
--config=$cfg \
--save_dir=${model_dir} \
--trainer_count=4 \
Expand Down
6 changes: 4 additions & 2 deletions paddle/pserver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@ set(PSERVER_SOURCES
BaseClient.cpp
ParameterClient2.cpp
ParameterServer2.cpp
SparseParameterDistribution.cpp)
SparseParameterDistribution.cpp
ParameterServerController.cpp)

set(PSERVER_HEADERS
BaseClient.h
ParameterClient2.h
ParameterServer2.h
SparseParameterDistribution.h)
SparseParameterDistribution.h
ParameterServerController.h)

add_library(paddle_pserver STATIC
${PSERVER_SOURCES})
Expand Down
59 changes: 5 additions & 54 deletions paddle/pserver/ParameterServer2Main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,66 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include <fstream>
#include "paddle/utils/StringUtil.h"
#include "paddle/utils/Util.h"

#include "ParameterServer2.h"
#include "RDMANetwork.h"
#include "paddle/utils/Flags.h"
#include "ParameterServerController.h"

using namespace paddle; // NOLINT

int main(int argc, char** argv) {
initMain(argc, argv);

std::vector<std::string> devices;
std::vector<std::shared_ptr<ParameterServer2>> pservers;

// round robin to loadbalance RDMA server ENGINE
int rdmaCpu = 0;
int onlineCpus = rdma::numCpus();
int numPorts = FLAGS_ports_num + FLAGS_ports_num_for_sparse;
if (FLAGS_nics.empty()) {
pservers.resize(numPorts);
for (int i = 0; i < numPorts; ++i) {
if (FLAGS_rdma_tcp == "rdma") {
pservers[i].reset(
new ParameterServer2(std::string(), FLAGS_port + i, rdmaCpu++));
rdmaCpu = rdmaCpu % onlineCpus;
} else {
pservers[i].reset(new ParameterServer2(std::string(), FLAGS_port + i));
}
CHECK(pservers[i]->init()) << "Fail to initialize parameter server"
<< FLAGS_port + i;
LOG(INFO) << "pserver started : " << FLAGS_port + i;
pservers[i]->start();
}
} else {
str::split(FLAGS_nics, ',', &devices);
pservers.resize(devices.size() * numPorts);
for (int i = 0; i < numPorts; ++i) {
for (size_t j = 0; j < devices.size(); ++j) {
if (FLAGS_rdma_tcp == "rdma") {
pservers[i * devices.size() + j].reset(new ParameterServer2(
getIpAddr(devices[j]), FLAGS_port + i, rdmaCpu++));
rdmaCpu = rdmaCpu % onlineCpus;
} else {
pservers[i * devices.size() + j].reset(
new ParameterServer2(getIpAddr(devices[j]), FLAGS_port + i));
}
CHECK(pservers[i * devices.size() + j]->init())
<< "Fail to initialize parameter server" << devices[j]
<< FLAGS_port + i;
LOG(INFO) << "pserver started : " << devices[j] << ":"
<< FLAGS_port + i;
pservers[i * devices.size() + j]->start();
}
}
}

for (auto& pserver : pservers) {
pserver->join();
}
std::unique_ptr<ParameterServerController> pServerPtr(
paddle::ParameterServerController::createByGflags());
pServerPtr->start();
pServerPtr->join();

return 0;
}
101 changes: 101 additions & 0 deletions paddle/pserver/ParameterServerController.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "ParameterServerController.h"

namespace paddle {

/**
 * @brief Create and initialize every ParameterServer2 described by config.
 *
 * Listening ports are config.port() + i for i in
 * [0, config.ports_num() + config.ports_num_for_sparse()).
 *  - If config.nics() is empty, one server is created per port with an empty
 *    address.
 *  - Otherwise config.nics() is split on ',' and one server is created per
 *    (port, nic) pair, bound to getIpAddr(nic). Servers are stored port-major
 *    in pservers_, i.e. at index portIndex * numNics + nicIndex.
 *
 * Every server is init()-ed here (CHECK-failing on error) but not started;
 * call start() for that.
 */
ParameterServerController::ParameterServerController(
    const ParameterServerConfig& config) {
  const bool useRdma = config.rdma_tcp() == "rdma";
  const int onlineCpus = rdma::numCpus();
  const int numPorts = config.ports_num() + config.ports_num_for_sparse();
  // Next CPU handed to an RDMA server; advanced round robin to load balance
  // the RDMA server ENGINE.
  int rdmaCpu = 0;

  // Builds one server for addr:port. Only the RDMA variant is given a CPU
  // index; rdmaCpu is left untouched otherwise.
  auto makeServer = [&](const std::string& addr,
                        int port) -> std::unique_ptr<ParameterServer2> {
    std::unique_ptr<ParameterServer2> server;
    if (useRdma) {
      server.reset(new ParameterServer2(addr, port, rdmaCpu++));
      rdmaCpu = rdmaCpu % onlineCpus;
    } else {
      server.reset(new ParameterServer2(addr, port));
    }
    return server;
  };

  if (config.nics().empty()) {
    // No NIC list: one server per port, constructed with an empty address.
    pservers_.resize(numPorts);
    for (int i = 0; i < numPorts; ++i) {
      const int port = config.port() + i;
      pservers_[i] = makeServer(std::string(), port);
      // The trailing space keeps the port from being glued to the message.
      CHECK(pservers_[i]->init())
          << "Fail to initialize parameter server " << port;
    }
    return;
  }

  std::vector<std::string> devices;
  str::split(config.nics(), ',', &devices);
  pservers_.resize(devices.size() * numPorts);
  for (int i = 0; i < numPorts; ++i) {
    const int port = config.port() + i;
    for (size_t j = 0; j < devices.size(); ++j) {
      auto& server = pservers_[i * devices.size() + j];
      server = makeServer(getIpAddr(devices[j]), port);
      // Report the failing endpoint as "<nic>:<port>".
      CHECK(server->init()) << "Fail to initialize parameter server "
                            << devices[j] << ":" << port;
    }
  }
}

/**
 * @brief Dtor. Joins every pserver via join() before the controller's members
 * are destroyed.
 */
ParameterServerController::~ParameterServerController() {
  join();
}

/**
 * @brief Build a ParameterServerConfig from gflags and create a controller
 * from it.
 *
 * Copies FLAGS_nics, FLAGS_rdma_tcp, FLAGS_port, FLAGS_ports_num and
 * FLAGS_ports_num_for_sparse into the config, then delegates to create().
 *
 * @return the pointer returned by create(config); the caller takes ownership.
 */
ParameterServerController* ParameterServerController::createByGflags() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

createFromGflags

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,这个地方之前确实纠结过

ParameterServerConfig config;

// Only these five fields are populated from flags; no other config field is
// set here.
config.set_nics(FLAGS_nics);
config.set_rdma_tcp(FLAGS_rdma_tcp);
config.set_port(FLAGS_port);
config.set_ports_num(FLAGS_ports_num);
config.set_ports_num_for_sparse(FLAGS_ports_num_for_sparse);

return create(config);
}

/**
 * @brief Heap-allocate a controller configured by config.
 * @param config settings forwarded unchanged to the constructor.
 * @return a raw pointer obtained from new; the caller takes ownership.
 */
ParameterServerController* ParameterServerController::create(
    const ParameterServerConfig& config) {
  ParameterServerController* controller =
      new ParameterServerController(config);
  return controller;
}

/**
 * @brief Call start() on every server held in pservers_, in storage order.
 *
 * Logs the number of servers first, then logs each server's index right
 * before that server's start() is invoked.
 */
void ParameterServerController::start() {
LOG(INFO) << "pserver sizes : " << pservers_.size();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"pserver sizes" ==> "number of pserver instances"?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

// i is only used for logging: it is the position of the server in pservers_.
int i = 0;
for (const auto& pserver : pservers_) {
LOG(INFO) << "pserver started : " << i;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LOG(INFO) << "Staring pserver " << i;

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

pserver->start();
i++;
}
}

/**
 * @brief Call join() on every server held in pservers_, in storage order.
 *
 * Logs the number of servers first, then logs each server's index right
 * before that server's join() is invoked. The destructor also calls this.
 */
void ParameterServerController::join() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

join 对应的是 fork, start 对应的是 wait。这里的join 应该改名为 wait 才对。

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,好主意

LOG(INFO) << "pserver sizes : " << pservers_.size();
// i is only used for logging: it is the position of the server in pservers_.
int i = 0;
for (const auto& pserver : pservers_) {
LOG(INFO) << "pserver join : " << i;
pserver->join();
i++;
}
}

} // namespace paddle
66 changes: 66 additions & 0 deletions paddle/pserver/ParameterServerController.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "ParameterServer2.h"
#include "ParameterServerConfig.pb.h"
#include "RDMANetwork.h"
#include "paddle/utils/StringUtil.h"

namespace paddle {

/**
 * @brief Owns a group of ParameterServer2 instances and drives them together.
 *
 * The constructor creates and initializes the servers described by a
 * ParameterServerConfig; start() and join() forward to every owned server.
 * Instances are normally obtained through create() or createByGflags().
 */
class ParameterServerController final {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

新增的class应该有class comment

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

public:
// NOTE(review): DISABLE_COPY is a macro defined outside this file; presumably
// it removes the copy constructor and copy assignment -- confirm.
DISABLE_COPY(ParameterServerController);

/**
 * @brief Ctor. Creates and initializes all pservers described by the given
 * ParameterServerConfig.
 */
explicit ParameterServerController(const ParameterServerConfig& config);

/**
 * @brief Dtor. Calls join() before the owned pservers are destroyed.
 */
~ParameterServerController();

/**
 * @brief Create a ParameterServerController from gflags. This is kept for
 * compatibility with the old gflags-based configuration. The caller owns the
 * returned pointer.
 */
static ParameterServerController* createByGflags();

/**
 * @brief Create a ParameterServerController from a ParameterServerConfig, so
 * that parameter servers can be configured without gflags. All pservers are
 * initialized according to the config. The caller owns the returned pointer.
 */
static ParameterServerController* create(const ParameterServerConfig& config);

/**
 * @brief Start all pservers held by this ParameterServerController.
 */
void start();

/**
 * @brief Join all pservers held by this ParameterServerController, one after
 * another.
 */
void join();

private:
// The owned servers, in the order the constructor created them.
std::vector<std::unique_ptr<ParameterServer2>> pservers_;
};

} // namespace paddle
56 changes: 5 additions & 51 deletions paddle/trainer/TrainerMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/pserver/ParameterServer2.h"
#include "paddle/utils/Common.h"
#include <fenv.h>
#include "paddle/pserver/ParameterServerController.h"
#include "paddle/utils/PythonUtil.h"
#include "paddle/utils/StringUtil.h"

#include "ParamUtil.h"
#include "Trainer.h"
#include "paddle/pserver/RDMANetwork.h"

DEFINE_bool(start_pserver, false, "Whether to start pserver");
DECLARE_int32(gpu_id);
Expand All @@ -38,54 +36,10 @@ int main(int argc, char** argv) {
initMain(argc, argv);
initPython(argc, argv);

std::vector<std::unique_ptr<ParameterServer2>> pservers;
std::vector<std::string> devices;

std::unique_ptr<ParameterServerController> pServerPtr(nullptr);
if (FLAGS_start_pserver) {
// round robin to loadbalance RDMA server ENGINE
int rdmaCpu = 0;
int onlineCpus = rdma::numCpus();
int numPorts = FLAGS_ports_num + FLAGS_ports_num_for_sparse;
if (FLAGS_nics.empty()) {
pservers.resize(numPorts);
for (int i = 0; i < numPorts; ++i) {
if (FLAGS_rdma_tcp == "rdma") {
pservers[i].reset(
new ParameterServer2(std::string(), FLAGS_port + i, rdmaCpu++));
rdmaCpu = rdmaCpu % onlineCpus;
} else {
pservers[i].reset(
new ParameterServer2(std::string(), FLAGS_port + i));
}

CHECK(pservers[i]->init()) << "Fail to initialize parameter server"
<< FLAGS_port + i;
LOG(INFO) << "pserver started : " << FLAGS_port + i;
pservers[i]->start();
}
} else {
str::split(FLAGS_nics, ',', &devices);
pservers.resize(devices.size() * numPorts);
for (int i = 0; i < numPorts; ++i) {
for (size_t j = 0; j < devices.size(); ++j) {
if (FLAGS_rdma_tcp == "rdma") {
pservers[i * devices.size() + j].reset(new ParameterServer2(
getIpAddr(devices[j]), FLAGS_port + i, rdmaCpu++));
rdmaCpu = rdmaCpu % onlineCpus;
} else {
pservers[i * devices.size() + j].reset(
new ParameterServer2(getIpAddr(devices[j]), FLAGS_port + i));
}

CHECK(pservers[i * devices.size() + j]->init())
<< "Fail to initialize parameter server" << devices[j]
<< FLAGS_port + i;
LOG(INFO) << "pserver started : " << devices[j] << ":"
<< FLAGS_port + i;
pservers[i * devices.size() + j]->start();
}
}
}
pServerPtr.reset(paddle::ParameterServerController::createByGflags());
pServerPtr->start();
}
Trainer trainer;
auto config = TrainerConfigHelper::createFromFlags();
Expand Down
3 changes: 2 additions & 1 deletion proto/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ set(proto_filenames
ModelConfig.proto
ParameterConfig.proto
ParameterService.proto
TrainerConfig.proto)
TrainerConfig.proto
ParameterServerConfig.proto)

set(PROTO_GEN)
set(PROTO_GEN_PY)
Expand Down
43 changes: 43 additions & 0 deletions proto/ParameterServerConfig.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

上面有文件叫 PServerUtils.*,这里叫ParameterServer,显然不一致呀。

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个配置文件确实是用来配置parameter server的,目前的pserverutil封装了几个parameter server线程,根据config来创建这些线程。

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我的意思是到底应该叫 pserver 还是 parameter server 呢?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经按照命名规范修改为ParameterServerController


Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";

package paddle;

// Configuration for a parameter client. It currently carries a single field,
// trainer_id.
message ParameterClientConfig {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里应该有个注释说明这个proto message的用意。

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,写的比较简单

// NOTE(review): presumably identifies the trainer that owns this client --
// confirm against the client code.
required int32 trainer_id = 1;
}

// Configuration for the parameter servers created by
// ParameterServerController. It replaces the corresponding gflags.
//
// NOTE(review): every field is declared `required`, yet
// ParameterServerController::createByGflags() sets only nics, rdma_tcp, port,
// ports_num and ports_num_for_sparse. Confirm whether `optional` was intended
// for the remaining fields.
message ParameterServerConfig {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里应该有个注释说明这个proto message的用意。

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

// Number of ports used for sending parameters. Ports are allocated
// consecutively, starting from `port`.
required int32 ports_num = 1 [default = 1];
// Number of additional ports for sparse parameters. They follow the
// ports above, i.e. they start from (port + ports_num).
required int32 ports_num_for_sparse = 2 [default = 0];
// Comma-separated network device names for pservers
required string nics = 3 [default = "xgbe0,xgbe1"];
// Transport selector: "rdma" selects the RDMA code path, any other value
// the non-RDMA one.
required string rdma_tcp = 4 [default = "tcp"];
// First listening port for pservers
required int32 port = 5 [default = 20134];
// number of gradient servers
required int32 num_gradient_servers = 6 [default = 1];
// number of threads for sync op exec
required int32 pserver_num_threads = 7 [default = 1];
// minimum value allowed for config_.async_lagged_grad_discard_ratio()
required double async_lagged_ratio_min = 8 [default = 1.0];
// if async_lagged_grad_discard_ratio is not set in trainer_config.conf,
// this is used as the default value
required double async_lagged_ratio_default = 9 [default = 1.5];
}