diff --git a/awswrangler/aurora.py b/awswrangler/aurora.py
index bc4e2047c..a7305efe7 100644
--- a/awswrangler/aurora.py
+++ b/awswrangler/aurora.py
@@ -1,12 +1,14 @@
 from typing import TYPE_CHECKING, Union, List, Dict, Tuple, Any
-from logging import getLogger, Logger
+from logging import getLogger, Logger, INFO
 import json
 import warnings
 
 import pg8000  # type: ignore
+from pg8000 import ProgrammingError  # type: ignore
 import pymysql  # type: ignore
 import pandas as pd  # type: ignore
 from boto3 import client  # type: ignore
+import tenacity  # type: ignore
 
 from awswrangler import data_types
 from awswrangler.exceptions import InvalidEngine, InvalidDataframeType, AuroraLoadError
@@ -134,7 +136,7 @@ def load_table(dataframe: pd.DataFrame,
                    schema_name: str,
                    table_name: str,
                    connection: Any,
-                   num_files,
+                   num_files: int,
                    mode: str = "append",
                    preserve_index: bool = False,
                    engine: str = "mysql",
@@ -156,6 +158,54 @@ def load_table(dataframe: pd.DataFrame,
         :param region: AWS S3 bucket region (Required only for postgres engine)
         :return: None
         """
+        if "postgres" in engine.lower():
+            Aurora.load_table_postgres(dataframe=dataframe,
+                                       dataframe_type=dataframe_type,
+                                       load_paths=load_paths,
+                                       schema_name=schema_name,
+                                       table_name=table_name,
+                                       connection=connection,
+                                       mode=mode,
+                                       preserve_index=preserve_index,
+                                       region=region)
+        elif "mysql" in engine.lower():
+            Aurora.load_table_mysql(dataframe=dataframe,
+                                    dataframe_type=dataframe_type,
+                                    manifest_path=load_paths[0],
+                                    schema_name=schema_name,
+                                    table_name=table_name,
+                                    connection=connection,
+                                    mode=mode,
+                                    preserve_index=preserve_index,
+                                    num_files=num_files)
+        else:
+            raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
+
+    @staticmethod
+    def load_table_postgres(dataframe: pd.DataFrame,
+                            dataframe_type: str,
+                            load_paths: List[str],
+                            schema_name: str,
+                            table_name: str,
+                            connection: Any,
+                            mode: str = "append",
+                            preserve_index: bool = False,
+                            region: str = "us-east-1"):
+        """
+        Load text/CSV files into an Aurora table.
+        Creates the table if necessary.
+
+        :param dataframe: Pandas or Spark Dataframe
+        :param dataframe_type: "pandas" or "spark"
+        :param load_paths: S3 paths to be loaded (E.g. S3://...)
+        :param schema_name: Aurora schema
+        :param table_name: Aurora table name
+        :param connection: A PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
+        :param mode: append or overwrite
+        :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
+        :param region: AWS S3 bucket region (Required only for postgres engine)
+        :return: None
+        """
         with connection.cursor() as cursor:
             if mode == "overwrite":
                 Aurora._create_table(cursor=cursor,
@@ -164,30 +214,94 @@ def load_table(dataframe: pd.DataFrame,
                                      schema_name=schema_name,
                                      table_name=table_name,
                                      preserve_index=preserve_index,
-                                     engine=engine)
-            for path in load_paths:
-                sql = Aurora._get_load_sql(path=path,
-                                           schema_name=schema_name,
-                                           table_name=table_name,
-                                           engine=engine,
-                                           region=region)
-                logger.debug(sql)
+                                     engine="postgres")
+            connection.commit()
+        logger.debug("CREATE TABLE committed.")
+        for path in load_paths:
+            Aurora._load_object_postgres_with_retry(connection=connection,
+                                                    schema_name=schema_name,
+                                                    table_name=table_name,
+                                                    path=path,
+                                                    region=region)
+
+    @staticmethod
+    @tenacity.retry(retry=tenacity.retry_if_exception_type(exception_types=ProgrammingError),
+                    wait=tenacity.wait_random_exponential(multiplier=0.5),
+                    stop=tenacity.stop_after_attempt(max_attempt_number=5),
+                    reraise=True,
+                    after=tenacity.after_log(logger, INFO))
+    def _load_object_postgres_with_retry(connection: Any, schema_name: str, table_name: str, path: str,
+                                         region: str) -> None:
+        with connection.cursor() as cursor:
+            sql = Aurora._get_load_sql(path=path,
+                                       schema_name=schema_name,
+                                       table_name=table_name,
+                                       engine="postgres",
+                                       region=region)
+            logger.debug(sql)
+            try:
                 cursor.execute(sql)
+            except ProgrammingError as ex:
+                if "The file has been modified" in str(ex):
+                    connection.rollback()
+                raise ex
+        connection.commit()
+        logger.debug(f"Load committed for: {path}.")
-            connection.commit()
-        logger.debug("Load committed.")
 
+    @staticmethod
+    def load_table_mysql(dataframe: pd.DataFrame,
+                         dataframe_type: str,
+                         manifest_path: str,
+                         schema_name: str,
+                         table_name: str,
+                         connection: Any,
+                         num_files: int,
+                         mode: str = "append",
+                         preserve_index: bool = False):
+        """
+        Load text/CSV files into an Aurora table using a manifest file.
+        Creates the table if necessary.
+
-        if "mysql" in engine.lower():
-            with connection.cursor() as cursor:
-                sql = ("-- AWS DATA WRANGLER\n"
-                       f"SELECT COUNT(*) as num_files_loaded FROM mysql.aurora_s3_load_history "
-                       f"WHERE load_prefix = '{path}'")
-                logger.debug(sql)
-                cursor.execute(sql)
-                num_files_loaded = cursor.fetchall()[0][0]
-                if num_files_loaded != (num_files + 1):
-                    raise AuroraLoadError(
-                        f"Missing files to load. {num_files_loaded} files counted. {num_files + 1} expected.")
+        :param dataframe: Pandas or Spark Dataframe
+        :param dataframe_type: "pandas" or "spark"
+        :param manifest_path: S3 manifest path to be loaded (E.g. S3://...)
+        :param schema_name: Aurora schema
+        :param table_name: Aurora table name
+        :param connection: A PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
+        :param num_files: Number of files to be loaded
+        :param mode: append or overwrite
+        :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
+        :return: None
+        """
+        with connection.cursor() as cursor:
+            if mode == "overwrite":
+                Aurora._create_table(cursor=cursor,
+                                     dataframe=dataframe,
+                                     dataframe_type=dataframe_type,
+                                     schema_name=schema_name,
+                                     table_name=table_name,
+                                     preserve_index=preserve_index,
+                                     engine="mysql")
+            sql = Aurora._get_load_sql(path=manifest_path,
+                                       schema_name=schema_name,
+                                       table_name=table_name,
+                                       engine="mysql")
+            logger.debug(sql)
+            cursor.execute(sql)
+            logger.debug(f"Load done for: {manifest_path}")
+        connection.commit()
+        logger.debug("Load committed.")
+
+        with connection.cursor() as cursor:
+            sql = ("-- AWS DATA WRANGLER\n"
+                   f"SELECT COUNT(*) as num_files_loaded FROM mysql.aurora_s3_load_history "
+                   f"WHERE load_prefix = '{manifest_path}'")
+            logger.debug(sql)
+            cursor.execute(sql)
+            num_files_loaded = cursor.fetchall()[0][0]
+            if num_files_loaded != (num_files + 1):
+                raise AuroraLoadError(
+                    f"Missing files to load. {num_files_loaded} files counted. {num_files + 1} expected.")
 
     @staticmethod
     def _parse_path(path):
diff --git a/awswrangler/pandas.py b/awswrangler/pandas.py
index c998111ef..b9093676b 100644
--- a/awswrangler/pandas.py
+++ b/awswrangler/pandas.py
@@ -644,9 +644,11 @@ def _apply_dates_to_generator(generator, parse_dates):
     def to_csv(self,
                dataframe: pd.DataFrame,
                path: str,
-               sep: str = ",",
+               sep: Optional[str] = None,
+               na_rep: Optional[str] = None,
+               quoting: Optional[int] = None,
                escapechar: Optional[str] = None,
-               serde: str = "OpenCSVSerDe",
+               serde: Optional[str] = "OpenCSVSerDe",
               database: Optional[str] = None,
               table: Optional[str] = None,
               partition_cols: Optional[List[str]] = None,
@@ -665,8 +667,10 @@ def to_csv(self,
         :param dataframe: Pandas Dataframe
         :param path: AWS S3 path (E.g. s3://bucket-name/folder_name/
         :param sep: Same as pandas.to_csv()
+        :param na_rep: Same as pandas.to_csv()
+        :param quoting: Same as pandas.to_csv()
         :param escapechar: Same as pandas.to_csv()
-        :param serde: SerDe library name (e.g. OpenCSVSerDe, LazySimpleSerDe)
+        :param serde: SerDe library name (e.g. OpenCSVSerDe, LazySimpleSerDe) (For Athena/Glue Catalog only)
         :param database: AWS Glue Database name
         :param table: AWS Glue table name
         :param partition_cols: List of columns names that will be partitions on S3
@@ -680,9 +684,17 @@ def to_csv(self,
         :param columns_comments: Columns names and the related comments (Optional[Dict[str, str]])
         :return: List of objects written on S3
         """
-        if serde not in Pandas.VALID_CSV_SERDES:
+        if (serde is not None) and (serde not in Pandas.VALID_CSV_SERDES):
             raise InvalidSerDe(f"{serde} in not in the valid SerDe list ({Pandas.VALID_CSV_SERDES})")
-        extra_args: Dict[str, Optional[str]] = {"sep": sep, "serde": serde, "escapechar": escapechar}
+        if (database is not None) and (serde is None):
+            raise InvalidParameters("It is not possible to write to a Glue Database without a SerDe.")
+        extra_args: Dict[str, Optional[Union[str, int]]] = {
+            "sep": sep,
+            "na_rep": na_rep,
+            "serde": serde,
+            "escapechar": escapechar,
+            "quoting": quoting
+        }
         return self.to_s3(dataframe=dataframe,
                           path=path,
                           file_format="csv",
@@ -767,7 +779,7 @@ def to_s3(self,
              procs_cpu_bound=None,
              procs_io_bound=None,
              cast_columns=None,
-             extra_args: Optional[Dict[str, Optional[str]]] = None,
+             extra_args: Optional[Dict[str, Optional[Union[str, int]]]] = None,
              inplace: bool = True,
              description: Optional[str] = None,
              parameters: Optional[Dict[str, str]] = None,
@@ -1053,9 +1065,15 @@ def write_csv_dataframe(dataframe, path, preserve_index, compression, fs, extra_
         serde = extra_args.get("serde")
         if serde is None:
-            escapechar = extra_args.get("escapechar")
+            escapechar: Optional[str] = extra_args.get("escapechar")
             if escapechar is not None:
                 csv_extra_args["escapechar"] = escapechar
+            quoting: Optional[int] = extra_args.get("quoting")
+            if quoting is not None:
+                csv_extra_args["quoting"] = quoting
+            na_rep: Optional[str] = extra_args.get("na_rep")
+            if na_rep is not None:
+                csv_extra_args["na_rep"] = na_rep
         else:
             if serde == "OpenCSVSerDe":
                 csv_extra_args["quoting"] = csv.QUOTE_ALL
@@ -1063,7 +1081,8 @@ def write_csv_dataframe(dataframe, path, preserve_index, compression, fs, extra_
             elif serde == "LazySimpleSerDe":
                 csv_extra_args["quoting"] = csv.QUOTE_NONE
                 csv_extra_args["escapechar"] = "\\"
-        csv_buffer = bytes(
+        logger.debug(f"csv_extra_args: {csv_extra_args}")
+        csv_buffer: bytes = bytes(
             dataframe.to_csv(None, header=False, index=preserve_index, compression=compression, **csv_extra_args),
             "utf-8")
         Pandas._write_csv_to_s3_retrying(fs=fs, path=path, buffer=csv_buffer)
@@ -1554,9 +1573,13 @@ def to_aurora(self,
         temp_s3_path = self._session.athena.create_athena_bucket() + temp_directory + "/"
         temp_s3_path = temp_s3_path if temp_s3_path[-1] == "/" else temp_s3_path + "/"
         logger.debug(f"temp_s3_path: {temp_s3_path}")
+        na_rep: str = "NULL" if "mysql" in engine.lower() else ""
         paths: List[str] = self.to_csv(dataframe=dataframe,
                                        path=temp_s3_path,
+                                       serde=None,
                                        sep=",",
+                                       na_rep=na_rep,
+                                       quoting=csv.QUOTE_MINIMAL,
                                        escapechar="\"",
                                        preserve_index=preserve_index,
                                        mode="overwrite",
diff --git a/awswrangler/s3.py b/awswrangler/s3.py
index 3741a2cf4..2a5bbfb8e 100644
--- a/awswrangler/s3.py
+++ b/awswrangler/s3.py
@@ -308,8 +308,13 @@ def get_objects_sizes(self, objects_paths: List[str], procs_io_bound: Optional[i
             receive_pipes[i].close()
         return objects_sizes
 
-    def copy_listed_objects(self, objects_paths, source_path, target_path, mode="append", procs_io_bound=None):
-        if not procs_io_bound:
+    def copy_listed_objects(self,
+                            objects_paths: List[str],
+                            source_path: str,
+                            target_path: str,
+                            mode: str = "append",
+                            procs_io_bound: Optional[int] = None):
+        if procs_io_bound is None:
             procs_io_bound = self._session.procs_io_bound
         logger.debug(f"procs_io_bound: {procs_io_bound}")
         logger.debug(f"len(objects_paths): {len(objects_paths)}")
diff --git a/testing/test_awswrangler/test_pandas.py b/testing/test_awswrangler/test_pandas.py
index 544a787e1..47a984ed3 100644
--- a/testing/test_awswrangler/test_pandas.py
+++ b/testing/test_awswrangler/test_pandas.py
@@ -1417,7 +1417,9 @@ def test_read_parquet_dataset(session, bucket):
                               preserve_index=False,
                               procs_cpu_bound=4,
                               partition_cols=["partition"])
+    sleep(15)
     df2 = session.pandas.read_parquet(path=path)
+    wr.s3.delete_objects(path=path)
     assert len(list(df.columns)) == len(list(df2.columns))
     assert len(df.index) == len(df2.index)
 
@@ -1935,7 +1937,7 @@ def test_aurora_postgres_load_special(bucket, postgres_parameters):
             Decimal((0, (1, 9, 9), -2)),
             Decimal((0, (1, 9, 9), -2)),
             Decimal((0, (1, 9, 0), -2)),
-            Decimal((0, (3, 1, 2), -2))
+            None
         ]
     })
 
@@ -1976,7 +1978,7 @@ def test_aurora_postgres_load_special(bucket, postgres_parameters):
         assert rows[0][4] == Decimal((0, (1, 9, 9), -2))
         assert rows[1][4] == Decimal((0, (1, 9, 9), -2))
         assert rows[2][4] == Decimal((0, (1, 9, 0), -2))
-        assert rows[3][4] == Decimal((0, (3, 1, 2), -2))
+        assert rows[3][4] is None
     conn.close()
 
 
@@ -1990,7 +1992,7 @@ def test_aurora_mysql_load_special(bucket, mysql_parameters):
             Decimal((0, (1, 9, 9), -2)),
             Decimal((0, (1, 9, 9), -2)),
             Decimal((0, (1, 9, 0), -2)),
-            Decimal((0, (3, 1, 2), -2))
+            None
         ]
     })
 
@@ -2002,7 +2004,7 @@ def test_aurora_mysql_load_special(bucket, mysql_parameters):
                         mode="overwrite",
                         temp_s3_path=path,
                         engine="mysql",
-                        procs_cpu_bound=1)
+                        procs_cpu_bound=4)
     conn = Aurora.generate_connection(database="mysql",
                                       host=mysql_parameters["MysqlAddress"],
                                       port=3306,
@@ -2031,7 +2033,7 @@ def test_aurora_mysql_load_special(bucket, mysql_parameters):
         assert rows[0][4] == Decimal((0, (1, 9, 9), -2))
         assert rows[1][4] == Decimal((0, (1, 9, 9), -2))
         assert rows[2][4] == Decimal((0, (1, 9, 0), -2))
-        assert rows[3][4] == Decimal((0, (3, 1, 2), -2))
+        assert rows[3][4] is None
     conn.close()
 
 
@@ -2068,3 +2070,112 @@ def test_read_sql_athena_empty(ctas_approach):
     """
     df = wr.pandas.read_sql_athena(sql=sql, ctas_approach=ctas_approach)
     print(df)
+
+
+def test_aurora_postgres_load_special2(bucket, postgres_parameters):
+    dt = lambda x: datetime.strptime(x, "%Y-%m-%d %H:%M:%S.%f")  # noqa
+    df = pd.DataFrame({
+        "integer1": [0, 1, np.NaN, 3],
+        "integer2": [8986, 9735, 9918, 9150],
+        "string1": ["O", "P", "P", "O"],
+        "string2": ["050100", "010101", "010101", "050100"],
+        "string3": ["A", "R", "A", "R"],
+        "string4": ["SGD", "SGD", "SGD", "SGD"],
+        "float1": [0.0, 1800000.0, np.NaN, 0.0],
+        "string5": ["0000296722", "0000199396", "0000298592", "0000196380"],
+        "string6": [None, "C", "C", None],
+        "timestamp1":
+        [dt("2020-01-07 00:00:00.000"), None,
+         dt("2020-01-07 00:00:00.000"),
+         dt("2020-01-07 00:00:00.000")],
+        "string7": ["XXX", "XXX", "XXX", "XXX"],
+        "timestamp2": [
+            dt("2020-01-10 10:34:55.863"),
+            dt("2020-01-10 10:34:55.864"),
+            dt("2020-01-10 10:34:55.865"),
+            dt("2020-01-10 10:34:55.866")
+        ],
+    })
+    df = pd.concat([df for _ in range(10_000)])
+    path = f"s3://{bucket}/test_aurora_postgres_special"
+    wr.pandas.to_aurora(dataframe=df,
+                        connection="aws-data-wrangler-postgres",
+                        schema="public",
+                        table="test_aurora_postgres_load_special2",
+                        mode="overwrite",
+                        temp_s3_path=path,
+                        engine="postgres")
+    conn = Aurora.generate_connection(database="postgres",
+                                      host=postgres_parameters["PostgresAddress"],
+                                      port=3306,
+                                      user="test",
+                                      password=postgres_parameters["Password"],
+                                      engine="postgres")
+    with conn.cursor() as cursor:
+        cursor.execute("SELECT count(*) FROM public.test_aurora_postgres_load_special2")
+        assert cursor.fetchall()[0][0] == len(df.index)
+        cursor.execute("SELECT timestamp2 FROM public.test_aurora_postgres_load_special2 limit 4")
+        rows = cursor.fetchall()
+        assert rows[0][0] == dt("2020-01-10 10:34:55.863")
+        assert rows[1][0] == dt("2020-01-10 10:34:55.864")
+        assert rows[2][0] == dt("2020-01-10 10:34:55.865")
+        assert rows[3][0] == dt("2020-01-10 10:34:55.866")
+        cursor.execute(
+            "SELECT integer1, float1, string6, timestamp1 FROM public.test_aurora_postgres_load_special2 limit 4")
+        rows = cursor.fetchall()
+        assert rows[2][0] is None
+        assert rows[2][1] is None
+        assert rows[0][2] is None
+        assert rows[1][3] is None
+    conn.close()
+
+
+def test_aurora_mysql_load_special2(bucket, mysql_parameters):
+    dt = lambda x: datetime.strptime(x, "%Y-%m-%d %H:%M:%S.%f")  # noqa
+    df = pd.DataFrame({
+        "integer1": [0, 1, np.NaN, 3],
+        "integer2": [8986, 9735, 9918, 9150],
+        "string1": ["O", "P", "P", "O"],
+        "string2": ["050100", "010101", "010101", "050100"],
+        "string3": ["A", "R", "A", "R"],
+        "string4": ["SGD", "SGD", "SGD", "SGD"],
+        "float1": [0.0, 1800000.0, np.NaN, 0.0],
+        "string5": ["0000296722", "0000199396", "0000298592", "0000196380"],
+        "string6": [None, "C", "C", None],
+        "timestamp1":
+        [dt("2020-01-07 00:00:00.000"), None,
+         dt("2020-01-07 00:00:00.000"),
+         dt("2020-01-07 00:00:00.000")],
+        "string7": ["XXX", "XXX", "XXX", "XXX"],
+        "timestamp2": [
+            dt("2020-01-10 10:34:55.863"),
+            dt("2020-01-10 10:34:55.864"),
+            dt("2020-01-10 10:34:55.865"),
+            dt("2020-01-10 10:34:55.866")
+        ],
+    })
+    df = pd.concat([df for _ in range(10_000)])
+    path = f"s3://{bucket}/test_aurora_mysql_load_special2"
+    wr.pandas.to_aurora(dataframe=df,
+                        connection="aws-data-wrangler-mysql",
+                        schema="test",
+                        table="test_aurora_mysql_load_special2",
+                        mode="overwrite",
+                        temp_s3_path=path,
+                        engine="mysql")
+    conn = Aurora.generate_connection(database="mysql",
+                                      host=mysql_parameters["MysqlAddress"],
+                                      port=3306,
+                                      user="test",
+                                      password=mysql_parameters["Password"],
+                                      engine="mysql")
+    with conn.cursor() as cursor:
+        cursor.execute("SELECT count(*) FROM test.test_aurora_mysql_load_special2")
+        assert cursor.fetchall()[0][0] == len(df.index)
+        cursor.execute("SELECT integer1, float1, string6, timestamp1 FROM test.test_aurora_mysql_load_special2 limit 4")
+        rows = cursor.fetchall()
+        assert rows[2][0] is None
+        assert rows[2][1] is None
+        assert rows[0][2] is None
+        assert rows[1][3] is None
+    conn.close()
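
A minimal standalone sketch (illustrative only, not part of the patch) of how the CSV options introduced above end up mapping onto pandas.DataFrame.to_csv: to_aurora() now passes na_rep ("NULL" for MySQL, "" for PostgreSQL), quoting=csv.QUOTE_MINIMAL and escapechar='"' down through Pandas.to_csv into write_csv_dataframe. The DataFrame and variable names below are made up for the example.

import csv
import pandas as pd

# Hypothetical frame with a missing value, mirroring the new tests above.
df = pd.DataFrame({"c0": [1, None, 3], "c1": ["foo", None, 'b"oo']})

# Same keyword arguments the Aurora load path now uses when serde is None.
body = df.to_csv(None,
                 header=False,
                 index=False,
                 sep=",",
                 na_rep="NULL",
                 quoting=csv.QUOTE_MINIMAL,
                 escapechar='"')
print(body)  # missing values rendered as NULL; quotes added only where needed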