 DAEDALUS – Distributed and Automated Evolutionary Deep Architecture Learning with Unprecedented Scalability

This research code was developed as part of the research programme Open Technology Programme with project number 18373, which was financed by the Dutch Research Council (NWO), Elekta, and Ortec Logiqcare.

Project leaders: Peter A.N. Bosman, Tanja Alderliesten
Researchers: Alex Chebykin, Arthur Guijt, Vangelis Kostoulas
Main code developer: Arthur Guijt

# Stitching Segmentation Models

Quoting from "Rethinking Atrous Convolution for Semantic Image Segmentation":
> We evaluate the proposed models on the PASCAL VOC 2012 semantic segmentation benchmark [20] which contains 20 foreground object classes and one background class. The original dataset contains 1,464 (train), 1,449 (val), and
> 1,456 (test) pixel-level labeled images for training, validation, and testing, respectively. The dataset is augmented by
> the extra annotations provided by [29], resulting in 10,582 (trainaug) training images. The performance is measured in
> terms of pixel intersection-over-union (IOU) averaged across the 21 classes.

So the VOC dataset may be used.

Furthermore, the pretrained models appear to have been trained on COCO using only the VOC labels (`COCO_WITH_VOC_LABELS_V1`); we should verify what that implies for evaluating on VOC.

Alternative: MONAI for Medical Decathlon?

In [None]:
import recomb.problems as problems
problem = problems.VOCSegmentationProblem(root="<add-dataset-folder>")
# problem.ensure_downloaded()

In [None]:
import torch
import torchvision
import torchvision.datasets as thd
import torchvision.models.segmentation as segmentation_models
import recomb.layers as ly
import igraph as ig
import matplotlib.pyplot as plt
import polars as pl

from pathlib import Path

import recomb.cx as cx
import recomb.layers as ly
import recomb.problems as problems
from recomb.cx import forward_get_all_feature_maps, construct_trained_cx_network_stitching

In [None]:
# thd.CocoDetection(".", "")

In [None]:
# trsf = segmentation_models.DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1.transforms()
from torch.utils.data import DataLoader
# d_train, _, _ = problem.load_problem_dataset(transform_train=trsf, transform_validation=trsf)
d_train = problem.get_dataset_train()
# d_train.transform = trsf
dl_train = DataLoader(d_train)
it_train = iter(dl_train)
for _ in range(5):
    X, Y = next(it_train)

In [None]:
in_shape = X.shape
in_shape

In [None]:
out_shape = Y.shape
out_shape

In [None]:
# Load pretrained with weights COCO_WITH_VOC_LABELS_V1
model_a = segmentation_models.deeplabv3_mobilenet_v3_large(
    weights=torchvision.models.segmentation.DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT
)
model_a.eval()

model_b = segmentation_models.deeplabv3_resnet50(
    weights=torchvision.models.segmentation.DeepLabV3_ResNet50_Weights.DEFAULT
)
model_b.eval()

In [None]:
gca = ly.trace_network(model_a, in_shape, verbose=True).to_neural_net_graph()

In [None]:
from IPython.display import SVG, Image
with open("g.dot", "w") as f:
    gca.to_dot(f, include_ord_label=True)
! dot g.dot -Tpng -og.png
! dot g.dot -Tsvg -og.svg
# Image("g.png")
SVG("g.svg")

In [None]:
d_train = problem.get_dataset_train()
# d_train.transform = trsf
dl_train = DataLoader(d_train)
it_train = iter(dl_train)
next(it_train)

In [None]:
from torch.utils.data import DataLoader
# d_train, _, _ = problem.load_problem_dataset(transform_train=trsf, transform_validation=trsf)
d_train = problem.get_dataset_train()
# d_train.transform = trsf
dl_train = DataLoader(d_train)
it_train = iter(dl_train)
for _ in range(25):
    X, Y = next(it_train)

In [None]:
origout = model_a(X)

In [None]:
gca.verbose = True
gcaout = gca(X)

In [None]:
Y.shape

In [None]:
# Compare the original model with its traced graph on the first sample.
# Absolute differences are summed: with signed differences, positive and
# negative errors could cancel and report 0 even when the outputs differ.
torch.sum(torch.abs(origout["out"][0, ...].detach() - gcaout["out"][0, ...].detach()))

In [None]:
def subst_255_nan(t):
    """Return `t` with every entry equal to 255 replaced by NaN; other entries are kept."""
    keep_mask = t != 255
    return t.where(keep_mask, torch.nan)

In [None]:
fig, axs = plt.subplots(1, 4)
axs[0].imshow(X[0, ...].permute(1, 2, 0))
axs[1].imshow(subst_255_nan(Y[0, 0, ...]))
axs[2].imshow(origout["out"][0, ...].detach().argmax(dim=0))
axs[3].imshow(gcaout["out"][0, ...].detach().argmax(dim=0))

In [None]:
import plotly.express as plye
plye.imshow(origout["out"][0, ...].detach().argmax(dim=0).numpy())

In [None]:
origout["out"] - gcaout["out"]

In [None]:
origout["aux"] - gcaout["aux"]

In [None]:
gcb = ly.trace_network(model_b, in_shape).to_neural_net_graph()

In [None]:
from IPython.display import SVG
with open("g.dot", "w") as f:
    gcb.to_dot(f, include_ord_label=True)
! dot g.dot -Tsvg -og.png
SVG("g.png")

In [None]:
import recomb.problems as problems
import importlib
importlib.reload(problems)
problem = problems.VOCSegmentationProblem(root="<add-dataset-folder>", batched_validation=True)
# problem.ensure_downloaded()

In [None]:
NeuralNetIndividual = problems.NeuralNetIndividual
# o_* wrap the original pretrained models, neti_* wrap their traced graph
# counterparts (gca / gcb), so that both variants can be evaluated below.
o_a = NeuralNetIndividual(model_a)
neti_a = NeuralNetIndividual(gca)
o_b = NeuralNetIndividual(model_b)
neti_b = NeuralNetIndividual(gcb)

In [None]:
dev = torch.device('cuda')

In [None]:
import os
os.environ["RECOMB_NUM_DATALOADER_WORKERS"] = "2"

In [None]:
problem.evaluate_network(dev, o_a, batch_size=16, objective="both")

