From c9435f1f09edde34181acf25967632d05b5f87c2 Mon Sep 17 00:00:00 2001 From: igorborgest Date: Sat, 11 Apr 2020 15:09:42 -0300 Subject: [PATCH 1/2] Add categories argument for all read_parquet related functions #160 --- awswrangler/athena.py | 16 ++++++-- awswrangler/db.py | 8 ++++ awswrangler/s3.py | 45 ++++++++++++---------- testing/test_awswrangler/_utils.py | 32 +++++++++++++++ testing/test_awswrangler/test_data_lake.py | 38 +++++++++++++++++- testing/test_awswrangler/test_db.py | 38 +++++++++++++++++- 6 files changed, 150 insertions(+), 27 deletions(-) diff --git a/awswrangler/athena.py b/awswrangler/athena.py index b2886426d..f173ecacb 100644 --- a/awswrangler/athena.py +++ b/awswrangler/athena.py @@ -259,7 +259,7 @@ def _extract_ctas_manifest_paths(path: str, boto3_session: Optional[boto3.Sessio def _get_query_metadata( - query_execution_id: str, boto3_session: Optional[boto3.Session] = None + query_execution_id: str, categories: List[str] = None, boto3_session: Optional[boto3.Session] = None ) -> Tuple[Dict[str, str], List[str], List[str], Dict[str, Any], List[str]]: """Get query metadata.""" cols_types: Dict[str, str] = get_query_columns_types( @@ -285,7 +285,9 @@ def _get_query_metadata( "Please use ctas_approach=True for Struct columns." ) pandas_type: str = _data_types.athena2pandas(dtype=col_type) - if pandas_type in ["datetime64", "date"]: + if (categories is not None) and (col_name in categories): + dtype[col_name] = "category" + elif pandas_type in ["datetime64", "date"]: parse_timestamps.append(col_name) if pandas_type == "date": parse_dates.append(col_name) @@ -326,6 +328,7 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals sql: str, database: str, ctas_approach: bool = True, + categories: List[str] = None, chunksize: Optional[int] = None, s3_output: Optional[str] = None, workgroup: Optional[str] = None, @@ -377,6 +380,9 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals ctas_approach: bool Wraps the query using a CTAS, and read the resulted parquet data on S3. If false, read the regular CSV on S3. + categories: List[str], optional + List of columns names that should be returned as pandas.Categorical. + Recommended for memory restricted environments. chunksize: int, optional If specified, return an generator where chunksize is the number of rows to include in each chunk. 
s3_output : str, optional @@ -457,10 +463,12 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals dfs = _utils.empty_generator() else: s3.wait_objects_exist(paths=paths, use_threads=False, boto3_session=session) - dfs = s3.read_parquet(path=paths, use_threads=use_threads, boto3_session=session, chunked=chunked) + dfs = s3.read_parquet( + path=paths, use_threads=use_threads, boto3_session=session, chunked=chunked, categories=categories + ) return dfs dtype, parse_timestamps, parse_dates, converters, binaries = _get_query_metadata( - query_execution_id=query_id, boto3_session=session + query_execution_id=query_id, categories=categories, boto3_session=session ) path = f"{_s3_output}{query_id}.csv" s3.wait_objects_exist(paths=[path], use_threads=False, boto3_session=session) diff --git a/awswrangler/db.py b/awswrangler/db.py index c674df203..fb17c40cf 100644 --- a/awswrangler/db.py +++ b/awswrangler/db.py @@ -887,6 +887,7 @@ def unload_redshift( path: str, con: sqlalchemy.engine.Engine, iam_role: str, + categories: List[str] = None, chunked: bool = False, keep_files: bool = False, use_threads: bool = True, @@ -920,6 +921,9 @@ def unload_redshift( wr.db.get_engine(), wr.db.get_redshift_temp_engine() or wr.catalog.get_engine() iam_role : str AWS IAM role with the related permissions. + categories: List[str], optional + List of columns names that should be returned as pandas.Categorical. + Recommended for memory restricted environments. keep_files : bool Should keep the stage files? chunked : bool @@ -960,6 +964,7 @@ def unload_redshift( return pd.DataFrame() df: pd.DataFrame = s3.read_parquet( path=paths, + categories=categories, chunked=chunked, dataset=False, use_threads=use_threads, @@ -973,6 +978,7 @@ def unload_redshift( return _utils.empty_generator() return _read_parquet_iterator( paths=paths, + categories=categories, use_threads=use_threads, boto3_session=session, s3_additional_kwargs=s3_additional_kwargs, @@ -984,11 +990,13 @@ def _read_parquet_iterator( paths: List[str], keep_files: bool, use_threads: bool, + categories: List[str] = None, boto3_session: Optional[boto3.Session] = None, s3_additional_kwargs: Optional[Dict[str, str]] = None, ) -> Iterator[pd.DataFrame]: dfs: Iterator[pd.DataFrame] = s3.read_parquet( path=paths, + categories=categories, chunked=True, dataset=False, use_threads=use_threads, diff --git a/awswrangler/s3.py b/awswrangler/s3.py index 35432a1a8..f24091520 100644 --- a/awswrangler/s3.py +++ b/awswrangler/s3.py @@ -833,7 +833,7 @@ def _to_parquet_file( fs: s3fs.S3FileSystem, dtype: Dict[str, str], ) -> str: - table: pa.Table = pyarrow.Table.from_pandas(df=df, schema=schema, nthreads=cpus, preserve_index=index, safe=False) + table: pa.Table = pyarrow.Table.from_pandas(df=df, schema=schema, nthreads=cpus, preserve_index=index, safe=True) for col_name, col_type in dtype.items(): if col_name in table.column_names: col_index = table.column_names.index(col_name) @@ -1190,6 +1190,7 @@ def _read_text_full( def _read_parquet_init( path: Union[str, List[str]], filters: Optional[Union[List[Tuple], List[List[Tuple]]]] = None, + categories: List[str] = None, dataset: bool = False, use_threads: bool = True, boto3_session: Optional[boto3.Session] = None, @@ -1206,7 +1207,7 @@ def _read_parquet_init( fs: s3fs.S3FileSystem = _utils.get_fs(session=boto3_session, s3_additional_kwargs=s3_additional_kwargs) cpus: int = _utils.ensure_cpu_count(use_threads=use_threads) data: pyarrow.parquet.ParquetDataset = pyarrow.parquet.ParquetDataset( - 
path_or_paths=path_or_paths, filesystem=fs, metadata_nthreads=cpus, filters=filters + path_or_paths=path_or_paths, filesystem=fs, metadata_nthreads=cpus, filters=filters, read_dictionary=categories ) return data @@ -1217,6 +1218,7 @@ def read_parquet( columns: Optional[List[str]] = None, chunked: bool = False, dataset: bool = False, + categories: List[str] = None, use_threads: bool = True, boto3_session: Optional[boto3.Session] = None, s3_additional_kwargs: Optional[Dict[str, str]] = None, @@ -1243,6 +1245,9 @@ def read_parquet( Otherwise return a single DataFrame with the whole data. dataset: bool If True read a parquet dataset instead of simple file(s) loading all the related partitions as columns. + categories: List[str], optional + List of columns names that should be returned as pandas.Categorical. + Recommended for memory restricted environments. use_threads : bool True to enable concurrent requests, False to disable multiple threads. If enabled os.cpu_count() will be used as the max number of threads. @@ -1292,43 +1297,37 @@ def read_parquet( path=path, filters=filters, dataset=dataset, + categories=categories, use_threads=use_threads, boto3_session=boto3_session, s3_additional_kwargs=s3_additional_kwargs, ) - common_metadata = data.common_metadata - common_metadata = None if common_metadata is None else common_metadata.metadata.get(b"pandas", None) if chunked is False: - return _read_parquet(data=data, columns=columns, use_threads=use_threads, common_metadata=common_metadata) - return _read_parquet_chunked(data=data, columns=columns, use_threads=use_threads, common_metadata=common_metadata) + return _read_parquet(data=data, columns=columns, categories=categories, use_threads=use_threads) + return _read_parquet_chunked(data=data, columns=columns, categories=categories, use_threads=use_threads) def _read_parquet( data: pyarrow.parquet.ParquetDataset, columns: Optional[List[str]] = None, + categories: List[str] = None, use_threads: bool = True, - common_metadata: Any = None, ) -> pd.DataFrame: - # Data tables: List[pa.Table] = [] for piece in data.pieces: table: pa.Table = piece.read( - columns=columns, use_threads=use_threads, partitions=data.partitions, use_pandas_metadata=True + columns=columns, use_threads=use_threads, partitions=data.partitions, use_pandas_metadata=False ) tables.append(table) table = pa.lib.concat_tables(tables) - - # Metadata - current_metadata = table.schema.metadata or {} - if common_metadata and b"pandas" not in current_metadata: # pragma: no cover - table = table.replace_schema_metadata({b"pandas": common_metadata}) - return table.to_pandas( use_threads=use_threads, split_blocks=True, self_destruct=True, integer_object_nulls=False, date_as_object=True, + ignore_metadata=True, + categories=categories, types_mapper=_data_types.pyarrow2pandas_extension, ) @@ -1336,22 +1335,21 @@ def _read_parquet( def _read_parquet_chunked( data: pyarrow.parquet.ParquetDataset, columns: Optional[List[str]] = None, + categories: List[str] = None, use_threads: bool = True, - common_metadata: Any = None, ) -> Iterator[pd.DataFrame]: for piece in data.pieces: table: pa.Table = piece.read( - columns=columns, use_threads=use_threads, partitions=data.partitions, use_pandas_metadata=True + columns=columns, use_threads=use_threads, partitions=data.partitions, use_pandas_metadata=False ) - current_metadata = table.schema.metadata or {} - if common_metadata and b"pandas" not in current_metadata: # pragma: no cover - table = table.replace_schema_metadata({b"pandas": common_metadata}) 
yield table.to_pandas( use_threads=use_threads, split_blocks=True, self_destruct=True, integer_object_nulls=False, date_as_object=True, + ignore_metadata=True, + categories=categories, types_mapper=_data_types.pyarrow2pandas_extension, ) @@ -1670,6 +1668,7 @@ def read_parquet_table( database: str, filters: Optional[Union[List[Tuple], List[List[Tuple]]]] = None, columns: Optional[List[str]] = None, + categories: List[str] = None, chunked: bool = False, use_threads: bool = True, boto3_session: Optional[boto3.Session] = None, @@ -1690,7 +1689,10 @@ def read_parquet_table( filters: Union[List[Tuple], List[List[Tuple]]], optional List of filters to apply, like ``[[('x', '=', 0), ...], ...]``. columns : List[str], optional - Names of columns to read from the file(s) + Names of columns to read from the file(s). + categories: List[str], optional + List of columns names that should be returned as pandas.Categorical. + Recommended for memory restricted environments. chunked : bool If True will break the data in smaller DataFrames (Non deterministic number of lines). Otherwise return a single DataFrame with the whole data. @@ -1740,6 +1742,7 @@ def read_parquet_table( path=path, filters=filters, columns=columns, + categories=categories, chunked=chunked, dataset=True, use_threads=use_threads, diff --git a/testing/test_awswrangler/_utils.py b/testing/test_awswrangler/_utils.py index da17665a1..b55219d87 100644 --- a/testing/test_awswrangler/_utils.py +++ b/testing/test_awswrangler/_utils.py @@ -94,6 +94,25 @@ def get_df_cast(): return df +def get_df_category(): + df = pd.DataFrame( + { + "id": [1, 2, 3], + "string_object": ["foo", None, "boo"], + "string": ["foo", None, "boo"], + "binary": [b"1", None, b"2"], + "float": [1.0, None, 2.0], + "int": [1, None, 2], + "par0": [1, 1, 2], + "par1": ["a", "b", "b"], + } + ) + df["string"] = df["string"].astype("string") + df["int"] = df["int"].astype("Int64") + df["par1"] = df["par1"].astype("string") + return df + + def get_query_long(): return """ SELECT @@ -324,3 +343,16 @@ def ensure_data_types(df, has_list=False): if has_list is True: assert str(type(row["list"][0]).__name__) == "int64" assert str(type(row["list_list"][0][0]).__name__) == "int64" + + +def ensure_data_types_category(df): + assert len(df.columns) in (7, 8) + assert str(df["id"].dtype) in ("category", "Int64") + assert str(df["string_object"].dtype) == "category" + assert str(df["string"].dtype) == "category" + if "binary" in df.columns: + assert str(df["binary"].dtype) == "category" + assert str(df["float"].dtype) == "category" + assert str(df["int"].dtype) in ("category", "Int64") + assert str(df["par0"].dtype) in ("category", "Int64") + assert str(df["par1"].dtype) == "category" diff --git a/testing/test_awswrangler/test_data_lake.py b/testing/test_awswrangler/test_data_lake.py index 7cb94b229..86004df08 100644 --- a/testing/test_awswrangler/test_data_lake.py +++ b/testing/test_awswrangler/test_data_lake.py @@ -7,7 +7,8 @@ import awswrangler as wr -from ._utils import ensure_data_types, get_df, get_df_cast, get_df_list, get_query_long +from ._utils import (ensure_data_types, ensure_data_types_category, get_df, get_df_cast, get_df_category, get_df_list, + get_query_long) logging.basicConfig(level=logging.INFO, format="[%(asctime)s][%(levelname)s][%(name)s][%(funcName)s] %(message)s") logging.getLogger("awswrangler").setLevel(logging.DEBUG) @@ -614,3 +615,38 @@ def test_athena_time_zone(database): assert len(df.columns) == 2 assert df["type"][0] == "timestamp with time zone" assert 
df["value"][0].year == datetime.datetime.utcnow().year + + +def test_category(bucket, database): + df = get_df_category() + path = f"s3://{bucket}/test_category/" + paths = wr.s3.to_parquet( + df=df, + path=path, + dataset=True, + database=database, + table="test_category", + mode="overwrite", + partition_cols=["par0", "par1"], + )["paths"] + wr.s3.wait_objects_exist(paths=paths, use_threads=False) + df2 = wr.s3.read_parquet(path=path, dataset=True, categories=[c for c in df.columns if c not in ["par0", "par1"]]) + ensure_data_types_category(df2) + df2 = wr.athena.read_sql_query("SELECT * FROM test_category", database=database, categories=list(df.columns)) + ensure_data_types_category(df2) + df2 = wr.athena.read_sql_query( + "SELECT * FROM test_category", database=database, categories=list(df.columns), ctas_approach=False + ) + ensure_data_types_category(df2) + dfs = wr.athena.read_sql_query( + "SELECT * FROM test_category", database=database, categories=list(df.columns), ctas_approach=False, chunksize=1 + ) + for df2 in dfs: + ensure_data_types_category(df2) + dfs = wr.athena.read_sql_query( + "SELECT * FROM test_category", database=database, categories=list(df.columns), ctas_approach=True, chunksize=1 + ) + for df2 in dfs: + ensure_data_types_category(df2) + wr.s3.delete_objects(path=paths) + assert wr.catalog.delete_table_if_exists(database=database, table="test_category") is True diff --git a/testing/test_awswrangler/test_db.py b/testing/test_awswrangler/test_db.py index 47dcdd543..d809ec977 100644 --- a/testing/test_awswrangler/test_db.py +++ b/testing/test_awswrangler/test_db.py @@ -9,7 +9,7 @@ import awswrangler as wr -from ._utils import ensure_data_types, get_df +from ._utils import ensure_data_types, ensure_data_types_category, get_df, get_df_category logging.basicConfig(level=logging.INFO, format="[%(asctime)s][%(levelname)s][%(name)s][%(funcName)s] %(message)s") logging.getLogger("awswrangler").setLevel(logging.DEBUG) @@ -348,3 +348,39 @@ def test_redshift_spectrum(bucket, glue_database, external_schema): assert len(rows) == len(df.index) for row in rows: assert len(row) == len(df.columns) + + +def test_redshift_category(bucket, parameters): + path = f"s3://{bucket}/test_redshift_category/" + df = get_df_category().drop(["binary"], axis=1, inplace=False) + engine = wr.catalog.get_engine(connection=f"aws-data-wrangler-redshift") + wr.db.copy_to_redshift( + df=df, + path=path, + con=engine, + schema="public", + table="test_redshift_category", + mode="overwrite", + iam_role=parameters["redshift"]["role"], + ) + df2 = wr.db.unload_redshift( + sql="SELECT * FROM public.test_redshift_category", + con=engine, + iam_role=parameters["redshift"]["role"], + path=path, + keep_files=False, + categories=df.columns, + ) + ensure_data_types_category(df2) + dfs = wr.db.unload_redshift( + sql="SELECT * FROM public.test_redshift_category", + con=engine, + iam_role=parameters["redshift"]["role"], + path=path, + keep_files=False, + categories=df.columns, + chunked=True, + ) + for df2 in dfs: + ensure_data_types_category(df2) + wr.s3.delete_objects(path=path) From 6b24a5dcccae999064c71a712ee69d23c40f239c Mon Sep 17 00:00:00 2001 From: igorborgest Date: Sat, 11 Apr 2020 18:38:41 +0000 Subject: [PATCH 2/2] Add categories example to tutorial 6 #160 --- tutorials/06 - Amazon Athena.ipynb | 218 ++++++++++++++++++++++++++++- 1 file changed, 212 insertions(+), 6 deletions(-) diff --git a/tutorials/06 - Amazon Athena.ipynb b/tutorials/06 - Amazon Athena.ipynb index cca09157f..5501370fa 100644 --- 
a/tutorials/06 - Amazon Athena.ipynb +++ b/tutorials/06 - Amazon Athena.ipynb @@ -45,7 +45,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -73,7 +73,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -251,7 +251,7 @@ "[29240017 rows x 8 columns]" ] }, - "execution_count": 3, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -269,7 +269,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -285,7 +285,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -388,7 +388,7 @@ "7 obs_time string False " ] }, - "execution_count": 5, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -809,6 +809,212 @@ "wr.athena.read_sql_query(\"SELECT * FROM noaa\", database=\"awswrangler_test\", ctas_approach=False)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using categories to speed up and save memory!" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 3.84 s, sys: 2.01 s, total: 5.85 s\n", + "Wall time: 30.2 s\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
iddtelementvaluem_flagq_flags_flagobs_time
0SF0014658801890-01-02PRCP0NaNNaNINaN
1ASN000740681890-01-02PRCP0NaNNaNaNaN
2ASN000830291890-01-02PRCP25NaNNaNaNaN
3ASN000640211890-01-02PRCP0NaNNaNaNaN
4ASN000770221890-01-02PRCP0NaNNaNaNaN
...........................
29240012USC003954811899-12-31SNOW0NaNNaN6NaN
29240013ASN000630551899-12-31PRCP0NaNNaNaNaN
29240014USC003578141899-12-31TMAX78NaNNaN6NaN
29240015USC003578141899-12-31TMIN0NaNNaN6NaN
29240016USC003578141899-12-31PRCP102NaNNaN6NaN
\n", + "

29240017 rows × 8 columns

\n", + "
" + ], + "text/plain": [ + " id dt element value m_flag q_flag s_flag obs_time\n", + "0 SF001465880 1890-01-02 PRCP 0 NaN NaN I NaN\n", + "1 ASN00074068 1890-01-02 PRCP 0 NaN NaN a NaN\n", + "2 ASN00083029 1890-01-02 PRCP 25 NaN NaN a NaN\n", + "3 ASN00064021 1890-01-02 PRCP 0 NaN NaN a NaN\n", + "4 ASN00077022 1890-01-02 PRCP 0 NaN NaN a NaN\n", + "... ... ... ... ... ... ... ... ...\n", + "29240012 USC00395481 1899-12-31 SNOW 0 NaN NaN 6 NaN\n", + "29240013 ASN00063055 1899-12-31 PRCP 0 NaN NaN a NaN\n", + "29240014 USC00357814 1899-12-31 TMAX 78 NaN NaN 6 NaN\n", + "29240015 USC00357814 1899-12-31 TMIN 0 NaN NaN 6 NaN\n", + "29240016 USC00357814 1899-12-31 PRCP 102 NaN NaN 6 NaN\n", + "\n", + "[29240017 rows x 8 columns]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%time\n", + "\n", + "wr.athena.read_sql_query(\"SELECT * FROM noaa\", database=\"awswrangler_test\", categories=[\"id\", \"dt\", \"element\", \"value\", \"m_flag\", \"q_flag\", \"s_flag\", \"obs_time\"])" + ] + }, { "cell_type": "markdown", "metadata": {},