Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-1416] Fix inception inference example for potential index out …
Browse files Browse the repository at this point in the history
…of range error. (#15179)

* Adding support to get the data from 1D NDArray.

* Added the error handling for index out of range.
  • Loading branch information
leleamol authored and lanking520 committed Jun 7, 2019
1 parent bcff498 commit 745a41c
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 4 deletions.
4 changes: 1 addition & 3 deletions cpp-package/example/inference/inception_inference.cpp
Expand Up @@ -302,13 +302,11 @@ void Predictor::PredictImage(const std::string& image_file) {

// The output is available in executor->outputs.
auto array = executor->outputs[0].Copy(Context::cpu());

/*
* Find out the maximum accuracy and the index associated with that accuracy.
* This is done by using the argmax operator on NDArray.
*/
auto predicted = array.ArgmaxChannel();

/*
* Wait until all the previous write operations on the 'predicted'
* NDArray to be complete before we read it.
Expand All @@ -317,7 +315,7 @@ void Predictor::PredictImage(const std::string& image_file) {
*/
predicted.WaitToRead();

int best_idx = predicted.At(0, 0);
int best_idx = predicted.At(0);
float best_accuracy = array.At(0, best_idx);

if (output_labels.empty()) {
Expand Down
6 changes: 6 additions & 0 deletions cpp-package/include/mxnet-cpp/ndarray.h
Expand Up @@ -321,6 +321,12 @@ class NDArray {
*/
size_t Offset(size_t c, size_t h, size_t w) const;
/*!
* \brief return value of the element at (index)
* \param index position
* \return value of one dimensions array
*/
mx_float At(size_t index) const;
/*!
* \brief return value of the element at (h, w)
* \param h height position
* \param w width position
Expand Down
13 changes: 12 additions & 1 deletion cpp-package/include/mxnet-cpp/ndarray.hpp
Expand Up @@ -375,11 +375,15 @@ inline void NDArray::Save(const std::string &file_name,
}

inline size_t NDArray::Offset(size_t h, size_t w) const {
return (h * GetShape()[1]) + w;
auto const shape = GetShape();
CHECK_EQ(shape.size(), 2) << "The NDArray needs to be 2 dimensional.";

return (h * shape[1]) + w;
}

inline size_t NDArray::Offset(size_t c, size_t h, size_t w) const {
auto const shape = GetShape();
CHECK_EQ(shape.size(), 3) << "The NDArray needs to be 3 dimensional.";
return h * shape[0] * shape[2] + w * shape[0] + c;
}

Expand All @@ -391,6 +395,13 @@ inline mx_float NDArray::At(size_t c, size_t h, size_t w) const {
return GetData()[Offset(c, h, w)];
}

inline mx_float NDArray::At(size_t index) const {
auto shape = GetShape();
CHECK_EQ(shape.size(), 1) << "The NDArray needs to be 1 dimensional.";
CHECK_LT(index, shape[0]) << "Specified index is out of range.";
return GetData()[index];
}

inline size_t NDArray::Size() const {
size_t ret = 1;
for (auto &i : GetShape()) ret *= i;
Expand Down

0 comments on commit 745a41c

Please sign in to comment.