In [None]:
# Work from the repository root: the local `src.*` imports further down assume the
# root is the working directory, which is not the case when Jupyter was started
# inside the notebooks/ folder.
import os

_cwd = os.getcwd()
if _cwd.endswith("notebooks"):
    os.chdir(os.pardir)

# package imports
from datetime import datetime
import math
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch

# local imports
from src.evaluator import Evaluator
from src.args import Args
import src.graphs as graphs

# start autoreload
%load_ext autoreload
%autoreload 2

# select device for machine learning
device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
print(f"using device: {device}")

# settings for plots (seaborn/matplotlib)
sns.set_context("paper")
sns.set_style("darkgrid", {"grid.color": ".8"})
palette = "Dark2"

# RGG threshold reconstruction

This experiment visualizes what the framework learned about random geometric graphs (RGGs). In an RGG, two nodes are adjacent if their distance is at most some radius (threshold). We expect a framework that was trained on an RGG (and achieves a good score) to "understand" this rule.

To check this, we plot the framework's prediction for node pairs as a function of their distance. We expect the resulting plot to show few false negatives and few false positives: pairs closer than the radius (red line) should score above the decision threshold (green line), and pairs farther apart should score below it.

In [None]:
# experiment setup: RGG with 1000 nodes, "rjs" subgraph sampling
rgg_config = dict(
    graph_type="rgg",
    graph_size=1000,
    subgraph_alg="rjs",
    subgraph_size=100,
    subgraph_alpha=0.15,
)
args = Args(**rgg_config)

# Radius r chosen so that (n - 1) * pi * r^2 == args.rgg_avg_degree, presumably the
# expected degree of a node in the unit square when boundary effects are ignored.
# NOTE(review): should match the radius graphs.gen_graph uses for "rgg" — confirm.
other_nodes = args.graph_size - 1
radius = math.sqrt(args.rgg_avg_degree / (other_nodes * math.pi))

run_started = datetime.now().strftime('%d-%m--%H-%M')
experiment_key = f"rgg-threshold--{run_started}"

In [None]:
# generate graph and train evaluator
# The evaluator's writer logs to a run directory keyed by experiment and argument hash.
graph = graphs.gen_graph(args)
log_dir = f"runs/{experiment_key}--{args.__hash__()}"
evaluator = Evaluator(graph=graph, args=args, writer_log_dir=log_dir, device=device)

# fit the network with Adam (lr = 1e-3), showing a progress bar
adam = torch.optim.Adam(evaluator.net.parameters(), lr=1e-3)
evaluator.train(optimizer=adam, pbar=True)

# run the evaluator's test pass, tagged with the final epoch number
test_loss, test_ap, test_f1, test_threshold = evaluator.test(epoch=args.epochs)

# report the test metrics, one per line
metric_lines = "".join(
    f"\n- {name}: {value}"
    for name, value in (
        ("loss", test_loss),
        ("f1", test_f1),
        ("ap", test_ap),
        ("th", test_threshold),
    )
)
print("model performance: " + metric_lines)

In [None]:
# generate distance data
# Score every node pair in the whole dataset with the trained model.
_, preds = evaluator.score(evaluator.whole_dataset.dataloader)

# Pick the decision threshold that maximises F1 along the precision-recall curve.
precision, recall, thresholds = evaluator.pr_curve_fn(preds, evaluator.whole_dataset.ds_labels)
# Points with precision == recall == 0 give 0/0 = NaN, and argmax ranks NaN as the
# maximum, which would select a meaningless threshold; treat those points as F1 = 0.
f1_scores = torch.nan_to_num((2 * precision * recall) / (precision + recall), nan=0.0)
# Only consider F1 entries that have a matching threshold. torchmetrics-style curves
# append one extra (precision=1, recall=0) point. NOTE(review): confirm for pr_curve_fn.
best_idx = int(torch.argmax(f1_scores[: len(thresholds)]))
threshold = thresholds[best_idx].item()

positions = graph.nodes(data="pos")


def _pair_row(i, u, uf, v, vf):
    """Return one record for node pair (u, v) at dataset index ``i``.

    ``label`` is 1 when the Euclidean distance of the pair is within ``radius``;
    ``pred`` is the model score and ``pred_label`` that score thresholded at
    ``threshold``.
    NOTE(review): the label is recomputed from Euclidean distance rather than read
    from the graph; if gen_graph uses a different metric (e.g. toroidal), they differ.
    """
    dist = math.dist(positions[u], positions[v])  # computed once per pair
    pred = preds[i].item()
    return {
        "dist": dist,
        "dist_rounded": round(dist, 2),
        "label": 1 if dist <= radius else 0,
        "u": u,
        "v": v,
        "uf": uf,
        "vf": vf,
        "up": positions[u],
        "vp": positions[v],
        "i": i,
        "pred": pred,
        "pred_label": 1 if pred >= threshold else 0,
    }


df_edges = pd.DataFrame([
    _pair_row(i, u, uf, v, vf)
    for i, ((u, uf), (v, vf)) in enumerate(evaluator.whole_dataset.node_feature_pairs)
])

In [None]:
# plot distance data: mean prediction (± one standard deviation) per rounded pair distance
g = sns.relplot(
    kind="line",
    errorbar="sd",
    data=df_edges,
    x="dist_rounded",
    y="pred",
    palette=palette,
    aspect=2,
)
g.set_axis_labels("Edge Distance", "Prediction")

