Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion awswrangler/catalog/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Amazon Glue Catalog Module."""

from awswrangler.catalog._add import add_csv_partitions, add_parquet_partitions # noqa
from awswrangler.catalog._add import add_column, add_csv_partitions, add_parquet_partitions # noqa
from awswrangler.catalog._create import ( # noqa
_create_csv_table,
_create_parquet_table,
Expand All @@ -12,6 +12,7 @@
)
from awswrangler.catalog._delete import ( # noqa
delete_all_partitions,
delete_column,
delete_database,
delete_partitions,
delete_table_if_exists,
Expand Down
69 changes: 68 additions & 1 deletion awswrangler/catalog/_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@

from awswrangler import _utils, exceptions
from awswrangler._config import apply_configs
from awswrangler.catalog._definitions import _csv_partition_definition, _parquet_partition_definition
from awswrangler.catalog._definitions import (
_check_column_type,
_csv_partition_definition,
_parquet_partition_definition,
_update_table_definition,
)
from awswrangler.catalog._utils import _catalog_id, sanitize_table_name

_logger: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -157,3 +162,65 @@ def add_parquet_partitions(
_add_partitions(
database=database, table=table, boto3_session=boto3_session, inputs=inputs, catalog_id=catalog_id
)


@apply_configs
def add_column(
    database: str,
    table: str,
    column_name: str,
    column_type: str = "string",
    column_comment: Optional[str] = None,
    boto3_session: Optional[boto3.Session] = None,
    catalog_id: Optional[str] = None,
) -> None:
    """Add a column to an AWS Glue Catalog table.

    The current table definition is fetched, the new column is appended to
    ``StorageDescriptor.Columns`` and the table is updated in place.

    Parameters
    ----------
    database : str
        Database name.
    table : str
        Table name.
    column_name : str
        Column name
    column_type : str
        Column type.
    column_comment : str, optional
        Column Comment. Omitted from the column definition when None.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receive None.
    catalog_id : str, optional
        The ID of the Data Catalog that contains the table.
        If none is provided, the AWS account ID is used by default.

    Returns
    -------
    None
        None

    Raises
    ------
    ValueError
        If ``column_type`` is not a legal data type.
    exceptions.ServiceApiError
        If the update response reports errors carrying an error code.

    Examples
    --------
    >>> import awswrangler as wr
    >>> wr.catalog.add_column(
    ...     database='my_db',
    ...     table='my_table',
    ...     column_name='my_col',
    ...     column_type='int'
    ... )
    """
    # Validate before any AWS call; raises ValueError on an illegal type.
    _check_column_type(column_type)
    client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
    # Read from the same catalog that update_table writes to below.
    table_res: Dict[str, Any] = client_glue.get_table(
        **_catalog_id(catalog_id=catalog_id, DatabaseName=database, Name=table)
    )
    table_input: Dict[str, Any] = _update_table_definition(table_res)
    new_column: Dict[str, Any] = {"Name": column_name, "Type": column_type}
    # boto3 parameter validation rejects None for string members, so only
    # send the comment when one was actually supplied.
    if column_comment is not None:
        new_column["Comment"] = column_comment
    table_input["StorageDescriptor"]["Columns"].append(new_column)
    res: Dict[str, Any] = client_glue.update_table(
        **_catalog_id(catalog_id=catalog_id, DatabaseName=database, TableInput=table_input)
    )
    errors = res.get("Errors") or []
    if any("ErrorCode" in error.get("ErrorDetail", {}) for error in errors):
        raise exceptions.ServiceApiError(str(errors))
50 changes: 50 additions & 0 deletions awswrangler/catalog/_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,27 @@

_logger: logging.Logger = logging.getLogger(__name__)

_LEGAL_COLUMN_TYPES = [
"array",
"bigint",
"binary",
"boolean",
"char",
"date",
"decimal",
"double",
"float",
"int",
"interval",
"map",
"set",
"smallint",
"string",
"struct",
"timestamp",
"tinyint",
]


def _parquet_table_definition(
table: str, path: str, columns_types: Dict[str, str], partitions_types: Dict[str, str], compression: Optional[str]
Expand Down Expand Up @@ -138,3 +159,32 @@ def _csv_partition_definition(
{"Name": cname, "Type": dtype} for cname, dtype in columns_types.items()
]
return definition


def _check_column_type(column_type: str) -> bool:
    """Validate a column type against the list of legal type names.

    Parameterized and nested types are validated by their base name, so
    ``decimal(10,2)``, ``char(5)``, ``array<int>`` and ``struct<a:int>`` are
    accepted as long as the part before the first ``<`` or ``(`` is legal.

    Parameters
    ----------
    column_type : str
        Column type, optionally carrying parameters or element types.

    Returns
    -------
    bool
        Always True when the type is legal.

    Raises
    ------
    ValueError
        If the base type name is not in ``_LEGAL_COLUMN_TYPES``.
    """
    # "array<decimal(10,2)>" -> "array"; "decimal(10,2)" -> "decimal".
    base_type = column_type.partition("<")[0].partition("(")[0]
    if base_type not in _LEGAL_COLUMN_TYPES:
        raise ValueError(f"{column_type} is not a legal data type.")
    return True


def _update_table_definition(current_definition: Dict[str, Any]) -> Dict[str, Any]:
definition: Dict[str, Any] = dict()
keep_keys = [
"Name",
"Description",
"Owner",
"LastAccessTime",
"LastAnalyzedTime",
"Retention",
"StorageDescriptor",
"PartitionKeys",
"ViewOriginalText",
"ViewExpandedText",
"TableType",
"Parameters",
"TargetTable",
]
for key in current_definition["Table"]:
if key in keep_keys:
definition[key] = current_definition["Table"][key]
return definition
59 changes: 57 additions & 2 deletions awswrangler/catalog/_delete.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""AWS Glue Catalog Delete Module."""

import logging
from typing import List, Optional
from typing import Any, Dict, List, Optional

import boto3

from awswrangler import _utils
from awswrangler import _utils, exceptions
from awswrangler._config import apply_configs
from awswrangler.catalog._definitions import _update_table_definition
from awswrangler.catalog._get import _get_partitions
from awswrangler.catalog._utils import _catalog_id

Expand Down Expand Up @@ -181,3 +182,57 @@ def delete_all_partitions(
boto3_session=boto3_session,
)
return partitions_values


@apply_configs
def delete_column(
    database: str,
    table: str,
    column_name: str,
    boto3_session: Optional[boto3.Session] = None,
    catalog_id: Optional[str] = None,
) -> None:
    """Delete a column in a AWS Glue Catalog table.

    The current table definition is fetched, every column whose name equals
    ``column_name`` is removed from ``StorageDescriptor.Columns`` and the
    table is updated in place.

    Parameters
    ----------
    database : str
        Database name.
    table : str
        Table name.
    column_name : str
        Column name
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receive None.
    catalog_id : str, optional
        The ID of the Data Catalog that contains the table.
        If none is provided, the AWS account ID is used by default.

    Returns
    -------
    None
        None

    Raises
    ------
    exceptions.ServiceApiError
        If the update response reports errors carrying an error code.

    Examples
    --------
    >>> import awswrangler as wr
    >>> wr.catalog.delete_column(
    ...     database='my_db',
    ...     table='my_table',
    ...     column_name='my_col',
    ... )
    """
    client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
    # Read from the same catalog that update_table writes to below.
    table_res: Dict[str, Any] = client_glue.get_table(
        **_catalog_id(catalog_id=catalog_id, DatabaseName=database, Name=table)
    )
    table_input: Dict[str, Any] = _update_table_definition(table_res)
    table_input["StorageDescriptor"]["Columns"] = [
        column for column in table_input["StorageDescriptor"]["Columns"] if column["Name"] != column_name
    ]
    res: Dict[str, Any] = client_glue.update_table(
        **_catalog_id(catalog_id=catalog_id, DatabaseName=database, TableInput=table_input)
    )
    errors = res.get("Errors") or []
    if any("ErrorCode" in error.get("ErrorDetail", {}) for error in errors):
        raise exceptions.ServiceApiError(str(errors))
10 changes: 10 additions & 0 deletions tests/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,16 @@ def test_catalog(path: str, glue_database: str, glue_table: str) -> None:
assert len(tables) > 0
for tbl in tables:
assert tbl["DatabaseName"] == glue_database
# add & delete column
wr.catalog.add_column(
database=glue_database, table=glue_table, column_name="col2", column_type="int", column_comment="comment"
)
dtypes = wr.catalog.get_table_types(database=glue_database, table=glue_table)
assert len(dtypes) == 5
assert dtypes["col2"] == "int"
wr.catalog.delete_column(database=glue_database, table=glue_table, column_name="col2")
dtypes = wr.catalog.get_table_types(database=glue_database, table=glue_table)
assert len(dtypes) == 4
# search
tables = list(wr.catalog.search_tables(text="parquet", catalog_id=account_id))
assert len(tables) > 0
Expand Down