Skip to content
12 changes: 11 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,17 @@ class PutVariable(BaseModel):
type: Literal["PutVariable"] = "PutVariable"


class SetRenderedFields(BaseModel):
"""Payload for setting RTIF for a task instance."""

# We are using a BaseModel here compared to server using RootModel because we
# have a discriminator running with "type", and RootModel doesn't support type

rendered_fields: dict[str, str | None]
type: Literal["SetRenderedFields"] = "SetRenderedFields"


ToSupervisor = Annotated[
Union[TaskState, GetXCom, GetConnection, GetVariable, DeferTask, PutVariable, SetXCom],
Union[TaskState, GetXCom, GetConnection, GetVariable, DeferTask, PutVariable, SetXCom, SetRenderedFields],
Field(discriminator="type"),
]
27 changes: 25 additions & 2 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,14 @@

from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.execution_time.comms import DeferTask, StartupDetails, TaskState, ToSupervisor, ToTask
from airflow.sdk.execution_time.comms import (
DeferTask,
SetRenderedFields,
StartupDetails,
TaskState,
ToSupervisor,
ToTask,
)

if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger as Logger
Expand Down Expand Up @@ -136,11 +143,27 @@ def startup() -> tuple[RuntimeTaskInstance, Logger]:
# TODO: set the "magic loop" context vars for parsing
ti = parse(msg)
log.debug("DAG file parsed", file=msg.file)
return ti, log
else:
raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}")

# TODO: Render fields here
# 1. Implementing the part where we pull in the logic to render fields and add that here
# for all operators, we should do setattr(task, templated_field, rendered_templated_field)
# task.templated_fields should give all the templated_fields and each of those fields should
# give the rendered values.

# 2. Once rendered, we call the `set_rtif` API to store the rtif in the metadata DB
templated_fields = ti.task.template_fields
payload = {}

for field in templated_fields:
if field not in payload:
payload[field] = getattr(ti.task, field)

# so that we do not call the API unnecessarily
if payload:
SUPERVISOR_COMMS.send_request(log=log, msg=SetRenderedFields(rendered_fields=payload))
return ti, log


def run(ti: RuntimeTaskInstance, log: Logger):
Expand Down
37 changes: 37 additions & 0 deletions task_sdk/tests/dags/basic_templated_dag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from airflow.providers.standard.operators.bash import BashOperator
from airflow.sdk.definitions.dag import dag

default_args = {
"owner": "airflow",
"depends_on_past": False,
"retries": 1,
}


@dag()
def basic_templated_dag():
BashOperator(
task_id="task1",
bash_command="echo 'Logical date is {{ logical_date }}'",
)


basic_templated_dag()
31 changes: 29 additions & 2 deletions task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@

from airflow.sdk import DAG, BaseOperator
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.execution_time.comms import DeferTask, StartupDetails, TaskState
from airflow.sdk.execution_time.task_runner import CommsDecoder, parse, run
from airflow.sdk.execution_time.comms import DeferTask, SetRenderedFields, StartupDetails, TaskState
from airflow.sdk.execution_time.task_runner import CommsDecoder, parse, run, startup
from airflow.utils import timezone


Expand Down Expand Up @@ -159,3 +159,30 @@ def test_run_basic_skipped(test_dags_dir: Path, time_machine):
mock_supervisor_comms.send_request.assert_called_once_with(
msg=TaskState(state=TerminalTIState.SKIPPED, end_date=instant), log=mock.ANY
)


def test_startup_basic_templated_dag(test_dags_dir: Path):
"""Test running a basic task."""
what = StartupDetails(
ti=TaskInstance(id=uuid7(), task_id="task1", dag_id="basic_templated_dag", run_id="c", try_number=1),
file=str(test_dags_dir / "basic_templated_dag.py"),
requests_fd=0,
)
parse(what)

with mock.patch(
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as mock_supervisor_comms:
mock_supervisor_comms.get_message.return_value = what
startup()

mock_supervisor_comms.send_request.assert_called_once_with(
msg=SetRenderedFields(
rendered_fields={
"bash_command": "echo 'Logical date is {{ logical_date }}'",
"cwd": None,
"env": None,
}
),
log=mock.ANY,
)