2 changes: 1 addition & 1 deletion awswrangler/_data_types.py
@@ -372,7 +372,7 @@ def sqlalchemy_types_from_pandas(
df: pd.DataFrame, db_type: str, dtype: Optional[Dict[str, VisitableType]] = None
) -> Dict[str, VisitableType]:
"""Extract the related SQLAlchemy data types from any Pandas DataFrame."""
casts: Dict[str, VisitableType] = dtype if dtype else {}
casts: Dict[str, VisitableType] = dtype if dtype is not None else {}
pa_columns_types: Dict[str, Optional[pa.DataType]] = pyarrow_types_from_pandas(
df=df, index=False, ignore_cols=list(casts.keys())
)
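A quick note on why the switch from `if dtype` to `if dtype is not None` matters: an explicitly passed empty dict is falsy, so the old check silently discarded it and rebound `casts` to a brand-new dict. A minimal standalone sketch of the difference (not the library code itself):

```python
from typing import Dict, Optional

def pick_casts(dtype: Optional[Dict[str, str]] = None) -> Dict[str, str]:
    # Old behaviour: "dtype if dtype else {}" treats an empty dict as "not provided".
    truthy_check = dtype if dtype else {}
    # New behaviour: only an actual None falls back to a fresh dict.
    none_check = dtype if dtype is not None else {}
    return none_check

explicit_empty: Dict[str, str] = {}
result = pick_casts(explicit_empty)
# With "is not None" the caller's own dict object is kept;
# with the truthiness check it would be replaced by a new, unrelated dict.
assert result is explicit_empty
```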
67 changes: 62 additions & 5 deletions awswrangler/db.py
@@ -2,6 +2,7 @@

import json
import logging
import time
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from urllib.parse import quote_plus

@@ -91,7 +92,16 @@ def to_sql(df: pd.DataFrame, con: sqlalchemy.engine.Engine, **pandas_kwargs) ->
)
pandas_kwargs["dtype"] = dtypes
pandas_kwargs["con"] = con
df.to_sql(**pandas_kwargs)
max_attempts: int = 3
for attempt in range(max_attempts):
try:
df.to_sql(**pandas_kwargs)
except sqlalchemy.exc.InternalError as ex: # pragma: no cover
if attempt == (max_attempts - 1):
raise ex
time.sleep(1)
else:
break


def read_sql_query(
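The new retry loop only catches `sqlalchemy.exc.InternalError`, sleeps one second between attempts, and re-raises on the final attempt. A hedged, standalone sketch of the same for/try/except/else pattern — `TransientError` and `flaky_write` below are hypothetical stand-ins, not part of the PR:

```python
import time

class TransientError(Exception):
    """Stand-in for sqlalchemy.exc.InternalError in this sketch."""

def write_with_retries(flaky_write, max_attempts: int = 3) -> None:
    for attempt in range(max_attempts):
        try:
            flaky_write()
        except TransientError:
            if attempt == (max_attempts - 1):
                raise  # out of attempts, surface the error
            time.sleep(1)  # brief pause before retrying
        else:
            break  # success, stop retrying

calls = {"n": 0}
def flaky_write():
    calls["n"] += 1
    if calls["n"] < 3:
        raise TransientError("temporary hiccup")

write_with_retries(flaky_write)  # succeeds on the third attempt
```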
@@ -887,6 +897,9 @@ def unload_redshift(
path: str,
con: sqlalchemy.engine.Engine,
iam_role: str,
region: Optional[str] = None,
max_file_size: Optional[float] = None,
kms_key_id: Optional[str] = None,
categories: List[str] = None,
chunked: Union[bool, int] = False,
keep_files: bool = False,
@@ -937,6 +950,19 @@ def unload_redshift(
wr.db.get_engine(), wr.db.get_redshift_temp_engine() or wr.catalog.get_engine()
iam_role : str
AWS IAM role with the related permissions.
region : str, optional
Specifies the AWS Region where the target Amazon S3 bucket is located.
REGION is required for UNLOAD to an Amazon S3 bucket that isn't in the
same AWS Region as the Amazon Redshift cluster. By default, UNLOAD
assumes that the target Amazon S3 bucket is located in the same AWS
Region as the Amazon Redshift cluster.
max_file_size : float, optional
Specifies the maximum size (MB) of files that UNLOAD creates in Amazon S3.
Specify a decimal value between 5.0 MB and 6200.0 MB. If None, the default
maximum file size is 6200.0 MB.
kms_key_id : str, optional
Specifies the key ID for an AWS Key Management Service (AWS KMS) key to be
used to encrypt data files on Amazon S3.
categories: List[str], optional
List of columns names that should be returned as pandas.Categorical.
Recommended for memory restricted environments.
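Assuming an existing Glue connection named "aws-data-wrangler-redshift" and placeholder bucket, table, role ARN, and KMS key ID values (none of these come from the PR), the new keyword arguments would be passed roughly like this:

```python
import awswrangler as wr

engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift")
df = wr.db.unload_redshift(
    sql="SELECT * FROM public.my_table",    # placeholder query
    path="s3://my-bucket/unload/",          # placeholder staging prefix
    con=engine,
    iam_role="arn:aws:iam::123456789012:role/my-redshift-role",  # placeholder role
    region="us-west-2",       # bucket region, if it differs from the cluster's
    max_file_size=100.0,      # cap each unloaded file at roughly 100 MB
    kms_key_id="1234abcd-12ab-34cd-56ef-1234567890ab",  # placeholder KMS key ID
    keep_files=False,
)
```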
@@ -973,7 +999,15 @@ def unload_redshift(
"""
session: boto3.Session = _utils.ensure_session(session=boto3_session)
paths: List[str] = unload_redshift_to_files(
sql=sql, path=path, con=con, iam_role=iam_role, use_threads=use_threads, boto3_session=session
sql=sql,
path=path,
con=con,
iam_role=iam_role,
region=region,
max_file_size=max_file_size,
kms_key_id=kms_key_id,
use_threads=use_threads,
boto3_session=session,
)
s3.wait_objects_exist(paths=paths, use_threads=False, boto3_session=session)
if chunked is False:
@@ -1032,6 +1066,9 @@ def unload_redshift_to_files(
path: str,
con: sqlalchemy.engine.Engine,
iam_role: str,
region: Optional[str] = None,
max_file_size: Optional[float] = None,
kms_key_id: Optional[str] = None,
use_threads: bool = True,
manifest: bool = False,
partition_cols: Optional[List] = None,
@@ -1056,6 +1093,19 @@ def unload_redshift_to_files(
wr.db.get_engine(), wr.db.get_redshift_temp_engine() or wr.catalog.get_engine()
iam_role : str
AWS IAM role with the related permissions.
region : str, optional
Specifies the AWS Region where the target Amazon S3 bucket is located.
REGION is required for UNLOAD to an Amazon S3 bucket that isn't in the
same AWS Region as the Amazon Redshift cluster. By default, UNLOAD
assumes that the target Amazon S3 bucket is located in the same AWS
Region as the Amazon Redshift cluster.
max_file_size : float, optional
Specifies the maximum size (MB) of files that UNLOAD creates in Amazon S3.
Specify a decimal value between 5.0 MB and 6200.0 MB. If None, the default
maximum file size is 6200.0 MB.
kms_key_id : str, optional
Specifies the key ID for an AWS Key Management Service (AWS KMS) key to be
used to encrypt data files on Amazon S3.
use_threads : bool
True to enable concurrent requests, False to disable multiple threads.
If enabled os.cpu_count() will be used as the max number of threads.
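Unlike `unload_redshift`, this function stops after the UNLOAD itself and returns the S3 object paths rather than reading them back into a DataFrame, and it also accepts `partition_cols`. A brief sketch with placeholder values, mirroring the new test further down:

```python
import awswrangler as wr

engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift")
paths = wr.db.unload_redshift_to_files(
    sql="SELECT id, name FROM public.my_table",   # placeholder query
    path="s3://my-bucket/unload/",                # placeholder target prefix
    con=engine,
    iam_role="arn:aws:iam::123456789012:role/my-redshift-role",  # placeholder role
    region="us-west-2",
    max_file_size=5.0,        # smallest allowed value, forces multiple files
    kms_key_id="1234abcd-12ab-34cd-56ef-1234567890ab",  # placeholder KMS key ID
    partition_cols=["name"],  # Parquet partitioning on the name column
)
wr.s3.wait_objects_exist(paths=paths)  # files can then be read with wr.s3.read_parquet
```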
@@ -1086,19 +1136,26 @@ def unload_redshift_to_files(
session: boto3.Session = _utils.ensure_session(session=boto3_session)
s3.delete_objects(path=path, use_threads=use_threads, boto3_session=session)
with con.connect() as _con:
partition_str: str = f"PARTITION BY ({','.join(partition_cols)})\n" if partition_cols else ""
partition_str: str = f"\nPARTITION BY ({','.join(partition_cols)})" if partition_cols else ""
manifest_str: str = "\nmanifest" if manifest is True else ""
region_str: str = f"\nREGION AS '{region}'" if region is not None else ""
max_file_size_str: str = f"\nMAXFILESIZE AS {max_file_size} MB" if max_file_size is not None else ""
kms_key_id_str: str = f"\nKMS_KEY_ID '{kms_key_id}'" if kms_key_id is not None else ""
sql = (
f"UNLOAD ('{sql}')\n"
f"TO '{path}'\n"
f"IAM_ROLE '{iam_role}'\n"
"ALLOWOVERWRITE\n"
"PARALLEL ON\n"
"ENCRYPTED\n"
"FORMAT PARQUET\n"
"ENCRYPTED"
f"{kms_key_id_str}"
f"{partition_str}"
"FORMAT PARQUET"
f"{region_str}"
f"{max_file_size_str}"
f"{manifest_str};"
)
_logger.debug("sql: \n%s", sql)
_con.execute(sql)
sql = "SELECT pg_last_query_id() AS query_id"
query_id: int = _con.execute(sql).fetchall()[0][0]
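Each optional argument becomes either an empty string or a newline-prefixed clause, so unsupplied options simply vanish from the generated statement. A small standalone sketch of that clause-building pattern with placeholder inputs (ordering of the final statement follows the f-string assembly above; the values are illustrative only):

```python
# Placeholder inputs for illustration only.
region = "us-west-2"
max_file_size = 5.0
kms_key_id = "1234abcd-12ab-34cd-56ef-1234567890ab"
partition_cols = ["name"]
manifest = False

# Same pattern as above: None/empty -> "", otherwise a "\n"-prefixed clause.
partition_str = f"\nPARTITION BY ({','.join(partition_cols)})" if partition_cols else ""
manifest_str = "\nmanifest" if manifest is True else ""
region_str = f"\nREGION AS '{region}'" if region is not None else ""
max_file_size_str = f"\nMAXFILESIZE AS {max_file_size} MB" if max_file_size is not None else ""
kms_key_id_str = f"\nKMS_KEY_ID '{kms_key_id}'" if kms_key_id is not None else ""

# With the values above, the extra clauses appended to the fixed UNLOAD header
# should render roughly as:
#   KMS_KEY_ID '1234abcd-12ab-34cd-56ef-1234567890ab'
#   PARTITION BY (name)
#   REGION AS 'us-west-2'
#   MAXFILESIZE AS 5.0 MB
print(kms_key_id_str + partition_str + region_str + max_file_size_str + manifest_str)
```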
44 changes: 21 additions & 23 deletions awswrangler/s3.py
@@ -1132,7 +1132,7 @@ def _to_parquet_dataset(
schema: pa.Schema = _data_types.pyarrow_schema_from_pandas(
df=df, index=index, ignore_cols=partition_cols, dtype=dtype
)
_logger.debug("schema: %s", schema)
_logger.debug("schema: \n%s", schema)
if not partition_cols:
file_path: str = f"{path}{uuid.uuid4().hex}{compression_ext}.parquet"
_to_parquet_file(
@@ -1688,12 +1688,7 @@ def read_parquet(
data=data, columns=columns, categories=categories, use_threads=use_threads, validate_schema=validate_schema
)
return _read_parquet_chunked(
data=data,
columns=columns,
categories=categories,
chunked=chunked,
use_threads=use_threads,
validate_schema=validate_schema,
data=data, columns=columns, categories=categories, chunked=chunked, use_threads=use_threads
)


@@ -1728,29 +1723,32 @@ def _read_parquet_chunked(
data: pyarrow.parquet.ParquetDataset,
columns: Optional[List[str]] = None,
categories: List[str] = None,
validate_schema: bool = True,
chunked: Union[bool, int] = True,
use_threads: bool = True,
) -> Iterator[pd.DataFrame]:
promote: bool = not validate_schema
next_slice: Optional[pa.Table] = None
next_slice: Optional[pd.DataFrame] = None
for piece in data.pieces:
table: pa.Table = piece.read(
columns=columns, use_threads=use_threads, partitions=data.partitions, use_pandas_metadata=False
df: pd.DataFrame = _table2df(
table=piece.read(
columns=columns, use_threads=use_threads, partitions=data.partitions, use_pandas_metadata=False
),
categories=categories,
use_threads=use_threads,
)
if chunked is True:
yield _table2df(table=table, categories=categories, use_threads=use_threads)
yield df
else:
if next_slice:
table = pa.lib.concat_tables([next_slice, table], promote=promote)
while len(table) >= chunked:
yield _table2df(
table=table.slice(offset=0, length=chunked), categories=categories, use_threads=use_threads
)
table = table.slice(offset=chunked, length=None)
next_slice = table
if next_slice:
yield _table2df(table=next_slice, categories=categories, use_threads=use_threads)
if next_slice is not None:
df = pd.concat(objs=[next_slice, df], ignore_index=True, sort=False)
while len(df.index) >= chunked:
yield df.iloc[:chunked]
df = df.iloc[chunked:]
if df.empty:
next_slice = None
else:
next_slice = df
if next_slice is not None:
yield next_slice


def _table2df(table: pa.Table, categories: List[str] = None, use_threads: bool = True) -> pd.DataFrame:
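The rewritten `_read_parquet_chunked` now carries leftover rows over to the next piece with `pd.concat` and slices fixed-size chunks with `iloc`, instead of concatenating and slicing `pyarrow.Table` objects. A minimal standalone sketch of the same carry-over logic over plain DataFrames (not the library code itself):

```python
from typing import Iterator, List, Optional, Union

import pandas as pd

def chunk_frames(frames: List[pd.DataFrame], chunked: Union[bool, int]) -> Iterator[pd.DataFrame]:
    """Yield frames as-is (chunked=True) or in fixed-size row chunks (chunked=<int>)."""
    next_slice: Optional[pd.DataFrame] = None
    for df in frames:
        if chunked is True:
            yield df
            continue
        if next_slice is not None:
            # Prepend the leftover rows from the previous piece.
            df = pd.concat(objs=[next_slice, df], ignore_index=True, sort=False)
        while len(df.index) >= chunked:
            yield df.iloc[:chunked]
            df = df.iloc[chunked:]
        next_slice = None if df.empty else df
    if next_slice is not None:
        yield next_slice  # final partial chunk

# Example: pieces of 3 and 4 rows with chunk size 5 -> chunks of 5 and 2 rows.
pieces = [pd.DataFrame({"x": range(3)}), pd.DataFrame({"x": range(4)})]
print([len(c.index) for c in chunk_frames(pieces, chunked=5)])  # [5, 2]
```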
9 changes: 9 additions & 0 deletions testing/cloudformation.yaml
@@ -96,6 +96,15 @@ Resources:
PolicyDocument:
Version: 2012-10-17
Statement:
- Effect: Allow
Action:
- kms:Encrypt
- kms:Decrypt
- kms:GenerateDataKey
Resource:
- Fn::GetAtt:
- KmsKey
- Arn
- Effect: Allow
Action:
- s3:Get*
74 changes: 74 additions & 0 deletions testing/test_awswrangler/test_db.py
@@ -76,6 +76,11 @@ def external_schema(cloudformation_outputs, parameters, glue_database):
yield "aws_data_wrangler_external"


@pytest.fixture(scope="module")
def kms_key_id(cloudformation_outputs):
yield cloudformation_outputs["KmsKeyArn"].split("/", 1)[1]


@pytest.mark.parametrize("db_type", ["mysql", "redshift", "postgresql"])
def test_sql(parameters, db_type):
df = get_df()
@@ -386,3 +391,72 @@ def test_redshift_category(bucket, parameters):
for df2 in dfs:
ensure_data_types_category(df2)
wr.s3.delete_objects(path=path)


def test_redshift_unload_extras(bucket, parameters, kms_key_id):
table = "test_redshift_unload_extras"
schema = parameters["redshift"]["schema"]
path = f"s3://{bucket}/{table}/"
wr.s3.delete_objects(path=path)
engine = wr.catalog.get_engine(connection=f"aws-data-wrangler-redshift")
df = pd.DataFrame({"id": [1, 2], "name": ["foo", "boo"]})
wr.db.to_sql(df=df, con=engine, name=table, schema=schema, if_exists="replace", index=False)
paths = wr.db.unload_redshift_to_files(
sql=f"SELECT * FROM {schema}.{table}",
path=path,
con=engine,
iam_role=parameters["redshift"]["role"],
region=wr.s3.get_bucket_region(bucket),
max_file_size=5.0,
kms_key_id=kms_key_id,
partition_cols=["name"],
)
wr.s3.wait_objects_exist(paths=paths)
df = wr.s3.read_parquet(path=path, dataset=True)
assert len(df.index) == 2
assert len(df.columns) == 2
wr.s3.delete_objects(path=path)
df = wr.db.unload_redshift(
sql=f"SELECT * FROM {schema}.{table}",
con=engine,
iam_role=parameters["redshift"]["role"],
path=path,
keep_files=False,
region=wr.s3.get_bucket_region(bucket),
max_file_size=5.0,
kms_key_id=kms_key_id,
)
assert len(df.index) == 2
assert len(df.columns) == 2
wr.s3.delete_objects(path=path)


@pytest.mark.parametrize("db_type", ["mysql", "redshift", "postgresql"])
def test_to_sql_cast(parameters, db_type):
table = "test_to_sql_cast"
schema = parameters[db_type]["schema"]
df = pd.DataFrame(
{
"col": [
"".join([str(i)[-1] for i in range(1_024)]),
"".join([str(i)[-1] for i in range(1_024)]),
"".join([str(i)[-1] for i in range(1_024)]),
]
},
dtype="string",
)
engine = wr.catalog.get_engine(connection=f"aws-data-wrangler-{db_type}")
wr.db.to_sql(
df=df,
con=engine,
name=table,
schema=schema,
if_exists="replace",
index=False,
index_label=None,
chunksize=None,
method=None,
dtype={"col": sqlalchemy.types.VARCHAR(length=1_024)},
)
df2 = wr.db.read_sql_query(sql=f"SELECT * FROM {schema}.{table}", con=engine)
assert df.equals(df2)
4 changes: 2 additions & 2 deletions tox.ini
@@ -8,11 +8,11 @@ deps =
moto
-rrequirements-torch.txt
commands =
pytest -n 16 testing/test_awswrangler
pytest -n 8 testing/test_awswrangler

[testenv:py36]
deps =
{[testenv]deps}
pytest-cov
commands =
pytest --cov=awswrangler -n 16 testing/test_awswrangler
pytest --cov=awswrangler -n 8 testing/test_awswrangler