Skip to content

Commit

Permalink
Added impersonation_chain for DataflowStartFlexTemplateOperator and D…
Browse files Browse the repository at this point in the history
…ataflowStartSqlJobOperator (#24046)
  • Loading branch information
Łukasz Wyszomirski committed Jun 4, 2022
1 parent 5d5598e commit 90233bc
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 19 deletions.
23 changes: 19 additions & 4 deletions airflow/providers/google/cloud/hooks/dataflow.py
Expand Up @@ -972,16 +972,31 @@ def start_sql_job(
:param on_new_job_callback: Callback called when the job is known.
:return: the new job object
"""
gcp_options = [
f"--project={project_id}",
"--format=value(job.id)",
f"--job-name={job_name}",
f"--region={location}",
]

if self.impersonation_chain:
if isinstance(self.impersonation_chain, str):
impersonation_account = self.impersonation_chain
elif len(self.impersonation_chain) == 1:
impersonation_account = self.impersonation_chain[0]
else:
raise AirflowException(
"Chained list of accounts is not supported, please specify only one service account"
)
gcp_options.append(f"--impersonate-service-account={impersonation_account}")

cmd = [
"gcloud",
"dataflow",
"sql",
"query",
query,
f"--project={project_id}",
"--format=value(job.id)",
f"--job-name={job_name}",
f"--region={location}",
*gcp_options,
*(beam_options_to_args(options)),
]
self.log.info("Executing command: %s", " ".join(shlex.quote(c) for c in cmd))
Expand Down
22 changes: 22 additions & 0 deletions airflow/providers/google/cloud/operators/dataflow.py
Expand Up @@ -727,6 +727,14 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
If you in your pipeline do not call the wait_for_pipeline method, and pass wait_until_finish=False
to the operator, the second loop will check once is job not in terminal state and exit the loop.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
"""

template_fields: Sequence[str] = ("body", "location", "project_id", "gcp_conn_id")
Expand All @@ -742,6 +750,7 @@ def __init__(
drain_pipeline: bool = False,
cancel_timeout: Optional[int] = 10 * 60,
wait_until_finished: Optional[bool] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
*args,
**kwargs,
) -> None:
Expand All @@ -756,6 +765,7 @@ def __init__(
self.wait_until_finished = wait_until_finished
self.job = None
self.hook: Optional[DataflowHook] = None
self.impersonation_chain = impersonation_chain

def execute(self, context: 'Context'):
self.hook = DataflowHook(
Expand All @@ -764,6 +774,7 @@ def execute(self, context: 'Context'):
drain_pipeline=self.drain_pipeline,
cancel_timeout=self.cancel_timeout,
wait_until_finished=self.wait_until_finished,
impersonation_chain=self.impersonation_chain,
)

def set_current_job(current_job):
Expand Down Expand Up @@ -821,6 +832,14 @@ class DataflowStartSqlJobOperator(BaseOperator):
:param drain_pipeline: Optional, set to True if want to stop streaming job by draining it
instead of canceling during killing task instance. See:
https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
"""

template_fields: Sequence[str] = (
Expand All @@ -843,6 +862,7 @@ def __init__(
gcp_conn_id: str = "google_cloud_default",
delegate_to: Optional[str] = None,
drain_pipeline: bool = False,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
*args,
**kwargs,
) -> None:
Expand All @@ -855,6 +875,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.delegate_to = delegate_to
self.drain_pipeline = drain_pipeline
self.impersonation_chain = impersonation_chain
self.job = None
self.hook: Optional[DataflowHook] = None

Expand All @@ -863,6 +884,7 @@ def execute(self, context: 'Context'):
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
drain_pipeline=self.drain_pipeline,
impersonation_chain=self.impersonation_chain,
)

def set_current_job(current_job):
Expand Down
17 changes: 3 additions & 14 deletions tests/providers/google/cloud/hooks/test_dataflow.py
Expand Up @@ -173,20 +173,10 @@ def test_fn(self, *args, **kwargs):
FixtureFallback().test_fn({'project': "TEST"}, "TEST2")


def mock_init(
self,
gcp_conn_id,
delegate_to=None,
impersonation_chain=None,
):
pass


class TestDataflowHook(unittest.TestCase):
def setUp(self):
with mock.patch(BASE_STRING.format('GoogleBaseHook.__init__'), new=mock_init):
self.dataflow_hook = DataflowHook(gcp_conn_id='test')
self.dataflow_hook.beam_hook = MagicMock()
self.dataflow_hook = DataflowHook(gcp_conn_id='google_cloud_default')
self.dataflow_hook.beam_hook = MagicMock()

@mock.patch("airflow.providers.google.cloud.hooks.dataflow.DataflowHook._authorize")
@mock.patch("airflow.providers.google.cloud.hooks.dataflow.build")
Expand Down Expand Up @@ -792,8 +782,7 @@ def test_wait_for_done(self, mock_conn, mock_dataflowjob):

class TestDataflowTemplateHook(unittest.TestCase):
def setUp(self):
with mock.patch(BASE_STRING.format('GoogleBaseHook.__init__'), new=mock_init):
self.dataflow_hook = DataflowHook(gcp_conn_id='test')
self.dataflow_hook = DataflowHook(gcp_conn_id='google_cloud_default')

@mock.patch(DATAFLOW_STRING.format('uuid.uuid4'), return_value=MOCK_UUID)
@mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
Expand Down
13 changes: 12 additions & 1 deletion tests/providers/google/cloud/operators/test_dataflow.py
Expand Up @@ -496,6 +496,14 @@ def test_execute(self, mock_dataflow):
location=TEST_LOCATION,
)
start_flex_template.execute(mock.MagicMock())
mock_dataflow.assert_called_once_with(
gcp_conn_id='google_cloud_default',
delegate_to=None,
drain_pipeline=False,
cancel_timeout=600,
wait_until_finished=None,
impersonation_chain=None,
)
mock_dataflow.return_value.start_flex_template.assert_called_once_with(
body={"launchParameter": TEST_FLEX_PARAMETERS},
location=TEST_LOCATION,
Expand Down Expand Up @@ -533,7 +541,10 @@ def test_execute(self, mock_hook):

start_sql.execute(mock.MagicMock())
mock_hook.assert_called_once_with(
gcp_conn_id='google_cloud_default', delegate_to=None, drain_pipeline=False
gcp_conn_id='google_cloud_default',
delegate_to=None,
drain_pipeline=False,
impersonation_chain=None,
)
mock_hook.return_value.start_sql_job.assert_called_once_with(
job_name=TEST_SQL_JOB_NAME,
Expand Down

0 comments on commit 90233bc

Please sign in to comment.