Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -555,10 +555,10 @@ max-attributes=7
max-bool-expr=5

# Maximum number of branches for function / method body.
max-branches=12
max-branches=15

# Maximum number of locals for function / method body.
max-locals=25
max-locals=30

# Maximum number of parents for a class (see R0901).
max-parents=7
Expand Down
31 changes: 22 additions & 9 deletions awswrangler/_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def pyarrow2sqlalchemy( # pylint: disable=too-many-branches,too-many-return-sta


def pyarrow_types_from_pandas(
df: pd.DataFrame, index: bool, ignore_cols: Optional[List[str]] = None
df: pd.DataFrame, index: bool, ignore_cols: Optional[List[str]] = None, index_left: bool = False
) -> Dict[str, pa.DataType]:
"""Extract the related Pyarrow data types from any Pandas DataFrame."""
# Handle exception data types (e.g. Int64, Int32, string)
Expand Down Expand Up @@ -251,18 +251,23 @@ def pyarrow_types_from_pandas(
if (name not in df.columns) and (index is True):
indexes.append(name)

# Merging Index
sorted_cols: List[str] = indexes + list(df.columns) if index_left is True else list(df.columns) + indexes

# Filling schema
columns_types: Dict[str, pa.DataType]
columns_types = {n: cols_dtypes[n] for n in list(df.columns) + indexes} # add cols + indexes
columns_types = {n: cols_dtypes[n] for n in sorted_cols}
_logger.debug(f"columns_types: {columns_types}")
return columns_types


def athena_types_from_pandas(df: pd.DataFrame, index: bool, dtype: Optional[Dict[str, str]] = None) -> Dict[str, str]:
def athena_types_from_pandas(
df: pd.DataFrame, index: bool, dtype: Optional[Dict[str, str]] = None, index_left: bool = False
) -> Dict[str, str]:
"""Extract the related Athena data types from any Pandas DataFrame."""
casts: Dict[str, str] = dtype if dtype else {}
pa_columns_types: Dict[str, Optional[pa.DataType]] = pyarrow_types_from_pandas(
df=df, index=index, ignore_cols=list(casts.keys())
df=df, index=index, ignore_cols=list(casts.keys()), index_left=index_left
)
athena_columns_types: Dict[str, str] = {}
for k, v in pa_columns_types.items():
Expand All @@ -275,11 +280,17 @@ def athena_types_from_pandas(df: pd.DataFrame, index: bool, dtype: Optional[Dict


def athena_types_from_pandas_partitioned(
df: pd.DataFrame, index: bool, partition_cols: Optional[List[str]] = None, dtype: Optional[Dict[str, str]] = None
df: pd.DataFrame,
index: bool,
partition_cols: Optional[List[str]] = None,
dtype: Optional[Dict[str, str]] = None,
index_left: bool = False,
) -> Tuple[Dict[str, str], Dict[str, str]]:
"""Extract the related Athena data types from any Pandas DataFrame considering possible partitions."""
partitions: List[str] = partition_cols if partition_cols else []
athena_columns_types: Dict[str, str] = athena_types_from_pandas(df=df, index=index, dtype=dtype)
athena_columns_types: Dict[str, str] = athena_types_from_pandas(
df=df, index=index, dtype=dtype, index_left=index_left
)
columns_types: Dict[str, str] = {}
partitions_types: Dict[str, str] = {}
for k, v in athena_columns_types.items():
Expand All @@ -296,10 +307,12 @@ def pyarrow_schema_from_pandas(
"""Extract the related Pyarrow Schema from any Pandas DataFrame."""
casts: Dict[str, str] = {} if dtype is None else dtype
ignore: List[str] = [] if ignore_cols is None else ignore_cols
ignore = ignore + list(casts.keys())
columns_types: Dict[str, Optional[pa.DataType]] = pyarrow_types_from_pandas(df=df, index=index, ignore_cols=ignore)
ignore_plus = ignore + list(casts.keys())
columns_types: Dict[str, Optional[pa.DataType]] = pyarrow_types_from_pandas(
df=df, index=index, ignore_cols=ignore_plus
)
for k, v in casts.items():
if k in df.columns:
if (k in df.columns) and (k not in ignore):
columns_types[k] = athena2pyarrow(v)
columns_types = {k: v for k, v in columns_types.items() if v is not None}
_logger.debug(f"columns_types: {columns_types}")
Expand Down
Loading