diff --git a/awswrangler/aurora.py b/awswrangler/aurora.py
index 4e3a80a3f..8e8c9484c 100644
--- a/awswrangler/aurora.py
+++ b/awswrangler/aurora.py
@@ -1,6 +1,7 @@
 from typing import Union, List, Dict, Tuple, Any
 import logging
 import json
+import warnings
 
 import pg8000  # type: ignore
 import pymysql  # type: ignore
@@ -158,7 +159,6 @@ def load_table(dataframe: pd.DataFrame,
                                      table_name=table_name,
                                      preserve_index=preserve_index,
                                      engine=engine)
-
             for path in load_paths:
                 sql = Aurora._get_load_sql(path=path,
                                            schema_name=schema_name,
@@ -167,22 +167,21 @@ def load_table(dataframe: pd.DataFrame,
                                            region=region)
                 logger.debug(sql)
                 cursor.execute(sql)
-
-                if "mysql" in engine.lower():
-                    sql = ("-- AWS DATA WRANGLER\n"
-                           f"SELECT COUNT(*) as num_files_loaded FROM mysql.aurora_s3_load_history "
-                           f"WHERE load_prefix = '{path}'")
-                    logger.debug(sql)
-                    cursor.execute(sql)
-                    num_files_loaded = cursor.fetchall()[0][0]
-                    if num_files_loaded != (num_files + 1):
-                        connection.rollback()
-                        raise AuroraLoadError(
-                            f"Aurora load rolled back. {num_files_loaded} files counted. {num_files} expected.")
-
         connection.commit()
         logger.debug("Load committed.")
+        if "mysql" in engine.lower():
+            with connection.cursor() as cursor:
+                sql = ("-- AWS DATA WRANGLER\n"
+                       f"SELECT COUNT(*) as num_files_loaded FROM mysql.aurora_s3_load_history "
+                       f"WHERE load_prefix = '{path}'")
+                logger.debug(sql)
+                cursor.execute(sql)
+                num_files_loaded = cursor.fetchall()[0][0]
+                if num_files_loaded != (num_files + 1):
+                    raise AuroraLoadError(
+                        f"Missing files to load. {num_files_loaded} files counted. {num_files + 1} expected.")
+
 
     @staticmethod
     def _parse_path(path):
         path2 = path.replace("s3://", "")
@@ -233,7 +232,14 @@ def _create_table(cursor,
         sql: str = f"-- AWS DATA WRANGLER\n" \
                    f"DROP TABLE IF EXISTS {schema_name}.{table_name}"
         logger.debug(f"Drop table query:\n{sql}")
-        cursor.execute(sql)
+        if "postgres" in engine.lower():
+            cursor.execute(sql)
+        elif "mysql" in engine.lower():
+            with warnings.catch_warnings():
+                warnings.filterwarnings(action="ignore", message=".*Unknown table.*")
+                cursor.execute(sql)
+        else:
+            raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
         schema = Aurora._get_schema(dataframe=dataframe,
                                     dataframe_type=dataframe_type,
                                     preserve_index=preserve_index,
diff --git a/awswrangler/pandas.py b/awswrangler/pandas.py
index eda9569fd..162aafddb 100644
--- a/awswrangler/pandas.py
+++ b/awswrangler/pandas.py
@@ -1473,7 +1473,7 @@ def to_aurora(self,
         :param engine: "mysql" or "postgres"
         :param temp_s3_path: S3 path to write temporary files (E.g. s3://BUCKET_NAME/ANY_NAME/)
         :param preserve_index: Should we preserve the Dataframe index?
-        :param mode: append, overwrite or upsert
+        :param mode: append or overwrite
         :param procs_cpu_bound: Number of cores used for CPU bound tasks
         :param procs_io_bound: Number of cores used for I/O bound tasks
         :param inplace: True is cheapest (CPU and Memory) but False leaves your DataFrame intact
diff --git a/testing/test_awswrangler/test_pandas.py b/testing/test_awswrangler/test_pandas.py
index 381e80c5f..a16312555 100644
--- a/testing/test_awswrangler/test_pandas.py
+++ b/testing/test_awswrangler/test_pandas.py
@@ -1779,6 +1779,108 @@ def test_read_csv_list_iterator(bucket, sample, row_num):
     assert total_count == row_num * n
 
 
+def test_aurora_mysql_load_append(bucket, mysql_parameters):
+    n: int = 10_000
+    df = pd.DataFrame({"id": list((range(n))), "value": list(["foo" if i % 2 == 0 else "boo" for i in range(n)])})
+    conn = Aurora.generate_connection(database="mysql",
+                                      host=mysql_parameters["MysqlAddress"],
+                                      port=3306,
+                                      user="test",
+                                      password=mysql_parameters["Password"],
+                                      engine="mysql")
+    path = f"s3://{bucket}/test_aurora_mysql_load_append"
+
+    # LOAD
+    wr.pandas.to_aurora(dataframe=df,
+                        connection=conn,
+                        schema="test",
+                        table="test_aurora_mysql_load_append",
+                        mode="overwrite",
+                        temp_s3_path=path)
+    with conn.cursor() as cursor:
+        cursor.execute("SELECT count(*) FROM test.test_aurora_mysql_load_append")
+        count = cursor.fetchall()[0][0]
+        assert count == len(df.index)
+
+    # APPEND
+    wr.pandas.to_aurora(dataframe=df,
+                        connection=conn,
+                        schema="test",
+                        table="test_aurora_mysql_load_append",
+                        mode="append",
+                        temp_s3_path=path)
+    with conn.cursor() as cursor:
+        cursor.execute("SELECT count(*) FROM test.test_aurora_mysql_load_append")
+        count = cursor.fetchall()[0][0]
+        assert count == len(df.index) * 2
+
+    # RESET
+    wr.pandas.to_aurora(dataframe=df,
+                        connection=conn,
+                        schema="test",
+                        table="test_aurora_mysql_load_append",
+                        mode="overwrite",
+                        temp_s3_path=path)
+    with conn.cursor() as cursor:
+        cursor.execute("SELECT count(*) FROM test.test_aurora_mysql_load_append")
+        count = cursor.fetchall()[0][0]
+        assert count == len(df.index)
+
+    conn.close()
+
+
+def test_aurora_postgres_load_append(bucket, postgres_parameters):
+    df = pd.DataFrame({"id": [1, 2, 3], "value": ["foo", "boo", "bar"]})
+    conn = Aurora.generate_connection(database="postgres",
+                                      host=postgres_parameters["PostgresAddress"],
+                                      port=3306,
+                                      user="test",
+                                      password=postgres_parameters["Password"],
+                                      engine="postgres")
+    path = f"s3://{bucket}/test_aurora_postgres_load_append"
+
+    # LOAD
+    wr.pandas.to_aurora(dataframe=df,
+                        connection=conn,
+                        schema="public",
+                        table="test_aurora_postgres_load_append",
+                        mode="overwrite",
+                        temp_s3_path=path,
+                        engine="postgres")
+    with conn.cursor() as cursor:
+        cursor.execute("SELECT count(*) FROM public.test_aurora_postgres_load_append")
+        count = cursor.fetchall()[0][0]
+        assert count == len(df.index)
+
+    # APPEND
+    wr.pandas.to_aurora(dataframe=df,
+                        connection=conn,
+                        schema="public",
+                        table="test_aurora_postgres_load_append",
+                        mode="append",
+                        temp_s3_path=path,
+                        engine="postgres")
+    with conn.cursor() as cursor:
+        cursor.execute("SELECT count(*) FROM public.test_aurora_postgres_load_append")
+        count = cursor.fetchall()[0][0]
+        assert count == len(df.index) * 2
+
+    # RESET
+    wr.pandas.to_aurora(dataframe=df,
+                        connection=conn,
+                        schema="public",
+                        table="test_aurora_postgres_load_append",
+                        mode="overwrite",
+                        temp_s3_path=path,
+                        engine="postgres")
+    with conn.cursor() as cursor:
+        cursor.execute("SELECT count(*) FROM public.test_aurora_postgres_load_append")
+        count = cursor.fetchall()[0][0]
+        assert count == len(df.index)
+
+    conn.close()
+
+
 def test_to_csv_metadata(
         session,
         bucket,