diff --git a/awswrangler/s3/__init__.py b/awswrangler/s3/__init__.py index e95810ece..b2fbadba1 100644 --- a/awswrangler/s3/__init__.py +++ b/awswrangler/s3/__init__.py @@ -9,6 +9,7 @@ from awswrangler.s3._read_excel import read_excel # noqa from awswrangler.s3._read_parquet import read_parquet, read_parquet_metadata, read_parquet_table # noqa from awswrangler.s3._read_text import read_csv, read_fwf, read_json # noqa +from awswrangler.s3._select import select_query from awswrangler.s3._upload import upload # noqa from awswrangler.s3._wait import wait_objects_exist, wait_objects_not_exist # noqa from awswrangler.s3._write_excel import to_excel # noqa @@ -33,6 +34,7 @@ "read_json", "wait_objects_exist", "wait_objects_not_exist", + "select_query", "store_parquet_metadata", "to_parquet", "to_csv", diff --git a/awswrangler/s3/_select.py b/awswrangler/s3/_select.py new file mode 100644 index 000000000..399d75278 --- /dev/null +++ b/awswrangler/s3/_select.py @@ -0,0 +1,217 @@ +"""Amazon S3 Select Module (PRIVATE).""" + +import concurrent.futures +import itertools +import json +import logging +import pprint +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union + +import boto3 +import pandas as pd + +from awswrangler import _utils, exceptions +from awswrangler.s3._describe import size_objects + +_logger: logging.Logger = logging.getLogger(__name__) + +_RANGE_CHUNK_SIZE: int = int(1024 * 1024) + + +def _gen_scan_range(obj_size: int) -> Iterator[Tuple[int, int]]: + for i in range(0, obj_size, _RANGE_CHUNK_SIZE): + yield (i, i + min(_RANGE_CHUNK_SIZE, obj_size - i)) + + +def _select_object_content( + args: Dict[str, Any], + client_s3: boto3.Session, + scan_range: Optional[Tuple[int, int]] = None, +) -> pd.DataFrame: + if scan_range: + response = client_s3.select_object_content(**args, ScanRange={"Start": scan_range[0], "End": scan_range[1]}) + else: + response = client_s3.select_object_content(**args) + + dfs: List[pd.DataFrame] = [] + partial_record: str = "" + for event in response["Payload"]: + if "Records" in event: + records = event["Records"]["Payload"].decode(encoding="utf-8", errors="ignore").split("\n") + records[0] = partial_record + records[0] + # Record end can either be a partial record or a return char + partial_record = records.pop() + dfs.append( + pd.DataFrame( + [json.loads(record) for record in records], + ) + ) + if not dfs: + return pd.DataFrame() + return pd.concat(dfs, ignore_index=True) + + +def _paginate_stream( + args: Dict[str, Any], path: str, use_threads: Union[bool, int], boto3_session: Optional[boto3.Session] +) -> pd.DataFrame: + obj_size: int = size_objects( # type: ignore + path=[path], + use_threads=False, + boto3_session=boto3_session, + ).get(path) + if obj_size is None: + raise exceptions.InvalidArgumentValue(f"S3 object w/o defined size: {path}") + + dfs: List[pd.Dataframe] = [] + client_s3: boto3.client = _utils.client(service_name="s3", session=boto3_session) + + if use_threads is False: + dfs = list( + _select_object_content( + args=args, + client_s3=client_s3, + scan_range=scan_range, + ) + for scan_range in _gen_scan_range(obj_size=obj_size) + ) + else: + cpus: int = _utils.ensure_cpu_count(use_threads=use_threads) + with concurrent.futures.ThreadPoolExecutor(max_workers=cpus) as executor: + dfs = list( + executor.map( + _select_object_content, + itertools.repeat(args), + itertools.repeat(client_s3), + _gen_scan_range(obj_size=obj_size), + ) + ) + return pd.concat(dfs, ignore_index=True) + + +def select_query( + sql: str, + path: str, + 
input_serialization: str,
+    input_serialization_params: Dict[str, Union[bool, str]],
+    compression: Optional[str] = None,
+    use_threads: Union[bool, int] = False,
+    boto3_session: Optional[boto3.Session] = None,
+    s3_additional_kwargs: Optional[Dict[str, Any]] = None,
+) -> pd.DataFrame:
+    r"""Filter contents of an Amazon S3 object based on SQL statement.
+
+    Note: Scan ranges are only supported for uncompressed CSV/JSON objects, specifically CSV without
+    quoted record delimiters and JSON in LINES mode. When these conditions are not met, the scan cannot
+    be split across threads, which lowers performance.
+
+    Parameters
+    ----------
+    sql: str
+        SQL statement used to query the object.
+    path: str
+        S3 path to the object (e.g. s3://bucket/key).
+    input_serialization: str
+        Format of the S3 object queried.
+        Valid values: "CSV", "JSON", or "Parquet". Case sensitive.
+    input_serialization_params: Dict[str, Union[bool, str]]
+        Dictionary describing the serialization of the S3 object.
+    compression: Optional[str]
+        Compression type of the S3 object.
+        Valid values: None, "gzip", or "bzip2". gzip and bzip2 are only valid for CSV and JSON objects.
+    use_threads : Union[bool, int]
+        True to enable concurrent requests, False to disable multiple threads.
+        If enabled os.cpu_count() is used as the max number of threads.
+        If integer is provided, specified number is used.
+    boto3_session : boto3.Session(), optional
+        Boto3 Session. The default boto3 session is used if none is provided.
+    s3_additional_kwargs : Optional[Dict[str, Any]]
+        Forwarded to botocore requests.
+        Valid values: "SSECustomerAlgorithm", "SSECustomerKey", "ExpectedBucketOwner".
+        e.g. s3_additional_kwargs={'SSECustomerAlgorithm': 'md5'}
+
+    Returns
+    -------
+    pandas.DataFrame
+        Pandas DataFrame with results from query.
+
+    Examples
+    --------
+    Reading a gzip compressed JSON document
+
+    >>> import awswrangler as wr
+    >>> df = wr.s3.select_query(
+    ...     sql='SELECT * FROM s3object[*][*]',
+    ...     path='s3://bucket/key.json.gzip',
+    ...     input_serialization='JSON',
+    ...     input_serialization_params={
+    ...         'Type': 'Document',
+    ...     },
+    ...     compression="gzip",
+    ... )
+
+    Reading an entire CSV object using threads
+
+    >>> import awswrangler as wr
+    >>> df = wr.s3.select_query(
+    ...     sql='SELECT * FROM s3object',
+    ...     path='s3://bucket/key.csv',
+    ...     input_serialization='CSV',
+    ...     input_serialization_params={
+    ...         'FileHeaderInfo': 'Use',
+    ...         'RecordDelimiter': '\r\n'
+    ...     },
+    ...     use_threads=True,
+    ... )
+
+    Reading a single column from Parquet object with pushdown filter
+
+    >>> import awswrangler as wr
+    >>> df = wr.s3.select_query(
+    ...     sql='SELECT s.\"id\" FROM s3object s where s.\"id\" = 1.0',
+    ...     path='s3://bucket/key.snappy.parquet',
+    ...     input_serialization='Parquet',
+    ...     input_serialization_params={},
+    ... )
+    """
+    if path.endswith("/"):
+        raise exceptions.InvalidArgumentValue("path argument should be an S3 key, not a prefix.")
+    if input_serialization not in ["CSV", "JSON", "Parquet"]:
+        raise exceptions.InvalidArgumentValue("input_serialization argument must be 'CSV', 'JSON' or 'Parquet'")
+    if compression not in [None, "gzip", "bzip2"]:
+        raise exceptions.InvalidCompression(f"Invalid {compression} compression, please use None, 'gzip' or 'bzip2'.")
+    if compression and (input_serialization not in ["CSV", "JSON"]):
+        raise exceptions.InvalidArgumentCombination(
+            "'gzip' or 'bzip2' are only valid for input 'CSV' or 'JSON' objects."
+ ) + bucket, key = _utils.parse_path(path) + + args: Dict[str, Any] = { + "Bucket": bucket, + "Key": key, + "Expression": sql, + "ExpressionType": "SQL", + "RequestProgress": {"Enabled": False}, + "InputSerialization": { + input_serialization: input_serialization_params, + "CompressionType": compression.upper() if compression else "NONE", + }, + "OutputSerialization": { + "JSON": {}, + }, + } + if s3_additional_kwargs: + args.update(s3_additional_kwargs) + _logger.debug("args:\n%s", pprint.pformat(args)) + + if any( + [ + compression, + input_serialization_params.get("AllowQuotedRecordDelimiter"), + input_serialization_params.get("Type") == "Document", + ] + ): # Scan range is only supported for uncompressed CSV/JSON, CSV (without quoted delimiters) + # and JSON objects (in LINES mode only) + _logger.debug("Scan ranges are not supported given provided input.") + client_s3: boto3.client = _utils.client(service_name="s3", session=boto3_session) + return _select_object_content(args=args, client_s3=client_s3) + + return _paginate_stream(args=args, path=path, use_threads=use_threads, boto3_session=boto3_session) diff --git a/tests/test_s3_select.py b/tests/test_s3_select.py new file mode 100644 index 000000000..46218182c --- /dev/null +++ b/tests/test_s3_select.py @@ -0,0 +1,161 @@ +import logging + +import pandas as pd +import pytest + +import awswrangler as wr + +logging.getLogger("awswrangler").setLevel(logging.DEBUG) + + +@pytest.mark.parametrize("use_threads", [True, False, 2]) +def test_full_table(path, use_threads): + df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"], "c2": [4.0, 5.0, 6.0]}) + + # Parquet + file_path = f"{path}test_parquet_file.snappy.parquet" + wr.s3.to_parquet(df, file_path, compression="snappy") + df2 = wr.s3.select_query( + sql="select * from s3object", + path=file_path, + input_serialization="Parquet", + input_serialization_params={}, + use_threads=use_threads, + ) + assert df.equals(df2) + + # CSV + file_path = f"{path}test_csv_file.csv" + wr.s3.to_csv(df, file_path, index=False) + df3 = wr.s3.select_query( + sql="select * from s3object", + path=file_path, + input_serialization="CSV", + input_serialization_params={"FileHeaderInfo": "Use", "RecordDelimiter": "\n"}, + use_threads=use_threads, + ) + assert len(df.index) == len(df3.index) + assert list(df.columns) == list(df3.columns) + assert df.shape == df3.shape + + # JSON + file_path = f"{path}test_json_file.json" + wr.s3.to_json(df, file_path, orient="records") + df4 = wr.s3.select_query( + sql="select * from s3object[*][*]", + path=file_path, + input_serialization="JSON", + input_serialization_params={"Type": "Document"}, + use_threads=use_threads, + ) + assert df.equals(df4) + + +@pytest.mark.parametrize("use_threads", [True, False, 2]) +def test_push_down(path, use_threads): + df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"], "c2": [4.0, 5.0, 6.0]}) + + file_path = f"{path}test_parquet_file.snappy.parquet" + wr.s3.to_parquet(df, file_path, compression="snappy") + df2 = wr.s3.select_query( + sql='select * from s3object s where s."c0" = 1', + path=file_path, + input_serialization="Parquet", + input_serialization_params={}, + use_threads=use_threads, + ) + assert df2.shape == (1, 3) + assert df2.c0.sum() == 1 + + file_path = f"{path}test_parquet_file.gzip.parquet" + wr.s3.to_parquet(df, file_path, compression="gzip") + df2 = wr.s3.select_query( + sql='select * from s3object s where s."c0" = 99', + path=file_path, + input_serialization="Parquet", + input_serialization_params={}, + 
use_threads=use_threads, + ) + assert df2.shape == (0, 0) + + file_path = f"{path}test_csv_file.csv" + wr.s3.to_csv(df, file_path, header=False, index=False) + df3 = wr.s3.select_query( + sql='select s."_1" from s3object s limit 2', + path=file_path, + input_serialization="CSV", + input_serialization_params={"FileHeaderInfo": "None", "RecordDelimiter": "\n"}, + use_threads=use_threads, + ) + assert df3.shape == (2, 1) + + file_path = f"{path}test_json_file.json" + wr.s3.to_json(df, file_path, orient="records") + df4 = wr.s3.select_query( + sql="select count(*) from s3object[*][*]", + path=file_path, + input_serialization="JSON", + input_serialization_params={"Type": "Document"}, + use_threads=use_threads, + ) + assert df4.shape == (1, 1) + assert df4._1.sum() == 3 + + +@pytest.mark.parametrize("compression", ["gzip", "bz2"]) +def test_compression(path, compression): + df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"], "c2": [4.0, 5.0, 6.0]}) + + # CSV + file_path = f"{path}test_csv_file.csv" + wr.s3.to_csv(df, file_path, index=False, compression=compression) + df2 = wr.s3.select_query( + sql="select * from s3object", + path=file_path, + input_serialization="CSV", + input_serialization_params={"FileHeaderInfo": "Use", "RecordDelimiter": "\n"}, + compression="bzip2" if compression == "bz2" else compression, + use_threads=False, + ) + assert len(df.index) == len(df2.index) + assert list(df.columns) == list(df2.columns) + assert df.shape == df2.shape + + # JSON + file_path = f"{path}test_json_file.json" + wr.s3.to_json(df, file_path, orient="records", compression=compression) + df3 = wr.s3.select_query( + sql="select * from s3object[*][*]", + path=file_path, + input_serialization="JSON", + input_serialization_params={"Type": "Document"}, + compression="bzip2" if compression == "bz2" else compression, + use_threads=False, + ) + assert df.equals(df3) + + +@pytest.mark.parametrize( + "s3_additional_kwargs", + [None, {"ServerSideEncryption": "AES256"}, {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": None}], +) +def test_encryption(path, kms_key_id, s3_additional_kwargs): + if s3_additional_kwargs is not None and "SSEKMSKeyId" in s3_additional_kwargs: + s3_additional_kwargs["SSEKMSKeyId"] = kms_key_id + + df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"], "c2": [4.0, 5.0, 6.0]}) + file_path = f"{path}test_parquet_file.snappy.parquet" + wr.s3.to_parquet( + df, + file_path, + compression="snappy", + s3_additional_kwargs=s3_additional_kwargs, + ) + df2 = wr.s3.select_query( + sql="select * from s3object", + path=file_path, + input_serialization="Parquet", + input_serialization_params={}, + use_threads=False, + ) + assert df.equals(df2)
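
Note on the scan-range mechanism (not part of the patch): _gen_scan_range and _paginate_stream above split the object into roughly 1 MiB byte ranges and issue one select_object_content call per range; S3 Select only processes records that start inside a given range, so concatenating the per-range results yields the full result set without duplicates. Below is a minimal standalone sketch of the same idea against the boto3 API directly; the bucket, key, query and CSV serialization settings are placeholders, and it is an illustration of the technique rather than a drop-in replacement for the module.

import json

import boto3

CHUNK_SIZE = 1024 * 1024  # same 1 MiB step as _RANGE_CHUNK_SIZE

s3 = boto3.client("s3")
bucket, key = "my-bucket", "data/file.csv"  # placeholder object
obj_size = s3.head_object(Bucket=bucket, Key=key)["ContentLength"]

rows = []
for start in range(0, obj_size, CHUNK_SIZE):
    end = start + min(CHUNK_SIZE, obj_size - start)
    response = s3.select_object_content(
        Bucket=bucket,
        Key=key,
        Expression="SELECT * FROM s3object",  # placeholder query
        ExpressionType="SQL",
        InputSerialization={"CSV": {"FileHeaderInfo": "Use"}, "CompressionType": "NONE"},
        OutputSerialization={"JSON": {}},
        ScanRange={"Start": start, "End": end},
    )
    partial = ""
    for event in response["Payload"]:
        if "Records" in event:
            # Event payloads can end mid-record, so carry the tail over to the next payload,
            # mirroring the partial_record handling in _select_object_content.
            lines = (partial + event["Records"]["Payload"].decode("utf-8")).split("\n")
            partial = lines.pop()
            rows.extend(json.loads(line) for line in lines if line)

print(f"Fetched {len(rows)} records from {obj_size} bytes")

In the module itself, these ranges are fanned out over a ThreadPoolExecutor when use_threads is enabled, and ranging is skipped entirely (a single plain request) for compressed objects, CSV with AllowQuotedRecordDelimiter, or JSON in Document mode, where scan ranges are not supported.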