Remove unused params in GradientMachine::start #969

Merged
6 changes: 1 addition & 5 deletions paddle/gserver/gradientmachines/GradientMachine.h
@@ -212,11 +212,7 @@ class GradientMachine {
    * @note This function will only been implemented and used in a
Member:
Suggest explaining what start does.

Collaborator (Author):
A GradientMachine can only do forward and backward after start has been called. For some GradientMachines, start launches worker threads, and finish shuts them down.
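
For context, a minimal sketch of the contract described above, the way the trainer-side callers in this diff drive a GradientMachine; the runPass driver and numBatches are illustrative assumptions, not code from this PR:

    // Sketch only: runPass and numBatches are invented for illustration.
    // Assumes the Paddle types used elsewhere in this diff (GradientMachine,
    // Argument, PASS_TRAIN).
    void runPass(GradientMachine& gm,
                 std::vector<Argument>& inArgs,
                 std::vector<Argument>& outArgs,
                 int numBatches) {
      gm.start();  // after this call forward/backward are legal; some
                   // subclasses spawn worker threads here
      for (int i = 0; i < numBatches; ++i) {
        gm.forwardBackward(inArgs, &outArgs, PASS_TRAIN);
      }
      gm.finish();  // shuts down any threads started by start()
    }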

    * multithreaded environment.
    */
-  virtual void start(const TrainerConfig& config,
-                     DataProviderPtr dataProvider) {
-    (void)config;
-    (void)dataProvider;
-  }
+  virtual void start() {}
Member:
Were these two parameters ever actually used?

Collaborator (Author):
They were. Originally Paddle supported many kinds of GradientMachine, and some of them did not take data through forward; instead they obtained the DataProvider directly in start and pulled data from it themselves.

By now most GradientMachines have been unified into a single MultiGradientMachine, and all data input and output goes through forward/backward, so these start parameters are no longer needed.
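
A hypothetical sketch of the old pattern described here: a machine that kept the DataProvider passed to start and fetched batches itself during forward. The class name, members, and batch size are invented for illustration and do not appear in this PR:

    // Hypothetical illustration only; no such class exists in this diff.
    class SelfFeedingGradientMachine : public GradientMachine {
     public:
      void start(const TrainerConfig& config,
                 DataProviderPtr dataProvider) override {
        (void)config;
        dataProvider_ = dataProvider;  // keep the provider instead of ignoring it
      }

      void forward(const std::vector<Argument>& inArgs,
                   std::vector<Argument>* outArgs,
                   PassType passType) override {
        (void)inArgs;  // data does not come through forward in this pattern
        DataBatch dataBatch;
        dataProvider_->getNextBatch(batchSize_, &dataBatch);
        // ... run the network on dataBatch.getStreams(), fill *outArgs ...
      }

     private:
      DataProviderPtr dataProvider_;
      int64_t batchSize_ = 32;  // illustrative
    };

This is exactly the dependency the PR removes: once every machine receives its data through forward/backward, start no longer needs the TrainerConfig or the DataProvider.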


   /**
    * @brief check each work-thread whether is failed/error/finish,
2 changes: 1 addition & 1 deletion paddle/gserver/gradientmachines/MultiGradientMachine.cpp
@@ -441,7 +441,7 @@ TrainerThread::TrainerThread(const ModelConfig& config,
 TrainerThread::~TrainerThread() { stop(); }

 void TrainerThread::start() {
-  gradientMachine_->start(*(TrainerConfig*)nullptr, (DataProviderPtr) nullptr);
+  gradientMachine_->start();

   computeThread_.reset(new std::thread([this]() { computeThread(); }));
5 changes: 2 additions & 3 deletions paddle/gserver/gradientmachines/MultiNetwork.cpp
@@ -109,10 +109,9 @@ void MultiNetwork::onPassEnd() {
   }
 }

-void MultiNetwork::start(const TrainerConfig& config,
-                         DataProviderPtr dataProvider) {
+void MultiNetwork::start() {
   for (auto& subNetwork : subNetworks_) {
-    subNetwork->start(config, dataProvider);
+    subNetwork->start();
   }
 }
2 changes: 1 addition & 1 deletion paddle/gserver/gradientmachines/MultiNetwork.h
@@ -54,7 +54,7 @@ class MultiNetwork : public NeuralNetwork {
     return subNetworks_;
   }

-  virtual void start(const TrainerConfig& config, DataProviderPtr dataProvider);
+  virtual void start();

   virtual void finish();
6 changes: 1 addition & 5 deletions paddle/gserver/gradientmachines/ParallelNeuralNetwork.cpp
@@ -131,11 +131,7 @@ void ParallelNeuralNetwork::forwardBackward(const std::vector<Argument>& inArgs,
   backward(callback);
 }

-void ParallelNeuralNetwork::start(const TrainerConfig& config,
-                                  DataProviderPtr dataProvider) {
-  (void)config;
-  (void)dataProvider;
-
+void ParallelNeuralNetwork::start() {
   for (auto& thread : threads_) {
     thread->start();
   }
2 changes: 1 addition & 1 deletion paddle/gserver/gradientmachines/ParallelNeuralNetwork.h
@@ -56,7 +56,7 @@ class ParallelNeuralNetwork : public NeuralNetwork {
                        PassType passType,
                        const UpdateCallback &callback = NULL);

-  virtual void start(const TrainerConfig &config, DataProviderPtr dataProvider);
+  virtual void start();

   void addComputeThread(int deviceId);
2 changes: 1 addition & 1 deletion paddle/gserver/tests/test_NetworkCompare.cpp
@@ -114,7 +114,7 @@ void calcGradient(DataIn& in, DataOut& out, const std::string& configPath) {
       parameters[i]->getBuf(PARAMETER_VALUE)->copyFrom(*in.paraValues[i]);
     }
   }
-  gradientMachine->start(trainer.getConfig(), nullptr);
+  gradientMachine->start();
   gradientMachine->forward(in.inArgs, &outArgs, PASS_TRAIN);
   for (size_t i = 0; i < in.outGrads.size(); i++) {
     // If the all the layers in the config have no parameters, also
2 changes: 1 addition & 1 deletion paddle/gserver/tests/test_RecurrentGradientMachine.cpp
@@ -28,7 +28,7 @@ class TrainerForTest : public paddle::Trainer {
 public:
   void startTrain() {
     GradientMachine& gm = *this->trainerInternal_.getGradientMachine();
-    gm.start(this->getConfig(), dataProvider_);
+    gm.start();
   }

   void finishTrain() {
2 changes: 1 addition & 1 deletion paddle/trainer/Tester.cpp
@@ -257,7 +257,7 @@ void Tester::test() {
   CHECK(testDataProvider_) << "TestData is not specified";
   testDataProvider_->setSkipShuffle();
   testDataProvider_->reset();
-  gradientMachine_->start(*config_, testDataProvider_);
+  gradientMachine_->start();

   // For evaluation
   std::vector<std::string> modelList;
4 changes: 2 additions & 2 deletions paddle/trainer/Trainer.cpp
@@ -308,7 +308,7 @@ static double genPerturbation(real* d, real* grad, size_t dim) {
 }

 real Trainer::checkGradient() {
-  trainerInternal_.getGradientMachine()->start(*config_, dataProvider_);
+  trainerInternal_.getGradientMachine()->start();
   std::vector<ParameterPtr>& parameters =
       trainerInternal_.getGradientMachine()->getNonStaticParameters();
   DataBatch dataBatch;
@@ -390,7 +390,7 @@ void Trainer::startTrain() {
     dataProvider_->reset();
   }

-  trainerInternal_.getGradientMachine()->start(*config_, dataProvider_);
+  trainerInternal_.getGradientMachine()->start();
 }

 void Trainer::finishTrain() { trainerInternal_.getGradientMachine()->finish(); }
2 changes: 1 addition & 1 deletion paddle/trainer/tests/test_Compare.cpp
@@ -50,7 +50,7 @@ void calcGradient(bool useGpu, comData& Data) {
   trainer.getDataProvider()->getNextBatch(batchSize, &dataBatch);
   CHECK(dataBatch.getSize()) << "No data from data provider";
   vector<Argument>& inArgs = dataBatch.getStreams();
-  trainer.getGradientMachine()->start(trainer.getConfig(), nullptr);
+  trainer.getGradientMachine()->start();
   for (int i = 0; i < 2; ++i) {
     trainer.getGradientMachine()->forwardBackward(
         inArgs, &Data.outArgs, PASS_TRAIN);
2 changes: 1 addition & 1 deletion paddle/trainer/tests/test_CompareTwoNets.cpp
@@ -72,7 +72,7 @@ void calcGradient(ComData& data, const string configFile) {
   CHECK(dataBatch.getSize()) << "No data from data provider";
   vector<Argument>& inArgs = dataBatch.getStreams();

-  trainer.getGradientMachine()->start(trainer.getConfig(), nullptr);
+  trainer.getGradientMachine()->start();
   trainer.getGradientMachine()->forwardBackward(
       inArgs, &data.outArgs, PASS_TRAIN);