diff --git a/awswrangler/pandas.py b/awswrangler/pandas.py index 6a56e3538..c3f4f8b0f 100644 --- a/awswrangler/pandas.py +++ b/awswrangler/pandas.py @@ -74,6 +74,9 @@ def read_csv( escapechar=None, parse_dates: Union[bool, Dict, List] = False, infer_datetime_format=False, + na_values: Optional[Union[str, List[str]]] = None, + keep_default_na: bool = True, + na_filter: bool = True, encoding="utf-8", converters=None, ): @@ -98,6 +101,9 @@ def read_csv( :param escapechar: Same as pandas.read_csv() :param parse_dates: Same as pandas.read_csv() :param infer_datetime_format: Same as pandas.read_csv() + :param na_values: Same as pandas.read_csv() + :param keep_default_na: Same as pandas.read_csv() + :param na_filter: Same as pandas.read_csv() :param encoding: Same as pandas.read_csv() :param converters: Same as pandas.read_csv() :return: Pandas Dataframe or Iterator of Pandas Dataframes if max_result_size != None @@ -120,6 +126,9 @@ def read_csv( escapechar=escapechar, parse_dates=parse_dates, infer_datetime_format=infer_datetime_format, + na_values=na_values, + keep_default_na=keep_default_na, + na_filter=na_filter, encoding=encoding, converters=converters) else: @@ -139,6 +148,9 @@ def read_csv( escapechar=escapechar, parse_dates=parse_dates, infer_datetime_format=infer_datetime_format, + na_values=na_values, + keep_default_na=keep_default_na, + na_filter=na_filter, encoding=encoding, converters=converters) return ret @@ -161,6 +173,9 @@ def _read_csv_iterator( escapechar=None, parse_dates: Union[bool, Dict, List] = False, infer_datetime_format=False, + na_values: Optional[Union[str, List[str]]] = None, + keep_default_na: bool = True, + na_filter: bool = True, encoding="utf-8", converters=None, ): @@ -185,6 +200,9 @@ def _read_csv_iterator( :param escapechar: Same as pandas.read_csv() :param parse_dates: Same as pandas.read_csv() :param infer_datetime_format: Same as pandas.read_csv() + :param na_values: Same as pandas.read_csv() + :param keep_default_na: Same as pandas.read_csv() + :param na_filter: Same as pandas.read_csv() :param encoding: Same as pandas.read_csv() :param converters: Same as pandas.read_csv() :return: Pandas Dataframe @@ -211,6 +229,9 @@ def _read_csv_iterator( escapechar=escapechar, parse_dates=parse_dates, infer_datetime_format=infer_datetime_format, + na_values=na_values, + keep_default_na=keep_default_na, + na_filter=na_filter, encoding=encoding, converters=converters) else: @@ -251,6 +272,9 @@ def _read_csv_iterator( header=header, names=names, usecols=usecols, + na_values=na_values, + keep_default_na=keep_default_na, + na_filter=na_filter, sep=sep, thousands=thousands, decimal=decimal, @@ -371,6 +395,9 @@ def _read_csv_once( escapechar=None, parse_dates: Union[bool, Dict, List] = False, infer_datetime_format=False, + na_values: Optional[Union[str, List[str]]] = None, + keep_default_na: bool = True, + na_filter: bool = True, encoding=None, converters=None, ): @@ -395,6 +422,9 @@ def _read_csv_once( :param escapechar: Same as pandas.read_csv() :param parse_dates: Same as pandas.read_csv() :param infer_datetime_format: Same as pandas.read_csv() + :param na_values: Same as pandas.read_csv() + :param keep_default_na: Same as pandas.read_csv() + :param na_filter: Same as pandas.read_csv() :param encoding: Same as pandas.read_csv() :param converters: Same as pandas.read_csv() :return: Pandas Dataframe @@ -409,6 +439,9 @@ def _read_csv_once( header=header, names=names, usecols=usecols, + na_values=na_values, + keep_default_na=keep_default_na, + na_filter=na_filter, sep=sep, thousands=thousands, decimal=decimal, @@ -443,6 +476,9 @@ def _read_csv_once_remote(send_pipe: mp.connection.Connection, escapechar=None, parse_dates: Union[bool, Dict, List] = False, infer_datetime_format=False, + na_values: Optional[Union[str, List[str]]] = None, + keep_default_na: bool = True, + na_filter: bool = True, encoding=None, converters=None): df: pd.DataFrame = Pandas._read_csv_once(session_primitives=session_primitives, @@ -461,6 +497,9 @@ def _read_csv_once_remote(send_pipe: mp.connection.Connection, escapechar=escapechar, parse_dates=parse_dates, infer_datetime_format=infer_datetime_format, + na_values=na_values, + keep_default_na=keep_default_na, + na_filter=na_filter, encoding=encoding, converters=converters) send_pipe.send(df) @@ -869,7 +908,7 @@ def to_s3(self, logger.debug(f"cast_columns: {cast_columns}") partition_cols = [Athena.normalize_column_name(x) for x in partition_cols] logger.debug(f"partition_cols: {partition_cols}") - if extra_args is not None and "columns" in extra_args: + if extra_args is not None and "columns" in extra_args and extra_args["columns"] is not None: extra_args["columns"] = [Athena.normalize_column_name(x) for x in extra_args["columns"]] # type: ignore dataframe = Pandas.drop_duplicated_columns(dataframe=dataframe, inplace=inplace) if compression is not None: @@ -1754,7 +1793,12 @@ def read_sql_aurora(self, paths: List[str] = self._session.aurora.extract_manifest_paths(path=manifest_path) logger.debug(f"paths: {paths}") ret: Union[pd.DataFrame, Iterator[pd.DataFrame]] - ret = self.read_csv_list(paths=paths, max_result_size=max_result_size, header=None, names=col_names) + ret = self.read_csv_list(paths=paths, + max_result_size=max_result_size, + header=None, + names=col_names, + na_values=["\\N"], + keep_default_na=False) self._session.s3.delete_listed_objects(objects_paths=paths + [manifest_path]) except Exception as ex: connection.rollback() @@ -1782,6 +1826,9 @@ def read_csv_list( escapechar=None, parse_dates: Union[bool, Dict, List] = False, infer_datetime_format=False, + na_values: Optional[Union[str, List[str]]] = None, + keep_default_na: bool = True, + na_filter: bool = True, encoding="utf-8", converters=None, procs_cpu_bound: Optional[int] = None, @@ -1807,6 +1854,9 @@ def read_csv_list( :param escapechar: Same as pandas.read_csv() :param parse_dates: Same as pandas.read_csv() :param infer_datetime_format: Same as pandas.read_csv() + :param na_values: Same as pandas.read_csv() + :param keep_default_na: Same as pandas.read_csv() + :param na_filter: Same as pandas.read_csv() :param encoding: Same as pandas.read_csv() :param converters: Same as pandas.read_csv() :param procs_cpu_bound: Number of cores used for CPU bound tasks @@ -1828,35 +1878,40 @@ def read_csv_list( escapechar=escapechar, parse_dates=parse_dates, infer_datetime_format=infer_datetime_format, + na_values=na_values, + keep_default_na=keep_default_na, + na_filter=na_filter, encoding=encoding, converters=converters) else: procs_cpu_bound = procs_cpu_bound if procs_cpu_bound is not None else self._session.procs_cpu_bound if self._session.procs_cpu_bound is not None else 1 logger.debug(f"procs_cpu_bound: {procs_cpu_bound}") - df: Optional[pd.DataFrame] = None session_primitives = self._session.primitives if len(paths) == 1: path = paths[0] bucket_name, key_path = Pandas._parse_path(path) logger.debug(f"path: {path}") - df = self._read_csv_once(session_primitives=self._session.primitives, - bucket_name=bucket_name, - key_path=key_path, - header=header, - names=names, - usecols=usecols, - dtype=dtype, - sep=sep, - thousands=thousands, - decimal=decimal, - lineterminator=lineterminator, - quotechar=quotechar, - quoting=quoting, - escapechar=escapechar, - parse_dates=parse_dates, - infer_datetime_format=infer_datetime_format, - encoding=encoding, - converters=converters) + df: pd.DataFrame = self._read_csv_once(session_primitives=self._session.primitives, + bucket_name=bucket_name, + key_path=key_path, + header=header, + names=names, + usecols=usecols, + dtype=dtype, + sep=sep, + thousands=thousands, + decimal=decimal, + lineterminator=lineterminator, + quotechar=quotechar, + quoting=quoting, + escapechar=escapechar, + parse_dates=parse_dates, + infer_datetime_format=infer_datetime_format, + na_values=na_values, + keep_default_na=keep_default_na, + na_filter=na_filter, + encoding=encoding, + converters=converters) else: procs = [] receive_pipes = [] @@ -1869,7 +1924,7 @@ def read_csv_list( target=self._read_csv_once_remote, args=(send_pipe, session_primitives, bucket_name, key_path, header, names, usecols, dtype, sep, thousands, decimal, lineterminator, quotechar, quoting, escapechar, parse_dates, - infer_datetime_format, encoding, converters), + infer_datetime_format, na_values, keep_default_na, na_filter, encoding, converters), ) proc.daemon = False proc.start() @@ -1906,6 +1961,9 @@ def _read_csv_list_iterator( escapechar=None, parse_dates: Union[bool, Dict, List] = False, infer_datetime_format=False, + na_values: Optional[Union[str, List[str]]] = None, + keep_default_na: bool = True, + na_filter: bool = True, encoding="utf-8", converters=None, ): @@ -1930,6 +1988,9 @@ def _read_csv_list_iterator( :param escapechar: Same as pandas.read_csv() :param parse_dates: Same as pandas.read_csv() :param infer_datetime_format: Same as pandas.read_csv() + :param na_values: Same as pandas.read_csv() + :param keep_default_na: Same as pandas.read_csv() + :param na_filter: Same as pandas.read_csv() :param encoding: Same as pandas.read_csv() :param converters: Same as pandas.read_csv() :return: Iterator of iterators of Pandas Dataframes @@ -1953,6 +2014,9 @@ def _read_csv_list_iterator( escapechar=escapechar, parse_dates=parse_dates, infer_datetime_format=infer_datetime_format, + na_values=na_values, + keep_default_na=keep_default_na, + na_filter=na_filter, encoding=encoding, converters=converters) @@ -1973,6 +2037,9 @@ def read_csv_prefix( escapechar=None, parse_dates: Union[bool, Dict, List] = False, infer_datetime_format=False, + na_values: Optional[Union[str, List[str]]] = None, + keep_default_na: bool = True, + na_filter: bool = True, encoding="utf-8", converters=None, procs_cpu_bound: Optional[int] = None, @@ -1998,6 +2065,9 @@ def read_csv_prefix( :param escapechar: Same as pandas.read_csv() :param parse_dates: Same as pandas.read_csv() :param infer_datetime_format: Same as pandas.read_csv() + :param na_values: Same as pandas.read_csv() + :param keep_default_na: Same as pandas.read_csv() + :param na_filter: Same as pandas.read_csv() :param encoding: Same as pandas.read_csv() :param converters: Same as pandas.read_csv() :param procs_cpu_bound: Number of cores used for CPU bound tasks @@ -2020,5 +2090,9 @@ def read_csv_prefix( escapechar=escapechar, parse_dates=parse_dates, infer_datetime_format=infer_datetime_format, + na_values=na_values, + keep_default_na=keep_default_na, + na_filter=na_filter, encoding=encoding, - converters=converters) + converters=converters, + procs_cpu_bound=procs_cpu_bound) diff --git a/testing/test_awswrangler/test_pandas.py b/testing/test_awswrangler/test_pandas.py index ade8ed9cd..c4680ee43 100644 --- a/testing/test_awswrangler/test_pandas.py +++ b/testing/test_awswrangler/test_pandas.py @@ -2309,3 +2309,35 @@ def test_aurora_mysql_load_columns(bucket, mysql_parameters): assert rows[4][1] == "boo" assert rows[5][1] == "bar" conn.close() + + +def test_aurora_mysql_unload_null(bucket, mysql_parameters): + df = pd.DataFrame({ + "id": [1, 2, 3, 4, 5], + "c_str": ["foo", "", None, "bar", None], + "c_float": [1.1, None, 3.3, None, 5.5], + "c_int": [1, 2, None, 3, 4], + }) + df["c_int"] = df["c_int"].astype("Int64") + print(df) + conn = Aurora.generate_connection(database="mysql", + host=mysql_parameters["MysqlAddress"], + port=3306, + user="test", + password=mysql_parameters["Password"], + engine="mysql") + path = f"s3://{bucket}/test_aurora_mysql_unload_complex" + wr.pandas.to_aurora(dataframe=df, + connection=conn, + schema="test", + table="test_aurora_mysql_unload_complex", + mode="overwrite", + temp_s3_path=path) + path2 = f"s3://{bucket}/test_aurora_mysql_unload_complex2" + df2 = wr.pandas.read_sql_aurora(sql="SELECT * FROM test.test_aurora_mysql_unload_complex", + connection=conn, + col_names=["id", "c_str", "c_float", "c_int"], + temp_s3_path=path2) + df2["c_int"] = df2["c_int"].astype("Int64") + assert df.equals(df2) + conn.close()