-
Notifications
You must be signed in to change notification settings - Fork 17k
fix serialize_template_field handling callable value in dict #63871
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,87 +19,88 @@ | |
| from __future__ import annotations | ||
|
|
||
| import contextlib | ||
| import inspect | ||
| from typing import TYPE_CHECKING, Any | ||
|
|
||
| from airflow._shared.module_loading import qualname | ||
| from airflow._shared.secrets_masker import redact | ||
| from airflow._shared.template_rendering import truncate_rendered_value | ||
| from airflow.configuration import conf | ||
| from airflow.settings import json | ||
|
|
||
| if TYPE_CHECKING: | ||
| from airflow.partition_mappers.base import PartitionMapper | ||
| from airflow.timetables.base import Timetable as CoreTimetable | ||
|
|
||
|
|
||
| def serialize_template_field(template_field: Any, name: str) -> str | dict | list | int | float: | ||
| def serialize_template_field(template_field: Any, name: str) -> str | dict | list | int | float | bool | None: | ||
| """ | ||
| Return a serializable representation of the templated field. | ||
|
|
||
| If ``templated_field`` is provided via a callable then | ||
| return the following serialized value: ``<callable full_qualified_name>`` | ||
| The walk has two responsibilities: | ||
|
|
||
| If ``templated_field`` contains a class or instance that requires recursive | ||
| templating, store them as strings. Otherwise simply return the field as-is. | ||
| 1. **Make the template_field JSON-encodable** — every container is rebuilt | ||
| with primitive leaves (str/int/float/bool/None), tuples and sets are | ||
| flattened to lists, and unsupported objects fall through to ``str()`` | ||
| so ``json.dumps`` never raises on the result. | ||
| 2. **Keep the output deterministic across parses** — callables are replaced | ||
| with their qualified name (never the default ``<function ... at 0x...>`` | ||
| repr), dicts are key-sorted, and (frozen)sets are sorted by element so | ||
| the same input always produces the same string. | ||
| """ | ||
|
|
||
| def is_jsonable(x): | ||
| try: | ||
| json.dumps(x) | ||
| except (TypeError, OverflowError): | ||
| return False | ||
| else: | ||
| return True | ||
|
|
||
| def translate_tuples_to_lists(obj: Any): | ||
| """Recursively convert tuples to lists.""" | ||
| if isinstance(obj, tuple): | ||
| return [translate_tuples_to_lists(item) for item in obj] | ||
| if isinstance(obj, list): | ||
| return [translate_tuples_to_lists(item) for item in obj] | ||
| if isinstance(obj, dict): | ||
| return {key: translate_tuples_to_lists(value) for key, value in obj.items()} | ||
| return obj | ||
| def normalize_dict_key(key) -> str: | ||
| """Normalize a dict key to a serialized string type.""" | ||
| # Serialized template_field keys must all be strings, not a mix of types, so that | ||
| # downstream json.dumps(..., sort_keys=True) does not raise on mixed-type keys. | ||
| return str(serialize_object(key)) | ||
|
|
||
| def serialize_object(obj): | ||
| """Recursively rewrite ``obj`` into a JSON-encodable, hash-stable structure.""" | ||
| if obj is None or isinstance(obj, (str, int, float, bool)): | ||
| return obj | ||
|
|
||
| def sort_dict_recursively(obj: Any) -> Any: | ||
| """Recursively sort dictionaries to ensure consistent ordering.""" | ||
| if isinstance(obj, dict): | ||
| return {k: sort_dict_recursively(v) for k, v in sorted(obj.items())} | ||
| if isinstance(obj, list): | ||
| return [sort_dict_recursively(item) for item in obj] | ||
| if isinstance(obj, tuple): | ||
| return tuple(sort_dict_recursively(item) for item in obj) | ||
| return obj | ||
| # Serialize keys/values first so each key is a string and the output is hash-stable, | ||
| # then sort by the serialized key to prevents hash inconsistencies when dict ordering varies. | ||
| serialized_pairs = [(normalize_dict_key(k), serialize_object(v)) for k, v in obj.items()] | ||
| return dict(sorted(serialized_pairs, key=lambda kv: kv[0])) | ||
|
|
||
| if isinstance(obj, (list, tuple)): | ||
| return [serialize_object(item) for item in obj] | ||
|
|
||
| if isinstance(obj, (set, frozenset)): | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Previously |
||
| # JSON has no set type → flatten to a list with deterministic ordering | ||
| # so hash randomization on element types cannot shift cross-process iteration order. | ||
| return sorted( | ||
| (serialize_object(item) for item in obj), | ||
| key=lambda x: (type(x).__name__, str(x)), | ||
| ) | ||
|
|
||
| # Use inspect.getattr_static to bypass any custom __getattr__ / metaclass magic | ||
| if callable(inspect.getattr_static(obj, "serialize", None)): | ||
| return serialize_object(obj.serialize()) | ||
|
|
||
| # Kubernetes client objects (V1Pod, V1Container, ...) expose their content via to_dict() | ||
| if callable(inspect.getattr_static(obj, "to_dict", None)): | ||
| return serialize_object(obj.to_dict()) | ||
|
|
||
| if callable(obj): | ||
| # Use qualified name; default repr embeds memory addresses, which would change the DAG hash on every parse | ||
| return f"<callable {qualname(obj, True)}>" | ||
|
|
||
| # Non-primitive objects without a serialize attribute are converted to str | ||
| # So they don't break json.dumps downstream | ||
| return str(obj) | ||
|
|
||
| max_length = conf.getint("core", "max_templated_field_length") | ||
|
|
||
| if not is_jsonable(template_field): | ||
| try: | ||
| serialized = template_field.serialize() | ||
| except AttributeError: | ||
| if callable(template_field): | ||
| full_qualified_name = qualname(template_field, True) | ||
| serialized = f"<callable {full_qualified_name}>" | ||
| else: | ||
| serialized = str(template_field) | ||
| if len(serialized) > max_length: | ||
| rendered = redact(serialized, name) | ||
| return truncate_rendered_value(str(rendered), max_length) | ||
| return serialized | ||
| if not template_field and not isinstance(template_field, tuple): | ||
| # Avoid unnecessary serialization steps for empty fields unless they are tuples | ||
| # and need to be converted to lists | ||
| return template_field | ||
| template_field = translate_tuples_to_lists(template_field) | ||
| # Sort dictionaries recursively to ensure consistent string representation | ||
| # This prevents hash inconsistencies when dict ordering varies | ||
| if isinstance(template_field, dict): | ||
| template_field = sort_dict_recursively(template_field) | ||
| serialized = str(template_field) | ||
| if len(serialized) > max_length: | ||
| rendered = redact(serialized, name) | ||
| serialized = serialize_object(template_field) | ||
|
|
||
| if len(str(serialized)) > max_length: | ||
| rendered = redact(str(serialized), name) | ||
| return truncate_rendered_value(str(rendered), max_length) | ||
| return template_field | ||
|
|
||
| return serialized | ||
|
|
||
|
|
||
| class TimetableNotRegistered(ValueError): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,63 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
| from __future__ import annotations | ||
|
wjddn279 marked this conversation as resolved.
|
||
|
|
||
| from datetime import datetime | ||
|
|
||
| from airflow.sdk import dag, task, task_group | ||
|
|
||
|
|
||
| @dag( | ||
| dag_id="TEST_DTM", | ||
| dag_display_name="TEST DTM", | ||
| schedule=None, | ||
| default_args={"owner": "airflow", "email": ""}, | ||
| start_date=datetime(2024, 1, 25), | ||
| ) | ||
| def dtm_test( | ||
| exponent: int = 2, | ||
| ): | ||
|
|
||
| @task | ||
| def get_data(): | ||
| return [20, 100, 200, 222, 242, 272] | ||
|
|
||
| @task | ||
| def to_exp(number: int, exponent: int) -> float: | ||
| return number**exponent | ||
|
|
||
| @task | ||
| def trunc(number: float, digits: int) -> float: | ||
| return round(number / 22, digits) | ||
|
|
||
| @task | ||
| def save(number: list[float]): | ||
| for n in number: | ||
| print(f"Got number: {n}") | ||
|
|
||
| @task_group # type: ignore[type-var] | ||
| def transform(number: int, exponent: int) -> float: | ||
| a = to_exp(number, exponent) | ||
| b = trunc(a, 2) | ||
| return b | ||
|
|
||
| data = get_data() | ||
| result = transform.partial(exponent=exponent).expand(number=data) | ||
| save(result) # type: ignore[arg-type] | ||
|
|
||
|
|
||
| instance = dtm_test() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -116,11 +116,11 @@ def teardown_method(self): | |
| pytest.param([], [], id="list"), | ||
| pytest.param({}, {}, id="empty_dict"), | ||
| pytest.param((), [], id="empty_tuple"), | ||
| pytest.param(set(), "set()", id="empty_set"), | ||
| pytest.param(set(), [], id="empty_set"), | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update tests to reflect logic change: related https://github.com/apache/airflow/pull/63871/changes#r3212530026 |
||
| pytest.param("test-string", "test-string", id="string"), | ||
| pytest.param({"foo": "bar"}, {"foo": "bar"}, id="dict"), | ||
| pytest.param(("foo", "bar"), ["foo", "bar"], id="tuple"), | ||
| pytest.param({"foo"}, "{'foo'}", id="set"), | ||
| pytest.param({"foo"}, ["foo"], id="set"), | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| (date(2018, 12, 6), "2018-12-06"), | ||
| pytest.param(datetime(2018, 12, 6, 10, 55), "2018-12-06 10:55:00+00:00", id="datetime"), | ||
| pytest.param( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since keys are now fixed as strings by
normalize_dict_key, the logic that sorted by key type alongside the key value has been removed. Sorting is now performed solely by the key value.