Add a new param for BigQuery operators to support additional actions when resource exists (#29394)

* Add a new param to support additional actions when resource exists and deprecate the old one
---------

Co-authored-by: eladkal <45845474+eladkal@users.noreply.github.com>
hussein-awala and eladkal committed Feb 26, 2023
1 parent 228d79c commit a5adb87
Showing 2 changed files with 140 additions and 32 deletions.
101 changes: 70 additions & 31 deletions airflow/providers/google/cloud/operators/bigquery.py
@@ -28,7 +28,7 @@
 from google.api_core.retry import Retry
 from google.cloud.bigquery import DEFAULT_RETRY, CopyJob, ExtractJob, LoadJob, QueryJob
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.models import BaseOperator, BaseOperatorLink
 from airflow.models.xcom import XCom
 from airflow.providers.common.sql.operators.sql import (
@@ -68,6 +68,15 @@ class BigQueryUIColors(enum.Enum):
     DATASET = "#5F86FF"
 
 
+class IfExistAction(enum.Enum):
+    """Action to take if the resource exists."""
+
+    IGNORE = "ignore"
+    LOG = "log"
+    FAIL = "fail"
+    SKIP = "skip"
+
+
 class BigQueryConsoleLink(BaseOperatorLink):
     """Helper class for constructing BigQuery link."""

@@ -248,7 +257,9 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
         if not records:
             raise AirflowException("The query returned empty results")
         elif not all(bool(r) for r in records):
-            self._raise_exception(f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}")
+            self._raise_exception(  # type: ignore[attr-defined]
+                f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}"
+            )
         self.log.info("Record: %s", event["records"])
         self.log.info("Success.")
 
@@ -773,9 +784,6 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
     :param selected_fields: List of fields to return (comma-separated). If
         unspecified, all fields are returned.
     :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
-    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
-        if any. For this to work, the service account making the request must have
-        domain-wide delegation enabled. Deprecated.
     :param location: The location used for the operation.
     :param impersonation_chain: Optional service account to impersonate using short-term
         credentials, or chained list of accounts required to get the access_token
@@ -786,6 +794,9 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
     :param deferrable: Run operator in the deferrable mode
+    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+        if any. For this to work, the service account making the request must have
+        domain-wide delegation enabled. Deprecated.
     """
 
     template_fields: Sequence[str] = (
@@ -807,10 +818,10 @@ def __init__(
         max_results: int = 100,
         selected_fields: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         location: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
         deferrable: bool = False,
+        delegate_to: str | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -1253,7 +1264,10 @@ class BigQueryCreateEmptyTableOperator(GoogleCloudBaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
-    :param exists_ok: If ``True``, ignore "already exists" errors when creating the table.
+    :param if_exists: What should Airflow do if the table exists. If set to `log`, the TI will be
+        marked as success and an error message will be logged. Set to `ignore` to ignore the error,
+        set to `fail` to fail the TI, and set to `skip` to skip it.
+    :param exists_ok: Deprecated - use `if_exists="ignore"` instead.
     """
 
     template_fields: Sequence[str] = (
@@ -1282,17 +1296,18 @@ def __init__(
         gcs_schema_object: str | None = None,
         time_partitioning: dict | None = None,
         gcp_conn_id: str = "google_cloud_default",
-        bigquery_conn_id: str | None = None,
         google_cloud_storage_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         labels: dict | None = None,
         view: dict | None = None,
         materialized_view: dict | None = None,
         encryption_configuration: dict | None = None,
         location: str | None = None,
         cluster_fields: list[str] | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
-        exists_ok: bool = False,
+        if_exists: str = "log",
+        delegate_to: str | None = None,
+        bigquery_conn_id: str | None = None,
+        exists_ok: bool | None = None,
         **kwargs,
     ) -> None:
         if bigquery_conn_id:
@@ -1326,7 +1341,11 @@ def __init__(
         self.cluster_fields = cluster_fields
         self.table_resource = table_resource
         self.impersonation_chain = impersonation_chain
-        self.exists_ok = exists_ok
+        if exists_ok is not None:
+            warnings.warn("`exists_ok` parameter is deprecated, please use `if_exists`", DeprecationWarning)
+            self.if_exists = IfExistAction.IGNORE if exists_ok else IfExistAction.LOG
+        else:
+            self.if_exists = IfExistAction(if_exists)
 
     def execute(self, context: Context) -> None:
         bq_hook = BigQueryHook(
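A small standalone sketch (plain Python, assuming `IfExistAction` as defined above; `resolve_if_exists` is a hypothetical helper, not in the commit) of the boolean-to-enum mapping this shim performs; the deprecated flag, when given, takes precedence:

import warnings

def resolve_if_exists(if_exists: str = "log", exists_ok: bool | None = None) -> IfExistAction:
    # Mirrors the constructor logic: honor the deprecated flag if supplied, else parse the string.
    if exists_ok is not None:
        warnings.warn("`exists_ok` parameter is deprecated, please use `if_exists`", DeprecationWarning)
        return IfExistAction.IGNORE if exists_ok else IfExistAction.LOG
    return IfExistAction(if_exists)

assert resolve_if_exists(exists_ok=True) is IfExistAction.IGNORE   # old "ignore" behavior
assert resolve_if_exists(exists_ok=False) is IfExistAction.LOG     # old default
assert resolve_if_exists("fail") is IfExistAction.FAIL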
@@ -1362,7 +1381,7 @@ def execute(self, context: Context) -> None:
                 materialized_view=self.materialized_view,
                 encryption_configuration=self.encryption_configuration,
                 table_resource=self.table_resource,
-                exists_ok=self.exists_ok,
+                exists_ok=self.if_exists == IfExistAction.IGNORE,
             )
             BigQueryTableLink.persist(
                 context=context,
@@ -1375,7 +1394,13 @@ def execute(self, context: Context) -> None:
                 "Table %s.%s.%s created successfully", table.project, table.dataset_id, table.table_id
             )
         except Conflict:
-            self.log.info("Table %s.%s already exists.", self.dataset_id, self.table_id)
+            error_msg = f"Table {self.dataset_id}.{self.table_id} already exists."
+            if self.if_exists == IfExistAction.LOG:
+                self.log.info(error_msg)
+            elif self.if_exists == IfExistAction.FAIL:
+                raise AirflowException(error_msg)
+            else:
+                raise AirflowSkipException(error_msg)
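A hypothetical DAG snippet (task and resource names invented, not from this commit) showing the new parameter in use; with `if_exists="skip"` a Conflict raises AirflowSkipException, marking the task instance as skipped:

create_table = BigQueryCreateEmptyTableOperator(
    task_id="create_table_if_absent",  # hypothetical task id
    project_id="my-project",           # hypothetical project
    dataset_id="my_dataset",
    table_id="my_table",
    if_exists="skip",                  # one of "ignore" | "log" (default) | "fail" | "skip"
)

Note that `ignore` never reaches the `except Conflict` branch at all: it is forwarded to the hook as `exists_ok=True`, so `create_empty_table` swallows the conflict upstream.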


class BigQueryCreateExternalTableOperator(GoogleCloudBaseOperator):
Expand Down Expand Up @@ -1490,14 +1515,14 @@ def __init__(
allow_quoted_newlines: bool = False,
allow_jagged_rows: bool = False,
gcp_conn_id: str = "google_cloud_default",
bigquery_conn_id: str | None = None,
google_cloud_storage_conn_id: str = "google_cloud_default",
delegate_to: str | None = None,
src_fmt_configs: dict | None = None,
labels: dict | None = None,
encryption_configuration: dict | None = None,
location: str | None = None,
impersonation_chain: str | Sequence[str] | None = None,
delegate_to: str | None = None,
bigquery_conn_id: str | None = None,
**kwargs,
) -> None:
if bigquery_conn_id:
@@ -1721,8 +1746,8 @@ def __init__(
         project_id: str | None = None,
         delete_contents: bool = False,
         gcp_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
+        delegate_to: str | None = None,
         **kwargs,
     ) -> None:
         self.dataset_id = dataset_id
@@ -1779,7 +1804,9 @@ class BigQueryCreateEmptyDatasetOperator(GoogleCloudBaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
-    :param exists_ok: If ``True``, ignore "already exists" errors when creating the dataset.
+    :param if_exists: What should Airflow do if the dataset exists. If set to `log`, the TI will be
+        marked as success and an error message will be logged. Set to `ignore` to ignore the error,
+        set to `fail` to fail the TI, and set to `skip` to skip it.
     **Example**: ::
 
         create_new_dataset = BigQueryCreateEmptyDatasetOperator(
@@ -1789,6 +1816,7 @@ class BigQueryCreateEmptyDatasetOperator(GoogleCloudBaseOperator):
             gcp_conn_id='_my_gcp_conn_',
             task_id='newDatasetCreator',
             dag=dag)
+    :param exists_ok: Deprecated - use `if_exists="ignore"` instead.
     """
 
     template_fields: Sequence[str] = (
@@ -1809,9 +1837,10 @@ def __init__(
         dataset_reference: dict | None = None,
         location: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
-        exists_ok: bool = False,
+        if_exists: str = "log",
+        delegate_to: str | None = None,
+        exists_ok: bool | None = None,
         **kwargs,
     ) -> None:
 
@@ -1826,7 +1855,11 @@ def __init__(
             )
         self.delegate_to = delegate_to
         self.impersonation_chain = impersonation_chain
-        self.exists_ok = exists_ok
+        if exists_ok is not None:
+            warnings.warn("`exists_ok` parameter is deprecated, please use `if_exists`", DeprecationWarning)
+            self.if_exists = IfExistAction.IGNORE if exists_ok else IfExistAction.LOG
+        else:
+            self.if_exists = IfExistAction(if_exists)
 
         super().__init__(**kwargs)
 
@@ -1844,7 +1877,7 @@ def execute(self, context: Context) -> None:
                 dataset_id=self.dataset_id,
                 dataset_reference=self.dataset_reference,
                 location=self.location,
-                exists_ok=self.exists_ok,
+                exists_ok=self.if_exists == IfExistAction.IGNORE,
             )
             BigQueryDatasetLink.persist(
                 context=context,
@@ -1854,7 +1887,13 @@ def execute(self, context: Context) -> None:
             )
         except Conflict:
             dataset_id = self.dataset_reference.get("datasetReference", {}).get("datasetId", self.dataset_id)
-            self.log.info("Dataset %s already exists.", dataset_id)
+            error_msg = f"Dataset {dataset_id} already exists."
+            if self.if_exists == IfExistAction.LOG:
+                self.log.info(error_msg)
+            elif self.if_exists == IfExistAction.FAIL:
+                raise AirflowException(error_msg)
+            else:
+                raise AirflowSkipException(error_msg)
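For migration, a hedged before/after sketch (identifiers invented): the deprecated boolean maps onto the two mildest actions, so `exists_ok=True` becomes `if_exists="ignore"` and the old default `exists_ok=False` becomes `if_exists="log"`:

# Before this commit (now emits a DeprecationWarning):
create_dataset = BigQueryCreateEmptyDatasetOperator(
    task_id="create_dataset", dataset_id="my_dataset", exists_ok=True
)

# After: equivalent behavior.
create_dataset = BigQueryCreateEmptyDatasetOperator(
    task_id="create_dataset", dataset_id="my_dataset", if_exists="ignore"
)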


class BigQueryGetDatasetOperator(GoogleCloudBaseOperator):
@@ -1897,8 +1936,8 @@ def __init__(
         dataset_id: str,
         project_id: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
+        delegate_to: str | None = None,
         **kwargs,
     ) -> None:
         self.dataset_id = dataset_id
@@ -1972,8 +2011,8 @@ def __init__(
         project_id: str | None = None,
         max_results: int | None = None,
         gcp_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
+        delegate_to: str | None = None,
         **kwargs,
     ) -> None:
         self.dataset_id = dataset_id
@@ -2045,8 +2084,8 @@ def __init__(
         dataset_resource: dict,
         project_id: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
+        delegate_to: str | None = None,
         **kwargs,
     ) -> None:
         warnings.warn(
@@ -2133,8 +2172,8 @@ def __init__(
         table_id: str | None = None,
         project_id: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
+        delegate_to: str | None = None,
         **kwargs,
     ) -> None:
         self.dataset_id = dataset_id
@@ -2227,8 +2266,8 @@ def __init__(
         dataset_id: str | None = None,
         project_id: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
+        delegate_to: str | None = None,
         **kwargs,
     ) -> None:
         self.dataset_id = dataset_id
@@ -2308,10 +2347,10 @@ def __init__(
         *,
         deletion_dataset_table: str,
         gcp_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         ignore_if_missing: bool = False,
         location: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
+        delegate_to: str | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -2385,9 +2424,9 @@ def __init__(
         table_resource: dict,
         project_id: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         location: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
+        delegate_to: str | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -2496,8 +2535,8 @@ def __init__(
         include_policy_tags: bool = False,
         project_id: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
+        delegate_to: str | None = None,
         **kwargs,
     ) -> None:
         self.schema_fields_updates = schema_fields_updates
@@ -2616,12 +2655,12 @@ def __init__(
         force_rerun: bool = True,
         reattach_states: set[str] | None = None,
         gcp_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
         cancel_on_kill: bool = True,
         result_retry: Retry = DEFAULT_RETRY,
         result_timeout: float | None = None,
         deferrable: bool = False,
+        delegate_to: str | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)