Skip to content

Fix Task Mapping with XCOM arguments from other Tasks#47141

Merged
jedcunningham merged 3 commits intoapache:mainfrom
astronomer:AIP72-fix-mapping-from-task
Feb 27, 2025
Merged

Fix Task Mapping with XCOM arguments from other Tasks#47141
jedcunningham merged 3 commits intoapache:mainfrom
astronomer:AIP72-fix-mapping-from-task

Conversation

@amoghrajesh
Copy link
Contributor

@amoghrajesh amoghrajesh commented Feb 27, 2025

closes: #46580
closes: #46976

This PR fixes a bug that doesn't allow dynamic task mapping to be used with XCOM values returned from another task.

Two issues:

  • The issue was that we were not calling the "SetXcom" utility with mapped_length argument at all.
  • Second issue was that on unmapping, we were not calling with the right map_index due to which it was falling back to the "current task" map_index leading to wrong behaviour and crashes.

Testing

1. DAG mentioned in the issue:


from airflow import DAG
from airflow.decorators import task
from datetime import datetime


@task
def a_list():
    return [3, 6, 9]


@task
def ref_context(num, **context):
    return num + context["ti"].map_index * 5


@task
def assert_sum(nums, expect):
    print(nums)
    print(f"expecting sum: {' + '.join(map(str,nums))} == {expect}")
    print(sum(nums))
    assert sum(nums) == expect


with DAG(
    dag_id="context_ref",
    start_date=datetime(1970, 1, 1),
    schedule=None,
    tags=["taskmap"]
) as dag:
    assert_sum(ref_context.expand(num=a_list()), 33)

Success:
image

Task 1 xcom:
image

Task 2 mapped xcoms:
map_index = 0

image

map_index = 1
image

map_index = 2
image

Assert task

{"timestamp":"2025-02-27T10:15:54.946777Z","level":"info","event":"expecting sum: 3 + 11 + 19 == 33","chan":"stdout","logger":"task"}
image

DAG2: example_dynamic_task_mapping

from datetime import datetime

from airflow.decorators import task
from airflow.models.dag import DAG

with DAG(dag_id="example_dynamic_task_mapping", schedule=None, start_date=datetime(2022, 3, 4)) as dag:

    @task
    def add_one(x: int):
        return x + 1

    @task
    def sum_it(values):
        total = sum(values)
        print(f"Total was {total}")

    added_values = add_one.expand(x=[1, 2, 3])
    sum_it(added_values)

Success:
image

DAG3: cross product

DAG:

from datetime import datetime, timedelta

from airflow import DAG
from airflow.decorators import task
from airflow.models.taskinstance import TaskInstance

delays = [30, 60, 90]
l = [5, 10]


@task
def get_delays():
    return delays

@task
def get_delays2():
    return l

@task
def get_wakes(delay, delay2, **context):
    "Wake {delay} seconds after the task starts"
    ti: TaskInstance = context["ti"]
    return [delay, delay2]


with DAG(
    dag_id="mapping_with_xcom",
    start_date=datetime(1970, 1, 1),
    schedule=None,
    tags=["taskmap"]
) as dag:
    wake_times = get_wakes.expand(delay=get_delays(), delay2=get_delays2())

Task1:
image

Task2:
image

map_index = 0 for Task3:
image

Cross product value from the table:

5,get_wakes,2,return_value,mapping_with_xcom,manual__2025-02-27T10:08:45.638526+00:00_9SZchT5K,"""[60, 5]"""
5,get_wakes,1,return_value,mapping_with_xcom,manual__2025-02-27T10:08:45.638526+00:00_9SZchT5K,"""[30, 10]"""
5,get_wakes,3,return_value,mapping_with_xcom,manual__2025-02-27T10:08:45.638526+00:00_9SZchT5K,"""[60, 10]"""
5,get_wakes,5,return_value,mapping_with_xcom,manual__2025-02-27T10:08:45.638526+00:00_9SZchT5K,"""[90, 10]"""
5,get_wakes,0,return_value,mapping_with_xcom,manual__2025-02-27T10:08:45.638526+00:00_9SZchT5K,"""[30, 5]"""
5,get_wakes,4,return_value,mapping_with_xcom,manual__2025-02-27T10:08:45.638526+00:00_9SZchT5K,"""[90, 5]"""


