Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add unit tests
- Loading branch information
Showing
12 changed files
with
473 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
[run] | ||
omit = | ||
# omit autogenerated files | ||
*tensorflow_serving_api* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,4 +2,5 @@ flake8 | |
pytest | ||
pytest-cov | ||
pytest-mock | ||
requests | ||
requests | ||
grpcio-testing |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from ie_serving.tensorflow_serving_api import prediction_service_pb2 | ||
from ie_serving.tensorflow_serving_api import predict_pb2 | ||
from ie_serving.server.predict import PredictionServiceServicer | ||
from ie_serving.models.model import Model | ||
from ie_serving.models.ir_engine import IrEngine | ||
from tensorflow.contrib.util import make_tensor_proto | ||
import grpc_testing | ||
import numpy as np | ||
import pytest | ||
|
||
PREDICT_SERVICE = prediction_service_pb2.\ | ||
DESCRIPTOR.services_by_name['PredictionService'] | ||
|
||
|
||
@pytest.fixture | ||
def get_fake_model(): | ||
model_xml = 'model1.xml' | ||
model_bin = 'model1.bin' | ||
exec_net = None | ||
input_key = 'input' | ||
inputs = {input_key: [1, 1]} | ||
outputs = ['test_output'] | ||
engine = IrEngine(model_bin=model_bin, model_xml=model_xml, | ||
exec_net=exec_net, inputs=inputs, outputs=outputs) | ||
new_engines = {1: engine, 2: engine, 3: engine} | ||
new_model = Model(model_name="test", model_directory='fake_path/model/', | ||
available_versions=[1, 2, 3], engines=new_engines) | ||
return new_model | ||
|
||
|
||
@pytest.fixture | ||
def get_grpc_service_for_predict(get_fake_model): | ||
_real_time = grpc_testing.strict_real_time() | ||
servicer = PredictionServiceServicer(models={'test': get_fake_model}) | ||
descriptors_to_servicers = { | ||
PREDICT_SERVICE: servicer | ||
} | ||
_real_time_server = grpc_testing.server_from_dictionary( | ||
descriptors_to_servicers, _real_time) | ||
|
||
return _real_time_server | ||
|
||
|
||
def get_fake_request(model_name, data_shape, input_blob, version=None): | ||
request = predict_pb2.PredictRequest() | ||
request.model_spec.name = model_name | ||
if version is not None: | ||
request.model_spec.version.value = version | ||
data = np.ones(shape=data_shape) | ||
request.inputs[input_blob].CopyFrom( | ||
make_tensor_proto(data, shape=data.shape)) | ||
return request |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from ie_serving.models.ir_engine import IrEngine | ||
|
||
|
||
def test_init_class(): | ||
model_xml = 'model1.xml' | ||
model_bin = 'model1.bin' | ||
exec_net = None | ||
input_key = 'input' | ||
inputs = {input_key: []} | ||
outputs = None | ||
engine = IrEngine(model_bin=model_bin, model_xml=model_xml, | ||
exec_net=exec_net, inputs=inputs, outputs=outputs) | ||
assert model_xml == engine.model_xml | ||
assert model_bin == engine.model_bin | ||
assert exec_net == engine.exec_net | ||
assert input_key == engine.input_blob | ||
assert inputs == engine.inputs | ||
assert outputs == engine.outputs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import pytest | ||
from ie_serving.models.model import Model | ||
|
||
|
||
def test_model_init(): | ||
new_model = Model(model_name="test", model_directory='fake_path', | ||
available_versions=[1, 2, 3], engines={}) | ||
assert new_model.default_version == 3 | ||
assert new_model.model_name == 'test' | ||
assert new_model.model_directory == 'fake_path' | ||
assert new_model.engines == {} | ||
|
||
|
||
@pytest.mark.parametrize("path, expected_value", [ | ||
('fake_path/model/1', 1), | ||
('fake_path/model/1/test', 0), | ||
('fake_path/model/56', 56) | ||
]) | ||
def test_get_version_number_of_model(path, expected_value): | ||
output = Model.get_model_version_number(version_path=path) | ||
assert output == expected_value | ||
|
||
|
||
@pytest.mark.parametrize("path, model_files, expected_value", [ | ||
('fake_path/model/1', [['model.bin'], ['model.xml']], | ||
('model.xml', 'model.bin')), | ||
('fake_path/model/1', [['model'], ['model.xml']], | ||
(None, None)), | ||
('fake_path/model/1', [['model.bin'], ['model.yml']], | ||
(None, None)) | ||
]) | ||
def test_get_absolute_path_to_model(mocker, path, model_files, | ||
expected_value): | ||
model_mocker = mocker.patch('glob.glob') | ||
model_mocker.side_effect = model_files | ||
output1, output2 = Model.get_absolute_path_to_model( | ||
specific_version_model_path=path) | ||
assert expected_value[0] == output1 | ||
assert expected_value[1] == output2 | ||
|
||
|
||
def test_get_all_available_versions(mocker): | ||
new_model = Model(model_name="test", model_directory='fake_path/model/', | ||
available_versions=[1, 2, 3], engines={}) | ||
model_mocker = mocker.patch('glob.glob') | ||
models_path = [new_model.model_directory + str(x) for x in range(5)] | ||
model_mocker.return_value = models_path | ||
absolute_path_model_mocker = mocker.patch('ie_serving.models.model.Model.' | ||
'get_absolute_path_to_model') | ||
absolute_path_model_mocker.side_effect = [(None, None), | ||
('modelv2.xml', 'modelv2.bin'), | ||
(None, None), | ||
('modelv4.xml', 'modelv4.bin')] | ||
output = new_model.get_all_available_versions(new_model.model_directory) | ||
expected_output = [{'xml_model_path': 'modelv2.xml', | ||
'bin_model_path': 'modelv2.bin', 'version': 2}, | ||
{'xml_model_path': 'modelv4.xml', 'bin_model_path': | ||
'modelv4.bin', 'version': 4}] | ||
|
||
assert 2 == len(output) | ||
assert expected_output == output | ||
|
||
|
||
def test_get_engines_for_model(mocker): | ||
engines_mocker = mocker.patch('ie_serving.models.ir_engine.IrEngine.' | ||
'build') | ||
engines_mocker.side_effect = ['modelv2', 'modelv4'] | ||
available_versions = [{'xml_model_path': 'modelv2.xml', | ||
'bin_model_path': 'modelv2.bin', 'version': 2}, | ||
{'xml_model_path': 'modelv4.xml', 'bin_model_path': | ||
'modelv4.bin', 'version': 4}] | ||
output = Model.get_engines_for_model(versions=available_versions) | ||
assert 2 == len(output) | ||
assert 'modelv2' == output[2] | ||
assert 'modelv4' == output[4] |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import grpc | ||
import numpy as np | ||
from tensorflow.contrib.util import make_ndarray | ||
from conftest import get_fake_request, PREDICT_SERVICE | ||
|
||
|
||
def test_predict_successful(mocker, get_grpc_service_for_predict, | ||
get_fake_model): | ||
infer_mocker = mocker.patch('ie_serving.models.ir_engine.IrEngine.infer') | ||
expected_response = np.ones(shape=(2, 2)) | ||
infer_mocker.return_value = {'test_output': expected_response} | ||
|
||
request = get_fake_request(model_name='test', | ||
data_shape=(1, 1), input_blob='input') | ||
grpc_server = get_grpc_service_for_predict | ||
rpc = grpc_server.invoke_unary_unary( | ||
PREDICT_SERVICE.methods_by_name['Predict'], | ||
(), | ||
request, None) | ||
rpc.initial_metadata() | ||
response, trailing_metadata, code, details = rpc.termination() | ||
|
||
encoded_response = make_ndarray(response.outputs['test_output']) | ||
assert get_fake_model.default_version == response.model_spec.version.value | ||
assert grpc.StatusCode.OK == code | ||
assert expected_response.shape == encoded_response.shape | ||
|
||
|
||
def test_predict_successful_version(mocker, get_grpc_service_for_predict): | ||
infer_mocker = mocker.patch('ie_serving.models.ir_engine.IrEngine.infer') | ||
expected_response = np.ones(shape=(2, 2)) | ||
infer_mocker.return_value = {'test_output': expected_response} | ||
requested_version = 1 | ||
request = get_fake_request(model_name='test', data_shape=(1, 1), | ||
input_blob='input', version=requested_version) | ||
grpc_server = get_grpc_service_for_predict | ||
rpc = grpc_server.invoke_unary_unary( | ||
PREDICT_SERVICE.methods_by_name['Predict'], | ||
(), | ||
request, None) | ||
rpc.initial_metadata() | ||
response, trailing_metadata, code, details = rpc.termination() | ||
|
||
encoded_response = make_ndarray(response.outputs['test_output']) | ||
assert requested_version == response.model_spec.version.value | ||
assert grpc.StatusCode.OK == code | ||
assert expected_response.shape == encoded_response.shape | ||
|
||
|
||
def test_predict_wrong_model_name(get_grpc_service_for_predict): | ||
wrong_model_name = 'wrong_name' | ||
request = get_fake_request(model_name=wrong_model_name, data_shape=(1, 1), | ||
input_blob='input') | ||
grpc_server = get_grpc_service_for_predict | ||
rpc = grpc_server.invoke_unary_unary( | ||
PREDICT_SERVICE.methods_by_name['Predict'], | ||
(), | ||
request, None) | ||
rpc.initial_metadata() | ||
response, trailing_metadata, code, details = rpc.termination() | ||
assert grpc.StatusCode.NOT_FOUND == code | ||
|
||
|
||
def test_predict_wrong_model_version(get_grpc_service_for_predict): | ||
wrong_requested_version = 999 | ||
request = get_fake_request(model_name='test', data_shape=(1, 1), | ||
input_blob='input', | ||
version=wrong_requested_version) | ||
grpc_server = get_grpc_service_for_predict | ||
rpc = grpc_server.invoke_unary_unary( | ||
PREDICT_SERVICE.methods_by_name['Predict'], | ||
(), | ||
request, None) | ||
rpc.initial_metadata() | ||
response, trailing_metadata, code, details = rpc.termination() | ||
assert grpc.StatusCode.NOT_FOUND == code | ||
|
||
|
||
def test_predict_wrong_shape(get_grpc_service_for_predict): | ||
wrong_shape = (4, 4) | ||
request = get_fake_request(model_name='test', data_shape=wrong_shape, | ||
input_blob='input') | ||
grpc_server = get_grpc_service_for_predict | ||
rpc = grpc_server.invoke_unary_unary( | ||
PREDICT_SERVICE.methods_by_name['Predict'], | ||
(), | ||
request, None) | ||
rpc.initial_metadata() | ||
response, trailing_metadata, code, details = rpc.termination() | ||
assert grpc.StatusCode.INVALID_ARGUMENT == code | ||
|
||
|
||
def test_predict_wrong_input_blob(get_grpc_service_for_predict): | ||
wrong_input_blob = 'wrong_input_blob' | ||
request = get_fake_request(model_name='test', data_shape=(1, 1), | ||
input_blob=wrong_input_blob) | ||
grpc_server = get_grpc_service_for_predict | ||
rpc = grpc_server.invoke_unary_unary( | ||
PREDICT_SERVICE.methods_by_name['Predict'], | ||
(), | ||
request, None) | ||
rpc.initial_metadata() | ||
response, trailing_metadata, code, details = rpc.termination() | ||
assert grpc.StatusCode.INVALID_ARGUMENT == code |
Oops, something went wrong.