In [1]:
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import itertools
from scipy.stats import pearsonr, spearmanr
from IPython.display import display

In [2]:
# Settings run for each experiment. Every entry is also the sub-directory name
# that agg_train_results reads under <results root>/<dataset>/<architecture>/.
# "Normal" appears in several experiments (presumably the unmodified baseline),
# which is why those tables share an identical first row.
EXPERIMENT_DICT = {
    "augmentation_test": [
        "Normal",
        "Augmentation_25", "Augmentation_50", "Augmentation_75", "Augmentation_100",
    ],
    "shortcut_test": [
        "Shortcut_0",
        "Shortcut_25", "Shortcut_50", "Shortcut_75", "Shortcut_100",
    ],
    "label_test": [
        "Normal",
        "RandomLabels_25", "RandomLabels_50", "RandomLabels_75", "RandomLabels_100",
    ],
    "layer_test": ["Normal", "MultiLayer"],
}

In [3]:
def agg_train_results(
    exp,
    dataset,
    architecture,
    base_dir="C:/Users/Tobias/Eigene Dokumente/Research/similaritybench/experiments/models/graphs",
    seeds=range(5),
):
    """Average the last logged train/val/test metrics over seeds for one architecture.

    For every setting in ``EXPERIMENT_DICT[exp]`` this reads
    ``<base_dir>/<dataset>/<architecture>/<setting>/train_results_s<seed>.csv``
    for each seed, takes the last row of the last three columns, and averages
    those rows across seeds.

    Parameters
    ----------
    exp : str
        Key of ``EXPERIMENT_DICT`` selecting which settings to aggregate.
    dataset : str
        Dataset sub-directory under ``base_dir`` (e.g. ``"cora"``).
    architecture : str
        Architecture sub-directory under ``dataset`` (e.g. ``"GCN"``).
    base_dir : str or os.PathLike, optional
        Root of the results tree. The default is the original machine-specific
        path; pass your own location to run this on another machine.
    seeds : iterable of int, optional
        Seeds whose ``train_results_s<seed>.csv`` files are averaged (default 0-4).

    Returns
    -------
    pandas.DataFrame
        One row per setting, in ``EXPERIMENT_DICT[exp]`` order, with columns
        ``Setting``, ``Architecture``, ``Train``, ``Val``, ``Test``.

    Raises
    ------
    ValueError
        If ``seeds`` is empty (there would be nothing to average).
    FileNotFoundError
        If an expected CSV file does not exist.
    """
    seeds = list(seeds)
    if not seeds:
        # The original divided by a zero run count here and silently produced NaN.
        raise ValueError("seeds must contain at least one seed")

    arch_dir = os.path.join(base_dir, dataset, architecture)
    rows = []
    for setting in EXPERIMENT_DICT[exp]:
        setting_dir = os.path.join(arch_dir, setting)
        # Last row of each CSV (presumably the final epoch).
        # NOTE(review): assumes the last three columns are train, val, test in
        # that order -- confirm against the script that writes these files.
        final_metrics = [
            pd.read_csv(os.path.join(setting_dir, f"train_results_s{seed}.csv"))
            .iloc[-1, -3:]
            .to_numpy(dtype=float)
            for seed in seeds
        ]
        rows.append([setting, architecture, *np.mean(final_metrics, axis=0).tolist()])
    return pd.DataFrame(rows, columns=["Setting", "Architecture", "Train", "Val", "Test"])

