diff --git a/airflow/serialization/serde.py b/airflow/serialization/serde.py
index 25b3b4f429bd..16621f9d947c 100644
--- a/airflow/serialization/serde.py
+++ b/airflow/serialization/serde.py
@@ -154,6 +154,12 @@ def serialize(o: object, depth: int = 0) -> U | None:
         dct[DATA] = data
         return dct
 
+    # pydantic models: dump via .dict() and serialize the resulting dict recursively
+    if _is_pydantic(cls):
+        data = o.dict()  # type: ignore[attr-defined]
+        dct[DATA] = serialize(data, depth + 1)
+        return dct
+
     # dataclasses
     if dataclasses.is_dataclass(cls):
         # fixme: unfortunately using asdict with nested dataclasses it looses information
@@ -250,8 +256,8 @@ def deserialize(o: T | None, full=True, type_hint: Any = None) -> object:
     if hasattr(cls, "deserialize"):
         return getattr(cls, "deserialize")(deserialize(value), version)
 
-    # attr or dataclass
-    if attr.has(cls) or dataclasses.is_dataclass(cls):
+    # attr class, dataclass or pydantic model: all share the version check below
+    if attr.has(cls) or dataclasses.is_dataclass(cls) or _is_pydantic(cls):
         class_version = getattr(cls, "__version__", 0)
         if int(version) > class_version:
             raise TypeError(
@@ -302,6 +308,15 @@ def _stringify(classname: str, version: int, value: T | None) -> str:
     return s
 
 
+def _is_pydantic(cls: Any) -> bool:
+    """Return True if ``cls`` looks like a pydantic model.
+
+    The check is duck-typed on ``__validators__``, ``__fields__`` and ``dict``
+    as attribute lookups are significantly faster than using isinstance.
+    """
+    return hasattr(cls, "__validators__") and hasattr(cls, "__fields__") and hasattr(cls, "dict")
+
+
 def _register():
     """Register builtin serializers and deserializers for types that don't have any themselves."""
     _serializers.clear()
diff --git a/tests/serialization/test_serde.py b/tests/serialization/test_serde.py
index c5b972a7952f..a86b9ce8161b 100644
--- a/tests/serialization/test_serde.py
+++ b/tests/serialization/test_serde.py
@@ -24,6 +24,7 @@
 
 import attr
 import pytest
+from pydantic import BaseModel
 
 from airflow.datasets import Dataset
 from airflow.serialization.serde import (
@@ -92,6 +93,13 @@ class V:
     c: int
 
 
+class U(BaseModel):
+    __version__: ClassVar[int] = 1
+    x: int
+    v: V
+    u: tuple
+
+
 @pytest.mark.usefixtures("recalculate_patterns")
 class TestSerDe:
     def test_ser_primitives(self):
@@ -317,3 +325,9 @@ def test_deserialize_non_serialized_data(self):
         i = Z(10)
         e = deserialize(i)
         assert i == e
+
+    def test_pydantic(self):
+        i = U(x=10, v=V(W(10), ["l1", "l2"], (1, 2), 10), u=(1, 2))
+        e = serialize(i)
+        s = deserialize(e)
+        assert i == s  # the serialize/deserialize round trip must yield an equal model