Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
8fd4966
initial draft
luigift Apr 18, 2020
863ba26
adding Pytorch as a development dependency
igorborgest Apr 19, 2020
2864dc0
Cleaning up initial draft
igorborgest Apr 19, 2020
4fed4c7
Add first test
igorborgest Apr 19, 2020
72c739c
add audio and image dataset
luigift Apr 19, 2020
f72810e
Add label_col to torch.SQLDataset
igorborgest Apr 20, 2020
bf1be07
Updating cartesian product of pytest parameters
igorborgest Apr 20, 2020
1a41d18
Pivoting SQLDataset parser strategy to avoid cast losses.
igorborgest Apr 20, 2020
36c15e4
tested lambda & image datasets
luigift Apr 20, 2020
d4dcfc5
add audio test
luigift Apr 20, 2020
30dc2fa
Add test for torch.AudioS3Dataset
igorborgest Apr 20, 2020
5a9a83f
s3 iterable dataset
luigift Apr 22, 2020
60232f4
add tutorial draft
luigift Apr 23, 2020
215fbd5
add torch extras_requirements to setuptools
luigift Apr 23, 2020
0ad9e4b
handle labels in S3IterableDataset
luigift Apr 23, 2020
5e72ddf
clear bucket in S3Iterable Dataset test
luigift Apr 23, 2020
5b399ac
update setuptools
luigift Apr 23, 2020
2db15b6
update pytorch tutorial
luigift Apr 23, 2020
5e647c6
Update tutorial
igorborgest Apr 23, 2020
b3d9fe2
parallel tests fix
luigift Apr 23, 2020
c091fa8
fix lint
luigift Apr 24, 2020
37b7f1e
update readme
luigift Apr 24, 2020
33d74c4
remove capitalized requirement from docstring
luigift Apr 24, 2020
4b05b36
add torch requirements
luigift Apr 24, 2020
86cdb30
fix init and docs
luigift Apr 26, 2020
b3c8c81
update tutorial
luigift Apr 26, 2020
f6927a4
rollback pytorch==1.5.0, due to torchaudio requirement
luigift Apr 27, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/static-checking.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r requirements-dev.txt
pip install -r requirements-torch.txt
- name: CloudFormation Lint
run: cfn-lint -t testing/cloudformation.yaml
- name: Documentation Lint
run: pydocstyle awswrangler/ --add-ignore=D204
run: pydocstyle awswrangler/ --add-ignore=D204,D403
- name: mypy check
run: mypy awswrangler
- name: Flake8 Lint
Expand Down
3 changes: 2 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ disable=print-statement,
comprehension-escape,
C0330,
C0103,
W1202
W1202,
too-few-public-methods

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ df = wr.db.read_sql_query("SELECT * FROM external_schema.my_table", con=engine)
- [11 - CSV Datasets](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/11%20-%20CSV%20Datasets.ipynb)
- [12 - CSV Crawler](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/12%20-%20CSV%20Crawler.ipynb)
- [13 - Merging Datasets on S3](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/13%20-%20Merging%20Datasets%20on%20S3.ipynb)
- [14 - PyTorch](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/14%20-%20PyTorch.ipynb)
- [**API Reference**](https://aws-data-wrangler.readthedocs.io/en/latest/api.html)
- [Amazon S3](https://aws-data-wrangler.readthedocs.io/en/latest/api.html#amazon-s3)
- [AWS Glue Catalog](https://aws-data-wrangler.readthedocs.io/en/latest/api.html#aws-glue-catalog)
Expand Down
9 changes: 9 additions & 0 deletions awswrangler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,18 @@

"""

# NOTE: `import importlib` alone does not reliably expose the `importlib.util`
# submodule as an attribute; import it explicitly before calling find_spec().
import importlib.util
import logging

from awswrangler import athena, catalog, cloudwatch, db, emr, exceptions, s3  # noqa
from awswrangler.__metadata__ import __description__, __license__, __title__, __version__  # noqa

# Expose awswrangler.torch only when the optional PyTorch stack is installed.
# find_spec() returns None for a missing package, so the conjunction is truthy
# only when torch, torchvision, torchaudio, and Pillow are all importable.
if (
    importlib.util.find_spec("torch")
    and importlib.util.find_spec("torchvision")
    and importlib.util.find_spec("torchaudio")
    and importlib.util.find_spec("PIL")
):  # type: ignore
    from awswrangler import torch  # noqa

logging.getLogger("awswrangler").addHandler(logging.NullHandler())
45 changes: 27 additions & 18 deletions awswrangler/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,29 +155,15 @@ def read_sql_query(
... )

"""
if not isinstance(con, sqlalchemy.engine.Engine): # pragma: no cover
raise exceptions.InvalidConnection(
"Invalid 'con' argument, please pass a "
"SQLAlchemy Engine. Use wr.db.get_engine(), "
"wr.db.get_redshift_temp_engine() or wr.catalog.get_engine()"
)
_validate_engine(con=con)
with con.connect() as _con:
args = _convert_params(sql, params)
cursor = _con.execute(*args)
if chunksize is None:
return _records2df(records=cursor.fetchall(), cols_names=cursor.keys(), index=index_col, dtype=dtype)
return _iterate_cursor(cursor=cursor, chunksize=chunksize, index=index_col, dtype=dtype)


def _iterate_cursor(
cursor, chunksize: int, index: Optional[Union[str, List[str]]], dtype: Optional[Dict[str, pa.DataType]] = None
) -> Iterator[pd.DataFrame]:
while True:
records = cursor.fetchmany(chunksize)
if not records:
break
df: pd.DataFrame = _records2df(records=records, cols_names=cursor.keys(), index=index, dtype=dtype)
yield df
return _iterate_cursor(
cursor=cursor, chunksize=chunksize, cols_names=cursor.keys(), index=index_col, dtype=dtype
)


def _records2df(
Expand Down Expand Up @@ -207,6 +193,20 @@ def _records2df(
return df


def _iterate_cursor(
    cursor: Any,
    chunksize: int,
    cols_names: List[str],
    index: Optional[Union[str, List[str]]],
    dtype: Optional[Dict[str, pa.DataType]] = None,
) -> Iterator[pd.DataFrame]:
    """Yield DataFrames of at most ``chunksize`` rows fetched from ``cursor``.

    Stops as soon as ``fetchmany`` returns an empty batch; ``cols_names``
    and ``index``/``dtype`` are forwarded unchanged to the record parser.
    """
    while True:
        batch = cursor.fetchmany(chunksize)
        if not batch:
            return  # cursor exhausted
        yield _records2df(records=batch, cols_names=cols_names, index=index, dtype=dtype)


def _convert_params(sql: str, params: Optional[Union[List, Tuple, Dict]]) -> List[Any]:
args: List[Any] = [sql]
if params is not None:
Expand Down Expand Up @@ -1087,3 +1087,12 @@ def unload_redshift_to_files(
paths = [x[0].replace(" ", "") for x in _con.execute(sql).fetchall()]
_logger.debug(f"paths: {paths}")
return paths


def _validate_engine(con: sqlalchemy.engine.Engine) -> None:  # pragma: no cover
    """Raise ``InvalidConnection`` unless ``con`` is a SQLAlchemy Engine."""
    if isinstance(con, sqlalchemy.engine.Engine):
        return
    raise exceptions.InvalidConnection(
        "Invalid 'con' argument, please pass a "
        "SQLAlchemy Engine. Use wr.db.get_engine(), "
        "wr.db.get_redshift_temp_engine() or wr.catalog.get_engine()"
    )
11 changes: 7 additions & 4 deletions awswrangler/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def does_object_exist(path: str, boto3_session: Optional[boto3.Session] = None)
raise ex # pragma: no cover


def list_objects(path: str, boto3_session: Optional[boto3.Session] = None) -> List[str]:
def list_objects(path: str, suffix: Optional[str] = None, boto3_session: Optional[boto3.Session] = None) -> List[str]:
"""List Amazon S3 objects from a prefix.

Parameters
Expand All @@ -120,6 +120,8 @@ def list_objects(path: str, boto3_session: Optional[boto3.Session] = None) -> Li
S3 path (e.g. s3://bucket/prefix).
boto3_session : boto3.Session(), optional
Boto3 Session. The default boto3 session will be used if boto3_session receive None.
suffix: str, optional
Suffix for filtering S3 keys.

Returns
-------
Expand Down Expand Up @@ -155,15 +157,16 @@ def list_objects(path: str, boto3_session: Optional[boto3.Session] = None) -> Li
for content in contents:
if (content is not None) and ("Key" in content):
key: str = content["Key"]
paths.append(f"s3://{bucket}/{key}")
if (suffix is None) or key.endswith(suffix):
paths.append(f"s3://{bucket}/{key}")
return paths


def _path2list(path: Union[str, List[str]], boto3_session: Optional[boto3.Session]) -> List[str]:
def _path2list(path: object, boto3_session: boto3.Session, suffix: str = None) -> List[str]:
if isinstance(path, str): # prefix
paths: List[str] = list_objects(path=path, boto3_session=boto3_session)
elif isinstance(path, list):
paths = path
paths = path if suffix is None else [x for x in path if x.endswith(suffix)]
else:
raise exceptions.InvalidArgumentType(f"{type(path)} is not a valid path type. Please, use str or List[str].")
return paths
Expand Down
Loading