diff --git a/airflow/providers/google/common/hooks/base_google.py b/airflow/providers/google/common/hooks/base_google.py index 94cc420bae34d..77fd9394cfc62 100644 --- a/airflow/providers/google/common/hooks/base_google.py +++ b/airflow/providers/google/common/hooks/base_google.py @@ -216,6 +216,9 @@ def get_connection_form_widgets() -> dict[str, Any]: widget=BS3TextFieldWidget(), default=5, ), + "impersonation_chain": StringField( + lazy_gettext("Impersonation Chain"), widget=BS3TextFieldWidget() + ), } @staticmethod @@ -262,6 +265,9 @@ def get_credentials_and_project_id(self) -> tuple[google.auth.credentials.Creden credential_config_file: str | None = self._get_field("credential_config_file", None) + if not self.impersonation_chain: + self.impersonation_chain = self._get_field("impersonation_chain", None) + target_principal, delegates = _get_target_principal_and_delegates(self.impersonation_chain) credentials, project_id = get_credentials_and_project_id( diff --git a/docs/apache-airflow-providers-google/connections/gcp.rst b/docs/apache-airflow-providers-google/connections/gcp.rst index 59a40ec24b643..83bab9ca90b75 100644 --- a/docs/apache-airflow-providers-google/connections/gcp.rst +++ b/docs/apache-airflow-providers-google/connections/gcp.rst @@ -125,6 +125,16 @@ Number of Retries represents the last request. If zero (default), we attempt the request only once. +Impersonation Chain + Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in all requests leveraging this connection. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account. + When specifying the connection in environment variable you should specify it using URI syntax, with the following requirements: @@ -142,6 +152,7 @@ Number of Retries * ``scope`` - Scopes * ``num_retries`` - Number of Retries + Note that all components of the URI should be URL-encoded. For example, with URI format: @@ -165,6 +176,8 @@ Google operators support `direct impersonation of a service account `_ via ``impersonation_chain`` argument (``google_impersonation_chain`` in case of operators that also communicate with services of other cloud providers). +The impersonation chain can also be configured directly on the Google Cloud Connection +as described above, though the ``impersonation_chain`` passed to the operator takes precedence. For example: diff --git a/tests/providers/google/common/hooks/test_base_google.py b/tests/providers/google/common/hooks/test_base_google.py index 6ca5f19fc1a12..feb9a93a661d1 100644 --- a/tests/providers/google/common/hooks/test_base_google.py +++ b/tests/providers/google/common/hooks/test_base_google.py @@ -661,16 +661,37 @@ def test_authorize_assert_http_timeout_is_present(self, mock_get_credentials): assert http_authorized.timeout is not None @pytest.mark.parametrize( - "impersonation_chain, target_principal, delegates", + "impersonation_chain, impersonation_chain_from_conn, target_principal, delegates", [ - pytest.param("ACCOUNT_1", "ACCOUNT_1", None, id="string"), - pytest.param(["ACCOUNT_1"], "ACCOUNT_1", [], id="single_element_list"), + pytest.param("ACCOUNT_1", None, "ACCOUNT_1", None, id="string"), + pytest.param(None, "ACCOUNT_1", "ACCOUNT_1", None, id="string_in_conn"), + pytest.param("ACCOUNT_2", "ACCOUNT_1", "ACCOUNT_2", None, id="string_with_override"), + pytest.param(["ACCOUNT_1"], None, "ACCOUNT_1", [], id="single_element_list"), + pytest.param(None, ["ACCOUNT_1"], "ACCOUNT_1", [], id="single_element_list_in_conn"), + pytest.param( + ["ACCOUNT_1"], ["ACCOUNT_2"], "ACCOUNT_1", [], id="single_element_list_with_override" + ), pytest.param( ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"], + None, "ACCOUNT_3", ["ACCOUNT_1", "ACCOUNT_2"], id="multiple_elements_list", ), + pytest.param( + None, + ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"], + "ACCOUNT_3", + ["ACCOUNT_1", "ACCOUNT_2"], + id="multiple_elements_list_in_conn", + ), + pytest.param( + ["ACCOUNT_2", "ACCOUNT_3", "ACCOUNT_4"], + ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"], + "ACCOUNT_4", + ["ACCOUNT_2", "ACCOUNT_3"], + id="multiple_elements_list_with_override", + ), ], ) @mock.patch(MODULE_NAME + ".get_credentials_and_project_id") @@ -678,12 +699,14 @@ def test_get_credentials_and_project_id_with_impersonation_chain( self, mock_get_creds_and_proj_id, impersonation_chain, + impersonation_chain_from_conn, target_principal, delegates, ): mock_credentials = mock.MagicMock() mock_get_creds_and_proj_id.return_value = (mock_credentials, PROJECT_ID) self.instance.impersonation_chain = impersonation_chain + self.instance.extras = {"impersonation_chain": impersonation_chain_from_conn} result = self.instance.get_credentials_and_project_id() mock_get_creds_and_proj_id.assert_called_once_with( key_path=None,