In [4]:
def get_data_table(experiment, dataset, architectures=("GCN", "GraphSAGE", "GAT"), **agg_kwargs):
    """Build a Setting x (Architecture, metric) table of seed-averaged results.

    Calls ``agg_train_results`` once per architecture, then pivots the stacked
    long-format result into one row per setting (ordered as in
    ``EXPERIMENT_DICT[experiment]``) and a two-level column index: architecture
    on top (sorted alphabetically), then ``Train``/``Val``/``Test`` beneath it.

    Parameters
    ----------
    experiment : str
        Key of ``EXPERIMENT_DICT``.
    dataset : str
        Dataset name forwarded to ``agg_train_results``.
    architectures : iterable of str, optional
        Architectures to include; defaults to GCN, GraphSAGE and GAT.
    **agg_kwargs
        Extra keyword arguments forwarded unchanged to ``agg_train_results``.

    Returns
    -------
    pandas.DataFrame

    Raises
    ------
    ValueError
        From ``pd.concat`` if ``architectures`` is empty.
    """
    metrics = ["Train", "Val", "Test"]
    long_df = pd.concat(
        [agg_train_results(experiment, dataset, arch, **agg_kwargs) for arch in architectures],
        ignore_index=True,
    )
    # Columns become a (metric, Architecture) MultiIndex.
    wide_df = long_df.pivot(index="Setting", columns="Architecture", values=metrics)
    return (
        wide_df.reindex(EXPERIMENT_DICT[experiment])
        # Put the architecture on the outer column level.
        .swaplevel(axis=1)
        # Group columns by architecture; this also sorts the metric level
        # alphabetically (Test, Train, Val)...
        .sort_index(axis=1, level=0)
        # ...so restore Train/Val/Test order within each architecture.
        .reindex(columns=metrics, level=1)
    )

In [5]:
# Shortcut experiment on Cora: seed-averaged Train/Val/Test per architecture.
get_data_table(experiment="shortcut_test", dataset="cora")

Architecture,GAT,GAT,GAT,GCN,GCN,GCN,GraphSAGE,GraphSAGE,GraphSAGE
Unnamed: 0_level_1,Train,Val,Test,Train,Val,Test,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Shortcut_0,1.0,0.7492,0.7586,1.0,0.704,0.7336,1.0,0.7284,0.7454
Shortcut_25,1.0,0.7564,0.7596,1.0,0.7216,0.734,1.0,0.7328,0.739
Shortcut_50,1.0,0.7816,0.761,1.0,0.7416,0.7278,1.0,0.7652,0.737
Shortcut_75,1.0,0.826,0.7606,1.0,0.8104,0.7426,1.0,0.8232,0.7042
Shortcut_100,1.0,0.8748,0.76,1.0,0.848,0.7328,1.0,0.9136,0.6526


In [6]:
# Shortcut experiment on Flickr: seed-averaged Train/Val/Test per architecture.
get_data_table(experiment="shortcut_test", dataset="flickr")

Architecture,GAT,GAT,GAT,GCN,GCN,GCN,GraphSAGE,GraphSAGE,GraphSAGE
Unnamed: 0_level_1,Train,Val,Test,Train,Val,Test,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Shortcut_0,0.605136,0.497105,0.499843,0.504081,0.480056,0.480688,0.981652,0.511536,0.512464
Shortcut_25,0.632964,0.512218,0.514552,0.514725,0.489387,0.482409,0.974866,0.489378,0.454175
Shortcut_50,0.64739,0.512003,0.498042,0.481699,0.469523,0.436669,0.990682,0.568322,0.358858
Shortcut_75,0.622776,0.516583,0.449415,0.50275,0.492533,0.434276,0.994425,0.734421,0.210864
Shortcut_100,0.755895,0.682897,0.357191,0.363827,0.359403,0.307399,0.999982,0.999104,0.001138


In [7]:
# Shortcut experiment on ogbn-arxiv: seed-averaged Train/Val/Test per architecture.
get_data_table(experiment="shortcut_test", dataset="ogbn-arxiv")

Architecture,GAT,GAT,GAT,GCN,GCN,GCN,GraphSAGE,GraphSAGE,GraphSAGE
Unnamed: 0_level_1,Train,Val,Test,Train,Val,Test,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Shortcut_0,0.775545,0.725058,0.709043,0.808234,0.720514,0.703442,0.881437,0.72183,0.703689
Shortcut_25,0.784854,0.736897,0.711182,0.820673,0.735716,0.696035,0.916328,0.754079,0.689484
Shortcut_50,0.805025,0.756515,0.701146,0.846465,0.770254,0.693574,0.959481,0.81488,0.663864
Shortcut_75,0.825584,0.788396,0.699451,0.876436,0.806859,0.676621,0.991273,0.891238,0.60537
Shortcut_100,0.877598,0.854519,0.651001,0.921142,0.86266,0.593688,1.0,1.0,0.000416


