2 changes: 2 additions & 0 deletions .gitignore
@@ -138,6 +138,8 @@ testing/*parameters-*.properties
testing/*requirements*.txt
testing/coverage/*
building/*requirements*.txt
building/arrow
building/lambda/arrow
/docs/coverage/
/docs/build/
/docs/source/_build/
67 changes: 57 additions & 10 deletions awswrangler/_data_types.py
@@ -1,8 +1,9 @@
"""Internal (private) Data Types Module."""

import logging
import re
from decimal import Decimal
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, List, Match, Optional, Sequence, Tuple

import pandas as pd # type: ignore
import pyarrow as pa # type: ignore
@@ -139,8 +140,10 @@ def pyarrow2athena(dtype: pa.DataType) -> str: # pylint: disable=too-many-branc
return f"decimal({dtype.precision},{dtype.scale})"
if pa.types.is_list(dtype):
return f"array<{pyarrow2athena(dtype=dtype.value_type)}>"
if pa.types.is_struct(dtype): # pragma: no cover
return f"struct<{', '.join([f'{f.name}: {pyarrow2athena(dtype=f.type)}' for f in dtype])}>"
if pa.types.is_struct(dtype):
return f"struct<{', '.join([f'{f.name}:{pyarrow2athena(dtype=f.type)}' for f in dtype])}>"
if pa.types.is_map(dtype): # pragma: no cover
return f"map<{pyarrow2athena(dtype=dtype.key_type)},{pyarrow2athena(dtype=dtype.item_type)}>"
if dtype == pa.null():
raise exceptions.UndetectedType("We can not infer the data type from an entire null object column")
raise exceptions.UnsupportedType(f"Unsupported Pyarrow type: {dtype}") # pragma: no cover
@@ -167,7 +170,7 @@ def pyarrow2pandas_extension( # pylint: disable=too-many-branches,too-many-retu

def pyarrow2sqlalchemy( # pylint: disable=too-many-branches,too-many-return-statements
dtype: pa.DataType, db_type: str
) -> VisitableType:
) -> Optional[VisitableType]:
"""Pyarrow to Athena data types conversion."""
if pa.types.is_int8(dtype):
return sqlalchemy.types.SmallInteger
@@ -214,7 +217,7 @@ def pyarrow2sqlalchemy( # pylint: disable=too-many-branches,too-many-return-sta
if pa.types.is_dictionary(dtype):
return pyarrow2sqlalchemy(dtype=dtype.value_type, db_type=db_type)
if dtype == pa.null(): # pragma: no cover
raise exceptions.UndetectedType("We can not infer the data type from an entire null object column")
return None
raise exceptions.UnsupportedType(f"Unsupported Pyarrow type: {dtype}") # pragma: no cover


@@ -243,12 +246,23 @@ def pyarrow_types_from_pandas(
else:
cols.append(name)

# Filling cols_dtypes and indexes
# Filling cols_dtypes
for col in cols:
_logger.debug("Inferring PyArrow type from column: %s", col)
try:
schema: pa.Schema = pa.Schema.from_pandas(df=df[[col]], preserve_index=False)
except pa.ArrowInvalid as ex: # pragma: no cover
cols_dtypes[col] = process_not_inferred_dtype(ex)
else:
cols_dtypes[col] = schema.field(col).type

# Filling indexes
indexes: List[str] = []
for field in pa.Schema.from_pandas(df=df[cols], preserve_index=index):
name = str(field.name)
cols_dtypes[name] = field.type
if (name not in df.columns) and (index is True):
if index is True:
for field in pa.Schema.from_pandas(df=df[[]], preserve_index=True):
name = str(field.name)
_logger.debug("Inferring PyArrow type from index: %s", name)
cols_dtypes[name] = field.type
indexes.append(name)

# Merging Index
@@ -261,6 +275,39 @@
return columns_types


def process_not_inferred_dtype(ex: pa.ArrowInvalid) -> pa.DataType:
"""Infer data type from PyArrow inference exception."""
ex_str = str(ex)
_logger.debug("PyArrow was not able to infer data type:\n%s", ex_str)
match: Optional[Match] = re.search(
pattern="Could not convert (.*) with type (.*): did not recognize "
"Python value type when inferring an Arrow data type",
string=ex_str,
)
if match is None:
raise ex # pragma: no cover
groups: Optional[Sequence[str]] = match.groups()
if groups is None:
raise ex # pragma: no cover
if len(groups) != 2:
raise ex # pragma: no cover
_logger.debug("groups: %s", groups)
type_str: str = groups[1]
if type_str == "UUID":
return pa.string()
raise ex # pragma: no cover


def process_not_inferred_array(ex: pa.ArrowInvalid, values: Any) -> pa.Array:
"""Infer `pyarrow.array` from PyArrow inference exception."""
dtype = process_not_inferred_dtype(ex=ex)
if dtype == pa.string():
array: pa.Array = pa.array(obj=[str(x) for x in values], type=dtype, safe=True)
else:
raise ex # pragma: no cover
return array


def athena_types_from_pandas(
df: pd.DataFrame, index: bool, dtype: Optional[Dict[str, str]] = None, index_left: bool = False
) -> Dict[str, str]:
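For context, a minimal standalone sketch (not the library code; the helper name to_athena and its reduced type coverage are illustrative only) of how the new struct and map branches translate nested PyArrow types into Athena DDL strings:

import pyarrow as pa

def to_athena(dtype: pa.DataType) -> str:
    # Reduced mapping covering only the types exercised below.
    if pa.types.is_int64(dtype):
        return "bigint"
    if pa.types.is_string(dtype):
        return "string"
    if pa.types.is_list(dtype):
        return f"array<{to_athena(dtype.value_type)}>"
    if pa.types.is_struct(dtype):
        return f"struct<{', '.join(f'{f.name}:{to_athena(f.type)}' for f in dtype)}>"
    if pa.types.is_map(dtype):
        return f"map<{to_athena(dtype.key_type)},{to_athena(dtype.item_type)}>"
    raise ValueError(f"Unsupported type: {dtype}")

print(to_athena(pa.struct([("a", pa.int64())])))            # struct<a:bigint>
print(to_athena(pa.map_(pa.string(), pa.int64())))          # map<string,bigint>
print(to_athena(pa.list_(pa.struct([("a", pa.int64())]))))  # array<struct<a:bigint>>

The field name and type are joined without a space in the struct branch, matching the change in the hunk above.
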
5 changes: 4 additions & 1 deletion awswrangler/db.py
@@ -185,7 +185,10 @@ def _records2df(
arrays: List[pa.Array] = []
for col_values, col_name in zip(tuple(zip(*records)), cols_names): # Transposing
if (dtype is None) or (col_name not in dtype):
array: pa.Array = pa.array(obj=col_values, safe=True) # Creating Arrow array
try:
array: pa.Array = pa.array(obj=col_values, safe=True) # Creating Arrow array
except pa.ArrowInvalid as ex:
array = _data_types.process_not_inferred_array(ex, values=col_values) # Creating Arrow array
else:
array = pa.array(obj=col_values, type=dtype[col_name], safe=True) # Creating Arrow array with dtype
arrays.append(array)
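A small, self-contained illustration (hypothetical; not taken from the PR) of the failure the new try/except guards against, and the string fallback that process_not_inferred_array applies, assuming a result column made of uuid.UUID values:

import uuid

import pyarrow as pa

col_values = [uuid.uuid4(), uuid.uuid4()]
try:
    array = pa.array(obj=col_values, safe=True)  # ArrowInvalid: did not recognize Python value type (UUID)
except pa.ArrowInvalid:
    # Same idea as process_not_inferred_array: cast each value to str and build a string array.
    array = pa.array(obj=[str(x) for x in col_values], type=pa.string(), safe=True)

print(array.type)  # string
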
22 changes: 22 additions & 0 deletions testing/test_awswrangler/test_data_lake.py
@@ -1203,3 +1203,25 @@ def test_athena_encryption(
assert len(df2.columns) == 2
wr.catalog.delete_table_if_exists(database=database, table=table)
wr.s3.delete_objects(path=paths)


def test_athena_nested(bucket, database):
table = "test_athena_nested"
path = f"s3://{bucket}/{table}/"
df = pd.DataFrame(
{
"c0": [[1, 2, 3], [4, 5, 6]],
"c1": [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
"c2": [[["a", "b"], ["c", "d"]], [["e", "f"], ["g", "h"]]],
"c3": [[], [[[[[[[[1]]]]]]]]],
"c4": [{"a": 1}, {"a": 1}],
"c5": [{"a": {"b": {"c": [1, 2]}}}, {"a": {"b": {"c": [3, 4]}}}],
}
)
paths = wr.s3.to_parquet(
df=df, path=path, index=False, use_threads=True, dataset=True, mode="overwrite", database=database, table=table
)["paths"]
wr.s3.wait_objects_exist(paths=paths)
df2 = wr.athena.read_sql_query(sql=f"SELECT c0, c1, c2, c4 FROM {table}", database=database)
assert len(df2.index) == 2
assert len(df2.columns) == 4
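As a side note (not part of the test), the PyArrow schema inferred for such nested columns can be inspected directly before anything is written; roughly, list columns like c0-c2 should surface in Athena as array<...> types and dict columns like c4/c5 as struct<...> types:

import pandas as pd
import pyarrow as pa

df = pd.DataFrame({"c0": [[1, 2, 3], [4, 5, 6]], "c4": [{"a": 1}, {"a": 1}]})
print(pa.Schema.from_pandas(df, preserve_index=False))
# Expected fields (approximately):
#   c0: list<item: int64>   -> Athena array<bigint>
#   c4: struct<a: int64>    -> Athena struct<a:bigint>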