^ Add meaningful description above
Read the Pull Request Guidelines for more information.
In case of fundamental code changes, an Airflow Improvement Proposal (AIP) is needed.
In case of a new dependency, check compliance with the ASF 3rd Party License Policy.
In case of backwards incompatible changes please leave a note in a newsfragment file, named {pr_number}.significant.rst or {issue_number}.significant.rst, in newsfragments.

@boring-cyborg boring-cyborg bot added area:API Airflow's REST/HTTP API area:task-sdk labels Feb 27, 2025
@amoghrajesh
Copy link
Contributor Author

Also tested with the DAG from #46976.

DAG:

from __future__ import annotations

import datetime
from pathlib import Path

from airflow.decorators import task
from airflow.models.dag import DAG
from airflow.sdk import Param
from airflow.utils.trigger_rule import TriggerRule

# [START params_trigger]
with DAG(
    dag_id=Path(__file__).stem,
    dag_display_name="Params Trigger UI",
    description=__doc__.partition(".")[0],
    doc_md=__doc__,
    schedule=None,
    start_date=datetime.datetime(2022, 3, 4),
    catchup=False,
    tags=["example", "params"],
    params={
        "names": Param(
            ["Linda", "Martha", "Thomas"],
            type="array",
            description="Define the list of names for which greetings should be generated in the logs."
            " Please have one name per line.",
            title="Names to greet",
        ),
        "english": Param(True, type="boolean", title="English"),
        "german": Param(True, type="boolean", title="German (Formal)"),
        "french": Param(True, type="boolean", title="French"),
    },
) as dag:

    @task(task_id="get_names", task_display_name="Get names")
    def get_names(**kwargs) -> list[str]:
        params = kwargs["params"]
        if "names" not in params:
            print("Uuups, no names given, was no UI used to trigger?")
            return []
        return params["names"]

    @task.branch(task_id="select_languages", task_display_name="Select languages")
    def select_languages(**kwargs) -> list[str]:
        params = kwargs["params"]
        selected_languages = []
        for lang in ["english", "german", "french"]:
            if params[lang]:
                selected_languages.append(f"generate_{lang}_greeting")
        return selected_languages

    @task(task_id="generate_english_greeting", task_display_name="Generate English greeting")
    def generate_english_greeting(name: str) -> str:
        return f"Hello {name}!"

    @task(task_id="generate_german_greeting", task_display_name="Erzeuge Deutsche Begrüßung")
    def generate_german_greeting(name: str) -> str:
        return f"Sehr geehrter Herr/Frau {name}."

    @task(task_id="generate_french_greeting", task_display_name="Produire un message d'accueil en français")
    def generate_french_greeting(name: str) -> str:
        return f"Bonjour {name}!"

    @task(task_id="print_greetings", task_display_name="Print greetings", trigger_rule=TriggerRule.ALL_DONE)
    def print_greetings(greetings1, greetings2, greetings3) -> None:
        for g in greetings1 or []:
            print(g)
        for g in greetings2 or []:
            print(g)
        for g in greetings3 or []:
            print(g)
        if not (greetings1 or greetings2 or greetings3):
            print("sad, nobody to greet :-(")

    lang_select = select_languages()
    names = get_names()
    english_greetings = generate_english_greeting.expand(name=names)
    german_greetings = generate_german_greeting.expand(name=names)
    french_greetings = generate_french_greeting.expand(name=names)
    lang_select >> [english_greetings, german_greetings, french_greetings]
    results_print = print_greetings(english_greetings, german_greetings, french_greetings)
# [END params_trigger]

Works as expected:
image

XCOM table:

