From 4a5ac7a66227fbcf6545de0113baaa09bfe59dae Mon Sep 17 00:00:00 2001 From: igorborgest Date: Tue, 14 Apr 2020 12:37:57 -0300 Subject: [PATCH] Add Dataset feature to s3.to_csv #141 #170 --- .pylintrc | 4 +- awswrangler/_data_types.py | 31 +- awswrangler/catalog.py | 583 +++++++++++++++++---- awswrangler/s3.py | 339 ++++++++++-- docs/source/api.rst | 4 + testing/run-tests.sh | 2 +- testing/test_awswrangler/_utils.py | 43 ++ testing/test_awswrangler/test_data_lake.py | 212 +++++++- testing/test_awswrangler/test_db.py | 2 + 9 files changed, 1060 insertions(+), 160 deletions(-) diff --git a/.pylintrc b/.pylintrc index 0ecba0bcc..132ce213a 100644 --- a/.pylintrc +++ b/.pylintrc @@ -555,10 +555,10 @@ max-attributes=7 max-bool-expr=5 # Maximum number of branch for function / method body. -max-branches=12 +max-branches=15 # Maximum number of locals for function / method body. -max-locals=25 +max-locals=30 # Maximum number of parents for a class (see R0901). max-parents=7 diff --git a/awswrangler/_data_types.py b/awswrangler/_data_types.py index 31b648c47..2289d572a 100644 --- a/awswrangler/_data_types.py +++ b/awswrangler/_data_types.py @@ -219,7 +219,7 @@ def pyarrow2sqlalchemy( # pylint: disable=too-many-branches,too-many-return-sta def pyarrow_types_from_pandas( - df: pd.DataFrame, index: bool, ignore_cols: Optional[List[str]] = None + df: pd.DataFrame, index: bool, ignore_cols: Optional[List[str]] = None, index_left: bool = False ) -> Dict[str, pa.DataType]: """Extract the related Pyarrow data types from any Pandas DataFrame.""" # Handle exception data types (e.g. Int64, Int32, string) @@ -251,18 +251,23 @@ def pyarrow_types_from_pandas( if (name not in df.columns) and (index is True): indexes.append(name) + # Merging Index + sorted_cols: List[str] = indexes + list(df.columns) if index_left is True else list(df.columns) + indexes + # Filling schema columns_types: Dict[str, pa.DataType] - columns_types = {n: cols_dtypes[n] for n in list(df.columns) + indexes} # add cols + indexes + columns_types = {n: cols_dtypes[n] for n in sorted_cols} _logger.debug(f"columns_types: {columns_types}") return columns_types -def athena_types_from_pandas(df: pd.DataFrame, index: bool, dtype: Optional[Dict[str, str]] = None) -> Dict[str, str]: +def athena_types_from_pandas( + df: pd.DataFrame, index: bool, dtype: Optional[Dict[str, str]] = None, index_left: bool = False +) -> Dict[str, str]: """Extract the related Athena data types from any Pandas DataFrame.""" casts: Dict[str, str] = dtype if dtype else {} pa_columns_types: Dict[str, Optional[pa.DataType]] = pyarrow_types_from_pandas( - df=df, index=index, ignore_cols=list(casts.keys()) + df=df, index=index, ignore_cols=list(casts.keys()), index_left=index_left ) athena_columns_types: Dict[str, str] = {} for k, v in pa_columns_types.items(): @@ -275,11 +280,17 @@ def athena_types_from_pandas(df: pd.DataFrame, index: bool, dtype: Optional[Dict def athena_types_from_pandas_partitioned( - df: pd.DataFrame, index: bool, partition_cols: Optional[List[str]] = None, dtype: Optional[Dict[str, str]] = None + df: pd.DataFrame, + index: bool, + partition_cols: Optional[List[str]] = None, + dtype: Optional[Dict[str, str]] = None, + index_left: bool = False, ) -> Tuple[Dict[str, str], Dict[str, str]]: """Extract the related Athena data types from any Pandas DataFrame considering possible partitions.""" partitions: List[str] = partition_cols if partition_cols else [] - athena_columns_types: Dict[str, str] = athena_types_from_pandas(df=df, index=index, dtype=dtype) + 
athena_columns_types: Dict[str, str] = athena_types_from_pandas( + df=df, index=index, dtype=dtype, index_left=index_left + ) columns_types: Dict[str, str] = {} partitions_types: Dict[str, str] = {} for k, v in athena_columns_types.items(): @@ -296,10 +307,12 @@ def pyarrow_schema_from_pandas( """Extract the related Pyarrow Schema from any Pandas DataFrame.""" casts: Dict[str, str] = {} if dtype is None else dtype ignore: List[str] = [] if ignore_cols is None else ignore_cols - ignore = ignore + list(casts.keys()) - columns_types: Dict[str, Optional[pa.DataType]] = pyarrow_types_from_pandas(df=df, index=index, ignore_cols=ignore) + ignore_plus = ignore + list(casts.keys()) + columns_types: Dict[str, Optional[pa.DataType]] = pyarrow_types_from_pandas( + df=df, index=index, ignore_cols=ignore_plus + ) for k, v in casts.items(): - if k in df.columns: + if (k in df.columns) and (k not in ignore): columns_types[k] = athena2pyarrow(v) columns_types = {k: v for k, v in columns_types.items() if v is not None} _logger.debug(f"columns_types: {columns_types}") diff --git a/awswrangler/catalog.py b/awswrangler/catalog.py index f7a6ead16..6e1bcd374 100644 --- a/awswrangler/catalog.py +++ b/awswrangler/catalog.py @@ -5,14 +5,14 @@ import logging import re import unicodedata -from typing import Any, Dict, Iterator, List, Optional +from typing import Any, Dict, Iterator, List, Optional, Tuple from urllib.parse import quote_plus import boto3 # type: ignore import pandas as pd # type: ignore import sqlalchemy # type: ignore -from awswrangler import _utils, exceptions +from awswrangler import _data_types, _utils, exceptions _logger: logging.Logger = logging.getLogger(__name__) @@ -150,24 +150,16 @@ def create_parquet_table( table_input: Dict[str, Any] = _parquet_table_definition( table=table, path=path, columns_types=columns_types, partitions_types=partitions_types, compression=compression ) - if description is not None: - table_input["Description"] = description - if parameters is not None: - for k, v in parameters.items(): - table_input["Parameters"][k] = v - if columns_comments is not None: - for col in table_input["StorageDescriptor"]["Columns"]: - name: str = col["Name"] - if name in columns_comments: - col["Comment"] = columns_comments[name] - for par in table_input["PartitionKeys"]: - name = par["Name"] - if name in columns_comments: - par["Comment"] = columns_comments[name] - if mode == "overwrite": - delete_table_if_exists(database=database, table=table, boto3_session=boto3_session) - client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session) - client_glue.create_table(DatabaseName=database, TableInput=table_input) + _create_table( + database=database, + table=table, + description=description, + parameters=parameters, + columns_comments=columns_comments, + mode=mode, + boto3_session=boto3_session, + table_input=table_input, + ) def _parquet_table_definition( @@ -248,95 +240,7 @@ def add_parquet_partitions( _parquet_partition_definition(location=k, values=v, compression=compression) for k, v in partitions_values.items() ] - chunks: List[List[Dict[str, Any]]] = _utils.chunkify(lst=inputs, max_length=100) - client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session) - for chunk in chunks: - res: Dict[str, Any] = client_glue.batch_create_partition( - DatabaseName=database, TableName=table, PartitionInputList=chunk - ) - if ("Errors" in res) and res["Errors"]: # pragma: no cover - raise exceptions.ServiceApiError(str(res["Errors"])) - - -def 
get_parquet_partitions( - database: str, - table: str, - expression: Optional[str] = None, - catalog_id: Optional[str] = None, - boto3_session: Optional[boto3.Session] = None, -) -> Dict[str, List[str]]: - """Get all partitions from a Table in the AWS Glue Catalog. - - Expression argument instructions: - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/glue.html#Glue.Client.get_partitions - - Parameters - ---------- - database : str - Database name. - table : str - Table name. - expression : str, optional - An expression that filters the partitions to be returned. - catalog_id : str, optional - The ID of the Data Catalog from which to retrieve Databases. - If none is provided, the AWS account ID is used by default. - boto3_session : boto3.Session(), optional - Boto3 Session. The default boto3 session will be used if boto3_session receive None. - - Returns - ------- - Dict[str, List[str]] - partitions_values: Dictionary with keys as S3 path locations and values as a - list of partitions values as str (e.g. {'s3://bucket/prefix/y=2020/m=10/': ['2020', '10']}). - - Examples - -------- - Fetch all partitions - - >>> import awswrangler as wr - >>> wr.catalog.get_parquet_partitions( - ... database='default', - ... table='my_table', - ... ) - { - 's3://bucket/prefix/y=2020/m=10/': ['2020', '10'], - 's3://bucket/prefix/y=2020/m=11/': ['2020', '11'], - 's3://bucket/prefix/y=2020/m=12/': ['2020', '12'] - } - - Filtering partitions - - >>> import awswrangler as wr - >>> wr.catalog.get_parquet_partitions( - ... database='default', - ... table='my_table', - ... expression='m=10' - ... ) - { - 's3://bucket/prefix/y=2020/m=10/': ['2020', '10'] - } - - """ - client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session) - paginator = client_glue.get_paginator("get_partitions") - args: Dict[str, Any] = {} - if expression is not None: - args["Expression"] = expression - if catalog_id is not None: - args["CatalogId"] = catalog_id - response_iterator = paginator.paginate( - DatabaseName=database, TableName=table, PaginationConfig={"PageSize": 1000}, **args - ) - partitions_values: Dict[str, List[str]] = {} - for page in response_iterator: - if (page is not None) and ("Partitions" in page): - for partition in page["Partitions"]: - location: Optional[str] = partition["StorageDescriptor"].get("Location") - if location is not None: - values: List[str] = partition["Values"] - partitions_values[location] = values - return partitions_values + _add_partitions(database=database, table=table, boto3_session=boto3_session, inputs=inputs) def _parquet_partition_definition(location: str, values: List[str], compression: Optional[str]) -> Dict[str, Any]: @@ -943,3 +847,464 @@ def get_engine( raise exceptions.InvalidDatabaseType( # pragma: no cover f"{db_type} is not a valid Database type." f" Only Redshift, PostgreSQL and MySQL are supported." ) + + +def create_csv_table( + database: str, + table: str, + path: str, + columns_types: Dict[str, str], + partitions_types: Optional[Dict[str, str]], + compression: Optional[str] = None, + description: Optional[str] = None, + parameters: Optional[Dict[str, str]] = None, + columns_comments: Optional[Dict[str, str]] = None, + mode: str = "overwrite", + sep: str = ",", + boto3_session: Optional[boto3.Session] = None, +) -> None: + """Create a CSV Table (Metadata Only) in the AWS Glue Catalog. + + 'https://docs.aws.amazon.com/athena/latest/ug/data-types.html' + + Parameters + ---------- + database : str + Database name. 
+    table : str
+        Table name.
+    path : str
+        Amazon S3 path (e.g. s3://bucket/prefix/).
+    columns_types: Dict[str, str]
+        Dictionary with keys as column names and values as data types (e.g. {'col0': 'bigint', 'col1': 'double'}).
+    partitions_types: Dict[str, str], optional
+        Dictionary with keys as partition names and values as data types (e.g. {'col2': 'date'}).
+    compression: str, optional
+        Compression style (``None``, ``gzip``, etc).
+    description: str, optional
+        Table description.
+    parameters: Dict[str, str], optional
+        Key/value pairs to tag the table.
+    columns_comments: Dict[str, str], optional
+        Columns names and the related comments (e.g. {'col0': 'Column 0.', 'col1': 'Column 1.', 'col2': 'Partition.'}).
+    mode: str
+        Only 'overwrite' is available for now.
+    sep : str
+        String of length 1. Field delimiter for the output file.
+    boto3_session : boto3.Session(), optional
+        Boto3 Session. The default boto3 session will be used if boto3_session receives None.
+
+    Returns
+    -------
+    None
+        None.
+
+    Examples
+    --------
+    >>> import awswrangler as wr
+    >>> wr.catalog.create_csv_table(
+    ...     database='default',
+    ...     table='my_table',
+    ...     path='s3://bucket/prefix/',
+    ...     columns_types={'col0': 'bigint', 'col1': 'double'},
+    ...     partitions_types={'col2': 'date'},
+    ...     compression='gzip',
+    ...     description='My own table!',
+    ...     parameters={'source': 'postgresql'},
+    ...     columns_comments={'col0': 'Column 0.', 'col1': 'Column 1.', 'col2': 'Partition.'}
+    ... )
+
+    """
+    table = sanitize_table_name(table=table)
+    partitions_types = {} if partitions_types is None else partitions_types
+    table_input: Dict[str, Any] = _csv_table_definition(
+        table=table,
+        path=path,
+        columns_types=columns_types,
+        partitions_types=partitions_types,
+        compression=compression,
+        sep=sep,
+    )
+    _create_table(
+        database=database,
+        table=table,
+        description=description,
+        parameters=parameters,
+        columns_comments=columns_comments,
+        mode=mode,
+        boto3_session=boto3_session,
+        table_input=table_input,
+    )
+
+
+def _create_table(
+    database: str,
+    table: str,
+    description: Optional[str],
+    parameters: Optional[Dict[str, str]],
+    columns_comments: Optional[Dict[str, str]],
+    mode: str,
+    boto3_session: Optional[boto3.Session],
+    table_input: Dict[str, Any],
+):
+    if description is not None:
+        table_input["Description"] = description
+    if parameters is not None:
+        for k, v in parameters.items():
+            table_input["Parameters"][k] = v
+    if columns_comments is not None:
+        for col in table_input["StorageDescriptor"]["Columns"]:
+            name: str = col["Name"]
+            if name in columns_comments:
+                col["Comment"] = columns_comments[name]
+        for par in table_input["PartitionKeys"]:
+            name = par["Name"]
+            if name in columns_comments:
+                par["Comment"] = columns_comments[name]
+    session: boto3.Session = _utils.ensure_session(session=boto3_session)
+    if mode == "overwrite":
+        delete_table_if_exists(database=database, table=table, boto3_session=session)
+    client_glue: boto3.client = _utils.client(service_name="glue", session=session)
+    client_glue.create_table(DatabaseName=database, TableInput=table_input)
+
+
+def _csv_table_definition(
+    table: str,
+    path: str,
+    columns_types: Dict[str, str],
+    partitions_types: Dict[str, str],
+    compression: Optional[str],
+    sep: str,
+) -> Dict[str, Any]:
+    compressed: bool = compression is not None
+    return {
+        "Name": table,
+        "PartitionKeys": [{"Name": cname, "Type": dtype} for cname, dtype in partitions_types.items()],
+        "TableType": "EXTERNAL_TABLE",
+        "Parameters": {
+            "classification": "csv",
+ "compressionType": str(compression).lower(), + "typeOfData": "file", + "delimiter": sep, + "columnsOrdered": "true", + "areColumnsQuoted": "false", + }, + "StorageDescriptor": { + "Columns": [{"Name": cname, "Type": dtype} for cname, dtype in columns_types.items()], + "Location": path, + "InputFormat": "org.apache.hadoop.mapred.TextInputFormat", + "OutputFormat": "org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat", + "Compressed": compressed, + "NumberOfBuckets": -1, + "SerdeInfo": { + "Parameters": {"field.delim": sep, "escape.delim": "\\"}, + "SerializationLibrary": "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe", + }, + "StoredAsSubDirectories": False, + "SortColumns": [], + "Parameters": { + "classification": "csv", + "compressionType": str(compression).lower(), + "typeOfData": "file", + "delimiter": sep, + "columnsOrdered": "true", + "areColumnsQuoted": "false", + }, + }, + } + + +def add_csv_partitions( + database: str, + table: str, + partitions_values: Dict[str, List[str]], + compression: Optional[str] = None, + sep: str = ",", + boto3_session: Optional[boto3.Session] = None, +) -> None: + """Add partitions (metadata) to a CSV Table in the AWS Glue Catalog. + + Parameters + ---------- + database : str + Database name. + table : str + Table name. + partitions_values: Dict[str, List[str]] + Dictionary with keys as S3 path locations and values as a list of partitions values as str + (e.g. {'s3://bucket/prefix/y=2020/m=10/': ['2020', '10']}). + compression: str, optional + Compression style (``None``, ``gzip``, etc). + sep : str + String of length 1. Field delimiter for the output file. + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session receive None. + + Returns + ------- + None + None. + + Examples + -------- + >>> import awswrangler as wr + >>> wr.catalog.add_csv_partitions( + ... database='default', + ... table='my_table', + ... partitions_values={ + ... 's3://bucket/prefix/y=2020/m=10/': ['2020', '10'], + ... 's3://bucket/prefix/y=2020/m=11/': ['2020', '11'], + ... 's3://bucket/prefix/y=2020/m=12/': ['2020', '12'] + ... } + ... 
) + + """ + inputs: List[Dict[str, Any]] = [ + _csv_partition_definition(location=k, values=v, compression=compression, sep=sep) + for k, v in partitions_values.items() + ] + _add_partitions(database=database, table=table, boto3_session=boto3_session, inputs=inputs) + + +def _add_partitions(database: str, table: str, boto3_session: Optional[boto3.Session], inputs: List[Dict[str, Any]]): + chunks: List[List[Dict[str, Any]]] = _utils.chunkify(lst=inputs, max_length=100) + client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session) + for chunk in chunks: # pylint: disable=too-many-nested-blocks + res: Dict[str, Any] = client_glue.batch_create_partition( + DatabaseName=database, TableName=table, PartitionInputList=chunk + ) + if ("Errors" in res) and res["Errors"]: + for error in res["Errors"]: + if "ErrorDetail" in error: + if "ErrorCode" in error["ErrorDetail"]: + if error["ErrorDetail"]["ErrorCode"] != "AlreadyExistsException": # pragma: no cover + raise exceptions.ServiceApiError(str(res["Errors"])) + + +def _csv_partition_definition(location: str, values: List[str], compression: Optional[str], sep: str) -> Dict[str, Any]: + compressed: bool = compression is not None + return { + "StorageDescriptor": { + "InputFormat": "org.apache.hadoop.mapred.TextInputFormat", + "OutputFormat": "org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat", + "Location": location, + "Compressed": compressed, + "SerdeInfo": { + "Parameters": {"field.delim": sep, "escape.delim": "\\"}, + "SerializationLibrary": "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe", + }, + "StoredAsSubDirectories": False, + }, + "Values": values, + } + + +def get_parquet_partitions( + database: str, + table: str, + expression: Optional[str] = None, + catalog_id: Optional[str] = None, + boto3_session: Optional[boto3.Session] = None, +) -> Dict[str, List[str]]: + """Get all partitions from a Table in the AWS Glue Catalog. + + Expression argument instructions: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/glue.html#Glue.Client.get_partitions + + Parameters + ---------- + database : str + Database name. + table : str + Table name. + expression : str, optional + An expression that filters the partitions to be returned. + catalog_id : str, optional + The ID of the Data Catalog from which to retrieve Databases. + If none is provided, the AWS account ID is used by default. + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session receive None. + + Returns + ------- + Dict[str, List[str]] + partitions_values: Dictionary with keys as S3 path locations and values as a + list of partitions values as str (e.g. {'s3://bucket/prefix/y=2020/m=10/': ['2020', '10']}). + + Examples + -------- + Fetch all partitions + + >>> import awswrangler as wr + >>> wr.catalog.get_parquet_partitions( + ... database='default', + ... table='my_table', + ... ) + { + 's3://bucket/prefix/y=2020/m=10/': ['2020', '10'], + 's3://bucket/prefix/y=2020/m=11/': ['2020', '11'], + 's3://bucket/prefix/y=2020/m=12/': ['2020', '12'] + } + + Filtering partitions + + >>> import awswrangler as wr + >>> wr.catalog.get_parquet_partitions( + ... database='default', + ... table='my_table', + ... expression='m=10' + ... 
) + { + 's3://bucket/prefix/y=2020/m=10/': ['2020', '10'] + } + + """ + return _get_partitions( + database=database, table=table, expression=expression, catalog_id=catalog_id, boto3_session=boto3_session + ) + + +def get_csv_partitions( + database: str, + table: str, + expression: Optional[str] = None, + catalog_id: Optional[str] = None, + boto3_session: Optional[boto3.Session] = None, +) -> Dict[str, List[str]]: + """Get all partitions from a Table in the AWS Glue Catalog. + + Expression argument instructions: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/glue.html#Glue.Client.get_partitions + + Parameters + ---------- + database : str + Database name. + table : str + Table name. + expression : str, optional + An expression that filters the partitions to be returned. + catalog_id : str, optional + The ID of the Data Catalog from which to retrieve Databases. + If none is provided, the AWS account ID is used by default. + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session receive None. + + Returns + ------- + Dict[str, List[str]] + partitions_values: Dictionary with keys as S3 path locations and values as a + list of partitions values as str (e.g. {'s3://bucket/prefix/y=2020/m=10/': ['2020', '10']}). + + Examples + -------- + Fetch all partitions + + >>> import awswrangler as wr + >>> wr.catalog.get_csv_partitions( + ... database='default', + ... table='my_table', + ... ) + { + 's3://bucket/prefix/y=2020/m=10/': ['2020', '10'], + 's3://bucket/prefix/y=2020/m=11/': ['2020', '11'], + 's3://bucket/prefix/y=2020/m=12/': ['2020', '12'] + } + + Filtering partitions + + >>> import awswrangler as wr + >>> wr.catalog.get_csv_partitions( + ... database='default', + ... table='my_table', + ... expression='m=10' + ... ) + { + 's3://bucket/prefix/y=2020/m=10/': ['2020', '10'] + } + + """ + return _get_partitions( + database=database, table=table, expression=expression, catalog_id=catalog_id, boto3_session=boto3_session + ) + + +def _get_partitions( + database: str, + table: str, + expression: Optional[str] = None, + catalog_id: Optional[str] = None, + boto3_session: Optional[boto3.Session] = None, +) -> Dict[str, List[str]]: + client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session) + paginator = client_glue.get_paginator("get_partitions") + args: Dict[str, Any] = {} + if expression is not None: + args["Expression"] = expression + if catalog_id is not None: + args["CatalogId"] = catalog_id + response_iterator = paginator.paginate( + DatabaseName=database, TableName=table, PaginationConfig={"PageSize": 1000}, **args + ) + partitions_values: Dict[str, List[str]] = {} + for page in response_iterator: + if (page is not None) and ("Partitions" in page): + for partition in page["Partitions"]: + location: Optional[str] = partition["StorageDescriptor"].get("Location") + if location is not None: + values: List[str] = partition["Values"] + partitions_values[location] = values + return partitions_values + + +def extract_athena_types( + df: pd.DataFrame, + index: bool = False, + partition_cols: Optional[List[str]] = None, + dtype: Optional[Dict[str, str]] = None, + file_format: str = "parquet", +) -> Tuple[Dict[str, str], Dict[str, str]]: + """Extract columns and partitions types (Amazon Athena) from Pandas DataFrame. + + https://docs.aws.amazon.com/athena/latest/ug/data-types.html + + Parameters + ---------- + df : pandas.DataFrame + Pandas DataFrame. 
+ index : bool + Should consider the DataFrame index as a column?. + partition_cols : List[str], optional + List of partitions names. + dtype: Dict[str, str], optional + Dictionary of columns names and Athena/Glue types to be casted. + Useful when you have columns with undetermined or mixed data types. + (e.g. {'col name': 'bigint', 'col2 name': 'int'}) + file_format : str, optional + File format to be consided to place the index column: "parquet" | "csv". + + Returns + ------- + Tuple[Dict[str, str], Optional[Dict[str, str]]] + columns_types: Dictionary with keys as column names and vales as + data types (e.g. {'col0': 'bigint', 'col1': 'double'}). / + partitions_types: Dictionary with keys as partition names + and values as data types (e.g. {'col2': 'date'}). + + Examples + -------- + >>> import awswrangler as wr + >>> columns_types, partitions_types = wr.catalog.extract_athena_types( + ... df=df, index=False, partition_cols=["par0", "par1"], file_format="csv" + ... ) + + """ + if file_format == "parquet": + index_left: bool = False + elif file_format == "csv": + index_left = True + else: + raise exceptions.InvalidArgumentValue("file_format argument must be parquet or csv") + return _data_types.athena_types_from_pandas_partitioned( + df=df, index=index, partition_cols=partition_cols, dtype=dtype, index_left=index_left + ) diff --git a/awswrangler/s3.py b/awswrangler/s3.py index 5a37c7b3a..8c9aaf269 100644 --- a/awswrangler/s3.py +++ b/awswrangler/s3.py @@ -1,6 +1,7 @@ """Amazon S3 Module.""" import concurrent.futures +import csv import logging import time import uuid @@ -361,14 +362,39 @@ def size_objects( return size_list -def to_csv( +def to_csv( # pylint: disable=too-many-arguments df: pd.DataFrame, path: str, + sep: str = ",", + index: bool = True, + columns: Optional[List[str]] = None, + use_threads: bool = True, boto3_session: Optional[boto3.Session] = None, s3_additional_kwargs: Optional[Dict[str, str]] = None, + dataset: bool = False, + partition_cols: Optional[List[str]] = None, + mode: Optional[str] = None, + database: Optional[str] = None, + table: Optional[str] = None, + dtype: Optional[Dict[str, str]] = None, + description: Optional[str] = None, + parameters: Optional[Dict[str, str]] = None, + columns_comments: Optional[Dict[str, str]] = None, **pandas_kwargs, -) -> None: - """Write CSV file on Amazon S3. +) -> Dict[str, Union[List[str], Dict[str, List[str]]]]: + """Write CSV file or dataset on Amazon S3. + + The concept of Dataset goes beyond the simple idea of files and enable more + complex features like partitioning, casting and catalog integration (Amazon Athena/AWS Glue Catalog). + + Note + ---- + The table name and all column names will be automatically sanitize using + `wr.catalog.sanitize_table_name` and `wr.catalog.sanitize_column_name`. + + Note + ---- + In case of `use_threads=True` the number of process that will be spawned will be get from os.cpu_count(). Parameters ---------- @@ -376,11 +402,44 @@ def to_csv( Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html path : str Amazon S3 path (e.g. s3://bucket/filename.csv). + sep : str + String of length 1. Field delimiter for the output file. + index : bool + Write row names (index). + columns : List[str], optional + Columns to write. + use_threads : bool + True to enable concurrent requests, False to disable multiple threads. + If enabled os.cpu_count() will be used as the max number of threads. boto3_session : boto3.Session(), optional Boto3 Session. 
The default boto3 Session will be used if boto3_session receive None. s3_additional_kwargs: Forward to s3fs, useful for server side encryption https://s3fs.readthedocs.io/en/latest/#serverside-encryption + dataset: bool + If True store a parquet dataset instead of a single file. + If True, enable all follow arguments: + partition_cols, mode, database, table, description, parameters, columns_comments, . + partition_cols: List[str], optional + List of column names that will be used to create partitions. Only takes effect if dataset=True. + mode: str, optional + ``append`` (Default), ``overwrite``, ``overwrite_partitions``. Only takes effect if dataset=True. + database : str + Glue/Athena catalog: Database name. + table : str + Glue/Athena catalog: Table name. + dtype: Dict[str, str], optional + Dictionary of columns names and Athena/Glue types to be casted. + Useful when you have columns with undetermined or mixed data types. + Only takes effect if dataset=True. + (e.g. {'col name': 'bigint', 'col2 name': 'int'}) + description: str, optional + Glue/Athena catalog: Table description + parameters: Dict[str, str], optional + Glue/Athena catalog: Key/value pairs to tag the table. + columns_comments: Dict[str, str], optional + Glue/Athena catalog: + Columns names and the related comments (e.g. {'col0': 'Column 0.', 'col1': 'Column 1.', 'col2': 'Partition.'}). pandas_kwargs: keyword arguments forwarded to pandas.DataFrame.to_csv() https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_csv.html @@ -392,37 +451,245 @@ def to_csv( Examples -------- - Writing CSV file + Writing single file >>> import awswrangler as wr >>> import pandas as pd >>> wr.s3.to_csv( ... df=pd.DataFrame({'col': [1, 2, 3]}), - ... path='s3://bucket/filename.csv', + ... path='s3://bucket/prefix/my_file.csv', ... ) + { + 'paths': ['s3://bucket/prefix/my_file.csv'], + 'partitions_values': {} + } - Writing CSV file encrypted with a KMS key + Writing single file encrypted with a KMS key >>> import awswrangler as wr >>> import pandas as pd >>> wr.s3.to_csv( ... df=pd.DataFrame({'col': [1, 2, 3]}), - ... path='s3://bucket/filename.csv', + ... path='s3://bucket/prefix/my_file.csv', ... s3_additional_kwargs={ ... 'ServerSideEncryption': 'aws:kms', ... 'SSEKMSKeyId': 'YOUR_KMY_KEY_ARN' ... } ... ) + { + 'paths': ['s3://bucket/prefix/my_file.csv'], + 'partitions_values': {} + } + + Writing partitioned dataset + + >>> import awswrangler as wr + >>> import pandas as pd + >>> wr.s3.to_csv( + ... df=pd.DataFrame({ + ... 'col': [1, 2, 3], + ... 'col2': ['A', 'A', 'B'] + ... }), + ... path='s3://bucket/prefix', + ... dataset=True, + ... partition_cols=['col2'] + ... ) + { + 'paths': ['s3://.../col2=A/x.csv', 's3://.../col2=B/y.csv'], + 'partitions_values: { + 's3://.../col2=A/': ['A'], + 's3://.../col2=B/': ['B'] + } + } + + Writing dataset to S3 with metadata on Athena/Glue Catalog. + + >>> import awswrangler as wr + >>> import pandas as pd + >>> wr.s3.to_csv( + ... df=pd.DataFrame({ + ... 'col': [1, 2, 3], + ... 'col2': ['A', 'A', 'B'] + ... }), + ... path='s3://bucket/prefix', + ... dataset=True, + ... partition_cols=['col2'], + ... database='default', # Athena/Glue database + ... table='my_table' # Athena/Glue table + ... ) + { + 'paths': ['s3://.../col2=A/x.csv', 's3://.../col2=B/y.csv'], + 'partitions_values: { + 's3://.../col2=A/': ['A'], + 's3://.../col2=B/': ['B'] + } + } + + Writing dataset casting empty column data type + + >>> import awswrangler as wr + >>> import pandas as pd + >>> wr.s3.to_csv( + ... 
df=pd.DataFrame({ + ... 'col': [1, 2, 3], + ... 'col2': ['A', 'A', 'B'], + ... 'col3': [None, None, None] + ... }), + ... path='s3://bucket/prefix', + ... dataset=True, + ... database='default', # Athena/Glue database + ... table='my_table' # Athena/Glue table + ... dtype={'col3': 'date'} + ... ) + { + 'paths': ['s3://.../x.csv'], + 'partitions_values: {} + } """ - return _to_text( - file_format="csv", - df=df, - path=path, - boto3_session=boto3_session, - s3_additional_kwargs=s3_additional_kwargs, - **pandas_kwargs, - ) + if (database is None) ^ (table is None): + raise exceptions.InvalidArgumentCombination( + "Please pass database and table arguments to be able to store the metadata into the Athena/Glue Catalog." + ) + if df.empty is True: + raise exceptions.EmptyDataFrame() + session: boto3.Session = _utils.ensure_session(session=boto3_session) + partition_cols = partition_cols if partition_cols else [] + dtype = dtype if dtype else {} + columns_comments = columns_comments if columns_comments else {} + partitions_values: Dict[str, List[str]] = {} + fs: s3fs.S3FileSystem = _utils.get_fs(session=session, s3_additional_kwargs=s3_additional_kwargs) + if dataset is False: + if partition_cols: + raise exceptions.InvalidArgumentCombination("Please, pass dataset=True to be able to use partition_cols.") + if mode is not None: + raise exceptions.InvalidArgumentCombination("Please pass dataset=True to be able to use mode.") + if any(arg is not None for arg in (database, table, description, parameters)): + raise exceptions.InvalidArgumentCombination( + "Please pass dataset=True to be able to use any one of these " + "arguments: database, table, description, parameters, " + "columns_comments." + ) + pandas_kwargs["sep"] = sep + pandas_kwargs["index"] = index + pandas_kwargs["columns"] = columns + _to_text(file_format="csv", df=df, path=path, fs=fs, **pandas_kwargs) + paths = [path] + else: + mode = "append" if mode is None else mode + exist: bool = False + if columns: + df = df[columns] + if (database is not None) and (table is not None): # Normalize table to respect Athena's standards + df = catalog.sanitize_dataframe_columns_names(df=df) + partition_cols = [catalog.sanitize_column_name(p) for p in partition_cols] + dtype = {catalog.sanitize_column_name(k): v.lower() for k, v in dtype.items()} + columns_comments = {catalog.sanitize_column_name(k): v for k, v in columns_comments.items()} + exist = catalog.does_table_exist(database=database, table=table, boto3_session=session) + if (exist is True) and (mode in ("append", "overwrite_partitions")): + for k, v in catalog.get_table_types(database=database, table=table, boto3_session=session).items(): + dtype[k] = v + df = catalog.drop_duplicated_columns(df=df) + paths, partitions_values = _to_csv_dataset( + df=df, + path=path, + index=index, + sep=sep, + fs=fs, + use_threads=use_threads, + partition_cols=partition_cols, + dtype=dtype, + mode=mode, + boto3_session=session, + ) + if (database is not None) and (table is not None): + columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned( + df=df, index=index, partition_cols=partition_cols, dtype=dtype, index_left=True + ) + if (exist is False) or (mode == "overwrite"): + catalog.create_csv_table( + database=database, + table=table, + path=path, + columns_types=columns_types, + partitions_types=partitions_types, + description=description, + parameters=parameters, + columns_comments=columns_comments, + boto3_session=session, + mode="overwrite", + sep=sep, + ) + if 
partitions_values: + _logger.debug(f"partitions_values:\n{partitions_values}") + catalog.add_csv_partitions( + database=database, table=table, partitions_values=partitions_values, boto3_session=session, sep=sep + ) + return {"paths": paths, "partitions_values": partitions_values} + + +def _to_csv_dataset( + df: pd.DataFrame, + path: str, + index: bool, + sep: str, + fs: s3fs.S3FileSystem, + use_threads: bool, + mode: str, + dtype: Dict[str, str], + partition_cols: Optional[List[str]] = None, + boto3_session: Optional[boto3.Session] = None, +) -> Tuple[List[str], Dict[str, List[str]]]: + paths: List[str] = [] + partitions_values: Dict[str, List[str]] = {} + path = path if path[-1] == "/" else f"{path}/" + if mode not in ["append", "overwrite", "overwrite_partitions"]: + raise exceptions.InvalidArgumentValue( + f"{mode} is a invalid mode, please use append, overwrite or overwrite_partitions." + ) + if (mode == "overwrite") or ((mode == "overwrite_partitions") and (not partition_cols)): + delete_objects(path=path, use_threads=use_threads, boto3_session=boto3_session) + df = _data_types.cast_pandas_with_athena_types(df=df, dtype=dtype) + _logger.debug(f"dtypes: {df.dtypes}") + if not partition_cols: + file_path: str = f"{path}{uuid.uuid4().hex}.csv" + _to_text( + file_format="csv", + df=df, + path=file_path, + fs=fs, + quoting=csv.QUOTE_NONE, + escapechar="\\", + header=False, + date_format="%Y-%m-%d %H:%M:%S.%f", + index=index, + sep=sep, + ) + paths.append(file_path) + else: + for keys, subgroup in df.groupby(by=partition_cols, observed=True): + subgroup = subgroup.drop(partition_cols, axis="columns") + keys = (keys,) if not isinstance(keys, tuple) else keys + subdir = "/".join([f"{name}={val}" for name, val in zip(partition_cols, keys)]) + prefix: str = f"{path}{subdir}/" + if mode == "overwrite_partitions": + delete_objects(path=prefix, use_threads=use_threads, boto3_session=boto3_session) + file_path = f"{prefix}{uuid.uuid4().hex}.csv" + _to_text( + file_format="csv", + df=subgroup, + path=file_path, + fs=fs, + quoting=csv.QUOTE_NONE, + escapechar="\\", + header=False, + date_format="%Y-%m-%d %H:%M:%S.%f", + index=index, + sep=sep, + ) + paths.append(file_path) + partitions_values[prefix] = [str(k) for k in keys] + return paths, partitions_values def to_json( @@ -493,13 +760,15 @@ def _to_text( file_format: str, df: pd.DataFrame, path: str, + fs: Optional[s3fs.S3FileSystem] = None, boto3_session: Optional[boto3.Session] = None, s3_additional_kwargs: Optional[Dict[str, str]] = None, **pandas_kwargs, ) -> None: - if df.empty is True: + if df.empty is True: # pragma: no cover raise exceptions.EmptyDataFrame() - fs: s3fs.S3FileSystem = _utils.get_fs(session=boto3_session, s3_additional_kwargs=s3_additional_kwargs) + if fs is None: + fs = _utils.get_fs(session=boto3_session, s3_additional_kwargs=s3_additional_kwargs) with fs.open(path, "w") as f: if file_format == "csv": df.to_csv(f, **pandas_kwargs) @@ -690,7 +959,7 @@ def to_parquet( # pylint: disable=too-many-arguments """ if (database is None) ^ (table is None): raise exceptions.InvalidArgumentCombination( - "Please pass database and table arguments to be able to " "store the metadata into the Athena/Glue Catalog." + "Please pass database and table arguments to be able to store the metadata into the Athena/Glue Catalog." 
) if df.empty is True: raise exceptions.EmptyDataFrame() @@ -721,13 +990,18 @@ def to_parquet( # pylint: disable=too-many-arguments ) ] else: + mode = "append" if mode is None else mode + exist: bool = False if (database is not None) and (table is not None): # Normalize table to respect Athena's standards df = catalog.sanitize_dataframe_columns_names(df=df) partition_cols = [catalog.sanitize_column_name(p) for p in partition_cols] dtype = {catalog.sanitize_column_name(k): v.lower() for k, v in dtype.items()} columns_comments = {catalog.sanitize_column_name(k): v for k, v in columns_comments.items()} + exist = catalog.does_table_exist(database=database, table=table, boto3_session=session) + if (exist is True) and (mode in ("append", "overwrite_partitions")): + for k, v in catalog.get_table_types(database=database, table=table, boto3_session=session).items(): + dtype[k] = v df = catalog.drop_duplicated_columns(df=df) - mode = "append" if mode is None else mode paths, partitions_values = _to_parquet_dataset( df=df, path=path, @@ -746,19 +1020,20 @@ def to_parquet( # pylint: disable=too-many-arguments columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned( df=df, index=index, partition_cols=partition_cols, dtype=dtype ) - catalog.create_parquet_table( - database=database, - table=table, - path=path, - columns_types=columns_types, - partitions_types=partitions_types, - compression=compression, - description=description, - parameters=parameters, - columns_comments=columns_comments, - boto3_session=session, - mode="overwrite", - ) + if (exist is False) or (mode == "overwrite"): + catalog.create_parquet_table( + database=database, + table=table, + path=path, + columns_types=columns_types, + partitions_types=partitions_types, + compression=compression, + description=description, + parameters=parameters, + columns_comments=columns_comments, + boto3_session=session, + mode="overwrite", + ) if partitions_values: _logger.debug(f"partitions_values:\n{partitions_values}") catalog.add_parquet_partitions( diff --git a/docs/source/api.rst b/docs/source/api.rst index 2f71f41f4..abb7cb16f 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -40,11 +40,14 @@ AWS Glue Catalog add_parquet_partitions create_parquet_table + add_csv_partitions + create_csv_table databases delete_table_if_exists does_table_exist get_databases get_parquet_partitions + get_csv_partitions get_table_location get_table_types get_tables @@ -56,6 +59,7 @@ AWS Glue Catalog sanitize_table_name drop_duplicated_columns get_engine + extract_athena_types Amazon Athena ------------- diff --git a/testing/run-tests.sh b/testing/run-tests.sh index 07a4d0f88..3fb87ea78 100755 --- a/testing/run-tests.sh +++ b/testing/run-tests.sh @@ -9,7 +9,7 @@ START=$(microtime) ./run-validations.sh pushd .. 
-tox --recreate --develop -e ALL +tox --recreate --develop -e py36 coverage html --directory testing/coverage rm -rf .coverage* testing/Running Running diff --git a/testing/test_awswrangler/_utils.py b/testing/test_awswrangler/_utils.py index b55219d87..b4c210cd1 100644 --- a/testing/test_awswrangler/_utils.py +++ b/testing/test_awswrangler/_utils.py @@ -94,6 +94,27 @@ def get_df_cast(): return df +def get_df_csv(): + df = pd.DataFrame( + { + "id": [1, 2, 3], + "string_object": ["foo", None, "boo"], + "string": ["foo", None, "boo"], + "float": [1.0, None, 2.0], + "int": [1, None, 2], + "date": [dt("2020-01-01"), None, dt("2020-01-02")], + "timestamp": [ts("2020-01-01 00:00:00.0"), None, ts("2020-01-02 00:00:01.0")], + "bool": [True, None, False], + "par0": [1, 1, 2], + "par1": ["a", "b", "b"], + } + ) + df["string"] = df["string"].astype("string") + df["int"] = df["int"].astype("Int64") + df["par1"] = df["par1"].astype("string") + return df + + def get_df_category(): df = pd.DataFrame( { @@ -356,3 +377,25 @@ def ensure_data_types_category(df): assert str(df["int"].dtype) in ("category", "Int64") assert str(df["par0"].dtype) in ("category", "Int64") assert str(df["par1"].dtype) == "category" + + +def ensure_data_types_csv(df): + if "__index_level_0__" in df: + assert str(df["__index_level_0__"].dtype).startswith("Int") + assert str(df["id"].dtype).startswith("Int") + if "string_object" in df: + assert str(df["string_object"].dtype) == "string" + if "string" in df: + assert str(df["string"].dtype) == "string" + if "float" in df: + assert str(df["float"].dtype).startswith("float") + if "int" in df: + assert str(df["int"].dtype).startswith("Int") + assert str(df["date"].dtype) == "object" + assert str(df["timestamp"].dtype).startswith("datetime") + if "bool" in df: + assert str(df["bool"].dtype) == "boolean" + if "par0" in df: + assert str(df["par0"].dtype).startswith("Int") + if "par1" in df: + assert str(df["par1"].dtype) == "string" diff --git a/testing/test_awswrangler/test_data_lake.py b/testing/test_awswrangler/test_data_lake.py index 6f21647ae..483bc140c 100644 --- a/testing/test_awswrangler/test_data_lake.py +++ b/testing/test_awswrangler/test_data_lake.py @@ -7,8 +7,8 @@ import awswrangler as wr -from ._utils import (ensure_data_types, ensure_data_types_category, get_df, get_df_cast, get_df_category, get_df_list, - get_query_long) +from ._utils import (ensure_data_types, ensure_data_types_category, ensure_data_types_csv, get_df, get_df_cast, + get_df_category, get_df_csv, get_df_list, get_query_long) logging.basicConfig(level=logging.INFO, format="[%(asctime)s][%(levelname)s][%(name)s][%(funcName)s] %(message)s") logging.getLogger("awswrangler").setLevel(logging.DEBUG) @@ -90,6 +90,12 @@ def workgroup1(bucket): def test_athena_ctas(bucket, database, kms_key): + df = get_df_list() + columns_types, partitions_types = wr.catalog.extract_athena_types(df=df, partition_cols=["par0", "par1"]) + assert len(columns_types) == 16 + assert len(partitions_types) == 2 + with pytest.raises(wr.exceptions.InvalidArgumentValue): + wr.catalog.extract_athena_types(df=df, file_format="avro") paths = wr.s3.to_parquet( df=get_df_list(), path=f"s3://{bucket}/test_athena_ctas", @@ -669,25 +675,217 @@ def test_category(bucket, database): def test_parquet_validate_schema(bucket, database): path = f"s3://{bucket}/test_parquet_file_validate/" wr.s3.delete_objects(path=path) - df = pd.DataFrame({"id": [1, 2, 3]}) path_file = f"s3://{bucket}/test_parquet_file_validate/0.parquet" wr.s3.to_parquet(df=df, 
path=path_file) wr.s3.wait_objects_exist(paths=[path_file]) - df2 = pd.DataFrame({"id2": [1, 2, 3], "val": ["foo", "boo", "bar"]}) path_file2 = f"s3://{bucket}/test_parquet_file_validate/1.parquet" wr.s3.to_parquet(df=df2, path=path_file2) wr.s3.wait_objects_exist(paths=[path_file2]) - df3 = wr.s3.read_parquet(path=path, validate_schema=False) assert len(df3.index) == 6 assert len(df3.columns) == 3 - with pytest.raises(ValueError): wr.s3.read_parquet(path=path, validate_schema=True) - with pytest.raises(ValueError): wr.s3.store_parquet_metadata(path=path, database=database, table="test_parquet_validate_schema", dataset=True) + wr.s3.delete_objects(path=path) + + +def test_csv_dataset(bucket, database): + path = f"s3://{bucket}/test_csv_dataset/" + with pytest.raises(wr.exceptions.UndetectedType): + wr.s3.to_csv(pd.DataFrame({"A": [None]}), path, dataset=True, database=database, table="test_csv_dataset") + df = get_df_csv() + with pytest.raises(wr.exceptions.InvalidArgumentCombination): + wr.s3.to_csv(df, path, dataset=False, mode="overwrite", database=database, table="test_csv_dataset") + with pytest.raises(wr.exceptions.InvalidArgumentCombination): + wr.s3.to_csv(df, path, dataset=False, table="test_csv_dataset") + with pytest.raises(wr.exceptions.InvalidArgumentCombination): + wr.s3.to_csv(df, path, dataset=True, mode="overwrite", database=database) + with pytest.raises(wr.exceptions.InvalidArgumentCombination): + wr.s3.to_csv(df=df, path=path, mode="append") + with pytest.raises(wr.exceptions.InvalidArgumentCombination): + wr.s3.to_csv(df=df, path=path, partition_cols=["col2"]) + with pytest.raises(wr.exceptions.InvalidArgumentCombination): + wr.s3.to_csv(df=df, path=path, description="foo") + with pytest.raises(wr.exceptions.InvalidArgumentValue): + wr.s3.to_csv(df=df, path=path, partition_cols=["col2"], dataset=True, mode="WRONG") + paths = wr.s3.to_csv( + df=df, + path=path, + sep="|", + index=False, + use_threads=True, + boto3_session=None, + s3_additional_kwargs=None, + dataset=True, + partition_cols=["par0", "par1"], + mode="overwrite", + )["paths"] + wr.s3.wait_objects_exist(paths=paths) + df2 = wr.s3.read_csv(path=paths, sep="|", header=None) + assert len(df2.index) == 3 + assert len(df2.columns) == 8 + assert df2[0].sum() == 6 + wr.s3.delete_objects(path=paths) + + +def test_csv_catalog(bucket, database): + path = f"s3://{bucket}/test_csv_catalog/" + df = get_df_csv() + paths = wr.s3.to_csv( + df=df, + path=path, + sep="\t", + index=True, + use_threads=True, + boto3_session=None, + s3_additional_kwargs=None, + dataset=True, + partition_cols=["par0", "par1"], + mode="overwrite", + table="test_csv_catalog", + database=database, + )["paths"] + wr.s3.wait_objects_exist(paths=paths) + df2 = wr.athena.read_sql_table("test_csv_catalog", database) + assert len(df2.index) == 3 + assert len(df2.columns) == 11 + assert df2["id"].sum() == 6 + ensure_data_types_csv(df2) + wr.s3.delete_objects(path=paths) + assert wr.catalog.delete_table_if_exists(database=database, table="test_csv_catalog") is True + + +def test_csv_catalog_columns(bucket, database): + path = f"s3://{bucket}/test_csv_catalog_columns /" + paths = wr.s3.to_csv( + df=get_df_csv(), + path=path, + sep="|", + columns=["id", "date", "timestamp", "par0", "par1"], + index=False, + use_threads=False, + boto3_session=None, + s3_additional_kwargs=None, + dataset=True, + partition_cols=["par0", "par1"], + mode="overwrite", + table="test_csv_catalog_columns", + database=database, + )["paths"] + wr.s3.wait_objects_exist(paths=paths) + df2 
= wr.athena.read_sql_table("test_csv_catalog_columns", database) + assert len(df2.index) == 3 + assert len(df2.columns) == 5 + assert df2["id"].sum() == 6 + ensure_data_types_csv(df2) + + paths = wr.s3.to_csv( + df=pd.DataFrame({"id": [4], "date": [None], "timestamp": [None], "par0": [1], "par1": ["a"]}), + path=path, + sep="|", + index=False, + use_threads=False, + boto3_session=None, + s3_additional_kwargs=None, + dataset=True, + partition_cols=["par0", "par1"], + mode="overwrite_partitions", + table="test_csv_catalog_columns", + database=database, + )["paths"] + wr.s3.wait_objects_exist(paths=paths) + df2 = wr.athena.read_sql_table("test_csv_catalog_columns", database) + assert len(df2.index) == 3 + assert len(df2.columns) == 5 + assert df2["id"].sum() == 9 + ensure_data_types_csv(df2) + + wr.s3.delete_objects(path=path) + assert wr.catalog.delete_table_if_exists(database=database, table="test_csv_catalog_columns") is True + + +def test_athena_types(bucket, database): + path = f"s3://{bucket}/test_athena_types/" + df = get_df_csv() + paths = wr.s3.to_csv( + df=df, + path=path, + sep=",", + index=False, + use_threads=True, + boto3_session=None, + s3_additional_kwargs=None, + dataset=True, + partition_cols=["par0", "par1"], + mode="overwrite", + )["paths"] + wr.s3.wait_objects_exist(paths=paths) + columns_types, partitions_types = wr.catalog.extract_athena_types( + df=df, index=False, partition_cols=["par0", "par1"], file_format="csv" + ) + wr.catalog.create_csv_table( + table="test_athena_types", + database=database, + path=path, + partitions_types=partitions_types, + columns_types=columns_types, + ) + wr.athena.repair_table("test_athena_types", database) + assert len(wr.catalog.get_csv_partitions(database, "test_athena_types")) == 3 + df2 = wr.athena.read_sql_table("test_athena_types", database) + assert len(df2.index) == 3 + assert len(df2.columns) == 10 + assert df2["id"].sum() == 6 + ensure_data_types_csv(df2) + wr.s3.delete_objects(path=paths) + assert wr.catalog.delete_table_if_exists(database=database, table="test_athena_types") is True + + +def test_parquet_catalog_columns(bucket, database): + path = f"s3://{bucket}/test_parquet_catalog_columns /" + paths = wr.s3.to_parquet( + df=get_df_csv()[["id", "date", "timestamp", "par0", "par1"]], + path=path, + index=False, + use_threads=False, + boto3_session=None, + s3_additional_kwargs=None, + dataset=True, + partition_cols=["par0", "par1"], + mode="overwrite", + table="test_parquet_catalog_columns", + database=database, + )["paths"] + wr.s3.wait_objects_exist(paths=paths) + df2 = wr.athena.read_sql_table("test_parquet_catalog_columns", database) + assert len(df2.index) == 3 + assert len(df2.columns) == 5 + assert df2["id"].sum() == 6 + ensure_data_types_csv(df2) + + paths = wr.s3.to_parquet( + df=pd.DataFrame({"id": [4], "date": [None], "timestamp": [None], "par0": [1], "par1": ["a"]}), + path=path, + index=False, + use_threads=False, + boto3_session=None, + s3_additional_kwargs=None, + dataset=True, + partition_cols=["par0", "par1"], + mode="overwrite_partitions", + table="test_parquet_catalog_columns", + database=database, + )["paths"] + wr.s3.wait_objects_exist(paths=paths) + df2 = wr.athena.read_sql_table("test_parquet_catalog_columns", database) + assert len(df2.index) == 3 + assert len(df2.columns) == 5 + assert df2["id"].sum() == 9 + ensure_data_types_csv(df2) wr.s3.delete_objects(path=path) + assert wr.catalog.delete_table_if_exists(database=database, table="test_parquet_catalog_columns") is True diff --git 
a/testing/test_awswrangler/test_db.py b/testing/test_awswrangler/test_db.py
index d809ec977..adcacf4a4 100644
--- a/testing/test_awswrangler/test_db.py
+++ b/testing/test_awswrangler/test_db.py
@@ -348,6 +348,8 @@ def test_redshift_spectrum(bucket, glue_database, external_schema):
     assert len(rows) == len(df.index)
     for row in rows:
         assert len(row) == len(df.columns)
+    wr.s3.delete_objects(path=path)
+    assert wr.catalog.delete_table_if_exists(database=glue_database, table="test_redshift_spectrum") is True


 def test_redshift_category(bucket, parameters):
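
End-to-end usage sketch of the CSV dataset feature introduced above, mirroring the to_csv docstring examples
and the tests in test_data_lake.py. The bucket ("my-bucket"), Glue database ("default"), and table name
("my_csv_table") are placeholders, not values taken from this patch.

# Sketch: write a partitioned CSV dataset, register it in the Glue Catalog, read it back via Athena.
import pandas as pd

import awswrangler as wr

df = pd.DataFrame({"id": [1, 2, 3], "value": [10.0, 20.0, 30.0], "par": ["a", "a", "b"]})

# Inspect the Athena types that will be registered (extract_athena_types is added by this patch).
columns_types, partitions_types = wr.catalog.extract_athena_types(
    df=df, index=False, partition_cols=["par"], file_format="csv"
)

# Write the dataset and create/update the Glue table and its partitions.
res = wr.s3.to_csv(
    df=df,
    path="s3://my-bucket/prefix/",
    dataset=True,
    partition_cols=["par"],
    mode="overwrite",
    database="default",
    table="my_csv_table",
)
print(res["paths"], res["partitions_values"])

# Read the table back through Athena and clean up, as the tests above do.
df2 = wr.athena.read_sql_table(table="my_csv_table", database="default")
wr.s3.delete_objects(path="s3://my-bucket/prefix/")
wr.catalog.delete_table_if_exists(database="default", table="my_csv_table")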