In [77]:
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import itertools
from scipy.stats import pearsonr, spearmanr
from IPython.display import display

In [10]:
# Experiment name -> ordered list of setting names. Each setting is used as a
# sub-directory name when loading results and as the row order of the tables.
EXPERIMENT_DICT = dict(
    augmentation_test=[
        "Normal",
        "Augmentation_25",
        "Augmentation_50",
        "Augmentation_75",
        "Augmentation_100",
    ],
    shortcut_test=[
        "Shortcut_0",
        "Shortcut_25",
        "Shortcut_50",
        "Shortcut_75",
        "Shortcut_100",
    ],
)

In [78]:
GRAPH_RESULTS_ROOT = "C:/Users/Tobias/Eigene Dokumente/Research/similaritybench/experiments/models/graphs"


def agg_train_results(exp, dataset, architecture, root=None):
    """Average the last logged Train/Val/Test values over all runs of each setting.

    For every setting in ``EXPERIMENT_DICT[exp]``, each file in
    ``<root>/<dataset>/<architecture>/<setting>/`` whose name starts with
    ``"train_results"`` is read as CSV. The last row of its last three columns
    is taken as that run's (Train, Val, Test) values, and these are averaged
    across all matching runs.

    Args:
        exp: Key into ``EXPERIMENT_DICT`` selecting which settings to aggregate.
        dataset: Dataset sub-directory name (e.g. ``"cora"``).
        architecture: Architecture sub-directory name (e.g. ``"GCN"``); also
            written into the ``Architecture`` column of the result.
        root: Base results directory. ``None`` (default) uses
            ``GRAPH_RESULTS_ROOT``.

    Returns:
        DataFrame with columns ``Setting, Architecture, Train, Val, Test`` and
        one row per setting, in ``EXPERIMENT_DICT[exp]`` order. A setting with
        no ``train_results*`` files gets NaN for all three metrics.

    Raises:
        KeyError: If ``exp`` is not a key of ``EXPERIMENT_DICT``.
        FileNotFoundError: If a setting directory does not exist.
    """
    if root is None:
        root = GRAPH_RESULTS_ROOT
    res_path = os.path.join(root, dataset, architecture)
    res_list = []
    for setting in EXPERIMENT_DICT[exp]:
        setting_path = os.path.join(res_path, setting)
        # os.listdir order is arbitrary; sort for a deterministic run order.
        run_files = sorted(
            fname for fname in os.listdir(setting_path) if fname.startswith("train_results")
        )
        # Initialised once per setting (not per file) so every run is summed.
        setting_results = np.zeros(3)
        for fname in run_files:
            # NOTE(review): assumes the last three CSV columns are train/val/test, in that order.
            curr_df = pd.read_csv(os.path.join(setting_path, fname)).iloc[:, -3:]
            setting_results += curr_df.iloc[-1].to_numpy(dtype=float)
        if run_files:
            setting_results /= len(run_files)
        else:
            # No runs found: report missing values instead of dividing by zero.
            setting_results[:] = np.nan
        res_list.append([setting, architecture] + list(setting_results))
    return pd.DataFrame(data=res_list, columns=["Setting", "Architecture", "Train", "Val", "Test"])

