Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accuracy op #3907

Merged
merged 23 commits into from
Sep 14, 2017
Merged

Accuracy op #3907

merged 23 commits into from
Sep 14, 2017

Conversation

typhoonzero
Copy link
Contributor

Fix #3840

Copy link
Contributor

@dzhwinter dzhwinter left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it's better to split into small PRs. : )

auto *label = ctx.Input<framework::Tensor>("Label");

// label must be a vector
PADDLE_ENFORCE_EQ(label->dims().size(), 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please put the comment into the value assertation.

 PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label must be a vector")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


// label must be a vector
PADDLE_ENFORCE_EQ(label->dims().size(), 1);
PADDLE_ENFORCE_EQ(inference->dims()[0], label->dims()[0]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One assertation should contain some info if it cores dump.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

auto* label = ctx.Input<Tensor>("Label");
auto* accuracy = ctx.Output<Tensor>("Accuracy");
const size_t topk = 1;
// static_cast<AttrType>(ctx.op_.GetAttr<AttrType>("topk"));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove the comment if it is unused code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

}

// FIXME(typhoonzero): we don't accumulate the accuracy for now.
*accuracy_data = static_cast<T>(num_correct) / static_cast<T>(num_samples);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why here need to cast to the type of T? I think accuracy_data always is float or double according to the precision.
Suppose that we are serving online, then the T will be fp16 for acceleration, the accuracy_data will get the wrong type.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

auto* inference = ctx.Input<Tensor>("Inference");
auto* label = ctx.Input<Tensor>("Label");
auto* accuracy = ctx.Output<Tensor>("Accuracy");
const size_t topk = 1;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems topk is unused in Accuracy operator.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed.

// input must have >= 1d shape.
PADDLE_ENFORCE_GE(input->dims().size(), 1);
// input must have >= k columns.
PADDLE_ENFORCE_GE(input->dims()[input->dims().size() - 1], k);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same with above. The user can read the enforce information without trace into the source code.

}

template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* src, int idx,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we really need to put __forceinline__ explicitly? According to its document,
forceinline will be added by the nvcc compiler. To the best of my knowledge, function contains while, ifelse control flow loop, it can never be an inline function. Neither in c++ or cuda code.


// reshape input to a flattern matrix(like flat_inner_dims)
framework::DDim inputdims = input->dims();
const size_t row = framework::product(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There will be a flatten_to_2d interface in ddim, maybe should put a TODO here to replace in the future.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also maybe we need different ways to flatten: flatten_to_2d_inner and flatten_to_2d_outter

X.reshape(flat2dims);

for (size_t i = 0; i < row; i++) {
// TODO(typhoonzero): make this more efficient
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we have a better choice since the partial_sort is heap sort.
Topk is such a classic question, so I don't think the efficiency is a problem.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm think try not to copy the memory, since the input is const, we cannot do it currently.

AddOutput("Accuracy", "The accuracy of current batch");

AddComment(
R"DOC(Accuracy. It will print accuracy rate for classification.");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bad comments here.

R"DOC(...)DOC" is just like python """...""", which R"DOC( is the left """ and )DOC" is the right """.

Copy link
Contributor

@lcy-seso lcy-seso Sep 7, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I had ever been confused about what does R"abc(....)abc" mean (abc can be anything, it is just a delimiter ).
I found these two docs solve my questions.

Just to share.

@typhoonzero typhoonzero changed the title Accuracy op [WIP] Accuracy op Sep 7, 2017
@typhoonzero typhoonzero changed the title [WIP] Accuracy op Accuracy op Sep 13, 2017
auto* accuracy = ctx.Output<Tensor>("Accuracy");
const int* inference_data = inference->data<int>();
const int* label_data = label->data<int>();
T* accuracy_data = accuracy->mutable_data<T>(ctx.GetPlace());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the accuracy_data should also be changed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, what do you mean by "change", accuracy_data is passed to cuda kernel and assign the result in the cuda kernel.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just the same CPU code below.

float* accuracy_data = accuracy->mutable_data<float>(ctx.GetPlace());

Here the accuracy_data type is T, which may lose precision, maybe its better to write as mutable_data<float>. And I noticed that the AccuracyDivideKernel received a parameter of float.


// FIXME(typhoonzero): we don't accumulate the accuracy for now.
*accuracy_data =
static_cast<float>(num_correct) / static_cast<float>(num_samples);
Copy link
Contributor

@dzhwinter dzhwinter Sep 13, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need an ENFORCE check num_samples is not equal to zero? when user misuse this operator, the num_samples may be zero. I'm not sure it's useful.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, this return 0 if num_sample==0

auto* accuracy = ctx.Output<Tensor>("Accuracy");
const int* inference_data = inference->data<int>();
const int* label_data = label->data<int>();
T* accuracy_data = accuracy->mutable_data<T>(ctx.GetPlace());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just the same CPU code below.

float* accuracy_data = accuracy->mutable_data<float>(ctx.GetPlace());

Here the accuracy_data type is T, which may lose precision, maybe its better to write as mutable_data<float>. And I noticed that the AccuracyDivideKernel received a parameter of float.

size_t num_samples = inference->dims()[0];
size_t infer_width = inference->dims()[1];

AccuracyDivideKernel<<<1, 1>>>(num_samples, infer_width, 1, inference_data,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not very familiar to cuda. It seems launched kernel in <<<1,1>>> can not utilize the capacity of cuda. This operator writes the similar one crossEntropy, but I don't know how the block=512 comes from.
@qingqing01

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, I'm still trying to let the kernel use more threads, I'll enhance the kernel in next PR. I've got some problem writing kernel with atocimAdd.

Copy link
Contributor

@dzhwinter dzhwinter left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM++

@typhoonzero typhoonzero merged commit 2d62336 into PaddlePaddle:develop Sep 14, 2017
heavengate pushed a commit to heavengate/Paddle that referenced this pull request Aug 16, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants