diff --git a/task-sdk/src/airflow/sdk/bases/notifier.py b/task-sdk/src/airflow/sdk/bases/notifier.py index 7cfed6ae97efb..809dd1442baaf 100644 --- a/task-sdk/src/airflow/sdk/bases/notifier.py +++ b/task-sdk/src/airflow/sdk/bases/notifier.py @@ -87,9 +87,14 @@ def render_template_fields( :param context: Context dict with values to apply on content. :param jinja_env: Jinja environment to use for rendering. """ + if not self.template_fields: + return dag = context.get("dag") if not jinja_env: - jinja_env = self.get_template_env(dag=dag) + if dag is not None and hasattr(dag, "get_template_env"): + jinja_env = self.get_template_env(dag=dag) + else: + jinja_env = self.get_template_env() self._do_render_template_fields(self, self.template_fields, context, jinja_env, set()) async def async_notify(self, context: Context) -> None: diff --git a/task-sdk/tests/task_sdk/bases/test_notifier.py b/task-sdk/tests/task_sdk/bases/test_notifier.py index b8cedaa518831..99927d33359fb 100644 --- a/task-sdk/tests/task_sdk/bases/test_notifier.py +++ b/task-sdk/tests/task_sdk/bases/test_notifier.py @@ -24,6 +24,7 @@ import pytest from airflow.providers.standard.operators.empty import EmptyOperator +from airflow.serialization.serialized_objects import SerializedDAG from airflow.sdk.bases.notifier import BaseNotifier from airflow.sdk.definitions.dag import DAG @@ -45,6 +46,13 @@ def notify(self, context: Context) -> None: pass +class NoTemplateNotifier(BaseNotifier): + """Notifier used to verify callbacks without templated fields.""" + + def notify(self, context: Context) -> None: + pass + + class TestBaseNotifier: def test_render_message_with_message(self): with DAG("test_render_message_with_message") as dag: @@ -97,3 +105,28 @@ def test_notifier_call_with_prepared_context(self, caplog): } ) assert notifier.message == "task: some_task" + + def test_notifier_call_with_serialized_dag_and_no_template_fields(self): + with DAG("test_notifier_call_with_serialized_dag_and_no_template_fields") as dag: + EmptyOperator(task_id="test_id") + + notifier = NoTemplateNotifier() + notifier.notify = MagicMock() + serialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag)) + context: Context = {"dag": serialized_dag} + + notifier(context) + + notifier.notify.assert_called_once_with({"dag": serialized_dag}) + + def test_notifier_render_template_fields_with_serialized_dag(self): + with DAG("test_notifier_render_template_fields_with_serialized_dag") as dag: + EmptyOperator(task_id="test_id") + + notifier = MockNotifier(message="Hello {{ dag.dag_id }}") + serialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag)) + context: Context = {"dag": serialized_dag} + + notifier.render_template_fields(context) + + assert notifier.message == "Hello test_notifier_render_template_fields_with_serialized_dag"