diff --git a/awswrangler/pandas.py b/awswrangler/pandas.py index a6fa2622f..953f68052 100644 --- a/awswrangler/pandas.py +++ b/awswrangler/pandas.py @@ -682,7 +682,7 @@ def to_csv(self, """ if serde not in Pandas.VALID_CSV_SERDES: raise InvalidSerDe(f"{serde} in not in the valid SerDe list ({Pandas.VALID_CSV_SERDES})") - extra_args = {"sep": sep, "serde": serde, "escapechar": escapechar} + extra_args: Dict[str, Optional[str]] = {"sep": sep, "serde": serde, "escapechar": escapechar} return self.to_s3(dataframe=dataframe, path=path, file_format="csv", @@ -767,7 +767,7 @@ def to_s3(self, procs_cpu_bound=None, procs_io_bound=None, cast_columns=None, - extra_args=None, + extra_args: Optional[Dict[str, Optional[str]]] = None, inplace: bool = True, description: Optional[str] = None, parameters: Optional[Dict[str, str]] = None, @@ -922,7 +922,7 @@ def _data_to_s3_dataset_writer(dataframe: pd.DataFrame, session_primitives: "SessionPrimitives", file_format: str, cast_columns=None, - extra_args=None, + extra_args: Optional[Dict[str, Optional[str]]] = None, isolated_dataframe: bool = False): objects_paths = [] dataframe = Pandas._cast_pandas(dataframe=dataframe, cast_columns=cast_columns) @@ -980,7 +980,7 @@ def _data_to_s3_dataset_writer_remote(send_pipe, session_primitives: "SessionPrimitives", file_format, cast_columns=None, - extra_args=None): + extra_args: Optional[Dict[str, Optional[str]]] = None): send_pipe.send( Pandas._data_to_s3_dataset_writer(dataframe=dataframe, path=path, @@ -996,35 +996,35 @@ def _data_to_s3_dataset_writer_remote(send_pipe, @staticmethod def _data_to_s3_object_writer(dataframe: pd.DataFrame, - path: "str", + path: str, preserve_index: bool, - compression, + compression: str, session_primitives: "SessionPrimitives", - file_format, - cast_columns=None, - extra_args=None, - isolated_dataframe=False): + file_format: str, + cast_columns: Optional[List[str]] = None, + extra_args: Optional[Dict[str, Optional[str]]] = None, + isolated_dataframe=False) -> str: fs = get_fs(session_primitives=session_primitives) fs = pa.filesystem._ensure_filesystem(fs) mkdir_if_not_exists(fs, path) if compression is None: - compression_end = "" + compression_extension: str = "" elif compression == "snappy": - compression_end = ".snappy" + compression_extension = ".snappy" elif compression == "gzip": - compression_end = ".gz" + compression_extension = ".gz" else: raise InvalidCompression(compression) - guid = pa.compat.guid() + guid: str = pa.compat.guid() if file_format == "parquet": - outfile = f"{guid}.parquet{compression_end}" + outfile: str = f"{guid}{compression_extension}.parquet" elif file_format == "csv": - outfile = f"{guid}.csv{compression_end}" + outfile = f"{guid}{compression_extension}.csv" else: raise UnsupportedFileFormat(file_format) - object_path = "/".join([path, outfile]) + object_path: str = "/".join([path, outfile]) if file_format == "parquet": Pandas.write_parquet_dataframe(dataframe=dataframe, path=object_path, diff --git a/testing/test_awswrangler/test_pandas.py b/testing/test_awswrangler/test_pandas.py index 21ef18d83..c5d318836 100644 --- a/testing/test_awswrangler/test_pandas.py +++ b/testing/test_awswrangler/test_pandas.py @@ -1927,7 +1927,12 @@ def test_aurora_postgres_load_special(bucket, postgres_parameters): "value": ["foo", "boo", "bar", "abc"], "slashes": ["\\", "\"", "\\\\\\\\", "\"\"\"\""], "floats": [1.0, 2.0, 3.0, 4.0], - "decimals": [Decimal((0, (1, 9, 9), -2)), Decimal((0, (1, 9, 9), -2)), Decimal((0, (1, 9, 0), -2)), Decimal((0, (3, 1, 2), -2))] + "decimals": [ + Decimal((0, (1, 9, 9), -2)), + Decimal((0, (1, 9, 9), -2)), + Decimal((0, (1, 9, 0), -2)), + Decimal((0, (3, 1, 2), -2)) + ] }) path = f"s3://{bucket}/test_aurora_postgres_special" @@ -1977,8 +1982,12 @@ def test_aurora_mysql_load_special(bucket, mysql_parameters): "value": ["foo", "boo", "bar", "abc"], "slashes": ["\\", "\"", "\\\\\\\\", "\"\"\"\""], "floats": [1.0, 2.0, 3.0, 4.0], - "decimals": [Decimal((0, (1, 9, 9), -2)), Decimal((0, (1, 9, 9), -2)), Decimal((0, (1, 9, 0), -2)), - Decimal((0, (3, 1, 2), -2))] + "decimals": [ + Decimal((0, (1, 9, 9), -2)), + Decimal((0, (1, 9, 9), -2)), + Decimal((0, (1, 9, 0), -2)), + Decimal((0, (3, 1, 2), -2)) + ] }) path = f"s3://{bucket}/test_aurora_mysql_special"