diff --git a/airflow/providers/hashicorp/_internal_client/vault_client.py b/airflow/providers/hashicorp/_internal_client/vault_client.py index 0012d9580262c..e170eda7870a6 100644 --- a/airflow/providers/hashicorp/_internal_client/vault_client.py +++ b/airflow/providers/hashicorp/_internal_client/vault_client.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import os from functools import cached_property import hvac @@ -125,7 +126,7 @@ def __init__( raise VaultError( f"The auth_type is not supported: {auth_type}. It should be one of {VALID_AUTH_TYPES}" ) - if auth_type == "token" and not token and not token_path: + if auth_type == "token" and not token and not token_path and "VAULT_TOKEN" not in os.environ: raise VaultError("The 'token' authentication type requires 'token' or 'token_path'") if auth_type == "github" and not token and not token_path: raise VaultError("The 'github' authentication type requires 'token' or 'token_path'") @@ -151,7 +152,7 @@ def __init__( self.url = url self.auth_type = auth_type self.kwargs = kwargs - self.token = token + self.token = token or os.getenv("VAULT_TOKEN", None) self.token_path = token_path self.auth_mount_point = auth_mount_point self.mount_point = mount_point diff --git a/tests/providers/hashicorp/_internal_client/test_vault_client.py b/tests/providers/hashicorp/_internal_client/test_vault_client.py index 28c6944fa6b85..ba8e0a0cd715e 100644 --- a/tests/providers/hashicorp/_internal_client/test_vault_client.py +++ b/tests/providers/hashicorp/_internal_client/test_vault_client.py @@ -551,6 +551,19 @@ def test_token(self, mock_hvac): assert 2 == vault_client.kv_engine_version assert "secret" == vault_client.mount_point + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") + def test_token_in_env(self, mock_hvac, monkeypatch): + monkeypatch.setenv("VAULT_TOKEN", "s.7AU0I51yv1Q1lxOIg1F3ZRAS") + mock_client = mock.MagicMock() + mock_hvac.Client.return_value = mock_client + vault_client = _VaultClient(auth_type="token", url="http://localhost:8180", session=None) + client = vault_client.client + mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None) + client.is_authenticated.assert_called_with() + assert "s.7AU0I51yv1Q1lxOIg1F3ZRAS" == client.token + assert 2 == vault_client.kv_engine_version + assert "secret" == vault_client.mount_point + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_token_path(self, mock_hvac): mock_client = mock.MagicMock()