9,get_names,-1,return_value,example_params_trigger_ui,manual__2025-02-27T10:49:51.967849+00:00_hZl3Kekr,"""[\""Linda\"", \""Martha\"", \""Thomas\""]"""
9,select_languages,-1,return_value,example_params_trigger_ui,manual__2025-02-27T10:49:51.967849+00:00_hZl3Kekr,"""[\""generate_english_greeting\"", \""generate_german_greeting\"", \""generate_french_greeting\""]"""
9,generate_german_greeting,0,return_value,example_params_trigger_ui,manual__2025-02-27T10:49:51.967849+00:00_hZl3Kekr,"""\""Sehr geehrter Herr/Frau Linda.\"""""
9,generate_german_greeting,2,return_value,example_params_trigger_ui,manual__2025-02-27T10:49:51.967849+00:00_hZl3Kekr,"""\""Sehr geehrter Herr/Frau Thomas.\"""""
9,generate_french_greeting,2,return_value,example_params_trigger_ui,manual__2025-02-27T10:49:51.967849+00:00_hZl3Kekr,"""\""Bonjour Thomas!\"""""
9,generate_german_greeting,1,return_value,example_params_trigger_ui,manual__2025-02-27T10:49:51.967849+00:00_hZl3Kekr,"""\""Sehr geehrter Herr/Frau Martha.\"""""
9,generate_french_greeting,0,return_value,example_params_trigger_ui,manual__2025-02-27T10:49:51.967849+00:00_hZl3Kekr,"""\""Bonjour Linda!\"""""
9,generate_french_greeting,1,return_value,example_params_trigger_ui,manual__2025-02-27T10:49:51.967849+00:00_hZl3Kekr,"""\""Bonjour Martha!\"""""
9,generate_english_greeting,1,return_value,example_params_trigger_ui,manual__2025-02-27T10:49:51.967849+00:00_hZl3Kekr,"""\""Hello Martha!\"""""
9,generate_english_greeting,2,return_value,example_params_trigger_ui,manual__2025-02-27T10:49:51.967849+00:00_hZl3Kekr,"""\""Hello Thomas!\"""""
9,generate_english_greeting,0,return_value,example_params_trigger_ui,manual__2025-02-27T10:49:51.967849+00:00_hZl3Kekr,"""\""Hello Linda!\"""""

@ashb ashb added the area:task-execution-interface-aip72 AIP-72: Task Execution Interface (TEI) aka Task SDK label Feb 27, 2025
@amoghrajesh
Copy link
Contributor Author

I think I managed to handle task groups as well.

Example DAG:

from airflow import DAG
from airflow.decorators import task, task_group
from datetime import datetime

# Track the results for verification (only for testing purposes)
results = {}

# Expected values for reference
expected_values = {
    ("tg.t1", 0): ["a", "b"],
    ("tg.t1", 1): [4],
    ("tg.t1", 2): ["z"],
    ("tg.t2", 0): ["a", "b"],
    ("tg.t2", 1): [4],
    ("tg.t2", 2): ["z"],
    ("t3", None): [["a", "b"], [4], ["z"]],
}

# Define the DAG
with DAG(
    dag_id="tg_dag",
    start_date=datetime(2025, 1, 1),
    schedule=None,
    catchup=False,
) as dag:
    @task
    def t(value, ti=None):
        # Store results for verification
        global results
        results[(ti.task_id, ti.map_index)] = value

        print("Value is", value)

        return value

    @task
    def t(value, ti=None):
        # Store results for verification
        global results
        results[(ti.task_id, ti.map_index)] = value

        print("Value is", value)

        return value

    @task_group
    def tg(va):
        # Each expanded group has one t1 and t2 each.
        t1 = t.override(task_id="t1")(va)
        t2 = t.override(task_id="t2")(t1)
        return t2

    t2 = tg.expand(va=[["a", "b"], [4], ["z"]])

    t3 = t.override(task_id="t3")(t2)

Graph vieW:
image

Run:
image

XCOM table:

8,tg.t1,0,return_value,tg_dag,manual__2025-02-27T15:00:50.382954+00:00_BGTtj5oT,"""[\""a\"", \""b\""]"""
8,tg.t2,0,return_value,tg_dag,manual__2025-02-27T15:00:50.382954+00:00_BGTtj5oT,"""[\""a\"", \""b\""]"""
8,tg.t1,1,return_value,tg_dag,manual__2025-02-27T15:00:50.382954+00:00_BGTtj5oT,"""[4]"""
8,tg.t2,1,return_value,tg_dag,manual__2025-02-27T15:00:50.382954+00:00_BGTtj5oT,"""[4]"""
8,tg.t2,2,return_value,tg_dag,manual__2025-02-27T15:00:50.382954+00:00_BGTtj5oT,"""[\""z\""]"""
8,tg.t1,2,return_value,tg_dag,manual__2025-02-27T15:00:50.382954+00:00_BGTtj5oT,"""[\""z\""]"""

@jedcunningham jedcunningham merged commit 578ad78 into apache:main Feb 27, 2025
62 checks passed
@jedcunningham jedcunningham deleted the AIP72-fix-mapping-from-task branch February 27, 2025 16:05
@jscheffl
Copy link
Contributor

Cool! Thanks for fixing!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:API Airflow's REST/HTTP API area:task-execution-interface-aip72 AIP-72: Task Execution Interface (TEI) aka Task SDK area:task-sdk

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Mapped Tasks make DAG fail to display GRID in new UI Dynamic task mapping- Expand not working when provided value from a task

6 participants