diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py
index 66b259838e522..60ffb1e8eb44b 100644
--- a/airflow/providers/google/cloud/hooks/dataflow.py
+++ b/airflow/providers/google/cloud/hooks/dataflow.py
@@ -972,16 +972,31 @@ def start_sql_job(
         :param on_new_job_callback: Callback called when the job is known.
         :return: the new job object
         """
+        gcp_options = [
+            f"--project={project_id}",
+            "--format=value(job.id)",
+            f"--job-name={job_name}",
+            f"--region={location}",
+        ]
+
+        if self.impersonation_chain:
+            if isinstance(self.impersonation_chain, str):
+                impersonation_account = self.impersonation_chain
+            elif len(self.impersonation_chain) == 1:
+                impersonation_account = self.impersonation_chain[0]
+            else:
+                raise AirflowException(
+                    "Chained list of accounts is not supported, please specify only one service account"
+                )
+            gcp_options.append(f"--impersonate-service-account={impersonation_account}")
+
         cmd = [
             "gcloud",
             "dataflow",
             "sql",
             "query",
             query,
-            f"--project={project_id}",
-            "--format=value(job.id)",
-            f"--job-name={job_name}",
-            f"--region={location}",
+            *gcp_options,
             *(beam_options_to_args(options)),
         ]
         self.log.info("Executing command: %s", " ".join(shlex.quote(c) for c in cmd))
diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py
index 98a17d800cbe5..677bc94940dba 100644
--- a/airflow/providers/google/cloud/operators/dataflow.py
+++ b/airflow/providers/google/cloud/operators/dataflow.py
@@ -727,6 +727,14 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
         If you in your pipeline do not call the wait_for_pipeline method, and pass
         wait_until_finish=False to the operator, the second loop will check once is job not in
         terminal state and exit the loop.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
     """
 
     template_fields: Sequence[str] = ("body", "location", "project_id", "gcp_conn_id")
@@ -742,6 +750,7 @@ def __init__(
         drain_pipeline: bool = False,
         cancel_timeout: Optional[int] = 10 * 60,
         wait_until_finished: Optional[bool] = None,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
         *args,
         **kwargs,
     ) -> None:
@@ -756,6 +765,7 @@ def __init__(
         self.wait_until_finished = wait_until_finished
         self.job = None
         self.hook: Optional[DataflowHook] = None
+        self.impersonation_chain = impersonation_chain
 
     def execute(self, context: 'Context'):
         self.hook = DataflowHook(
@@ -764,6 +774,7 @@ def execute(self, context: 'Context'):
             drain_pipeline=self.drain_pipeline,
             cancel_timeout=self.cancel_timeout,
             wait_until_finished=self.wait_until_finished,
+            impersonation_chain=self.impersonation_chain,
         )
 
         def set_current_job(current_job):
@@ -821,6 +832,14 @@ class DataflowStartSqlJobOperator(BaseOperator):
     :param drain_pipeline: Optional, set to True if want to stop streaming job
         by draining it instead of canceling during killing task instance.
         See: https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
     """
 
     template_fields: Sequence[str] = (
@@ -843,6 +862,7 @@ def __init__(
         gcp_conn_id: str = "google_cloud_default",
         delegate_to: Optional[str] = None,
         drain_pipeline: bool = False,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
         *args,
         **kwargs,
     ) -> None:
@@ -855,6 +875,7 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.delegate_to = delegate_to
         self.drain_pipeline = drain_pipeline
+        self.impersonation_chain = impersonation_chain
         self.job = None
         self.hook: Optional[DataflowHook] = None
 
@@ -863,6 +884,7 @@ def execute(self, context: 'Context'):
             gcp_conn_id=self.gcp_conn_id,
             delegate_to=self.delegate_to,
             drain_pipeline=self.drain_pipeline,
+            impersonation_chain=self.impersonation_chain,
         )
 
         def set_current_job(current_job):
diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py b/tests/providers/google/cloud/hooks/test_dataflow.py
index 98aeebe2fb1e4..b24eeb566b777 100644
--- a/tests/providers/google/cloud/hooks/test_dataflow.py
+++ b/tests/providers/google/cloud/hooks/test_dataflow.py
@@ -173,20 +173,10 @@ def test_fn(self, *args, **kwargs):
         FixtureFallback().test_fn({'project': "TEST"}, "TEST2")
 
 
-def mock_init(
-    self,
-    gcp_conn_id,
-    delegate_to=None,
-    impersonation_chain=None,
-):
-    pass
-
-
 class TestDataflowHook(unittest.TestCase):
     def setUp(self):
-        with mock.patch(BASE_STRING.format('GoogleBaseHook.__init__'), new=mock_init):
-            self.dataflow_hook = DataflowHook(gcp_conn_id='test')
-            self.dataflow_hook.beam_hook = MagicMock()
+        self.dataflow_hook = DataflowHook(gcp_conn_id='google_cloud_default')
+        self.dataflow_hook.beam_hook = MagicMock()
 
     @mock.patch("airflow.providers.google.cloud.hooks.dataflow.DataflowHook._authorize")
     @mock.patch("airflow.providers.google.cloud.hooks.dataflow.build")
@@ -792,8 +782,7 @@ def test_wait_for_done(self, mock_conn, mock_dataflowjob):
 
 class TestDataflowTemplateHook(unittest.TestCase):
     def setUp(self):
-        with mock.patch(BASE_STRING.format('GoogleBaseHook.__init__'), new=mock_init):
-            self.dataflow_hook = DataflowHook(gcp_conn_id='test')
+        self.dataflow_hook = DataflowHook(gcp_conn_id='google_cloud_default')
 
     @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'), return_value=MOCK_UUID)
     @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
diff --git a/tests/providers/google/cloud/operators/test_dataflow.py b/tests/providers/google/cloud/operators/test_dataflow.py
index 333bf7c2725cf..2cc37cb589629 100644
--- a/tests/providers/google/cloud/operators/test_dataflow.py
+++ b/tests/providers/google/cloud/operators/test_dataflow.py
@@ -496,6 +496,14 @@ def test_execute(self, mock_dataflow):
             location=TEST_LOCATION,
         )
         start_flex_template.execute(mock.MagicMock())
+        mock_dataflow.assert_called_once_with(
+            gcp_conn_id='google_cloud_default',
+            delegate_to=None,
+            drain_pipeline=False,
+            cancel_timeout=600,
+            wait_until_finished=None,
+            impersonation_chain=None,
+        )
         mock_dataflow.return_value.start_flex_template.assert_called_once_with(
             body={"launchParameter": TEST_FLEX_PARAMETERS},
             location=TEST_LOCATION,
@@ -533,7 +541,10 @@ def test_execute(self, mock_hook):
         start_sql.execute(mock.MagicMock())
 
         mock_hook.assert_called_once_with(
-            gcp_conn_id='google_cloud_default', delegate_to=None, drain_pipeline=False
+            gcp_conn_id='google_cloud_default',
+            delegate_to=None,
+            drain_pipeline=False,
+            impersonation_chain=None,
        )
         mock_hook.return_value.start_sql_job.assert_called_once_with(
             job_name=TEST_SQL_JOB_NAME,
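For reviewers, a minimal usage sketch of the new parameter, assuming an otherwise standard DAG (the DAG id, query, options, and service-account address below are illustrative placeholders, not taken from this change). Because `DataflowHook.start_sql_job` shells out to `gcloud dataflow sql query`, which takes a single `--impersonate-service-account` flag, the SQL path accepts one account only (a string or a one-element list) and raises `AirflowException` for a longer chain:

```python
# Hypothetical DAG showing the new impersonation_chain parameter on
# DataflowStartSqlJobOperator; all names and the query are placeholders.
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataflow import DataflowStartSqlJobOperator

with DAG(dag_id="example_dataflow_sql_impersonation", start_date=datetime(2021, 1, 1)) as dag:
    start_sql = DataflowStartSqlJobOperator(
        task_id="start_sql_query",
        job_name="example-sql-job",
        query="SELECT sales_region, COUNT(*) FROM pubsub.topic.`my-project`.`transactions` GROUP BY sales_region",
        options={"bigquery-project": "my-project"},  # forwarded to gcloud via beam_options_to_args
        location="europe-west3",
        project_id="my-project",
        # A single account works, as a str (shown here) or a one-element list;
        # a chained list of several accounts raises AirflowException in the hook.
        impersonation_chain="transfer-account@my-project.iam.gserviceaccount.com",
    )
```

`DataflowStartFlexTemplateOperator` accepts the same parameter and simply forwards it to `DataflowHook`; the single-account restriction applies only to the SQL path, since that one goes through the gcloud CLI rather than the API client credentials.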