From b07f502e3c8c272d67103199d053a4e35826d369 Mon Sep 17 00:00:00 2001 From: Luigi Tedesco Date: Thu, 19 Dec 2019 14:34:21 -0600 Subject: [PATCH 1/2] sagemaker module --- awswrangler/sagemaker.py | 39 +++++++++++++++ awswrangler/session.py | 9 ++++ requirements-dev.txt | 4 +- testing/test_awswrangler/test_sagemaker.py | 58 ++++++++++++++++++++++ 4 files changed, 109 insertions(+), 1 deletion(-) create mode 100644 awswrangler/sagemaker.py create mode 100644 testing/test_awswrangler/test_sagemaker.py diff --git a/awswrangler/sagemaker.py b/awswrangler/sagemaker.py new file mode 100644 index 000000000..a231a6017 --- /dev/null +++ b/awswrangler/sagemaker.py @@ -0,0 +1,39 @@ +import pickle +import tarfile +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +class SageMaker: + def __init__(self, session): + self._session = session + self._client_s3 = session.boto3_session.client(service_name="s3", use_ssl=True, config=session.botocore_config) + + @staticmethod + def _parse_path(path): + path2 = path.replace("s3://", "") + parts = path2.partition("/") + return parts[0], parts[2] + + def get_job_outputs(self, path: str) -> Any: + + bucket, key = SageMaker._parse_path(path) + if key.split("/")[-1] != "model.tar.gz": + key = f"{key}/model.tar.gz" + body = self._client_s3.get_object(Bucket=bucket, Key=key)["Body"].read() + body = tarfile.io.BytesIO(body) + tar = tarfile.open(fileobj=body) + + results = [] + for member in tar.getmembers(): + f = tar.extractfile(member) + file_type = member.name.split(".")[-1] + + if file_type == "pkl": + f = pickle.load(f) + + results.append(f) + + return results diff --git a/awswrangler/session.py b/awswrangler/session.py index 16d1536f3..4fac5954c 100644 --- a/awswrangler/session.py +++ b/awswrangler/session.py @@ -13,6 +13,8 @@ from awswrangler.glue import Glue from awswrangler.redshift import Redshift from awswrangler.emr import EMR +from awswrangler.sagemaker import SageMaker + PYSPARK_INSTALLED = False if importlib.util.find_spec("pyspark"): # type: ignore @@ -112,6 +114,7 @@ def __init__(self, self._glue = None self._redshift = None self._spark = None + self._sagemaker = None def _load_new_boto3_session(self): """ @@ -281,6 +284,12 @@ def redshift(self): self._redshift = Redshift(session=self) return self._redshift + @property + def sagemaker(self): + if not self._sagemaker: + self._sagemaker = SageMaker(session=self) + return self._sagemaker + @property def spark(self): if not PYSPARK_INSTALLED: diff --git a/requirements-dev.txt b/requirements-dev.txt index fe387bbf9..5157fac0b 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,8 +3,10 @@ mypy~=0.750 flake8~=3.7.9 pytest-cov~=2.8.1 cfn-lint~=0.26.0 +scikit-learn==0.22 +sklearn==0.0 twine~=3.1.1 wheel~=0.33.6 sphinx~=2.2.2 pyspark~=2.4.4 -pyspark-stubs~=2.4.0.post6 \ No newline at end of file +pyspark-stubs~=2.4.0.post6 diff --git a/testing/test_awswrangler/test_sagemaker.py b/testing/test_awswrangler/test_sagemaker.py new file mode 100644 index 000000000..2c3033c11 --- /dev/null +++ b/testing/test_awswrangler/test_sagemaker.py @@ -0,0 +1,58 @@ +import os +import pickle +import logging +import tarfile + +import boto3 +import pytest + +from awswrangler import Session +from sklearn.linear_model import LinearRegression + +logging.basicConfig(level=logging.INFO, format="[%(asctime)s][%(levelname)s][%(name)s][%(funcName)s] %(message)s") +logging.getLogger("awswrangler").setLevel(logging.DEBUG) + + +@pytest.fixture(scope="module") +def session(): + yield Session() + + 
+@pytest.fixture(scope="module") +def cloudformation_outputs(): + response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test-arena") + outputs = {} + for output in response.get("Stacks")[0].get("Outputs"): + outputs[output.get("OutputKey")] = output.get("OutputValue") + yield outputs + + +@pytest.fixture(scope="module") +def bucket(session, cloudformation_outputs): + if "BucketName" in cloudformation_outputs: + bucket = cloudformation_outputs["BucketName"] + session.s3.delete_objects(path=f"s3://{bucket}/") + else: + raise Exception("You must deploy the test infrastructure using Cloudformation!") + yield bucket + session.s3.delete_objects(path=f"s3://{bucket}/") + + +def test_get_job_outputs(session, bucket): + model_path = "output" + s3 = boto3.resource("s3") + + lr = LinearRegression() + with open("model.pkl", "wb") as fp: + pickle.dump(lr, fp, pickle.HIGHEST_PROTOCOL) + + with tarfile.open("model.tar.gz", "w:gz") as tar: + tar.add("model.pkl") + + s3.Bucket(bucket).upload_file("model.tar.gz", f"{model_path}/model.tar.gz") + outputs = session.sagemaker.get_job_outputs(f"{bucket}/{model_path}") + + os.remove("model.pkl") + os.remove("model.tar.gz") + + assert type(outputs[0]) == LinearRegression From 84763fbe4c888714265b827485fbc76fd6177586 Mon Sep 17 00:00:00 2001 From: Luigi Tedesco Date: Thu, 19 Dec 2019 14:37:46 -0600 Subject: [PATCH 2/2] default session --- README.md | 126 ++++++++++++++++++++++++---------------- awswrangler/__init__.py | 29 +++++++++ 2 files changed, 104 insertions(+), 51 deletions(-) diff --git a/README.md b/README.md index 0ba09638a..44544fa90 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,7 @@ ## Use Cases ### Pandas + * Pandas -> Parquet (S3) (Parallel) * Pandas -> CSV (S3) (Parallel) * Pandas -> Glue Catalog Table @@ -38,11 +39,13 @@ * Encrypt Pandas Dataframes on S3 with KMS keys ### PySpark + * PySpark -> Redshift (Parallel) * Register Glue table from Dataframe stored on S3 * Flatten nested DataFrames ### General + * List S3 objects (Parallel) * Delete S3 objects (Parallel) * Delete listed S3 objects (Parallel) @@ -78,8 +81,9 @@ Runs anywhere (AWS Lambda, AWS Glue Python Shell, EMR, EC2, on-premises, local, #### Writing Pandas Dataframe to S3 + Glue Catalog ```py3 -wrangler = awswrangler.Session() -wrangler.pandas.to_parquet( +import awswrangler as wr + +wr.pandas.to_parquet( dataframe=dataframe, database="database", path="s3://...", @@ -92,12 +96,14 @@ If a Glue Database name is passed, all the metadata will be created in the Glue #### Writing Pandas Dataframe to S3 as Parquet encrypting with a KMS key ```py3 +import awswrangler as wr + extra_args = { "ServerSideEncryption": "aws:kms", "SSEKMSKeyId": "YOUR_KMY_KEY_ARN" } -wrangler = awswrangler.Session(s3_additional_kwargs=extra_args) -wrangler.pandas.to_parquet( +sess = wr.Session(s3_additional_kwargs=extra_args) +sess.pandas.to_parquet( path="s3://..." 
) ``` @@ -105,8 +111,9 @@ wrangler.pandas.to_parquet( #### Reading from AWS Athena to Pandas ```py3 -wrangler = awswrangler.Session() -dataframe = wrangler.pandas.read_sql_athena( +import awswrangler as wr + +dataframe = wr.pandas.read_sql_athena( sql="select * from table", database="database" ) @@ -115,21 +122,25 @@ dataframe = wrangler.pandas.read_sql_athena( #### Reading from AWS Athena to Pandas in chunks (For memory restrictions) ```py3 -wrangler = awswrangler.Session() -dataframe_iter = wrangler.pandas.read_sql_athena( +import awswrangler as wr + +df_iter = wr.pandas.read_sql_athena( sql="select * from table", database="database", max_result_size=512_000_000 # 512 MB ) -for dataframe in dataframe_iter: - print(dataframe) # Do whatever you want + +for df in df_iter: + print(df) # Do whatever you want ``` #### Reading from AWS Athena to Pandas with the blazing fast CTAS approach ```py3 -wrangler = awswrangler.Session(athena_ctas_approach=True) -dataframe = wrangler.pandas.read_sql_athena( +import awswrangler as wr + +sess = wr.Session(athena_ctas_approach=True) +dataframe = sess.pandas.read_sql_athena( sql="select * from table", database="database" ) @@ -138,27 +149,31 @@ dataframe = wrangler.pandas.read_sql_athena( #### Reading from S3 (CSV) to Pandas ```py3 -wrangler = awswrangler.Session() -dataframe = wrangler.pandas.read_csv(path="s3://...") +import awswrangler as wr + +dataframe = wr.pandas.read_csv(path="s3://...") ``` #### Reading from S3 (CSV) to Pandas in chunks (For memory restrictions) ```py3 -wrangler = awswrangler.Session() -dataframe_iter = wrangler.pandas.read_csv( +import awswrangler as wr + +df_iter = wr.pandas.read_csv( path="s3://...", max_result_size=512_000_000 # 512 MB ) -for dataframe in dataframe_iter: - print(dataframe) # Do whatever you want + +for df in df_iter: + print(df) # Do whatever you want ``` #### Reading from CloudWatch Logs Insights to Pandas ```py3 -wrangler = awswrangler.Session() -dataframe = wrangler.pandas.read_log_query( +import awswrangler as wr + +dataframe = wr.pandas.read_log_query( log_group_names=[LOG_GROUP_NAME], query="fields @timestamp, @message | sort @timestamp desc | limit 5", ) @@ -168,14 +183,13 @@ dataframe = wrangler.pandas.read_log_query( ```py3 import pandas -import awswrangler +import awswrangler as wr df = pandas.read_... # Read from anywhere # Typical Pandas, Numpy or Pyarrow transformation HERE! 
-wrangler = awswrangler.Session() -wrangler.pandas.to_parquet( # Storing the data and metadata to Data Lake +wr.pandas.to_parquet( # Storing the data and metadata to Data Lake dataframe=dataframe, database="database", path="s3://...", @@ -186,8 +200,9 @@ wrangler.pandas.to_parquet( # Storing the data and metadata to Data Lake #### Loading Pandas Dataframe to Redshift ```py3 -wrangler = awswrangler.Session() -wrangler.pandas.to_redshift( +import awswrangler as wr + +wr.pandas.to_redshift( dataframe=dataframe, path="s3://temp_path", schema="...", @@ -202,8 +217,9 @@ wrangler.pandas.to_redshift( #### Extract Redshift query to Pandas DataFrame ```py3 -wrangler = awswrangler.Session() -dataframe = session.pandas.read_sql_redshift( +import awswrangler as wr + +dataframe = wr.pandas.read_sql_redshift( sql="SELECT ...", iam_role="YOUR_ROLE_ARN", connection=con, @@ -215,8 +231,9 @@ dataframe = session.pandas.read_sql_redshift( #### Loading PySpark Dataframe to Redshift ```py3 -wrangler = awswrangler.Session(spark_session=spark) -wrangler.spark.to_redshift( +import awswrangler as wr + +wr.spark.to_redshift( dataframe=df, path="s3://...", connection=conn, @@ -230,13 +247,15 @@ wrangler.spark.to_redshift( #### Register Glue table from Dataframe stored on S3 ```py3 +import awswrangler as wr + dataframe.write \ .mode("overwrite") \ .format("parquet") \ .partitionBy(["year", "month"]) \ .save(compression="gzip", path="s3://...") -wrangler = awswrangler.Session(spark_session=spark) -wrangler.spark.create_glue_table( +sess = wr.Session(spark_session=spark) +sess.spark.create_glue_table( dataframe=dataframe, file_format="parquet", partition_by=["year", "month"], @@ -248,8 +267,9 @@ wrangler.spark.create_glue_table( #### Flatten nested PySpark DataFrame ```py3 -wrangler = awswrangler.Session(spark_session=spark) -dfs = wrangler.spark.flatten(dataframe=df_nested) +import awswrangler as wr +sess = awswrangler.Session(spark_session=spark) +dfs = sess.spark.flatten(dataframe=df_nested) for name, df_flat in dfs.items(): print(name) df_flat.show() @@ -260,15 +280,17 @@ for name, df_flat in dfs.items(): #### Deleting a bunch of S3 objects (parallel) ```py3 -wrangler = awswrangler.Session() -wrangler.s3.delete_objects(path="s3://...") +import awswrangler as wr + +wr.s3.delete_objects(path="s3://...") ``` #### Get CloudWatch Logs Insights query results ```py3 -wrangler = awswrangler.Session() -results = wrangler.cloudwatchlogs.query( +import awswrangler as wr + +results = wr.cloudwatchlogs.query( log_group_names=[LOG_GROUP_NAME], query="fields @timestamp, @message | sort @timestamp desc | limit 5", ) @@ -277,15 +299,17 @@ results = wrangler.cloudwatchlogs.query( #### Load partitions on Athena/Glue table (repair table) ```py3 -wrangler = awswrangler.Session() -wrangler.athena.repair_table(database="db_name", table="tbl_name") +import awswrangler as wr + +wr.athena.repair_table(database="db_name", table="tbl_name") ``` #### Create EMR cluster ```py3 -wrangler = awswrangler.Session() -cluster_id = wrangler.emr.create_cluster( +import awswrangler as wr + +cluster_id = wr.emr.create_cluster( cluster_name="wrangler_cluster", logging_s3_path=f"s3://BUCKET_NAME/emr-logs/", emr_release="emr-5.27.0", @@ -337,28 +361,28 @@ print(cluster_id) #### Athena query to receive the result as python primitives (*Iterable[Dict[str, Any]*) ```py3 -wrangler = awswrangler.Session() -for row in wrangler.athena.query(query="...", database="..."): +import awswrangler as wr + +for row in wr.athena.query(query="...", database="..."): print(row) 
``` ## Diving Deep - ### Parallelism, Non-picklable objects and GeoPandas AWS Data Wrangler tries to parallelize everything that is possible (I/O and CPU bound task). You can control the parallelism level using the parameters: -- **procs_cpu_bound**: number of processes that can be used in single node applications for CPU bound case (Default: os.cpu_count()) -- **procs_io_bound**: number of processes that can be used in single node applications for I/O bound cases (Default: os.cpu_count() * PROCS_IO_BOUND_FACTOR) +* **procs_cpu_bound**: number of processes that can be used in single node applications for CPU bound case (Default: os.cpu_count()) +* **procs_io_bound**: number of processes that can be used in single node applications for I/O bound cases (Default: os.cpu_count() * PROCS_IO_BOUND_FACTOR) Both can be defined on Session level or directly in the functions. Some special cases will not work with parallelism: -- GeoPandas -- Columns with non-picklable objects +* GeoPandas +* Columns with non-picklable objects To handle that use `procs_cpu_bound=1` and avoid the distribution of the dataframe. @@ -370,7 +394,7 @@ We can handle this object column fine inferring the types of theses objects insi To work with null object columns you can explicitly set the expected Athena data type for the target table doing: ```py3 -import awswrangler +import awswrangler as wr import pandas as pd dataframe = pd.DataFrame({ @@ -378,8 +402,8 @@ dataframe = pd.DataFrame({ "col_string_null": [None, None], "col_date_null": [None, None], }) -session = awswrangler.Session() -session.pandas.to_parquet( + +wr.pandas.to_parquet( dataframe=dataframe, database="DATABASE", path=f"s3://...", diff --git a/awswrangler/__init__.py b/awswrangler/__init__.py index 246c56cc1..eb4a82f00 100644 --- a/awswrangler/__init__.py +++ b/awswrangler/__init__.py @@ -10,10 +10,39 @@ from awswrangler.glue import Glue # noqa from awswrangler.redshift import Redshift # noqa from awswrangler.emr import EMR # noqa +from awswrangler.sagemaker import SageMaker # noqa import awswrangler.utils # noqa import awswrangler.data_types # noqa + +class DynamicInstantiate: + + __default_session = Session() + + def __init__(self, service): + self._service = service + + def __getattr__(self, name): + return getattr( + getattr( + DynamicInstantiate.__default_session, + self._service + ), + name + ) + + if importlib.util.find_spec("pyspark"): # type: ignore from awswrangler.spark import Spark # noqa +s3 = DynamicInstantiate("s3") +emr = DynamicInstantiate("emr") +glue = DynamicInstantiate("glue") +spark = DynamicInstantiate("spark") +pandas = DynamicInstantiate("pandas") +athena = DynamicInstantiate("athena") +redshift = DynamicInstantiate("redshift") +sagemaker = DynamicInstantiate("sagemaker") +cloudwatchlogs = DynamicInstantiate("cloudwatchlogs") + logging.getLogger("awswrangler").addHandler(logging.NullHandler())
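
Taken together, the two commits let callers skip explicit `Session` management: `DynamicInstantiate` builds a shared default `Session` at import time and forwards attribute access to the matching service property, so `wr.sagemaker`, `wr.s3`, etc. resolve straight to the session-bound helpers. Below is a minimal usage sketch of the new `SageMaker.get_job_outputs` through that default session; the bucket and training-job prefix are placeholders, and the type of the unpickled object depends on whatever was packed into `model.tar.gz`.

```py3
import awswrangler as wr

# Placeholder S3 location of a finished SageMaker training job's output.
# get_job_outputs() appends "model.tar.gz" to the key when the path does not
# already point at the file itself.
outputs = wr.sagemaker.get_job_outputs("s3://my-bucket/my-training-job/output")

# Tarball members are returned in order: ".pkl" members are unpickled,
# any other member comes back as the raw extracted file object.
model = outputs[0]

# The same default session backs the other module-level shortcuts, e.g.:
wr.s3.delete_objects(path="s3://my-bucket/temp/")
```

Custom settings (KMS encryption arguments, `athena_ctas_approach`, a Spark session) still go through an explicit `wr.Session(...)`, as in the README examples updated above.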