From a42e68c2a2c55f5fe2e782f757d95981d080c8e4 Mon Sep 17 00:00:00 2001
From: igorborgest
Date: Wed, 22 Jan 2020 21:23:02 -0300
Subject: [PATCH 1/3] Add columns parameters to Pandas.to_csv()

---
 awswrangler/data_types.py               |  8 +--
 awswrangler/glue.py                     | 92 ++++++++++++++-----------
 awswrangler/pandas.py                   | 32 ++++++---
 testing/test_awswrangler/test_pandas.py | 28 ++++++++
 4 files changed, 108 insertions(+), 52 deletions(-)

diff --git a/awswrangler/data_types.py b/awswrangler/data_types.py
index 83b45d9c6..aadd62ead 100644
--- a/awswrangler/data_types.py
+++ b/awswrangler/data_types.py
@@ -370,8 +370,8 @@ def extract_pyarrow_schema_from_pandas(dataframe: pd.DataFrame,
     :param indexes_position: "right" or "left"
     :return: Pyarrow schema (e.g. [("col name": "bigint"), ("col2 name": "int")]
     """
-    cols = []
-    cols_dtypes = {}
+    cols: List[str] = []
+    cols_dtypes: Dict[str, str] = {}
     if indexes_position not in ("right", "left"):
         raise ValueError(f"indexes_position must be \"right\" or \"left\"")
 
@@ -384,10 +384,10 @@ def extract_pyarrow_schema_from_pandas(dataframe: pd.DataFrame,
             cols.append(name)
 
     # Filling cols_dtypes and indexes
-    indexes = []
+    indexes: List[str] = []
     for field in pa.Schema.from_pandas(df=dataframe[cols], preserve_index=preserve_index):
         name = str(field.name)
-        dtype = field.type
+        dtype = str(field.type)
         cols_dtypes[name] = dtype
         if name not in dataframe.columns:
             indexes.append(name)
diff --git a/awswrangler/glue.py b/awswrangler/glue.py
index e1f34f121..8548f4887 100644
--- a/awswrangler/glue.py
+++ b/awswrangler/glue.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Dict, Optional, Any, Iterator, List, Union
+from typing import TYPE_CHECKING, Dict, Optional, Any, Iterator, List, Union, Tuple
 from math import ceil
 from itertools import islice
 import re
@@ -55,16 +55,16 @@ def get_table_python_types(self, database: str, table: str) -> Dict[str, Optiona
     def metadata_to_glue(self,
                          dataframe,
                          path: str,
-                         objects_paths,
-                         file_format,
-                         database=None,
-                         table=None,
-                         partition_cols=None,
-                         preserve_index=True,
+                         objects_paths: List[str],
+                         file_format: str,
+                         database: str,
+                         table: Optional[str],
+                         partition_cols: Optional[List[str]] = None,
+                         preserve_index: bool = True,
                          mode: str = "append",
-                         compression=None,
-                         cast_columns=None,
-                         extra_args: Optional[Dict[str, Optional[Union[str, int]]]] = None,
+                         compression: Optional[str] = None,
+                         cast_columns: Optional[Dict[str, str]] = None,
+                         extra_args: Optional[Dict[str, Optional[Union[str, int, List[str]]]]] = None,
                          description: Optional[str] = None,
                          parameters: Optional[Dict[str, str]] = None,
                          columns_comments: Optional[Dict[str, str]] = None) -> None:
@@ -88,6 +88,8 @@ def metadata_to_glue(self,
         :return: None
         """
         indexes_position = "left" if file_format == "csv" else "right"
+        schema: List[Tuple[str, str]]
+        partition_cols_schema: List[Tuple[str, str]]
         schema, partition_cols_schema = Glue._build_schema(dataframe=dataframe,
                                                            partition_cols=partition_cols,
                                                            preserve_index=preserve_index,
@@ -138,14 +140,14 @@ def does_table_exists(self, database, table):
             return False
 
     def create_table(self,
-                     database,
-                     table,
-                     schema,
-                     path,
-                     file_format,
-                     compression,
-                     partition_cols_schema=None,
-                     extra_args=None,
+                     database: str,
+                     table: str,
+                     schema: List[Tuple[str, str]],
+                     path: str,
+                     file_format: str,
+                     compression: Optional[str],
+                     partition_cols_schema: List[Tuple[str, str]],
+                     extra_args: Optional[Dict[str, Union[str, int, List[str], None]]] = None,
                      description: Optional[str] = None,
                      parameters: Optional[Dict[str, str]] = None,
                     columns_comments: Optional[Dict[str, str]] = None) -> None:
@@ -166,13 +168,17 @@ def create_table(self,
         :return: None
         """
         if file_format == "parquet":
-            table_input = Glue.parquet_table_definition(table, partition_cols_schema, schema, path, compression)
+            table_input: Dict[str, Any] = Glue.parquet_table_definition(table=table,
+                                                                        partition_cols_schema=partition_cols_schema,
+                                                                        schema=schema,
+                                                                        path=path,
+                                                                        compression=compression)
         elif file_format == "csv":
-            table_input = Glue.csv_table_definition(table,
-                                                    partition_cols_schema,
-                                                    schema,
-                                                    path,
-                                                    compression,
+            table_input = Glue.csv_table_definition(table=table,
+                                                    partition_cols_schema=partition_cols_schema,
+                                                    schema=schema,
+                                                    path=path,
+                                                    compression=compression,
                                                     extra_args=extra_args)
         else:
             raise UnsupportedFileFormat(file_format)
@@ -223,19 +229,23 @@ def get_connection_details(self, name):
         return self._client_glue.get_connection(Name=name, HidePassword=False)["Connection"]
 
     @staticmethod
-    def _build_schema(dataframe, partition_cols, preserve_index, indexes_position, cast_columns=None):
+    def _build_schema(
+            dataframe,
+            partition_cols: Optional[List[str]],
+            preserve_index: bool,
+            indexes_position: str,
+            cast_columns: Optional[Dict[str, str]] = None) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
         if cast_columns is None:
             cast_columns = {}
         logger.debug(f"dataframe.dtypes:\n{dataframe.dtypes}")
-        if not partition_cols:
+        if partition_cols is None:
             partition_cols = []
-        pyarrow_schema = data_types.extract_pyarrow_schema_from_pandas(dataframe=dataframe,
-                                                                       preserve_index=preserve_index,
-                                                                       indexes_position=indexes_position)
+        pyarrow_schema: List[Tuple[str, str]] = data_types.extract_pyarrow_schema_from_pandas(
+            dataframe=dataframe, preserve_index=preserve_index, indexes_position=indexes_position)
 
-        schema_built = []
-        partition_cols_types = {}
+        schema_built: List[Tuple[str, str]] = []
+        partition_cols_types: Dict[str, str] = {}
         for name, dtype in pyarrow_schema:
             if (cast_columns is not None) and (name in cast_columns.keys()):
                 if name in partition_cols:
@@ -256,7 +266,7 @@ def _build_schema(dataframe, partition_cols, preserve_index, indexes_position, c
             else:
                 schema_built.append((name, athena_type))
 
-        partition_cols_schema_built = [(name, partition_cols_types[name]) for name in partition_cols]
+        partition_cols_schema_built: List = [(name, partition_cols_types[name]) for name in partition_cols]
 
         logger.debug(f"schema_built:\n{schema_built}")
         logger.debug(f"partition_cols_schema_built:\n{partition_cols_schema_built}")
@@ -269,12 +279,12 @@ def parse_table_name(path):
         return path.rpartition("/")[2]
 
     @staticmethod
-    def csv_table_definition(table,
-                             partition_cols_schema,
-                             schema,
-                             path,
-                             compression,
-                             extra_args: Optional[Dict[str, Optional[Union[str, int]]]] = None):
+    def csv_table_definition(table: str,
+                             partition_cols_schema: List[Tuple[str, str]],
+                             schema: List[Tuple[str, str]],
+                             path: str,
+                             compression: Optional[str],
+                             extra_args: Optional[Dict[str, Optional[Union[str, int, List[str]]]]] = None):
         if extra_args is None:
             extra_args = {"sep": ","}
         if partition_cols_schema is None:
@@ -301,6 +311,9 @@ def csv_table_definition(table,
             refined_schema = [(name, dtype) if dtype in dtypes_allowed else (name, "string") for name, dtype in schema]
         else:
             raise InvalidSerDe(f"{serde} in not in the valid SerDe list.")
+        if "columns" in extra_args:
+            refined_schema = [(name, dtype) for name, dtype in refined_schema
+                              if name in extra_args["columns"]]  # type: ignore
         return {
             "Name": table,
             "PartitionKeys": [{
@@ -378,7 +391,8 @@ def csv_partition_definition(partition, compression, extra_args=None):
         }
 
     @staticmethod
-    def parquet_table_definition(table, partition_cols_schema, schema, path, compression):
+    def parquet_table_definition(table: str, partition_cols_schema: List[Tuple[str, str]],
+                                 schema: List[Tuple[str, str]], path: str, compression: Optional[str]):
         if not partition_cols_schema:
             partition_cols_schema = []
         compressed = False if compression is None else True
diff --git a/awswrangler/pandas.py b/awswrangler/pandas.py
index 1e36b1c00..776735268 100644
--- a/awswrangler/pandas.py
+++ b/awswrangler/pandas.py
@@ -696,6 +696,7 @@ def to_csv(self,
                path: str,
                sep: Optional[str] = None,
                na_rep: Optional[str] = None,
+               columns: Optional[List[str]] = None,
               quoting: Optional[int] = None,
               escapechar: Optional[str] = None,
               serde: Optional[str] = "OpenCSVSerDe",
@@ -718,6 +719,7 @@ def to_csv(self,
         :param path: AWS S3 path (E.g. s3://bucket-name/folder_name/
         :param sep: Same as pandas.to_csv()
         :param na_rep: Same as pandas.to_csv()
+        :param columns: Same as pandas.to_csv()
         :param quoting: Same as pandas.to_csv()
         :param escapechar: Same as pandas.to_csv()
         :param serde: SerDe library name (e.g. OpenCSVSerDe, LazySimpleSerDe) (For Athena/Glue Catalog only)
@@ -738,9 +740,10 @@ def to_csv(self,
             raise InvalidSerDe(f"{serde} in not in the valid SerDe list ({Pandas.VALID_CSV_SERDES})")
         if (database is not None) and (serde is None):
             raise InvalidParameters(f"It is not possible write to a Glue Database without a SerDe.")
-        extra_args: Dict[str, Optional[Union[str, int]]] = {
+        extra_args: Dict[str, Optional[Union[str, int, List[str]]]] = {
             "sep": sep,
             "na_rep": na_rep,
+            "columns": columns,
             "serde": serde,
             "escapechar": escapechar,
             "quoting": quoting
@@ -822,14 +825,14 @@ def to_s3(self,
              file_format: str,
              database: Optional[str] = None,
              table: Optional[str] = None,
-              partition_cols=None,
-              preserve_index=True,
+              partition_cols: Optional[List[str]] = None,
+              preserve_index: bool = True,
              mode: str = "append",
-              compression=None,
-              procs_cpu_bound=None,
-              procs_io_bound=None,
-              cast_columns=None,
-              extra_args: Optional[Dict[str, Optional[Union[str, int]]]] = None,
+              compression: Optional[str] = None,
+              procs_cpu_bound: Optional[int] = None,
+              procs_io_bound: Optional[int] = None,
+              cast_columns: Optional[Dict[str, str]] = None,
+              extra_args: Optional[Dict[str, Optional[Union[str, int, List[str]]]]] = None,
              inplace: bool = True,
              description: Optional[str] = None,
              parameters: Optional[Dict[str, str]] = None,
@@ -866,6 +869,8 @@ def to_s3(self,
         logger.debug(f"cast_columns: {cast_columns}")
         partition_cols = [Athena.normalize_column_name(x) for x in partition_cols]
         logger.debug(f"partition_cols: {partition_cols}")
+        if extra_args is not None and "columns" in extra_args:
+            extra_args["columns"] = [Athena.normalize_column_name(x) for x in extra_args["columns"]]
         dataframe = Pandas.drop_duplicated_columns(dataframe=dataframe, inplace=inplace)
         if compression is not None:
             compression = compression.lower()
@@ -1112,6 +1117,9 @@ def write_csv_dataframe(dataframe, path, preserve_index, compression, fs, extra_
         sep = extra_args.get("sep")
         if sep is not None:
             csv_extra_args["sep"] = sep
+        columns = extra_args.get("columns")
+        if columns is not None:
+            csv_extra_args["columns"] = columns
 
         serde = extra_args.get("serde")
         if serde is None:
@@ -1519,7 +1527,10 @@ def _read_parquet_path(session_primitives: "SessionPrimitives",
             fs.invalidate_cache()
             table = pq.read_table(source=path, columns=columns, filters=filters, filesystem=fs, use_threads=use_threads)
             # Check if we lose some integer during the conversion (Happens when has some null value)
-            integers = [field.name for field in table.schema if str(field.type).startswith("int") and field.name != "__index_level_0__"]
+            integers = [
+                field.name for field in table.schema
+                if str(field.type).startswith("int") and field.name != "__index_level_0__"
+            ]
             logger.debug(f"Converting to Pandas: {path}")
             df = table.to_pandas(use_threads=use_threads, integer_object_nulls=True)
             logger.debug(f"Casting Int64 columns: {path}")
@@ -1612,6 +1623,7 @@ def to_aurora(self,
                  temp_s3_path: Optional[str] = None,
                  preserve_index: bool = False,
                  mode: str = "append",
+                  columns: Optional[List[str]] = None,
                  procs_cpu_bound: Optional[int] = None,
                  procs_io_bound: Optional[int] = None,
                  inplace=True) -> None:
@@ -1626,6 +1638,7 @@ def to_aurora(self,
         :param temp_s3_path: S3 path to write temporary files (E.g. s3://BUCKET_NAME/ANY_NAME/)
         :param preserve_index: Should we preserve the Dataframe index?
         :param mode: append or overwrite
+        :param columns: List of columns to load
         :param procs_cpu_bound: Number of cores used for CPU bound tasks
         :param procs_io_bound: Number of cores used for I/O bound tasks
         :param inplace: True is cheapest (CPU and Memory) but False leaves your DataFrame intact
@@ -1654,6 +1667,7 @@ def to_aurora(self,
                    serde=None,
                    sep=",",
                    na_rep=na_rep,
+                    columns=columns,
                    quoting=csv.QUOTE_MINIMAL,
                    escapechar="\"",
                    preserve_index=preserve_index,
diff --git a/testing/test_awswrangler/test_pandas.py b/testing/test_awswrangler/test_pandas.py
index 3c917b3ab..fec3f14ab 100644
--- a/testing/test_awswrangler/test_pandas.py
+++ b/testing/test_awswrangler/test_pandas.py
@@ -2185,6 +2185,7 @@ def test_to_parquet_categorical_partitions(bucket):
     x['Year'] = x['Year'].astype('category')
     wr.pandas.to_parquet(x[x.Year == 1990], path=path, partition_cols=["Year"])
     y = wr.pandas.read_parquet(path=path)
+    wr.s3.delete_objects(path=path)
     assert len(x[x.Year == 1990].index) == len(y.index)
 
@@ -2197,5 +2198,32 @@ def test_range_index(bucket, database):
     print(x)
     wr.pandas.to_parquet(dataframe=x, path=path, database=database)
     df = wr.pandas.read_parquet(path=path)
+    wr.s3.delete_objects(path=path)
     assert len(x.columns) == len(df.columns)
     assert len(x.index) == len(df.index)
+
+
+def test_to_csv_columns(bucket, database):
+    path = f"s3://{bucket}/test_to_csv_columns"
+    wr.s3.delete_objects(path=path)
+    df = pd.DataFrame({
+        "A": [1, 2, 3],
+        "B": [4, 5, 6],
+        "C": ["foo", "boo", "bar"]
+    })
+    wr.s3.delete_objects(path=path)
+    wr.pandas.to_csv(
+        dataframe=df,
+        database=database,
+        path=path,
+        columns=["A", "B"],
+        mode="overwrite",
+        preserve_index=False,
+        procs_cpu_bound=1,
+        inplace=False
+    )
+    sleep(10)
+    df2 = wr.pandas.read_sql_athena(database=database, sql="SELECT * FROM test_to_csv_columns")
+    wr.s3.delete_objects(path=path)
+    assert len(df.columns) == len(df2.columns) + 1
+    assert len(df.index) == len(df2.index)
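
Usage sketch (not part of the patch): with this change, callers can restrict which Dataframe columns are written to S3 and registered in the Glue table, mirroring the test above. The bucket and database names below are placeholders.

    import pandas as pd
    import awswrangler as wr

    df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": ["foo", "boo", "bar"]})

    # Only columns "A" and "B" are written to the CSV objects and kept in the
    # Glue table schema; column "C" is dropped from both.
    wr.pandas.to_csv(dataframe=df,
                     database="my_database",               # placeholder Glue database
                     path="s3://my-bucket/my_table/",      # placeholder S3 path
                     columns=["A", "B"],
                     mode="overwrite",
                     preserve_index=False)
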
From 71849773a013e43ae28e1a8ca7434cf1784c9b8d Mon Sep 17 00:00:00 2001
From: igorborgest
Date: Thu, 23 Jan 2020 15:52:26 -0300
Subject: [PATCH 2/3] Bumping dependencies versions

---
 requirements-dev.txt | 4 ++--
 requirements.txt     | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/requirements-dev.txt b/requirements-dev.txt
index 8a6725e13..9caa920ce 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -3,11 +3,11 @@ mypy~=0.761
 flake8~=3.7.9
 pytest-cov~=2.8.1
 scikit-learn~=0.22.1
-cfn-lint~=0.26.3
+cfn-lint~=0.27.1
 twine~=3.1.1
 wheel~=0.33.6
 sphinx~=2.3.1
 pyspark~=2.4.4
 pyspark-stubs~=2.4.0.post7
 jupyter~=1.0.0
-jupyterlab~=1.2.4
\ No newline at end of file
+jupyterlab~=1.2.5
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 2bb8b8bb0..f22262cc4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,8 @@
 numpy~=1.18.1
 pandas~=0.25.3
 pyarrow~=0.15.1
-botocore~=1.14.2
-boto3~=1.11.2
+botocore~=1.14.7
+boto3~=1.11.7
 s3fs~=0.4.0
 tenacity~=6.0.0
 pg8000~=1.13.2

From 2dcec02b4bc54bb1a71a4e4cba4b9922cdc4c26e Mon Sep 17 00:00:00 2001
From: igorborgest
Date: Thu, 23 Jan 2020 17:00:39 -0300
Subject: [PATCH 3/3] Add columns parameters for Pandas.to_aurora() and
 Aurora.load_table()

---
 awswrangler/aurora.py                   |  69 +++++++++++----
 awswrangler/pandas.py                   |   3 +-
 testing/test_awswrangler/test_pandas.py | 112 ++++++++++++++++++++----
 3 files changed, 151 insertions(+), 33 deletions(-)

diff --git a/awswrangler/aurora.py b/awswrangler/aurora.py
index 68f0802f5..30fa2c3bf 100644
--- a/awswrangler/aurora.py
+++ b/awswrangler/aurora.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Union, List, Dict, Tuple, Any
+from typing import TYPE_CHECKING, Union, List, Dict, Tuple, Any, Optional
 from logging import getLogger, Logger, INFO
 import json
 import warnings
@@ -137,6 +137,7 @@ def load_table(dataframe: pd.DataFrame,
                   table_name: str,
                   connection: Any,
                   num_files: int,
+                   columns: Optional[List[str]] = None,
                   mode: str = "append",
                   preserve_index: bool = False,
                   engine: str = "mysql",
@@ -152,6 +153,7 @@ def load_table(dataframe: pd.DataFrame,
         :param table_name: Aurora table name
         :param connection: A PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
         :param num_files: Number of files to be loaded
+        :param columns: List of columns to load
         :param mode: append or overwrite
         :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
         :param engine: "mysql" or "postgres"
@@ -167,7 +169,8 @@ def load_table(dataframe: pd.DataFrame,
                                        connection=connection,
                                        mode=mode,
                                        preserve_index=preserve_index,
-                                       region=region)
+                                       region=region,
+                                       columns=columns)
         elif "mysql" in engine.lower():
             Aurora.load_table_mysql(dataframe=dataframe,
                                     dataframe_type=dataframe_type,
@@ -177,7 +180,8 @@ def load_table(dataframe: pd.DataFrame,
                                     connection=connection,
                                     mode=mode,
                                     preserve_index=preserve_index,
-                                    num_files=num_files)
+                                    num_files=num_files,
+                                    columns=columns)
         else:
             raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
 
@@ -190,7 +194,8 @@ def load_table_postgres(dataframe: pd.DataFrame,
                            connection: Any,
                            mode: str = "append",
                            preserve_index: bool = False,
-                            region: str = "us-east-1"):
+                            region: str = "us-east-1",
+                            columns: Optional[List[str]] = None):
         """
         Load text/CSV files into a Aurora table using a manifest file.
         Creates the table if necessary.
@@ -204,6 +209,7 @@ def load_table_postgres(dataframe: pd.DataFrame,
         :param mode: append or overwrite
         :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
         :param region: AWS S3 bucket region (Required only for postgres engine)
+        :param columns: List of columns to load
         :return: None
         """
         with connection.cursor() as cursor:
@@ -214,7 +220,8 @@ def load_table_postgres(dataframe: pd.DataFrame,
                                  schema_name=schema_name,
                                  table_name=table_name,
                                  preserve_index=preserve_index,
-                                 engine="postgres")
+                                 engine="postgres",
+                                 columns=columns)
             connection.commit()
             logger.debug("CREATE TABLE committed.")
         for path in load_paths:
@@ -222,7 +229,8 @@ def load_table_postgres(dataframe: pd.DataFrame,
                                        schema_name=schema_name,
                                        table_name=table_name,
                                        engine="postgres",
-                                       region=region)
+                                       region=region,
+                                       columns=columns)
            Aurora._load_object_postgres_with_retry(connection=connection, sql=sql)
            logger.debug(f"Load committed for: {path}.")
 
@@ -257,7 +265,8 @@ def load_table_mysql(dataframe: pd.DataFrame,
                         connection: Any,
                         num_files: int,
                         mode: str = "append",
-                         preserve_index: bool = False):
+                         preserve_index: bool = False,
+                         columns: Optional[List[str]] = None):
         """
         Load text/CSV files into a Aurora table using a manifest file.
         Creates the table if necessary.
@@ -271,6 +280,7 @@ def load_table_mysql(dataframe: pd.DataFrame,
         :param num_files: Number of files to be loaded
         :param mode: append or overwrite
         :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
+        :param columns: List of columns to load
         :return: None
         """
         with connection.cursor() as cursor:
@@ -281,11 +291,13 @@ def load_table_mysql(dataframe: pd.DataFrame,
                                  schema_name=schema_name,
                                  table_name=table_name,
                                  preserve_index=preserve_index,
-                                 engine="mysql")
+                                 engine="mysql",
+                                 columns=columns)
            sql = Aurora._get_load_sql(path=manifest_path,
                                       schema_name=schema_name,
                                       table_name=table_name,
-                                       engine="mysql")
+                                       engine="mysql",
+                                       columns=columns)
            logger.debug(sql)
            cursor.execute(sql)
            logger.debug(f"Load done for: {manifest_path}")
@@ -310,22 +322,40 @@ def _parse_path(path):
         return parts[0], parts[2]
 
     @staticmethod
-    def _get_load_sql(path: str, schema_name: str, table_name: str, engine: str, region: str = "us-east-1") -> str:
+    def _get_load_sql(path: str,
+                      schema_name: str,
+                      table_name: str,
+                      engine: str,
+                      region: str = "us-east-1",
+                      columns: Optional[List[str]] = None) -> str:
         if "postgres" in engine.lower():
             bucket, key = Aurora._parse_path(path=path)
+            if columns is None:
+                cols_str: str = ""
+            else:
+                cols_str = ",".join(columns)
             sql: str = ("-- AWS DATA WRANGLER\n"
                         "SELECT aws_s3.table_import_from_s3(\n"
                         f"'{schema_name}.{table_name}',\n"
-                        "'',\n"
+                        f"'{cols_str}',\n"
                        "'(FORMAT CSV, DELIMITER '','', QUOTE ''\"'', ESCAPE ''\"'')',\n"
                        f"'({bucket},{key},{region})')")
         elif "mysql" in engine.lower():
+            if columns is None:
+                cols_str = ""
+            else:
+                # building something like: (@col1,@col2) set col1=@col1,col2=@col2
+                col_str = [f"@{x}" for x in columns]
+                set_str = [f"{x}=@{x}" for x in columns]
+                cols_str = f"({','.join(col_str)}) SET {','.join(set_str)}"
+                logger.debug(f"cols_str: {cols_str}")
             sql = ("-- AWS DATA WRANGLER\n"
                    f"LOAD DATA FROM S3 MANIFEST '{path}'\n"
                    "REPLACE\n"
                    f"INTO TABLE {schema_name}.{table_name}\n"
                    "FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY '\"' ESCAPED BY '\"'\n"
-                   "LINES TERMINATED BY '\\n'")
+                   "LINES TERMINATED BY '\\n'"
+                   f"{cols_str}")
         else:
             raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
         return sql
 
@@ -337,7 +367,8 @@ def _create_table(cursor,
                      schema_name,
                      table_name,
                      preserve_index=False,
-                      engine: str = "mysql"):
+                      engine: str = "mysql",
+                      columns: Optional[List[str]] = None):
         """
         Creates Aurora table.
 
@@ -348,6 +379,7 @@ def _create_table(cursor,
         :param table_name: Redshift table name
         :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
         :param engine: "mysql" or "postgres"
+        :param columns: List of columns to load
         :return: None
         """
         sql: str = f"-- AWS DATA WRANGLER\n" \
@@ -364,7 +396,8 @@ def _create_table(cursor,
            schema = Aurora._get_schema(dataframe=dataframe,
                                        dataframe_type=dataframe_type,
                                        preserve_index=preserve_index,
-                                        engine=engine)
+                                        engine=engine,
+                                        columns=columns)
         cols_str: str = "".join([f"{col[0]} {col[1]},\n" for col in schema])[:-2]
         sql = f"-- AWS DATA WRANGLER\n" f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} (\n" f"{cols_str})"
         logger.debug(f"Create table query:\n{sql}")
@@ -374,7 +407,8 @@ def _create_table(cursor,
     def _get_schema(dataframe,
                    dataframe_type: str,
                    preserve_index: bool,
-                    engine: str = "mysql") -> List[Tuple[str, str]]:
+                    engine: str = "mysql",
+                    columns: Optional[List[str]] = None) -> List[Tuple[str, str]]:
         schema_built: List[Tuple[str, str]] = []
         if "postgres" in engine.lower():
            convert_func = data_types.pyarrow2postgres
@@ -386,8 +420,9 @@ def _get_schema(dataframe,
            pyarrow_schema: List[Tuple[str, str]] = data_types.extract_pyarrow_schema_from_pandas(
                dataframe=dataframe, preserve_index=preserve_index, indexes_position="right")
            for name, dtype in pyarrow_schema:
-                aurora_type: str = convert_func(dtype)
-                schema_built.append((name, aurora_type))
+                if columns is None or name in columns:
+                    aurora_type: str = convert_func(dtype)
+                    schema_built.append((name, aurora_type))
         else:
            raise InvalidDataframeType(f"{dataframe_type} is not a valid DataFrame type. Please use 'pandas'!")
         return schema_built
diff --git a/awswrangler/pandas.py b/awswrangler/pandas.py
index 776735268..6a56e3538 100644
--- a/awswrangler/pandas.py
+++ b/awswrangler/pandas.py
@@ -870,7 +870,7 @@ def to_s3(self,
         partition_cols = [Athena.normalize_column_name(x) for x in partition_cols]
         logger.debug(f"partition_cols: {partition_cols}")
         if extra_args is not None and "columns" in extra_args:
-            extra_args["columns"] = [Athena.normalize_column_name(x) for x in extra_args["columns"]]
+            extra_args["columns"] = [Athena.normalize_column_name(x) for x in extra_args["columns"]]  # type: ignore
         dataframe = Pandas.drop_duplicated_columns(dataframe=dataframe, inplace=inplace)
         if compression is not None:
             compression = compression.lower()
@@ -1691,6 +1691,7 @@ def to_aurora(self,
                          load_paths=load_paths,
                          schema_name=schema,
                          table_name=table,
+                          columns=columns,
                          connection=connection,
                          num_files=len(paths),
                          mode=mode,
diff --git a/testing/test_awswrangler/test_pandas.py b/testing/test_awswrangler/test_pandas.py
index fec3f14ab..ade8ed9cd 100644
--- a/testing/test_awswrangler/test_pandas.py
+++ b/testing/test_awswrangler/test_pandas.py
@@ -2206,24 +2206,106 @@ def test_range_index(bucket, database):
 def test_to_csv_columns(bucket, database):
     path = f"s3://{bucket}/test_to_csv_columns"
     wr.s3.delete_objects(path=path)
-    df = pd.DataFrame({
-        "A": [1, 2, 3],
-        "B": [4, 5, 6],
-        "C": ["foo", "boo", "bar"]
-    })
+    df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": ["foo", "boo", "bar"]})
     wr.s3.delete_objects(path=path)
-    wr.pandas.to_csv(
-        dataframe=df,
-        database=database,
-        path=path,
-        columns=["A", "B"],
-        mode="overwrite",
-        preserve_index=False,
-        procs_cpu_bound=1,
-        inplace=False
-    )
+    wr.pandas.to_csv(dataframe=df,
+                     database=database,
+                     path=path,
+                     columns=["A", "B"],
+                     mode="overwrite",
+                     preserve_index=False,
+                     procs_cpu_bound=1,
+                     inplace=False)
     sleep(10)
     df2 = wr.pandas.read_sql_athena(database=database, sql="SELECT * FROM test_to_csv_columns")
     wr.s3.delete_objects(path=path)
     assert len(df.columns) == len(df2.columns) + 1
     assert len(df.index) == len(df2.index)
+
+
+def test_aurora_postgres_load_columns(bucket, postgres_parameters):
+    df = pd.DataFrame({"id": [1, 2, 3], "value": ["foo", "boo", "bar"], "value2": [4, 5, 6]})
+    conn = Aurora.generate_connection(database="postgres",
+                                      host=postgres_parameters["PostgresAddress"],
+                                      port=3306,
+                                      user="test",
+                                      password=postgres_parameters["Password"],
+                                      engine="postgres")
+    path = f"s3://{bucket}/test_aurora_postgres_load_columns"
+    wr.pandas.to_aurora(dataframe=df,
+                        connection=conn,
+                        schema="public",
+                        table="test_aurora_postgres_load_columns",
+                        mode="overwrite",
+                        temp_s3_path=path,
+                        engine="postgres",
+                        columns=["id", "value"])
+    wr.pandas.to_aurora(dataframe=df,
+                        connection=conn,
+                        schema="public",
+                        table="test_aurora_postgres_load_columns",
+                        mode="append",
+                        temp_s3_path=path,
+                        engine="postgres",
+                        columns=["value"])
+    with conn.cursor() as cursor:
+        cursor.execute("SELECT * FROM public.test_aurora_postgres_load_columns")
+        rows = cursor.fetchall()
+        assert len(rows) == len(df.index) * 2
+        assert rows[0][0] == 1
+        assert rows[1][0] == 2
+        assert rows[2][0] == 3
+        assert rows[3][0] is None
+        assert rows[4][0] is None
+        assert rows[5][0] is None
+        assert rows[0][1] == "foo"
+        assert rows[1][1] == "boo"
+        assert rows[2][1] == "bar"
+        assert rows[3][1] == "foo"
+        assert rows[4][1] == "boo"
+        assert rows[5][1] == "bar"
+    conn.close()
+
+
+def test_aurora_mysql_load_columns(bucket, mysql_parameters):
2, 3], "value": ["foo", "boo", "bar"], "value2": [4, 5, 6]}) + conn = Aurora.generate_connection(database="mysql", + host=mysql_parameters["MysqlAddress"], + port=3306, + user="test", + password=mysql_parameters["Password"], + engine="mysql") + path = f"s3://{bucket}/test_aurora_mysql_load_columns" + wr.pandas.to_aurora(dataframe=df, + connection=conn, + schema="test", + table="test_aurora_mysql_load_columns", + mode="overwrite", + temp_s3_path=path, + engine="mysql", + columns=["id", "value"]) + wr.pandas.to_aurora(dataframe=df, + connection=conn, + schema="test", + table="test_aurora_mysql_load_columns", + mode="append", + temp_s3_path=path, + engine=" mysql", + columns=["value"]) + with conn.cursor() as cursor: + cursor.execute("SELECT * FROM test.test_aurora_mysql_load_columns") + rows = cursor.fetchall() + assert len(rows) == len(df.index) * 2 + assert rows[0][0] == 1 + assert rows[1][0] == 2 + assert rows[2][0] == 3 + assert rows[3][0] is None + assert rows[4][0] is None + assert rows[5][0] is None + assert rows[0][1] == "foo" + assert rows[1][1] == "boo" + assert rows[2][1] == "bar" + assert rows[3][1] == "foo" + assert rows[4][1] == "boo" + assert rows[5][1] == "bar" + conn.close()