diff --git a/README.md b/README.md index c20ae18ec..5ffde81a0 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,7 @@ * Get EMR step state * Athena query to receive the result as python primitives (*Iterable[Dict[str, Any]*) * Load and Unzip SageMaker jobs outputs +* Load and Unzip SageMaker models * Redshift -> Parquet (S3) * Aurora -> CSV (S3) (MySQL) (NEW :star:) @@ -417,6 +418,14 @@ for row in wr.athena.query(query="...", database="..."): ```py3 import awswrangler as wr +outputs = wr.sagemaker.get_model("JOB_NAME") +``` + +#### Load and unzip SageMaker job output + +```py3 +import awswrangler as wr + outputs = wr.sagemaker.get_job_outputs("JOB_NAME") ``` diff --git a/awswrangler/exceptions.py b/awswrangler/exceptions.py index a9bf91d4f..af679e71e 100644 --- a/awswrangler/exceptions.py +++ b/awswrangler/exceptions.py @@ -104,3 +104,7 @@ class AWSCredentialsNotFound(Exception): class InvalidEngine(Exception): pass + + +class InvalidSagemakerOutput(Exception): + pass diff --git a/awswrangler/sagemaker.py b/awswrangler/sagemaker.py index 9d8d4b07f..86c385113 100644 --- a/awswrangler/sagemaker.py +++ b/awswrangler/sagemaker.py @@ -1,9 +1,9 @@ -from typing import Any +from typing import Any, Dict import pickle import tarfile import logging -from awswrangler.exceptions import InvalidParameters +from awswrangler.exceptions import InvalidParameters, InvalidSagemakerOutput logger = logging.getLogger(__name__) @@ -22,34 +22,68 @@ def _parse_path(path): parts = path2.partition("/") return parts[0], parts[2] - def get_job_outputs(self, job_name: str = None, path: str = None) -> Any: + def get_job_outputs(self, job_name: str = None, path: str = None) -> Dict[str, Any]: + """ + Extract and deserialize all Sagemaker's outputs (everything inside model.tar.gz) + + :param job_name: Sagemaker's job name + :param path: S3 path (model.tar.gz path) + :return: A Dictionary with all filenames (key) and all objects (values) + """ if path and job_name: - raise InvalidParameters("Specify 
either path, job_arn or job_name")
+            raise InvalidParameters("Specify either path or job_name")
         if job_name:
             path = self._client_sagemaker.describe_training_job(
                 TrainingJobName=job_name)["ModelArtifacts"]["S3ModelArtifacts"]
 
-        if not self._session.s3.does_object_exists(path):
-            return None
+        if path is not None:
+            if path.split("/")[-1] != "model.tar.gz":
+                path = f"{path}/model.tar.gz"
 
-        bucket, key = SageMaker._parse_path(path)
-        if key.split("/")[-1] != "model.tar.gz":
-            key = f"{key}/model.tar.gz"
+        if self._session.s3.does_object_exists(path) is False:
+            raise InvalidSagemakerOutput(f"Path does not exists ({path})")
 
+        bucket: str
+        key: str
+        bucket, key = SageMaker._parse_path(path)
         body = self._client_s3.get_object(Bucket=bucket, Key=key)["Body"].read()
         body = tarfile.io.BytesIO(body)  # type: ignore
         tar = tarfile.open(fileobj=body)
 
-        results = []
-        for member in tar.getmembers():
+        members = tar.getmembers()
+        if len(members) < 1:
+            raise InvalidSagemakerOutput(f"No artifacts found in {path}")
+
+        results: Dict[str, Any] = {}
+        for member in members:
+            logger.debug(f"member: {member.name}")
             f = tar.extractfile(member)
-            file_type = member.name.split(".")[-1]
+            file_type: str = member.name.split(".")[-1]
             if (file_type == "pkl") and (f is not None):
                 f = pickle.load(f)
-            results.append(f)
+            results[member.name] = f
 
         return results
+
+    def get_model(self, job_name: str = None, path: str = None, model_name: str = None) -> Any:
+        """
+        Extract and deserialize a Sagemaker's output model (.tar.gz)
+
+        :param job_name: Sagemaker's job name
+        :param path: S3 path (model.tar.gz path)
+        :param model_name: model filename inside the tarball (e.g.: "model.pkl")
+        :return: The deserialized model artifact
+        """
+        outputs: Dict[str, Any] = self.get_job_outputs(job_name=job_name, path=path)
+        outputs_len: int = len(outputs)
+        if model_name in outputs:
+            return outputs[model_name]
+        elif outputs_len > 1:
+            raise InvalidSagemakerOutput(
+                f"Number of artifacts found: {outputs_len}. 
Please, specify a model_name or use the Sagemaker.get_job_outputs() method." + ) + return list(outputs.values())[0] diff --git a/docs/source/examples.rst b/docs/source/examples.rst index 04812dcf4..290840c7a 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -370,6 +370,15 @@ Athena query to receive the result as python primitives (Iterable[Dict[str, Any] for row in wr.athena.query(query="...", database="..."): print(row) +Load and unzip SageMaker model +`````````````````````````````` + +.. code-block:: python + + import awswrangler as wr + + outputs = wr.sagemaker.get_model("JOB_NAME") + Load and unzip SageMaker job output ``````````````````````````````````` diff --git a/docs/source/index.rst b/docs/source/index.rst index 86b3af6f7..917d3117b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -54,6 +54,7 @@ General * Get EMR step state * Athena query to receive the result as python primitives (*Iterable[Dict[str, Any]*) * Load and Unzip SageMaker jobs outputs +* Load and Unzip SageMaker models * Redshift -> Parquet (S3) * Aurora -> CSV (S3) (MySQL) (NEW :star:) diff --git a/testing/test_awswrangler/test_sagemaker.py b/testing/test_awswrangler/test_sagemaker.py index 5a6bc210e..a1af2adcb 100644 --- a/testing/test_awswrangler/test_sagemaker.py +++ b/testing/test_awswrangler/test_sagemaker.py @@ -6,7 +6,9 @@ import boto3 import pytest +import awswrangler as wr from awswrangler import Session +from awswrangler.exceptions import InvalidSagemakerOutput from sklearn.linear_model import LinearRegression logging.basicConfig(level=logging.INFO, format="[%(asctime)s][%(levelname)s][%(name)s][%(funcName)s] %(message)s") @@ -54,18 +56,89 @@ def model(bucket): yield f"s3://{bucket}/{model_path}" - os.remove("model.pkl") - os.remove("model.tar.gz") + try: + os.remove("model.pkl") + except OSError: + pass + try: + os.remove("model.tar.gz") + except OSError: + pass + + +@pytest.fixture(scope="module") +def model_empty(bucket): + model_path = 
"output_empty/model.tar.gz" + + with tarfile.open("model.tar.gz", "w:gz") as tar: + pass + + s3 = boto3.resource("s3") + s3.Bucket(bucket).upload_file("model.tar.gz", model_path) + + yield f"s3://{bucket}/{model_path}" + + try: + os.remove("model.tar.gz") + except OSError: + pass + + +@pytest.fixture(scope="module") +def model_double(bucket): + model_path = "output_double/model.tar.gz" + + lr = LinearRegression() + with open("model.pkl", "wb") as fp: + pickle.dump(lr, fp, pickle.HIGHEST_PROTOCOL) + + with open("model2.pkl", "wb") as fp: + pickle.dump(lr, fp, pickle.HIGHEST_PROTOCOL) + + with tarfile.open("model.tar.gz", "w:gz") as tar: + tar.add("model.pkl") + tar.add("model2.pkl") + + s3 = boto3.resource("s3") + s3.Bucket(bucket).upload_file("model.tar.gz", model_path) + + yield f"s3://{bucket}/{model_path}" + + try: + os.remove("model.pkl") + except OSError: + pass + try: + os.remove("model2.pkl") + except OSError: + pass + try: + os.remove("model.tar.gz") + except OSError: + pass def test_get_job_outputs_by_path(session, model): outputs = session.sagemaker.get_job_outputs(path=model) - assert type(outputs[0]) == LinearRegression + assert type(list(outputs.values())[0]) == LinearRegression def test_get_job_outputs_by_job_id(session, bucket): pass -def test_get_job_outputs_empty(session, bucket): - pass +def test_get_model_empty(model_empty): + with pytest.raises(InvalidSagemakerOutput): + wr.sagemaker.get_model(path=model_empty) + + +def test_get_model_double(session, model_double): + with pytest.raises(InvalidSagemakerOutput): + wr.sagemaker.get_model(path=model_double) + model = session.sagemaker.get_model(path=model_double, model_name="model.pkl") + assert type(model) == LinearRegression + + +def test_get_model_by_path(session, model): + model = session.sagemaker.get_model(path=model) + assert type(model) == LinearRegression