4 changes: 2 additions & 2 deletions awswrangler/aurora.py
@@ -202,14 +202,14 @@ def _get_load_sql(path: str, schema_name: str, table_name: str, engine: str, reg
"SELECT aws_s3.table_import_from_s3(\n"
f"'{schema_name}.{table_name}',\n"
"'',\n"
"'(FORMAT CSV, DELIMITER '','', QUOTE ''\"'', ESCAPE ''\\'')',\n"
"'(FORMAT CSV, DELIMITER '','', QUOTE ''\"'', ESCAPE ''\"'')',\n"
f"'({bucket},{key},{region})')")
elif "mysql" in engine.lower():
sql = ("-- AWS DATA WRANGLER\n"
f"LOAD DATA FROM S3 MANIFEST '{path}'\n"
"REPLACE\n"
f"INTO TABLE {schema_name}.{table_name}\n"
"FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY '\"' ESCAPED BY '\\\\'\n"
"FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY '\"' ESCAPED BY '\"'\n"
"LINES TERMINATED BY '\\n'")
else:
raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
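For context on the two one-character changes above: both loaders now expect the standard CSV convention where an embedded quote is escaped by doubling it, which is what Python's csv writer (used by pandas underneath) emits when fields are quoted. A minimal sketch, not part of this diff, assuming the writer's default doublequote=True behavior:

import csv
import io

buf = io.StringIO()
writer = csv.writer(buf, quoting=csv.QUOTE_ALL)  # doublequote=True by default
writer.writerow([1, "plain", 'has "quotes" and a \\ backslash'])
print(buf.getvalue(), end="")
# "1","plain","has ""quotes"" and a \ backslash"

With ESCAPE '"' in the aws_s3.table_import_from_s3 options (PostgreSQL) and ESCAPED BY '"' in LOAD DATA FROM S3 (MySQL), the doubled quote is read back as a single quote and a literal backslash survives unchanged, whereas the previous ESCAPE '\' setting could consume backslashes in the data as escape characters and corrupt or reject such rows.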
196 changes: 112 additions & 84 deletions awswrangler/pandas.py
@@ -637,17 +637,18 @@ def _apply_dates_to_generator(generator, parse_dates):
yield df

def to_csv(self,
dataframe,
path,
sep=",",
serde="OpenCSVSerDe",
dataframe: pd.DataFrame,
path: str,
sep: str = ",",
escapechar: Optional[str] = None,
serde: str = "OpenCSVSerDe",
database: Optional[str] = None,
table=None,
partition_cols=None,
preserve_index=True,
mode="append",
procs_cpu_bound=None,
procs_io_bound=None,
table: Optional[str] = None,
partition_cols: Optional[List[str]] = None,
preserve_index: bool = True,
mode: str = "append",
procs_cpu_bound: Optional[int] = None,
procs_io_bound: Optional[int] = None,
inplace=True,
description: Optional[str] = None,
parameters: Optional[Dict[str, str]] = None,
@@ -659,6 +660,7 @@ def to_csv(self,
:param dataframe: Pandas Dataframe
:param path: AWS S3 path (e.g. s3://bucket-name/folder_name/)
:param sep: Same as pandas.to_csv()
:param escapechar: Same as pandas.to_csv()
:param serde: SerDe library name (e.g. OpenCSVSerDe, LazySimpleSerDe)
:param database: AWS Glue Database name
:param table: AWS Glue table name
@@ -675,7 +677,7 @@
"""
if serde not in Pandas.VALID_CSV_SERDES:
raise InvalidSerDe(f"{serde} in not in the valid SerDe list ({Pandas.VALID_CSV_SERDES})")
extra_args = {"sep": sep, "serde": serde}
extra_args = {"sep": sep, "serde": serde, "escapechar": escapechar}
return self.to_s3(dataframe=dataframe,
path=path,
file_format="csv",
@@ -1041,8 +1043,13 @@ def write_csv_dataframe(dataframe, path, preserve_index, compression, fs, extra_
sep = extra_args.get("sep")
if sep is not None:
csv_extra_args["sep"] = sep

serde = extra_args.get("serde")
if serde is not None:
if serde is None:
escapechar = extra_args.get("escapechar")
if escapechar is not None:
csv_extra_args["escapechar"] = escapechar
else:
if serde == "OpenCSVSerDe":
csv_extra_args["quoting"] = csv.QUOTE_ALL
csv_extra_args["escapechar"] = "\\"
@@ -1511,7 +1518,7 @@ def to_aurora(self,
Load Pandas Dataframe as a Table on Aurora

:param dataframe: Pandas Dataframe
:param connection: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
:param connection: Glue connection name (str) OR a PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
:param schema: The Aurora schema for the table
:param table: The name of the desired Aurora table
:param engine: "mysql" or "postgres"
@@ -1523,58 +1530,66 @@
:param inplace: True is cheapest (CPU and Memory) but False leaves your DataFrame intact
:return: None
"""
if temp_s3_path is None:
if self._session.aurora_temp_s3_path is not None:
temp_s3_path = self._session.aurora_temp_s3_path
else:
guid: str = pa.compat.guid()
temp_directory = f"temp_aurora_{guid}"
temp_s3_path = self._session.athena.create_athena_bucket() + temp_directory + "/"
temp_s3_path = temp_s3_path if temp_s3_path[-1] == "/" else temp_s3_path + "/"
logger.debug(f"temp_s3_path: {temp_s3_path}")

paths: List[str] = self.to_csv(dataframe=dataframe,
path=temp_s3_path,
sep=",",
preserve_index=preserve_index,
mode="overwrite",
procs_cpu_bound=procs_cpu_bound,
procs_io_bound=procs_io_bound,
inplace=inplace)

load_paths: List[str]
region: str = "us-east-1"
if "postgres" in engine.lower():
load_paths = paths.copy()
bucket, _ = Pandas._parse_path(path=load_paths[0])
region = self._session.s3.get_bucket_region(bucket=bucket)
elif "mysql" in engine.lower():
manifest_path: str = f"{temp_s3_path}manifest_{pa.compat.guid()}.json"
self._session.aurora.write_load_manifest(manifest_path=manifest_path, objects_paths=paths)
load_paths = [manifest_path]
else:
raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
logger.debug(f"load_paths: {load_paths}")

Aurora.load_table(dataframe=dataframe,
dataframe_type="pandas",
load_paths=load_paths,
schema_name=schema,
table_name=table,
connection=connection,
num_files=len(paths),
mode=mode,
preserve_index=preserve_index,
engine=engine,
region=region)

if "postgres" in engine.lower():
self._session.s3.delete_listed_objects(objects_paths=load_paths, procs_io_bound=procs_io_bound)
elif "mysql" in engine.lower():
self._session.s3.delete_listed_objects(objects_paths=load_paths + [manifest_path],
procs_io_bound=procs_io_bound)
else:
if ("postgres" not in engine.lower()) and ("mysql" not in engine.lower()):
raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
generated_conn: bool = False
if type(connection) == str:
logger.debug("Glue connection (str) provided.")
connection = self._session.glue.get_connection(name=connection)
generated_conn = True
try:
if temp_s3_path is None:
if self._session.aurora_temp_s3_path is not None:
temp_s3_path = self._session.aurora_temp_s3_path
else:
guid: str = pa.compat.guid()
temp_directory = f"temp_aurora_{guid}"
temp_s3_path = self._session.athena.create_athena_bucket() + temp_directory + "/"
temp_s3_path = temp_s3_path if temp_s3_path[-1] == "/" else temp_s3_path + "/"
logger.debug(f"temp_s3_path: {temp_s3_path}")
paths: List[str] = self.to_csv(dataframe=dataframe,
path=temp_s3_path,
sep=",",
escapechar="\"",
preserve_index=preserve_index,
mode="overwrite",
procs_cpu_bound=procs_cpu_bound,
procs_io_bound=procs_io_bound,
inplace=inplace)
load_paths: List[str]
region: str = "us-east-1"
if "postgres" in engine.lower():
load_paths = paths.copy()
bucket, _ = Pandas._parse_path(path=load_paths[0])
region = self._session.s3.get_bucket_region(bucket=bucket)
elif "mysql" in engine.lower():
manifest_path: str = f"{temp_s3_path}manifest_{pa.compat.guid()}.json"
self._session.aurora.write_load_manifest(manifest_path=manifest_path, objects_paths=paths)
load_paths = [manifest_path]
logger.debug(f"load_paths: {load_paths}")
Aurora.load_table(dataframe=dataframe,
dataframe_type="pandas",
load_paths=load_paths,
schema_name=schema,
table_name=table,
connection=connection,
num_files=len(paths),
mode=mode,
preserve_index=preserve_index,
engine=engine,
region=region)
if "postgres" in engine.lower():
self._session.s3.delete_listed_objects(objects_paths=load_paths, procs_io_bound=procs_io_bound)
elif "mysql" in engine.lower():
self._session.s3.delete_listed_objects(objects_paths=load_paths + [manifest_path],
procs_io_bound=procs_io_bound)
except Exception as ex:
connection.rollback()
if generated_conn is True:
connection.close()
raise ex
if generated_conn is True:
connection.close()

def read_sql_aurora(self,
sql: str,
@@ -1587,7 +1602,7 @@ def read_sql_aurora(self,
Convert a query result in a Pandas Dataframe.

:param sql: SQL Query
:param connection: A PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
:param connection: Glue connection name (str) OR a PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
:param col_names: List of column names. Default (None) is use columns IDs as column names.
:param temp_s3_path: AWS S3 path to write temporary data (e.g. s3://...) (Default uses the Athena's results bucket)
:param engine: Only "mysql" by now
@@ -1596,25 +1611,38 @@
"""
if "mysql" not in engine.lower():
raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql'!")
guid: str = pa.compat.guid()
name: str = f"temp_aurora_{guid}"
if temp_s3_path is None:
if self._session.aurora_temp_s3_path is not None:
temp_s3_path = self._session.aurora_temp_s3_path
else:
temp_s3_path = self._session.athena.create_athena_bucket()
temp_s3_path = temp_s3_path[:-1] if temp_s3_path[-1] == "/" else temp_s3_path
temp_s3_path = f"{temp_s3_path}/{name}"
logger.debug(f"temp_s3_path: {temp_s3_path}")
manifest_path: str = self._session.aurora.to_s3(sql=sql,
path=temp_s3_path,
connection=connection,
engine=engine)
paths: List[str] = self._session.aurora.extract_manifest_paths(path=manifest_path)
logger.debug(f"paths: {paths}")
ret: Union[pd.DataFrame, Iterator[pd.DataFrame]]
ret = self.read_csv_list(paths=paths, max_result_size=max_result_size, header=None, names=col_names)
self._session.s3.delete_listed_objects(objects_paths=paths + [manifest_path])
generated_conn: bool = False
if type(connection) == str:
logger.debug("Glue connection (str) provided.")
connection = self._session.glue.get_connection(name=connection)
generated_conn = True
try:
guid: str = pa.compat.guid()
name: str = f"temp_aurora_{guid}"
if temp_s3_path is None:
if self._session.aurora_temp_s3_path is not None:
temp_s3_path = self._session.aurora_temp_s3_path
else:
temp_s3_path = self._session.athena.create_athena_bucket()
temp_s3_path = temp_s3_path[:-1] if temp_s3_path[-1] == "/" else temp_s3_path
temp_s3_path = f"{temp_s3_path}/{name}"
logger.debug(f"temp_s3_path: {temp_s3_path}")
manifest_path: str = self._session.aurora.to_s3(sql=sql,
path=temp_s3_path,
connection=connection,
engine=engine)
paths: List[str] = self._session.aurora.extract_manifest_paths(path=manifest_path)
logger.debug(f"paths: {paths}")
ret: Union[pd.DataFrame, Iterator[pd.DataFrame]]
ret = self.read_csv_list(paths=paths, max_result_size=max_result_size, header=None, names=col_names)
self._session.s3.delete_listed_objects(objects_paths=paths + [manifest_path])
except Exception as ex:
connection.rollback()
if generated_conn is True:
connection.close()
raise ex
if generated_conn is True:
connection.close()
return ret

def read_csv_list(
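A short usage sketch of the new connection handling in to_aurora / read_sql_aurora; the Glue connection name, endpoint, schema and table below are placeholders, and the Aurora import follows the test module's usage rather than anything changed in this PR:

import pandas as pd
import awswrangler as wr
from awswrangler import Aurora

df = pd.DataFrame({"id": [1, 2], "value": ['has "quotes"', "plain"]})

# Passing a Glue connection name (str): the session resolves it via
# glue.get_connection(), rolls it back on failure and closes it afterwards.
wr.pandas.to_aurora(dataframe=df,
                    connection="my-glue-connection",
                    schema="public",
                    table="my_table",
                    engine="postgres",
                    mode="overwrite")

# Passing an already-open PEP 249 connection still works; the caller keeps
# ownership and closes it.
conn = Aurora.generate_connection(database="mysql",
                                  host="my-cluster.cluster-xxxx.us-east-1.rds.amazonaws.com",
                                  port=3306,
                                  user="test",
                                  password="secret",
                                  engine="mysql")
df2 = wr.pandas.read_sql_aurora(sql="SELECT * FROM test.my_table",
                                connection=conn,
                                engine="mysql")
conn.close()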
84 changes: 84 additions & 0 deletions testing/test_awswrangler/test_pandas.py
@@ -1917,3 +1917,87 @@ def test_to_csv_metadata(
assert len(list(dataframe.columns)) == len(list(dataframe2.columns))
assert len(session.glue.tables(database=database, search_text="boo bar").index) == 1
assert len(session.glue.tables(database=database, search_text="value").index) > 0


def test_aurora_postgres_load_special(bucket, postgres_parameters):
df = pd.DataFrame({
"id": [1, 2, 3, 4],
"value": ["foo", "boo", "bar", "abc"],
"special": ["\\", "\"", "\\\\\\\\", "\"\"\"\""]
})

path = f"s3://{bucket}/test_aurora_postgres_slash"
wr.pandas.to_aurora(
dataframe=df,
connection="aws-data-wrangler-postgres",
schema="public",
table="test_aurora_postgres_special",
mode="overwrite",
temp_s3_path=path,
engine="postgres",
procs_cpu_bound=4
)
conn = Aurora.generate_connection(database="postgres",
host=postgres_parameters["PostgresAddress"],
port=3306,
user="test",
password=postgres_parameters["Password"],
engine="postgres")
with conn.cursor() as cursor:
cursor.execute("SELECT * FROM public.test_aurora_postgres_special")
rows = cursor.fetchall()
assert len(rows) == len(df.index)
assert rows[0][0] == 1
assert rows[1][0] == 2
assert rows[2][0] == 3
assert rows[0][1] == "foo"
assert rows[1][1] == "boo"
assert rows[2][1] == "bar"
assert rows[3][1] == "abc"
assert rows[0][2] == "\\"
assert rows[1][2] == "\""
assert rows[2][2] == "\\\\\\\\"
assert rows[3][2] == "\"\"\"\""
conn.close()


def test_aurora_mysql_load_special(bucket, mysql_parameters):
df = pd.DataFrame({
"id": [1, 2, 3, 4],
"value": ["foo", "boo", "bar", "abc"],
"special": ["\\", "\"", "\\\\\\\\", "\"\"\"\""]
})

path = f"s3://{bucket}/test_aurora_mysql_special"
wr.pandas.to_aurora(
dataframe=df,
connection="aws-data-wrangler-mysql",
schema="test",
table="test_aurora_mysql_special",
mode="overwrite",
temp_s3_path=path,
engine="mysql",
procs_cpu_bound=1
)
conn = Aurora.generate_connection(database="mysql",
host=mysql_parameters["MysqlAddress"],
port=3306,
user="test",
password=mysql_parameters["Password"],
engine="mysql")
with conn.cursor() as cursor:
cursor.execute("SELECT * FROM test.test_aurora_mysql_special")
rows = cursor.fetchall()
assert len(rows) == len(df.index)
assert rows[0][0] == 1
assert rows[1][0] == 2
assert rows[2][0] == 3
assert rows[0][1] == "foo"
assert rows[1][1] == "boo"
assert rows[2][1] == "bar"
assert rows[3][1] == "abc"
assert rows[0][2] == "\\"
assert rows[1][2] == "\""
assert rows[2][2] == "\\\\\\\\"
assert rows[3][2] == "\"\"\"\""
conn.close()