Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1367,7 +1367,7 @@ def _get_current_databricks_task(self) -> dict[str, Any]:

def _convert_to_databricks_workflow_task(
self,
relevant_upstreams: list[BaseOperator],
relevant_upstreams: list[str],
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you clarify why this change is needed?

Copy link
Copy Markdown
Contributor Author

@Vamsi-klu Vamsi-klu May 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. This is correcting the annotation to match the runtime value rather than changing behavior.

relevant_upstreams is populated from task.task_id in DatabricksWorkflowTaskGroup.__exit__, so it is a list of task-id strings. _convert_to_databricks_workflow_task() then compares those strings with self.upstream_task_ids:

for task_id in self.upstream_task_ids
if task_id in relevant_upstreams

The old list[BaseOperator] annotation was misleading: passing operators there would make that membership check fail because it compares str task IDs to operator objects. The actual BaseOperator instances are still carried separately in task_dict, which is used when resolving the parent task’s Databricks task key.

So this change is mainly a type-hint cleanup that also makes the tests reflect the real call shape.

task_dict: dict[str, BaseOperator],
context: Context | None = None,
) -> dict[str, object]:
Expand Down Expand Up @@ -1621,7 +1621,7 @@ def _extend_workflow_notebook_packages(

def _convert_to_databricks_workflow_task(
self,
relevant_upstreams: list[BaseOperator],
relevant_upstreams: list[str],
task_dict: dict[str, BaseOperator],
context: Context | None = None,
) -> dict[str, object]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def __init__(
self.python_params = python_params or []
self.spark_submit_params = spark_submit_params or []
self.tasks_to_convert = tasks_to_convert or {}
self.relevant_upstreams = [task_id]
self.relevant_upstreams: list[str] = []
self.workflow_run_metadata: WorkflowRunMetadata | None = None
super().__init__(task_id=task_id, **kwargs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2576,15 +2576,22 @@ def test_convert_to_databricks_workflow_task(self):
operator.task_group = databricks_workflow_task_group
operator.task_id = "test_task"
operator.upstream_task_ids = ["upstream_task"]
relevant_upstreams = [MagicMock(task_id="upstream_task")]
task_dict = {"upstream_task": MagicMock(task_id="upstream_task")}
upstream_task = DatabricksNotebookOperator(
notebook_path="/path/to/upstream",
source="WORKSPACE",
task_id="upstream_task",
dag=dag,
)
relevant_upstreams = ["upstream_task"]
task_dict = {"upstream_task": upstream_task}

task_json = operator._convert_to_databricks_workflow_task(relevant_upstreams, task_dict)

task_key = hashlib.md5(b"example_dag__test_task").hexdigest()
upstream_task_key = hashlib.md5(b"example_dag__upstream_task").hexdigest()
expected_json = {
"task_key": task_key,
"depends_on": [],
"depends_on": [{"task_key": upstream_task_key}],
"timeout_seconds": 0,
"email_notifications": {},
"notebook_task": {
Expand Down Expand Up @@ -2755,6 +2762,19 @@ def test_generate_databricks_task_key(self):
expected_task_key = hashlib.md5(task_key).hexdigest()
assert expected_task_key == operator.databricks_task_key

def test_generate_databricks_task_key_requires_task_dict_when_task_id_passed(self):
    """Looking up a parent task's key without a ``task_dict`` is a programmer error.

    Calls ``_generate_databricks_task_key`` with only ``task_id`` and expects a
    ``ValueError`` carrying the exact guidance message.
    """
    import re

    operator = DatabricksTaskOperator(
        task_id="test_task",
        databricks_conn_id="test_conn_id",
        task_config={},
    )
    expected_message = "Must pass task_dict if task_id is provided in _generate_databricks_task_key."
    # ``pytest.raises(match=...)`` applies the pattern with ``re.search``; escape it so the
    # dots in the message are matched literally instead of as "any character".
    with pytest.raises(ValueError, match=re.escape(expected_message)):
        operator._generate_databricks_task_key(task_id="upstream_task")

def test_user_databricks_task_key(self):
task_config = {}
operator = DatabricksTaskOperator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import annotations

import hashlib
from unittest.mock import MagicMock, patch

import pytest
Expand All @@ -30,6 +31,7 @@
from airflow.models.baseoperator import BaseOperator
from airflow.providers.common.compat.sdk import AirflowException
from airflow.providers.databricks.hooks.databricks import RunLifeCycleState
from airflow.providers.databricks.operators.databricks import DatabricksNotebookOperator
from airflow.providers.databricks.operators.databricks_workflow import (
DatabricksWorkflowTaskGroup,
WorkflowRunMetadata,
Expand Down Expand Up @@ -333,3 +335,240 @@ def test_on_kill(mock_databricks_hook, context, mock_workflow_run_metadata):
operator.on_kill()

operator._hook.cancel_run.assert_called_once_with(RUN_ID)


class TestWorkflowDependsOn:
    """Verify that each task's ``depends_on`` entry names the *parent's* ``task_key``.

    Regression coverage for issue apache/airflow#47614 (root cause fixed by #48492).
    Every scenario assembles a real ``DAG`` containing a ``DatabricksWorkflowTaskGroup``
    of real ``DatabricksNotebookOperator`` tasks (no operator mocks), calls
    ``_CreateDatabricksWorkflowOperator.create_workflow_json`` on the group's launch
    task, and checks the ``tasks[*]['depends_on']`` entries of the result.
    """

    DAG_ID = "test_depends_on_dag"
    GROUP_ID = "wf_group"
    CONN_ID = "databricks_conn"

    @staticmethod
    def _build_notebook(task_id: str, **kwargs) -> DatabricksNotebookOperator:
        """Create a ``WORKSPACE`` notebook operator whose path is derived from ``task_id``."""
        return DatabricksNotebookOperator(
            source="WORKSPACE",
            notebook_path=f"/path/{task_id}",
            task_id=task_id,
            **kwargs,
        )

    def _expected_default_key(self, group_task_id: str) -> str:
        """Return the md5 hex digest of ``<DAG_ID>__<GROUP_ID>.<group_task_id>``."""
        raw_key = f"{self.DAG_ID}__{self.GROUP_ID}.{group_task_id}"
        return hashlib.md5(raw_key.encode()).hexdigest()

    def _launch_task(self, dag: DAG) -> _CreateDatabricksWorkflowOperator:
        """Fetch the group's ``launch`` task from ``dag`` and check its type."""
        launch_op = dag.task_dict[f"{self.GROUP_ID}.launch"]
        assert isinstance(launch_op, _CreateDatabricksWorkflowOperator)
        return launch_op

    @staticmethod
    def _tasks_by_key(workflow_json: dict) -> dict:
        """Index the ``tasks`` list of a workflow payload by each entry's ``task_key``."""
        indexed = {}
        for task_spec in workflow_json["tasks"]:
            indexed[task_spec["task_key"]] = task_spec
        return indexed

    def _new_dag(self) -> DAG:
        """Instantiate the unscheduled DAG shared by every scenario."""
        return DAG(dag_id=self.DAG_ID, schedule=None, start_date=DEFAULT_DATE)

    def _workflow_group(self) -> DatabricksWorkflowTaskGroup:
        """Instantiate the workflow task group under test."""
        return DatabricksWorkflowTaskGroup(group_id=self.GROUP_ID, databricks_conn_id=self.CONN_ID)

    def _workflow_tasks(self, dag: DAG) -> dict:
        """Generate the workflow JSON from the launch task and index its tasks by key."""
        payload = self._launch_task(dag).create_workflow_json()
        return self._tasks_by_key(payload)

    def test_depends_on_uses_parent_key_default_keys(self):
        """``task_A >> task_B``: ``task_B.depends_on`` must carry ``task_A``'s key."""
        with self._new_dag() as dag, self._workflow_group():
            task_a = self._build_notebook("task_a")
            task_b = self._build_notebook("task_b")
            task_a >> task_b

        specs = self._workflow_tasks(dag)
        key_a = self._expected_default_key("task_a")
        key_b = self._expected_default_key("task_b")

        assert set(specs) == {key_a, key_b}
        assert specs[key_a]["depends_on"] == []
        assert specs[key_b]["depends_on"] == [{"task_key": key_a}]

    def test_depends_on_uses_parent_key_custom_parent_key(self):
        """A parent's explicit ``databricks_task_key`` is what the child's ``depends_on`` holds."""
        with self._new_dag() as dag, self._workflow_group():
            task_a = self._build_notebook("task_a", databricks_task_key="custom_a")
            task_b = self._build_notebook("task_b")
            task_a >> task_b

        specs = self._workflow_tasks(dag)
        key_b = self._expected_default_key("task_b")

        assert "custom_a" in specs
        assert specs[key_b]["depends_on"] == [{"task_key": "custom_a"}]

    def test_depends_on_falls_back_to_hash_when_parent_key_too_long(self):
        """An explicit key over 100 chars is not used; task and ``depends_on`` carry the hash."""
        too_long_key = "x" * 101
        with self._new_dag() as dag, self._workflow_group():
            task_a = self._build_notebook("task_a", databricks_task_key=too_long_key)
            task_b = self._build_notebook("task_b")
            task_a >> task_b

        specs = self._workflow_tasks(dag)
        key_a = self._expected_default_key("task_a")
        key_b = self._expected_default_key("task_b")

        assert too_long_key not in specs
        assert key_a in specs
        assert specs[key_b]["depends_on"] == [{"task_key": key_a}]

    def test_depends_on_diamond_dependency(self):
        """``A >> [B, C] >> D``: D lists B and C as parents; B and C list only A."""
        with self._new_dag() as dag, self._workflow_group():
            task_a = self._build_notebook("task_a")
            task_b = self._build_notebook("task_b")
            task_c = self._build_notebook("task_c")
            task_d = self._build_notebook("task_d")
            task_a >> [task_b, task_c] >> task_d

        specs = self._workflow_tasks(dag)
        key_a = self._expected_default_key("task_a")
        key_b = self._expected_default_key("task_b")
        key_c = self._expected_default_key("task_c")
        key_d = self._expected_default_key("task_d")

        assert specs[key_a]["depends_on"] == []
        assert specs[key_b]["depends_on"] == [{"task_key": key_a}]
        assert specs[key_c]["depends_on"] == [{"task_key": key_a}]
        # Compared as a set, so the order of D's parents is not part of the check.
        parents_of_d = {dependency["task_key"] for dependency in specs[key_d]["depends_on"]}
        assert parents_of_d == {key_b, key_c}

    def test_depends_on_fan_out_dependency(self):
        """``A >> [B, C]``: each downstream lists A's key and nothing else."""
        with self._new_dag() as dag, self._workflow_group():
            task_a = self._build_notebook("task_a")
            task_b = self._build_notebook("task_b")
            task_c = self._build_notebook("task_c")
            task_a >> [task_b, task_c]

        specs = self._workflow_tasks(dag)
        key_a = self._expected_default_key("task_a")
        key_b = self._expected_default_key("task_b")
        key_c = self._expected_default_key("task_c")

        assert specs[key_a]["depends_on"] == []
        assert specs[key_b]["depends_on"] == [{"task_key": key_a}]
        assert specs[key_c]["depends_on"] == [{"task_key": key_a}]

    def test_root_tasks_have_empty_depends_on(self):
        """The launch task is a root's Airflow upstream but must never show up in ``depends_on``."""
        with self._new_dag() as dag, self._workflow_group():
            root_a = self._build_notebook("root_a")
            root_b = self._build_notebook("root_b")
            downstream = self._build_notebook("downstream")
            downstream.set_upstream([root_a, root_b])

        launch_task = self._launch_task(dag)
        # Sanity: both roots actually have the launch task as an Airflow upstream.
        for root_name in ("root_a", "root_b"):
            root_op = dag.task_dict[f"{self.GROUP_ID}.{root_name}"]
            assert launch_task.task_id in root_op.upstream_task_ids

        specs = self._tasks_by_key(launch_task.create_workflow_json())
        key_root_a = self._expected_default_key("root_a")
        key_root_b = self._expected_default_key("root_b")

        assert specs[key_root_a]["depends_on"] == []
        assert specs[key_root_b]["depends_on"] == []

    def test_depends_on_filters_out_external_upstream(self):
        """An upstream task defined outside the workflow group is absent from ``depends_on``."""
        with self._new_dag() as dag:
            external_op = EmptyOperator(task_id="external_op")
            with self._workflow_group():
                dbx_task = self._build_notebook("dbx_task")
                external_op >> dbx_task

        specs = self._workflow_tasks(dag)
        key_dbx = self._expected_default_key("dbx_task")

        assert specs[key_dbx]["depends_on"] == []


class TestWorkflowDependsOnWirePayload:
    """Check ``depends_on`` in the job spec handed to the Databricks hook.

    :class:`TestWorkflowDependsOn` inspects the in-process ``create_workflow_json`` payload.
    The tests here inspect the *wire* payload instead: the spec that ``_create_or_reset_job``
    passes to ``DatabricksHook.create_job`` (no existing job) or ``DatabricksHook.reset_job``
    (existing job), i.e. what the Databricks REST API receives.
    """

    DAG_ID = "test_depends_on_wire_dag"
    GROUP_ID = "wf_group"
    CONN_ID = "databricks_conn"

    @staticmethod
    def _build_notebook(task_id: str, **kwargs) -> DatabricksNotebookOperator:
        """Create a ``WORKSPACE`` notebook operator whose path is derived from ``task_id``."""
        return DatabricksNotebookOperator(
            source="WORKSPACE",
            notebook_path=f"/path/{task_id}",
            task_id=task_id,
            **kwargs,
        )

    def _expected_default_key(self, group_task_id: str) -> str:
        """Return the md5 hex digest of ``<DAG_ID>__<GROUP_ID>.<group_task_id>``."""
        raw_key = f"{self.DAG_ID}__{self.GROUP_ID}.{group_task_id}"
        return hashlib.md5(raw_key.encode()).hexdigest()

    def _launch_task(self, dag: DAG) -> _CreateDatabricksWorkflowOperator:
        """Fetch the group's ``launch`` task from ``dag`` and check its type."""
        launch_op = dag.task_dict[f"{self.GROUP_ID}.launch"]
        assert isinstance(launch_op, _CreateDatabricksWorkflowOperator)
        return launch_op

    @staticmethod
    def _tasks_by_key(workflow_json: dict) -> dict:
        """Index the ``tasks`` list of a job spec by each entry's ``task_key``."""
        indexed = {}
        for task_spec in workflow_json["tasks"]:
            indexed[task_spec["task_key"]] = task_spec
        return indexed

    def _build_two_task_dag(self) -> DAG:
        """Build a DAG whose workflow group holds ``task_a >> task_b``."""
        dag = DAG(dag_id=self.DAG_ID, schedule=None, start_date=DEFAULT_DATE)
        with dag:
            # The group is created while the DAG context is active, then entered.
            workflow_group = DatabricksWorkflowTaskGroup(
                group_id=self.GROUP_ID,
                databricks_conn_id=self.CONN_ID,
            )
            with workflow_group:
                parent = self._build_notebook("task_a")
                child = self._build_notebook("task_b")
                parent >> child
        return dag

    def _assert_parent_depends_on(self, job_spec: dict) -> None:
        """Assert the spec has exactly tasks A and B, with B depending on A's key."""
        specs = self._tasks_by_key(job_spec)
        key_a = self._expected_default_key("task_a")
        key_b = self._expected_default_key("task_b")

        assert len(job_spec["tasks"]) == 2
        assert set(specs) == {key_a, key_b}
        assert specs[key_a]["depends_on"] == []
        assert specs[key_b]["depends_on"] == [{"task_key": key_a}]

    def test_create_job_payload_carries_parent_depends_on(self, mock_databricks_hook):
        """With no existing job, ``create_job`` gets a spec whose ``depends_on`` names the parent."""
        launch_task = self._launch_task(self._build_two_task_dag())
        hook = launch_task._hook
        hook.list_jobs.return_value = []
        hook.create_job.return_value = 999

        launch_task._create_or_reset_job(context=MagicMock())

        hook.create_job.assert_called_once()
        hook.reset_job.assert_not_called()
        # The spec is expected to be the one and only positional argument.
        (job_spec,) = hook.create_job.call_args.args
        self._assert_parent_depends_on(job_spec)

    def test_reset_job_payload_carries_parent_depends_on(self, mock_databricks_hook):
        """With an existing job, ``reset_job`` gets a spec whose ``depends_on`` names the parent."""
        launch_task = self._launch_task(self._build_two_task_dag())
        hook = launch_task._hook
        hook.list_jobs.return_value = [{"job_id": 42}]

        launch_task._create_or_reset_job(context=MagicMock())

        hook.reset_job.assert_called_once()
        hook.create_job.assert_not_called()
        job_id, job_spec = hook.reset_job.call_args.args
        assert job_id == 42
        self._assert_parent_depends_on(job_spec)
Loading