160 changes: 137 additions & 23 deletions awswrangler/aurora.py
@@ -1,12 +1,14 @@
from typing import TYPE_CHECKING, Union, List, Dict, Tuple, Any
from logging import getLogger, Logger
from logging import getLogger, Logger, INFO
import json
import warnings

import pg8000 # type: ignore
from pg8000 import ProgrammingError # type: ignore
import pymysql # type: ignore
import pandas as pd # type: ignore
from boto3 import client # type: ignore
import tenacity # type: ignore

from awswrangler import data_types
from awswrangler.exceptions import InvalidEngine, InvalidDataframeType, AuroraLoadError
@@ -134,7 +136,7 @@ def load_table(dataframe: pd.DataFrame,
schema_name: str,
table_name: str,
connection: Any,
num_files,
num_files: int,
mode: str = "append",
preserve_index: bool = False,
engine: str = "mysql",
@@ -156,6 +158,54 @@ def load_table(dataframe: pd.DataFrame,
:param region: AWS S3 bucket region (Required only for postgres engine)
:return: None
"""
if "postgres" in engine.lower():
Aurora.load_table_postgres(dataframe=dataframe,
dataframe_type=dataframe_type,
load_paths=load_paths,
schema_name=schema_name,
table_name=table_name,
connection=connection,
mode=mode,
preserve_index=preserve_index,
region=region)
elif "mysql" in engine.lower():
Aurora.load_table_mysql(dataframe=dataframe,
dataframe_type=dataframe_type,
manifest_path=load_paths[0],
schema_name=schema_name,
table_name=table_name,
connection=connection,
mode=mode,
preserve_index=preserve_index,
num_files=num_files)
else:
raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")

@staticmethod
def load_table_postgres(dataframe: pd.DataFrame,
dataframe_type: str,
load_paths: List[str],
schema_name: str,
table_name: str,
connection: Any,
mode: str = "append",
preserve_index: bool = False,
region: str = "us-east-1"):
"""
Load text/CSV files into an Aurora PostgreSQL table.
Creates the table if necessary.

:param dataframe: Pandas or Spark Dataframe
:param dataframe_type: "pandas" or "spark"
:param load_paths: S3 paths to be loaded (E.g. S3://...)
:param schema_name: Aurora schema
:param table_name: Aurora table name
:param connection: A PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
:param mode: append or overwrite
:param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
:param region: AWS S3 bucket region
:return: None
"""
with connection.cursor() as cursor:
if mode == "overwrite":
Aurora._create_table(cursor=cursor,
@@ -164,30 +214,94 @@ def load_table(dataframe: pd.DataFrame,
schema_name=schema_name,
table_name=table_name,
preserve_index=preserve_index,
engine=engine)
for path in load_paths:
sql = Aurora._get_load_sql(path=path,
schema_name=schema_name,
table_name=table_name,
engine=engine,
region=region)
logger.debug(sql)
engine="postgres")
connection.commit()
logger.debug("CREATE TABLE committed.")
for path in load_paths:
Aurora._load_object_postgres_with_retry(connection=connection,
schema_name=schema_name,
table_name=table_name,
path=path,
region=region)

@staticmethod
@tenacity.retry(retry=tenacity.retry_if_exception_type(exception_types=ProgrammingError),
wait=tenacity.wait_random_exponential(multiplier=0.5),
stop=tenacity.stop_after_attempt(max_attempt_number=5),
reraise=True,
after=tenacity.after_log(logger, INFO))
def _load_object_postgres_with_retry(connection: Any, schema_name: str, table_name: str, path: str,
region: str) -> None:
with connection.cursor() as cursor:
sql = Aurora._get_load_sql(path=path,
schema_name=schema_name,
table_name=table_name,
engine="postgres",
region=region)
logger.debug(sql)
try:
cursor.execute(sql)
except ProgrammingError as ex:
if "The file has been modified" in str(ex):
connection.rollback()
raise ex
connection.commit()
logger.debug(f"Load committed for: {path}.")

connection.commit()
logger.debug("Load committed.")
@staticmethod
def load_table_mysql(dataframe: pd.DataFrame,
dataframe_type: str,
manifest_path: str,
schema_name: str,
table_name: str,
connection: Any,
num_files: int,
mode: str = "append",
preserve_index: bool = False):
"""
Load text/CSV files into an Aurora MySQL table using a manifest file.
Creates the table if necessary.

if "mysql" in engine.lower():
with connection.cursor() as cursor:
sql = ("-- AWS DATA WRANGLER\n"
f"SELECT COUNT(*) as num_files_loaded FROM mysql.aurora_s3_load_history "
f"WHERE load_prefix = '{path}'")
logger.debug(sql)
cursor.execute(sql)
num_files_loaded = cursor.fetchall()[0][0]
if num_files_loaded != (num_files + 1):
raise AuroraLoadError(
f"Missing files to load. {num_files_loaded} files counted. {num_files + 1} expected.")
:param dataframe: Pandas or Spark Dataframe
:param dataframe_type: "pandas" or "spark"
:param manifest_path: S3 manifest path to be loaded (E.g. S3://...)
:param schema_name: Aurora schema
:param table_name: Aurora table name
:param connection: A PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
:param num_files: Number of files to be loaded
:param mode: append or overwrite
:param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
:return: None
"""
with connection.cursor() as cursor:
if mode == "overwrite":
Aurora._create_table(cursor=cursor,
dataframe=dataframe,
dataframe_type=dataframe_type,
schema_name=schema_name,
table_name=table_name,
preserve_index=preserve_index,
engine="mysql")
sql = Aurora._get_load_sql(path=manifest_path,
schema_name=schema_name,
table_name=table_name,
engine="mysql")
logger.debug(sql)
cursor.execute(sql)
logger.debug(f"Load done for: {manifest_path}")
connection.commit()
logger.debug("Load committed.")

with connection.cursor() as cursor:
sql = ("-- AWS DATA WRANGLER\n"
f"SELECT COUNT(*) as num_files_loaded FROM mysql.aurora_s3_load_history "
f"WHERE load_prefix = '{manifest_path}'")
logger.debug(sql)
cursor.execute(sql)
num_files_loaded = cursor.fetchall()[0][0]
if num_files_loaded != (num_files + 1):
raise AuroraLoadError(
f"Missing files to load. {num_files_loaded} files counted. {num_files + 1} expected.")

@staticmethod
def _parse_path(path):
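The new `_load_object_postgres_with_retry` wraps each per-object load in tenacity's retry decorator: the inner `except` rolls back and re-raises on the transient "The file has been modified" error, and the decorator then retries with jittered exponential backoff, up to five attempts, re-raising the final failure. A minimal, self-contained sketch of the same pattern (`TransientError` and the `flaky_operation` body are placeholders for `pg8000.ProgrammingError` and the real load):

```python
import logging

import tenacity

logger = logging.getLogger(__name__)

class TransientError(Exception):
    """Placeholder for pg8000.ProgrammingError in the real code."""

@tenacity.retry(retry=tenacity.retry_if_exception_type(TransientError),
                wait=tenacity.wait_random_exponential(multiplier=0.5),  # jittered backoff
                stop=tenacity.stop_after_attempt(5),                    # at most 5 tries
                reraise=True,                                           # surface the original error
                after=tenacity.after_log(logger, logging.INFO))         # log each failed attempt
def flaky_operation() -> None:
    ...  # e.g. one LOAD of a single S3 object into Aurora PostgreSQL
```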
39 changes: 31 additions & 8 deletions awswrangler/pandas.py
@@ -644,9 +644,11 @@ def _apply_dates_to_generator(generator, parse_dates):
def to_csv(self,
dataframe: pd.DataFrame,
path: str,
sep: str = ",",
sep: Optional[str] = None,
na_rep: Optional[str] = None,
quoting: Optional[int] = None,
escapechar: Optional[str] = None,
serde: str = "OpenCSVSerDe",
serde: Optional[str] = "OpenCSVSerDe",
database: Optional[str] = None,
table: Optional[str] = None,
partition_cols: Optional[List[str]] = None,
@@ -665,8 +667,10 @@ def to_csv(self,
:param dataframe: Pandas Dataframe
:param path: AWS S3 path (E.g. s3://bucket-name/folder_name/)
:param sep: Same as pandas.to_csv()
:param na_rep: Same as pandas.to_csv()
:param quoting: Same as pandas.to_csv()
:param escapechar: Same as pandas.to_csv()
:param serde: SerDe library name (e.g. OpenCSVSerDe, LazySimpleSerDe)
:param serde: SerDe library name (e.g. OpenCSVSerDe, LazySimpleSerDe) (For Athena/Glue Catalog only)
:param database: AWS Glue Database name
:param table: AWS Glue table name
:param partition_cols: List of columns names that will be partitions on S3
@@ -680,9 +684,17 @@ def to_csv(self,
:param columns_comments: Columns names and the related comments (Optional[Dict[str, str]])
:return: List of objects written on S3
"""
if serde not in Pandas.VALID_CSV_SERDES:
if (serde is not None) and (serde not in Pandas.VALID_CSV_SERDES):
raise InvalidSerDe(f"{serde} in not in the valid SerDe list ({Pandas.VALID_CSV_SERDES})")
extra_args: Dict[str, Optional[str]] = {"sep": sep, "serde": serde, "escapechar": escapechar}
if (database is not None) and (serde is None):
raise InvalidParameters(f"It is not possible write to a Glue Database without a SerDe.")
extra_args: Dict[str, Optional[Union[str, int]]] = {
"sep": sep,
"na_rep": na_rep,
"serde": serde,
"escapechar": escapechar,
"quoting": quoting
}
return self.to_s3(dataframe=dataframe,
path=path,
file_format="csv",
@@ -767,7 +779,7 @@ def to_s3(self,
procs_cpu_bound=None,
procs_io_bound=None,
cast_columns=None,
extra_args: Optional[Dict[str, Optional[str]]] = None,
extra_args: Optional[Dict[str, Optional[Union[str, int]]]] = None,
inplace: bool = True,
description: Optional[str] = None,
parameters: Optional[Dict[str, str]] = None,
@@ -1053,17 +1065,24 @@ def write_csv_dataframe(dataframe, path, preserve_index, compression, fs, extra_

serde = extra_args.get("serde")
if serde is None:
escapechar = extra_args.get("escapechar")
escapechar: Optional[str] = extra_args.get("escapechar")
if escapechar is not None:
csv_extra_args["escapechar"] = escapechar
quoting: Optional[str] = extra_args.get("quoting")
if escapechar is not None:
csv_extra_args["quoting"] = quoting
na_rep: Optional[str] = extra_args.get("na_rep")
if na_rep is not None:
csv_extra_args["na_rep"] = na_rep
else:
if serde == "OpenCSVSerDe":
csv_extra_args["quoting"] = csv.QUOTE_ALL
csv_extra_args["escapechar"] = "\\"
elif serde == "LazySimpleSerDe":
csv_extra_args["quoting"] = csv.QUOTE_NONE
csv_extra_args["escapechar"] = "\\"
csv_buffer = bytes(
logger.debug(f"csv_extra_args: {csv_extra_args}")
csv_buffer: bytes = bytes(
dataframe.to_csv(None, header=False, index=preserve_index, compression=compression, **csv_extra_args),
"utf-8")
Pandas._write_csv_to_s3_retrying(fs=fs, path=path, buffer=csv_buffer)
@@ -1554,9 +1573,13 @@ def to_aurora(self,
temp_s3_path = self._session.athena.create_athena_bucket() + temp_directory + "/"
temp_s3_path = temp_s3_path if temp_s3_path[-1] == "/" else temp_s3_path + "/"
logger.debug(f"temp_s3_path: {temp_s3_path}")
na_rep: str = "NULL" if "mysql" in engine.lower() else ""
paths: List[str] = self.to_csv(dataframe=dataframe,
path=temp_s3_path,
serde=None,
sep=",",
na_rep=na_rep,
quoting=csv.QUOTE_MINIMAL,
escapechar="\"",
preserve_index=preserve_index,
mode="overwrite",
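In `write_csv_dataframe`, the SerDe now dictates quoting: OpenCSVSerDe quotes every field while LazySimpleSerDe disables quoting entirely, both escaping with a backslash; with `serde=None` (the new `to_aurora` path) the caller's own `sep`/`na_rep`/`quoting`/`escapechar` are passed through to `pandas.to_csv()`. A sketch of that selection logic as a hypothetical helper (`csv_args_for_serde` is not part of the library):

```python
import csv
from typing import Any, Dict, Optional

def csv_args_for_serde(serde: Optional[str],
                       quoting: Optional[int] = None,
                       escapechar: Optional[str] = None,
                       na_rep: Optional[str] = None) -> Dict[str, Any]:
    """Mirror the SerDe-driven kwargs selection for pandas.to_csv()."""
    args: Dict[str, Any] = {}
    if serde is None:
        # No SerDe target: pass the caller's options straight through.
        if escapechar is not None:
            args["escapechar"] = escapechar
        if quoting is not None:
            args["quoting"] = quoting
        if na_rep is not None:
            args["na_rep"] = na_rep
    elif serde == "OpenCSVSerDe":
        args["quoting"] = csv.QUOTE_ALL    # quote every field
        args["escapechar"] = "\\"
    elif serde == "LazySimpleSerDe":
        args["quoting"] = csv.QUOTE_NONE   # no quoting at all
        args["escapechar"] = "\\"
    return args
```

For example, `df.to_csv(None, **csv_args_for_serde("OpenCSVSerDe"))` would produce the fully quoted, backslash-escaped output Athena's OpenCSVSerDe expects.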
9 changes: 7 additions & 2 deletions awswrangler/s3.py
@@ -308,8 +308,13 @@ def get_objects_sizes(self, objects_paths: List[str], procs_io_bound: Optional[int]
receive_pipes[i].close()
return objects_sizes

def copy_listed_objects(self, objects_paths, source_path, target_path, mode="append", procs_io_bound=None):
if not procs_io_bound:
def copy_listed_objects(self,
objects_paths: List[str],
source_path: str,
target_path: str,
mode: str = "append",
procs_io_bound: Optional[int] = None):
if procs_io_bound is None:
procs_io_bound = self._session.procs_io_bound
logger.debug(f"procs_io_bound: {procs_io_bound}")
logger.debug(f"len(objects_paths): {len(objects_paths)}")
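The change from `if not procs_io_bound:` to `if procs_io_bound is None:` matters for integer parameters: an explicit 0 is falsy and would previously have been silently replaced by the session default. A small illustration (`effective_procs` is a hypothetical helper, not library code):

```python
from typing import Optional

def effective_procs(procs_io_bound: Optional[int], default: int = 8) -> int:
    # `is None` treats an explicit 0 as a real value;
    # `if not procs_io_bound:` would overwrite 0 with the default.
    return default if procs_io_bound is None else procs_io_bound

assert effective_procs(None) == 8
assert effective_procs(0) == 0  # preserved, not replaced
```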