diff --git a/airflow/providers/hashicorp/_internal_client/vault_client.py b/airflow/providers/hashicorp/_internal_client/vault_client.py index 076a8696667f4..ea8aaf0071230 100644 --- a/airflow/providers/hashicorp/_internal_client/vault_client.py +++ b/airflow/providers/hashicorp/_internal_client/vault_client.py @@ -89,7 +89,7 @@ def __init__( url: str | None = None, auth_type: str = "token", auth_mount_point: str | None = None, - mount_point: str = "secret", + mount_point: str | None = "secret", kv_engine_version: int | None = None, token: str | None = None, token_path: str | None = None, @@ -324,6 +324,15 @@ def _set_token(self, _client: hvac.Client) -> None: else: _client.token = self.token + def _parse_secret_path(self, secret_path: str) -> tuple[str, str]: + if not self.mount_point: + split_secret_path = secret_path.split("/", 1) + if len(split_secret_path) < 2: + raise InvalidPath + return split_secret_path[0], split_secret_path[1] + else: + return self.mount_point, secret_path + def get_secret(self, secret_path: str, secret_version: int | None = None) -> dict | None: """ Get secret value from the KV engine. @@ -337,19 +346,19 @@ def get_secret(self, secret_path: str, secret_version: int | None = None) -> dic :return: secret stored in the vault as a dictionary """ + mount_point = None try: + mount_point, secret_path = self._parse_secret_path(secret_path) if self.kv_engine_version == 1: if secret_version: raise VaultError("Secret version can only be used with version 2 of the KV engine") - response = self.client.secrets.kv.v1.read_secret( - path=secret_path, mount_point=self.mount_point - ) + response = self.client.secrets.kv.v1.read_secret(path=secret_path, mount_point=mount_point) else: response = self.client.secrets.kv.v2.read_secret_version( - path=secret_path, mount_point=self.mount_point, version=secret_version + path=secret_path, mount_point=mount_point, version=secret_version ) except InvalidPath: - self.log.debug("Secret not found %s with mount point %s", secret_path, self.mount_point) + self.log.debug("Secret not found %s with mount point %s", secret_path, mount_point) return None return_data = response["data"] if self.kv_engine_version == 1 else response["data"]["data"] @@ -367,12 +376,12 @@ def get_secret_metadata(self, secret_path: str) -> dict | None: """ if self.kv_engine_version == 1: raise VaultError("Metadata might only be used with version 2 of the KV engine.") + mount_point = None try: - return self.client.secrets.kv.v2.read_secret_metadata( - path=secret_path, mount_point=self.mount_point - ) + mount_point, secret_path = self._parse_secret_path(secret_path) + return self.client.secrets.kv.v2.read_secret_metadata(path=secret_path, mount_point=mount_point) except InvalidPath: - self.log.debug("Secret not found %s with mount point %s", secret_path, self.mount_point) + self.log.debug("Secret not found %s with mount point %s", secret_path, mount_point) return None def get_secret_including_metadata( @@ -391,15 +400,17 @@ def get_secret_including_metadata( """ if self.kv_engine_version == 1: raise VaultError("Metadata might only be used with version 2 of the KV engine.") + mount_point = None try: + mount_point, secret_path = self._parse_secret_path(secret_path) return self.client.secrets.kv.v2.read_secret_version( - path=secret_path, mount_point=self.mount_point, version=secret_version + path=secret_path, mount_point=mount_point, version=secret_version ) except InvalidPath: self.log.debug( "Secret not found %s with mount point %s and version %s", secret_path, - self.mount_point, + mount_point, secret_version, ) return None @@ -429,12 +440,13 @@ def create_or_update_secret( raise VaultError("The method parameter is only valid for version 1") if self.kv_engine_version == 1 and cas: raise VaultError("The cas parameter is only valid for version 2") + mount_point, secret_path = self._parse_secret_path(secret_path) if self.kv_engine_version == 1: response = self.client.secrets.kv.v1.create_or_update_secret( - secret_path=secret_path, secret=secret, mount_point=self.mount_point, method=method + secret_path=secret_path, secret=secret, mount_point=mount_point, method=method ) else: response = self.client.secrets.kv.v2.create_or_update_secret( - secret_path=secret_path, secret=secret, mount_point=self.mount_point, cas=cas + secret_path=secret_path, secret=secret, mount_point=mount_point, cas=cas ) return response diff --git a/airflow/providers/hashicorp/secrets/vault.py b/airflow/providers/hashicorp/secrets/vault.py index 9c22ff71d66b8..79943aacd7e8b 100644 --- a/airflow/providers/hashicorp/secrets/vault.py +++ b/airflow/providers/hashicorp/secrets/vault.py @@ -59,7 +59,8 @@ class VaultBackend(BaseSecretsBackend, LoggingMixin): Default depends on the authentication method used. :param mount_point: The "path" the secret engine was mounted on. Default is "secret". Note that this mount_point is not used for authentication if authentication is done via a - different engine. For authentication mount_points see, auth_mount_point. + different engine. If set to None, the mount secret should be provided as a prefix for each + variable/connection_id. For authentication mount_points see, auth_mount_point. :param kv_engine_version: Select the version of the engine to run (``1`` or ``2``, default: ``2``). :param token: Authentication token to include in requests sent to Vault. (for ``token`` and ``github`` auth_type) @@ -94,7 +95,7 @@ def __init__( url: str | None = None, auth_type: str = "token", auth_mount_point: str | None = None, - mount_point: str = "secret", + mount_point: str | None = "secret", kv_engine_version: int = 2, token: str | None = None, token_path: str | None = None, @@ -156,17 +157,29 @@ def __init__( **kwargs, ) + def _parse_path(self, secret_path: str) -> tuple[str | None, str | None]: + if not self.mount_point: + split_secret_path = secret_path.split("/", 1) + if len(split_secret_path) < 2: + return None, None + return split_secret_path[0], split_secret_path[1] + else: + return "", secret_path + def get_response(self, conn_id: str) -> dict | None: """ Get data from Vault :return: The data from the Vault path if exists """ - if self.connections_path is None: + mount_point, conn_key = self._parse_path(conn_id) + if self.connections_path is None or conn_key is None: return None - secret_path = self.build_path(self.connections_path, conn_id) - return self.vault_client.get_secret(secret_path=secret_path) + secret_path = self.build_path(self.connections_path, conn_key) + return self.vault_client.get_secret( + secret_path=(mount_point + "/" if mount_point else "") + secret_path + ) def get_conn_uri(self, conn_id: str) -> str | None: """ @@ -219,11 +232,14 @@ def get_variable(self, key: str) -> str | None: :param key: Variable Key :return: Variable Value retrieved from the vault """ - if self.variables_path is None: + mount_point, variable_key = self._parse_path(key) + if self.variables_path is None or variable_key is None: return None else: - secret_path = self.build_path(self.variables_path, key) - response = self.vault_client.get_secret(secret_path=secret_path) + secret_path = self.build_path(self.variables_path, variable_key) + response = self.vault_client.get_secret( + secret_path=(mount_point + "/" if mount_point else "") + secret_path + ) return response.get("value") if response else None def get_config(self, key: str) -> str | None: @@ -233,9 +249,12 @@ def get_config(self, key: str) -> str | None: :param key: Configuration Option Key :return: Configuration Option Value retrieved from the vault """ - if self.config_path is None: + mount_point, config_key = self._parse_path(key) + if self.config_path is None or config_key is None: return None else: - secret_path = self.build_path(self.config_path, key) - response = self.vault_client.get_secret(secret_path=secret_path) + secret_path = self.build_path(self.config_path, config_key) + response = self.vault_client.get_secret( + secret_path=(mount_point + "/" if mount_point else "") + secret_path + ) return response.get("value") if response else None diff --git a/tests/providers/hashicorp/_internal_client/test_vault_client.py b/tests/providers/hashicorp/_internal_client/test_vault_client.py index 1bab652dc01d9..6fd8dcf6a76b6 100644 --- a/tests/providers/hashicorp/_internal_client/test_vault_client.py +++ b/tests/providers/hashicorp/_internal_client/test_vault_client.py @@ -661,10 +661,48 @@ def test_get_existing_key_v2(self, mock_hvac): radius_secret="pass", url="http://localhost:8180", ) - secret = vault_client.get_secret(secret_path="missing") + secret = vault_client.get_secret(secret_path="path/to/secret") assert {"secret_key": "secret_value"} == secret mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with( - mount_point="secret", path="missing", version=None + mount_point="secret", path="path/to/secret", version=None + ) + + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") + def test_get_existing_key_v2_without_preconfigured_mount_point(self, mock_hvac): + mock_client = mock.MagicMock() + mock_hvac.Client.return_value = mock_client + + mock_client.secrets.kv.v2.read_secret_version.return_value = { + "request_id": "94011e25-f8dc-ec29-221b-1f9c1d9ad2ae", + "lease_id": "", + "renewable": False, + "lease_duration": 0, + "data": { + "data": {"secret_key": "secret_value"}, + "metadata": { + "created_time": "2020-03-16T21:01:43.331126Z", + "deletion_time": "", + "destroyed": False, + "version": 1, + }, + }, + "wrap_info": None, + "warnings": None, + "auth": None, + } + + vault_client = _VaultClient( + auth_type="radius", + radius_host="radhost", + radius_port=8110, + radius_secret="pass", + url="http://localhost:8180", + mount_point=None, + ) + secret = vault_client.get_secret(secret_path="mount_point/path/to/secret") + assert {"secret_key": "secret_value"} == secret + mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with( + mount_point="mount_point", path="path/to/secret", version=None ) @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") @@ -728,9 +766,42 @@ def test_get_existing_key_v1(self, mock_hvac): kv_engine_version=1, url="http://localhost:8180", ) - secret = vault_client.get_secret(secret_path="missing") + secret = vault_client.get_secret(secret_path="/path/to/secret") assert {"value": "world"} == secret - mock_client.secrets.kv.v1.read_secret.assert_called_once_with(mount_point="secret", path="missing") + mock_client.secrets.kv.v1.read_secret.assert_called_once_with( + mount_point="secret", path="/path/to/secret" + ) + + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") + def test_get_existing_key_v1_without_preconfigured_mount_point(self, mock_hvac): + mock_client = mock.MagicMock() + mock_hvac.Client.return_value = mock_client + + mock_client.secrets.kv.v1.read_secret.return_value = { + "request_id": "182d0673-618c-9889-4cba-4e1f4cfe4b4b", + "lease_id": "", + "renewable": False, + "lease_duration": 2764800, + "data": {"value": "world"}, + "wrap_info": None, + "warnings": None, + "auth": None, + } + + vault_client = _VaultClient( + auth_type="radius", + radius_host="radhost", + radius_port=8110, + radius_secret="pass", + kv_engine_version=1, + url="http://localhost:8180", + mount_point=None, + ) + secret = vault_client.get_secret(secret_path="mount_point/path/to/secret") + assert {"value": "world"} == secret + mock_client.secrets.kv.v1.read_secret.assert_called_once_with( + mount_point="mount_point", path="path/to/secret" + ) @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_get_existing_key_v1_different_auth_mount_point(self, mock_hvac): diff --git a/tests/providers/hashicorp/secrets/test_vault.py b/tests/providers/hashicorp/secrets/test_vault.py index a29e6dc21e8c3..309dbd9d6a663 100644 --- a/tests/providers/hashicorp/secrets/test_vault.py +++ b/tests/providers/hashicorp/secrets/test_vault.py @@ -60,6 +60,41 @@ def test_get_conn_uri(self, mock_hvac): returned_uri = test_client.get_conn_uri(conn_id="test_postgres") assert "postgresql://airflow:airflow@host:5432/airflow" == returned_uri + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") + def test_get_conn_uri_without_predefined_mount_point(self, mock_hvac): + mock_client = mock.MagicMock() + mock_hvac.Client.return_value = mock_client + mock_client.secrets.kv.v2.read_secret_version.return_value = { + "request_id": "94011e25-f8dc-ec29-221b-1f9c1d9ad2ae", + "lease_id": "", + "renewable": False, + "lease_duration": 0, + "data": { + "data": {"conn_uri": "postgresql://airflow:airflow@host:5432/airflow"}, + "metadata": { + "created_time": "2020-03-16T21:01:43.331126Z", + "deletion_time": "", + "destroyed": False, + "version": 1, + }, + }, + "wrap_info": None, + "warnings": None, + "auth": None, + } + + kwargs = { + "connections_path": "connections", + "mount_point": None, + "auth_type": "token", + "url": "http://127.0.0.1:8200", + "token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS", + } + + test_client = VaultBackend(**kwargs) + returned_uri = test_client.get_conn_uri(conn_id="airflow/test_postgres") + assert "postgresql://airflow:airflow@host:5432/airflow" == returned_uri + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_get_connection(self, mock_hvac): mock_client = mock.MagicMock() @@ -103,6 +138,49 @@ def test_get_connection(self, mock_hvac): connection = test_client.get_connection(conn_id="test_postgres") assert "postgresql://airflow:airflow@host:5432/airflow?foo=bar&baz=taz" == connection.get_uri() + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") + def test_get_connection_without_predefined_mount_point(self, mock_hvac): + mock_client = mock.MagicMock() + mock_hvac.Client.return_value = mock_client + mock_client.secrets.kv.v2.read_secret_version.return_value = { + "request_id": "94011e25-f8dc-ec29-221b-1f9c1d9ad2ae", + "lease_id": "", + "renewable": False, + "lease_duration": 0, + "data": { + "data": { + "conn_type": "postgresql", + "login": "airflow", + "password": "airflow", + "host": "host", + "port": "5432", + "schema": "airflow", + "extra": '{"foo":"bar","baz":"taz"}', + }, + "metadata": { + "created_time": "2020-03-16T21:01:43.331126Z", + "deletion_time": "", + "destroyed": False, + "version": 1, + }, + }, + "wrap_info": None, + "warnings": None, + "auth": None, + } + + kwargs = { + "connections_path": "connections", + "mount_point": None, + "auth_type": "token", + "url": "http://127.0.0.1:8200", + "token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS", + } + + test_client = VaultBackend(**kwargs) + connection = test_client.get_connection(conn_id="airflow/test_postgres") + assert "postgresql://airflow:airflow@host:5432/airflow?foo=bar&baz=taz" == connection.get_uri() + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_get_conn_uri_engine_version_1(self, mock_hvac): mock_client = mock.MagicMock() @@ -234,6 +312,41 @@ def test_get_variable_value(self, mock_hvac): returned_uri = test_client.get_variable("hello") assert "world" == returned_uri + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") + def test_get_variable_value_without_predefined_mount_point(self, mock_hvac): + mock_client = mock.MagicMock() + mock_hvac.Client.return_value = mock_client + mock_client.secrets.kv.v2.read_secret_version.return_value = { + "request_id": "2d48a2ad-6bcb-e5b6-429d-da35fdf31f56", + "lease_id": "", + "renewable": False, + "lease_duration": 0, + "data": { + "data": {"value": "world"}, + "metadata": { + "created_time": "2020-03-28T02:10:54.301784Z", + "deletion_time": "", + "destroyed": False, + "version": 1, + }, + }, + "wrap_info": None, + "warnings": None, + "auth": None, + } + + kwargs = { + "variables_path": "variables", + "mount_point": None, + "auth_type": "token", + "url": "http://127.0.0.1:8200", + "token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS", + } + + test_client = VaultBackend(**kwargs) + returned_uri = test_client.get_variable("airflow/hello") + assert "world" == returned_uri + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_get_variable_value_engine_version_1(self, mock_hvac): mock_client = mock.MagicMock() @@ -265,6 +378,37 @@ def test_get_variable_value_engine_version_1(self, mock_hvac): ) assert "world" == returned_uri + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") + def test_get_variable_value_engine_version_1_without_predefined_mount_point(self, mock_hvac): + mock_client = mock.MagicMock() + mock_hvac.Client.return_value = mock_client + mock_client.secrets.kv.v1.read_secret.return_value = { + "request_id": "182d0673-618c-9889-4cba-4e1f4cfe4b4b", + "lease_id": "", + "renewable": False, + "lease_duration": 2764800, + "data": {"value": "world"}, + "wrap_info": None, + "warnings": None, + "auth": None, + } + + kwargs = { + "variables_path": "variables", + "mount_point": None, + "auth_type": "token", + "url": "http://127.0.0.1:8200", + "token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS", + "kv_engine_version": 1, + } + + test_client = VaultBackend(**kwargs) + returned_uri = test_client.get_variable("airflow/hello") + mock_client.secrets.kv.v1.read_secret.assert_called_once_with( + mount_point="airflow", path="variables/hello" + ) + assert "world" == returned_uri + @mock.patch.dict( "os.environ", { @@ -361,6 +505,41 @@ def test_get_config_value(self, mock_hvac): returned_uri = test_client.get_config("sql_alchemy_conn") assert "sqlite:////Users/airflow/airflow/airflow.db" == returned_uri + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") + def test_get_config_value_without_predefined_mount_point(self, mock_hvac): + mock_client = mock.MagicMock() + mock_hvac.Client.return_value = mock_client + mock_client.secrets.kv.v2.read_secret_version.return_value = { + "request_id": "2d48a2ad-6bcb-e5b6-429d-da35fdf31f56", + "lease_id": "", + "renewable": False, + "lease_duration": 0, + "data": { + "data": {"value": "sqlite:////Users/airflow/airflow/airflow.db"}, + "metadata": { + "created_time": "2020-03-28T02:10:54.301784Z", + "deletion_time": "", + "destroyed": False, + "version": 1, + }, + }, + "wrap_info": None, + "warnings": None, + "auth": None, + } + + kwargs = { + "configs_path": "configurations", + "mount_point": None, + "auth_type": "token", + "url": "http://127.0.0.1:8200", + "token": "s.FnL7qg0YnHZDpf4zKKuFy0UK", + } + + test_client = VaultBackend(**kwargs) + returned_uri = test_client.get_config("airflow/sql_alchemy_conn") + assert "sqlite:////Users/airflow/airflow/airflow.db" == returned_uri + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_connections_path_none_value(self, mock_hvac): mock_client = mock.MagicMock()