Skip to content

Commit

Permalink
Add ability to pass impersonation_chain to BigQuery triggers (#35629)
Browse files Browse the repository at this point in the history
This PR adds the possibility to pass impersonation_chain to BigQuery triggers, so that customers can execute triggers in a different project by passing a dedicated service account (SA).

Co-authored-by: Ulada Zakharava <Vlada_Zakharava@epam.com>
  • Loading branch information
VladaZakharova and Ulada Zakharava committed Nov 15, 2023
1 parent 946e1e0 commit 054904b
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 17 deletions.
24 changes: 24 additions & 0 deletions airflow/providers/google/cloud/hooks/bigquery.py
Expand Up @@ -3086,6 +3086,18 @@ class BigQueryAsyncHook(GoogleBaseAsyncHook):

sync_hook_class = BigQueryHook

def __init__(
    self,
    gcp_conn_id: str = "google_cloud_default",
    impersonation_chain: str | Sequence[str] | None = None,
    **kwargs,
):
    """Initialize the async BigQuery hook.

    The connection id and the impersonation chain are named explicitly in the
    signature and handed to the parent initializer as keyword arguments; any
    additional keyword arguments are passed through unchanged.

    :param gcp_conn_id: Reference to google cloud connection id.
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
    :param kwargs: Extra keyword arguments forwarded as-is to the parent initializer.
    """
    super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, **kwargs)

async def get_job_instance(
self, project_id: str | None, job_id: str | None, session: ClientSession
) -> Job:
Expand Down Expand Up @@ -3311,6 +3323,18 @@ class BigQueryTableAsyncHook(GoogleBaseAsyncHook):

sync_hook_class = BigQueryHook

def __init__(
    self,
    gcp_conn_id: str = "google_cloud_default",
    impersonation_chain: str | Sequence[str] | None = None,
    **kwargs,
):
    """Initialize the async BigQuery table hook.

    The connection id and the impersonation chain are named explicitly in the
    signature and handed to the parent initializer as keyword arguments; any
    additional keyword arguments are passed through unchanged.

    :param gcp_conn_id: Reference to google cloud connection id.
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
    :param kwargs: Extra keyword arguments forwarded as-is to the parent initializer.
    """
    super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, **kwargs)

async def get_table_client(
self, dataset: str, table_id: str, project_id: str, session: ClientSession
) -> Table_async:
Expand Down
8 changes: 6 additions & 2 deletions airflow/providers/google/cloud/operators/bigquery.py
Expand Up @@ -301,6 +301,7 @@ def execute(self, context: Context):
else:
hook = BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
job = self._submit_job(hook, job_id="")
context["ti"].xcom_push(key="job_id", value=job.job_id)
Expand All @@ -312,6 +313,7 @@ def execute(self, context: Context):
job_id=job.job_id,
project_id=hook.project_id,
poll_interval=self.poll_interval,
impersonation_chain=self.impersonation_chain,
),
method_name="execute_complete",
)
Expand Down Expand Up @@ -424,7 +426,7 @@ def execute(self, context: Context) -> None: # type: ignore[override]
if not self.deferrable:
super().execute(context=context)
else:
hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id)
hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)

job = self._submit_job(hook, job_id="")
context["ti"].xcom_push(key="job_id", value=job.job_id)
Expand All @@ -439,6 +441,7 @@ def execute(self, context: Context) -> None: # type: ignore[override]
pass_value=self.pass_value,
tolerance=self.tol,
poll_interval=self.poll_interval,
impersonation_chain=self.impersonation_chain,
),
method_name="execute_complete",
)
Expand Down Expand Up @@ -573,7 +576,7 @@ def execute(self, context: Context):
if not self.deferrable:
super().execute(context)
else:
hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id)
hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
self.log.info("Using ratio formula: %s", self.ratio_formula)

self.log.info("Executing SQL check: %s", self.sql1)
Expand All @@ -596,6 +599,7 @@ def execute(self, context: Context):
ratio_formula=self.ratio_formula,
ignore_zero=self.ignore_zero,
poll_interval=self.poll_interval,
impersonation_chain=self.impersonation_chain,
),
method_name="execute_complete",
)
Expand Down
84 changes: 70 additions & 14 deletions airflow/providers/google/cloud/triggers/bigquery.py
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

