From 3161a50723cc5db0150e4811fe1012bd3b936692 Mon Sep 17 00:00:00 2001
From: igorborgest
Date: Sun, 16 Feb 2020 16:45:27 -0300
Subject: [PATCH] Pandas.read_parquet() now returns Int64 for integer columns
 with null values, and Pandas.to_redshift() can cast them.

---
 awswrangler/data_types.py                 | 11 +++++--
 awswrangler/pandas.py                     |  3 +-
 awswrangler/redshift.py                   |  4 ++-
 testing/test_awswrangler/test_pandas.py   | 16 ++++++++++
 testing/test_awswrangler/test_redshift.py | 36 +++++++++++++++++++++++
 5 files changed, 64 insertions(+), 6 deletions(-)

diff --git a/awswrangler/data_types.py b/awswrangler/data_types.py
index 22e8bb2d3..d1017e9c2 100644
--- a/awswrangler/data_types.py
+++ b/awswrangler/data_types.py
@@ -382,24 +382,29 @@ def convert_schema(func: Callable, schema: List[Tuple[str, str]]) -> Dict[str, str]:
 
 def extract_pyarrow_schema_from_pandas(dataframe: pd.DataFrame,
                                        preserve_index: bool,
-                                       indexes_position: str = "right") -> List[Tuple[str, Any]]:
+                                       indexes_position: str = "right",
+                                       ignore_cols: Optional[List[str]] = None) -> List[Tuple[str, Any]]:
     """
     Extract the related Pyarrow schema from any Pandas DataFrame.
 
     :param dataframe: Pandas Dataframe
     :param preserve_index: True or False
     :param indexes_position: "right" or "left"
+    :param ignore_cols: List of columns to be ignored
     :return: Pyarrow schema (e.g. [("col name", "bigint"), ("col2 name", "int")])
     """
+    ignore_cols = [] if ignore_cols is None else ignore_cols
     cols: List[str] = []
-    cols_dtypes: Dict[str, str] = {}
+    cols_dtypes: Dict[str, Optional[str]] = {}
     if indexes_position not in ("right", "left"):
         raise ValueError(f"indexes_position must be \"right\" or \"left\"")
 
     # Handle exception data types (e.g. Int64, string)
     for name, dtype in dataframe.dtypes.to_dict().items():
         dtype = str(dtype)
-        if dtype == "Int64":
+        if name in ignore_cols:
+            cols_dtypes[name] = None
+        elif dtype == "Int64":
            cols_dtypes[name] = "int64"
         elif dtype == "string":
             cols_dtypes[name] = "string"
diff --git a/awswrangler/pandas.py b/awswrangler/pandas.py
index 159c665b2..c6304b520 100644
--- a/awswrangler/pandas.py
+++ b/awswrangler/pandas.py
@@ -843,7 +843,6 @@ def _data_to_s3_dataset_writer(dataframe: pd.DataFrame,
                                                                 isolated_dataframe=isolated_dataframe)
             objects_paths.append(object_path)
         else:
-            dataframe = Pandas._cast_pandas(dataframe=dataframe, cast_columns=cast_columns)
             for keys, subgroup in dataframe.groupby(by=partition_cols, observed=True):
                 subgroup = subgroup.drop(partition_cols, axis="columns")
                 if not isinstance(keys, tuple):
@@ -1390,7 +1389,7 @@ def _read_parquet_path(session_primitives: "SessionPrimitives",
             if str(field.type).startswith("int") and field.name != "__index_level_0__"
         ]
         logger.debug(f"Converting to Pandas: {path}")
-        df = table.to_pandas(use_threads=use_threads, integer_object_nulls=True)
+        df = table.to_pandas(use_threads=use_threads, integer_object_nulls=False)
         logger.debug(f"Casting Int64 columns: {path}")
         for c in integers:
             if not str(df[c].dtype).startswith("int"):
