diff --git a/paddle/cuda/include/hl_matrix.h b/paddle/cuda/include/hl_matrix.h
index 6f21b82afdc6c..eb454c59c1e58 100644
--- a/paddle/cuda/include/hl_matrix.h
+++ b/paddle/cuda/include/hl_matrix.h
@@ -69,19 +69,6 @@ extern void hl_sequence_softmax_forward(real* A_d,
                                         const int* index,
                                         int numSequence);
 
-/**
- * @brief   Matrix classification error.
- *
- * @param[in]  A_d   input matrix (M x N).
- * @param[in]  B_d   input vector (M x 1).
- * @param[out] C_d   output vector (M x 1).
- * @param[in]  dimM  matrix height.
- * @param[in]  dimN  matrix width.
- *
- */
-extern void hl_matrix_classification_error(
-    real* A_d, int* B_d, real* C_d, int dimM, int dimN);
-
 /**
  * @brief   Matrix cross entropy.
  *
diff --git a/paddle/cuda/include/hl_top_k.h b/paddle/cuda/include/hl_top_k.h
index 77949ed295a6e..79ae0d0e741de 100644
--- a/paddle/cuda/include/hl_top_k.h
+++ b/paddle/cuda/include/hl_top_k.h
@@ -58,4 +58,30 @@ extern void hl_sparse_matrix_top_k(real* topVal,
                                    int beamSize,
                                    int numSamples);
 
-#endif /* HL_TOP_K_H_ */
+/**
+ * @brief   Matrix classification error.
+ *
+ * @param[out] topVal      top-k values.
+ * @param[in]  ldv         leading dimension of topVal.
+ * @param[out] topIds      top-k indices.
+ * @param[in]  src         input value.
+ * @param[in]  lds         leading dimension of src.
+ * @param[in]  dim         width of input value.
+ * @param[in]  topkSize    number of top-k elements.
+ * @param[in]  numSamples  height of input value.
+ * @param[in]  label       ground truth label.
+ * @param[out] recResult   top-k classification error.
+ *
+ */
+extern void hl_matrix_classification_error(real* topVal,
+                                           int ldv,
+                                           int* topIds,
+                                           real* src,
+                                           int lds,
+                                           int dim,
+                                           int topkSize,
+                                           int numSamples,
+                                           int* label,
+                                           real* recResult);
+
+#endif  // HL_TOP_K_H_
diff --git a/paddle/cuda/include/stub/hl_matrix_stub.h b/paddle/cuda/include/stub/hl_matrix_stub.h
index f4e6461cdcf19..127cb7e27983e 100644
--- a/paddle/cuda/include/stub/hl_matrix_stub.h
+++ b/paddle/cuda/include/stub/hl_matrix_stub.h
@@ -35,8 +35,16 @@ inline void hl_sequence_softmax_forward(real* A_d,
 inline void hl_matrix_softmax_derivative(
     real* grad_d, real* output_d, real* sftmaxSum_d, int dimM, int dimN) {}
 
-inline void hl_matrix_classification_error(
-    real* A_d, int* B_d, real* C_d, int dimM, int dimN) {}
+inline void hl_matrix_classification_error(real* topVal,
+                                           int ldv,
+                                           int* topIds,
+                                           real* src,
+                                           int lds,
+                                           int dim,
+                                           int topkSize,
+                                           int numSamples,
+                                           int* label,
+                                           real* recResult) {}
 
 inline void hl_matrix_cross_entropy(
     real* A_d, real* C_d, int* label_d, int dimM, int dimN) {}
diff --git a/paddle/cuda/src/hl_cuda_matrix.cu b/paddle/cuda/src/hl_cuda_matrix.cu
index 96c07d9c3b7a3..9bcc7fb7de44b 100644
--- a/paddle/cuda/src/hl_cuda_matrix.cu
+++ b/paddle/cuda/src/hl_cuda_matrix.cu
@@ -265,59 +265,6 @@ void hl_matrix_softmax_derivative(real *grad_d,
   CHECK_SYNC("hl_matrix_softmax_derivative failed");
 }
 
-template <int blockSize>
-__global__ void KeMatrixClassificationError(real* in_A,
-                                            int* in_B,
-                                            real* out_C,
-                                            int dimN) {
-  __shared__ real max_s[blockSize];
-  __shared__ int max_l[blockSize];
-  const int tid = threadIdx.x;
-  const int rowId = blockIdx.x;
-
-  max_s[tid] = -1e30f;
-  in_A += rowId * dimN;
-  real tmp;
-  for (int colId = tid; colId < dimN; colId += blockSize) {
-    tmp = in_A[colId];
-    if (max_s[tid] < tmp) {
-      max_s[tid] = tmp;
-      max_l[tid] = colId;
-    }
-  }
-  __syncthreads();
-
-  for (int stride = blockSize/2; stride > 0; stride = stride/2) {
-    if (tid < stride) {
-      if (max_s[tid] < max_s[tid + stride]) {
-        max_s[tid] = max_s[tid + stride];
-        max_l[tid] = max_l[tid + stride];
-      }
-    }
-    __syncthreads();
-  }
-  __syncthreads();
-
-  if (tid == 0) {
-    out_C[rowId] = (max_l[0] == in_B[rowId] ? 0 : 1.0f);
-  }
-}
-
-void hl_matrix_classification_error(real* A_d,
-                                    int* B_d,
-                                    real* C_d,
-                                    int dimM,
-                                    int dimN) {
-  CHECK_NOTNULL(A_d);
-  CHECK_NOTNULL(B_d);
-  CHECK_NOTNULL(C_d);
-
-  // each sample is calculated by one block
-  KeMatrixClassificationError<1024><<< dimM, 1024, 0, STREAM_DEFAULT >>>
-      (A_d, B_d, C_d, dimN);
-  CHECK_SYNC("hl_matrix_classification_error");
-}
-
 __global__ void KeMatrixMultiBinaryCrossEntropy(real* output,
                                                 real* entropy,
                                                 int* row,
diff --git a/paddle/cuda/src/hl_top_k.cu b/paddle/cuda/src/hl_top_k.cu
index f0ef0cc3c51f9..4f0bbfcf4e3aa 100644
--- a/paddle/cuda/src/hl_top_k.cu
+++ b/paddle/cuda/src/hl_top_k.cu
@@ -384,3 +384,81 @@ void hl_sparse_matrix_top_k(real* topVal, int ldv,
 
   CHECK_SYNC("hl_sparse_matrix_top_k failed");
 }
+
+/**
+ * Each block computes one sample.
+ * In a block:
+ * 1. every thread gets its top maxLength values;
+ * 2. merge to shTopK, block reduce and get the max value;
+ * 3. go to the second step, until one thread's topK value is null;
+ * 4. go to the first step, until the topK values are obtained.
+ */
+template <int maxLength, int blockSize>
+__global__ void KeMatrixTopKClassificationError(real* topVal, int ldv,
+                                                int * topIds,
+                                                real* src, int lds,
+                                                int dim,
+                                                int beamSize,
+                                                int* label,
+                                                real* recResult) {
+  __shared__ Pair shTopK[blockSize];
+  __shared__ int maxId[blockSize / 2];
+  const int tid = threadIdx.x;
+  const int warp = threadIdx.x / 32;
+  src += blockIdx.x * lds;
+  topVal += blockIdx.x * ldv;
+  topIds += blockIdx.x * beamSize;
+
+  Pair topK[maxLength];  // NOLINT
+  int beam = maxLength;
+  Pair max;
+  bool isEmpty = false;
+  bool firstStep = true;
+  int topkSize = beamSize;
+
+  for (int k = 0; k < maxLength; k++) {
+    topK[k].set(-HL_FLOAT_MAX, -1);
+  }
+
+  while (beamSize) {
+    threadGetTopK<maxLength, blockSize>
+        (topK, beam, beamSize, src, firstStep, isEmpty, max, dim, tid);
+
+    shTopK[tid] = topK[0];
+    blockReduce<maxLength, blockSize>
+        (shTopK, maxId, topK, &topVal, &topIds, beam, beamSize, tid, warp);
+  }
+
+  __syncthreads();
+  if (tid == 0) {
+    for (int i = 0; i < topkSize; i++) {
+      if (*--topIds == label[blockIdx.x]) {
+        recResult[blockIdx.x] = 0;
+        break;
+      }
+      recResult[blockIdx.x] = 1.0f;
+    }
+  }
+}
+
+void hl_matrix_classification_error(real* topVal, int ldv,
+                                    int* topIds,
+                                    real* src, int lds,
+                                    int dim,
+                                    int topkSize,
+                                    int numSamples,
+                                    int* label,
+                                    real* recResult) {
+  CHECK_NOTNULL(topVal);
+  CHECK_NOTNULL(topIds);
+  CHECK_NOTNULL(src);
+
+  if (topkSize > dim) topkSize = dim;
+
+  dim3 threads(256, 1);
+  dim3 grid(numSamples, 1);
+  KeMatrixTopKClassificationError<5, 256>
+      <<< grid, threads, 0, STREAM_DEFAULT >>>
+      (topVal, ldv, topIds, src, lds, dim, topkSize, label, recResult);
+
+  CHECK_SYNC("hl_matrix_top_k classification error failed");
+}
diff --git a/paddle/gserver/evaluators/Evaluator.cpp b/paddle/gserver/evaluators/Evaluator.cpp
index 2bf6ead0dc382..cd8d1e9ecbfd8 100644
--- a/paddle/gserver/evaluators/Evaluator.cpp
+++ b/paddle/gserver/evaluators/Evaluator.cpp
@@ -39,6 +39,14 @@ void Evaluator::eval(const NeuralNetwork& nn) {
  */
 class ClassificationErrorEvaluator : public Evaluator {
 public:
+  /*
+  ClassificationErrorEvaluator() : totalScore2_(0) {}
+
+  virtual void start() {
+    Evaluator::start();
+    totalScore2_ = 0;
+  } */
+
   virtual void updateSamplesNum(const std::vector<Argument>& arguments) {
     if (3 == arguments.size()) {
       numSamples_ += arguments[2].value->getSum();
@@ -76,9 +84,11 @@ class ClassificationErrorEvaluator : public Evaluator {
         1,
         /* trans= */ false,
         useGpu(arguments[0].deviceId));
+
+    errorMat->zeroMem();
+
     if (label != nullptr) {
-      errorMat->classificationError(*output, *label);
+      errorMat->classificationError(*output, *label, config_.top_k());
     } else if (dynamic_cast<CpuSparseMatrix*>(multiBinaryLabel.get()) ||
                dynamic_cast<GpuSparseMatrix*>(multiBinaryLabel.get())) {
       errorMat->classificationErrorMulti(
@@ -94,6 +104,16 @@ class ClassificationErrorEvaluator : public Evaluator {
     return errorMat;
   }
 
+  void printStats(std::ostream& os) const {
+    if (config_.top_k() == 1) {
+      os << config_.name() << "="
+         << (numSamples_ ? totalScore_ / numSamples_ : 0);
+    } else {
+      os << " top_" << config_.top_k()
+         << "_error=" << (numSamples_ ? totalScore_ / numSamples_ : 0);
+    }
+  }
+
   virtual real evalImp(std::vector<Argument>& arguments) {
     MatrixPtr errorMat = calcError(arguments);
     return errorMat->getSum();
diff --git a/paddle/gserver/layers/Layer.h b/paddle/gserver/layers/Layer.h
index 6dfd48fb96618..7c4bea072157a 100644
--- a/paddle/gserver/layers/Layer.h
+++ b/paddle/gserver/layers/Layer.h
@@ -311,6 +311,7 @@ class Layer {
       return *output->second;
     } else {
       LOG(FATAL) << "No specific output " << str;
+      return *((Argument*)nullptr);
     }
   }
 }
diff --git a/paddle/gserver/tests/test_Evaluator.cpp b/paddle/gserver/tests/test_Evaluator.cpp
index 8165eb8269336..d39c042271567 100644
--- a/paddle/gserver/tests/test_Evaluator.cpp
+++ b/paddle/gserver/tests/test_Evaluator.cpp
@@ -129,6 +129,7 @@ void testEvaluatorAll(TestConfig testConf,
 TEST(Evaluator, classification_error) {
   TestConfig config;
   config.evaluatorConfig.set_type("classification_error");
+  config.evaluatorConfig.set_top_k(5);
 
   config.inputDefs.push_back({INPUT_DATA, "output", 50});
   config.inputDefs.push_back({INPUT_LABEL, "label", 50});
diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp
index 1964b2f8bfaeb..07450bfb0ef70 100644
--- a/paddle/math/Matrix.cpp
+++ b/paddle/math/Matrix.cpp
@@ -732,6 +732,7 @@ void GpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) {
   size_t beam = maxVal.getWidth();
   CHECK_EQ(maxIds.getSize(), numSamples * beam);
   CHECK_EQ(maxVal.getHeight(), numSamples);
+  CHECK_EQ(maxVal.getWidth(), beam);
 
   hl_matrix_top_k(maxVal.getData(),
                   maxVal.getStride(),
@@ -792,19 +793,32 @@ void GpuMatrix::maxoutBackward(Matrix& a,
 }
 
 /*calulate the error of classification */
-void GpuMatrix::classificationError(Matrix& output, IVector& label) {
-  auto output_ptr = dynamic_cast<GpuMatrix*>(&output);
-  auto label_ptr = dynamic_cast<GpuIVector*>(&label);
-  CHECK(output_ptr && label_ptr) << "Invalid argument pointer";
-
-  CHECK(height_ == output_ptr->height_ && width_ == 1)
+void GpuMatrix::classificationError(Matrix& output,
+                                    IVector& label,
+                                    size_t topkSize) {
+  auto gpuOutput = dynamic_cast<GpuMatrix*>(&output);
+  auto gpuLabel = dynamic_cast<GpuIVector*>(&label);
+  size_t numSamples = this->getHeight();
+  GpuMatrixPtr gpuTopVal = std::make_shared<GpuMatrix>(numSamples, topkSize);
+  GpuIVectorPtr gpuTopIds = std::make_shared<GpuIVector>(numSamples * topkSize);
+
+  CHECK(gpuOutput && gpuLabel) << "Invalid argument pointer";
+  CHECK(gpuTopVal && gpuTopIds) << "Allocate GPU memory failed";
+  CHECK(gpuLabel->getSize() == numSamples) << "Vector size is not equal";
+  CHECK(numSamples == gpuOutput->getHeight() && this->getWidth() == 1)
       << "Matrix dimensions are not equal";
-  hl_matrix_classification_error((real*)output_ptr->data_,
-                                 (int*)label_ptr->getData(),
-                                 data_,
-                                 height_,
-                                 output_ptr->width_);
+  size_t dim = gpuOutput->getWidth();
+  hl_matrix_classification_error(gpuTopVal->getData(),
+                                 gpuTopVal->getStride(),
+                                 gpuTopIds->getData(),
+                                 gpuOutput->getData(),
+                                 gpuOutput->getStride(),
+                                 dim,
+                                 topkSize,
+                                 numSamples,
+                                 gpuLabel->getData(),
+                                 this->getData());
 }
 
 /* copy -log(output[i * width + label]) to this->data[i] */
@@ -3039,7 +3053,7 @@ void CpuMatrix::rowMax(Matrix& max) { max.maxRows(*this); }
 
-/* get beam size of max ids and values */
+/* Get the top k elements of each row of this matrix */
 void CpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) {
   CHECK(isContiguous());
   CHECK(!maxIds.useGpu() && !maxVal.useGpu()) << "Matrix type are not equal";
@@ -3047,6 +3061,7 @@
   size_t beam = maxVal.getWidth();
   CHECK_EQ(maxIds.getSize(), numSamples * beam);
   CHECK_EQ(maxVal.getHeight(), numSamples);
+  CHECK_EQ(maxVal.getWidth(), beam);
 
   real* a = getData();
   int* s = maxIds.getData();
@@ -3198,32 +3213,39 @@ void CpuMatrix::rowNormalizeL1(Matrix& out) {
 }
 
 /* calulate classification error */
-void CpuMatrix::classificationError(Matrix& output, IVector& label) {
-  CHECK(dynamic_cast<CpuMatrix*>(&output));
-  CHECK(dynamic_cast<CpuIVector*>(&label));
+void CpuMatrix::classificationError(Matrix& output,
+                                    IVector& label,
+                                    size_t topkSize) {
+  size_t numSamples = this->getHeight();
+  auto cpuOutput = dynamic_cast<CpuMatrix*>(&output);
+  auto cpuLabel = dynamic_cast<CpuIVector*>(&label);
+  IVectorPtr cpuTopIds = std::make_shared<CpuIVector>(numSamples * topkSize);
+  MatrixPtr cpuTopVal = std::make_shared<CpuMatrix>(numSamples, topkSize);
+
+  CHECK(cpuOutput && cpuLabel) << "Invalid argument pointer";
+  CHECK(cpuTopIds && cpuTopVal) << "Allocate cpu memory failed";
+  CHECK(cpuLabel->getSize() == numSamples) << "Vector size is not equal";
+  CHECK(cpuOutput->getHeight() == numSamples && this->getWidth() == 1)
+      << "Matrix dimensions are not equal";
 
-  CHECK_EQ(getWidth(), (size_t)1);
-  size_t numSamples = getHeight();
-  CHECK_EQ(label.getSize(), numSamples);
-  CHECK_EQ(output.getHeight(), numSamples);
+  // top k matrix classification
+  cpuOutput->rowMax(*cpuTopIds, *cpuTopVal);
 
-  size_t dim = output.getWidth();
-  real* out = output.getData();
-  int* lbl = label.getData();
-  real maxData = 0.0;
-  int maxIndex = -1;
+  size_t dim = cpuOutput->getWidth();
+  real* result = this->getData();
+  int* ids = cpuTopIds->getData();
+  int* lbl = cpuLabel->getData();
   for (size_t i = 0; i < numSamples; ++i) {
     CHECK_GE(lbl[i], 0);
     CHECK_LT((size_t)lbl[i], dim);
-    maxData = out[i * dim];
-    maxIndex = 0;
-    for (size_t j = 0; j < dim; ++j) {
-      if (maxData < out[i * dim + j]) {
-        maxIndex = j;
-        maxData = out[i * dim + j];
+
+    for (size_t j = 0; j < topkSize; ++j) {
+      if (ids[j + i * topkSize] == lbl[i]) {
+        result[i] = 0;
+        break;
       }
+      result[i] = 1.0f;
     }
-    getData()[i] = (maxIndex != lbl[i]);
   }
 }
diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h
index ea4bbb86b057b..d0ba2e93feabf 100644
--- a/paddle/math/Matrix.h
+++ b/paddle/math/Matrix.h
@@ -836,8 +836,11 @@ class Matrix : public BaseMatrix {
   * output[i] = 1 if row i is an error.
   *
   * output[i] = 0 if row i is correct.
+  *
   */
-  virtual void classificationError(Matrix& output, IVector& label) {
+  virtual void classificationError(Matrix& output,
+                                   IVector& label,
+                                   size_t topkSize = 1) {
     LOG(FATAL) << "Not implemented";
   }
 
@@ -1314,7 +1317,7 @@ class GpuMatrix : public Matrix {
   void check(std::ostream& os, Matrix& refMat, bool printDiff = true);
   void randomizeUniform();
 
-  void classificationError(Matrix& output, IVector& label);
+  void classificationError(Matrix& output, IVector& label, size_t topkSize = 1);
 
   void convExpand(Matrix& feature,
                   int feaImgHeight,
@@ -1739,7 +1742,7 @@ class CpuMatrix : public Matrix {
 
   void randomizeUniform();
 
-  void classificationError(Matrix& output, IVector& label);
+  void classificationError(Matrix& output, IVector& label, size_t topkSize = 1);
 
   void addByBitCode(size_t numClasses, const IVector& codes, const Matrix& vec);
diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp
index 6caaea443c1df..08b64c1bb6f5d 100644
--- a/paddle/math/tests/test_matrixCompare.cpp
+++ b/paddle/math/tests/test_matrixCompare.cpp
@@ -764,7 +764,7 @@ TEST(Matrix, paramReluBackwardDiff) {
   }
 }
 
-void testClassificationError(int numSamples, int dim) {
+void testClassificationError(int numSamples, int dim, int topkSize) {
   MatrixPtr cpuError = std::make_shared<CpuMatrix>(numSamples, 1);
   MatrixPtr gpuError = std::make_shared<GpuMatrix>(numSamples, 1);
   MatrixPtr cpuOutput = std::make_shared<CpuMatrix>(numSamples, dim);
@@ -777,17 +777,22 @@ void testClassificationError(int numSamples, int dim) {
   gpuOutput->copyFrom(*cpuOutput);
   gpuLabel->copyFrom(*cpuLabel);
 
-  cpuError->classificationError(*cpuOutput, *cpuLabel);
-  gpuError->classificationError(*gpuOutput, *gpuLabel);
+  cpuError->classificationError(*cpuOutput, *cpuLabel, topkSize);
+  gpuError->classificationError(*gpuOutput, *gpuLabel, topkSize);
 
   TensorCheckEqual(*cpuError, *gpuError);
 }
 
 TEST(Matrix, classificationError) {
-  for (auto numSamples : {1, 10, 100, 1000, 70000}) {
-    for (auto dim : {1, 10, 100, 1000}) {
-      VLOG(3) << " numSamples=" << numSamples << " dim=" << dim;
-      testClassificationError(numSamples, dim);
+  for (auto numSamples : {1, 5, 31, 90, 150, 300}) {
+    for (auto dim :
+         {1, 5, 8, 10, 15, 64, 80, 120, 256, 300, 1280, 5120, 50000}) {
+      for (auto topkSize : {1, 5, 10, 20, 40, (int)rand() % dim + 1}) {
+        if (topkSize > dim) continue;
+        VLOG(3) << " sample= " << numSamples << " topkSize= " << topkSize
+                << " dim= " << dim;
+        testClassificationError(numSamples, dim, topkSize);
+      }
     }
   }
 }
diff --git a/paddle/parameter/Parameter.cpp b/paddle/parameter/Parameter.cpp
index 29d6e20dc1696..1ccded8187967 100644
--- a/paddle/parameter/Parameter.cpp
+++ b/paddle/parameter/Parameter.cpp
@@ -375,10 +375,6 @@ bool Parameter::load(const std::string& filename) {
   std::ifstream fs(filename, std::ios_base::binary);
   if (!fs) {
     LOG(INFO) << "missing parameters [" << filename << "] while loading model.";
-    if (isStatic()) {
-      LOG(FATAL) << getName() << " is static but missing, not allowed.";
-      return false;
-    }
     if (kMissParameterFail == FLAGS_load_missing_parameter_strategy) {
       LOG(FATAL) << getName() << " missing, not allowed.";
       return false;
diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto
index be4634d5103c0..65d5d50277b66 100644
--- a/proto/ModelConfig.proto
+++ b/proto/ModelConfig.proto
@@ -475,6 +475,10 @@ message EvaluatorConfig {
   // Used by ChunkEvaluator
   // chunk of these types are not counted
   repeated int32 excluded_chunk_types = 12;
+
+  // Used by ClassificationErrorEvaluator
+  // top-k classification error
+  optional int32 top_k = 13 [default = 1];
 }
 
 message LinkConfig {
diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index d403a6029a3e9..da937152ee0ce 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -1253,6 +1253,7 @@ def Evaluator(
         dict_file=None,
         result_file=None,
         num_results=None,
+        top_k=None,
         delimited=None,
         excluded_chunk_types=None, ):
     evaluator = g_config.model_config.evaluators.add()
@@ -1280,6 +1281,8 @@ def Evaluator(
         evaluator.result_file = result_file
     if num_results is not None:
         evaluator.num_results = num_results
+    if top_k is not None:
+        evaluator.top_k = top_k
     if delimited is not None:
         evaluator.delimited = delimited
diff --git a/python/paddle/trainer_config_helpers/evaluators.py b/python/paddle/trainer_config_helpers/evaluators.py
index bd247ea9af9d8..567521ee9dbad 100644
--- a/python/paddle/trainer_config_helpers/evaluators.py
+++ b/python/paddle/trainer_config_helpers/evaluators.py
@@ -71,6 +71,7 @@ def evaluator_base(
         result_file=None,
         num_results=None,
         delimited=None,
+        top_k=None,
        excluded_chunk_types=None, ):
     """
     Evaluator will evaluate the network status while training/testing.
@@ -104,12 +105,15 @@ def evaluator_base(
     :param weight: An input layer which is a weight for each sample.
                    Each evaluator may calculate differently to use this weight.
     :type weight: LayerOutput.
+    :param top_k: number k in top-k error rate
+    :type top_k: int
     """
     # inputs type assertions.
     assert classification_threshold is None or isinstance(
         classification_threshold, float)
     assert positive_label is None or isinstance(positive_label, int)
     assert num_results is None or isinstance(num_results, int)
+    assert top_k is None or isinstance(top_k, int)
 
     if not isinstance(input, list):
         input = [input]
@@ -130,6 +134,8 @@ def evaluator_base(
         dict_file=dict_file,
         result_file=result_file,
         delimited=delimited,
+        num_results=num_results,
+        top_k=top_k,
         excluded_chunk_types=excluded_chunk_types, )
 
 
@@ -139,6 +145,7 @@ def classification_error_evaluator(input,
                                    label,
                                    name=None,
                                    weight=None,
+                                   top_k=None,
                                    threshold=None):
     """
     Classification Error Evaluator. It will print error rate for classification.
@@ -167,6 +174,8 @@ def classification_error_evaluator(input,
                    then means not set weight. The larger weight it is, the more
                    important this sample is.
     :type weight: LayerOutput
+    :param top_k: number k in top-k error rate
+    :type top_k: int
     :param threshold: The classification threshold.
     :type threshold: float
     :return: None.
@@ -178,6 +187,7 @@ def classification_error_evaluator(input,
         input=input,
         label=label,
         weight=weight,
+        top_k=top_k,
         classification_threshold=threshold, )
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index 0d3a31b9d39a2..00aef80691fba 100755
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -2870,8 +2870,8 @@ def gru_step_layer(input,
     :param name:
     :param gate_act:
     :param bias_attr:
-    :param param_attr: the parameter_attribute for transforming the output_mem
-        from previous step.
+    :param param_attr: the parameter_attribute for transforming the output_mem
+                       from previous step.
     :param layer_attr:
     :return: LayerOutput object.
     :rtype: LayerOutput
@@ -2882,10 +2882,10 @@ def gru_step_layer(input,
     Layer(
         name=name,
         type=LayerType.GRU_STEP_LAYER,
-        # The parameter here is for transforming the output_mem. The input has
-        # already been transformed outside this module so it does not need
-        # parameter associated with it.
-        # The parameter here is instead grouped with input is due to
+        # The parameter here is for transforming the output_mem. The input has
+        # already been transformed outside this module so it does not need
+        # parameter associated with it.
+        # The parameter here is instead grouped with input is due to
         # backward model compatibility.
         inputs=[Input(input.name, **param_attr.attr), output_mem.name],
         bias=ParamAttr.to_bias(bias_attr),
@@ -3536,6 +3536,7 @@ def classification_cost(input,
                         label,
                         weight=None,
                         name=None,
+                        top_k=None,
                         evaluator=classification_error_evaluator,
                         layer_attr=None):
     """
@@ -3550,6 +3551,8 @@ def classification_cost(input,
     :param weight: The weight affects the cost, namely the scale of cost. It is
                    an optional argument.
     :type weight: LayerOutput
+    :param top_k: number k in top-k error rate
+    :type top_k: int
     :param evaluator: Evaluator method.
     :param layer_attr: layer's extra attribute.
     :type layer_attr: ExtraLayerAttribute
@@ -3577,7 +3580,7 @@ def __add_evaluator__(e):
         assert isinstance(e.for_classification, bool)
         assert e.for_classification
 
-        e(name=e.__name__, input=input, label=label, weight=weight)
+        e(name=e.__name__, input=input, label=label, weight=weight, top_k=top_k)
 
     if not isinstance(evaluator, collections.Sequence):
         evaluator = [evaluator]
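
Usage note (not part of the patch): a minimal, illustrative trainer-config sketch of the new top_k argument. The layer names and sizes below are made up; only classification_cost, classification_error_evaluator, and their top_k parameter come from this change.

# Hypothetical config sketch: assumes a 784-dimensional input and 10 classes.
from paddle.trainer_config_helpers import *

img = data_layer(name='image', size=784)
lbl = data_layer(name='label', size=10)
prob = fc_layer(input=img, size=10, act=SoftmaxActivation())

# With top_k=5 the evaluator reports "top_5_error=..." (see the new
# ClassificationErrorEvaluator::printStats above) instead of the default
# top-1 error, i.e. the fraction of samples whose true label is not among
# the five highest-scoring classes.
cost = classification_cost(input=prob, label=lbl, top_k=5)
outputs(cost)

# The evaluator can also be attached on its own with the same argument.
classification_error_evaluator(input=prob, label=lbl, name='top5_error', top_k=5)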