Skip to content

Commit

Permalink
Refactor multiple equals to contains in providers (#34441)
Browse files Browse the repository at this point in the history
  • Loading branch information
eumiro committed Sep 26, 2023
1 parent 3b30b8f commit 8bea45f
Show file tree
Hide file tree
Showing 10 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/hooks/redshift_data.py
Expand Up @@ -120,7 +120,7 @@ def wait_for_results(self, statement_id, poll_interval):
if num_rows is not None:
self.log.info("Processed %s rows", num_rows)
return status
elif status == "FAILED" or status == "ABORTED":
elif status in ("FAILED", "ABORTED"):
raise ValueError(
f"Statement {statement_id!r} terminated with status {status}. "
f"Response details: {pformat(resp)}"
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/triggers/s3.py
Expand Up @@ -197,7 +197,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
allow_delete=self.allow_delete,
last_activity_time=self.last_activity_time,
)
if result.get("status") == "success" or result.get("status") == "error":
if result.get("status") in ("success", "error"):
yield TriggerEvent(result)
elif result.get("status") == "pending":
self.previous_objects = result.get("previous_objects", set())
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/hooks/kubernetes_engine.py
Expand Up @@ -126,7 +126,7 @@ def wait_for_operation(self, operation: Operation, project_id: str | None = None
self.log.info("Waiting for OPERATION_NAME %s", operation.name)
time.sleep(OPERATIONAL_POLL_INTERVAL)
while operation.status != Operation.Status.DONE:
if operation.status == Operation.Status.RUNNING or operation.status == Operation.Status.PENDING:
if operation.status in (Operation.Status.RUNNING, Operation.Status.PENDING):
time.sleep(OPERATIONAL_POLL_INTERVAL)
else:
raise exceptions.GoogleCloudError(f"Operation has failed with status: {operation.status}")
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/operators/bigquery_dts.py
Expand Up @@ -371,7 +371,7 @@ def _wait_for_transfer_to_be_done(self, run_id: str, transfer_config_id: str, in
state = transfer_run.state

if self._job_is_done(state):
if state == TransferState.FAILED or state == TransferState.CANCELLED:
if state in (TransferState.FAILED, TransferState.CANCELLED):
raise AirflowException(f"Transfer run was finished with {state} status.")

result = TransferRun.to_dict(transfer_run)
Expand All @@ -393,7 +393,7 @@ def _job_is_done(state: TransferState) -> bool:

def execute_completed(self, context: Context, event: dict):
"""Method to be executed after invoked trigger in defer method finishes its job."""
if event["status"] == "failed" or event["status"] == "cancelled":
if event["status"] in ("failed", "cancelled"):
self.log.error("Trigger finished its work with status: %s.", event["status"])
raise AirflowException(event["message"])

Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/operators/dataflow.py
Expand Up @@ -715,7 +715,7 @@ def set_current_job(current_job):

def execute_complete(self, context: Context, event: dict[str, Any]):
"""Method which executes after trigger finishes its work."""
if event["status"] == "error" or event["status"] == "stopped":
if event["status"] in ("error", "stopped"):
self.log.info("status: %s, msg: %s", event["status"], event["message"])
raise AirflowException(event["message"])

Expand Down Expand Up @@ -905,7 +905,7 @@ def _append_uuid_to_job_name(self):

def execute_complete(self, context: Context, event: dict):
"""Method which executes after trigger finishes its work."""
if event["status"] == "error" or event["status"] == "stopped":
if event["status"] in ("error", "stopped"):
self.log.info("status: %s, msg: %s", event["status"], event["message"])
raise AirflowException(event["message"])

Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/operators/dataproc.py
Expand Up @@ -1878,7 +1878,7 @@ def execute_complete(self, context, event=None) -> None:
This returns immediately. It relies on trigger to throw an exception,
otherwise it assumes execution was successful.
"""
if event["status"] == "failed" or event["status"] == "error":
if event["status"] in ("failed", "error"):
self.log.exception("Unexpected error in the operation.")
raise AirflowException(event["message"])

Expand Down Expand Up @@ -2009,7 +2009,7 @@ def execute_complete(self, context, event=None) -> None:
This returns immediately. It relies on trigger to throw an exception,
otherwise it assumes execution was successful.
"""
if event["status"] == "failed" or event["status"] == "error":
if event["status"] in ("failed", "error"):
self.log.exception("Unexpected error in the operation.")
raise AirflowException(event["message"])

Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/operators/kubernetes_engine.py
Expand Up @@ -164,7 +164,7 @@ def execute_complete(self, context: Context, event: dict) -> str:
status = event["status"]
message = event["message"]

if status == "failed" or status == "error":
if status in ("failed", "error"):
self.log.exception("Trigger ended with one of the failed statuses.")
raise AirflowException(message)

Expand Down Expand Up @@ -371,7 +371,7 @@ def execute_complete(self, context: Context, event: dict) -> str:
status = event["status"]
message = event["message"]

if status == "failed" or status == "error":
if status in ("failed", "error"):
self.log.exception("Trigger ended with one of the failed statuses.")
raise AirflowException(message)

Expand Down
Expand Up @@ -201,7 +201,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
}
)
return
elif status == Operation.Status.RUNNING or status == Operation.Status.PENDING:
elif status in (Operation.Status.RUNNING, Operation.Status.PENDING):
self.log.info("Operation is still running.")
self.log.info("Sleeping for %ss...", self.poll_interval)
await asyncio.sleep(self.poll_interval)
Expand Down
Expand Up @@ -477,7 +477,7 @@ def _validate_search_targets(self, targets, body):
matches = set()
for target in targets:
print(f"Loop over:::target = {target}")
if target == "_all" or target == "":
if target in ("_all", ""):
matches.update(self.__documents_dict)
elif "*" in target:
matches.update(fnmatch.filter(self.__documents_dict, target))
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/google/cloud/operators/test_bigquery.py
Expand Up @@ -2041,7 +2041,7 @@ def test_bigquery_value_check_empty(self):
)
with pytest.raises(AirflowException) as missing_param:
BigQueryValueCheckOperator(deferrable=True, kwargs={})
assert (missing_param.value.args[0] == expected) or (missing_param.value.args[0] == expected1)
assert missing_param.value.args[0] in (expected, expected1)

def test_bigquery_value_check_operator_execute_complete_success(self):
"""Tests response message in case of success event"""
Expand Down

0 comments on commit 8bea45f

Please sign in to comment.