
Add top-k error #1337

Merged — 8 commits merged into PaddlePaddle:develop on Feb 22, 2017
Conversation

lzhao4ever (Contributor):

No description provided.

@@ -375,10 +375,10 @@ bool Parameter::load(const std::string& filename) {
   std::ifstream fs(filename, std::ios_base::binary);
   if (!fs) {
     LOG(INFO) << "missing parameters [" << filename << "] while loading model.";
-    if (isStatic()) {
+    /*if (isStatic()) {
Collaborator:

Why are these lines commented out? Either uncomment them or remove them.

Contributor (author):

It causes an error when I try to set some batch norm layers to static while their parameters are loaded from a pre-trained model. I am trying to find out who added these lines and why.


int* ids = nullptr;
int* lbl = nullptr;
IVectorPtr dest = IVector::create(maxIds_->getSize(), false);
Collaborator:

IVectorPtr dest = nullptr;
IVectorPtr dest2 = nullptr;

int* lbl = nullptr;
IVectorPtr dest = IVector::create(maxIds_->getSize(), false);
IVectorPtr dest2 = IVector::create(label->getSize(), false);
if (useGpu(arguments[0].deviceId)) {
Collaborator:

dest = IVector::create(label->getSize(), false);
dest2 = IVector::create(label->getSize(), false);
dest->copyFrom(*maxIds_);
dest2->copyFrom(*label);

Contributor (author):

I will rewrite these lines.

ids = maxIds_->getData();
lbl = label->getData();
}

Collaborator:

errorMat2 is not necessary, because we only need to return a counter (totalScore_).

Contributor (author):

True. I will remove errorMat2.

}

real* result2 = errorMat2->getData();
for (size_t i = 0; i < height; ++i) {
Collaborator:
bool contain = false;
for (size_t j = 0; j < width && !contain; ++j) {
  error = (ids[i * width + j] == lbl[i]);  // label is in prediction ids
}
if (!contain) {
  totalScore2_ += 1.0;
}

Contributor (author):
Variable "error" should be "contain".
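With that correction applied (and the counter suggestion above), the loop could look like the following standalone sketch. The wrapper name `countTopKErrors` is illustrative, and a local counter stands in for the `totalScore2_` member of the real evaluator:

```cpp
#include <cstddef>

// Corrected loop: `contain` (not `error`) records whether the ground-truth
// label lbl[i] appears among the top-`width` predicted ids of sample i.
// A local counter stands in for the totalScore2_ member of the real class.
double countTopKErrors(const int* ids, const int* lbl,
                       std::size_t height, std::size_t width) {
  double totalScore2 = 0.0;
  for (std::size_t i = 0; i < height; ++i) {
    bool contain = false;
    for (std::size_t j = 0; j < width && !contain; ++j) {
      contain = (ids[i * width + j] == lbl[i]);  // label is in prediction ids
    }
    if (!contain) {
      totalScore2 += 1.0;  // count one top-k miss for this sample
    }
  }
  return totalScore2;
}
```

For ids = {3, 1, 2, 0} (two samples, width 2) and labels {1, 5}, sample 0 contains its label and sample 1 does not, so the count is 1.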

if (result2[i] == 0.0) {
break;
}
result2[i] = (ids[i * width + j] != lbl[i]); // top-k error
Collaborator:

The logic of this code is a bit messy.

Contributor (author):

I will rewrite this part.

if (label != nullptr) {
errorMat->classificationError(*output, *label);
errorMat->classificationError(*output, *label); // top-1 error
@gangliao (Contributor), Feb 21, 2017:

There is a small issue here: there is no need to wait for classificationError to finish, copy the result back, and then compute top-k afterwards. A more direct approach is to do the top-k classification error directly on the CPU and GPU.

Concretely: add a default parameter to classificationError in Matrix.h:

virtual void classificationError(MatrixPtr output, IVectorPtr label,
                                 size_t topkSize = 1);
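As a sketch of why the defaulted parameter is convenient, here is a standalone stand-in (not the real Matrix.h API; the name `classificationErrorDemo` and the id layout are assumptions): old call sites that omit `topkSize` still compute top-1, while top-k callers pass it explicitly.

```cpp
#include <cstddef>

// Stand-in for the proposed signature: `topIds` holds, row-major, the
// `topkSize` highest-scoring class ids for each sample, sorted by descending
// score. Returns the mean error rate. All names are illustrative.
double classificationErrorDemo(const int* topIds, const int* labels,
                               std::size_t numSamples,
                               std::size_t topkSize = 1) {
  double errors = 0.0;
  for (std::size_t i = 0; i < numSamples; ++i) {
    bool hit = false;
    for (std::size_t j = 0; j < topkSize && !hit; ++j) {
      hit = (topIds[i * topkSize + j] == labels[i]);
    }
    if (!hit) errors += 1.0;
  }
  return errors / numSamples;
}
```

`classificationErrorDemo(ids, lbl, n)` keeps the old top-1 behavior; `classificationErrorDemo(ids, lbl, n, 5)` yields the top-5 error.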

@gangliao (Contributor), Feb 21, 2017:

Matrix.cc, classificationError on GPU:

/* calculate the classification error */
void GpuMatrix::classificationError(MatrixPtr output,
                                    IVectorPtr label,
                                    size_t topkSize) {
  size_t numSamples = this->getHeight();
  GpuMatrixPtr gpuOutput = std::dynamic_pointer_cast<GpuMatrix>(output);
  GpuIVectorPtr gpuLabel = std::dynamic_pointer_cast<GpuIVector>(label);
  GpuMatrixPtr gpuTopVal = std::make_shared<GpuMatrix>(numSamples, topkSize);
  GpuIVectorPtr gpuTopIds = std::make_shared<GpuIVector>(numSamples * topkSize);

  CHECK(gpuOutput && gpuLabel) << "Invalid argument pointer";
  CHECK(gpuTopVal && gpuTopIds) << "Allocate GPU memory failed";
  CHECK(gpuLabel->getSize() == numSamples) << "Vector size is not equal";
  CHECK(numSamples == gpuOutput->getHeight() && this->getWidth() == 1)
    << "Matrix dimensions are not equal";

  size_t dim = gpuOutput->getWidth();
  int ret = hl_matrix_classification_error(gpuTopVal->getData(),
                                           gpuTopVal->getStride(),
                                           gpuTopIds->getData(),
                                           gpuOutput->getData(),
                                           gpuOutput->getStride(),
                                           dim,
                                           topkSize,
                                           numSamples,
                                           gpuLabel->getData(),
                                           this->getData());

  CHECK_EQ(ret, HPPL_RET_OK) << "Error matrix classification ";
}

The concrete GPU implementation:

/**
 * Each block computes one sample.
 * In a block:
 * 1. every thread gets its top maxLength values;
 * 2. merge them into shTopK, block-reduce and get the max value;
 * 3. go back to the second step, until one thread's topK values are exhausted;
 * 4. go back to the first step, until the full topK values are obtained.
 */
template<int maxLength, int blockSize>
__global__ void KeMatrixTopKClassificationError(Elem_t* topVal, int ldv,
                                                int * topIds,
                                                Elem_t* src, int lds,
                                                int dim,
                                                int beamSize,
                                                int* label,
                                                Elem_t* recResult) {
  __shared__ Pair shTopK[blockSize];
  __shared__ int maxId[blockSize / 2];
  const int tid = threadIdx.x;
  const int warp = threadIdx.x / 32;
  src += blockIdx.x * lds;
  topVal += blockIdx.x * ldv;
  topIds += blockIdx.x * beamSize;

  Pair topK[maxLength]; // NOLINT
  int beam = maxLength;
  Pair max;
  bool isEmpty = false;
  bool firstStep = true;
  int topkSize = beamSize;

  for (int k = 0; k < maxLength; k++) {
    topK[k].set(-HL_FLOAT_MAX, -1);
  }

  while (beamSize) {
    threadGetTopK<maxLength, blockSize>
      (topK, beam, beamSize, src, firstStep, isEmpty, max, dim, tid);

    shTopK[tid] = topK[0];
    blockReduce<maxLength, blockSize>
      (shTopK, maxId, topK, &topVal, &topIds, beam, beamSize, tid, warp);
  }

  __syncthreads();
  if (tid == 0) {
    for (int i = 0; i < topkSize; i++) {
        if (*--topIds == label[blockIdx.x]) {
            recResult[blockIdx.x] = 0;
            break;
        }
        recResult[blockIdx.x] = 1.0f;
    }
  }
}

int hl_matrix_classification_error(Elem_t* topVal, int ldv,
                                   int* topIds,
                                   Elem_t* src, int lds,
                                   int dim,
                                   int topkSize,
                                   int numSamples,
                                   int* label,
                                   Elem_t* recResult) {
  if (NULL == topVal || NULL == topIds || NULL == src) {
    return HPPL_RET_ERROR;
  }

  if (topkSize > dim) topkSize = dim;

  dim3 threads(256, 1);
  dim3 grid(numSamples, 1);
  KeMatrixTopKClassificationError<5, 256>
  <<< grid, threads, 0, STREAM_DEFAULT >>>
  (topVal, ldv, topIds, src, lds, dim, topkSize, label, recResult);

  if (sync_error("hl_matrix_top_k classification error failed")) {
    return HPPL_RET_ERROR;
  }

  return HPPL_RET_OK;
}

@gangliao (Contributor), Feb 21, 2017:

Matrix.cc, classificationError on CPU:

/* calculate classification error */
void CpuMatrix::classificationError(MatrixPtr output,
                                    IVectorPtr label,
                                    size_t topkSize) {
  size_t numSamples = this->getHeight();
  MatrixPtr cpuOutput = std::dynamic_pointer_cast<CpuMatrix>(output);
  IVectorPtr cpuLabel = std::dynamic_pointer_cast<CpuIVector>(label);
  IVectorPtr cpuTopIds = std::make_shared<CpuIVector>(numSamples * topkSize);
  MatrixPtr cpuTopVal = std::make_shared<CpuMatrix>(numSamples, topkSize);

  CHECK(cpuOutput && cpuLabel) << "Invalid argument pointer";
  CHECK(cpuTopIds && cpuTopVal) << "Allocate cpu memory failed";
  CHECK(cpuLabel->getSize() == numSamples) << "Vector size is not equal";
  CHECK(cpuOutput->getHeight() == numSamples && this->getWidth() == 1)
    << "Matrix dimensions are not equal";

  // top k matrix classification
  cpuOutput->rowMax(*cpuTopIds, *cpuTopVal);

  size_t dim = cpuOutput->getWidth();
  real* result = this->getData();
  int* ids = cpuTopIds->getData();
  int* lbl = cpuLabel->getData();
  for (size_t i = 0; i < numSamples; ++i) {
    CHECK_GE(lbl[i], 0);
    CHECK_LT((size_t)lbl[i], dim);

    for (size_t j = 0; j < topkSize; ++j) {
      if (ids[j + i * topkSize] == lbl[i]) {
        result[i] = 0;
        break;
      }
      result[i] = 1.0f;
    }
  }
}
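One subtlety in the loop above: `result[i]` is set to 1.0f on every miss and overwritten with 0 (followed by a break) on a hit, which amounts to a membership test. A standalone sketch, with illustrative names, checks that equivalence against `std::find`:

```cpp
#include <algorithm>
#include <cstddef>

// The inner loop of the CPU version above, lifted out for one sample:
// writes 1.0f on every miss and 0 (then breaks) on a hit.
float topkErrorOneSample(const int* ids, int label, std::size_t topkSize) {
  float result = 0.0f;
  for (std::size_t j = 0; j < topkSize; ++j) {
    if (ids[j] == label) {
      result = 0;
      break;
    }
    result = 1.0f;
  }
  return result;
}

// Equivalent plain membership test, used as a reference.
float topkErrorReference(const int* ids, int label, std::size_t topkSize) {
  return std::find(ids, ids + topkSize, label) == ids + topkSize ? 1.0f : 0.0f;
}
```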

Contributor (author):

I want to compute both the top-1 error and the top-k error. Your implementation only gives the top-k error.

Contributor:

void classificationError(MatrixPtr output, IVectorPtr label,
                         size_t topkSize = 1)

The default is top-1.

Contributor:

gangliao's approach lets you attach two Evaluators at the same time, one for top-1 and the other for top-k.
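Since the top-1 id is simply the first entry of a score-sorted top-k list, the two evaluators can in principle read both answers off one rowMax result. A standalone sketch under that assumption (all names illustrative):

```cpp
#include <cstddef>

struct Errors {
  double top1;
  double topk;
};

// Computes top-1 and top-k error counts from one row-major list of the k
// highest-scoring ids per sample (sorted by descending score).
Errors bothErrors(const int* topkIds, const int* lbl,
                  std::size_t numSamples, std::size_t k) {
  Errors e{0.0, 0.0};
  for (std::size_t i = 0; i < numSamples; ++i) {
    const int* row = topkIds + i * k;
    if (row[0] != lbl[i]) e.top1 += 1.0;  // best single guess is wrong
    bool hit = false;
    for (std::size_t j = 0; j < k && !hit; ++j) {
      hit = (row[j] == lbl[i]);
    }
    if (!hit) e.topk += 1.0;  // label missing from all k guesses
  }
  return e;
}
```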

Contributor (author):

Understood. Thanks.

Contributor (author):

I will revise as suggested.

@gangliao (Contributor) left a review comment:

A few minor issues.

@gangliao (Contributor) left a comment:

LGTM++

@lzhao4ever (Contributor, author):

Thanks for providing the detailed code and explanation :)

@wangkuiyi wangkuiyi merged commit f2c7c9b into PaddlePaddle:develop Feb 22, 2017
zhhsplendid pushed a commit to zhhsplendid/Paddle that referenced this pull request on Sep 25, 2019.