Auto pruning #2603

Closed · wants to merge 41 commits
Changes from 5 commits

Commits (41)
8317882
add new auto pruning module
NHZlX Jun 26, 2017
9a8d498
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
NHZlX Jun 26, 2017
6dbc941
the log function in preprocess must add 2, or it will be a bug
NHZlX Jun 30, 2017
07cbc9e
fixed some bug of auto pruning
NHZlX Jul 1, 2017
7526315
delete fault file
NHZlX Jul 1, 2017
44c4060
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
NHZlX Jul 1, 2017
54a5577
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
NHZlX Jul 1, 2017
e681196
there is a conflict between the momentum and the mask
NHZlX Jul 3, 2017
b960be8
auto pruning modify bug
NHZlX Jul 3, 2017
6849985
fix bug in auto pruning: before saving the model, multiply the paramete…
NHZlX Jul 4, 2017
1e44249
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
NHZlX Jul 4, 2017
cbd5afb
Merge branch 'auto_pruning' of https://github.com/NHZlX/Paddle into a…
NHZlX Jul 4, 2017
d875def
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
NHZlX Jul 4, 2017
3813928
add the handleBeforeFetch which the parameter operate with the hook b…
NHZlX Jul 7, 2017
a4862cd
dynamic pruning parameter config
NHZlX Jul 7, 2017
3251527
refactor the dynamic pruning and fixed some bug
NHZlX Jul 7, 2017
12cf82f
add dynamic pruning interface of the python
NHZlX Jul 7, 2017
3747d19
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
NHZlX Jul 7, 2017
6636a21
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
NHZlX Jul 7, 2017
9d98dd1
delete the handleBeforeFetch function
NHZlX Jul 17, 2017
1368682
set the updateHook after the updateImpl
NHZlX Jul 17, 2017
59c8c43
modify the parameter config in ParameterAttribute
NHZlX Jul 17, 2017
34b7b90
modify the related python interface of pruning
NHZlX Jul 17, 2017
d6b04ec
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
NHZlX Jul 17, 2017
e184a41
delete the explicit keyword from ParameterPruningHook constructors
NHZlX Jul 17, 2017
ab3d10b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
NHZlX Jul 18, 2017
e8e5a67
Update ParameterUpdaterBase.h
NHZlX Jul 20, 2017
8b74b72
Update trainer.py
NHZlX Jul 20, 2017
8966db2
fix the format
NHZlX Jul 20, 2017
b896b32
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
NHZlX Jul 20, 2017
3a24d7e
Merge branch 'auto_pruning' of https://github.com/NHZlX/Paddle into a…
NHZlX Jul 20, 2017
5c822fe
fix format
NHZlX Jul 20, 2017
ecf25b1
tiny modify
NHZlX Jul 30, 2017
6e5d805
Update ParameterUpdaterHook.cpp
NHZlX Aug 2, 2017
44aec34
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
NHZlX Sep 21, 2017
322fcbe
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
NHZlX Sep 21, 2017
cc139b2
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
NHZlX Sep 22, 2017
a5407fa
modify code style
NHZlX Sep 22, 2017
bd6749a
fix error in cpu
NHZlX Oct 25, 2017
bb3038d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
NHZlX Oct 25, 2017
7c9d5e5
add comments of preprocess func
NHZlX Oct 25, 2017
3 changes: 2 additions & 1 deletion paddle/api/PaddleAPI.h
@@ -573,7 +573,7 @@ class Parameter {

ParameterConfig* getConfig();
void setValueUpdated();

void handleBeforeSave();
bool save(const std::string& filename) const;

bool load(const std::string& filename) const;
@@ -880,6 +880,7 @@ class ParameterUpdater {
* @param param
*/
void update(Parameter* param);
void preprocess(Parameter* param, size_t currentPass, size_t currentBatch);
Contributor:

Please add comments for the functions newly added in the .h files; the same applies below.

Contributor Author:

Done

Contributor:

For the parameter names, I think passId and batchId would be better.


/**
* @breif only get required sparse rows by default.
4 changes: 4 additions & 0 deletions paddle/api/Parameter.cpp
@@ -61,6 +61,10 @@ bool Parameter::save(const std::string& filename) const {
return m->getPtr()->save(filename);
}

void Parameter::handleBeforeSave() {
return m->getPtr()->handleBeforeSave();
}

bool Parameter::load(const std::string& filename) const {
return m->getPtr()->load(filename);
}
5 changes: 5 additions & 0 deletions paddle/api/ParameterUpdater.cpp
@@ -87,6 +87,11 @@ void ParameterUpdater::update(Parameter *param) {
m->updater->update(paddleParam);
}

void ParameterUpdater::preprocess(Parameter *param, size_t currentPass, size_t currentBatch) {
auto paddleParam = param->m->getPtr();
m->updater->preprocess(paddleParam, currentPass, currentBatch);
}

void ParameterUpdater::getParametersRemote(bool fullSize, bool apply) {
m->updater->getParametersRemote(fullSize, apply);
}
12 changes: 12 additions & 0 deletions paddle/parameter/Parameter.h
@@ -257,6 +257,18 @@ class Parameter {
* It could modify gradient/momentum/etc here. Such as drop some gradient,
* etc.
*/
void handleBeforeSave(){
for (auto& hook : updaterHooks_) {
hook->handleBeforeSave(this);
}
}

void preProcessHook(size_t currentPass, size_t currentBatch) {
for (auto& hook : updaterHooks_) {
hook->preprocess(this, currentPass, currentBatch);
}
}

void updateHook() {
for (auto& hook : updaterHooks_) {
hook->update(this);
4 changes: 4 additions & 0 deletions paddle/parameter/ParameterUpdaterBase.h
@@ -61,6 +61,10 @@ class ParameterUpdater {
this->updateImpl(para);
}

void preprocess(Parameter* para, size_t currentPass, size_t currentBatch) {
SetDevice setDevice(para->getDeviceId());
para->preProcessHook(currentPass, currentBatch);
}
// only get required sparse rows by default,
// get full matrix parameter if *fullSize* set
// get PARAMETER_APPLY on pserver if *apply* set
157 changes: 120 additions & 37 deletions paddle/parameter/ParameterUpdaterHook.cpp
@@ -17,6 +17,7 @@ limitations under the License. */
#include <algorithm>
#include <atomic>
#include <fstream>
#include <iostream>
#include <mutex>
#include <thread>
#include <unordered_map>
@@ -29,42 +30,22 @@ limitations under the License. */

namespace paddle {

/**
* The static pruning hook
* Static means user specify a sparsity_ratio before training started, and the
* network will prune the parameters based on the sparsity_ratio. More details
* can be found https://arxiv.org/pdf/1506.02626.pdf.
*/

class StaticPruningHook : public IParameterUpdaterHook {
class ParameterPruningHook : public IParameterUpdaterHook {
public:
explicit StaticPruningHook(const ParameterUpdaterHookConfig &hookConfig)
: initCount_(0) {
sparsityRatio_ = hookConfig.sparsity_ratio();
}
explicit ParameterPruningHook() : initCount_(0) {}

static bool sortPairAscend(const std::pair<real, size_t> &pair1,
const std::pair<real, size_t> &pair2) {
return pair1.first > pair2.first;
}

void update(Parameter *para) {
updateThreadChecker_.check();
auto &vec = para->getBuf(PARAMETER_GRADIENT);
if (vec) {
vec->dotMul(*maskVec_);
}
}
virtual void update(Parameter *para) {/*do nothing*/}
virtual void handleBeforeSave(Parameter *para) {/*do nothing*/}
virtual void preprocess(Parameter *para, size_t currentPass, size_t currentBatch) {}

void generateMask(Parameter *para) {
virtual void generateMask(Parameter *para, size_t nonZeroNum) {
Contributor:

Suggest passing real sparsityRatio as the second parameter instead.

Contributor Author:

OK, will do.

VectorPtr maskTemp = Vector::create(para->getSize(), false);
maskTemp->zeroMem();
real *maskTempData = maskTemp->getData();
size_t nonZeroNum = para->getSize() * (1 - sparsityRatio_);

VectorPtr paraVec = para->getBuf(PARAMETER_VALUE);
VectorPtr paraCpuCopy = Vector::create(para->getSize(), false);

paraCpuCopy->copyFrom(*paraVec);
std::vector<std::pair<real, size_t>> param;

@@ -73,38 +54,138 @@ class StaticPruningHook : public IParameterUpdaterHook {

std::partial_sort(
param.begin(), param.begin() + nonZeroNum, param.end(), sortPairAscend);

for (size_t i = 0; i < nonZeroNum; i++) maskTempData[param[i].second] = 1.0;

// Currently just use a mask vector for hack.
if (para->useGpu()) {
maskVec_ = Vector::create(para->getSize(), para->useGpu());
maskVec_->copyFrom(*maskTemp);
this-> maskVec_ = Vector::create(para->getSize(), para->useGpu());
this-> maskVec_->copyFrom(*maskTemp);
} else {
maskVec_ = maskTemp;
this-> maskVec_ = maskTemp;
}
}

void init(Parameter *para) {
generateMask(para);
static bool sortPairAscend(const std::pair<real, size_t> &pair1,
const std::pair<real, size_t> &pair2) {
return pair1.first > pair2.first;
}


protected:
std::atomic<size_t> initCount_;
SameThreadChecker updateThreadChecker_;
VectorPtr maskVec_;
};

/**
* The static pruning hook
* Static means user specify a sparsity_ratio before training started, and the
* network will prune the parameters based on the sparsity_ratio. More details
* can be found https://arxiv.org/pdf/1506.02626.pdf.
*/

class StaticPruningHook : public ParameterPruningHook {
public:
explicit StaticPruningHook(const ParameterUpdaterHookConfig &hookConfig)
: ParameterPruningHook() {
this->sparsityRatio_ = hookConfig.sparsity_ratio();
}

void update(Parameter *para) override{
updateThreadChecker_.check();
auto &vec = para->getBuf(PARAMETER_GRADIENT);
if (vec) {
vec->dotMul(*maskVec_);
}
}

void init(Parameter *para) override {
size_t initCount = this->initCount_.fetch_add(1);
CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must invoke "
"in same ParamterUpdater";
Contributor:

invoke -> be invoked
same -> the same

VLOG(3) << "Initialize Parameter " << para;
SetDevice device(para->getDeviceId());

size_t nonZeroNum = para->getSize() * (1 - sparsityRatio_);
this->generateMask(para, nonZeroNum);

auto &paraVec = para->getBuf(PARAMETER_VALUE);
paraVec->dotMul(*maskVec_);
paraVec->dotMul(*this->maskVec_);
Contributor:

Just a question here: *this->maskVec_ is equivalent to *(this->maskVec_), right? Then why add the this?

Contributor Author:

The this can indeed be dropped...

}

private:
SameThreadChecker updateThreadChecker_;
std::atomic<size_t> initCount_;
VectorPtr maskVec_;
real sparsityRatio_;
};

IParameterUpdaterHook::IParameterUpdaterHook() {}
class DynamicPruningHook : public ParameterPruningHook {
public:
explicit DynamicPruningHook(const ParameterUpdaterHookConfig &hookConfig)
: ParameterPruningHook() {
this->upperBound_ = hookConfig.upper_bound();
this->interPass_ = hookConfig.inter_pass();
this->endPass_ = hookConfig.end_pass();
Contributor:

These parameter names are not very intuitive; readers have to guess what they stand for. Suggestions:

  1. upper_bound -> sparsity_upper_bound
  2. The inter in inter_pass presumably stands for interval; abbreviating it to inter obscures the meaning. Could it be renamed to sparsity_increasing_interval, or something similar?
  3. As I understand it, endPass is used to compute the increment of sparsityRatio at each step, but setting endPass as a parameter here is not a good fit:
  • Users already set num_passes once when calling train; having to set it again here is cumbersome.
  • I also understand we cannot simply reuse the externally set num_passes, because users tend to set num_passes to a very large value and then kill the job once training converges.
    So please reconsider how the increment is configured.

Contributor Author:

upper_bound has been renamed to sparsity_upper_bound and inter_pass to interval_pass; attribute names that are too long hurt the user experience. I think end_pass is still necessary: num_passes is the number of passes of the whole training run, while end_pass is the number of passes during which sparsity_ratio keeps changing.

}

void init(Parameter *para) override {
// init mask
size_t initCount = this->initCount_.fetch_add(1);
CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must invoke "
"in same ParamterUpdater";
Contributor:

Same as above.

VLOG(3) << "Initialize Parameter " << para;
this->maskVec_ = Vector::create(para->getSize(), para->useGpu());
this->maskVec_->reset(1.0);

/*
real *data = this->maskVec_->getData();
for (size_t i = 0; i < para->getSize(); i++){
std::cout << data[i] << " " ;
}
*/
Contributor:

Please delete the commented-out code on lines 138-143.

Contributor Author:

Done

}

void handleBeforeSave(Parameter *para) override{
updateThreadChecker_.check();
auto &vec = para->getBuf(PARAMETER_VALUE);
if (vec) {
vec->dotMul(*maskVec_);
}
}

void preprocess(Parameter *para, size_t currentPass, size_t currentBatch) override {
if (currentPass % interPass_ == 0 && currentPass <= endPass_ && currentBatch == 0) {
real boundWeight =
this->upperBound_ / std::log(this->endPass_ / (real)this->interPass_);
real sparsityRatio =
boundWeight * std::log(2 + currentPass / (real)interPass_);

size_t nonZeroNum = para->getSize() * (1 - sparsityRatio);
this->generateMask(para, nonZeroNum);
std::cout << para->getName() << " Current sparsity ratio: " <<
sparsityRatio <<" " << nonZeroNum<<std::endl;
}
//add the the temp
auto &paraVec = para->getBuf(PARAMETER_VALUE);
paraVec->dotMul(*this->maskVec_);
/*
VectorPtr paraCopyCpu = Vector::create(para->getSize(), false);
paraCopyCpu->copyFrom(*paraVec);
real *data = paraCopyCpu->getData();
size_t sum_non = 0;
for(size_t i = 0; i < para->getSize(); i++){
if(data[i] != 0.0)
sum_non += 1;
}
std::cout<<"sum_non: " <<sum_non << " " << para->getSize()<< std::endl;
*/
Contributor:

Please delete the commented-out code on lines 169-178.

Contributor Author:

Done

}

private:
real upperBound_;
size_t interPass_;
size_t endPass_;
};

IParameterUpdaterHook::IParameterUpdaterHook() {}
IParameterUpdaterHook::~IParameterUpdaterHook() {}

/**
@@ -139,6 +220,8 @@ static IParameterUpdaterHook *createImpl(
auto &type = config.type();
if (type == "pruning") {
return new StaticPruningHook(config);
} else if (type == "dpruning") {
Contributor:

Should dpruning be spelled out in full as dynamic_pruning, so that users understand it more easily?

Contributor Author:

Done

return new DynamicPruningHook(config);
}

LOG(FATAL) << "Unknown Hook type: " << type;
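As a side note, the sparsity schedule that DynamicPruningHook::preprocess implements can be sketched in a few lines of Python. This is only an illustration of the formula, using the defaults later added to ParameterConfig.proto (upper_bound=0.8, inter_pass=1, end_pass=20); the helper name is made up for the example and is not part of the PR.

```python
import math

def dynamic_sparsity_ratio(current_pass, upper_bound=0.8, inter_pass=1, end_pass=20):
    # Mirrors: boundWeight   = upperBound_ / log(endPass_ / interPass_)
    #          sparsityRatio = boundWeight * log(2 + currentPass / interPass_)
    bound_weight = upper_bound / math.log(end_pass / float(inter_pass))
    return bound_weight * math.log(2 + current_pass / float(inter_pass))

# The hook only regenerates the mask when
#   currentPass % interPass == 0 and currentPass <= endPass and currentBatch == 0,
# so the ratio rises logarithmically and reaches roughly upper_bound around end_pass:
for p in (0, 1, 5, 10, 20):
    print("pass %2d -> sparsity ratio %.3f" % (p, dynamic_sparsity_ratio(p)))
```

This also makes the naming discussion above more concrete: end_pass only controls how fast the ratio approaches the upper bound and is independent of the num_passes used for training.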
2 changes: 2 additions & 0 deletions paddle/parameter/ParameterUpdaterHook.h
@@ -52,6 +52,8 @@ class IParameterUpdaterHook {
* The init hook method. Invoke in ParameterUpdater::init
*/
virtual void init(Parameter* para) = 0;
virtual void preprocess(Parameter* para, size_t currentPass, size_t currentBatch) = 0;
virtual void handleBeforeSave(Parameter* para) = 0;

protected:
/**
3 changes: 3 additions & 0 deletions proto/ParameterConfig.proto
@@ -29,6 +29,9 @@ message ParameterUpdaterHookConfig {
required string type = 1;
// this represents the ratio of zero element to be set by the Parameter
optional double sparsity_ratio = 2 [default = 0.6];
optional double upper_bound = 3 [default = 0.8];
optional int32 inter_pass = 4 [default = 1];
optional int32 end_pass = 5 [default = 20];
}

message ParameterConfig {
4 changes: 4 additions & 0 deletions python/paddle/trainer/config_parser.py
@@ -3145,6 +3145,10 @@ def ParameterHook(type, **kwargs):
if sparsity_ratio is not None:
hook.sparsity_ratio = sparsity_ratio
return hook
elif type == 'dpruning':
hook = ParameterUpdaterHookConfig()
hook.type = type
return hook
else:
return None

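For illustration, the ParameterHook helper above just fills in a ParameterUpdaterHookConfig message. A minimal sketch of the two variants follows; the import path and the make_hook name are assumptions for the example, and in this revision the 'dpruning' branch still relies on the proto defaults (upper_bound, inter_pass, end_pass) because extra keyword arguments are ignored.

```python
# Sketch only; assumes the generated ParameterUpdaterHookConfig message is importable.
from paddle.proto.ParameterConfig_pb2 import ParameterUpdaterHookConfig

def make_hook(hook_type, sparsity_ratio=None):
    hook = ParameterUpdaterHookConfig()
    hook.type = hook_type
    if hook_type == 'pruning' and sparsity_ratio is not None:
        hook.sparsity_ratio = sparsity_ratio  # static pruning: fixed ratio of zeros
    # 'dpruning' keeps the proto defaults: upper_bound=0.8, inter_pass=1, end_pass=20
    return hook

static_hook = make_hook('pruning', sparsity_ratio=0.75)
dynamic_hook = make_hook('dpruning')
```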
1 change: 1 addition & 0 deletions python/paddle/v2/parameters.py
@@ -136,6 +136,7 @@ def __getitem__(self, key):
for each_gradient_machine in self.__gradient_machines__:
param = __get_parameter_in_gradient_machine__(
each_gradient_machine, key)
param.handleBeforeSave()
# for simplify implementation now, we always copy from C++
assert isinstance(param, api.Parameter)
val = param.getBuf(api.PARAMETER_VALUE)
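The call to handleBeforeSave here matters because, for DynamicPruningHook, update() is a no-op and the mask is re-applied to the value buffer only at the start of each batch in preprocess(); after the final optimizer step, pruned positions can therefore hold small nonzero values again, and handleBeforeSave restores exact zeros before the values are copied out. Below is a small self-contained numpy sketch of that masking logic; numpy stands in for the Vector buffers, and generate_mask only mirrors the idea of generateMask above.

```python
import numpy as np

def generate_mask(values, nonzero_num):
    # Keep the nonzero_num largest-magnitude entries, zero out the rest
    # (the same idea as ParameterPruningHook::generateMask).
    mask = np.zeros_like(values)
    keep = np.argsort(-np.abs(values))[:nonzero_num]
    mask[keep] = 1.0
    return mask

rng = np.random.RandomState(0)
value = rng.randn(8)
nonzero_num = int(value.size * (1 - 0.6))  # sparsity ratio 0.6 -> keep 3 of 8

mask = generate_mask(value, nonzero_num)
value *= mask                    # preprocess(): prune the value buffer
value += 0.01 * rng.randn(8)     # an unmasked update step re-populates pruned slots
value *= mask                    # handleBeforeSave(): restore exact zeros before saving
print("nonzero after save: %d (expected %d)" % ((value != 0).sum(), nonzero_num))
```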
7 changes: 5 additions & 2 deletions python/paddle/v2/trainer.py
@@ -140,9 +140,9 @@ def train(self, reader, num_passes=1, event_handler=None, feeding=None):
out_args = api.Arguments.createArguments(0)
feeder = DataFeeder(self.__data_types__, feeding)
for pass_id in xrange(num_passes):
event_handler(v2_event.BeginPass(pass_id))
pass_evaluator.start()
self.__parameter_updater__.startPass()
pass_evaluator.start()
event_handler(v2_event.BeginPass(pass_id))
for batch_id, data_batch in enumerate(reader()):
batch_evaluator.start()
event_handler(
@@ -152,6 +152,9 @@ def train(self, reader, num_passes=1, event_handler=None, feeding=None):
len(data_batch))
in_args = feeder(data_batch)
self.__prepare_parameter__(in_args)
for each_param in self.__gradient_machine__.getNonStaticParameters(
):
self.__parameter_updater__.preprocess(each_param, pass_id, batch_id)
self.__gradient_machine__.forwardBackward(in_args, out_args,
pass_type)
self.__gradient_machine__.eval(pass_evaluator)