In [8]:
# Augmentation experiment on Cora: seed-averaged Train/Val/Test per architecture.
get_data_table(experiment="augmentation_test", dataset="cora")

Architecture,GAT,GAT,GAT,GCN,GCN,GCN,GraphSAGE,GraphSAGE,GraphSAGE
Unnamed: 0_level_1,Train,Val,Test,Train,Val,Test,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Normal,1.0,0.7516,0.7602,1.0,0.72,0.7422,1.0,0.7344,0.7448
Augmentation_25,1.0,0.7468,0.7614,1.0,0.6956,0.709,1.0,0.7148,0.721
Augmentation_50,1.0,0.7568,0.7692,1.0,0.6988,0.7158,1.0,0.6808,0.702
Augmentation_75,1.0,0.7528,0.764,1.0,0.72,0.74,1.0,0.684,0.6856
Augmentation_100,0.998571,0.756,0.7706,0.995714,0.7308,0.7568,1.0,0.65,0.6558


In [9]:
# Augmentation experiment on Flickr: seed-averaged Train/Val/Test per architecture.
get_data_table(experiment="augmentation_test", dataset="flickr")

Architecture,GAT,GAT,GAT,GCN,GCN,GCN,GraphSAGE,GraphSAGE,GraphSAGE
Unnamed: 0_level_1,Train,Val,Test,Train,Val,Test,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Normal,0.603621,0.504823,0.509362,0.453203,0.438679,0.437288,0.964114,0.472589,0.470336
Augmentation_25,0.592215,0.516619,0.518361,0.504982,0.498422,0.498113,0.94863,0.482861,0.482454
Augmentation_50,0.559278,0.522329,0.523775,0.523792,0.516762,0.515377,0.908343,0.478899,0.476171
Augmentation_75,0.547626,0.524399,0.525694,0.472623,0.470581,0.469995,0.844114,0.462863,0.461301
Augmentation_100,0.522483,0.511178,0.508484,0.443299,0.443609,0.443956,0.775964,0.447831,0.448779


In [10]:
# Augmentation experiment on ogbn-arxiv: seed-averaged Train/Val/Test per architecture.
get_data_table(experiment="augmentation_test", dataset="ogbn-arxiv")

Architecture,GAT,GAT,GAT,GCN,GCN,GCN,GraphSAGE,GraphSAGE,GraphSAGE
Unnamed: 0_level_1,Train,Val,Test,Train,Val,Test,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Normal,0.766088,0.722601,0.710368,0.797682,0.724346,0.705384,0.854081,0.725098,0.708076
Augmentation_25,0.743915,0.720789,0.713252,0.769884,0.726065,0.715108,0.819899,0.723098,0.707969
Augmentation_50,0.731188,0.718346,0.706989,0.754711,0.726575,0.71533,0.804462,0.724937,0.710145
Augmentation_75,0.718527,0.714829,0.705623,0.740216,0.717178,0.699977,0.786332,0.723273,0.710565
Augmentation_100,0.693441,0.692694,0.683715,0.719431,0.697768,0.673181,0.761124,0.709286,0.698274


In [11]:
# Random-label experiment on Cora: seed-averaged Train/Val/Test per architecture.
get_data_table(experiment="label_test", dataset="cora")

Architecture,GAT,GAT,GAT,GCN,GCN,GCN,GraphSAGE,GraphSAGE,GraphSAGE
Unnamed: 0_level_1,Train,Val,Test,Train,Val,Test,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Normal,1.0,0.7516,0.7602,1.0,0.72,0.7422,1.0,0.7344,0.7448
RandomLabels_25,0.997143,0.4444,0.4524,0.992857,0.3972,0.398,1.0,0.4292,0.4256
RandomLabels_50,0.994286,0.2176,0.2368,0.988571,0.2104,0.2098,1.0,0.2268,0.2326
RandomLabels_75,0.991429,0.152,0.1472,0.991429,0.1472,0.1514,1.0,0.1436,0.1462
RandomLabels_100,0.988571,0.15,0.1636,0.984286,0.1504,0.1616,1.0,0.1476,0.1552


In [12]:
# Random-label experiment on Flickr: seed-averaged Train/Val/Test per architecture.
get_data_table(experiment="label_test", dataset="flickr")

