From a91ee7ac2fe29f460a4e4b0d8c1346f40672be43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miroslav=20=C5=A0ediv=C3=BD?= <6774676+eumiro@users.noreply.github.com> Date: Sun, 20 Aug 2023 18:20:24 +0000 Subject: [PATCH] Refactor: Simplify code in smaller providers (#33234) --- airflow/providers/common/sql/operators/sql.py | 2 +- .../providers/databricks/utils/databricks.py | 10 ++------ airflow/providers/dbt/cloud/hooks/dbt.py | 2 +- airflow/providers/http/hooks/http.py | 24 +++++++++---------- .../oracle/transfers/oracle_to_oracle.py | 2 +- airflow/providers/sendgrid/utils/emailer.py | 4 ++-- 6 files changed, 19 insertions(+), 25 deletions(-) diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py index 709c831b5aa4c..cf42fa4f57363 100644 --- a/airflow/providers/common/sql/operators/sql.py +++ b/airflow/providers/common/sql/operators/sql.py @@ -783,7 +783,7 @@ def execute(self, context: Context): self.log.info("Record: %s", records) if not records: self._raise_exception(f"The following query returned zero rows: {self.sql}") - elif not all(bool(r) for r in records): + elif not all(records): self._raise_exception(f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}") self.log.info("Success.") diff --git a/airflow/providers/databricks/utils/databricks.py b/airflow/providers/databricks/utils/databricks.py index 050eba9096943..cac38b8fd9a91 100644 --- a/airflow/providers/databricks/utils/databricks.py +++ b/airflow/providers/databricks/utils/databricks.py @@ -36,19 +36,13 @@ def normalise_json_content(content, json_path: str = "json") -> str | bool | lis normalise = normalise_json_content if isinstance(content, (str, bool)): return content - elif isinstance( - content, - ( - int, - float, - ), - ): + elif isinstance(content, (int, float)): # Databricks can tolerate either numeric or string types in the API backend. return str(content) elif isinstance(content, (list, tuple)): return [normalise(e, f"{json_path}[{i}]") for i, e in enumerate(content)] elif isinstance(content, dict): - return {k: normalise(v, f"{json_path}[{k}]") for k, v in list(content.items())} + return {k: normalise(v, f"{json_path}[{k}]") for k, v in content.items()} else: param_type = type(content) msg = f"Type {param_type} used for parameter {json_path} is not a number or a string" diff --git a/airflow/providers/dbt/cloud/hooks/dbt.py b/airflow/providers/dbt/cloud/hooks/dbt.py index fdbffc4b3444f..8c589970ab620 100644 --- a/airflow/providers/dbt/cloud/hooks/dbt.py +++ b/airflow/providers/dbt/cloud/hooks/dbt.py @@ -291,7 +291,7 @@ def _paginate(self, endpoint: str, payload: dict[str, Any] | None = None) -> lis _paginate_payload = payload.copy() if payload else {} _paginate_payload["offset"] = limit - while not num_current_results >= num_total_results: + while num_current_results < num_total_results: response = self.run(endpoint=endpoint, data=_paginate_payload) resp_json = response.json() results.append(response) diff --git a/airflow/providers/http/hooks/http.py b/airflow/providers/http/hooks/http.py index 0b0443efc484d..6d5884792227b 100644 --- a/airflow/providers/http/hooks/http.py +++ b/airflow/providers/http/hooks/http.py @@ -340,10 +340,9 @@ async def run( if headers: _headers.update(headers) - if self.base_url and not self.base_url.endswith("/") and endpoint and not endpoint.startswith("/"): - url = self.base_url + "/" + endpoint - else: - url = (self.base_url or "") + (endpoint or "") + base_url = (self.base_url or "").rstrip("/") + endpoint = (endpoint or "").lstrip("/") + url = f"{base_url}/{endpoint}" async with aiohttp.ClientSession() as session: if self.method == "GET": @@ -363,8 +362,7 @@ async def run( else: raise AirflowException(f"Unexpected HTTP Method: {self.method}") - attempt_num = 1 - while True: + for attempt in range(1, 1 + self.retry_limit): response = await request_func( url, json=data if self.method in ("POST", "PATCH") else None, @@ -375,22 +373,24 @@ async def run( ) try: response.raise_for_status() - return response except ClientResponseError as e: self.log.warning( "[Try %d of %d] Request to %s failed.", - attempt_num, + attempt, self.retry_limit, url, ) - if not self._retryable_error_async(e) or attempt_num == self.retry_limit: + if not self._retryable_error_async(e) or attempt == self.retry_limit: self.log.exception("HTTP error with status: %s", e.status) # In this case, the user probably made a mistake. # Don't retry. raise AirflowException(f"{e.status}:{e.message}") - - attempt_num += 1 - await asyncio.sleep(self.retry_delay) + else: + await asyncio.sleep(self.retry_delay) + else: + return response + else: + raise NotImplementedError # should not reach this, but makes mypy happy def _retryable_error_async(self, exception: ClientResponseError) -> bool: """Determine whether an exception may successful on a subsequent attempt. diff --git a/airflow/providers/oracle/transfers/oracle_to_oracle.py b/airflow/providers/oracle/transfers/oracle_to_oracle.py index 9e54a6d756626..6762a5742f10d 100644 --- a/airflow/providers/oracle/transfers/oracle_to_oracle.py +++ b/airflow/providers/oracle/transfers/oracle_to_oracle.py @@ -73,7 +73,7 @@ def _execute(self, src_hook, dest_hook, context) -> None: rows_total = 0 rows = cursor.fetchmany(self.rows_chunk) - while len(rows) > 0: + while rows: rows_total += len(rows) dest_hook.bulk_insert_rows( self.destination_table, rows, target_fields=target_fields, commit_every=self.rows_chunk diff --git a/airflow/providers/sendgrid/utils/emailer.py b/airflow/providers/sendgrid/utils/emailer.py index 818e24f2bd392..66856738b6717 100644 --- a/airflow/providers/sendgrid/utils/emailer.py +++ b/airflow/providers/sendgrid/utils/emailer.py @@ -94,8 +94,8 @@ def send_email( # Add custom_args to personalization if present pers_custom_args = kwargs.get("personalization_custom_args") if isinstance(pers_custom_args, dict): - for key in pers_custom_args.keys(): - personalization.add_custom_arg(CustomArg(key, pers_custom_args[key])) + for key, val in pers_custom_args.items(): + personalization.add_custom_arg(CustomArg(key, val)) mail.add_personalization(personalization) mail.add_content(Content("text/html", html_content))