Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

YAML file supports extra json parameters #9549

Merged
merged 9 commits into from
Jul 8, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 10 additions & 1 deletion airflow/secrets/local_filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def _parse_yaml_file(file_path: str) -> Tuple[Dict[str, List[str]], List[FileSyn
return {}, [FileSyntaxError(line_no=1, message="The file is empty.")]
try:
secrets = yaml.safe_load(content)

except yaml.MarkedYAMLError as e:
return {}, [FileSyntaxError(line_no=e.problem_mark.line, message=str(e))]
if not isinstance(secrets, dict):
Expand Down Expand Up @@ -180,7 +181,7 @@ def _create_connection(conn_id: str, value: Any):
if isinstance(value, str):
return Connection(conn_id=conn_id, uri=value)
if isinstance(value, dict):
connection_parameter_names = get_connection_parameter_names()
connection_parameter_names = get_connection_parameter_names() | {"extra_dejson"}
current_keys = set(value.keys())
if not current_keys.issubset(connection_parameter_names):
illegal_keys = current_keys - connection_parameter_names
Expand All @@ -189,6 +190,14 @@ def _create_connection(conn_id: str, value: Any):
f"The object have illegal keys: {illegal_keys_list}. "
f"The dictionary can only contain the following keys: {connection_parameter_names}"
)
if "extra" in value and "extra_dejson" in value:
raise AirflowException(
"The extra and extra_dejson parameters are mutually exclusive. "
"Please provide only one parameter."
)
if "extra_dejson" in value:
value["extra"] = json.dumps(value["extra_dejson"])
del value["extra_dejson"]

if "conn_id" in current_keys and conn_id != value["conn_id"]:
raise AirflowException(
Expand Down
6 changes: 5 additions & 1 deletion docs/howto/use-alternative-secrets-backend.rst
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ The following is a sample JSON file.
}

The YAML file structure is similar to that of a JSON. The key-value pair of connection ID and the definitions of one or more connections.
The connection can be defined as a URI (string) or JSON object.
The connection can be defined as a URI (string) or JSON object. Any extra json parameters can be provided with the key extra_dejson (Keys extra and extra_dejson are mutually exclusive).
VinayGb665 marked this conversation as resolved.
Show resolved Hide resolved
For a guide about defining a connection as a URI, see:: :ref:`generating_connection_uri`.
For a description of the connection object parameters see :class:`~airflow.models.connection.Connection`.
VinayGb665 marked this conversation as resolved.
Show resolved Hide resolved
The following is a sample YAML file.
Expand All @@ -137,6 +137,10 @@ The following is a sample YAML file.
login: Login
password: None
port: 1234
extra_dejson:
a: b
nestedblock_dict:
x: y

You can also define connections using a ``.env`` file. Then the key is the connection ID, and
the value should describe the connection using the URI. If the connection ID is repeated, all values will
Expand Down
70 changes: 68 additions & 2 deletions tests/secrets/test_local_filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,15 @@ def test_missing_file(self, mock_exists):
schema: lschema
login: Login
password: None
port: 1234""",
port: 1234
extra_dejson:
VinayGb665 marked this conversation as resolved.
Show resolved Hide resolved
extra__google_cloud_platform__keyfile_dict:
a: b
extra__google_cloud_platform__keyfile_path: asaa""",
{"conn_a": ["mysql://hosta"], "conn_b": ["mysql://hostb", "mysql://hostc"],
"conn_c": ["scheme://Login:None@host:1234/lschema"]}),
"conn_c": [''.join("""scheme://Login:None@host:1234/lschema?
extra__google_cloud_platform__keyfile_dict=%7B%27a%27%3A+%27b%27%7D
&extra__google_cloud_platform__keyfile_path=asaa""".split())]}),
)
)
def test_yaml_file_should_load_connection(self, file_content, expected_connection_uris):
Expand All @@ -241,6 +247,66 @@ def test_yaml_file_should_load_connection(self, file_content, expected_connectio

self.assertEqual(expected_connection_uris, connection_uris_by_conn_id)

@parameterized.expand(
(
("""conn_c:
conn_type: scheme
host: host
schema: lschema
login: Login
password: None
port: 1234
extra_dejson:
aws_conn_id: bbb
region_name: ccc
""", {"conn_c": [{"aws_conn_id": "bbb", "region_name": "ccc"}]}),
("""conn_d:
conn_type: scheme
host: host
schema: lschema
login: Login
password: None
port: 1234
extra_dejson:
extra__google_cloud_platform__keyfile_dict:
a: b
extra__google_cloud_platform__key_path: xxx
""", {"conn_d": [{"extra__google_cloud_platform__keyfile_dict": {"a": "b"},
"extra__google_cloud_platform__key_path": "xxx"}]}),

)
)
def test_yaml_file_should_load_connection_extras(self, file_content, expected_extras):
with mock_local_file(file_content):
connections_by_conn_id = local_filesystem.load_connections("a.yaml")
connection_uris_by_conn_id = {
conn_id: [connection.extra_dejson for connection in connections]
for conn_id, connections in connections_by_conn_id.items()
}
self.assertEqual(expected_extras, connection_uris_by_conn_id)

@parameterized.expand(
(
("""conn_c:
conn_type: scheme
host: host
schema: lschema
login: Login
password: None
port: 1234
extra:
abc: xyz
extra_dejson:
aws_conn_id: bbb
region_name: ccc
""", "The extra and extra_dejson parameters are mutually exclusive."),
)
)
def test_yaml_invalid_extra(self, file_content, expected_message):
with mock_local_file(file_content):
with self.assertRaisesRegex(AirflowException, re.escape(expected_message)):
local_filesystem.load_connections("a.yaml")


class TestLocalFileBackend(unittest.TestCase):
def test_should_read_variable(self):
Expand Down