diff --git a/awswrangler/catalog.py b/awswrangler/catalog.py
index 3b408785e..86b59906c 100644
--- a/awswrangler/catalog.py
+++ b/awswrangler/catalog.py
@@ -5,7 +5,7 @@
 import logging
 import re
 import unicodedata
-from typing import Any, Dict, Iterator, List, Optional, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 from urllib.parse import quote_plus

 import boto3  # type: ignore
@@ -989,6 +989,8 @@ def _create_table(
             DatabaseName=database, TableName=table, PartitionsToDelete=[{"Values": v} for v in partitions_values]
         )
         client_glue.update_table(DatabaseName=database, TableInput=table_input, SkipArchive=skip_archive)
+    elif (exist is True) and (mode == "append") and (parameters is not None):
+        upsert_table_parameters(parameters=parameters, database=database, table=table, boto3_session=session)
     elif exist is False:
         client_glue.create_table(DatabaseName=database, TableInput=table_input)

@@ -1333,3 +1335,155 @@ def extract_athena_types(
     return _data_types.athena_types_from_pandas_partitioned(
         df=df, index=index, partition_cols=partition_cols, dtype=dtype, index_left=index_left
     )
+
+
+def get_table_parameters(
+    database: str, table: str, catalog_id: Optional[str] = None, boto3_session: Optional[boto3.Session] = None
+) -> Dict[str, str]:
+    """Get all parameters of the table.
+
+    Parameters
+    ----------
+    database : str
+        Database name.
+    table : str
+        Table name.
+    catalog_id : str, optional
+        The ID of the Data Catalog from which to retrieve Databases.
+        If none is provided, the AWS account ID is used by default.
+    boto3_session : boto3.Session(), optional
+        Boto3 Session. The default boto3 session will be used if boto3_session receives None.
+
+    Returns
+    -------
+    Dict[str, str]
+        Dictionary of parameters.
+
+    Examples
+    --------
+    >>> import awswrangler as wr
+    >>> pars = wr.catalog.get_table_parameters(database="...", table="...")
+
+    """
+    client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
+    args: Dict[str, str] = {}
+    if catalog_id is not None:
+        args["CatalogId"] = catalog_id  # pragma: no cover
+    args["DatabaseName"] = database
+    args["Name"] = table
+    response: Dict[str, Any] = client_glue.get_table(**args)
+    parameters: Dict[str, str] = response["Table"]["Parameters"]
+    return parameters
+
+
+def upsert_table_parameters(
+    parameters: Dict[str, str],
+    database: str,
+    table: str,
+    catalog_id: Optional[str] = None,
+    boto3_session: Optional[boto3.Session] = None,
+) -> Dict[str, str]:
+    """Insert or update the received parameters.
+
+    Parameters
+    ----------
+    parameters : Dict[str, str]
+        e.g. {"source": "mysql", "destination": "datalake"}
+    database : str
+        Database name.
+    table : str
+        Table name.
+    catalog_id : str, optional
+        The ID of the Data Catalog from which to retrieve Databases.
+        If none is provided, the AWS account ID is used by default.
+    boto3_session : boto3.Session(), optional
+        Boto3 Session. The default boto3 session will be used if boto3_session receives None.
+
+    Returns
+    -------
+    Dict[str, str]
+        All parameters after the upsert.
+
+    Examples
+    --------
+    >>> import awswrangler as wr
+    >>> pars = wr.catalog.upsert_table_parameters(
+    ...     parameters={"source": "mysql", "destination": "datalake"},
+    ...     database="...",
+    ...     table="...")
+
+    """
+    session: boto3.Session = _utils.ensure_session(session=boto3_session)
+    pars: Dict[str, str] = get_table_parameters(
+        database=database, table=table, catalog_id=catalog_id, boto3_session=session
+    )
+    for k, v in parameters.items():
+        pars[k] = v
+    overwrite_table_parameters(
+        parameters=pars, database=database, table=table, catalog_id=catalog_id, boto3_session=session
+    )
+    return pars
+
+
+def overwrite_table_parameters(
+    parameters: Dict[str, str],
+    database: str,
+    table: str,
+    catalog_id: Optional[str] = None,
+    boto3_session: Optional[boto3.Session] = None,
+) -> Dict[str, str]:
+    """Overwrite all existing parameters.
+
+    Parameters
+    ----------
+    parameters : Dict[str, str]
+        e.g. {"source": "mysql", "destination": "datalake"}
+    database : str
+        Database name.
+    table : str
+        Table name.
+    catalog_id : str, optional
+        The ID of the Data Catalog from which to retrieve Databases.
+        If none is provided, the AWS account ID is used by default.
+    boto3_session : boto3.Session(), optional
+        Boto3 Session. The default boto3 session will be used if boto3_session receives None.
+
+    Returns
+    -------
+    Dict[str, str]
+        All parameters after the overwrite (the same ones received).
+
+    Examples
+    --------
+    >>> import awswrangler as wr
+    >>> pars = wr.catalog.overwrite_table_parameters(
+    ...     parameters={"source": "mysql", "destination": "datalake"},
+    ...     database="...",
+    ...     table="...")
+
+    """
+    client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
+    args: Dict[str, str] = {}
+    if catalog_id is not None:
+        args["CatalogId"] = catalog_id  # pragma: no cover
+    args["DatabaseName"] = database
+    args["Name"] = table
+    response: Dict[str, Any] = client_glue.get_table(**args)
+    response["Table"]["Parameters"] = parameters
+    if "DatabaseName" in response["Table"]:
+        del response["Table"]["DatabaseName"]
+    if "CreateTime" in response["Table"]:
+        del response["Table"]["CreateTime"]
+    if "UpdateTime" in response["Table"]:
+        del response["Table"]["UpdateTime"]
+    if "CreatedBy" in response["Table"]:
+        del response["Table"]["CreatedBy"]
+    if "IsRegisteredWithLakeFormation" in response["Table"]:
+        del response["Table"]["IsRegisteredWithLakeFormation"]
+    args2: Dict[str, Union[str, Dict[str, Any]]] = {}
+    if catalog_id is not None:
+        args2["CatalogId"] = catalog_id  # pragma: no cover
+    args2["DatabaseName"] = database
+    args2["TableInput"] = response["Table"]
+    client_glue.update_table(**args2)
+    return parameters
diff --git a/awswrangler/s3.py b/awswrangler/s3.py
index 3f9960600..9ba8a0db6 100644
--- a/awswrangler/s3.py
+++ b/awswrangler/s3.py
@@ -453,6 +453,10 @@ def to_csv(  # pylint: disable=too-many-arguments
     The table name and all column names will be automatically sanitize using
     `wr.catalog.sanitize_table_name` and `wr.catalog.sanitize_column_name`.

+    Note
+    ----
+    On `append` mode, the `parameters` will be upserted on an existing table.
+
     Note
     ----
     In case of `use_threads=True` the number of threads that will be spawned will be get from os.cpu_count().
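The three catalog functions above compose naturally: the upsert is a get, a merge, and an overwrite. A minimal usage sketch against an already-cataloged table (the database and table names below are placeholders)::

    import awswrangler as wr

    # Current parameters straight from the Glue Catalog.
    pars = wr.catalog.get_table_parameters(database="my_db", table="my_table")

    # Merge new keys into the existing ones; existing keys not listed here are kept.
    pars = wr.catalog.upsert_table_parameters(
        parameters={"source": "mysql"}, database="my_db", table="my_table"
    )

    # Replace the whole parameters dictionary.
    pars = wr.catalog.overwrite_table_parameters(
        parameters={"source": "mysql", "destination": "datalake"}, database="my_db", table="my_table"
    )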
@@ -640,7 +644,6 @@ def to_csv(  # pylint: disable=too-many-arguments
         paths = [path]
     else:
         mode = "append" if mode is None else mode
-        exist: bool = False
         if columns:
             df = df[columns]
         if (database is not None) and (table is not None):  # Normalize table to respect Athena's standards
@@ -648,7 +651,7 @@ def to_csv(  # pylint: disable=too-many-arguments
             partition_cols = [catalog.sanitize_column_name(p) for p in partition_cols]
             dtype = {catalog.sanitize_column_name(k): v.lower() for k, v in dtype.items()}
             columns_comments = {catalog.sanitize_column_name(k): v for k, v in columns_comments.items()}
-            exist = catalog.does_table_exist(database=database, table=table, boto3_session=session)
+            exist: bool = catalog.does_table_exist(database=database, table=table, boto3_session=session)
             if (exist is True) and (mode in ("append", "overwrite_partitions")):
                 for k, v in catalog.get_table_types(database=database, table=table, boto3_session=session).items():
                     dtype[k] = v
@@ -669,21 +672,20 @@ def to_csv(  # pylint: disable=too-many-arguments
             columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned(
                 df=df, index=index, partition_cols=partition_cols, dtype=dtype, index_left=True
             )
-            if (exist is False) or (mode == "overwrite"):
-                catalog.create_csv_table(
-                    database=database,
-                    table=table,
-                    path=path,
-                    columns_types=columns_types,
-                    partitions_types=partitions_types,
-                    description=description,
-                    parameters=parameters,
-                    columns_comments=columns_comments,
-                    boto3_session=session,
-                    mode="overwrite",
-                    catalog_versioning=catalog_versioning,
-                    sep=sep,
-                )
+            catalog.create_csv_table(
+                database=database,
+                table=table,
+                path=path,
+                columns_types=columns_types,
+                partitions_types=partitions_types,
+                description=description,
+                parameters=parameters,
+                columns_comments=columns_comments,
+                boto3_session=session,
+                mode=mode,
+                catalog_versioning=catalog_versioning,
+                sep=sep,
+            )
             if partitions_values:
                 _logger.debug("partitions_values:\n%s", partitions_values)
                 catalog.add_csv_partitions(
@@ -869,6 +871,10 @@ def to_parquet(  # pylint: disable=too-many-arguments
     The table name and all column names will be automatically sanitize using
     `wr.catalog.sanitize_table_name` and `wr.catalog.sanitize_column_name`.

+    Note
+    ----
+    On `append` mode, the `parameters` will be upserted on an existing table.
+
     Note
     ----
     In case of `use_threads=True` the number of threads that will be spawned will be get from os.cpu_count().
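A short sketch of the behavior this note documents: on an appending dataset write, the supplied `parameters` are merged into, not substituted for, whatever the table already carries. Bucket, database and table names below are placeholders::

    import pandas as pd
    import awswrangler as wr

    path = "s3://my-bucket/my_table/"  # placeholder location

    # First write creates the table and stores {"a": "1"} as its parameters.
    wr.s3.to_parquet(
        df=pd.DataFrame({"c0": [1, 2]}),
        path=path,
        dataset=True,
        database="my_db",
        table="my_table",
        mode="overwrite",
        parameters={"a": "1"},
    )

    # The append keeps "a" and upserts "b" into the existing parameters.
    wr.s3.to_parquet(
        df=pd.DataFrame({"c0": [3, 4]}),
        path=path,
        dataset=True,
        database="my_db",
        table="my_table",
        mode="append",
        parameters={"b": "2"},
    )

    pars = wr.catalog.get_table_parameters(database="my_db", table="my_table")
    assert pars["a"] == "1" and pars["b"] == "2"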
@@ -1058,13 +1064,12 @@ def to_parquet(  # pylint: disable=too-many-arguments
         ]
     else:
         mode = "append" if mode is None else mode
-        exist: bool = False
         if (database is not None) and (table is not None):  # Normalize table to respect Athena's standards
             df = catalog.sanitize_dataframe_columns_names(df=df)
             partition_cols = [catalog.sanitize_column_name(p) for p in partition_cols]
             dtype = {catalog.sanitize_column_name(k): v.lower() for k, v in dtype.items()}
             columns_comments = {catalog.sanitize_column_name(k): v for k, v in columns_comments.items()}
-            exist = catalog.does_table_exist(database=database, table=table, boto3_session=session)
+            exist: bool = catalog.does_table_exist(database=database, table=table, boto3_session=session)
             if (exist is True) and (mode in ("append", "overwrite_partitions")):
                 for k, v in catalog.get_table_types(database=database, table=table, boto3_session=session).items():
                     dtype[k] = v
@@ -1087,21 +1092,20 @@ def to_parquet(  # pylint: disable=too-many-arguments
             columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned(
                 df=df, index=index, partition_cols=partition_cols, dtype=dtype
             )
-            if (exist is False) or (mode == "overwrite"):
-                catalog.create_parquet_table(
-                    database=database,
-                    table=table,
-                    path=path,
-                    columns_types=columns_types,
-                    partitions_types=partitions_types,
-                    compression=compression,
-                    description=description,
-                    parameters=parameters,
-                    columns_comments=columns_comments,
-                    boto3_session=session,
-                    mode="overwrite",
-                    catalog_versioning=catalog_versioning,
-                )
+            catalog.create_parquet_table(
+                database=database,
+                table=table,
+                path=path,
+                columns_types=columns_types,
+                partitions_types=partitions_types,
+                compression=compression,
+                description=description,
+                parameters=parameters,
+                columns_comments=columns_comments,
+                boto3_session=session,
+                mode=mode,
+                catalog_versioning=catalog_versioning,
+            )
             if partitions_values:
                 _logger.debug("partitions_values:\n%s", partitions_values)
                 catalog.add_parquet_partitions(
@@ -1865,6 +1869,10 @@ def store_parquet_metadata(
     The concept of Dataset goes beyond the simple idea of files and enable more complex features like partitioning
     and catalog integration (AWS Glue Catalog).

+    Note
+    ----
+    On `append` mode, the `parameters` will be upserted on an existing table.
+
     Note
     ----
     In case of `use_threads=True` the number of threads that will be spawned will be get from os.cpu_count().
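store_parquet_metadata gains the same note, so the identical upsert semantics apply when it refreshes an existing table. A hedged sketch, assuming the usual store_parquet_metadata arguments; the path, names and parameter value below are placeholders::

    import awswrangler as wr

    # Infer the schema from Parquet files already in S3 and refresh the Glue table;
    # in append mode the supplied parameters are upserted into the existing ones.
    wr.s3.store_parquet_metadata(
        path="s3://my-bucket/my_table/",
        database="my_db",
        table="my_table",
        dataset=True,
        mode="append",
        parameters={"refreshed_by": "nightly-job"},
    )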
diff --git a/docs/source/api.rst b/docs/source/api.rst
index 6b841705e..6444f382c 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -63,6 +63,9 @@ AWS Glue Catalog
     drop_duplicated_columns
     get_engine
     extract_athena_types
+    get_table_parameters
+    upsert_table_parameters
+    overwrite_table_parameters

 Amazon Athena
 -------------
diff --git a/testing/test_awswrangler/test_data_lake.py b/testing/test_awswrangler/test_data_lake.py
index 21a4f4bf0..497fa344f 100644
--- a/testing/test_awswrangler/test_data_lake.py
+++ b/testing/test_awswrangler/test_data_lake.py
@@ -1478,3 +1478,63 @@ def test_parquet_overwrite_partition_cols(bucket, database, external_schema):

     wr.s3.delete_objects(path=path)
     wr.catalog.delete_table_if_exists(database=database, table=table)
+
+
+def test_catalog_parameters(bucket, database):
+    table = "test_catalog_parameters"
+    path = f"s3://{bucket}/{table}/"
+    wr.s3.delete_objects(path=path)
+    wr.catalog.delete_table_if_exists(database=database, table=table)
+
+    wr.s3.to_parquet(
+        df=pd.DataFrame({"c0": [1, 2]}),
+        path=path,
+        dataset=True,
+        database=database,
+        table=table,
+        mode="overwrite",
+        parameters={"a": "1", "b": "2"},
+    )
+    pars = wr.catalog.get_table_parameters(database=database, table=table)
+    assert pars["a"] == "1"
+    assert pars["b"] == "2"
+    pars["a"] = "0"
+    pars["c"] = "3"
+    wr.catalog.upsert_table_parameters(parameters=pars, database=database, table=table)
+    pars = wr.catalog.get_table_parameters(database=database, table=table)
+    assert pars["a"] == "0"
+    assert pars["b"] == "2"
+    assert pars["c"] == "3"
+    wr.catalog.overwrite_table_parameters(parameters={"d": "4"}, database=database, table=table)
+    pars = wr.catalog.get_table_parameters(database=database, table=table)
+    assert pars.get("a") is None
+    assert pars.get("b") is None
+    assert pars.get("c") is None
+    assert pars["d"] == "4"
+    df = wr.athena.read_sql_table(table=table, database=database)
+    assert len(df.index) == 2
+    assert len(df.columns) == 1
+    assert df.c0.sum() == 3
+
+    wr.s3.to_parquet(
+        df=pd.DataFrame({"c0": [3, 4]}),
+        path=path,
+        dataset=True,
+        database=database,
+        table=table,
+        mode="append",
+        parameters={"e": "5"},
+    )
+    pars = wr.catalog.get_table_parameters(database=database, table=table)
+    assert pars.get("a") is None
+    assert pars.get("b") is None
+    assert pars.get("c") is None
+    assert pars["d"] == "4"
+    assert pars["e"] == "5"
+    df = wr.athena.read_sql_table(table=table, database=database)
+    assert len(df.index) == 4
+    assert len(df.columns) == 1
+    assert df.c0.sum() == 10
+
+    wr.s3.delete_objects(path=path)
+    wr.catalog.delete_table_if_exists(database=database, table=table)
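For reference, the append step in this test keeps parameter "d" and gains "e" because the writer now always calls catalog.create_parquet_table with the caller's mode, and the dispatch inside catalog._create_table (first hunk of this patch) routes the append case to upsert_table_parameters. A condensed paraphrase of that dispatch; the condition guarding the first branch is not visible in the hunk and is assumed here::

    # Paraphrase of the updated branch in catalog._create_table, not the verbatim code.
    if exist is True and mode == "overwrite":  # assumed condition; only its body appears in the hunk
        # refresh the table definition in place (stale partitions were already queued for deletion above)
        client_glue.update_table(DatabaseName=database, TableInput=table_input, SkipArchive=skip_archive)
    elif (exist is True) and (mode == "append") and (parameters is not None):
        # keep the existing definition and only upsert the supplied parameters
        upsert_table_parameters(parameters=parameters, database=database, table=table, boto3_session=session)
    elif exist is False:
        client_glue.create_table(DatabaseName=database, TableInput=table_input)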