Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion task-sdk/src/airflow/sdk/bases/notifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,14 @@ def render_template_fields(
:param context: Context dict with values to apply on content.
:param jinja_env: Jinja environment to use for rendering.
"""
if not self.template_fields:
return
dag = context.get("dag")
if not jinja_env:
jinja_env = self.get_template_env(dag=dag)
if dag is not None and hasattr(dag, "get_template_env"):
jinja_env = self.get_template_env(dag=dag)
else:
jinja_env = self.get_template_env()
self._do_render_template_fields(self, self.template_fields, context, jinja_env, set())

async def async_notify(self, context: Context) -> None:
Expand Down
33 changes: 33 additions & 0 deletions task-sdk/tests/task_sdk/bases/test_notifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import pytest

from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.sdk.bases.notifier import BaseNotifier
from airflow.sdk.definitions.dag import DAG

Expand All @@ -45,6 +46,13 @@ def notify(self, context: Context) -> None:
pass


class NoTemplateNotifier(BaseNotifier):
"""Notifier used to verify callbacks without templated fields."""

def notify(self, context: Context) -> None:
pass


class TestBaseNotifier:
def test_render_message_with_message(self):
with DAG("test_render_message_with_message") as dag:
Expand Down Expand Up @@ -97,3 +105,28 @@ def test_notifier_call_with_prepared_context(self, caplog):
}
)
assert notifier.message == "task: some_task"

def test_notifier_call_with_serialized_dag_and_no_template_fields(self):
with DAG("test_notifier_call_with_serialized_dag_and_no_template_fields") as dag:
EmptyOperator(task_id="test_id")

notifier = NoTemplateNotifier()
notifier.notify = MagicMock()
serialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))
context: Context = {"dag": serialized_dag}

notifier(context)

notifier.notify.assert_called_once_with({"dag": serialized_dag})

def test_notifier_render_template_fields_with_serialized_dag(self):
with DAG("test_notifier_render_template_fields_with_serialized_dag") as dag:
EmptyOperator(task_id="test_id")

notifier = MockNotifier(message="Hello {{ dag.dag_id }}")
serialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))
context: Context = {"dag": serialized_dag}

notifier.render_template_fields(context)

assert notifier.message == "Hello test_notifier_render_template_fields_with_serialized_dag"
Loading