In [None]:
problem.evaluate_network(dev, neti_a, batch_size=16, objective="both")

In [None]:
problem.evaluate_network(dev, o_b, batch_size=16, objective="both")

In [None]:
problem.evaluate_network(dev, neti_b, batch_size=16, objective="both")

Does ensembling provide identical results too?

In [None]:
# # Assuming dictionary output

# class LinearDictAggregate(ly.ModuleT):
#     def __init__(self, ws):
#         super().__init__()
        
#         self.ws = torch.nn.Parameter(torch.tensor(ws).reshape(-1), requires_grad=False)

#     def get_reconstructor(self):
#         ws = self.ws
#         return lambda: LinearDictAggregate(ws)
    
#     def forward_key(self, x, k):
        
#         assert len(x) == len(self.ws)
#         stacked = torch.stack([xv[k] for xv in x], dim=0)
#         return torch.sum(stacked * self.ws.reshape([-1 if x == 0 else 1 for x in range(len(stacked.shape))]), dim=0)

#     def forward(self, *x):
#         # Note - assumes set of keys is always identical.
#         return {
#             k: self.forward_key(x, k)
#             for k in x[0].keys()
#         }

# class LinearDictEnsemble(ly.ModuleT):
#     def __init__(self, submodules, ws):
#         super().__init__()
#         self.submodules = torch.nn.ModuleList(submodules)
#         self.ws = ws

#     def get_reconstructor(self):
#         ws = self.ws
#         submodules_reconstructor = [m.get_reconstructor() for m in self.submodules]
#         return lambda: LinearDictEnsemble([m() for m in submodules_reconstructor], ws)
    
#     def forward(self, x):
#         r = None
#         for sm, w in zip(self.submodules, self.ws):
#             o = sm(x)
#             if r is None:
#                 r = {k: o[k] * w for k in o.keys()}
#             else:
#                 for k in o.keys():
#                     r[k] += o[k] * w
#         return r

#     def to_subgraph(self, gc: ly.GraphConstructor, feature_inputs):
#         agg = LinearDictAggregate(self.ws)
#         agg_inputs = [(i, sm.to_subgraph(gc, feature_inputs)) for i, sm in enumerate(self.submodules)]
#         out = agg.to_subgraph(gc, agg_inputs)
#         return out

In [None]:
net_be = ly.LinearDictEnsemble([model_a, model_b], [0.5, 0.5])
neti_be = NeuralNetIndividual(net_be)
torch.manual_seed(42)
problem.evaluate_network(dev, neti_be, batch_size=16, objective="both")

In [None]:
from IPython.display import SVG
net_be = ly.LinearDictEnsemble([gca, gcb], [0.5, 0.5]).to_graph()
with open("g.dot", "w") as f:
    net_be.to_dot(f, include_ord_label=True)
! dot g.dot -Tsvg -og.png
SVG("g.png")

In [None]:
net_be = ly.LinearDictEnsemble([gca, gcb], [0.5, 0.5]).to_graph()
neti_be = NeuralNetIndividual(net_be)
torch.manual_seed(42)
problem.evaluate_network(dev, neti_be, batch_size=16, objective="both")

In [None]:
# import importlib 
import recomb.cx as cx

In [None]:
# We use the training dataset to avoid training on the validation data
# importlib.reload(problems)
dataset = problem.get_dataset_train()

# Grab an item from a dataloader for use in the forward pass
from torch.utils.data import DataLoader
dl = DataLoader(dataset)
dli = iter(dl)
X, _y = next(dli)

for p in gca.parameters():
    p.requires_grad_(False)
for p in gcb.parameters():
    p.requires_grad_(False)

In [None]:
# importlib.reload(cx)

In [None]:
batch_size = 4
stitched = cx.construct_trained_cx_network_stitching(
    dataset=dataset,
    dev=dev,
    net_a=gca,
    net_b=gcb,
    X_in_many=X,
    ensemblers=[ly.LinearDictAggregate([0.5, 0.5])],
    compute_similarity=cx.compute_mock_similarity,
    feature_shape_should_match=False,
    batch_size=batch_size,

    pretrain_cx_network = False
)

In [None]:
len(stitched[1].joiners)

In [None]:
from IPython.display import SVG
with open("g.dot", "w") as f:
    stitched[0].to_dot(f, include_ord_label=True)
! dot g.dot -Tsvg -og.png
SVG("g.png")

In [None]:
# importlib.reload(problems)
problem = problems.VOCSegmentationProblem(root="<add-dataset-folder>")
# problem.ensure_downloaded()
dataset = problem.get_dataset_train()

In [None]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
gca.cpu()
gcb.cpu()

In [None]:
# torch.autograd.set_detect_anomaly(False)
batch_size = 4 # note - 16 is too large for this network
# would need to use checkpointing.

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir="./logs/train-stitchseg-redo")

stitched = cx.construct_trained_cx_network_stitching(
    dataset=dataset,
    dev=dev,
    net_a=gca,
    net_b=gcb,
    X_in_many=X,
    ensemblers=[ly.LinearDictAggregate([0.5, 0.5])],
    compute_similarity=cx.compute_mock_similarity,
    feature_shape_should_match=False,
    num_samples_pretrain = 128, # 16384,
    batch_size=batch_size,

    pretrain_cx_network = True,
    immediately_backprop = False,
)

In [None]:
16384/4

In [None]:
# Reset the `agg` attribute on every switch in every joiner, and on the
# output switch, of the stitching info returned alongside the network.
# NOTE(review): presumably this drops state that should not end up in the
# file written by the torch.save call further down — confirm.
for es in stitched[1].joiners:
    for e in es:
        e.agg = None
stitched[1].output_switch.agg = None

In [None]:
stitched_simpl = (stitched[0], cx.SitchingInfo(stitched[1].joiners, stitched[1].output_switch))

In [None]:
torch.save(stitched_simpl, "stitched-seg.th")

In [None]:
stitched_simpl = torch.load("stitched-seg.th")

In [None]:
stitchnet, stitchinfo = stitched_simpl

In [None]:
from IPython.display import SVG
with open("g.dot", "w") as f:
    stitchnet.to_dot(f, include_ord_label=True)