import asyncio
from typing import Any, AsyncIterator, SupportsAbs
from typing import Any, AsyncIterator, Sequence, SupportsAbs

from aiohttp import ClientSession
from aiohttp.client_exceptions import ClientResponseError
Expand All @@ -35,7 +35,15 @@ class BigQueryInsertJobTrigger(BaseTrigger):
:param project_id: Google Cloud Project where the job is running
:param dataset_id: The dataset ID of the requested table. (templated)
:param table_id: The table ID of the requested table. (templated)
:param poll_interval: polling period in seconds to check for the status
:param poll_interval: polling period in seconds to check for the status. (templated)
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account. (templated)
"""

def __init__(
Expand All @@ -46,6 +54,7 @@ def __init__(
dataset_id: str | None = None,
table_id: str | None = None,
poll_interval: float = 4.0,
impersonation_chain: str | Sequence[str] | None = None,
):
super().__init__()
self.log.info("Using the connection %s .", conn_id)
Expand All @@ -56,6 +65,7 @@ def __init__(
self.project_id = project_id
self.table_id = table_id
self.poll_interval = poll_interval
self.impersonation_chain = impersonation_chain

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes BigQueryInsertJobTrigger arguments and classpath."""
Expand All @@ -68,6 +78,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"project_id": self.project_id,
"table_id": self.table_id,
"poll_interval": self.poll_interval,
"impersonation_chain": self.impersonation_chain,
},
)

Expand Down Expand Up @@ -101,7 +112,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
yield TriggerEvent({"status": "error", "message": str(e)})

def _get_async_hook(self) -> BigQueryAsyncHook:
return BigQueryAsyncHook(gcp_conn_id=self.conn_id)
return BigQueryAsyncHook(gcp_conn_id=self.conn_id, impersonation_chain=self.impersonation_chain)


class BigQueryCheckTrigger(BigQueryInsertJobTrigger):
Expand All @@ -118,6 +129,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"project_id": self.project_id,
"table_id": self.table_id,
"poll_interval": self.poll_interval,
"impersonation_chain": self.impersonation_chain,
},
)

Expand Down Expand Up @@ -191,6 +203,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"project_id": self.project_id,
"table_id": self.table_id,
"poll_interval": self.poll_interval,
"impersonation_chain": self.impersonation_chain,
"as_dict": self.as_dict,
},
)
Expand Down Expand Up @@ -240,13 +253,20 @@ class BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger):
:param dataset_id: The dataset ID of the requested table. (templated)
:param table: table name
:param metrics_thresholds: dictionary of ratios indexed by metrics
:param date_filter_column: column name
:param days_back: number of days between ds and the ds we want to check
against
:param ratio_formula: ration formula
:param ignore_zero: boolean value to consider zero or not
:param date_filter_column: column name. (templated)
:param days_back: number of days between ds and the ds we want to check against. (templated)
:param ratio_formula: ratio formula. (templated)
:param ignore_zero: boolean value to consider zero or not. (templated)
:param table_id: The table ID of the requested table. (templated)
:param poll_interval: polling period in seconds to check for the status
:param poll_interval: polling period in seconds to check for the status. (templated)
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account. (templated)
"""

def __init__(
Expand All @@ -264,6 +284,7 @@ def __init__(
dataset_id: str | None = None,
table_id: str | None = None,
poll_interval: float = 4.0,
impersonation_chain: str | Sequence[str] | None = None,
):
super().__init__(
conn_id=conn_id,
Expand All @@ -272,6 +293,7 @@ def __init__(
dataset_id=dataset_id,
table_id=table_id,
poll_interval=poll_interval,
impersonation_chain=impersonation_chain,
)
self.conn_id = conn_id
self.first_job_id = first_job_id
Expand Down Expand Up @@ -299,6 +321,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"days_back": self.days_back,
"ratio_formula": self.ratio_formula,
"ignore_zero": self.ignore_zero,
"impersonation_chain": self.impersonation_chain,
},
)

Expand Down Expand Up @@ -386,12 +409,20 @@ class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger):
:param conn_id: Reference to google cloud connection id
:param sql: the sql to be executed
:param pass_value: pass value
:param job_id: The ID of the job
:param job_id: The ID of the job
:param project_id: Google Cloud Project where the job is running
:param tolerance: certain metrics for tolerance
:param tolerance: certain metrics for tolerance. (templated)
:param dataset_id: The dataset ID of the requested table. (templated)
:param table_id: The table ID of the requested table. (templated)
:param poll_interval: polling period in seconds to check for the status
:param poll_interval: polling period in seconds to check for the status. (templated)
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account. (templated)
"""

