Skip to content

Commit

Permalink
Implements JSON-string connection representation generator (#35723)
Browse files Browse the repository at this point in the history
* Implements JSON-string connection representation generator

* json_repr -> as_json()

* Apply suggestions from code review

Co-authored-by: Vincent <97131062+vincbeck@users.noreply.github.com>

---------

Co-authored-by: Vincent <97131062+vincbeck@users.noreply.github.com>
  • Loading branch information
Taragolis and vincbeck committed Nov 23, 2023
1 parent eb691fc commit b07d799
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 2 deletions.
35 changes: 35 additions & 0 deletions airflow/models/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from airflow.models.base import ID_LEN, Base
from airflow.models.crypto import get_fernet
from airflow.secrets.cache import SecretCache
from airflow.utils.helpers import prune_dict
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.log.secrets_masker import mask_secret
from airflow.utils.module_loading import import_string
Expand Down Expand Up @@ -480,6 +481,34 @@ def get_connection_from_secrets(cls, conn_id: str) -> Connection:
def to_dict(self) -> dict[str, Any]:
    """Return the connection id, description and URI of this connection as a dictionary."""
    summary: dict[str, Any] = {}
    summary["conn_id"] = self.conn_id
    summary["description"] = self.description
    summary["uri"] = self.get_uri()
    return summary

def to_json_dict(self, *, prune_empty: bool = False, validate: bool = True) -> dict[str, Any]:
    """
    Build a JSON-serializable dictionary describing this Connection.

    The dictionary is assembled from the ``conn_id``, ``conn_type``, ``description``,
    ``host``, ``login``, ``password``, ``schema`` and ``port`` attributes, with ``extra``
    taken from ``extra_dejson``.

    :param prune_empty: If set, the collected fields are passed through ``prune_dict`` in
        ``"strict"`` mode, and ``extra`` is only included when it is truthy.
    :param validate: If set, the result is handed to ``json.dumps`` so that a value which
        cannot be serialized raises here rather than later.

    :meta private:
    """
    field_names = (
        "conn_id",
        "conn_type",
        "description",
        "host",
        "login",
        "password",
        "schema",
        "port",
    )
    payload: dict[str, Any] = {name: getattr(self, name) for name in field_names}
    if prune_empty:
        payload = prune_dict(val=payload, mode="strict")

    extra = self.extra_dejson
    if extra or not prune_empty:
        payload["extra"] = extra

    if validate:
        # The encoded string is discarded; the call is made only for the error it raises.
        json.dumps(payload)
    return payload

@classmethod
def from_json(cls, value, conn_id=None) -> Connection:
kwargs = json.loads(value)
Expand All @@ -496,3 +525,9 @@ def from_json(cls, value, conn_id=None) -> Connection:
except ValueError:
raise ValueError(f"Expected integer value for `port`, but got {port!r} instead.")
return Connection(conn_id=conn_id, **kwargs)

def as_json(self) -> str:
    """Serialize this Connection to a JSON string, leaving out the ``conn_id`` key."""
    fields = self.to_json_dict(prune_empty=True, validate=False)
    if "conn_id" in fields:
        del fields["conn_id"]
    return json.dumps(fields)
2 changes: 1 addition & 1 deletion airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def serialize(
type_=DAT.SIMPLE_TASK_INSTANCE,
)
elif isinstance(var, Connection):
return cls._encode(var.to_dict(), type_=DAT.CONNECTION)
return cls._encode(var.to_json_dict(validate=True), type_=DAT.CONNECTION)
elif use_pydantic_models and _ENABLE_AIP_44:

def _pydantic_model_dump(model_cls: type[BaseModel], var: Any) -> dict[str, Any]:
Expand Down
37 changes: 37 additions & 0 deletions docs/apache-airflow/howto/connection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,43 @@ If serializing with JSON:
}
}'
Generating a JSON connection representation
"""""""""""""""""""""""""""""""""""""""""""

.. versionadded:: 2.8.0


To make connection JSON generation easier, the :py:class:`~airflow.models.connection.Connection` class has a
convenience method :py:meth:`~airflow.models.connection.Connection.as_json`. It can be used like so:

.. code-block:: pycon
>>> from airflow.models.connection import Connection
>>> c = Connection(
... conn_id="some_conn",
... conn_type="mysql",
... description="connection description",
... host="myhost.com",
... login="myname",
... password="mypassword",
... extra={"this_param": "some val", "that_param": "other val*"},
... )
>>> print(f"AIRFLOW_CONN_{c.conn_id.upper()}='{c.as_json()}'")
AIRFLOW_CONN_SOME_CONN='{"conn_type": "mysql", "description": "connection description", "host": "myhost.com", "login": "myname", "password": "mypassword", "extra": {"this_param": "some val", "that_param": "other val*"}}'
In addition, the same approach can be used to convert a Connection from URI format to JSON format:

