From 08cf244090583e253d4295d7d1d2aa4d0bbb867e Mon Sep 17 00:00:00 2001
From: igorborgest
Date: Sun, 3 May 2020 21:56:43 -0300
Subject: [PATCH] Add support to write nested types (array and struct).

---
 .gitignore                                 |  2 +
 awswrangler/_data_types.py                 | 67 ++++++++++++++++++----
 awswrangler/db.py                          |  5 +-
 testing/test_awswrangler/test_data_lake.py | 22 +++++++
 4 files changed, 85 insertions(+), 11 deletions(-)

diff --git a/.gitignore b/.gitignore
index 1947b2e87..8a3474a30 100644
--- a/.gitignore
+++ b/.gitignore
@@ -138,6 +138,8 @@ testing/*parameters-*.properties
 testing/*requirements*.txt
 testing/coverage/*
 building/*requirements*.txt
+building/arrow
+building/lambda/arrow
 /docs/coverage/
 /docs/build/
 /docs/source/_build/
diff --git a/awswrangler/_data_types.py b/awswrangler/_data_types.py
index 01237ea49..fac82a37b 100644
--- a/awswrangler/_data_types.py
+++ b/awswrangler/_data_types.py
@@ -1,8 +1,9 @@
 """Internal (private) Data Types Module."""
 
 import logging
+import re
 from decimal import Decimal
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Match, Optional, Sequence, Tuple
 
 import pandas as pd  # type: ignore
 import pyarrow as pa  # type: ignore
@@ -139,8 +140,10 @@ def pyarrow2athena(dtype: pa.DataType) -> str:  # pylint: disable=too-many-branc
         return f"decimal({dtype.precision},{dtype.scale})"
     if pa.types.is_list(dtype):
         return f"array<{pyarrow2athena(dtype=dtype.value_type)}>"
-    if pa.types.is_struct(dtype):  # pragma: no cover
-        return f"struct<{', '.join([f'{f.name}: {pyarrow2athena(dtype=f.type)}' for f in dtype])}>"
+    if pa.types.is_struct(dtype):
+        return f"struct<{', '.join([f'{f.name}:{pyarrow2athena(dtype=f.type)}' for f in dtype])}>"
+    if pa.types.is_map(dtype):  # pragma: no cover
+        return f"map<{pyarrow2athena(dtype=dtype.key_type)},{pyarrow2athena(dtype=dtype.item_type)}>"
     if dtype == pa.null():
         raise exceptions.UndetectedType("We can not infer the data type from an entire null object column")
     raise exceptions.UnsupportedType(f"Unsupported Pyarrow type: {dtype}")  # pragma: no cover
@@ -167,7 +170,7 @@ def pyarrow2pandas_extension(  # pylint: disable=too-many-branches,too-many-retu
 
 def pyarrow2sqlalchemy(  # pylint: disable=too-many-branches,too-many-return-statements
     dtype: pa.DataType, db_type: str
-) -> VisitableType:
+) -> Optional[VisitableType]:
     """Pyarrow to Athena data types conversion."""
     if pa.types.is_int8(dtype):
         return sqlalchemy.types.SmallInteger
@@ -214,7 +217,7 @@ def pyarrow2sqlalchemy(  # pylint: disable=too-many-branches,too-many-return-sta
     if pa.types.is_dictionary(dtype):
         return pyarrow2sqlalchemy(dtype=dtype.value_type, db_type=db_type)
     if dtype == pa.null():  # pragma: no cover
-        raise exceptions.UndetectedType("We can not infer the data type from an entire null object column")
+        return None
     raise exceptions.UnsupportedType(f"Unsupported Pyarrow type: {dtype}")  # pragma: no cover
 
 
@@ -243,12 +246,23 @@
         else:
             cols.append(name)
 
-    # Filling cols_dtypes and indexes
+    # Filling cols_dtypes
+    for col in cols:
+        _logger.debug("Inferring PyArrow type from column: %s", col)
+        try:
+            schema: pa.Schema = pa.Schema.from_pandas(df=df[[col]], preserve_index=False)
+        except pa.ArrowInvalid as ex:  # pragma: no cover
+            cols_dtypes[col] = process_not_inferred_dtype(ex)
+        else:
+            cols_dtypes[col] = schema.field(col).type
+
+    # Filling indexes
     indexes: List[str] = []
-    for field in pa.Schema.from_pandas(df=df[cols], preserve_index=index):
-        name = str(field.name)
-        cols_dtypes[name] = field.type
-        if (name not in df.columns) and (index is True):
+    if index is True:
+        for field in pa.Schema.from_pandas(df=df[[]], preserve_index=True):
+            name = str(field.name)
+            _logger.debug("Inferring PyArrow type from index: %s", name)
+            cols_dtypes[name] = field.type
             indexes.append(name)
 
     # Merging Index
@@ -261,6 +275,39 @@
     return columns_types
 
 
+def process_not_inferred_dtype(ex: pa.ArrowInvalid) -> pa.DataType:
+    """Infer data type from PyArrow inference exception."""
+    ex_str = str(ex)
+    _logger.debug("PyArrow was not able to infer data type:\n%s", ex_str)
+    match: Optional[Match] = re.search(
+        pattern="Could not convert (.*) with type (.*): did not recognize "
+        "Python value type when inferring an Arrow data type",
+        string=ex_str,
+    )
+    if match is None:
+        raise ex  # pragma: no cover
+    groups: Optional[Sequence[str]] = match.groups()
+    if groups is None:
+        raise ex  # pragma: no cover
+    if len(groups) != 2:
+        raise ex  # pragma: no cover
+    _logger.debug("groups: %s", groups)
+    type_str: str = groups[1]
+    if type_str == "UUID":
+        return pa.string()
+    raise ex  # pragma: no cover
+
+
+def process_not_inferred_array(ex: pa.ArrowInvalid, values: Any) -> pa.Array:
+    """Infer `pyarrow.array` from PyArrow inference exception."""
+    dtype = process_not_inferred_dtype(ex=ex)
+    if dtype == pa.string():
+        array: pa.Array = pa.array(obj=[str(x) for x in values], type=dtype, safe=True)
+    else:
+        raise ex  # pragma: no cover
+    return array
+
+
 def athena_types_from_pandas(
     df: pd.DataFrame, index: bool, dtype: Optional[Dict[str, str]] = None, index_left: bool = False
 ) -> Dict[str, str]:
diff --git a/awswrangler/db.py b/awswrangler/db.py
index b695bdd17..f5e90c78e 100644
--- a/awswrangler/db.py
+++ b/awswrangler/db.py
@@ -185,7 +185,10 @@ def _records2df(
     arrays: List[pa.Array] = []
     for col_values, col_name in zip(tuple(zip(*records)), cols_names):  # Transposing
         if (dtype is None) or (col_name not in dtype):
-            array: pa.Array = pa.array(obj=col_values, safe=True)  # Creating Arrow array
+            try:
+                array: pa.Array = pa.array(obj=col_values, safe=True)  # Creating Arrow array
+            except pa.ArrowInvalid as ex:
+                array = _data_types.process_not_inferred_array(ex, values=col_values)  # Creating Arrow array
         else:
             array = pa.array(obj=col_values, type=dtype[col_name], safe=True)  # Creating Arrow array with dtype
         arrays.append(array)
diff --git a/testing/test_awswrangler/test_data_lake.py b/testing/test_awswrangler/test_data_lake.py
index b05fb0881..99c1df1c6 100644
--- a/testing/test_awswrangler/test_data_lake.py
+++ b/testing/test_awswrangler/test_data_lake.py
@@ -1203,3 +1203,25 @@ def test_athena_encryption(
     assert len(df2.columns) == 2
     wr.catalog.delete_table_if_exists(database=database, table=table)
     wr.s3.delete_objects(path=paths)
+
+
+def test_athena_nested(bucket, database):
+    table = "test_athena_nested"
+    path = f"s3://{bucket}/{table}/"
+    df = pd.DataFrame(
+        {
+            "c0": [[1, 2, 3], [4, 5, 6]],
+            "c1": [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
+            "c2": [[["a", "b"], ["c", "d"]], [["e", "f"], ["g", "h"]]],
+            "c3": [[], [[[[[[[[1]]]]]]]]],
+            "c4": [{"a": 1}, {"a": 1}],
+            "c5": [{"a": {"b": {"c": [1, 2]}}}, {"a": {"b": {"c": [3, 4]}}}],
+        }
+    )
+    paths = wr.s3.to_parquet(
+        df=df, path=path, index=False, use_threads=True, dataset=True, mode="overwrite", database=database, table=table
+    )["paths"]
+    wr.s3.wait_objects_exist(paths=paths)
+    df2 = wr.athena.read_sql_query(sql=f"SELECT c0, c1, c2, c4 FROM {table}", database=database)
+    assert len(df2.index) == 2
+    assert len(df2.columns) == 4
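
Reviewer note (not part of the commit): a minimal sketch of the Athena DDL strings the updated pyarrow2athena mapping is expected to emit for nested columns. The sample columns mirror the new test_athena_nested test; the expected outputs assume int64 maps to bigint, as elsewhere in the module.

# Illustrative sketch only -- assumes awswrangler with this patch applied.
import pandas as pd
import pyarrow as pa

from awswrangler._data_types import pyarrow2athena

df = pd.DataFrame(
    {
        "c0": [[1, 2, 3], [4, 5, 6]],  # list column -> array<...>
        "c4": [{"a": 1}, {"a": 1}],    # dict column -> struct<...>
    }
)
schema = pa.Schema.from_pandas(df=df, preserve_index=False)
for field in schema:
    print(field.name, "->", pyarrow2athena(dtype=field.type))
# Expected, assuming int64 -> bigint:
#   c0 -> array<bigint>
#   c4 -> struct<a:bigint>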
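
Likewise, a rough sketch of the fallback path added to _records2df in db.py: when PyArrow raises ArrowInvalid for values it cannot infer (UUIDs are the case handled by process_not_inferred_dtype), the column is rebuilt as strings. The sample values are hypothetical and only illustrate the control flow with the PyArrow version targeted by this patch.

# Illustrative sketch only -- mirrors the try/except added to db._records2df.
import uuid

import pyarrow as pa

from awswrangler import _data_types

col_values = (uuid.uuid4(), uuid.uuid4())  # PyArrow cannot infer a type for uuid.UUID
try:
    array: pa.Array = pa.array(obj=col_values, safe=True)
except pa.ArrowInvalid as ex:
    # process_not_inferred_dtype() parses the "did not recognize Python value type"
    # message, finds the offending type is UUID, and falls back to pa.string().
    array = _data_types.process_not_inferred_array(ex, values=col_values)
print(array.type)  # expected: string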