def __init__(
Expand All @@ -405,6 +436,7 @@ def __init__(
dataset_id: str | None = None,
table_id: str | None = None,
poll_interval: float = 4.0,
impersonation_chain: str | Sequence[str] | None = None,
):
super().__init__(
conn_id=conn_id,
Expand All @@ -413,6 +445,7 @@ def __init__(
dataset_id=dataset_id,
table_id=table_id,
poll_interval=poll_interval,
impersonation_chain=impersonation_chain,
)
self.sql = sql
self.pass_value = pass_value
Expand All @@ -432,6 +465,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"table_id": self.table_id,
"tolerance": self.tolerance,
"poll_interval": self.poll_interval,
"impersonation_chain": self.impersonation_chain,
},
)

Expand Down Expand Up @@ -473,6 +507,14 @@ class BigQueryTableExistenceTrigger(BaseTrigger):
:param gcp_conn_id: Reference to google cloud connection id
:param hook_params: params for hook
:param poll_interval: polling period in seconds to check for the status
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account. (templated)
"""

def __init__(
Expand All @@ -483,13 +525,15 @@ def __init__(
gcp_conn_id: str,
hook_params: dict[str, Any],
poll_interval: float = 4.0,
impersonation_chain: str | Sequence[str] | None = None,
):
self.dataset_id = dataset_id
self.project_id = project_id
self.table_id = table_id
self.gcp_conn_id: str = gcp_conn_id
self.poll_interval = poll_interval
self.hook_params = hook_params
self.impersonation_chain = impersonation_chain

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes BigQueryTableExistenceTrigger arguments and classpath."""
Expand All @@ -502,11 +546,14 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"gcp_conn_id": self.gcp_conn_id,
"poll_interval": self.poll_interval,
"hook_params": self.hook_params,
"impersonation_chain": self.impersonation_chain,
},
)

def _get_async_hook(self) -> BigQueryTableAsyncHook:
return BigQueryTableAsyncHook(gcp_conn_id=self.gcp_conn_id)
return BigQueryTableAsyncHook(
gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
)

async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"""Will run until the table exists in the Google Big Query."""
Expand Down Expand Up @@ -561,6 +608,14 @@ class BigQueryTablePartitionExistenceTrigger(BigQueryTableExistenceTrigger):
:param gcp_conn_id: Reference to google cloud connection id
:param hook_params: params for hook
:param poll_interval: polling period in seconds to check for the status
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account. (templated)
"""

def __init__(self, partition_id: str, **kwargs):
Expand All @@ -578,13 +633,14 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"table_id": self.table_id,
"gcp_conn_id": self.gcp_conn_id,
"poll_interval": self.poll_interval,
"impersonation_chain": self.impersonation_chain,
"hook_params": self.hook_params,
},
)

async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"""Will run until the table exists in the Google Big Query."""
hook = BigQueryAsyncHook(gcp_conn_id=self.gcp_conn_id)
hook = BigQueryAsyncHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
job_id = None
while True:
if job_id is not None:
Expand Down
2 changes: 1 addition & 1 deletion dev/breeze/README.md
Expand Up @@ -66,6 +66,6 @@ PLEASE DO NOT MODIFY THE HASH BELOW! IT IS AUTOMATICALLY UPDATED BY PRE-COMMIT.

---------------------------------------------------------------------------------------------------------

Package config hash: 51d9c2ec8af90c2941d58cf28397e9972d31718bc5d74538eb0614ed9418310e7b1d14bb3ee11f4df6e8403390869838217dc641cdb1416a223b7cf69adf1b20
Package config hash: Missing file /usr/local/google/home/uladaz/Documents/Project/Airflow/airflow/dev/breeze/setup.py

---------------------------------------------------------------------------------------------------------

0 comments on commit 054904b

Please sign in to comment.