! dot g.dot -Tsvg -og.png
SVG("g.png")

In [None]:
def get_cx_connectivity_graph(cx_net: ly.NeuralNetGraph):
    """
    Condense the graph of `cx_net` to the vertices that matter for crossover.

    Works on a copy of `cx_net.graph`. Vertices are visited in topological
    order and each gets a "cxs" attribute: the set of kept vertices that feed
    into it. Vertices whose module is a `cx.CXN` are kept and receive direct
    edges (carrying the original "socket" attribute) from the kept vertices
    upstream of each of their sockets. Vertices without incoming edges, and
    vertices without outgoing edges that have a negative module index, are
    kept as well. Every other vertex is removed.

    Returns the condensed copy; `cx_net` itself is not modified.
    """
    graph = cx_net.graph.copy()
    removable = []
    shortcut_edges = []
    for idx in graph.topological_sorting():
        vertex = graph.vs[idx]

        # Input-like vertices (no incoming edges) and output-like vertices
        # (no outgoing edges, negative module index) refer to themselves.
        keep_as_is = len(vertex.in_edges()) == 0 or (
            len(vertex.out_edges()) == 0 and vertex["module"] < 0
        )
        if keep_as_is:
            vertex["cxs"] = {vertex.index}

        is_cxn = isinstance(cx_net.submodules[vertex["module"]], cx.CXN)
        upstream = set()
        # For a CXN: which kept vertices reach it, grouped by input socket.
        per_socket = {}
        for edge in vertex.in_edges():
            source_cxs = graph.vs[edge.source]["cxs"]
            upstream.update(source_cxs)
            if is_cxn:
                per_socket.setdefault(edge["socket"], set()).update(source_cxs)

        if is_cxn:
            for socket, sources in per_socket.items():
                shortcut_edges.extend((src, idx, socket) for src in sources)
            vertex["cxs"] = {idx}
        elif not keep_as_is:
            # Pass-through vertex: forward what reaches it and drop it later.
            vertex["cxs"] = upstream
            removable.append(idx)

    # Add the shortcut edges first: deleting vertices renumbers the rest.
    graph.add_edges(
        [(src, dst) for src, dst, _ in shortcut_edges],
        attributes=dict(socket=[socket for _, _, socket in shortcut_edges]),
    )
    graph.delete_vertices(removable)
    return graph

In [None]:
cxg = get_cx_connectivity_graph(stitchnet)
fig, ax = plt.subplots()
graph_layout = cxg.layout("sugiyama")
graph_layout.rotate(-90)
ig.plot(cxg, target=ax, layout=graph_layout)

In [None]:
def compute_primary_line_membership(cx_graph, verbose=False):
    """
    Recover for each vertex in the crossover point graph to which network it originally belonged,
    assuming that input 0 (socket 0) to each crossover point maintains the original graph.

    The result is stored in the vertex attribute "og": vertices of the same line share the index
    of one representative vertex. Vertices 0 and 1 are treated as the input and output
    nodes and get the placeholder value -1.

    (Note that this method may be skipped by tracking the original origin during stitching and assigning
     membership accordingly.)

    Parameters:
        cx_graph: graph exposing `vs`, `topological_sorting()`, per-vertex `in_edges()` and a
            "socket" attribute on its edges. It is modified in place.
        verbose: when True, print the intermediate merging decisions.

    Returns:
        The same `cx_graph` object, with the "og" attribute filled in.
    """
    # Initialize graph membership to each on their own
    cx_graph.vs["og"] = range(len(cx_graph.vs))
    # For input & output add a placeholder
    cx_graph.vs[0]["og"] = -1
    cx_graph.vs[1]["og"] = -1

    # Go over the graph in a topological order
    ordering = cx_graph.topological_sorting()
    for o in ordering:
        v = cx_graph.vs[o]
        # skip over input & output nodes
        if v["og"] == -1:
            continue
        # Sources connected through socket 0 continue the same original network.
        new_og = v["og"]
        same_origin_nodes = [e.source for e in v.in_edges() if e["socket"] == 0]
        if verbose: print(f"same origin: {same_origin_nodes}")
        # Merge identities in union-find like structure: take the smallest representative.
        for set_elem in same_origin_nodes:
            v_other = cx_graph.vs[set_elem]
            if verbose: print(f"incident edge og is {v_other['og']}")
            if v_other["og"] == -1:
                continue
            new_og = min(new_og, v_other["og"])
        if verbose: print(f"og was {v['og']} should update to {new_og}")
        v["og"] = new_og
        # Redirect the representatives of the merged sources to the new representative.
        for set_elem in same_origin_nodes:
            v_other = cx_graph.vs[set_elem]
            if v_other["og"] == -1:
                continue
            v_id = cx_graph.vs[v_other["og"]]
            v_id["og"] = new_og
        if verbose: print(f"og is now {v['og']} updated to {new_og}")

    # Resolve every vertex to the root of its chain (find with path compression).
    for o in ordering:
        v = cx_graph.vs[o]
        # if special case or already its own representative, skip
        if v["og"] == -1:
            continue
        if v["og"] == o:
            continue
        # Follow the chain until a vertex that represents itself is found.
        root = v["og"]
        while True:
            candidate = cx_graph.vs[root]
            # Stop at a self-representing vertex. Also stop at a placeholder: the previous
            # `continue` here would have looped forever, as nothing changes between iterations.
            if candidate["og"] == -1 or candidate["og"] == root:
                break
            root = candidate["og"]
        # Walk the chain again, pointing every visited vertex directly at the root.
        cursor = v["og"]
        v["og"] = root
        while True:
            v_other = cx_graph.vs[cursor]
            # Same reasoning as above: never spin on a placeholder.
            if v_other["og"] == -1:
                break
            cursor = v_other["og"]
            v_other["og"] = root
            if cursor == root:
                break
    return cx_graph


In [None]:
compute_primary_line_membership(cxg);

In [None]:
cxg.vs["og"]

In [None]:
def compute_parallel_set(cxg, i):
    """
    Return the set of vertex indices of `cxg` that are neither returned by
    `cxg.subcomponent(i, mode='out')` nor by `cxg.subcomponent(i, mode='in')`,
    i.e. the vertices that are not ordered relative to `i`.
    """
    downstream = set(cxg.subcomponent(i, mode='out'))
    upstream = set(cxg.subcomponent(i, mode='in'))
    related = downstream | upstream
    return {idx for idx in range(len(cxg.vs)) if idx not in related}

