Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce operator_name dupe in serialised JSON #25819

Merged
Merged 2 commits
on Aug 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,13 @@ def serialize_to_json(
if cls._is_excluded(value, key, object_to_serialize):
continue

if key in decorated_fields:
if key == '_operator_name':
# When operator_name matches task_type, omit it from the serialized
# output to reduce the JSON payload (deserialization falls back to _task_type).
task_type = getattr(object_to_serialize, '_task_type', None)
if value != task_type:
serialized_object[key] = cls._serialize(value)
elif key in decorated_fields:
serialized_object[key] = cls._serialize(value)
elif key == "timetable" and value is not None:
serialized_object[key] = _encode_timetable(value)
Expand Down Expand Up @@ -684,7 +690,8 @@ def _serialize_node(cls, op: Union[BaseOperator, MappedOperator], include_deps:
serialize_op = cls.serialize_to_json(op, cls._decorated_fields)
serialize_op['_task_type'] = getattr(op, "_task_type", type(op).__name__)
serialize_op['_task_module'] = getattr(op, "_task_module", type(op).__module__)
serialize_op['_operator_name'] = op.operator_name
if op.operator_name != serialize_op['_task_type']:
serialize_op['_operator_name'] = op.operator_name

# Used to determine if an Operator is inherited from EmptyOperator
serialize_op['_is_empty'] = op.inherits_from_empty_operator
Expand Down Expand Up @@ -745,6 +752,9 @@ def populate_operator(cls, op: Operator, encoded_op: Dict[str, Any]) -> None:
# Extra Operator Links defined in Plugins
op_extra_links_from_plugin = {}

if "_operator_name" not in encoded_op:
encoded_op["_operator_name"] = encoded_op["_task_type"]

# We don't want to load Extra Operator links in Scheduler
if cls._load_operator_extra_links:
from airflow import plugins_manager
Expand Down Expand Up @@ -845,6 +855,10 @@ def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> Operator:
if encoded_op.get("_is_mapped", False):
# Most of these will be loaded later; these are just stand-ins.
op_data = {k: v for k, v in encoded_op.items() if k in BaseOperator.get_serialized_fields()}
try:
operator_name = encoded_op["_operator_name"]
except KeyError:
operator_name = encoded_op["_task_type"]
op = MappedOperator(
operator_class=op_data,
expand_input=EXPAND_INPUT_EMPTY,
Expand All @@ -861,7 +875,7 @@ def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> Operator:
is_empty=False,
task_module=encoded_op["_task_module"],
task_type=encoded_op["_task_type"],
operator_name=encoded_op["_operator_name"],
operator_name=operator_name,
dag=None,
task_group=None,
start_date=None,
Expand Down
6 changes: 0 additions & 6 deletions tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ def detect_task_dependencies(task: Operator) -> Optional[DagDependency]: # type
"template_fields_renderers": {'bash_command': 'bash', 'env': 'json'},
"bash_command": "echo {{ task.task_id }}",
"_task_type": "BashOperator",
"_operator_name": "BashOperator",
"_task_module": "airflow.operators.bash",
"pool": "default_pool",
"executor_config": {
Expand Down Expand Up @@ -1854,7 +1853,6 @@ def test_operator_expand_serde():
'_is_mapped': True,
'_task_module': 'airflow.operators.bash',
'_task_type': 'BashOperator',
'_operator_name': 'BashOperator',
'downstream_task_ids': [],
'expand_input': {
"type": "dict-of-lists",
Expand Down Expand Up @@ -1886,7 +1884,6 @@ def test_operator_expand_serde():

assert op.operator_class == {
'_task_type': 'BashOperator',
'_operator_name': 'BashOperator',
'downstream_task_ids': [],
'task_id': 'a',
'template_ext': ['.sh', '.bash'],
Expand All @@ -1913,7 +1910,6 @@ def test_operator_expand_xcomarg_serde():
'_is_mapped': True,
'_task_module': 'tests.test_utils.mock_operators',
'_task_type': 'MockOperator',
'_operator_name': 'MockOperator',
'downstream_task_ids': [],
'expand_input': {
"type": "dict-of-lists",
Expand Down Expand Up @@ -1963,7 +1959,6 @@ def test_operator_expand_kwargs_serde(strict):
'_is_mapped': True,
'_task_module': 'tests.test_utils.mock_operators',
'_task_type': 'MockOperator',
'_operator_name': 'MockOperator',
'downstream_task_ids': [],
'expand_input': {
"type": "list-of-dicts",
Expand Down Expand Up @@ -2236,7 +2231,6 @@ class MyDummyOperator(DummyOperator):
'_is_empty': is_inherit,
'_task_module': 'tests.serialization.test_dag_serialization',
'_task_type': 'MyDummyOperator',
'_operator_name': 'MyDummyOperator',
'downstream_task_ids': [],
"pool": "default_pool",
'task_id': 'my_task',
Expand Down