In [1]:
from itertools import product
from display import *
from util import load_experiments

# Figure-export toggles (SVG / PNG / PDF); all disabled by default.
save_svg = save_png = save_pdf = False

  self[key] = other[key]


## Loading data

In [2]:
# Every experiment is identified by a (dynamics, network) pair and its summary is
# named ``exp-<dynamics>-<network>``. Pairs whose file is missing are reported by
# ``load_experiments`` (see the output below) rather than raising.
path = "../../data/case-study/summaries"

# Stochastic contagion: simple (sis), complex (plancksis) and interacting (sissis)
# dynamics on unweighted Erdos-Renyi (gnp) and Barabasi-Albert (ba) networks.
dynamics = ["sis", "plancksis", "sissis"]
networks = ["gnp", "ba"]
# One entry per pair; the previous ``f"exp-{d}-{networks}"`` interpolated the
# whole list and produced a single, non-existent name per dynamics.
exp_names = {(d, n): f"exp-{d}-{n}" for d, n in product(dynamics, networks)}
stocont_exp = load_experiments(path, exp_names)

# Metapopulation: DSIR dynamics on the weighted variants of the same networks.
dynamics = ["dsir"]
networks = ["w_gnp", "w_ba"]
exp_names = {(d, n): f"exp-{d}-{n}" for d, n in product(dynamics, networks)}
metapop_exp = load_experiments(path, exp_names)

Did not find file `exp-sis-ba.zip`, kept proceding.
Did not find file `exp-plancksis-ba.zip`, kept proceding.


## Plotting helpers (styles, accuracy and error-vs-degree plots)


In [3]:
# Default per-model styling, keyed by model: "gnn", "mle" and "uni"
# (presumably a uniform baseline — confirm). The plotting cell passes its own
# per-network styles to ``make_acc_plots``; these are only fallbacks.
colors = dict(gnn=color_pale["blue"], mle=color_dark["red"], uni=color_dark["grey"])
linestyles = dict(gnn="-", mle="-", uni="--")
markers = dict(gnn="o", mle="^", uni="None")


def score(x, y):
    """Return the Pearson correlation coefficient between ``x`` and ``y``.

    The p-value that ``pearsonr`` also reports is discarded.
    """
    coefficient, _p_value = pearsonr(x, y)
    return coefficient

def accuracy_plot(x, y, ax, marker, color, with_r=True):
    """Scatter predictions ``y`` against targets ``x`` on the unit square.

    Parameters
    ----------
    x, y : array-like
        Targets and predictions (flattened by the callers in this notebook).
    ax : matplotlib Axes
        Axes to draw on.
    marker, color :
        Marker and color of the scatter points.
    with_r : bool, optional
        When true, annotate the panel (lower right) with ``score(x, y)``
        rounded to 5 decimals.

    Returns
    -------
    The same ``ax``.
    """
    # Small, translucent markers without a connecting line.
    ax.plot(x, y, linestyle="None", marker=marker, markersize=2, color=color, alpha=0.3)

    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    for set_ticks in (ax.set_xticks, ax.set_yticks):
        set_ticks([0, 0.5, 1])

    if with_r:
        label_plot(ax, fr"$r = {np.round(score(x, y), 5)}$", "lower right")

    ax.set_xlabel(r"Target [$y_i(t)$]", fontsize=large_fontsize)
    ax.set_ylabel(r"Prediction [$\hat{y}_i(t)$]", fontsize=large_fontsize)
    return ax

def acc_deg_plot(true, pred, k, ax, m, ls, c, min_count=10):
    """Plot the prediction error ``1 - r`` against node degree on log-log axes.

    Nodes are grouped by degree. For each degree class with more than
    ``min_count`` nodes, ``r = score(true, pred)`` is computed on that class
    (flattened). Classes that are too small, or whose score is NaN, are left out.

    Parameters
    ----------
    true, pred : array-like
        Targets and predictions. They are indexed with the boolean mask
        ``k == degree``, so their first axis must align with ``k``.
    k : numpy.ndarray
        Degree of each node.
    ax : matplotlib Axes
        Axes to draw on. It may already hold other curves; this function does
        not set y limits, so earlier curves are never cropped.
    m, ls, c :
        Marker, line style and color of the curve.
    min_count : int, optional
        Degree classes with at most this many nodes are skipped (default 10).

    Returns
    -------
    The same ``ax``.
    """
    all_degrees = np.arange(k.min(), k.max() + 1)
    r = np.full(all_degrees.shape, np.nan)
    for i, degree in enumerate(all_degrees):
        mask = k == degree
        if np.sum(mask) > min_count:
            r[i] = score(true[mask].flatten(), pred[mask].flatten())

    # Keep only the degree classes that received a finite score.
    valid = ~np.isnan(r)
    _k, r = all_degrees[valid], r[valid]

    # "x" and "+" glyphs look smaller than filled markers, so enlarge them.
    ms = {"x": 6, "+": 8}.get(m, 4)
    ax.plot(_k, 1 - r, marker=m, linestyle=ls, color=c, markersize=ms, lw=1)
    ax.set_xlabel(r"Number of neighbors [$k$]", fontsize=large_fontsize)
    ax.set_ylabel(r"Error [$1 - r$]", fontsize=large_fontsize)
    ax.set_xscale("log")
    ax.set_yscale("log")

    if k.min() == 0:
        # Degree 0 cannot be placed on the log-scaled x axis; start the view at 1.
        ax.set_xlim([1, 100])
        ax.set_xticks([1, 10, 100])
    else:
        # Fall back to the smallest observed degree when no class was kept,
        # instead of failing on the empty ``_k``.
        lower = _k.min() if _k.size else k.min()
        ax.set_xlim([lower, 100])
        ax.set_xticks([10, 100])
    return ax