def compute_all_parallel_set(cxg):
    """Return a list whose entry `i` is `compute_parallel_set(cxg, i)`, for every vertex index of `cxg`."""
    parallel_sets = []
    for vertex_idx in range(len(cxg.vs)):
        parallel_sets.append(compute_parallel_set(cxg, vertex_idx))
    return parallel_sets

In [None]:
parallel_cxs = compute_all_parallel_set(cxg)

In [None]:
from copy import copy
from functools import partial
def enumerate_parallel_set(g, set_idx, parallel_sets, i, verbose=False):
    """
    Enumerate the groups of vertices that have to be switched together with vertex `i`.

    Parameters:
        g: graph whose vertices expose an "og" attribute via `g.vs[idx]["og"]`.
        set_idx: callable `set_idx(idx, value)` used to switch a vertex on (1) or off (0).
        parallel_sets: per-vertex sets of parallel vertices, indexed by vertex index.
        i: index of the vertex to start from.
        verbose: when True, print the decisions taken during enumeration.

    Yields:
        Pairs `(apply, restore)` of zero-argument callables. `apply()` calls
        `set_idx(idx, 1)` for every vertex in the group, `restore()` calls
        `set_idx(idx, 0)` for the same vertices.
    """
    def call_funcs(lfn):
        for fn in lfn:
            fn()

    for set_list, restore_list in enumerate_parallel_set_recur(
        g, set_idx, parallel_sets, i, None, set(), verbose=verbose
    ):
        # Bind the lists now. A lambda would look the loop variables up when called, so
        # pairs collected first and called later would all act on the last group.
        yield partial(call_funcs, set_list), partial(call_funcs, restore_list)

def enumerate_parallel_set_recur(g, set_idx, parallel_sets, i, current_set=None, unpickable: "set | None" = None, ref_og=None, verbose=False):
    """
    Recursive worker for `enumerate_parallel_set`.

    Parameters:
        g, set_idx, parallel_sets, verbose: see `enumerate_parallel_set`.
        i: index of the vertex picked at this level.
        current_set: vertices that are still parallel to everything picked so far, or None
            on the initial call (then derived from `parallel_sets[i]`, keeping only vertices
            whose "og" equals that of `i`).
        unpickable: vertices that may not be picked any more. The same set object is shared
            with the recursive calls; every level removes what it added before finishing.
        ref_og: "og" value of the starting vertex. Overwritten on the initial call and
            otherwise only forwarded.

    Yields:
        Pairs `(set_list, restore_list)` of lists of zero-argument callables.
    """
    if unpickable is None:
        # A fresh set per call tree; a shared default set would leak state between calls.
        unpickable = set()

    if current_set is None:
        # Initial case - current set is the parallel set of the index we are starting with,
        # filtered to the vertices with a matching og.
        ref_og = g.vs[i]["og"]
        current_set = {a for a in parallel_sets[i] if g.vs[a]["og"] == ref_og}
    else:
        # Otherwise, update the set of uncovered elements.
        current_set = current_set.intersection(parallel_sets[i])

    if len(current_set) == 0:
        if verbose: print(f"base case - no other choices necessary after picking {i}")
        # yield setter for configuring and unconfiguring i. No other configurations necessary
        # as there are no other branches.
        yield [partial(set_idx, i, 1)], [partial(set_idx, i, 0)]
        # return - as there are no more elements in the neighborhood.
        return

    # Obtain a fixed (sorted, hence reproducible) ordering of the leftover elements to be picked.
    ordering = sorted(current_set - unpickable)
    if not ordering:
        # Uncovered elements remain, but none of them may be picked any more.
        if verbose: print("forbidden case - no choices cover all branches anymore...")
        return

    # Find current reverse cumulative intersection.
    # The intersection of sets picked so far provides knowledge of elements that may need
    # to be picked to cover all branches.
    # If we perform this operation cumulatively from the right the elements left in the
    # set allow us to identify necessary picks.
    # if we have the set with fixed ordering [ 1, 2, 3]
    # and the set corresponding here are 1 -> {0, 2, 3}, 2 -> {0, 1, 3}, 3 -> {0, 1, 2}
    # (note: the index itself is never contained within its own parallel set)
    # In this case the sequence of sets would be
    # [{}, {1}, {1, 2}]
    # as the only set that does not contain {1} is the set corresponding to {1}, 1 must be picked.
    # NOTE: the loop indices below are deliberately not named `i`. Reusing `i` would
    # overwrite the vertex index that is used further down.
    count = len(ordering)
    cumulative_sets_rl = [None] * count
    cumulative_sets_rl[-1] = current_set.intersection(parallel_sets[ordering[-1]])
    required_right = {ordering[-1]}
    for k in range(count - 1, 0, -1):
        el = ordering[k - 1]
        cumulative_sets_rl[k - 1] = cumulative_sets_rl[k].intersection(parallel_sets[el])
        # note - if el is in cumulative_sets_rl[k], el needs to be picked if we do not
        # pick any of the preceding elements as there are no further elements to cover this branch.
        if el in cumulative_sets_rl[k]:
            required_right.add(el)
    # If we do it the other direction we can do the same thing for any following elements.
    cumulative_sets_lr = [None] * count
    cumulative_sets_lr[0] = current_set.intersection(parallel_sets[ordering[0]])
    required_left = {ordering[0]}
    for k in range(count - 1):
        el = ordering[k + 1]
        cumulative_sets_lr[k + 1] = cumulative_sets_lr[k].intersection(parallel_sets[el])
        # similar reasoning - if we pick none of the elements after this one, there would be
        # an uncovered branch
        if el in cumulative_sets_lr[k]:
            required_left.add(el)
    # Elements that are in both required sets are always to be taken.
    always_required = required_left.intersection(required_right)

    # cumulative_sets_lr[-1] and cumulative_sets_rl[0] should always be empty sets - if they are not
    # there exists an element that is not optional that was excluded.
    if len(cumulative_sets_lr[-1]) != 0 or len(cumulative_sets_rl[0]) != 0:
        if verbose: print("forbidden case - no choices cover all branches anymore...")
        return

    fixed_set = [partial(set_idx, i, 1)]
    fixed_restore = [partial(set_idx, i, 0)]

    if verbose: print(f"in this case to cover all branches {always_required} are required")
    for a in sorted(always_required):
        current_set.intersection_update(parallel_sets[a])
        fixed_set.append(partial(set_idx, a, 1))
        fixed_restore.append(partial(set_idx, a, 0))

    if len(current_set) == 0:
        if verbose: print(f"fixed case - no more free choices left to make after picking {i}")
        yield fixed_set, fixed_restore
        return

    if verbose: print(f"recursive case for {i}")
    # Remember what this level marks as unpickable so that exactly that is undone afterwards.
    marked_here = []
    for e in sorted(current_set - unpickable):
        # consider the cases where we pick it
        if verbose: print(f"considering picking {e}")
        for set_list, restore_list in enumerate_parallel_set_recur(
            g, set_idx, parallel_sets, e, current_set,
            unpickable=unpickable, ref_og=ref_og, verbose=verbose,
        ):
            yield (fixed_set + set_list), (fixed_restore + restore_list)
        if verbose: print(f"no longer considering picking {e}")
        # now - for the following picks consider the case where we do not allow e to be picked anymore.
        unpickable.add(e)
        marked_here.append(e)
    # to avoid issues with branching allow picking these elements again if another branch investigates them.
    for e in marked_here:
        unpickable.discard(e)


