Handle json encoding of V1Pod in task callback #27609

Merged · 9 commits · Nov 16, 2022
14 changes: 7 additions & 7 deletions airflow/callbacks/callback_requests.py
@@ -84,17 +84,17 @@ def __init__(
self.is_failure_callback = is_failure_callback

def to_json(self) -> str:
dict_obj = self.__dict__.copy()
dict_obj["simple_task_instance"] = self.simple_task_instance.as_dict()
return json.dumps(dict_obj)
from airflow.serialization.serialized_objects import BaseSerialization

val = BaseSerialization.serialize(self.__dict__, strict=True)
return json.dumps(val)

@classmethod
def from_json(cls, json_str: str):
from airflow.models.taskinstance import SimpleTaskInstance
from airflow.serialization.serialized_objects import BaseSerialization

kwargs = json.loads(json_str)
simple_ti = SimpleTaskInstance.from_dict(obj_dict=kwargs.pop("simple_task_instance"))
return cls(simple_task_instance=simple_ti, **kwargs)
val = json.loads(json_str)
return cls(**BaseSerialization.deserialize(val))


class DagCallbackRequest(CallbackRequest):
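Note: `to_json` now pushes the entire `__dict__` through `BaseSerialization.serialize(..., strict=True)`, so nested values such as a `V1Pod` under `executor_config` get a typed encoding instead of an implicit `str()` cast. A rough sketch of the resulting payload shape, assuming the `__type`/`__var` envelope keys used by Airflow's `Encoding` enum (illustrative, abridged):

```python
import json

from airflow.callbacks.callback_requests import TaskCallbackRequest
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
from airflow.operators.bash import BashOperator

# Build a minimal request the same way the new tests below do.
ti = TaskInstance(task=BashOperator(task_id="hi", bash_command="hi"))
request = TaskCallbackRequest("hi", SimpleTaskInstance.from_ti(ti))

payload = json.loads(request.to_json())
# Expected rough shape (assumption: BaseSerialization wraps non-primitives in
# {"__type": ..., "__var": ...} envelopes):
# {
#     "__type": "dict",
#     "__var": {
#         "full_filepath": "hi",
#         "simple_task_instance": {"__type": "simple_task_instance", "__var": {...}},
#         "is_failure_callback": true,
#         "msg": null
#     }
# }
assert payload["__type"] == "dict"
assert TaskCallbackRequest.from_json(request.to_json()) == request
```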
2 changes: 1 addition & 1 deletion airflow/exceptions.py
@@ -216,7 +216,7 @@ def __str__(self) -> str:


class SerializationError(AirflowException):
"""A problem occurred when trying to serialize a DAG."""
"""A problem occurred when trying to serialize something."""


class ParamValidationError(AirflowException):
10 changes: 10 additions & 0 deletions airflow/models/taskinstance.py
@@ -2571,6 +2571,11 @@ def __eq__(self, other):
return NotImplemented

def as_dict(self):
warnings.warn(
"This method is deprecated. Use BaseSerialization.serialize.",
RemovedInAirflow3Warning,
stacklevel=2,
)
new_dict = dict(self.__dict__)
for key in new_dict:
if key in ["start_date", "end_date"]:
@@ -2601,6 +2606,11 @@ def from_ti(cls, ti: TaskInstance) -> SimpleTaskInstance:

@classmethod
def from_dict(cls, obj_dict: dict) -> SimpleTaskInstance:
warnings.warn(
"This method is deprecated. Use BaseSerialization.deserialize.",
RemovedInAirflow3Warning,
stacklevel=2,
)
ti_key = TaskInstanceKey(*obj_dict.pop("key"))
start_date = None
end_date = None
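Note: with `as_dict` and `from_dict` deprecated, the message points callers at `BaseSerialization` instead. A minimal sketch of the replacement usage (constructing the `TaskInstance` the same way the tests below do):

```python
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
from airflow.operators.bash import BashOperator
from airflow.serialization.serialized_objects import BaseSerialization

simple_ti = SimpleTaskInstance.from_ti(TaskInstance(task=BashOperator(task_id="hi", bash_command="hi")))

# Instead of simple_ti.as_dict() / SimpleTaskInstance.from_dict(...):
encoded = BaseSerialization.serialize(simple_ti)    # tagged as "simple_task_instance"
restored = BaseSerialization.deserialize(encoded)   # a SimpleTaskInstance again
assert isinstance(restored, SimpleTaskInstance)
```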
1 change: 1 addition & 0 deletions airflow/serialization/enums.py
@@ -50,3 +50,4 @@ class DagAttributeTypes(str, Enum):
PARAM = "param"
XCOM_REF = "xcomref"
DATASET = "dataset"
SIMPLE_TASK_INSTANCE = "simple_task_instance"
23 changes: 17 additions & 6 deletions airflow/serialization/serialized_objects.py
@@ -44,6 +44,7 @@
from airflow.models.mappedoperator import MappedOperator
from airflow.models.operator import Operator
from airflow.models.param import Param, ParamsDict
from airflow.models.taskinstance import SimpleTaskInstance
from airflow.models.taskmixin import DAGNode
from airflow.models.xcom_arg import XComArg, deserialize_xcom_arg, serialize_xcom_arg
from airflow.providers_manager import ProvidersManager
@@ -381,7 +382,9 @@ def serialize_to_json(
return serialized_object

@classmethod
def serialize(cls, var: Any) -> Any: # Unfortunately there is no support for recursive types in mypy
def serialize(
cls, var: Any, *, strict: bool = False
) -> Any: # Unfortunately there is no support for recursive types in mypy
"""Helper function of depth first search for serialization.

The serialization protocol is:
@@ -400,9 +403,11 @@ def serialize(cls, var: Any) -> Any: # Unfortunately there is no support for re
return var.value
return var
elif isinstance(var, dict):
return cls._encode({str(k): cls.serialize(v) for k, v in var.items()}, type_=DAT.DICT)
return cls._encode(
{str(k): cls.serialize(v, strict=strict) for k, v in var.items()}, type_=DAT.DICT
)
elif isinstance(var, list):
return [cls.serialize(v) for v in var]
return [cls.serialize(v, strict=strict) for v in var]
elif var.__class__.__name__ == "V1Pod" and _has_kubernetes() and isinstance(var, k8s.V1Pod):
json_pod = PodGenerator.serialize_pod(var)
return cls._encode(json_pod, type_=DAT.POD)
@@ -427,12 +432,12 @@ def serialize(cls, var: Any) -> Any: # Unfortunately there is no support for re
elif isinstance(var, set):
# FIXME: casts set to list in customized serialization in future.
try:
return cls._encode(sorted(cls.serialize(v) for v in var), type_=DAT.SET)
return cls._encode(sorted(cls.serialize(v, strict=strict) for v in var), type_=DAT.SET)
except TypeError:
return cls._encode([cls.serialize(v) for v in var], type_=DAT.SET)
return cls._encode([cls.serialize(v, strict=strict) for v in var], type_=DAT.SET)
elif isinstance(var, tuple):
# FIXME: casts tuple to list in customized serialization in future.
return cls._encode([cls.serialize(v) for v in var], type_=DAT.TUPLE)
return cls._encode([cls.serialize(v, strict=strict) for v in var], type_=DAT.TUPLE)
elif isinstance(var, TaskGroup):
return TaskGroupSerialization.serialize_task_group(var)
elif isinstance(var, Param):
@@ -441,8 +446,12 @@ def serialize(cls, var: Any) -> Any: # Unfortunately there is no support for re
return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF)
elif isinstance(var, Dataset):
return cls._encode(dict(uri=var.uri, extra=var.extra), type_=DAT.DATASET)
elif isinstance(var, SimpleTaskInstance):
return cls._encode(cls.serialize(var.__dict__, strict=strict), type_=DAT.SIMPLE_TASK_INSTANCE)
else:
log.debug("Cast type %s to str in serialization.", type(var))
if strict:
raise SerializationError("Encountered unexpected type")
return str(var)

@classmethod
@@ -491,6 +500,8 @@ def deserialize(cls, encoded_var: Any) -> Any:
return _XComRef(var) # Delay deserializing XComArg objects until we have the entire DAG.
elif type_ == DAT.DATASET:
return Dataset(**var)
elif type_ == DAT.SIMPLE_TASK_INSTANCE:
return SimpleTaskInstance(**cls.deserialize(var))
else:
raise TypeError(f"Invalid type {type_!s} in deserialization.")

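Note: because the new `SimpleTaskInstance` branch recurses through `cls.serialize` with `strict` forwarded, a pod stored under `executor_config` is handled by the existing `DAT.POD` branch rather than hitting the `str()` fallback. A hedged end-to-end sketch (assumes the kubernetes client is installed):

```python
from kubernetes.client import models as k8s

from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
from airflow.operators.bash import BashOperator
from airflow.serialization.serialized_objects import BaseSerialization

pod = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="hello", namespace="ns"))
op = BashOperator(task_id="hi", bash_command="hi", executor_config={"pod_override": pod})
simple_ti = SimpleTaskInstance.from_ti(TaskInstance(task=op))

# strict=True: a value the serializer does not recognize raises SerializationError
# instead of being silently cast to str; the V1Pod itself goes through PodGenerator.
encoded = BaseSerialization.serialize(simple_ti, strict=True)
restored = BaseSerialization.deserialize(encoded)
assert restored.executor_config["pod_override"] == pod
```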
5 changes: 5 additions & 0 deletions tests/__init__.py
@@ -15,3 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from pathlib import Path

REPO_ROOT = Path(__file__).parent.parent
36 changes: 36 additions & 0 deletions tests/callbacks/test_callback_requests.py
@@ -92,3 +92,39 @@ def test_taskcallback_to_json_with_start_date_and_end_date(self, session, create
json_str = input.to_json()
result = TaskCallbackRequest.from_json(json_str)
assert input == result

def test_simple_ti_roundtrip_exec_config_pod(self):
"""A callback request including a TI with an exec config with a V1Pod should safely roundtrip."""
from kubernetes.client import models as k8s

from airflow.callbacks.callback_requests import TaskCallbackRequest
from airflow.models import TaskInstance
from airflow.models.taskinstance import SimpleTaskInstance
from airflow.operators.bash import BashOperator

test_pod = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="hello", namespace="ns"))
op = BashOperator(task_id="hi", executor_config={"pod_override": test_pod}, bash_command="hi")
ti = TaskInstance(task=op)
s = SimpleTaskInstance.from_ti(ti)
data = TaskCallbackRequest("hi", s).to_json()
actual = TaskCallbackRequest.from_json(data).simple_task_instance.executor_config["pod_override"]
assert actual == test_pod

def test_simple_ti_roundtrip_dates(self):
"""A callback request including a TI with an exec config with a V1Pod should safely roundtrip."""
from unittest.mock import MagicMock

from airflow.callbacks.callback_requests import TaskCallbackRequest
from airflow.models import TaskInstance
from airflow.models.taskinstance import SimpleTaskInstance
from airflow.operators.bash import BashOperator

op = BashOperator(task_id="hi", bash_command="hi")
ti = TaskInstance(task=op)
ti.set_state("SUCCESS", session=MagicMock())
start_date = ti.start_date
end_date = ti.end_date
s = SimpleTaskInstance.from_ti(ti)
data = TaskCallbackRequest("hi", s).to_json()
assert TaskCallbackRequest.from_json(data).simple_task_instance.start_date == start_date
assert TaskCallbackRequest.from_json(data).simple_task_instance.end_date == end_date
78 changes: 78 additions & 0 deletions tests/serialization/test_serialized_objects.py
@@ -0,0 +1,78 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import pytest

from airflow.exceptions import SerializationError
from tests import REPO_ROOT


def test_recursive_serialize_calls_must_forward_kwargs():
"""Any time we recurse cls.serialize, we must forward all kwargs."""
import ast

valid_recursive_call_count = 0
file = REPO_ROOT / "airflow/serialization/serialized_objects.py"
content = file.read_text()
tree = ast.parse(content)

class_def = None
for stmt in ast.walk(tree):
if not isinstance(stmt, ast.ClassDef):
continue
if stmt.name == "BaseSerialization":
class_def = stmt

method_def = None
for elem in ast.walk(class_def):
if isinstance(elem, ast.FunctionDef):
if elem.name == "serialize":
method_def = elem
break
kwonly_args = [x.arg for x in method_def.args.kwonlyargs]

for elem in ast.walk(method_def):
if isinstance(elem, ast.Call):
if getattr(elem.func, "attr", "") == "serialize":
kwargs = {y.arg: y.value for y in elem.keywords}
for name in kwonly_args:
if name not in kwargs or getattr(kwargs[name], "id", "") != name:
ref = f"{file}:{elem.lineno}"
message = (
f"Error at {ref}; recursive calls to `cls.serialize` "
f"must forward the `{name}` argument"
)
raise Exception(message)
valid_recursive_call_count += 1
print(f"validated calls: {valid_recursive_call_count}")
assert valid_recursive_call_count > 0


def test_strict_mode():
"""If strict=True, serialization should fail when object is not JSON serializable."""

class Test:
a = 1

from airflow.serialization.serialized_objects import BaseSerialization

obj = [[[Test()]]] # nested to verify recursive behavior
BaseSerialization.serialize(obj) # does not raise
with pytest.raises(SerializationError, match="Encountered unexpected type"):
BaseSerialization.serialize(obj, strict=True) # now raises