From 3fa9d46ec74ef8453fcf17fbd49280cb6fb37cef Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Miroslav=20=C5=A0ediv=C3=BD?= <6774676+eumiro@users.noreply.github.com>
Date: Tue, 12 Sep 2023 21:24:17 +0000
Subject: [PATCH] Refactor: Simplify code in providers/google (#33229)

Co-authored-by: Elad Kalif <45845474+eladkal@users.noreply.github.com>
Co-authored-by: Tzu-ping Chung
---
 .../providers/google/cloud/hooks/bigquery.py  |  7 +--
 airflow/providers/google/cloud/hooks/gcs.py   | 48 +++++++++----------
 .../google/cloud/operators/bigquery.py        |  2 +-
 .../google/cloud/operators/cloud_build.py     |  2 +-
 .../cloud/transfers/facebook_ads_to_gcs.py    |  4 +-
 .../google/cloud/transfers/gcs_to_gcs.py      |  4 +-
 .../google/cloud/utils/bigquery_get_data.py   | 10 ++--
 .../google/cloud/utils/field_validator.py     | 18 ++++---
 8 files changed, 40 insertions(+), 55 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py
index c98e7e92ab22e..be486c6c8d2bb 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -2874,12 +2874,7 @@ def fetchall(self) -> list[list]:
 
         A sequence of sequences (e.g. a list of tuples) is returned.
         """
-        result = []
-        while True:
-            one = self.fetchone()
-            if one is None:
-                break
-            result.append(one)
+        result = list(iter(self.fetchone, None))
         return result
 
     def get_arraysize(self) -> int:
diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py
index acc4ad688af09..d8bb36037fa10 100644
--- a/airflow/providers/google/cloud/hooks/gcs.py
+++ b/airflow/providers/google/cloud/hooks/gcs.py
@@ -330,11 +330,16 @@ def download(
         # TODO: future improvement check file size before downloading,
         # to check for local space availability
 
-        num_file_attempts = 0
+        if num_max_attempts is None:
+            num_max_attempts = 3
+
+        for attempt in range(num_max_attempts):
+            if attempt:
+                # Wait with exponential backoff scheme before retrying.
+                timeout_seconds = 2**attempt
+                time.sleep(timeout_seconds)
 
-        while True:
             try:
-                num_file_attempts += 1
                 client = self.get_conn()
                 bucket = client.bucket(bucket_name, user_project=user_project)
                 blob = bucket.blob(blob_name=object_name, chunk_size=chunk_size)
@@ -347,19 +352,17 @@ def download(
                     return blob.download_as_bytes()
 
             except GoogleCloudError:
-                if num_file_attempts == num_max_attempts:
+                if attempt == num_max_attempts - 1:
                     self.log.error(
                         "Download attempt of object: %s from %s has failed. Attempt: %s, max %s.",
                         object_name,
                         bucket_name,
-                        num_file_attempts,
+                        attempt,
                         num_max_attempts,
                     )
                     raise
-
-                # Wait with exponential backoff scheme before retrying.
-                timeout_seconds = 2 ** (num_file_attempts - 1)
-                time.sleep(timeout_seconds)
+        else:
+            raise NotImplementedError  # should not reach this, but makes mypy happy
 
     def download_as_byte_array(
         self,
@@ -826,15 +829,10 @@ def _list(
                 versions=versions,
             )
 
-            blob_names = []
-            for blob in blobs:
-                blob_names.append(blob.name)
-
-            prefixes = blobs.prefixes
-            if prefixes:
-                ids += list(prefixes)
+            if blobs.prefixes:
+                ids.extend(blobs.prefixes)
             else:
-                ids += blob_names
+                ids.extend(blob.name for blob in blobs)
 
             page_token = blobs.next_page_token
             if page_token is None:
@@ -942,16 +940,14 @@ def list_by_timespan(
                 versions=versions,
             )
 
-            blob_names = []
-            for blob in blobs:
-                if timespan_start <= blob.updated.replace(tzinfo=timezone.utc) < timespan_end:
-                    blob_names.append(blob.name)
-
-            prefixes = blobs.prefixes
-            if prefixes:
-                ids += list(prefixes)
+            if blobs.prefixes:
+                ids.extend(blobs.prefixes)
             else:
-                ids += blob_names
+                ids.extend(
+                    blob.name
+                    for blob in blobs
+                    if timespan_start <= blob.updated.replace(tzinfo=timezone.utc) < timespan_end
+                )
 
             page_token = blobs.next_page_token
             if page_token is None:
diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index 4a17c41ba2cc8..1ed333d9dc4b6 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -329,7 +329,7 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
         records = event["records"]
         if not records:
             raise AirflowException("The query returned empty results")
-        elif not all(bool(r) for r in records):
+        elif not all(records):
             self._raise_exception(  # type: ignore[attr-defined]
                 f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}"
             )
diff --git a/airflow/providers/google/cloud/operators/cloud_build.py b/airflow/providers/google/cloud/operators/cloud_build.py
index 9daacefa72139..690a64f349cb5 100644
--- a/airflow/providers/google/cloud/operators/cloud_build.py
+++ b/airflow/providers/google/cloud/operators/cloud_build.py
@@ -202,7 +202,7 @@ def prepare_template(self) -> None:
         if not isinstance(self.build_raw, str):
             return
         with open(self.build_raw) as file:
-            if any(self.build_raw.endswith(ext) for ext in [".yaml", ".yml"]):
+            if self.build_raw.endswith((".yaml", ".yml")):
                 self.build = yaml.safe_load(file.read())
             if self.build_raw.endswith(".json"):
                 self.build = json.loads(file.read())
diff --git a/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py b/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py
index bc0dae153a2dc..758bd818ca88d 100644
--- a/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py
@@ -226,7 +226,5 @@ def _flush_rows(self, converted_rows: list[Any] | None, object_name: str):
 
     def _transform_object_name_with_account_id(self, account_id: str):
         directory_parts = self.object_name.split("/")
-        directory_parts[len(directory_parts) - 1] = (
-            account_id + "_" + directory_parts[len(directory_parts) - 1]
-        )
+        directory_parts[-1] = f"{account_id}_{directory_parts[-1]}"
         return "/".join(directory_parts)
diff --git a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
index b262a75e6afc6..ad7a41c23d61a 100644
--- a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
@@ -207,7 +207,7 @@ def __init__(
                 stacklevel=2,
             )
             self.source_object = source_object
-        if source_objects and any([WILDCARD in obj for obj in source_objects]):
+        if source_objects and any(WILDCARD in obj for obj in source_objects):
             warnings.warn(
                 "Usage of wildcard (*) in 'source_objects' is deprecated, utilize 'match_glob' instead",
                 AirflowProviderDeprecationWarning,
@@ -429,7 +429,7 @@ def _copy_multiple_objects(self, hook, source_objects, prefix):
         # Check whether the prefix is a root directory for all the rest of objects.
         _pref = prefix.rstrip("/")
         is_directory = prefix.endswith("/") or all(
-            [obj.replace(_pref, "", 1).startswith("/") for obj in source_objects]
+            obj.replace(_pref, "", 1).startswith("/") for obj in source_objects
         )
 
         if is_directory:
diff --git a/airflow/providers/google/cloud/utils/bigquery_get_data.py b/airflow/providers/google/cloud/utils/bigquery_get_data.py
index 8fb61fc52ccc5..d178aee963f12 100644
--- a/airflow/providers/google/cloud/utils/bigquery_get_data.py
+++ b/airflow/providers/google/cloud/utils/bigquery_get_data.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import itertools
 from typing import TYPE_CHECKING
 
 from google.cloud.bigquery.table import Row, RowIterator
@@ -38,14 +39,13 @@ def bigquery_get_data(
     logger.info("Fetching Data from:")
     logger.info("Dataset: %s ; Table: %s", dataset_id, table_id)
 
-    i = 0
-    while True:
+    for start_index in itertools.count(step=batch_size):
         rows: list[Row] | RowIterator = big_query_hook.list_rows(
             dataset_id=dataset_id,
             table_id=table_id,
             max_results=batch_size,
             selected_fields=selected_fields,
-            start_index=i * batch_size,
+            start_index=start_index,
         )
 
         if isinstance(rows, RowIterator):
@@ -55,8 +55,6 @@ def bigquery_get_data(
             logger.info("Job Finished")
             return
 
-        logger.info("Total Extracted rows: %s", len(rows) + i * batch_size)
+        logger.info("Total Extracted rows: %s", len(rows) + start_index)
 
         yield [row.values() for row in rows]
-
-        i += 1
diff --git a/airflow/providers/google/cloud/utils/field_validator.py b/airflow/providers/google/cloud/utils/field_validator.py
index 87aee5d7af027..415351c69c21f 100644
--- a/airflow/providers/google/cloud/utils/field_validator.py
+++ b/airflow/providers/google/cloud/utils/field_validator.py
@@ -257,7 +257,7 @@ def _validate_dict(self, children_validation_specs: dict, full_field_path: str,
                 self._validate_field(
                     validation_spec=child_validation_spec, dictionary_to_validate=value, parent=full_field_path
                 )
-        all_dict_keys = [spec["name"] for spec in children_validation_specs]
+        all_dict_keys = {spec["name"] for spec in children_validation_specs}
         for field_name in value.keys():
             if field_name not in all_dict_keys:
                 self.log.warning(
@@ -428,20 +428,18 @@ def validate(self, body_to_validate: dict) -> None:
             raise GcpFieldValidationException(
                 f"There was an error when validating: body '{body_to_validate}': '{e}'"
             )
-        all_field_names = [
+        all_field_names = {
             spec["name"]
             for spec in self._validation_specs
             if spec.get("type") != "union" and spec.get("api_version") != self._api_version
-        ]
+        }
         all_union_fields = [spec for spec in self._validation_specs if spec.get("type") == "union"]
         for union_field in all_union_fields:
-            all_field_names.extend(
-                [
-                    nested_union_spec["name"]
-                    for nested_union_spec in union_field["fields"]
-                    if nested_union_spec.get("type") != "union"
-                    and nested_union_spec.get("api_version") != self._api_version
-                ]
+            all_field_names.update(
+                nested_union_spec["name"]
+                for nested_union_spec in union_field["fields"]
+                if nested_union_spec.get("type") != "union"
+                and nested_union_spec.get("api_version") != self._api_version
             )
         for field_name in body_to_validate.keys():
             if field_name not in all_field_names:
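
Reviewer note, not part of the patch: the fetchall() and bigquery_get_data() hunks lean on two standard-library idioms, the two-argument form of iter() and itertools.count() with a step. A minimal standalone sketch; the buffer list and batch_size value below are invented purely for illustration:

import itertools

# iter(callable, sentinel) keeps calling the zero-argument callable until it
# returns the sentinel; everything before the sentinel is yielded in order.
buffer = [1, 2, 3, None, 4]
result = list(iter(lambda: buffer.pop(0), None))
assert result == [1, 2, 3]

# itertools.count(step=batch_size) yields 0, batch_size, 2 * batch_size, ...
# which is what replaces the manual `i * batch_size` bookkeeping.
batch_size = 100
offsets = []
for start_index in itertools.count(step=batch_size):
    if start_index >= 300:  # stand-in for "the hook returned no more rows"
        break
    offsets.append(start_index)
assert offsets == [0, 100, 200]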
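
Reviewer note, not part of the patch: the GCSHook.download() hunk replaces an unbounded `while True` retry loop with a bounded `for` loop whose `else` branch is unreachable by construction. A rough sketch of the same shape, with a hypothetical flaky_call() standing in for the real GCS request; the real method returns from inside the `try`, which is why its `else` exists only to satisfy mypy:

import random
import time


def flaky_call() -> str:
    # Hypothetical stand-in for the blob download call; fails half the time.
    if random.random() < 0.5:
        raise RuntimeError("transient error")
    return "payload"


num_max_attempts = 3
for attempt in range(num_max_attempts):
    if attempt:
        # Back off only before retries (2s, then 4s), never before the first try.
        time.sleep(2**attempt)
    try:
        print(flaky_call())
        break
    except RuntimeError:
        if attempt == num_max_attempts - 1:
            raise  # give up after the last attempt
else:
    # Reached only if the loop finishes without break or raise, which cannot
    # happen here; the patch raises NotImplementedError at this point.
    raise NotImplementedError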
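
Reviewer note, not part of the patch: the remaining hunks are behaviour-preserving swaps to lighter built-in forms. The toy data below (records, path, source_objects, specs) is invented purely to show the equivalences:

records = [1, "x", 3.5]
assert all(records) == all(bool(r) for r in records)  # truthiness is implicit in all()

path = "cloudbuild.yaml"
assert path.endswith((".yaml", ".yml"))  # str.endswith accepts a tuple of suffixes

source_objects = ["data/a.csv", "data/*.csv"]
assert any("*" in obj for obj in source_objects)  # generator, no throwaway list

account_id = "123"
directory_parts = ["bucket", "prefix", "file.csv"]
directory_parts[-1] = f"{account_id}_{directory_parts[-1]}"  # negative index + f-string
assert directory_parts == ["bucket", "prefix", "123_file.csv"]

specs = [{"name": "a"}, {"name": "b"}]
all_field_names = {spec["name"] for spec in specs}  # set comprehension: O(1) membership
assert "a" in all_field_names and "c" not in all_field_names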