In [17]:
import os
from pathlib import Path

import igraph as ig
import numpy as np
import torch
from recomb.cx import CXN, compute_cx_module_ordering_and_out_switch

In [9]:
# Directory containing the pre-stitched model files (*.th) loaded below.
# Set the STITCHED_MODELS_DIR environment variable to point at it (or replace the placeholder),
# so the notebook does not have to be edited per machine.
base_path = Path(os.environ.get("STITCHED_MODELS_DIR", "<base directory to pre-stitched models here>"))

In [16]:
def get_annotated_cx_connectivity_graph(cx_net):
    """
    Build a reduced ("annotated") connectivity graph that keeps only the CX vertices of ``cx_net``.

    Works on a copy of ``cx_net.graph``; the original graph is left untouched. Vertices are visited
    in topological order, and each one is given a ``cxs`` vertex attribute:

      * retained vertices get ``{own index}``;
      * every other vertex gets the union of its predecessors' ``cxs``, i.e. the set of nearest
        retained ancestors whose output flows into it.

    Retained vertices are:

      * input vertices (no in-edges);
      * output vertices (no out-edges and ``module < 0``);
      * vertices whose submodule ``cx_net.submodules[module]`` is a ``CXN``.

    All other vertices are deleted. For every CXN vertex ``i`` and every distinct ``socket`` value on
    its in-edges, an edge ``(s, i)`` labelled with that socket is added for each ``s`` in the union of
    the ``cxs`` sets of the predecessors arriving on that socket. This records which retained vertex
    feeds which socket of which CX.

    Parameters
    ----------
    cx_net : object exposing ``graph`` (presumably an ``igraph.Graph``) with a ``module`` vertex
        attribute and a ``socket`` edge attribute, and ``submodules``, indexable by ``module``.

    Returns
    -------
    The reduced graph copy. Retained vertices keep their original attributes plus ``cxs``.

    NOTE(review): the values stored in ``cxs`` are vertex ids from *before* ``delete_vertices``.
    igraph renumbers vertices on deletion, so these ids are presumably stale in the returned graph.

    NOTE(review): original edges between two retained vertices are not deleted. When a CXN directly
    feeds another CXN, the result therefore contains both the original edge and an added parallel edge.
    """
    g = cx_net.graph.copy()
    o = g.topological_sorting()
    idxs_to_remove = []
    edges_to_add = []
    edge_labels_to_add = []
    # Topological order guarantees every predecessor already has its "cxs" set when it is read below.
    for i in o:
        vi = g.vs[i]

        is_edge_case = False
        if len(vi.in_edges()) == 0:
            # edge case - input node
            vi["cxs"] = set([vi.index])
            is_edge_case = True
        
        if len(vi.out_edges())  == 0 and vi["module"] < 0:
            # edge case - output node
            vi["cxs"] = set([vi.index])
            is_edge_case = True
        
        # Union of all retained ancestors reaching this vertex, ignoring sockets.
        cxs_in = set()
        for e in vi.in_edges():
            cxs_in.update(g.vs[e.source]["cxs"])

        # NOTE(review): for edge-case vertices ``vi["module"]`` is negative (at least for outputs),
        # so this uses Python negative indexing and inspects a submodule counted from the end.
        # If that submodule is a CXN, the edge-case vertex takes the CXN branch. Confirm intended.
        if isinstance(cx_net.submodules[vi["module"]], CXN):
            # Determine by socket, so that we can associate connections with sockets ('annotated')
            cxs_in_by_socket = {}
            for e in vi.in_edges():
                cx_e = cxs_in_by_socket.get(e["socket"])
                if cx_e is None:
                    cx_e = set()
                cx_e.update(g.vs[e.source]["cxs"])
                cxs_in_by_socket[e["socket"]] = cx_e
            
            # One new edge (ancestor -> this CX) per (socket, ancestor) pair, labelled with the socket.
            edges_to_add += [(s, i) for (sk, ss) in cxs_in_by_socket.items() for s in ss]
            edge_labels_to_add += [sk for (sk, ss) in cxs_in_by_socket.items() for s in ss]
            # A CX vertex is a boundary: downstream vertices see it, not its ancestors.
            vi["cxs"] = set([i])
        elif not is_edge_case:
            # Plain module: pass its ancestors through and schedule it for removal.
            vi["cxs"] = cxs_in
            idxs_to_remove.append(i)
        # NOTE(review): an output vertex that does not take the CXN branch receives no added edges.
        # It stays connected only through original edges whose source is itself retained.

    # Edges are added before any deletion, so the collected (pre-deletion) vertex ids are still valid.
    g.add_edges(edges_to_add, attributes={"socket": edge_labels_to_add})
    g.delete_vertices(idxs_to_remove)
    return g

def embed_genotype_indices(stitchnet, mo):
    """
    Store each vertex's position in the module ordering on the vertex itself.

    For every vertex of ``stitchnet.graph``, ``genotype_index`` is set to the index of
    ``vertex["module"]`` within ``mo``, or ``None`` when that module does not appear in ``mo``.
    If a module occurs more than once in ``mo``, its last position is used.
    The graph is modified in place and nothing is returned.
    """
    # Embed variable indices into graph (the vertex data should be preserved!)
    position_of = {module: position for position, module in enumerate(mo)}
    for vertex in stitchnet.graph.vs:
        module = vertex["module"]
        vertex["genotype_index"] = position_of[module] if module in position_of else None