Architecture,GAT,GAT,GAT,GCN,GCN,GCN,GraphSAGE,GraphSAGE,GraphSAGE
Unnamed: 0_level_1,Train,Val,Test,Train,Val,Test,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Normal,0.603621,0.504823,0.509362,0.453203,0.438679,0.437288,0.964114,0.472589,0.470336
RandomLabels_25,0.47101,0.364109,0.366083,0.366207,0.34438,0.344588,0.957185,0.332951,0.331511
RandomLabels_50,0.395514,0.25943,0.258746,0.290075,0.265194,0.262869,0.950216,0.203917,0.202268
RandomLabels_75,0.323671,0.158776,0.157684,0.216887,0.167255,0.167857,0.97522,0.148162,0.146874
RandomLabels_100,0.33107,0.152456,0.153104,0.208757,0.151237,0.155022,0.982889,0.15147,0.151992


In [13]:
# Random-label experiment on ogbn-arxiv: seed-averaged Train/Val/Test per architecture.
get_data_table(experiment="label_test", dataset="ogbn-arxiv")

Architecture,GAT,GAT,GAT,GCN,GCN,GCN,GraphSAGE,GraphSAGE,GraphSAGE
Unnamed: 0_level_1,Train,Val,Test,Train,Val,Test,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Normal,0.766088,0.722601,0.710368,0.797682,0.724346,0.705384,0.854081,0.725098,0.708076
RandomLabels_25,0.570711,0.540428,0.530358,0.588729,0.532018,0.520799,0.628674,0.535427,0.520593
RandomLabels_50,0.383352,0.357884,0.349941,0.398597,0.349757,0.343329,0.43718,0.350683,0.34546
RandomLabels_75,0.201533,0.177382,0.173652,0.221972,0.165878,0.163759,0.28208,0.165086,0.162496
RandomLabels_100,0.086482,0.025806,0.025821,0.137538,0.025323,0.025287,0.270413,0.026209,0.025435


In [14]:
# Layer experiment on Cora: seed-averaged Train/Val/Test per architecture.
# NOTE(review): the GAT MultiLayer row shows Train ~0.14 / Test ~0.08 while GCN and
# GraphSAGE stay near their Normal scores -- looks like collapsed GAT runs; verify.
get_data_table(experiment="layer_test", dataset="cora")

Architecture,GAT,GAT,GAT,GCN,GCN,GCN,GraphSAGE,GraphSAGE,GraphSAGE
Unnamed: 0_level_1,Train,Val,Test,Train,Val,Test,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Normal,1.0,0.7516,0.7602,1.0,0.72,0.7422,1.0,0.7344,0.7448
MultiLayer,0.142857,0.0788,0.081,0.994286,0.6548,0.683,0.99,0.6212,0.6356


In [15]:
# Layer experiment on Flickr: seed-averaged Train/Val/Test per architecture.
get_data_table(experiment="layer_test", dataset="flickr")

Architecture,GAT,GAT,GAT,GCN,GCN,GCN,GraphSAGE,GraphSAGE,GraphSAGE
Unnamed: 0_level_1,Train,Val,Test,Train,Val,Test,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Normal,0.603621,0.504823,0.509362,0.453203,0.438679,0.437288,0.964114,0.472589,0.470336
MultiLayer,0.345089,0.346791,0.346605,0.340101,0.337549,0.336199,0.461042,0.412917,0.412719


In [16]:
# Layer experiment on ogbn-arxiv: seed-averaged Train/Val/Test per architecture.
# NOTE(review): the GAT MultiLayer row shows Train ~0.04 / Test ~0.02 while GCN and
# GraphSAGE stay near their Normal scores -- looks like collapsed GAT runs; verify.
get_data_table(experiment="layer_test", dataset="ogbn-arxiv")

Architecture,GAT,GAT,GAT,GCN,GCN,GCN,GraphSAGE,GraphSAGE,GraphSAGE
Unnamed: 0_level_1,Train,Val,Test,Train,Val,Test,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Normal,0.766088,0.722601,0.710368,0.797682,0.724346,0.705384,0.854081,0.725098,0.708076
MultiLayer,0.044095,0.021799,0.019986,0.750236,0.706809,0.700545,0.762536,0.711453,0.700434
