diff --git a/awswrangler/__init__.py b/awswrangler/__init__.py index c0d62206f..e07a214da 100644 --- a/awswrangler/__init__.py +++ b/awswrangler/__init__.py @@ -17,12 +17,14 @@ class DynamicInstantiate: - __default_session = Session() + __default_session = None def __init__(self, service): self._service = service def __getattr__(self, name): + if DynamicInstantiate.__default_session is None: + DynamicInstantiate.__default_session = Session() return getattr(getattr(DynamicInstantiate.__default_session, self._service), name) diff --git a/awswrangler/exceptions.py b/awswrangler/exceptions.py index c940e2cb5..d66dd8ee0 100644 --- a/awswrangler/exceptions.py +++ b/awswrangler/exceptions.py @@ -92,3 +92,7 @@ class InvalidTable(Exception): class InvalidParameters(Exception): pass + + +class AWSCredentialsNotFound(Exception): + pass diff --git a/awswrangler/sagemaker.py b/awswrangler/sagemaker.py index 6654687c0..bb32ff63b 100644 --- a/awswrangler/sagemaker.py +++ b/awswrangler/sagemaker.py @@ -1,7 +1,9 @@ import pickle import tarfile import logging + from typing import Any +from awswrangler.exceptions import InvalidParameters logger = logging.getLogger(__name__) @@ -10,6 +12,7 @@ class SageMaker: def __init__(self, session): self._session = session self._client_s3 = session.boto3_session.client(service_name="s3", use_ssl=True, config=session.botocore_config) + self._client_sagemaker = session.boto3_session.client(service_name="sagemaker") @staticmethod def _parse_path(path): @@ -17,11 +20,21 @@ def _parse_path(path): parts = path2.partition("/") return parts[0], parts[2] - def get_job_outputs(self, path: str) -> Any: + def get_job_outputs(self, job_name: str = None, path: str = None) -> Any: + + if path and job_name: + raise InvalidParameters("Specify either path or job_name") + + if job_name: + path = self._client_sagemaker.describe_training_job(TrainingJobName=job_name)["ModelArtifacts"]["S3ModelArtifacts"] + + if not self._session.s3.does_object_exists(path): 
+ return None bucket, key = SageMaker._parse_path(path) if key.split("/")[-1] != "model.tar.gz": key = f"{key}/model.tar.gz" + body = self._client_s3.get_object(Bucket=bucket, Key=key)["Body"].read() body = tarfile.io.BytesIO(body) # type: ignore tar = tarfile.open(fileobj=body) diff --git a/awswrangler/session.py b/awswrangler/session.py index 359e5d48f..12dfef129 100644 --- a/awswrangler/session.py +++ b/awswrangler/session.py @@ -14,6 +14,7 @@ from awswrangler.redshift import Redshift from awswrangler.emr import EMR from awswrangler.sagemaker import SageMaker +from awswrangler.exceptions import AWSCredentialsNotFound PYSPARK_INSTALLED = False if importlib.util.find_spec("pyspark"): # type: ignore @@ -77,6 +78,7 @@ def __init__(self, :param athena_kms_key: For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID. :param redshift_temp_s3_path: redshift_temp_s3_path: AWS S3 path to write temporary data (e.g. s3://...) """ + self._profile_name: Optional[str] = (boto3_session.profile_name if boto3_session else profile_name) self._aws_access_key_id: Optional[str] = (boto3_session.get_credentials().access_key if boto3_session else aws_access_key_id) @@ -130,8 +132,11 @@ def _load_new_boto3_session(self): args["aws_secret_access_key"] = self.aws_secret_access_key self._boto3_session = boto3.Session(**args) self._profile_name = self._boto3_session.profile_name - self._aws_access_key_id = self._boto3_session.get_credentials().access_key - self._aws_secret_access_key = self._boto3_session.get_credentials().secret_key + credentials = self._boto3_session.get_credentials() + if credentials is None: + raise AWSCredentialsNotFound("Please run aws configure: https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html") + self._aws_access_key_id = credentials.access_key + self._aws_secret_access_key = credentials.secret_key self._region_name = self._boto3_session.region_name def _load_new_primitives(self): diff --git a/testing/test_awswrangler/test_sagemaker.py 
b/testing/test_awswrangler/test_sagemaker.py index 2c3033c11..568210306 100644 --- a/testing/test_awswrangler/test_sagemaker.py +++ b/testing/test_awswrangler/test_sagemaker.py @@ -38,9 +38,9 @@ def bucket(session, cloudformation_outputs): session.s3.delete_objects(path=f"s3://{bucket}/") -def test_get_job_outputs(session, bucket): - model_path = "output" - s3 = boto3.resource("s3") +@pytest.fixture(scope="module") +def model(bucket): + model_path = "output/model.tar.gz" lr = LinearRegression() with open("model.pkl", "wb") as fp: @@ -49,10 +49,23 @@ def test_get_job_outputs(session, bucket): with tarfile.open("model.tar.gz", "w:gz") as tar: tar.add("model.pkl") - s3.Bucket(bucket).upload_file("model.tar.gz", f"{model_path}/model.tar.gz") - outputs = session.sagemaker.get_job_outputs(f"{bucket}/{model_path}") + s3 = boto3.resource("s3") + s3.Bucket(bucket).upload_file("model.tar.gz", model_path) + + yield f"s3://{bucket}/{model_path}" os.remove("model.pkl") os.remove("model.tar.gz") + +def test_get_job_outputs_by_path(session, model): + outputs = session.sagemaker.get_job_outputs(path=model) assert type(outputs[0]) == LinearRegression + + +def test_get_job_outputs_by_job_id(session, bucket): + pass + + +def test_get_job_outputs_empty(session, bucket): + pass