diff --git a/airflow/providers/hashicorp/secrets/vault.py b/airflow/providers/hashicorp/secrets/vault.py index 79943aacd7e8b..bf8b78c00a1b2 100644 --- a/airflow/providers/hashicorp/secrets/vault.py +++ b/airflow/providers/hashicorp/secrets/vault.py @@ -175,8 +175,10 @@ def get_response(self, conn_id: str) -> dict | None: mount_point, conn_key = self._parse_path(conn_id) if self.connections_path is None or conn_key is None: return None - - secret_path = self.build_path(self.connections_path, conn_key) + if self.connections_path == "": + secret_path = conn_key + else: + secret_path = self.build_path(self.connections_path, conn_key) return self.vault_client.get_secret( secret_path=(mount_point + "/" if mount_point else "") + secret_path ) @@ -235,12 +237,14 @@ def get_variable(self, key: str) -> str | None: mount_point, variable_key = self._parse_path(key) if self.variables_path is None or variable_key is None: return None + if self.variables_path == "": + secret_path = variable_key else: secret_path = self.build_path(self.variables_path, variable_key) - response = self.vault_client.get_secret( - secret_path=(mount_point + "/" if mount_point else "") + secret_path - ) - return response.get("value") if response else None + response = self.vault_client.get_secret( + secret_path=(mount_point + "/" if mount_point else "") + secret_path + ) + return response.get("value") if response else None def get_config(self, key: str) -> str | None: """ @@ -252,9 +256,11 @@ def get_config(self, key: str) -> str | None: mount_point, config_key = self._parse_path(key) if self.config_path is None or config_key is None: return None + if self.config_path == "": + secret_path = config_key else: secret_path = self.build_path(self.config_path, config_key) - response = self.vault_client.get_secret( - secret_path=(mount_point + "/" if mount_point else "") + secret_path - ) - return response.get("value") if response else None + response = self.vault_client.get_secret( + secret_path=(mount_point + "/" if mount_point else "") + secret_path + ) + return response.get("value") if response else None diff --git a/tests/providers/hashicorp/secrets/test_vault.py b/tests/providers/hashicorp/secrets/test_vault.py index 309dbd9d6a663..4897a73c22334 100644 --- a/tests/providers/hashicorp/secrets/test_vault.py +++ b/tests/providers/hashicorp/secrets/test_vault.py @@ -181,8 +181,39 @@ def test_get_connection_without_predefined_mount_point(self, mock_hvac): connection = test_client.get_connection(conn_id="airflow/test_postgres") assert "postgresql://airflow:airflow@host:5432/airflow?foo=bar&baz=taz" == connection.get_uri() + @pytest.mark.parametrize( + "mount_point, connections_path, conn_id, expected_args", + [ + ( + "airflow", + "connections", + "test_postgres", + {"mount_point": "airflow", "path": "connections/test_postgres"}, + ), + ( + "airflow", + "", + "path/to/connections/test_postgres", + {"mount_point": "airflow", "path": "path/to/connections/test_postgres"}, + ), + ( + None, + "connections", + "airflow/test_postgres", + {"mount_point": "airflow", "path": "connections/test_postgres"}, + ), + ( + None, + "", + "airflow/path/to/connections/test_postgres", + {"mount_point": "airflow", "path": "path/to/connections/test_postgres"}, + ), + ], + ) @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") - def test_get_conn_uri_engine_version_1(self, mock_hvac): + def test_get_conn_uri_engine_version_1( + self, mock_hvac, mount_point, connections_path, conn_id, expected_args + ): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client mock_client.secrets.kv.v1.read_secret.return_value = { @@ -197,8 +228,8 @@ def test_get_conn_uri_engine_version_1(self, mock_hvac): } kwargs = { - "connections_path": "connections", - "mount_point": "airflow", + "connections_path": connections_path, + "mount_point": mount_point, "auth_type": "token", "url": "http://127.0.0.1:8200", "token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS", @@ -206,10 +237,8 @@ def test_get_conn_uri_engine_version_1(self, mock_hvac): } test_client = VaultBackend(**kwargs) - returned_uri = test_client.get_conn_uri(conn_id="test_postgres") - mock_client.secrets.kv.v1.read_secret.assert_called_once_with( - mount_point="airflow", path="connections/test_postgres" - ) + returned_uri = test_client.get_conn_uri(conn_id=conn_id) + mock_client.secrets.kv.v1.read_secret.assert_called_once_with(**expected_args) assert "postgresql://airflow:airflow@host:5432/airflow" == returned_uri @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") @@ -347,39 +376,29 @@ def test_get_variable_value_without_predefined_mount_point(self, mock_hvac): returned_uri = test_client.get_variable("airflow/hello") assert "world" == returned_uri + @pytest.mark.parametrize( + "mount_point, variables_path, variable_key, expected_args", + [ + ("airflow", "variables", "hello", {"mount_point": "airflow", "path": "variables/hello"}), + ( + "airflow", + "", + "path/to/variables/hello", + {"mount_point": "airflow", "path": "path/to/variables/hello"}, + ), + (None, "variables", "airflow/hello", {"mount_point": "airflow", "path": "variables/hello"}), + ( + None, + "", + "airflow/path/to/variables/hello", + {"mount_point": "airflow", "path": "path/to/variables/hello"}, + ), + ], + ) @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") - def test_get_variable_value_engine_version_1(self, mock_hvac): - mock_client = mock.MagicMock() - mock_hvac.Client.return_value = mock_client - mock_client.secrets.kv.v1.read_secret.return_value = { - "request_id": "182d0673-618c-9889-4cba-4e1f4cfe4b4b", - "lease_id": "", - "renewable": False, - "lease_duration": 2764800, - "data": {"value": "world"}, - "wrap_info": None, - "warnings": None, - "auth": None, - } - - kwargs = { - "variables_path": "variables", - "mount_point": "airflow", - "auth_type": "token", - "url": "http://127.0.0.1:8200", - "token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS", - "kv_engine_version": 1, - } - - test_client = VaultBackend(**kwargs) - returned_uri = test_client.get_variable("hello") - mock_client.secrets.kv.v1.read_secret.assert_called_once_with( - mount_point="airflow", path="variables/hello" - ) - assert "world" == returned_uri - - @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") - def test_get_variable_value_engine_version_1_without_predefined_mount_point(self, mock_hvac): + def test_get_variable_value_engine_version_1( + self, mock_hvac, mount_point, variables_path, variable_key, expected_args + ): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client mock_client.secrets.kv.v1.read_secret.return_value = { @@ -394,8 +413,8 @@ def test_get_variable_value_engine_version_1_without_predefined_mount_point(self } kwargs = { - "variables_path": "variables", - "mount_point": None, + "variables_path": variables_path, + "mount_point": mount_point, "auth_type": "token", "url": "http://127.0.0.1:8200", "token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS", @@ -403,10 +422,8 @@ def test_get_variable_value_engine_version_1_without_predefined_mount_point(self } test_client = VaultBackend(**kwargs) - returned_uri = test_client.get_variable("airflow/hello") - mock_client.secrets.kv.v1.read_secret.assert_called_once_with( - mount_point="airflow", path="variables/hello" - ) + returned_uri = test_client.get_variable(variable_key) + mock_client.secrets.kv.v1.read_secret.assert_called_once_with(**expected_args) assert "world" == returned_uri @mock.patch.dict(