# draw on the grid's single axes explicitly instead of pyplot's implicit "current axes"
ax = g.ax
ax.axvline(x=radius, color="r", label="RGG radius")
ax.axhline(y=threshold, color="g", label="decision threshold")
ax.set_ylim(bottom=0)
# fixed x-range for this configuration; NOTE(review): revisit if graph_size/avg degree change
ax.set_xlim(left=-0.01, right=0.21)
ax.legend()

plt.tight_layout()
# savefig does not create missing directories
os.makedirs("out", exist_ok=True)
plt.savefig('./out/original_edge_pred.pdf')

# Graph reconstructions

This experiment plots the framework's reconstruction of a geometric inhomogeneous random graph (GIRG). We expect the reconstructed graph to be similar to the original graph.

In [None]:
# experiment setup: GIRG reconstruction
args = Args(
    graph_type="girg",
    graph_size=1000,
    subgraph_alg="rjs",
    subgraph_size=250,
    subgraph_alpha=1.0,
)
# No RGG radius here: it would be derived from args.rgg_avg_degree, which does not
# describe a GIRG, and nothing in this experiment reads it.
# The key is prefixed "girg" so this run's log directory is not mislabelled as an RGG run.
experiment_key = f"girg-reconstruction--{datetime.now().strftime('%d-%m--%H-%M')}"

In [None]:
# generate graph and train evaluator
# The evaluator's writer logs to a run directory keyed by experiment and argument hash.
graph = graphs.gen_graph(args)
log_dir = f"runs/{experiment_key}--{args.__hash__()}"
evaluator = Evaluator(graph=graph, args=args, writer_log_dir=log_dir, device=device)

# fit the network with Adam (lr = 1e-3), showing a progress bar
adam = torch.optim.Adam(evaluator.net.parameters(), lr=1e-3)
evaluator.train(optimizer=adam, pbar=True)

# run the evaluator's test pass, tagged with the final epoch number
test_loss, test_ap, test_f1, test_threshold = evaluator.test(epoch=args.epochs)

# report the test metrics, one per line
metrics = [("loss", test_loss), ("f1", test_f1), ("ap", test_ap), ("th", test_threshold)]
report = "model performance:"
for name, value in metrics:
    report += f"\n- {name}: {value}"
print(report)

In [None]:
# plot the results: the evaluator draws the reconstructed graph
# (toroid=True is presumably how the GIRG's wrap-around geometry is rendered — confirm)
evaluator.eval(toroid=True)

plt.tight_layout(pad=2)
# savefig does not create missing directories
os.makedirs("out", exist_ok=True)
plt.savefig('./out/reconstruction_girg.pdf')

# Graph reconstructions - false negatives and false positives

Like the previous experiment, this experiment plots the framework's reconstruction of a graph; note that the cell below configures an RGG (`graph_type="rgg"`), not a GIRG. Instead of showing the prediction for every edge, it highlights the false negatives (edges that exist but were not predicted) and the false positives (edges that were predicted but do not exist).

In [None]:
# experiment setup: false-negative / false-positive classification
# NOTE(review): the markdown above describes this as a GIRG experiment, but the graph
# configured here is an RGG — confirm which one is intended.
fn_fp_config = dict(
    graph_type="rgg",
    graph_size=1000,
    subgraph_alg="rjs",
    subgraph_size=100,
    subgraph_alpha=0.15,
)
args = Args(**fn_fp_config)

# radius r with (n - 1) * pi * r^2 == args.rgg_avg_degree
other_nodes = args.graph_size - 1
radius = math.sqrt(args.rgg_avg_degree / (other_nodes * math.pi))

run_started = datetime.now().strftime('%d-%m--%H-%M')
experiment_key = f"rgg-reconstruction-fn-fp--{run_started}"

In [None]:
# generate graph and train evaluator
# The evaluator's writer logs to a run directory keyed by experiment and argument hash.
graph = graphs.gen_graph(args)
log_dir = f"runs/{experiment_key}--{args.__hash__()}"
evaluator = Evaluator(graph=graph, args=args, writer_log_dir=log_dir, device=device)

# fit the network with Adam (lr = 1e-3), showing a progress bar
adam = torch.optim.Adam(evaluator.net.parameters(), lr=1e-3)
evaluator.train(optimizer=adam, pbar=True)

# run the evaluator's test pass, tagged with the final epoch number
test_loss, test_ap, test_f1, test_threshold = evaluator.test(epoch=args.epochs)

# report the test metrics, one per line
metric_lines = "".join(
    f"\n- {name}: {value}"
    for name, value in (
        ("loss", test_loss),
        ("f1", test_f1),
        ("ap", test_ap),
        ("th", test_threshold),
    )
)
print("model performance:" + metric_lines)

In [None]:
# plot the results: the evaluator's classification view of the reconstruction
fig = evaluator.classify(toroid=False)

plt.tight_layout(pad=2)
# savefig does not create missing directories
os.makedirs("out", exist_ok=True)
# NOTE(review): the file name says "girg" but this experiment's graph is configured
# with graph_type="rgg" — rename once the intended graph type is confirmed.
plt.savefig('./out/classification_all_reconstruction_girg.pdf')