8 changes: 7 additions & 1 deletion awswrangler/s3/_write.py
@@ -47,7 +47,7 @@ def _validate_args(
table: Optional[str],
database: Optional[str],
dataset: bool,
path: str,
path: Optional[str],
partition_cols: Optional[List[str]],
bucketing_info: Optional[Tuple[List[str], int]],
mode: Optional[str],
@@ -58,6 +58,8 @@
if df.empty is True:
raise exceptions.EmptyDataFrame()
if dataset is False:
if path is None:
raise exceptions.InvalidArgumentValue("If dataset is False, the `path` argument must be passed.")
if path.endswith("/"):
raise exceptions.InvalidArgumentValue(
"If <dataset=False>, the argument <path> should be a file path, not a directory."
@@ -79,6 +81,10 @@
"Arguments database and table must be passed together. If you want to store your dataset metadata in "
"the Glue Catalog, please ensure you are passing both."
)
elif all(x is None for x in [path, database, table]):
raise exceptions.InvalidArgumentCombination(
"You must specify a `path` if dataset is True and database/table are not enabled."
)
elif bucketing_info and bucketing_info[1] <= 0:
raise exceptions.InvalidArgumentValue(
"Please pass a value greater than 1 for the number of buckets for bucketing."
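Taken together, the new checks mean `path` can now only be omitted when a Glue database/table pair is supplied. A minimal sketch of the failure modes (the DataFrame contents are hypothetical; the exception classes are the ones used in the hunk above):

import pandas as pd
import awswrangler as wr

df = pd.DataFrame({"c0": [1, 2, 3]})

# dataset=False (the default) still requires an explicit S3 file path.
try:
    wr.s3.to_parquet(df=df)
except wr.exceptions.InvalidArgumentValue as err:
    print(err)  # If dataset is False, the `path` argument must be passed.

# dataset=True without database/table also requires a path.
try:
    wr.s3.to_parquet(df=df, dataset=True)
except wr.exceptions.InvalidArgumentCombination as err:
    print(err)  # You must specify a `path` if dataset is True ...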
22 changes: 18 additions & 4 deletions awswrangler/s3/_write_parquet.py
@@ -198,7 +198,7 @@ def _to_parquet(
@apply_configs
def to_parquet( # pylint: disable=too-many-arguments,too-many-locals
df: pd.DataFrame,
path: str,
path: Optional[str] = None,
index: bool = False,
compression: Optional[str] = "snappy",
max_rows_by_file: Optional[int] = None,
@@ -252,8 +252,9 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals
----------
df: pandas.DataFrame
Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html
path : str
path : str, optional
S3 path (for file e.g. ``s3://bucket/prefix/filename.parquet``) (for dataset e.g. ``s3://bucket/prefix``).
Required if ``dataset=False`` or when ``dataset=True`` and creating a new dataset.
index : bool
True to store the DataFrame index in file, otherwise False to ignore it.
compression: str, optional
@@ -511,6 +512,19 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals
catalog_table_input = catalog._get_table_input( # pylint: disable=protected-access
database=database, table=table, boto3_session=session, catalog_id=catalog_id
)
catalog_path = catalog_table_input["StorageDescriptor"]["Location"] if catalog_table_input else None
if path is None:
if catalog_path:
path = catalog_path
else:
raise exceptions.InvalidArgumentValue(
"Glue table does not exist in the catalog. Please pass the `path` argument to create it."
)
elif path and catalog_path:
if path.rstrip("/") != catalog_path.rstrip("/"):
raise exceptions.InvalidArgumentValue(
f"The specified path: {path}, does not match the existing Glue catalog table path: {catalog_path}"
)
df = _apply_dtype(df=df, dtype=dtype, catalog_table_input=catalog_table_input, mode=mode)
schema: pa.Schema = _data_types.pyarrow_schema_from_pandas(
df=df, index=index, ignore_cols=partition_cols, dtype=dtype
Expand Down Expand Up @@ -545,7 +559,7 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals
func=_to_parquet,
concurrent_partitioning=concurrent_partitioning,
df=df,
path_root=path,
path_root=path, # type: ignore
index=index,
compression=compression,
compression_ext=compression_ext,
@@ -565,7 +579,7 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals
catalog._create_parquet_table( # pylint: disable=protected-access
database=database,
table=table,
path=path,
path=path, # type: ignore
columns_types=columns_types,
partitions_types=partitions_types,
bucketing_info=bucketing_info,
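The net effect: a dataset already registered in the Glue Catalog can be written to without repeating its S3 location. When `path` is omitted it is resolved from the table's StorageDescriptor, and an explicit `path` that disagrees with the catalog raises InvalidArgumentValue. A usage sketch under those assumptions (bucket, database, and table names are hypothetical):

import pandas as pd
import awswrangler as wr

df = pd.DataFrame({"c0": [1, 2]})

# First write registers the table and records its S3 location in the catalog.
wr.s3.to_parquet(
    df=df,
    path="s3://my-bucket/my-prefix",  # hypothetical location
    dataset=True,
    database="my_db",
    table="my_table",
    mode="overwrite",
)

# Later writes can omit `path`; it is read back from the Glue table.
wr.s3.to_parquet(df=df, dataset=True, database="my_db", table="my_table", mode="append")

# A conflicting path (after trailing-slash normalization) raises InvalidArgumentValue.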
29 changes: 22 additions & 7 deletions awswrangler/s3/_write_text.py
@@ -72,9 +72,9 @@ def _to_text(


@apply_configs
def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-statements
def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-statements,too-many-branches
df: pd.DataFrame,
path: str,
path: Optional[str] = None,
sep: str = ",",
index: bool = True,
columns: Optional[List[str]] = None,
@@ -137,8 +137,9 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-statements
----------
df: pandas.DataFrame
Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html
path : str
Amazon S3 path (e.g. s3://bucket/filename.csv).
path : str, optional
Amazon S3 path (e.g. s3://bucket/prefix/filename.csv) (for dataset e.g. ``s3://bucket/prefix``).
Required if ``dataset=False`` or when creating a new dataset.
sep : str
String of length 1. Field delimiter for the output file.
index : bool
@@ -414,13 +415,27 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-statements
catalog_table_input = catalog._get_table_input( # pylint: disable=protected-access
database=database, table=table, boto3_session=session, catalog_id=catalog_id
)
catalog_path = catalog_table_input["StorageDescriptor"]["Location"] if catalog_table_input else None
if path is None:
if catalog_path:
path = catalog_path
else:
raise exceptions.InvalidArgumentValue(
"Glue table does not exist in the catalog. Please pass the `path` argument to create it."
)
elif path and catalog_path:
if path.rstrip("/") != catalog_path.rstrip("/"):
raise exceptions.InvalidArgumentValue(
f"The specified path: {path}, does not match the existing Glue catalog table path: {catalog_path}"
)
if pandas_kwargs.get("compression") not in ("gzip", "bz2", None):
raise exceptions.InvalidArgumentCombination(
"If database and table are given, you must use one of these compressions: gzip, bz2 or None."
)

df = _apply_dtype(df=df, dtype=dtype, catalog_table_input=catalog_table_input, mode=mode)

paths: List[str] = []
if dataset is False:
pandas_kwargs["sep"] = sep
pandas_kwargs["index"] = index
@@ -434,7 +449,7 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-statements
s3_additional_kwargs=s3_additional_kwargs,
**pandas_kwargs,
)
paths = [path]
paths = [path] # type: ignore
else:
if database and table:
quoting: Optional[int] = csv.QUOTE_NONE
@@ -461,7 +476,7 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-statements
func=_to_text,
concurrent_partitioning=concurrent_partitioning,
df=df,
path_root=path,
path_root=path, # type: ignore
index=index,
sep=sep,
compression=compression,
@@ -486,7 +501,7 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-statements
catalog._create_csv_table( # pylint: disable=protected-access
database=database,
table=table,
path=path,
path=path, # type: ignore
columns_types=columns_types,
partitions_types=partitions_types,
bucketing_info=bucketing_info,
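to_csv gains the same catalog-based path resolution; the only extra constraint, kept from before, is that catalog-backed CSV writes must use gzip, bz2, or no compression. A sketch, reusing the hypothetical database from above:

import pandas as pd
import awswrangler as wr

df = pd.DataFrame({"c0": ["a", "b"]})

# Create the CSV dataset once with an explicit (hypothetical) path...
wr.s3.to_csv(
    df=df,
    path="s3://my-bucket/my-csv-prefix",
    dataset=True,
    database="my_db",
    table="my_csv_table",
    mode="overwrite",
)

# ...then append without `path`; the location comes from the Glue catalog entry.
wr.s3.to_csv(df=df, dataset=True, database="my_db", table="my_csv_table", mode="append")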
4 changes: 0 additions & 4 deletions tests/test__routines.py
@@ -44,7 +44,6 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_partitioning):
df = pd.DataFrame({"c1": [None, 1, None]}, dtype="Int16")
wr.s3.to_parquet(
df=df,
path=path,
Contributor Author: Since `path` is now optional, it was removed from some tests to check that they still pass even without referencing it.

dataset=True,
mode="overwrite",
database=glue_database,
@@ -101,7 +100,6 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_partitioning):
df = pd.DataFrame({"c2": ["a", None, "b"], "c1": [None, None, None]})
wr.s3.to_parquet(
df=df,
path=path,
dataset=True,
mode="append",
database=glue_database,
@@ -162,7 +160,6 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_partitioning):
df = pd.DataFrame({"c0": ["foo", None], "c1": [0, 1]})
wr.s3.to_parquet(
df=df,
path=path,
dataset=True,
mode="overwrite",
database=glue_database,
@@ -223,7 +220,6 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_partitioning):
df = pd.DataFrame({"c0": [1, 2], "c1": ["1", "3"], "c2": [True, False]})
wr.s3.to_parquet(
df=df,
path=path,
dataset=True,
mode="overwrite_partitions",
database=glue_database,
2 changes: 0 additions & 2 deletions tests/test_athena_csv.py
@@ -49,7 +49,6 @@ def test_to_csv_modes(glue_database, glue_table, path, use_threads, concurrent_partitioning):
df = pd.DataFrame({"c1": [0, 1, 2]}, dtype="Int16")
wr.s3.to_csv(
df=df,
path=path,
dataset=True,
mode="overwrite",
database=glue_database,
@@ -106,7 +105,6 @@ def test_to_csv_modes(glue_database, glue_table, path, use_threads, concurrent_partitioning):
df = pd.DataFrame({"c0": ["foo", "boo"], "c1": [0, 1]})
wr.s3.to_csv(
df=df,
path=path,
dataset=True,
mode="overwrite",
database=glue_database,
17 changes: 17 additions & 0 deletions tests/test_s3.py
@@ -103,6 +103,23 @@ def mock_make_api_call(self, operation_name, kwarg):
wr.s3.delete_objects(path=[path])


def test_missing_or_wrong_path(path, glue_database, glue_table):
# Missing path
df = pd.DataFrame({"FooBoo": [1, 2, 3]})
with pytest.raises(wr.exceptions.InvalidArgumentValue):
wr.s3.to_parquet(df=df)
with pytest.raises(wr.exceptions.InvalidArgumentCombination):
wr.s3.to_parquet(df=df, dataset=True)
with pytest.raises(wr.exceptions.InvalidArgumentValue):
wr.s3.to_parquet(df=df, dataset=True, database=glue_database, table=glue_table)

# Wrong path
wr.s3.to_parquet(df=df, path=path, dataset=True, database=glue_database, table=glue_table)
wrong_path = "s3://bucket/prefix"
with pytest.raises(wr.exceptions.InvalidArgumentValue):
wr.s3.to_parquet(df=df, path=wrong_path, dataset=True, database=glue_database, table=glue_table)
Contributor Author (on lines +106 to +120): In certain cases, Wrangler must raise an error if the path is missing or wrong.



def test_s3_empty_dfs():
df = pd.DataFrame()
with pytest.raises(wr.exceptions.EmptyDataFrame):