In [None]:
def set_idx(i, to):
    """
    Set the `active` attribute of the submodule behind vertex `i` of `cxg` to `to`.

    The vertex' "module" attribute is used as index into `stitchnet.submodules`.
    The change is reported on standard output.
    """
    vertex = cxg.vs[i]
    module_idx = vertex['module']
    switch = stitchnet.submodules[module_idx]
    switch.active = to
    message = f"submodule {module_idx} ({type(switch)}) active set to {to}"
    print(message)


for set_list, restore_list in enumerate_parallel_set(cxg, set_idx, parallel_cxs, 16):
    set_list()
    print("----------------------")
    restore_list()

In [None]:
import plotly.graph_objects as go
import numpy as np

# Transform graph layout to coords
layout_coords = np.array(graph_layout)
vertex_coords = np.array([layout_coords[v.index] for v in cxg.vs])
# edge_coords = [[layout_coords[e.source, :], layout_coords[e.target, :], [np.nan, np.nan]] for e in cxg.es]
edge_coords = [[
    layout_coords[e.source, :],
    layout_coords[e.target, :],
    [np.nan, np.nan],
    ]
    for e in cxg.es]
edge_coords = np.array(edge_coords).reshape(-1, 2)

# edge_trace = go.Scatter(x=edge_x, y=edge_y, line=dict(width=lineWidth, color=lineColor), hoverinfo='none', mode='lines')
edge_trace = go.Scatter(
    x=edge_coords[:, 0],
    y=edge_coords[:, 1],
    hoverinfo='none',
    mode='lines+markers',
    marker=dict(
        size=10,
        symbol="arrow",
        angleref="previous"
        )
)

# node_trace = go.Scatter(x=node_x, y=node_y, mode='markers', hoverinfo='text', marker=dict(showscale=False, color = nodeColor, size=nodeSize))
node_trace = go.Scatter(
    x=vertex_coords[:, 0],
    y=vertex_coords[:, 1],
    mode='markers',
    hoverinfo='text',
    marker=dict(showscale=False)
)

fig = go.Figure(data=[edge_trace, node_trace],
             layout=go.Layout(
                showlegend=False,
                hovermode='closest',
                margin=dict(b=20,l=5,r=5,t=40),
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False))
                )
fig

In [None]:
import plotly.graph_objects as go
import numpy as np

graph_layout, cxge = cxg.layout_sugiyama(return_extended_graph=True)
# Transform graph layout to coords
layout_coords = np.array(graph_layout)
vertex_coords = np.array([layout_coords[v.index] for v in cxg.vs])
# edge_coords = [[layout_coords[e.source, :], layout_coords[e.target, :], [np.nan, np.nan]] for e in cxg.es]
edge_coords = [[
    layout_coords[e.source, :],
    layout_coords[e.target, :],
    [np.nan, np.nan],
    ]
    for e in cxge.es]
edge_coords = np.array(edge_coords).reshape(-1, 2)
# num_orig_nodes has to be defined before it is used in the comprehension below.
# NOTE(review): this assumes the extra vertices of the extended layout graph are
# numbered after the original ones, so edges ending in them get no arrow head — confirm.
num_orig_nodes = len(cxg.vs)
is_arrow_end = np.array([[0, 0, 0] if edge.target >= num_orig_nodes else [0, 10, 0] for edge in cxge.es]).ravel()

# edge_trace = go.Scatter(x=edge_x, y=edge_y, line=dict(width=lineWidth, color=lineColor), hoverinfo='none', mode='lines')
edge_trace = go.Scatter(
    # note: coords are transposed.
    x=edge_coords[:, 1],
    y=edge_coords[:, 0],
    hoverinfo='none',
    mode='lines+markers',
    marker=dict(
        size=10,
        angleref="previous",
        symbol="arrow",
        ),
    marker_size = is_arrow_end,
)

def color_table(ogv):
    """Map an "og" value to a marker color: 2 is 'red', 3 is 'blue', anything else is 'white'."""
    return 'red' if ogv == 2 else 'blue' if ogv == 3 else 'white'

# node_trace = go.Scatter(x=node_x, y=node_y, mode='markers', hoverinfo='text', marker=dict(showscale=False, color = nodeColor, size=nodeSize))
node_trace = go.Scatter(
    # note: coords are transposed.
    x=vertex_coords[:, 1],
    y=vertex_coords[:, 0],
    marker_color = [color_table(og) for og in cxg.vs["og"]],
    mode='markers',
    # hoverinfo='text',
    marker=dict(showscale=False)
)

fig = go.Figure(data=[edge_trace, node_trace],
             layout=go.Layout(
                showlegend=False,
                # hovermode='closest',
                margin=dict(b=20,l=5,r=5,t=40),
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False))
                )
