4 changes: 3 additions & 1 deletion awswrangler/__init__.py
@@ -17,12 +17,14 @@

class DynamicInstantiate:

    __default_session = Session()
    __default_session = None

    def __init__(self, service):
        self._service = service

    def __getattr__(self, name):
        if DynamicInstantiate.__default_session is None:
            DynamicInstantiate.__default_session = Session()
        return getattr(getattr(DynamicInstantiate.__default_session, self._service), name)


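This change makes the module-level default Session lazy: it is now created on first attribute access rather than at import time, so importing awswrangler no longer requires AWS credentials to be resolvable. A minimal usage sketch, assuming the package exposes per-service proxies (e.g. awswrangler.s3) built from DynamicInstantiate; the bucket name is a placeholder:

import awswrangler  # import succeeds even with no credentials configured

# The default Session, and therefore the credential lookup, is only built on the
# first delegated call through a service proxy such as awswrangler.s3:
awswrangler.s3.delete_objects(path="s3://my-bucket/temp/")
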
4 changes: 4 additions & 0 deletions awswrangler/exceptions.py
@@ -92,3 +92,7 @@ class InvalidTable(Exception):

class InvalidParameters(Exception):
    pass


class AWSCredentialsNotFound(Exception):
    pass
15 changes: 14 additions & 1 deletion awswrangler/sagemaker.py
@@ -1,7 +1,9 @@
import pickle
import tarfile
import logging

from typing import Any
from awswrangler.exceptions import InvalidParameters

logger = logging.getLogger(__name__)

@@ -10,18 +12,29 @@ class SageMaker:
    def __init__(self, session):
        self._session = session
        self._client_s3 = session.boto3_session.client(service_name="s3", use_ssl=True, config=session.botocore_config)
        self._client_sagemaker = session.boto3_session.client(service_name="sagemaker")

    @staticmethod
    def _parse_path(path):
        path2 = path.replace("s3://", "")
        parts = path2.partition("/")
        return parts[0], parts[2]

    def get_job_outputs(self, path: str) -> Any:
    def get_job_outputs(self, job_name: str = None, path: str = None) -> Any:

        if path and job_name:
            raise InvalidParameters("Specify either path or job_name, not both")

        if job_name:
            path = self._client_sagemaker.describe_training_job(TrainingJobName=job_name)["ModelArtifacts"]["S3ModelArtifacts"]

        if not self._session.s3.does_object_exists(path):
            return None

        bucket, key = SageMaker._parse_path(path)
        if key.split("/")[-1] != "model.tar.gz":
            key = f"{key}/model.tar.gz"

        body = self._client_s3.get_object(Bucket=bucket, Key=key)["Body"].read()
        body = tarfile.io.BytesIO(body)  # type: ignore
        tar = tarfile.open(fileobj=body)
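get_job_outputs() now accepts either an S3 path or a SageMaker training job name; passing both raises InvalidParameters, and when a job name is given the artifact location is resolved from the job's ModelArtifacts. A hedged usage sketch of the two call styles (bucket and job names are placeholders; the return value is assumed to be the list of objects unpickled from model.tar.gz, as the tests below exercise with a pickled LinearRegression):

import awswrangler

session = awswrangler.Session()

# Load artifacts from an explicit S3 location of the model archive...
outputs = session.sagemaker.get_job_outputs(path="s3://my-bucket/output/model.tar.gz")

# ...or let the wrapper resolve the location from a completed training job.
outputs = session.sagemaker.get_job_outputs(job_name="my-training-job")
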
9 changes: 7 additions & 2 deletions awswrangler/session.py
@@ -14,6 +14,7 @@
from awswrangler.redshift import Redshift
from awswrangler.emr import EMR
from awswrangler.sagemaker import SageMaker
from awswrangler.exceptions import AWSCredentialsNotFound

PYSPARK_INSTALLED = False
if importlib.util.find_spec("pyspark"): # type: ignore
@@ -77,6 +78,7 @@ def __init__(self,
        :param athena_kms_key: For SSE-KMS and CSE-KMS, this is the KMS key ARN or ID.
        :param redshift_temp_s3_path: AWS S3 path to write temporary data (e.g. s3://...)
        """

        self._profile_name: Optional[str] = (boto3_session.profile_name if boto3_session else profile_name)
        self._aws_access_key_id: Optional[str] = (boto3_session.get_credentials().access_key
                                                  if boto3_session else aws_access_key_id)
@@ -130,8 +132,11 @@ def _load_new_boto3_session(self):
args["aws_secret_access_key"] = self.aws_secret_access_key
self._boto3_session = boto3.Session(**args)
self._profile_name = self._boto3_session.profile_name
self._aws_access_key_id = self._boto3_session.get_credentials().access_key
self._aws_secret_access_key = self._boto3_session.get_credentials().secret_key
credentials = self._boto3_session.get_credentials()
if credentials is None:
raise AWSCredentialsNotFound("Please run aws configure: https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html")
self._aws_access_key_id = credentials.access_key
self._aws_secret_access_key = credentials.secret_key
self._region_name = self._boto3_session.region_name

def _load_new_primitives(self):
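_load_new_boto3_session() now fetches the credentials object once and raises the new AWSCredentialsNotFound when boto3 cannot resolve any credentials, instead of failing later with an AttributeError on None. A hedged sketch of how calling code might surface that error, assuming the Session constructor triggers this loading step:

import awswrangler
from awswrangler.exceptions import AWSCredentialsNotFound

try:
    session = awswrangler.Session()
except AWSCredentialsNotFound:
    # boto3's provider chain (env vars, ~/.aws/credentials, instance profile, ...) found nothing.
    print("Run `aws configure` or export AWS credentials before using awswrangler.")
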
23 changes: 18 additions & 5 deletions testing/test_awswrangler/test_sagemaker.py
@@ -38,9 +38,9 @@ def bucket(session, cloudformation_outputs):
    session.s3.delete_objects(path=f"s3://{bucket}/")


def test_get_job_outputs(session, bucket):
    model_path = "output"
    s3 = boto3.resource("s3")
@pytest.fixture(scope="module")
def model(bucket):
    model_path = "output/model.tar.gz"

    lr = LinearRegression()
    with open("model.pkl", "wb") as fp:
@@ -49,10 +49,23 @@ def test_get_job_outputs(session, bucket):
    with tarfile.open("model.tar.gz", "w:gz") as tar:
        tar.add("model.pkl")

    s3.Bucket(bucket).upload_file("model.tar.gz", f"{model_path}/model.tar.gz")
    outputs = session.sagemaker.get_job_outputs(f"{bucket}/{model_path}")
    s3 = boto3.resource("s3")
    s3.Bucket(bucket).upload_file("model.tar.gz", model_path)

    yield f"s3://{bucket}/{model_path}"

    os.remove("model.pkl")
    os.remove("model.tar.gz")


def test_get_job_outputs_by_path(session, model):
    outputs = session.sagemaker.get_job_outputs(path=model)
    assert type(outputs[0]) == LinearRegression


def test_get_job_outputs_by_job_id(session, bucket):
    pass


def test_get_job_outputs_empty(session, bucket):
    pass
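
The last two tests are committed as placeholders; exercising the job_name branch would require a completed SageMaker training job. A hedged sketch of how the empty case could be covered, based on the does_object_exists guard added to get_job_outputs (the test name and key are illustrative, not part of the PR):

def test_get_job_outputs_missing_object(session, bucket):
    # No object exists at this key, so get_job_outputs() is expected to return None.
    outputs = session.sagemaker.get_job_outputs(path=f"s3://{bucket}/does-not-exist/model.tar.gz")
    assert outputs is None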