diff --git a/rfa_toolbox/encodings/pytorch/ingest_architecture.py b/rfa_toolbox/encodings/pytorch/ingest_architecture.py index e189ed3..0322828 100644 --- a/rfa_toolbox/encodings/pytorch/ingest_architecture.py +++ b/rfa_toolbox/encodings/pytorch/ingest_architecture.py @@ -180,28 +180,28 @@ def is_relevant_type(t): ): # go into subgraph sub_prefix = prefix + submodule_name + "." - # with dot.subgraph(name="cluster_" + name) as sub_dot: - # sub_dot.attr(label=label) - submod = mod - # iterate to the lowest submodule hirarchy - for i, k in enumerate(submodule_name.split(".")): - submod = getattr(submod, k) - # create subgraph for the submodule - make_graph( - submod, - dot=dot, - prefix=sub_prefix, - input_preds=[preds[i] for i in list(n.inputs())[1:]], - parent_dot=dot, - classes_to_visit=classes_to_visit, - classes_found=classes_found, - classes_to_not_visit=classes_to_not_visit, - ) - # creating a mapping from the c-values - # to the output of the respective subgraph - for i, o in enumerate(n.outputs()): - # print(i, sub_prefix + f'out_{i}', type(o)) - preds[o] = {sub_prefix + f"out_{i}"}, set() + with dot.subgraph(name="cluster_" + name) as sub_dot: + # sub_dot.attr(label=label) + submod = mod + # iterate to the lowest submodule hirarchy + for i, k in enumerate(submodule_name.split(".")): + submod = getattr(submod, k) + # create subgraph for the submodule + make_graph( + submod, + dot=sub_dot, + prefix=sub_prefix, + input_preds=[preds[i] for i in list(n.inputs())[1:]], + parent_dot=dot, + classes_to_visit=classes_to_visit, + classes_found=classes_found, + classes_to_not_visit=classes_to_not_visit, + ) + # creating a mapping from the c-values + # to the output of the respective subgraph + for i, o in enumerate(n.outputs()): + # print(i, sub_prefix + f'out_{i}', type(o)) + preds[o] = {sub_prefix + f"out_{i}"}, set() else: # here the basic node (Conv2D, BatchNorm etc.) are created. dot.node(name, label=label, shape="box")