Skip to content

Commit

Permalink
Merge pull request #1782 from PrefectHQ/viz-fix
Browse files Browse the repository at this point in the history
Visualization Fix
  • Loading branch information
cicdw committed Nov 30, 2019
2 parents d91c5eb + 402638f commit e9a4331
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 14 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Expand Up @@ -18,7 +18,7 @@ These changes are available in the [master branch](https://github.com/PrefectHQ/

### Fixes

- None
- Fix issue with `flow.visualize()` for mapped tasks which are skipped - [#1765](https://github.com/PrefectHQ/prefect/issues/1765)

### Deprecations

Expand Down
41 changes: 28 additions & 13 deletions src/prefect/core/flow.py
Expand Up @@ -1086,13 +1086,21 @@ def get_color(task: Task, map_index: int = None) -> str:
name = "{} <map>".format(t.name) if is_mapped else t.name
if is_mapped and flow_state:
assert isinstance(flow_state.result, dict)
for map_index, _ in enumerate(flow_state.result[t].map_states):
if flow_state.result[t].is_mapped():
for map_index, _ in enumerate(flow_state.result[t].map_states):
kwargs = dict(
color=get_color(t, map_index=map_index),
style="filled",
colorscheme="svg",
)
graph.node(
str(id(t)) + str(map_index), name, shape=shape, **kwargs
)
else:
kwargs = dict(
color=get_color(t, map_index=map_index),
style="filled",
colorscheme="svg",
color=get_color(t), style="filled", colorscheme="svg",
)
graph.node(str(id(t)) + str(map_index), name, shape=shape, **kwargs)
graph.node(str(id(t)), name, shape=shape, **kwargs)
else:
kwargs = (
{}
Expand All @@ -1108,15 +1116,22 @@ def get_color(task: Task, map_index: int = None) -> str:
or any(edge.mapped for edge in self.edges_to(e.downstream_task))
) and flow_state:
assert isinstance(flow_state.result, dict)
for map_index, _ in enumerate(
flow_state.result[e.downstream_task].map_states
):
upstream_id = str(id(e.upstream_task))
if any(edge.mapped for edge in self.edges_to(e.upstream_task)):
upstream_id += str(map_index)
down_state = flow_state.result[e.downstream_task]
if down_state.is_mapped():
for map_index, _ in enumerate(down_state.map_states):
upstream_id = str(id(e.upstream_task))
if any(edge.mapped for edge in self.edges_to(e.upstream_task)):
upstream_id += str(map_index)
graph.edge(
upstream_id,
str(id(e.downstream_task)) + str(map_index),
e.key,
style=style,
)
else:
graph.edge(
upstream_id,
str(id(e.downstream_task)) + str(map_index),
str(id(e.upstream_task)),
str(id(e.downstream_task)),
e.key,
style=style,
)
Expand Down
19 changes: 19 additions & 0 deletions tests/core/test_flow.py
Expand Up @@ -1041,6 +1041,25 @@ def test_viz_reflects_mapping(self):
"label=y style=dashed" not in graph.source
) # constants are no longer represented

def test_viz_can_handle_skipped_mapped_tasks(self):
ipython = MagicMock(
get_ipython=lambda: MagicMock(config=dict(IPKernelApp=True))
)
with patch.dict("sys.modules", IPython=ipython):
with Flow(name="test") as f:
t = Task(name="a_list_task")
res = AddTask(name="a_nice_task").map(x=t, y=8)

graph = f.visualize(
flow_state=Success(result={t: Success(), res: Skipped()})
)
assert 'label="a_nice_task <map>" color="#62757f80"' in graph.source
assert 'label=a_list_task color="#28a74580"' in graph.source
assert "label=x style=dashed" in graph.source
assert (
"label=y style=dashed" not in graph.source
) # constants are no longer represented

@pytest.mark.parametrize("state", [Success(), Failed(), Skipped()])
def test_viz_if_flow_state_provided(self, state):
import graphviz
Expand Down

0 comments on commit e9a4331

Please sign in to comment.