120 changes: 97 additions & 23 deletions awswrangler/pandas.py
@@ -74,6 +74,9 @@ def read_csv(
escapechar=None,
parse_dates: Union[bool, Dict, List] = False,
infer_datetime_format=False,
na_values: Optional[Union[str, List[str]]] = None,
keep_default_na: bool = True,
na_filter: bool = True,
encoding="utf-8",
converters=None,
):
@@ -98,6 +101,9 @@
:param escapechar: Same as pandas.read_csv()
:param parse_dates: Same as pandas.read_csv()
:param infer_datetime_format: Same as pandas.read_csv()
:param na_values: Same as pandas.read_csv()
:param keep_default_na: Same as pandas.read_csv()
:param na_filter: Same as pandas.read_csv()
:param encoding: Same as pandas.read_csv()
:param converters: Same as pandas.read_csv()
:return: Pandas Dataframe or Iterator of Pandas Dataframes if max_result_size != None
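
The three new keyword arguments are passed straight through to pandas, so they keep pandas' semantics: na_values adds extra strings to recognize as missing, keep_default_na controls whether pandas' built-in NA strings (which include the empty string) stay active, and na_filter can switch missing-value detection off entirely. A minimal sketch in plain pandas (not awswrangler; the "NULL" marker is illustrative):

import io
import pandas as pd

buf = io.StringIO("id,name\n1,\n2,NULL\n3,foo\n")
# With keep_default_na=False only the strings listed in na_values are
# treated as missing, so the empty field in row 1 survives as "":
df = pd.read_csv(buf, na_values=["NULL"], keep_default_na=False)
# row 1 -> name == "", row 2 -> name is NaN, row 3 -> "foo"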
@@ -120,6 +126,9 @@
escapechar=escapechar,
parse_dates=parse_dates,
infer_datetime_format=infer_datetime_format,
na_values=na_values,
keep_default_na=keep_default_na,
na_filter=na_filter,
encoding=encoding,
converters=converters)
else:
@@ -139,6 +148,9 @@
escapechar=escapechar,
parse_dates=parse_dates,
infer_datetime_format=infer_datetime_format,
na_values=na_values,
keep_default_na=keep_default_na,
na_filter=na_filter,
encoding=encoding,
converters=converters)
return ret
@@ -161,6 +173,9 @@ def _read_csv_iterator(
escapechar=None,
parse_dates: Union[bool, Dict, List] = False,
infer_datetime_format=False,
na_values: Optional[Union[str, List[str]]] = None,
keep_default_na: bool = True,
na_filter: bool = True,
encoding="utf-8",
converters=None,
):
@@ -185,6 +200,9 @@
:param escapechar: Same as pandas.read_csv()
:param parse_dates: Same as pandas.read_csv()
:param infer_datetime_format: Same as pandas.read_csv()
:param na_values: Same as pandas.read_csv()
:param keep_default_na: Same as pandas.read_csv()
:param na_filter: Same as pandas.read_csv()
:param encoding: Same as pandas.read_csv()
:param converters: Same as pandas.read_csv()
:return: Pandas Dataframe
@@ -211,6 +229,9 @@
escapechar=escapechar,
parse_dates=parse_dates,
infer_datetime_format=infer_datetime_format,
na_values=na_values,
keep_default_na=keep_default_na,
na_filter=na_filter,
encoding=encoding,
converters=converters)
else:
@@ -251,6 +272,9 @@ def _read_csv_iterator(
header=header,
names=names,
usecols=usecols,
na_values=na_values,
keep_default_na=keep_default_na,
na_filter=na_filter,
sep=sep,
thousands=thousands,
decimal=decimal,
@@ -371,6 +395,9 @@ def _read_csv_once(
escapechar=None,
parse_dates: Union[bool, Dict, List] = False,
infer_datetime_format=False,
na_values: Optional[Union[str, List[str]]] = None,
keep_default_na: bool = True,
na_filter: bool = True,
encoding=None,
converters=None,
):
@@ -395,6 +422,9 @@
:param escapechar: Same as pandas.read_csv()
:param parse_dates: Same as pandas.read_csv()
:param infer_datetime_format: Same as pandas.read_csv()
:param na_values: Same as pandas.read_csv()
:param keep_default_na: Same as pandas.read_csv()
:param na_filter: Same as pandas.read_csv()
:param encoding: Same as pandas.read_csv()
:param converters: Same as pandas.read_csv()
:return: Pandas Dataframe
@@ -409,6 +439,9 @@
header=header,
names=names,
usecols=usecols,
na_values=na_values,
keep_default_na=keep_default_na,
na_filter=na_filter,
sep=sep,
thousands=thousands,
decimal=decimal,
@@ -443,6 +476,9 @@ def _read_csv_once_remote(send_pipe: mp.connection.Connection,
escapechar=None,
parse_dates: Union[bool, Dict, List] = False,
infer_datetime_format=False,
na_values: Optional[Union[str, List[str]]] = None,
keep_default_na: bool = True,
na_filter: bool = True,
encoding=None,
converters=None):
df: pd.DataFrame = Pandas._read_csv_once(session_primitives=session_primitives,
@@ -461,6 +497,9 @@
escapechar=escapechar,
parse_dates=parse_dates,
infer_datetime_format=infer_datetime_format,
na_values=na_values,
keep_default_na=keep_default_na,
na_filter=na_filter,
encoding=encoding,
converters=converters)
send_pipe.send(df)
@@ -869,7 +908,7 @@ def to_s3(self,
logger.debug(f"cast_columns: {cast_columns}")
partition_cols = [Athena.normalize_column_name(x) for x in partition_cols]
logger.debug(f"partition_cols: {partition_cols}")
if extra_args is not None and "columns" in extra_args:
if extra_args is not None and "columns" in extra_args and extra_args["columns"] is not None:
extra_args["columns"] = [Athena.normalize_column_name(x) for x in extra_args["columns"]] # type: ignore
dataframe = Pandas.drop_duplicated_columns(dataframe=dataframe, inplace=inplace)
if compression is not None:
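
The extra clause in the guard above matters because callers may pass extra_args={"columns": None}: the old check only tested membership, so the list comprehension would then iterate over None and raise a TypeError. A minimal illustration (Athena.normalize_column_name swapped for str.lower as a stand-in):

extra_args = {"columns": None}
if extra_args is not None and "columns" in extra_args and extra_args["columns"] is not None:
    extra_args["columns"] = [str.lower(c) for c in extra_args["columns"]]
# Without the final clause, iterating extra_args["columns"] (None) would
# raise TypeError: 'NoneType' object is not iterable.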
@@ -1754,7 +1793,12 @@ def read_sql_aurora(self,
paths: List[str] = self._session.aurora.extract_manifest_paths(path=manifest_path)
logger.debug(f"paths: {paths}")
ret: Union[pd.DataFrame, Iterator[pd.DataFrame]]
ret = self.read_csv_list(paths=paths, max_result_size=max_result_size, header=None, names=col_names)
ret = self.read_csv_list(paths=paths,
max_result_size=max_result_size,
header=None,
names=col_names,
na_values=["\\N"],
keep_default_na=False)
self._session.s3.delete_listed_objects(objects_paths=paths + [manifest_path])
except Exception as ex:
connection.rollback()
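
The hard-coded na_values=["\\N"] with keep_default_na=False reflects how MySQL-compatible engines unload data to text files: SQL NULL is written as the two-character sequence \N, while an empty string stays an empty field. Reading with pandas' defaults would collapse both into NaN; with these settings only \N becomes missing. A sketch in plain pandas of what read_csv_list now does with such a file:

import io
import pandas as pd

unloaded = io.StringIO("1,foo\n2,\n3,\\N\n")  # row 2: empty string, row 3: SQL NULL
df = pd.read_csv(unloaded, header=None, names=["id", "c_str"],
                 na_values=["\\N"], keep_default_na=False)
# df.loc[1, "c_str"] == ""       (empty string preserved)
# pd.isna(df.loc[2, "c_str"])    (only \N maps to NaN)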
@@ -1782,6 +1826,9 @@ def read_csv_list(
escapechar=None,
parse_dates: Union[bool, Dict, List] = False,
infer_datetime_format=False,
na_values: Optional[Union[str, List[str]]] = None,
keep_default_na: bool = True,
na_filter: bool = True,
encoding="utf-8",
converters=None,
procs_cpu_bound: Optional[int] = None,
@@ -1807,6 +1854,9 @@
:param escapechar: Same as pandas.read_csv()
:param parse_dates: Same as pandas.read_csv()
:param infer_datetime_format: Same as pandas.read_csv()
:param na_values: Same as pandas.read_csv()
:param keep_default_na: Same as pandas.read_csv()
:param na_filter: Same as pandas.read_csv()
:param encoding: Same as pandas.read_csv()
:param converters: Same as pandas.read_csv()
:param procs_cpu_bound: Number of cores used for CPU bound tasks
@@ -1828,35 +1878,40 @@ def read_csv_list(
escapechar=escapechar,
parse_dates=parse_dates,
infer_datetime_format=infer_datetime_format,
na_values=na_values,
keep_default_na=keep_default_na,
na_filter=na_filter,
encoding=encoding,
converters=converters)
else:
procs_cpu_bound = procs_cpu_bound if procs_cpu_bound is not None else self._session.procs_cpu_bound if self._session.procs_cpu_bound is not None else 1
logger.debug(f"procs_cpu_bound: {procs_cpu_bound}")
df: Optional[pd.DataFrame] = None
session_primitives = self._session.primitives
if len(paths) == 1:
path = paths[0]
bucket_name, key_path = Pandas._parse_path(path)
logger.debug(f"path: {path}")
df = self._read_csv_once(session_primitives=self._session.primitives,
bucket_name=bucket_name,
key_path=key_path,
header=header,
names=names,
usecols=usecols,
dtype=dtype,
sep=sep,
thousands=thousands,
decimal=decimal,
lineterminator=lineterminator,
quotechar=quotechar,
quoting=quoting,
escapechar=escapechar,
parse_dates=parse_dates,
infer_datetime_format=infer_datetime_format,
encoding=encoding,
converters=converters)
df: pd.DataFrame = self._read_csv_once(session_primitives=self._session.primitives,
bucket_name=bucket_name,
key_path=key_path,
header=header,
names=names,
usecols=usecols,
dtype=dtype,
sep=sep,
thousands=thousands,
decimal=decimal,
lineterminator=lineterminator,
quotechar=quotechar,
quoting=quoting,
escapechar=escapechar,
parse_dates=parse_dates,
infer_datetime_format=infer_datetime_format,
na_values=na_values,
keep_default_na=keep_default_na,
na_filter=na_filter,
encoding=encoding,
converters=converters)
else:
procs = []
receive_pipes = []
@@ -1869,7 +1924,7 @@ read_csv_list(
target=self._read_csv_once_remote,
args=(send_pipe, session_primitives, bucket_name, key_path, header, names, usecols, dtype, sep,
thousands, decimal, lineterminator, quotechar, quoting, escapechar, parse_dates,
infer_datetime_format, encoding, converters),
infer_datetime_format, na_values, keep_default_na, na_filter, encoding, converters),
)
proc.daemon = False
proc.start()
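
Note that _read_csv_once_remote receives its parameters positionally here, so the three new arguments must be spliced into the args tuple at exactly the position they occupy in that function's signature (between infer_datetime_format and encoding). The surrounding fan-out pattern, reduced to a runnable sketch with illustrative names:

import multiprocessing as mp

def worker(send_pipe, value):
    send_pipe.send(value * 2)  # stand-in for Pandas._read_csv_once(...)
    send_pipe.close()

if __name__ == "__main__":
    procs, receive_pipes = [], []
    for v in [1, 2, 3]:
        receive_pipe, send_pipe = mp.Pipe(duplex=False)  # one pipe per path
        proc = mp.Process(target=worker, args=(send_pipe, v))
        proc.start()
        procs.append(proc)
        receive_pipes.append(receive_pipe)
    dfs = [pipe.recv() for pipe in receive_pipes]  # collect partial results
    for proc in procs:
        proc.join()
    print(dfs)  # [2, 4, 6]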
@@ -1906,6 +1961,9 @@ def _read_csv_list_iterator(
escapechar=None,
parse_dates: Union[bool, Dict, List] = False,
infer_datetime_format=False,
na_values: Optional[Union[str, List[str]]] = None,
keep_default_na: bool = True,
na_filter: bool = True,
encoding="utf-8",
converters=None,
):
@@ -1930,6 +1988,9 @@
:param escapechar: Same as pandas.read_csv()
:param parse_dates: Same as pandas.read_csv()
:param infer_datetime_format: Same as pandas.read_csv()
:param na_values: Same as pandas.read_csv()
:param keep_default_na: Same as pandas.read_csv()
:param na_filter: Same as pandas.read_csv()
:param encoding: Same as pandas.read_csv()
:param converters: Same as pandas.read_csv()
:return: Iterator of iterators of Pandas Dataframes
@@ -1953,6 +2014,9 @@ def _read_csv_list_iterator(
escapechar=escapechar,
parse_dates=parse_dates,
infer_datetime_format=infer_datetime_format,
na_values=na_values,
keep_default_na=keep_default_na,
na_filter=na_filter,
encoding=encoding,
converters=converters)

@@ -1973,6 +2037,9 @@ def read_csv_prefix(
escapechar=None,
parse_dates: Union[bool, Dict, List] = False,
infer_datetime_format=False,
na_values: Optional[Union[str, List[str]]] = None,
keep_default_na: bool = True,
na_filter: bool = True,
encoding="utf-8",
converters=None,
procs_cpu_bound: Optional[int] = None,
@@ -1998,6 +2065,9 @@
:param escapechar: Same as pandas.read_csv()
:param parse_dates: Same as pandas.read_csv()
:param infer_datetime_format: Same as pandas.read_csv()
:param na_values: Same as pandas.read_csv()
:param keep_default_na: Same as pandas.read_csv()
:param na_filter: Same as pandas.read_csv()
:param encoding: Same as pandas.read_csv()
:param converters: Same as pandas.read_csv()
:param procs_cpu_bound: Number of cores used for CPU bound tasks
@@ -2020,5 +2090,9 @@ read_csv_prefix(
escapechar=escapechar,
parse_dates=parse_dates,
infer_datetime_format=infer_datetime_format,
na_values=na_values,
keep_default_na=keep_default_na,
na_filter=na_filter,
encoding=encoding,
converters=converters)
converters=converters,
procs_cpu_bound=procs_cpu_bound)
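
Besides threading the NA arguments through, this last hunk also fixes read_csv_prefix, which previously dropped the procs_cpu_bound argument instead of forwarding it to read_csv_list. A hypothetical call through the module-level facade (bucket and key are placeholders; the facade is assumed to mirror the method signature above, as the tests' wr.pandas.* calls suggest):

import awswrangler as wr

df = wr.pandas.read_csv(path="s3://my-bucket/raw/data.csv",
                        na_values=["NULL", "\\N"],
                        keep_default_na=False,  # keep "" as an empty string
                        na_filter=True)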
32 changes: 32 additions & 0 deletions testing/test_awswrangler/test_pandas.py
@@ -2309,3 +2309,35 @@ def test_aurora_mysql_load_columns(bucket, mysql_parameters):
assert rows[4][1] == "boo"
assert rows[5][1] == "bar"
conn.close()


def test_aurora_mysql_unload_null(bucket, mysql_parameters):
df = pd.DataFrame({
"id": [1, 2, 3, 4, 5],
"c_str": ["foo", "", None, "bar", None],
"c_float": [1.1, None, 3.3, None, 5.5],
"c_int": [1, 2, None, 3, 4],
})
df["c_int"] = df["c_int"].astype("Int64")
print(df)
conn = Aurora.generate_connection(database="mysql",
host=mysql_parameters["MysqlAddress"],
port=3306,
user="test",
password=mysql_parameters["Password"],
engine="mysql")
path = f"s3://{bucket}/test_aurora_mysql_unload_complex"
wr.pandas.to_aurora(dataframe=df,
connection=conn,
schema="test",
table="test_aurora_mysql_unload_complex",
mode="overwrite",
temp_s3_path=path)
path2 = f"s3://{bucket}/test_aurora_mysql_unload_complex2"
df2 = wr.pandas.read_sql_aurora(sql="SELECT * FROM test.test_aurora_mysql_unload_complex",
connection=conn,
col_names=["id", "c_str", "c_float", "c_int"],
temp_s3_path=path2)
df2["c_int"] = df2["c_int"].astype("Int64")
assert df.equals(df2)
conn.close()
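
The astype("Int64") casts (capital I) are what let the integer column round-trip with missing values: pandas' nullable integer extension dtype holds pd.NA directly, whereas a plain int64 column would be silently upcast to float64. A quick runnable check:

import pandas as pd

s = pd.Series([1, 2, None, 3, 4], dtype="Int64")
assert s.isna().sum() == 1         # the None is stored as pd.NA
assert str(s.dtype) == "Int64"     # no silent upcast to float64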