Skip to content

Commit

Permalink
feat: add columns parameters support (#2814)
Browse files Browse the repository at this point in the history
* feat: add columns parameters support

* tests: expand test routine

* docs: fix minor docs edits

---------

Co-authored-by: Leon Luttenberger <LeonLuttenberger@users.noreply.github.com>
  • Loading branch information
jaidisido and LeonLuttenberger committed May 14, 2024
1 parent 1686539 commit 44ae3fb
Show file tree
Hide file tree
Showing 10 changed files with 156 additions and 22 deletions.
2 changes: 2 additions & 0 deletions awswrangler/catalog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
_get_table_input,
databases,
get_columns_comments,
get_columns_parameters,
get_connection,
get_csv_partitions,
get_databases,
Expand Down Expand Up @@ -83,6 +84,7 @@
"_get_table_input",
"databases",
"get_columns_comments",
"get_columns_parameters",
"get_connection",
"get_csv_partitions",
"get_databases",
Expand Down
42 changes: 41 additions & 1 deletion awswrangler/catalog/_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
_logger: logging.Logger = logging.getLogger(__name__)


def _update_if_necessary(dic: dict[str, str], key: str, value: str | None, mode: str) -> str:
def _update_if_necessary(
dic: dict[str, str | dict[str, str]], key: str, value: str | dict[str, str] | None, mode: str
) -> str:
if value is not None:
if key not in dic or dic[key] != value:
dic[key] = value
Expand All @@ -46,6 +48,7 @@ def _create_table( # noqa: PLR0912,PLR0915
table_exist: bool,
partitions_types: dict[str, str] | None,
columns_comments: dict[str, str] | None,
columns_parameters: dict[str, dict[str, str]] | None,
athena_partition_projection_settings: typing.AthenaPartitionProjectionSettings | None,
catalog_id: str | None,
) -> None:
Expand Down Expand Up @@ -130,6 +133,19 @@ def _create_table( # noqa: PLR0912,PLR0915
if name in columns_comments:
mode = _update_if_necessary(dic=par, key="Comment", value=columns_comments[name], mode=mode)

# Column parameters
columns_parameters = columns_parameters if columns_parameters else {}
columns_parameters = {sanitize_column_name(k): v for k, v in columns_parameters.items()}
if columns_parameters:
for col in table_input["StorageDescriptor"]["Columns"]:
name: str = col["Name"] # type: ignore[no-redef]
if name in columns_parameters:
mode = _update_if_necessary(dic=col, key="Parameters", value=columns_parameters[name], mode=mode)
for par in table_input["PartitionKeys"]:
name = par["Name"]
if name in columns_parameters:
mode = _update_if_necessary(dic=par, key="Parameters", value=columns_parameters[name], mode=mode)

_logger.debug("table_input: %s", table_input)

client_glue = _utils.client(service_name="glue", session=boto3_session)
Expand Down Expand Up @@ -275,6 +291,7 @@ def _create_parquet_table(
description: str | None,
parameters: dict[str, str] | None,
columns_comments: dict[str, str] | None,
columns_parameters: dict[str, dict[str, str]] | None,
mode: str,
catalog_versioning: bool,
athena_partition_projection_settings: typing.AthenaPartitionProjectionSettings | None,
Expand Down Expand Up @@ -311,6 +328,7 @@ def _create_parquet_table(
description=description,
parameters=parameters,
columns_comments=columns_comments,
columns_parameters=columns_parameters,
mode=mode,
catalog_versioning=catalog_versioning,
boto3_session=boto3_session,
Expand All @@ -335,6 +353,7 @@ def _create_orc_table(
description: str | None,
parameters: dict[str, str] | None,
columns_comments: dict[str, str] | None,
columns_parameters: dict[str, dict[str, str]] | None,
mode: str,
catalog_versioning: bool,
athena_partition_projection_settings: typing.AthenaPartitionProjectionSettings | None,
Expand Down Expand Up @@ -371,6 +390,7 @@ def _create_orc_table(
description=description,
parameters=parameters,
columns_comments=columns_comments,
columns_parameters=columns_parameters,
mode=mode,
catalog_versioning=catalog_versioning,
boto3_session=boto3_session,
Expand All @@ -394,6 +414,7 @@ def _create_csv_table(
compression: str | None,
parameters: dict[str, str] | None,
columns_comments: dict[str, str] | None,
columns_parameters: dict[str, dict[str, str]] | None,
mode: str,
catalog_versioning: bool,
schema_evolution: bool,
Expand Down Expand Up @@ -444,6 +465,7 @@ def _create_csv_table(
description=description,
parameters=parameters,
columns_comments=columns_comments,
columns_parameters=columns_parameters,
mode=mode,
catalog_versioning=catalog_versioning,
boto3_session=boto3_session,
Expand All @@ -467,6 +489,7 @@ def _create_json_table(
compression: str | None,
parameters: dict[str, str] | None,
columns_comments: dict[str, str] | None,
columns_parameters: dict[str, dict[str, str]] | None,
mode: str,
catalog_versioning: bool,
schema_evolution: bool,
Expand Down Expand Up @@ -512,6 +535,7 @@ def _create_json_table(
description=description,
parameters=parameters,
columns_comments=columns_comments,
columns_parameters=columns_parameters,
mode=mode,
catalog_versioning=catalog_versioning,
boto3_session=boto3_session,
Expand Down Expand Up @@ -713,6 +737,7 @@ def create_parquet_table(
description: str | None = None,
parameters: dict[str, str] | None = None,
columns_comments: dict[str, str] | None = None,
columns_parameters: dict[str, dict[str, str]] | None = None,
mode: Literal["overwrite", "append"] = "overwrite",
catalog_versioning: bool = False,
athena_partition_projection_settings: typing.AthenaPartitionProjectionSettings | None = None,
Expand Down Expand Up @@ -751,6 +776,8 @@ def create_parquet_table(
Key/value pairs to tag the table.
columns_comments: Dict[str, str], optional
Columns names and the related comments (e.g. {'col0': 'Column 0.', 'col1': 'Column 1.', 'col2': 'Partition.'}).
columns_parameters: Dict[str, Dict[str, str]], optional
Columns names and the related parameters (e.g. {'col0': {'par0': 'Param 0', 'par1': 'Param 1'}}).
mode: str
'overwrite' to recreate any possible existing table or 'append' to keep any possible existing table.
catalog_versioning : bool
Expand Down Expand Up @@ -848,6 +875,7 @@ def create_parquet_table(
description=description,
parameters=parameters,
columns_comments=columns_comments,
columns_parameters=columns_parameters,
mode=mode,
catalog_versioning=catalog_versioning,
athena_partition_projection_settings=athena_partition_projection_settings,
Expand All @@ -870,6 +898,7 @@ def create_orc_table(
description: str | None = None,
parameters: dict[str, str] | None = None,
columns_comments: dict[str, str] | None = None,
columns_parameters: dict[str, dict[str, str]] | None = None,
mode: Literal["overwrite", "append"] = "overwrite",
catalog_versioning: bool = False,
athena_partition_projection_settings: typing.AthenaPartitionProjectionSettings | None = None,
Expand Down Expand Up @@ -908,6 +937,8 @@ def create_orc_table(
Key/value pairs to tag the table.
columns_comments: Dict[str, str], optional
Columns names and the related comments (e.g. {'col0': 'Column 0.', 'col1': 'Column 1.', 'col2': 'Partition.'}).
columns_parameters: Dict[str, Dict[str, str]], optional
Columns names and the related parameters (e.g. {'col0': {'par0': 'Param 0', 'par1': 'Param 1'}}).
mode: str
'overwrite' to recreate any possible existing table or 'append' to keep any possible existing table.
catalog_versioning : bool
Expand Down Expand Up @@ -1005,6 +1036,7 @@ def create_orc_table(
description=description,
parameters=parameters,
columns_comments=columns_comments,
columns_parameters=columns_parameters,
mode=mode,
catalog_versioning=catalog_versioning,
athena_partition_projection_settings=athena_partition_projection_settings,
Expand All @@ -1026,6 +1058,7 @@ def create_csv_table(
description: str | None = None,
parameters: dict[str, str] | None = None,
columns_comments: dict[str, str] | None = None,
columns_parameters: dict[str, dict[str, str]] | None = None,
mode: Literal["overwrite", "append"] = "overwrite",
catalog_versioning: bool = False,
schema_evolution: bool = False,
Expand Down Expand Up @@ -1072,6 +1105,8 @@ def create_csv_table(
Key/value pairs to tag the table.
columns_comments: Dict[str, str], optional
Columns names and the related comments (e.g. {'col0': 'Column 0.', 'col1': 'Column 1.', 'col2': 'Partition.'}).
columns_parameters: Dict[str, Dict[str, str]], optional
Columns names and the related parameters (e.g. {'col0': {'par0': 'Param 0', 'par1': 'Param 1'}}).
mode : str
'overwrite' to recreate any possible existing table or 'append' to keep any possible existing table.
catalog_versioning : bool
Expand Down Expand Up @@ -1188,6 +1223,7 @@ def create_csv_table(
description=description,
parameters=parameters,
columns_comments=columns_comments,
columns_parameters=columns_parameters,
mode=mode,
catalog_versioning=catalog_versioning,
schema_evolution=schema_evolution,
Expand All @@ -1214,6 +1250,7 @@ def create_json_table(
description: str | None = None,
parameters: dict[str, str] | None = None,
columns_comments: dict[str, str] | None = None,
columns_parameters: dict[str, dict[str, str]] | None = None,
mode: Literal["overwrite", "append"] = "overwrite",
catalog_versioning: bool = False,
schema_evolution: bool = False,
Expand Down Expand Up @@ -1253,6 +1290,8 @@ def create_json_table(
Key/value pairs to tag the table.
columns_comments: Dict[str, str], optional
Columns names and the related comments (e.g. {'col0': 'Column 0.', 'col1': 'Column 1.', 'col2': 'Partition.'}).
columns_parameters: Dict[str, Dict[str, str]], optional
Columns names and the related parameters (e.g. {'col0': {'par0': 'Param 0', 'par1': 'Param 1'}}).
mode : str
'overwrite' to recreate any possible existing table or 'append' to keep any possible existing table.
catalog_versioning : bool
Expand Down Expand Up @@ -1361,6 +1400,7 @@ def create_json_table(
description=description,
parameters=parameters,
columns_comments=columns_comments,
columns_parameters=columns_parameters,
mode=mode,
catalog_versioning=catalog_versioning,
schema_evolution=schema_evolution,
Expand Down
45 changes: 44 additions & 1 deletion awswrangler/catalog/_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import base64
import itertools
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterator, cast
from typing import TYPE_CHECKING, Any, Dict, Iterator, Mapping, cast

import boto3
import botocore.exceptions
Expand Down Expand Up @@ -887,6 +887,49 @@ def get_columns_comments(
return comments


@apply_configs
def get_columns_parameters(
database: str,
table: str,
catalog_id: str | None = None,
boto3_session: boto3.Session | None = None,
) -> dict[str, Mapping[str, str] | None]:
"""Get all columns parameters.
Parameters
----------
database : str
Database name.
table : str
Table name.
catalog_id : str, optional
The ID of the Data Catalog from which to retrieve Databases.
If none is provided, the AWS account ID is used by default.
boto3_session : boto3.Session(), optional
Boto3 Session. The default boto3 session will be used if boto3_session receive None.
Returns
-------
Dict[str, Optional[Dict[str, str]]]
Columns parameters.
Examples
--------
>>> import awswrangler as wr
>>> pars = wr.catalog.get_columns_parameters(database="...", table="...")
"""
client_glue = _utils.client("glue", session=boto3_session)
response = client_glue.get_table(**_catalog_id(catalog_id=catalog_id, DatabaseName=database, Name=table))
parameters = {}
for c in response["Table"]["StorageDescriptor"]["Columns"]:
parameters[c["Name"]] = c.get("Parameters")
if "PartitionKeys" in response["Table"]:
for p in response["Table"]["PartitionKeys"]:
parameters[p["Name"]] = p.get("Parameters")
return parameters


@apply_configs
def get_table_versions(
database: str, table: str, catalog_id: str | None = None, boto3_session: boto3.Session | None = None
Expand Down
8 changes: 6 additions & 2 deletions awswrangler/s3/_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def _validate_args(
description: str | None,
parameters: dict[str, str] | None,
columns_comments: dict[str, str] | None,
columns_parameters: dict[str, dict[str, str]] | None,
execution_engine: Enum,
) -> None:
if df.empty is True:
Expand All @@ -87,11 +88,11 @@ def _validate_args(
raise exceptions.InvalidArgumentCombination("Please, pass dataset=True to be able to use bucketing_info.")
if mode is not None:
raise exceptions.InvalidArgumentCombination("Please pass dataset=True to be able to use mode.")
if any(arg is not None for arg in (table, description, parameters, columns_comments)):
if any(arg is not None for arg in (table, description, parameters, columns_comments, columns_parameters)):
raise exceptions.InvalidArgumentCombination(
"Please pass dataset=True to be able to use any one of these "
"arguments: database, table, description, parameters, "
"columns_comments."
"columns_comments, columns_parameters."
)
elif (database is None) != (table is None):
raise exceptions.InvalidArgumentCombination(
Expand Down Expand Up @@ -214,6 +215,7 @@ def _create_glue_table(
description: str | None,
parameters: dict[str, str] | None,
columns_comments: dict[str, str] | None,
columns_parameters: dict[str, dict[str, str]] | None,
mode: str,
catalog_versioning: bool,
athena_partition_projection_settings: typing.AthenaPartitionProjectionSettings | None,
Expand Down Expand Up @@ -262,6 +264,7 @@ def write( # noqa: PLR0912,PLR0913
description: str | None,
parameters: dict[str, str] | None,
columns_comments: dict[str, str] | None,
columns_parameters: dict[str, dict[str, str]] | None,
regular_partitions: bool,
table_type: str | None,
dtype: dict[str, str] | None,
Expand Down Expand Up @@ -361,6 +364,7 @@ def write( # noqa: PLR0912,PLR0913
"description": description,
"parameters": parameters,
"columns_comments": columns_comments,
"columns_parameters": columns_parameters,
"boto3_session": boto3_session,
"mode": mode,
"catalog_versioning": catalog_versioning,
Expand Down
5 changes: 5 additions & 0 deletions awswrangler/s3/_write_orc.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def _create_glue_table(
description: str | None = None,
parameters: dict[str, str] | None = None,
columns_comments: dict[str, str] | None = None,
columns_parameters: dict[str, dict[str, str]] | None = None,
mode: str = "overwrite",
catalog_versioning: bool = False,
athena_partition_projection_settings: AthenaPartitionProjectionSettings | None = None,
Expand All @@ -272,6 +273,7 @@ def _create_glue_table(
description=description,
parameters=parameters,
columns_comments=columns_comments,
columns_parameters=columns_parameters,
mode=mode,
catalog_versioning=catalog_versioning,
athena_partition_projection_settings=athena_partition_projection_settings,
Expand Down Expand Up @@ -629,6 +631,7 @@ def to_orc(
description = glue_table_settings.get("description")
parameters = glue_table_settings.get("parameters")
columns_comments = glue_table_settings.get("columns_comments")
columns_parameters = glue_table_settings.get("columns_parameters")
regular_partitions = glue_table_settings.get("regular_partitions", True)

_validate_args(
Expand All @@ -643,6 +646,7 @@ def to_orc(
description=description,
parameters=parameters,
columns_comments=columns_comments,
columns_parameters=columns_parameters,
execution_engine=engine.get(),
)

Expand Down Expand Up @@ -682,6 +686,7 @@ def to_orc(
description=description,
parameters=parameters,
columns_comments=columns_comments,
columns_parameters=columns_parameters,
table_type=table_type,
regular_partitions=regular_partitions,
dtype=dtype,
Expand Down

0 comments on commit 44ae3fb

Please sign in to comment.