16 changes: 12 additions & 4 deletions awswrangler/athena.py
@@ -259,7 +259,7 @@ def _extract_ctas_manifest_paths(path: str, boto3_session: Optional[boto3.Sessio


def _get_query_metadata(
query_execution_id: str, boto3_session: Optional[boto3.Session] = None
query_execution_id: str, categories: List[str] = None, boto3_session: Optional[boto3.Session] = None
) -> Tuple[Dict[str, str], List[str], List[str], Dict[str, Any], List[str]]:
"""Get query metadata."""
cols_types: Dict[str, str] = get_query_columns_types(
@@ -285,7 +285,9 @@ def _get_query_metadata(
"Please use ctas_approach=True for Struct columns."
)
pandas_type: str = _data_types.athena2pandas(dtype=col_type)
if pandas_type in ["datetime64", "date"]:
if (categories is not None) and (col_name in categories):
dtype[col_name] = "category"
elif pandas_type in ["datetime64", "date"]:
parse_timestamps.append(col_name)
if pandas_type == "date":
parse_dates.append(col_name)
@@ -326,6 +328,7 @@ def read_sql_query(  # pylint: disable=too-many-branches,too-many-locals
sql: str,
database: str,
ctas_approach: bool = True,
categories: List[str] = None,
chunksize: Optional[int] = None,
s3_output: Optional[str] = None,
workgroup: Optional[str] = None,
@@ -377,6 +380,9 @@ def read_sql_query(  # pylint: disable=too-many-branches,too-many-locals
ctas_approach: bool
Wraps the query using a CTAS, and reads the resulting parquet data on S3.
If false, reads the regular CSV on S3.
categories: List[str], optional
List of column names that should be returned as pandas.Categorical.
Recommended for memory-restricted environments.
chunksize: int, optional
If specified, return a generator where chunksize is the number of rows to include in each chunk.
s3_output : str, optional
@@ -457,10 +463,12 @@ def read_sql_query(  # pylint: disable=too-many-branches,too-many-locals
dfs = _utils.empty_generator()
else:
s3.wait_objects_exist(paths=paths, use_threads=False, boto3_session=session)
dfs = s3.read_parquet(path=paths, use_threads=use_threads, boto3_session=session, chunked=chunked)
dfs = s3.read_parquet(
path=paths, use_threads=use_threads, boto3_session=session, chunked=chunked, categories=categories
)
return dfs
dtype, parse_timestamps, parse_dates, converters, binaries = _get_query_metadata(
query_execution_id=query_id, boto3_session=session
query_execution_id=query_id, categories=categories, boto3_session=session
)
path = f"{_s3_output}{query_id}.csv"
s3.wait_objects_exist(paths=[path], use_threads=False, boto3_session=session)
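
For context, a minimal usage sketch of the `categories` argument added to `wr.athena.read_sql_query` above; the database, table, and column names are hypothetical placeholders, not part of this change:

```python
import awswrangler as wr

# Hypothetical database/table/column names, for illustration only.
df = wr.athena.read_sql_query(
    sql="SELECT id, status, value FROM my_table",
    database="my_database",
    ctas_approach=True,     # read back the CTAS parquet output
    categories=["status"],  # "status" is returned as pandas.Categorical
)
print(df["status"].dtype)   # expected: category
```

With `ctas_approach=False`, the same list is forwarded to `_get_query_metadata`, which maps the listed columns to the `category` dtype before the CSV result is parsed.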
8 changes: 8 additions & 0 deletions awswrangler/db.py
@@ -887,6 +887,7 @@ def unload_redshift(
path: str,
con: sqlalchemy.engine.Engine,
iam_role: str,
categories: List[str] = None,
chunked: bool = False,
keep_files: bool = False,
use_threads: bool = True,
@@ -920,6 +921,9 @@ def unload_redshift(
wr.db.get_engine(), wr.db.get_redshift_temp_engine() or wr.catalog.get_engine()
iam_role : str
AWS IAM role with the related permissions.
categories: List[str], optional
List of column names that should be returned as pandas.Categorical.
Recommended for memory-restricted environments.
keep_files : bool
Should keep the stage files?
chunked : bool
@@ -960,6 +964,7 @@ def unload_redshift(
return pd.DataFrame()
df: pd.DataFrame = s3.read_parquet(
path=paths,
categories=categories,
chunked=chunked,
dataset=False,
use_threads=use_threads,
@@ -973,6 +978,7 @@ def unload_redshift(
return _utils.empty_generator()
return _read_parquet_iterator(
paths=paths,
categories=categories,
use_threads=use_threads,
boto3_session=session,
s3_additional_kwargs=s3_additional_kwargs,
@@ -984,11 +990,13 @@ def _read_parquet_iterator(
paths: List[str],
keep_files: bool,
use_threads: bool,
categories: List[str] = None,
boto3_session: Optional[boto3.Session] = None,
s3_additional_kwargs: Optional[Dict[str, str]] = None,
) -> Iterator[pd.DataFrame]:
dfs: Iterator[pd.DataFrame] = s3.read_parquet(
path=paths,
categories=categories,
chunked=True,
dataset=False,
use_threads=use_threads,
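
A hedged usage sketch of the new `categories` argument on `wr.db.unload_redshift`; the Glue connection name, IAM role ARN, and S3 staging path below are placeholders, not values from this PR:

```python
import awswrangler as wr

# Hypothetical connection name, role ARN, and staging path, for illustration only.
engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift")
df = wr.db.unload_redshift(
    sql="SELECT * FROM public.my_table",
    con=engine,
    iam_role="arn:aws:iam::123456789012:role/my-unload-role",
    path="s3://my-bucket/unload-staging/",
    categories=["status"],  # returned as pandas.Categorical
    keep_files=False,
)
```

With `chunked=True` the same list is forwarded to `_read_parquet_iterator`, so each yielded DataFrame carries the categorical dtypes as well.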
45 changes: 24 additions & 21 deletions awswrangler/s3.py
@@ -833,7 +833,7 @@ def _to_parquet_file(
fs: s3fs.S3FileSystem,
dtype: Dict[str, str],
) -> str:
table: pa.Table = pyarrow.Table.from_pandas(df=df, schema=schema, nthreads=cpus, preserve_index=index, safe=False)
table: pa.Table = pyarrow.Table.from_pandas(df=df, schema=schema, nthreads=cpus, preserve_index=index, safe=True)
for col_name, col_type in dtype.items():
if col_name in table.column_names:
col_index = table.column_names.index(col_name)
@@ -1190,6 +1190,7 @@ def _read_text_full(
def _read_parquet_init(
path: Union[str, List[str]],
filters: Optional[Union[List[Tuple], List[List[Tuple]]]] = None,
categories: List[str] = None,
dataset: bool = False,
use_threads: bool = True,
boto3_session: Optional[boto3.Session] = None,
@@ -1206,7 +1207,7 @@ def _read_parquet_init(
fs: s3fs.S3FileSystem = _utils.get_fs(session=boto3_session, s3_additional_kwargs=s3_additional_kwargs)
cpus: int = _utils.ensure_cpu_count(use_threads=use_threads)
data: pyarrow.parquet.ParquetDataset = pyarrow.parquet.ParquetDataset(
path_or_paths=path_or_paths, filesystem=fs, metadata_nthreads=cpus, filters=filters
path_or_paths=path_or_paths, filesystem=fs, metadata_nthreads=cpus, filters=filters, read_dictionary=categories
)
return data

@@ -1217,6 +1218,7 @@ def read_parquet(
columns: Optional[List[str]] = None,
chunked: bool = False,
dataset: bool = False,
categories: List[str] = None,
use_threads: bool = True,
boto3_session: Optional[boto3.Session] = None,
s3_additional_kwargs: Optional[Dict[str, str]] = None,
@@ -1243,6 +1245,9 @@ def read_parquet(
Otherwise return a single DataFrame with the whole data.
dataset: bool
If True read a parquet dataset instead of simple file(s) loading all the related partitions as columns.
categories: List[str], optional
List of column names that should be returned as pandas.Categorical.
Recommended for memory-restricted environments.
use_threads : bool
True to enable concurrent requests, False to disable multiple threads.
If enabled os.cpu_count() will be used as the max number of threads.
@@ -1292,66 +1297,59 @@
path=path,
filters=filters,
dataset=dataset,
categories=categories,
use_threads=use_threads,
boto3_session=boto3_session,
s3_additional_kwargs=s3_additional_kwargs,
)
common_metadata = data.common_metadata
common_metadata = None if common_metadata is None else common_metadata.metadata.get(b"pandas", None)
if chunked is False:
return _read_parquet(data=data, columns=columns, use_threads=use_threads, common_metadata=common_metadata)
return _read_parquet_chunked(data=data, columns=columns, use_threads=use_threads, common_metadata=common_metadata)
return _read_parquet(data=data, columns=columns, categories=categories, use_threads=use_threads)
return _read_parquet_chunked(data=data, columns=columns, categories=categories, use_threads=use_threads)


def _read_parquet(
data: pyarrow.parquet.ParquetDataset,
columns: Optional[List[str]] = None,
categories: List[str] = None,
use_threads: bool = True,
common_metadata: Any = None,
) -> pd.DataFrame:
# Data
tables: List[pa.Table] = []
for piece in data.pieces:
table: pa.Table = piece.read(
columns=columns, use_threads=use_threads, partitions=data.partitions, use_pandas_metadata=True
columns=columns, use_threads=use_threads, partitions=data.partitions, use_pandas_metadata=False
)
tables.append(table)
table = pa.lib.concat_tables(tables)

# Metadata
current_metadata = table.schema.metadata or {}
if common_metadata and b"pandas" not in current_metadata: # pragma: no cover
table = table.replace_schema_metadata({b"pandas": common_metadata})

return table.to_pandas(
use_threads=use_threads,
split_blocks=True,
self_destruct=True,
integer_object_nulls=False,
date_as_object=True,
ignore_metadata=True,
categories=categories,
types_mapper=_data_types.pyarrow2pandas_extension,
)


def _read_parquet_chunked(
data: pyarrow.parquet.ParquetDataset,
columns: Optional[List[str]] = None,
categories: List[str] = None,
use_threads: bool = True,
common_metadata: Any = None,
) -> Iterator[pd.DataFrame]:
for piece in data.pieces:
table: pa.Table = piece.read(
columns=columns, use_threads=use_threads, partitions=data.partitions, use_pandas_metadata=True
columns=columns, use_threads=use_threads, partitions=data.partitions, use_pandas_metadata=False
)
current_metadata = table.schema.metadata or {}
if common_metadata and b"pandas" not in current_metadata: # pragma: no cover
table = table.replace_schema_metadata({b"pandas": common_metadata})
yield table.to_pandas(
use_threads=use_threads,
split_blocks=True,
self_destruct=True,
integer_object_nulls=False,
date_as_object=True,
ignore_metadata=True,
categories=categories,
types_mapper=_data_types.pyarrow2pandas_extension,
)

@@ -1670,6 +1668,7 @@ def read_parquet_table(
database: str,
filters: Optional[Union[List[Tuple], List[List[Tuple]]]] = None,
columns: Optional[List[str]] = None,
categories: List[str] = None,
chunked: bool = False,
use_threads: bool = True,
boto3_session: Optional[boto3.Session] = None,
Expand All @@ -1690,7 +1689,10 @@ def read_parquet_table(
filters: Union[List[Tuple], List[List[Tuple]]], optional
List of filters to apply, like ``[[('x', '=', 0), ...], ...]``.
columns : List[str], optional
Names of columns to read from the file(s)
Names of columns to read from the file(s).
categories: List[str], optional
List of column names that should be returned as pandas.Categorical.
Recommended for memory-restricted environments.
chunked : bool
If True will break the data in smaller DataFrames (Non deterministic number of lines).
Otherwise return a single DataFrame with the whole data.
@@ -1740,6 +1742,7 @@ def read_parquet_table(
path=path,
filters=filters,
columns=columns,
categories=categories,
chunked=chunked,
dataset=True,
use_threads=use_threads,
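
The `categories` list is applied at two points in the Arrow path above: it is passed as `read_dictionary` to `pyarrow.parquet.ParquetDataset` and again as `categories` to `Table.to_pandas`, so the listed columns stay dictionary-encoded on read and land in pandas as `Categorical`. A minimal usage sketch, with bucket, prefix, and column names as placeholders:

```python
import awswrangler as wr

# Hypothetical bucket/prefix and column names, for illustration only.
df = wr.s3.read_parquet(
    path="s3://my-bucket/my-dataset/",
    dataset=True,                   # load partition columns as well
    categories=["par1", "status"],  # returned as pandas.Categorical
)
print(df.dtypes)
```

`wr.s3.read_parquet_table` accepts the same argument and simply forwards it to `read_parquet`.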
32 changes: 32 additions & 0 deletions testing/test_awswrangler/_utils.py
@@ -94,6 +94,25 @@ def get_df_cast():
return df


def get_df_category():
df = pd.DataFrame(
{
"id": [1, 2, 3],
"string_object": ["foo", None, "boo"],
"string": ["foo", None, "boo"],
"binary": [b"1", None, b"2"],
"float": [1.0, None, 2.0],
"int": [1, None, 2],
"par0": [1, 1, 2],
"par1": ["a", "b", "b"],
}
)
df["string"] = df["string"].astype("string")
df["int"] = df["int"].astype("Int64")
df["par1"] = df["par1"].astype("string")
return df


def get_query_long():
return """
SELECT
@@ -324,3 +343,16 @@ def ensure_data_types(df, has_list=False):
if has_list is True:
assert str(type(row["list"][0]).__name__) == "int64"
assert str(type(row["list_list"][0][0]).__name__) == "int64"


def ensure_data_types_category(df):
assert len(df.columns) in (7, 8)
assert str(df["id"].dtype) in ("category", "Int64")
assert str(df["string_object"].dtype) == "category"
assert str(df["string"].dtype) == "category"
if "binary" in df.columns:
assert str(df["binary"].dtype) == "category"
assert str(df["float"].dtype) == "category"
assert str(df["int"].dtype) in ("category", "Int64")
assert str(df["par0"].dtype) in ("category", "Int64")
assert str(df["par1"].dtype) == "category"
38 changes: 37 additions & 1 deletion testing/test_awswrangler/test_data_lake.py
@@ -7,7 +7,8 @@

import awswrangler as wr

from ._utils import ensure_data_types, get_df, get_df_cast, get_df_list, get_query_long
from ._utils import (ensure_data_types, ensure_data_types_category, get_df, get_df_cast, get_df_category, get_df_list,
get_query_long)

logging.basicConfig(level=logging.INFO, format="[%(asctime)s][%(levelname)s][%(name)s][%(funcName)s] %(message)s")
logging.getLogger("awswrangler").setLevel(logging.DEBUG)
@@ -614,3 +615,38 @@ def test_athena_time_zone(database):
assert len(df.columns) == 2
assert df["type"][0] == "timestamp with time zone"
assert df["value"][0].year == datetime.datetime.utcnow().year


def test_category(bucket, database):
df = get_df_category()
path = f"s3://{bucket}/test_category/"
paths = wr.s3.to_parquet(
df=df,
path=path,
dataset=True,
database=database,
table="test_category",
mode="overwrite",
partition_cols=["par0", "par1"],
)["paths"]
wr.s3.wait_objects_exist(paths=paths, use_threads=False)
df2 = wr.s3.read_parquet(path=path, dataset=True, categories=[c for c in df.columns if c not in ["par0", "par1"]])
ensure_data_types_category(df2)
df2 = wr.athena.read_sql_query("SELECT * FROM test_category", database=database, categories=list(df.columns))
ensure_data_types_category(df2)
df2 = wr.athena.read_sql_query(
"SELECT * FROM test_category", database=database, categories=list(df.columns), ctas_approach=False
)
ensure_data_types_category(df2)
dfs = wr.athena.read_sql_query(
"SELECT * FROM test_category", database=database, categories=list(df.columns), ctas_approach=False, chunksize=1
)
for df2 in dfs:
ensure_data_types_category(df2)
dfs = wr.athena.read_sql_query(
"SELECT * FROM test_category", database=database, categories=list(df.columns), ctas_approach=True, chunksize=1
)
for df2 in dfs:
ensure_data_types_category(df2)
wr.s3.delete_objects(path=paths)
assert wr.catalog.delete_table_if_exists(database=database, table="test_category") is True
38 changes: 37 additions & 1 deletion testing/test_awswrangler/test_db.py
@@ -9,7 +9,7 @@

import awswrangler as wr

from ._utils import ensure_data_types, get_df
from ._utils import ensure_data_types, ensure_data_types_category, get_df, get_df_category

logging.basicConfig(level=logging.INFO, format="[%(asctime)s][%(levelname)s][%(name)s][%(funcName)s] %(message)s")
logging.getLogger("awswrangler").setLevel(logging.DEBUG)
@@ -348,3 +348,39 @@ def test_redshift_spectrum(bucket, glue_database, external_schema):
assert len(rows) == len(df.index)
for row in rows:
assert len(row) == len(df.columns)


def test_redshift_category(bucket, parameters):
path = f"s3://{bucket}/test_redshift_category/"
df = get_df_category().drop(["binary"], axis=1, inplace=False)
engine = wr.catalog.get_engine(connection=f"aws-data-wrangler-redshift")
wr.db.copy_to_redshift(
df=df,
path=path,
con=engine,
schema="public",
table="test_redshift_category",
mode="overwrite",
iam_role=parameters["redshift"]["role"],
)
df2 = wr.db.unload_redshift(
sql="SELECT * FROM public.test_redshift_category",
con=engine,
iam_role=parameters["redshift"]["role"],
path=path,
keep_files=False,
categories=df.columns,
)
ensure_data_types_category(df2)
dfs = wr.db.unload_redshift(
sql="SELECT * FROM public.test_redshift_category",
con=engine,
iam_role=parameters["redshift"]["role"],
path=path,
keep_files=False,
categories=df.columns,
chunked=True,
)
for df2 in dfs:
ensure_data_types_category(df2)
wr.s3.delete_objects(path=path)