Skip to content

Commit

Permalink
fix: color options now work as intended and minimum and maximum recep…
Browse files Browse the repository at this point in the history
…tive field size are now depicted. Receptive field plotting in the graph can be disabled.
  • Loading branch information
MLRichter committed Jan 18, 2022
1 parent 6957dcf commit 26f9ab8
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions rfa_toolbox/vizualize.py
Expand Up @@ -24,6 +24,7 @@ def visualize_node(
input_res: int,
color_border: bool,
color_critical: bool,
include_rf_info: bool = True,
) -> None:
"""Create a node in a graphviz-graph based on an EnrichedNetworkNode instance.
Also creates all edges that lead from predecessor nodes to this node.
Expand All @@ -35,26 +36,33 @@ def visualize_node(
coloring critical and border layers
color_border: The color used for marking border layer
color_critical: The color used for marking critical layers
include_rf_info: If True the receptive field information is
included in the node description
Returns:
Nothing.
"""
color = "white"
if node.is_border(input_resolution=input_res):
if node.is_border(input_resolution=input_res) and color_border:
color = "red"
elif (
np.all(np.asarray(node.receptive_field_min) > np.asarray(input_res))
and color_critical
and not node.is_border(input_resolution=input_res)
):
color = "orange"
elif (
np.any(np.asarray(node.receptive_field_min) > np.asarray(input_res))
and color_critical
and not node.is_border(input_resolution=input_res)
):
color = "yellow"
l_name = node.layer_info.name
rf_info = "\\n" + f"r={node.receptive_field_min}"
rf_info = (
"\\n" + f"r(min)={node.receptive_field_min}, r(max)={node.receptive_field_max}"
)

filters = f"\\n{node.layer_info.filters} filters"
units = f"\\n{node.layer_info.units} units"

Expand All @@ -63,7 +71,8 @@ def visualize_node(
label += filters
elif node.layer_info.units is not None:
label += units
label += rf_info
if include_rf_info:
label += rf_info

dot.node(
f"{node.name}-{id(node)}",
Expand All @@ -81,6 +90,7 @@ def visualize_architecture(
input_res: int = 224,
color_critical: bool = True,
color_border: bool = True,
include_rf_info: bool = True,
) -> graphviz.Digraph:
"""Visualize an architecture using graphviz
and mark critical and border layers in the graph visualization.
Expand All @@ -95,6 +105,8 @@ def visualize_architecture(
critical and border layers)
color_critical: if True the critical layers are colored orange, True by default.
color_border: if True the border layers are colored red, True by default.
include_rf_info: if True the receptive field information is included in the node
description
Returns:
A graphviz.Digraph object that can visualize the network architecture.
Expand All @@ -111,6 +123,7 @@ def visualize_architecture(
input_res=input_res,
color_border=color_border,
color_critical=color_critical,
include_rf_info=include_rf_info,
)
return f

Expand Down

0 comments on commit 26f9ab8

Please sign in to comment.