Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 244 additions & 4 deletions awswrangler/glue.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import Dict, Optional
from typing import Dict, Optional, Any, Iterator, List
from math import ceil
from itertools import islice
import re
import logging

from pandas import DataFrame # type: ignore

from awswrangler import data_types
from awswrangler.athena import Athena
from awswrangler.exceptions import UnsupportedFileFormat, InvalidSerDe, ApiError, UnsupportedType, UndetectedType, InvalidTable, InvalidArguments
Expand Down Expand Up @@ -55,7 +58,29 @@ def metadata_to_glue(self,
mode="append",
compression=None,
cast_columns=None,
extra_args=None):
extra_args=None,
description: Optional[str] = None,
parameters: Optional[Dict[str, str]] = None,
columns_comments: Optional[Dict[str, str]] = None) -> None:
"""

:param dataframe: Pandas Dataframe
:param objects_paths: Files paths on S3
:param preserve_index: Should preserve index on S3?
:param partition_cols: partitions names
:param mode: "append", "overwrite", "overwrite_partitions"
:param cast_columns: Dictionary of columns names and Athena/Glue types to be casted. (E.g. {"col name": "bigint", "col2 name": "int"}) (Only for "parquet" file_format)
:param database: AWS Glue Database name
:param table: AWS Glue table name
:param path: AWS S3 path (E.g. s3://bucket-name/folder_name/
:param file_format: "csv" or "parquet"
:param compression: None, gzip, snappy, etc
:param extra_args: Extra arguments specific for each file formats (E.g. "sep" for CSV)
:param description: Table description
:param parameters: Key/value pairs to tag the table (Optional[Dict[str, str]])
:param columns_comments: Columns names and the related comments (Optional[Dict[str, str]])
:return: None
"""
indexes_position = "left" if file_format == "csv" else "right"
schema, partition_cols_schema = Glue._build_schema(dataframe=dataframe,
partition_cols=partition_cols,
Expand All @@ -75,7 +100,10 @@ def metadata_to_glue(self,
path=path,
file_format=file_format,
compression=compression,
extra_args=extra_args)
extra_args=extra_args,
description=description,
parameters=parameters,
columns_comments=columns_comments)
if partition_cols:
partitions_tuples = Glue._parse_partitions_tuples(objects_paths=objects_paths,
partition_cols=partition_cols)
Expand Down Expand Up @@ -111,7 +139,26 @@ def create_table(self,
file_format,
compression,
partition_cols_schema=None,
extra_args=None):
extra_args=None,
description: Optional[str] = None,
parameters: Optional[Dict[str, str]] = None,
columns_comments: Optional[Dict[str, str]] = None) -> None:
"""
Create Glue table (Catalog)

:param database: AWS Glue Database name
:param table: AWS Glue table name
:param schema: Table schema
:param path: AWS S3 path (E.g. s3://bucket-name/folder_name/
:param file_format: "csv" or "parquet"
:param compression: None, gzip, snappy, etc
:param partition_cols_schema: Partitions schema
:param extra_args: Extra arguments specific for each file formats (E.g. "sep" for CSV)
:param description: Table description
:param parameters: Key/value pairs to tag the table (Optional[Dict[str, str]])
:param columns_comments: Columns names and the related comments (Optional[Dict[str, str]])
:return: None
"""
if file_format == "parquet":
table_input = Glue.parquet_table_definition(table, partition_cols_schema, schema, path, compression)
elif file_format == "csv":
Expand All @@ -123,6 +170,20 @@ def create_table(self,
extra_args=extra_args)
else:
raise UnsupportedFileFormat(file_format)
if description is not None:
table_input["Description"] = description
if parameters is not None:
for k, v in parameters.items():
table_input["Parameters"][k] = v
if columns_comments is not None:
for col in table_input["StorageDescriptor"]["Columns"]:
name = col["Name"]
if name in columns_comments:
col["Comment"] = columns_comments[name]
for par in table_input["PartitionKeys"]:
name = par["Name"]
if name in columns_comments:
par["Comment"] = columns_comments[name]
self._client_glue.create_table(DatabaseName=database, TableInput=table_input)

