Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Visualization Fix #1782

Merged
merged 3 commits into from Nov 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Expand Up @@ -18,7 +18,7 @@ These changes are available in the [master branch](https://github.com/PrefectHQ/

### Fixes

- None
- Fix issue with `flow.visualize()` for mapped tasks which are skipped - [#1765](https://github.com/PrefectHQ/prefect/issues/1765)

### Deprecations

Expand Down
41 changes: 28 additions & 13 deletions src/prefect/core/flow.py
Expand Up @@ -1086,13 +1086,21 @@ def get_color(task: Task, map_index: int = None) -> str:
name = "{} <map>".format(t.name) if is_mapped else t.name
if is_mapped and flow_state:
assert isinstance(flow_state.result, dict)
for map_index, _ in enumerate(flow_state.result[t].map_states):
if flow_state.result[t].is_mapped():
for map_index, _ in enumerate(flow_state.result[t].map_states):
kwargs = dict(
color=get_color(t, map_index=map_index),
style="filled",
colorscheme="svg",
)
graph.node(
str(id(t)) + str(map_index), name, shape=shape, **kwargs
)
else:
kwargs = dict(
color=get_color(t, map_index=map_index),
style="filled",
colorscheme="svg",
color=get_color(t), style="filled", colorscheme="svg",
)
graph.node(str(id(t)) + str(map_index), name, shape=shape, **kwargs)
graph.node(str(id(t)), name, shape=shape, **kwargs)
else:
kwargs = (
{}
Expand All @@ -1108,15 +1116,22 @@ def get_color(task: Task, map_index: int = None) -> str:
or any(edge.mapped for edge in self.edges_to(e.downstream_task))
) and flow_state:
assert isinstance(flow_state.result, dict)
for map_index, _ in enumerate(
flow_state.result[e.downstream_task].map_states
):
upstream_id = str(id(e.upstream_task))
if any(edge.mapped for edge in self.edges_to(e.upstream_task)):
upstream_id += str(map_index)
down_state = flow_state.result[e.downstream_task]
if down_state.is_mapped():
for map_index, _ in enumerate(down_state.map_states):
upstream_id = str(id(e.upstream_task))
if any(edge.mapped for edge in self.edges_to(e.upstream_task)):
upstream_id += str(map_index)
graph.edge(
upstream_id,
str(id(e.downstream_task)) + str(map_index),
e.key,
style=style,
)
else:
graph.edge(
upstream_id,
str(id(e.downstream_task)) + str(map_index),
str(id(e.upstream_task)),
str(id(e.downstream_task)),
e.key,
style=style,
)
Expand Down
19 changes: 19 additions & 0 deletions tests/core/test_flow.py
Expand Up @@ -1041,6 +1041,25 @@ def test_viz_reflects_mapping(self):
"label=y style=dashed" not in graph.source
) # constants are no longer represented

def test_viz_can_handle_skipped_mapped_tasks(self):
ipython = MagicMock(
get_ipython=lambda: MagicMock(config=dict(IPKernelApp=True))
)
with patch.dict("sys.modules", IPython=ipython):
with Flow(name="test") as f:
t = Task(name="a_list_task")
res = AddTask(name="a_nice_task").map(x=t, y=8)

graph = f.visualize(
flow_state=Success(result={t: Success(), res: Skipped()})
)
assert 'label="a_nice_task <map>" color="#62757f80"' in graph.source
assert 'label=a_list_task color="#28a74580"' in graph.source
assert "label=x style=dashed" in graph.source
assert (
"label=y style=dashed" not in graph.source
) # constants are no longer represented

@pytest.mark.parametrize("state", [Success(), Failed(), Skipped()])
def test_viz_if_flow_state_provided(self, state):
import graphviz
Expand Down