17 changes: 17 additions & 0 deletions awswrangler/s3/_write.py
@@ -99,3 +99,20 @@ def _sanitize(
dtype = {catalog.sanitize_column_name(k): v.lower() for k, v in dtype.items()}
_utils.check_duplicated_columns(df=df)
return df, dtype, partition_cols


def _check_schema_changes(columns_types: Dict[str, str], table_input: Optional[Dict[str, Any]], mode: str) -> None:
if (table_input is not None) and (mode in ("append", "overwrite_partitions")):
catalog_cols: Dict[str, str] = {x["Name"]: x["Type"] for x in table_input["StorageDescriptor"]["Columns"]}
for c, t in columns_types.items():
if c not in catalog_cols:
raise exceptions.InvalidArgumentValue(
f"Schema change detected: New column {c} with type {t}. "
"Please pass schema_evolution=True to allow new columns "
"behaviour."
)
if t != catalog_cols[c]: # Data type change detected!
raise exceptions.InvalidArgumentValue(
f"Schema change detected: Data type change on column {c} "
f"(Old type: {catalog_cols[c]} / New type {t})."
)
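For illustration, a minimal sketch of how the relocated helper behaves. It calls the private _check_schema_changes directly with a hand-built table_input dict that mimics the shape of a Glue GetTable response; this is not part of the public API and is shown only to make the failure mode concrete.

# Illustrative only: _check_schema_changes is a private helper and
# table_input below is a hand-built stand-in for a Glue GetTable response.
from awswrangler import exceptions
from awswrangler.s3._write import _check_schema_changes

table_input = {
    "StorageDescriptor": {
        "Columns": [
            {"Name": "c0", "Type": "bigint"},
            {"Name": "c1", "Type": "bigint"},
        ]
    }
}

# Same columns and types as the catalog: no exception is raised.
_check_schema_changes(
    columns_types={"c0": "bigint", "c1": "bigint"},
    table_input=table_input,
    mode="append",
)

# A new column "test" triggers InvalidArgumentValue; with
# schema_evolution=True the writers simply skip this call.
try:
    _check_schema_changes(
        columns_types={"c0": "bigint", "c1": "bigint", "test": "bigint"},
        table_input=table_input,
        mode="append",
    )
except exceptions.InvalidArgumentValue as ex:
    print(ex)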
19 changes: 1 addition & 18 deletions awswrangler/s3/_write_parquet.py
@@ -17,30 +17,13 @@
from awswrangler.s3._delete import delete_objects
from awswrangler.s3._fs import open_s3_object
from awswrangler.s3._read_parquet import _read_parquet_metadata
from awswrangler.s3._write import _COMPRESSION_2_EXT, _apply_dtype, _sanitize, _validate_args
from awswrangler.s3._write import _COMPRESSION_2_EXT, _apply_dtype, _check_schema_changes, _sanitize, _validate_args
from awswrangler.s3._write_concurrent import _WriteProxy
from awswrangler.s3._write_dataset import _to_dataset

_logger: logging.Logger = logging.getLogger(__name__)


def _check_schema_changes(columns_types: Dict[str, str], table_input: Optional[Dict[str, Any]], mode: str) -> None:
if (table_input is not None) and (mode in ("append", "overwrite_partitions")):
catalog_cols: Dict[str, str] = {x["Name"]: x["Type"] for x in table_input["StorageDescriptor"]["Columns"]}
for c, t in columns_types.items():
if c not in catalog_cols:
raise exceptions.InvalidArgumentValue(
f"Schema change detected: New column {c} with type {t}. "
"Please pass schema_evolution=True to allow new columns "
"behaviour."
)
if t != catalog_cols[c]: # Data type change detected!
raise exceptions.InvalidArgumentValue(
f"Schema change detected: Data type change on column {c} "
f"(Old type: {catalog_cols[c]} / New type {t})."
)


def _get_file_path(file_counter: int, file_path: str) -> str:
slash_index: int = file_path.rfind("/")
dot_index: int = file_path.find(".", slash_index)
21 changes: 17 additions & 4 deletions awswrangler/s3/_write_text.py
@@ -14,7 +14,7 @@
from awswrangler._config import apply_configs
from awswrangler.s3._delete import delete_objects
from awswrangler.s3._fs import open_s3_object
from awswrangler.s3._write import _COMPRESSION_2_EXT, _apply_dtype, _sanitize, _validate_args
from awswrangler.s3._write import _COMPRESSION_2_EXT, _apply_dtype, _check_schema_changes, _sanitize, _validate_args
from awswrangler.s3._write_dataset import _to_dataset

_logger: logging.Logger = logging.getLogger(__name__)
@@ -87,6 +87,7 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state
concurrent_partitioning: bool = False,
mode: Optional[str] = None,
catalog_versioning: bool = False,
schema_evolution: bool = False,
database: Optional[str] = None,
table: Optional[str] = None,
dtype: Optional[Dict[str, str]] = None,
@@ -182,6 +183,11 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state
https://aws-data-wrangler.readthedocs.io/en/2.9.0/stubs/awswrangler.s3.to_parquet.html#awswrangler.s3.to_parquet
catalog_versioning : bool
If True and `mode="overwrite"`, creates an archived version of the table catalog before updating it.
schema_evolution : bool
If True allows schema evolution (new or missing columns), otherwise an exception will be raised.
(Only considered if dataset=True and mode in ("append", "overwrite_partitions"))
Related tutorial:
https://aws-data-wrangler.readthedocs.io/en/2.9.0/tutorials/014%20-%20Schema%20Evolution.html
database : str, optional
Glue/Athena catalog: Database name.
table : str, optional
@@ -474,6 +480,16 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state
pd_kwargs.pop("compression", None)

df = df[columns] if columns else df

columns_types: Dict[str, str] = {}
partitions_types: Dict[str, str] = {}
if (database is not None) and (table is not None):
columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned(
df=df, index=index, partition_cols=partition_cols, dtype=dtype, index_left=True
)
if schema_evolution is False:
_check_schema_changes(columns_types=columns_types, table_input=catalog_table_input, mode=mode)

Comment on lines +483 to +492
jaidisido (Contributor) commented on Jul 7, 2021:
Why was it necessary to move this block outside the existing if condition for database and table on line 499? Mostly asking because it's likely to create a conflict with my Governed table branch which I would prefer to avoid :)

Author (Contributor) replied:
Hahah, yes. I wanted to do it before the _to_dataset call, which in some cases sends a request to delete the files, and right after df = df[columns] if columns else df, which is the point where we form the "final" dataframe that should be checked for schema evolution.

paths, partitions_values = _to_dataset(
func=_to_text,
concurrent_partitioning=concurrent_partitioning,
@@ -498,9 +514,6 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state
)
if database and table:
try:
columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned(
df=df, index=index, partition_cols=partition_cols, dtype=dtype, index_left=True
)
serde_info: Dict[str, Any] = {}
if catalog_table_input:
serde_info = catalog_table_input["StorageDescriptor"]["SerdeInfo"]
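Taken together, the user-facing behaviour of the new flag would look roughly like the sketch below. The bucket, database, and table names are placeholders, and mode="append" is passed explicitly so the example does not rely on the default. Because the check now runs before _to_dataset, the error is raised before any objects are written or deleted.

import pandas as pd
import awswrangler as wr

# Placeholder S3 prefix and Glue names, for illustration only.
path = "s3://my-bucket/my-csv-dataset/"
df = pd.DataFrame({"c0": [0, 1, 2], "c1": [3, 4, 5]})

# First write creates the Glue table from the DataFrame schema.
wr.s3.to_csv(df=df, path=path, dataset=True, database="my_db", table="my_table")

# Appending with an extra column and the default schema_evolution=False
# now fails fast with InvalidArgumentValue.
df["c2"] = [6, 7, 8]
try:
    wr.s3.to_csv(
        df=df,
        path=path,
        dataset=True,
        database="my_db",
        table="my_table",
        mode="append",
    )
except wr.exceptions.InvalidArgumentValue as ex:
    print(ex)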
11 changes: 11 additions & 0 deletions tests/test_s3_text.py
@@ -331,3 +331,14 @@ def test_read_csv_versioned(path) -> None:
df_temp = wr.s3.read_csv(path_file, version_id=version_id)
assert df_temp.equals(df)
assert version_id == wr.s3.describe_objects(path=path_file, version_id=version_id)[path_file]["VersionId"]


def test_to_csv_schema_evolution(path, glue_database, glue_table) -> None:
path_file = f"{path}0.csv"
df = pd.DataFrame({"c0": [0, 1, 2], "c1": [3, 4, 5]})
wr.s3.to_csv(df=df, path=path_file, dataset=True, database=glue_database, table=glue_table)
df["test"] = 1
with pytest.raises(wr.exceptions.InvalidArgumentValue):
wr.s3.to_csv(
df=df, path=path_file, dataset=True, database=glue_database, table=glue_table, schema_evolution=True
)
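Because the helper now lives in _write.py and is imported by _write_parquet.py as well, the Parquet writer goes through the same shared check. A rough Parquet-side equivalent is sketched below, assuming to_parquet's existing schema_evolution flag and using placeholder names throughout.

import pandas as pd
import awswrangler as wr

# Placeholder names again; passing schema_evolution=False is the point here.
path = "s3://my-bucket/my-parquet-dataset/"
df = pd.DataFrame({"c0": [0, 1, 2], "c1": [3, 4, 5]})
wr.s3.to_parquet(df=df, path=path, dataset=True, database="my_db", table="my_parquet_table")

# The shared _check_schema_changes rejects the new column here too.
df["c2"] = [6, 7, 8]
try:
    wr.s3.to_parquet(
        df=df,
        path=path,
        dataset=True,
        database="my_db",
        table="my_parquet_table",
        mode="append",
        schema_evolution=False,
    )
except wr.exceptions.InvalidArgumentValue as ex:
    print(ex)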