Skip to content

Commit

Permalink
treeview - deterministic and new getter (apache#37162)
Browse files Browse the repository at this point in the history
* treeview - determinist and new getter

* review 1

---------

Co-authored-by: raphaelauv <raphaelauv@users.noreply.github.com>
  • Loading branch information
2 people authored and abhishekbhakat committed Mar 5, 2024
1 parent 475b0e3 commit 806a914
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 10 deletions.
25 changes: 18 additions & 7 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
Callable,
Collection,
Container,
Generator,
Iterable,
Iterator,
List,
Expand Down Expand Up @@ -2627,15 +2628,25 @@ def pickle(self, session=NEW_SESSION) -> DagPickle:

def tree_view(self) -> None:
"""Print an ASCII tree representation of the DAG."""
for tmp in self._generate_tree_view():
print(tmp)

def get_downstream(task, level=0):
print((" " * level * 4) + str(task))
def _generate_tree_view(self) -> Generator[str, None, None]:
def get_downstream(task, level=0) -> Generator[str, None, None]:
yield (" " * level * 4) + str(task)
level += 1
for t in task.downstream_list:
get_downstream(t, level)

for t in self.roots:
get_downstream(t)
for tmp_task in sorted(task.downstream_list, key=lambda x: x.task_id):
yield from get_downstream(tmp_task, level)

for t in sorted(self.roots, key=lambda x: x.task_id):
yield from get_downstream(t)

def get_tree_view(self) -> str:
"""Return an ASCII tree representation of the DAG."""
rst = ""
for tmp in self._generate_tree_view():
rst += tmp + "\n"
return rst

@property
def task(self) -> TaskDecoratorCollection:
Expand Down
17 changes: 14 additions & 3 deletions tests/models/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1422,19 +1422,30 @@ def test_leaves(self):
def test_tree_view(self):
"""Verify correctness of dag.tree_view()."""
with DAG("test_dag", start_date=DEFAULT_DATE) as dag:
op1 = EmptyOperator(task_id="t1")
op1_a = EmptyOperator(task_id="t1_a")
op1_b = EmptyOperator(task_id="t1_b")
op2 = EmptyOperator(task_id="t2")
op3 = EmptyOperator(task_id="t3")
op1 >> op2 >> op3
op1_b >> op2
op1_a >> op2 >> op3

with redirect_stdout(StringIO()) as stdout:
dag.tree_view()
stdout = stdout.getvalue()

stdout_lines = stdout.splitlines()
assert "t1" in stdout_lines[0]
assert "t1_a" in stdout_lines[0]
assert "t2" in stdout_lines[1]
assert "t3" in stdout_lines[2]
assert "t1_b" in stdout_lines[3]
assert dag.get_tree_view() == (
"<Task(EmptyOperator): t1_a>\n"
" <Task(EmptyOperator): t2>\n"
" <Task(EmptyOperator): t3>\n"
"<Task(EmptyOperator): t1_b>\n"
" <Task(EmptyOperator): t2>\n"
" <Task(EmptyOperator): t3>\n"
)

def test_duplicate_task_ids_not_allowed_with_dag_context_manager(self):
"""Verify tasks with Duplicate task_id raises error"""
Expand Down

0 comments on commit 806a914

Please sign in to comment.