From 84f7de2027fa864afe668773b9f51198aa73daee Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Wed, 19 May 2021 09:08:11 +0100 Subject: [PATCH 1/4] WIP - Stashing --- awswrangler/s3/__init__.py | 2 + awswrangler/s3/_select.py | 127 +++++++++++++++++++++++++++++++++++++ 2 files changed, 129 insertions(+) create mode 100644 awswrangler/s3/_select.py diff --git a/awswrangler/s3/__init__.py b/awswrangler/s3/__init__.py index e95810ece..b2fbadba1 100644 --- a/awswrangler/s3/__init__.py +++ b/awswrangler/s3/__init__.py @@ -9,6 +9,7 @@ from awswrangler.s3._read_excel import read_excel # noqa from awswrangler.s3._read_parquet import read_parquet, read_parquet_metadata, read_parquet_table # noqa from awswrangler.s3._read_text import read_csv, read_fwf, read_json # noqa +from awswrangler.s3._select import select_query from awswrangler.s3._upload import upload # noqa from awswrangler.s3._wait import wait_objects_exist, wait_objects_not_exist # noqa from awswrangler.s3._write_excel import to_excel # noqa @@ -33,6 +34,7 @@ "read_json", "wait_objects_exist", "wait_objects_not_exist", + "select_query", "store_parquet_metadata", "to_parquet", "to_csv", diff --git a/awswrangler/s3/_select.py b/awswrangler/s3/_select.py new file mode 100644 index 000000000..c788bf180 --- /dev/null +++ b/awswrangler/s3/_select.py @@ -0,0 +1,127 @@ +"""Amazon S3 Select Module (PRIVATE).""" + +import concurrent.futures +import itertools +from io import StringIO +import logging +import pprint +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union + +import boto3 +import pandas as pd + +from awswrangler import _utils, exceptions +from awswrangler.s3._describe import size_objects + +_logger: logging.Logger = logging.getLogger(__name__) + +_RANGE_CHUNK_SIZE: int = 5_242_880 # 5 MB (5 * 2**20) + + +def _select_object_content( + args: Dict[str, Any], scan_range: Optional[Tuple[int, int]], client_s3: Optional[boto3.Session] +) -> pd.DataFrame: + if scan_range: + args.update({"ScanRange": {"Start": scan_range[0], "End": scan_range[1]}}) + response = client_s3.select_object_content(**args) + l: pd.DataFrame = [] + print(type(response["Payload"])) + for event in response["Payload"]: + print(type(event)) + if "Records" in event: + l.append(pd.read_csv(StringIO(event["Records"]["Payload"].decode("utf-8")))) + return pd.concat(l) + +def _paginate_stream( + args: Dict[str, Any], path: str, use_threads: Union[bool, int], boto3_session: Optional[boto3.Session] +) -> pd.DataFrame: + obj_size: int = size_objects( # type: ignore + path=[path], + use_threads=False, + boto3_session=boto3_session, + ).get(path) + if obj_size is None: + raise exceptions.InvalidArgumentValue(f"S3 object w/o defined size: {path}") + + scan_ranges: List[Tuple[int, int]] = [] + for i in range(0, obj_size, _RANGE_CHUNK_SIZE): + scan_ranges.append((i, i + min(_RANGE_CHUNK_SIZE, obj_size - i))) + + dfs_iterator: List[pd.Dataframe] = [] + client_s3: boto3.client = _utils.client(service_name="s3", session=boto3_session) + if use_threads is False: + dfs_iterator = list( + _select_object_content( + args=args, + scan_range=scan_range, + client_s3=client_s3, + ) + for scan_range in scan_ranges + ) + else: + cpus: int = _utils.ensure_cpu_count(use_threads=use_threads) + with concurrent.futures.ThreadPoolExecutor(max_workers=cpus) as executor: + dfs_iterator = list( + executor.map( + _select_object_content, + itertools.repeat(args), + scan_ranges, + itertools.repeat(client_s3), + ) + ) + return pd.concat([df for df in dfs_iterator]) + + +# TODO: clarify when to use 
@config (e.g. read_parquet vs read_parquet_table) +def select_query( # Read sql query or S3 select? Here or in a separate file? + sql: str, + path: str, + input_serialization: str, + output_serialization: str, + input_serialization_params: Dict[str, Union[bool, str]] = {}, + output_serialization_params: Dict[str, str] = {}, + compression: Optional[str] = None, + use_threads: Union[bool, int] = False, + boto3_session: Optional[boto3.Session] = None, + params: Optional[Dict[str, Any]] = None, + s3_additional_kwargs: Optional[Dict[str, Any]] = None, +) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: + + if path.endswith("/"): + raise exceptions.InvalidArgumentValue(" argument should be an S3 key, not a prefix.") + if input_serialization not in ["CSV", "JSON", "Parquet"]: + raise exceptions.InvalidArgumentValue(" argument must be 'CSV', 'JSON' or 'Parquet'") + if compression not in [None, "gzip", "bzip2"]: + raise exceptions.InvalidCompression(f"Invalid {compression} compression, please use None, 'gzip' or 'bzip2'.") + else: + if compression and (input_serialization not in ["CSV", "JSON"]): + raise exceptions.InvalidArgumentCombination( + "'gzip' or 'bzip2' are only valid for input 'CSV' or 'JSON' objects." + ) + if output_serialization not in [None, "CSV", "JSON"]: + raise exceptions.InvalidArgumentValue(" argument must be None, 'csv' or 'json'") + if params is None: + params = {} + for key, value in params.items(): + sql = sql.replace(f":{key};", str(value)) + bucket, key = _utils.parse_path(path) + + args: Dict[str, Any] = { + "Bucket": bucket, + "Key": key, + "Expression": sql, + "ExpressionType": "SQL", + "RequestProgress": {"Enabled": False}, + "InputSerialization": { + input_serialization: input_serialization_params, + "CompressionType": compression.upper() if compression else "NONE", + }, + "OutputSerialization": { + output_serialization: output_serialization_params, + }, + } + if s3_additional_kwargs: + args.update(s3_additional_kwargs) + _logger.debug("args:\n%s", pprint.pformat(args)) + + return _paginate_stream(args=args, path=path, use_threads=use_threads, boto3_session=boto3_session) From ee17882bb1b0a9775ba16eb119bbf318af95ea35 Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Thu, 27 May 2021 16:14:19 +0100 Subject: [PATCH 2/4] Major - Tested --- awswrangler/athena/_read.py | 3 +- awswrangler/chime.py | 2 +- awswrangler/s3/_fs.py | 2 +- awswrangler/s3/_select.py | 159 +++++++++++++++++++++------ awswrangler/s3/_write_concurrent.py | 2 +- tests/test_s3_select.py | 161 ++++++++++++++++++++++++++++ 6 files changed, 289 insertions(+), 40 deletions(-) create mode 100644 tests/test_s3_select.py diff --git a/awswrangler/athena/_read.py b/awswrangler/athena/_read.py index 1eba8c637..2f6753a4d 100644 --- a/awswrangler/athena/_read.py +++ b/awswrangler/athena/_read.py @@ -799,8 +799,7 @@ def read_sql_query( for key, value in params.items(): sql = sql.replace(f":{key};", str(value)) - if max_remote_cache_entries > max_local_cache_entries: - max_remote_cache_entries = max_local_cache_entries + max_remote_cache_entries = min(max_remote_cache_entries, max_local_cache_entries) _cache_manager.max_cache_size = max_local_cache_entries cache_info: _CacheInfo = _check_for_cached_results( diff --git a/awswrangler/chime.py b/awswrangler/chime.py index ddc0015a1..df5091732 100644 --- a/awswrangler/chime.py +++ b/awswrangler/chime.py @@ -29,7 +29,7 @@ def post_message(webhook: str, message: str) -> Optional[Any]: chime_message = {"Content": "Message: %s" % (message)} req = Request(webhook, 
json.dumps(chime_message).encode("utf-8")) try: - response = urlopen(req) + response = urlopen(req) # pylint: disable=R1732 _logger.info("Message posted on Chime. Got respone as %s", response.read()) except HTTPError as e: _logger.exception("Request failed: %d %s", e.code, e.reason) diff --git a/awswrangler/s3/_fs.py b/awswrangler/s3/_fs.py index dff6f84fe..fc099ac58 100644 --- a/awswrangler/s3/_fs.py +++ b/awswrangler/s3/_fs.py @@ -131,7 +131,7 @@ def __init__(self, use_threads: Union[bool, int]): self._results: List[Dict[str, Union[str, int]]] = [] self._cpus: int = _utils.ensure_cpu_count(use_threads=use_threads) if self._cpus > 1: - self._exec = concurrent.futures.ThreadPoolExecutor(max_workers=self._cpus) + self._exec = concurrent.futures.ThreadPoolExecutor(max_workers=self._cpus) # pylint: disable=R1732 self._futures: List[Any] = [] else: self._exec = None diff --git a/awswrangler/s3/_select.py b/awswrangler/s3/_select.py index c788bf180..9d0bb028f 100644 --- a/awswrangler/s3/_select.py +++ b/awswrangler/s3/_select.py @@ -2,10 +2,10 @@ import concurrent.futures import itertools -from io import StringIO +import json import logging import pprint -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import boto3 import pandas as pd @@ -15,22 +15,36 @@ _logger: logging.Logger = logging.getLogger(__name__) -_RANGE_CHUNK_SIZE: int = 5_242_880 # 5 MB (5 * 2**20) +_RANGE_CHUNK_SIZE: int = int(1024 * 1024) def _select_object_content( - args: Dict[str, Any], scan_range: Optional[Tuple[int, int]], client_s3: Optional[boto3.Session] + args: Dict[str, Any], + client_s3: boto3.Session, + scan_range: Optional[Tuple[int, int]] = None, ) -> pd.DataFrame: if scan_range: args.update({"ScanRange": {"Start": scan_range[0], "End": scan_range[1]}}) response = client_s3.select_object_content(**args) - l: pd.DataFrame = [] - print(type(response["Payload"])) + + dfs: List[pd.DataFrame] = [] + partial_record: str = "" for event in response["Payload"]: - print(type(event)) if "Records" in event: - l.append(pd.read_csv(StringIO(event["Records"]["Payload"].decode("utf-8")))) - return pd.concat(l) + records = partial_record.join(event["Records"]["Payload"].decode(encoding="utf-8", errors="ignore")).split( + "\n" + ) + # Record end can either be a partial record or a return char + partial_record = records[-1] + dfs.append( + pd.DataFrame( + [json.loads(record) for record in records[:-1]], + ) + ) + if not dfs: + return pd.DataFrame() + return pd.concat(dfs, ignore_index=True) + def _paginate_stream( args: Dict[str, Any], path: str, use_threads: Union[bool, int], boto3_session: Optional[boto3.Session] @@ -43,67 +57,130 @@ def _paginate_stream( if obj_size is None: raise exceptions.InvalidArgumentValue(f"S3 object w/o defined size: {path}") + dfs: List[pd.Dataframe] = [] + client_s3: boto3.client = _utils.client(service_name="s3", session=boto3_session) + scan_ranges: List[Tuple[int, int]] = [] for i in range(0, obj_size, _RANGE_CHUNK_SIZE): scan_ranges.append((i, i + min(_RANGE_CHUNK_SIZE, obj_size - i))) - dfs_iterator: List[pd.Dataframe] = [] - client_s3: boto3.client = _utils.client(service_name="s3", session=boto3_session) if use_threads is False: - dfs_iterator = list( + dfs = list( _select_object_content( args=args, - scan_range=scan_range, client_s3=client_s3, + scan_range=scan_range, ) for scan_range in scan_ranges ) else: cpus: int = _utils.ensure_cpu_count(use_threads=use_threads) with 
concurrent.futures.ThreadPoolExecutor(max_workers=cpus) as executor: - dfs_iterator = list( + dfs = list( executor.map( _select_object_content, itertools.repeat(args), - scan_ranges, itertools.repeat(client_s3), + scan_ranges, ) ) - return pd.concat([df for df in dfs_iterator]) + return pd.concat(dfs, ignore_index=True) -# TODO: clarify when to use @config (e.g. read_parquet vs read_parquet_table) -def select_query( # Read sql query or S3 select? Here or in a separate file? +def select_query( sql: str, path: str, input_serialization: str, - output_serialization: str, - input_serialization_params: Dict[str, Union[bool, str]] = {}, - output_serialization_params: Dict[str, str] = {}, + input_serialization_params: Dict[str, Union[bool, str]], compression: Optional[str] = None, use_threads: Union[bool, int] = False, boto3_session: Optional[boto3.Session] = None, - params: Optional[Dict[str, Any]] = None, s3_additional_kwargs: Optional[Dict[str, Any]] = None, -) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: +) -> pd.DataFrame: + r"""Filter contents of an Amazon S3 object based on SQL statement. + Note: Scan ranges are only supported for uncompressed CSV/JSON, CSV (without quoted delimiters) + and JSON objects (in LINES mode only). It means scanning cannot be split across threads if the latter + conditions are not met, leading to lower performance. + + Parameters + ---------- + sql: str + SQL statement used to query the object. + path: str + S3 path to the object (e.g. s3://bucket/key). + input_serialization: str, + Format of the S3 object queried. + Valid values: "CSV", "JSON", or "Parquet". Case sensitive. + input_serialization_params: Dict[str, Union[bool, str]] + Dictionary describing the serialization of the S3 object. + compression: Optional[str] + Compression type of the S3 object. + Valid values: None, "gzip", or "bzip2". gzip and bzip2 are only valid for CSV and JSON objects. + use_threads : Union[bool, int] + True to enable concurrent requests, False to disable multiple threads. + If enabled os.cpu_count() is used as the max number of threads. + If integer is provided, specified number is used. + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session is used if none is provided. + s3_additional_kwargs : Optional[Dict[str, Any]] + Forwarded to botocore requests. + Valid values: "SSECustomerAlgorithm", "SSECustomerKey", "ExpectedBucketOwner". + e.g. s3_additional_kwargs={'SSECustomerAlgorithm': 'md5'} + + Returns + ------- + pandas.DataFrame + Pandas DataFrame with results from query. + + Examples + -------- + Reading a gzip compressed JSON document + + >>> import awswrangler as wr + >>> df = wr.s3.select_query( + ... sql='SELECT * FROM s3object[*][*]', + ... path='s3://bucket/key.json.gzip', + ... input_serialization='JSON', + ... input_serialization_params={ + ... 'Type': 'Document', + ... }, + ... compression="gzip", + ... ) + + Reading an entire CSV object using threads + + >>> import awswrangler as wr + >>> df = wr.s3.select_query( + ... sql='SELECT * FROM s3object', + ... path='s3://bucket/key.csv', + ... input_serialization='CSV', + ... input_serialization_params={ + ... 'FileHeaderInfo': 'Use', + ... 'RecordDelimiter': '\r\n' + ... }, + ... use_threads=True, + ... ) + + Reading a single column from Parquet object with pushdown filter + + >>> import awswrangler as wr + >>> df = wr.s3.select_query( + ... sql='SELECT s.\"id\" FROM s3object s where s.\"id\" = 1.0', + ... path='s3://bucket/key.snappy.parquet', + ... input_serialization='Parquet', + ... 
) + """ if path.endswith("/"): raise exceptions.InvalidArgumentValue(" argument should be an S3 key, not a prefix.") if input_serialization not in ["CSV", "JSON", "Parquet"]: raise exceptions.InvalidArgumentValue(" argument must be 'CSV', 'JSON' or 'Parquet'") if compression not in [None, "gzip", "bzip2"]: raise exceptions.InvalidCompression(f"Invalid {compression} compression, please use None, 'gzip' or 'bzip2'.") - else: - if compression and (input_serialization not in ["CSV", "JSON"]): - raise exceptions.InvalidArgumentCombination( - "'gzip' or 'bzip2' are only valid for input 'CSV' or 'JSON' objects." - ) - if output_serialization not in [None, "CSV", "JSON"]: - raise exceptions.InvalidArgumentValue(" argument must be None, 'csv' or 'json'") - if params is None: - params = {} - for key, value in params.items(): - sql = sql.replace(f":{key};", str(value)) + if compression and (input_serialization not in ["CSV", "JSON"]): + raise exceptions.InvalidArgumentCombination( + "'gzip' or 'bzip2' are only valid for input 'CSV' or 'JSON' objects." + ) bucket, key = _utils.parse_path(path) args: Dict[str, Any] = { @@ -117,11 +194,23 @@ def select_query( # Read sql query or S3 select? Here or in a separate file? "CompressionType": compression.upper() if compression else "NONE", }, "OutputSerialization": { - output_serialization: output_serialization_params, + "JSON": {}, }, } if s3_additional_kwargs: args.update(s3_additional_kwargs) _logger.debug("args:\n%s", pprint.pformat(args)) + if any( + [ + compression, + input_serialization_params.get("AllowQuotedRecordDelimiter"), + input_serialization_params.get("Type") == "Document", + ] + ): # Scan range is only supported for uncompressed CSV/JSON, CSV (without quoted delimiters) + # and JSON objects (in LINES mode only) + _logger.debug("Scan ranges are not supported given provided input.") + client_s3: boto3.client = _utils.client(service_name="s3", session=boto3_session) + return _select_object_content(args=args, client_s3=client_s3) + return _paginate_stream(args=args, path=path, use_threads=use_threads, boto3_session=boto3_session) diff --git a/awswrangler/s3/_write_concurrent.py b/awswrangler/s3/_write_concurrent.py index a2fc7e8fc..ab5061227 100644 --- a/awswrangler/s3/_write_concurrent.py +++ b/awswrangler/s3/_write_concurrent.py @@ -18,7 +18,7 @@ def __init__(self, use_threads: bool): self._results: List[str] = [] self._cpus: int = _utils.ensure_cpu_count(use_threads=use_threads) if self._cpus > 1: - self._exec = concurrent.futures.ThreadPoolExecutor(max_workers=self._cpus) + self._exec = concurrent.futures.ThreadPoolExecutor(max_workers=self._cpus) # pylint: disable=R1732 self._futures: List[Any] = [] else: self._exec = None diff --git a/tests/test_s3_select.py b/tests/test_s3_select.py new file mode 100644 index 000000000..46218182c --- /dev/null +++ b/tests/test_s3_select.py @@ -0,0 +1,161 @@ +import logging + +import pandas as pd +import pytest + +import awswrangler as wr + +logging.getLogger("awswrangler").setLevel(logging.DEBUG) + + +@pytest.mark.parametrize("use_threads", [True, False, 2]) +def test_full_table(path, use_threads): + df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"], "c2": [4.0, 5.0, 6.0]}) + + # Parquet + file_path = f"{path}test_parquet_file.snappy.parquet" + wr.s3.to_parquet(df, file_path, compression="snappy") + df2 = wr.s3.select_query( + sql="select * from s3object", + path=file_path, + input_serialization="Parquet", + input_serialization_params={}, + use_threads=use_threads, + ) + assert df.equals(df2) 
+ + # CSV + file_path = f"{path}test_csv_file.csv" + wr.s3.to_csv(df, file_path, index=False) + df3 = wr.s3.select_query( + sql="select * from s3object", + path=file_path, + input_serialization="CSV", + input_serialization_params={"FileHeaderInfo": "Use", "RecordDelimiter": "\n"}, + use_threads=use_threads, + ) + assert len(df.index) == len(df3.index) + assert list(df.columns) == list(df3.columns) + assert df.shape == df3.shape + + # JSON + file_path = f"{path}test_json_file.json" + wr.s3.to_json(df, file_path, orient="records") + df4 = wr.s3.select_query( + sql="select * from s3object[*][*]", + path=file_path, + input_serialization="JSON", + input_serialization_params={"Type": "Document"}, + use_threads=use_threads, + ) + assert df.equals(df4) + + +@pytest.mark.parametrize("use_threads", [True, False, 2]) +def test_push_down(path, use_threads): + df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"], "c2": [4.0, 5.0, 6.0]}) + + file_path = f"{path}test_parquet_file.snappy.parquet" + wr.s3.to_parquet(df, file_path, compression="snappy") + df2 = wr.s3.select_query( + sql='select * from s3object s where s."c0" = 1', + path=file_path, + input_serialization="Parquet", + input_serialization_params={}, + use_threads=use_threads, + ) + assert df2.shape == (1, 3) + assert df2.c0.sum() == 1 + + file_path = f"{path}test_parquet_file.gzip.parquet" + wr.s3.to_parquet(df, file_path, compression="gzip") + df2 = wr.s3.select_query( + sql='select * from s3object s where s."c0" = 99', + path=file_path, + input_serialization="Parquet", + input_serialization_params={}, + use_threads=use_threads, + ) + assert df2.shape == (0, 0) + + file_path = f"{path}test_csv_file.csv" + wr.s3.to_csv(df, file_path, header=False, index=False) + df3 = wr.s3.select_query( + sql='select s."_1" from s3object s limit 2', + path=file_path, + input_serialization="CSV", + input_serialization_params={"FileHeaderInfo": "None", "RecordDelimiter": "\n"}, + use_threads=use_threads, + ) + assert df3.shape == (2, 1) + + file_path = f"{path}test_json_file.json" + wr.s3.to_json(df, file_path, orient="records") + df4 = wr.s3.select_query( + sql="select count(*) from s3object[*][*]", + path=file_path, + input_serialization="JSON", + input_serialization_params={"Type": "Document"}, + use_threads=use_threads, + ) + assert df4.shape == (1, 1) + assert df4._1.sum() == 3 + + +@pytest.mark.parametrize("compression", ["gzip", "bz2"]) +def test_compression(path, compression): + df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"], "c2": [4.0, 5.0, 6.0]}) + + # CSV + file_path = f"{path}test_csv_file.csv" + wr.s3.to_csv(df, file_path, index=False, compression=compression) + df2 = wr.s3.select_query( + sql="select * from s3object", + path=file_path, + input_serialization="CSV", + input_serialization_params={"FileHeaderInfo": "Use", "RecordDelimiter": "\n"}, + compression="bzip2" if compression == "bz2" else compression, + use_threads=False, + ) + assert len(df.index) == len(df2.index) + assert list(df.columns) == list(df2.columns) + assert df.shape == df2.shape + + # JSON + file_path = f"{path}test_json_file.json" + wr.s3.to_json(df, file_path, orient="records", compression=compression) + df3 = wr.s3.select_query( + sql="select * from s3object[*][*]", + path=file_path, + input_serialization="JSON", + input_serialization_params={"Type": "Document"}, + compression="bzip2" if compression == "bz2" else compression, + use_threads=False, + ) + assert df.equals(df3) + + +@pytest.mark.parametrize( + "s3_additional_kwargs", + [None, 
{"ServerSideEncryption": "AES256"}, {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": None}], +) +def test_encryption(path, kms_key_id, s3_additional_kwargs): + if s3_additional_kwargs is not None and "SSEKMSKeyId" in s3_additional_kwargs: + s3_additional_kwargs["SSEKMSKeyId"] = kms_key_id + + df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"], "c2": [4.0, 5.0, 6.0]}) + file_path = f"{path}test_parquet_file.snappy.parquet" + wr.s3.to_parquet( + df, + file_path, + compression="snappy", + s3_additional_kwargs=s3_additional_kwargs, + ) + df2 = wr.s3.select_query( + sql="select * from s3object", + path=file_path, + input_serialization="Parquet", + input_serialization_params={}, + use_threads=False, + ) + assert df.equals(df2) From a1e6b4de7e571db43dfe72b3979e749c1917d594 Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Thu, 27 May 2021 23:45:01 +0100 Subject: [PATCH 3/4] Minor - Fixing delimiter split --- awswrangler/s3/_select.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/awswrangler/s3/_select.py b/awswrangler/s3/_select.py index 9d0bb028f..c18f040e5 100644 --- a/awswrangler/s3/_select.py +++ b/awswrangler/s3/_select.py @@ -31,14 +31,13 @@ def _select_object_content( partial_record: str = "" for event in response["Payload"]: if "Records" in event: - records = partial_record.join(event["Records"]["Payload"].decode(encoding="utf-8", errors="ignore")).split( - "\n" - ) + records = partial_record + event["Records"]["Payload"].decode(encoding="utf-8", errors="ignore") + split_records = records.split("\n") # Record end can either be a partial record or a return char - partial_record = records[-1] + partial_record = split_records[-1] dfs.append( pd.DataFrame( - [json.loads(record) for record in records[:-1]], + [json.loads(record) for record in split_records[:-1]], ) ) if not dfs: From b5cfeca7868fbae149156b0dcbb060e41a7593d7 Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Fri, 28 May 2021 17:57:28 +0100 Subject: [PATCH 4/4] [skip ci] - Minor - Addressing comments --- awswrangler/s3/_select.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/awswrangler/s3/_select.py b/awswrangler/s3/_select.py index c18f040e5..399d75278 100644 --- a/awswrangler/s3/_select.py +++ b/awswrangler/s3/_select.py @@ -5,7 +5,7 @@ import json import logging import pprint -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union import boto3 import pandas as pd @@ -18,26 +18,32 @@ _RANGE_CHUNK_SIZE: int = int(1024 * 1024) +def _gen_scan_range(obj_size: int) -> Iterator[Tuple[int, int]]: + for i in range(0, obj_size, _RANGE_CHUNK_SIZE): + yield (i, i + min(_RANGE_CHUNK_SIZE, obj_size - i)) + + def _select_object_content( args: Dict[str, Any], client_s3: boto3.Session, scan_range: Optional[Tuple[int, int]] = None, ) -> pd.DataFrame: if scan_range: - args.update({"ScanRange": {"Start": scan_range[0], "End": scan_range[1]}}) - response = client_s3.select_object_content(**args) + response = client_s3.select_object_content(**args, ScanRange={"Start": scan_range[0], "End": scan_range[1]}) + else: + response = client_s3.select_object_content(**args) dfs: List[pd.DataFrame] = [] partial_record: str = "" for event in response["Payload"]: if "Records" in event: - records = partial_record + event["Records"]["Payload"].decode(encoding="utf-8", errors="ignore") - split_records = records.split("\n") + records = event["Records"]["Payload"].decode(encoding="utf-8", 
errors="ignore").split("\n") + records[0] = partial_record + records[0] # Record end can either be a partial record or a return char - partial_record = split_records[-1] + partial_record = records.pop() dfs.append( pd.DataFrame( - [json.loads(record) for record in split_records[:-1]], + [json.loads(record) for record in records], ) ) if not dfs: @@ -59,10 +65,6 @@ def _paginate_stream( dfs: List[pd.Dataframe] = [] client_s3: boto3.client = _utils.client(service_name="s3", session=boto3_session) - scan_ranges: List[Tuple[int, int]] = [] - for i in range(0, obj_size, _RANGE_CHUNK_SIZE): - scan_ranges.append((i, i + min(_RANGE_CHUNK_SIZE, obj_size - i))) - if use_threads is False: dfs = list( _select_object_content( @@ -70,7 +72,7 @@ def _paginate_stream( client_s3=client_s3, scan_range=scan_range, ) - for scan_range in scan_ranges + for scan_range in _gen_scan_range(obj_size=obj_size) ) else: cpus: int = _utils.ensure_cpu_count(use_threads=use_threads) @@ -80,7 +82,7 @@ def _paginate_stream( _select_object_content, itertools.repeat(args), itertools.repeat(client_s3), - scan_ranges, + _gen_scan_range(obj_size=obj_size), ) ) return pd.concat(dfs, ignore_index=True)