diff --git a/awswrangler/distributed/ray/_register.py b/awswrangler/distributed/ray/_register.py index 066e9bf8a..1ded3fde8 100644 --- a/awswrangler/distributed/ray/_register.py +++ b/awswrangler/distributed/ray/_register.py @@ -6,7 +6,6 @@ from awswrangler.distributed.ray import ray_remote from awswrangler.lakeformation._read import _get_work_unit_results from awswrangler.s3._delete import _delete_objects -from awswrangler.s3._describe import _describe_object from awswrangler.s3._read_parquet import _read_parquet, _read_parquet_metadata_file from awswrangler.s3._read_text import _read_text from awswrangler.s3._select import _select_object_content, _select_query @@ -21,7 +20,6 @@ def register_ray() -> None: """Register dispatched Ray and Modin (on Ray) methods.""" for func in [ _get_work_unit_results, - _describe_object, _delete_objects, _read_parquet_metadata_file, _select_query, diff --git a/awswrangler/s3/_describe.py b/awswrangler/s3/_describe.py index 666596b90..3728d5245 100644 --- a/awswrangler/s3/_describe.py +++ b/awswrangler/s3/_describe.py @@ -1,5 +1,6 @@ """Amazon S3 Describe Module (INTERNAL).""" +import concurrent.futures import datetime import itertools import logging @@ -8,19 +9,15 @@ import boto3 from awswrangler import _utils -from awswrangler._distributed import engine -from awswrangler._threading import _get_executor -from awswrangler.distributed.ray import ray_get from awswrangler.s3 import _fs from awswrangler.s3._list import _path2list _logger: logging.Logger = logging.getLogger(__name__) -@engine.dispatch_on_engine def _describe_object( - boto3_session: boto3.Session, path: str, + boto3_session: boto3.Session, s3_additional_kwargs: Optional[Dict[str, Any]], version_id: Optional[str] = None, ) -> Tuple[str, Dict[str, Any]]: @@ -43,6 +40,18 @@ def _describe_object( return path, desc +def _describe_object_concurrent( + path: str, + boto3_primitives: _utils.Boto3PrimitivesType, + s3_additional_kwargs: Optional[Dict[str, Any]], + version_id: Optional[str] = None, +) -> Tuple[str, Dict[str, Any]]: + boto3_session = _utils.boto3_from_primitives(primitives=boto3_primitives) + return _describe_object( + path=path, boto3_session=boto3_session, s3_additional_kwargs=s3_additional_kwargs, version_id=version_id + ) + + def describe_objects( path: Union[str, List[str]], version_id: Optional[Union[str, Dict[str, str]]] = None, @@ -118,22 +127,41 @@ def describe_objects( last_modified_end=last_modified_end, s3_additional_kwargs=s3_additional_kwargs, ) - if len(paths) < 1: return {} resp_list: List[Tuple[str, Dict[str, Any]]] - - executor = _get_executor(use_threads=use_threads) - resp_list = ray_get( - executor.map( - _describe_object, - boto3_session, - paths, - itertools.repeat(s3_additional_kwargs), - [version_id.get(p) if isinstance(version_id, dict) else version_id for p in paths], - ) - ) - + if len(paths) == 1: + resp_list = [ + _describe_object( + path=paths[0], + version_id=version_id.get(paths[0]) if isinstance(version_id, dict) else version_id, + boto3_session=boto3_session, + s3_additional_kwargs=s3_additional_kwargs, + ) + ] + elif use_threads is False: + resp_list = [ + _describe_object( + path=p, + version_id=version_id.get(p) if isinstance(version_id, dict) else version_id, + boto3_session=boto3_session, + s3_additional_kwargs=s3_additional_kwargs, + ) + for p in paths + ] + else: + cpus: int = _utils.ensure_cpu_count(use_threads=use_threads) + versions = [version_id.get(p) if isinstance(version_id, dict) else version_id for p in paths] + with concurrent.futures.ThreadPoolExecutor(max_workers=cpus) as executor: + resp_list = list( + executor.map( + _describe_object_concurrent, + paths, + versions, + itertools.repeat(_utils.boto3_to_primitives(boto3_session=boto3_session)), + itertools.repeat(s3_additional_kwargs), + ) + ) desc_dict: Dict[str, Dict[str, Any]] = dict(resp_list) return desc_dict diff --git a/awswrangler/s3/_read.py b/awswrangler/s3/_read.py index 2147d1fe6..c4e9b3cc1 100644 --- a/awswrangler/s3/_read.py +++ b/awswrangler/s3/_read.py @@ -1,6 +1,8 @@ """Amazon S3 Read Module (PRIVATE).""" +import concurrent.futures import logging +from functools import partial from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, cast import numpy as np @@ -9,6 +11,7 @@ from awswrangler import exceptions from awswrangler._arrow import _extract_partitions_from_path +from awswrangler._utils import boto3_to_primitives, ensure_cpu_count from awswrangler.s3._list import _prefix_cleanup _logger: logging.Logger = logging.getLogger(__name__) @@ -107,3 +110,28 @@ def _union(dfs: List[pd.DataFrame], ignore_index: Optional[bool]) -> pd.DataFram for df in dfs: df[col] = pd.Categorical(df[col].values, categories=cat.categories) return pd.concat(objs=dfs, sort=False, copy=False, ignore_index=ignore_index) + + +def _read_dfs_from_multiple_paths( + read_func: Callable[..., pd.DataFrame], + paths: List[str], + version_ids: Optional[Dict[str, str]], + use_threads: Union[bool, int], + kwargs: Dict[str, Any], +) -> List[pd.DataFrame]: + cpus = ensure_cpu_count(use_threads) + if cpus < 2: + return [ + read_func( + path, + version_id=version_ids.get(path) if version_ids else None, + **kwargs, + ) + for path in paths + ] + + with concurrent.futures.ThreadPoolExecutor(max_workers=ensure_cpu_count(use_threads)) as executor: + kwargs["boto3_session"] = boto3_to_primitives(kwargs["boto3_session"]) + partial_read_func = partial(read_func, **kwargs) + versions = [version_ids.get(p) if isinstance(version_ids, dict) else None for p in paths] + return list(df for df in executor.map(partial_read_func, paths, versions))