fig

In [None]:
is_arrow_end

In [None]:
stitchnet.eval()
stitchnet(X)

In [None]:
# Embed computational cost info
import torchinfo
import recomb.eval_costs as ec
cost_summary = torchinfo.summary(stitchnet, input_data=[X])
ec.embed_cost_stats_in_model(cost_summary)

In [None]:
print(("The stitched neural network has "
       f"{len(stitchinfo.joiners)} matches"
       ))

## Evaluate predetermined neighborhood

In [None]:
import gc
stitchnet.cpu()
gc.collect()
torch.cuda.empty_cache()

In [None]:
# this network will cost  70200072306 mul-adds and approx 3032042904 of memory
# this network will cost 180146119266 mul-adds and approx 3761352936 of memory


In [None]:
from IPython.display import SVG
with open("g.dot", "w") as f:
    stitchnet.to_dot(f, include_ord_label=True)
! dot g.dot -Tsvg -og.png
SVG("g.png")


total_mult_adds, total_bytes = ec.evaluate_compute_cost(stitchnet)
print(f"this network will cost {total_mult_adds} mul-adds and approx {total_bytes} of memory")


In [None]:
stitchinfo.output_switch.active = 2
stitchinfo.output_switch.simplify = True

for j in stitchinfo.joiners:
        j[0].active = 0
        j[1].active = 0
        j[0].simplify = True
        j[1].simplify = True

dev = torch.device("cuda:1")

stitchnet_pruned = stitchnet.to_graph()
stitchnet_pruned.prune_unused()

total_mult_adds, total_bytes = ec.evaluate_compute_cost(stitchnet_pruned)
print(f"this network will cost {total_mult_adds} mul-adds and approx {total_bytes} of memory")


from IPython.display import SVG
with open("g.dot", "w") as f:
    stitchnet_pruned.to_dot(f, include_ord_label=True)
! dot g.dot -Tsvg -og.png
SVG("g.png")

In [None]:
problem = problems.VOCSegmentationProblem("<add-dataset-folder>", batched_validation=True, validation_sample_limit=1000)

In [None]:
dev = torch.device("cuda:1")
# Evaluate reference networks
reference_q = []
stitchinfo.output_switch.active = 2
stitchinfo.output_switch.simplify = True

# Reset every joiner pair to choice 0 and mark it for simplification.
for j in stitchinfo.joiners:
    j[0].active = 0
    j[1].active = 0
    j[0].simplify = True
    j[1].simplify = True

batch_size = 16
# Evaluate once per output switch setting. 2 is reported as the full ensemble below.
# NOTE(review): 1 and 0 presumably select the individual parent networks — confirm.
for output_choice in (2, 1, 0):
    stitchinfo.output_switch.active = output_choice
    stitchnet_pruned = stitchnet.to_graph()
    stitchnet_pruned.prune_unused()
    total_mult_adds, total_bytes = ec.evaluate_compute_cost(stitchnet_pruned)
    if output_choice == 2:
        print(f"full ensemble will cost {total_mult_adds} mul-adds and approx {total_bytes} of memory")

    neti_os = NeuralNetIndividual(stitchnet_pruned)
    accuracy, loss = problem.evaluate_network(dev, neti_os, batch_size=batch_size, objective="both")
    reference_q.append((accuracy, loss, total_bytes, total_mult_adds, cx.convert_stitcher_to_genotype(stitchinfo, stringify=False)))

    # Free memory before the next network is built and evaluated.
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
ensembles_a = []
ensembles_b = []
start_a_end_b = []
start_b_end_a = []

# note - usually 1, but due to the large amount of matches, this has been
# increased so that we can evaluate blocks of solutions instead.
step = 10

In [None]:
stitchinfo.output_switch.active = 2
stitchinfo.output_switch.simplify = True

dev = torch.device("cuda:0")

for j in stitchinfo.joiners:
        j[0].active = 0
        j[1].active = 0
        j[0].simplify = True
        j[1].simplify = True

i = 0
j = stitchinfo.joiners[i]
j[0].active = 0
j[1].active = 1

stitchnet_pruned = stitchnet.to_graph()
stitchnet_pruned.prune_unused()

# Get compute & memory requirements
# s = torchinfo.summary(stitchnet_pruned, input_data=[X])
total_mult_adds, total_bytes = ec.evaluate_compute_cost(stitchnet_pruned)

neti_os = NeuralNetIndividual(stitchnet_pruned)


gc.collect()
torch.cuda.empty_cache()

j[0].active = 0
j[1].active = 0

from IPython.display import SVG
with open("g.dot", "w") as f:
    stitchnet_pruned.to_dot(f, include_ord_label=True)
! dot g.dot -Tsvg -og.png
SVG("g.png")

# accuracy, loss = problem.evaluate_network(dev, neti_os, batch_size=batch_size, objective="both")
# ensembles_a.append((accuracy, loss, total_bytes, total_mult_adds, cx.convert_stitcher_to_genotype(stitchinfo, stringify=False)))


In [None]:
# Evaluate neighborhood of networks
def _evaluate_current_stitch_configuration(dev, batch_size):
    """Build, cost and evaluate the network for the current switch settings.

    Returns one sample tuple:
    (accuracy, loss, total bytes, multiply-adds, genotype).
    The genotype is read from `stitchinfo` as it is configured at call time,
    so the caller must only restore its switches after this function returns.
    """
    pruned = stitchnet.to_graph()
    pruned.prune_unused()

    # Get compute & memory requirements
    total_mult_adds, total_bytes = ec.evaluate_compute_cost(pruned)

    individual = NeuralNetIndividual(pruned)
    accuracy, loss = problem.evaluate_network(
        dev, individual, batch_size=batch_size, objective="both"
    )
    genotype = cx.convert_stitcher_to_genotype(stitchinfo, stringify=False)
    return (accuracy, loss, total_bytes, total_mult_adds, genotype)


