46 changes: 44 additions & 2 deletions awswrangler/pandas.py
@@ -523,12 +523,12 @@ def read_sql_athena(self,
:param workgroup: The name of the workgroup in which the query is being started. (By default uses the Session() workgroup)
:param encryption: None|'SSE_S3'|'SSE_KMS'|'CSE_KMS'
:param kms_key: For SSE-KMS and CSE-KMS, this is the KMS key ARN or ID.
:param ctas_approach: Wraps the query with a CTAS (Session's deafult is False)
:param ctas_approach: Wraps the query with a CTAS (Session's default is False)
:param procs_cpu_bound: Number of cores used for CPU bound tasks
:param max_result_size: Max number of bytes on each request to S3 (VALID ONLY FOR ctas_approach=False)
:return: Pandas Dataframe or Iterator of Pandas Dataframes if max_result_size was passed
"""
ctas_approach = ctas_approach if ctas_approach is not None else self._session.ctas_approach if self._session.ctas_approach is not None else False
ctas_approach = ctas_approach if ctas_approach is not None else self._session.athena_ctas_approach if self._session.athena_ctas_approach is not None else False
if ctas_approach is True and max_result_size is not None:
raise InvalidParameters("ctas_approach can't use max_result_size!")
if s3_output is None:
@@ -1376,3 +1376,45 @@ def read_table(self,
"""
path: str = self._session.glue.get_table_location(database=database, table=table)
return self.read_parquet(path=path, columns=columns, filters=filters, procs_cpu_bound=procs_cpu_bound)

def read_sql_redshift(self,
sql: str,
iam_role: str,
connection: Any,
temp_s3_path: Optional[str] = None,
procs_cpu_bound: Optional[int] = None) -> pd.DataFrame:
"""
Convert a query result into a Pandas Dataframe.

:param sql: SQL Query
:param iam_role: AWS IAM role with the related permissions
:param connection: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
:param temp_s3_path: AWS S3 path to write temporary data (e.g. s3://...) (Defaults to the Session's redshift_temp_s3_path, or the Athena results bucket if that is not set)
:param procs_cpu_bound: Number of cores used for CPU bound tasks
:return: Pandas Dataframe
"""
guid: str = pa.compat.guid()
name: str = f"temp_redshift_{guid}"
if temp_s3_path is None:
if self._session.redshift_temp_s3_path is not None:
temp_s3_path = self._session.redshift_temp_s3_path
else:
temp_s3_path = self._session.athena.create_athena_bucket()
temp_s3_path = temp_s3_path[:-1] if temp_s3_path[-1] == "/" else temp_s3_path
temp_s3_path = f"{temp_s3_path}/{name}"
logger.debug(f"temp_s3_path: {temp_s3_path}")
paths: Optional[List[str]] = None
try:
paths = self._session.redshift.to_parquet(sql=sql,
path=temp_s3_path,
iam_role=iam_role,
connection=connection)
logger.debug(f"paths: {paths}")
df: pd.DataFrame = self.read_parquet(path=paths, procs_cpu_bound=procs_cpu_bound) # type: ignore
self._session.s3.delete_listed_objects(objects_paths=paths)
return df
except Exception as e:
if paths is not None:
self._session.s3.delete_listed_objects(objects_paths=paths)
else:
self._session.s3.delete_objects(path=temp_s3_path)
raise e
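For reference, a minimal usage sketch of the new Pandas.read_sql_redshift entry point, following the same pattern the tests below exercise. The cluster endpoint, credentials, bucket and IAM role ARN are placeholders, and the top-level Session/Redshift imports are assumed to match the package's public exports.

from awswrangler import Session, Redshift

session = Session()  # credentials resolved as for any Boto3 session

# PEP 249 connection to the cluster (placeholder endpoint and credentials)
con = Redshift.generate_connection(
    database="test",
    host="my-cluster.xxxxxxxxxxxx.us-east-1.redshift.amazonaws.com",
    port=5439,
    user="test",
    password="my-password",
)

# UNLOADs the query to temporary Parquet files on S3, reads them back into
# a DataFrame and deletes the temporary objects when it is done.
df = session.pandas.read_sql_redshift(
    sql="SELECT * FROM public.test",
    iam_role="arn:aws:iam::123456789012:role/my-redshift-unload-role",
    connection=con,
    temp_s3_path="s3://my-bucket/temp_redshift/",
)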
8 changes: 4 additions & 4 deletions awswrangler/redshift.py
@@ -351,20 +351,20 @@ def _get_redshift_schema(dataframe, dataframe_type, preserve_index=False, cast_c
def to_parquet(sql: str,
path: str,
iam_role: str,
redshift_conn: Any,
connection: Any,
partition_cols: Optional[List] = None) -> List[str]:
"""
Write a query result as parquet files on S3

:param sql: SQL Query
:param path: AWS S3 path to write the data (e.g. s3://...)
:param iam_role: AWS IAM role with the related permissions
:param redshift_conn: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
:param connection: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
:param partition_cols: Specifies the partition keys for the unload operation.
"""
sql = sql.replace("'", "''").replace(";", "")  # double single quotes so the query can sit inside UNLOAD's quoted string; strip semicolons
path = path if path[-1] == "/" else path + "/"
cursor: Any = redshift_conn.cursor()
cursor: Any = connection.cursor()
partition_str: str = ""
if partition_cols is not None:
partition_str = f"PARTITION BY ({','.join([x for x in partition_cols])})\n"
@@ -389,6 +389,6 @@ def to_parquet(sql: str,
cursor.execute(query)
paths: List[str] = [row[0].replace(" ", "") for row in cursor.fetchall()]
logger.debug(f"paths: {paths}")
redshift_conn.commit()
connection.commit()
cursor.close()
return paths
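For the renamed parameter, a short sketch of calling Redshift.to_parquet directly with connection= (formerly redshift_conn=). It reuses a PEP 249 connection such as the con produced by Redshift.generate_connection() in the sketch above; the bucket and role ARN are placeholders.

# Unload the query result to partitioned Parquet files under the S3 prefix
# and return the list of object paths reported by Redshift.
paths = Redshift.to_parquet(
    sql="SELECT * FROM public.test",
    path="s3://my-bucket/unload/",
    iam_role="arn:aws:iam::123456789012:role/my-redshift-unload-role",
    connection=con,
    partition_cols=["name"],
)
print(paths)  # one s3:// path per unloaded file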
24 changes: 20 additions & 4 deletions awswrangler/session.py
@@ -49,7 +49,8 @@ def __init__(self,
athena_encryption: Optional[str] = "SSE_S3",
athena_kms_key: Optional[str] = None,
athena_database: str = "default",
athena_ctas_approach: bool = False):
athena_ctas_approach: bool = False,
redshift_temp_s3_path: Optional[str] = None):
"""
Most parameters inherit from Boto3 or Pyspark.
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html
@@ -73,6 +74,7 @@ def __init__(self,
:param athena_s3_output: AWS S3 path
:param athena_encryption: None|'SSE_S3'|'SSE_KMS'|'CSE_KMS'
:param athena_kms_key: For SSE-KMS and CSE-KMS, this is the KMS key ARN or ID.
:param redshift_temp_s3_path: AWS S3 path to write temporary data (e.g. s3://...)
"""
self._profile_name: Optional[str] = (boto3_session.profile_name if boto3_session else profile_name)
self._aws_access_key_id: Optional[str] = (boto3_session.get_credentials().access_key
@@ -95,6 +97,7 @@ def __init__(self,
self._athena_kms_key: Optional[str] = athena_kms_key
self._athena_database: str = athena_database
self._athena_ctas_approach: bool = athena_ctas_approach
self._redshift_temp_s3_path: Optional[str] = redshift_temp_s3_path
self._primitives = None
self._load_new_primitives()
if boto3_session:
@@ -149,7 +152,8 @@ def _load_new_primitives(self):
athena_encryption=self._athena_encryption,
athena_kms_key=self._athena_kms_key,
athena_database=self._athena_database,
athena_ctas_approach=self._athena_ctas_approach)
athena_ctas_approach=self._athena_ctas_approach,
redshift_temp_s3_path=self._redshift_temp_s3_path)

@property
def profile_name(self):
@@ -223,6 +227,10 @@ def athena_database(self) -> str:
def athena_ctas_approach(self) -> bool:
return self._athena_ctas_approach

@property
def redshift_temp_s3_path(self) -> Optional[str]:
return self._redshift_temp_s3_path

@property
def boto3_session(self):
return self._boto3_session
@@ -304,7 +312,8 @@ def __init__(self,
athena_encryption: Optional[str] = None,
athena_kms_key: Optional[str] = None,
athena_database: Optional[str] = None,
athena_ctas_approach: bool = False):
athena_ctas_approach: bool = False,
redshift_temp_s3_path: Optional[str] = None):
"""
Most parameters inherit from Boto3.
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html
@@ -325,6 +334,7 @@ def __init__(self,
:param athena_s3_output: AWS S3 path
:param athena_encryption: None|'SSE_S3'|'SSE_KMS'|'CSE_KMS'
:param athena_kms_key: For SSE-KMS and CSE-KMS, this is the KMS key ARN or ID.
:param redshift_temp_s3_path: AWS S3 path to write temporary data (e.g. s3://...)
"""
self._profile_name: Optional[str] = profile_name
self._aws_access_key_id: Optional[str] = aws_access_key_id
@@ -342,6 +352,7 @@ def __init__(self,
self._athena_kms_key: Optional[str] = athena_kms_key
self._athena_database: Optional[str] = athena_database
self._athena_ctas_approach: bool = athena_ctas_approach
self._redshift_temp_s3_path: Optional[str] = redshift_temp_s3_path

@property
def profile_name(self):
@@ -407,6 +418,10 @@ def athena_database(self) -> Optional[str]:
def athena_ctas_approach(self) -> bool:
return self._athena_ctas_approach

@property
def redshift_temp_s3_path(self) -> Optional[str]:
return self._redshift_temp_s3_path

@property
def session(self):
"""
@@ -427,4 +442,5 @@ def session(self):
athena_encryption=self._athena_encryption,
athena_kms_key=self._athena_kms_key,
athena_database=self._athena_database,
athena_ctas_approach=self._athena_ctas_approach)
athena_ctas_approach=self._athena_ctas_approach,
redshift_temp_s3_path=self._redshift_temp_s3_path)
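A short sketch of the new session-level default: with redshift_temp_s3_path set on the Session, Pandas.read_sql_redshift can omit temp_s3_path and fall back to it (otherwise it falls back to the Athena results bucket). The bucket name is a placeholder.

from awswrangler import Session

session = Session(redshift_temp_s3_path="s3://my-bucket/redshift-temp/")
assert session.redshift_temp_s3_path == "s3://my-bucket/redshift-temp/"

# read_sql_redshift called without temp_s3_path will now stage its
# temporary Parquet files under the session-level path configured above.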
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,8 +1,8 @@
numpy~=1.17.4
pandas~=0.25.3
pyarrow~=0.15.1
botocore~=1.13.36
boto3~=1.10.36
botocore~=1.13.37
boto3~=1.10.37
s3fs~=0.4.0
tenacity~=6.0.0
pg8000~=1.13.2
6 changes: 4 additions & 2 deletions testing/test_awswrangler/test_pandas.py
@@ -1442,6 +1442,7 @@ def test_read_table(session, bucket, database):
preserve_index=False,
procs_cpu_bound=1)
df2 = session.pandas.read_table(database=database, table="test")
session.s3.delete_objects(path=path)
assert len(list(df.columns)) == len(list(df2.columns))
assert len(df.index) == len(df2.index)

@@ -1465,7 +1466,7 @@ def test_read_table2(session, bucket, database):
3)]],
"partition": [0, 0, 1]
})
path = f"s3://{bucket}/test_read_table/"
path = f"s3://{bucket}/test_read_table2/"
session.pandas.to_parquet(dataframe=df,
database=database,
table="test",
@@ -1474,8 +1475,9 @@ def test_read_table2(session, bucket, database):
preserve_index=False,
procs_cpu_bound=4,
partition_cols=["partition"])
sleep(5)
sleep(15)
df2 = session.pandas.read_table(database=database, table="test")
session.s3.delete_objects(path=path)
assert len(list(df.columns)) == len(list(df2.columns))
assert len(df.index) == len(df2.index)

86 changes: 83 additions & 3 deletions testing/test_awswrangler/test_redshift.py
@@ -510,7 +510,9 @@ def test_to_redshift_spark_decimal(session, bucket, redshift_parameters):
assert row[2] == Decimal((0, (1, 9, 0, 0, 0, 0), -5))


def test_to_parquet(bucket, redshift_parameters):
def test_to_parquet(session, bucket, redshift_parameters):
n: int = 1_000_000
df = pd.DataFrame({"id": list((range(n))), "name": list(["foo" if i % 2 == 0 else "boo" for i in range(n)])})
con = Redshift.generate_connection(
database="test",
host=redshift_parameters.get("RedshiftAddress"),
@@ -519,9 +521,87 @@ def test_to_parquet(bucket, redshift_parameters):
password=redshift_parameters.get("RedshiftPassword"),
)
path = f"s3://{bucket}/test_to_parquet/"
session.pandas.to_redshift(
dataframe=df,
path=path,
schema="public",
table="test",
connection=con,
iam_role=redshift_parameters.get("RedshiftRole"),
mode="overwrite",
preserve_index=True,
)
path = f"s3://{bucket}/test_to_parquet2/"
paths = Redshift.to_parquet(sql="SELECT * FROM public.test",
path=path,
iam_role=redshift_parameters.get("RedshiftRole"),
redshift_conn=con,
connection=con,
partition_cols=["name"])
assert len(paths) == 20
assert len(paths) == 4


@pytest.mark.parametrize("sample_name", ["micro", "small", "nano"])
def test_read_sql_redshift_pandas(session, bucket, redshift_parameters, sample_name):
if sample_name == "micro":
dates = ["date"]
elif sample_name == "small":
dates = ["date"]
else:
dates = ["date", "time"]
df = pd.read_csv(f"data_samples/{sample_name}.csv", parse_dates=dates, infer_datetime_format=True)
df["date"] = df["date"].dt.date
con = Redshift.generate_connection(
database="test",
host=redshift_parameters.get("RedshiftAddress"),
port=redshift_parameters.get("RedshiftPort"),
user="test",
password=redshift_parameters.get("RedshiftPassword"),
)
path = f"s3://{bucket}/test_read_sql_redshift_pandas/"
session.pandas.to_redshift(
dataframe=df,
path=path,
schema="public",
table="test",
connection=con,
iam_role=redshift_parameters.get("RedshiftRole"),
mode="overwrite",
preserve_index=True,
)
path2 = f"s3://{bucket}/test_read_sql_redshift_pandas2/"
df2 = session.pandas.read_sql_redshift(sql="select * from public.test",
iam_role=redshift_parameters.get("RedshiftRole"),
connection=con,
temp_s3_path=path2)
assert len(df.index) == len(df2.index)
assert len(df.columns) + 1 == len(df2.columns)


def test_read_sql_redshift_pandas2(session, bucket, redshift_parameters):
n: int = 1_000_000
df = pd.DataFrame({"id": list((range(n))), "val": list(["foo" if i % 2 == 0 else "boo" for i in range(n)])})
con = Redshift.generate_connection(
database="test",
host=redshift_parameters.get("RedshiftAddress"),
port=redshift_parameters.get("RedshiftPort"),
user="test",
password=redshift_parameters.get("RedshiftPassword"),
)
path = f"s3://{bucket}/test_read_sql_redshift_pandas2/"
session.pandas.to_redshift(
dataframe=df,
path=path,
schema="public",
table="test",
connection=con,
iam_role=redshift_parameters.get("RedshiftRole"),
mode="overwrite",
preserve_index=True,
)
path2 = f"s3://{bucket}/test_read_sql_redshift_pandas22/"
df2 = session.pandas.read_sql_redshift(sql="select * from public.test",
iam_role=redshift_parameters.get("RedshiftRole"),
connection=con,
temp_s3_path=path2)
assert len(df.index) == len(df2.index)
assert len(df.columns) + 1 == len(df2.columns)