11 changes: 8 additions & 3 deletions awswrangler/data_types.py
@@ -382,24 +382,29 @@ def convert_schema(func: Callable, schema: List[Tuple[str, str]]) -> Dict[str, s

 def extract_pyarrow_schema_from_pandas(dataframe: pd.DataFrame,
                                        preserve_index: bool,
-                                       indexes_position: str = "right") -> List[Tuple[str, Any]]:
+                                       indexes_position: str = "right",
+                                       ignore_cols: Optional[List[str]] = None) -> List[Tuple[str, Any]]:
     """
     Extract the related Pyarrow schema from any Pandas DataFrame.

     :param dataframe: Pandas Dataframe
     :param preserve_index: True or False
     :param indexes_position: "right" or "left"
+    :param ignore_cols: List of columns to be ignored
     :return: Pyarrow schema (e.g. [("col name": "bigint"), ("col2 name": "int")]
     """
+    ignore_cols = [] if ignore_cols is None else ignore_cols
     cols: List[str] = []
-    cols_dtypes: Dict[str, str] = {}
+    cols_dtypes: Dict[str, Optional[str]] = {}
     if indexes_position not in ("right", "left"):
         raise ValueError(f"indexes_position must be \"right\" or \"left\"")

     # Handle exception data types (e.g. Int64, string)
     for name, dtype in dataframe.dtypes.to_dict().items():
         dtype = str(dtype)
-        if dtype == "Int64":
+        if name in ignore_cols:
+            cols_dtypes[name] = None
+        elif dtype == "Int64":
             cols_dtypes[name] = "int64"
         elif dtype == "string":
             cols_dtypes[name] = "string"
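A minimal usage sketch of the new ignore_cols parameter (not part of the diff; the frame and column names are hypothetical, and the exact return shape is an assumption based on the cols_dtypes[name] = None branch above):

    import pandas as pd

    from awswrangler import data_types

    # Hypothetical frame: "col_b" is an all-NA nullable Int64 column, the
    # exact case pyarrow cannot infer a concrete type for.
    df = pd.DataFrame({
        "col_a": [1, 2, 3],
        "col_b": pd.array([pd.NA, pd.NA, pd.NA], dtype="Int64"),
    })

    schema = data_types.extract_pyarrow_schema_from_pandas(
        dataframe=df,
        preserve_index=False,
        indexes_position="right",
        ignore_cols=["col_b"],  # skip inference; the caller supplies a type later
    )
    # Assumed result: "col_b" is paired with None instead of an inferred type,
    # e.g. [("col_a", <int64-ish type>), ("col_b", None)]
    print(schema)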
3 changes: 1 addition & 2 deletions awswrangler/pandas.py
@@ -843,7 +843,6 @@ def _data_to_s3_dataset_writer(dataframe: pd.DataFrame,
                                                                 isolated_dataframe=isolated_dataframe)
             objects_paths.append(object_path)
         else:
-            dataframe = Pandas._cast_pandas(dataframe=dataframe, cast_columns=cast_columns)
             for keys, subgroup in dataframe.groupby(by=partition_cols, observed=True):
                 subgroup = subgroup.drop(partition_cols, axis="columns")
                 if not isinstance(keys, tuple):
@@ -1390,7 +1389,7 @@ def _read_parquet_path(session_primitives: "SessionPrimitives",
             if str(field.type).startswith("int") and field.name != "__index_level_0__"
         ]
         logger.debug(f"Converting to Pandas: {path}")
-        df = table.to_pandas(use_threads=use_threads, integer_object_nulls=True)
+        df = table.to_pandas(use_threads=use_threads, integer_object_nulls=False)
         logger.debug(f"Casting Int64 columns: {path}")
         for c in integers:
             if not str(df[c].dtype).startswith("int"):
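The flipped integer_object_nulls flag changes how pyarrow materializes nullable integer columns: with True they come back as object columns holding Python ints and Nones, while with False (pyarrow's default) they come back as float64 with NaN, which the cast loop above can then promote to pandas' nullable Int64. A standalone sketch of the two behaviors, using only pandas and pyarrow:

    import pyarrow as pa

    table = pa.table({"col": pa.array([1, None], type=pa.int64())})

    # integer_object_nulls=True -> object column holding Python ints and None
    print(table.to_pandas(integer_object_nulls=True)["col"].dtype)   # object

    # integer_object_nulls=False -> float64 column holding 1.0 and NaN
    df = table.to_pandas(integer_object_nulls=False)
    print(df["col"].dtype)                                           # float64

    # The loop above then restores the nullable extension dtype; casting
    # float64-with-NaN to Int64 is safe when the values are whole numbers.
    df["col"] = df["col"].astype("Int64")
    print(df["col"].dtype)                                           # Int64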
4 changes: 3 additions & 1 deletion awswrangler/redshift.py
@@ -431,9 +431,11 @@ def _get_redshift_schema(dataframe,
     varchar_lengths = {} if varchar_lengths is None else varchar_lengths
     schema_built: List[Tuple[str, str]] = []
     if dataframe_type.lower() == "pandas":
+        ignore_cols = list(cast_columns.keys()) if cast_columns is not None else None
         pyarrow_schema = data_types.extract_pyarrow_schema_from_pandas(dataframe=dataframe,
                                                                        preserve_index=preserve_index,
-                                                                       indexes_position="right")
+                                                                       indexes_position="right",
+                                                                       ignore_cols=ignore_cols)
         for name, dtype in pyarrow_schema:
             if (cast_columns is not None) and (name in cast_columns.keys()):
                 schema_built.append((name, cast_columns[name]))
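A dry run of how the two changes interact in _get_redshift_schema (hypothetical values, mirroring the test added further down): columns named in cast_columns are excluded from pyarrow inference via ignore_cols, come back typed as None, and then take the user-supplied Redshift type instead, so an all-NA Int64 column can no longer break inference:

    from typing import List, Optional, Tuple

    # Assumed inputs: "col4" is in cast_columns, so schema extraction
    # skipped it and reported None as its dtype.
    cast_columns = {"col4": "INT8"}
    pyarrow_schema: List[Tuple[str, Optional[str]]] = [("id", "bigint"), ("col4", None)]

    schema_built: List[Tuple[str, str]] = []
    for name, dtype in pyarrow_schema:
        if (cast_columns is not None) and (name in cast_columns.keys()):
            schema_built.append((name, cast_columns[name]))  # user-supplied type wins
        else:
            schema_built.append((name, dtype))

    print(schema_built)  # [('id', 'bigint'), ('col4', 'INT8')]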
16 changes: 16 additions & 0 deletions testing/test_awswrangler/test_pandas.py
@@ -2536,3 +2536,19 @@ def test_sequential_overwrite(bucket):
     df3 = wr.pandas.read_parquet(path=path)
     assert len(df3.index) == 1
     assert df3.col[0] == 2
+
+
+def test_read_parquet_int_na(bucket):
+    path = f"s3://{bucket}/test_read_parquet_int_na/"
+    df = pd.DataFrame({"col": [1] + [pd.NA for _ in range(10_000)]}, dtype="Int64")
+    wr.pandas.to_parquet(
+        dataframe=df,
+        path=path,
+        preserve_index=False,
+        mode="overwrite",
+        procs_cpu_bound=4
+    )
+    df2 = wr.pandas.read_parquet(path=path)
+    assert len(df2.index) == 10_001
+    assert len(df2.columns) == 1
+    assert df2.dtypes["col"] == "Int64"
36 changes: 36 additions & 0 deletions testing/test_awswrangler/test_redshift.py
@@ -911,3 +911,39 @@ def test_to_redshift_spark_varchar(session, bucket, redshift_parameters):
     for row in rows:
         assert len(row) == len(pdf.columns)
     conn.close()
+
+
+def test_to_redshift_int_na(bucket, redshift_parameters):
+    df = pd.DataFrame({
+        "id": [1, 2, 3, 4, 5],
+        "col1": [1, pd.NA, 2, pd.NA, pd.NA],
+        "col2": [pd.NA, pd.NA, pd.NA, pd.NA, pd.NA],
+        "col3": [None, None, None, None, None],
+        "col4": [1, pd.NA, 2, pd.NA, pd.NA]
+    })
+    df["col1"] = df["col1"].astype("Int64")
+    df["col2"] = df["col2"].astype("Int64")
+    df["col3"] = df["col3"].astype("Int64")
+    path = f"s3://{bucket}/test_to_redshift_int_na"
+    wr.pandas.to_redshift(dataframe=df,
+                          path=path,
+                          schema="public",
+                          table="test_to_redshift_int_na",
+                          connection="aws-data-wrangler-redshift",
+                          iam_role=redshift_parameters.get("RedshiftRole"),
+                          mode="overwrite",
+                          preserve_index=False,
+                          cast_columns={
+                              "col4": "INT8"
+                          })
+    conn = wr.glue.get_connection("aws-data-wrangler-redshift")
+    with conn.cursor() as cursor:
+        cursor.execute("SELECT * FROM public.test_to_redshift_int_na")
+        rows = cursor.fetchall()
+        assert len(rows) == len(df.index)
+        for row in rows:
+            assert len(row) == len(df.columns)
+        cursor.execute("SELECT SUM(col1) FROM public.test_to_redshift_int_na")
+        rows = cursor.fetchall()
+        assert rows[0][0] == 3
+    conn.close()