From a6b7e92fbf5438a52fa2b8b89f26eae19aba9eeb Mon Sep 17 00:00:00 2001 From: igorborgest Date: Sat, 11 Apr 2020 10:03:26 -0300 Subject: [PATCH] Fixing athena queries for workgroups without encryption #159 --- awswrangler/athena.py | 7 +-- testing/test_awswrangler/test_data_lake.py | 62 ++++++++++++---------- 2 files changed, 38 insertions(+), 31 deletions(-) diff --git a/awswrangler/athena.py b/awswrangler/athena.py index b2886426d..72151f718 100644 --- a/awswrangler/athena.py +++ b/awswrangler/athena.py @@ -539,12 +539,13 @@ def get_work_group(workgroup: str, boto3_session: Optional[boto3.Session] = None def _ensure_workgroup( session: boto3.Session, workgroup: Optional[str] = None ) -> Tuple[Optional[str], Optional[str], Optional[str]]: - if workgroup: + if workgroup is not None: res: Dict[str, Any] = get_work_group(workgroup=workgroup, boto3_session=session) config: Dict[str, Any] = res["WorkGroup"]["Configuration"]["ResultConfiguration"] wg_s3_output: Optional[str] = config.get("OutputLocation") - wg_encryption: Optional[str] = config["EncryptionConfiguration"].get("EncryptionOption") - wg_kms_key: Optional[str] = config["EncryptionConfiguration"].get("KmsKey") + encrypt_config: Optional[Dict[str, str]] = config.get("EncryptionConfiguration") + wg_encryption: Optional[str] = None if encrypt_config is None else encrypt_config.get("EncryptionOption") + wg_kms_key: Optional[str] = None if encrypt_config is None else encrypt_config.get("KmsKey") else: wg_s3_output, wg_encryption, wg_kms_key = None, None, None return wg_s3_output, wg_encryption, wg_kms_key diff --git a/testing/test_awswrangler/test_data_lake.py b/testing/test_awswrangler/test_data_lake.py index 7cb94b229..879fed5f8 100644 --- a/testing/test_awswrangler/test_data_lake.py +++ b/testing/test_awswrangler/test_data_lake.py @@ -25,43 +25,27 @@ def cloudformation_outputs(): @pytest.fixture(scope="module") def region(cloudformation_outputs): - if "Region" in cloudformation_outputs: - region = cloudformation_outputs["Region"] - else: - raise Exception("You must deploy/update the test infrastructure (CloudFormation)!") - yield region + yield cloudformation_outputs["Region"] @pytest.fixture(scope="module") def bucket(cloudformation_outputs): - if "BucketName" in cloudformation_outputs: - bucket = cloudformation_outputs["BucketName"] - else: - raise Exception("You must deploy/update the test infrastructure (CloudFormation)") - yield bucket + yield cloudformation_outputs["BucketName"] @pytest.fixture(scope="module") def database(cloudformation_outputs): - if "GlueDatabaseName" in cloudformation_outputs: - database = cloudformation_outputs["GlueDatabaseName"] - else: - raise Exception("You must deploy the test infrastructure using Cloudformation!") - yield database + yield cloudformation_outputs["GlueDatabaseName"] @pytest.fixture(scope="module") def kms_key(cloudformation_outputs): - if "KmsKeyArn" in cloudformation_outputs: - key = cloudformation_outputs["KmsKeyArn"] - else: - raise Exception("You must deploy the test infrastructure using Cloudformation!") - yield key + yield cloudformation_outputs["KmsKeyArn"] @pytest.fixture(scope="module") -def workgroup_secondary(bucket): - wkg_name = "awswrangler_test" +def workgroup0(bucket): + wkg_name = "awswrangler_test_0" client = boto3.client("athena") wkgs = client.list_work_groups() wkgs = [x["Name"] for x in wkgs["WorkGroups"]] @@ -70,7 +54,7 @@ def workgroup_secondary(bucket): Name=wkg_name, Configuration={ "ResultConfiguration": { - "OutputLocation": f"s3://{bucket}/athena_workgroup_secondary/", + "OutputLocation": f"s3://{bucket}/athena_workgroup0/", "EncryptionConfiguration": {"EncryptionOption": "SSE_S3"}, }, "EnforceWorkGroupConfiguration": True, @@ -78,7 +62,28 @@ def workgroup_secondary(bucket): "BytesScannedCutoffPerQuery": 100_000_000, "RequesterPaysEnabled": False, }, - Description="AWS Data Wrangler Test WorkGroup", + Description="AWS Data Wrangler Test WorkGroup Number 0", + ) + yield wkg_name + + +@pytest.fixture(scope="module") +def workgroup1(bucket): + wkg_name = "awswrangler_test_1" + client = boto3.client("athena") + wkgs = client.list_work_groups() + wkgs = [x["Name"] for x in wkgs["WorkGroups"]] + if wkg_name not in wkgs: + client.create_work_group( + Name=wkg_name, + Configuration={ + "ResultConfiguration": {"OutputLocation": f"s3://{bucket}/athena_workgroup1/"}, + "EnforceWorkGroupConfiguration": True, + "PublishCloudWatchMetricsEnabled": True, + "BytesScannedCutoffPerQuery": 100_000_000, + "RequesterPaysEnabled": False, + }, + Description="AWS Data Wrangler Test WorkGroup Number 1", ) yield wkg_name @@ -120,7 +125,7 @@ def test_athena_ctas(bucket, database, kms_key): wr.s3.delete_objects(path=f"s3://{bucket}/test_athena_ctas_result/") -def test_athena(bucket, database, kms_key, workgroup_secondary): +def test_athena(bucket, database, kms_key, workgroup0, workgroup1): wr.s3.delete_objects(path=f"s3://{bucket}/test_athena/") paths = wr.s3.to_parquet( df=get_df(), @@ -141,13 +146,13 @@ def test_athena(bucket, database, kms_key, workgroup_secondary): chunksize=1, encryption="SSE_KMS", kms_key=kms_key, - workgroup=workgroup_secondary, + workgroup=workgroup0, ) for df2 in dfs: print(df2) ensure_data_types(df=df2) df = wr.athena.read_sql_query( - sql="SELECT * FROM __test_athena", database=database, ctas_approach=False, workgroup=workgroup_secondary + sql="SELECT * FROM __test_athena", database=database, ctas_approach=False, workgroup=workgroup1 ) assert len(df.index) == 3 ensure_data_types(df=df) @@ -155,7 +160,8 @@ def test_athena(bucket, database, kms_key, workgroup_secondary): wr.catalog.delete_table_if_exists(database=database, table="__test_athena") wr.s3.delete_objects(path=paths) wr.s3.wait_objects_not_exist(paths=paths) - wr.s3.delete_objects(path=f"s3://{bucket}/athena_workgroup_secondary/") + wr.s3.delete_objects(path=f"s3://{bucket}/athena_workgroup0/") + wr.s3.delete_objects(path=f"s3://{bucket}/athena_workgroup1/") def test_csv(bucket):