-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Auto pruning #2603
Auto pruning #2603
Changes from 5 commits
8317882
9a8d498
6dbc941
07cbc9e
7526315
44c4060
54a5577
e681196
b960be8
6849985
1e44249
cbd5afb
d875def
3813928
a4862cd
3251527
12cf82f
3747d19
6636a21
9d98dd1
1368682
59c8c43
34b7b90
d6b04ec
e184a41
ab3d10b
e8e5a67
8b74b72
8966db2
b896b32
3a24d7e
5c822fe
ecf25b1
6e5d805
44aec34
322fcbe
cc139b2
a5407fa
bd6749a
bb3038d
7c9d5e5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -573,7 +573,7 @@ class Parameter { | |
|
||
ParameterConfig* getConfig(); | ||
void setValueUpdated(); | ||
|
||
void handleBeforeSave(); | ||
bool save(const std::string& filename) const; | ||
|
||
bool load(const std::string& filename) const; | ||
|
@@ -880,6 +880,7 @@ class ParameterUpdater { | |
* @param param | ||
*/ | ||
void update(Parameter* param); | ||
void preprocess(Parameter* param, size_t currentPass, size_t currentBatch); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 参数名,我觉得还是叫 |
||
|
||
/** | ||
* @brief only get required sparse rows by default. |
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ limitations under the License. */ | |
#include <algorithm> | ||
#include <atomic> | ||
#include <fstream> | ||
#include <iostream> | ||
#include <mutex> | ||
#include <thread> | ||
#include <unordered_map> | ||
|
@@ -29,42 +30,22 @@ limitations under the License. */ | |
|
||
namespace paddle { | ||
|
||
/** | ||
* The static pruning hook | ||
* Static means user specify a sparsity_ratio before training started, and the | ||
* network will prune the parameters based on the sparsity_ratio. More details | ||
* can be found https://arxiv.org/pdf/1506.02626.pdf. | ||
*/ | ||
|
||
class StaticPruningHook : public IParameterUpdaterHook { | ||
class ParameterPruningHook : public IParameterUpdaterHook { | ||
public: | ||
explicit StaticPruningHook(const ParameterUpdaterHookConfig &hookConfig) | ||
: initCount_(0) { | ||
sparsityRatio_ = hookConfig.sparsity_ratio(); | ||
} | ||
explicit ParameterPruningHook() : initCount_(0) {} | ||
|
||
static bool sortPairAscend(const std::pair<real, size_t> &pair1, | ||
const std::pair<real, size_t> &pair2) { | ||
return pair1.first > pair2.first; | ||
} | ||
|
||
void update(Parameter *para) { | ||
updateThreadChecker_.check(); | ||
auto &vec = para->getBuf(PARAMETER_GRADIENT); | ||
if (vec) { | ||
vec->dotMul(*maskVec_); | ||
} | ||
} | ||
virtual void update(Parameter *para) {/*do nothing*/} | ||
virtual void handleBeforeSave(Parameter *para) {/*do nothing*/} | ||
virtual void preprocess(Parameter *para, size_t currentPass, size_t currentBatch) {} | ||
|
||
void generateMask(Parameter *para) { | ||
virtual void generateMask(Parameter *para, size_t nonZeroNum) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 建议第二个参数传 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 嗯, 好的 |
||
VectorPtr maskTemp = Vector::create(para->getSize(), false); | ||
maskTemp->zeroMem(); | ||
real *maskTempData = maskTemp->getData(); | ||
size_t nonZeroNum = para->getSize() * (1 - sparsityRatio_); | ||
|
||
VectorPtr paraVec = para->getBuf(PARAMETER_VALUE); | ||
VectorPtr paraCpuCopy = Vector::create(para->getSize(), false); | ||
|
||
paraCpuCopy->copyFrom(*paraVec); | ||
std::vector<std::pair<real, size_t>> param; | ||
|
||
|
@@ -73,38 +54,138 @@ class StaticPruningHook : public IParameterUpdaterHook { | |
|
||
std::partial_sort( | ||
param.begin(), param.begin() + nonZeroNum, param.end(), sortPairAscend); | ||
|
||
for (size_t i = 0; i < nonZeroNum; i++) maskTempData[param[i].second] = 1.0; | ||
|
||
// Currently just use a mask vector for hack. | ||
if (para->useGpu()) { | ||
maskVec_ = Vector::create(para->getSize(), para->useGpu()); | ||
maskVec_->copyFrom(*maskTemp); | ||
this-> maskVec_ = Vector::create(para->getSize(), para->useGpu()); | ||
this-> maskVec_->copyFrom(*maskTemp); | ||
} else { | ||
maskVec_ = maskTemp; | ||
this-> maskVec_ = maskTemp; | ||
} | ||
} | ||
|
||
void init(Parameter *para) { | ||
generateMask(para); | ||
static bool sortPairAscend(const std::pair<real, size_t> &pair1, | ||
const std::pair<real, size_t> &pair2) { | ||
return pair1.first > pair2.first; | ||
} | ||
|
||
|
||
protected: | ||
std::atomic<size_t> initCount_; | ||
SameThreadChecker updateThreadChecker_; | ||
VectorPtr maskVec_; | ||
}; | ||
|
||
/** | ||
* The static pruning hook | ||
* Static means user specify a sparsity_ratio before training started, and the | ||
* network will prune the parameters based on the sparsity_ratio. More details | ||
* can be found https://arxiv.org/pdf/1506.02626.pdf. | ||
*/ | ||
|
||
class StaticPruningHook : public ParameterPruningHook { | ||
public: | ||
explicit StaticPruningHook(const ParameterUpdaterHookConfig &hookConfig) | ||
: ParameterPruningHook() { | ||
this->sparsityRatio_ = hookConfig.sparsity_ratio(); | ||
} | ||
|
||
void update(Parameter *para) override{ | ||
updateThreadChecker_.check(); | ||
auto &vec = para->getBuf(PARAMETER_GRADIENT); | ||
if (vec) { | ||
vec->dotMul(*maskVec_); | ||
} | ||
} | ||
|
||
void init(Parameter *para) override { | ||
size_t initCount = this->initCount_.fetch_add(1); | ||
CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must invoke " | ||
"in same ParamterUpdater"; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
VLOG(3) << "Initialize Parameter " << para; | ||
SetDevice device(para->getDeviceId()); | ||
|
||
size_t nonZeroNum = para->getSize() * (1 - sparsityRatio_); | ||
this->generateMask(para, nonZeroNum); | ||
|
||
auto ¶Vec = para->getBuf(PARAMETER_VALUE); | ||
paraVec->dotMul(*maskVec_); | ||
paraVec->dotMul(*this->maskVec_); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里只是有个疑问, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this 可以去掉... |
||
} | ||
|
||
private: | ||
SameThreadChecker updateThreadChecker_; | ||
std::atomic<size_t> initCount_; | ||
VectorPtr maskVec_; | ||
real sparsityRatio_; | ||
}; | ||
|
||
IParameterUpdaterHook::IParameterUpdaterHook() {} | ||
class DynamicPruningHook : public ParameterPruningHook { | ||
public: | ||
explicit DynamicPruningHook(const ParameterUpdaterHookConfig &hookConfig) | ||
: ParameterPruningHook() { | ||
this->upperBound_ = hookConfig.upper_bound(); | ||
this->interPass_ = hookConfig.inter_pass(); | ||
this->endPass_ = hookConfig.end_pass(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这几个参数命名都很不直观,需要去猜测所代表的意思,建议:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. upper_bound 改成了 sparsity_upper_bound inter_pass 改成 interval_pass ,属性名字太长的话用户体验不是太好,end_pass 的话我认为是有必要有的,num_passes 是整个训练过程经过的pass, 而end_pass 是我们sparsity_ratio 变化期间经过的pass |
||
} | ||
|
||
void init(Parameter *para) override { | ||
// init mask | ||
size_t initCount = this->initCount_.fetch_add(1); | ||
CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must invoke " | ||
"in same ParamterUpdater"; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 |
||
VLOG(3) << "Initialize Parameter " << para; | ||
this->maskVec_ = Vector::create(para->getSize(), para->useGpu()); | ||
this->maskVec_->reset(1.0); | ||
|
||
/* | ||
real *data = this->maskVec_->getData(); | ||
for (size_t i = 0; i < para->getSize(); i++){ | ||
std::cout << data[i] << " " ; | ||
} | ||
*/ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 138-143行注释的代码请删掉。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
} | ||
|
||
void handleBeforeSave(Parameter *para) override{ | ||
updateThreadChecker_.check(); | ||
auto &vec = para->getBuf(PARAMETER_VALUE); | ||
if (vec) { | ||
vec->dotMul(*maskVec_); | ||
} | ||
} | ||
|
||
void preprocess(Parameter *para, size_t currentPass, size_t currentBatch) override { | ||
if (currentPass % interPass_ == 0 && currentPass <= endPass_ && currentBatch == 0) { | ||
real boundWeight = | ||
this->upperBound_ / std::log(this->endPass_ / (real)this->interPass_); | ||
real sparsityRatio = | ||
boundWeight * std::log(2 + currentPass / (real)interPass_); | ||
|
||
size_t nonZeroNum = para->getSize() * (1 - sparsityRatio); | ||
this->generateMask(para, nonZeroNum); | ||
std::cout << para->getName() << " Current sparsity ratio: " << | ||
sparsityRatio <<" " << nonZeroNum<<std::endl; | ||
} | ||
// apply the current pruning mask to the parameter values | ||
auto ¶Vec = para->getBuf(PARAMETER_VALUE); | ||
paraVec->dotMul(*this->maskVec_); | ||
/* | ||
VectorPtr paraCopyCpu = Vector::create(para->getSize(), false); | ||
paraCopyCpu->copyFrom(*paraVec); | ||
real *data = paraCopyCpu->getData(); | ||
size_t sum_non = 0; | ||
for(size_t i = 0; i < para->getSize(); i++){ | ||
if(data[i] != 0.0) | ||
sum_non += 1; | ||
} | ||
std::cout<<"sum_non: " <<sum_non << " " << para->getSize()<< std::endl; | ||
*/ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 169-178注释的代码请删掉。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
} | ||
|
||
private: | ||
real upperBound_; | ||
size_t interPass_; | ||
size_t endPass_; | ||
}; | ||
|
||
IParameterUpdaterHook::IParameterUpdaterHook() {} | ||
IParameterUpdaterHook::~IParameterUpdaterHook() {} | ||
|
||
/** | ||
|
@@ -139,6 +220,8 @@ static IParameterUpdaterHook *createImpl( | |
auto &type = config.type(); | ||
if (type == "pruning") { | ||
return new StaticPruningHook(config); | ||
} else if (type == "dpruning") { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. dpruning需要写全称dynamic_pruning,以便用户更好理解呢? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
return new DynamicPruningHook(config); | ||
} | ||
|
||
LOG(FATAL) << "Unknown Hook type: " << type; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
.h文件中新加的函数都请添加注释,下同
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done