Skip to content

Commit

Permalink
Pass kwargs from vault hook to hvac client (#26680)
Browse files Browse the repository at this point in the history
Co-authored-by: phil <1bluek1te@gmail.com>
  • Loading branch information
bluek1te and phil committed Nov 2, 2022
1 parent 9c6d2ab commit 1a3f785
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 23 deletions.
61 changes: 38 additions & 23 deletions airflow/providers/hashicorp/hooks/vault.py
Expand Up @@ -30,6 +30,7 @@
DEFAULT_KV_ENGINE_VERSION,
_VaultClient,
)
from airflow.utils.helpers import merge_dicts


class VaultHook(BaseHook):
Expand Down Expand Up @@ -118,6 +119,7 @@ def __init__(
azure_resource: str | None = None,
radius_host: str | None = None,
radius_port: int | None = None,
**kwargs,
):
super().__init__()
self.connection = self.get_connection(vault_conn_id)
Expand All @@ -135,6 +137,11 @@ def __init__(
except ValueError:
raise VaultError(f"The version is not an int: {conn_version}. ")

client_kwargs = self.connection.extra_dejson.get("client_kwargs", {})

if kwargs:
client_kwargs = merge_dicts(client_kwargs, kwargs)

if auth_type == "approle":
if role_id:
warnings.warn(
Expand Down Expand Up @@ -179,6 +186,10 @@ def __init__(
else (None, None)
)

key_id = self.connection.extra_dejson.get("key_id")
if not key_id:
key_id = self.connection.login

if self.connection.conn_type == "vault":
conn_protocol = "http"
elif self.connection.conn_type == "vaults":
Expand All @@ -197,31 +208,35 @@ def __init__(
# Schema is really path in the Connection definition. This is pretty confusing because of URL schema
mount_point = self.connection.schema if self.connection.schema else "secret"

self.vault_client = _VaultClient(
url=url,
auth_type=auth_type,
auth_mount_point=auth_mount_point,
mount_point=mount_point,
kv_engine_version=kv_engine_version,
token=self.connection.password,
token_path=token_path,
username=self.connection.login,
password=self.connection.password,
key_id=self.connection.login,
secret_id=self.connection.password,
role_id=role_id,
kubernetes_role=kubernetes_role,
kubernetes_jwt_path=kubernetes_jwt_path,
gcp_key_path=gcp_key_path,
gcp_keyfile_dict=gcp_keyfile_dict,
gcp_scopes=gcp_scopes,
azure_tenant_id=azure_tenant_id,
azure_resource=azure_resource,
radius_host=radius_host,
radius_secret=self.connection.password,
radius_port=radius_port,
client_kwargs.update(
**dict(
url=url,
auth_type=auth_type,
auth_mount_point=auth_mount_point,
mount_point=mount_point,
kv_engine_version=kv_engine_version,
token=self.connection.password,
token_path=token_path,
username=self.connection.login,
password=self.connection.password,
key_id=self.connection.login,
secret_id=self.connection.password,
role_id=role_id,
kubernetes_role=kubernetes_role,
kubernetes_jwt_path=kubernetes_jwt_path,
gcp_key_path=gcp_key_path,
gcp_keyfile_dict=gcp_keyfile_dict,
gcp_scopes=gcp_scopes,
azure_tenant_id=azure_tenant_id,
azure_resource=azure_resource,
radius_host=radius_host,
radius_secret=self.connection.password,
radius_port=radius_port,
)
)

self.vault_client = _VaultClient(**client_kwargs)

def _get_kubernetes_parameters_from_connection(
self, kubernetes_jwt_path: str | None, kubernetes_role: str | None
) -> tuple[str, str | None]:
Expand Down
27 changes: 27 additions & 0 deletions tests/providers/hashicorp/hooks/test_vault.py
Expand Up @@ -626,6 +626,33 @@ def test_kubernetes_dejson(self, mock_kubernetes, mock_hvac, mock_get_connection
test_client.is_authenticated.assert_called_with()
assert 2 == test_hook.vault_client.kv_engine_version

@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_client_kwargs(self, mock_hvac, mock_get_connection):
"""This test checks that values in connection extras keyed with 'client_kwargs' will be
consumed by the underlying Hashicorp Vault client init. The order of precedence should
be kwargs (passed through the hook init) > client_kwargs (found in connection extras).
"""
mock_client = mock.MagicMock()
mock_hvac.Client.return_value = mock_client
mock_connection = self.get_mock_connection()
mock_get_connection.return_value = mock_connection

connection_dict = {
"client_kwargs": {"namespace": "name", "timeout": 50, "generic_arg": "generic_val1"}
}

mock_connection.extra_dejson.get.side_effect = connection_dict.get
kwargs = {"vault_conn_id": "vault_conn_id", "generic_arg": "generic_val0"}
test_hook = VaultHook(**kwargs)
test_client = test_hook.get_conn()
mock_get_connection.assert_called_with("vault_conn_id")
mock_hvac.Client.assert_called_with(
url="http://localhost:8180", namespace="name", timeout=50, generic_arg="generic_val0"
)
test_client.is_authenticated.assert_called_with()
assert 2 == test_hook.vault_client.kv_engine_version

@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_ldap_init_params(self, mock_hvac, mock_get_connection):
Expand Down

0 comments on commit 1a3f785

Please sign in to comment.