diff --git a/awswrangler/distributed/ray/_register.py b/awswrangler/distributed/ray/_register.py index a9a7cdbf0..3120b78f6 100644 --- a/awswrangler/distributed/ray/_register.py +++ b/awswrangler/distributed/ray/_register.py @@ -6,6 +6,7 @@ from awswrangler.distributed.ray import ray_remote from awswrangler.lakeformation._read import _get_work_unit_results from awswrangler.s3._delete import _delete_objects +from awswrangler.s3._describe import _describe_object from awswrangler.s3._read_parquet import _read_parquet, _read_parquet_metadata_file from awswrangler.s3._read_text import _read_text from awswrangler.s3._select import _select_object_content, _select_query @@ -20,6 +21,7 @@ def register_ray() -> None: """Register dispatched Ray and Modin (on Ray) methods.""" for func in [ _get_work_unit_results, + _describe_object, _delete_objects, _read_parquet_metadata_file, _select_query, diff --git a/awswrangler/s3/_describe.py b/awswrangler/s3/_describe.py index 3728d5245..666596b90 100644 --- a/awswrangler/s3/_describe.py +++ b/awswrangler/s3/_describe.py @@ -1,6 +1,5 @@ """Amazon S3 Describe Module (INTERNAL).""" -import concurrent.futures import datetime import itertools import logging @@ -9,15 +8,19 @@ import boto3 from awswrangler import _utils +from awswrangler._distributed import engine +from awswrangler._threading import _get_executor +from awswrangler.distributed.ray import ray_get from awswrangler.s3 import _fs from awswrangler.s3._list import _path2list _logger: logging.Logger = logging.getLogger(__name__) +@engine.dispatch_on_engine def _describe_object( - path: str, boto3_session: boto3.Session, + path: str, s3_additional_kwargs: Optional[Dict[str, Any]], version_id: Optional[str] = None, ) -> Tuple[str, Dict[str, Any]]: @@ -40,18 +43,6 @@ def _describe_object( return path, desc -def _describe_object_concurrent( - path: str, - boto3_primitives: _utils.Boto3PrimitivesType, - s3_additional_kwargs: Optional[Dict[str, Any]], - version_id: Optional[str] = None, -) -> Tuple[str, Dict[str, Any]]: - boto3_session = _utils.boto3_from_primitives(primitives=boto3_primitives) - return _describe_object( - path=path, boto3_session=boto3_session, s3_additional_kwargs=s3_additional_kwargs, version_id=version_id - ) - - def describe_objects( path: Union[str, List[str]], version_id: Optional[Union[str, Dict[str, str]]] = None, @@ -127,41 +118,22 @@ def describe_objects( last_modified_end=last_modified_end, s3_additional_kwargs=s3_additional_kwargs, ) + if len(paths) < 1: return {} resp_list: List[Tuple[str, Dict[str, Any]]] - if len(paths) == 1: - resp_list = [ - _describe_object( - path=paths[0], - version_id=version_id.get(paths[0]) if isinstance(version_id, dict) else version_id, - boto3_session=boto3_session, - s3_additional_kwargs=s3_additional_kwargs, - ) - ] - elif use_threads is False: - resp_list = [ - _describe_object( - path=p, - version_id=version_id.get(p) if isinstance(version_id, dict) else version_id, - boto3_session=boto3_session, - s3_additional_kwargs=s3_additional_kwargs, - ) - for p in paths - ] - else: - cpus: int = _utils.ensure_cpu_count(use_threads=use_threads) - versions = [version_id.get(p) if isinstance(version_id, dict) else version_id for p in paths] - with concurrent.futures.ThreadPoolExecutor(max_workers=cpus) as executor: - resp_list = list( - executor.map( - _describe_object_concurrent, - paths, - versions, - itertools.repeat(_utils.boto3_to_primitives(boto3_session=boto3_session)), - itertools.repeat(s3_additional_kwargs), - ) - ) + + executor = _get_executor(use_threads=use_threads) + resp_list = ray_get( + executor.map( + _describe_object, + boto3_session, + paths, + itertools.repeat(s3_additional_kwargs), + [version_id.get(p) if isinstance(version_id, dict) else version_id for p in paths], + ) + ) + desc_dict: Dict[str, Dict[str, Any]] = dict(resp_list) return desc_dict diff --git a/awswrangler/s3/_read.py b/awswrangler/s3/_read.py index c4e9b3cc1..2147d1fe6 100644 --- a/awswrangler/s3/_read.py +++ b/awswrangler/s3/_read.py @@ -1,8 +1,6 @@ """Amazon S3 Read Module (PRIVATE).""" -import concurrent.futures import logging -from functools import partial from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, cast import numpy as np @@ -11,7 +9,6 @@ from awswrangler import exceptions from awswrangler._arrow import _extract_partitions_from_path -from awswrangler._utils import boto3_to_primitives, ensure_cpu_count from awswrangler.s3._list import _prefix_cleanup _logger: logging.Logger = logging.getLogger(__name__) @@ -110,28 +107,3 @@ def _union(dfs: List[pd.DataFrame], ignore_index: Optional[bool]) -> pd.DataFram for df in dfs: df[col] = pd.Categorical(df[col].values, categories=cat.categories) return pd.concat(objs=dfs, sort=False, copy=False, ignore_index=ignore_index) - - -def _read_dfs_from_multiple_paths( - read_func: Callable[..., pd.DataFrame], - paths: List[str], - version_ids: Optional[Dict[str, str]], - use_threads: Union[bool, int], - kwargs: Dict[str, Any], -) -> List[pd.DataFrame]: - cpus = ensure_cpu_count(use_threads) - if cpus < 2: - return [ - read_func( - path, - version_id=version_ids.get(path) if version_ids else None, - **kwargs, - ) - for path in paths - ] - - with concurrent.futures.ThreadPoolExecutor(max_workers=ensure_cpu_count(use_threads)) as executor: - kwargs["boto3_session"] = boto3_to_primitives(kwargs["boto3_session"]) - partial_read_func = partial(read_func, **kwargs) - versions = [version_ids.get(p) if isinstance(version_ids, dict) else None for p in paths] - return list(df for df in executor.map(partial_read_func, paths, versions))