From 42eac323229133a5e38ef962490d4b3d55c47287 Mon Sep 17 00:00:00 2001 From: Bryan Date: Mon, 16 Nov 2020 22:22:00 +0800 Subject: [PATCH 1/4] add add/delete column --- awswrangler/catalog/__init__.py | 10 +++----- awswrangler/catalog/_add.py | 36 ++++++++++++++++++++++++++++- awswrangler/catalog/_definitions.py | 23 ++++++++++++++++++ awswrangler/catalog/_delete.py | 30 ++++++++++++++++++++++-- tests/test_athena.py | 18 +++++++++++++++ 5 files changed, 107 insertions(+), 10 deletions(-) diff --git a/awswrangler/catalog/__init__.py b/awswrangler/catalog/__init__.py index 11be6645d..adceef43f 100644 --- a/awswrangler/catalog/__init__.py +++ b/awswrangler/catalog/__init__.py @@ -1,6 +1,6 @@ """Amazon Glue Catalog Module.""" -from awswrangler.catalog._add import add_csv_partitions, add_parquet_partitions # noqa +from awswrangler.catalog._add import add_column, add_csv_partitions, add_parquet_partitions # noqa from awswrangler.catalog._create import ( # noqa _create_csv_table, _create_parquet_table, @@ -10,12 +10,8 @@ overwrite_table_parameters, upsert_table_parameters, ) -from awswrangler.catalog._delete import ( # noqa - delete_all_partitions, - delete_database, - delete_partitions, - delete_table_if_exists, -) +from awswrangler.catalog._delete import delete_table_if_exists # noqa +from awswrangler.catalog._delete import delete_all_partitions, delete_column, delete_database, delete_partitions from awswrangler.catalog._get import ( # noqa _get_table_input, databases, diff --git a/awswrangler/catalog/_add.py b/awswrangler/catalog/_add.py index 7990cc1a4..d6ada1804 100644 --- a/awswrangler/catalog/_add.py +++ b/awswrangler/catalog/_add.py @@ -7,7 +7,12 @@ from awswrangler import _utils, exceptions from awswrangler._config import apply_configs -from awswrangler.catalog._definitions import _csv_partition_definition, _parquet_partition_definition +from awswrangler.catalog._definitions import ( + _check_column_type, + _csv_partition_definition, + 
_parquet_partition_definition, + _update_table_definition, +) from awswrangler.catalog._utils import _catalog_id, sanitize_table_name _logger: logging.Logger = logging.getLogger(__name__) @@ -157,3 +162,32 @@ def add_parquet_partitions( _add_partitions( database=database, table=table, boto3_session=boto3_session, inputs=inputs, catalog_id=catalog_id ) + + +@apply_configs +def add_column( + database: str, + table: str, + column_name: str, + column_type: str = 'string', + column_comment: str = None, + boto3_session: Optional[boto3.Session] = None, + catalog_id: Optional[str] = None, +) -> None: + if _check_column_type(column_type): + client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session) + res: Dict[str, Any] = client_glue.get_table(DatabaseName=database, Name=table) + table_input: dict = _update_table_definition(res) + table_input['StorageDescriptor']['Columns'].append({ + 'Name': column_name, + 'Type': column_type, + 'Comment': column_comment + }) + res: Dict[str, Any] = client_glue.update_table( + **_catalog_id(catalog_id=catalog_id, DatabaseName=database, TableInput=table_input) + ) + if ("Errors" in res) and res["Errors"]: + for error in res["Errors"]: + if "ErrorDetail" in error: + if "ErrorCode" in error["ErrorDetail"]: + raise exceptions.ServiceApiError(str(res["Errors"])) diff --git a/awswrangler/catalog/_definitions.py b/awswrangler/catalog/_definitions.py index 8b2076256..881eb3ae6 100644 --- a/awswrangler/catalog/_definitions.py +++ b/awswrangler/catalog/_definitions.py @@ -5,6 +5,9 @@ _logger: logging.Logger = logging.getLogger(__name__) +_LEGAL_COLUMN_TYPES = ["array", "bigint", "binary", "boolean", "char", "date", "decimal", "double", "float", "int", + "interval", "map", "set", "smallint", "string", "struct", "timestamp", "tinyint"] + def _parquet_table_definition( table: str, path: str, columns_types: Dict[str, str], partitions_types: Dict[str, str], compression: Optional[str] @@ -138,3 +141,23 @@ def 
_csv_partition_definition( {"Name": cname, "Type": dtype} for cname, dtype in columns_types.items() ] return definition + + +def _check_column_type( + column_type: str +) -> bool: + if column_type not in _LEGAL_COLUMN_TYPES: + raise ValueError(f"{column_type} is not a legal data type.") + else: + return True + + +def _update_table_definition(current_definition: dict): + definition: dict[Any, Any] = dict() + keep_keys = ["Name", "Description", "Owner", "LastAccessTime", "LastAnalyzedTime", "Retention", + "StorageDescriptor", "PartitionKeys", "ViewOriginalText", "ViewExpandedText", "TableType", + "Parameters", "TargetTable"] + for key in current_definition['Table']: + if key in keep_keys: + definition[key] = current_definition['Table'][key] + return definition diff --git a/awswrangler/catalog/_delete.py b/awswrangler/catalog/_delete.py index 1c1c32bbd..88603cd65 100644 --- a/awswrangler/catalog/_delete.py +++ b/awswrangler/catalog/_delete.py @@ -1,12 +1,13 @@ """AWS Glue Catalog Delete Module.""" import logging -from typing import List, Optional +from typing import Any, Dict, List, Optional import boto3 -from awswrangler import _utils +from awswrangler import _utils, exceptions from awswrangler._config import apply_configs +from awswrangler.catalog._definitions import _update_table_definition from awswrangler.catalog._get import _get_partitions from awswrangler.catalog._utils import _catalog_id @@ -181,3 +182,28 @@ def delete_all_partitions( boto3_session=boto3_session, ) return partitions_values + + +@apply_configs +def delete_column( + database: str, + table: str, + column_name: str, + boto3_session: Optional[boto3.Session] = None, + catalog_id: Optional[str] = None, +) -> None: + client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session) + res: Dict[str, Any] = client_glue.get_table(DatabaseName=database, Name=table) + table_input: dict = _update_table_definition(res) + print(table_input) + table_input['StorageDescriptor']['Columns'] = \ 
+ [i for i in table_input['StorageDescriptor']['Columns'] if i['Name'] != column_name] + print(table_input) + res: Dict[str, Any] = client_glue.update_table( + **_catalog_id(catalog_id=catalog_id, DatabaseName=database, TableInput=table_input) + ) + if ("Errors" in res) and res["Errors"]: + for error in res["Errors"]: + if "ErrorDetail" in error: + if "ErrorCode" in error["ErrorDetail"]: + raise exceptions.ServiceApiError(str(res["Errors"])) diff --git a/tests/test_athena.py b/tests/test_athena.py index 56dc6ce91..69e229d95 100644 --- a/tests/test_athena.py +++ b/tests/test_athena.py @@ -244,6 +244,24 @@ def test_catalog(path: str, glue_database: str, glue_table: str) -> None: assert len(tables) > 0 for tbl in tables: assert tbl["DatabaseName"] == glue_database + # add & delete column + wr.catalog.add_column( + database=glue_database, + table=glue_table, + column_name="col2", + column_type="int", + column_comment="comment" + ) + dtypes = wr.catalog.get_table_types(database=glue_database, table=glue_table) + assert len(dtypes) == 5 + assert dtypes["col2"] == "int" + wr.catalog.delete_column( + database=glue_database, + table=glue_table, + column_name="col2" + ) + dtypes = wr.catalog.get_table_types(database=glue_database, table=glue_table) + assert len(dtypes) == 4 # search tables = list(wr.catalog.search_tables(text="parquet", catalog_id=account_id)) assert len(tables) > 0 From 9c8acfc7c4b7c24fe74950f87fd1a8d41059454c Mon Sep 17 00:00:00 2001 From: Bryan Date: Mon, 16 Nov 2020 22:34:00 +0800 Subject: [PATCH 2/4] fix format --- awswrangler/catalog/_add.py | 6 +++--- awswrangler/catalog/_definitions.py | 4 ++-- awswrangler/catalog/_delete.py | 6 ++---- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/awswrangler/catalog/_add.py b/awswrangler/catalog/_add.py index d6ada1804..7471cc8f7 100644 --- a/awswrangler/catalog/_add.py +++ b/awswrangler/catalog/_add.py @@ -170,14 +170,14 @@ def add_column( table: str, column_name: str, column_type: str = 'string', - 
column_comment: str = None, + column_comment: Optional[str] = None, boto3_session: Optional[boto3.Session] = None, catalog_id: Optional[str] = None, ) -> None: if _check_column_type(column_type): client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session) - res: Dict[str, Any] = client_glue.get_table(DatabaseName=database, Name=table) - table_input: dict = _update_table_definition(res) + table_res: Dict[str, Any] = client_glue.get_table(DatabaseName=database, Name=table) + table_input: Dict[str, Any] = _update_table_definition(table_res) table_input['StorageDescriptor']['Columns'].append({ 'Name': column_name, 'Type': column_type, diff --git a/awswrangler/catalog/_definitions.py b/awswrangler/catalog/_definitions.py index 881eb3ae6..4c91bcf97 100644 --- a/awswrangler/catalog/_definitions.py +++ b/awswrangler/catalog/_definitions.py @@ -152,8 +152,8 @@ def _check_column_type( return True -def _update_table_definition(current_definition: dict): - definition: dict[Any, Any] = dict() +def _update_table_definition(current_definition: Dict[str, Any]) -> Dict[str, Any]: + definition: Dict[str, Any] = dict() keep_keys = ["Name", "Description", "Owner", "LastAccessTime", "LastAnalyzedTime", "Retention", "StorageDescriptor", "PartitionKeys", "ViewOriginalText", "ViewExpandedText", "TableType", "Parameters", "TargetTable"] diff --git a/awswrangler/catalog/_delete.py b/awswrangler/catalog/_delete.py index 88603cd65..3241e8c1c 100644 --- a/awswrangler/catalog/_delete.py +++ b/awswrangler/catalog/_delete.py @@ -193,12 +193,10 @@ def delete_column( catalog_id: Optional[str] = None, ) -> None: client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session) - res: Dict[str, Any] = client_glue.get_table(DatabaseName=database, Name=table) - table_input: dict = _update_table_definition(res) - print(table_input) + table_res: Dict[str, Any] = client_glue.get_table(DatabaseName=database, Name=table) + table_input: Dict[str, Any] = 
_update_table_definition(table_res) table_input['StorageDescriptor']['Columns'] = \ [i for i in table_input['StorageDescriptor']['Columns'] if i['Name'] != column_name] - print(table_input) res: Dict[str, Any] = client_glue.update_table( **_catalog_id(catalog_id=catalog_id, DatabaseName=database, TableInput=table_input) ) From a421e2a9643ee5d293eeb900d10a5b333ae4dc2f Mon Sep 17 00:00:00 2001 From: Bryan Date: Mon, 16 Nov 2020 22:41:25 +0800 Subject: [PATCH 3/4] fix flake8 --- awswrangler/catalog/__init__.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/awswrangler/catalog/__init__.py b/awswrangler/catalog/__init__.py index adceef43f..b83b7a7d3 100644 --- a/awswrangler/catalog/__init__.py +++ b/awswrangler/catalog/__init__.py @@ -10,8 +10,13 @@ overwrite_table_parameters, upsert_table_parameters, ) -from awswrangler.catalog._delete import delete_table_if_exists # noqa -from awswrangler.catalog._delete import delete_all_partitions, delete_column, delete_database, delete_partitions +from awswrangler.catalog._delete import ( # noqa + delete_all_partitions, + delete_column, + delete_database, + delete_partitions, + delete_table_if_exists, +) from awswrangler.catalog._get import ( # noqa _get_table_input, databases, From 85e2217d1e5ded3385dcef07a5987f21c5e98ded Mon Sep 17 00:00:00 2001 From: Bryan Date: Mon, 16 Nov 2020 23:03:03 +0800 Subject: [PATCH 4/4] add docs --- awswrangler/catalog/_add.py | 57 +++++++++++++++++++++++------ awswrangler/catalog/_definitions.py | 51 ++++++++++++++++++++------ awswrangler/catalog/_delete.py | 45 +++++++++++++++++++---- tests/test_athena.py | 12 +----- 4 files changed, 124 insertions(+), 41 deletions(-) diff --git a/awswrangler/catalog/_add.py b/awswrangler/catalog/_add.py index 7471cc8f7..9b3f6de39 100644 --- a/awswrangler/catalog/_add.py +++ b/awswrangler/catalog/_add.py @@ -166,23 +166,56 @@ def add_parquet_partitions( @apply_configs def add_column( - database: str, - table: str, - column_name: str, - 
column_type: str = 'string', - column_comment: Optional[str] = None, - boto3_session: Optional[boto3.Session] = None, - catalog_id: Optional[str] = None, + database: str, + table: str, + column_name: str, + column_type: str = "string", + column_comment: Optional[str] = None, + boto3_session: Optional[boto3.Session] = None, + catalog_id: Optional[str] = None, ) -> None: + """Add a column in a AWS Glue Catalog table. + + Parameters + ---------- + database : str + Database name. + table : str + Table name. + column_name : str + Column name + column_type : str + Column type. + column_comment : str + Column Comment + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session receive None. + catalog_id : str, optional + The ID of the Data Catalog from which to retrieve Databases. + If none is provided, the AWS account ID is used by default. + + Returns + ------- + None + None + + Examples + -------- + >>> import awswrangler as wr + >>> wr.catalog.add_column( + ... database='my_db', + ... table='my_table', + ... column_name='my_col', + ... column_type='int' + ... 
) + """ if _check_column_type(column_type): client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session) table_res: Dict[str, Any] = client_glue.get_table(DatabaseName=database, Name=table) table_input: Dict[str, Any] = _update_table_definition(table_res) - table_input['StorageDescriptor']['Columns'].append({ - 'Name': column_name, - 'Type': column_type, - 'Comment': column_comment - }) + table_input["StorageDescriptor"]["Columns"].append( + {"Name": column_name, "Type": column_type, "Comment": column_comment} + ) res: Dict[str, Any] = client_glue.update_table( **_catalog_id(catalog_id=catalog_id, DatabaseName=database, TableInput=table_input) ) diff --git a/awswrangler/catalog/_definitions.py b/awswrangler/catalog/_definitions.py index 4c91bcf97..ceb513ec7 100644 --- a/awswrangler/catalog/_definitions.py +++ b/awswrangler/catalog/_definitions.py @@ -5,8 +5,26 @@ _logger: logging.Logger = logging.getLogger(__name__) -_LEGAL_COLUMN_TYPES = ["array", "bigint", "binary", "boolean", "char", "date", "decimal", "double", "float", "int", - "interval", "map", "set", "smallint", "string", "struct", "timestamp", "tinyint"] +_LEGAL_COLUMN_TYPES = [ + "array", + "bigint", + "binary", + "boolean", + "char", + "date", + "decimal", + "double", + "float", + "int", + "interval", + "map", + "set", + "smallint", + "string", + "struct", + "timestamp", + "tinyint", +] def _parquet_table_definition( @@ -143,21 +161,30 @@ def _csv_partition_definition( return definition -def _check_column_type( - column_type: str -) -> bool: +def _check_column_type(column_type: str) -> bool: if column_type not in _LEGAL_COLUMN_TYPES: raise ValueError(f"{column_type} is not a legal data type.") - else: - return True + return True def _update_table_definition(current_definition: Dict[str, Any]) -> Dict[str, Any]: definition: Dict[str, Any] = dict() - keep_keys = ["Name", "Description", "Owner", "LastAccessTime", "LastAnalyzedTime", "Retention", - "StorageDescriptor", 
"PartitionKeys", "ViewOriginalText", "ViewExpandedText", "TableType", - "Parameters", "TargetTable"] - for key in current_definition['Table']: + keep_keys = [ + "Name", + "Description", + "Owner", + "LastAccessTime", + "LastAnalyzedTime", + "Retention", + "StorageDescriptor", + "PartitionKeys", + "ViewOriginalText", + "ViewExpandedText", + "TableType", + "Parameters", + "TargetTable", + ] + for key in current_definition["Table"]: if key in keep_keys: - definition[key] = current_definition['Table'][key] + definition[key] = current_definition["Table"][key] return definition diff --git a/awswrangler/catalog/_delete.py b/awswrangler/catalog/_delete.py index 3241e8c1c..5436e8346 100644 --- a/awswrangler/catalog/_delete.py +++ b/awswrangler/catalog/_delete.py @@ -186,17 +186,48 @@ def delete_all_partitions( @apply_configs def delete_column( - database: str, - table: str, - column_name: str, - boto3_session: Optional[boto3.Session] = None, - catalog_id: Optional[str] = None, + database: str, + table: str, + column_name: str, + boto3_session: Optional[boto3.Session] = None, + catalog_id: Optional[str] = None, ) -> None: + """Delete a column in a AWS Glue Catalog table. + + Parameters + ---------- + database : str + Database name. + table : str + Table name. + column_name : str + Column name + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session receive None. + catalog_id : str, optional + The ID of the Data Catalog from which to retrieve Databases. + If none is provided, the AWS account ID is used by default. + + Returns + ------- + None + None + + Examples + -------- + >>> import awswrangler as wr + >>> wr.catalog.delete_column( + ... database='my_db', + ... table='my_table', + ... column_name='my_col', + ... 
) + """ client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session) table_res: Dict[str, Any] = client_glue.get_table(DatabaseName=database, Name=table) table_input: Dict[str, Any] = _update_table_definition(table_res) - table_input['StorageDescriptor']['Columns'] = \ - [i for i in table_input['StorageDescriptor']['Columns'] if i['Name'] != column_name] + table_input["StorageDescriptor"]["Columns"] = [ + i for i in table_input["StorageDescriptor"]["Columns"] if i["Name"] != column_name + ] res: Dict[str, Any] = client_glue.update_table( **_catalog_id(catalog_id=catalog_id, DatabaseName=database, TableInput=table_input) ) diff --git a/tests/test_athena.py b/tests/test_athena.py index 69e229d95..3ceabdb65 100644 --- a/tests/test_athena.py +++ b/tests/test_athena.py @@ -246,20 +246,12 @@ def test_catalog(path: str, glue_database: str, glue_table: str) -> None: assert tbl["DatabaseName"] == glue_database # add & delete column wr.catalog.add_column( - database=glue_database, - table=glue_table, - column_name="col2", - column_type="int", - column_comment="comment" + database=glue_database, table=glue_table, column_name="col2", column_type="int", column_comment="comment" ) dtypes = wr.catalog.get_table_types(database=glue_database, table=glue_table) assert len(dtypes) == 5 assert dtypes["col2"] == "int" - wr.catalog.delete_column( - database=glue_database, - table=glue_table, - column_name="col2" - ) + wr.catalog.delete_column(database=glue_database, table=glue_table, column_name="col2") dtypes = wr.catalog.get_table_types(database=glue_database, table=glue_table) assert len(dtypes) == 4 # search