Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 21 additions & 15 deletions awswrangler/aurora.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Union, List, Dict, Tuple, Any
import logging
import json
import warnings

import pg8000 # type: ignore
import pymysql # type: ignore
Expand Down Expand Up @@ -158,7 +159,6 @@ def load_table(dataframe: pd.DataFrame,
table_name=table_name,
preserve_index=preserve_index,
engine=engine)

for path in load_paths:
sql = Aurora._get_load_sql(path=path,
schema_name=schema_name,
Expand All @@ -167,22 +167,21 @@ def load_table(dataframe: pd.DataFrame,
region=region)
logger.debug(sql)
cursor.execute(sql)

if "mysql" in engine.lower():
sql = ("-- AWS DATA WRANGLER\n"
f"SELECT COUNT(*) as num_files_loaded FROM mysql.aurora_s3_load_history "
f"WHERE load_prefix = '{path}'")
logger.debug(sql)
cursor.execute(sql)
num_files_loaded = cursor.fetchall()[0][0]
if num_files_loaded != (num_files + 1):
connection.rollback()
raise AuroraLoadError(
f"Aurora load rolled back. {num_files_loaded} files counted. {num_files} expected.")

connection.commit()
logger.debug("Load committed.")

if "mysql" in engine.lower():
with connection.cursor() as cursor:
sql = ("-- AWS DATA WRANGLER\n"
f"SELECT COUNT(*) as num_files_loaded FROM mysql.aurora_s3_load_history "
f"WHERE load_prefix = '{path}'")
logger.debug(sql)
cursor.execute(sql)
num_files_loaded = cursor.fetchall()[0][0]
if num_files_loaded != (num_files + 1):
raise AuroraLoadError(
f"Missing files to load. {num_files_loaded} files counted. {num_files + 1} expected.")

@staticmethod
def _parse_path(path):
path2 = path.replace("s3://", "")
Expand Down Expand Up @@ -233,7 +232,14 @@ def _create_table(cursor,
sql: str = f"-- AWS DATA WRANGLER\n" \
f"DROP TABLE IF EXISTS {schema_name}.{table_name}"
logger.debug(f"Drop table query:\n{sql}")
cursor.execute(sql)
if "postgres" in engine.lower():
cursor.execute(sql)
elif "mysql" in engine.lower():
with warnings.catch_warnings():
warnings.filterwarnings(action="ignore", message=".*Unknown table.*")
cursor.execute(sql)
else:
raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
schema = Aurora._get_schema(dataframe=dataframe,
dataframe_type=dataframe_type,
preserve_index=preserve_index,
Expand Down
2 changes: 1 addition & 1 deletion awswrangler/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1473,7 +1473,7 @@ def to_aurora(self,
:param engine: "mysql" or "postgres"
:param temp_s3_path: S3 path to write temporary files (E.g. s3://BUCKET_NAME/ANY_NAME/)
:param preserve_index: Should we preserve the Dataframe index?
:param mode: append, overwrite or upsert
:param mode: append or overwrite
:param procs_cpu_bound: Number of cores used for CPU bound tasks
:param procs_io_bound: Number of cores used for I/O bound tasks
:param inplace: True is cheapest (CPU and Memory) but False leaves your DataFrame intact
Expand Down
102 changes: 102 additions & 0 deletions testing/test_awswrangler/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1779,6 +1779,108 @@ def test_read_csv_list_iterator(bucket, sample, row_num):
assert total_count == row_num * n


def test_aurora_mysql_load_append(bucket, mysql_parameters):
    """Round-trip an Aurora MySQL S3 load: overwrite, append, then overwrite.

    Asserts the table row count after each step: the first overwrite leaves
    exactly len(df) rows, append doubles them, and a final overwrite resets
    the table back to len(df) rows.
    """
    row_total: int = 10_000
    df = pd.DataFrame({
        "id": list(range(row_total)),
        "value": ["foo" if i % 2 == 0 else "boo" for i in range(row_total)],
    })
    conn = Aurora.generate_connection(database="mysql",
                                      host=mysql_parameters["MysqlAddress"],
                                      port=3306,
                                      user="test",
                                      password=mysql_parameters["Password"],
                                      engine="mysql")
    path = f"s3://{bucket}/test_aurora_mysql_load_append"

    def current_count():
        # Rows currently present in the target table.
        with conn.cursor() as cursor:
            cursor.execute("SELECT count(*) FROM test.test_aurora_mysql_load_append")
            return cursor.fetchall()[0][0]

    # LOAD (overwrite) -> n rows; APPEND -> 2n rows; RESET (overwrite) -> n rows.
    for mode, multiplier in (("overwrite", 1), ("append", 2), ("overwrite", 1)):
        wr.pandas.to_aurora(dataframe=df,
                            connection=conn,
                            schema="test",
                            table="test_aurora_mysql_load_append",
                            mode=mode,
                            temp_s3_path=path)
        assert current_count() == len(df.index) * multiplier

    conn.close()


def test_aurora_postgres_load_append(bucket, postgres_parameters):
    """Round-trip an Aurora PostgreSQL S3 load: overwrite, append, then overwrite.

    Asserts the table row count after each step: overwrite leaves len(df)
    rows, append doubles them, and a final overwrite resets to len(df).
    """
    df = pd.DataFrame({"id": [1, 2, 3], "value": ["foo", "boo", "bar"]})
    # NOTE(review): port 3306 is MySQL's default, not PostgreSQL's (5432);
    # presumably the test cluster is provisioned on 3306 — confirm against
    # the test infrastructure template before changing.
    conn = Aurora.generate_connection(database="postgres",
                                      host=postgres_parameters["PostgresAddress"],
                                      port=3306,
                                      user="test",
                                      password=postgres_parameters["Password"],
                                      engine="postgres")
    path = f"s3://{bucket}/test_aurora_postgres_load_append"

    def current_count():
        # Rows currently present in the target table.
        with conn.cursor() as cursor:
            cursor.execute("SELECT count(*) FROM public.test_aurora_postgres_load_append")
            return cursor.fetchall()[0][0]

    # LOAD (overwrite) -> n rows; APPEND -> 2n rows; RESET (overwrite) -> n rows.
    for mode, multiplier in (("overwrite", 1), ("append", 2), ("overwrite", 1)):
        wr.pandas.to_aurora(dataframe=df,
                            connection=conn,
                            schema="public",
                            table="test_aurora_postgres_load_append",
                            mode=mode,
                            temp_s3_path=path,
                            engine="postgres")
        assert current_count() == len(df.index) * multiplier

    conn.close()


def test_to_csv_metadata(
session,
bucket,
Expand Down