diff --git a/awswrangler/redshift.py b/awswrangler/redshift.py
index 53b0b4159..323fb6216 100644
--- a/awswrangler/redshift.py
+++ b/awswrangler/redshift.py
@@ -869,7 +869,7 @@ def unload(
     sql: str,
     path: str,
     con: redshift_connector.Connection,
-    iam_role: Optional[str],
+    iam_role: Optional[str] = None,
     aws_access_key_id: Optional[str] = None,
     aws_secret_access_key: Optional[str] = None,
     aws_session_token: Optional[str] = None,
diff --git a/tests/test_redshift.py b/tests/test_redshift.py
index 23d23aa6f..340171f23 100644
--- a/tests/test_redshift.py
+++ b/tests/test_redshift.py
@@ -10,6 +10,7 @@
 import redshift_connector
 
 import awswrangler as wr
+from awswrangler import _utils
 
 from ._utils import dt, ensure_data_types, ensure_data_types_category, get_df, get_df_category, ts
 
@@ -776,3 +777,97 @@ def test_connect_secret_manager(dbname):
     df = wr.redshift.read_sql_query("SELECT 1", con=con)
     con.close()
     assert df.shape == (1, 1)
+
+
+def test_copy_unload_session(path, redshift_table):
+    df = pd.DataFrame({"col0": [1, 2, 3]})
+    con = wr.redshift.connect("aws-data-wrangler-redshift")
+    wr.redshift.copy(df=df, path=path, con=con, schema="public", table=redshift_table, mode="overwrite")
+    df2 = wr.redshift.unload(
+        sql=f"SELECT * FROM public.{redshift_table}",
+        con=con,
+        path=path,
+        keep_files=False,
+    )
+    assert df2.shape == (3, 1)
+    wr.redshift.copy(
+        df=df,
+        path=path,
+        con=con,
+        schema="public",
+        table=redshift_table,
+        mode="append",
+    )
+    df2 = wr.redshift.unload(
+        sql=f"SELECT * FROM public.{redshift_table}",
+        con=con,
+        path=path,
+        keep_files=False,
+    )
+    assert df2.shape == (6, 1)
+    dfs = wr.redshift.unload(
+        sql=f"SELECT * FROM public.{redshift_table}",
+        con=con,
+        path=path,
+        keep_files=False,
+        chunked=True,
+    )
+    for chunk in dfs:
+        assert len(chunk.columns) == 1
+    con.close()
+
+
+def test_copy_unload_creds(path, redshift_table):
+    credentials = _utils.get_credentials_from_session()
+    df = pd.DataFrame({"col0": [1, 2, 3]})
+    con = wr.redshift.connect("aws-data-wrangler-redshift")
+    wr.redshift.copy(
+        df=df,
+        path=path,
+        con=con,
+        schema="public",
+        table=redshift_table,
+        mode="overwrite",
+        aws_access_key_id=credentials.access_key,
+        aws_secret_access_key=credentials.secret_key,
+    )
+    df2 = wr.redshift.unload(
+        sql=f"SELECT * FROM public.{redshift_table}",
+        con=con,
+        path=path,
+        keep_files=False,
+        aws_access_key_id=credentials.access_key,
+        aws_secret_access_key=credentials.secret_key,
+    )
+    assert df2.shape == (3, 1)
+    wr.redshift.copy(
+        df=df,
+        path=path,
+        con=con,
+        schema="public",
+        table=redshift_table,
+        mode="append",
+        aws_access_key_id=credentials.access_key,
+        aws_secret_access_key=credentials.secret_key,
+    )
+    df2 = wr.redshift.unload(
+        sql=f"SELECT * FROM public.{redshift_table}",
+        con=con,
+        path=path,
+        keep_files=False,
+        aws_access_key_id=credentials.access_key,
+        aws_secret_access_key=credentials.secret_key,
+    )
+    assert df2.shape == (6, 1)
+    dfs = wr.redshift.unload(
+        sql=f"SELECT * FROM public.{redshift_table}",
+        con=con,
+        path=path,
+        keep_files=False,
+        chunked=True,
+        aws_access_key_id=credentials.access_key,
+        aws_secret_access_key=credentials.secret_key,
+    )
+    for chunk in dfs:
+        assert len(chunk.columns) == 1
+    con.close()
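
A minimal usage sketch of the API after this patch (not part of the diff itself): with iam_role now defaulting to None, wr.redshift.unload can authenticate with an explicit key pair instead of a role, mirroring the new test_copy_unload_creds test. The table name and S3 staging prefix below are hypothetical placeholders; "aws-data-wrangler-redshift" is the Glue connection name the tests use.

    import awswrangler as wr
    from awswrangler import _utils

    # Resolve temporary credentials from the default boto3 session,
    # using the same helper the new tests rely on.
    credentials = _utils.get_credentials_from_session()

    con = wr.redshift.connect("aws-data-wrangler-redshift")
    df = wr.redshift.unload(
        sql="SELECT * FROM public.my_table",  # hypothetical table
        con=con,
        path="s3://my-bucket/staging/",  # hypothetical staging prefix
        keep_files=False,
        aws_access_key_id=credentials.access_key,
        aws_secret_access_key=credentials.secret_key,
    )
    con.close()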