In [1]:
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import itertools
from scipy.stats import pearsonr, spearmanr
from IPython.display import display

In [2]:
EXPERIMENT_DICT = {"augmentation_test":     ["Normal", "Augmentation_25", "Augmentation_50", "Augmentation_75", "Augmentation_100"],
                   "shortcut_test": ["Shortcut_0", "Shortcut_25", "Shortcut_50", "Shortcut_75", "Shortcut_100"],
                   "label_test":     ["Normal", "RandomLabels_25", "RandomLabels_50", "RandomLabels_75", "RandomLabels_100"],
}

In [5]:
def agg_train_results(exp, dataset, architecture, code=None):
    res_path = os.path.join("C:/Users/Tobias/Eigene Dokumente/Research/similaritybench/experiments/models/graphs", dataset, architecture)
    res_list = []
    res_csvs = [f"train_results_s{s}.csv" for s in range(5)]
    for setting in EXPERIMENT_DICT[exp]:
        setting_path = os.path.join(res_path, setting)
        if code is not None:
            setting_path = os.path.join(setting_path, code)
        for fname in res_csvs:
            setting_results = np.zeros(3)
            n_runs = 0
            curr_df = pd.read_csv(os.path.join(setting_path,fname)).iloc[:,-3:]
            curr_results = np.array(curr_df.iloc[-1])
            setting_results += curr_results
            n_runs+=1
        setting_results /= n_runs
        res_list.append([setting, architecture] + list(setting_results))
    df_res = pd.DataFrame(data=res_list, columns = ["Setting", "Architecture", "Train", "Val", "Test"])
    return df_res

In [10]:
def get_data_table(experiment, dataset, architectures = None, code=None):
    data_list = []
    if architectures is None:
        architectures = ["GCN", "GraphSAGE", "GAT"]
    for arch in architectures:
        data_list.append(agg_train_results(exp=experiment, dataset=dataset, architecture=arch, code=code))
    res_df = pd.concat(data_list)
    res_df = res_df.pivot(index="Setting", columns=["Architecture"])
    return res_df.reindex(EXPERIMENT_DICT[experiment])[["Train", "Val", "Test"]].swaplevel(axis=1).sort_index(axis=1, level=0).reindex(columns=["Train", "Val", "Test"], level=1)

In [11]:
get_data_table("shortcut_test", "cora", ["GCN"], code="10feats")

Architecture,GCN,GCN,GCN
Unnamed: 0_level_1,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2
Shortcut_0,0.8,0.564,0.589
Shortcut_25,0.621429,0.268,0.315
Shortcut_50,0.721429,0.54,0.565
Shortcut_75,0.978571,0.634,0.636
Shortcut_100,0.371429,0.18,0.18


In [12]:
get_data_table("shortcut_test", "cora", ["GCN"], code="5feats")

Architecture,GCN,GCN,GCN
Unnamed: 0_level_1,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2
Shortcut_0,0.742857,0.34,0.347
Shortcut_25,0.821429,0.366,0.386
Shortcut_50,1.0,0.744,0.751
Shortcut_75,0.814286,0.472,0.435
Shortcut_100,0.985714,0.592,0.58


In [14]:
get_data_table("shortcut_test", "cora", ["GCN", "GraphSAGE"], code="4feats")

Architecture,GCN,GCN,GCN,GraphSAGE,GraphSAGE,GraphSAGE
Unnamed: 0_level_1,Train,Val,Test,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
Shortcut_0,0.992857,0.686,0.702,1.0,0.608,0.636
Shortcut_25,0.892857,0.518,0.553,0.985714,0.54,0.58
Shortcut_50,0.878571,0.652,0.672,1.0,0.694,0.691
Shortcut_75,0.964286,0.602,0.601,0.95,0.59,0.591
Shortcut_100,1.0,0.694,0.671,0.957143,0.702,0.595


In [18]:
get_data_table("shortcut_test", "cora", code="1-hot")

