Add get_table_parameters, upsert_table_parameters, overwrite_table_parameters. #224
igorborgest committed May 8, 2020
1 parent 014228f commit a87867a
Showing 4 changed files with 260 additions and 35 deletions.
156 changes: 155 additions & 1 deletion awswrangler/catalog.py
@@ -5,7 +5,7 @@
import logging
import re
import unicodedata
from typing import Any, Dict, Iterator, List, Optional, Tuple
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from urllib.parse import quote_plus

import boto3 # type: ignore
@@ -989,6 +989,8 @@ def _create_table(
            DatabaseName=database, TableName=table, PartitionsToDelete=[{"Values": v} for v in partitions_values]
        )
        client_glue.update_table(DatabaseName=database, TableInput=table_input, SkipArchive=skip_archive)
    elif (exist is True) and (mode == "append") and (parameters is not None):
        upsert_table_parameters(parameters=parameters, database=database, table=table, boto3_session=session)
    elif exist is False:
        client_glue.create_table(DatabaseName=database, TableInput=table_input)
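
For context, the added branch fires only when the table already exists, the write mode is `append`, and the caller passed `parameters`; in that case the schema is left untouched and only the parameters are merged. A hypothetical toy model of the dispatch (illustration only; the real _create_table calls the Glue API and also handles partitions and archiving):

def dispatch(exist: bool, mode: str, parameters) -> str:
    if exist and mode == "overwrite":
        return "update_table"  # replace the whole table definition
    if exist and mode == "append" and parameters is not None:
        return "upsert_table_parameters"  # merge parameters only, keep the schema
    if not exist:
        return "create_table"  # brand-new table
    return "no-op"

assert dispatch(True, "append", {"k": "v"}) == "upsert_table_parameters"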

@@ -1333,3 +1335,155 @@ def extract_athena_types(
    return _data_types.athena_types_from_pandas_partitioned(
        df=df, index=index, partition_cols=partition_cols, dtype=dtype, index_left=index_left
    )


def get_table_parameters(
    database: str, table: str, catalog_id: Optional[str] = None, boto3_session: Optional[boto3.Session] = None
) -> Dict[str, str]:
    """Get all parameters of a table.

    Parameters
    ----------
    database : str
        Database name.
    table : str
        Table name.
    catalog_id : str, optional
        The ID of the Data Catalog from which to retrieve Databases.
        If none is provided, the AWS account ID is used by default.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receives None.

    Returns
    -------
    Dict[str, str]
        Dictionary of parameters.

    Examples
    --------
    >>> import awswrangler as wr
    >>> pars = wr.catalog.get_table_parameters(database="...", table="...")

    """
    client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
    args: Dict[str, str] = {}
    if catalog_id is not None:
        args["CatalogId"] = catalog_id  # pragma: no cover
    args["DatabaseName"] = database
    args["Name"] = table
    response: Dict[str, Any] = client_glue.get_table(**args)
    parameters: Dict[str, str] = response["Table"]["Parameters"]
    return parameters


def upsert_table_parameters(
    parameters: Dict[str, str],
    database: str,
    table: str,
    catalog_id: Optional[str] = None,
    boto3_session: Optional[boto3.Session] = None,
) -> Dict[str, str]:
    """Insert or update the received parameters.

    Parameters
    ----------
    parameters : Dict[str, str]
        e.g. {"source": "mysql", "destination": "datalake"}
    database : str
        Database name.
    table : str
        Table name.
    catalog_id : str, optional
        The ID of the Data Catalog from which to retrieve Databases.
        If none is provided, the AWS account ID is used by default.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receives None.

    Returns
    -------
    Dict[str, str]
        All parameters after the upsert.

    Examples
    --------
    >>> import awswrangler as wr
    >>> pars = wr.catalog.upsert_table_parameters(
    ...     parameters={"source": "mysql", "destination": "datalake"},
    ...     database="...",
    ...     table="...")

    """
    session: boto3.Session = _utils.ensure_session(session=boto3_session)
    pars: Dict[str, str] = get_table_parameters(
        database=database, table=table, catalog_id=catalog_id, boto3_session=session
    )
    for k, v in parameters.items():
        pars[k] = v
    overwrite_table_parameters(
        parameters=pars, database=database, table=table, catalog_id=catalog_id, boto3_session=session
    )
    return pars
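
The merge itself is plain dictionary semantics: keys arriving in `parameters` win over keys already on the table. A quick self-contained illustration (no AWS calls; values are made up):

existing = {"a": "1", "b": "2"}
incoming = {"a": "0", "c": "3"}
merged = dict(existing)
merged.update(incoming)  # same effect as the k/v loop above
assert merged == {"a": "0", "b": "2", "c": "3"}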


def overwrite_table_parameters(
    parameters: Dict[str, str],
    database: str,
    table: str,
    catalog_id: Optional[str] = None,
    boto3_session: Optional[boto3.Session] = None,
) -> Dict[str, str]:
    """Overwrite all existing parameters.

    Parameters
    ----------
    parameters : Dict[str, str]
        e.g. {"source": "mysql", "destination": "datalake"}
    database : str
        Database name.
    table : str
        Table name.
    catalog_id : str, optional
        The ID of the Data Catalog from which to retrieve Databases.
        If none is provided, the AWS account ID is used by default.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receives None.

    Returns
    -------
    Dict[str, str]
        All parameters after the overwrite (the same dictionary received).

    Examples
    --------
    >>> import awswrangler as wr
    >>> pars = wr.catalog.overwrite_table_parameters(
    ...     parameters={"source": "mysql", "destination": "datalake"},
    ...     database="...",
    ...     table="...")

    """
    client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
    args: Dict[str, str] = {}
    if catalog_id is not None:
        args["CatalogId"] = catalog_id  # pragma: no cover
    args["DatabaseName"] = database
    args["Name"] = table
    response: Dict[str, Any] = client_glue.get_table(**args)
    response["Table"]["Parameters"] = parameters
if "DatabaseName" in response["Table"]:
del response["Table"]["DatabaseName"]
if "CreateTime" in response["Table"]:
del response["Table"]["CreateTime"]
if "UpdateTime" in response["Table"]:
del response["Table"]["UpdateTime"]
if "CreatedBy" in response["Table"]:
del response["Table"]["CreatedBy"]
if "IsRegisteredWithLakeFormation" in response["Table"]:
del response["Table"]["IsRegisteredWithLakeFormation"]
args2: Dict[str, Union[str, Dict[str, Any]]] = {}
if catalog_id is not None:
args2["CatalogId"] = catalog_id # pragma: no cover
args2["DatabaseName"] = database
args2["TableInput"] = response["Table"]
client_glue.update_table(**args2)
return parameters
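
Taken together, the three new catalog functions compose as follows; a minimal usage sketch, assuming an existing Glue table (the database and table names are placeholders):

import awswrangler as wr

# Read the table's current parameters.
pars = wr.catalog.get_table_parameters(database="my_db", table="my_table")

# Merge new keys into the existing parameters (insert or update).
pars = wr.catalog.upsert_table_parameters(
    parameters={"source": "mysql"}, database="my_db", table="my_table"
)

# Replace the whole parameter set with exactly this dictionary.
pars = wr.catalog.overwrite_table_parameters(
    parameters={"source": "mysql", "destination": "datalake"},
    database="my_db",
    table="my_table",
)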
76 changes: 42 additions & 34 deletions awswrangler/s3.py
@@ -453,6 +453,10 @@ def to_csv(  # pylint: disable=too-many-arguments
    The table name and all column names will be automatically sanitized using
    `wr.catalog.sanitize_table_name` and `wr.catalog.sanitize_column_name`.

    Note
    ----
    In `append` mode, the `parameters` will be upserted on an existing table.

    Note
    ----
    If `use_threads=True`, the number of threads that will be spawned will be obtained from os.cpu_count().
@@ -640,15 +644,14 @@ def to_csv(  # pylint: disable=too-many-arguments
        paths = [path]
    else:
        mode = "append" if mode is None else mode
        exist: bool = False
        if columns:
            df = df[columns]
        if (database is not None) and (table is not None):  # Normalize table to respect Athena's standards
            df = catalog.sanitize_dataframe_columns_names(df=df)
            partition_cols = [catalog.sanitize_column_name(p) for p in partition_cols]
            dtype = {catalog.sanitize_column_name(k): v.lower() for k, v in dtype.items()}
            columns_comments = {catalog.sanitize_column_name(k): v for k, v in columns_comments.items()}
            exist = catalog.does_table_exist(database=database, table=table, boto3_session=session)
            exist: bool = catalog.does_table_exist(database=database, table=table, boto3_session=session)
            if (exist is True) and (mode in ("append", "overwrite_partitions")):
                for k, v in catalog.get_table_types(database=database, table=table, boto3_session=session).items():
                    dtype[k] = v
@@ -669,21 +672,20 @@ def to_csv(  # pylint: disable=too-many-arguments
            columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned(
                df=df, index=index, partition_cols=partition_cols, dtype=dtype, index_left=True
            )
            if (exist is False) or (mode == "overwrite"):
                catalog.create_csv_table(
                    database=database,
                    table=table,
                    path=path,
                    columns_types=columns_types,
                    partitions_types=partitions_types,
                    description=description,
                    parameters=parameters,
                    columns_comments=columns_comments,
                    boto3_session=session,
                    mode="overwrite",
                    catalog_versioning=catalog_versioning,
                    sep=sep,
                )
            catalog.create_csv_table(
                database=database,
                table=table,
                path=path,
                columns_types=columns_types,
                partitions_types=partitions_types,
                description=description,
                parameters=parameters,
                columns_comments=columns_comments,
                boto3_session=session,
                mode=mode,
                catalog_versioning=catalog_versioning,
                sep=sep,
            )
            if partitions_values:
                _logger.debug("partitions_values:\n%s", partitions_values)
                catalog.add_csv_partitions(
@@ -869,6 +871,10 @@ def to_parquet(  # pylint: disable=too-many-arguments
    The table name and all column names will be automatically sanitized using
    `wr.catalog.sanitize_table_name` and `wr.catalog.sanitize_column_name`.

    Note
    ----
    In `append` mode, the `parameters` will be upserted on an existing table.

    Note
    ----
    If `use_threads=True`, the number of threads that will be spawned will be obtained from os.cpu_count().
@@ -1058,13 +1064,12 @@ def to_parquet(  # pylint: disable=too-many-arguments
        ]
    else:
        mode = "append" if mode is None else mode
        exist: bool = False
        if (database is not None) and (table is not None):  # Normalize table to respect Athena's standards
            df = catalog.sanitize_dataframe_columns_names(df=df)
            partition_cols = [catalog.sanitize_column_name(p) for p in partition_cols]
            dtype = {catalog.sanitize_column_name(k): v.lower() for k, v in dtype.items()}
            columns_comments = {catalog.sanitize_column_name(k): v for k, v in columns_comments.items()}
            exist = catalog.does_table_exist(database=database, table=table, boto3_session=session)
            exist: bool = catalog.does_table_exist(database=database, table=table, boto3_session=session)
            if (exist is True) and (mode in ("append", "overwrite_partitions")):
                for k, v in catalog.get_table_types(database=database, table=table, boto3_session=session).items():
                    dtype[k] = v
@@ -1087,21 +1092,20 @@ def to_parquet(  # pylint: disable=too-many-arguments
            columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned(
                df=df, index=index, partition_cols=partition_cols, dtype=dtype
            )
            if (exist is False) or (mode == "overwrite"):
                catalog.create_parquet_table(
                    database=database,
                    table=table,
                    path=path,
                    columns_types=columns_types,
                    partitions_types=partitions_types,
                    compression=compression,
                    description=description,
                    parameters=parameters,
                    columns_comments=columns_comments,
                    boto3_session=session,
                    mode="overwrite",
                    catalog_versioning=catalog_versioning,
                )
            catalog.create_parquet_table(
                database=database,
                table=table,
                path=path,
                columns_types=columns_types,
                partitions_types=partitions_types,
                compression=compression,
                description=description,
                parameters=parameters,
                columns_comments=columns_comments,
                boto3_session=session,
                mode=mode,
                catalog_versioning=catalog_versioning,
            )
            if partitions_values:
                _logger.debug("partitions_values:\n%s", partitions_values)
                catalog.add_parquet_partitions(
@@ -1865,6 +1869,10 @@ def store_parquet_metadata(
    The concept of Dataset goes beyond the simple idea of files and enables more
    complex features like partitioning and catalog integration (AWS Glue Catalog).

    Note
    ----
    In `append` mode, the `parameters` will be upserted on an existing table.

    Note
    ----
    If `use_threads=True`, the number of threads that will be spawned will be obtained from os.cpu_count().
3 changes: 3 additions & 0 deletions docs/source/api.rst
@@ -63,6 +63,9 @@ AWS Glue Catalog
    drop_duplicated_columns
    get_engine
    extract_athena_types
    get_table_parameters
    upsert_table_parameters
    overwrite_table_parameters

Amazon Athena
-------------
60 changes: 60 additions & 0 deletions testing/test_awswrangler/test_data_lake.py
@@ -1478,3 +1478,63 @@ def test_parquet_overwrite_partition_cols(bucket, database, external_schema):

    wr.s3.delete_objects(path=path)
    wr.catalog.delete_table_if_exists(database=database, table=table)


def test_catalog_parameters(bucket, database):
    table = "test_catalog_parameters"
    path = f"s3://{bucket}/{table}/"
    wr.s3.delete_objects(path=path)
    wr.catalog.delete_table_if_exists(database=database, table=table)

    wr.s3.to_parquet(
        df=pd.DataFrame({"c0": [1, 2]}),
        path=path,
        dataset=True,
        database=database,
        table=table,
        mode="overwrite",
        parameters={"a": "1", "b": "2"},
    )
    pars = wr.catalog.get_table_parameters(database=database, table=table)
    assert pars["a"] == "1"
    assert pars["b"] == "2"
    pars["a"] = "0"
    pars["c"] = "3"
    wr.catalog.upsert_table_parameters(parameters=pars, database=database, table=table)
    pars = wr.catalog.get_table_parameters(database=database, table=table)
    assert pars["a"] == "0"
    assert pars["b"] == "2"
    assert pars["c"] == "3"
    wr.catalog.overwrite_table_parameters(parameters={"d": "4"}, database=database, table=table)
    pars = wr.catalog.get_table_parameters(database=database, table=table)
    assert pars.get("a") is None
    assert pars.get("b") is None
    assert pars.get("c") is None
    assert pars["d"] == "4"
    df = wr.athena.read_sql_table(table=table, database=database)
    assert len(df.index) == 2
    assert len(df.columns) == 1
    assert df.c0.sum() == 3

    wr.s3.to_parquet(
        df=pd.DataFrame({"c0": [3, 4]}),
        path=path,
        dataset=True,
        database=database,
        table=table,
        mode="append",
        parameters={"e": "5"},
    )
    pars = wr.catalog.get_table_parameters(database=database, table=table)
    assert pars.get("a") is None
    assert pars.get("b") is None
    assert pars.get("c") is None
    assert pars["d"] == "4"
    assert pars["e"] == "5"
    df = wr.athena.read_sql_table(table=table, database=database)
    assert len(df.index) == 4
    assert len(df.columns) == 1
    assert df.c0.sum() == 10

    wr.s3.delete_objects(path=path)
    wr.catalog.delete_table_if_exists(database=database, table=table)
