Skip to content

Commit

Permalink
Fix empty paths in Vault secrets backend (#29908)
Browse files Browse the repository at this point in the history
* fix empty variables, config and connections paths

* add tests for empty paths with and without mount point
  • Loading branch information
hussein-awala committed Mar 4, 2023
1 parent 9710394 commit 4fa91d7
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 55 deletions.
26 changes: 16 additions & 10 deletions airflow/providers/hashicorp/secrets/vault.py
Expand Up @@ -175,8 +175,10 @@ def get_response(self, conn_id: str) -> dict | None:
mount_point, conn_key = self._parse_path(conn_id)
if self.connections_path is None or conn_key is None:
return None

secret_path = self.build_path(self.connections_path, conn_key)
if self.connections_path == "":
secret_path = conn_key
else:
secret_path = self.build_path(self.connections_path, conn_key)
return self.vault_client.get_secret(
secret_path=(mount_point + "/" if mount_point else "") + secret_path
)
Expand Down Expand Up @@ -235,12 +237,14 @@ def get_variable(self, key: str) -> str | None:
mount_point, variable_key = self._parse_path(key)
if self.variables_path is None or variable_key is None:
return None
if self.variables_path == "":
secret_path = variable_key
else:
secret_path = self.build_path(self.variables_path, variable_key)
response = self.vault_client.get_secret(
secret_path=(mount_point + "/" if mount_point else "") + secret_path
)
return response.get("value") if response else None
response = self.vault_client.get_secret(
secret_path=(mount_point + "/" if mount_point else "") + secret_path
)
return response.get("value") if response else None

def get_config(self, key: str) -> str | None:
"""
Expand All @@ -252,9 +256,11 @@ def get_config(self, key: str) -> str | None:
mount_point, config_key = self._parse_path(key)
if self.config_path is None or config_key is None:
return None
if self.config_path == "":
secret_path = config_key
else:
secret_path = self.build_path(self.config_path, config_key)
response = self.vault_client.get_secret(
secret_path=(mount_point + "/" if mount_point else "") + secret_path
)
return response.get("value") if response else None
response = self.vault_client.get_secret(
secret_path=(mount_point + "/" if mount_point else "") + secret_path
)
return response.get("value") if response else None
107 changes: 62 additions & 45 deletions tests/providers/hashicorp/secrets/test_vault.py
Expand Up @@ -181,8 +181,39 @@ def test_get_connection_without_predefined_mount_point(self, mock_hvac):
connection = test_client.get_connection(conn_id="airflow/test_postgres")
assert "postgresql://airflow:airflow@host:5432/airflow?foo=bar&baz=taz" == connection.get_uri()

@pytest.mark.parametrize(
"mount_point, connections_path, conn_id, expected_args",
[
(
"airflow",
"connections",
"test_postgres",
{"mount_point": "airflow", "path": "connections/test_postgres"},
),
(
"airflow",
"",
"path/to/connections/test_postgres",
{"mount_point": "airflow", "path": "path/to/connections/test_postgres"},
),
(
None,
"connections",
"airflow/test_postgres",
{"mount_point": "airflow", "path": "connections/test_postgres"},
),
(
None,
"",
"airflow/path/to/connections/test_postgres",
{"mount_point": "airflow", "path": "path/to/connections/test_postgres"},
),
],
)
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_get_conn_uri_engine_version_1(self, mock_hvac):
def test_get_conn_uri_engine_version_1(
self, mock_hvac, mount_point, connections_path, conn_id, expected_args
):
mock_client = mock.MagicMock()
mock_hvac.Client.return_value = mock_client
mock_client.secrets.kv.v1.read_secret.return_value = {
Expand All @@ -197,19 +228,17 @@ def test_get_conn_uri_engine_version_1(self, mock_hvac):
}

kwargs = {
"connections_path": "connections",
"mount_point": "airflow",
"connections_path": connections_path,
"mount_point": mount_point,
"auth_type": "token",
"url": "http://127.0.0.1:8200",
"token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS",
"kv_engine_version": 1,
}

test_client = VaultBackend(**kwargs)
returned_uri = test_client.get_conn_uri(conn_id="test_postgres")
mock_client.secrets.kv.v1.read_secret.assert_called_once_with(
mount_point="airflow", path="connections/test_postgres"
)
returned_uri = test_client.get_conn_uri(conn_id=conn_id)
mock_client.secrets.kv.v1.read_secret.assert_called_once_with(**expected_args)
assert "postgresql://airflow:airflow@host:5432/airflow" == returned_uri

@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
Expand Down Expand Up @@ -347,39 +376,29 @@ def test_get_variable_value_without_predefined_mount_point(self, mock_hvac):
returned_uri = test_client.get_variable("airflow/hello")
assert "world" == returned_uri

@pytest.mark.parametrize(
"mount_point, variables_path, variable_key, expected_args",
[
("airflow", "variables", "hello", {"mount_point": "airflow", "path": "variables/hello"}),
(
"airflow",
"",
"path/to/variables/hello",
{"mount_point": "airflow", "path": "path/to/variables/hello"},
),
(None, "variables", "airflow/hello", {"mount_point": "airflow", "path": "variables/hello"}),
(
None,
"",
"airflow/path/to/variables/hello",
{"mount_point": "airflow", "path": "path/to/variables/hello"},
),
],
)
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_get_variable_value_engine_version_1(self, mock_hvac):
mock_client = mock.MagicMock()
mock_hvac.Client.return_value = mock_client
mock_client.secrets.kv.v1.read_secret.return_value = {
"request_id": "182d0673-618c-9889-4cba-4e1f4cfe4b4b",
"lease_id": "",
"renewable": False,
"lease_duration": 2764800,
"data": {"value": "world"},
"wrap_info": None,
"warnings": None,
"auth": None,
}

kwargs = {
"variables_path": "variables",
"mount_point": "airflow",
"auth_type": "token",
"url": "http://127.0.0.1:8200",
"token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS",
"kv_engine_version": 1,
}

test_client = VaultBackend(**kwargs)
returned_uri = test_client.get_variable("hello")
mock_client.secrets.kv.v1.read_secret.assert_called_once_with(
mount_point="airflow", path="variables/hello"
)
assert "world" == returned_uri

@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_get_variable_value_engine_version_1_without_predefined_mount_point(self, mock_hvac):
def test_get_variable_value_engine_version_1(
self, mock_hvac, mount_point, variables_path, variable_key, expected_args
):
mock_client = mock.MagicMock()
mock_hvac.Client.return_value = mock_client
mock_client.secrets.kv.v1.read_secret.return_value = {
Expand All @@ -394,19 +413,17 @@ def test_get_variable_value_engine_version_1_without_predefined_mount_point(self
}

kwargs = {
"variables_path": "variables",
"mount_point": None,
"variables_path": variables_path,
"mount_point": mount_point,
"auth_type": "token",
"url": "http://127.0.0.1:8200",
"token": "s.7AU0I51yv1Q1lxOIg1F3ZRAS",
"kv_engine_version": 1,
}

test_client = VaultBackend(**kwargs)
returned_uri = test_client.get_variable("airflow/hello")
mock_client.secrets.kv.v1.read_secret.assert_called_once_with(
mount_point="airflow", path="variables/hello"
)
returned_uri = test_client.get_variable(variable_key)
mock_client.secrets.kv.v1.read_secret.assert_called_once_with(**expected_args)
assert "world" == returned_uri

@mock.patch.dict(
Expand Down

0 comments on commit 4fa91d7

Please sign in to comment.