Architecture,GAT,GAT,GAT,GCN,GCN,GCN,GraphSAGE,GraphSAGE,GraphSAGE
Unnamed: 0_level_1,Train,Val,Test,Train,Val,Test,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Shortcut_0,1.0,0.742,0.771,1.0,0.696,0.718,1.0,0.746,0.749
Shortcut_25,1.0,0.748,0.767,1.0,0.716,0.72,1.0,0.746,0.731
Shortcut_50,1.0,0.772,0.778,1.0,0.728,0.735,1.0,0.758,0.742
Shortcut_75,1.0,0.822,0.783,1.0,0.798,0.748,1.0,0.828,0.737
Shortcut_100,1.0,0.884,0.779,1.0,0.86,0.753,1.0,0.902,0.651


In [7]:
get_data_table("shortcut_test", "cora")

Architecture,GAT,GAT,GAT,GCN,GCN,GCN,GraphSAGE,GraphSAGE,GraphSAGE
Unnamed: 0_level_1,Train,Val,Test,Train,Val,Test,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Shortcut_0,1.0,0.728,0.738,1.0,0.696,0.727,1.0,0.726,0.729
Shortcut_25,1.0,0.726,0.735,1.0,0.714,0.726,1.0,0.594,0.588
Shortcut_50,1.0,0.698,0.714,1.0,0.738,0.737,1.0,0.69,0.677
Shortcut_75,1.0,0.738,0.742,1.0,0.738,0.726,1.0,0.678,0.632
Shortcut_100,1.0,0.692,0.708,1.0,0.758,0.769,1.0,0.642,0.564


In [6]:
get_data_table("shortcut_test", "flickr")

Architecture,GAT,GAT,GAT,GCN,GCN,GCN,GraphSAGE,GraphSAGE,GraphSAGE
Unnamed: 0_level_1,Train,Val,Test,Train,Val,Test,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Shortcut_0,0.580706,0.496728,0.498633,0.43079,0.431472,0.429794,0.976067,0.483776,0.490611
Shortcut_25,0.56549,0.503451,0.505535,0.498599,0.48893,0.488953,0.976381,0.461321,0.449917
Shortcut_50,0.620101,0.493188,0.490297,0.472695,0.469075,0.466589,0.971765,0.529446,0.436562
Shortcut_75,0.620011,0.481221,0.475866,0.427025,0.402384,0.339936,0.992112,0.651578,0.318604
Shortcut_100,0.520067,0.445635,0.372832,0.518678,0.492336,0.398781,0.999866,0.98019,0.033882


In [7]:
get_data_table("shortcut_test", "ogbn-arxiv")

Architecture,GAT,GAT,GAT,GCN,GCN,GCN,GraphSAGE,GraphSAGE,GraphSAGE
Unnamed: 0_level_1,Train,Val,Test,Train,Val,Test,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Shortcut_0,0.7577,0.720662,0.705245,0.788357,0.723581,0.706212,0.838951,0.71499,0.692735
Shortcut_25,0.760834,0.720125,0.699751,0.789622,0.720192,0.699484,0.843459,0.72942,0.707714
Shortcut_50,0.770653,0.72781,0.711067,0.795681,0.731132,0.699566,0.876755,0.775529,0.670946
Shortcut_75,0.784025,0.746166,0.711787,0.807689,0.743515,0.699237,0.922884,0.845934,0.612761
Shortcut_100,0.820774,0.782845,0.691048,0.823391,0.763851,0.68111,0.996778,0.981677,0.087793


In [8]:
get_data_table("augmentation_test", "cora")

Architecture,GAT,GAT,GAT,GCN,GCN,GCN,GraphSAGE,GraphSAGE,GraphSAGE
Unnamed: 0_level_1,Train,Val,Test,Train,Val,Test,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Normal,1.0,0.756,0.768,1.0,0.7,0.729,1.0,0.72,0.728
Augmentation_25,1.0,0.746,0.751,1.0,0.696,0.712,1.0,0.696,0.709
Augmentation_50,1.0,0.764,0.769,1.0,0.69,0.671,1.0,0.674,0.698
Augmentation_75,1.0,0.744,0.764,1.0,0.736,0.748,1.0,0.668,0.681
Augmentation_100,0.992857,0.748,0.764,0.992857,0.722,0.754,1.0,0.658,0.643


In [9]:
get_data_table("augmentation_test", "flickr")