In [87]:
def get_data_table(experiment, dataset, architectures=("GCN", "GraphSAGE", "GAT")):
    """Build a wide Setting x (Architecture, metric) results table for one experiment.

    Calls ``agg_train_results`` once per architecture. It then pivots the
    long-format rows so each architecture gets a ``Train, Val, Test`` column
    group.

    Args:
        experiment: Key into ``EXPERIMENT_DICT`` (e.g. ``"shortcut_test"``); it
            selects both the settings aggregated and the row order.
        dataset: Dataset sub-directory name passed to ``agg_train_results``.
        architectures: Architectures to include. Column groups are sorted
            alphabetically by architecture name, regardless of this order.

    Returns:
        DataFrame indexed by setting (in ``EXPERIMENT_DICT[experiment]`` order)
        with two-level columns ``(Architecture, metric)``. Within each
        architecture, the metrics are ordered ``Train, Val, Test``.
    """
    metrics = ["Train", "Val", "Test"]
    # Use the requested experiment; it was previously hard-coded to "shortcut_test".
    per_arch = [agg_train_results(experiment, dataset, arch) for arch in architectures]
    long_df = pd.concat(per_arch)
    # Columns become a (metric, Architecture) MultiIndex.
    wide_df = long_df.pivot(index="Setting", columns=["Architecture"])
    return (
        wide_df.reindex(EXPERIMENT_DICT[experiment])[metrics]
        .swaplevel(axis=1)  # -> (Architecture, metric)
        .sort_index(axis=1, level=0)
        # sort_index also sorted the metric level alphabetically; restore Train, Val, Test.
        .reindex(columns=metrics, level=1)
    )

In [88]:
# Shortcut experiment on Cora: aggregated Train/Val/Test results per architecture.
get_data_table(experiment="shortcut_test", dataset="cora")

Architecture,GAT,GAT,GAT,GCN,GCN,GCN,GraphSAGE,GraphSAGE,GraphSAGE
Unnamed: 0_level_1,Train,Val,Test,Train,Val,Test,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Shortcut_0,1.0,0.73,0.749,1.0,0.704,0.722,0.914286,0.552,0.563
Shortcut_25,1.0,0.748,0.767,1.0,0.732,0.765,0.935714,0.494,0.524
Shortcut_50,1.0,0.722,0.74,1.0,0.716,0.738,1.0,0.678,0.669
Shortcut_75,1.0,0.744,0.747,1.0,0.716,0.728,1.0,0.636,0.631
Shortcut_100,1.0,0.754,0.761,1.0,0.706,0.703,1.0,0.774,0.662


In [89]:
# Shortcut experiment on Flickr: aggregated Train/Val/Test results per architecture.
get_data_table(experiment="shortcut_test", dataset="flickr")

Architecture,GAT,GAT,GAT,GCN,GCN,GCN,GraphSAGE,GraphSAGE,GraphSAGE
Unnamed: 0_level_1,Train,Val,Test,Train,Val,Test,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Shortcut_0,0.526454,0.485344,0.485367,0.431933,0.433802,0.4332,0.968493,0.466879,0.470981
Shortcut_25,0.577098,0.514208,0.515977,0.457793,0.453343,0.452113,0.961748,0.479742,0.465648
Shortcut_50,0.589221,0.468268,0.466365,0.508885,0.484269,0.469547,0.957916,0.530611,0.427105
Shortcut_75,0.553972,0.431965,0.428943,0.502521,0.498566,0.492269,0.987877,0.651757,0.339309
Shortcut_100,0.62707,0.536931,0.449917,0.509154,0.488795,0.423027,0.999574,0.969702,0.050733


In [90]:
# Shortcut experiment on ogbn-arxiv: aggregated Train/Val/Test results per architecture.
get_data_table(experiment="shortcut_test", dataset="ogbn-arxiv")

Architecture,GAT,GAT,GAT,GCN,GCN,GCN,GraphSAGE,GraphSAGE,GraphSAGE
Unnamed: 0_level_1,Train,Val,Test,Train,Val,Test,Train,Val,Test
Setting,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Shortcut_0,0.757128,0.718145,0.713289,0.790216,0.718816,0.699134,0.837829,0.723078,0.712302
Shortcut_25,0.763231,0.722373,0.710121,0.789446,0.72687,0.715038,0.846989,0.730159,0.708763
Shortcut_50,0.769246,0.716568,0.682201,0.796253,0.721568,0.686275,0.875678,0.774556,0.65749
Shortcut_75,0.784531,0.744052,0.707467,0.801509,0.733112,0.686233,0.922884,0.85325,0.614057
Shortcut_100,0.823523,0.785127,0.691501,0.822269,0.757072,0.658972,0.997009,0.981878,0.094562
