9 changes: 9 additions & 0 deletions README.md
@@ -61,6 +61,7 @@
* Get EMR step state
* Athena query to receive the result as python primitives (*Iterable[Dict[str, Any]]*)
* Load and Unzip SageMaker job outputs
* Load and Unzip SageMaker models
* Redshift -> Parquet (S3)
* Aurora -> CSV (S3) (MySQL) (NEW :star:)

@@ -417,6 +418,14 @@ for row in wr.athena.query(query="...", database="..."):

#### Load and unzip SageMaker model

```py3
import awswrangler as wr

model = wr.sagemaker.get_model("JOB_NAME")
```

#### Load and unzip SageMaker job output

```py3
import awswrangler as wr

outputs = wr.sagemaker.get_job_outputs("JOB_NAME")
```
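
Since `get_job_outputs` returns a `Dict[str, Any]` keyed by the filenames inside `model.tar.gz` (with `.pkl` members already unpickled), a minimal sketch of consuming the result, assuming a completed training job named `JOB_NAME`:

```py3
import awswrangler as wr

# Keys are member filenames from model.tar.gz; .pkl values arrive unpickled
outputs = wr.sagemaker.get_job_outputs("JOB_NAME")
for filename, artifact in outputs.items():
    print(filename, type(artifact))
```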

4 changes: 4 additions & 0 deletions awswrangler/exceptions.py
@@ -104,3 +104,7 @@ class AWSCredentialsNotFound(Exception):

class InvalidEngine(Exception):
pass


class InvalidSagemakerOutput(Exception):
pass
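
`InvalidSagemakerOutput` is raised when the `model.tar.gz` is missing, empty, or holds more than one artifact while no `model_name` was given. A hedged sketch of handling the ambiguous case; `"JOB_NAME"` and `"model.pkl"` are placeholder values:

```py3
import awswrangler as wr
from awswrangler.exceptions import InvalidSagemakerOutput

try:
    model = wr.sagemaker.get_model("JOB_NAME")
except InvalidSagemakerOutput:
    # Tarball holds several artifacts: select one explicitly by filename
    model = wr.sagemaker.get_model("JOB_NAME", model_name="model.pkl")
```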
60 changes: 47 additions & 13 deletions awswrangler/sagemaker.py
@@ -1,9 +1,9 @@
from typing import Any
from typing import Any, Dict
import pickle
import tarfile
import logging

from awswrangler.exceptions import InvalidParameters
from awswrangler.exceptions import InvalidParameters, InvalidSagemakerOutput

logger = logging.getLogger(__name__)

@@ -22,34 +22,68 @@ def _parse_path(path):
parts = path2.partition("/")
return parts[0], parts[2]

def get_job_outputs(self, job_name: str = None, path: str = None) -> Any:
def get_job_outputs(self, job_name: str = None, path: str = None) -> Dict[str, Any]:
"""
Extract and deserialize all of a SageMaker job's outputs (everything inside model.tar.gz)

:param job_name: SageMaker training job name
:param path: S3 path (model.tar.gz path)
:return: A dictionary mapping filenames (keys) to deserialized objects (values)
"""

if path and job_name:
raise InvalidParameters("Specify either path, job_arn or job_name")
raise InvalidParameters("Specify either path or job_name")

if job_name:
path = self._client_sagemaker.describe_training_job(
TrainingJobName=job_name)["ModelArtifacts"]["S3ModelArtifacts"]

if not self._session.s3.does_object_exists(path):
return None
if path is not None:
if path.split("/")[-1] != "model.tar.gz":
path = f"{path}/model.tar.gz"

bucket, key = SageMaker._parse_path(path)
if key.split("/")[-1] != "model.tar.gz":
key = f"{key}/model.tar.gz"
if self._session.s3.does_object_exists(path) is False:
raise InvalidSagemakerOutput(f"Path does not exists ({path})")

bucket: str
key: str
bucket, key = SageMaker._parse_path(path)
body = self._client_s3.get_object(Bucket=bucket, Key=key)["Body"].read()
body = tarfile.io.BytesIO(body) # type: ignore
tar = tarfile.open(fileobj=body)

results = []
for member in tar.getmembers():
members = tar.getmembers()
if len(members) < 1:
raise InvalidSagemakerOutput(f"No artifacts found in {path}")

results: Dict[str, Any] = {}
for member in members:
logger.debug(f"member: {member.name}")
f = tar.extractfile(member)
file_type = member.name.split(".")[-1]
file_type: str = member.name.split(".")[-1]

if (file_type == "pkl") and (f is not None):
f = pickle.load(f)

results.append(f)
results[member.name] = f

return results

def get_model(self, job_name: str = None, path: str = None, model_name: str = None) -> Any:
"""
Extract and deserialize a SageMaker output model (model.tar.gz)

:param job_name: SageMaker training job name
:param path: S3 path (model.tar.gz path)
:param model_name: Name of the model file inside the tarball (e.g. model.pkl)
:return: The deserialized model object
"""
outputs: Dict[str, Any] = self.get_job_outputs(job_name=job_name, path=path)
outputs_len: int = len(outputs)
if model_name in outputs:
return outputs[model_name]
elif outputs_len > 1:
raise InvalidSagemakerOutput(
f"Number of artifacts found: {outputs_len}. Please, specify a model_name or use the Sagemaker.get_job_outputs() method."
)
return list(outputs.values())[0]
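
For reference, a standalone sketch of the in-memory pattern `get_job_outputs` implements above: download `model.tar.gz`, open it from a `BytesIO` buffer, and unpickle any `.pkl` members. The bucket and key are placeholder values:

```py3
import io
import pickle
import tarfile

import boto3

s3 = boto3.client("s3")
# Placeholder bucket/key pointing at an existing SageMaker model.tar.gz
body = s3.get_object(Bucket="BUCKET", Key="output/model.tar.gz")["Body"].read()

results = {}
with tarfile.open(fileobj=io.BytesIO(body)) as tar:  # mode "r" auto-detects gzip
    for member in tar.getmembers():
        f = tar.extractfile(member)
        if member.name.endswith(".pkl") and f is not None:
            f = pickle.load(f)  # deserialize pickled artifacts
        results[member.name] = f
```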
9 changes: 9 additions & 0 deletions docs/source/examples.rst
@@ -370,6 +370,15 @@ Athena query to receive the result as python primitives (Iterable[Dict[str, Any]])
for row in wr.athena.query(query="...", database="..."):
print(row)

Load and unzip SageMaker model
``````````````````````````````

.. code-block:: python

import awswrangler as wr

model = wr.sagemaker.get_model("JOB_NAME")

Load and unzip SageMaker job output
```````````````````````````````````

1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -54,6 +54,7 @@ General
* Get EMR step state
* Athena query to receive the result as python primitives (*Iterable[Dict[str, Any]]*)
* Load and Unzip SageMaker job outputs
* Load and Unzip SageMaker models
* Redshift -> Parquet (S3)
* Aurora -> CSV (S3) (MySQL) (NEW :star:)

83 changes: 78 additions & 5 deletions testing/test_awswrangler/test_sagemaker.py
@@ -6,7 +6,9 @@
import boto3
import pytest

import awswrangler as wr
from awswrangler import Session
from awswrangler.exceptions import InvalidSagemakerOutput
from sklearn.linear_model import LinearRegression

logging.basicConfig(level=logging.INFO, format="[%(asctime)s][%(levelname)s][%(name)s][%(funcName)s] %(message)s")
@@ -54,18 +56,89 @@ def model(bucket):

yield f"s3://{bucket}/{model_path}"

os.remove("model.pkl")
os.remove("model.tar.gz")
try:
os.remove("model.pkl")
except OSError:
pass
try:
os.remove("model.tar.gz")
except OSError:
pass


@pytest.fixture(scope="module")
def model_empty(bucket):
model_path = "output_empty/model.tar.gz"

with tarfile.open("model.tar.gz", "w:gz") as tar:
pass

s3 = boto3.resource("s3")
s3.Bucket(bucket).upload_file("model.tar.gz", model_path)

yield f"s3://{bucket}/{model_path}"

try:
os.remove("model.tar.gz")
except OSError:
pass


@pytest.fixture(scope="module")
def model_double(bucket):
model_path = "output_double/model.tar.gz"

lr = LinearRegression()
with open("model.pkl", "wb") as fp:
pickle.dump(lr, fp, pickle.HIGHEST_PROTOCOL)

with open("model2.pkl", "wb") as fp:
pickle.dump(lr, fp, pickle.HIGHEST_PROTOCOL)

with tarfile.open("model.tar.gz", "w:gz") as tar:
tar.add("model.pkl")
tar.add("model2.pkl")

s3 = boto3.resource("s3")
s3.Bucket(bucket).upload_file("model.tar.gz", model_path)

yield f"s3://{bucket}/{model_path}"

try:
os.remove("model.pkl")
except OSError:
pass
try:
os.remove("model2.pkl")
except OSError:
pass
try:
os.remove("model.tar.gz")
except OSError:
pass


def test_get_job_outputs_by_path(session, model):
outputs = session.sagemaker.get_job_outputs(path=model)
assert type(outputs[0]) == LinearRegression
assert type(list(outputs.values())[0]) == LinearRegression


def test_get_job_outputs_by_job_id(session, bucket):
pass


def test_get_job_outputs_empty(session, bucket):
pass
def test_get_model_empty(model_empty):
with pytest.raises(InvalidSagemakerOutput):
wr.sagemaker.get_model(path=model_empty)


def test_get_model_double(session, model_double):
with pytest.raises(InvalidSagemakerOutput):
wr.sagemaker.get_model(path=model_double)
model = session.sagemaker.get_model(path=model_double, model_name="model.pkl")
assert type(model) == LinearRegression


def test_get_model_by_path(session, model):
model = session.sagemaker.get_model(path=model)
assert type(model) == LinearRegression