Skip to content

Commit

Permalink
Rewrite code according to reviewer comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Liang Zhao committed Feb 22, 2017
1 parent e768721 commit d256512
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 28 deletions.
34 changes: 10 additions & 24 deletions paddle/gserver/evaluators/Evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,36 +102,22 @@ class ClassificationErrorEvaluator : public Evaluator {
maxValues_, height, width, false, useGpu(arguments[0].deviceId));
output->rowMax(*maxIds_, *maxValues_); // top-k values

int* ids = nullptr;
int* lbl = nullptr;
IVectorPtr dest = IVector::create(maxIds_->getSize(), false);
IVectorPtr dest2 = IVector::create(label->getSize(), false);
if (useGpu(arguments[0].deviceId)) {
hl_memcpy_device2host((void*)dest->getData(),
(void*)maxIds_->getData(),
sizeof(int) * maxIds_->getSize());
ids = dest->getData();

hl_memcpy_device2host((void*)dest2->getData(),
(void*)label->getData(),
sizeof(int) * label->getSize());
lbl = dest2->getData();
} else {
ids = maxIds_->getData();
lbl = label->getData();
}
dest->copyFrom(*maxIds_);
dest2->copyFrom(*label);
int* ids = dest->getData();
int* lbl = dest2->getData();

real* result2 = errorMat2->getData();
for (size_t i = 0; i < height; ++i) {
result2[i] = (ids[i * width] != lbl[i]); // initialize top-k error
for (size_t j = 1; j < width; ++j) {
if (result2[i] == 0.0) {
break;
}
result2[i] = (ids[i * width + j] != lbl[i]); // top-k error
bool contain = false;
for (size_t j = 0; j < width && !contain; ++j) {
contain = (ids[i * width + j] == lbl[i]);
}
if (!contain) {
totalScore2_ += 1.0; // update top-k error
}
}
totalScore2_ += errorMat2->getSum();
}
} else if (dynamic_cast<CpuSparseMatrix*>(multiBinaryLabel.get()) ||
dynamic_cast<GpuSparseMatrix*>(multiBinaryLabel.get())) {
Expand Down
4 changes: 0 additions & 4 deletions paddle/parameter/Parameter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,10 +375,6 @@ bool Parameter::load(const std::string& filename) {
std::ifstream fs(filename, std::ios_base::binary);
if (!fs) {
LOG(INFO) << "missing parameters [" << filename << "] while loading model.";
/*if (isStatic()) {
LOG(FATAL) << getName() << " is static but missing, not allowed.";
return false;
}*/
if (kMissParameterFail == FLAGS_load_missing_parameter_strategy) {
LOG(FATAL) << getName() << " missing, not allowed.";
return false;
Expand Down

0 comments on commit d256512

Please sign in to comment.