diff --git a/awswrangler/pandas.py b/awswrangler/pandas.py
index 6c9575e2f..4e15dbe71 100644
--- a/awswrangler/pandas.py
+++ b/awswrangler/pandas.py
@@ -523,12 +523,12 @@ def read_sql_athena(self,
         :param workgroup: The name of the workgroup in which the query is being started. (By default uses de Session() workgroup)
         :param encryption: None|'SSE_S3'|'SSE_KMS'|'CSE_KMS'
         :param kms_key: For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID.
-        :param ctas_approach: Wraps the query with a CTAS (Session's deafult is False)
+        :param ctas_approach: Wraps the query with a CTAS (Session's default is False)
         :param procs_cpu_bound: Number of cores used for CPU bound tasks
         :param max_result_size: Max number of bytes on each request to S3 (VALID ONLY FOR ctas_approach=False)
         :return: Pandas Dataframe or Iterator of Pandas Dataframes if max_result_size was passed
         """
-        ctas_approach = ctas_approach if ctas_approach is not None else self._session.ctas_approach if self._session.ctas_approach is not None else False
+        ctas_approach = ctas_approach if ctas_approach is not None else self._session.athena_ctas_approach if self._session.athena_ctas_approach is not None else False
         if ctas_approach is True and max_result_size is not None:
             raise InvalidParameters("ctas_approach can't use max_result_size!")
         if s3_output is None:
@@ -1376,3 +1376,45 @@ def read_table(self,
         """
         path: str = self._session.glue.get_table_location(database=database, table=table)
         return self.read_parquet(path=path, columns=columns, filters=filters, procs_cpu_bound=procs_cpu_bound)
+
+    def read_sql_redshift(self,
+                          sql: str,
+                          iam_role: str,
+                          connection: Any,
+                          temp_s3_path: Optional[str] = None,
+                          procs_cpu_bound: Optional[int] = None) -> pd.DataFrame:
+        """
+        Convert a query result into a Pandas Dataframe.
+
+        :param sql: SQL Query
+        :param iam_role: AWS IAM role with the related permissions
+        :param connection: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
+        :param temp_s3_path: AWS S3 path to write temporary data (e.g. s3://...) (Defaults to the Athena results bucket)
+        :param procs_cpu_bound: Number of cores used for CPU bound tasks
+        """
+        guid: str = pa.compat.guid()
+        name: str = f"temp_redshift_{guid}"
+        if temp_s3_path is None:
+            if self._session.redshift_temp_s3_path is not None:
+                temp_s3_path = self._session.redshift_temp_s3_path
+            else:
+                temp_s3_path = self._session.athena.create_athena_bucket()
+        temp_s3_path = temp_s3_path[:-1] if temp_s3_path[-1] == "/" else temp_s3_path
+        temp_s3_path = f"{temp_s3_path}/{name}"
+        logger.debug(f"temp_s3_path: {temp_s3_path}")
+        paths: Optional[List[str]] = None
+        try:
+            paths = self._session.redshift.to_parquet(sql=sql,
+                                                      path=temp_s3_path,
+                                                      iam_role=iam_role,
+                                                      connection=connection)
+            logger.debug(f"paths: {paths}")
+            df: pd.DataFrame = self.read_parquet(path=paths, procs_cpu_bound=procs_cpu_bound)  # type: ignore
+            self._session.s3.delete_listed_objects(objects_paths=paths)
+            return df
+        except Exception as e:
+            if paths is not None:
+                self._session.s3.delete_listed_objects(objects_paths=paths)
+            else:
+                self._session.s3.delete_objects(path=temp_s3_path)
+            raise e
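For reviewers, a minimal usage sketch of the new `read_sql_redshift` entry point; it mirrors the call pattern of the tests added further down. The cluster endpoint, credentials, role ARN and bucket below are placeholders, and it assumes `Session` and `Redshift` are importable from the package top level, as in the existing test suite:

```python
from awswrangler import Redshift, Session

# Placeholders only: endpoint, credentials, IAM role ARN and bucket are not real.
session = Session()
con = Redshift.generate_connection(
    database="test",
    host="my-cluster.xxxxxxxxxxxx.us-east-1.redshift.amazonaws.com",
    port=5439,
    user="test",
    password="my-password",
)
df = session.pandas.read_sql_redshift(
    sql="SELECT * FROM public.test",
    iam_role="arn:aws:iam::123456789012:role/my-redshift-unload-role",
    connection=con,
    temp_s3_path="s3://my-bucket/temp_redshift/",  # optional; see the session.py change below
)
print(len(df.index))
```

On success the temporary Parquet files staged under `temp_s3_path` are deleted; on failure the method cleans up the prefix before re-raising.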
diff --git a/awswrangler/redshift.py b/awswrangler/redshift.py
index d41dbb0e1..bb04f6553 100644
--- a/awswrangler/redshift.py
+++ b/awswrangler/redshift.py
@@ -351,7 +351,7 @@ def _get_redshift_schema(dataframe, dataframe_type, preserve_index=False, cast_c
     def to_parquet(sql: str,
                    path: str,
                    iam_role: str,
-                   redshift_conn: Any,
+                   connection: Any,
                    partition_cols: Optional[List] = None) -> List[str]:
         """
         Write a query result as parquet files on S3
@@ -359,12 +359,12 @@ def to_parquet(sql: str,
         :param sql: SQL Query
         :param path: AWS S3 path to write the data (e.g. s3://...)
         :param iam_role: AWS IAM role with the related permissions
-        :param redshift_conn: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
+        :param connection: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
         :param partition_cols: Specifies the partition keys for the unload operation.
         """
         sql = sql.replace("'", "\'").replace(";", "")  # escaping single quote
         path = path if path[-1] == "/" else path + "/"
-        cursor: Any = redshift_conn.cursor()
+        cursor: Any = connection.cursor()
         partition_str: str = ""
         if partition_cols is not None:
             partition_str = f"PARTITION BY ({','.join([x for x in partition_cols])})\n"
@@ -389,6 +389,6 @@ def to_parquet(sql: str,
         cursor.execute(query)
         paths: List[str] = [row[0].replace(" ", "") for row in cursor.fetchall()]
         logger.debug(f"paths: {paths}")
-        redshift_conn.commit()
+        connection.commit()
         cursor.close()
         return paths
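Aside from the docstring, the only API-surface change in this file is the keyword rename, so existing callers must switch from `redshift_conn=` to `connection=`. A hedged before/after sketch with placeholder values, following the call shape of the updated test further down:

```python
from awswrangler import Redshift

# Placeholder connection details, bucket and role ARN.
con = Redshift.generate_connection(
    database="test",
    host="my-cluster.xxxxxxxxxxxx.us-east-1.redshift.amazonaws.com",
    port=5439,
    user="test",
    password="my-password",
)
paths = Redshift.to_parquet(sql="SELECT * FROM public.test",
                            path="s3://my-bucket/unload/",
                            iam_role="arn:aws:iam::123456789012:role/my-redshift-unload-role",
                            connection=con,  # was: redshift_conn=con
                            partition_cols=["name"])
print(paths)  # S3 paths of the Parquet files written by the unload
```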
""" self._profile_name: Optional[str] = (boto3_session.profile_name if boto3_session else profile_name) self._aws_access_key_id: Optional[str] = (boto3_session.get_credentials().access_key @@ -95,6 +97,7 @@ def __init__(self, self._athena_kms_key: Optional[str] = athena_kms_key self._athena_database: str = athena_database self._athena_ctas_approach: bool = athena_ctas_approach + self._redshift_temp_s3_path: Optional[str] = redshift_temp_s3_path self._primitives = None self._load_new_primitives() if boto3_session: @@ -149,7 +152,8 @@ def _load_new_primitives(self): athena_encryption=self._athena_encryption, athena_kms_key=self._athena_kms_key, athena_database=self._athena_database, - athena_ctas_approach=self._athena_ctas_approach) + athena_ctas_approach=self._athena_ctas_approach, + redshift_temp_s3_path=self._redshift_temp_s3_path) @property def profile_name(self): @@ -223,6 +227,10 @@ def athena_database(self) -> str: def athena_ctas_approach(self) -> bool: return self._athena_ctas_approach + @property + def redshift_temp_s3_path(self) -> Optional[str]: + return self._redshift_temp_s3_path + @property def boto3_session(self): return self._boto3_session @@ -304,7 +312,8 @@ def __init__(self, athena_encryption: Optional[str] = None, athena_kms_key: Optional[str] = None, athena_database: Optional[str] = None, - athena_ctas_approach: bool = False): + athena_ctas_approach: bool = False, + redshift_temp_s3_path: Optional[str] = None): """ Most parameters inherit from Boto3. https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html @@ -325,6 +334,7 @@ def __init__(self, :param athena_s3_output: AWS S3 path :param athena_encryption: None|'SSE_S3'|'SSE_KMS'|'CSE_KMS' :param athena_kms_key: For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID. + :param redshift_temp_s3_path: AWS S3 path to write temporary data (e.g. s3://...) 
""" self._profile_name: Optional[str] = profile_name self._aws_access_key_id: Optional[str] = aws_access_key_id @@ -342,6 +352,7 @@ def __init__(self, self._athena_kms_key: Optional[str] = athena_kms_key self._athena_database: Optional[str] = athena_database self._athena_ctas_approach: bool = athena_ctas_approach + self._redshift_temp_s3_path: Optional[str] = redshift_temp_s3_path @property def profile_name(self): @@ -407,6 +418,10 @@ def athena_database(self) -> Optional[str]: def athena_ctas_approach(self) -> bool: return self._athena_ctas_approach + @property + def redshift_temp_s3_path(self) -> Optional[str]: + return self._redshift_temp_s3_path + @property def session(self): """ @@ -427,4 +442,5 @@ def session(self): athena_encryption=self._athena_encryption, athena_kms_key=self._athena_kms_key, athena_database=self._athena_database, - athena_ctas_approach=self._athena_ctas_approach) + athena_ctas_approach=self._athena_ctas_approach, + redshift_temp_s3_path=self._redshift_temp_s3_path) diff --git a/requirements.txt b/requirements.txt index 9b5c21e18..6eabe9565 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ numpy~=1.17.4 pandas~=0.25.3 pyarrow~=0.15.1 -botocore~=1.13.36 -boto3~=1.10.36 +botocore~=1.13.37 +boto3~=1.10.37 s3fs~=0.4.0 tenacity~=6.0.0 pg8000~=1.13.2 \ No newline at end of file diff --git a/testing/test_awswrangler/test_pandas.py b/testing/test_awswrangler/test_pandas.py index b717b636f..c15b24fdc 100644 --- a/testing/test_awswrangler/test_pandas.py +++ b/testing/test_awswrangler/test_pandas.py @@ -1442,6 +1442,7 @@ def test_read_table(session, bucket, database): preserve_index=False, procs_cpu_bound=1) df2 = session.pandas.read_table(database=database, table="test") + session.s3.delete_objects(path=path) assert len(list(df.columns)) == len(list(df2.columns)) assert len(df.index) == len(df2.index) @@ -1465,7 +1466,7 @@ def test_read_table2(session, bucket, database): 3)]], "partition": [0, 0, 1] }) - path = f"s3://{bucket}/test_read_table/" + path = f"s3://{bucket}/test_read_table2/" session.pandas.to_parquet(dataframe=df, database=database, table="test", @@ -1474,8 +1475,9 @@ def test_read_table2(session, bucket, database): preserve_index=False, procs_cpu_bound=4, partition_cols=["partition"]) - sleep(5) + sleep(15) df2 = session.pandas.read_table(database=database, table="test") + session.s3.delete_objects(path=path) assert len(list(df.columns)) == len(list(df2.columns)) assert len(df.index) == len(df2.index) diff --git a/testing/test_awswrangler/test_redshift.py b/testing/test_awswrangler/test_redshift.py index e227b99bb..692aac648 100644 --- a/testing/test_awswrangler/test_redshift.py +++ b/testing/test_awswrangler/test_redshift.py @@ -510,7 +510,9 @@ def test_to_redshift_spark_decimal(session, bucket, redshift_parameters): assert row[2] == Decimal((0, (1, 9, 0, 0, 0, 0), -5)) -def test_to_parquet(bucket, redshift_parameters): +def test_to_parquet(session, bucket, redshift_parameters): + n: int = 1_000_000 + df = pd.DataFrame({"id": list((range(n))), "name": list(["foo" if i % 2 == 0 else "boo" for i in range(n)])}) con = Redshift.generate_connection( database="test", host=redshift_parameters.get("RedshiftAddress"), @@ -519,9 +521,87 @@ def test_to_parquet(bucket, redshift_parameters): password=redshift_parameters.get("RedshiftPassword"), ) path = f"s3://{bucket}/test_to_parquet/" + session.pandas.to_redshift( + dataframe=df, + path=path, + schema="public", + table="test", + connection=con, + iam_role=redshift_parameters.get("RedshiftRole"), + 
mode="overwrite", + preserve_index=True, + ) + path = f"s3://{bucket}/test_to_parquet2/" paths = Redshift.to_parquet(sql="SELECT * FROM public.test", path=path, iam_role=redshift_parameters.get("RedshiftRole"), - redshift_conn=con, + connection=con, partition_cols=["name"]) - assert len(paths) == 20 + assert len(paths) == 4 + + +@pytest.mark.parametrize("sample_name", ["micro", "small", "nano"]) +def test_read_sql_redshift_pandas(session, bucket, redshift_parameters, sample_name): + if sample_name == "micro": + dates = ["date"] + elif sample_name == "small": + dates = ["date"] + else: + dates = ["date", "time"] + df = pd.read_csv(f"data_samples/{sample_name}.csv", parse_dates=dates, infer_datetime_format=True) + df["date"] = df["date"].dt.date + con = Redshift.generate_connection( + database="test", + host=redshift_parameters.get("RedshiftAddress"), + port=redshift_parameters.get("RedshiftPort"), + user="test", + password=redshift_parameters.get("RedshiftPassword"), + ) + path = f"s3://{bucket}/test_read_sql_redshift_pandas/" + session.pandas.to_redshift( + dataframe=df, + path=path, + schema="public", + table="test", + connection=con, + iam_role=redshift_parameters.get("RedshiftRole"), + mode="overwrite", + preserve_index=True, + ) + path2 = f"s3://{bucket}/test_read_sql_redshift_pandas2/" + df2 = session.pandas.read_sql_redshift(sql="select * from public.test", + iam_role=redshift_parameters.get("RedshiftRole"), + connection=con, + temp_s3_path=path2) + assert len(df.index) == len(df2.index) + assert len(df.columns) + 1 == len(df2.columns) + + +def test_read_sql_redshift_pandas2(session, bucket, redshift_parameters): + n: int = 1_000_000 + df = pd.DataFrame({"id": list((range(n))), "val": list(["foo" if i % 2 == 0 else "boo" for i in range(n)])}) + con = Redshift.generate_connection( + database="test", + host=redshift_parameters.get("RedshiftAddress"), + port=redshift_parameters.get("RedshiftPort"), + user="test", + password=redshift_parameters.get("RedshiftPassword"), + ) + path = f"s3://{bucket}/test_read_sql_redshift_pandas2/" + session.pandas.to_redshift( + dataframe=df, + path=path, + schema="public", + table="test", + connection=con, + iam_role=redshift_parameters.get("RedshiftRole"), + mode="overwrite", + preserve_index=True, + ) + path2 = f"s3://{bucket}/test_read_sql_redshift_pandas22/" + df2 = session.pandas.read_sql_redshift(sql="select * from public.test", + iam_role=redshift_parameters.get("RedshiftRole"), + connection=con, + temp_s3_path=path2) + assert len(df.index) == len(df2.index) + assert len(df.columns) + 1 == len(df2.columns)