Refactor: Simplify code in providers/google (#33229)

Co-authored-by: Elad Kalif <45845474+eladkal@users.noreply.github.com>
Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
3 people committed Sep 12, 2023
1 parent 3df1af4 commit 3fa9d46
Showing 8 changed files with 40 additions and 55 deletions.
7 changes: 1 addition & 6 deletions airflow/providers/google/cloud/hooks/bigquery.py
@@ -2874,12 +2874,7 @@ def fetchall(self) -> list[list]:
         A sequence of sequences (e.g. a list of tuples) is returned.
         """
-        result = []
-        while True:
-            one = self.fetchone()
-            if one is None:
-                break
-            result.append(one)
+        result = list(iter(self.fetchone, None))
         return result
 
     def get_arraysize(self) -> int:
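
The one-line replacement relies on the two-argument form of iter(callable, sentinel), which calls the zero-argument callable repeatedly and stops as soon as it returns the sentinel. A minimal standalone sketch (the FakeCursor class is illustrative, not part of the provider):

class FakeCursor:
    """Toy cursor whose fetchone() returns None when exhausted."""

    def __init__(self, rows):
        self._rows = iter(rows)

    def fetchone(self):
        return next(self._rows, None)

cursor = FakeCursor([(1, "a"), (2, "b")])
print(list(iter(cursor.fetchone, None)))  # [(1, 'a'), (2, 'b')]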
48 changes: 22 additions & 26 deletions airflow/providers/google/cloud/hooks/gcs.py
@@ -330,11 +330,16 @@ def download(
         # TODO: future improvement check file size before downloading,
         # to check for local space availability
 
-        num_file_attempts = 0
+        if num_max_attempts is None:
+            num_max_attempts = 3
 
-        while True:
+        for attempt in range(num_max_attempts):
+            if attempt:
+                # Wait with exponential backoff scheme before retrying.
+                timeout_seconds = 2**attempt
+                time.sleep(timeout_seconds)
+
             try:
-                num_file_attempts += 1
                 client = self.get_conn()
                 bucket = client.bucket(bucket_name, user_project=user_project)
                 blob = bucket.blob(blob_name=object_name, chunk_size=chunk_size)
@@ -347,19 +352,17 @@
                     return blob.download_as_bytes()
 
             except GoogleCloudError:
-                if num_file_attempts == num_max_attempts:
+                if attempt == num_max_attempts - 1:
                     self.log.error(
                         "Download attempt of object: %s from %s has failed. Attempt: %s, max %s.",
                         object_name,
                         bucket_name,
-                        num_file_attempts,
+                        attempt,
                         num_max_attempts,
                     )
                     raise
-
-                # Wait with exponential backoff scheme before retrying.
-                timeout_seconds = 2 ** (num_file_attempts - 1)
-                time.sleep(timeout_seconds)
+        else:
+            raise NotImplementedError  # should not reach this, but makes mypy happy
 
     def download_as_byte_array(
         self,
@@ -826,15 +829,10 @@ def _list(
                 versions=versions,
             )
 
-            blob_names = []
-            for blob in blobs:
-                blob_names.append(blob.name)
-
-            prefixes = blobs.prefixes
-            if prefixes:
-                ids += list(prefixes)
+            if blobs.prefixes:
+                ids.extend(blobs.prefixes)
             else:
-                ids += blob_names
+                ids.extend(blob.name for blob in blobs)
 
             page_token = blobs.next_page_token
             if page_token is None:
@@ -942,16 +940,14 @@ def list_by_timespan(
                 versions=versions,
             )
 
-            blob_names = []
-            for blob in blobs:
-                if timespan_start <= blob.updated.replace(tzinfo=timezone.utc) < timespan_end:
-                    blob_names.append(blob.name)
-
-            prefixes = blobs.prefixes
-            if prefixes:
-                ids += list(prefixes)
+            if blobs.prefixes:
+                ids.extend(blobs.prefixes)
             else:
-                ids += blob_names
+                ids.extend(
+                    blob.name
+                    for blob in blobs
+                    if timespan_start <= blob.updated.replace(tzinfo=timezone.utc) < timespan_end
+                )
 
             page_token = blobs.next_page_token
             if page_token is None:
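
The download() rewrite bounds the retries with a plain for loop: the exponential backoff now runs at the top of each retry (skipped on the first pass via the `if attempt:` check), the final failing attempt re-raises, and the for/else branch exists only to convince mypy that the function cannot fall through. A self-contained sketch of the same pattern, with a hypothetical flaky_call standing in for the GCS download:

import time

def call_with_retries(flaky_call, num_max_attempts: int = 3):
    # Bounded retry loop with exponential backoff before each retry.
    for attempt in range(num_max_attempts):
        if attempt:
            time.sleep(2**attempt)  # 2, 4, 8, ... seconds before each retry
        try:
            return flaky_call()
        except Exception:
            if attempt == num_max_attempts - 1:
                raise  # final attempt failed, propagate the error
    else:
        raise NotImplementedError  # unreachable in practice; mirrors the hint added for mypy

The two listing hunks feed list.extend() with generator expressions instead of building a temporary blob_names list first; extend() accepts any iterable. For example, with made-up blob records:

ids: list[str] = []
blobs = [{"name": "a.csv"}, {"name": "b.csv"}]
ids.extend(blob["name"] for blob in blobs)  # no intermediate list needed
print(ids)  # ['a.csv', 'b.csv']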
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/operators/bigquery.py
@@ -329,7 +329,7 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
         records = event["records"]
         if not records:
             raise AirflowException("The query returned empty results")
-        elif not all(bool(r) for r in records):
+        elif not all(records):
             self._raise_exception(  # type: ignore[attr-defined]
                 f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}"
             )
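
all() already applies standard truth testing to each element, so the explicit bool(r) calls were redundant. For example, with illustrative records:

records = [[1, "x"], [2, None], []]
print(all(bool(r) for r in records))  # False, the empty row is falsy
print(all(records))                   # False, identical result without bool()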
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/operators/cloud_build.py
@@ -202,7 +202,7 @@ def prepare_template(self) -> None:
         if not isinstance(self.build_raw, str):
             return
         with open(self.build_raw) as file:
-            if any(self.build_raw.endswith(ext) for ext in [".yaml", ".yml"]):
+            if self.build_raw.endswith((".yaml", ".yml")):
                 self.build = yaml.safe_load(file.read())
             if self.build_raw.endswith(".json"):
                 self.build = json.loads(file.read())
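
str.endswith() accepts a tuple of suffixes, so a single call covers both YAML extensions. For instance (the file name is illustrative):

build_raw = "cloudbuild.yml"
print(build_raw.endswith((".yaml", ".yml")))  # True
print(build_raw.endswith(".json"))            # False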
@@ -226,7 +226,5 @@ def _flush_rows(self, converted_rows: list[Any] | None, object_name: str):
 
     def _transform_object_name_with_account_id(self, account_id: str):
         directory_parts = self.object_name.split("/")
-        directory_parts[len(directory_parts) - 1] = (
-            account_id + "_" + directory_parts[len(directory_parts) - 1]
-        )
+        directory_parts[-1] = f"{account_id}_{directory_parts[-1]}"
        return "/".join(directory_parts)
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/transfers/gcs_to_gcs.py
@@ -207,7 +207,7 @@ def __init__(
                 stacklevel=2,
             )
             self.source_object = source_object
-        if source_objects and any([WILDCARD in obj for obj in source_objects]):
+        if source_objects and any(WILDCARD in obj for obj in source_objects):
             warnings.warn(
                 "Usage of wildcard (*) in 'source_objects' is deprecated, utilize 'match_glob' instead",
                 AirflowProviderDeprecationWarning,
@@ -429,7 +429,7 @@ def _copy_multiple_objects(self, hook, source_objects, prefix):
         # Check whether the prefix is a root directory for all the rest of objects.
         _pref = prefix.rstrip("/")
         is_directory = prefix.endswith("/") or all(
-            [obj.replace(_pref, "", 1).startswith("/") for obj in source_objects]
+            obj.replace(_pref, "", 1).startswith("/") for obj in source_objects
         )
 
         if is_directory:
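
Passing a bare generator expression to any()/all() lets them short-circuit on the first decisive element, whereas the bracketed list comprehension had to be built in full before the check. A small sketch (the sample object names are made up; WILDCARD is "*" in this module):

WILDCARD = "*"
source_objects = ["data/*.csv", "data/b.csv", "data/c.csv"]
# any() stops at the first element containing the wildcard.
print(any(WILDCARD in obj for obj in source_objects))  # True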
10 changes: 4 additions & 6 deletions airflow/providers/google/cloud/utils/bigquery_get_data.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import itertools
 from typing import TYPE_CHECKING
 
 from google.cloud.bigquery.table import Row, RowIterator
@@ -38,14 +39,13 @@ def bigquery_get_data(
     logger.info("Fetching Data from:")
     logger.info("Dataset: %s ; Table: %s", dataset_id, table_id)
 
-    i = 0
-    while True:
+    for start_index in itertools.count(step=batch_size):
         rows: list[Row] | RowIterator = big_query_hook.list_rows(
             dataset_id=dataset_id,
             table_id=table_id,
             max_results=batch_size,
             selected_fields=selected_fields,
-            start_index=i * batch_size,
+            start_index=start_index,
         )
 
         if isinstance(rows, RowIterator):
@@ -55,8 +55,6 @@
             logger.info("Job Finished")
             return
 
-        logger.info("Total Extracted rows: %s", len(rows) + i * batch_size)
+        logger.info("Total Extracted rows: %s", len(rows) + start_index)
 
         yield [row.values() for row in rows]
-
-        i += 1
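
itertools.count(step=batch_size) yields 0, batch_size, 2 * batch_size, ... indefinitely, which removes the hand-maintained i counter and the i * batch_size multiplication. For example, with an illustrative batch size:

import itertools

batch_size = 1000
for start_index in itertools.count(step=batch_size):
    print(start_index)       # 0, 1000, 2000, 3000
    if start_index >= 3000:  # illustrative stop; the real loop returns once a page comes back empty
        break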
18 changes: 8 additions & 10 deletions airflow/providers/google/cloud/utils/field_validator.py
@@ -257,7 +257,7 @@ def _validate_dict(self, children_validation_specs: dict, full_field_path: str,
             self._validate_field(
                 validation_spec=child_validation_spec, dictionary_to_validate=value, parent=full_field_path
             )
-        all_dict_keys = [spec["name"] for spec in children_validation_specs]
+        all_dict_keys = {spec["name"] for spec in children_validation_specs}
         for field_name in value.keys():
             if field_name not in all_dict_keys:
                 self.log.warning(
@@ -428,20 +428,18 @@ def validate(self, body_to_validate: dict) -> None:
             raise GcpFieldValidationException(
                 f"There was an error when validating: body '{body_to_validate}': '{e}'"
             )
-        all_field_names = [
+        all_field_names = {
             spec["name"]
             for spec in self._validation_specs
             if spec.get("type") != "union" and spec.get("api_version") != self._api_version
-        ]
+        }
         all_union_fields = [spec for spec in self._validation_specs if spec.get("type") == "union"]
         for union_field in all_union_fields:
-            all_field_names.extend(
-                [
-                    nested_union_spec["name"]
-                    for nested_union_spec in union_field["fields"]
-                    if nested_union_spec.get("type") != "union"
-                    and nested_union_spec.get("api_version") != self._api_version
-                ]
+            all_field_names.update(
+                nested_union_spec["name"]
+                for nested_union_spec in union_field["fields"]
+                if nested_union_spec.get("type") != "union"
+                and nested_union_spec.get("api_version") != self._api_version
             )
         for field_name in body_to_validate.keys():
             if field_name not in all_field_names:
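
Collecting the known field names in a set (a set comprehension plus set.update(), which accepts any iterable) makes the later "field_name not in all_field_names" checks average O(1) lookups instead of list scans. A minimal sketch with made-up specs:

specs = [{"name": "location"}, {"name": "body"}]
union_specs = [{"name": "source"}, {"name": "destination"}]

all_field_names = {spec["name"] for spec in specs}
all_field_names.update(spec["name"] for spec in union_specs)

print("body" in all_field_names)     # True
print("unknown" in all_field_names)  # False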