Architecture,GAT,GAT,GAT,GCN,GCN,GCN,GraphSAGE,GraphSAGE,GraphSAGE
Unnamed: 0_level_1,Train,Val,Test,Train,Val,Test,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Normal,0.589961,0.498745,0.504952,0.468751,0.455943,0.453816,0.965199,0.477053,0.476986
Augmentation_25,0.574185,0.514701,0.517546,0.525714,0.514835,0.51535,0.960717,0.483014,0.48474
Augmentation_50,0.59144,0.541368,0.543674,0.517221,0.511787,0.508986,0.930017,0.475887,0.470802
Augmentation_75,0.552605,0.525323,0.529826,0.511978,0.507664,0.50428,0.870387,0.456795,0.45646
Augmentation_100,0.525087,0.513759,0.509882,0.444773,0.443887,0.445032,0.805826,0.450789,0.450903


In [10]:
get_data_table("augmentation_test", "ogbn-arxiv")

Architecture,GAT,GAT,GAT,GCN,GCN,GCN,GraphSAGE,GraphSAGE,GraphSAGE
Unnamed: 0_level_1,Train,Val,Test,Train,Val,Test,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Normal,0.763473,0.717977,0.701808,0.787741,0.714051,0.69755,0.852344,0.723279,0.715388
Augmentation_25,0.744494,0.721199,0.712775,0.773271,0.731233,0.71831,0.815254,0.720729,0.707343
Augmentation_50,0.730804,0.713782,0.698105,0.754137,0.723783,0.716046,0.804038,0.728884,0.721684
Augmentation_75,0.718015,0.712977,0.701479,0.739886,0.714152,0.69255,0.786444,0.72338,0.71368
Augmentation_100,0.693758,0.691768,0.680472,0.716344,0.69546,0.671481,0.758635,0.709588,0.701644


In [11]:
get_data_table("label_test", "cora")

Architecture,GAT,GAT,GAT,GCN,GCN,GCN,GraphSAGE,GraphSAGE,GraphSAGE
Unnamed: 0_level_1,Train,Val,Test,Train,Val,Test,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Normal,1.0,0.756,0.768,1.0,0.7,0.729,1.0,0.72,0.728
RandomLabels_25,0.050711,1.0,0.432,0.017239,1.0,0.36,0.004216,1.0,0.434
RandomLabels_50,0.089152,1.0,0.218,0.010516,1.0,0.21,0.00639,1.0,0.224
RandomLabels_75,0.054461,0.992857,0.166,0.017131,0.992857,0.158,0.004783,1.0,0.138
RandomLabels_100,0.076068,0.985714,0.148,0.05556,0.985714,0.146,0.004632,1.0,0.136


In [12]:
get_data_table("label_test", "flickr")

Architecture,GAT,GAT,GAT,GCN,GCN,GCN,GraphSAGE,GraphSAGE,GraphSAGE
Unnamed: 0_level_1,Train,Val,Test,Train,Val,Test,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Normal,0.589961,0.498745,0.504952,0.468751,0.455943,0.453816,0.965199,0.477053,0.476986
RandomLabels_25,1.407595,0.498286,0.372221,1.600528,0.414274,0.385398,0.602903,0.962286,0.31566
RandomLabels_50,1.645359,0.378286,0.256947,1.792863,0.306487,0.263625,0.688117,0.944784,0.204598
RandomLabels_75,1.770817,0.34577,0.173001,1.890668,0.220661,0.177483,0.651707,0.978846,0.150278
RandomLabels_100,1.742481,0.326947,0.146961,1.865442,0.197602,0.153012,0.62279,0.982723,0.152877


In [13]:
get_data_table("label_test", "ogbn-arxiv")

Architecture,GAT,GAT,GAT,GCN,GCN,GCN,GraphSAGE,GraphSAGE,GraphSAGE
Unnamed: 0_level_1,Train,Val,Test,Train,Val,Test,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Normal,0.763473,0.717977,0.701808,0.787741,0.714051,0.69755,0.852344,0.723279,0.715388
RandomLabels_25,2.015594,0.571161,0.539045,1.882507,0.586941,0.523642,1.72386,0.629177,0.534984
RandomLabels_50,2.798009,0.382622,0.35914,2.669756,0.3953,0.348468,2.512188,0.436668,0.349072
RandomLabels_75,3.338989,0.201295,0.177959,3.214353,0.220824,0.168395,3.071998,0.284404,0.168026
RandomLabels_100,3.570866,0.08467,0.02641,3.417229,0.137056,0.025202,3.275284,0.279093,0.025672