def evaluate_neighborhood_given_offset(offset, batch_size=16):
    """Evaluate every `step`-th single-stitch network, starting at `offset`.

    For each joiner index in `range(offset, len(stitchinfo.joiners), step)`
    exactly one joiner is switched on, the resulting network is pruned,
    costed and evaluated, and the joiner is switched off again. This is done
    in four passes, each appending to its own notebook-level list:

    ====================  ==================  =================
    output switch active  joiner (0, 1)       results appended to
    ====================  ==================  =================
    2                     (0, 1)              ensembles_a
    2                     (1, 0)              ensembles_b
    1                     (0, 1)              start_a_end_b
    0                     (1, 0)              start_b_end_a
    ====================  ==================  =================

    Args:
        offset: Index of the first joiner to evaluate; calling this with
            offsets 0 .. step - 1 covers every joiner once.
        batch_size: Batch size forwarded to `problem.evaluate_network`.

    Side effects: mutates `stitchinfo` (all joiners are reset to position 0
    with `simplify` enabled; the output switch is left at position 0) and
    reads the notebook-level `step`, `stitchnet`, `problem`, `ec` and `cx`.
    """
    stitchinfo.output_switch.active = 2
    stitchinfo.output_switch.simplify = True

    dev = torch.device("cuda:0")

    # Start from a known state: every joiner at position 0.
    for joiner in stitchinfo.joiners:
        joiner[0].active = 0
        joiner[1].active = 0
        joiner[0].simplify = True
        joiner[1].simplify = True

    # (output switch position, (joiner[0].active, joiner[1].active), results)
    passes = (
        (2, (0, 1), ensembles_a),
        (2, (1, 0), ensembles_b),
        (1, (0, 1), start_a_end_b),
        (0, (1, 0), start_b_end_a),
    )

    for output_position, (first_active, second_active), results in passes:
        stitchinfo.output_switch.active = output_position

        for index in range(offset, len(stitchinfo.joiners), step):
            joiner = stitchinfo.joiners[index]
            joiner[0].active = first_active
            joiner[1].active = second_active
            try:
                results.append(
                    _evaluate_current_stitch_configuration(dev, batch_size)
                )
            finally:
                # Always restore the joiner, even when the evaluation raises,
                # so one failure does not leak an active stitch into the
                # following evaluations.
                joiner[0].active = 0
                joiner[1].active = 0

                # The helper's locals are gone by now, so this can actually
                # reclaim the network that was just evaluated.
                gc.collect()
                torch.cuda.empty_cache()

In [None]:
# Note - very little training performed above initially
# Offset 0: evaluates joiners 0, step, 2 * step, ... The following cells run
# the remaining offsets one cell at a time.
evaluate_neighborhood_given_offset(0)

In [None]:
# Offset 1: joiners 1, 1 + step, 1 + 2 * step, ...
evaluate_neighborhood_given_offset(1)

In [None]:
# Offset 2: joiners 2, 2 + step, 2 + 2 * step, ...
evaluate_neighborhood_given_offset(2)

In [None]:
# Offset 3: joiners 3, 3 + step, 3 + 2 * step, ...
evaluate_neighborhood_given_offset(3)

In [None]:
# Offset 4: joiners 4, 4 + step, 4 + 2 * step, ...
evaluate_neighborhood_given_offset(4)

In [None]:
# Offset 5: joiners 5, 5 + step, 5 + 2 * step, ...
evaluate_neighborhood_given_offset(5)

In [None]:
# Offset 6: joiners 6, 6 + step, 6 + 2 * step, ...
evaluate_neighborhood_given_offset(6)

In [None]:
# Offset 7: joiners 7, 7 + step, 7 + 2 * step, ...
evaluate_neighborhood_given_offset(7)

In [None]:
# Offset 8: joiners 8, 8 + step, 8 + 2 * step, ...
evaluate_neighborhood_given_offset(8)

In [None]:
# Offset 9: joiners 9, 9 + step, 9 + 2 * step, ... With step = 10 this is the
# last offset needed to cover every joiner.
evaluate_neighborhood_given_offset(9)

In [None]:
import polars as pl

In [None]:
df_schema = ["accuracy", "loss", "total bytes", "multiply-adds", "genotype"]


def _labelled_samples_frame(rows, set_label, contains_stitch):
    """Turn a list of sample tuples into a frame tagged with its origin.

    Args:
        rows: List of (accuracy, loss, total bytes, multiply-adds, genotype)
            tuples, one tuple per evaluated network.
        set_label: Value for the "set" column; either one string for all rows
            or a `pl.Series` with one label per row.
        contains_stitch: Value for the boolean "contains stitch" column.

    `orient="row"` is passed explicitly: each tuple is one row. Without it
    polars has to guess the orientation from the data, which is heuristic and
    version-dependent (newer releases warn about it, and a list whose length
    happens to equal the number of columns can be read as column-oriented).
    """
    return pl.DataFrame(rows, schema=df_schema, orient="row").with_columns([
        pl.lit(set_label).alias("set"),
        pl.lit(contains_stitch).alias("contains stitch"),
    ])


# The three reference rows were appended in the order ensemble, b, a.
samples_reference = _labelled_samples_frame(
    reference_q, pl.Series(["ensemble", "b", "a"]), False
)
samples_ensemble_a = _labelled_samples_frame(ensembles_a, "ensemble-major-a", True)
samples_ensemble_b = _labelled_samples_frame(ensembles_b, "ensemble-major-b", True)
samples_ab = _labelled_samples_frame(start_a_end_b, "stitch-a-to-b", True)
samples_ba = _labelled_samples_frame(start_b_end_a, "stitch-b-to-a", True)

# Reference rows go first: later cells pick them out by row index.
samples = pl.concat([
    samples_reference,
    samples_ensemble_a,
    samples_ensemble_b,
    samples_ab,
    samples_ba,
], how="vertical_relaxed").with_columns(
    # Clipped copy of the loss, limited to the range [0, 4].
    pl.col("loss").clip(0.0, 4.0).alias("loss-clip")
)
samples.write_ipc("segmentation-stitch-samples.arrow")


## Plot approximation front

In [None]:
# Reload the samples saved by the cell above. This used to read
# "resnet-efficientnet-stitch-samples.arrow", a file this notebook never
# writes; it now reads the file that is actually produced here.
samples = pl.read_ipc("segmentation-stitch-samples.arrow")

