Skip to content

Commit

Permalink
support iam token from metadata, simplify code (#38411)
Browse files Browse the repository at this point in the history
  • Loading branch information
uzhastik committed Mar 22, 2024
1 parent fd5fe8d commit 30817a5
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 51 deletions.
54 changes: 17 additions & 37 deletions airflow/providers/yandex/hooks/yq.py
Expand Up @@ -16,16 +16,14 @@
# under the License.
from __future__ import annotations

import time
from datetime import timedelta
from typing import Any

import jwt
import requests
from urllib3.util.retry import Retry
import yandexcloud
import yandexcloud._auth_fabric as auth_fabric
from yandex.cloud.iam.v1.iam_token_service_pb2_grpc import IamTokenServiceStub
from yandex_query_client import YQHttpClient, YQHttpClientConfig

from airflow.exceptions import AirflowException
from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook
from airflow.providers.yandex.utils.user_agent import provider_user_agent

Expand Down Expand Up @@ -98,35 +96,17 @@ def compose_query_web_link(self, query_id: str):
return self.client.compose_query_web_link(query_id)

def _get_iam_token(self) -> str:
if "token" in self.credentials:
return self.credentials["token"]
if "service_account_key" in self.credentials:
return YQHook._resolve_service_account_key(self.credentials["service_account_key"])
raise AirflowException(f"Unknown credentials type, available keys {self.credentials.keys()}")

@staticmethod
def _resolve_service_account_key(sa_info: dict) -> str:
with YQHook._create_session() as session:
api = "https://iam.api.cloud.yandex.net/iam/v1/tokens"
now = int(time.time())
payload = {"aud": api, "iss": sa_info["service_account_id"], "iat": now, "exp": now + 360}

encoded_token = jwt.encode(
payload, sa_info["private_key"], algorithm="PS256", headers={"kid": sa_info["id"]}
)

data = {"jwt": encoded_token}
iam_response = session.post(api, json=data)
iam_response.raise_for_status()

return iam_response.json()["iamToken"]

@staticmethod
def _create_session() -> requests.Session:
session = requests.Session()
session.verify = False
retry = Retry(backoff_factor=0.3, total=10)
session.mount("http://", requests.adapters.HTTPAdapter(max_retries=retry))
session.mount("https://", requests.adapters.HTTPAdapter(max_retries=retry))

return session
iam_token = self.credentials.get("token")
if iam_token is not None:
return iam_token

service_account_key = self.credentials.get("service_account_key")
# if service_account_key is None metadata server will be used
token_requester = auth_fabric.get_auth_token_requester(service_account_key=service_account_key)

if service_account_key is None:
return token_requester.get_token()

sdk = yandexcloud.SDK()
client = sdk.client(IamTokenServiceStub)
return client.Create(token_requester.get_token_request()).iam_token
3 changes: 0 additions & 3 deletions airflow/providers/yandex/provider.yaml
Expand Up @@ -50,9 +50,6 @@ dependencies:
- apache-airflow>=2.6.0
- yandexcloud>=0.228.0
- yandex-query-client>=0.1.2
- python-dateutil>=2.8.0
# Requests 3 if it will be released, will be heavily breaking.
- requests>=2.27.0,<3

integrations:
- integration-name: Yandex.Cloud
Expand Down
2 changes: 0 additions & 2 deletions generated/provider_dependencies.json
Expand Up @@ -1180,8 +1180,6 @@
"yandex": {
"deps": [
"apache-airflow>=2.6.0",
"python-dateutil>=2.8.0",
"requests>=2.27.0,<3",
"yandex-query-client>=0.1.2",
"yandexcloud>=0.228.0"
],
Expand Down
2 changes: 0 additions & 2 deletions pyproject.toml
Expand Up @@ -975,8 +975,6 @@ weaviate = [ # source: airflow/providers/weaviate/provider.yaml
"weaviate-client>=3.24.2",
]
yandex = [ # source: airflow/providers/yandex/provider.yaml
"python-dateutil>=2.8.0",
"requests>=2.27.0,<3",
"yandex-query-client>=0.1.2",
"yandexcloud>=0.228.0",
]
Expand Down
49 changes: 42 additions & 7 deletions tests/providers/yandex/hooks/test_yq.py
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import json
from datetime import timedelta
from unittest import mock

Expand All @@ -26,6 +27,7 @@
from airflow.providers.yandex.hooks.yq import YQHook

OAUTH_TOKEN = "my_oauth_token"
IAM_TOKEN = "my_iam_token"
SERVICE_ACCOUNT_AUTH_KEY_JSON = """{"id":"my_id", "service_account_id":"my_sa1", "private_key":"my_pk"}"""


Expand All @@ -34,6 +36,18 @@ def __init__(self) -> None:
self.client = None


class DummyTokenRequester:
def get_token(self) -> str:
return IAM_TOKEN

def get_token_request(self) -> str:
return "my_dummy_request"


class DummyCreateTokenResponse:
iam_token = "zzz"


class TestYandexCloudYqHook:
def _init_hook(self):
with mock.patch("airflow.hooks.base.BaseHook.get_connection") as mock_get_connection:
Expand Down Expand Up @@ -68,18 +82,33 @@ def test_oauth_token_usage(self):
m.assert_called_once_with("query1")

@responses.activate()
@mock.patch("yandexcloud.SDK")
@mock.patch("jwt.encode")
def test_select_results(self, mock_jwt, mock_sdk):
@mock.patch("yandexcloud._auth_fabric.get_auth_token_requester", return_value=DummyTokenRequester())
def test_metadata_token_usage(self, mock_get_auth_token_requester):
responses.post(
"https://iam.api.cloud.yandex.net/iam/v1/tokens",
json={"iamToken": "super_token"},
"https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries",
match=[
matchers.header_matcher(
{"Content-Type": "application/json", "Authorization": f"Bearer {IAM_TOKEN}"}
),
matchers.query_param_matcher({"project": "my_folder_id"}),
],
json={"id": "query1"},
status=200,
)

mock_jwt.return_value = "zzzz"
mock_sdk.return_value = DummySDK()
self.connection = Connection(extra={})
self._init_hook()
query_id = self.hook.create_query(query_text="select 777", name="my query")
assert query_id == "query1"

@mock.patch(
"yandex.cloud.iam.v1.iam_token_service_pb2_grpc.IamTokenServiceStub.Create",
create=True,
new_callable=mock.PropertyMock,
)
@mock.patch("yandexcloud._auth_fabric.__validate_service_account_key")
@mock.patch("yandexcloud._auth_fabric.get_auth_token_requester", return_value=DummyTokenRequester())
def test_select_results(self, mock_get_auth_token_requester, mock_validate, mock_create_token):
with mock.patch.multiple(
"yandex_query_client.YQHttpClient",
create_query=mock.DEFAULT,
Expand All @@ -90,6 +119,12 @@ def test_select_results(self, mock_jwt, mock_sdk):
stop_query=mock.DEFAULT,
) as mocks:
self._init_hook()
mock_validate.assert_called()
mock_create_token.assert_called()
mock_get_auth_token_requester.assert_called_once_with(
service_account_key=json.loads(SERVICE_ACCOUNT_AUTH_KEY_JSON)
)

mocks["create_query"].return_value = "query1"
mocks["wait_query_to_succeed"].return_value = 2
mocks["get_query_all_result_sets"].return_value = {"x": 765}
Expand Down

0 comments on commit 30817a5

Please sign in to comment.