Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,14 +298,23 @@ def get(
return XComResponse.model_validate_json(resp.read())

def set(
self, dag_id: str, run_id: str, task_id: str, key: str, value, map_index: int | None = None
self,
dag_id: str,
run_id: str,
task_id: str,
key: str,
value,
map_index: int | None = None,
mapped_length: int | None = None,
) -> dict[str, bool]:
"""Set a XCom value via the API server."""
# TODO: check if we need to use map_index as params in the uri
# ref: https://github.com/apache/airflow/blob/v2-10-stable/airflow/api_connexion/openapi/v1.yaml#L1785C1-L1785C81
params = {}
if map_index is not None and map_index >= 0:
params = {"map_index": map_index}
if mapped_length is not None and mapped_length >= 0:
params["mapped_length"] = mapped_length
self.client.post(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params, json=value)
# Any error from the server will anyway be propagated down to the supervisor,
# so we choose to send a generic response to the supervisor over the server response to
Expand Down
23 changes: 17 additions & 6 deletions task_sdk/src/airflow/sdk/definitions/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,12 +338,23 @@ def resolve(self, context: Mapping[str, Any]) -> Any:

if self.operator.is_mapped:
return LazyXComSequence[Any](xcom_arg=self, ti=ti)

result = ti.xcom_pull(
task_ids=task_id,
key=self.key,
default=NOTSET,
)
tg = ti.task.get_closest_mapped_task_group()
result = None
if tg is None:
# regular task
result = ti.xcom_pull(
task_ids=task_id,
key=self.key,
default=NOTSET,
map_indexes=None,
)
else:
# task from a task group
result = ti.xcom_pull(
task_ids=task_id,
key=self.key,
default=NOTSET,
)
if not isinstance(result, ArgNotSet):
return result
if self.key == XCOM_RETURN_KEY:
Expand Down
4 changes: 3 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,9 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
self._terminal_state = IntermediateTIState.UP_FOR_RESCHEDULE
self.client.task_instances.reschedule(self.id, msg)
elif isinstance(msg, SetXCom):
self.client.xcoms.set(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.value, msg.map_index)
self.client.xcoms.set(
msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.value, msg.map_index, msg.mapped_length
)
elif isinstance(msg, PutVariable):
self.client.variables.set(msg.key, msg.value, msg.description)
elif isinstance(msg, SetRenderedFields):
Expand Down
27 changes: 27 additions & 0 deletions task_sdk/tests/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,33 @@ def handle_request(request: httpx.Request) -> httpx.Response:
)
assert result == {"ok": True}

def test_xcom_set_with_mapped_length(self):
# Simulate a successful response from the server when setting an xcom with mapped_length
def handle_request(request: httpx.Request) -> httpx.Response:
if (
request.url.path == "/xcoms/dag_id/run_id/task_id/key"
and request.url.params.get("map_index") == "2"
and request.url.params.get("mapped_length") == "3"
):
assert json.loads(request.read()) == "value1"
return httpx.Response(
status_code=201,
json={"message": "XCom successfully set"},
)
return httpx.Response(status_code=400, json={"detail": "Bad Request"})

client = make_client(transport=httpx.MockTransport(handle_request))
result = client.xcoms.set(
dag_id="dag_id",
run_id="run_id",
task_id="task_id",
key="key",
value="value1",
map_index=2,
mapped_length=3,
)
assert result == {"ok": True}


class TestConnectionOperations:
"""
Expand Down
10 changes: 8 additions & 2 deletions task_sdk/tests/definitions/test_mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,10 +570,16 @@ def tg(va):
return t2

# The group is mapped by 3.
t2 = tg.expand(va=[["a", "b"], [4], ["z"]])
tg1 = tg.expand(
va=[
["a", "b"],
[4],
["z"],
]
)

# Aggregates results from task group.
t.override(task_id="t3")(t2)
t.override(task_id="t3")(tg1)

def xcom_get():
# TODO: Tidy this after #45927 is reopened and fixed properly
Expand Down
27 changes: 27 additions & 0 deletions task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,6 +1042,7 @@ def watched_subprocess(self, mocker):
"test_key",
'{"key": "test_key", "value": {"key2": "value2"}}',
None,
None,
),
{},
{"ok": True},
Expand All @@ -1065,11 +1066,37 @@ def watched_subprocess(self, mocker):
"test_key",
'{"key": "test_key", "value": {"key2": "value2"}}',
2,
None,
),
{},
{"ok": True},
id="set_xcom_with_map_index",
),
pytest.param(
SetXCom(
dag_id="test_dag",
run_id="test_run",
task_id="test_task",
key="test_key",
value='{"key": "test_key", "value": {"key2": "value2"}}',
map_index=2,
mapped_length=3,
),
b"",
"xcoms.set",
(
"test_dag",
"test_run",
"test_task",
"test_key",
'{"key": "test_key", "value": {"key2": "value2"}}',
2,
3,
),
{},
{"ok": True},
id="set_xcom_with_map_index_and_mapped_length",
),
# we aren't adding all states under TerminalTIState here, because this test's scope is only to check
# if it can handle TaskState message
pytest.param(
Expand Down
27 changes: 27 additions & 0 deletions tests/api_fastapi/execution_api/routes/test_xcoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,33 @@ def test_xcom_set(self, client, create_task_instance, session, value, expected_v
task_map = session.query(TaskMap).filter_by(task_id=ti.task_id, dag_id=ti.dag_id).one_or_none()
assert task_map is None, "Should not be mapped"

def test_xcom_set_mapped(self, client, create_task_instance, session):
ti = create_task_instance()
session.commit()

response = client.post(
f"/execution/xcoms/{ti.dag_id}/{ti.run_id}/{ti.task_id}/xcom_1",
params={"map_index": -1, "mapped_length": 3},
json="value1",
)

assert response.status_code == 201
assert response.json() == {"message": "XCom successfully set"}

xcom = (
session.query(XCom)
.filter_by(task_id=ti.task_id, dag_id=ti.dag_id, key="xcom_1", map_index=-1)
.first()
)
assert xcom.value == "value1"
task_map = session.query(TaskMap).filter_by(task_id=ti.task_id, dag_id=ti.dag_id).one_or_none()
assert task_map is not None, "Should be mapped"
assert task_map.dag_id == "dag"
assert task_map.run_id == "test"
assert task_map.task_id == "op1"
assert task_map.map_index == -1
assert task_map.length == 3

@pytest.mark.parametrize(
("length", "err_context"),
[
Expand Down