diff --git a/awswrangler/s3/_read_parquet.py b/awswrangler/s3/_read_parquet.py index e46a5a5a7..00d037d7c 100644 --- a/awswrangler/s3/_read_parquet.py +++ b/awswrangler/s3/_read_parquet.py @@ -2,6 +2,7 @@ import concurrent.futures import datetime +import functools import itertools import json import logging @@ -339,8 +340,8 @@ def _read_parquet_chunked( if next_slice is not None: df = _union(dfs=[next_slice, df], ignore_index=ignore_index) while len(df.index) >= chunked: - yield df.iloc[:chunked] - df = df.iloc[chunked:] + yield df.iloc[:chunked, :].copy() + df = df.iloc[chunked:, :] if df.empty: next_slice = None else: @@ -773,26 +774,32 @@ def read_parquet_table( path: str = res["Table"]["StorageDescriptor"]["Location"] except KeyError as ex: raise exceptions.InvalidTable(f"Missing s3 location for {database}.{table}.") from ex - return _data_types.cast_pandas_with_athena_types( - df=read_parquet( - path=path, - path_suffix=filename_suffix, - path_ignore_suffix=filename_ignore_suffix, - partition_filter=partition_filter, - columns=columns, - validate_schema=validate_schema, - categories=categories, - safe=safe, - map_types=map_types, - chunked=chunked, - dataset=True, - use_threads=use_threads, - boto3_session=boto3_session, - s3_additional_kwargs=s3_additional_kwargs, - ), - dtype=_extract_partitions_dtypes_from_table_details(response=res), + df = read_parquet( + path=path, + path_suffix=filename_suffix, + path_ignore_suffix=filename_ignore_suffix, + partition_filter=partition_filter, + columns=columns, + validate_schema=validate_schema, + categories=categories, + safe=safe, + map_types=map_types, + chunked=chunked, + dataset=True, + use_threads=use_threads, + boto3_session=boto3_session, + s3_additional_kwargs=s3_additional_kwargs, + ) + partial_cast_function = functools.partial( + _data_types.cast_pandas_with_athena_types, dtype=_extract_partitions_dtypes_from_table_details(response=res) ) + if isinstance(df, pd.DataFrame): + return partial_cast_function(df) + + # df is a generator, so map is needed for casting dtypes + return map(partial_cast_function, df) + @apply_configs def read_parquet_metadata(