diff --git a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp index 3159026e6b923..8ef5e9d0c116d 100644 --- a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp +++ b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp @@ -166,11 +166,21 @@ MultiGradientMachine::MultiGradientMachine(const ModelConfig& config, outArgStream_ = HPPL_STREAM_1; + start(); +} + +void MultiGradientMachine::start() { for (auto& thread : threads_) { thread->start(); } } +void MultiGradientMachine::finish() { + for (auto& thread : threads_) { + thread->stop(); + } +} + std::vector*> MultiGradientMachine::getSlaveParameters() { std::vector*> vec; @@ -326,12 +336,6 @@ void MultiGradientMachine::onPassEnd() { } } -void MultiGradientMachine::finish() { - for (auto& thread : threads_) { - thread->stop(); - } -} - Evaluator* MultiGradientMachine::makeEvaluator() const { return threads_[0]->getGradientMachine()->makeEvaluator(); } @@ -445,7 +449,7 @@ TrainerThread::TrainerThread(const ModelConfig& config, gradStream_ = HPPL_STREAM_2; valueStream_ = HPPL_STREAM_3; - stopping_ = false; + stopping_ = true; updateCounter_ = 0; parameterUpdated_ = false; } @@ -453,6 +457,10 @@ TrainerThread::TrainerThread(const ModelConfig& config, TrainerThread::~TrainerThread() { stop(); } void TrainerThread::start() { + if (!stopping_) return; + + stopping_ = false; + gradientMachine_->start(); computeThread_.reset(new std::thread([this]() { computeThread(); })); diff --git a/paddle/gserver/gradientmachines/MultiGradientMachine.h b/paddle/gserver/gradientmachines/MultiGradientMachine.h index 70203bbb97fe7..5e7622f929fd5 100644 --- a/paddle/gserver/gradientmachines/MultiGradientMachine.h +++ b/paddle/gserver/gradientmachines/MultiGradientMachine.h @@ -176,6 +176,10 @@ class MultiGradientMachine : public GradientMachine { explicit MultiGradientMachine(const ModelConfig& config, bool useGpu); + virtual void start(); + + virtual void finish(); + virtual void prefetch(const std::vector& inArgs); virtual void forward(const std::vector& inArgs, @@ -193,8 +197,6 @@ class MultiGradientMachine : public GradientMachine { virtual void onPassEnd(); - virtual void finish(); - virtual Evaluator* makeEvaluator() const; virtual void eval(Evaluator* evaluator) const;