diff --git a/awswrangler/redshift.py b/awswrangler/redshift.py
index 3d45b90df..4dae6fb2e 100644
--- a/awswrangler/redshift.py
+++ b/awswrangler/redshift.py
@@ -431,9 +431,11 @@ def _get_redshift_schema(dataframe,
     varchar_lengths = {} if varchar_lengths is None else varchar_lengths
     schema_built: List[Tuple[str, str]] = []
     if dataframe_type.lower() == "pandas":
+        ignore_cols = list(cast_columns.keys()) if cast_columns is not None else None
         pyarrow_schema = data_types.extract_pyarrow_schema_from_pandas(dataframe=dataframe,
                                                                        preserve_index=preserve_index,
indexes_position="right") + indexes_position="right", + ignore_cols=ignore_cols) for name, dtype in pyarrow_schema: if (cast_columns is not None) and (name in cast_columns.keys()): schema_built.append((name, cast_columns[name])) diff --git a/testing/test_awswrangler/test_pandas.py b/testing/test_awswrangler/test_pandas.py index 24c63d3b6..142380712 100644 --- a/testing/test_awswrangler/test_pandas.py +++ b/testing/test_awswrangler/test_pandas.py @@ -2536,3 +2536,19 @@ def test_sequential_overwrite(bucket): df3 = wr.pandas.read_parquet(path=path) assert len(df3.index) == 1 assert df3.col[0] == 2 + + +def test_read_parquet_int_na(bucket): + path = f"s3://{bucket}/test_read_parquet_int_na/" + df = pd.DataFrame({"col": [1] + [pd.NA for _ in range(10_000)]}, dtype="Int64") + wr.pandas.to_parquet( + dataframe=df, + path=path, + preserve_index=False, + mode="overwrite", + procs_cpu_bound=4 + ) + df2 = wr.pandas.read_parquet(path=path) + assert len(df2.index) == 10_001 + assert len(df2.columns) == 1 + assert df2.dtypes["col"] == "Int64" diff --git a/testing/test_awswrangler/test_redshift.py b/testing/test_awswrangler/test_redshift.py index a409a8364..a8836ea34 100644 --- a/testing/test_awswrangler/test_redshift.py +++ b/testing/test_awswrangler/test_redshift.py @@ -911,3 +911,39 @@ def test_to_redshift_spark_varchar(session, bucket, redshift_parameters): for row in rows: assert len(row) == len(pdf.columns) conn.close() + + +def test_to_redshift_int_na(bucket, redshift_parameters): + df = pd.DataFrame({ + "id": [1, 2, 3, 4, 5], + "col1": [1, pd.NA, 2, pd.NA, pd.NA], + "col2": [pd.NA, pd.NA, pd.NA, pd.NA, pd.NA], + "col3": [None, None, None, None, None], + "col4": [1, pd.NA, 2, pd.NA, pd.NA] + }) + df["col1"] = df["col1"].astype("Int64") + df["col2"] = df["col2"].astype("Int64") + df["col3"] = df["col3"].astype("Int64") + path = f"s3://{bucket}/test_to_redshift_int_na" + wr.pandas.to_redshift(dataframe=df, + path=path, + schema="public", + table="test_to_redshift_int_na", + connection="aws-data-wrangler-redshift", + iam_role=redshift_parameters.get("RedshiftRole"), + mode="overwrite", + preserve_index=False, + cast_columns={ + "col4": "INT8" + }) + conn = wr.glue.get_connection("aws-data-wrangler-redshift") + with conn.cursor() as cursor: + cursor.execute("SELECT * FROM public.test_to_redshift_int_na") + rows = cursor.fetchall() + assert len(rows) == len(df.index) + for row in rows: + assert len(row) == len(df.columns) + cursor.execute("SELECT SUM(col1) FROM public.test_to_redshift_int_na") + rows = cursor.fetchall() + assert rows[0][0] == 3 + conn.close()