Skip to content

Commit

Permalink
Merge pull request #1147 from HexToString/fix_grpc_bug
Browse files Browse the repository at this point in the history
fix a grpc bug
  • Loading branch information
TeslaZhao committed Apr 19, 2021
2 parents fbbce8a + c4e5969 commit 49b18a5
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
34 changes: 25 additions & 9 deletions python/paddle_serving_server/rpc_service.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os

import numpy as np
import google.protobuf.text_format

from .proto import general_model_config_pb2 as m_config
Expand All @@ -9,6 +23,7 @@
os.path.join(os.path.abspath(os.path.dirname(__file__)), 'proto'))
from .proto import multi_lang_general_model_service_pb2_grpc


class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
MultiLangGeneralModelServiceServicer):
def __init__(self, model_config_path_list, is_multi_model, endpoints):
Expand All @@ -31,7 +46,7 @@ def _parse_model_config(self, model_config_path_list):
model_config_path_list = [model_config_path_list]
elif isinstance(model_config_path_list, list):
pass

file_path_list = []
for single_model_config in model_config_path_list:
if os.path.isdir(single_model_config):
Expand All @@ -57,7 +72,7 @@ def _parse_model_config(self, model_config_path_list):
f = open(file_path_list[-1], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)

self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_types_ = {}
for i, var in enumerate(model_conf.fetch_var):
Expand Down Expand Up @@ -86,11 +101,11 @@ def _unpack_inference_request(self, request):
v_type = self.feed_types_[name]
data = None
if is_python:
if v_type == 0:# int64
if v_type == 0: # int64
data = np.frombuffer(var.data, dtype="int64")
elif v_type == 1:# float32
elif v_type == 1: # float32
data = np.frombuffer(var.data, dtype="float32")
elif v_type == 2:# int32
elif v_type == 2: # int32
data = np.frombuffer(var.data, dtype="int32")
else:
raise Exception("error type.")
Expand All @@ -99,7 +114,7 @@ def _unpack_inference_request(self, request):
data = np.array(list(var.int64_data), dtype="int64")
elif v_type == 1: # float32
data = np.array(list(var.float_data), dtype="float32")
elif v_type == 2:# int32
elif v_type == 2: # int32
data = np.array(list(var.int_data), dtype="int32")
else:
raise Exception("error type.")
Expand Down Expand Up @@ -155,7 +170,8 @@ def SetTimeout(self, request, context):
# This process and Inference process cannot be operated at the same time.
# For performance reasons, do not add thread lock temporarily.
timeout_ms = request.timeout_ms
self._init_bclient(self.model_config_path_list, self.endpoints_, timeout_ms)
self._init_bclient(self.model_config_path_list, self.endpoints_,
timeout_ms)
resp = multi_lang_general_model_service_pb2.SimpleResponse()
resp.err_code = 0
return resp
Expand All @@ -176,4 +192,4 @@ def GetClientConfig(self, request, context):
# dict should be added when graphMaker is used.
resp = multi_lang_general_model_service_pb2.GetClientConfigResponse()
resp.client_config_str_list[:] = self.model_config_path_list
return resp
return resp
2 changes: 1 addition & 1 deletion python/paddle_serving_server/web_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def load_model_config(self,

self.fetch_vars = {var.name: var for var in model_conf.fetch_var}
if client_config_path == None:
self.client_config_path = self.server_config_dir_paths
self.client_config_path = file_path_list

def set_gpus(self, gpus):
print("This API will be deprecated later. Please do not use it")
Expand Down

0 comments on commit 49b18a5

Please sign in to comment.