Skip to content

Commit

Permalink
IVGCVSW-5559 Add int8_t to tflite delegate on ExecuteNetwork
Browse files Browse the repository at this point in the history
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Signed-off-by: Kevin May <kevin.may@arm.com>
Change-Id: I56afc73d48848bc40842692831c05316484757a4
  • Loading branch information
FinnWilliamsArm authored and FrancisMurtagh-arm committed Nov 20, 2020
1 parent 66da751 commit 4f55a25
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 74 deletions.
98 changes: 55 additions & 43 deletions tests/ExecuteNetwork/ExecuteNetwork.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,57 +88,50 @@ int TfLiteDelegateMainImpl(const ExecuteNetworkParams& params,
if (params.m_InputTypes[inputIndex].compare("float") == 0)
{
auto inputData = tfLiteInterpreter->typed_tensor<float>(input);
TContainer tensorData;
PopulateTensorWithData(tensorData,
params.m_InputTensorShapes[inputIndex]->GetNumElements(),
params.m_InputTypes[inputIndex],
armnn::EmptyOptional(),
dataFile);

mapbox::util::apply_visitor([&](auto&& value)
{
for (unsigned int i = 0; i < inputSize; ++i)
{
inputData[i] = value.data()[i];
}
},
tensorData);
std::vector<float> tensorData;
PopulateTensorWithDataGeneric<float>(tensorData,
params.m_InputTensorShapes[inputIndex]->GetNumElements(),
dataFile,
[](const std::string& s)
{ return std::stof(s); });

std::copy(tensorData.begin(), tensorData.end(), inputData);
}
else if (params.m_InputTypes[inputIndex].compare("int8") == 0)
{
auto inputData = tfLiteInterpreter->typed_tensor<int8_t>(input);
std::vector<int8_t> tensorData;
PopulateTensorWithDataGeneric<int8_t>(tensorData,
params.m_InputTensorShapes[inputIndex]->GetNumElements(),
dataFile,
[](const std::string& s)
{ return armnn::numeric_cast<int8_t>(std::stoi(s)); });

std::copy(tensorData.begin(), tensorData.end(), inputData);
}
else if (params.m_InputTypes[inputIndex].compare("int") == 0)
{
auto inputData = tfLiteInterpreter->typed_tensor<int32_t>(input);
TContainer tensorData;
PopulateTensorWithData(tensorData,
params.m_InputTensorShapes[inputIndex]->GetNumElements(),
params.m_InputTypes[inputIndex],
armnn::EmptyOptional(),
dataFile);
mapbox::util::apply_visitor([&](auto&& value)
{
for (unsigned int i = 0; i < inputSize; ++i)
{
inputData[i] = value.data()[i];
}
},
tensorData);
std::vector<int32_t> tensorData;
PopulateTensorWithDataGeneric<int32_t>(tensorData,
params.m_InputTensorShapes[inputIndex]->GetNumElements(),
dataFile,
[](const std::string& s)
{ return std::stoi(s); });

std::copy(tensorData.begin(), tensorData.end(), inputData);
}
else if (params.m_InputTypes[inputIndex].compare("qasymm8") == 0)
{
auto inputData = tfLiteInterpreter->typed_tensor<uint8_t>(input);
TContainer tensorData;
PopulateTensorWithData(tensorData,
params.m_InputTensorShapes[inputIndex]->GetNumElements(),
params.m_InputTypes[inputIndex],
armnn::EmptyOptional(),
dataFile);
mapbox::util::apply_visitor([&](auto&& value)
{
for (unsigned int i = 0; i < inputSize; ++i)
{
inputData[i] = value.data()[i];
}
},
tensorData);
std::vector<uint8_t> tensorData;
PopulateTensorWithDataGeneric<uint8_t>(tensorData,
params.m_InputTensorShapes[inputIndex]->GetNumElements(),
dataFile,
[](const std::string& s)
{ return armnn::numeric_cast<uint8_t>(std::stoi(s)); });

std::copy(tensorData.begin(), tensorData.end(), inputData);
}
else
{
Expand Down Expand Up @@ -203,6 +196,25 @@ int TfLiteDelegateMainImpl(const ExecuteNetworkParams& params,
}
}
}
else if (params.m_OutputTypes[outputIndex].compare("int8") == 0)
{
auto tfLiteDelageOutputData = tfLiteInterpreter->typed_tensor<int8_t>(tfLiteDelegateOutputId);
if(tfLiteDelageOutputData == NULL)
{
ARMNN_LOG(fatal) << "Output tensor is null, output type: "
"\"" << params.m_OutputTypes[outputIndex] << "\" may be incorrect.";
return EXIT_FAILURE;
}

for (int i = 0; i < outputSize; ++i)
{
std::cout << signed(tfLiteDelageOutputData[i]) << ", ";
if (i % 60 == 0)
{
std::cout << std::endl;
}
}
}
else if (params.m_OutputTypes[outputIndex].compare("qasymm8") == 0)
{
auto tfLiteDelageOutputData = tfLiteInterpreter->typed_tensor<uint8_t>(tfLiteDelegateOutputId);
Expand Down
30 changes: 0 additions & 30 deletions tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,36 +25,6 @@
#include "armnnOnnxParser/IOnnxParser.hpp"
#endif


/// Parses whitespace/comma/colon-separated numeric tokens from a stream into a vector.
///
/// @tparam T                 Element type of the resulting vector.
/// @tparam TParseElementFunc Callable converting a single token (std::string) to T;
///                           expected to throw a std::exception-derived error on bad input.
/// @param stream           Input stream read line-by-line until EOF.
/// @param parseElementFunc Per-token conversion function.
/// @param chars            Delimiter set used to tokenize each line (default: tab, space, comma, colon).
/// @return Vector of successfully parsed elements; unparseable tokens are logged and skipped,
///         so the result may contain fewer elements than the stream had tokens.
template<typename T, typename TParseElementFunc>
std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char* chars = "\t ,:")
{
    std::vector<T> result;
    // Processes line-by-line.
    std::string line;
    while (std::getline(stream, line))
    {
        std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, chars);
        for (const std::string& token : tokens)
        {
            if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
            {
                try
                {
                    result.push_back(parseElementFunc(token));
                }
                // Bad tokens are skipped, not fatal: the element is dropped and an error is logged.
                catch (const std::exception&)
                {
                    ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
                }
            }
        }
    }

    return result;
}