.. code-block:: pycon
>>> from airflow.models.connection import Connection
>>> c = Connection(
... conn_id="awesome_conn",
... description="Example Connection",
... uri="aws://AKIAIOSFODNN7EXAMPLE:wJalrXUtnFEMI%2FK7MDENG%2FbPxRfiCYEXAMPLEKEY@/?__extra__=%7B%22region_name%22%3A+%22eu-central-1%22%2C+%22config_kwargs%22%3A+%7B%22retries%22%3A+%7B%22mode%22%3A+%22standard%22%2C+%22max_attempts%22%3A+10%7D%7D%7D",
... )
>>> print(f"AIRFLOW_CONN_{c.conn_id.upper()}='{c.as_json()}'")
AIRFLOW_CONN_AWESOME_CONN='{"conn_type": "aws", "description": "Example Connection", "host": "", "login": "AKIAIOSFODNN7EXAMPLE", "password": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", "schema": "", "extra": {"region_name": "eu-central-1", "config_kwargs": {"retries": {"mode": "standard", "max_attempts": 10}}}}'
URI format example
^^^^^^^^^^^^^^^^^^
Expand Down
56 changes: 56 additions & 0 deletions tests/always/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,3 +790,59 @@ def test_get_uri_no_conn_type(self):
assert Connection(host="abc").get_uri() == "//abc"
# parsing back as conn still works
assert Connection(uri="//abc").host == "abc"

@pytest.mark.parametrize(
    "conn, expected_json",
    [
        pytest.param(Connection(), "{}", id="empty"),
        pytest.param(Connection(host="apache.org", extra={}), '{"host": "apache.org"}', id="empty-extra"),
        pytest.param(
            Connection(conn_type="foo", login="", password="p@$$"),
            '{"conn_type": "foo", "login": "", "password": "p@$$"}',
            id="some-fields",
        ),
        pytest.param(
            Connection(
                conn_type="bar",
                description="Sample Description",
                host="example.org",
                login="user",
                password="p@$$",
                schema="schema",
                port=777,
                extra={"foo": "bar", "answer": 42},
            ),
            json.dumps(
                {
                    "conn_type": "bar",
                    "description": "Sample Description",
                    "host": "example.org",
                    "login": "user",
                    "password": "p@$$",
                    "schema": "schema",
                    "port": 777,
                    "extra": {"foo": "bar", "answer": 42},
                }
            ),
            id="all-fields",
        ),
        pytest.param(
            Connection(uri="aws://"),
            # While parsing the URI, some of the fields are evaluated as empty strings
            '{"conn_type": "aws", "host": "", "schema": ""}',
            id="uri",
        ),
    ],
)
def test_as_json_from_connection(self, conn: Connection, expected_json):
    """
    Check ``Connection.as_json`` output and that it round-trips through ``Connection.from_json``.

    :param conn: Connection to serialize.
    :param expected_json: Exact JSON string ``as_json`` is expected to produce for ``conn``.
    """
    result = conn.as_json()
    assert result == expected_json
    restored_conn = Connection.from_json(result)

    # ``conn_id`` is not compared: it does not appear in any of the expected JSON strings.
    assert restored_conn.conn_type == conn.conn_type
    assert restored_conn.description == conn.description
    assert restored_conn.host == conn.host
    # ``login`` is part of the serialized payload, so it has to survive the round trip too.
    assert restored_conn.login == conn.login
    assert restored_conn.password == conn.password
    assert restored_conn.schema == conn.schema
    assert restored_conn.port == conn.port
    assert restored_conn.extra_dejson == conn.extra_dejson
23 changes: 22 additions & 1 deletion tests/serialization/test_serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from airflow.models.xcom_arg import XComArg
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import PythonOperator
from airflow.serialization.enums import DagAttributeTypes as DAT
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.pydantic.dag import DagModelPydantic
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.serialization.pydantic.job import JobPydantic
Expand Down Expand Up @@ -213,6 +213,27 @@ def test_serialize_deserialize(input, encoded_type, cmp_func):
json.dumps(serialized) # does not raise


@pytest.mark.parametrize(
    "conn_uri",
    [
        pytest.param("aws://", id="only-conn-type"),
        pytest.param("postgres://username:password@ec2.compute.com:5432/the_database", id="all-non-extra"),
        pytest.param(
            "///?__extra__=%7B%22foo%22%3A+%22bar%22%2C+%22answer%22%3A+42%2C+%22"
            "nullable%22%3A+null%2C+%22empty%22%3A+%22%22%2C+%22zero%22%3A+0%7D",
            id="extra",
        ),
    ],
)
def test_backcompat_deserialize_connection(conn_uri):
    """Check that a connection encoded in the older ``conn_id`` + ``uri`` form still deserializes."""
    from airflow.serialization.serialized_objects import BaseSerialization

    # Payload shape used before this change: only the connection id and its URI.
    legacy_var = {"conn_id": "TEST_ID", "uri": conn_uri}
    encoded = {Encoding.TYPE: DAT.CONNECTION, Encoding.VAR: legacy_var}

    restored = BaseSerialization.deserialize(encoded)
    assert restored.get_uri() == conn_uri


@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled")
@pytest.mark.parametrize(
"input, pydantic_class, encoded_type, cmp_func",
Expand Down

0 comments on commit b07d799

Please sign in to comment.