Merged
Changes from all commits (33 commits)
8fd4966
initial draft
luigift Apr 18, 2020
863ba26
adding Pytorch as a development dependency
igorborgest Apr 19, 2020
2864dc0
Cleaning up initial draft
igorborgest Apr 19, 2020
4fed4c7
Add first test
igorborgest Apr 19, 2020
72c739c
add audio and image dataset
luigift Apr 19, 2020
f72810e
Add label_col to torch.SQLDataset
igorborgest Apr 20, 2020
bf1be07
Updating cartesian product of pytest parameters
igorborgest Apr 20, 2020
1a41d18
Pivoting SQLDataset parser strategy to avoid cast losses.
igorborgest Apr 20, 2020
36c15e4
tested lambda & image datasets
luigift Apr 20, 2020
d4dcfc5
add audio test
luigift Apr 20, 2020
30dc2fa
Add test for torch.AudioS3Dataset
igorborgest Apr 20, 2020
5a9a83f
s3 iterable dataset
luigift Apr 22, 2020
60232f4
add tutorial draft
luigift Apr 23, 2020
215fbd5
add torch extras_requirements to setuptools
luigift Apr 23, 2020
0ad9e4b
handle labels in S3IterableDataset
luigift Apr 23, 2020
5e72ddf
clear bucket in S3Iterable Dataset test
luigift Apr 23, 2020
5b399ac
update setuptools
luigift Apr 23, 2020
2db15b6
update pytorch tutorial
luigift Apr 23, 2020
5e647c6
Update tutorial
igorborgest Apr 23, 2020
b3d9fe2
parallel tests fix
luigift Apr 23, 2020
c091fa8
fix lint
luigift Apr 24, 2020
37b7f1e
update readme
luigift Apr 24, 2020
33d74c4
remove capitalized requirement from docstring
luigift Apr 24, 2020
4b05b36
add torch requirements
luigift Apr 24, 2020
86cdb30
fix init and docs
luigift Apr 26, 2020
b3c8c81
update tutorial
luigift Apr 26, 2020
f6927a4
rollback pytorch==1.5.0, due to torchaudio requirement
luigift Apr 27, 2020
3103e34
Merge branch 'dev' into pytorch
igorborgest Apr 27, 2020
7fd449e
Adapting to validations
igorborgest Apr 27, 2020
fd115d8
Bumping dev dependencies
igorborgest Apr 27, 2020
8fad37c
Bumping PyTorch libs versions
igorborgest Apr 27, 2020
85bfade
Replacing all f-string on logging commands
igorborgest Apr 27, 2020
910e3b6
100% test coverage on wr.torch
igorborgest Apr 27, 2020
9 changes: 3 additions & 6 deletions .github/workflows/static-checking.yml
@@ -24,15 +24,12 @@ jobs:
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r requirements-dev.txt
- name: Setup Environment
run: ./setup-dev-env.sh
- name: CloudFormation Lint
run: cfn-lint -t testing/cloudformation.yaml
- name: Documentation Lint
run: pydocstyle awswrangler/ --add-ignore=D204
run: pydocstyle awswrangler/ --add-ignore=D204,D403
- name: mypy check
run: mypy awswrangler
- name: Flake8 Lint
3 changes: 2 additions & 1 deletion .pylintrc
@@ -141,7 +141,8 @@ disable=print-statement,
comprehension-escape,
C0330,
C0103,
W1202
W1202,
too-few-public-methods

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
1 change: 1 addition & 0 deletions README.md
@@ -84,6 +84,7 @@ df = wr.db.read_sql_query("SELECT * FROM external_schema.my_table", con=engine)
- [11 - CSV Datasets](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/11%20-%20CSV%20Datasets.ipynb)
- [12 - CSV Crawler](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/12%20-%20CSV%20Crawler.ipynb)
- [13 - Merging Datasets on S3](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/13%20-%20Merging%20Datasets%20on%20S3.ipynb)
- [14 - PyTorch](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/14%20-%20PyTorch.ipynb)
- [15 - EMR](https://github.com/awslabs/aws-data-wrangler/blob/dev/tutorials/15%20-%20EMR.ipynb)
- [16 - EMR & Docker](https://github.com/awslabs/aws-data-wrangler/blob/dev/tutorials/16%20-%20EMR%20%26%20Docker.ipynb)
- [**API Reference**](https://aws-data-wrangler.readthedocs.io/en/latest/api.html)
4 changes: 4 additions & 0 deletions awswrangler/__init__.py
@@ -6,9 +6,13 @@
"""

import logging
from importlib.util import find_spec

from awswrangler import athena, catalog, cloudwatch, db, emr, exceptions, s3 # noqa
from awswrangler.__metadata__ import __description__, __license__, __title__, __version__ # noqa
from awswrangler._utils import get_account_id # noqa

if find_spec("torch") and find_spec("torchvision") and find_spec("torchaudio") and find_spec("PIL"):
from awswrangler import torch # noqa

logging.getLogger("awswrangler").addHandler(logging.NullHandler())
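
For reference, a minimal sketch of the optional-import guard added above: `importlib.util.find_spec` reports whether a module can be imported without actually importing it, so `wr.torch` is only exposed when the PyTorch extras are installed. The module name below is a hypothetical placeholder, not part of this PR.

```python
# Minimal sketch of the optional-dependency guard; "some_optional_lib" is hypothetical.
from importlib.util import find_spec

if find_spec("some_optional_lib") is not None:
    import some_optional_lib  # noqa: F401  # only imported when actually installed
    HAS_OPTIONAL = True
else:
    HAS_OPTIONAL = False

print(f"optional extra available: {HAS_OPTIONAL}")
```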
14 changes: 7 additions & 7 deletions awswrangler/_data_types.py
@@ -207,7 +207,7 @@ def pyarrow2sqlalchemy( # pylint: disable=too-many-branches,too-many-return-sta
return sqlalchemy.types.Date
if pa.types.is_binary(dtype):
if db_type == "redshift":
raise exceptions.UnsupportedType(f"Binary columns are not supported for Redshift.") # pragma: no cover
raise exceptions.UnsupportedType("Binary columns are not supported for Redshift.") # pragma: no cover
return sqlalchemy.types.Binary
if pa.types.is_decimal(dtype):
return sqlalchemy.types.Numeric(precision=dtype.precision, scale=dtype.scale)
@@ -257,7 +257,7 @@ def pyarrow_types_from_pandas(
# Filling schema
columns_types: Dict[str, pa.DataType]
columns_types = {n: cols_dtypes[n] for n in sorted_cols}
_logger.debug(f"columns_types: {columns_types}")
_logger.debug("columns_types: %s", columns_types)
return columns_types


@@ -275,7 +275,7 @@ def athena_types_from_pandas(
athena_columns_types[k] = casts[k]
else:
athena_columns_types[k] = pyarrow2athena(dtype=v)
_logger.debug(f"athena_columns_types: {athena_columns_types}")
_logger.debug("athena_columns_types: %s", athena_columns_types)
return athena_columns_types


@@ -315,7 +315,7 @@ def pyarrow_schema_from_pandas(
if (k in df.columns) and (k not in ignore):
columns_types[k] = athena2pyarrow(v)
columns_types = {k: v for k, v in columns_types.items() if v is not None}
_logger.debug(f"columns_types: {columns_types}")
_logger.debug("columns_types: %s", columns_types)
return pa.schema(fields=columns_types)


@@ -324,11 +324,11 @@ def athena_types_from_pyarrow_schema(
) -> Tuple[Dict[str, str], Optional[Dict[str, str]]]:
"""Extract the related Athena data types from any PyArrow Schema considering possible partitions."""
columns_types: Dict[str, str] = {str(f.name): pyarrow2athena(dtype=f.type) for f in schema}
_logger.debug(f"columns_types: {columns_types}")
_logger.debug("columns_types: %s", columns_types)
partitions_types: Optional[Dict[str, str]] = None
if partitions is not None:
partitions_types = {p.name: pyarrow2athena(p.dictionary.type) for p in partitions}
_logger.debug(f"partitions_types: {partitions_types}")
_logger.debug("partitions_types: %s", partitions_types)
return columns_types, partitions_types


@@ -382,5 +382,5 @@ def sqlalchemy_types_from_pandas(
sqlalchemy_columns_types[k] = casts[k]
else:
sqlalchemy_columns_types[k] = pyarrow2sqlalchemy(dtype=v, db_type=db_type)
_logger.debug(f"sqlalchemy_columns_types: {sqlalchemy_columns_types}")
_logger.debug("sqlalchemy_columns_types: %s", sqlalchemy_columns_types)
return sqlalchemy_columns_types
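
A minimal illustration of the logging change applied throughout this file (and the rest of the PR): passing values as arguments instead of pre-formatting with an f-string defers string interpolation until the record is actually emitted, so disabled debug logging costs almost nothing.

```python
import logging

_logger = logging.getLogger(__name__)
columns_types = {"id": "int64", "name": "string"}

# Eager: the f-string is always built, even when DEBUG logging is disabled.
_logger.debug(f"columns_types: {columns_types}")

# Lazy: %-style arguments are interpolated only if the record is emitted.
_logger.debug("columns_types: %s", columns_types)
```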
26 changes: 13 additions & 13 deletions awswrangler/athena.py
@@ -176,8 +176,8 @@ def wait_query(query_execution_id: str, boto3_session: Optional[boto3.Session] =
time.sleep(_QUERY_WAIT_POLLING_DELAY)
response = client_athena.get_query_execution(QueryExecutionId=query_execution_id)
state = response["QueryExecution"]["Status"]["State"]
_logger.debug(f"state: {state}")
_logger.debug(f"StateChangeReason: {response['QueryExecution']['Status'].get('StateChangeReason')}")
_logger.debug("state: %s", state)
_logger.debug("StateChangeReason: %s", response["QueryExecution"]["Status"].get("StateChangeReason"))
if state == "FAILED":
raise exceptions.QueryFailed(response["QueryExecution"]["Status"].get("StateChangeReason"))
if state == "CANCELLED":
@@ -265,7 +265,7 @@ def _get_query_metadata(
cols_types: Dict[str, str] = get_query_columns_types(
query_execution_id=query_execution_id, boto3_session=boto3_session
)
_logger.debug(f"cols_types: {cols_types}")
_logger.debug("cols_types: %s", cols_types)
dtype: Dict[str, str] = {}
parse_timestamps: List[str] = []
parse_dates: List[str] = []
@@ -298,11 +298,11 @@ def _get_query_metadata(
converters[col_name] = lambda x: Decimal(str(x)) if str(x) not in ("", "none", " ", "<NA>") else None
else:
dtype[col_name] = pandas_type
_logger.debug(f"dtype: {dtype}")
_logger.debug(f"parse_timestamps: {parse_timestamps}")
_logger.debug(f"parse_dates: {parse_dates}")
_logger.debug(f"converters: {converters}")
_logger.debug(f"binaries: {binaries}")
_logger.debug("dtype: %s", dtype)
_logger.debug("parse_timestamps: %s", parse_timestamps)
_logger.debug("parse_dates: %s", parse_dates)
_logger.debug("converters: %s", converters)
_logger.debug("binaries: %s", binaries)
return dtype, parse_timestamps, parse_dates, converters, binaries


@@ -446,7 +446,7 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals
f") AS\n"
f"{sql}"
)
_logger.debug(f"sql: {sql}")
_logger.debug("sql: %s", sql)
query_id: str = start_query_execution(
sql=sql,
database=database,
@@ -456,7 +456,7 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals
kms_key=kms_key,
boto3_session=session,
)
_logger.debug(f"query_id: {query_id}")
_logger.debug("query_id: %s", query_id)
query_response: Dict[str, Any] = wait_query(query_execution_id=query_id, boto3_session=session)
if query_response["QueryExecution"]["Status"]["State"] in ["FAILED", "CANCELLED"]: # pragma: no cover
reason: str = query_response["QueryExecution"]["Status"]["StateChangeReason"]
@@ -468,7 +468,7 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals
manifest_path: str = f"{_s3_output}/tables/{query_id}-manifest.csv"
paths: List[str] = _extract_ctas_manifest_paths(path=manifest_path, boto3_session=session)
chunked: Union[bool, int] = False if chunksize is None else chunksize
_logger.debug(f"chunked: {chunked}")
_logger.debug("chunked: %s", chunked)
if not paths:
if chunked is False:
dfs = pd.DataFrame()
@@ -485,9 +485,9 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals
)
path = f"{_s3_output}/{query_id}.csv"
s3.wait_objects_exist(paths=[path], use_threads=False, boto3_session=session)
_logger.debug(f"Start CSV reading from {path}")
_logger.debug("Start CSV reading from %s", path)
_chunksize: Optional[int] = chunksize if isinstance(chunksize, int) else None
_logger.debug(f"_chunksize: {_chunksize}")
_logger.debug("_chunksize: %s", _chunksize)
ret = s3.read_csv(
path=[path],
dtype=dtype,
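
For context, a hedged usage sketch of the chunked path exercised above: when `chunksize` is set, `read_sql_query` yields DataFrames instead of returning a single frame. The database, table, and chunk size below are placeholders and assume AWS credentials plus Athena access are already configured.

```python
import awswrangler as wr

# Placeholder query/database; assumes AWS credentials and an Athena-queryable table.
for chunk in wr.athena.read_sql_query(
    sql="SELECT * FROM my_table",
    database="my_database",
    chunksize=10_000,  # yields an iterator of DataFrames rather than one frame
):
    print(len(chunk.index))
```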
10 changes: 5 additions & 5 deletions awswrangler/catalog.py
@@ -766,7 +766,7 @@ def drop_duplicated_columns(df: pd.DataFrame) -> pd.DataFrame:
duplicated_cols = df.columns.duplicated()
duplicated_cols_names: List[str] = list(df.columns[duplicated_cols])
if len(duplicated_cols_names) > 0:
_logger.warning(f"Dropping repeated columns: {duplicated_cols_names}")
_logger.warning("Dropping repeated columns: %s", duplicated_cols_names)
return df.loc[:, ~duplicated_cols]


@@ -967,11 +967,11 @@ def _create_table(
if name in columns_comments:
par["Comment"] = columns_comments[name]
session: boto3.Session = _utils.ensure_session(session=boto3_session)

if mode == "overwrite":
exist: bool = does_table_exist(database=database, table=table, boto3_session=session)
if (mode == "overwrite") or (exist is False):
delete_table_if_exists(database=database, table=table, boto3_session=session)
client_glue: boto3.client = _utils.client(service_name="glue", session=session)
client_glue.create_table(DatabaseName=database, TableInput=table_input)
client_glue: boto3.client = _utils.client(service_name="glue", session=session)
client_glue.create_table(DatabaseName=database, TableInput=table_input)


def _csv_table_definition(
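
A hedged reading of the control-flow change in `_create_table` above, as an isolated sketch: after this hunk the Glue table appears to be dropped and recreated only when the write mode is "overwrite" or the table does not exist yet. The helper below is hypothetical and only illustrates that decision, not the library's actual code.

```python
def should_recreate_table(mode: str, table_exists: bool) -> bool:
    """Hypothetical helper: decide whether to drop and recreate the Glue table."""
    # Recreate on explicit overwrite, or create when the table is missing;
    # appending to an existing table leaves the catalog definition untouched.
    return mode == "overwrite" or not table_exists


assert should_recreate_table("overwrite", True) is True
assert should_recreate_table("append", True) is False
assert should_recreate_table("append", False) is True
```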
8 changes: 4 additions & 4 deletions awswrangler/cloudwatch.py
@@ -56,11 +56,11 @@ def start_query(
... )

"""
_logger.debug(f"log_group_names: {log_group_names}")
_logger.debug("log_group_names: %s", log_group_names)
start_timestamp: int = int(1000 * start_time.timestamp())
end_timestamp: int = int(1000 * end_time.timestamp())
_logger.debug(f"start_timestamp: {start_timestamp}")
_logger.debug(f"end_timestamp: {end_timestamp}")
_logger.debug("start_timestamp: %s", start_timestamp)
_logger.debug("end_timestamp: %s", end_timestamp)
args: Dict[str, Any] = {
"logGroupNames": log_group_names,
"startTime": start_timestamp,
Expand Down Expand Up @@ -109,7 +109,7 @@ def wait_query(query_id: str, boto3_session: Optional[boto3.Session] = None) ->
time.sleep(_QUERY_WAIT_POLLING_DELAY)
response = client_logs.get_query_results(queryId=query_id)
status = response["status"]
_logger.debug(f"status: {status}")
_logger.debug("status: %s", status)
if status == "Failed": # pragma: no cover
raise exceptions.QueryFailed(f"query ID: {query_id}")
if status == "Cancelled":
67 changes: 38 additions & 29 deletions awswrangler/db.py
@@ -155,29 +155,15 @@ def read_sql_query(
... )

"""
if not isinstance(con, sqlalchemy.engine.Engine): # pragma: no cover
raise exceptions.InvalidConnection(
"Invalid 'con' argument, please pass a "
"SQLAlchemy Engine. Use wr.db.get_engine(), "
"wr.db.get_redshift_temp_engine() or wr.catalog.get_engine()"
)
_validate_engine(con=con)
with con.connect() as _con:
args = _convert_params(sql, params)
cursor = _con.execute(*args)
if chunksize is None:
return _records2df(records=cursor.fetchall(), cols_names=cursor.keys(), index=index_col, dtype=dtype)
return _iterate_cursor(cursor=cursor, chunksize=chunksize, index=index_col, dtype=dtype)


def _iterate_cursor(
cursor, chunksize: int, index: Optional[Union[str, List[str]]], dtype: Optional[Dict[str, pa.DataType]] = None
) -> Iterator[pd.DataFrame]:
while True:
records = cursor.fetchmany(chunksize)
if not records:
break
df: pd.DataFrame = _records2df(records=records, cols_names=cursor.keys(), index=index, dtype=dtype)
yield df
return _iterate_cursor(
cursor=cursor, chunksize=chunksize, cols_names=cursor.keys(), index=index_col, dtype=dtype
)


def _records2df(
@@ -207,6 +193,20 @@ def _records2df(
return df


def _iterate_cursor(
cursor: Any,
chunksize: int,
cols_names: List[str],
index: Optional[Union[str, List[str]]],
dtype: Optional[Dict[str, pa.DataType]] = None,
) -> Iterator[pd.DataFrame]:
while True:
records = cursor.fetchmany(chunksize)
if not records:
break
yield _records2df(records=records, cols_names=cols_names, index=index, dtype=dtype)


def _convert_params(sql: str, params: Optional[Union[List, Tuple, Dict]]) -> List[Any]:
args: List[Any] = [sql]
if params is not None:
@@ -646,7 +646,7 @@ def copy_files_to_redshift( # pylint: disable=too-many-locals,too-many-argument
athena_types, _ = s3.read_parquet_metadata(
path=paths, dataset=False, use_threads=use_threads, boto3_session=session
)
_logger.debug(f"athena_types: {athena_types}")
_logger.debug("athena_types: %s", athena_types)
redshift_types: Dict[str, str] = {}
for col_name, col_type in athena_types.items():
length: int = _varchar_lengths[col_name] if col_name in _varchar_lengths else varchar_lengths_default
@@ -680,7 +680,7 @@ def copy_files_to_redshift( # pylint: disable=too-many-locals,too-many-argument
def _rs_upsert(con: Any, table: str, temp_table: str, schema: str, primary_keys: Optional[List[str]] = None) -> None:
if not primary_keys:
primary_keys = _rs_get_primary_keys(con=con, schema=schema, table=table)
_logger.debug(f"primary_keys: {primary_keys}")
_logger.debug("primary_keys: %s", primary_keys)
if not primary_keys: # pragma: no cover
raise exceptions.InvalidRedshiftPrimaryKeys()
equals_clause: str = f"{table}.%s = {temp_table}.%s"
@@ -735,7 +735,7 @@ def _rs_create_table(
f"{distkey_str}"
f"{sortkey_str}"
)
_logger.debug(f"Create table query:\n{sql}")
_logger.debug("Create table query:\n%s", sql)
con.execute(sql)
return table, schema

@@ -746,7 +746,7 @@ def _rs_validate_parameters(
if diststyle not in _RS_DISTSTYLES:
raise exceptions.InvalidRedshiftDiststyle(f"diststyle must be in {_RS_DISTSTYLES}")
cols = list(redshift_types.keys())
_logger.debug(f"Redshift columns: {cols}")
_logger.debug("Redshift columns: %s", cols)
if (diststyle == "KEY") and (not distkey):
raise exceptions.InvalidRedshiftDistkey("You must pass a distkey if you intend to use KEY diststyle")
if distkey and distkey not in cols:
@@ -775,13 +775,13 @@ def _rs_copy(
sql: str = (
f"COPY {table_name} FROM '{manifest_path}'\n" f"IAM_ROLE '{iam_role}'\n" "MANIFEST\n" "FORMAT AS PARQUET"
)
_logger.debug(f"copy query:\n{sql}")
_logger.debug("copy query:\n%s", sql)
con.execute(sql)
sql = "SELECT pg_last_copy_id() AS query_id"
query_id: int = con.execute(sql).fetchall()[0][0]
sql = f"SELECT COUNT(DISTINCT filename) as num_files_loaded " f"FROM STL_LOAD_COMMITS WHERE query = {query_id}"
num_files_loaded: int = con.execute(sql).fetchall()[0][0]
_logger.debug(f"{num_files_loaded} files counted. {num_files} expected.")
_logger.debug("%s files counted. %s expected.", num_files_loaded, num_files)
if num_files_loaded != num_files: # pragma: no cover
raise exceptions.RedshiftLoadError(
f"Redshift load rollbacked. {num_files_loaded} files counted. {num_files} expected."
@@ -846,17 +846,17 @@ def write_redshift_copy_manifest(
payload: str = json.dumps(manifest)
bucket: str
bucket, key = _utils.parse_path(manifest_path)
_logger.debug(f"payload: {payload}")
_logger.debug("payload: %s", payload)
client_s3: boto3.client = _utils.client(service_name="s3", session=session)
_logger.debug(f"bucket: {bucket}")
_logger.debug(f"key: {key}")
_logger.debug("bucket: %s", bucket)
_logger.debug("key: %s", key)
client_s3.put_object(Body=payload, Bucket=bucket, Key=key)
return manifest


def _rs_drop_table(con: Any, schema: str, table: str) -> None:
sql = f"DROP TABLE IF EXISTS {schema}.{table}"
_logger.debug(f"Drop table query:\n{sql}")
_logger.debug("Drop table query:\n%s", sql)
con.execute(sql)


@@ -1104,5 +1104,14 @@ def unload_redshift_to_files(
query_id: int = _con.execute(sql).fetchall()[0][0]
sql = f"SELECT path FROM STL_UNLOAD_LOG WHERE query={query_id};"
paths = [x[0].replace(" ", "") for x in _con.execute(sql).fetchall()]
_logger.debug(f"paths: {paths}")
_logger.debug("paths: %s", paths)
return paths


def _validate_engine(con: sqlalchemy.engine.Engine) -> None: # pragma: no cover
if not isinstance(con, sqlalchemy.engine.Engine):
raise exceptions.InvalidConnection(
"Invalid 'con' argument, please pass a "
"SQLAlchemy Engine. Use wr.db.get_engine(), "
"wr.db.get_redshift_temp_engine() or wr.catalog.get_engine()"
)
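
For reference, a standalone sketch of the chunked-read generator kept by the db.py refactor above, using sqlite3 in place of a SQLAlchemy cursor; the table, column names, and chunk size are arbitrary.

```python
import sqlite3
from typing import Iterator, List, Tuple

import pandas as pd


def iterate_cursor(cursor: sqlite3.Cursor, chunksize: int, cols_names: List[str]) -> Iterator[pd.DataFrame]:
    # Pull `chunksize` rows at a time and yield each batch as a DataFrame,
    # mirroring the behaviour of the private _iterate_cursor helper above.
    while True:
        records: List[Tuple] = cursor.fetchmany(chunksize)
        if not records:
            break
        yield pd.DataFrame.from_records(records, columns=cols_names)


con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE t (id INTEGER, name TEXT)")
con.executemany("INSERT INTO t VALUES (?, ?)", [(i, f"row{i}") for i in range(10)])
cur = con.execute("SELECT id, name FROM t")
for df in iterate_cursor(cur, chunksize=4, cols_names=["id", "name"]):
    print(len(df.index))  # 4, 4, 2
```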