Skip to content

Commit

Permalink
Introduce anonymous credentials in GCP base hook (apache#39695)
Browse files Browse the repository at this point in the history
  • Loading branch information
shahar1 authored and RNHTTR committed Jun 1, 2024
1 parent 6377c5f commit 8822f43
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 49 deletions.
73 changes: 41 additions & 32 deletions airflow/providers/google/cloud/utils/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
from urllib.parse import urlencode

import google.auth
import google.auth.credentials
import google.oauth2.service_account
from google.auth import impersonated_credentials # type: ignore[attr-defined]
from google.auth.credentials import AnonymousCredentials, Credentials
from google.auth.environment_vars import CREDENTIALS, LEGACY_PROJECT, PROJECT

from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -178,6 +178,7 @@ class _CredentialProvider(LoggingMixin):
:param key_secret_name: Keyfile Secret Name in GCP Secret Manager.
:param key_secret_project_id: Project ID to read the secrets from. If not passed, the project ID from
default credentials will be used.
:param credential_config_file: File path to or content of a GCP credential configuration file.
:param scopes: OAuth scopes for the connection
:param delegate_to: The account to impersonate using domain-wide delegation of authority,
if any. For this to work, the service account making the request must have
Expand All @@ -192,6 +193,8 @@ class to configure Logger.
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account and target_principal
granting the role to the last account from the list.
:param is_anonymous: Provides an anonymous set of credentials,
which is useful for APIs which do not require authentication.
"""

def __init__(
Expand All @@ -206,13 +209,14 @@ def __init__(
disable_logging: bool = False,
target_principal: str | None = None,
delegates: Sequence[str] | None = None,
is_anonymous: bool | None = None,
) -> None:
super().__init__()
key_options = [key_path, key_secret_name, keyfile_dict]
key_options = [key_path, keyfile_dict, credential_config_file, key_secret_name, is_anonymous]
if len([x for x in key_options if x]) > 1:
raise AirflowException(
"The `keyfile_dict`, `key_path`, and `key_secret_name` fields "
"are all mutually exclusive. Please provide only one value."
"The `keyfile_dict`, `key_path`, `credential_config_file`, `is_anonymous` and"
" `key_secret_name` fields are all mutually exclusive. Please provide only one value."
)
self.key_path = key_path
self.keyfile_dict = keyfile_dict
Expand All @@ -224,43 +228,48 @@ def __init__(
self.disable_logging = disable_logging
self.target_principal = target_principal
self.delegates = delegates
self.is_anonymous = is_anonymous

def get_credentials_and_project(self) -> tuple[google.auth.credentials.Credentials, str]:
def get_credentials_and_project(self) -> tuple[Credentials, str]:
"""
Get current credentials and project ID.
Project ID is an empty string when using anonymous credentials.
:return: Google Auth Credentials
"""
if self.key_path:
credentials, project_id = self._get_credentials_using_key_path()
elif self.key_secret_name:
credentials, project_id = self._get_credentials_using_key_secret_name()
elif self.keyfile_dict:
credentials, project_id = self._get_credentials_using_keyfile_dict()
elif self.credential_config_file:
credentials, project_id = self._get_credentials_using_credential_config_file()
if self.is_anonymous:
credentials, project_id = AnonymousCredentials(), ""
else:
credentials, project_id = self._get_credentials_using_adc()

if self.delegate_to:
if hasattr(credentials, "with_subject"):
credentials = credentials.with_subject(self.delegate_to)
if self.key_path:
credentials, project_id = self._get_credentials_using_key_path()
elif self.key_secret_name:
credentials, project_id = self._get_credentials_using_key_secret_name()
elif self.keyfile_dict:
credentials, project_id = self._get_credentials_using_keyfile_dict()
elif self.credential_config_file:
credentials, project_id = self._get_credentials_using_credential_config_file()
else:
raise AirflowException(
"The `delegate_to` parameter cannot be used here as the current "
"authentication method does not support account impersonate. "
"Please use service-account for authorization."
credentials, project_id = self._get_credentials_using_adc()
if self.delegate_to:
if hasattr(credentials, "with_subject"):
credentials = credentials.with_subject(self.delegate_to)
else:
raise AirflowException(
"The `delegate_to` parameter cannot be used here as the current "
"authentication method does not support account impersonate. "
"Please use service-account for authorization."
)

if self.target_principal:
credentials = impersonated_credentials.Credentials(
source_credentials=credentials,
target_principal=self.target_principal,
delegates=self.delegates,
target_scopes=self.scopes,
)

if self.target_principal:
credentials = impersonated_credentials.Credentials(
source_credentials=credentials,
target_principal=self.target_principal,
delegates=self.delegates,
target_scopes=self.scopes,
)

project_id = _get_project_id_from_service_account_email(self.target_principal)
project_id = _get_project_id_from_service_account_email(self.target_principal)

return credentials, project_id

Expand Down Expand Up @@ -357,7 +366,7 @@ def _log_debug(self, *args, **kwargs) -> None:
self.log.debug(*args, **kwargs)


def get_credentials_and_project_id(*args, **kwargs) -> tuple[google.auth.credentials.Credentials, str]:
def get_credentials_and_project_id(*args, **kwargs) -> tuple[Credentials, str]:
"""Return the Credentials object for Google API and the associated project_id."""
return _CredentialProvider(*args, **kwargs).get_credentials_and_project()

Expand Down
16 changes: 11 additions & 5 deletions airflow/providers/google/common/hooks/base_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from typing import TYPE_CHECKING, Any, Callable, Generator, Sequence, TypeVar, cast

import google.auth
import google.auth.credentials
import google.oauth2.service_account
import google_auth_httplib2
import requests
Expand Down Expand Up @@ -223,7 +222,7 @@ def get_connection_form_widgets(cls) -> dict[str, Any]:
"""Return connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget
from flask_babel import lazy_gettext
from wtforms import IntegerField, PasswordField, StringField
from wtforms import BooleanField, IntegerField, PasswordField, StringField
from wtforms.validators import NumberRange

return {
Expand All @@ -249,6 +248,9 @@ def get_connection_form_widgets(cls) -> dict[str, Any]:
"impersonation_chain": StringField(
lazy_gettext("Impersonation Chain"), widget=BS3TextFieldWidget()
),
"is_anonymous": BooleanField(
lazy_gettext("Anonymous credentials (ignores all other settings)"), default=False
),
}

@classmethod
Expand All @@ -270,10 +272,10 @@ def __init__(
self.delegate_to = delegate_to
self.impersonation_chain = impersonation_chain
self.extras: dict = self.get_connection(self.gcp_conn_id).extra_dejson
self._cached_credentials: google.auth.credentials.Credentials | None = None
self._cached_credentials: Credentials | None = None
self._cached_project_id: str | None = None

def get_credentials_and_project_id(self) -> tuple[google.auth.credentials.Credentials, str | None]:
def get_credentials_and_project_id(self) -> tuple[Credentials, str | None]:
"""Return the Credentials object for Google API and the associated project_id."""
if self._cached_credentials is not None:
return self._cached_credentials, self._cached_project_id
Expand Down Expand Up @@ -301,6 +303,7 @@ def get_credentials_and_project_id(self) -> tuple[google.auth.credentials.Creden
self.impersonation_chain = [s.strip() for s in self.impersonation_chain.split(",")]

target_principal, delegates = _get_target_principal_and_delegates(self.impersonation_chain)
is_anonymous = self._get_field("is_anonymous")

credentials, project_id = get_credentials_and_project_id(
key_path=key_path,
Expand All @@ -312,6 +315,7 @@ def get_credentials_and_project_id(self) -> tuple[google.auth.credentials.Creden
delegate_to=self.delegate_to,
target_principal=target_principal,
delegates=delegates,
is_anonymous=is_anonymous,
)

overridden_project_id = self._get_field("project")
Expand All @@ -323,7 +327,7 @@ def get_credentials_and_project_id(self) -> tuple[google.auth.credentials.Creden

return credentials, project_id

def get_credentials(self) -> google.auth.credentials.Credentials:
def get_credentials(self) -> Credentials:
"""Return the Credentials object for Google API."""
credentials, _ = self.get_credentials_and_project_id()
return credentials
Expand Down Expand Up @@ -655,6 +659,8 @@ def download_content_from_request(file_handle, request: dict, chunk_size: int) -
def test_connection(self):
"""Test the Google cloud connectivity from UI."""
status, message = False, ""
if self._get_field("is_anonymous"):
return True, "Credentials are anonymous"
try:
token = self._get_access_token()
url = f"https://www.googleapis.com/oauth2/v3/tokeninfo?access_token={token}"
Expand Down
12 changes: 6 additions & 6 deletions tests/providers/google/cloud/utils/test_credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,14 +375,14 @@ def test_get_credentials_and_project_id_with_key_secret_name_when_key_is_invalid
get_credentials_and_project_id(key_secret_name="secret name")

def test_get_credentials_and_project_id_with_mutually_exclusive_configuration(self):
with pytest.raises(
AirflowException,
match=re.escape(
"The `keyfile_dict`, `key_path`, and `key_secret_name` fields are all mutually exclusive."
),
):
with pytest.raises(AirflowException, match="mutually exclusive."):
get_credentials_and_project_id(key_path="KEY.json", keyfile_dict={"private_key": "PRIVATE_KEY"})

@mock.patch("airflow.providers.google.cloud.utils.credentials_provider.AnonymousCredentials")
def test_get_credentials_using_anonymous_credentials(self, mock_anonymous_credentials):
result = get_credentials_and_project_id(is_anonymous=True)
assert result == (mock_anonymous_credentials.return_value, "")

@mock.patch("google.auth.default", return_value=("CREDENTIALS", "PROJECT_ID"))
@mock.patch(
"google.oauth2.service_account.Credentials.from_service_account_info",
Expand Down
32 changes: 26 additions & 6 deletions tests/providers/google/common/hooks/test_base_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,7 @@ def test_get_credentials_and_project_id_with_default_auth(self, mock_get_creds_a
delegate_to=None,
target_principal=None,
delegates=None,
is_anonymous=None,
)
assert ("CREDENTIALS", "PROJECT_ID") == result

Expand Down Expand Up @@ -449,6 +450,7 @@ def test_get_credentials_and_project_id_with_service_account_file(self, mock_get
delegate_to=None,
target_principal=None,
delegates=None,
is_anonymous=None,
)
assert (mock_credentials, "PROJECT_ID") == result

Expand Down Expand Up @@ -479,6 +481,7 @@ def test_get_credentials_and_project_id_with_service_account_info(self, mock_get
delegate_to=None,
target_principal=None,
delegates=None,
is_anonymous=None,
)
assert (mock_credentials, "PROJECT_ID") == result

Expand All @@ -499,6 +502,7 @@ def test_get_credentials_and_project_id_with_default_auth_and_delegate(self, moc
delegate_to="USER",
target_principal=None,
delegates=None,
is_anonymous=None,
)
assert (mock_credentials, "PROJECT_ID") == result

Expand Down Expand Up @@ -535,6 +539,7 @@ def test_get_credentials_and_project_id_with_default_auth_and_overridden_project
delegate_to=None,
target_principal=None,
delegates=None,
is_anonymous=None,
)
assert ("CREDENTIALS", "SECOND_PROJECT_ID") == result

Expand All @@ -544,12 +549,7 @@ def test_get_credentials_and_project_id_with_mutually_exclusive_configuration(se
"key_path": "KEY_PATH",
"keyfile_dict": '{"KEY": "VALUE"}',
}
with pytest.raises(
AirflowException,
match=re.escape(
"The `keyfile_dict`, `key_path`, and `key_secret_name` fields are all mutually exclusive. "
),
):
with pytest.raises(AirflowException, match="mutually exclusive"):
self.instance.get_credentials_and_project_id()

def test_get_credentials_and_project_id_with_invalid_keyfile_dict(self):
Expand All @@ -559,6 +559,25 @@ def test_get_credentials_and_project_id_with_invalid_keyfile_dict(self):
with pytest.raises(AirflowException, match=re.escape("Invalid key JSON.")):
self.instance.get_credentials_and_project_id()

@mock.patch(MODULE_NAME + ".get_credentials_and_project_id", return_value=("CREDENTIALS", ""))
def test_get_credentials_and_project_id_with_is_anonymous(self, mock_get_creds_and_proj_id):
self.instance.extras = {
"is_anonymous": True,
}
self.instance.get_credentials_and_project_id()
mock_get_creds_and_proj_id.assert_called_once_with(
key_path=None,
keyfile_dict=None,
credential_config_file=None,
key_secret_name=None,
key_secret_project_id=None,
scopes=self.instance.scopes,
delegate_to=None,
target_principal=None,
delegates=None,
is_anonymous=True,
)

@pytest.mark.skipif(
not default_creds_available, reason="Default Google Cloud credentials not available to run tests"
)
Expand Down Expand Up @@ -764,6 +783,7 @@ def test_get_credentials_and_project_id_with_impersonation_chain(
delegate_to=None,
target_principal=target_principal,
delegates=delegates,
is_anonymous=None,
)
assert (mock_credentials, PROJECT_ID) == result

Expand Down

0 comments on commit 8822f43

Please sign in to comment.