diff --git a/README.rst b/README.rst
index ac73a7acfa..5ec1fb350e 100644
--- a/README.rst
+++ b/README.rst
@@ -201,7 +201,7 @@ In order to host a SparkML model in SageMaker, it should be serialized with ``ML
 
 For more information on MLeap, see https://github.com/combust/mleap .
 
-Supported major version of Spark: 2.2 (MLeap version - 0.9.6)
+Supported major version of Spark: 2.4 (MLeap version - 0.9.6)
 
 Here is an example on how to create an instance of ``SparkMLModel`` class and use ``deploy()`` method to create an
 endpoint which can be used to perform prediction against your trained SparkML Model.
diff --git a/src/sagemaker/image_uri_config/sparkml-serving.json b/src/sagemaker/image_uri_config/sparkml-serving.json
index de9d0e0d4c..fd14df3bf9 100644
--- a/src/sagemaker/image_uri_config/sparkml-serving.json
+++ b/src/sagemaker/image_uri_config/sparkml-serving.json
@@ -30,5 +30,34 @@
             },
             "repository": "sagemaker-sparkml-serving"
-        }
+        },
+        "2.4": {
+            "registries": {
+                "af-south-1": "510948584623",
+                "ap-east-1": "651117190479",
+                "ap-northeast-1": "354813040037",
+                "ap-northeast-2": "366743142698",
+                "ap-south-1": "720646828776",
+                "ap-southeast-1": "121021644041",
+                "ap-southeast-2": "783357654285",
+                "ca-central-1": "341280168497",
+                "cn-north-1": "450853457545",
+                "cn-northwest-1": "451049120500",
+                "eu-central-1": "492215442770",
+                "eu-north-1": "662702820516",
+                "eu-west-1": "141502667606",
+                "eu-west-2": "764974769150",
+                "eu-west-3": "659782779980",
+                "eu-south-1": "978288397137",
+                "me-south-1": "801668240914",
+                "sa-east-1": "737474898029",
+                "us-east-1": "683313688378",
+                "us-east-2": "257758044811",
+                "us-gov-west-1": "414596584902",
+                "us-iso-east-1": "833128469047",
+                "us-west-1": "746614075791",
+                "us-west-2": "246618743249"
+            },
+            "repository": "sagemaker-sparkml-serving"
+        }
     }
 }
diff --git a/src/sagemaker/sparkml/model.py b/src/sagemaker/sparkml/model.py
index a2b91ff6a8..32f49da16e 100644
--- a/src/sagemaker/sparkml/model.py
+++ b/src/sagemaker/sparkml/model.py
@@ -59,7 +59,7 @@ class SparkMLModel(Model):
     model .
     """
 
-    def __init__(self, model_data, role=None, spark_version=2.2, sagemaker_session=None, **kwargs):
+    def __init__(self, model_data, role=None, spark_version=2.4, sagemaker_session=None, **kwargs):
         """Initialize a SparkMLModel.
 
         Args:
@@ -73,7 +73,7 @@ def __init__(self, model_data, role=None, spark_version=2.2, sagemaker_session=N
                 artifacts. After the endpoint is created, the inference code
                 might use the IAM role, if it needs to access an AWS resource.
             spark_version (str): Spark version you want to use for executing the
-                inference (default: '2.2').
+                inference (default: '2.4').
             sagemaker_session (sagemaker.session.Session): Session object which
                 manages interactions with Amazon SageMaker APIs and any other
                 AWS services needed. If not specified, the estimator creates one
diff --git a/tests/integ/test_sparkml_serving.py b/tests/integ/test_sparkml_serving.py
index 66bda1beb5..094281918b 100644
--- a/tests/integ/test_sparkml_serving.py
+++ b/tests/integ/test_sparkml_serving.py
@@ -17,6 +17,8 @@
 
 import pytest
 
+from botocore.errorfactory import ClientError
+
 from sagemaker.sparkml.model import SparkMLModel
 from sagemaker.utils import sagemaker_timestamp
 from tests.integ import DATA_DIR
@@ -24,12 +26,9 @@
 
 
 @pytest.mark.canary_quick
-@pytest.mark.skip(
-    reason="This test has always failed, but the failure was masked by a bug. "
-    "This test should be fixed. Details in https://github.com/aws/sagemaker-python-sdk/pull/968"
-)
 def test_sparkml_model_deploy(sagemaker_session, cpu_instance_type):
-    # Uploads an MLeap serialized MLeap model to S3 and use that to deploy a SparkML model to perform inference
+    # Uploads an MLeap-serialized model to S3 and uses it to deploy
+    # a SparkML model for inference
     data_path = os.path.join(DATA_DIR, "sparkml_model")
     endpoint_name = "test-sparkml-deploy-{}".format(sagemaker_timestamp())
     model_data = sagemaker_session.upload_data(
@@ -59,7 +58,8 @@ def test_sparkml_model_deploy(sagemaker_session, cpu_instance_type):
         predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
 
         valid_data = "1.0,C,38.0,71.5,1.0,female"
-        assert predictor.predict(valid_data) == "1.0,0.0,38.0,1.0,71.5,0.0,1.0"
+        assert predictor.predict(valid_data) == b"1.0,0.0,38.0,1.0,71.5,0.0,1.0"
 
         invalid_data = "1.0,28.0,C,38.0,71.5,1.0"
-        assert predictor.predict(invalid_data) is None
+        with pytest.raises(ClientError):
+            predictor.predict(invalid_data)
diff --git a/tests/unit/test_sparkml_serving.py b/tests/unit/test_sparkml_serving.py
index eb91022b72..4888639537 100644
--- a/tests/unit/test_sparkml_serving.py
+++ b/tests/unit/test_sparkml_serving.py
@@ -49,7 +49,7 @@ def sagemaker_session():
 
 def test_sparkml_model(sagemaker_session):
     sparkml = SparkMLModel(sagemaker_session=sagemaker_session, model_data=MODEL_DATA, role=ROLE)
-    assert sparkml.image_uri == image_uris.retrieve("sparkml-serving", REGION, version="2.2")
+    assert sparkml.image_uri == image_uris.retrieve("sparkml-serving", REGION, version="2.4")
 
 
 def test_predictor_type(sagemaker_session):
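For context, a minimal usage sketch of the behavior this change pins down (not part of the diff itself): the S3 path, role ARN, endpoint name, and instance type below are placeholders, while the call pattern, the bytes return value, and the ClientError raised for malformed input mirror the updated integration test.

from botocore.errorfactory import ClientError

from sagemaker.sparkml.model import SparkMLModel

# Placeholder values -- not taken from this PR.
model_data = "s3://my-bucket/sparkml/mleap_model.tar.gz"  # MLeap-serialized model archive
role = "arn:aws:iam::111122223333:role/SageMakerRole"     # IAM role usable by SageMaker

# spark_version now defaults to 2.4, so the sagemaker-sparkml-serving 2.4 image
# is selected without passing an explicit version.
model = SparkMLModel(model_data=model_data, role=role)
predictor = model.deploy(1, "ml.c4.xlarge", endpoint_name="my-sparkml-endpoint")

# The endpoint returns raw bytes, e.g. b"1.0,0.0,...", hence the b"" literal in the test.
response = predictor.predict("1.0,C,38.0,71.5,1.0,female")

# Malformed rows are rejected by the container and surface as a botocore
# ClientError rather than a None return value.
try:
    predictor.predict("1.0,28.0,C,38.0,71.5,1.0")
except ClientError as error:
    print("invalid input rejected:", error)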