awswrangler/_data_types.py (35 changes: 22 additions & 13 deletions)
@@ -175,15 +175,15 @@ def pyarrow2sqlalchemy(  # pylint: disable=too-many-branches,too-many-return-statements
 ) -> Optional[VisitableType]:
     """Pyarrow to Athena data types conversion."""
     if pa.types.is_int8(dtype):
-        return sqlalchemy.types.SMALLINT
+        return sqlalchemy.types.SmallInteger
     if pa.types.is_int16(dtype):
-        return sqlalchemy.types.SMALLINT
+        return sqlalchemy.types.SmallInteger
     if pa.types.is_int32(dtype):
-        return sqlalchemy.types.INTEGER
+        return sqlalchemy.types.Integer
     if pa.types.is_int64(dtype):
-        return sqlalchemy.types.BIGINT
+        return sqlalchemy.types.BigInteger
     if pa.types.is_float32(dtype):
-        return sqlalchemy.types.FLOAT
+        return sqlalchemy.types.Float
     if pa.types.is_float64(dtype):
         if db_type == "mysql":
             return sqlalchemy.dialects.mysql.DOUBLE
@@ -195,25 +195,25 @@ def pyarrow2sqlalchemy(  # pylint: disable=too-many-branches,too-many-return-statements
         f"{db_type} is a invalid database type, please choose between postgresql, mysql and redshift."
     )  # pragma: no cover
     if pa.types.is_boolean(dtype):
-        return sqlalchemy.types.BOOLEAN
+        return sqlalchemy.types.Boolean
     if pa.types.is_string(dtype):
         if db_type == "mysql":
-            return sqlalchemy.types.TEXT
+            return sqlalchemy.types.Text
         if db_type == "postgresql":
-            return sqlalchemy.types.TEXT
+            return sqlalchemy.types.Text
         if db_type == "redshift":
             return sqlalchemy.types.VARCHAR(length=256)
         raise exceptions.InvalidDatabaseType(
             f"{db_type} is a invalid database type. " f"Please choose between postgresql, mysql and redshift."
         )  # pragma: no cover
     if pa.types.is_timestamp(dtype):
-        return sqlalchemy.types.DATETIME
+        return sqlalchemy.types.DateTime
     if pa.types.is_date(dtype):
-        return sqlalchemy.types.DATE
+        return sqlalchemy.types.Date
     if pa.types.is_binary(dtype):
         if db_type == "redshift":
             raise exceptions.UnsupportedType("Binary columns are not supported for Redshift.")  # pragma: no cover
-        return sqlalchemy.types.BINARY
+        return sqlalchemy.types.Binary
     if pa.types.is_decimal(dtype):
         return sqlalchemy.types.Numeric(precision=dtype.precision, scale=dtype.scale)
     if pa.types.is_dictionary(dtype):
@@ -396,7 +396,7 @@ def cast_pandas_with_athena_types(df: pd.DataFrame, dtype: Dict[str, str]) -> pd.DataFrame:
             df[col] = (
                 df[col]
                 .astype("string")
-                .apply(lambda x: Decimal(str(x)) if str(x) not in ("", "none", " ", "<NA>") else None)
+                .apply(lambda x: Decimal(str(x)) if str(x) not in ("", "none", "None", " ", "<NA>") else None)
             )
         elif pandas_type == "string":
             curr_type: str = str(df[col].dtypes)
@@ -405,7 +405,16 @@ def cast_pandas_with_athena_types(df: pd.DataFrame, dtype: Dict[str, str]) -> pd.DataFrame:
             else:
                 df[col] = df[col].astype("string")
     else:
-        df[col] = df[col].astype(pandas_type)
+        try:
+            df[col] = df[col].astype(pandas_type)
+        except TypeError as ex:
+            if "object cannot be converted to an IntegerDtype" not in str(ex):
+                raise ex  # pragma: no cover
+            df[col] = (
+                df[col]
+                .apply(lambda x: int(x) if str(x) not in ("", "none", "None", " ", "<NA>") else None)
+                .astype(pandas_type)
+            )
     return df


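The new try/except in cast_pandas_with_athena_types() works around a pandas quirk: casting an object column straight to a nullable integer dtype can raise TypeError. A minimal sketch of the failure mode and of the recovery path used above, assuming pandas >= 1.0 with the nullable Int64 dtype (sample values are made up):

# Sketch only: reproduces the TypeError branch handled by the patch.
import pandas as pd

df = pd.DataFrame({"c0": [1.0, 2.0, None]}, dtype="object")
try:
    df["c0"] = df["c0"].astype("Int64")  # older pandas: "object cannot be converted to an IntegerDtype"
except TypeError:
    # Same recovery as the patch: coerce element-wise, mapping null-ish strings to None.
    df["c0"] = (
        df["c0"]
        .apply(lambda x: int(x) if str(x) not in ("", "none", "None", " ", "<NA>") else None)
        .astype("Int64")
    )
print(df["c0"].dtype)  # Int64
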
awswrangler/catalog.py (183 changes: 158 additions & 25 deletions)
@@ -150,9 +150,33 @@ def create_parquet_table(
     """
     table = sanitize_table_name(table=table)
     partitions_types = {} if partitions_types is None else partitions_types
-    table_input: Dict[str, Any] = _parquet_table_definition(
-        table=table, path=path, columns_types=columns_types, partitions_types=partitions_types, compression=compression
-    )
+    session: boto3.Session = _utils.ensure_session(session=boto3_session)
+    cat_table_input: Optional[Dict[str, Any]] = _get_table_input(database=database, table=table, boto3_session=session)
+    table_input: Dict[str, Any]
+    if (cat_table_input is not None) and (mode in ("append", "overwrite_partitions")):
+        table_input = cat_table_input
+        updated: bool = False
+        cat_cols: Dict[str, str] = {x["Name"]: x["Type"] for x in table_input["StorageDescriptor"]["Columns"]}
+        for c, t in columns_types.items():
+            if c not in cat_cols:
+                _logger.debug("New column %s with type %s.", c, t)
+                table_input["StorageDescriptor"]["Columns"].append({"Name": c, "Type": t})
+                updated = True
+            elif t != cat_cols[c]:  # Data type change detected!
+                raise exceptions.InvalidArgumentValue(
+                    f"Data type change detected on column {c}. Old type: {cat_cols[c]}. New type {t}."
+                )
+        if updated is True:
+            mode = "update"
+    else:
+        table_input = _parquet_table_definition(
+            table=table,
+            path=path,
+            columns_types=columns_types,
+            partitions_types=partitions_types,
+            compression=compression,
+        )
+    table_exist: bool = cat_table_input is not None
     _create_table(
         database=database,
         table=table,
@@ -161,8 +185,9 @@ def create_parquet_table(
         columns_comments=columns_comments,
         mode=mode,
         catalog_versioning=catalog_versioning,
-        boto3_session=boto3_session,
+        boto3_session=session,
         table_input=table_input,
+        table_exist=table_exist,
     )
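
The branch above is what enables schema evolution on append. A hedged usage sketch (bucket, database, and table names are placeholders, and wr.s3.to_parquet() is assumed to route through create_parquet_table()): appending a DataFrame that carries an extra column now extends the Glue definition, while a type change on an existing column raises InvalidArgumentValue.

# Hypothetical usage sketch; all names are placeholders.
import pandas as pd
import awswrangler as wr

path = "s3://my-bucket/my-table/"
wr.s3.to_parquet(
    df=pd.DataFrame({"c0": [1, 2]}),
    path=path, dataset=True, database="my_db", table="my_table", mode="overwrite",
)
# c1 is new to the catalog: the column is appended and mode flips to "update".
wr.s3.to_parquet(
    df=pd.DataFrame({"c0": [3], "c1": ["foo"]}),
    path=path, dataset=True, database="my_db", table="my_table", mode="append",
)
# Re-sending c0 with a different type would raise exceptions.InvalidArgumentValue instead.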


@@ -266,7 +291,9 @@ def _parquet_partition_definition(location: str, values: List[str], compression: Optional[str]) -> Dict[str, Any]:
     }


-def get_table_types(database: str, table: str, boto3_session: Optional[boto3.Session] = None) -> Dict[str, str]:
+def get_table_types(
+    database: str, table: str, boto3_session: Optional[boto3.Session] = None
+) -> Optional[Dict[str, str]]:
     """Get all columns and types from a table.

     Parameters
@@ -280,8 +307,8 @@ def get_table_types(database: str, table: str, boto3_session: Optional[boto3.Session] = None) -> Dict[str, str]:

     Returns
     -------
-    Dict[str, str]
-        A dictionary as {'col name': 'col data type'}.
+    Optional[Dict[str, str]]
+        If table exists, a dictionary like {'col name': 'col data type'}. Otherwise None.

     Examples
     --------
@@ -291,7 +318,10 @@ def get_table_types(database: str, table: str, boto3_session: Optional[boto3.Session] = None) -> Dict[str, str]:

     """
     client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
-    response: Dict[str, Any] = client_glue.get_table(DatabaseName=database, Name=table)
+    try:
+        response: Dict[str, Any] = client_glue.get_table(DatabaseName=database, Name=table)
+    except client_glue.exceptions.EntityNotFoundException:
+        return None
     dtypes: Dict[str, str] = {}
     for col in response["Table"]["StorageDescriptor"]["Columns"]:
         dtypes[col["Name"]] = col["Type"]
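
With this change a missing table yields None instead of an EntityNotFoundException, so callers can branch without a try/except. A short sketch (names are placeholders):

# Hypothetical usage sketch; database and table names are placeholders.
import awswrangler as wr

dtypes = wr.catalog.get_table_types(database="my_db", table="maybe_missing")
if dtypes is None:
    print("Table not found in the Glue catalog.")
else:
    for name, athena_type in dtypes.items():
        print(name, athena_type)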
@@ -938,6 +968,7 @@ def create_csv_table(
         compression=compression,
         sep=sep,
     )
+    session: boto3.Session = _utils.ensure_session(session=boto3_session)
     _create_table(
         database=database,
         table=table,
@@ -946,8 +977,9 @@ def create_csv_table(
         columns_comments=columns_comments,
         mode=mode,
         catalog_versioning=catalog_versioning,
-        boto3_session=boto3_session,
+        boto3_session=session,
         table_input=table_input,
+        table_exist=does_table_exist(database=database, table=table, boto3_session=session),
     )


@@ -961,6 +993,7 @@ def _create_table(
     catalog_versioning: bool,
     boto3_session: Optional[boto3.Session],
     table_input: Dict[str, Any],
+    table_exist: bool,
 ):
     if description is not None:
         table_input["Description"] = description
@@ -978,23 +1011,25 @@ def _create_table(
             par["Comment"] = columns_comments[name]
     session: boto3.Session = _utils.ensure_session(session=boto3_session)
     client_glue: boto3.client = _utils.client(service_name="glue", session=session)
-    exist: bool = does_table_exist(database=database, table=table, boto3_session=session)
-    if mode not in ("overwrite", "append", "overwrite_partitions"):  # pragma: no cover
+    skip_archive: bool = not catalog_versioning
+    if mode not in ("overwrite", "append", "overwrite_partitions", "update"):  # pragma: no cover
         raise exceptions.InvalidArgument(
             f"{mode} is not a valid mode. It must be 'overwrite', 'append' or 'overwrite_partitions'."
         )
-    if (exist is True) and (mode == "overwrite"):
-        skip_archive: bool = not catalog_versioning
+    if (table_exist is True) and (mode == "overwrite"):
         partitions_values: List[List[str]] = list(
             _get_partitions(database=database, table=table, boto3_session=session).values()
         )
         client_glue.batch_delete_partition(
             DatabaseName=database, TableName=table, PartitionsToDelete=[{"Values": v} for v in partitions_values]
         )
         client_glue.update_table(DatabaseName=database, TableInput=table_input, SkipArchive=skip_archive)
-    elif (exist is True) and (mode in ("append", "overwrite_partitions")) and (parameters is not None):
-        upsert_table_parameters(parameters=parameters, database=database, table=table, boto3_session=session)
-    elif exist is False:
+    elif (table_exist is True) and (mode in ("append", "overwrite_partitions", "update")):
+        if parameters is not None:
+            upsert_table_parameters(parameters=parameters, database=database, table=table, boto3_session=session)
+        if mode == "update":
+            client_glue.update_table(DatabaseName=database, TableInput=table_input, SkipArchive=skip_archive)
+    elif table_exist is False:
         client_glue.create_table(DatabaseName=database, TableInput=table_input)
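
Hoisting skip_archive matters because the new "update" branch reuses it: Glue's UpdateTable archives the previous table version unless SkipArchive is true, which is exactly what catalog_versioning toggles. A standalone boto3 sketch of that behavior, assuming an existing table and placeholder names:

# Sketch of the SkipArchive semantics relied on above (placeholder names).
import boto3

glue = boto3.client("glue")
table_input = glue.get_table(DatabaseName="my_db", Name="my_table")["Table"]
# Drop the read-only keys UpdateTable rejects (mirrors what _get_table_input() strips).
for key in ("DatabaseName", "CreateTime", "UpdateTime", "CreatedBy", "IsRegisteredWithLakeFormation"):
    table_input.pop(key, None)
# catalog_versioning=True -> SkipArchive=False: the previous version is archived and listed here.
glue.update_table(DatabaseName="my_db", TableInput=table_input, SkipArchive=False)
versions = glue.get_table_versions(DatabaseName="my_db", TableName="my_table")["TableVersions"]
print(len(versions))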


@@ -1379,6 +1414,88 @@ def get_table_parameters(
     return parameters


+def get_table_description(
+    database: str, table: str, catalog_id: Optional[str] = None, boto3_session: Optional[boto3.Session] = None
+) -> str:
+    """Get table description.
+
+    Parameters
+    ----------
+    database : str
+        Database name.
+    table : str
+        Table name.
+    catalog_id : str, optional
+        The ID of the Data Catalog from which to retrieve Databases.
+        If none is provided, the AWS account ID is used by default.
+    boto3_session : boto3.Session(), optional
+        Boto3 Session. The default boto3 session will be used if boto3_session receive None.
+
+    Returns
+    -------
+    str
+        Description.
+
+    Examples
+    --------
+    >>> import awswrangler as wr
+    >>> desc = wr.catalog.get_table_description(database="...", table="...")
+
+    """
+    client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
+    args: Dict[str, str] = {}
+    if catalog_id is not None:
+        args["CatalogId"] = catalog_id  # pragma: no cover
+    args["DatabaseName"] = database
+    args["Name"] = table
+    response: Dict[str, Any] = client_glue.get_table(**args)
+    desc: str = response["Table"]["Description"]
+    return desc
+
+
+def get_columns_comments(
+    database: str, table: str, catalog_id: Optional[str] = None, boto3_session: Optional[boto3.Session] = None
+) -> Dict[str, str]:
+    """Get all columns comments.
+
+    Parameters
+    ----------
+    database : str
+        Database name.
+    table : str
+        Table name.
+    catalog_id : str, optional
+        The ID of the Data Catalog from which to retrieve Databases.
+        If none is provided, the AWS account ID is used by default.
+    boto3_session : boto3.Session(), optional
+        Boto3 Session. The default boto3 session will be used if boto3_session receive None.
+
+    Returns
+    -------
+    Dict[str, str]
+        Columns comments. e.g. {"col1": "foo boo bar"}.
+
+    Examples
+    --------
+    >>> import awswrangler as wr
+    >>> comments = wr.catalog.get_columns_comments(database="...", table="...")
+
+    """
+    client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
+    args: Dict[str, str] = {}
+    if catalog_id is not None:
+        args["CatalogId"] = catalog_id  # pragma: no cover
+    args["DatabaseName"] = database
+    args["Name"] = table
+    response: Dict[str, Any] = client_glue.get_table(**args)
+    comments: Dict[str, str] = {}
+    for c in response["Table"]["StorageDescriptor"]["Columns"]:
+        comments[c["Name"]] = c["Comment"]
+    for p in response["Table"]["PartitionKeys"]:
+        comments[p["Name"]] = p["Comment"]
+    return comments
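
A hedged round trip for the two new getters (names and comments are placeholders): column comments set through create_parquet_table() come back from get_columns_comments(), and the table description from get_table_description().

# Hypothetical round-trip sketch; all names are placeholders.
import awswrangler as wr

wr.catalog.create_parquet_table(
    database="my_db",
    table="my_table",
    path="s3://my-bucket/my-table/",
    columns_types={"c0": "bigint", "c1": "string"},
    description="My table.",
    columns_comments={"c0": "first column", "c1": "second column"},
    mode="overwrite",
)
print(wr.catalog.get_table_description(database="my_db", table="my_table"))  # My table.
print(wr.catalog.get_columns_comments(database="my_db", table="my_table"))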


 def upsert_table_parameters(
     parameters: Dict[str, str],
     database: str,
@@ -1465,14 +1582,36 @@ def overwrite_table_parameters(
     ...     table="...")

     """
+    session: boto3.Session = _utils.ensure_session(session=boto3_session)
+    table_input: Optional[Dict[str, Any]] = _get_table_input(
+        database=database, table=table, catalog_id=catalog_id, boto3_session=session
+    )
+    if table_input is None:
+        raise exceptions.InvalidTable(f"Table {table} does not exist on database {database}.")
+    table_input["Parameters"] = parameters
+    args2: Dict[str, Union[str, Dict[str, Any]]] = {}
+    if catalog_id is not None:
+        args2["CatalogId"] = catalog_id  # pragma: no cover
+    args2["DatabaseName"] = database
+    args2["TableInput"] = table_input
+    client_glue: boto3.client = _utils.client(service_name="glue", session=session)
+    client_glue.update_table(**args2)
+    return parameters
+
+
+def _get_table_input(
+    database: str, table: str, boto3_session: Optional[boto3.Session], catalog_id: Optional[str] = None
+) -> Optional[Dict[str, str]]:
     client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
     args: Dict[str, str] = {}
     if catalog_id is not None:
         args["CatalogId"] = catalog_id  # pragma: no cover
     args["DatabaseName"] = database
     args["Name"] = table
-    response: Dict[str, Any] = client_glue.get_table(**args)
-    response["Table"]["Parameters"] = parameters
+    try:
+        response: Dict[str, Any] = client_glue.get_table(**args)
+    except client_glue.exceptions.EntityNotFoundException:
+        return None
     if "DatabaseName" in response["Table"]:
         del response["Table"]["DatabaseName"]
     if "CreateTime" in response["Table"]:
@@ -1483,10 +1622,4 @@
         del response["Table"]["CreatedBy"]
     if "IsRegisteredWithLakeFormation" in response["Table"]:
         del response["Table"]["IsRegisteredWithLakeFormation"]
-    args2: Dict[str, Union[str, Dict[str, Any]]] = {}
-    if catalog_id is not None:
-        args2["CatalogId"] = catalog_id  # pragma: no cover
-    args2["DatabaseName"] = database
-    args2["TableInput"] = response["Table"]
-    client_glue.update_table(**args2)
-    return parameters
+    return response["Table"]
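
Net effect of the refactor: overwrite_table_parameters() now fails fast with InvalidTable when the target is missing, instead of leaking a raw Glue EntityNotFoundException, and _get_table_input() becomes the single place that sanitizes a fetched table into a valid TableInput. A hedged sketch of the new error path (names are placeholders):

# Hypothetical sketch; database and table names are placeholders.
import awswrangler as wr
from awswrangler import exceptions

try:
    wr.catalog.overwrite_table_parameters(
        parameters={"source": "etl", "owner": "data-team"},
        database="my_db",
        table="missing_table",
    )
except exceptions.InvalidTable as ex:
    print(ex)  # Table missing_table does not exist on database my_db.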