diff --git a/serving/processor/serving/BUILD b/serving/processor/serving/BUILD
index 6226fd1acbc..9a7b69e1e57 100644
--- a/serving/processor/serving/BUILD
+++ b/serving/processor/serving/BUILD
@@ -27,7 +27,8 @@ tf_cc_shared_object(
     srcs = ["processor.cc",
             "processor.h",],
     deps = [
-        "model_serving"],
+        "model_serving",
+    ],
 )
 
 cc_library(
@@ -130,6 +131,23 @@ cc_test(
         "@com_google_googletest//:gtest_main",],
 )
 
+cc_library(
+    name = "message_coding",
+    srcs = ["message_coding.cc",],
+    hdrs = ["message_coding.h",],
+    deps = [
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:lib",
+        "//serving/processor/framework:model_version",
+        "//serving/processor/storage:model_store",
+        "model_message",
+        "predict_proto_cc",
+        "utils",
+    ],
+)
+
 cc_library(
     name = "model_instance",
     srcs = ["model_instance.cc",],
@@ -144,6 +162,7 @@ cc_library(
         "//serving/processor/framework:graph_optimizer",
         "//serving/processor/framework:model_version",
         "//serving/processor/storage:model_store",
+        ":message_coding",
        "model_config",
         "model_partition",
         "model_session",
@@ -156,10 +175,8 @@ cc_library(
 cc_library(
     name = "model_serving",
     srcs = ["model_serving.cc",
-            "message_coding.cc",
             "model_impl.cc",],
     hdrs = ["model_serving.h",
-            "message_coding.h",
             "model_impl.h",],
     deps = [
         "//tensorflow/core:protos_all_cc",
@@ -168,6 +185,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//serving/processor/framework:model_version",
         "//serving/processor/storage:model_store",
+        ":message_coding",
         "model_config",
         "model_session",
         "model_message",
diff --git a/serving/processor/serving/message_coding.cc b/serving/processor/serving/message_coding.cc
index aa3b18926f5..2073904897d 100644
--- a/serving/processor/serving/message_coding.cc
+++ b/serving/processor/serving/message_coding.cc
@@ -9,17 +9,9 @@ ProtoBufParser::ProtoBufParser(int thread_num) {
       thread_num));
 }
 
-Status ProtoBufParser::ParseRequestFromBuf(
-    const void* input_data, int input_size, Call& call,
-    const SignatureInfo* signature_info) {
-  eas::PredictRequest request;
-  bool success = request.ParseFromArray(input_data, input_size);
-  if (!success) {
-    LOG(ERROR) << "Parse request from array failed, input_data: " << input_data
-               << ", input_size: " << input_size;
-    return Status(errors::Code::INVALID_ARGUMENT, "Please check the input data.");
-  }
-
+Status ProtoBufParser::ParseRequest(
+    const eas::PredictRequest& request,
+    const SignatureInfo* signature_info, Call& call) {
   for (auto& input : request.inputs()) {
     if (signature_info->input_key_idx.find(input.first) ==
         signature_info->input_key_idx.end()) {
@@ -49,6 +41,20 @@ Status ProtoBufParser::ParseRequestFromBuf(
   return Status::OK();
 }
 
+Status ProtoBufParser::ParseRequestFromBuf(
+    const void* input_data, int input_size, Call& call,
+    const SignatureInfo* signature_info) {
+  eas::PredictRequest request;
+  bool success = request.ParseFromArray(input_data, input_size);
+  if (!success) {
+    LOG(ERROR) << "Parse request from array failed, input_data: " << input_data
+               << ", input_size: " << input_size;
+    return Status(errors::Code::INVALID_ARGUMENT, "Please check the input data.");
+  }
+
+  return ParseRequest(request, signature_info, call);
+}
+
 Status ProtoBufParser::ParseResponseToBuf(
     const Call& call, void** output_data, int* output_size,
     const SignatureInfo* signature_info) {
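Note on the `message_coding.cc` change above: the protobuf-to-`Call` mapping is hoisted out of `ParseRequestFromBuf` into a new `ParseRequest` overload, so callers that already hold a deserialized `eas::PredictRequest` (such as the warmup path below) reuse the mapping without a serialize/deserialize round trip. A minimal sketch of the two entry points after the split; the free-standing helper functions are illustrative, not part of this patch:

```cpp
#include "serving/processor/serving/message_coding.h"

using tensorflow::Status;
using tensorflow::processor::Call;
using tensorflow::processor::IParser;
using tensorflow::processor::SignatureInfo;

// Raw-buffer path: deserializes internally, then delegates to ParseRequest().
Status HandleRawBuffer(IParser* parser, const void* data, int size,
                       const SignatureInfo* info, Call* call) {
  return parser->ParseRequestFromBuf(data, size, *call, info);
}

// Decoded-request path: skips the ParseFromArray() step entirely.
Status HandleDecodedRequest(IParser* parser, const eas::PredictRequest& req,
                            const SignatureInfo* info, Call* call) {
  return parser->ParseRequest(req, info, *call);
}
```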
diff --git a/serving/processor/serving/message_coding.h b/serving/processor/serving/message_coding.h
index d021c69c794..d5e4b7c6ffe 100644
--- a/serving/processor/serving/message_coding.h
+++ b/serving/processor/serving/message_coding.h
@@ -3,6 +3,7 @@
 
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/threadpool.h"
+#include "serving/processor/serving/predict.pb.h"
 
 namespace tensorflow {
 namespace processor {
@@ -19,6 +20,10 @@ class IParser {
       const void* input_data, int input_size, Call& call,
       const SignatureInfo* info) = 0;
 
+  virtual Status ParseRequest(
+      const eas::PredictRequest& request,
+      const SignatureInfo* signature_info, Call& call) = 0;
+
   virtual Status ParseResponseToBuf(
       const Call& call, void** output_data, int* output_size,
       const SignatureInfo* info) = 0;
@@ -52,6 +57,10 @@ class ProtoBufParser : public IParser {
       const void* input_data, int input_size, Call& call,
       const SignatureInfo* info) override;
 
+  Status ParseRequest(
+      const eas::PredictRequest& request,
+      const SignatureInfo* signature_info, Call& call) override;
+
   Status ParseResponseToBuf(
       const Call& call, void** output_data, int* output_size,
       const SignatureInfo* info) override;
@@ -83,6 +92,12 @@ class FlatBufferParser : public IParser {
     return Status::OK();
   }
 
+  Status ParseRequest(
+      const eas::PredictRequest& request,
+      const SignatureInfo* signature_info, Call& call) override {
+    return Status::OK();
+  }
+
   Status ParseResponseToBuf(
       const Call& call, void** output_data, int* output_size,
       const SignatureInfo* info) override {
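One caveat in `message_coding.h` above: `FlatBufferParser::ParseRequest` is a stub that returns `Status::OK()` without populating `call`, so a FlatBuffer-configured instance would warm up against an empty request and appear to succeed. A possible hardening, reusing the `error::Code` enum this patch already uses elsewhere; this is a suggested alternative body, not part of the patch:

```cpp
// Hypothetical alternative stub: fail loudly until FlatBuffer support lands,
// so callers cannot mistake an unpopulated Call for a parsed request.
Status ParseRequest(const eas::PredictRequest& request,
                    const SignatureInfo* signature_info, Call& call) override {
  return Status(error::Code::UNIMPLEMENTED,
                "ParseRequest is not implemented for FlatBufferParser.");
}
```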
diff --git a/serving/processor/serving/model_instance.cc b/serving/processor/serving/model_instance.cc
index 6da859c3c4a..146b7b0bc82 100644
--- a/serving/processor/serving/model_instance.cc
+++ b/serving/processor/serving/model_instance.cc
@@ -1,4 +1,5 @@
 #include <fstream>
+#include "serving/processor/serving/message_coding.h"
 #include "serving/processor/serving/model_instance.h"
 #include "serving/processor/serving/model_partition.h"
 #include "serving/processor/serving/model_session.h"
@@ -24,6 +25,7 @@ namespace processor {
 namespace {
 constexpr int _60_Seconds = 60;
 constexpr int MAX_TRY_COUNT = 10;
+constexpr int WARMUP_COUNT = 5;
 
 Tensor CreateTensor(const TensorInfo& tensor_info) {
   auto real_ts = tensor_info.tensor_shape();
@@ -71,44 +73,35 @@ Tensor CreateTensor(const TensorInfo& tensor_info) {
   return tensor;
 }
 
-Call CreateWarmupParams(SignatureDef& sig_def) {
-  Call call;
+Status CreateWarmupParams(SignatureDef& sig_def, Call* call) {
   for (auto it : sig_def.inputs()) {
     const auto& tensor = CreateTensor(it.second);
-    call.request.inputs.emplace_back(it.second.name(), tensor);
+    call->request.inputs.emplace_back(it.second.name(), tensor);
   }
 
   for (auto it : sig_def.outputs()) {
-    call.request.output_tensor_names.emplace_back(it.second.name());
+    call->request.output_tensor_names.emplace_back(it.second.name());
   }
 
-  return call;
+  return Status::OK();
 }
 
-Call CreateWarmupParams(SignatureDef& sig_def,
-                        const std::string& warmup_file_name) {
+Status CreateWarmupParams(SignatureDef& sig_def,
+                          const std::string& warmup_file_name,
+                          Call* call, IParser* parser,
+                          const SignatureInfo& signature_info) {
   // Parse warmup file
   eas::PredictRequest request;
   std::fstream input(warmup_file_name, std::ios::in | std::ios::binary);
-  request.ParseFromIstream(&input);
-  input.close();
-
-  Call call;
-  for (auto& input : request.inputs()) {
-    call.request.inputs.emplace_back(input.first,
-        util::Proto2Tensor(input.second));
-  }
-
-  call.request.output_tensor_names =
-      std::vector<std::string>(request.output_filter().begin(),
-                               request.output_filter().end());
-
-  // User need to set fetches
-  if (call.request.output_tensor_names.size() == 0) {
-    LOG(FATAL) << "warmup file must be contain fetches.";
+  bool success = request.ParseFromIstream(&input);
+  if (!success) {
+    LOG(ERROR) << "Read warmup file failed: " << warmup_file_name;
+    return Status(error::Code::INTERNAL,
+                  "Read warmup file failed, please check the warmup file path");
   }
+  input.close();
 
-  return call;
+  return parser->ParseRequest(request, &signature_info, *call);
 }
 
 bool ShouldWarmup(SignatureDef& sig_def) {
@@ -264,6 +257,7 @@ Status LocalSessionInstance::Init(ModelConfig* config,
       {kSavedModelTagServe}, &meta_graph_def_));
 
   warmup_file_name_ = config->warmup_file_name;
+  parser_ = ParserFactory::GetInstance(config->serialize_protocol, 4);
 
   GraphOptimizerOption option;
   option.native_tf_mode = true;
@@ -356,21 +350,38 @@ Status LocalSessionInstance::Warmup(
     return Status::OK();
   }
 
+  LOG(INFO) << "Try to warmup model: " << warmup_file_name_;
+
+  Status s;
   Call call;
   if (warmup_file_name_.empty()) {
-    call = CreateWarmupParams(model_signature_.second);
+    s = CreateWarmupParams(model_signature_.second, &call);
   } else {
-    call = CreateWarmupParams(model_signature_.second,
-                              warmup_file_name_);
+    s = CreateWarmupParams(model_signature_.second,
+                           warmup_file_name_, &call,
+                           parser_, signature_info_);
+  }
+  if (!s.ok()) {
+    LOG(ERROR) << "Create warmup params failed, warmup will be canceled.";
+    return s;
   }
 
-  if (warmup_session) {
-    return warmup_session->LocalPredict(
-        call.request, call.response);
+  int left_try_count = WARMUP_COUNT;
+  while (left_try_count > 0) {
+    if (warmup_session) {
+      s = warmup_session->LocalPredict(
+          call.request, call.response);
+    } else {
+      s = session_mgr_->LocalPredict(
+          call.request, call.response);
+    }
+    if (!s.ok()) return s;
+
+    --left_try_count;
+    call.response.outputs.clear();
   }
+  LOG(INFO) << "Warmup model succeeded: " << warmup_file_name_;
 
-  return session_mgr_->LocalPredict(
-      call.request, call.response);
+  return Status::OK();
 }
 
 std::string LocalSessionInstance::DebugString() {
@@ -482,6 +493,7 @@ Status RemoteSessionInstance::Init(ModelConfig* model_config,
     backup_storage_ = new FeatureStoreMgr(&backup_model_config);
 
   warmup_file_name_ = model_config->warmup_file_name;
+  parser_ = ParserFactory::GetInstance(model_config->serialize_protocol, 4);
 
   // set active flag
   serving_storage_->SetStorageActiveStatus(active);
@@ -542,21 +554,36 @@ Status RemoteSessionInstance::Warmup(
     return Status::OK();
   }
 
+  Status s;
   Call call;
   if (warmup_file_name_.empty()) {
-    call = CreateWarmupParams(model_signature_.second);
+    s = CreateWarmupParams(model_signature_.second, &call);
   } else {
-    call = CreateWarmupParams(model_signature_.second,
-                              warmup_file_name_);
+    s = CreateWarmupParams(model_signature_.second,
+                           warmup_file_name_, &call,
+                           parser_, signature_info_);
+  }
+  if (!s.ok()) {
+    LOG(ERROR) << "Create warmup params failed, warmup will be canceled.";
+    return s;
   }
 
-  if (warmup_session) {
-    return warmup_session->Predict(
-        call.request, call.response);
+  int left_try_count = WARMUP_COUNT;
+  while (left_try_count > 0) {
+    if (warmup_session) {
+      s = warmup_session->Predict(
+          call.request, call.response);
+    } else {
+      s = session_mgr_->Predict(
+          call.request, call.response);
+    }
+    if (!s.ok()) return s;
+
+    --left_try_count;
+    call.response.outputs.clear();
   }
 
-  return session_mgr_->Predict(
-      call.request, call.response);
+  return Status::OK();
 }
 
 Status RemoteSessionInstance::FullModelUpdate(
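For reference, the warmup file that `CreateWarmupParams` reads above is simply a binary-serialized `eas::PredictRequest` (parsed back with `ParseFromIstream`), which the new code feeds through `IParser::ParseRequest` and then replays `WARMUP_COUNT` (5) times, clearing `call.response.outputs` between iterations. A hedged sketch of producing such a file offline; the fetch name is illustrative and the inputs must match the model's signature:

```cpp
#include <fstream>
#include <string>
#include "serving/processor/serving/predict.pb.h"

// Writes a warmup request to `path`; the processor reads it back with
// eas::PredictRequest::ParseFromIstream() in binary mode.
bool WriteWarmupFile(const std::string& path) {
  eas::PredictRequest request;
  // Populate request.inputs() with tensors matching the model signature,
  // and list the fetches in output_filter, e.g.:
  request.add_output_filter("prediction:0");  // illustrative fetch name
  std::ofstream out(path, std::ios::out | std::ios::trunc | std::ios::binary);
  return request.SerializeToOstream(&out);
}
```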
diff --git a/serving/processor/serving/model_instance.h b/serving/processor/serving/model_instance.h
index 05780154067..428b50e43a9 100644
--- a/serving/processor/serving/model_instance.h
+++ b/serving/processor/serving/model_instance.h
@@ -22,6 +22,7 @@ class ModelStore;
 class ModelSession;
 class ModelSessionMgr;
 class IFeatureStoreMgr;
+class IParser;
 
 class LocalSessionInstance {
  public:
@@ -60,6 +61,7 @@ class LocalSessionInstance {
   SignatureInfo signature_info_;
 
   std::string warmup_file_name_;
+  IParser* parser_ = nullptr;
 
   ModelSessionMgr* session_mgr_ = nullptr;
   SessionOptions* session_options_ = nullptr;
@@ -109,6 +111,7 @@ class RemoteSessionInstance {
   SignatureInfo signature_info_;
 
   std::string warmup_file_name_;
+  IParser* parser_ = nullptr;
 
   ModelSessionMgr* session_mgr_ = nullptr;
   SessionOptions* session_options_ = nullptr;
diff --git a/serving/processor/serving/model_serving.cc b/serving/processor/serving/model_serving.cc
index 1675819a2c0..3b653e69fca 100644
--- a/serving/processor/serving/model_serving.cc
+++ b/serving/processor/serving/model_serving.cc
@@ -24,8 +24,7 @@ Status Model::Init(const char* model_config) {
   }
 
   if (!config->warmup_file_name.empty()) {
-    config->warmup_file_name =
-        model_entry_ + config->warmup_file_name;
+    LOG(INFO) << "User set warmup file: " << config->warmup_file_name;
   }
 
   parser_ = ParserFactory::GetInstance(config->serialize_protocol,