Skip to content

Commit

Permalink
Add Secrets backend for Microsoft Azure Key Vault (#10898)
Browse files Browse the repository at this point in the history
(cherry picked from commit f77a11d)
  • Loading branch information
kaxil authored and potiuk committed Sep 18, 2020
1 parent 118faaf commit ea36166
Show file tree
Hide file tree
Showing 6 changed files with 309 additions and 15 deletions.
12 changes: 6 additions & 6 deletions CONTRIBUTING.rst
Original file line number Diff line number Diff line change
Expand Up @@ -557,12 +557,12 @@ This is the full list of those extras:
.. START EXTRAS HERE
all, all_dbs, async, atlas, aws, azure, azure_blob_storage, azure_container_instances, azure_cosmos,
azure_data_lake, cassandra, celery, cgroups, cloudant, crypto, dask, databricks, datadog, devel,
devel_azure, devel_ci, devel_hadoop, doc, docker, druid, elasticsearch, emr, gcp, gcp_api,
github_enterprise, google_auth, grpc, hashicorp, hdfs, hive, jdbc, jira, kerberos, kubernetes, ldap,
mongo, mssql, mysql, oracle, papermill, password, pinot, postgres, presto, qds, rabbitmq, redis, s3,
salesforce, samba, segment, sendgrid, sentry, slack, snowflake, ssh, statsd, vertica, virtualenv,
webhdfs, winrm
azure_data_lake, azure_secrets, cassandra, celery, cgroups, cloudant, crypto, dask, databricks,
datadog, devel, devel_azure, devel_ci, devel_hadoop, doc, docker, druid, elasticsearch, emr, gcp,
gcp_api, github_enterprise, google_auth, grpc, hashicorp, hdfs, hive, jdbc, jira, kerberos,
kubernetes, ldap, mongo, mssql, mysql, oracle, papermill, password, pinot, postgres, presto, qds,
rabbitmq, redis, s3, salesforce, samba, segment, sendgrid, sentry, slack, snowflake, ssh, statsd,
vertica, virtualenv, webhdfs, winrm

.. END EXTRAS HERE
Expand Down
12 changes: 6 additions & 6 deletions INSTALL
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ pip install . \
# START EXTRAS HERE

all, all_dbs, async, atlas, aws, azure, azure_blob_storage, azure_container_instances, azure_cosmos,
azure_data_lake, cassandra, celery, cgroups, cloudant, crypto, dask, databricks, datadog, devel,
devel_azure, devel_ci, devel_hadoop, doc, docker, druid, elasticsearch, emr, gcp, gcp_api,
github_enterprise, google_auth, grpc, hashicorp, hdfs, hive, jdbc, jira, kerberos, kubernetes, ldap,
mongo, mssql, mysql, oracle, papermill, password, pinot, postgres, presto, qds, rabbitmq, redis, s3,
salesforce, samba, segment, sendgrid, sentry, slack, snowflake, ssh, statsd, vertica, virtualenv,
webhdfs, winrm
azure_data_lake, azure_secrets, cassandra, celery, cgroups, cloudant, crypto, dask, databricks,
datadog, devel, devel_azure, devel_ci, devel_hadoop, doc, docker, druid, elasticsearch, emr, gcp,
gcp_api, github_enterprise, google_auth, grpc, hashicorp, hdfs, hive, jdbc, jira, kerberos,
kubernetes, ldap, mongo, mssql, mysql, oracle, papermill, password, pinot, postgres, presto, qds,
rabbitmq, redis, s3, salesforce, samba, segment, sendgrid, sentry, slack, snowflake, ssh, statsd,
vertica, virtualenv, webhdfs, winrm

# END EXTRAS HERE

Expand Down
148 changes: 148 additions & 0 deletions airflow/contrib/secrets/azure_key_vault.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Optional

from azure.core.exceptions import ResourceNotFoundError
from azure.identity import DefaultAzureCredential
from azure.keyvault.secrets import SecretClient
from cached_property import cached_property

from airflow.secrets import BaseSecretsBackend
from airflow.utils.log.logging_mixin import LoggingMixin


class AzureKeyVaultBackend(BaseSecretsBackend, LoggingMixin):
"""
Retrieves Airflow Connections or Variables from Azure Key Vault secrets.
The Azure Key Vault can be configured as a secrets backend in the ``airflow.cfg``:
.. code-block:: ini
[secrets]
backend = airflow.providers.microsoft.azure.secrets.azure_key_vault.AzureKeyVaultBackend
backend_kwargs = {"connections_prefix": "airflow-connections", "vault_url": "<azure_key_vault_uri>"}
For example, if the secrets prefix is ``airflow-connections-smtp-default``, this would be accessible
if you provide ``{"connections_prefix": "airflow-connections"}`` and request conn_id ``smtp-default``.
And if variables prefix is ``airflow-variables-hello``, this would be accessible
if you provide ``{"variables_prefix": "airflow-variables"}`` and request variable key ``hello``.
:param connections_prefix: Specifies the prefix of the secret to read to get Connections
:type connections_prefix: str
:param variables_prefix: Specifies the prefix of the secret to read to get Variables
:type variables_prefix: str
:param config_prefix: Specifies the prefix of the secret to read to get Variables.
:type config_prefix: str
:param vault_url: The URL of an Azure Key Vault to use
:type vault_url: str
:param sep: separator used to concatenate secret_prefix and secret_id. Default: "-"
:type sep: str
"""

def __init__(
self,
connections_prefix='airflow-connections', # type: str
variables_prefix='airflow-variables', # type: str
config_prefix='airflow-config', # type: str
vault_url='', # type: str
sep='-', # type: str
**kwargs
):
super(AzureKeyVaultBackend, self).__init__()
self.vault_url = vault_url
self.connections_prefix = connections_prefix.rstrip(sep)
self.variables_prefix = variables_prefix.rstrip(sep)
self.config_prefix = config_prefix.rstrip(sep)
self.sep = sep
self.kwargs = kwargs

@cached_property
def client(self):
"""
Create a Azure Key Vault client.
"""
credential = DefaultAzureCredential()
client = SecretClient(vault_url=self.vault_url, credential=credential, **self.kwargs)
return client

def get_conn_uri(self, conn_id):
# type: (str) -> Optional[str]
"""
Get an Airflow Connection URI from an Azure Key Vault secret
:param conn_id: The Airflow connection id to retrieve
:type conn_id: str
"""
return self._get_secret(self.connections_prefix, conn_id)

def get_variable(self, key):
# type: (str) -> Optional[str]
"""
Get an Airflow Variable from an Azure Key Vault secret.
:param key: Variable Key
:type key: str
:return: Variable Value
"""
return self._get_secret(self.variables_prefix, key)

def get_config(self, key):
# type: (str) -> Optional[str]
"""
Get Airflow Configuration
:param key: Configuration Option Key
:return: Configuration Option Value
"""
return self._get_secret(self.config_prefix, key)

@staticmethod
def build_path(path_prefix, secret_id, sep='-'):
# type: (str, str, str) -> str
"""
Given a path_prefix and secret_id, build a valid secret name for the Azure Key Vault Backend.
Also replaces underscore in the path with dashes to support easy switching between
environment variables, so ``connection_default`` becomes ``connection-default``.
:param path_prefix: The path prefix of the secret to retrieve
:type path_prefix: str
:param secret_id: Name of the secret
:type secret_id: str
:param sep: Separator used to concatenate path_prefix and secret_id
:type sep: str
"""
path = '{}{}{}'.format(path_prefix, sep, secret_id)
return path.replace('_', sep)

def _get_secret(self, path_prefix, secret_id):
# type: (str, str) -> Optional[str]
"""
Get an Azure Key Vault secret value
:param path_prefix: Prefix for the Path to get Secret
:type path_prefix: str
:param secret_id: Secret Key
:type secret_id: str
"""
name = self.build_path(path_prefix, secret_id, self.sep)
try:
secret = self.client.get_secret(name=name)
return secret.value
except ResourceNotFoundError as ex:
self.log.debug('Secret %s not found: %s', name, ex)
return None
35 changes: 35 additions & 0 deletions docs/howto/use-alternative-secrets-backend.rst
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,41 @@ When ``gcp_key_path`` is not provided, it will use the Application Default Crede
The value of the Secrets Manager secret id must be the :ref:`connection URI representation <generating_connection_uri>`
of the connection object.

Azure Key Vault Backend
^^^^^^^^^^^^^^^^^^^^^^^

To enable the Azure Key Vault as secrets backend, specify
:py:class:`~airflow.contrib.secrets.azure_key_vault.AzureKeyVaultBackend`
as the ``backend`` in ``[secrets]`` section of ``airflow.cfg``.

Here is a sample configuration:

.. code-block:: ini
[secrets]
backend = airflow.contrib.secrets.azure_key_vault.AzureKeyVaultBackend
backend_kwargs = {"connections_prefix": "airflow-connections", "variables_prefix": "airflow-variables", "vault_url": "https://example-akv-resource-name.vault.azure.net/"}
For client authentication, the ``DefaultAzureCredential`` from the Azure Python SDK is used as credential provider,
which supports service principal, managed identity and user credentials.


Storing and Retrieving Connections
""""""""""""""""""""""""""""""""""

If you have set ``connections_prefix`` as ``airflow-connections``, then for a connection id of ``smtp_default``,
you would want to store your connection at ``airflow-connections-smtp-default``.

The value of the secret must be the :ref:`connection URI representation <generating_connection_uri>`
of the connection object.

Storing and Retrieving Variables
""""""""""""""""""""""""""""""""

If you have set ``variables_prefix`` as ``airflow-variables``, then for an Variable key of ``hello``,
you would want to store your Variable at ``airflow-variables-hello``.


.. _roll_your_own_secrets_backend:

Roll your own secrets backend
Expand Down
12 changes: 9 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,10 @@ def write_version(filename=os.path.join(*[my_dir, "airflow", "git_version"])):
'azure-mgmt-datalake-store>=0.5.0',
'azure-mgmt-resource>=2.2.0',
]
azure_secrets = [
'azure-identity>=1.3.1',
'azure-keyvault>=4.1.0',
]
cassandra = [
'cassandra-driver>=3.13.0,<3.21.0',
]
Expand Down Expand Up @@ -450,9 +454,10 @@ def write_version(filename=os.path.join(*[my_dir, "airflow", "git_version"])):

devel_minreq = aws + cgroups + devel + doc + kubernetes + mysql + password
devel_hadoop = devel_minreq + hdfs + hive + kerberos + presto + webhdfs
devel_azure = azure_cosmos + azure_data_lake + devel_minreq

devel_azure = azure_blob_storage + azure_container_instances + azure_cosmos + azure_data_lake + azure_secrets + devel_minreq # noqa
devel_all = (all_dbs + atlas + aws +
azure_blob_storage + azure_container_instances + azure_cosmos + azure_data_lake +
devel_azure +
celery + cgroups + crypto + datadog + devel + doc + docker +
elasticsearch + gcp + grpc + hashicorp + jdbc + jenkins + kerberos + kubernetes + ldap +
oracle + papermill + password +
Expand Down Expand Up @@ -481,11 +486,12 @@ def write_version(filename=os.path.join(*[my_dir, "airflow", "git_version"])):
'async': async_packages,
'atlas': atlas,
'aws': aws,
'azure': azure_blob_storage + azure_container_instances + azure_cosmos + azure_data_lake,
'azure': azure_blob_storage + azure_container_instances + azure_cosmos + azure_data_lake + azure_secrets,
'azure_blob_storage': azure_blob_storage,
'azure_container_instances': azure_container_instances,
'azure_cosmos': azure_cosmos,
'azure_data_lake': azure_data_lake,
'azure_secrets': azure_secrets,
'cassandra': cassandra,
'celery': celery,
'cgroups': cgroups,
Expand Down
105 changes: 105 additions & 0 deletions tests/contrib/secrets/test_azure_key_vault.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import unittest
from tests.compat import mock

from azure.core.exceptions import ResourceNotFoundError

from airflow.contrib.secrets.azure_key_vault import AzureKeyVaultBackend


class TestAzureKeyVaultBackend(unittest.TestCase):
@mock.patch('airflow.contrib.secrets.azure_key_vault.AzureKeyVaultBackend.get_conn_uri')
def test_get_connections(self, mock_get_uri):
mock_get_uri.return_value = 'scheme://user:pass@host:100'
conn_list = AzureKeyVaultBackend().get_connections('fake_conn')
conn = conn_list[0]
self.assertEqual(conn.host, 'host')

@mock.patch('airflow.contrib.secrets.azure_key_vault.DefaultAzureCredential')
@mock.patch('airflow.contrib.secrets.azure_key_vault.SecretClient')
def test_get_conn_uri(self, mock_secret_client, mock_azure_cred):
mock_cred = mock.Mock()
mock_sec_client = mock.Mock()
mock_azure_cred.return_value = mock_cred
mock_secret_client.return_value = mock_sec_client

mock_sec_client.get_secret.return_value = mock.Mock(
value='postgresql://airflow:airflow@host:5432/airflow'
)

backend = AzureKeyVaultBackend(vault_url="https://example-akv-resource-name.vault.azure.net/")
returned_uri = backend.get_conn_uri(conn_id='hi')
mock_secret_client.assert_called_once_with(
credential=mock_cred, vault_url='https://example-akv-resource-name.vault.azure.net/'
)
self.assertEqual(returned_uri, 'postgresql://airflow:airflow@host:5432/airflow')

@mock.patch('airflow.contrib.secrets.azure_key_vault.AzureKeyVaultBackend.client')
def test_get_conn_uri_non_existent_key(self, mock_client):
"""
Test that if the key with connection ID is not present,
AzureKeyVaultBackend.get_connections should return None
"""
conn_id = 'test_mysql'
mock_client.get_secret.side_effect = ResourceNotFoundError
backend = AzureKeyVaultBackend(vault_url="https://example-akv-resource-name.vault.azure.net/")

self.assertIsNone(backend.get_conn_uri(conn_id=conn_id))
self.assertEqual([], backend.get_connections(conn_id=conn_id))

@mock.patch('airflow.contrib.secrets.azure_key_vault.AzureKeyVaultBackend.client')
def test_get_variable(self, mock_client):
mock_client.get_secret.return_value = mock.Mock(value='world')
backend = AzureKeyVaultBackend()
returned_uri = backend.get_variable('hello')
mock_client.get_secret.assert_called_with(name='airflow-variables-hello')
self.assertEqual('world', returned_uri)

@mock.patch('airflow.contrib.secrets.azure_key_vault.AzureKeyVaultBackend.client')
def test_get_variable_non_existent_key(self, mock_client):
"""
Test that if Variable key is not present,
AzureKeyVaultBackend.get_variables should return None
"""
mock_client.get_secret.side_effect = ResourceNotFoundError
backend = AzureKeyVaultBackend()
self.assertIsNone(backend.get_variable('test_mysql'))

@mock.patch('airflow.contrib.secrets.azure_key_vault.AzureKeyVaultBackend.client')
def test_get_secret_value_not_found(self, mock_client):
"""
Test that if a non-existent secret returns None
"""
mock_client.get_secret.side_effect = ResourceNotFoundError
backend = AzureKeyVaultBackend()
self.assertIsNone(
backend._get_secret(path_prefix=backend.connections_prefix, secret_id='test_non_existent')
)

@mock.patch('airflow.contrib.secrets.azure_key_vault.AzureKeyVaultBackend.client')
def test_get_secret_value(self, mock_client):
"""
Test that get_secret returns the secret value
"""
mock_client.get_secret.return_value = mock.Mock(value='super-secret')
backend = AzureKeyVaultBackend()
secret_val = backend._get_secret('af-secrets', 'test_mysql_password')
mock_client.get_secret.assert_called_with(name='af-secrets-test-mysql-password')
self.assertEqual(secret_val, 'super-secret')

0 comments on commit ea36166

Please sign in to comment.