diff --git a/awswrangler/s3/_read_parquet.py b/awswrangler/s3/_read_parquet.py
index 6facbc06c..75286d3e3 100644
--- a/awswrangler/s3/_read_parquet.py
+++ b/awswrangler/s3/_read_parquet.py
@@ -238,6 +238,7 @@ def _arrowtable2df(
     table: pa.Table,
     categories: Optional[List[str]],
     safe: bool,
+    map_types: bool,
     use_threads: bool,
     dataset: bool,
     path: str,
@@ -257,7 +258,7 @@ def _arrowtable2df(
             strings_to_categorical=False,
             safe=safe,
             categories=categories,
-            types_mapper=_data_types.pyarrow2pandas_extension,
+            types_mapper=_data_types.pyarrow2pandas_extension if map_types else None,
         ),
         dataset=dataset,
         path=path,
@@ -279,6 +280,7 @@ def _read_parquet_chunked(
     columns: Optional[List[str]],
     categories: Optional[List[str]],
     safe: bool,
+    map_types: bool,
     boto3_session: boto3.Session,
     dataset: bool,
     path_root: Optional[str],
@@ -325,6 +327,7 @@ def _read_parquet_chunked(
                     ),
                     categories=categories,
                     safe=safe,
+                    map_types=map_types,
                     use_threads=use_threads,
                     dataset=dataset,
                     path=path,
@@ -404,6 +407,7 @@ def _read_parquet(
     columns: Optional[List[str]],
     categories: Optional[List[str]],
     safe: bool,
+    map_types: bool,
     boto3_session: boto3.Session,
     dataset: bool,
     path_root: Optional[str],
@@ -421,6 +425,7 @@ def _read_parquet(
         ),
         categories=categories,
         safe=safe,
+        map_types=map_types,
         use_threads=use_threads,
         dataset=dataset,
         path=path,
@@ -441,6 +446,7 @@ def read_parquet(
     dataset: bool = False,
     categories: Optional[List[str]] = None,
     safe: bool = True,
+    map_types: bool = True,
     use_threads: bool = True,
     last_modified_begin: Optional[datetime.datetime] = None,
     last_modified_end: Optional[datetime.datetime] = None,
@@ -524,6 +530,10 @@ def read_parquet(
         data in a pandas DataFrame or Series (e.g. timestamps are always
         stored as nanoseconds in pandas). This option controls whether it
         is a safe cast or not.
+    map_types : bool, default True
+        True to convert pyarrow DataTypes to pandas ExtensionDtypes. It is
+        used to override the default pandas type for conversion of built-in
+        pyarrow types or in absence of pandas_metadata in the Table schema.
     use_threads : bool
         True to enable concurrent requests, False to disable multiple threads.
         If enabled os.cpu_count() will be used as the max number of threads.
@@ -597,6 +607,7 @@ def read_parquet(
         "columns": columns,
         "categories": categories,
         "safe": safe,
+        "map_types": map_types,
         "boto3_session": session,
         "dataset": dataset,
         "path_root": path_root,
@@ -633,6 +644,7 @@ def read_parquet_table(
     validate_schema: bool = True,
     categories: Optional[List[str]] = None,
     safe: bool = True,
+    map_types: bool = True,
     chunked: Union[bool, int] = False,
     use_threads: bool = True,
     boto3_session: Optional[boto3.Session] = None,
@@ -699,6 +711,10 @@ def read_parquet_table(
         data in a pandas DataFrame or Series (e.g. timestamps are always
         stored as nanoseconds in pandas). This option controls whether it
         is a safe cast or not.
+    map_types : bool, default True
+        True to convert pyarrow DataTypes to pandas ExtensionDtypes. It is
+        used to override the default pandas type for conversion of built-in
+        pyarrow types or in absence of pandas_metadata in the Table schema.
     chunked : bool
         If True will break the data in smaller DataFrames (Non deterministic number of lines).
         Otherwise return a single DataFrame with the whole data.
@@ -767,6 +783,7 @@ def read_parquet_table(
         validate_schema=validate_schema,
         categories=categories,
         safe=safe,
+        map_types=map_types,
         chunked=chunked,
         dataset=True,
         use_threads=use_threads,
diff --git a/tests/test_s3_parquet.py b/tests/test_s3_parquet.py
index 5411b9373..a7eeda567 100644
--- a/tests/test_s3_parquet.py
+++ b/tests/test_s3_parquet.py
@@ -192,6 +192,16 @@ def test_to_parquet_file_dtype(path, use_threads):
     assert str(df2.c1.dtype) == "string"
 
 
+def test_read_parquet_map_types(path):
+    df = pd.DataFrame({"c0": [0, 1, 1, 2]}, dtype=np.int8)
+    file_path = f"{path}0.parquet"
+    wr.s3.to_parquet(df, file_path)
+    df2 = wr.s3.read_parquet(file_path)
+    assert str(df2.c0.dtype) == "Int8"
+    df3 = wr.s3.read_parquet(file_path, map_types=False)
+    assert str(df3.c0.dtype) == "int8"
+
+
 @pytest.mark.parametrize("use_threads", [True, False])
 @pytest.mark.parametrize("max_rows_by_file", [None, 0, 40, 250, 1000])
 def test_parquet_with_size(path, use_threads, max_rows_by_file):
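Usage sketch (not part of the patch): the snippet below mirrors the new test to show what the map_types flag changes for callers of wr.s3.read_parquet. The S3 path is a placeholder; substitute a bucket/prefix you control.

    import numpy as np
    import pandas as pd
    import awswrangler as wr

    # Placeholder destination -- replace with your own bucket/prefix.
    path = "s3://my-bucket/map-types-demo/0.parquet"

    # Write a small frame stored with a plain NumPy dtype.
    wr.s3.to_parquet(pd.DataFrame({"c0": [0, 1, 1, 2]}, dtype=np.int8), path)

    # Default behaviour (map_types=True): pyarrow types are mapped to pandas
    # ExtensionDtypes, so the column comes back as the nullable "Int8".
    print(wr.s3.read_parquet(path).c0.dtype)                   # Int8

    # map_types=False skips the types_mapper and keeps pandas' default
    # conversion, i.e. the plain NumPy "int8".
    print(wr.s3.read_parquet(path, map_types=False).c0.dtype)  # int8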