Skip to content

Commit

Permalink
Add SharedVariable.default_update graphs to debugprint
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Feb 7, 2023
1 parent 5f21ff0 commit a5008d1
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 16 deletions.
103 changes: 90 additions & 13 deletions aesara/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def debugprint(
print_destroy_map: bool = False,
print_view_map: bool = False,
print_fgraph_inputs: bool = False,
print_default_updates: bool = False,
ids: Optional[IDTypesType] = None,
) -> Union[str, TextIO]:
r"""Print a graph as text.
Expand Down Expand Up @@ -177,6 +178,8 @@ def debugprint(
Whether to print the `view_map`\s of printed objects
print_fgraph_inputs
Print the inputs of `FunctionGraph`\s.
print_default_updates
Print the `SharedVariable.default_update` values.
Returns
-------
Expand Down Expand Up @@ -263,6 +266,7 @@ def debugprint(
raise TypeError(f"debugprint cannot print an object type {type(obj)}")

inner_graph_vars: List[Variable] = []
default_updates: List[Variable] = []

if any(p for p in profile_list if p is not None and p.fct_callcount > 0):
print(
Expand Down Expand Up @@ -297,14 +301,16 @@ def debugprint(
print_type=print_type,
file=_file,
id_type=id_type,
inner_graph_ops=inner_graph_vars,
inner_graph_vars=inner_graph_vars,
stop_on_name=stop_on_name,
used_ids=used_ids,
op_information=op_information,
parent_node=var.owner,
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
print_default_updates=print_default_updates,
default_updates=default_updates,
)

for var, profile, storage_map, topo_order in zip(
Expand All @@ -325,7 +331,7 @@ def debugprint(
file=_file,
topo_order=topo_order,
id_type=id_type,
inner_graph_ops=inner_graph_vars,
inner_graph_vars=inner_graph_vars,
stop_on_name=stop_on_name,
profile=profile,
storage_map=storage_map,
Expand All @@ -335,6 +341,8 @@ def debugprint(
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
print_default_updates=print_default_updates,
default_updates=default_updates,
)

if len(inner_graph_vars) > 0:
Expand Down Expand Up @@ -384,7 +392,7 @@ def debugprint(
print_type=print_type,
file=_file,
id_type=id_type,
inner_graph_ops=inner_graph_vars,
inner_graph_vars=inner_graph_vars,
stop_on_name=stop_on_name,
inner_to_outer_inputs=inner_to_outer_inputs,
used_ids=used_ids,
Expand All @@ -393,6 +401,8 @@ def debugprint(
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
print_default_updates=print_default_updates,
default_updates=default_updates,
)

if print_fgraph_inputs:
Expand All @@ -406,7 +416,7 @@ def debugprint(
file=_file,
id_type=id_type,
stop_on_name=stop_on_name,
inner_graph_ops=inner_graph_vars,
inner_graph_vars=inner_graph_vars,
inner_to_outer_inputs=inner_to_outer_inputs,
used_ids=used_ids,
op_information=op_information,
Expand All @@ -415,6 +425,8 @@ def debugprint(
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
inner_graph_node=ig_var.owner,
print_default_updates=print_default_updates,
default_updates=default_updates,
)
inner_to_outer_inputs = None

Expand All @@ -436,7 +448,7 @@ def debugprint(
id_type=id_type,
stop_on_name=stop_on_name,
prefix_child=new_prefix_child,
inner_graph_ops=inner_graph_vars,
inner_graph_vars=inner_graph_vars,
inner_to_outer_inputs=inner_to_outer_inputs,
used_ids=used_ids,
op_information=op_information,
Expand All @@ -445,8 +457,43 @@ def debugprint(
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
inner_graph_node=ig_var.owner,
print_default_updates=print_default_updates,
default_updates=default_updates,
)

if len(default_updates) > 0:
print("", file=_file)
print("Default updates:", file=_file)

inner_to_outer_inputs = {}

for var in default_updates:

print("", file=_file)

update_var = var.default_update
inner_to_outer_inputs[update_var] = var

_debugprint(
update_var,
depth=depth,
done=done,
print_type=print_type,
file=_file,
id_type=id_type,
inner_graph_vars=inner_graph_vars,
stop_on_name=stop_on_name,
inner_to_outer_inputs=inner_to_outer_inputs,
used_ids=used_ids,
op_information=op_information,
parent_node=None,
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
print_default_updates=print_default_updates,
default_updates=default_updates,
)

if file is _file:
return file
elif file == "str":
Expand All @@ -470,7 +517,7 @@ def _debugprint(
id_type: IDTypesType = "CHAR",
stop_on_name: bool = False,
prefix_child: Optional[str] = None,
inner_graph_ops: Optional[List[Variable]] = None,
inner_graph_vars: Optional[List[Variable]] = None,
profile: Optional[ProfileStats] = None,
inner_to_outer_inputs: Optional[Dict[Variable, Variable]] = None,
storage_map: Optional[StorageMapType] = None,
Expand All @@ -479,6 +526,8 @@ def _debugprint(
parent_node: Optional[Apply] = None,
print_op_info: bool = False,
inner_graph_node: Optional[Apply] = None,
print_default_updates: bool = False,
default_updates: Optional[List[Variable]] = None,
) -> TextIO:
r"""Print the graph represented by `var`.
Expand Down Expand Up @@ -506,8 +555,8 @@ def _debugprint(
See `debugprint`.
stop_on_name
Whether to print `Op` ``view_map``\s.
inner_graph_ops
A list of `Op`\s with inner graphs.
inner_graph_vars
A list of `Variables`\s with inner graphs.
inner_to_outer_inputs
A dictionary mapping an `Op`'s inner-inputs to its outer-inputs.
storage_map
Expand All @@ -522,6 +571,10 @@ def _debugprint(
See `debugprint`.
inner_graph_node
The inner-graph node in which `var` is contained.
print_default_updates
Print the `SharedVariable.default_update` values.
default_updates
A list of `Variables`\s with default updates.
"""
if depth == 0:
return file
Expand All @@ -534,8 +587,11 @@ def _debugprint(
else:
_done = done

if inner_graph_ops is None:
inner_graph_ops = []
if inner_graph_vars is None:
inner_graph_vars = []

if default_updates is None:
default_updates = []

if print_type:
type_str = f" <{var.type}>"
Expand Down Expand Up @@ -664,9 +720,9 @@ def get_id_str(
if hasattr(in_var, "owner") and hasattr(in_var.owner, "op"):
if (
isinstance(in_var.owner.op, HasInnerGraph)
and in_var not in inner_graph_ops
and in_var not in inner_graph_vars
):
inner_graph_ops.append(in_var)
inner_graph_vars.append(in_var)

_debugprint(
in_var,
Expand All @@ -679,7 +735,7 @@ def get_id_str(
id_type=id_type,
stop_on_name=stop_on_name,
prefix_child=new_prefix_child,
inner_graph_ops=inner_graph_ops,
inner_graph_vars=inner_graph_vars,
profile=profile,
inner_to_outer_inputs=inner_to_outer_inputs,
storage_map=storage_map,
Expand All @@ -690,6 +746,8 @@ def get_id_str(
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
inner_graph_node=inner_graph_node,
print_default_updates=print_default_updates,
default_updates=default_updates,
)
else:

Expand All @@ -705,6 +763,25 @@ def get_id_str(

var_output = f"{prefix}{var}{id_str}{type_str}{data}"

# `SharedVariable`s with default updates are considered "inner-graph" variables
if (
print_default_updates
and isinstance(var, SharedVariable)
and var.default_update is not None
):
update_obj = (
var.default_update
if var.default_update.owner is None
else var.default_update.owner
)
update_obj_id = get_id_str(update_obj)
var_output = f"{var_output} <- {update_obj_id}"

# We still want to print the graph later
if var not in default_updates:
default_updates.append(var)
del _done[update_obj]

if print_op_info and var.owner and var.owner not in op_information:
op_information.update(op_debug_information(var.owner.op, var.owner))

Expand Down
4 changes: 2 additions & 2 deletions tests/graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def is_variable(x):


class MyType(Type):
def filter(self, data):
def filter(self, data, **kwargs):
return data

def __eq__(self, other):
Expand All @@ -27,7 +27,7 @@ def __repr__(self):


class MyType2(Type):
def filter(self, data):
def filter(self, data, **kwargs):
return data

def __eq__(self, other):
Expand Down
93 changes: 92 additions & 1 deletion tests/test_printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import aesara
from aesara.compile.mode import get_mode
from aesara.compile.ops import deep_copy_op
from aesara.compile.sharedvalue import SharedVariable
from aesara.printing import (
PatternPrinter,
PPrinter,
Expand All @@ -25,7 +26,7 @@
)
from aesara.tensor import as_tensor_variable
from aesara.tensor.type import dmatrix, dvector, matrix
from tests.graph.utils import MyInnerGraphOp, MyOp, MyVariable
from tests.graph.utils import MyInnerGraphOp, MyOp, MyType, MyVariable


@pytest.mark.skipif(not pydot_imported, reason="pydot not available")
Expand Down Expand Up @@ -450,3 +451,93 @@ def test_Print(capsys):

stdout, stderr = capsys.readouterr()
assert "hello" in stdout


def test_debugprint_default_updates():

op1 = MyOp("op1")
op2 = MyOp("op2")

r1 = MyVariable("1")
s1 = SharedVariable(MyType(), None, None, name="s1")
s2 = SharedVariable(MyType(), None, None, name="s2")

s1.default_update = op1(r1, s2)
s2.default_update = op1(r1, s1)

out = op2(r1, s1)
out.name = "o1"

s = StringIO()
debugprint(out, file=s, print_default_updates=True)
s = s.getvalue()

reference = dedent(
r"""
op2 [id A] 'o1'
|1 [id B]
|s1 [id C] <- [id D]
Default updates:
op1 [id D]
|1 [id B]
|s2 [id E] <- [id F]
op1 [id F]
|1 [id B]
|s1 [id C] <- [id D]
"""
).lstrip()

assert s == reference


def test_debugprint_inner_graph_default_updates():
"""Test for updates on shared variables in an `OpFromGraph`."""

r1 = MyVariable("1")
r2 = MyVariable("2")
o1 = MyOp("op1")(r1, r2)
o1.name = "o1"

# Inner graph
igo_in_1 = MyVariable("4")
igo_in_s = SharedVariable(MyType(), None, None, name="s")
igo_in_s.default_update = o1
igo_out_1 = MyOp("op2")(igo_in_1, igo_in_s)
igo_out_1.name = "igo1"

from aesara.compile.builders import OpFromGraph

igo = OpFromGraph([igo_in_1], [igo_out_1])

r3 = MyVariable("3")
out = igo(r3)

s = StringIO()
debugprint(out, file=s, print_default_updates=True)
s = s.getvalue()

reference = dedent(
r"""
OpFromGraph{inline=False} [id A]
|3 [id B]
|s [id C] <- [id D]
Inner graphs:
OpFromGraph{inline=False} [id A]
>op2 [id E] 'igo1'
> |*0-<MyType()> [id F]
> |*1-<MyType()> [id G]
Default updates:
op1 [id D] 'o1'
|1 [id H]
|2 [id I]
"""
).lstrip()

assert s == reference

0 comments on commit a5008d1

Please sign in to comment.