# Extract some rows of reference interest. This relies on the reference rows
# being the first three rows of the file, in the order ensemble, b, a.
dfcna = samples[2]
dfcnb = samples[1]
dfcnens = samples[0]

# Direction in which each criterion improves:
# +1 means larger is better, -1 means smaller is better.
improvement_direction = {
    "accuracy": 1,
    "loss": -1,
    "loss-clip": -1,
    "total bytes": -1,
    "multiply-adds": -1,
    # "genotype": 0, # -- not a criterion
}

In [None]:
# Rough estimate of the total evaluation time in hours: one evaluation per
# joiner for each of the 4 configurations, presumably at ~15 seconds per
# evaluation (TODO confirm), divided by 60 * 60 seconds per hour.
len(stitchinfo.joiners) * 4 * 15 / 60 / 60

In [None]:
# # How many seconds per evaluated sample?
# number_of_minutes = 16 * 8
# number_of_seconds = 0
# number_of_samples = len(samples) - 3
# seconds_total = number_of_minutes * 60 + number_of_seconds
# seconds_per_sample = seconds_total / number_of_samples

# print(f"spent {number_of_minutes}m{number_of_seconds}s "
#       f"to evaluate {number_of_samples} samples.\n"
#       f"Resulting in a cost of {seconds_per_sample}s per sample.")

In [None]:
# The two criteria / objectives for which the Pareto set is computed.
c0 = "accuracy"
c1 = "multiply-adds"


def _pareto_rows(candidates):
    """Return the rows of lazy frame `candidates` that are Pareto-optimal.

    Rows are ordered best-first on `c0`. A row is kept when its `c1` value
    (sign-adjusted so that smaller is better) is strictly below the running
    minimum over all rows that precede it in that order.
    """
    best_c0_first = improvement_direction[c0] > 0
    c1_as_minimisation = pl.col(c1) * -improvement_direction[c1]
    beats_everything_before = pl.col("c1-min") < pl.col("mv").shift(1)

    ranked = candidates.sort(c0, descending=best_c0_first)
    ranked = ranked.with_columns(c1_as_minimisation.alias("c1-min"))
    ranked = ranked.with_columns((pl.col("c1-min")).cummin().alias("mv"))
    ranked = ranked.with_columns(beats_everything_before.alias("is pareto"))
    # The first row has no predecessor to compare against (the shift yields
    # null there); it counts as Pareto-optimal.
    ranked = ranked.fill_null(True)
    return ranked.filter(pl.col("is pareto")).collect()


# Front over all samples, and over the stitched samples only.
samples_pareto = _pareto_rows(samples.lazy())
samples_pareto_stitch_only = _pareto_rows(
    samples.lazy().filter(pl.col("contains stitch"))
)

In [None]:
def get_direction_arrow(c):
    """Return '->' if larger values of criterion `c` are better, else '<-'."""
    if improvement_direction[c] > 0:
        return '->'
    return '<-'


# One small-dot series per set of stitched samples.
stitched_samples = samples.filter(pl.col("contains stitch"))
for sn, df in stitched_samples.group_by("set", maintain_order=True):
    plt.scatter(df[c0], df[c1], label=sn, s=1.0)

# Pareto fronts in grey: squares for the front over all samples, circles for
# the front over stitched samples only.
plt.scatter(samples_pareto[c0], samples_pareto[c1], alpha=0.4, marker="s", color="grey")
plt.scatter(samples_pareto_stitch_only[c0], samples_pareto_stitch_only[c1], s=20.0, alpha=0.5, color="grey")

# The three reference rows, drawn as crosses.
for reference_label, reference_row in (("a", dfcna), ("b", dfcnb), ("ensemble", dfcnens)):
    plt.scatter(reference_row[c0], reference_row[c1], label=reference_label, marker='x')

# Axis labels carry an arrow pointing in the direction of improvement.
plt.xlabel(f"{c0} ({get_direction_arrow(c0)})")
plt.ylabel(f"{c1} ({get_direction_arrow(c1)})")
# Legend placed outside the axes, anchored at their top-right corner.
legend_options = dict(
    loc='upper left',
    bbox_to_anchor=(1.0, 1.0),
    fancybox=False,
    shadow=True,
)
plt.legend(**legend_options)

**Potential Points of Improvement?**
1. Pretrain for longer? (e.g. specific stopping condition?)
2. Train using actual loss function.

In [None]:
# 2-D histogram of all samples over the two criteria c0 and c1.
histogram_x = samples[c0]
histogram_y = samples[c1]
plt.hist2d(histogram_x, histogram_y)
plt.colorbar()

### Some test evaluations

In [None]:
# CUDA device with index 1; the next cell evaluates network a on it.
dev2 = torch.device("cuda:1")

In [None]:
# Evaluate the two original (unstitched) networks: a on dev2, b on dev.
neti_a = NeuralNetIndividual(gca)
neti_b = NeuralNetIndividual(gcb)
original_result_a = problem.evaluate_network(dev2, neti_a, objective="both")
original_result_b = problem.evaluate_network(dev, neti_b, objective="both")
original_result_a, original_result_b

In [None]:
# Evaluate the (unpruned) stitch network with every joiner at position 0,
# once through output position 0 and once through output position 1.
stitchneti = NeuralNetIndividual(stitchnet)
for j in stitchinfo.joiners:
    j[0].active = j[1].active = 0

results_per_output = []
for output_position in (0, 1):
    stitchinfo.output_switch.active = output_position
    results_per_output.append(
        problem.evaluate_network(dev, stitchneti, objective="both")
    )
roa, rob = results_per_output
roa, rob

In [None]:
# Evaluate the (unpruned) stitch network with only joiner 18 switched on,
# through output position 2.
stitchneti = NeuralNetIndividual(stitchnet)
for j in stitchinfo.joiners:
    j[0].active = j[1].active = 0
stitchinfo.output_switch.active = 2
j = stitchinfo.joiners[18]
# For the opposite direction use: j[0].active, j[1].active = 0, 1
j[0].active, j[1].active = 1, 0

problem.evaluate_network(dev, stitchneti, objective="both")

In [None]:
import gc


# Final cleanup: collect Python garbage first so that unreachable tensors are
# dropped, then release the CUDA memory that PyTorch still holds cached.
gc.collect()
torch.cuda.empty_cache()