Skip to content

Commit

Permalink
fix code style to pass CI.
Browse files Browse the repository at this point in the history
  • Loading branch information
lcy-seso committed Aug 31, 2017
1 parent 3d1b871 commit 36f0aa7
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
11 changes: 7 additions & 4 deletions paddle/gserver/layers/CrossEntropyOverBeam.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ void CostForOneSequence::calValidExpandStep() {
start,
start + goldRowIds_[i - 1] * beamSize_ + goldColIds_[i - 1],
[](const real& val) { return val != -1.; });
} else
} else {
goldRowIds_[i] = 0;
}

real* start =
beams_->candidateIds[i]->getData() + goldRowIds_[i] * beamSize_;
Expand Down Expand Up @@ -288,7 +289,7 @@ void CrossEntropyOverBeam::copyInputsToCpu() {

void CrossEntropyOverBeam::splitBatchBeams() {
beamCosts_.resize(batchSize_);
beamPerSeq_.resize(batchSize_, beamExpanCount_);
beamPerSeq_.resize(batchSize_, BeamExpansion(beamExpanCount_));

for (size_t i = 0; i < beamExpanCount_; ++i) {
int* seqStarts =
Expand All @@ -300,8 +301,9 @@ void CrossEntropyOverBeam::splitBatchBeams() {
subSeqStarts =
getInput(i * 3).subSequenceStartPositions->getMutableData(false);
maxLen = getInput(i * 3).subSequenceStartPositions->getSize() - 1;
} else
} else {
maxLen = getInput(i).sequenceStartPositions->getSize() - 1;
}

for (size_t j = 0; j < batchSize_; ++j) {
beamPerSeq_[j].scores[i] =
Expand Down Expand Up @@ -348,8 +350,9 @@ void CrossEntropyOverBeam::resizeOutput() {
inGrad->getWidth(),
false,
false);
} else
} else {
candidateScoreGrad_[i] = std::move(inGrad);
}
candidateScoreGrad_[i]->zeroMem();
}
}
Expand Down
6 changes: 3 additions & 3 deletions paddle/gserver/layers/CrossEntropyOverBeam.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ struct BeamExpansion {

size_t expansionCount;

BeamExpansion(int n) {
explicit BeamExpansion(int n) {
expansionCount = n;
scores.resize(expansionCount);
seqInfo.resize(expansionCount);
candidateIds.resize(expansionCount);
scoreGrad.resize(expansionCount);

gold.resize(expansionCount);
};
}
};
typedef std::shared_ptr<BeamExpansion> BeamExpansionPtr;

Expand Down Expand Up @@ -74,7 +74,7 @@ class CostForOneSequence {
CHECK_GT(beams_->seqInfo[beamId]->getSize() - 1, rowId);
int* starts = beams_->seqInfo[beamId]->getData();
return starts[rowId] - starts[0];
};
}

size_t beamSize_;
size_t validExpansionCount_;
Expand Down

0 comments on commit 36f0aa7

Please sign in to comment.