Skip to content

Commit

Permalink
Add additional bounds checks to NNAPI FMQ deserialize utility functions
Browse files Browse the repository at this point in the history
This CL adds the following additional bounds checks:
* Adds additional checks of the index of the std::vector before
  accessing the element at the index
* Changes the array index operator [] to the checked std::vector::at
  method

Bug: 256589724
Test: mma
Merged-In: I6bfb02a5cd76258284cc4d797a4508b21e672c4b
Change-Id: I6bfb02a5cd76258284cc4d797a4508b21e672c4b
(cherry picked from commit 67d9ebe)
Merged-In: I6bfb02a5cd76258284cc4d797a4508b21e672c4b
  • Loading branch information
Michael Butler authored and Android Build Coastguard Worker committed Dec 8, 2022
1 parent 7af0956 commit e7355d6
Showing 1 changed file with 33 additions and 23 deletions.
56 changes: 33 additions & 23 deletions neuralnetworks/1.2/utils/src/BurstUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,13 @@ nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTiming>>
size_t index = 0;

// validate packet information
if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
if (index >= data.size() ||
data.at(index).getDiscriminator() != discriminator::packetInformation) {
return NN_ERROR() << "FMQ Request packet ill-formed";
}

// unpackage packet information
const FmqRequestDatum::PacketInformation& packetInfo = data[index].packetInformation();
const FmqRequestDatum::PacketInformation& packetInfo = data.at(index).packetInformation();
index++;
const uint32_t packetSize = packetInfo.packetSize;
const uint32_t numberOfInputOperands = packetInfo.numberOfInputOperands;
Expand All @@ -212,13 +213,14 @@ nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTiming>>
inputs.reserve(numberOfInputOperands);
for (size_t operand = 0; operand < numberOfInputOperands; ++operand) {
// validate input operand information
if (data[index].getDiscriminator() != discriminator::inputOperandInformation) {
if (index >= data.size() ||
data.at(index).getDiscriminator() != discriminator::inputOperandInformation) {
return NN_ERROR() << "FMQ Request packet ill-formed";
}

// unpackage operand information
const FmqRequestDatum::OperandInformation& operandInfo =
data[index].inputOperandInformation();
data.at(index).inputOperandInformation();
index++;
const bool hasNoValue = operandInfo.hasNoValue;
const V1_0::DataLocation location = operandInfo.location;
Expand All @@ -229,12 +231,13 @@ nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTiming>>
dimensions.reserve(numberOfDimensions);
for (size_t i = 0; i < numberOfDimensions; ++i) {
// validate dimension
if (data[index].getDiscriminator() != discriminator::inputOperandDimensionValue) {
if (index >= data.size() ||
data.at(index).getDiscriminator() != discriminator::inputOperandDimensionValue) {
return NN_ERROR() << "FMQ Request packet ill-formed";
}

// unpackage dimension
const uint32_t dimension = data[index].inputOperandDimensionValue();
const uint32_t dimension = data.at(index).inputOperandDimensionValue();
index++;

// store result
Expand All @@ -251,13 +254,14 @@ nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTiming>>
outputs.reserve(numberOfOutputOperands);
for (size_t operand = 0; operand < numberOfOutputOperands; ++operand) {
// validate output operand information
if (data[index].getDiscriminator() != discriminator::outputOperandInformation) {
if (index >= data.size() ||
data.at(index).getDiscriminator() != discriminator::outputOperandInformation) {
return NN_ERROR() << "FMQ Request packet ill-formed";
}

// unpackage operand information
const FmqRequestDatum::OperandInformation& operandInfo =
data[index].outputOperandInformation();
data.at(index).outputOperandInformation();
index++;
const bool hasNoValue = operandInfo.hasNoValue;
const V1_0::DataLocation location = operandInfo.location;
Expand All @@ -268,12 +272,13 @@ nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTiming>>
dimensions.reserve(numberOfDimensions);
for (size_t i = 0; i < numberOfDimensions; ++i) {
// validate dimension
if (data[index].getDiscriminator() != discriminator::outputOperandDimensionValue) {
if (index >= data.size() ||
data.at(index).getDiscriminator() != discriminator::outputOperandDimensionValue) {
return NN_ERROR() << "FMQ Request packet ill-formed";
}

// unpackage dimension
const uint32_t dimension = data[index].outputOperandDimensionValue();
const uint32_t dimension = data.at(index).outputOperandDimensionValue();
index++;

// store result
Expand All @@ -290,30 +295,31 @@ nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTiming>>
slots.reserve(numberOfPools);
for (size_t pool = 0; pool < numberOfPools; ++pool) {
// validate input operand information
if (data[index].getDiscriminator() != discriminator::poolIdentifier) {
if (index >= data.size() ||
data.at(index).getDiscriminator() != discriminator::poolIdentifier) {
return NN_ERROR() << "FMQ Request packet ill-formed";
}

// unpackage operand information
const int32_t poolId = data[index].poolIdentifier();
const int32_t poolId = data.at(index).poolIdentifier();
index++;

// store result
slots.push_back(poolId);
}

// validate measureTiming
if (data[index].getDiscriminator() != discriminator::measureTiming) {
if (index >= data.size() || data.at(index).getDiscriminator() != discriminator::measureTiming) {
return NN_ERROR() << "FMQ Request packet ill-formed";
}

// unpackage measureTiming
const V1_2::MeasureTiming measure = data[index].measureTiming();
const V1_2::MeasureTiming measure = data.at(index).measureTiming();
index++;

// validate packet information
if (index != packetSize) {
return NN_ERROR() << "FMQ Result packet ill-formed";
return NN_ERROR() << "FMQ Request packet ill-formed";
}

// return request
Expand All @@ -328,12 +334,13 @@ nn::Result<std::tuple<V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, V1_2::T
size_t index = 0;

// validate packet information
if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
if (index >= data.size() ||
data.at(index).getDiscriminator() != discriminator::packetInformation) {
return NN_ERROR() << "FMQ Result packet ill-formed";
}

// unpackage packet information
const FmqResultDatum::PacketInformation& packetInfo = data[index].packetInformation();
const FmqResultDatum::PacketInformation& packetInfo = data.at(index).packetInformation();
index++;
const uint32_t packetSize = packetInfo.packetSize;
const V1_0::ErrorStatus errorStatus = packetInfo.errorStatus;
Expand All @@ -349,12 +356,13 @@ nn::Result<std::tuple<V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, V1_2::T
outputShapes.reserve(numberOfOperands);
for (size_t operand = 0; operand < numberOfOperands; ++operand) {
// validate operand information
if (data[index].getDiscriminator() != discriminator::operandInformation) {
if (index >= data.size() ||
data.at(index).getDiscriminator() != discriminator::operandInformation) {
return NN_ERROR() << "FMQ Result packet ill-formed";
}

// unpackage operand information
const FmqResultDatum::OperandInformation& operandInfo = data[index].operandInformation();
const FmqResultDatum::OperandInformation& operandInfo = data.at(index).operandInformation();
index++;
const bool isSufficient = operandInfo.isSufficient;
const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
Expand All @@ -364,12 +372,13 @@ nn::Result<std::tuple<V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, V1_2::T
dimensions.reserve(numberOfDimensions);
for (size_t i = 0; i < numberOfDimensions; ++i) {
// validate dimension
if (data[index].getDiscriminator() != discriminator::operandDimensionValue) {
if (index >= data.size() ||
data.at(index).getDiscriminator() != discriminator::operandDimensionValue) {
return NN_ERROR() << "FMQ Result packet ill-formed";
}

// unpackage dimension
const uint32_t dimension = data[index].operandDimensionValue();
const uint32_t dimension = data.at(index).operandDimensionValue();
index++;

// store result
Expand All @@ -381,12 +390,13 @@ nn::Result<std::tuple<V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, V1_2::T
}

// validate execution timing
if (data[index].getDiscriminator() != discriminator::executionTiming) {
if (index >= data.size() ||
data.at(index).getDiscriminator() != discriminator::executionTiming) {
return NN_ERROR() << "FMQ Result packet ill-formed";
}

// unpackage execution timing
const V1_2::Timing timing = data[index].executionTiming();
const V1_2::Timing timing = data.at(index).executionTiming();
index++;

// validate packet information
Expand Down

0 comments on commit e7355d6

Please sign in to comment.