diff --git a/tests/integ/test_tfs.py b/tests/integ/test_tfs.py index 8d8f8083ef..8eaead9d7d 100644 --- a/tests/integ/test_tfs.py +++ b/tests/integ/test_tfs.py @@ -26,25 +26,8 @@ from sagemaker.tensorflow.serving import Model, Predictor -@pytest.fixture( - scope="session", - params=[ - "ml.c5.xlarge", - pytest.param( - "ml.p3.2xlarge", - marks=pytest.mark.skipif( - tests.integ.test_region() in tests.integ.HOSTING_NO_P3_REGIONS, - reason="no ml.p3 instances in this region", - ), - ), - ], -) -def instance_type(request): - return request.param - - @pytest.fixture(scope="module") -def tfs_predictor(instance_type, sagemaker_session, tf_full_version): +def tfs_predictor(sagemaker_session, tf_full_version): endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving") model_data = sagemaker_session.upload_data( path=os.path.join(tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz"), @@ -57,7 +40,7 @@ def tfs_predictor(instance_type, sagemaker_session, tf_full_version): framework_version=tf_full_version, sagemaker_session=sagemaker_session, ) - predictor = model.deploy(1, instance_type, endpoint_name=endpoint_name) + predictor = model.deploy(1, "ml.c5.xlarge", endpoint_name=endpoint_name) yield predictor @@ -130,8 +113,6 @@ def tfs_predictor_with_model_and_entry_point_and_dependencies( @pytest.fixture(scope="module") def tfs_predictor_with_accelerator(sagemaker_session, tf_full_version): endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving") - instance_type = "ml.c4.large" - accelerator_type = "ml.eia1.medium" model_data = sagemaker_session.upload_data( path=os.path.join(tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz"), key_prefix="tensorflow-serving/models", @@ -144,13 +125,13 @@ def tfs_predictor_with_accelerator(sagemaker_session, tf_full_version): sagemaker_session=sagemaker_session, ) predictor = model.deploy( - 1, instance_type, endpoint_name=endpoint_name, accelerator_type=accelerator_type + 1, "ml.c4.large", endpoint_name=endpoint_name, accelerator_type="ml.eia1.medium" ) yield predictor @pytest.mark.canary_quick -def test_predict(tfs_predictor, instance_type): # pylint: disable=W0613 +def test_predict(tfs_predictor): # pylint: disable=W0613 input_data = {"instances": [1.0, 2.0, 5.0]} expected_result = {"predictions": [3.5, 4.0, 5.5]}