9 changes: 7 additions & 2 deletions awswrangler/s3.py
@@ -16,6 +16,7 @@
 import pyarrow.lib  # type: ignore
 import pyarrow.parquet  # type: ignore
 import s3fs  # type: ignore
+from pandas.io.common import infer_compression  # type: ignore

 from awswrangler import _data_types, _utils, catalog, exceptions

@@ -1450,7 +1451,9 @@ def _read_text_chunksize(
     fs: s3fs.S3FileSystem = _utils.get_fs(session=boto3_session, s3_additional_kwargs=s3_additional_kwargs)
     for path in paths:
         _logger.debug(f"path: {path}")
-        with fs.open(path, "r") as f:
+        if pandas_args.get("compression", "infer") == "infer":
+            pandas_args["compression"] = infer_compression(path, compression="infer")
+        with fs.open(path, "rb") as f:
             reader: pandas.io.parsers.TextFileReader = parser_func(f, chunksize=chunksize, **pandas_args)
             for df in reader:
                 yield df
@@ -1464,7 +1467,9 @@ def _read_text_full(
     s3_additional_kwargs: Optional[Dict[str, str]] = None,
 ) -> pd.DataFrame:
     fs: s3fs.S3FileSystem = _utils.get_fs(session=boto3_session, s3_additional_kwargs=s3_additional_kwargs)
-    with fs.open(path, "r") as f:
+    if pandas_args.get("compression", "infer") == "infer":
+        pandas_args["compression"] = infer_compression(path, compression="infer")
+    with fs.open(path, "rb") as f:
         return parser_func(f, **pandas_args)


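Illustrative sketch (not part of this diff): pandas' infer_compression maps a path suffix to a codec name, which is what the new branch relies on whenever pandas_args leaves compression at its default of "infer". Because the S3 object is now opened in binary mode ("rb"), pandas performs the decompression itself using the inferred codec. The example paths below are made up.

from pandas.io.common import infer_compression

infer_compression("s3://my-bucket/data.csv.gz", compression="infer")   # -> "gzip"
infer_compression("s3://my-bucket/data.csv.bz2", compression="infer")  # -> "bz2"
infer_compression("s3://my-bucket/data.csv.xz", compression="infer")   # -> "xz"
infer_compression("s3://my-bucket/data.csv", compression="infer")      # -> None (plain text)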
70 changes: 69 additions & 1 deletion testing/test_awswrangler/test_data_lake.py
@@ -1,5 +1,9 @@
+import bz2
 import datetime
+import gzip
 import logging
+import lzma
+from io import BytesIO, TextIOWrapper

 import boto3
 import pandas as pd
@@ -846,7 +850,7 @@ def test_athena_types(bucket, database):


 def test_parquet_catalog_columns(bucket, database):
-    path = f"s3://{bucket}/test_parquet_catalog_columns /"
+    path = f"s3://{bucket}/test_parquet_catalog_columns/"
     paths = wr.s3.to_parquet(
         df=get_df_csv()[["id", "date", "timestamp", "par0", "par1"]],
         path=path,
@@ -889,3 +893,67 @@ def test_parquet_catalog_columns(bucket, database):

     wr.s3.delete_objects(path=path)
     assert wr.catalog.delete_table_if_exists(database=database, table="test_parquet_catalog_columns") is True
+
+
+@pytest.mark.parametrize("compression", [None, "gzip", "snappy"])
+def test_parquet_compress(bucket, database, compression):
+    path = f"s3://{bucket}/test_parquet_compress_{compression}/"
+    paths = wr.s3.to_parquet(
+        df=get_df(),
+        path=path,
+        compression=compression,
+        dataset=True,
+        database=database,
+        table=f"test_parquet_compress_{compression}",
+        mode="overwrite",
+    )["paths"]
+    wr.s3.wait_objects_exist(paths=paths)
+    df2 = wr.athena.read_sql_table(f"test_parquet_compress_{compression}", database)
+    ensure_data_types(df2)
+    df2 = wr.s3.read_parquet(path=path)
+    wr.s3.delete_objects(path=path)
+    assert wr.catalog.delete_table_if_exists(database=database, table=f"test_parquet_compress_{compression}") is True
+    ensure_data_types(df2)
+
+
+@pytest.mark.parametrize("compression", ["gzip", "bz2", "xz"])
+def test_csv_compress(bucket, compression):
+    path = f"s3://{bucket}/test_csv_compress_{compression}/"
+    wr.s3.delete_objects(path=path)
+    df = get_df_csv()
+    if compression == "gzip":
+        buffer = BytesIO()
+        with gzip.GzipFile(mode="w", fileobj=buffer) as zipped_file:
+            df.to_csv(TextIOWrapper(zipped_file, "utf8"), index=False, header=None)
+        s3_resource = boto3.resource("s3")
+        s3_object = s3_resource.Object(bucket, f"test_csv_compress_{compression}/test.csv.gz")
+        s3_object.put(Body=buffer.getvalue())
+        file_path = f"s3://{bucket}/test_csv_compress_{compression}/test.csv.gz"
+    elif compression == "bz2":
+        buffer = BytesIO()
+        with bz2.BZ2File(mode="w", filename=buffer) as zipped_file:
+            df.to_csv(TextIOWrapper(zipped_file, "utf8"), index=False, header=None)
+        s3_resource = boto3.resource("s3")
+        s3_object = s3_resource.Object(bucket, f"test_csv_compress_{compression}/test.csv.bz2")
+        s3_object.put(Body=buffer.getvalue())
+        file_path = f"s3://{bucket}/test_csv_compress_{compression}/test.csv.bz2"
+    elif compression == "xz":
+        buffer = BytesIO()
+        with lzma.LZMAFile(mode="w", filename=buffer) as zipped_file:
+            df.to_csv(TextIOWrapper(zipped_file, "utf8"), index=False, header=None)
+        s3_resource = boto3.resource("s3")
+        s3_object = s3_resource.Object(bucket, f"test_csv_compress_{compression}/test.csv.xz")
+        s3_object.put(Body=buffer.getvalue())
+        file_path = f"s3://{bucket}/test_csv_compress_{compression}/test.csv.xz"
+    else:
+        file_path = f"s3://{bucket}/test_csv_compress_{compression}/test.csv"
+        wr.s3.to_csv(df=df, path=file_path, index=False, header=None)
+
+    wr.s3.wait_objects_exist(paths=[file_path])
+    df2 = wr.s3.read_csv(path=[file_path], names=df.columns)
+    assert len(df2.index) == 3
+    assert len(df2.columns) == 10
+    dfs = wr.s3.read_csv(path=[file_path], names=df.columns, chunksize=1)
+    for df3 in dfs:
+        assert len(df3.columns) == 10
+    wr.s3.delete_objects(path=path)
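Illustrative usage sketch (not part of this diff; bucket and key names are made up): with the inference in place, a compressed CSV on S3 can be read without an explicit compression argument, for both full and chunked reads.

import awswrangler as wr

# Codec inferred from the ".gz" suffix; no compression argument needed.
df = wr.s3.read_csv(path=["s3://my-bucket/prefix/data.csv.gz"], names=["c0", "c1"])

# Chunked reads go through the same inference path (_read_text_chunksize).
for chunk in wr.s3.read_csv(path=["s3://my-bucket/prefix/data.csv.gz"], names=["c0", "c1"], chunksize=1000):
    print(len(chunk.index))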