def davuag_process_edge(graph, genotype, e):
    """
    Determine whether an edge e is (potentially) used according to the annotated graph, using
    the given genotype.

    Used in determine_active_variables_using_annot_graph.

    Parameters
    ----------
    graph : annotated cx graph whose vertices carry ``genotype_index`` and ``module`` attributes.
    genotype : indexable by genotype position; ``genotype[k]`` is compared against the ``socket``
        of edges entering the vertex whose ``genotype_index`` is ``k``.
    e : an edge of ``graph`` with a ``target`` vertex id and a ``socket`` attribute.

    Returns
    -------
    True when the target vertex has no genotype position (``None`` or non-finite, e.g. NaN).
    Such vertices are not governed by the genotype, so their edges are always kept.
    Otherwise, whether ``genotype[target's genotype_index] == e["socket"]``.

    Raises
    ------
    AssertionError
        If a vertex with a genotype position has a negative ``module`` (e.g. the output vertex).
    """
    target_vertex = graph.vs[e.target]
    related_genotype_pos = target_vertex["genotype_index"]
    # Keep edges for edge cases (they are not impacted)
    if related_genotype_pos is None or not np.isfinite(related_genotype_pos):
        return True
    # Beyond this point the vertex has a genotype position. Only real modules (module >= 0) may
    # have one; an output/edge-case vertex reaching here indicates a broken annotation.
    assert target_vertex["module"] >= 0, "Output should not have a genotype position"
    related_genotype_pos = int(related_genotype_pos)
    # Keep edges where socket matches (remove otherwise)
    return genotype[related_genotype_pos] == e["socket"]

def determine_active_variables_using_annot_graph(graph, genotype):
    """
    Given an annotated cx graph (e.g. generated by get_annotated_cx_connectivity_graph), determine the set of
    variables which are active - i.e. those which upon a change without changes to other variables will affect
    the output of the network.

    Every edge that ``davuag_process_edge`` rejects for ``genotype`` is removed from a copy of ``graph``.
    The result is the set of ``genotype_index`` values (ignoring ``None`` / non-finite entries) of all
    vertices that can still reach the output vertex (the first vertex with ``module == -1``).

    Parameters
    ----------
    graph : annotated cx graph with ``module`` and ``genotype_index`` vertex attributes and a
        ``socket`` edge attribute. It is not modified.
    genotype : indexable by genotype position (see ``davuag_process_edge``).

    Returns
    -------
    numpy.ndarray of int
        Sorted, unique indices of the active variables. An empty int array if there are none.

    Raises
    ------
    ValueError
        If ``graph`` has no vertex with ``module == -1``.
    """
    # dtype=bool keeps `~` valid when the graph has no edges (np.array([]) would be float64 and
    # `~` on it raises TypeError).
    keep_edge = np.array([davuag_process_edge(graph, genotype, e) for e in graph.es], dtype=bool)
    graph_b = graph.copy()
    # The copy has the same edge ids as `graph`, so the mask computed on `graph` applies to `graph_b`.
    graph_b.delete_edges(np.flatnonzero(~keep_edge))

    output_vertices = np.flatnonzero(np.array([a == -1 for a in graph_b.vs["module"]], dtype=bool))
    if output_vertices.size == 0:
        raise ValueError("Annotated graph has no output vertex (module == -1)")
    # All vertices that can still reach the output through the edges this genotype keeps.
    active_vertices = graph_b.subcomponent(int(output_vertices[0]), mode='in')

    active_variables = np.unique(np.array(
        [int(a) for a in graph_b.vs[active_vertices]["genotype_index"] if a is not None and np.isfinite(a)],
        dtype=int,
    ))
    return active_variables

In [18]:
# Convert one pre-stitched network into its annotated stitch graph (GraphML, gzipped).
# The checkpoint pickles a full (network, info) tuple, not a state dict. It therefore needs
# weights_only=False, since the default became True in PyTorch 2.6. Only load trusted files this way.
model_name = "stitched-imagenet-a-resnet152-b-efficientnet-b4"
cx_net, cx_net_info = torch.load(base_path / f"{model_name}.th", weights_only=False)
mo, out_switch = compute_cx_module_ordering_and_out_switch(cx_net, cx_net_info)
embed_genotype_indices(cx_net, mo)
cx_graph = get_annotated_cx_connectivity_graph(cx_net)
cx_graph.write_graphmlz(f"{model_name}-stitchgraph.graphmlz")

In [19]:
# Convert one pre-stitched network into its annotated stitch graph (GraphML, gzipped).
# The checkpoint pickles a full (network, info) tuple, not a state dict. It therefore needs
# weights_only=False, since the default became True in PyTorch 2.6. Only load trusted files this way.
model_name = "stitched-imagenet-b-a-resnet50-b-resnext50_32x4d"
cx_net, cx_net_info = torch.load(base_path / f"{model_name}.th", weights_only=False)
mo, out_switch = compute_cx_module_ordering_and_out_switch(cx_net, cx_net_info)
embed_genotype_indices(cx_net, mo)
cx_graph = get_annotated_cx_connectivity_graph(cx_net)
cx_graph.write_graphmlz(f"{model_name}-stitchgraph.graphmlz")

In [20]:
# Convert one pre-stitched network into its annotated stitch graph (GraphML, gzipped).
# The checkpoint pickles a full (network, info) tuple, not a state dict. It therefore needs
# weights_only=False, since the default became True in PyTorch 2.6. Only load trusted files this way.
model_name = "stitched-voc-a-deeplab-mobilenetv3-b-deeplab-resnet50"
cx_net, cx_net_info = torch.load(base_path / f"{model_name}.th", weights_only=False)
mo, out_switch = compute_cx_module_ordering_and_out_switch(cx_net, cx_net_info)
embed_genotype_indices(cx_net, mo)
cx_graph = get_annotated_cx_connectivity_graph(cx_net)
cx_graph.write_graphmlz(f"{model_name}-stitchgraph.graphmlz")