diff --git a/awswrangler/s3/_write.py b/awswrangler/s3/_write.py
index 9af8d5956..9867c3312 100644
--- a/awswrangler/s3/_write.py
+++ b/awswrangler/s3/_write.py
@@ -99,3 +99,20 @@ def _sanitize(
     dtype = {catalog.sanitize_column_name(k): v.lower() for k, v in dtype.items()}
     _utils.check_duplicated_columns(df=df)
     return df, dtype, partition_cols
+
+
+def _check_schema_changes(columns_types: Dict[str, str], table_input: Optional[Dict[str, Any]], mode: str) -> None:
+    if (table_input is not None) and (mode in ("append", "overwrite_partitions")):
+        catalog_cols: Dict[str, str] = {x["Name"]: x["Type"] for x in table_input["StorageDescriptor"]["Columns"]}
+        for c, t in columns_types.items():
+            if c not in catalog_cols:
+                raise exceptions.InvalidArgumentValue(
+                    f"Schema change detected: New column {c} with type {t}. "
+                    "Please pass schema_evolution=True to allow new columns "
+                    "behaviour."
+                )
+            if t != catalog_cols[c]:  # Data type change detected!
+                raise exceptions.InvalidArgumentValue(
+                    f"Schema change detected: Data type change on column {c} "
+                    f"(Old type: {catalog_cols[c]} / New type {t})."
+                )
diff --git a/awswrangler/s3/_write_parquet.py b/awswrangler/s3/_write_parquet.py
index 8996ec650..766291f25 100644
--- a/awswrangler/s3/_write_parquet.py
+++ b/awswrangler/s3/_write_parquet.py
@@ -17,30 +17,13 @@
 from awswrangler.s3._delete import delete_objects
 from awswrangler.s3._fs import open_s3_object
 from awswrangler.s3._read_parquet import _read_parquet_metadata
-from awswrangler.s3._write import _COMPRESSION_2_EXT, _apply_dtype, _sanitize, _validate_args
+from awswrangler.s3._write import _COMPRESSION_2_EXT, _apply_dtype, _check_schema_changes, _sanitize, _validate_args
 from awswrangler.s3._write_concurrent import _WriteProxy
 from awswrangler.s3._write_dataset import _to_dataset
 
 _logger: logging.Logger = logging.getLogger(__name__)
 
 
-def _check_schema_changes(columns_types: Dict[str, str], table_input: Optional[Dict[str, Any]], mode: str) -> None:
-    if (table_input is not None) and (mode in ("append", "overwrite_partitions")):
-        catalog_cols: Dict[str, str] = {x["Name"]: x["Type"] for x in table_input["StorageDescriptor"]["Columns"]}
-        for c, t in columns_types.items():
-            if c not in catalog_cols:
-                raise exceptions.InvalidArgumentValue(
-                    f"Schema change detected: New column {c} with type {t}. "
-                    "Please pass schema_evolution=True to allow new columns "
-                    "behaviour."
-                )
-            if t != catalog_cols[c]:  # Data type change detected!
-                raise exceptions.InvalidArgumentValue(
-                    f"Schema change detected: Data type change on column {c} "
-                    f"(Old type: {catalog_cols[c]} / New type {t})."
-                )
-
-
 def _get_file_path(file_counter: int, file_path: str) -> str:
     slash_index: int = file_path.rfind("/")
     dot_index: int = file_path.find(".", slash_index)
diff --git a/awswrangler/s3/_write_text.py b/awswrangler/s3/_write_text.py
index 354e76cdf..9aa8950a3 100644
--- a/awswrangler/s3/_write_text.py
+++ b/awswrangler/s3/_write_text.py
@@ -14,7 +14,7 @@
 from awswrangler._config import apply_configs
 from awswrangler.s3._delete import delete_objects
 from awswrangler.s3._fs import open_s3_object
-from awswrangler.s3._write import _COMPRESSION_2_EXT, _apply_dtype, _sanitize, _validate_args
+from awswrangler.s3._write import _COMPRESSION_2_EXT, _apply_dtype, _check_schema_changes, _sanitize, _validate_args
 from awswrangler.s3._write_dataset import _to_dataset
 
 _logger: logging.Logger = logging.getLogger(__name__)
@@ -87,6 +87,7 @@ def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-statements
     concurrent_partitioning: bool = False,
     mode: Optional[str] = None,
     catalog_versioning: bool = False,
+    schema_evolution: bool = False,
     database: Optional[str] = None,
     table: Optional[str] = None,
     dtype: Optional[Dict[str, str]] = None,
@@ -182,6 +183,11 @@ def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-statements
         https://aws-data-wrangler.readthedocs.io/en/2.9.0/stubs/awswrangler.s3.to_parquet.html#awswrangler.s3.to_parquet
     catalog_versioning : bool
         If True and `mode="overwrite"`, creates an archived version of the table catalog before updating it.
+    schema_evolution : bool
+        If True, allows schema evolution (new or missing columns), otherwise an exception will be raised.
+        (Only considered if dataset=True and mode in ("append", "overwrite_partitions"))
+        Related tutorial:
+        https://aws-data-wrangler.readthedocs.io/en/2.9.0/tutorials/014%20-%20Schema%20Evolution.html
     database : str, optional
         Glue/Athena catalog: Database name.
     table : str, optional
@@ -474,6 +480,16 @@ def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-statements
             pd_kwargs.pop("compression", None)
 
         df = df[columns] if columns else df
+
+        columns_types: Dict[str, str] = {}
+        partitions_types: Dict[str, str] = {}
+        if (database is not None) and (table is not None):
+            columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned(
+                df=df, index=index, partition_cols=partition_cols, dtype=dtype, index_left=True
+            )
+            if schema_evolution is False:
+                _check_schema_changes(columns_types=columns_types, table_input=catalog_table_input, mode=mode)
+
         paths, partitions_values = _to_dataset(
             func=_to_text,
             concurrent_partitioning=concurrent_partitioning,
@@ -498,9 +514,6 @@ def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-statements
         )
         if database and table:
             try:
-                columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned(
-                    df=df, index=index, partition_cols=partition_cols, dtype=dtype, index_left=True
-                )
                 serde_info: Dict[str, Any] = {}
                 if catalog_table_input:
                     serde_info = catalog_table_input["StorageDescriptor"]["SerdeInfo"]
diff --git a/tests/test_s3_text.py b/tests/test_s3_text.py
index e730caa93..18dcc7ad2 100644
--- a/tests/test_s3_text.py
+++ b/tests/test_s3_text.py
@@ -331,3 +331,14 @@ def test_read_csv_versioned(path) -> None:
     df_temp = wr.s3.read_csv(path_file, version_id=version_id)
     assert df_temp.equals(df)
     assert version_id == wr.s3.describe_objects(path=path_file, version_id=version_id)[path_file]["VersionId"]
+
+
+def test_to_csv_schema_evolution(path, glue_database, glue_table) -> None:
+    path_file = f"{path}0.csv"
+    df = pd.DataFrame({"c0": [0, 1, 2], "c1": [3, 4, 5]})
+    wr.s3.to_csv(df=df, path=path_file, dataset=True, database=glue_database, table=glue_table)
+    df["test"] = 1
+    with pytest.raises(wr.exceptions.InvalidArgumentValue):
+        wr.s3.to_csv(
+            df=df, path=path_file, dataset=True, database=glue_database, table=glue_table, schema_evolution=True
+        )
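
Usage sketch (illustrative only): the snippet below shows the check this patch wires into wr.s3.to_csv. The bucket, prefix, database, and table names are placeholders, and a reachable Glue catalog plus S3 access are assumed; the error text comes from _check_schema_changes above.

    import awswrangler as wr
    import pandas as pd

    # First write creates the Glue table with columns c0 and c1.
    df = pd.DataFrame({"c0": [0, 1, 2], "c1": [3, 4, 5]})
    wr.s3.to_csv(
        df=df,
        path="s3://my-bucket/my-prefix/0.csv",  # placeholder path
        dataset=True,
        database="my_database",  # placeholder Glue database
        table="my_table",  # placeholder Glue table
    )

    # Appending a frame that gained a column while schema_evolution stays at
    # its default of False is rejected by _check_schema_changes.
    df["c2"] = [6, 7, 8]
    try:
        wr.s3.to_csv(
            df=df,
            path="s3://my-bucket/my-prefix/0.csv",
            dataset=True,
            database="my_database",
            table="my_table",
            mode="append",
            schema_evolution=False,
        )
    except wr.exceptions.InvalidArgumentValue as ex:
        print(ex)  # "Schema change detected: New column c2 ..."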