In [None]:
import torch
import torchvision.datasets as thd
import timm
import recomb.layers as ly
import igraph as ig
import matplotlib.pyplot as plt
import polars as pl

from pathlib import Path

import recomb.cx as cx
import recomb.layers as ly
import recomb.problems as problems
from recomb.cx import forward_get_all_feature_maps, construct_trained_cx_network_stitching

In [None]:
# Root folder of the ImageNet data; replace the placeholder before running.
dataset_path = Path("<add-dataset-folder>")
# (Unused alternative) imgnet_train = thd.ImageNet("<add-dataset-folder>")

# Input shape used to trace the networks: (batch, channels, height, width).
# ImageNet networks are often invariant to the image size, so the spatial size
# is not really fixed by the models. Stitching has to account for this somehow,
# because the image shape of the feature maps then does not need to be constrained.
imgnet_in_shape = (1, 3, 256, 256)

# A conservative batch size, kept small just in case.
batch_size = 32

In [None]:
timm.list_models("efficientnet*")

In [None]:
model_a = timm.create_model("resnet152", pretrained=True)
model_a.eval()
gca = ly.trace_network(model_a, imgnet_in_shape).to_neural_net_graph()

In [None]:
from IPython.display import SVG
with open("g.dot", "w") as f:
    gca.to_dot(f, include_ord_label=True)
! dot g.dot -Tsvg -og.png
SVG("g.png")

In [None]:
# model_b = timm.create_model("vgg13", pretrained=True)
model_b = timm.create_model("efficientnet_b4", pretrained=True)
model_b.eval()
gcb = ly.trace_network(model_b, imgnet_in_shape).to_neural_net_graph()

In [None]:
from IPython.display import SVG
with open("g.dot", "w") as f:
    gcb.to_dot(f, include_ord_label=True)
! dot g.dot -Tsvg -og.png
SVG("g.png")

In [None]:
from torchvision.transforms import v2 as transforms
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

def get_transform_for_timm_model(model):
    """Build the preprocessing transform timm associates with `model`.

    The data config is resolved from ``model.pretrained_cfg`` (and the model itself)
    with ``resolve_data_config``; its entries are passed as keyword arguments to
    ``create_transform``.
    """
    data_config = resolve_data_config(model.pretrained_cfg, model=model)
    return create_transform(**data_config)

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
# Resize the short side to 235 px and center-crop 224 px (235 ~= 224 / 0.95,
# presumably mirroring the crop ratio of the pretrained timm configs -- verify),
# convert to float in [0, 1], then apply the ImageNet mean/std normalization.
tf = transforms.Compose([
    transforms.ToImage(),
    transforms.Resize(235, antialias=True),
    transforms.CenterCrop(224),
    transforms.ToDtype(torch.float32, scale=True),
    normalize,
])
# tf = get_transform_for_timm_model(model_a)

# Use the configured dataset root instead of repeating the placeholder literal.
imagenet_train = problems.ImageNetKG(str(dataset_path), "train", transform=tf)

In [None]:
from torch.utils.data import DataLoader
dl = DataLoader(imagenet_train)
dli = iter(dl)
for _ in range(5):
    X_o, y_o = next(dli)

In [None]:
plt.imshow(X_o[0, :, :, :].permute(1, 2, 0))


In [None]:
outpt = torch.topk(model_a(X_o), k=5)
list(imagenet_train.class_name(outpt.indices.ravel().numpy())), list(imagenet_train.class_name(y_o.item()))

In [None]:
outpt = torch.topk(gca(X_o), k=5)
list(imagenet_train.class_name(outpt.indices.ravel().numpy())), list(imagenet_train.class_name(y_o.item()))

In [None]:
outpt = torch.topk(model_b(X_o), k=5)
list(imagenet_train.class_name(outpt.indices.ravel().numpy())), list(imagenet_train.class_name(y_o.item()))

In [None]:
outpt = torch.topk(gcb(X_o), k=5)
list(imagenet_train.class_name(outpt.indices.ravel().numpy())), list(imagenet_train.class_name(y_o.item()))

In [None]:
from recomb.problems import NeuralNetIndividual, ImageNetProblem

In [None]:
neti_a = NeuralNetIndividual(gca)
neti_b = NeuralNetIndividual(gcb)

In [None]:
dev = torch.device('cuda')

In [None]:
# Set validation sample limit to 1000.
# There are 23000-ish samples in the full validation set.
# As evaluating 1000 samples takes ~15s per # network, 
# 23000-ish samples should take approximately 345s each.
problem = ImageNetProblem("<add-dataset-folder>", validation_sample_limit=1000)

In [None]:
problem.evaluate_network(dev, neti_a, batch_size=16, objective="both")

In [None]:
problem.evaluate_network(dev, neti_b, batch_size=16, objective="loss")

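Does ensembling give identical results too, i.e. does an ensemble of the traced graphs (`gca`, `gcb`) match an ensemble of the original timm models?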

In [None]:
net_be = ly.LinearEnsemble([model_a, model_b], [0.5, 0.5])
neti_be = NeuralNetIndividual(net_be)
torch.manual_seed(42)
problem.evaluate_network(dev, neti_be, batch_size=16, objective="loss")

In [None]:
net_be = ly.LinearEnsemble([gca, gcb], [0.5, 0.5]).to_graph()
neti_be = NeuralNetIndividual(net_be)
torch.manual_seed(42)
problem.evaluate_network(dev, neti_be, batch_size=16, objective="loss")

In [None]:
net_be = ly.LinearEnsemble([model_a, model_b], [1.0, 1.0])
neti_be = NeuralNetIndividual(net_be)
torch.manual_seed(42)
problem.evaluate_network(dev, neti_be, batch_size=16, objective="loss")

In [None]:
net_be = ly.LinearEnsemble([gca, gcb], [1.0, 1.0]).to_graph()
neti_be = NeuralNetIndividual(net_be)
torch.manual_seed(42)
problem.evaluate_network(dev, neti_be, batch_size=16, objective="loss")

In [None]:
# import importlib 
import recomb.cx as cx

In [None]:
# Draw from the training split so that nothing is fit on the validation data.
dataset = problem.get_dataset_train()

# One item from a default DataLoader serves as input for the forward passes.
from torch.utils.data import DataLoader
X, _y = next(iter(DataLoader(dataset)))

# Freeze every parameter of both parent networks.
for frozen_net in (gca, gcb):
    for param in frozen_net.parameters():
        param.requires_grad_(False)

In [None]:
import importlib
# NOTE(review): reload only refreshes the `cx` module object. Names bound earlier via
# `from recomb.cx import ...` (forward_get_all_feature_maps,
# construct_trained_cx_network_stitching) and already-created instances keep the old
# definitions. This looks like a leftover from interactive development; prefer
# restarting the kernel.
importlib.reload(cx)


In [None]:
def compute_S(net_a, net_b, stitching_library, compute_similarity,
              X_in_many=None, device=None, use_gpu=True):
    """Compute the pairwise similarity matrix between the feature maps of two networks.

    Both networks are run in eval mode on the same input to collect their feature
    maps. Every graph vertex that yields a feature map is characterized by
    `stitching_library` and annotated with its feature-map index ("fmidx").
    Similarities are computed by ``cx.compute_pairwise_similarities``.

    Parameters
    ----------
    net_a, net_b : networks exposing ``store_state_eval``, ``train_restore``, ``to``
        and an igraph ``graph``.
    stitching_library : object with ``characterize_graph`` and ``characterize_fm``.
    compute_similarity : similarity function passed to
        ``cx.compute_pairwise_similarities``.
    X_in_many : input batch; defaults to the notebook-global ``X``.
    device : device used when ``use_gpu`` is set; defaults to the global ``dev``.
    use_gpu : move both networks and the input to ``device`` before the forward pass.
        When False, the input is used as-is (previously this raised a NameError).

    Returns
    -------
    The matrix returned by ``cx.compute_pairwise_similarities`` with entries
    [0, 0] and [1, 1] (the input and output pairing) forced to 1.
    """
    if X_in_many is None:
        X_in_many = X
    if device is None:
        device = dev

    net_a.store_state_eval()
    net_b.store_state_eval()
    # Always restore the previous training state, even if moving or forwarding fails.
    try:
        if use_gpu:
            net_b.to(device)
            net_a.to(device)
            X_in_many = X_in_many.to(device)
        # Set up reference networks.
        fms_a, points_a = forward_get_all_feature_maps(net_a, X_in_many, return_points=True)
        fms_b, points_b = forward_get_all_feature_maps(net_b, X_in_many, return_points=True)
    finally:
        net_a.train_restore()
        net_b.train_restore()

    # Only points that actually produced a feature map take part in the similarity.
    points_a_v = [p for (p, fm) in zip(points_a, fms_a) if fm is not None]
    fms_a_v = [fm for fm in fms_a if fm is not None]
    points_b_v = [p for (p, fm) in zip(points_b, fms_b) if fm is not None]
    fms_b_v = [fm for fm in fms_b if fm is not None]

    # Characterize graphs
    stitching_library.characterize_graph(net_a.graph)
    stitching_library.characterize_graph(net_b.graph)

    # Store feature-map characteristics on each vertex together with its index
    # into the *unfiltered* feature-map list (network a first, then b).
    for net, points, fms in ((net_a, points_a, fms_a), (net_b, points_b, fms_b)):
        for fmidx, (p, fm) in enumerate(zip(points, fms)):
            if fm is None:
                continue
            stitching_library.characterize_fm(net.graph.vs[p[1]], fm, net.graph)
            net.graph.vs[p[1]]["fmidx"] = fmidx

    S = cx.compute_pairwise_similarities(
        net_a.graph,
        fms_a_v,
        points_a_v,
        net_b.graph,
        fms_b_v,
        points_b_v,
        compute_similarity=compute_similarity,
        stitching_library=stitching_library,
    )
    # Input & output override
    S[0, 0] = 1
    S[1, 1] = 1

    return S

# Similarity matrix between the feature maps of gca and gcb.
S = compute_S(
    gca,
    gcb,
    stitching_library=cx.CVStitchingLib(True, False),
    # stitching_library=cx.BalancingCVStitchingLib(True, False),
    compute_similarity=cx.compute_mock_similarity,
)
plt.imshow(S)

In [None]:
S

In [None]:
import numpy as np
np.nanmin(S)

In [None]:
torch.save((gca.graph, gcb.graph, S), "stitching-problem.pickle")

In [None]:
gca.graph.write_graphmlz("stitching-problem-graph_a.graphml.gz")
gcb.graph.write_graphmlz("stitching-problem-graph_b.graphml.gz")
np.savetxt("stitching-problem-similarity.txt.gz", S)

In [None]:
# stitched = cx.construct_trained_cx_network_stitching(
#     dataset=dataset,
#     dev=dev,
#     net_a=gca,
#     net_b=gcb,
#     X_in_many=X,
#     ensemblers=[ly.LinearCombine([0.5, 0.5])],
#     compute_similarity=cx.compute_mock_similarity,
#     feature_shape_should_match=False,
#     batch_size=batch_size,

#     pretrain_cx_network = False
# )

In [None]:
# from IPython.display import SVG
# with open("g.dot", "w") as f:
#     stitched[0].to_dot(f, include_ord_label=True)
# ! dot g.dot -Tsvg -og.png
# SVG("g.png")

In [None]:
# Alternative that initializes the layers a bit more intelligently.
class CVInitStitchingLib(cx.StitchingLib):
    """Stitching library that initializes new stitches from feature-map statistics.

    Stitches are 1x1 convolutions (4D maps with equal spatial size) or linear
    layers (2D maps). After the layer's own initialization, its weights are scaled
    by std_b / std_a and its bias is shifted by mean_b - mean_a. The statistics are
    the global std/mean of the feature maps recorded by ``characterize_fm``.
    """

    def __init__(self, image_shape_should_match, feature_shape_should_match):
        # image_shape_should_match: only stitch when spatial dims (sh[2:]) agree.
        # feature_shape_should_match: only stitch when feature counts (sh[1]) agree.
        self.image_shape_should_match = image_shape_should_match
        self.feature_shape_should_match = feature_shape_should_match

    def characterize_fm(self, v, fm, graph=None):
        """Record type, shape and global std/mean of feature map `fm` on vertex `v`."""
        if isinstance(fm, torch.Tensor):
            v["ty"] = "tensor"
            v["sh"] = list(fm.shape)
            with torch.no_grad():
                std, mean = torch.std_mean(fm)
            # Store plain floats: 0-dim tensors would stay on the feature map's device
            # (e.g. CUDA) and break the in-place update of CPU layer parameters in
            # create_stitch; floats can also be serialized with the graph.
            v["std"], v["mean"] = float(std), float(mean)
        else:
            v["ty"] = "unk"

    @staticmethod
    def _has_stitch_layer(sh_a, sh_b):
        """Whether create_stitch defines a layer for feature maps of these shapes."""
        if len(sh_a) == 4 and len(sh_b) == 4:
            # 1x1 convolution: requires matching spatial dims.
            return sh_a[2:] == sh_b[2:]
        return len(sh_a) == 2 and len(sh_b) == 2

    @staticmethod
    def _init_adjustment(a, b):
        """Return (bias offset, weight scale) = (mean_b - mean_a, std_b / std_a).

        Falls back to a neutral adjustment (offset 0, scale 1) when a value is not
        finite or std_a is zero, instead of producing inf/nan parameters.
        """
        import math

        offset = float(b["mean"]) - float(a["mean"])
        if not math.isfinite(offset):
            offset = 0.0
        std_a = float(a["std"])
        scale = float(b["std"]) / std_a if std_a != 0.0 else float("nan")
        if not math.isfinite(scale):
            scale = 1.0
        return offset, scale

    def can_stitch(self, a, b):
        """Whether a stitch from vertex `a` into vertex `b` can be created.

        Only returns True for pairs create_stitch can actually build, so callers
        never request a stitch that would raise.
        """
        # If types are unknown - do not allow stitching at these points.
        if a["ty"] == "unk" or b["ty"] == "unk":
            return False
        # Cascade through: only tensor -> tensor stitches are supported.
        if a["ty"] != "tensor" or b["ty"] != "tensor":
            return False

        sh_a = a["sh"]
        sh_b = b["sh"]
        # Checked first, which also guarantees rank >= 2 for the sh[1] access below.
        if not self._has_stitch_layer(sh_a, sh_b):
            return False
        if self.image_shape_should_match and sh_a[2:] != sh_b[2:]:
            return False
        if self.feature_shape_should_match and sh_a[1] != sh_b[1]:
            return False
        return True

    def create_stitch(self, a, b):
        """Create a Conv2d (1x1) or Linear stitch from `a` to `b`, statistics-initialized.

        Raises
        ------
        Exception
            If no stitching layer is defined for the types or shapes of `a` and `b`.
        """
        if a["ty"] == "tensor" and b["ty"] == "tensor":
            sh_a = a["sh"]
            sh_b = b["sh"]

            # normally we preserve mean & variance with the chosen init
            # This strategy sets the weights & biases accordingly
            # albeit under an uniform assumption.
            offset, scale = self._init_adjustment(a, b)

            num_features_in = sh_a[1]
            num_features_out = sh_b[1]
            if len(sh_a) == 4 and len(sh_b) == 4 and sh_a[2:] == sh_b[2:]:
                stitch = ly.Conv2d(
                    num_features_in, num_features_out, kernel_size=(1, 1)
                )
            elif len(sh_a) == 2 and len(sh_b) == 2:
                stitch = ly.Linear(num_features_in, num_features_out)
            else:
                raise Exception(
                    f"cannot join items. No merging layer defined for shapes a: {sh_a} b: {sh_b}"
                )
            with torch.no_grad():
                stitch.layer.bias += offset
                stitch.layer.weight *= scale
            return stitch
        raise Exception(
                    f"cannot join items. No stitching layer defined between layers from type {a['ty']} to {b['ty']}"
                )



In [None]:
from torch.utils.tensorboard import SummaryWriter
summarywriter = SummaryWriter("./logs/train-stitch-imagenet-2")

In [None]:
# torch.autograd.set_detect_anomaly(False)
stitched = cx.construct_trained_cx_network_stitching(
    dataset=dataset,
    dev=dev,
    net_a=gca,
    net_b=gcb,
    X_in_many=X,
    ensemblers=[ly.LinearCombine([0.5, 0.5])],
    stitching_library=cx.CVStitchingLib(True, False),
    # stitching_library=CVInitStitchingLib(True, False),
    # stitching_library=cx.BalancingCVStitchingLib(True, False),
    # compute_similarity=cx.compute_mock_similarity,
    compute_similarity=cx.compute_mock_similarity,
    feature_shape_should_match=False,
    batch_size=batch_size,

    lr_pretrain=1e-3,
    weight_decay_pretrain=1e-5,

    num_epochs_pretrain = 3,
    pretrain_cx_network = True,
    summarywriter=summarywriter,
    immediately_backprop = False,
)

In [None]:
for es in stitched[1].joiners:
    for e in es:
        e.agg = None
stitched[1].output_switch.agg = None

In [None]:
stitched_simpl = (stitched[0], cx.SitchingInfo(stitched[1].joiners, stitched[1].output_switch))

In [None]:
torch.save(stitched_simpl, "stitched-imagenet-a-resnet152-b-efficientnet-b4.th")
# torch.save(stitched_simpl, "stitched-imagenet-a-resnet152-b-efficientnet-b4--r.th")

In [None]:
stitched = torch.load("stitched-imagenet-a-resnet152-b-efficientnet-b4.th")

In [None]:
stitchnet, stitchinfo = stitched

In [None]:
# Train supernetwork by randomly sampling layers
for m in stitchnet.submodules:
    if not isinstance(m, cx.CXN): continue
    m.active = [0, 1]
    m.p = [0.9, 0.1]
    m.determine_p()
    m.randomize_per_sample = True
stitchinfo.output_switch.active = [0, 1, 2]
stitchinfo.output_switch.p = None
stitchinfo.output_switch.determine_p()

In [None]:
import numpy as np
for m in stitchnet.submodules:
    if not isinstance(m, cx.CXN): continue
    if m.p is None: continue
    if m.p[-1] is None:
        m.p = None
        continue
    m.p = np.cumsum(m.p)
    m.p /= m.p[-1]

In [None]:
from torch.utils.tensorboard import SummaryWriter
summarywriter = SummaryWriter("./logs/refine-imagenet")
stitchneti = NeuralNetIndividual(stitchnet)
problem.train_network(dev, stitchneti, lr=1e-4, weight_decay=1e-5, num_epochs=5, minout_nan=True, batch_size=batch_size, raise_on_nan_loss=False, summarywriter=summarywriter)


In [None]:
stitchneti.net.cpu()

In [None]:
torch.save((stitchnet, stitchinfo), "stitch-train-test.th")

In [None]:
# Embed computational cost info
import torchinfo
import recomb.eval_costs as ec
cost_summary = torchinfo.summary(stitchnet, input_data=[X_o])
ec.embed_cost_stats_in_model(cost_summary)

In [None]:
print(("The stitched neural network has "
       f"{len(stitchinfo.joiners)} matches"
       ))

## Evaluate predetermined neighborhood

In [None]:
# Evaluate reference networks: with every joiner disabled, only the output switch
# decides which unstitched network is evaluated. Order 2, 1, 0 yields the rows that
# are later labelled "ensemble", "b", "a".
reference_q = []
stitchinfo.output_switch.simplify = True
for j in stitchinfo.joiners:
    j[0].active = 0
    j[1].active = 0
    j[0].simplify = True
    j[1].simplify = True

for output_choice in (2, 1, 0):
    stitchinfo.output_switch.active = output_choice
    stitchnet_pruned = stitchnet.to_graph()
    stitchnet_pruned.prune_unused()
    total_mult_adds, total_bytes = ec.evaluate_compute_cost(stitchnet_pruned)

    neti_os = NeuralNetIndividual(stitchnet_pruned)
    accuracy, loss = problem.evaluate_network(dev, neti_os, objective="both")
    reference_q.append((accuracy, loss, total_bytes, total_mult_adds,
                        cx.convert_stitcher_to_genotype(stitchinfo, stringify=False)))

In [None]:
# Result lists filled by the neighborhood evaluation, one per kind of variant.
ensembles_a, ensembles_b = [], []
start_a_end_b, start_b_end_a = [], []

# Only every `step`-th joiner is evaluated per run. Usually 1, but because of the
# large number of matches it was increased so that blocks of solutions can be
# evaluated instead.
step = 10

In [None]:
# Evaluate neighborhood of networks: for every `step`-th joiner starting at `offset`,
# enable exactly that one joiner and evaluate the resulting network. Results are
# appended to the lists from the previous cell (not reset here), presumably so runs
# with different offsets accumulate blocks of solutions.
offset = 9
stitchinfo.output_switch.active = 2
stitchinfo.output_switch.simplify = True

for j in stitchinfo.joiners:
    j[0].active = 0
    j[1].active = 0
    j[0].simplify = True
    j[1].simplify = True

# (output switch choice, j[0].active, j[1].active, list receiving the results)
neighborhood_variants = [
    (2, 0, 1, ensembles_a),
    (2, 1, 0, ensembles_b),
    (1, 0, 1, start_a_end_b),
    (0, 1, 0, start_b_end_a),
]

for output_choice, active_0, active_1, results in neighborhood_variants:
    stitchinfo.output_switch.active = output_choice
    for i in range(offset, len(stitchinfo.joiners), step):
        j = stitchinfo.joiners[i]
        j[0].active = active_0
        j[1].active = active_1

        stitchnet_pruned = stitchnet.to_graph()
        stitchnet_pruned.prune_unused()

        # Get compute & memory requirements
        total_mult_adds, total_bytes = ec.evaluate_compute_cost(stitchnet_pruned)

        neti_os = NeuralNetIndividual(stitchnet_pruned)
        accuracy, loss = problem.evaluate_network(dev, neti_os, objective="both")
        results.append((accuracy, loss, total_bytes, total_mult_adds,
                        cx.convert_stitcher_to_genotype(stitchinfo, stringify=False)))

        # Disable this joiner again before moving to the next one.
        j[0].active = 0
        j[1].active = 0

In [None]:
import polars as pl

In [None]:
df_schema = ["accuracy", "loss", "total bytes", "multiply-adds", "genotype"]


def to_samples_frame(rows, set_label, contains_stitch):
    """Build a frame of evaluation records tagged with their sample set.

    `rows` holds (accuracy, loss, total bytes, multiply-adds, genotype) tuples;
    `set_label` is a single label for all rows or a list with one label per row.
    """
    # orient="row" must be explicit: otherwise polars infers the orientation and falls
    # back to columns when the number of records equals the number of columns (5).
    frame = pl.DataFrame(rows, schema=df_schema, orient="row")
    set_values = pl.Series(set_label) if isinstance(set_label, list) else set_label
    return frame.with_columns([
        pl.lit(set_values).alias("set"),
        pl.lit(contains_stitch).alias("contains stitch"),
    ])


samples = pl.concat([
    to_samples_frame(reference_q, ["ensemble", "b", "a"], False),
    to_samples_frame(ensembles_a, "ensemble-major-a", True),
    to_samples_frame(ensembles_b, "ensemble-major-b", True),
    to_samples_frame(start_a_end_b, "stitch-a-to-b", True),
    to_samples_frame(start_b_end_a, "stitch-b-to-a", True),
]).with_columns(
    pl.col("loss").clip(0.0, 4.0).alias("loss-clip")
)
samples.write_ipc("resnet-efficientnet-stitch-samples.arrow")


## Plot the approximated Pareto front

In [None]:
samples = pl.read_ipc("resnet-efficientnet-stitch-samples.arrow")

# Reference rows (the original networks and their plain ensemble), selected by their
# "set" label rather than by row position, so they do not depend on the concat order.
dfcna = samples.filter(pl.col("set") == "a")
dfcnb = samples.filter(pl.col("set") == "b")
dfcnens = samples.filter(pl.col("set") == "ensemble")

# Whether larger (1) or smaller (-1) values are better for each criterion.
improvement_direction = {
    "accuracy": 1,
    "loss": -1,
    "loss-clip": -1,
    "total bytes": -1,
    "multiply-adds": -1,
    # "genotype": 0, # -- not a criterion
}

In [None]:
len(stitchinfo.joiners) * 4 * 15 / 60 / 60

In [None]:
# How many seconds per evaluated sample?
number_of_minutes = 16 * 8
number_of_seconds = 0
number_of_samples = len(samples) - 3
seconds_total = number_of_minutes * 60 + number_of_seconds
seconds_per_sample = seconds_total / number_of_samples

print(f"spent {number_of_minutes}m{number_of_seconds}s "
      f"to evaluate {number_of_samples} samples.\n"
      f"Resulting in a cost of {seconds_per_sample}s per sample.")

In [None]:
# Compute pareto set from these points, with respect to these
# two criteria / objectives
c0 = "accuracy"
c1 = "multiply-adds"


def pareto_front(df, x_col, y_col):
    """Return the rows of `df` that are Pareto-optimal with respect to `x_col` and `y_col`.

    Rows are ordered from best to worst on `x_col` (per ``improvement_direction``); ties
    are broken by the better `y_col` value so a dominated tied row is not reported.
    `y_col` is sign-adjusted so smaller is better ("c1-min"). A row is on the front
    when its "c1-min" is strictly below that of every row before it.
    """
    return (df.lazy()
        .with_columns((pl.col(y_col) * -improvement_direction[y_col]).alias("c1-min"))
        .sort([x_col, "c1-min"], descending=[improvement_direction[x_col] > 0, False])
        .with_columns(pl.col("c1-min").cum_min().alias("mv"))
        # The first row has no predecessor (null after shift) and is always on the front;
        # fill only this column rather than every column of the frame.
        .with_columns((pl.col("c1-min") < pl.col("mv").shift(1)).fill_null(True).alias("is pareto"))
        .filter(pl.col("is pareto"))
    ).collect()


samples_pareto = pareto_front(samples, c0, c1)
samples_pareto_stitch_only = pareto_front(samples.filter(pl.col("contains stitch")), c0, c1)

In [None]:
for sn, df in samples.filter(pl.col("contains stitch")).group_by("set", maintain_order=True):
    # polars >= 1.0 yields the group key as a 1-tuple even for a single key column.
    label = sn[0] if isinstance(sn, tuple) else sn
    plt.scatter(df[c0], df[c1], label=label, s=1.0)

# Grey squares: front over all samples; grey circles: front over stitched samples only.
plt.scatter(samples_pareto[c0], samples_pareto[c1], alpha=0.4, marker="s", color="grey",
            label="pareto front (all)")
plt.scatter(samples_pareto_stitch_only[c0], samples_pareto_stitch_only[c1], s=20.0, alpha=0.5,
            color="grey", label="pareto front (stitched)")

plt.scatter(dfcna[c0], dfcna[c1], label="a", marker='x')
plt.scatter(dfcnb[c0], dfcnb[c1], label="b", marker='x')
plt.scatter(dfcnens[c0], dfcnens[c1], label="ensemble", marker='x')


def get_direction_arrow(c):
    """Arrow marking which direction along an axis is better for criterion `c`."""
    return '->' if improvement_direction[c] > 0 else '<-'


plt.title(f"Stitched networks: {c0} vs. {c1}")
plt.xlabel(f"{c0} ({get_direction_arrow(c0)})")
plt.ylabel(f"{c1} ({get_direction_arrow(c1)})")
plt.legend(loc='upper left',
           bbox_to_anchor=(1.0, 1.0),
           fancybox=False,
           shadow=True)

**Potential points of improvement**
1. Pretrain for longer (e.g. with an explicit stopping condition)?
2. Train using the actual loss function?

In [None]:
# Density of all evaluated samples over the two criteria.
plt.hist2d(samples[c0], samples[c1])
plt.xlabel(c0)
plt.ylabel(c1)
plt.colorbar(label="number of samples")

### Some test evaluations

In [None]:
dev2 = torch.device("cuda:1")

In [None]:
# Evaluate original networks
neti_a = NeuralNetIndividual(gca)
neti_b = NeuralNetIndividual(gcb)
problem.evaluate_network(dev2, neti_a, objective="both"),\
    problem.evaluate_network(dev, neti_b, objective="both")

In [None]:
stitchneti = NeuralNetIndividual(stitchnet)
for j in stitchinfo.joiners:
    j[0].active = 0
    j[1].active = 0

stitchinfo.output_switch.active = 0
roa = problem.evaluate_network(dev, stitchneti, objective="both")
stitchinfo.output_switch.active = 1
rob = problem.evaluate_network(dev, stitchneti, objective="both")
roa, rob

In [None]:
stitchneti = NeuralNetIndividual(stitchnet)
for j in stitchinfo.joiners:
    j[0].active = 0
    j[1].active = 0
stitchinfo.output_switch.active = 2
j = stitchinfo.joiners[18]
# j[0].active = 0
# j[1].active = 1
j[0].active = 1
j[1].active = 0

problem.evaluate_network(dev, stitchneti, objective="both")

In [None]:
import gc

# Collect unreachable Python objects first, then return cached CUDA blocks
# that are no longer referenced.
n_collected = gc.collect()
torch.cuda.empty_cache()