Skip to content

Commit

Permalink
Check model signature (#176)
Browse files Browse the repository at this point in the history
Fixes #136.

Signed-off-by: Zhao Lufan <zhao.lufan30@zte.com.cn>

Co-authored-by: Zhao Lufan <zhao.lufan30@zte.com.cn>
  • Loading branch information
EFanZh and Zhao Lufan committed Jun 4, 2020
1 parent 26600b0 commit 9e01655
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@ using adlik::serving::Batch;
using adlik::serving::BatchingMessageTask;
using adlik::serving::getTFDataTypeName;
using adlik::serving::InputContext;
using adlik::serving::ModelConfigProto;
using adlik::serving::ModelInput;
using adlik::serving::ModelOutput;
using adlik::serving::OutputContext;
using adlik::serving::PredictRequestProvider;
using adlik::serving::TensorShapeDims;
using adlik::serving::tfLiteTypeToTfType;
using google::protobuf::RepeatedPtrField;
using std::make_unique;
using std::shared_ptr;
using std::string;
Expand Down Expand Up @@ -108,6 +112,28 @@ variant<tuple<InputSignature, size_t>, Status> getSignature(const Interpreter& i
return make_tuple(std::move(result), batchSize);
}

// Verifies that the signature extracted from the TFLite model matches the
// signature declared in the model's configuration (its input or output list).
//
// @param modelSignature    Signature derived from the loaded model's tensors.
// @param declaredSignature Repeated ModelInput/ModelOutput messages from the
//                          ModelConfigProto that the model must agree with.
// @return Status::OK() on an exact match, otherwise an Internal status whose
//         message shows both signatures for diagnosis.
template <class T>
Status checkSignature(const InputSignature& modelSignature, const RepeatedPtrField<T>& declaredSignature) {
  // Rebuild an InputSignature from the declared config entries so the two can
  // be compared with operator==.
  InputSignature signature2;

  for (const auto& item : declaredSignature) {
    // piecewise_construct avoids constructing a temporary key/value pair; the
    // dims are copied into an owned TensorShapeDims since `item` is transient.
    signature2.emplace(
        std::piecewise_construct,
        std::forward_as_tuple(item.name()),
        std::forward_as_tuple(item.data_type(), TensorShapeDims::owned(item.dims().begin(), item.dims().end())));
  }

  if (modelSignature == signature2) {
    return Status::OK();
  } else {
    // Fix: the original code printed the model signature twice, so the error
    // message could never actually show the declared signature. Print the
    // declared (config-derived) signature in the second slot.
    return Internal("Model signature does not match declared one. Model signature: ",
                    displaySignature(modelSignature),
                    ". Declared signature: ",
                    displaySignature(signature2),
                    ".");
  }
}

StringViewMap<InputContext> getInputContextMap(const Interpreter& interpreter) {
StringViewMap<InputContext> result;

Expand Down Expand Up @@ -260,7 +286,8 @@ Status TensorFlowLiteBatchProcessor::processBatch(Batch<BatchingMessageTask>& ba

variant<unique_ptr<TensorFlowLiteBatchProcessor>, Status> TensorFlowLiteBatchProcessor::create(
shared_ptr<FlatBufferModel> model,
const OpResolver& opResolver) {
const OpResolver& opResolver,
const ModelConfigProto& modelConfigProto) {
unique_ptr<Interpreter> interpreter;

if (InterpreterBuilder{*model, opResolver}(&interpreter, 1) != TfLiteStatus::kTfLiteOk) {
Expand All @@ -271,23 +298,42 @@ variant<unique_ptr<TensorFlowLiteBatchProcessor>, Status> TensorFlowLiteBatchPro
return Internal("Unable to allocate tensors");
}

auto maybeSignature = getSignature(*interpreter, interpreter->inputs());
// Check input signature.

if (absl::holds_alternative<tuple<InputSignature, size_t>>(maybeSignature)) {
auto signature = std::move(absl::get<tuple<InputSignature, size_t>>(maybeSignature));
auto inputContextMap = getInputContextMap(*interpreter);
auto outputContexts = getOutputContexts(*interpreter);
auto maybeInputSignature = getSignature(*interpreter, interpreter->inputs());

return make_unique<TensorFlowLiteBatchProcessor>(ConstructCredential{},
std::move(model),
std::move(interpreter),
std::move(std::get<InputSignature>(signature)),
std::move(std::get<size_t>(signature)),
std::move(inputContextMap),
std::move(outputContexts));
} else {
return std::move(absl::get<Status>(maybeSignature));
if (!absl::holds_alternative<tuple<InputSignature, size_t>>(maybeInputSignature)) {
return std::move(absl::get<Status>(maybeInputSignature));
}

auto inputSignature = std::move(absl::get<tuple<InputSignature, size_t>>(maybeInputSignature));

TF_RETURN_IF_ERROR(checkSignature(std::get<InputSignature>(inputSignature), modelConfigProto.input()));

// Check output signature.

auto maybeOutputSignature = getSignature(*interpreter, interpreter->outputs());

if (!absl::holds_alternative<tuple<InputSignature, size_t>>(maybeOutputSignature)) {
return std::move(absl::get<Status>(maybeOutputSignature));
}

TF_RETURN_IF_ERROR(
checkSignature(std::get<InputSignature>(absl::get<tuple<InputSignature, size_t>>(maybeOutputSignature)),
modelConfigProto.output()));

// Get IO contexts.

auto inputContextMap = getInputContextMap(*interpreter);
auto outputContexts = getOutputContexts(*interpreter);

return make_unique<TensorFlowLiteBatchProcessor>(ConstructCredential{},
std::move(model),
std::move(interpreter),
std::move(std::get<InputSignature>(inputSignature)),
std::move(std::get<size_t>(inputSignature)),
std::move(inputContextMap),
std::move(outputContexts));
}
} // namespace serving
} // namespace adlik
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#define ADLIK_SERVING_RUNTIME_TENSORFLOW_LITE_TENSORFLOW_LITE_BATCH_PROCESSOR_H

#include "absl/hash/hash.h"
#include "adlik_serving/framework/domain/model_config.pb.h"
#include "adlik_serving/runtime/batching/batch_processor.h"
#include "adlik_serving/runtime/tensorflow_lite/input_context.h"
#include "adlik_serving/runtime/tensorflow_lite/output_context.h"
Expand Down Expand Up @@ -44,7 +45,8 @@ class TensorFlowLiteBatchProcessor : public BatchProcessor {

static absl::variant<std::unique_ptr<TensorFlowLiteBatchProcessor>, tensorflow::Status> create(
std::shared_ptr<tflite::FlatBufferModel> model,
const tflite::OpResolver& opResolver);
const tflite::OpResolver& opResolver,
const ModelConfigProto& modelConfigProto);
};
} // namespace serving
} // namespace adlik
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ variant<unique_ptr<TensorFlowLiteModel>, Status> internalCreate(const ModelConfi

for (const auto& instanceGroup : normalizedModelConfig.instance_group()) {
for (int i = 0; i != instanceGroup.count(); ++i) {
auto processor = TensorFlowLiteBatchProcessor::create(flatBufferModel, opResolver);
auto processor = TensorFlowLiteBatchProcessor::create(flatBufferModel, opResolver, modelConfig);

if (absl::holds_alternative<unique_ptr<TensorFlowLiteBatchProcessor>>(processor)) {
result->add(std::move(absl::get<unique_ptr<TensorFlowLiteBatchProcessor>>(processor)));
Expand Down

0 comments on commit 9e01655

Please sign in to comment.