Add top-k error #1337

Merged: 8 commits, Feb 22, 2017
13 changes: 0 additions & 13 deletions paddle/cuda/include/hl_matrix.h
@@ -69,19 +69,6 @@ extern void hl_sequence_softmax_forward(real* A_d,
const int* index,
int numSequence);

/**
 * @brief Matrix classification error.
 *
 * @param[in] A_d input matrix (M x N).
 * @param[in] B_d input vector (M x 1).
 * @param[out] C_d output vector (M x 1).
 * @param[in] dimM matrix height.
 * @param[in] dimN matrix width.
 *
 */
extern void hl_matrix_classification_error(
    real* A_d, int* B_d, real* C_d, int dimM, int dimN);

/**
 * @brief Matrix cross entropy.
 *
28 changes: 27 additions & 1 deletion paddle/cuda/include/hl_top_k.h
@@ -58,4 +58,30 @@ extern void hl_sparse_matrix_top_k(real* topVal,
int beamSize,
int numSamples);

#endif /* HL_TOP_K_H_ */
/**
 * @brief Matrix top-k classification error.
 *
 * @param[out] topVal top k values.
 * @param[in] ldv leading dimension of topVal.
 * @param[out] topIds indices of the top k values.
 * @param[in] src input values.
 * @param[in] lds leading dimension of src.
 * @param[in] dim width of the input matrix.
 * @param[in] topkSize k, the number of top elements to keep.
 * @param[in] numSamples height of the input matrix (number of samples).
 * @param[in] label ground-truth labels.
 * @param[out] recResult per-sample top-k classification error
 *                       (0 if the label is in the top k, otherwise 1).
 */
extern void hl_matrix_classification_error(real* topVal,
                                           int ldv,
                                           int* topIds,
                                           real* src,
                                           int lds,
                                           int dim,
                                           int topkSize,
                                           int numSamples,
                                           int* label,
                                           real* recResult);

#endif // HL_TOP_K_H_
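For orientation, here is a minimal host-side sketch of how the new entry point might be called for a dense, row-contiguous numSamples x dim score matrix. The helper name and the assumption that every pointer is a caller-allocated device buffer are illustrative only, not part of this PR.

#include "hl_top_k.h"

// Hypothetical wrapper (not in this PR): per-sample top-k error for a dense
// score matrix. `scores` is numSamples x dim, `labels` holds one ground-truth
// class index per sample, and `perSampleError` receives 0 (label in top k) or
// 1 (label missed). All pointers are assumed to be device buffers.
void topKErrorOnGpu(real* scores, int* labels, real* perSampleError,
                    real* topVal, int* topIds,
                    int numSamples, int dim, int k) {
  hl_matrix_classification_error(topVal, /* ldv = */ k,
                                 topIds,
                                 scores, /* lds = */ dim,
                                 dim, k, numSamples,
                                 labels, perSampleError);
}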
12 changes: 10 additions & 2 deletions paddle/cuda/include/stub/hl_matrix_stub.h
@@ -35,8 +35,16 @@ inline void hl_sequence_softmax_forward(real* A_d,
inline void hl_matrix_softmax_derivative(
real* grad_d, real* output_d, real* sftmaxSum_d, int dimM, int dimN) {}

inline void hl_matrix_classification_error(
    real* A_d, int* B_d, real* C_d, int dimM, int dimN) {}
inline void hl_matrix_classification_error(real* topVal,
                                           int ldv,
                                           int* topIds,
                                           real* src,
                                           int lds,
                                           int dim,
                                           int topkSize,
                                           int numSamples,
                                           int* label,
                                           real* recResult) {}

inline void hl_matrix_cross_entropy(
real* A_d, real* C_d, int* label_d, int dimM, int dimN) {}
53 changes: 0 additions & 53 deletions paddle/cuda/src/hl_cuda_matrix.cu
@@ -265,59 +265,6 @@ void hl_matrix_softmax_derivative(real *grad_d,
CHECK_SYNC("hl_matrix_softmax_derivative failed");
}

template<int blockSize>
__global__ void KeMatrixClassificationError(real* in_A,
                                            int* in_B,
                                            real* out_C,
                                            int dimN) {
  __shared__ real max_s[blockSize];
  __shared__ int max_l[blockSize];
  const int tid = threadIdx.x;
  const int rowId = blockIdx.x;

  max_s[tid] = -1e30f;
  in_A += rowId * dimN;
  real tmp;
  for (int colId = tid; colId < dimN; colId += blockSize) {
    tmp = in_A[colId];
    if (max_s[tid] < tmp) {
      max_s[tid] = tmp;
      max_l[tid] = colId;
    }
  }
  __syncthreads();

  for (int stride = blockSize/2; stride > 0; stride = stride/2) {
    if (tid < stride) {
      if (max_s[tid] < max_s[tid + stride]) {
        max_s[tid] = max_s[tid + stride];
        max_l[tid] = max_l[tid + stride];
      }
    }
    __syncthreads();
  }
  __syncthreads();

  if (tid == 0) {
    out_C[rowId] = (max_l[0] == in_B[rowId] ? 0 : 1.0f);
  }
}

void hl_matrix_classification_error(real* A_d,
                                    int* B_d,
                                    real* C_d,
                                    int dimM,
                                    int dimN) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(B_d);
  CHECK_NOTNULL(C_d);

  // each sample is calculated by one block
  KeMatrixClassificationError<1024><<< dimM, 1024, 0, STREAM_DEFAULT >>>
      (A_d, B_d, C_d, dimN);
  CHECK_SYNC("hl_matrix_classification_error");
}

__global__ void KeMatrixMultiBinaryCrossEntropy(real* output,
real* entropy,
int* row,
78 changes: 78 additions & 0 deletions paddle/cuda/src/hl_top_k.cu
@@ -384,3 +384,81 @@ void hl_sparse_matrix_top_k(real* topVal, int ldv,
CHECK_SYNC("hl_sparse_matrix_top_k failed");
}

/**
 * Each block computes one sample.
 * Within a block:
 * 1. every thread gets its top maxLength values;
 * 2. merge them into shTopK, block-reduce, and take the max value;
 * 3. repeat the second step until one thread's topK buffer is exhausted;
 * 4. repeat from the first step until all topK values have been found.
 */
template<int maxLength, int blockSize>
__global__ void KeMatrixTopKClassificationError(real* topVal, int ldv,
                                                int* topIds,
                                                real* src, int lds,
                                                int dim,
                                                int beamSize,
                                                int* label,
                                                real* recResult) {
  __shared__ Pair shTopK[blockSize];
  __shared__ int maxId[blockSize / 2];
  const int tid = threadIdx.x;
  const int warp = threadIdx.x / 32;
  src += blockIdx.x * lds;
  topVal += blockIdx.x * ldv;
  topIds += blockIdx.x * beamSize;

  Pair topK[maxLength];  // NOLINT
  int beam = maxLength;
  Pair max;
  bool isEmpty = false;
  bool firstStep = true;
  int topkSize = beamSize;

  for (int k = 0; k < maxLength; k++) {
    topK[k].set(-HL_FLOAT_MAX, -1);
  }

  while (beamSize) {
    threadGetTopK<maxLength, blockSize>
        (topK, beam, beamSize, src, firstStep, isEmpty, max, dim, tid);

    shTopK[tid] = topK[0];
    blockReduce<maxLength, blockSize>
        (shTopK, maxId, topK, &topVal, &topIds, beam, beamSize, tid, warp);
  }

  __syncthreads();
  if (tid == 0) {
    for (int i = 0; i < topkSize; i++) {
      if (*--topIds == label[blockIdx.x]) {
        recResult[blockIdx.x] = 0;
        break;
      }
      recResult[blockIdx.x] = 1.0f;
    }
  }
}

void hl_matrix_classification_error(real* topVal, int ldv,
                                    int* topIds,
                                    real* src, int lds,
                                    int dim,
                                    int topkSize,
                                    int numSamples,
                                    int* label,
                                    real* recResult) {
  CHECK_NOTNULL(topVal);
  CHECK_NOTNULL(topIds);
  CHECK_NOTNULL(src);

  if (topkSize > dim) topkSize = dim;

  dim3 threads(256, 1);
  dim3 grid(numSamples, 1);
  KeMatrixTopKClassificationError<5, 256>
      <<< grid, threads, 0, STREAM_DEFAULT >>>
      (topVal, ldv, topIds, src, lds, dim, topkSize, label, recResult);

  CHECK_SYNC("hl_matrix_top_k classification error failed");
}
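As a sanity reference for what the kernel computes per sample, here is a minimal CPU sketch under the same 0/1 convention; the function name and the std::partial_sort approach are assumptions for illustration, not code from this PR (tie-breaking may differ from the GPU path).

#include <algorithm>
#include <numeric>
#include <vector>

// Hypothetical CPU reference: returns 0 if `label` is among the k largest
// entries of `row` (length `dim`), otherwise 1 -- mirroring recResult above.
float topKErrorReference(const float* row, int dim, int k, int label) {
  if (k > dim) k = dim;  // same clamp as the GPU wrapper
  std::vector<int> idx(dim);
  std::iota(idx.begin(), idx.end(), 0);
  // Bring the indices of the k largest scores to the front.
  std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                    [row](int a, int b) { return row[a] > row[b]; });
  for (int i = 0; i < k; ++i) {
    if (idx[i] == label) return 0.0f;
  }
  return 1.0f;
}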
22 changes: 21 additions & 1 deletion paddle/gserver/evaluators/Evaluator.cpp
@@ -39,6 +39,14 @@ void Evaluator::eval(const NeuralNetwork& nn) {
*/
class ClassificationErrorEvaluator : public Evaluator {
public:
/*
ClassificationErrorEvaluator() : totalScore2_(0) {}

virtual void start() {
Evaluator::start();
totalScore2_ = 0;
} */

virtual void updateSamplesNum(const std::vector<Argument>& arguments) {
if (3 == arguments.size()) {
numSamples_ += arguments[2].value->getSum();
@@ -76,9 +84,11 @@ class ClassificationErrorEvaluator : public Evaluator {
1,
/* trans= */ false,
useGpu(arguments[0].deviceId));

errorMat->zeroMem();

if (label != nullptr) {
errorMat->classificationError(*output, *label);
errorMat->classificationError(*output, *label, config_.top_k());
} else if (dynamic_cast<CpuSparseMatrix*>(multiBinaryLabel.get()) ||
dynamic_cast<GpuSparseMatrix*>(multiBinaryLabel.get())) {
errorMat->classificationErrorMulti(
@@ -94,6 +104,16 @@ class ClassificationErrorEvaluator : public Evaluator {
return errorMat;
}

  void printStats(std::ostream& os) const {
    if (config_.top_k() == 1) {
      os << config_.name() << "="
         << (numSamples_ ? totalScore_ / numSamples_ : 0);
    } else {
      os << " top_" << config_.top_k()
         << "_error=" << (numSamples_ ? totalScore_ / numSamples_ : 0);
    }
  }

  virtual real evalImp(std::vector<Argument>& arguments) {
    MatrixPtr errorMat = calcError(arguments);
    return errorMat->getSum();
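The number reported by printStats is simply the mean of the per-sample 0/1 errors that evalImp sums up; a small illustrative sketch of that arithmetic (names are assumptions, not Paddle API):

#include <numeric>
#include <vector>

// Illustrative only: error rate = errorCount / sampleCount, i.e.
// totalScore_ / numSamples_ in the evaluator above.
double topKErrorRate(const std::vector<double>& perSampleError) {
  if (perSampleError.empty()) return 0.0;
  double errorCount =
      std::accumulate(perSampleError.begin(), perSampleError.end(), 0.0);
  return errorCount / static_cast<double>(perSampleError.size());
}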
1 change: 1 addition & 0 deletions paddle/gserver/layers/Layer.h
@@ -311,6 +311,7 @@ class Layer {
return *output->second;
} else {
LOG(FATAL) << "No specific output " << str;
return *((Argument*)nullptr);
}
}
}
1 change: 1 addition & 0 deletions paddle/gserver/tests/test_Evaluator.cpp
@@ -129,6 +129,7 @@ void testEvaluatorAll(TestConfig testConf,
TEST(Evaluator, classification_error) {
TestConfig config;
config.evaluatorConfig.set_type("classification_error");
config.evaluatorConfig.set_top_k(5);

config.inputDefs.push_back({INPUT_DATA, "output", 50});
config.inputDefs.push_back({INPUT_LABEL, "label", 50});