diff --git a/README.md b/README.md index 8d63a8749..9fd8867b1 100644 --- a/README.md +++ b/README.md @@ -27,13 +27,14 @@ * Pandas -> Glue Catalog Table * Pandas -> Athena (Parallel) * Pandas -> Redshift (Append/Overwrite/Upsert) (Parallel) -* Parquet (S3) -> Pandas (Parallel) (NEW :star:) +* Pandas -> Aurora (MySQL/PostgreSQL) (Append/Overwrite) (Via S3) (NEW :star:) +* Parquet (S3) -> Pandas (Parallel) * CSV (S3) -> Pandas (One shot or Batching) -* Glue Catalog Table -> Pandas (Parallel) (NEW :star:) -* Athena -> Pandas (One shot, Batching or Parallel (NEW :star:)) -* Redshift -> Pandas (Parallel) (NEW :star:) -* Redshift -> Parquet (S3) (NEW :star:) +* Glue Catalog Table -> Pandas (Parallel) +* Athena -> Pandas (One shot, Batching or Parallel) +* Redshift -> Pandas (Parallel) * CloudWatch Logs Insights -> Pandas +* Aurora -> Pandas (MySQL) (Via S3) (NEW :star:) * Encrypt Pandas Dataframes on S3 with KMS keys ### PySpark @@ -60,6 +61,8 @@ * Get EMR step state * Athena query to receive the result as python primitives (*Iterable[Dict[str, Any]*) * Load and Unzip SageMaker jobs outputs +* Redshift -> Parquet (S3) +* Aurora -> CSV (S3) (MySQL) (NEW :star:) ## Installation @@ -147,6 +150,22 @@ df = sess.pandas.read_sql_athena( ) ``` +#### Reading from Glue Catalog (Parquet) to Pandas + +```py3 +import awswrangler as wr + +df = wr.pandas.read_table(database="DATABASE_NAME", table="TABLE_NAME") +``` + +#### Reading from S3 (Parquet) to Pandas + +```py3 +import awswrangler as wr + +df = wr.pandas.read_parquet(path="s3://...", columns=["c1", "c3"], filters=[("c5", "=", 0)]) +``` + #### Reading from S3 (CSV) to Pandas ```py3 @@ -227,6 +246,30 @@ df = wr.pandas.read_sql_redshift( temp_s3_path="s3://temp_path") ``` +#### Loading Pandas Dataframe to Aurora (MySQL/PostgreSQL) + +```py3 +import awswrangler as wr + +wr.pandas.to_aurora( + dataframe=df, + connection=con, + schema="...", + table="..." 
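+    # engine="mysql" is the default; pass engine="postgres" for Aurora PostgreSQL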
+) +``` + +#### Extract Aurora query to Pandas DataFrame (MySQL) + +```py3 +import awswrangler as wr + +df = wr.pandas.read_sql_aurora( + sql="SELECT ...", + connection=con +) +``` + ### PySpark #### Loading PySpark Dataframe to Redshift diff --git a/awswrangler/__init__.py b/awswrangler/__init__.py index e07a214da..fce152e1b 100644 --- a/awswrangler/__init__.py +++ b/awswrangler/__init__.py @@ -9,6 +9,7 @@ from awswrangler.cloudwatchlogs import CloudWatchLogs # noqa from awswrangler.glue import Glue # noqa from awswrangler.redshift import Redshift # noqa +from awswrangler.aurora import Aurora # noqa from awswrangler.emr import EMR # noqa from awswrangler.sagemaker import SageMaker # noqa import awswrangler.utils # noqa @@ -38,6 +39,7 @@ def __getattr__(self, name): pandas = DynamicInstantiate("pandas") athena = DynamicInstantiate("athena") redshift = DynamicInstantiate("redshift") +aurora = DynamicInstantiate("aurora") sagemaker = DynamicInstantiate("sagemaker") cloudwatchlogs = DynamicInstantiate("cloudwatchlogs") diff --git a/awswrangler/aurora.py b/awswrangler/aurora.py new file mode 100644 index 000000000..4e3a80a3f --- /dev/null +++ b/awswrangler/aurora.py @@ -0,0 +1,297 @@ +from typing import Union, List, Dict, Tuple, Any +import logging +import json + +import pg8000 # type: ignore +import pymysql # type: ignore +import pandas as pd # type: ignore + +from awswrangler import data_types +from awswrangler.exceptions import InvalidEngine, InvalidDataframeType, AuroraLoadError + +logger = logging.getLogger(__name__) + + +class Aurora: + def __init__(self, session): + self._session = session + self._client_s3 = session.boto3_session.client(service_name="s3", use_ssl=True, config=session.botocore_config) + + @staticmethod + def _validate_connection(database: str, + host: str, + port: Union[str, int], + user: str, + password: str, + engine: str = "mysql", + tcp_keepalive: bool = True, + application_name: str = "aws-data-wrangler-validation", + validation_timeout: int = 5) -> None: + if "postgres" in engine.lower(): + conn = pg8000.connect(database=database, + host=host, + port=int(port), + user=user, + password=password, + ssl=True, + application_name=application_name, + tcp_keepalive=tcp_keepalive, + timeout=validation_timeout) + elif "mysql" in engine.lower(): + conn = pymysql.connect(database=database, + host=host, + port=int(port), + user=user, + password=password, + program_name=application_name, + connect_timeout=validation_timeout) + else: + raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!") + conn.close() + + @staticmethod + def generate_connection(database: str, + host: str, + port: Union[str, int], + user: str, + password: str, + engine: str = "mysql", + tcp_keepalive: bool = True, + application_name: str = "aws-data-wrangler", + connection_timeout: int = 1_200_000, + validation_timeout: int = 5): + """ + Generates a valid connection object + + :param database: The name of the database instance to connect with. + :param host: The hostname of the Aurora server to connect with. + :param port: The TCP/IP port of the Aurora server instance. + :param user: The username to connect to the Aurora database with. + :param password: The user password to connect to the server with. 
+ :param engine: "mysql" or "postgres" + :param tcp_keepalive: If True then use TCP keepalive + :param application_name: Application name + :param connection_timeout: Connection Timeout + :param validation_timeout: Timeout to try to validate the connection + :return: PEP 249 compatible connection + """ + Aurora._validate_connection(database=database, + host=host, + port=port, + user=user, + password=password, + engine=engine, + tcp_keepalive=tcp_keepalive, + application_name=application_name, + validation_timeout=validation_timeout) + if "postgres" in engine.lower(): + conn = pg8000.connect(database=database, + host=host, + port=int(port), + user=user, + password=password, + ssl=True, + application_name=application_name, + tcp_keepalive=tcp_keepalive, + timeout=connection_timeout) + elif "mysql" in engine.lower(): + conn = pymysql.connect(database=database, + host=host, + port=int(port), + user=user, + password=password, + program_name=application_name, + connect_timeout=connection_timeout) + else: + raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!") + return conn + + def write_load_manifest(self, manifest_path: str, + objects_paths: List[str]) -> Dict[str, List[Dict[str, Union[str, bool]]]]: + manifest: Dict[str, List[Dict[str, Union[str, bool]]]] = {"entries": []} + path: str + for path in objects_paths: + entry: Dict[str, Union[str, bool]] = {"url": path, "mandatory": True} + manifest["entries"].append(entry) + payload: str = json.dumps(manifest) + bucket: str + bucket, path = manifest_path.replace("s3://", "").split("/", 1) + logger.info(f"payload: {payload}") + self._client_s3.put_object(Body=payload, Bucket=bucket, Key=path) + return manifest + + @staticmethod + def load_table(dataframe: pd.DataFrame, + dataframe_type: str, + load_paths: List[str], + schema_name: str, + table_name: str, + connection: Any, + num_files, + mode: str = "append", + preserve_index: bool = False, + engine: str = "mysql", + region: str = "us-east-1"): + """ + Load text/CSV files into a Aurora table using a manifest file. + Creates the table if necessary. + + :param dataframe: Pandas or Spark Dataframe + :param dataframe_type: "pandas" or "spark" + :param load_paths: S3 paths to be loaded (E.g. S3://...) + :param schema_name: Aurora schema + :param table_name: Aurora table name + :param connection: A PEP 249 compatible connection (Can be generated with Aurora.generate_connection()) + :param num_files: Number of files to be loaded + :param mode: append or overwrite + :param preserve_index: Should we preserve the Dataframe index? 
(ONLY for Pandas Dataframe) + :param engine: "mysql" or "postgres" + :param region: AWS S3 bucket region (Required only for postgres engine) + :return: None + """ + with connection.cursor() as cursor: + if mode == "overwrite": + Aurora._create_table(cursor=cursor, + dataframe=dataframe, + dataframe_type=dataframe_type, + schema_name=schema_name, + table_name=table_name, + preserve_index=preserve_index, + engine=engine) + + for path in load_paths: + sql = Aurora._get_load_sql(path=path, + schema_name=schema_name, + table_name=table_name, + engine=engine, + region=region) + logger.debug(sql) + cursor.execute(sql) + + if "mysql" in engine.lower(): + sql = ("-- AWS DATA WRANGLER\n" + f"SELECT COUNT(*) as num_files_loaded FROM mysql.aurora_s3_load_history " + f"WHERE load_prefix = '{path}'") + logger.debug(sql) + cursor.execute(sql) + num_files_loaded = cursor.fetchall()[0][0] + if num_files_loaded != (num_files + 1): + connection.rollback() + raise AuroraLoadError( + f"Aurora load rolled back. {num_files_loaded} files counted. {num_files} expected.") + + connection.commit() + logger.debug("Load committed.") + + @staticmethod + def _parse_path(path): + path2 = path.replace("s3://", "") + parts = path2.partition("/") + return parts[0], parts[2] + + @staticmethod + def _get_load_sql(path: str, schema_name: str, table_name: str, engine: str, region: str = "us-east-1") -> str: + if "postgres" in engine.lower(): + bucket, key = Aurora._parse_path(path=path) + sql: str = ("-- AWS DATA WRANGLER\n" + "SELECT aws_s3.table_import_from_s3(\n" + f"'{schema_name}.{table_name}',\n" + "'',\n" + "'(FORMAT CSV, DELIMITER '','', QUOTE ''\"'', ESCAPE ''\\'')',\n" + f"'({bucket},{key},{region})')") + elif "mysql" in engine.lower(): + sql = ("-- AWS DATA WRANGLER\n" + f"LOAD DATA FROM S3 MANIFEST '{path}'\n" + "REPLACE\n" + f"INTO TABLE {schema_name}.{table_name}\n" + "FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY '\"' ESCAPED BY '\\\\'\n" + "LINES TERMINATED BY '\\n'") + else: + raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!") + return sql + + @staticmethod + def _create_table(cursor, + dataframe, + dataframe_type, + schema_name, + table_name, + preserve_index=False, + engine: str = "mysql"): + """ + Creates Aurora table. + + :param cursor: A PEP 249 compatible cursor + :param dataframe: Pandas or Spark Dataframe + :param dataframe_type: "pandas" or "spark" + :param schema_name: Redshift schema + :param table_name: Redshift table name + :param preserve_index: Should we preserve the Dataframe index? 
(ONLY for Pandas Dataframe) + :param engine: "mysql" or "postgres" + :return: None + """ + sql: str = f"-- AWS DATA WRANGLER\n" \ + f"DROP TABLE IF EXISTS {schema_name}.{table_name}" + logger.debug(f"Drop table query:\n{sql}") + cursor.execute(sql) + schema = Aurora._get_schema(dataframe=dataframe, + dataframe_type=dataframe_type, + preserve_index=preserve_index, + engine=engine) + cols_str: str = "".join([f"{col[0]} {col[1]},\n" for col in schema])[:-2] + sql = f"-- AWS DATA WRANGLER\n" f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} (\n" f"{cols_str})" + logger.debug(f"Create table query:\n{sql}") + cursor.execute(sql) + + @staticmethod + def _get_schema(dataframe, + dataframe_type: str, + preserve_index: bool, + engine: str = "mysql") -> List[Tuple[str, str]]: + schema_built: List[Tuple[str, str]] = [] + if "postgres" in engine.lower(): + convert_func = data_types.pyarrow2postgres + elif "mysql" in engine.lower(): + convert_func = data_types.pyarrow2mysql + else: + raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!") + if dataframe_type.lower() == "pandas": + pyarrow_schema: List[Tuple[str, str]] = data_types.extract_pyarrow_schema_from_pandas( + dataframe=dataframe, preserve_index=preserve_index, indexes_position="right") + for name, dtype in pyarrow_schema: + aurora_type: str = convert_func(dtype) + schema_built.append((name, aurora_type)) + else: + raise InvalidDataframeType(f"{dataframe_type} is not a valid DataFrame type. Please use 'pandas'!") + return schema_built + + def to_s3(self, sql: str, path: str, connection: Any, engine: str = "mysql") -> str: + """ + Write a query result on S3 + + :param sql: SQL Query + :param path: AWS S3 path to write the data (e.g. s3://...) + :param connection: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection()) + :param engine: Only "mysql" by now + :return: Manifest S3 path + """ + if "mysql" not in engine.lower(): + raise InvalidEngine(f"{engine} is not a valid engine. 
Please use 'mysql'!") + path = path[-1] if path[-1] == "/" else path + self._session.s3.delete_objects(path=path) + sql = f"{sql}\n" \ + f"INTO OUTFILE S3 '{path}'\n" \ + "FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY '\"' ESCAPED BY '\\\\'\n" \ + "LINES TERMINATED BY '\\n'\n" \ + "MANIFEST ON\n" \ + "OVERWRITE ON" + with connection.cursor() as cursor: + logger.debug(sql) + cursor.execute(sql) + connection.commit() + return path + ".manifest" + + def extract_manifest_paths(self, path: str) -> List[str]: + bucket_name, key_path = Aurora._parse_path(path) + body: bytes = self._client_s3.get_object(Bucket=bucket_name, Key=key_path)["Body"].read() + return [x["url"] for x in json.loads(body.decode('utf-8'))["entries"]] diff --git a/awswrangler/data_types.py b/awswrangler/data_types.py index 1fdd02034..b8eaf9fed 100644 --- a/awswrangler/data_types.py +++ b/awswrangler/data_types.py @@ -203,6 +203,58 @@ def pyarrow2redshift(dtype: pa.types) -> str: raise UnsupportedType(f"Unsupported Pyarrow type: {dtype}") +def pyarrow2postgres(dtype: pa.types) -> str: + dtype_str = str(dtype).lower() + if dtype_str == "int16": + return "SMALLINT" + elif dtype_str == "int32": + return "INT" + elif dtype_str == "int64": + return "BIGINT" + elif dtype_str == "float": + return "FLOAT4" + elif dtype_str == "double": + return "FLOAT8" + elif dtype_str == "bool": + return "BOOLEAN" + elif dtype_str == "string": + return "VARCHAR(256)" + elif dtype_str.startswith("timestamp"): + return "TIMESTAMP" + elif dtype_str.startswith("date"): + return "DATE" + elif dtype_str.startswith("decimal"): + return dtype_str.replace(" ", "").upper() + else: + raise UnsupportedType(f"Unsupported Pyarrow type: {dtype}") + + +def pyarrow2mysql(dtype: pa.types) -> str: + dtype_str = str(dtype).lower() + if dtype_str == "int16": + return "SMALLINT" + elif dtype_str == "int32": + return "INT" + elif dtype_str == "int64": + return "BIGINT" + elif dtype_str == "float": + return "FLOAT" + elif dtype_str == "double": + return "DOUBLE" + elif dtype_str == "bool": + return "BOOLEAN" + elif dtype_str == "string": + return "VARCHAR(256)" + elif dtype_str.startswith("timestamp"): + return "TIMESTAMP" + elif dtype_str.startswith("date"): + return "DATE" + elif dtype_str.startswith("decimal"): + return dtype_str.replace(" ", "").upper() + else: + raise UnsupportedType(f"Unsupported Pyarrow type: {dtype}") + + def python2athena(python_type: type) -> str: python_type_str: str = str(python_type) if python_type_str == "": diff --git a/awswrangler/exceptions.py b/awswrangler/exceptions.py index d66dd8ee0..a9bf91d4f 100644 --- a/awswrangler/exceptions.py +++ b/awswrangler/exceptions.py @@ -26,6 +26,10 @@ class RedshiftLoadError(Exception): pass +class AuroraLoadError(Exception): + pass + + class AthenaQueryError(Exception): pass @@ -96,3 +100,7 @@ class InvalidParameters(Exception): class AWSCredentialsNotFound(Exception): pass + + +class InvalidEngine(Exception): + pass diff --git a/awswrangler/pandas.py b/awswrangler/pandas.py index 92e0a6ff0..6224d8c63 100644 --- a/awswrangler/pandas.py +++ b/awswrangler/pandas.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple, Optional, Any, Union +from typing import Dict, List, Tuple, Optional, Any, Union, Iterator from io import BytesIO, StringIO import multiprocessing as mp import logging @@ -19,10 +19,11 @@ from awswrangler import data_types from awswrangler.exceptions import (UnsupportedWriteMode, UnsupportedFileFormat, AthenaQueryError, EmptyS3Object, LineTerminatorNotFound, EmptyDataframe, InvalidSerDe, 
InvalidCompression, - InvalidParameters) + InvalidParameters, InvalidEngine) from awswrangler.utils import calculate_bounders from awswrangler import s3 from awswrangler.athena import Athena +from awswrangler.aurora import Aurora logger = logging.getLogger(__name__) @@ -41,6 +42,7 @@ class Pandas: def __init__(self, session): self._session = session + self._client_s3 = session.boto3_session.client(service_name="s3", use_ssl=True, config=session.botocore_config) @staticmethod def _parse_path(path): @@ -94,53 +96,47 @@ def read_csv( :return: Pandas Dataframe or Iterator of Pandas Dataframes if max_result_size != None """ bucket_name, key_path = self._parse_path(path) - client_s3 = self._session.boto3_session.client(service_name="s3", - use_ssl=True, - config=self._session.botocore_config) - if max_result_size: - ret = Pandas._read_csv_iterator(client_s3=client_s3, - bucket_name=bucket_name, - key_path=key_path, - max_result_size=max_result_size, - header=header, - names=names, - usecols=usecols, - dtype=dtype, - sep=sep, - thousands=thousands, - decimal=decimal, - lineterminator=lineterminator, - quotechar=quotechar, - quoting=quoting, - escapechar=escapechar, - parse_dates=parse_dates, - infer_datetime_format=infer_datetime_format, - encoding=encoding, - converters=converters) + if max_result_size is not None: + ret = self._read_csv_iterator(bucket_name=bucket_name, + key_path=key_path, + max_result_size=max_result_size, + header=header, + names=names, + usecols=usecols, + dtype=dtype, + sep=sep, + thousands=thousands, + decimal=decimal, + lineterminator=lineterminator, + quotechar=quotechar, + quoting=quoting, + escapechar=escapechar, + parse_dates=parse_dates, + infer_datetime_format=infer_datetime_format, + encoding=encoding, + converters=converters) else: - ret = Pandas._read_csv_once(client_s3=client_s3, - bucket_name=bucket_name, - key_path=key_path, - header=header, - names=names, - usecols=usecols, - dtype=dtype, - sep=sep, - thousands=thousands, - decimal=decimal, - lineterminator=lineterminator, - quotechar=quotechar, - quoting=quoting, - escapechar=escapechar, - parse_dates=parse_dates, - infer_datetime_format=infer_datetime_format, - encoding=encoding, - converters=converters) + ret = self._read_csv_once(bucket_name=bucket_name, + key_path=key_path, + header=header, + names=names, + usecols=usecols, + dtype=dtype, + sep=sep, + thousands=thousands, + decimal=decimal, + lineterminator=lineterminator, + quotechar=quotechar, + quoting=quoting, + escapechar=escapechar, + parse_dates=parse_dates, + infer_datetime_format=infer_datetime_format, + encoding=encoding, + converters=converters) return ret - @staticmethod def _read_csv_iterator( - client_s3, + self, bucket_name, key_path, max_result_size=200_000_000, # 200 MB @@ -165,7 +161,6 @@ def _read_csv_iterator( Try to mimic as most as possible pandas.read_csv() https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html - :param client_s3: Boto3 S3 client object :param bucket_name: S3 bucket name :param key_path: S3 key path (W/o bucket) :param max_result_size: Max number of bytes on each request to S3 @@ -186,37 +181,35 @@ def _read_csv_iterator( :param converters: Same as pandas.read_csv() :return: Pandas Dataframe """ - metadata = s3.S3.head_object_with_retry(client=client_s3, bucket=bucket_name, key=key_path) - logger.debug(f"metadata: {metadata}") + metadata = s3.S3.head_object_with_retry(client=self._client_s3, bucket=bucket_name, key=key_path) total_size = metadata["ContentLength"] logger.debug(f"total_size: 
{total_size}") if total_size <= 0: raise EmptyS3Object(metadata) elif total_size <= max_result_size: - yield Pandas._read_csv_once(client_s3=client_s3, - bucket_name=bucket_name, - key_path=key_path, - header=header, - names=names, - usecols=usecols, - dtype=dtype, - sep=sep, - thousands=thousands, - decimal=decimal, - lineterminator=lineterminator, - quotechar=quotechar, - quoting=quoting, - escapechar=escapechar, - parse_dates=parse_dates, - infer_datetime_format=infer_datetime_format, - encoding=encoding, - converters=converters) + yield self._read_csv_once(bucket_name=bucket_name, + key_path=key_path, + header=header, + names=names, + usecols=usecols, + dtype=dtype, + sep=sep, + thousands=thousands, + decimal=decimal, + lineterminator=lineterminator, + quotechar=quotechar, + quoting=quoting, + escapechar=escapechar, + parse_dates=parse_dates, + infer_datetime_format=infer_datetime_format, + encoding=encoding, + converters=converters) else: bounders = calculate_bounders(num_items=total_size, max_size=max_result_size) logger.debug(f"bounders: {bounders}") - bounders_len = len(bounders) - count = 0 - forgotten_bytes = 0 + bounders_len: int = len(bounders) + count: int = 0 + forgotten_bytes: int = 0 for ini, end in bounders: count += 1 @@ -224,7 +217,7 @@ def _read_csv_iterator( end -= 1 # Range is inclusive, contrary from Python's List bytes_range = "bytes={}-{}".format(ini, end) logger.debug(f"bytes_range: {bytes_range}") - body = client_s3.get_object(Bucket=bucket_name, Key=key_path, Range=bytes_range)["Body"].read() + body = self._client_s3.get_object(Bucket=bucket_name, Key=key_path, Range=bytes_range)["Body"].read() chunk_size = len(body) logger.debug(f"chunk_size (bytes): {chunk_size}") @@ -351,9 +344,8 @@ def _find_terminator(body, sep, quoting, quotechar, lineterminator): raise LineTerminatorNotFound() return index - @staticmethod def _read_csv_once( - client_s3, + self, bucket_name, key_path, header="infer", @@ -377,7 +369,6 @@ def _read_csv_once( Try to mimic as most as possible pandas.read_csv() https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html - :param client_s3: Boto3 S3 client object :param bucket_name: S3 bucket name :param key_path: S3 key path (W/o bucket) :param header: Same as pandas.read_csv() @@ -398,7 +389,7 @@ def _read_csv_once( :return: Pandas Dataframe """ buff = BytesIO() - client_s3.download_fileobj(Bucket=bucket_name, Key=key_path, Fileobj=buff) + self._client_s3.download_fileobj(Bucket=bucket_name, Key=key_path, Fileobj=buff) buff.seek(0), dataframe = pd.read_csv( buff, @@ -834,9 +825,9 @@ def data_to_s3(self, procs_io_bound=None, cast_columns=None, extra_args=None): - if not procs_cpu_bound: + if procs_cpu_bound is None: procs_cpu_bound = self._session.procs_cpu_bound - if not procs_io_bound: + if procs_io_bound is None: procs_io_bound = self._session.procs_io_bound logger.debug(f"procs_cpu_bound: {procs_cpu_bound}") logger.debug(f"procs_io_bound: {procs_io_bound}") @@ -1220,7 +1211,7 @@ def drop_duplicated_columns(dataframe: pd.DataFrame, inplace: bool = True) -> pd def read_parquet(self, path: Union[str, List[str]], columns: Optional[List[str]] = None, - filters: Optional[Union[List[Tuple[Any]], List[Tuple[Any]]]] = None, + filters: Optional[Union[List[Tuple[Any]], List[List[Tuple[Any]]]]] = None, procs_cpu_bound: Optional[int] = None) -> pd.DataFrame: """ Read parquet data from S3 @@ -1283,7 +1274,7 @@ def _read_parquet_paths_remote(send_pipe: mp.connection.Connection, session_primitives: Any, path: Union[str, List[str]], 
columns: Optional[List[str]] = None, - filters: Optional[Union[List[Tuple[Any]], List[Tuple[Any]]]] = None, + filters: Optional[Union[List[Tuple[Any]], List[List[Tuple[Any]]]]] = None, procs_cpu_bound: Optional[int] = None): df: pd.DataFrame = Pandas._read_parquet_paths(session_primitives=session_primitives, path=path, @@ -1297,7 +1288,7 @@ def _read_parquet_paths_remote(send_pipe: mp.connection.Connection, def _read_parquet_paths(session_primitives: Any, path: Union[str, List[str]], columns: Optional[List[str]] = None, - filters: Optional[Union[List[Tuple[Any]], List[Tuple[Any]]]] = None, + filters: Optional[Union[List[Tuple[Any]], List[List[Tuple[Any]]]]] = None, procs_cpu_bound: Optional[int] = None) -> pd.DataFrame: """ Read parquet data from S3 @@ -1336,7 +1327,7 @@ def _read_parquet_paths(session_primitives: Any, def _read_parquet_path(session_primitives: Any, path: str, columns: Optional[List[str]] = None, - filters: Optional[Union[List[Tuple[Any]], List[Tuple[Any]]]] = None, + filters: Optional[Union[List[Tuple[Any]], List[List[Tuple[Any]]]]] = None, procs_cpu_bound: Optional[int] = None) -> pd.DataFrame: """ Read parquet data from S3 @@ -1378,7 +1369,7 @@ def read_table(self, database: str, table: str, columns: Optional[List[str]] = None, - filters: Optional[Union[List[Tuple[Any]], List[Tuple[Any]]]] = None, + filters: Optional[Union[List[Tuple[Any]], List[List[Tuple[Any]]]]] = None, procs_cpu_bound: Optional[int] = None) -> pd.DataFrame: """ Read PARQUET table from S3 using the Glue Catalog location skipping Athena's necessity @@ -1410,13 +1401,14 @@ def read_sql_redshift(self, guid: str = pa.compat.guid() name: str = f"temp_redshift_{guid}" if temp_s3_path is None: - if self._session.athena_s3_output is not None: + if self._session.redshift_temp_s3_path is not None: temp_s3_path = self._session.redshift_temp_s3_path else: temp_s3_path = self._session.athena.create_athena_bucket() temp_s3_path = temp_s3_path[:-1] if temp_s3_path[-1] == "/" else temp_s3_path temp_s3_path = f"{temp_s3_path}/{name}" logger.debug(f"temp_s3_path: {temp_s3_path}") + self._session.s3.delete_objects(path=temp_s3_path) paths: Optional[List[str]] = None try: paths = self._session.redshift.to_parquet(sql=sql, @@ -1425,11 +1417,285 @@ def read_sql_redshift(self, connection=connection) logger.debug(f"paths: {paths}") df: pd.DataFrame = self.read_parquet(path=paths, procs_cpu_bound=procs_cpu_bound) # type: ignore - self._session.s3.delete_listed_objects(objects_paths=paths) + self._session.s3.delete_listed_objects(objects_paths=paths + [temp_s3_path + "/manifest"]) # type: ignore return df except Exception as e: if paths is not None: - self._session.s3.delete_listed_objects(objects_paths=paths) + self._session.s3.delete_listed_objects(objects_paths=paths + [temp_s3_path + "/manifest"]) else: self._session.s3.delete_objects(path=temp_s3_path) raise e + + def to_aurora(self, + dataframe: pd.DataFrame, + connection: Any, + schema: str, + table: str, + engine: str = "mysql", + temp_s3_path: Optional[str] = None, + preserve_index: bool = False, + mode: str = "append", + procs_cpu_bound: Optional[int] = None, + procs_io_bound: Optional[int] = None, + inplace=True) -> None: + """ + Load Pandas Dataframe as a Table on Aurora + + :param dataframe: Pandas Dataframe + :param connection: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection()) + :param schema: The Redshift Schema for the table + :param table: The name of the desired Redshift table + :param engine: "mysql" or "postgres" + 
:param temp_s3_path: S3 path to write temporary files (E.g. s3://BUCKET_NAME/ANY_NAME/) + :param preserve_index: Should we preserve the Dataframe index? + :param mode: append, overwrite or upsert + :param procs_cpu_bound: Number of cores used for CPU bound tasks + :param procs_io_bound: Number of cores used for I/O bound tasks + :param inplace: True is cheapest (CPU and Memory) but False leaves your DataFrame intact + :return: None + """ + if temp_s3_path is None: + if self._session.aurora_temp_s3_path is not None: + temp_s3_path = self._session.aurora_temp_s3_path + else: + guid: str = pa.compat.guid() + temp_directory = f"temp_aurora_{guid}" + temp_s3_path = self._session.athena.create_athena_bucket() + temp_directory + "/" + temp_s3_path = temp_s3_path if temp_s3_path[-1] == "/" else temp_s3_path + "/" + logger.debug(f"temp_s3_path: {temp_s3_path}") + + paths: List[str] = self.to_csv(dataframe=dataframe, + path=temp_s3_path, + sep=",", + preserve_index=preserve_index, + mode="overwrite", + procs_cpu_bound=procs_cpu_bound, + procs_io_bound=procs_io_bound, + inplace=inplace) + + load_paths: List[str] + region: str = "us-east-1" + if "postgres" in engine.lower(): + load_paths = paths.copy() + bucket, _ = Pandas._parse_path(path=load_paths[0]) + region = self._session.s3.get_bucket_region(bucket=bucket) + elif "mysql" in engine.lower(): + manifest_path: str = f"{temp_s3_path}manifest_{pa.compat.guid()}.json" + self._session.aurora.write_load_manifest(manifest_path=manifest_path, objects_paths=paths) + load_paths = [manifest_path] + else: + raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!") + logger.debug(f"load_paths: {load_paths}") + + Aurora.load_table(dataframe=dataframe, + dataframe_type="pandas", + load_paths=load_paths, + schema_name=schema, + table_name=table, + connection=connection, + num_files=len(paths), + mode=mode, + preserve_index=preserve_index, + engine=engine, + region=region) + + self._session.s3.delete_objects(path=temp_s3_path, procs_io_bound=procs_io_bound) + + def read_sql_aurora(self, + sql: str, + connection: Any, + col_names: Optional[List[str]] = None, + temp_s3_path: Optional[str] = None, + engine: str = "mysql", + max_result_size: Optional[int] = None) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: + """ + Convert a query result in a Pandas Dataframe. + + :param sql: SQL Query + :param connection: A PEP 249 compatible connection (Can be generated with Aurora.generate_connection()) + :param col_names: List of column names. Default (None) is use columns IDs as column names. + :param temp_s3_path: AWS S3 path to write temporary data (e.g. s3://...) (Default uses the Athena's results bucket) + :param engine: Only "mysql" by now + :param max_result_size: Max number of bytes on each request to S3 + :return: Pandas Dataframe or Iterator of Pandas Dataframes if max_result_size != None + """ + if "mysql" not in engine.lower(): + raise InvalidEngine(f"{engine} is not a valid engine. 
Please use 'mysql'!") + guid: str = pa.compat.guid() + name: str = f"temp_aurora_{guid}" + if temp_s3_path is None: + if self._session.aurora_temp_s3_path is not None: + temp_s3_path = self._session.aurora_temp_s3_path + else: + temp_s3_path = self._session.athena.create_athena_bucket() + temp_s3_path = temp_s3_path[:-1] if temp_s3_path[-1] == "/" else temp_s3_path + temp_s3_path = f"{temp_s3_path}/{name}" + logger.debug(f"temp_s3_path: {temp_s3_path}") + manifest_path: str = self._session.aurora.to_s3(sql=sql, + path=temp_s3_path, + connection=connection, + engine=engine) + paths: List[str] = self._session.aurora.extract_manifest_paths(path=manifest_path) + logger.debug(f"paths: {paths}") + ret: Union[pd.DataFrame, Iterator[pd.DataFrame]] + ret = self.read_csv_list(paths=paths, max_result_size=max_result_size, header=None, names=col_names) + self._session.s3.delete_listed_objects(objects_paths=paths + [manifest_path]) + return ret + + def read_csv_list( + self, + paths, + max_result_size=None, + header: Optional[str] = "infer", + names=None, + usecols=None, + dtype=None, + sep=",", + thousands=None, + decimal=".", + lineterminator="\n", + quotechar='"', + quoting=csv.QUOTE_MINIMAL, + escapechar=None, + parse_dates: Union[bool, Dict, List] = False, + infer_datetime_format=False, + encoding="utf-8", + converters=None, + ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: + """ + Read CSV files from AWS S3 using optimized strategies. + Try to mimic as most as possible pandas.read_csv() + https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html + P.S. max_result_size != None tries to mimic the chunksize behaviour in pandas.read_sql() + + :param paths: AWS S3 paths (E.g. S3://BUCKET_NAME/KEY_NAME) + :param max_result_size: Max number of bytes on each request to S3 + :param header: Same as pandas.read_csv() + :param names: Same as pandas.read_csv() + :param usecols: Same as pandas.read_csv() + :param dtype: Same as pandas.read_csv() + :param sep: Same as pandas.read_csv() + :param thousands: Same as pandas.read_csv() + :param decimal: Same as pandas.read_csv() + :param lineterminator: Same as pandas.read_csv() + :param quotechar: Same as pandas.read_csv() + :param quoting: Same as pandas.read_csv() + :param escapechar: Same as pandas.read_csv() + :param parse_dates: Same as pandas.read_csv() + :param infer_datetime_format: Same as pandas.read_csv() + :param encoding: Same as pandas.read_csv() + :param converters: Same as pandas.read_csv() + :return: Pandas Dataframe or Iterator of Pandas Dataframes if max_result_size != None + """ + if max_result_size is not None: + return self._read_csv_list_iterator(paths=paths, + max_result_size=max_result_size, + header=header, + names=names, + usecols=usecols, + dtype=dtype, + sep=sep, + thousands=thousands, + decimal=decimal, + lineterminator=lineterminator, + quotechar=quotechar, + quoting=quoting, + escapechar=escapechar, + parse_dates=parse_dates, + infer_datetime_format=infer_datetime_format, + encoding=encoding, + converters=converters) + else: + df_full: Optional[pd.DataFrame] = None + for path in paths: + logger.debug(f"path: {path}") + bucket_name, key_path = Pandas._parse_path(path) + df = self._read_csv_once(bucket_name=bucket_name, + key_path=key_path, + header=header, + names=names, + usecols=usecols, + dtype=dtype, + sep=sep, + thousands=thousands, + decimal=decimal, + lineterminator=lineterminator, + quotechar=quotechar, + quoting=quoting, + escapechar=escapechar, + parse_dates=parse_dates, + 
infer_datetime_format=infer_datetime_format, + encoding=encoding, + converters=converters) + if df_full is None: + df_full = df + else: + df_full = pd.concat(objs=[df_full, df], ignore_index=True) + return df_full + + def _read_csv_list_iterator( + self, + paths, + max_result_size=None, + header="infer", + names=None, + usecols=None, + dtype=None, + sep=",", + thousands=None, + decimal=".", + lineterminator="\n", + quotechar='"', + quoting=csv.QUOTE_MINIMAL, + escapechar=None, + parse_dates: Union[bool, Dict, List] = False, + infer_datetime_format=False, + encoding="utf-8", + converters=None, + ): + """ + Read CSV files from AWS S3 using optimized strategies. + Try to mimic as most as possible pandas.read_csv() + https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html + P.S. Try to mimic the chunksize behaviour in pandas.read_sql() + + :param paths: AWS S3 paths (E.g. S3://BUCKET_NAME/KEY_NAME) + :param max_result_size: Max number of bytes on each request to S3 + :param header: Same as pandas.read_csv() + :param names: Same as pandas.read_csv() + :param usecols: Same as pandas.read_csv() + :param dtype: Same as pandas.read_csv() + :param sep: Same as pandas.read_csv() + :param thousands: Same as pandas.read_csv() + :param decimal: Same as pandas.read_csv() + :param lineterminator: Same as pandas.read_csv() + :param quotechar: Same as pandas.read_csv() + :param quoting: Same as pandas.read_csv() + :param escapechar: Same as pandas.read_csv() + :param parse_dates: Same as pandas.read_csv() + :param infer_datetime_format: Same as pandas.read_csv() + :param encoding: Same as pandas.read_csv() + :param converters: Same as pandas.read_csv() + :return: Iterator of iterators of Pandas Dataframes + """ + for path in paths: + logger.debug(f"path: {path}") + bucket_name, key_path = Pandas._parse_path(path) + yield from self._read_csv_iterator(bucket_name=bucket_name, + key_path=key_path, + max_result_size=max_result_size, + header=header, + names=names, + usecols=usecols, + dtype=dtype, + sep=sep, + thousands=thousands, + decimal=decimal, + lineterminator=lineterminator, + quotechar=quotechar, + quoting=quoting, + escapechar=escapechar, + parse_dates=parse_dates, + infer_datetime_format=infer_datetime_format, + encoding=encoding, + converters=converters) diff --git a/awswrangler/redshift.py b/awswrangler/redshift.py index 1019c01e9..abc730cda 100644 --- a/awswrangler/redshift.py +++ b/awswrangler/redshift.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Union, Optional, Any +from typing import Dict, List, Union, Optional, Any, Tuple import json import logging @@ -28,6 +28,7 @@ class Redshift: def __init__(self, session): self._session = session + self._client_s3 = session.boto3_session.client(service_name="s3", use_ssl=True, config=session.botocore_config) @staticmethod def _validate_connection(database, @@ -38,19 +39,16 @@ def _validate_connection(database, tcp_keepalive=True, application_name="aws-data-wrangler-validation", validation_timeout=5): - try: - conn = pg8000.connect(database=database, - host=host, - port=int(port), - user=user, - password=password, - ssl=True, - application_name=application_name, - tcp_keepalive=tcp_keepalive, - timeout=validation_timeout) - conn.close() - except pg8000.core.InterfaceError as e: - raise e + conn = pg8000.connect(database=database, + host=host, + port=int(port), + user=user, + password=password, + ssl=True, + application_name=application_name, + tcp_keepalive=tcp_keepalive, + timeout=validation_timeout) + conn.close() 
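+        # If the connection attempt fails, the driver's exception (e.g. pg8000 InterfaceError)
+        # propagates to the caller.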
@staticmethod def generate_connection(database, @@ -86,8 +84,6 @@ def generate_connection(database, tcp_keepalive=tcp_keepalive, application_name=application_name, validation_timeout=validation_timeout) - if isinstance(type(port), str) or isinstance(type(port), float): - port = int(port) conn = pg8000.connect(database=database, host=host, port=int(port), @@ -133,11 +129,10 @@ def write_load_manifest( } manifest["entries"].append(entry) payload: str = json.dumps(manifest) - client_s3 = self._session.boto3_session.client(service_name="s3", config=self._session.botocore_config) bucket: str bucket, path = manifest_path.replace("s3://", "").split("/", 1) logger.info(f"payload: {payload}") - client_s3.put_object(Body=payload, Bucket=bucket, Key=path) + self._client_s3.put_object(Body=payload, Bucket=bucket, Key=path) return manifest @staticmethod @@ -189,69 +184,67 @@ def load_table(dataframe, """ final_table_name: Optional[str] = None temp_table_name: Optional[str] = None - cursor = redshift_conn.cursor() - if mode == "overwrite": - Redshift._create_table(cursor=cursor, - dataframe=dataframe, - dataframe_type=dataframe_type, - schema_name=schema_name, - table_name=table_name, - diststyle=diststyle, - distkey=distkey, - sortstyle=sortstyle, - sortkey=sortkey, - primary_keys=primary_keys, - preserve_index=preserve_index, - cast_columns=cast_columns) - table_name = f"{schema_name}.{table_name}" - elif mode == "upsert": - guid: str = pa.compat.guid() - temp_table_name = f"temp_redshift_{guid}" - final_table_name = table_name - table_name = temp_table_name - sql: str = f"CREATE TEMPORARY TABLE {temp_table_name} (LIKE {schema_name}.{final_table_name})" - logger.debug(sql) - cursor.execute(sql) - else: - table_name = f"{schema_name}.{table_name}" + with redshift_conn.cursor() as cursor: + if mode == "overwrite": + Redshift._create_table(cursor=cursor, + dataframe=dataframe, + dataframe_type=dataframe_type, + schema_name=schema_name, + table_name=table_name, + diststyle=diststyle, + distkey=distkey, + sortstyle=sortstyle, + sortkey=sortkey, + primary_keys=primary_keys, + preserve_index=preserve_index, + cast_columns=cast_columns) + table_name = f"{schema_name}.{table_name}" + elif mode == "upsert": + guid: str = pa.compat.guid() + temp_table_name = f"temp_redshift_{guid}" + final_table_name = table_name + table_name = temp_table_name + sql: str = f"CREATE TEMPORARY TABLE {temp_table_name} (LIKE {schema_name}.{final_table_name})" + logger.debug(sql) + cursor.execute(sql) + else: + table_name = f"{schema_name}.{table_name}" - sql = ("-- AWS DATA WRANGLER\n" - f"COPY {table_name} FROM '{manifest_path}'\n" - f"IAM_ROLE '{iam_role}'\n" - "MANIFEST\n" - "FORMAT AS PARQUET") - logger.debug(sql) - cursor.execute(sql) - cursor.execute("-- AWS DATA WRANGLER\n SELECT pg_last_copy_id() AS query_id") - query_id = cursor.fetchall()[0][0] - sql = ("-- AWS DATA WRANGLER\n" - f"SELECT COUNT(*) as num_files_loaded FROM STL_LOAD_COMMITS WHERE query = {query_id}") - cursor.execute(sql) - num_files_loaded = cursor.fetchall()[0][0] - if num_files_loaded != num_files: - redshift_conn.rollback() - cursor.close() - raise RedshiftLoadError( - f"Redshift load rollbacked. {num_files_loaded} files counted. 
{num_files} expected.") - - if (mode == "upsert") and (final_table_name is not None): - if not primary_keys: - primary_keys = Redshift.get_primary_keys(connection=redshift_conn, - schema=schema_name, - table=final_table_name) - if not primary_keys: - raise InvalidRedshiftPrimaryKeys() - equals_clause = f"{final_table_name}.%s = {temp_table_name}.%s" - join_clause = " AND ".join([equals_clause % (pk, pk) for pk in primary_keys]) - sql = f"DELETE FROM {schema_name}.{final_table_name} USING {temp_table_name} WHERE {join_clause}" + sql = ("-- AWS DATA WRANGLER\n" + f"COPY {table_name} FROM '{manifest_path}'\n" + f"IAM_ROLE '{iam_role}'\n" + "MANIFEST\n" + "FORMAT AS PARQUET") logger.debug(sql) cursor.execute(sql) - sql = f"INSERT INTO {schema_name}.{final_table_name} SELECT * FROM {temp_table_name}" - logger.debug(sql) + cursor.execute("-- AWS DATA WRANGLER\n SELECT pg_last_copy_id() AS query_id") + query_id = cursor.fetchall()[0][0] + sql = ("-- AWS DATA WRANGLER\n" + f"SELECT COUNT(*) as num_files_loaded FROM STL_LOAD_COMMITS WHERE query = {query_id}") cursor.execute(sql) + num_files_loaded = cursor.fetchall()[0][0] + if num_files_loaded != num_files: + redshift_conn.rollback() + raise RedshiftLoadError( + f"Redshift load rollbacked. {num_files_loaded} files counted. {num_files} expected.") + + if (mode == "upsert") and (final_table_name is not None): + if not primary_keys: + primary_keys = Redshift.get_primary_keys(connection=redshift_conn, + schema=schema_name, + table=final_table_name) + if not primary_keys: + raise InvalidRedshiftPrimaryKeys() + equals_clause = f"{final_table_name}.%s = {temp_table_name}.%s" + join_clause = " AND ".join([equals_clause % (pk, pk) for pk in primary_keys]) + sql = f"DELETE FROM {schema_name}.{final_table_name} USING {temp_table_name} WHERE {join_clause}" + logger.debug(sql) + cursor.execute(sql) + sql = f"INSERT INTO {schema_name}.{final_table_name} SELECT * FROM {temp_table_name}" + logger.debug(sql) + cursor.execute(sql) redshift_conn.commit() - cursor.close() @staticmethod def _create_table(cursor, @@ -376,11 +369,14 @@ def _validate_parameters(schema, diststyle, distkey, sortstyle, sortkey): f"Currently value: {key}") @staticmethod - def _get_redshift_schema(dataframe, dataframe_type, preserve_index=False, cast_columns=None): + def _get_redshift_schema(dataframe, + dataframe_type: str, + preserve_index: bool = False, + cast_columns=None) -> List[Tuple[str, str]]: if cast_columns is None: cast_columns = {} - schema_built = [] - if dataframe_type == "pandas": + schema_built: List[Tuple[str, str]] = [] + if dataframe_type.lower() == "pandas": pyarrow_schema = data_types.extract_pyarrow_schema_from_pandas(dataframe=dataframe, preserve_index=preserve_index, indexes_position="right") @@ -390,7 +386,7 @@ def _get_redshift_schema(dataframe, dataframe_type, preserve_index=False, cast_c else: redshift_type = data_types.pyarrow2redshift(dtype) schema_built.append((name, redshift_type)) - elif dataframe_type == "spark": + elif dataframe_type.lower() == "spark": for name, dtype in dataframe.dtypes: if name in cast_columns.keys(): redshift_type = data_types.athena2redshift(cast_columns[name]) @@ -398,7 +394,8 @@ def _get_redshift_schema(dataframe, dataframe_type, preserve_index=False, cast_c redshift_type = data_types.spark2redshift(dtype) schema_built.append((name, redshift_type)) else: - raise InvalidDataframeType(dataframe_type) + raise InvalidDataframeType( + f"{dataframe_type} is not a valid DataFrame type. 
Please use 'pandas' or 'spark'!") return schema_built def to_parquet(self, diff --git a/awswrangler/s3.py b/awswrangler/s3.py index 344a4ef30..f4af2384b 100644 --- a/awswrangler/s3.py +++ b/awswrangler/s3.py @@ -354,3 +354,10 @@ def copy_objects_batch(session_primitives, batch): copy_source = {"Bucket": source_bucket, "Key": source_key} target_bucket, target_key = S3.parse_object_path(path=target_obj) resource.meta.client.copy(copy_source, target_bucket, target_key) + + def get_bucket_region(self, bucket: str) -> str: + logger.debug(f"bucket: {bucket}") + region: str = self._client_s3.get_bucket_location(Bucket=bucket)["LocationConstraint"] + region = "us-east-1" if region is None else region + logger.debug(f"region: {region}") + return region diff --git a/awswrangler/sagemaker.py b/awswrangler/sagemaker.py index dc8ab5c66..9d8d4b07f 100644 --- a/awswrangler/sagemaker.py +++ b/awswrangler/sagemaker.py @@ -1,8 +1,8 @@ +from typing import Any import pickle import tarfile import logging -from typing import Any from awswrangler.exceptions import InvalidParameters logger = logging.getLogger(__name__) @@ -12,7 +12,9 @@ class SageMaker: def __init__(self, session): self._session = session self._client_s3 = session.boto3_session.client(service_name="s3", use_ssl=True, config=session.botocore_config) - self._client_sagemaker = session.boto3_session.client(service_name="sagemaker") + self._client_sagemaker = session.boto3_session.client(service_name="sagemaker", + use_ssl=True, + config=session.botocore_config) @staticmethod def _parse_path(path): diff --git a/awswrangler/session.py b/awswrangler/session.py index 2e05f828e..3e627d4fc 100644 --- a/awswrangler/session.py +++ b/awswrangler/session.py @@ -12,6 +12,7 @@ from awswrangler.pandas import Pandas from awswrangler.glue import Glue from awswrangler.redshift import Redshift +from awswrangler.aurora import Aurora from awswrangler.emr import EMR from awswrangler.sagemaker import SageMaker from awswrangler.exceptions import AWSCredentialsNotFound @@ -52,7 +53,8 @@ def __init__(self, athena_kms_key: Optional[str] = None, athena_database: str = "default", athena_ctas_approach: bool = False, - redshift_temp_s3_path: Optional[str] = None): + redshift_temp_s3_path: Optional[str] = None, + aurora_temp_s3_path: Optional[str] = None): """ Most parameters inherit from Boto3 or Pyspark. https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html @@ -76,7 +78,8 @@ def __init__(self, :param athena_s3_output: AWS S3 path :param athena_encryption: None|'SSE_S3'|'SSE_KMS'|'CSE_KMS' :param athena_kms_key: For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID. - :param redshift_temp_s3_path: redshift_temp_s3_path: AWS S3 path to write temporary data (e.g. s3://...) + :param redshift_temp_s3_path: AWS S3 path to write temporary data (e.g. s3://...) + :param aurora_temp_s3_path: AWS S3 path to write temporary data (e.g. s3://...) 
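+
+        A minimal usage sketch (the bucket name is illustrative)::
+
+            session = Session(aurora_temp_s3_path="s3://YOUR_BUCKET/temp/")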
""" self._profile_name: Optional[str] = (boto3_session.profile_name if boto3_session else profile_name) self._aws_access_key_id: Optional[str] = (boto3_session.get_credentials().access_key @@ -100,6 +103,7 @@ def __init__(self, self._athena_database: str = athena_database self._athena_ctas_approach: bool = athena_ctas_approach self._redshift_temp_s3_path: Optional[str] = redshift_temp_s3_path + self._aurora_temp_s3_path: Optional[str] = aurora_temp_s3_path self._primitives = None self._load_new_primitives() if boto3_session: @@ -113,6 +117,7 @@ def __init__(self, self._pandas = None self._glue = None self._redshift = None + self._aurora = None self._spark = None self._sagemaker = None @@ -160,7 +165,8 @@ def _load_new_primitives(self): athena_kms_key=self._athena_kms_key, athena_database=self._athena_database, athena_ctas_approach=self._athena_ctas_approach, - redshift_temp_s3_path=self._redshift_temp_s3_path) + redshift_temp_s3_path=self._redshift_temp_s3_path, + aurora_temp_s3_path=self._aurora_temp_s3_path) @property def profile_name(self): @@ -238,6 +244,10 @@ def athena_ctas_approach(self) -> bool: def redshift_temp_s3_path(self) -> Optional[str]: return self._redshift_temp_s3_path + @property + def aurora_temp_s3_path(self) -> Optional[str]: + return self._aurora_temp_s3_path + @property def boto3_session(self): return self._boto3_session @@ -288,6 +298,12 @@ def redshift(self): self._redshift = Redshift(session=self) return self._redshift + @property + def aurora(self): + if not self._aurora: + self._aurora = Aurora(session=self) + return self._aurora + @property def sagemaker(self): if not self._sagemaker: @@ -326,7 +342,8 @@ def __init__(self, athena_kms_key: Optional[str] = None, athena_database: Optional[str] = None, athena_ctas_approach: bool = False, - redshift_temp_s3_path: Optional[str] = None): + redshift_temp_s3_path: Optional[str] = None, + aurora_temp_s3_path: Optional[str] = None): """ Most parameters inherit from Boto3. https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html @@ -348,6 +365,7 @@ def __init__(self, :param athena_encryption: None|'SSE_S3'|'SSE_KMS'|'CSE_KMS' :param athena_kms_key: For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID. :param redshift_temp_s3_path: AWS S3 path to write temporary data (e.g. s3://...) + :param aurora_temp_s3_path: AWS S3 path to write temporary data (e.g. s3://...) 
""" self._profile_name: Optional[str] = profile_name self._aws_access_key_id: Optional[str] = aws_access_key_id @@ -366,6 +384,7 @@ def __init__(self, self._athena_database: Optional[str] = athena_database self._athena_ctas_approach: bool = athena_ctas_approach self._redshift_temp_s3_path: Optional[str] = redshift_temp_s3_path + self._aurora_temp_s3_path: Optional[str] = aurora_temp_s3_path @property def profile_name(self): @@ -435,6 +454,10 @@ def athena_ctas_approach(self) -> bool: def redshift_temp_s3_path(self) -> Optional[str]: return self._redshift_temp_s3_path + @property + def aurora_temp_s3_path(self) -> Optional[str]: + return self._aurora_temp_s3_path + @property def session(self): """ @@ -456,4 +479,5 @@ def session(self): athena_kms_key=self._athena_kms_key, athena_database=self._athena_database, athena_ctas_approach=self._athena_ctas_approach, - redshift_temp_s3_path=self._redshift_temp_s3_path) + redshift_temp_s3_path=self._redshift_temp_s3_path, + aurora_temp_s3_path=self._aurora_temp_s3_path) diff --git a/building/Dockerfile b/building/Dockerfile index f2509f1a9..161055296 100644 --- a/building/Dockerfile +++ b/building/Dockerfile @@ -6,7 +6,7 @@ RUN yum install -y \ bison \ flex \ autoconf \ - python37-devel + python36-devel RUN pip3 install --upgrade pip diff --git a/building/build-lambda-layer.sh b/building/build-lambda-layer.sh index 6c0248de0..e8d6a8ae4 100755 --- a/building/build-lambda-layer.sh +++ b/building/build-lambda-layer.sh @@ -1,18 +1,15 @@ #!/usr/bin/env bash set -e - # Go back to AWSWRANGLER directory cd /aws-data-wrangler/ rm -rf dist/*.zip -# Build PyArrow files if necessary -if [ ! -d "dist/pyarrow_files" ] ; then - cd building - ./build-pyarrow.sh - cd .. -fi +# Build PyArrow files +cd building +./build-pyarrow.sh +cd .. # Preparing directories mkdir -p dist diff --git a/building/build-pyarrow.sh b/building/build-pyarrow.sh index 0251af8fb..71373acf7 100755 --- a/building/build-pyarrow.sh +++ b/building/build-pyarrow.sh @@ -7,7 +7,6 @@ rm -rf \ dist \ /aws-data-wrangler/dist/pyarrow_wheels \ /aws-data-wrangler/dist/pyarrow_files \ - /aws-data-wrangler/dist/pyarrow_wheels/ # Clone desired Arrow version git clone \ diff --git a/docs/source/examples.rst b/docs/source/examples.rst index dfaeccbc6..04812dcf4 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -83,6 +83,23 @@ Reading from AWS Athena to Pandas with the blazing fast CTAS approach database="database" ) +Reading from Glue Catalog (Parquet) to Pandas +````````````````````````````````````````````` + +.. code-block:: python + + import awswrangler as wr + + df = wr.pandas.read_table(database="DATABASE_NAME", table="TABLE_NAME") + +Reading from S3 (Parquet) to Pandas +``````````````````````````````````` + +.. code-block:: python + + import awswrangler as wr + + df = wr.pandas.read_parquet(path="s3://...", columns=["c1", "c3"], filters=[("c5", "=", 0)]) Reading from S3 (CSV) to Pandas ``````````````````````````````` @@ -174,6 +191,32 @@ Extract Redshift query to Pandas DataFrame connection=con, temp_s3_path="s3://temp_path") +Loading Pandas Dataframe to Aurora (MySQL/PostgreSQL) +````````````````````````````````````````````````````` + +.. code-block:: python + + import awswrangler as wr + + wr.pandas.to_aurora( + dataframe=df, + connection=con, + schema="...", + table="..." + ) + + +Extract Aurora query to Pandas DataFrame (MySQL) +```````````````````````````````````````````````` + +.. 
code-block:: python + + import awswrangler as wr + + df = wr.pandas.read_sql_aurora( + sql="SELECT ...", + connection=con + ) PySpark ------- diff --git a/docs/source/index.rst b/docs/source/index.rst index 3550961f2..86b3af6f7 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -20,13 +20,14 @@ Pandas * Pandas -> Glue Catalog Table * Pandas -> Athena (Parallel) * Pandas -> Redshift (Append/Overwrite/Upsert) (Parallel) +* Pandas -> Aurora (MySQL/PostgreSQL) (Append/Overwrite) (Via S3) (NEW) * Parquet (S3) -> Pandas (Parallel) * CSV (S3) -> Pandas (One shot or Batching) * Glue Catalog Table -> Pandas (Parallel) * Athena -> Pandas (One shot, Batching or Parallel) * Redshift -> Pandas (Parallel) -* Redshift -> Parquet (S3) * CloudWatch Logs Insights -> Pandas +* Aurora -> Pandas (MySQL) (Via S3) (NEW) * Encrypt Pandas Dataframes on S3 with KMS keys PySpark @@ -45,13 +46,16 @@ General * Get the size of S3 objects (Parallel) * Get CloudWatch Logs Insights query results * Load partitions on Athena/Glue table (repair table) -* Create EMR cluster (For humans) (NEW) -* Terminate EMR cluster (NEW) -* Get EMR cluster state (NEW) -* Submit EMR step(s) (For humans) (NEW) -* Get EMR step state (NEW) -* Athena query to receive the result as python primitives (Iterable[Dict[str, Any]) (NEW) +* Create EMR cluster (For humans) +* Terminate EMR cluster +* Get EMR cluster state +* Submit EMR step(s) (For humans) +* Get EMR step state +* Get EMR step state +* Athena query to receive the result as python primitives (*Iterable[Dict[str, Any]*) * Load and Unzip SageMaker jobs outputs +* Redshift -> Parquet (S3) +* Aurora -> CSV (S3) (MySQL) (NEW :star:) Table Of Contents diff --git a/requirements.txt b/requirements.txt index 7db12a62b..c32f44bdf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ botocore~=1.13.45 boto3~=1.10.45 s3fs~=0.4.0 tenacity~=6.0.0 -pg8000~=1.13.2 \ No newline at end of file +pg8000~=1.13.2 +pymysql~=0.9.3 \ No newline at end of file diff --git a/testing/deploy-cloudformation.sh b/testing/deploy-cloudformation.sh index 0d3b05080..c9c20628f 100755 --- a/testing/deploy-cloudformation.sh +++ b/testing/deploy-cloudformation.sh @@ -3,6 +3,6 @@ set -e aws cloudformation deploy \ --template-file template.yaml \ ---stack-name aws-data-wrangler-test-arena \ +--stack-name aws-data-wrangler-test \ --capabilities CAPABILITY_IAM \ --parameter-overrides $(cat parameters.properties) diff --git a/testing/parameters.properties b/testing/parameters.properties index 6e245de14..30f9c0b47 100644 --- a/testing/parameters.properties +++ b/testing/parameters.properties @@ -1,4 +1,5 @@ VpcId=VPC_ID SubnetId=SUBNET_ID +SubnetId2=SUBNET_ID2 Password=REDSHIFT_PASSWORD TestUser=AWS_USER_THAT_WILL_RUN_THE_TESTS_ON_CLI \ No newline at end of file diff --git a/testing/template.yaml b/testing/template.yaml index b9e002d3e..47575efb3 100644 --- a/testing/template.yaml +++ b/testing/template.yaml @@ -10,6 +10,9 @@ Parameters: SubnetId: Type: String Description: Redshift Subnet ID + SubnetId2: + Type: String + Description: Redshift Subnet ID Password: Type: String Description: Redshift Password @@ -97,6 +100,7 @@ Resources: Action: - "s3:Get*" - "s3:List*" + - "s3:Put*" Resource: - !Join ['', ['arn:aws:s3:::', !Ref Bucket]] - !Join ['', ['arn:aws:s3:::', !Ref Bucket, /*]] @@ -108,7 +112,7 @@ Resources: SubnetIds: - Ref: SubnetId - RedshiftSecurityGroup: + DatabaseSecurityGroup: Type: AWS::EC2::SecurityGroup Properties: VpcId: !Ref VpcId @@ -117,7 +121,7 @@ Resources: Redshift: Type: 
AWS::Redshift::Cluster DependsOn: - - RedshiftSecurityGroup + - DatabaseSecurityGroup - RedshiftSubnetGroup - RedshiftRole Properties: @@ -127,7 +131,7 @@ Resources: NodeType: dc2.large ClusterType: single-node VpcSecurityGroupIds: - - !Ref RedshiftSecurityGroup + - !Ref DatabaseSecurityGroup ClusterSubnetGroupName: !Ref RedshiftSubnetGroup PubliclyAccessible: true Port: 5439 @@ -152,6 +156,121 @@ Resources: Properties: LogGroupName: !Ref LogGroup + RdsSubnetGroup: + Type: AWS::RDS::DBSubnetGroup + Properties: + DBSubnetGroupDescription: RDS Database Subnet Group + SubnetIds: + - Ref: SubnetId + - Ref: SubnetId2 + + AuroraRole: + Type: AWS::IAM::Role + Properties: + AssumeRolePolicyDocument: + Version: 2012-10-17 + Statement: + - Effect: Allow + Principal: + Service: + - rds.amazonaws.com + Action: + - sts:AssumeRole + Path: "/" + Policies: + - PolicyName: S3GetAndList + PolicyDocument: + Version: 2012-10-17 + Statement: + - Effect: Allow + Action: + - "s3:Get*" + - "s3:List*" + - "s3:Put*" + Resource: + - !Join ['', ['arn:aws:s3:::', !Ref Bucket]] + - !Join ['', ['arn:aws:s3:::', !Ref Bucket, /*]] + + PostgresParameterGroup: + Type: AWS::RDS::DBClusterParameterGroup + Properties: + Description: Postgres 10 + Family: aurora-postgresql10 + Parameters: + apg_plan_mgmt.capture_plan_baselines: "off" + + AuroraClusterPostgres: + Type: AWS::RDS::DBCluster + DependsOn: + - PostgresParameterGroup + - AuroraRole + - RdsSubnetGroup + - DatabaseSecurityGroup + Properties: + Engine: aurora-postgresql + DBClusterIdentifier : postgres-cluster-wrangler + MasterUsername: test + MasterUserPassword: !Ref Password + BackupRetentionPeriod: 1 + DBSubnetGroupName: !Ref RdsSubnetGroup + VpcSecurityGroupIds: + - !Ref DatabaseSecurityGroup + DBClusterParameterGroupName: !Ref PostgresParameterGroup + AssociatedRoles: + - FeatureName: s3Import + RoleArn: !GetAtt AuroraRole.Arn + + AuroraInstancePostgres: + Type: AWS::RDS::DBInstance + Properties: + Engine: aurora-postgresql + DBInstanceIdentifier: postgres-instance-wrangler + DBClusterIdentifier: !Ref AuroraClusterPostgres + DBInstanceClass: db.t3.medium + DBSubnetGroupName: !Ref RdsSubnetGroup + PubliclyAccessible: true + + MysqlParameterGroup: + Type: AWS::RDS::DBClusterParameterGroup + Properties: + Description: Mysql 5.7 + Family: aurora-mysql5.7 + Parameters: + aurora_load_from_s3_role: !GetAtt AuroraRole.Arn + aws_default_s3_role: !GetAtt AuroraRole.Arn + aurora_select_into_s3_role: !GetAtt AuroraRole.Arn + + AuroraClusterMysql: + Type: AWS::RDS::DBCluster + DependsOn: + - MysqlParameterGroup + - AuroraRole + - RdsSubnetGroup + - DatabaseSecurityGroup + Properties: + Engine: aurora-mysql + DBClusterIdentifier: mysql-cluster-wrangler + MasterUsername: test + MasterUserPassword: !Ref Password + BackupRetentionPeriod: 1 + DBSubnetGroupName: !Ref RdsSubnetGroup + VpcSecurityGroupIds: + - !Ref DatabaseSecurityGroup + DBClusterParameterGroupName: !Ref MysqlParameterGroup + AssociatedRoles: + - RoleArn: !GetAtt AuroraRole.Arn + + + AuroraInstanceMysql: + Type: AWS::RDS::DBInstance + Properties: + Engine: aurora-mysql + DBInstanceIdentifier: mysql-instance-wrangler + DBClusterIdentifier: !Ref AuroraClusterMysql + DBInstanceClass: db.t3.medium + DBSubnetGroupName: !Ref RdsSubnetGroup + PubliclyAccessible: true + Outputs: BucketName: Value: !Ref Bucket @@ -162,9 +281,9 @@ Outputs: RedshiftPort: Value: !GetAtt Redshift.Endpoint.Port Description: Redshift Endpoint Port. - RedshiftPassword: + Password: Value: !Ref Password - Description: Redshift Password. 
+ Description: Password. RedshiftRole: Value: !GetAtt RedshiftRole.Arn Description: Redshift IAM role. @@ -183,3 +302,12 @@ Outputs: SubnetId: Value: !Ref SubnetId Description: Subnet ID + SubnetId2: + Value: !Ref SubnetId2 + Description: Subnet ID 2 + PostgresAddress: + Value: !GetAtt AuroraInstancePostgres.Endpoint.Address + Description: Postgres Address + MysqlAddress: + Value: !GetAtt AuroraInstanceMysql.Endpoint.Address + Description: Mysql Address \ No newline at end of file diff --git a/testing/test_awswrangler/test_athena.py b/testing/test_awswrangler/test_athena.py index 92d9f38ff..5fbfcfeca 100644 --- a/testing/test_awswrangler/test_athena.py +++ b/testing/test_awswrangler/test_athena.py @@ -15,7 +15,7 @@ @pytest.fixture(scope="module") def cloudformation_outputs(): - response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test-arena") + response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test") outputs = {} for output in response.get("Stacks")[0].get("Outputs"): outputs[output.get("OutputKey")] = output.get("OutputValue") diff --git a/testing/test_awswrangler/test_aurora.py b/testing/test_awswrangler/test_aurora.py new file mode 100644 index 000000000..90f35701f --- /dev/null +++ b/testing/test_awswrangler/test_aurora.py @@ -0,0 +1,89 @@ +import logging + +import pytest +import boto3 + +from awswrangler import Aurora +from awswrangler.exceptions import InvalidEngine + +logging.basicConfig(level=logging.INFO, format="[%(asctime)s][%(levelname)s][%(name)s][%(funcName)s] %(message)s") +logging.getLogger("awswrangler").setLevel(logging.DEBUG) + + +@pytest.fixture(scope="module") +def cloudformation_outputs(): + response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test") + outputs = {} + for output in response.get("Stacks")[0].get("Outputs"): + outputs[output.get("OutputKey")] = output.get("OutputValue") + yield outputs + + +@pytest.fixture(scope="module") +def postgres_parameters(cloudformation_outputs): + postgres_parameters = {} + if "PostgresAddress" in cloudformation_outputs: + postgres_parameters["PostgresAddress"] = cloudformation_outputs.get("PostgresAddress") + else: + raise Exception("You must deploy the test infrastructure using SAM!") + if "Password" in cloudformation_outputs: + postgres_parameters["Password"] = cloudformation_outputs.get("Password") + else: + raise Exception("You must deploy the test infrastructure using SAM!") + yield postgres_parameters + + +@pytest.fixture(scope="module") +def mysql_parameters(cloudformation_outputs): + mysql_parameters = {} + if "MysqlAddress" in cloudformation_outputs: + mysql_parameters["MysqlAddress"] = cloudformation_outputs.get("MysqlAddress") + else: + raise Exception("You must deploy the test infrastructure using SAM!") + if "Password" in cloudformation_outputs: + mysql_parameters["Password"] = cloudformation_outputs.get("Password") + else: + raise Exception("You must deploy the test infrastructure using SAM!") + yield mysql_parameters + + +def test_postgres_connection(postgres_parameters): + conn = Aurora.generate_connection(database="postgres", + host=postgres_parameters["PostgresAddress"], + port=3306, + user="test", + password=postgres_parameters["Password"], + engine="postgres") + cursor = conn.cursor() + cursor.execute("SELECT 1 + 2, 3 + 4") + first_row = cursor.fetchall()[0] + assert first_row[0] == 3 + assert first_row[1] == 7 + cursor.close() + conn.close() + + +def test_mysql_connection(mysql_parameters): + conn = 
Aurora.generate_connection(database="mysql", + host=mysql_parameters["MysqlAddress"], + port=3306, + user="test", + password=mysql_parameters["Password"], + engine="mysql") + cursor = conn.cursor() + cursor.execute("SELECT 1 + 2, 3 + 4") + first_row = cursor.fetchall()[0] + assert first_row[0] == 3 + assert first_row[1] == 7 + cursor.close() + conn.close() + + +def test_invalid_engine(mysql_parameters): + with pytest.raises(InvalidEngine): + Aurora.generate_connection(database="mysql", + host=mysql_parameters["MysqlAddress"], + port=3306, + user="test", + password=mysql_parameters["Password"], + engine="foo") diff --git a/testing/test_awswrangler/test_cloudwatchlogs.py b/testing/test_awswrangler/test_cloudwatchlogs.py index 383ecbb49..920e4f6d6 100644 --- a/testing/test_awswrangler/test_cloudwatchlogs.py +++ b/testing/test_awswrangler/test_cloudwatchlogs.py @@ -13,7 +13,7 @@ @pytest.fixture(scope="module") def cloudformation_outputs(): - response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test-arena") + response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test") outputs = {} for output in response.get("Stacks")[0].get("Outputs"): outputs[output.get("OutputKey")] = output.get("OutputValue") diff --git a/testing/test_awswrangler/test_emr.py b/testing/test_awswrangler/test_emr.py index 8a1d1e93a..559f73326 100644 --- a/testing/test_awswrangler/test_emr.py +++ b/testing/test_awswrangler/test_emr.py @@ -12,7 +12,7 @@ @pytest.fixture(scope="module") def cloudformation_outputs(): - response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test-arena") + response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test") outputs = {} for output in response.get("Stacks")[0].get("Outputs"): outputs[output.get("OutputKey")] = output.get("OutputValue") diff --git a/testing/test_awswrangler/test_glue.py b/testing/test_awswrangler/test_glue.py index 58c52e68a..2b13e44cb 100644 --- a/testing/test_awswrangler/test_glue.py +++ b/testing/test_awswrangler/test_glue.py @@ -12,7 +12,7 @@ @pytest.fixture(scope="module") def cloudformation_outputs(): - response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test-arena") + response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test") outputs = {} for output in response.get("Stacks")[0].get("Outputs"): outputs[output.get("OutputKey")] = output.get("OutputValue") diff --git a/testing/test_awswrangler/test_pandas.py b/testing/test_awswrangler/test_pandas.py index 81daf398c..f41605a89 100644 --- a/testing/test_awswrangler/test_pandas.py +++ b/testing/test_awswrangler/test_pandas.py @@ -3,13 +3,15 @@ import csv from datetime import datetime, date from decimal import Decimal +import warnings import pytest import boto3 import pandas as pd import numpy as np -from awswrangler import Session, Pandas +import awswrangler as wr +from awswrangler import Session, Pandas, Aurora from awswrangler.exceptions import LineTerminatorNotFound, EmptyDataframe, InvalidSerDe, UndetectedType logging.basicConfig(level=logging.INFO, format="[%(asctime)s][%(levelname)s][%(name)s][%(funcName)s] %(message)s") @@ -18,7 +20,7 @@ @pytest.fixture(scope="module") def cloudformation_outputs(): - response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test-arena") + response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test") outputs = {} for output in 
response.get("Stacks")[0].get("Outputs"): outputs[output.get("OutputKey")] = output.get("OutputValue") @@ -87,6 +89,58 @@ def logstream(cloudformation_outputs, loggroup): yield logstream +@pytest.fixture(scope="module") +def postgres_parameters(cloudformation_outputs): + postgres_parameters = {} + if "PostgresAddress" in cloudformation_outputs: + postgres_parameters["PostgresAddress"] = cloudformation_outputs.get("PostgresAddress") + else: + raise Exception("You must deploy the test infrastructure using SAM!") + if "Password" in cloudformation_outputs: + postgres_parameters["Password"] = cloudformation_outputs.get("Password") + else: + raise Exception("You must deploy the test infrastructure using SAM!") + conn = Aurora.generate_connection(database="postgres", + host=postgres_parameters["PostgresAddress"], + port=3306, + user="test", + password=postgres_parameters["Password"], + engine="postgres") + with conn.cursor() as cursor: + sql = "CREATE EXTENSION IF NOT EXISTS aws_s3 CASCADE" + cursor.execute(sql) + conn.commit() + conn.close() + yield postgres_parameters + + +@pytest.fixture(scope="module") +def mysql_parameters(cloudformation_outputs): + mysql_parameters = {} + if "MysqlAddress" in cloudformation_outputs: + mysql_parameters["MysqlAddress"] = cloudformation_outputs.get("MysqlAddress") + else: + raise Exception("You must deploy the test infrastructure using SAM!") + if "Password" in cloudformation_outputs: + mysql_parameters["Password"] = cloudformation_outputs.get("Password") + else: + raise Exception("You must deploy the test infrastructure using SAM!") + conn = Aurora.generate_connection(database="mysql", + host=mysql_parameters["MysqlAddress"], + port=3306, + user="test", + password=mysql_parameters["Password"], + engine="mysql") + with conn.cursor() as cursor: + sql = "CREATE DATABASE IF NOT EXISTS test" + with warnings.catch_warnings(): + warnings.filterwarnings(action="ignore", message=".*database exists.*") + cursor.execute(sql) + conn.commit() + conn.close() + yield mysql_parameters + + @pytest.mark.parametrize("sample, row_num", [("data_samples/micro.csv", 30), ("data_samples/small.csv", 100)]) def test_read_csv(session, bucket, sample, row_num): boto3.client("s3").upload_file(sample, bucket, sample) @@ -1592,3 +1646,134 @@ def test_to_csv_single_file(session, bucket, database): assert len(list(df.columns)) + 1 == len(list(df2.columns)) assert len(df.index) == len(df2.index) print(df2) + + +def test_aurora_mysql_load_simple(bucket, mysql_parameters): + df = pd.DataFrame({"id": [1, 2, 3], "value": ["foo", "boo", "bar"]}) + conn = Aurora.generate_connection(database="mysql", + host=mysql_parameters["MysqlAddress"], + port=3306, + user="test", + password=mysql_parameters["Password"], + engine="mysql") + path = f"s3://{bucket}/test_aurora_mysql_load_simple" + wr.pandas.to_aurora(dataframe=df, + connection=conn, + schema="test", + table="test_aurora_mysql_load_simple", + mode="overwrite", + temp_s3_path=path) + with conn.cursor() as cursor: + cursor.execute("SELECT * FROM test.test_aurora_mysql_load_simple") + rows = cursor.fetchall() + assert len(rows) == len(df.index) + assert rows[0][0] == 1 + assert rows[1][0] == 2 + assert rows[2][0] == 3 + assert rows[0][1] == "foo" + assert rows[1][1] == "boo" + assert rows[2][1] == "bar" + conn.close() + + +def test_aurora_postgres_load_simple(bucket, postgres_parameters): + df = pd.DataFrame({"id": [1, 2, 3], "value": ["foo", "boo", "bar"]}) + conn = Aurora.generate_connection(database="postgres", + 
host=postgres_parameters["PostgresAddress"], + port=3306, + user="test", + password=postgres_parameters["Password"], + engine="postgres") + path = f"s3://{bucket}/test_aurora_postgres_load_simple" + wr.pandas.to_aurora( + dataframe=df, + connection=conn, + schema="public", + table="test_aurora_postgres_load_simple", + mode="overwrite", + temp_s3_path=path, + engine="postgres", + ) + with conn.cursor() as cursor: + cursor.execute("SELECT * FROM public.test_aurora_postgres_load_simple") + rows = cursor.fetchall() + assert len(rows) == len(df.index) + assert rows[0][0] == 1 + assert rows[1][0] == 2 + assert rows[2][0] == 3 + assert rows[0][1] == "foo" + assert rows[1][1] == "boo" + assert rows[2][1] == "bar" + conn.close() + + +def test_aurora_mysql_unload_simple(bucket, mysql_parameters): + df = pd.DataFrame({"id": [1, 2, 3], "value": ["foo", "boo", "bar"]}) + conn = Aurora.generate_connection(database="mysql", + host=mysql_parameters["MysqlAddress"], + port=3306, + user="test", + password=mysql_parameters["Password"], + engine="mysql") + path = f"s3://{bucket}/test_aurora_mysql_unload_simple" + wr.pandas.to_aurora(dataframe=df, + connection=conn, + schema="test", + table="test_aurora_mysql_load_simple", + mode="overwrite", + temp_s3_path=path) + with conn.cursor() as cursor: + cursor.execute("SELECT * FROM test.test_aurora_mysql_load_simple") + rows = cursor.fetchall() + assert len(rows) == len(df.index) + assert rows[0][0] == 1 + assert rows[1][0] == 2 + assert rows[2][0] == 3 + assert rows[0][1] == "foo" + assert rows[1][1] == "boo" + assert rows[2][1] == "bar" + path2 = f"s3://{bucket}/test_aurora_mysql_unload_simple2" + df2 = wr.pandas.read_sql_aurora(sql="SELECT * FROM test.test_aurora_mysql_load_simple", + connection=conn, + col_names=["id", "value"], + temp_s3_path=path2) + assert len(df.index) == len(df2.index) + assert len(df.columns) == len(df2.columns) + df2 = wr.pandas.read_sql_aurora(sql="SELECT * FROM test.test_aurora_mysql_load_simple", + connection=conn, + temp_s3_path=path2) + assert len(df.index) == len(df2.index) + assert len(df.columns) == len(df2.columns) + conn.close() + + +@pytest.mark.parametrize("sample, row_num", [("data_samples/micro.csv", 30), ("data_samples/small.csv", 100)]) +def test_read_csv_list(bucket, sample, row_num): + n = 10 + paths = [] + for i in range(n): + key = f"{sample}_{i}" + boto3.client("s3").upload_file(sample, bucket, key) + paths.append(f"s3://{bucket}/{key}") + dataframe = wr.pandas.read_csv_list(paths=paths) + wr.s3.delete_listed_objects(objects_paths=paths) + assert len(dataframe.index) == row_num * n + + +@pytest.mark.parametrize("sample, row_num", [("data_samples/micro.csv", 30), ("data_samples/small.csv", 100)]) +def test_read_csv_list_iterator(bucket, sample, row_num): + n = 10 + paths = [] + for i in range(n): + key = f"{sample}_{i}" + boto3.client("s3").upload_file(sample, bucket, key) + paths.append(f"s3://{bucket}/{key}") + + df_iter = wr.pandas.read_csv_list(paths=paths, max_result_size=200) + total_count = 0 + for df in df_iter: + count = len(df.index) + print(f"count: {count}") + total_count += count + wr.s3.delete_listed_objects(objects_paths=paths) + assert total_count == row_num * n diff --git a/testing/test_awswrangler/test_redshift.py b/testing/test_awswrangler/test_redshift.py index 3ec1ed60d..9e5945641 100644 --- a/testing/test_awswrangler/test_redshift.py +++ b/testing/test_awswrangler/test_redshift.py @@ -9,6 +9,7 @@ from pyspark.sql import SparkSession import pg8000 +import awswrangler as wr from awswrangler import 
Session, Redshift from awswrangler.exceptions import InvalidRedshiftDiststyle, InvalidRedshiftDistkey, InvalidRedshiftSortstyle, InvalidRedshiftSortkey @@ -18,7 +19,7 @@ @pytest.fixture(scope="module") def cloudformation_outputs(): - response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test-arena") + response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test") outputs = {} for output in response.get("Stacks")[0].get("Outputs"): outputs[output.get("OutputKey")] = output.get("OutputValue") @@ -48,8 +49,8 @@ def redshift_parameters(cloudformation_outputs): redshift_parameters["RedshiftAddress"] = cloudformation_outputs.get("RedshiftAddress") else: raise Exception("You must deploy the test infrastructure using SAM!") - if "RedshiftPassword" in cloudformation_outputs: - redshift_parameters["RedshiftPassword"] = cloudformation_outputs.get("RedshiftPassword") + if "Password" in cloudformation_outputs: + redshift_parameters["Password"] = cloudformation_outputs.get("Password") else: raise Exception("You must deploy the test infrastructure using SAM!") if "RedshiftPort" in cloudformation_outputs: @@ -90,7 +91,7 @@ def test_to_redshift_pandas(session, bucket, redshift_parameters, sample_name, m host=redshift_parameters.get("RedshiftAddress"), port=redshift_parameters.get("RedshiftPort"), user="test", - password=redshift_parameters.get("RedshiftPassword"), + password=redshift_parameters.get("Password"), ) path = f"s3://{bucket}/redshift-load/" session.pandas.to_redshift( @@ -130,7 +131,7 @@ def test_to_redshift_pandas_cast(session, bucket, redshift_parameters): host=redshift_parameters.get("RedshiftAddress"), port=redshift_parameters.get("RedshiftPort"), user="test", - password=redshift_parameters.get("RedshiftPassword"), + password=redshift_parameters.get("Password"), ) path = f"s3://{bucket}/redshift-load/" session.pandas.to_redshift(dataframe=df, @@ -169,7 +170,7 @@ def test_to_redshift_pandas_exceptions(session, bucket, redshift_parameters, sam host=redshift_parameters.get("RedshiftAddress"), port=redshift_parameters.get("RedshiftPort"), user="test", - password=redshift_parameters.get("RedshiftPassword"), + password=redshift_parameters.get("Password"), ) path = f"s3://{bucket}/redshift-load/" with pytest.raises(exc): @@ -223,7 +224,7 @@ def test_to_redshift_spark(session, bucket, redshift_parameters, sample_name, mo host=redshift_parameters.get("RedshiftAddress"), port=redshift_parameters.get("RedshiftPort"), user="test", - password=redshift_parameters.get("RedshiftPassword"), + password=redshift_parameters.get("Password"), ) session.spark.to_redshift( dataframe=dataframe, @@ -260,7 +261,7 @@ def test_to_redshift_spark_big(session, bucket, redshift_parameters): host=redshift_parameters.get("RedshiftAddress"), port=redshift_parameters.get("RedshiftPort"), user="test", - password=redshift_parameters.get("RedshiftPassword"), + password=redshift_parameters.get("Password"), ) session.spark.to_redshift( dataframe=dataframe, @@ -288,7 +289,7 @@ def test_to_redshift_spark_bool(session, bucket, redshift_parameters): host=redshift_parameters.get("RedshiftAddress"), port=redshift_parameters.get("RedshiftPort"), user="test", - password=redshift_parameters.get("RedshiftPassword"), + password=redshift_parameters.get("Password"), ) session.spark.to_redshift( dataframe=dataframe, @@ -325,7 +326,7 @@ def test_stress_to_redshift_spark_big(session, bucket, redshift_parameters): host=redshift_parameters.get("RedshiftAddress"), 
port=redshift_parameters.get("RedshiftPort"), user="test", - password=redshift_parameters.get("RedshiftPassword"), + password=redshift_parameters.get("Password"), ) session.spark.to_redshift( dataframe=dataframe, @@ -360,7 +361,7 @@ def test_to_redshift_spark_exceptions(session, bucket, redshift_parameters, samp host=redshift_parameters.get("RedshiftAddress"), port=redshift_parameters.get("RedshiftPort"), user="test", - password=redshift_parameters.get("RedshiftPassword"), + password=redshift_parameters.get("Password"), ) with pytest.raises(exc): assert session.spark.to_redshift( @@ -399,7 +400,7 @@ def test_connection_timeout(redshift_parameters): host=redshift_parameters.get("RedshiftAddress"), port=12345, user="test", - password=redshift_parameters.get("RedshiftPassword"), + password=redshift_parameters.get("Password"), ) @@ -409,7 +410,7 @@ def test_connection_with_different_port_types(redshift_parameters): host=redshift_parameters.get("RedshiftAddress"), port=str(redshift_parameters.get("RedshiftPort")), user="test", - password=redshift_parameters.get("RedshiftPassword"), + password=redshift_parameters.get("Password"), ) conn.close() conn = Redshift.generate_connection( @@ -417,7 +418,7 @@ def test_connection_with_different_port_types(redshift_parameters): host=redshift_parameters.get("RedshiftAddress"), port=float(redshift_parameters.get("RedshiftPort")), user="test", - password=redshift_parameters.get("RedshiftPassword"), + password=redshift_parameters.get("Password"), ) conn.close() @@ -434,7 +435,7 @@ def test_to_redshift_pandas_decimal(session, bucket, redshift_parameters): host=redshift_parameters.get("RedshiftAddress"), port=redshift_parameters.get("RedshiftPort"), user="test", - password=redshift_parameters.get("RedshiftPassword"), + password=redshift_parameters.get("Password"), ) path = f"s3://{bucket}/redshift-load/" session.pandas.to_redshift( @@ -479,7 +480,7 @@ def test_to_redshift_spark_decimal(session, bucket, redshift_parameters): host=redshift_parameters.get("RedshiftAddress"), port=redshift_parameters.get("RedshiftPort"), user="test", - password=redshift_parameters.get("RedshiftPassword"), + password=redshift_parameters.get("Password"), ) path = f"s3://{bucket}/redshift-load2/" session.spark.to_redshift( @@ -518,7 +519,7 @@ def test_to_parquet(session, bucket, redshift_parameters): host=redshift_parameters.get("RedshiftAddress"), port=redshift_parameters.get("RedshiftPort"), user="test", - password=redshift_parameters.get("RedshiftPassword"), + password=redshift_parameters.get("Password"), ) path = f"s3://{bucket}/test_to_parquet/" session.pandas.to_redshift( @@ -555,7 +556,7 @@ def test_read_sql_redshift_pandas(session, bucket, redshift_parameters, sample_n host=redshift_parameters.get("RedshiftAddress"), port=redshift_parameters.get("RedshiftPort"), user="test", - password=redshift_parameters.get("RedshiftPassword"), + password=redshift_parameters.get("Password"), ) path = f"s3://{bucket}/test_read_sql_redshift_pandas/" session.pandas.to_redshift( @@ -585,7 +586,7 @@ def test_read_sql_redshift_pandas2(session, bucket, redshift_parameters): host=redshift_parameters.get("RedshiftAddress"), port=redshift_parameters.get("RedshiftPort"), user="test", - password=redshift_parameters.get("RedshiftPassword"), + password=redshift_parameters.get("Password"), ) path = f"s3://{bucket}/test_read_sql_redshift_pandas2/" session.pandas.to_redshift( @@ -603,61 +604,68 @@ def test_read_sql_redshift_pandas2(session, bucket, redshift_parameters): 
iam_role=redshift_parameters.get("RedshiftRole"), connection=con, temp_s3_path=path2) + wr.s3.delete_objects(path=f"s3://{bucket}/") assert len(df.index) == len(df2.index) assert len(df.columns) + 1 == len(df2.columns) def test_to_redshift_pandas_upsert(session, bucket, redshift_parameters): + wr.s3.delete_objects(path=f"s3://{bucket}/") con = Redshift.generate_connection( database="test", host=redshift_parameters.get("RedshiftAddress"), port=redshift_parameters.get("RedshiftPort"), user="test", - password=redshift_parameters.get("RedshiftPassword"), + password=redshift_parameters.get("Password"), ) - # CREATE - df = pd.DataFrame({ - "id": list((range(1_000_000))), - "val": list(["foo" if i % 2 == 0 else "boo" for i in range(1_000_000)]) - }) - path = f"s3://{bucket}/test_to_redshift_pandas_upsert/" - session.pandas.to_redshift(dataframe=df, - path=path, - schema="public", - table="test_upsert", - connection=con, - iam_role=redshift_parameters.get("RedshiftRole"), - mode="overwrite", - preserve_index=True, - primary_keys=["id"]) - path = f"s3://{bucket}/test_to_redshift_pandas_upsert2/" - df2 = session.pandas.read_sql_redshift(sql="select * from public.test_upsert", - iam_role=redshift_parameters.get("RedshiftRole"), - connection=con, - temp_s3_path=path) - assert len(df.index) == len(df2.index) - assert len(df.columns) + 1 == len(df2.columns) + df = pd.DataFrame({"id": list((range(1_000))), "val": list(["foo" if i % 2 == 0 else "boo" for i in range(1_000)])}) - # UPSERT df3 = pd.DataFrame({ - "id": list((range(1_000_000, 1_500_000))), - "val": list(["foo" if i % 2 == 0 else "boo" for i in range(500_000)]) + "id": list((range(1_000, 1_500))), + "val": list(["foo" if i % 2 == 0 else "boo" for i in range(500)]) }) - path = f"s3://{bucket}/test_to_redshift_pandas_upsert3/" - session.pandas.to_redshift(dataframe=df3, - path=path, - schema="public", - table="test_upsert", - connection=con, - iam_role=redshift_parameters.get("RedshiftRole"), - mode="upsert", - preserve_index=True, - primary_keys=["id"]) - path = f"s3://{bucket}/test_to_redshift_pandas_upsert4/" - df4 = session.pandas.read_sql_redshift(sql="select * from public.test_upsert", - iam_role=redshift_parameters.get("RedshiftRole"), - connection=con, - temp_s3_path=path) - assert len(df.index) + len(df3.index) == len(df4.index) - assert len(df.columns) + 1 == len(df2.columns) + + for i in range(10): + print(f"run: {i}") + + # CREATE + path = f"s3://{bucket}/test_to_redshift_pandas_upsert/" + session.pandas.to_redshift(dataframe=df, + path=path, + schema="public", + table="test_upsert", + connection=con, + iam_role=redshift_parameters.get("RedshiftRole"), + mode="overwrite", + preserve_index=True, + primary_keys=["id"]) + path = f"s3://{bucket}/test_to_redshift_pandas_upsert2/" + df2 = session.pandas.read_sql_redshift(sql="select * from public.test_upsert", + iam_role=redshift_parameters.get("RedshiftRole"), + connection=con, + temp_s3_path=path) + assert len(df.index) == len(df2.index) + assert len(df.columns) + 1 == len(df2.columns) + + # UPSERT + path = f"s3://{bucket}/test_to_redshift_pandas_upsert3/" + session.pandas.to_redshift(dataframe=df3, + path=path, + schema="public", + table="test_upsert", + connection=con, + iam_role=redshift_parameters.get("RedshiftRole"), + mode="upsert", + preserve_index=True, + primary_keys=["id"]) + path = f"s3://{bucket}/test_to_redshift_pandas_upsert4/" + df4 = session.pandas.read_sql_redshift(sql="select * from public.test_upsert", + iam_role=redshift_parameters.get("RedshiftRole"), + 
connection=con, + temp_s3_path=path) + assert len(df.index) + len(df3.index) == len(df4.index) + assert len(df.columns) + 1 == len(df2.columns) + + wr.s3.delete_objects(path=f"s3://{bucket}/") + con.close() diff --git a/testing/test_awswrangler/test_s3.py b/testing/test_awswrangler/test_s3.py index 1ffc10912..d47dc27c4 100644 --- a/testing/test_awswrangler/test_s3.py +++ b/testing/test_awswrangler/test_s3.py @@ -56,7 +56,7 @@ def write_fake_objects(bucket, path, num, size=10): @pytest.fixture(scope="module") def cloudformation_outputs(): - response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test-arena") + response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test") outputs = {} for output in response.get("Stacks")[0].get("Outputs"): outputs[output.get("OutputKey")] = output.get("OutputValue") diff --git a/testing/test_awswrangler/test_sagemaker.py b/testing/test_awswrangler/test_sagemaker.py index 568210306..5a6bc210e 100644 --- a/testing/test_awswrangler/test_sagemaker.py +++ b/testing/test_awswrangler/test_sagemaker.py @@ -20,7 +20,7 @@ def session(): @pytest.fixture(scope="module") def cloudformation_outputs(): - response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test-arena") + response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test") outputs = {} for output in response.get("Stacks")[0].get("Outputs"): outputs[output.get("OutputKey")] = output.get("OutputValue") diff --git a/testing/test_awswrangler/test_spark.py b/testing/test_awswrangler/test_spark.py index 9836e850b..c3688d092 100644 --- a/testing/test_awswrangler/test_spark.py +++ b/testing/test_awswrangler/test_spark.py @@ -17,7 +17,7 @@ @pytest.fixture(scope="module") def cloudformation_outputs(): - response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test-arena") + response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test") outputs = {} for output in response.get("Stacks")[0].get("Outputs"): outputs[output.get("OutputKey")] = output.get("OutputValue")
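
For quick reference, the connection helper exercised in test_aurora.py can also be used on its own. A minimal sketch of the MySQL path, assuming a reachable Aurora MySQL endpoint; the endpoint and credentials below are placeholders, not values from this change:

```py3
from awswrangler import Aurora

# Placeholder endpoint and credentials -- in the test suite these come from the
# "MysqlAddress" and "Password" outputs of the test CloudFormation stack.
conn = Aurora.generate_connection(database="mysql",
                                  host="MYSQL_CLUSTER_ENDPOINT",
                                  port=3306,
                                  user="test",
                                  password="YOUR_PASSWORD",
                                  engine="mysql")
cursor = conn.cursor()
cursor.execute("SELECT 1 + 2, 3 + 4")
print(cursor.fetchall()[0])  # expected: (3, 7), as asserted in the tests
cursor.close()
conn.close()
```

Passing any engine other than "mysql" or "postgres" raises InvalidEngine, which is what test_invalid_engine checks.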
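
The load path added here stages the DataFrame on S3 before loading it into Aurora. A sketch mirroring test_aurora_mysql_load_simple, with a hypothetical table name and bucket (schema "test" is the database the MySQL fixture creates):

```py3
import pandas as pd
import awswrangler as wr
from awswrangler import Aurora

df = pd.DataFrame({"id": [1, 2, 3], "value": ["foo", "boo", "bar"]})

conn = Aurora.generate_connection(database="mysql",
                                  host="MYSQL_CLUSTER_ENDPOINT",  # placeholder
                                  port=3306,
                                  user="test",
                                  password="YOUR_PASSWORD",       # placeholder
                                  engine="mysql")

# The load is staged through S3, hence the temporary S3 path.
wr.pandas.to_aurora(dataframe=df,
                    connection=conn,
                    schema="test",
                    table="my_table",
                    mode="overwrite",
                    temp_s3_path="s3://YOUR_BUCKET/temp_to_aurora/")
conn.close()
```

For Aurora PostgreSQL the tests pass engine="postgres" and a schema such as "public"; the cluster must have the aws_s3 extension installed, which the postgres fixture creates.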
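
The read path works the other way around: the query result is unloaded to CSV on S3 and read back as a DataFrame. A sketch based on test_aurora_mysql_unload_simple, again with placeholder names:

```py3
import awswrangler as wr
from awswrangler import Aurora

conn = Aurora.generate_connection(database="mysql",
                                  host="MYSQL_CLUSTER_ENDPOINT",  # placeholder
                                  port=3306,
                                  user="test",
                                  password="YOUR_PASSWORD",       # placeholder
                                  engine="mysql")

# col_names is optional; the test exercises both variants.
df = wr.pandas.read_sql_aurora(sql="SELECT * FROM test.my_table",
                               connection=conn,
                               col_names=["id", "value"],
                               temp_s3_path="s3://YOUR_BUCKET/temp_unload/")
conn.close()
```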
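
test_read_csv_list and test_read_csv_list_iterator also cover reading a list of CSV objects from S3, either in one shot or in batches. A sketch with hypothetical object keys:

```py3
import awswrangler as wr

# Placeholder keys; the tests upload several copies of data_samples/micro.csv.
paths = [f"s3://YOUR_BUCKET/data_samples/micro.csv_{i}" for i in range(10)]

# One shot: every object is concatenated into a single DataFrame.
df = wr.pandas.read_csv_list(paths=paths)

# Batching: with max_result_size the call returns an iterator of DataFrames.
total_count = 0
for chunk in wr.pandas.read_csv_list(paths=paths, max_result_size=200):
    total_count += len(chunk.index)
```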
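
Finally, the test fixtures now resolve endpoints and the shared password from the renamed "aws-data-wrangler-test" stack. A sketch of that lookup pattern, assuming the stack has been deployed with deploy-cloudformation.sh:

```py3
import boto3

# Same pattern as the cloudformation_outputs fixtures in the test modules.
response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test")
outputs = {o["OutputKey"]: o["OutputValue"] for o in response["Stacks"][0]["Outputs"]}

mysql_address = outputs["MysqlAddress"]        # new template output
postgres_address = outputs["PostgresAddress"]  # new template output
password = outputs["Password"]                 # renamed from RedshiftPassword
```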