def make_acc_plots(experiment, ax_acc, ax_deg, markers, linestyles, colors):
    """Draw one experiment's GNN scatter and its MLE/GNN error-vs-degree curves.

    Parameters
    ----------
    experiment :
        Loaded experiment whose ``metrics`` contain "TrueLTPMetrics",
        "GNNLTPMetrics" and "MLELTPMetrics".
    ax_acc : matplotlib Axes
        Receives the GNN prediction-vs-target scatter.
    ax_deg : matplotlib Axes
        Receives the ``1 - r`` vs. degree curves, MLE first and then GNN.
    markers, linestyles, colors : dict
        Per-model styles keyed by "gnn" and "mle". The scatter always uses
        the "^" marker; ``markers`` only styles the degree curves.
    """
    metrics = experiment.metrics
    true_data = metrics["TrueLTPMetrics"].data
    target = true_data["ltp"]
    # NOTE(review): presumably column 0 of ``summaries`` is the node's own state
    # and the remaining columns count neighbors per state, so this row sum is the
    # node degree — confirm.
    degree = true_data["summaries"][:, 1:].sum(-1)
    predictions = {
        "gnn": metrics["GNNLTPMetrics"].data["ltp"],
        "mle": metrics["MLELTPMetrics"].data["ltp"],
    }

    accuracy_plot(target.flatten(), predictions["gnn"].flatten(), ax_acc, "^", colors["gnn"])
    for model in ("mle", "gnn"):
        acc_deg_plot(
            target, predictions[model], degree, ax_deg,
            markers[model], linestyles[model], colors[model],
        )

    for panel in (ax_acc, ax_deg):
        panel.tick_params(axis='both', which='both', labelsize=small_fontsize)


## Making the plot

In [None]:
# Styling per (unweighted) network family. The same table drives the curves and
# the legend, so the legend cannot disagree with what is plotted. Previously the
# MLE legend entries and the metapopulation degree markers were swapped.
network_styles = {
    "gnp": {
        "colors": {"gnn": color_dark["green"], "mle": color_dark["red"]},
        "markers": {"gnn": "^", "mle": "+"},
        "linestyles": {"gnn": "-", "mle": "-"},
    },
    "ba": {
        "colors": {"gnn": color_dark["blue"], "mle": color_dark["orange"]},
        "markers": {"gnn": "s", "mle": "x"},
        "linestyles": {"gnn": "-", "mle": "-"},
    },
}
network_labels = {"gnp": "ER", "ba": "BA"}

fig, ax = plt.subplots(3, 4, figsize=(4 * 4, 3 * 3))

# Columns 0-2: stochastic contagion. Row 0 holds ER (gnp) scatters, row 1 holds
# BA scatters, and row 2 holds the error-vs-degree curves of both networks.
for col, d in enumerate(["sis", "plancksis", "sissis"]):
    for row, network in enumerate(["gnp", "ba"]):
        style = network_styles[network]
        make_acc_plots(
            stocont_exp[d, network],
            ax[row, col],
            ax[2, col],
            style["markers"],
            style["linestyles"],
            style["colors"],
        )

# Column 3: metapopulation (DSIR) dynamics on weighted networks, GNN only.
# The degree curves reuse the GNN marker of the matching unweighted network so
# they agree with the legend drawn on this panel.
for row, (network, base, scatter_marker) in enumerate(
    [("w_gnp", "gnp", "v"), ("w_ba", "ba", "^")]
):
    data = metapop_exp["dsir", network].metrics["PredictionMetrics"].data
    color = network_styles[base]["colors"]["gnn"]
    accuracy_plot(data["true"].flatten(), data["pred"].flatten(), ax[row, -1], scatter_marker, color)
    acc_deg_plot(
        data["true"], data["pred"], data["degree"], ax[2, -1],
        network_styles[base]["markers"]["gnn"], "-", color,
    )
ax[2, -1].set_ylim([1e-5, 1e-1])

# make_acc_plots already styles the ticks of columns 0-2.
for panel in ax[:, -1]:
    panel.tick_params(axis='both', which='both', labelsize=small_fontsize)

for panel, title in zip(ax[0], ["Simple", "Complex", "Interacting", "Metapopulation"]):
    panel.set_title(rf"\textbf{{{title}}}", fontsize=large_fontsize)

# Panel letters (a)-(l), row by row.
for panel, letter in zip(ax.flat, "abcdefghijkl"):
    label_plot(panel, rf"\textbf{{({letter})}}", loc="upper left")

# Legend entries, ordered GNN-ER, GNN-BA, MLE-ER, MLE-BA. "+" glyphs are enlarged
# because they look smaller than the other markers.
handles = []
for model in ("gnn", "mle"):
    for network in ("gnp", "ba"):
        marker = network_styles[network]["markers"][model]
        handles.append(
            Line2D(
                [-1], [-1],
                linestyle="None",
                marker=marker,
                linewidth=3,
                markersize=10 if marker == "+" else 8,
                color=network_styles[network]["colors"][model],
                label=f"{model.upper()}-{network_labels[network]}",
            )
        )
ax[-1, -1].legend(
    handles=handles, loc="lower right", fancybox=False, fontsize=14,
    framealpha=0.0, ncol=2, handletextpad=0.1,
)

# ``pad`` is keyword-only in current matplotlib; passing it positionally fails.
fig.tight_layout(pad=0.1, w_pad=1)

# Honour every export flag from the first cell, creating the output folder if needed.
figname = "manuscript-figure1"
for enabled, ext in ((save_png, "png"), (save_pdf, "pdf"), (save_svg, "svg")):
    if enabled:
        os.makedirs(ext, exist_ok=True)
        fig.savefig(os.path.join(ext, f"{figname}.{ext}"))
