diff --git a/airflow/providers/yandex/hooks/yandex.py b/airflow/providers/yandex/hooks/yandex.py index b54de6f319be0..e91b287bfadca 100644 --- a/airflow/providers/yandex/hooks/yandex.py +++ b/airflow/providers/yandex/hooks/yandex.py @@ -88,14 +88,21 @@ def get_connection_form_widgets() -> dict[str, Any]: @classmethod def provider_user_agent(cls) -> str | None: """Construct User-Agent from Airflow core & provider package versions.""" - import airflow + from airflow import __version__ as airflow_version + from airflow.configuration import conf from airflow.providers_manager import ProvidersManager try: manager = ProvidersManager() provider_name = manager.hooks[cls.conn_type].package_name # type: ignore[union-attr] provider = manager.providers[provider_name] - return f"apache-airflow/{airflow.__version__} {provider_name}/{provider.version}" + return " ".join( + ( + conf.get("yandex", "sdk_user_agent_prefix", fallback=""), + f"apache-airflow/{airflow_version}", + f"{provider_name}/{provider.version}", + ) + ).strip() except KeyError: warnings.warn(f"Hook '{cls.hook_name}' info is not initialized in airflow.ProviderManager") return None @@ -115,6 +122,7 @@ def __init__( yandex_conn_id: str | None = None, default_folder_id: str | None = None, default_public_ssh_key: str | None = None, + default_service_account_id: str | None = None, ) -> None: super().__init__() if connection_id: @@ -129,31 +137,44 @@ def __init__( credentials = self._get_credentials() sdk_config = self._get_endpoint() self.sdk = yandexcloud.SDK(user_agent=self.provider_user_agent(), **sdk_config, **credentials) - self.default_folder_id = default_folder_id or self._get_field("folder_id", False) - self.default_public_ssh_key = default_public_ssh_key or self._get_field("public_ssh_key", False) + self.default_folder_id = default_folder_id or self._get_field("folder_id") + self.default_public_ssh_key = default_public_ssh_key or self._get_field("public_ssh_key") + self.default_service_account_id = default_service_account_id or self._get_service_account_id() self.client = self.sdk.client - def _get_credentials(self) -> dict[str, Any]: - service_account_json_path = self._get_field("service_account_json_path", False) - service_account_json = self._get_field("service_account_json", False) - oauth_token = self._get_field("oauth", False) - if not (service_account_json or oauth_token or service_account_json_path): - raise AirflowException( - "No credentials are found in connection. Specify either service account " - "authentication JSON or user OAuth token in Yandex.Cloud connection" - ) + def _get_service_account_key(self) -> dict[str, str] | None: + service_account_json = self._get_field("service_account_json") + service_account_json_path = self._get_field("service_account_json_path") if service_account_json_path: with open(service_account_json_path) as infile: service_account_json = infile.read() if service_account_json: - service_account_key = json.loads(service_account_json) - return {"service_account_key": service_account_key} - else: + return json.loads(service_account_json) + return None + + def _get_service_account_id(self) -> str | None: + sa_key = self._get_service_account_key() + if sa_key: + return sa_key.get("service_account_id") + return None + + def _get_credentials(self) -> dict[str, Any]: + oauth_token = self._get_field("oauth") + if oauth_token: return {"token": oauth_token} + service_account_key = self._get_service_account_key() + if service_account_key: + return {"service_account_key": service_account_key} + + raise AirflowException( + "No credentials are found in connection. Specify either service account " + "authentication JSON or user OAuth token in Yandex.Cloud connection" + ) + def _get_endpoint(self) -> dict[str, str]: sdk_config = {} - endpoint = self._get_field("endpoint", None) + endpoint = self._get_field("endpoint") if endpoint: sdk_config["endpoint"] = endpoint return sdk_config diff --git a/airflow/providers/yandex/operators/yandexcloud_dataproc.py b/airflow/providers/yandex/operators/yandexcloud_dataproc.py index dfd3a07fe4f94..de4ea6e9c2bac 100644 --- a/airflow/providers/yandex/operators/yandexcloud_dataproc.py +++ b/airflow/providers/yandex/operators/yandexcloud_dataproc.py @@ -194,7 +194,7 @@ def execute(self, context: Context) -> dict: services=self.services, s3_bucket=self.s3_bucket, zone=self.zone, - service_account_id=self.service_account_id, + service_account_id=self.service_account_id or self.hook.default_service_account_id, masternode_resource_preset=self.masternode_resource_preset, masternode_disk_size=self.masternode_disk_size, masternode_disk_type=self.masternode_disk_type, diff --git a/airflow/providers/yandex/provider.yaml b/airflow/providers/yandex/provider.yaml index fcee1ecc61178..5217b6d2b831c 100644 --- a/airflow/providers/yandex/provider.yaml +++ b/airflow/providers/yandex/provider.yaml @@ -70,3 +70,15 @@ hooks: connection-types: - hook-class-name: airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook connection-type: yandexcloud + +config: + yandex: + description: This section contains settings for Yandex Cloud integration. + options: + sdk_user_agent_prefix: + description: | + Prefix for User-Agent header in Yandex.Cloud SDK requests + version_added: 3.6.0 + type: string + example: ~ + default: "" diff --git a/tests/providers/yandex/hooks/test_yandex.py b/tests/providers/yandex/hooks/test_yandex.py index f802ba8d870ad..23b460dabfaba 100644 --- a/tests/providers/yandex/hooks/test_yandex.py +++ b/tests/providers/yandex/hooks/test_yandex.py @@ -25,6 +25,7 @@ from airflow.exceptions import AirflowException from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook +from tests.test_utils.config import conf_vars class TestYandexHook: @@ -139,6 +140,17 @@ def test_get_endpoint_unspecified(self, get_credentials_mock, get_connection_moc assert hook._get_endpoint() == {} + @mock.patch("airflow.hooks.base.BaseHook.get_connection") + @mock.patch("airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook._get_credentials") + def test_sdk_user_agent(self, get_credentials_mock, get_connection_mock): + get_connection_mock.return_value = mock.Mock(connection_id="yandexcloud_default", extra_dejson="{}") + get_credentials_mock.return_value = {"token": 122323} + sdk_prefix = "MyAirflow" + + with conf_vars({("yandex", "sdk_user_agent_prefix"): sdk_prefix}): + hook = YandexCloudBaseHook() + assert hook.sdk._channels._client_user_agent.startswith(sdk_prefix) + @pytest.mark.parametrize( "uri", [