template<armnn::DataType NonQuantizedType>
auto ParseDataArray(std::istream& stream);

Expand Down
52 changes: 51 additions & 1 deletion tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@

#include <armnn/IRuntime.hpp>
#include <armnn/Types.hpp>
#include <armnn/Logging.hpp>
#include <armnn/utility/StringUtils.hpp>

#include <mapbox/variant.hpp>

#include <iostream>
#include <fstream>


std::vector<unsigned int> ParseArray(std::istream& stream);
Expand Down Expand Up @@ -68,4 +71,51 @@ bool ValidatePath(const std::string& file, const bool expectFile);
* @param expectFile bool - If true, checks for a regular file.
* @return bool - True if all given strings are valid paths., false otherwise.
* */
bool ValidatePaths(const std::vector<std::string>& fileVec, const bool expectFile);
bool ValidatePaths(const std::vector<std::string>& fileVec, const bool expectFile);

template<typename T, typename TParseElementFunc>
std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char* chars = "\t ,:")
{
std::vector<T> result;
// Processes line-by-line.
std::string line;
while (std::getline(stream, line))
{
std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, chars);
for (const std::string& token : tokens)
{
if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
{
try
{
result.push_back(parseElementFunc(token));
}
catch (const std::exception&)
{
ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
}
}
}
}

return result;
}

/// Fills tensorData either from a data file (parsed element-by-element) or with zeros.
///
/// @tparam T                 Element type of the tensor buffer.
/// @tparam TParseElementFunc Callable converting a single string token to T.
/// @param tensorData    Output vector; overwritten with the parsed or zero-filled data.
/// @param numElements   Number of elements to zero-fill when no (usable) file is given.
/// @param dataFile      Optional path to a text file of delimited numeric tokens.
/// @param parseFunction Per-token conversion function passed through to ParseArrayImpl.
///
/// If the file cannot be opened, an error is logged and the zero-fill fallback is used,
/// so tensorData always has a well-defined size instead of silently ending up empty.
template <typename T, typename TParseElementFunc>
void PopulateTensorWithDataGeneric(std::vector<T>& tensorData,
                                   unsigned int numElements,
                                   const armnn::Optional<std::string>& dataFile,
                                   TParseElementFunc parseFunction)
{
    const bool readFromFile = dataFile.has_value() && !dataFile.value().empty();

    if (readFromFile)
    {
        std::ifstream inputTensorFile(dataFile.value());
        if (inputTensorFile.is_open())
        {
            tensorData = ParseArrayImpl<T>(inputTensorFile, parseFunction);
            return;
        }
        // A bad path previously produced an empty tensor with no diagnostic; report it
        // and fall through to the zero-fill so the buffer still has numElements entries.
        ARMNN_LOG(error) << "'" << dataFile.value() << "' could not be opened. Filling tensor with 0s.";
    }

    tensorData = std::vector<T>(numElements, static_cast<T>(0));
}

0 comments on commit 4f55a25

Please sign in to comment.