def add_partitions(self, database, table, partition_paths, file_format, compression, extra_args=None):
Expand Down Expand Up @@ -390,3 +451,182 @@ def get_table_location(self, database: str, table: str):
return res["Table"]["StorageDescriptor"]["Location"]
except KeyError:
raise InvalidTable(f"{database}.{table}")

def get_databases(self, catalog_id: Optional[str] = None) -> Iterator[Dict[str, Any]]:
"""
Get an iterator of databases

:param catalog_id: The ID of the Data Catalog from which to retrieve Databases. If none is provided, the AWS account ID is used by default.
:return: Iterator[Dict[str, Any]] of Databases
"""
paginator = self._client_glue.get_paginator("get_databases")
if catalog_id is None:
response_iterator = paginator.paginate()
else:
response_iterator = paginator.paginate(CatalogId=catalog_id)
for page in response_iterator:
for db in page["DatabaseList"]:
yield db

def get_tables(self,
catalog_id: Optional[str] = None,
database: Optional[str] = None,
name_contains: Optional[str] = None,
name_prefix: Optional[str] = None,
name_suffix: Optional[str] = None) -> Iterator[Dict[str, Any]]:
"""
Get an iterator of tables

:param catalog_id: The ID of the Data Catalog from which to retrieve Databases. If none is provided, the AWS account ID is used by default.
:param database: Filter a specific database
:param name_contains: Select by a specific string on table name
:param name_prefix: Select by a specific prefix on table name
:param name_suffix: Select by a specific suffix on table name
:return: Iterator[Dict[str, Any]] of Tables
"""
paginator = self._client_glue.get_paginator("get_tables")
args: Dict[str, str] = {}
if catalog_id is not None:
args["CatalogId"] = catalog_id
if (name_prefix is not None) and (name_suffix is not None) and (name_contains is not None):
args["Expression"] = f"{name_prefix}.*{name_contains}.*{name_suffix}"
elif (name_prefix is not None) and (name_suffix is not None):
args["Expression"] = f"{name_prefix}.*{name_suffix}"
elif name_contains is not None:
args["Expression"] = f".*{name_contains}.*"
elif name_prefix is not None:
args["Expression"] = f"{name_prefix}.*"
elif name_suffix is not None:
args["Expression"] = f".*{name_suffix}"
if database is not None:
databases = [database]
else:
databases = [x["Name"] for x in self.get_databases(catalog_id=catalog_id)]
for db in databases:
args["DatabaseName"] = db
response_iterator = paginator.paginate(**args)
for page in response_iterator:
for tbl in page["TableList"]:
yield tbl

def tables(self,
limit: int = 100,
catalog_id: Optional[str] = None,
database: Optional[str] = None,
search_text: Optional[str] = None,
name_contains: Optional[str] = None,
name_prefix: Optional[str] = None,
name_suffix: Optional[str] = None) -> DataFrame:
"""
Get a Dataframe with tables filtered by a search term, prefix, suffix.

:param limit: Max number of tables
:param catalog_id: The ID of the Data Catalog from which to retrieve Databases. If none is provided, the AWS account ID is used by default.
:param database: Glue database name
:param search_text: Select only tables with the given string in table's properties
:param name_contains: Select by a specific string on table name
:param name_prefix: Select only tables with the given string in the name prefix
:param name_suffix: Select only tables with the given string in the name suffix
:return: Pandas Dataframe filled by formatted infos
"""
if search_text is None:
table_iter = self.get_tables(catalog_id=catalog_id,
database=database,
name_contains=name_contains,
name_prefix=name_prefix,
name_suffix=name_suffix)
tables: List[Dict[str, Any]] = list(islice(table_iter, limit))
else:
tables = list(self.search_tables(text=search_text, catalog_id=catalog_id))
if database is not None:
tables = [x for x in tables if x["DatabaseName"] == database]
if name_contains is not None:
tables = [x for x in tables if name_contains in x["Name"]]
if name_prefix is not None:
tables = [x for x in tables if x["Name"].startswith(name_prefix)]
if name_suffix is not None:
tables = [x for x in tables if x["Name"].endswith(name_suffix)]
tables = tables[:limit]

