Skip to content

Commit

Permalink
Support multiple mount points in Vault backend secret (#29734)
Browse files Browse the repository at this point in the history
  • Loading branch information
hussein-awala committed Feb 24, 2023
1 parent d078374 commit dff425b
Show file tree
Hide file tree
Showing 4 changed files with 310 additions and 29 deletions.
40 changes: 26 additions & 14 deletions airflow/providers/hashicorp/_internal_client/vault_client.py
Expand Up @@ -89,7 +89,7 @@ def __init__(
url: str | None = None,
auth_type: str = "token",
auth_mount_point: str | None = None,
mount_point: str = "secret",
mount_point: str | None = "secret",
kv_engine_version: int | None = None,
token: str | None = None,
token_path: str | None = None,
Expand Down Expand Up @@ -324,6 +324,15 @@ def _set_token(self, _client: hvac.Client) -> None:
else:
_client.token = self.token

def _parse_secret_path(self, secret_path: str) -> tuple[str, str]:
if not self.mount_point:
split_secret_path = secret_path.split("/", 1)
if len(split_secret_path) < 2:
raise InvalidPath
return split_secret_path[0], split_secret_path[1]
else:
return self.mount_point, secret_path

def get_secret(self, secret_path: str, secret_version: int | None = None) -> dict | None:
"""
Get secret value from the KV engine.
Expand All @@ -337,19 +346,19 @@ def get_secret(self, secret_path: str, secret_version: int | None = None) -> dic
:return: secret stored in the vault as a dictionary
"""
mount_point = None
try:
mount_point, secret_path = self._parse_secret_path(secret_path)
if self.kv_engine_version == 1:
if secret_version:
raise VaultError("Secret version can only be used with version 2 of the KV engine")
response = self.client.secrets.kv.v1.read_secret(
path=secret_path, mount_point=self.mount_point
)
response = self.client.secrets.kv.v1.read_secret(path=secret_path, mount_point=mount_point)
else:
response = self.client.secrets.kv.v2.read_secret_version(
path=secret_path, mount_point=self.mount_point, version=secret_version
path=secret_path, mount_point=mount_point, version=secret_version
)
except InvalidPath:
self.log.debug("Secret not found %s with mount point %s", secret_path, self.mount_point)
self.log.debug("Secret not found %s with mount point %s", secret_path, mount_point)
return None

return_data = response["data"] if self.kv_engine_version == 1 else response["data"]["data"]
Expand All @@ -367,12 +376,12 @@ def get_secret_metadata(self, secret_path: str) -> dict | None:
"""
if self.kv_engine_version == 1:
raise VaultError("Metadata might only be used with version 2 of the KV engine.")
mount_point = None
try:
return self.client.secrets.kv.v2.read_secret_metadata(
path=secret_path, mount_point=self.mount_point
)
mount_point, secret_path = self._parse_secret_path(secret_path)
return self.client.secrets.kv.v2.read_secret_metadata(path=secret_path, mount_point=mount_point)
except InvalidPath:
self.log.debug("Secret not found %s with mount point %s", secret_path, self.mount_point)
self.log.debug("Secret not found %s with mount point %s", secret_path, mount_point)
return None

def get_secret_including_metadata(
Expand All @@ -391,15 +400,17 @@ def get_secret_including_metadata(
"""
if self.kv_engine_version == 1:
raise VaultError("Metadata might only be used with version 2 of the KV engine.")
mount_point = None
try:
mount_point, secret_path = self._parse_secret_path(secret_path)
return self.client.secrets.kv.v2.read_secret_version(
path=secret_path, mount_point=self.mount_point, version=secret_version
path=secret_path, mount_point=mount_point, version=secret_version
)
except InvalidPath:
self.log.debug(
"Secret not found %s with mount point %s and version %s",
secret_path,
self.mount_point,
mount_point,
secret_version,
)
return None
Expand Down Expand Up @@ -429,12 +440,13 @@ def create_or_update_secret(
raise VaultError("The method parameter is only valid for version 1")
if self.kv_engine_version == 1 and cas:
raise VaultError("The cas parameter is only valid for version 2")
mount_point, secret_path = self._parse_secret_path(secret_path)
if self.kv_engine_version == 1:
response = self.client.secrets.kv.v1.create_or_update_secret(
secret_path=secret_path, secret=secret, mount_point=self.mount_point, method=method
secret_path=secret_path, secret=secret, mount_point=mount_point, method=method
)
else:
response = self.client.secrets.kv.v2.create_or_update_secret(
secret_path=secret_path, secret=secret, mount_point=self.mount_point, cas=cas
secret_path=secret_path, secret=secret, mount_point=mount_point, cas=cas
)
return response
41 changes: 30 additions & 11 deletions airflow/providers/hashicorp/secrets/vault.py
Expand Up @@ -59,7 +59,8 @@ class VaultBackend(BaseSecretsBackend, LoggingMixin):
Default depends on the authentication method used.
:param mount_point: The "path" the secret engine was mounted on. Default is "secret". Note that
this mount_point is not used for authentication if authentication is done via a
different engine. For authentication mount_points see, auth_mount_point.
different engine. If set to None, the mount secret should be provided as a prefix for each
variable/connection_id. For authentication mount_points see, auth_mount_point.
:param kv_engine_version: Select the version of the engine to run (``1`` or ``2``, default: ``2``).
:param token: Authentication token to include in requests sent to Vault.
(for ``token`` and ``github`` auth_type)
Expand Down Expand Up @@ -94,7 +95,7 @@ def __init__(
url: str | None = None,
auth_type: str = "token",
auth_mount_point: str | None = None,
mount_point: str = "secret",
mount_point: str | None = "secret",
kv_engine_version: int = 2,
token: str | None = None,
token_path: str | None = None,
Expand Down Expand Up @@ -156,17 +157,29 @@ def __init__(
**kwargs,
)

def _parse_path(self, secret_path: str) -> tuple[str | None, str | None]:
if not self.mount_point:
split_secret_path = secret_path.split("/", 1)
if len(split_secret_path) < 2:
return None, None
return split_secret_path[0], split_secret_path[1]
else:
return "", secret_path

def get_response(self, conn_id: str) -> dict | None:
"""
Get data from Vault
:return: The data from the Vault path if exists
"""
if self.connections_path is None:
mount_point, conn_key = self._parse_path(conn_id)
if self.connections_path is None or conn_key is None:
return None

secret_path = self.build_path(self.connections_path, conn_id)
return self.vault_client.get_secret(secret_path=secret_path)
secret_path = self.build_path(self.connections_path, conn_key)
return self.vault_client.get_secret(
secret_path=(mount_point + "/" if mount_point else "") + secret_path
)

def get_conn_uri(self, conn_id: str) -> str | None:
"""
Expand Down Expand Up @@ -219,11 +232,14 @@ def get_variable(self, key: str) -> str | None:
:param key: Variable Key
:return: Variable Value retrieved from the vault
"""
if self.variables_path is None:
mount_point, variable_key = self._parse_path(key)
if self.variables_path is None or variable_key is None:
return None
else:
secret_path = self.build_path(self.variables_path, key)
response = self.vault_client.get_secret(secret_path=secret_path)
secret_path = self.build_path(self.variables_path, variable_key)
response = self.vault_client.get_secret(
secret_path=(mount_point + "/" if mount_point else "") + secret_path
)
return response.get("value") if response else None

def get_config(self, key: str) -> str | None:
Expand All @@ -233,9 +249,12 @@ def get_config(self, key: str) -> str | None:
:param key: Configuration Option Key
:return: Configuration Option Value retrieved from the vault
"""
if self.config_path is None:
mount_point, config_key = self._parse_path(key)
if self.config_path is None or config_key is None:
return None
else:
secret_path = self.build_path(self.config_path, key)
response = self.vault_client.get_secret(secret_path=secret_path)
secret_path = self.build_path(self.config_path, config_key)
response = self.vault_client.get_secret(
secret_path=(mount_point + "/" if mount_point else "") + secret_path
)
return response.get("value") if response else None
79 changes: 75 additions & 4 deletions tests/providers/hashicorp/_internal_client/test_vault_client.py
Expand Up @@ -661,10 +661,48 @@ def test_get_existing_key_v2(self, mock_hvac):
radius_secret="pass",
url="http://localhost:8180",
)
secret = vault_client.get_secret(secret_path="missing")
secret = vault_client.get_secret(secret_path="path/to/secret")
assert {"secret_key": "secret_value"} == secret
mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with(
mount_point="secret", path="missing", version=None
mount_point="secret", path="path/to/secret", version=None
)

@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_get_existing_key_v2_without_preconfigured_mount_point(self, mock_hvac):
mock_client = mock.MagicMock()
mock_hvac.Client.return_value = mock_client

mock_client.secrets.kv.v2.read_secret_version.return_value = {
"request_id": "94011e25-f8dc-ec29-221b-1f9c1d9ad2ae",
"lease_id": "",
"renewable": False,
"lease_duration": 0,
"data": {
"data": {"secret_key": "secret_value"},
"metadata": {
"created_time": "2020-03-16T21:01:43.331126Z",
"deletion_time": "",
"destroyed": False,
"version": 1,
},
},
"wrap_info": None,
"warnings": None,
"auth": None,
}

vault_client = _VaultClient(
auth_type="radius",
radius_host="radhost",
radius_port=8110,
radius_secret="pass",
url="http://localhost:8180",
mount_point=None,
)
secret = vault_client.get_secret(secret_path="mount_point/path/to/secret")
assert {"secret_key": "secret_value"} == secret
mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with(
mount_point="mount_point", path="path/to/secret", version=None
)

@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
Expand Down Expand Up @@ -728,9 +766,42 @@ def test_get_existing_key_v1(self, mock_hvac):
kv_engine_version=1,
url="http://localhost:8180",
)
secret = vault_client.get_secret(secret_path="missing")
secret = vault_client.get_secret(secret_path="/path/to/secret")
assert {"value": "world"} == secret
mock_client.secrets.kv.v1.read_secret.assert_called_once_with(mount_point="secret", path="missing")
mock_client.secrets.kv.v1.read_secret.assert_called_once_with(
mount_point="secret", path="/path/to/secret"
)

@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_get_existing_key_v1_without_preconfigured_mount_point(self, mock_hvac):
mock_client = mock.MagicMock()
mock_hvac.Client.return_value = mock_client

mock_client.secrets.kv.v1.read_secret.return_value = {
"request_id": "182d0673-618c-9889-4cba-4e1f4cfe4b4b",
"lease_id": "",
"renewable": False,
"lease_duration": 2764800,
"data": {"value": "world"},
"wrap_info": None,
"warnings": None,
"auth": None,
}

vault_client = _VaultClient(
auth_type="radius",
radius_host="radhost",
radius_port=8110,
radius_secret="pass",
kv_engine_version=1,
url="http://localhost:8180",
mount_point=None,
)
secret = vault_client.get_secret(secret_path="mount_point/path/to/secret")
assert {"value": "world"} == secret
mock_client.secrets.kv.v1.read_secret.assert_called_once_with(
mount_point="mount_point", path="path/to/secret"
)

@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_get_existing_key_v1_different_auth_mount_point(self, mock_hvac):
Expand Down

0 comments on commit dff425b

Please sign in to comment.