Skip to content

Commit

Permalink
Refactor: Simplify code in smaller providers (#33234)
Browse files Browse the repository at this point in the history
  • Loading branch information
eumiro committed Aug 20, 2023
1 parent 4390524 commit a91ee7a
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 25 deletions.
2 changes: 1 addition & 1 deletion airflow/providers/common/sql/operators/sql.py
Expand Up @@ -783,7 +783,7 @@ def execute(self, context: Context):
self.log.info("Record: %s", records)
if not records:
self._raise_exception(f"The following query returned zero rows: {self.sql}")
elif not all(bool(r) for r in records):
elif not all(records):
self._raise_exception(f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}")

self.log.info("Success.")
Expand Down
10 changes: 2 additions & 8 deletions airflow/providers/databricks/utils/databricks.py
Expand Up @@ -36,19 +36,13 @@ def normalise_json_content(content, json_path: str = "json") -> str | bool | lis
normalise = normalise_json_content
if isinstance(content, (str, bool)):
return content
elif isinstance(
content,
(
int,
float,
),
):
elif isinstance(content, (int, float)):
# Databricks can tolerate either numeric or string types in the API backend.
return str(content)
elif isinstance(content, (list, tuple)):
return [normalise(e, f"{json_path}[{i}]") for i, e in enumerate(content)]
elif isinstance(content, dict):
return {k: normalise(v, f"{json_path}[{k}]") for k, v in list(content.items())}
return {k: normalise(v, f"{json_path}[{k}]") for k, v in content.items()}
else:
param_type = type(content)
msg = f"Type {param_type} used for parameter {json_path} is not a number or a string"
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/dbt/cloud/hooks/dbt.py
Expand Up @@ -291,7 +291,7 @@ def _paginate(self, endpoint: str, payload: dict[str, Any] | None = None) -> lis
_paginate_payload = payload.copy() if payload else {}
_paginate_payload["offset"] = limit

while not num_current_results >= num_total_results:
while num_current_results < num_total_results:
response = self.run(endpoint=endpoint, data=_paginate_payload)
resp_json = response.json()
results.append(response)
Expand Down
24 changes: 12 additions & 12 deletions airflow/providers/http/hooks/http.py
Expand Up @@ -340,10 +340,9 @@ async def run(
if headers:
_headers.update(headers)

if self.base_url and not self.base_url.endswith("/") and endpoint and not endpoint.startswith("/"):
url = self.base_url + "/" + endpoint
else:
url = (self.base_url or "") + (endpoint or "")
base_url = (self.base_url or "").rstrip("/")
endpoint = (endpoint or "").lstrip("/")
url = f"{base_url}/{endpoint}"

async with aiohttp.ClientSession() as session:
if self.method == "GET":
Expand All @@ -363,8 +362,7 @@ async def run(
else:
raise AirflowException(f"Unexpected HTTP Method: {self.method}")

attempt_num = 1
while True:
for attempt in range(1, 1 + self.retry_limit):
response = await request_func(
url,
json=data if self.method in ("POST", "PATCH") else None,
Expand All @@ -375,22 +373,24 @@ async def run(
)
try:
response.raise_for_status()
return response
except ClientResponseError as e:
self.log.warning(
"[Try %d of %d] Request to %s failed.",
attempt_num,
attempt,
self.retry_limit,
url,
)
if not self._retryable_error_async(e) or attempt_num == self.retry_limit:
if not self._retryable_error_async(e) or attempt == self.retry_limit:
self.log.exception("HTTP error with status: %s", e.status)
# In this case, the user probably made a mistake.
# Don't retry.
raise AirflowException(f"{e.status}:{e.message}")

attempt_num += 1
await asyncio.sleep(self.retry_delay)
else:
await asyncio.sleep(self.retry_delay)
else:
return response
else:
raise NotImplementedError # should not reach this, but makes mypy happy

def _retryable_error_async(self, exception: ClientResponseError) -> bool:
"""Determine whether an exception may successful on a subsequent attempt.
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/oracle/transfers/oracle_to_oracle.py
Expand Up @@ -73,7 +73,7 @@ def _execute(self, src_hook, dest_hook, context) -> None:

rows_total = 0
rows = cursor.fetchmany(self.rows_chunk)
while len(rows) > 0:
while rows:
rows_total += len(rows)
dest_hook.bulk_insert_rows(
self.destination_table, rows, target_fields=target_fields, commit_every=self.rows_chunk
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/sendgrid/utils/emailer.py
Expand Up @@ -94,8 +94,8 @@ def send_email(
# Add custom_args to personalization if present
pers_custom_args = kwargs.get("personalization_custom_args")
if isinstance(pers_custom_args, dict):
for key in pers_custom_args.keys():
personalization.add_custom_arg(CustomArg(key, pers_custom_args[key]))
for key, val in pers_custom_args.items():
personalization.add_custom_arg(CustomArg(key, val))

mail.add_personalization(personalization)
mail.add_content(Content("text/html", html_content))
Expand Down

0 comments on commit a91ee7a

Please sign in to comment.