df_dict: Dict[str, List] = {"Database": [], "Table": [], "Description": [], "Columns": [], "Partitions": []}
for table in tables:
df_dict["Database"].append(table["DatabaseName"])
df_dict["Table"].append(table["Name"])
if "Description" in table:
df_dict["Description"].append(table["Description"])
else:
df_dict["Description"].append("")
df_dict["Columns"].append(", ".join([x["Name"] for x in table["StorageDescriptor"]["Columns"]]))
df_dict["Partitions"].append(", ".join([x["Name"] for x in table["PartitionKeys"]]))
return DataFrame(data=df_dict)

def search_tables(self, text: str, catalog_id: Optional[str] = None):
"""
Get iterator of tables filtered by a search string.

:param text: Select only tables with the given string in table's properties.
:param catalog_id: The ID of the Data Catalog from which to retrieve Databases. If none is provided, the AWS account ID is used by default.
:return: Iterator of tables
"""
args: Dict[str, Any] = {"SearchText": text}
if catalog_id is not None:
args["CatalogId"] = catalog_id
response = self._client_glue.search_tables(**args)
for tbl in response["TableList"]:
yield tbl
while "NextToken" in response:
args["NextToken"] = response["NextToken"]
response = self._client_glue.search_tables(**args)
for tbl in response["TableList"]:
yield tbl

def databases(self, limit: int = 100, catalog_id: Optional[str] = None) -> DataFrame:
"""
Get iterator of databases.

:param limit: Max number of tables
:param catalog_id: The ID of the Data Catalog from which to retrieve Databases. If none is provided, the AWS account ID is used by default.
:return: Pandas Dataframe filled by formatted infos
"""
database_iter = self.get_databases(catalog_id=catalog_id)
dbs = islice(database_iter, limit)
df_dict: Dict[str, List] = {"Database": [], "Description": []}
for db in dbs:
df_dict["Database"].append(db["Name"])
if "Description" in db:
df_dict["Description"].append(db["Description"])
else:
df_dict["Description"].append("")
return DataFrame(data=df_dict)

def table(self, database: str, name: str, catalog_id: Optional[str] = None) -> DataFrame:
"""
Get table details as Pandas Dataframe

:param database: Glue database name
:param name: Table name
:param catalog_id: The ID of the Data Catalog from which to retrieve Databases. If none is provided, the AWS account ID is used by default.
:return: Pandas Dataframe filled by formatted infos
"""
if catalog_id is None:
table: Dict[str, Any] = self._client_glue.get_table(DatabaseName=database, Name=name)["Table"]
else:
table = self._client_glue.get_table(CatalogId=catalog_id, DatabaseName=database, Name=name)["Table"]
df_dict: Dict[str, List] = {"Column Name": [], "Type": [], "Partition": [], "Comment": []}
for col in table["StorageDescriptor"]["Columns"]:
df_dict["Column Name"].append(col["Name"])
df_dict["Type"].append(col["Type"])
df_dict["Partition"].append(False)
if "Comment" in table:
df_dict["Comment"].append(table["Comment"])
else:
df_dict["Comment"].append("")
for col in table["PartitionKeys"]:
df_dict["Column Name"].append(col["Name"])
df_dict["Type"].append(col["Type"])
df_dict["Partition"].append(True)
if "Comment" in table:
df_dict["Comment"].append(table["Comment"])
else:
df_dict["Comment"].append("")
return DataFrame(data=df_dict)
Loading