Skip to content

Commit

Permalink
Yandex dataproc deduce default service account (#35059)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: Petr Reznikov <prez@yandex-team.ru>
  • Loading branch information
Piatachock and Petr Reznikov committed Nov 3, 2023
1 parent 13865ab commit 0b850a9
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 18 deletions.
55 changes: 38 additions & 17 deletions airflow/providers/yandex/hooks/yandex.py
Expand Up @@ -88,14 +88,21 @@ def get_connection_form_widgets() -> dict[str, Any]:
@classmethod
def provider_user_agent(cls) -> str | None:
    """Construct User-Agent from Airflow core & provider package versions.

    The result has the form
    ``"<prefix> apache-airflow/<airflow version> <provider package>/<provider version>"``
    where the prefix is read from the ``[yandex] sdk_user_agent_prefix`` option
    (empty when the option is not set).

    :return: the User-Agent string, or ``None`` (after emitting a warning) when a
        ``KeyError`` is raised while looking up the hook or its provider in
        ``ProvidersManager``.
    """
    from airflow import __version__ as airflow_version
    from airflow.configuration import conf
    from airflow.providers_manager import ProvidersManager

    try:
        manager = ProvidersManager()
        provider_name = manager.hooks[cls.conn_type].package_name  # type: ignore[union-attr]
        provider = manager.providers[provider_name]
        # With an empty prefix the join leaves a leading space; strip() removes it.
        return " ".join(
            (
                conf.get("yandex", "sdk_user_agent_prefix", fallback=""),
                f"apache-airflow/{airflow_version}",
                f"{provider_name}/{provider.version}",
            )
        ).strip()
    except KeyError:
        warnings.warn(f"Hook '{cls.hook_name}' info is not initialized in airflow.ProviderManager")
        return None
Expand All @@ -115,6 +122,7 @@ def __init__(
yandex_conn_id: str | None = None,
default_folder_id: str | None = None,
default_public_ssh_key: str | None = None,
default_service_account_id: str | None = None,
) -> None:
super().__init__()
if connection_id:
Expand All @@ -129,31 +137,44 @@ def __init__(
credentials = self._get_credentials()
sdk_config = self._get_endpoint()
self.sdk = yandexcloud.SDK(user_agent=self.provider_user_agent(), **sdk_config, **credentials)
self.default_folder_id = default_folder_id or self._get_field("folder_id", False)
self.default_public_ssh_key = default_public_ssh_key or self._get_field("public_ssh_key", False)
self.default_folder_id = default_folder_id or self._get_field("folder_id")
self.default_public_ssh_key = default_public_ssh_key or self._get_field("public_ssh_key")
self.default_service_account_id = default_service_account_id or self._get_service_account_id()
self.client = self.sdk.client

def _get_service_account_key(self) -> dict[str, str] | None:
    """Load the service account authorized key configured on the connection.

    Reads the ``service_account_json`` and ``service_account_json_path`` connection
    fields. When a path is configured, the content of that file replaces the inline
    JSON value.

    :return: the parsed JSON key, or ``None`` when neither field yields any content.
    :raises OSError: if the configured key file cannot be opened.
    :raises json.JSONDecodeError: if the key content is not valid JSON.
    """
    service_account_json = self._get_field("service_account_json")
    service_account_json_path = self._get_field("service_account_json_path")
    if service_account_json_path:
        # Read as UTF-8 explicitly so the result does not depend on the host locale.
        with open(service_account_json_path, encoding="utf-8") as infile:
            service_account_json = infile.read()
    if service_account_json:
        return json.loads(service_account_json)
    return None

def _get_service_account_id(self) -> str | None:
    """Return the ``service_account_id`` entry of the connection's service account key.

    :return: the value stored under ``"service_account_id"``, or ``None`` when
        ``_get_service_account_key()`` returns nothing or the key has no such entry.
    """
    sa_key = self._get_service_account_key()
    if sa_key:
        return sa_key.get("service_account_id")
    return None

def _get_credentials(self) -> dict[str, Any]:
    """Build the credential keyword arguments for the SDK from the connection.

    The ``oauth`` connection field is checked first; the service account key
    returned by ``_get_service_account_key()`` is used only when no OAuth token
    is set.

    :return: ``{"token": ...}`` for an OAuth token, otherwise
        ``{"service_account_key": ...}``.
    :raises AirflowException: if the connection holds neither kind of credential.
    """
    oauth_token = self._get_field("oauth")
    if oauth_token:
        return {"token": oauth_token}

    service_account_key = self._get_service_account_key()
    if service_account_key:
        return {"service_account_key": service_account_key}

    raise AirflowException(
        "No credentials are found in connection. Specify either service account "
        "authentication JSON or user OAuth token in Yandex.Cloud connection"
    )

def _get_endpoint(self) -> dict[str, str]:
    """Build the endpoint keyword arguments for the SDK from the connection.

    :return: ``{"endpoint": <value>}`` when the ``endpoint`` connection field is
        set to a non-empty value, otherwise an empty dict.
    """
    sdk_config = {}
    endpoint = self._get_field("endpoint")
    if endpoint:
        sdk_config["endpoint"] = endpoint
    return sdk_config
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/yandex/operators/yandexcloud_dataproc.py
Expand Up @@ -194,7 +194,7 @@ def execute(self, context: Context) -> dict:
services=self.services,
s3_bucket=self.s3_bucket,
zone=self.zone,
service_account_id=self.service_account_id,
service_account_id=self.service_account_id or self.hook.default_service_account_id,
masternode_resource_preset=self.masternode_resource_preset,
masternode_disk_size=self.masternode_disk_size,
masternode_disk_type=self.masternode_disk_type,
Expand Down
12 changes: 12 additions & 0 deletions airflow/providers/yandex/provider.yaml
Expand Up @@ -70,3 +70,15 @@ hooks:
connection-types:
- hook-class-name: airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook
connection-type: yandexcloud

# Provider-level configuration options.
config:
  yandex:
    description: This section contains settings for Yandex Cloud integration.
    options:
      # Read by the hook's provider_user_agent() when building the SDK User-Agent.
      sdk_user_agent_prefix:
        description: |
          Prefix for User-Agent header in Yandex.Cloud SDK requests
        version_added: 3.6.0
        type: string
        example: ~
        default: ""
12 changes: 12 additions & 0 deletions tests/providers/yandex/hooks/test_yandex.py
Expand Up @@ -25,6 +25,7 @@

from airflow.exceptions import AirflowException
from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook
from tests.test_utils.config import conf_vars


class TestYandexHook:
Expand Down Expand Up @@ -139,6 +140,17 @@ def test_get_endpoint_unspecified(self, get_credentials_mock, get_connection_moc

assert hook._get_endpoint() == {}

@mock.patch("airflow.hooks.base.BaseHook.get_connection")
@mock.patch("airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook._get_credentials")
def test_sdk_user_agent(self, get_credentials_mock, get_connection_mock):
    """The SDK User-Agent starts with the configured ``sdk_user_agent_prefix``."""
    get_connection_mock.return_value = mock.Mock(connection_id="yandexcloud_default", extra_dejson="{}")
    get_credentials_mock.return_value = {"token": 122323}
    sdk_prefix = "MyAirflow"

    # The option is overridden only for the duration of the hook construction and check.
    with conf_vars({("yandex", "sdk_user_agent_prefix"): sdk_prefix}):
        hook = YandexCloudBaseHook()
        assert hook.sdk._channels._client_user_agent.startswith(sdk_prefix)

@pytest.mark.parametrize(
"uri",
[
Expand Down

0 comments on commit 0b850a9

Please sign in to comment.