24 changes: 21 additions & 3 deletions serving/processor/serving/BUILD
@@ -27,7 +27,8 @@ tf_cc_shared_object(
     srcs = ["processor.cc",
             "processor.h",],
     deps = [
-        "model_serving"],
+        "model_serving",
+    ],
 )
 
 cc_library(
@@ -130,6 +131,23 @@ cc_test(
         "@com_google_googletest//:gtest_main",],
 )
 
+cc_library(
+    name = "message_coding",
+    srcs = ["message_coding.cc",],
+    hdrs = ["message_coding.h",],
+    deps = [
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:lib",
+        "//serving/processor/framework:model_version",
+        "//serving/processor/storage:model_store",
+        "model_message",
+        "predict_proto_cc",
+        "utils",
+    ],
+)
+
 cc_library(
     name = "model_instance",
     srcs = ["model_instance.cc",],
@@ -144,6 +162,7 @@ cc_library(
         "//serving/processor/framework:graph_optimizer",
         "//serving/processor/framework:model_version",
         "//serving/processor/storage:model_store",
+        ":message_coding",
         "model_config",
         "model_partition",
         "model_session",
@@ -156,10 +175,8 @@
 cc_library(
     name = "model_serving",
     srcs = ["model_serving.cc",
-            "message_coding.cc",
             "model_impl.cc",],
     hdrs = ["model_serving.h",
-            "message_coding.h",
             "model_impl.h",],
     deps = [
         "//tensorflow/core:protos_all_cc",
@@ -168,6 +185,7 @@
         "//tensorflow/core:lib",
         "//serving/processor/framework:model_version",
         "//serving/processor/storage:model_store",
+        ":message_coding",
         "model_config",
         "model_session",
         "model_message",
28 changes: 17 additions & 11 deletions serving/processor/serving/message_coding.cc
@@ -9,17 +9,9 @@ ProtoBufParser::ProtoBufParser(int thread_num) {
       thread_num));
 }
 
-Status ProtoBufParser::ParseRequestFromBuf(
-    const void* input_data, int input_size, Call& call,
-    const SignatureInfo* signature_info) {
-  eas::PredictRequest request;
-  bool success = request.ParseFromArray(input_data, input_size);
-  if (!success) {
-    LOG(ERROR) << "Parse request from array failed, input_data: " << input_data
-               << ", input_size: " << input_size;
-    return Status(errors::Code::INVALID_ARGUMENT, "Please check the input data.");
-  }
-
+Status ProtoBufParser::ParseRequest(
+    const eas::PredictRequest& request,
+    const SignatureInfo* signature_info, Call& call) {
   for (auto& input : request.inputs()) {
     if (signature_info->input_key_idx.find(input.first) ==
         signature_info->input_key_idx.end()) {
@@ -49,6 +41,20 @@ Status ProtoBufParser::ParseRequestFromBuf(
   return Status::OK();
 }
 
+Status ProtoBufParser::ParseRequestFromBuf(
+    const void* input_data, int input_size, Call& call,
+    const SignatureInfo* signature_info) {
+  eas::PredictRequest request;
+  bool success = request.ParseFromArray(input_data, input_size);
+  if (!success) {
+    LOG(ERROR) << "Parse request from array failed, input_data: " << input_data
+               << ", input_size: " << input_size;
+    return Status(errors::Code::INVALID_ARGUMENT, "Please check the input data.");
+  }
+
+  return ParseRequest(request, signature_info, call);
+}
+
 Status ProtoBufParser::ParseResponseToBuf(
     const Call& call, void** output_data, int* output_size,
     const SignatureInfo* signature_info) {
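The refactor above splits deserialization from request translation: ParseRequestFromBuf now only runs ParseFromArray and then delegates to the new ParseRequest, so a caller that already holds a deserialized eas::PredictRequest (the warmup path below) can skip the buffer round-trip. A minimal sketch of the two entry points, using only the signatures visible in this diff; the helper names are hypothetical:

// Sketch only: Call, SignatureInfo, and IParser come from the processor
// headers touched in this PR; the two helper functions are hypothetical.
Status HandleWireRequest(IParser* parser, const void* buf, int len,
                         const SignatureInfo* info, Call& call) {
  // Wire path: deserialize the raw buffer, then reuse the shared logic.
  return parser->ParseRequestFromBuf(buf, len, call, info);
}

Status HandleWarmupRequest(IParser* parser, const eas::PredictRequest& req,
                           const SignatureInfo* info, Call& call) {
  // Warmup path: the request was already parsed from a file, so no
  // ParseFromArray round-trip is needed.
  return parser->ParseRequest(req, info, call);
}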
15 changes: 15 additions & 0 deletions serving/processor/serving/message_coding.h
@@ -3,6 +3,7 @@
 
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/threadpool.h"
+#include "serving/processor/serving/predict.pb.h"
 
 namespace tensorflow {
 namespace processor {
@@ -19,6 +20,10 @@ class IParser {
       const void* input_data, int input_size, Call& call,
       const SignatureInfo* info) = 0;
 
+  virtual Status ParseRequest(
+      const eas::PredictRequest& request,
+      const SignatureInfo* signature_info, Call& call) = 0;
+
   virtual Status ParseResponseToBuf(
       const Call& call, void** output_data,
       int* output_size, const SignatureInfo* info) = 0;
@@ -52,6 +57,10 @@ class ProtoBufParser : public IParser {
       const void* input_data, int input_size,
       Call& call, const SignatureInfo* info) override;
 
+  Status ParseRequest(
+      const eas::PredictRequest& request,
+      const SignatureInfo* signature_info, Call& call) override;
+
   Status ParseResponseToBuf(
       const Call& call, void** output_data,
       int* output_size, const SignatureInfo* info) override;
@@ -83,6 +92,12 @@ class FlatBufferParser : public IParser {
     return Status::OK();
  }
 
+  Status ParseRequest(
+      const eas::PredictRequest& request,
+      const SignatureInfo* signature_info, Call& call) override {
+    return Status::OK();
+  }
+
   Status ParseResponseToBuf(
       const Call& call, void** output_data,
       int* output_size, const SignatureInfo* info) override {
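Every parser now implements the new ParseRequest overload; FlatBufferParser only stubs it with Status::OK(), so the protobuf path is the only one that actually consumes an eas::PredictRequest. A hedged usage sketch through the factory; the "protobuf" literal and the thread count are assumptions, only the GetInstance(protocol, thread_num) call shape appears elsewhere in this PR:

// Assumes ParserFactory::GetInstance(protocol, thread_num) as called from
// model_instance.cc below; the "protobuf" literal is an assumption.
IParser* parser = ParserFactory::GetInstance("protobuf", 4);

eas::PredictRequest request;
// ... fill request.inputs() and request.output_filter() here ...

Call call;
SignatureInfo signature_info;  // normally derived from the loaded SavedModel
Status s = parser->ParseRequest(request, &signature_info, call);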
107 changes: 67 additions & 40 deletions serving/processor/serving/model_instance.cc
@@ -1,4 +1,5 @@
 #include <fstream>
+#include "serving/processor/serving/message_coding.h"
 #include "serving/processor/serving/model_instance.h"
 #include "serving/processor/serving/model_partition.h"
 #include "serving/processor/serving/model_session.h"
@@ -24,6 +25,7 @@ namespace processor {
 namespace {
 constexpr int _60_Seconds = 60;
 constexpr int MAX_TRY_COUNT = 10;
+constexpr int WARMUP_COUNT = 5;
 
 Tensor CreateTensor(const TensorInfo& tensor_info) {
   auto real_ts = tensor_info.tensor_shape();
@@ -71,44 +73,35 @@ Tensor CreateTensor(const TensorInfo& tensor_info) {
   return tensor;
 }
 
-Call CreateWarmupParams(SignatureDef& sig_def) {
-  Call call;
+Status CreateWarmupParams(SignatureDef& sig_def, Call* call) {
   for (auto it : sig_def.inputs()) {
     const auto& tensor = CreateTensor(it.second);
-    call.request.inputs.emplace_back(it.second.name(), tensor);
+    call->request.inputs.emplace_back(it.second.name(), tensor);
   }
 
   for (auto it : sig_def.outputs()) {
-    call.request.output_tensor_names.emplace_back(it.second.name());
+    call->request.output_tensor_names.emplace_back(it.second.name());
   }
 
-  return call;
+  return Status::OK();
 }
 
-Call CreateWarmupParams(SignatureDef& sig_def,
-                        const std::string& warmup_file_name) {
+Status CreateWarmupParams(SignatureDef& sig_def,
+                          const std::string& warmup_file_name,
+                          Call* call, IParser* parser,
+                          const SignatureInfo& signature_info) {
   // Parse warmup file
   eas::PredictRequest request;
   std::fstream input(warmup_file_name, std::ios::in | std::ios::binary);
-  request.ParseFromIstream(&input);
-  input.close();
-
-  Call call;
-  for (auto& input : request.inputs()) {
-    call.request.inputs.emplace_back(input.first,
-                                     util::Proto2Tensor(input.second));
-  }
-
-  call.request.output_tensor_names =
-      std::vector<std::string>(request.output_filter().begin(),
-                               request.output_filter().end());
-
-  // User need to set fetches
-  if (call.request.output_tensor_names.size() == 0) {
-    LOG(FATAL) << "warmup file must be contain fetches.";
+  bool success = request.ParseFromIstream(&input);
+  if (!success) {
+    LOG(ERROR) << "Read warmup file failed: " << warmup_file_name;
+    return Status(error::Code::INTERNAL,
+                  "Read warmup file failed, please check the warmup file path");
   }
+  input.close();
 
-  return call;
+  return parser->ParseRequest(request, &signature_info, *call);
 }
 
 bool ShouldWarmup(SignatureDef& sig_def) {
@@ -264,6 +257,7 @@ Status LocalSessionInstance::Init(ModelConfig* config,
       {kSavedModelTagServe}, &meta_graph_def_));
 
   warmup_file_name_ = config->warmup_file_name;
+  parser_ = ParserFactory::GetInstance(config->serialize_protocol, 4);
 
   GraphOptimizerOption option;
   option.native_tf_mode = true;
@@ -356,21 +350,38 @@ Status LocalSessionInstance::Warmup(
     return Status::OK();
   }
 
+  LOG(INFO) << "Try to warmup model: " << warmup_file_name_;
+  Status s;
   Call call;
   if (warmup_file_name_.empty()) {
-    call = CreateWarmupParams(model_signature_.second);
+    s = CreateWarmupParams(model_signature_.second, &call);
   } else {
-    call = CreateWarmupParams(model_signature_.second,
-                              warmup_file_name_);
+    s = CreateWarmupParams(model_signature_.second,
+                           warmup_file_name_, &call,
+                           parser_, signature_info_);
   }
+  if (!s.ok()) {
+    LOG(ERROR) << "Create warmup params failed, warmup will be canceled.";
+    return s;
+  }
 
-  if (warmup_session) {
-    return warmup_session->LocalPredict(
-        call.request, call.response);
+  int left_try_count = WARMUP_COUNT;
+  while (left_try_count > 0) {
+    if (warmup_session) {
+      s = warmup_session->LocalPredict(
+          call.request, call.response);
+    } else {
+      s = session_mgr_->LocalPredict(
+          call.request, call.response);
+    }
+    if (!s.ok()) return s;
+
+    --left_try_count;
+    call.response.outputs.clear();
   }
+  LOG(INFO) << "Warmup model succeeded: " << warmup_file_name_;
 
-  return session_mgr_->LocalPredict(
-      call.request, call.response);
+  return Status::OK();
 }
 
 std::string LocalSessionInstance::DebugString() {
@@ -482,6 +493,7 @@ Status RemoteSessionInstance::Init(ModelConfig* model_config,
   backup_storage_ = new FeatureStoreMgr(&backup_model_config);
 
   warmup_file_name_ = model_config->warmup_file_name;
+  parser_ = ParserFactory::GetInstance(model_config->serialize_protocol, 4);
 
   // set active flag
   serving_storage_->SetStorageActiveStatus(active);
@@ -542,21 +554,36 @@ Status RemoteSessionInstance::Warmup(
     return Status::OK();
   }
 
+  Status s;
   Call call;
   if (warmup_file_name_.empty()) {
-    call = CreateWarmupParams(model_signature_.second);
+    s = CreateWarmupParams(model_signature_.second, &call);
   } else {
-    call = CreateWarmupParams(model_signature_.second,
-                              warmup_file_name_);
+    s = CreateWarmupParams(model_signature_.second,
+                           warmup_file_name_, &call,
+                           parser_, signature_info_);
   }
+  if (!s.ok()) {
+    LOG(ERROR) << "Create warmup params failed, warmup will be canceled.";
+    return s;
+  }
 
-  if (warmup_session) {
-    return warmup_session->Predict(
-        call.request, call.response);
+  int left_try_count = WARMUP_COUNT;
+  while (left_try_count > 0) {
+    if (warmup_session) {
+      s = warmup_session->Predict(
+          call.request, call.response);
+    } else {
+      s = session_mgr_->Predict(
+          call.request, call.response);
+    }
+    if (!s.ok()) return s;
+
+    --left_try_count;
+    call.response.outputs.clear();
   }
 
-  return session_mgr_->Predict(
-      call.request, call.response);
+  return Status::OK();
 }
 
 Status RemoteSessionInstance::FullModelUpdate(
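Both Warmup implementations now share the same shape: build the warmup Call once, run it WARMUP_COUNT times, clear the response outputs between iterations so results do not accumulate, and abort on the first failed run. Distilled into a standalone sketch; RunWarmup and the predict callback are hypothetical, while Call, WARMUP_COUNT, and the request/response members are taken from the diff:

// Hypothetical distillation of the retry loop above; `predict` stands in
// for the warmup_session / session_mgr_ dispatch chosen inside the loop.
template <typename PredictFn>
Status RunWarmup(Call& call, PredictFn&& predict) {
  for (int left = WARMUP_COUNT; left > 0; --left) {
    Status s = predict(call.request, call.response);
    if (!s.ok()) return s;          // abort warmup on the first error
    call.response.outputs.clear();  // reset outputs between iterations
  }
  return Status::OK();
}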
3 changes: 3 additions & 0 deletions serving/processor/serving/model_instance.h
@@ -22,6 +22,7 @@ class ModelStore;
 class ModelSession;
 class ModelSessionMgr;
 class IFeatureStoreMgr;
+class IParser;
 
 class LocalSessionInstance {
  public:
@@ -60,6 +61,7 @@ class LocalSessionInstance {
   SignatureInfo signature_info_;
 
   std::string warmup_file_name_;
+  IParser* parser_ = nullptr;
 
   ModelSessionMgr* session_mgr_ = nullptr;
   SessionOptions* session_options_ = nullptr;
@@ -109,6 +111,7 @@ class RemoteSessionInstance {
   SignatureInfo signature_info_;
 
   std::string warmup_file_name_;
+  IParser* parser_ = nullptr;
 
   ModelSessionMgr* session_mgr_ = nullptr;
   SessionOptions* session_options_ = nullptr;
3 changes: 1 addition & 2 deletions serving/processor/serving/model_serving.cc
@@ -24,8 +24,7 @@ Status Model::Init(const char* model_config) {
   }
 
   if (!config->warmup_file_name.empty()) {
-    config->warmup_file_name =
-        model_entry_ + config->warmup_file_name;
+    LOG(INFO) << "User set warmup file: " << config->warmup_file_name;
   }
 
   parser_ = ParserFactory::GetInstance(config->serialize_protocol,
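With the model_entry_ prefix concatenation removed, warmup_file_name is now used exactly as supplied, so configs should carry the full path to the warmup file. An illustrative initialization under that assumption; the JSON shape and the default-constructed Model are guesses, only the warmup_file_name and serialize_protocol fields and the Model::Init(const char*) entry point are visible in this PR:

// Illustrative only: the exact config schema is not shown in this PR.
const char* kModelConfig = R"({
  "serialize_protocol": "protobuf",
  "warmup_file_name": "/models/my_model/warmup.bin"
})";

tensorflow::processor::Model model;  // default construction is an assumption
tensorflow::Status s = model.Init(kModelConfig);
if (!s.ok()) {
  LOG(ERROR) << "Model init failed: " << s.ToString();
}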