diff --git a/awswrangler/catalog/_create.py b/awswrangler/catalog/_create.py
index 47fcd4e35..adefbff1b 100644
--- a/awswrangler/catalog/_create.py
+++ b/awswrangler/catalog/_create.py
@@ -42,6 +42,7 @@ def _create_table(  # pylint: disable=too-many-branches,too-many-statements
     projection_values: Optional[Dict[str, str]],
     projection_intervals: Optional[Dict[str, str]],
     projection_digits: Optional[Dict[str, str]],
+    projection_storage_location_template: Optional[str],
     catalog_id: Optional[str],
 ) -> None:
     # Description
@@ -71,7 +72,7 @@ def _create_table(  # pylint: disable=too-many-branches,too-many-statements
         projection_digits = {sanitize_column_name(k): v for k, v in projection_digits.items()}
         for k, v in projection_types.items():
             dtype: Optional[str] = partitions_types.get(k)
-            if dtype is None:
+            if dtype is None and projection_storage_location_template is None:
                 raise exceptions.InvalidArgumentCombination(
                     f"Column {k} appears as projected column but not as partitioned column."
                 )
@@ -95,6 +96,12 @@ def _create_table(  # pylint: disable=too-many-branches,too-many-statements
             mode = _update_if_necessary(
                 dic=table_input["Parameters"], key=f"projection.{k}.digits", value=str(v), mode=mode
             )
+        mode = _update_if_necessary(
+            table_input["Parameters"],
+            key="storage.location.template",
+            value=projection_storage_location_template,
+            mode=mode,
+        )
     else:
         table_input["Parameters"]["projection.enabled"] = "false"
@@ -232,6 +239,7 @@ def _create_parquet_table(
     projection_values: Optional[Dict[str, str]],
     projection_intervals: Optional[Dict[str, str]],
     projection_digits: Optional[Dict[str, str]],
+    projection_storage_location_template: Optional[str],
     boto3_session: Optional[boto3.Session],
     catalog_table_input: Optional[Dict[str, Any]],
 ) -> None:
@@ -280,6 +288,7 @@ def _create_parquet_table(
         projection_values=projection_values,
         projection_intervals=projection_intervals,
         projection_digits=projection_digits,
+        projection_storage_location_template=projection_storage_location_template,
         catalog_id=catalog_id,
     )
@@ -309,6 +318,7 @@ def _create_csv_table(  # pylint: disable=too-many-arguments
     projection_values: Optional[Dict[str, str]],
     projection_intervals: Optional[Dict[str, str]],
     projection_digits: Optional[Dict[str, str]],
+    projection_storage_location_template: Optional[str],
     catalog_table_input: Optional[Dict[str, Any]],
     catalog_id: Optional[str],
 ) -> None:
@@ -353,6 +363,7 @@ def _create_csv_table(  # pylint: disable=too-many-arguments
         projection_values=projection_values,
         projection_intervals=projection_intervals,
         projection_digits=projection_digits,
+        projection_storage_location_template=projection_storage_location_template,
         catalog_id=catalog_id,
     )
@@ -380,6 +391,7 @@ def _create_json_table(  # pylint: disable=too-many-arguments
     projection_values: Optional[Dict[str, str]],
     projection_intervals: Optional[Dict[str, str]],
     projection_digits: Optional[Dict[str, str]],
+    projection_storage_location_template: Optional[str],
     catalog_table_input: Optional[Dict[str, Any]],
     catalog_id: Optional[str],
 ) -> None:
@@ -422,6 +434,7 @@ def _create_json_table(  # pylint: disable=too-many-arguments
         projection_values=projection_values,
         projection_intervals=projection_intervals,
         projection_digits=projection_digits,
+        projection_storage_location_template=projection_storage_location_template,
         catalog_id=catalog_id,
     )
@@ -613,6 +626,7 @@ def create_parquet_table(
     projection_values: Optional[Dict[str, str]] = None,
     projection_intervals: Optional[Dict[str, str]] = None,
     projection_digits: Optional[Dict[str, str]] = None,
+    projection_storage_location_template: Optional[str] = None,
     boto3_session: Optional[boto3.Session] = None,
 ) -> None:
     """Create a Parquet Table (Metadata Only) in the AWS Glue Catalog.
@@ -673,6 +687,11 @@ def create_parquet_table(
         Dictionary of partitions names and Athena projections digits.
         https://docs.aws.amazon.com/athena/latest/ug/partition-projection-supported-types.html
         (e.g. {'col_name': '1', 'col2_name': '2'})
+    projection_storage_location_template: Optional[str]
+        Value which allows Athena to properly map partition values if the S3 file locations do not follow
+        a typical `.../column=value/...` pattern.
+        https://docs.aws.amazon.com/athena/latest/ug/partition-projection-setting-up.html
+        (e.g. s3://bucket/table_root/a=${a}/${b}/some_static_subdirectory/${c}/)
     boto3_session : boto3.Session(), optional
         Boto3 Session. The default boto3 session will be used if boto3_session receive None.
@@ -721,13 +740,14 @@ def create_parquet_table(
         projection_values=projection_values,
         projection_intervals=projection_intervals,
         projection_digits=projection_digits,
+        projection_storage_location_template=projection_storage_location_template,
         boto3_session=boto3_session,
         catalog_table_input=catalog_table_input,
     )


 @apply_configs
-def create_csv_table(
+def create_csv_table(  # pylint: disable=too-many-arguments
     database: str,
     table: str,
     path: str,
@@ -752,6 +772,7 @@ def create_csv_table(
     projection_values: Optional[Dict[str, str]] = None,
     projection_intervals: Optional[Dict[str, str]] = None,
     projection_digits: Optional[Dict[str, str]] = None,
+    projection_storage_location_template: Optional[str] = None,
     catalog_id: Optional[str] = None,
 ) -> None:
     r"""Create a CSV Table (Metadata Only) in the AWS Glue Catalog.
@@ -825,6 +846,11 @@ def create_csv_table(
         Dictionary of partitions names and Athena projections digits.
         https://docs.aws.amazon.com/athena/latest/ug/partition-projection-supported-types.html
         (e.g. {'col_name': '1', 'col2_name': '2'})
+    projection_storage_location_template: Optional[str]
+        Value which allows Athena to properly map partition values if the S3 file locations do not follow
+        a typical `.../column=value/...` pattern.
+        https://docs.aws.amazon.com/athena/latest/ug/partition-projection-setting-up.html
+        (e.g. s3://bucket/table_root/a=${a}/${b}/some_static_subdirectory/${c}/)
     boto3_session : boto3.Session(), optional
         Boto3 Session. The default boto3 session will be used if boto3_session receive None.
     catalog_id : str, optional
@@ -877,6 +903,7 @@ def create_csv_table(
         projection_values=projection_values,
         projection_intervals=projection_intervals,
         projection_digits=projection_digits,
+        projection_storage_location_template=projection_storage_location_template,
         boto3_session=boto3_session,
         catalog_table_input=catalog_table_input,
         sep=sep,
@@ -910,6 +937,7 @@ def create_json_table(
     projection_values: Optional[Dict[str, str]] = None,
     projection_intervals: Optional[Dict[str, str]] = None,
     projection_digits: Optional[Dict[str, str]] = None,
+    projection_storage_location_template: Optional[str] = None,
     catalog_id: Optional[str] = None,
 ) -> None:
     r"""Create a JSON Table (Metadata Only) in the AWS Glue Catalog.
@@ -979,6 +1007,11 @@ def create_json_table(
         Dictionary of partitions names and Athena projections digits.
         https://docs.aws.amazon.com/athena/latest/ug/partition-projection-supported-types.html
         (e.g. {'col_name': '1', 'col2_name': '2'})
+    projection_storage_location_template: Optional[str]
+        Value which allows Athena to properly map partition values if the S3 file locations do not follow
+        a typical `.../column=value/...` pattern.
+        https://docs.aws.amazon.com/athena/latest/ug/partition-projection-setting-up.html
+        (e.g. s3://bucket/table_root/a=${a}/${b}/some_static_subdirectory/${c}/)
     boto3_session : boto3.Session(), optional
         Boto3 Session. The default boto3 session will be used if boto3_session receive None.
     catalog_id : str, optional
@@ -1030,6 +1063,7 @@ def create_json_table(
         projection_values=projection_values,
         projection_intervals=projection_intervals,
         projection_digits=projection_digits,
+        projection_storage_location_template=projection_storage_location_template,
         boto3_session=boto3_session,
         catalog_table_input=catalog_table_input,
         serde_library=serde_library,
diff --git a/awswrangler/s3/_write_parquet.py b/awswrangler/s3/_write_parquet.py
index 13b272901..410db8490 100644
--- a/awswrangler/s3/_write_parquet.py
+++ b/awswrangler/s3/_write_parquet.py
@@ -606,6 +606,7 @@ def to_parquet(  # pylint: disable=too-many-arguments,too-many-locals
             projection_values=projection_values,
             projection_intervals=projection_intervals,
             projection_digits=projection_digits,
+            projection_storage_location_template=None,
             catalog_id=catalog_id,
             catalog_table_input=catalog_table_input,
         )
diff --git a/awswrangler/s3/_write_text.py b/awswrangler/s3/_write_text.py
index 001cdbceb..0f7ce4330 100644
--- a/awswrangler/s3/_write_text.py
+++ b/awswrangler/s3/_write_text.py
@@ -538,6 +538,7 @@ def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-state
             projection_values=projection_values,
             projection_intervals=projection_intervals,
             projection_digits=projection_digits,
+            projection_storage_location_template=None,
             catalog_table_input=catalog_table_input,
             catalog_id=catalog_id,
             compression=pandas_kwargs.get("compression"),
@@ -888,6 +889,7 @@ def to_json(  # pylint: disable=too-many-arguments,too-many-locals,too-many-stat
             projection_values=projection_values,
             projection_intervals=projection_intervals,
             projection_digits=projection_digits,
+            projection_storage_location_template=None,
             catalog_table_input=catalog_table_input,
             catalog_id=catalog_id,
             compression=pandas_kwargs.get("compression"),
diff --git a/tests/test_athena_projection.py b/tests/test_athena_projection.py
index 7850daa75..abe200721 100644
--- a/tests/test_athena_projection.py
+++ b/tests/test_athena_projection.py
@@ -94,3 +94,28 @@ def test_to_parquet_projection_injected(glue_database, glue_table, path):
     df2 = wr.athena.read_sql_query(f"SELECT * FROM {glue_table} WHERE c1='foo' AND c2='0'", glue_database)
     assert df2.shape == (1, 3)
     assert df2.c0.iloc[0] == 0
+
+
+def test_to_parquet_storage_location(glue_database, glue_table, path):
+    df1 = pd.DataFrame({"c0": [0], "c1": ["foo"], "c2": ["0"]})
+    df2 = pd.DataFrame({"c0": [1], "c1": ["foo"], "c2": ["1"]})
+    df3 = pd.DataFrame({"c0": [2], "c1": ["boo"], "c2": ["2"]})
+    df4 = pd.DataFrame({"c0": [3], "c1": ["boo"], "c2": ["3"]})
+
+    wr.s3.to_parquet(df=df1, path=f"{path}foo/0/file0.parquet")
+    wr.s3.to_parquet(df=df2, path=f"{path}foo/1/file1.parquet")
+    wr.s3.to_parquet(df=df3, path=f"{path}boo/2/file2.parquet")
+    wr.s3.to_parquet(df=df4, path=f"{path}boo/3/file3.parquet")
+    column_types, partitions_types = wr.catalog.extract_athena_types(df1)
+    wr.catalog.create_parquet_table(
+        database=glue_database,
+        table=glue_table,
+        path=path,
+        columns_types=column_types,
+        projection_enabled=True,
+        projection_types={"c1": "injected", "c2": "injected"},
+        projection_storage_location_template=f"{path}${{c1}}/${{c2}}",
+    )
+
+    df5 = wr.athena.read_sql_query(f"SELECT * FROM {glue_table} WHERE c1='foo' AND c2='0'", glue_database)
+    pd.testing.assert_frame_equal(df1, df5, check_dtype=False)
diff --git a/tests/test_config.py b/tests/test_config.py
index 86ea3eccc..318f0f807 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -126,6 +126,7 @@ def test_basics(path, glue_database, glue_table, workgroup0, workgroup1):


 def test_athena_cache_configuration():
+    wr.config.max_remote_cache_entries = 50
     wr.config.max_local_cache_entries = 20
     assert wr.config.max_remote_cache_entries == 20
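
For context, below is a minimal usage sketch of the new `projection_storage_location_template` parameter outside the test fixtures. The bucket, database, table, and column names are placeholders invented for illustration, not values from this PR; the call pattern mirrors `create_parquet_table` as extended above.

```python
import awswrangler as wr
import pandas as pd

# Assumed layout: files live under s3://example-bucket/events/<region>/<day>/...,
# i.e. the keys do NOT follow the Hive-style .../region=.../day=... convention.
df = pd.DataFrame({"value": [1.0], "region": ["eu"], "day": ["2021-06-01"]})
wr.s3.to_parquet(df=df, path="s3://example-bucket/events/eu/2021-06-01/part0.parquet")

# Register the table with partition projection; the template is written into the
# table's "storage.location.template" parameter so Athena can build the S3 prefix
# from the projected partition values instead of expecting column=value folders.
wr.catalog.create_parquet_table(
    database="example_db",
    table="events",
    path="s3://example-bucket/events/",
    columns_types={"value": "double"},
    partitions_types={"region": "string", "day": "string"},
    projection_enabled=True,
    projection_types={"region": "injected", "day": "injected"},
    projection_storage_location_template="s3://example-bucket/events/${region}/${day}/",
)

# With "injected" projection the query must pin the projected columns in the WHERE
# clause; no MSCK REPAIR TABLE or explicit partition registration is needed.
df_out = wr.athena.read_sql_query(
    "SELECT * FROM events WHERE region = 'eu' AND day = '2021-06-01'",
    database="example_db",
)
```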