diff --git a/awswrangler/catalog/__init__.py b/awswrangler/catalog/__init__.py index 11be6645d..b83b7a7d3 100644 --- a/awswrangler/catalog/__init__.py +++ b/awswrangler/catalog/__init__.py @@ -1,6 +1,6 @@ """Amazon Glue Catalog Module.""" -from awswrangler.catalog._add import add_csv_partitions, add_parquet_partitions # noqa +from awswrangler.catalog._add import add_column, add_csv_partitions, add_parquet_partitions # noqa from awswrangler.catalog._create import ( # noqa _create_csv_table, _create_parquet_table, @@ -12,6 +12,7 @@ ) from awswrangler.catalog._delete import ( # noqa delete_all_partitions, + delete_column, delete_database, delete_partitions, delete_table_if_exists, diff --git a/awswrangler/catalog/_add.py b/awswrangler/catalog/_add.py index 7990cc1a4..9b3f6de39 100644 --- a/awswrangler/catalog/_add.py +++ b/awswrangler/catalog/_add.py @@ -7,7 +7,12 @@ from awswrangler import _utils, exceptions from awswrangler._config import apply_configs -from awswrangler.catalog._definitions import _csv_partition_definition, _parquet_partition_definition +from awswrangler.catalog._definitions import ( + _check_column_type, + _csv_partition_definition, + _parquet_partition_definition, + _update_table_definition, +) from awswrangler.catalog._utils import _catalog_id, sanitize_table_name _logger: logging.Logger = logging.getLogger(__name__) @@ -157,3 +162,65 @@ def add_parquet_partitions( _add_partitions( database=database, table=table, boto3_session=boto3_session, inputs=inputs, catalog_id=catalog_id ) + + +@apply_configs +def add_column( + database: str, + table: str, + column_name: str, + column_type: str = "string", + column_comment: Optional[str] = None, + boto3_session: Optional[boto3.Session] = None, + catalog_id: Optional[str] = None, +) -> None: + """Delete a column in a AWS Glue Catalog table. + + Parameters + ---------- + database : str + Database name. + table : str + Table name. + column_name : str + Column name + column_type : str + Column type. + column_comment : str + Column Comment + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session receive None. + catalog_id : str, optional + The ID of the Data Catalog from which to retrieve Databases. + If none is provided, the AWS account ID is used by default. + + Returns + ------- + None + None + + Examples + -------- + >>> import awswrangler as wr + >>> wr.catalog.add_column( + ... database='my_db', + ... table='my_table', + ... column_name='my_col', + ... column_type='int' + ... ) + """ + if _check_column_type(column_type): + client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session) + table_res: Dict[str, Any] = client_glue.get_table(DatabaseName=database, Name=table) + table_input: Dict[str, Any] = _update_table_definition(table_res) + table_input["StorageDescriptor"]["Columns"].append( + {"Name": column_name, "Type": column_type, "Comment": column_comment} + ) + res: Dict[str, Any] = client_glue.update_table( + **_catalog_id(catalog_id=catalog_id, DatabaseName=database, TableInput=table_input) + ) + if ("Errors" in res) and res["Errors"]: + for error in res["Errors"]: + if "ErrorDetail" in error: + if "ErrorCode" in error["ErrorDetail"]: + raise exceptions.ServiceApiError(str(res["Errors"])) diff --git a/awswrangler/catalog/_definitions.py b/awswrangler/catalog/_definitions.py index 8b2076256..ceb513ec7 100644 --- a/awswrangler/catalog/_definitions.py +++ b/awswrangler/catalog/_definitions.py @@ -5,6 +5,27 @@ _logger: logging.Logger = logging.getLogger(__name__) +_LEGAL_COLUMN_TYPES = [ + "array", + "bigint", + "binary", + "boolean", + "char", + "date", + "decimal", + "double", + "float", + "int", + "interval", + "map", + "set", + "smallint", + "string", + "struct", + "timestamp", + "tinyint", +] + def _parquet_table_definition( table: str, path: str, columns_types: Dict[str, str], partitions_types: Dict[str, str], compression: Optional[str] @@ -138,3 +159,32 @@ def _csv_partition_definition( {"Name": cname, "Type": dtype} for cname, dtype in columns_types.items() ] return definition + + +def _check_column_type(column_type: str) -> bool: + if column_type not in _LEGAL_COLUMN_TYPES: + raise ValueError(f"{column_type} is not a legal data type.") + return True + + +def _update_table_definition(current_definition: Dict[str, Any]) -> Dict[str, Any]: + definition: Dict[str, Any] = dict() + keep_keys = [ + "Name", + "Description", + "Owner", + "LastAccessTime", + "LastAnalyzedTime", + "Retention", + "StorageDescriptor", + "PartitionKeys", + "ViewOriginalText", + "ViewExpandedText", + "TableType", + "Parameters", + "TargetTable", + ] + for key in current_definition["Table"]: + if key in keep_keys: + definition[key] = current_definition["Table"][key] + return definition diff --git a/awswrangler/catalog/_delete.py b/awswrangler/catalog/_delete.py index 1c1c32bbd..5436e8346 100644 --- a/awswrangler/catalog/_delete.py +++ b/awswrangler/catalog/_delete.py @@ -1,12 +1,13 @@ """AWS Glue Catalog Delete Module.""" import logging -from typing import List, Optional +from typing import Any, Dict, List, Optional import boto3 -from awswrangler import _utils +from awswrangler import _utils, exceptions from awswrangler._config import apply_configs +from awswrangler.catalog._definitions import _update_table_definition from awswrangler.catalog._get import _get_partitions from awswrangler.catalog._utils import _catalog_id @@ -181,3 +182,57 @@ def delete_all_partitions( boto3_session=boto3_session, ) return partitions_values + + +@apply_configs +def delete_column( + database: str, + table: str, + column_name: str, + boto3_session: Optional[boto3.Session] = None, + catalog_id: Optional[str] = None, +) -> None: + """Delete a column in a AWS Glue Catalog table. + + Parameters + ---------- + database : str + Database name. + table : str + Table name. + column_name : str + Column name + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session receive None. + catalog_id : str, optional + The ID of the Data Catalog from which to retrieve Databases. + If none is provided, the AWS account ID is used by default. + + Returns + ------- + None + None + + Examples + -------- + >>> import awswrangler as wr + >>> wr.catalog.delete_column( + ... database='my_db', + ... table='my_table', + ... column_name='my_col', + ... ) + """ + client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session) + table_res: Dict[str, Any] = client_glue.get_table(DatabaseName=database, Name=table) + table_input: Dict[str, Any] = _update_table_definition(table_res) + table_input["StorageDescriptor"]["Columns"] = [ + i for i in table_input["StorageDescriptor"]["Columns"] if i["Name"] != column_name + ] + res: Dict[str, Any] = client_glue.update_table( + **_catalog_id(catalog_id=catalog_id, DatabaseName=database, TableInput=table_input) + ) + if ("Errors" in res) and res["Errors"]: + for error in res["Errors"]: + if "ErrorDetail" in error: + if "ErrorCode" in error["ErrorDetail"]: + raise exceptions.ServiceApiError(str(res["Errors"])) diff --git a/tests/test_athena.py b/tests/test_athena.py index 56dc6ce91..3ceabdb65 100644 --- a/tests/test_athena.py +++ b/tests/test_athena.py @@ -244,6 +244,16 @@ def test_catalog(path: str, glue_database: str, glue_table: str) -> None: assert len(tables) > 0 for tbl in tables: assert tbl["DatabaseName"] == glue_database + # add & delete column + wr.catalog.add_column( + database=glue_database, table=glue_table, column_name="col2", column_type="int", column_comment="comment" + ) + dtypes = wr.catalog.get_table_types(database=glue_database, table=glue_table) + assert len(dtypes) == 5 + assert dtypes["col2"] == "int" + wr.catalog.delete_column(database=glue_database, table=glue_table, column_name="col2") + dtypes = wr.catalog.get_table_types(database=glue_database, table=glue_table) + assert len(dtypes) == 4 # search tables = list(wr.catalog.search_tables(text="parquet", catalog_id=account_id)) assert len(tables) > 0