2 changes: 1 addition & 1 deletion awswrangler/_arrow.py
@@ -19,7 +19,7 @@ def _extract_partitions_from_path(path_root: str, path: str) -> Dict[str, str]:
         raise Exception(f"Object {path} is not under the root path ({path_root}).")
     path_wo_filename: str = path.rpartition("/")[0] + "/"
     path_wo_prefix: str = path_wo_filename.replace(f"{path_root}/", "")
-    dirs: Tuple[str, ...] = tuple(x for x in path_wo_prefix.split("/") if (x != "") and (x.count("=") > 0))
+    dirs: Tuple[str, ...] = tuple(x for x in path_wo_prefix.split("/") if x and (x.count("=") > 0))
     if not dirs:
         return {}
     values_tups = cast(Tuple[Tuple[str, str]], tuple(tuple(x.split("=", maxsplit=1)[:2]) for x in dirs))
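The only change in this hunk swaps an explicit empty-string comparison for a plain truthiness check; because empty strings are falsy in Python, both filters keep exactly the same partition segments. A minimal standalone illustration with made-up path segments (not taken from the library):

    segments = "year=2021//month=5".split("/")  # ['year=2021', '', 'month=5']
    kept_old = [x for x in segments if (x != "") and (x.count("=") > 0)]
    kept_new = [x for x in segments if x and (x.count("=") > 0)]
    assert kept_old == kept_new == ["year=2021", "month=5"]

The same substitution appears again in awswrangler/athena/_read.py, awswrangler/s3/_read.py, and awswrangler/s3/_read_parquet.py below.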
22 changes: 10 additions & 12 deletions awswrangler/_data_types.py
@@ -342,9 +342,7 @@ def athena2pyarrow(dtype: str) -> pa.DataType: # pylint: disable=too-many-retur
     raise exceptions.UnsupportedType(f"Unsupported Athena type: {dtype}")


-def athena2pandas(
-    dtype: str, dtype_backend: Optional[str] = None
-) -> str: # pylint: disable=too-many-branches,too-many-return-statements
+def athena2pandas(dtype: str, dtype_backend: Optional[str] = None) -> str: # pylint: disable=too-many-return-statements
     """Athena to Pandas data types conversion."""
     dtype = dtype.lower()
     if dtype == "tinyint":
@@ -493,24 +491,24 @@ def pyarrow_types_from_pandas( # pylint: disable=too-many-branches,too-many-sta
     cols: List[str] = []
     cols_dtypes: Dict[str, Optional[pa.DataType]] = {}
     for name, dtype in df.dtypes.to_dict().items():
-        dtype = str(dtype)
+        dtype_str = str(dtype)
         if name in ignore_cols:
             cols_dtypes[name] = None
-        elif dtype == "Int8":
+        elif dtype_str == "Int8":
             cols_dtypes[name] = pa.int8()
-        elif dtype == "Int16":
+        elif dtype_str == "Int16":
             cols_dtypes[name] = pa.int16()
-        elif dtype == "Int32":
+        elif dtype_str == "Int32":
             cols_dtypes[name] = pa.int32()
-        elif dtype == "Int64":
+        elif dtype_str == "Int64":
             cols_dtypes[name] = pa.int64()
-        elif dtype == "float32":
+        elif dtype_str == "float32":
             cols_dtypes[name] = pa.float32()
-        elif dtype == "float64":
+        elif dtype_str == "float64":
             cols_dtypes[name] = pa.float64()
-        elif dtype == "string":
+        elif dtype_str == "string":
             cols_dtypes[name] = pa.string()
-        elif dtype == "boolean":
+        elif dtype_str == "boolean":
             cols_dtypes[name] = pa.bool_()
         else:
             cols.append(name)
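The dtype to dtype_str rename in this hunk avoids rebinding the for-loop variable inside the loop body, the situation ruff reports as PLW2901 (redefined-loop-name); the other files in this PR silence the same rule with noqa comments instead. A minimal sketch of the two spellings, with illustrative names only:

    dtypes = {"a": 1, "b": 2}

    # Rebinding the loop variable is what PLW2901 flags:
    for name, dtype in dtypes.items():
        dtype = str(dtype)  # loop variable overwritten each iteration

    # Binding the converted value to a new name keeps the loop variable intact:
    for name, dtype in dtypes.items():
        dtype_str = str(dtype)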
2 changes: 1 addition & 1 deletion awswrangler/_databases.py
@@ -160,7 +160,7 @@ def _records2df(
     for col_values, col_name in zip(tuple(zip(*records)), cols_names): # Transposing
         if (dtype is None) or (col_name not in dtype):
             if _oracledb_found:
-                col_values = oracle.handle_oracle_objects(col_values, col_name)
+                col_values = oracle.handle_oracle_objects(col_values, col_name) # ruff: noqa: PLW2901
             try:
                 array: pa.Array = pa.array(obj=col_values, safe=safe) # Creating Arrow array
             except pa.ArrowInvalid as ex:
4 changes: 2 additions & 2 deletions awswrangler/athena/_read.py
@@ -40,7 +40,7 @@ def _extract_ctas_manifest_paths(path: str, boto3_session: Optional[boto3.Sessio
     bucket_name, key_path = _utils.parse_path(path)
     client_s3 = _utils.client(service_name="s3", session=boto3_session)
     body: bytes = client_s3.get_object(Bucket=bucket_name, Key=key_path)["Body"].read()
-    paths = [x for x in body.decode("utf-8").split("\n") if x != ""]
+    paths = [x for x in body.decode("utf-8").split("\n") if x]
     _logger.debug("Read %d paths from manifest file in: %s", len(paths), path)
     return paths

@@ -58,7 +58,7 @@ def _add_query_metadata_generator(
 ) -> Iterator[pd.DataFrame]:
     """Add Query Execution metadata to every DF in iterator."""
     for df in dfs:
-        df = _apply_query_metadata(df=df, query_metadata=query_metadata)
+        df = _apply_query_metadata(df=df, query_metadata=query_metadata) # ruff: noqa: PLW2901
         yield df

13 changes: 6 additions & 7 deletions awswrangler/athena/_utils.py
@@ -97,11 +97,10 @@ def _start_query_execution(
             args["ResultConfiguration"]["EncryptionConfiguration"] = {"EncryptionOption": wg_config.encryption}
             if wg_config.kms_key is not None:
                 args["ResultConfiguration"]["EncryptionConfiguration"]["KmsKey"] = wg_config.kms_key
-    else:
-        if encryption is not None:
-            args["ResultConfiguration"]["EncryptionConfiguration"] = {"EncryptionOption": encryption}
-            if kms_key is not None:
-                args["ResultConfiguration"]["EncryptionConfiguration"]["KmsKey"] = kms_key
+    elif encryption is not None:
+        args["ResultConfiguration"]["EncryptionConfiguration"] = {"EncryptionOption": encryption}
+        if kms_key is not None:
+            args["ResultConfiguration"]["EncryptionConfiguration"]["KmsKey"] = kms_key

     # database
     if database is not None:
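The hunk above folds an else: branch whose body is a single if into an elif, removing one level of nesting; this appears to be the pattern ruff reports as PLR5501 (collapsible-else-if), and the same rewrite shows up below in awswrangler/dynamodb/_utils.py, awswrangler/neptune/_neptune.py, and awswrangler/opensearch/_write.py. A generic before/after sketch with placeholder values:

    x = 0

    # Before: a lone if nested under else
    if x > 0:
        label = "positive"
    else:
        if x == 0:
            label = "zero"
        else:
            label = "negative"

    # After: flattened into elif, one indentation level less
    if x > 0:
        label = "positive"
    elif x == 0:
        label = "zero"
    else:
        label = "negative"

    assert label == "zero"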
@@ -187,8 +186,8 @@ def _parse_describe_table(df: pd.DataFrame) -> pd.DataFrame:
     origin_df_dict = df.to_dict()
     target_df_dict: Dict[str, List[Union[str, bool]]] = {"Column Name": [], "Type": [], "Partition": [], "Comment": []}
     for index, col_name in origin_df_dict["col_name"].items():
-        col_name = col_name.strip()
-        if col_name.startswith("#") or col_name == "":
+        col_name = col_name.strip() # ruff: noqa: PLW2901
+        if col_name.startswith("#") or not col_name:
             pass
         elif col_name in target_df_dict["Column Name"]:
             index_col_name = target_df_dict["Column Name"].index(col_name)
4 changes: 3 additions & 1 deletion awswrangler/data_api/_connector.py
@@ -107,7 +107,9 @@ def _get_statement_result(self, request_id: str) -> pd.DataFrame:
         pass

     @staticmethod
-    def _get_column_value(column_value: Dict[str, Any], col_type: Optional[str] = None) -> Any:
+    def _get_column_value( # pylint: disable=too-many-return-statements
+        column_value: Dict[str, Any], col_type: Optional[str] = None
+    ) -> Any:
         """Return the first non-null key value for a given dictionary.

         The key names for a given record depend on the column type: stringValue, longValue, etc.
6 changes: 4 additions & 2 deletions awswrangler/data_api/rds.py
@@ -313,7 +313,9 @@ def _create_table(
     con.execute(sql, database=database, transaction_id=transaction_id)


-def _create_value_dict(value: Any) -> Tuple[Dict[str, Any], Optional[str]]:
+def _create_value_dict( # pylint: disable=too-many-return-statements
+    value: Any,
+) -> Tuple[Dict[str, Any], Optional[str]]:
     if value is None or pd.isnull(value):
         return {"isNull": True}, None

@@ -351,7 +353,7 @@ def _generate_parameters(columns: List[str], values: List[Any]) -> List[Dict[str
     parameter_list = []

     for col, value in zip(columns, values):
-        value, type_hint = _create_value_dict(value)
+        value, type_hint = _create_value_dict(value) # ruff: noqa: PLW2901

         parameter = {
             "name": col,
6 changes: 3 additions & 3 deletions awswrangler/data_api/redshift.py
@@ -87,13 +87,13 @@ def rollback_transaction(self, transaction_id: str) -> str:
         raise NotImplementedError("Redshift Data API does not support transactions.")

     def _validate_redshift_target(self) -> None:
-        if self.database == "":
+        if not self.database:
             raise ValueError("`database` must be set for connection")
-        if self.cluster_id == "" and self.workgroup_name == "":
+        if not self.cluster_id and not self.workgroup_name:
             raise ValueError("Either `cluster_id` or `workgroup_name`(Redshift Serverless) must be set for connection")

     def _validate_auth_method(self) -> None:
-        if self.workgroup_name == "" and self.secret_arn == "" and self.db_user == "":
+        if not self.workgroup_name and not self.secret_arn and not self.db_user:
             raise ValueError("Either `secret_arn` or `db_user` must be set for authentication")

     def _execute_statement(
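Note that not self.database is falsy for None as well as for the empty string, so these rewritten validations would also treat a None value as missing, a slight widening of the original == "" comparisons. A tiny illustration with throwaway values:

    for value in ("", None):
        assert not value          # the new truthiness check rejects both
    assert None != ""             # whereas the old equality check only matched the empty string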
10 changes: 4 additions & 6 deletions awswrangler/dynamodb/_utils.py
@@ -165,19 +165,17 @@ def _serialize_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:

     if "ExpressionAttributeNames" in kwargs:
         kwargs["ExpressionAttributeNames"].update(names)
-    else:
-        if names:
-            kwargs["ExpressionAttributeNames"] = names
+    elif names:
+        kwargs["ExpressionAttributeNames"] = names

     values = {k: serializer.serialize(v) for k, v in values.items()}
     if "ExpressionAttributeValues" in kwargs:
         kwargs["ExpressionAttributeValues"] = {
             k: serializer.serialize(v) for k, v in kwargs["ExpressionAttributeValues"].items()
         }
         kwargs["ExpressionAttributeValues"].update(values)
-    else:
-        if values:
-            kwargs["ExpressionAttributeValues"] = values
+    elif values:
+        kwargs["ExpressionAttributeValues"] = values

     return kwargs

2 changes: 1 addition & 1 deletion awswrangler/neptune/_gremlin_parser.py
@@ -68,7 +68,7 @@ def _parse_dict(data: Any) -> Any:
     for k, v in data.items():
         # If the key is a Vertex or an Edge do special processing
         if isinstance(k, (gremlin.Vertex, gremlin.Edge)):
-            k = k.id
+            k = k.id # ruff: noqa: PLW2901

         # If the value is a list do special processing to make it a scalar if the list is of length 1
         if isinstance(v, list) and len(v) == 1:
15 changes: 7 additions & 8 deletions awswrangler/neptune/_neptune.py
@@ -511,16 +511,15 @@ def _set_properties(
         if column not in ["~id", "~label", "~to", "~from"]:
             if ignore_cardinality and pd.notna(value):
                 g = g.property(_get_column_name(column), value)
-            else:
+            elif use_header_cardinality:
                 # If the column header is specifying the cardinality then use it
-                if use_header_cardinality:
-                    if column.lower().find("(single)") > 0 and pd.notna(value):
-                        g = g.property(gremlin.Cardinality.single, _get_column_name(column), value)
-                    else:
-                        g = _expand_properties(g, _get_column_name(column), value)
+                if column.lower().find("(single)") > 0 and pd.notna(value):
+                    g = g.property(gremlin.Cardinality.single, _get_column_name(column), value)
                 else:
-                    # If not using header cardinality then use the default of set
-                    g = _expand_properties(g, column, value)
+                    g = _expand_properties(g, _get_column_name(column), value)
+            else:
+                # If not using header cardinality then use the default of set
+                g = _expand_properties(g, column, value)
     return g

7 changes: 3 additions & 4 deletions awswrangler/opensearch/_write.py
@@ -199,11 +199,10 @@ def create_index(
     if mappings:
         if _get_distribution(client) == "opensearch" or _get_version_major(client) >= 7:
             body["mappings"] = mappings # doc type deprecated
+        elif doc_type:
+            body["mappings"] = {doc_type: mappings}
         else:
-            if doc_type:
-                body["mappings"] = {doc_type: mappings}
-            else:
-                body["mappings"] = {index: mappings}
+            body["mappings"] = {index: mappings}
     if settings:
         body["settings"] = settings
     if not body:
2 changes: 1 addition & 1 deletion awswrangler/oracle.py
@@ -591,7 +591,7 @@ def to_sql(
             df=df, column_placeholders=column_placeholders, chunksize=chunksize
         )
         for _, parameters in placeholder_parameter_pair_generator:
-            parameters = list(zip(*[iter(parameters)] * len(df.columns)))
+            parameters = list(zip(*[iter(parameters)] * len(df.columns))) # ruff: noqa: PLW2901
             _logger.debug("sql: %s", sql)
             cursor.executemany(sql, parameters)

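The line carrying the noqa comment above relies on the zip-of-a-repeated-iterator idiom: repeating the same iterator object len(df.columns) times makes zip pull consecutive items from it, regrouping a flat parameter sequence into one tuple per row for executemany. A small standalone illustration with invented values:

    flat = ["a", 1, "b", 2, "c", 3]  # flattened parameters for 3 rows of 2 columns
    n_columns = 2
    rows = list(zip(*[iter(flat)] * n_columns))
    assert rows == [("a", 1), ("b", 2), ("c", 3)]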
2 changes: 1 addition & 1 deletion awswrangler/s3/_read.py
@@ -46,7 +46,7 @@ def _extract_partitions_metadata_from_paths(
         path_wo_filename: str = p.rpartition("/")[0] + "/"
         if path_wo_filename not in partitions_values:
             path_wo_prefix: str = path_wo_filename.replace(f"{path}", "")
-            dirs: Tuple[str, ...] = tuple(x for x in path_wo_prefix.split("/") if (x != "") and (x.count("=") > 0))
+            dirs: Tuple[str, ...] = tuple(x for x in path_wo_prefix.split("/") if x and (x.count("=") > 0))
             if dirs:
                 values_tups = cast(Tuple[Tuple[str, str]], tuple(tuple(x.split("=", maxsplit=1)[:2]) for x in dirs))
                 values_dics: Dict[str, str] = dict(values_tups)
4 changes: 2 additions & 2 deletions awswrangler/s3/_read_parquet.py
@@ -59,9 +59,9 @@ def _ensure_locations_are_valid(paths: Iterable[str]) -> Iterator[str]:
     for path in paths:
         suffix: str = path.rpartition("/")[2]
         # If the suffix looks like a partition,
-        if (suffix != "") and (suffix.count("=") == 1):
+        if suffix and (suffix.count("=") == 1):
             # the path should end in a '/' character.
-            path = f"{path}/"
+            path = f"{path}/" # ruff: noqa: PLW2901
         yield path

4 changes: 2 additions & 2 deletions awswrangler/s3/_write_dataset.py
@@ -149,8 +149,8 @@ def _to_partitions(
     s3_client = client(service_name="s3", session=boto3_session)
     for keys, subgroup in df.groupby(by=partition_cols, observed=True):
         # Keys are either a primitive type or a tuple if partitioning by multiple cols
-        keys = (keys,) if not isinstance(keys, tuple) else keys
-        subgroup = subgroup.drop(partition_cols, axis="columns")
+        keys = (keys,) if not isinstance(keys, tuple) else keys # ruff: noqa: PLW2901
+        subgroup = subgroup.drop(partition_cols, axis="columns") # ruff: noqa: PLW2901
         prefix = _delete_objects(
             keys=keys,
             path_root=path_root,
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -145,7 +145,7 @@ extend_exclude = '''

 [tool.ruff]
 select = ["D", "E", "F", "I001", "I002", "PL", "W"]
-ignore = ["E501", "PLR2004", "PLR0913", "PLR0915"]
+ignore = ["E501", "PLR2004", "PLR0911", "PLR0912", "PLR0913", "PLR0915"]
 fixable = ["I001"]
 exclude = [
     ".eggs",
2 changes: 1 addition & 1 deletion validate.sh
@@ -5,5 +5,5 @@ black --check .
 ruff . --ignore "PL" --ignore "D"
 ruff awswrangler
 mypy --install-types --non-interactive awswrangler
-pylint -j 0 --disable=all --enable=R0913,R0915 awswrangler
+pylint -j 0 --disable=all --enable=R0911,R0912,R0913,R0915 awswrangler
 doc8 --ignore-path docs/source/stubs --max-line-length 120 docs/source
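For reference, the lint codes touched by this change map to the following checks (a descriptive summary drawn from the ruff and pylint documentation, not part of the diff); the pylint R-codes enabled in validate.sh use the same numeric suffixes as their ruff PLR counterparts:

    LINT_CODES = {
        "PLR0911": "too-many-return-statements",  # pylint R0911
        "PLR0912": "too-many-branches",           # pylint R0912
        "PLR0913": "too-many-arguments",          # pylint R0913
        "PLR0915": "too-many-statements",         # pylint R0915
        "PLR2004": "magic-value-comparison",
        "PLW2901": "redefined-loop-name",
        "E501": "line-too-long",
    }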