In [1]:
import pandas as pd
import itertools
import os
import pandas.io.formats.style

In [2]:
# Show floats with thousands separators and three decimals in displayed DataFrames.
pd.set_option("display.float_format", "{:,.3f}".format)

### Globals, Helper Functions, etc.

In [3]:
# Similarity-measure class name -> short label used as the row name in all tables.
MEASURE_NAME_DICT = dict(
    AlignedCosineSimilarity="AlignCos",
    CKA="CKA",
    ConcentricityDifference="ConcDiff",
    DistanceCorrelation="DistCorr",
    EigenspaceOverlapScore="EOS",
    GeometryScore="GS",
    Gulp="GULP",
    HardCorrelationMatch="HardCorr",
    IMDScore="IMD",
    JaccardSimilarity="Jaccard",
    LinearRegression="LinReg",
    MagnitudeDifference="MagDiff",
    OrthogonalAngularShapeMetricCentered="ShapeMet",
    OrthogonalProcrustesCenteredAndNormalized="OrthProc",
    PWCCA="PWCCA",
    PermutationProcrustes="PermProc",
    ProcrustesSizeAndShapeDistance="ProcDist",
    RSA="RSA",
    RSMNormDifference="RSMDiff",
    RankSimilarity="RankSim",
    SVCCA="SVCCA",
    SecondOrderCosineSimilarity="2nd-Cos",
    SoftCorrelationMatch="SoftCorr",
    UniformityDifference="UnifDiff",
)

# Raw column-header labels (datasets, architectures, quality measures) -> display labels.
# Kept as a literal because "ogbn-arxiv" is not a valid keyword-argument name.
PIVOT_COL_DICT = {
    "cora": "Cora",
    "flickr": "Flickr",
    "ogbn-arxiv": "OGBN-Arxiv",
    "GraphSAGE": "SAGE",
    "violation_rate": "Violation Rate",
    "correlation": "Spearman Correlation",
}

# Result-CSV column names -> names used as pivot-table index/column levels.
# Both quality-measure columns collapse onto the same "Measure" level.
COLUMN_NAME_DICT = dict(
    similarity_measure="Similarity Measure",
    quality_measure="Measure",
    functional_similarity_measure="Measure",
    architecture="Model",
    representation_dataset="Dataset",
)

# Keyword arguments for Styler.to_latex. The column spec is one left-aligned index
# column followed by two "||"-separated groups of three "|"-separated "rrr" blocks,
# i.e. it is laid out for exactly 18 data columns.
LATEX_FORMAT_DICT = dict(
    hrules=True,
    column_format="l||rrr|rrr|rrr||rrr|rrr|rrr",
    multicol_align="c",
)

In [26]:
# Directory holding the aggregated result CSVs, resolved as
# <two levels above the current working directory>/experiments/results/res_csvs.
# Set the EXPERIMENT_RESULTS_PATH environment variable to read results from
# somewhere else instead of hardcoding a machine-specific absolute path here.
EXPERIMENT_RESULTS_PATH = os.environ.get(
    "EXPERIMENT_RESULTS_PATH",
    os.path.join(os.path.dirname(os.path.dirname(os.getcwd())), "experiments", "results", "res_csvs"),
)

# Experiment name -> comparison type; both parts appear in the full-result CSV file names.
COMPARISON_TYPE_DICT = {
    "label_test": "group_separation",
    "layer_test": "monotonicity",
    "augmentation_test": "group_separation",
    "shortcut_test": "group_separation",
}

EXPERIMENTS = list(COMPARISON_TYPE_DICT.keys())
DATASETS = ["cora", "flickr", "ogbn-arxiv"]
ARCHITECTURES = ["GCN", "GraphSAGE", "GAT"]

def FULL_DF_FILE_NAME(experiment, comparison_type, dataset, groups=5):
    """Return the file name of a full result CSV.

    The group count is only encoded in the name (as "_<groups>groups") when it is
    below the default of 5, e.g. "label_test_group_separation_cora_3groups_full.csv".
    """
    groups_tag = f"_{groups}groups" if groups < 5 else ""
    return f"{experiment}_{comparison_type}_{dataset}{groups_tag}_full.csv"

In [5]:
def get_pivot_table(experiment, dataset):
    """Load one experiment's full result CSV for a single dataset and pivot it.

    Rows are the raw similarity-measure names; columns are a
    (representation_dataset, architecture, quality_measure) MultiIndex holding ``value``.
    No renaming or NaN filtering is applied.
    """
    csv_name = FULL_DF_FILE_NAME(experiment, COMPARISON_TYPE_DICT[experiment], dataset)
    results = pd.read_csv(os.path.join(EXPERIMENT_RESULTS_PATH, csv_name))
    wanted = ["similarity_measure", "quality_measure", "value", "architecture", "representation_dataset"]
    return results[wanted].pivot(
        index="similarity_measure",
        columns=["representation_dataset", "architecture", "quality_measure"],
        values="value",
    )

In [18]:
# NOTE: superseded by get_agg_pivot_table in the next cell; kept for reference only.
# def get_agg_pivot_table(experiment, datasets, groups=5):
#     dfs = []
#     for dataset in datasets:
#         path = os.path.join(EXPERIMENT_RESULTS_PATH, FULL_DF_FILE_NAME(experiment, COMPARISON_TYPE_DICT[experiment], dataset, groups))
#         df = pd.read_csv(path)
#         data = df.loc[:, ["similarity_measure", "quality_measure", "value", "architecture", "representation_dataset"]].dropna()
   
#         data = data.rename(COLUMN_NAME_DICT, axis="columns")
#         dfs.append(data.iloc[:])
#     df_cc = pd.concat(dfs, axis=0)
#     df_res = df_cc.pivot_table(index="Similarity Measure", columns=["Measure", "Dataset", "Model"], values="value", aggfunc="mean")
#     df_res = df_res.rename(MEASURE_NAME_DICT, axis="index")
#     return df_res.rename(PIVOT_COL_DICT, axis="columns")

In [58]:
def get_agg_pivot_table(experiment, datasets, conformity=True):
    """Aggregate the full result CSVs of one experiment into a wide pivot table.

    Every CSV in EXPERIMENT_RESULTS_PATH whose file name contains both
    ``experiment`` and one of ``datasets`` is read. Rows with missing values are
    dropped, and the mean ``value`` is taken per (similarity measure, quality
    measure, dataset, model).

    Parameters:
      - experiment [str]: experiment key that appears in the CSV file names,
                          e.g. "label_test" or "layer_test".
      - datasets [list[str]]: dataset names to include, e.g. ["cora", "flickr"].
      - conformity [bool]: if True and a "violation_rate" quality measure is
                           present, replace it by 1 - violation_rate and label
                           it "Conformity Rate".
    Returns:
      DataFrame indexed by short similarity-measure name (MEASURE_NAME_DICT),
      with a (Measure, Dataset, Model) column MultiIndex.
    Raises:
      ValueError: if no matching CSV file is found.
    """
    columns = ["similarity_measure", "quality_measure", "value", "architecture", "representation_dataset"]
    # Only CSVs are considered so other files in the directory cannot break
    # read_csv; sorting makes the read order deterministic.
    csv_files = sorted(f for f in os.listdir(EXPERIMENT_RESULTS_PATH) if f.endswith(".csv"))

    dfs = []
    for dataset in datasets:
        for f in csv_files:
            # NOTE(review): plain substring match -- every file whose name contains both
            # strings is averaged in (e.g. variants with a "<N>groups" suffix).
            if experiment in f and dataset in f:
                df = pd.read_csv(os.path.join(EXPERIMENT_RESULTS_PATH, f), encoding="utf-8")
                data = df.loc[:, columns].dropna().rename(COLUMN_NAME_DICT, axis="columns")
                dfs.append(data)
    if not dfs:
        raise ValueError(
            f"No result CSV matching experiment {experiment!r} and datasets {datasets!r} "
            f"found in {EXPERIMENT_RESULTS_PATH!r}"
        )

    df_cc = pd.concat(dfs, axis=0)
    df_res = df_cc.pivot_table(index="Similarity Measure", columns=["Measure", "Dataset", "Model"], values="value", aggfunc="mean")
    # Only flip violation rate into conformity rate when that measure was actually recorded.
    if conformity and "violation_rate" in df_res.columns.get_level_values("Measure"):
        df_res["violation_rate"] = 1 - df_res["violation_rate"]
        df_res = df_res.rename({"violation_rate": "Conformity Rate"}, axis=1)
    df_res = df_res.rename(MEASURE_NAME_DICT, axis="index")
    return df_res.rename(PIVOT_COL_DICT, axis="columns")

### Results for Group Separation and Layer Monotonicity Experiments

In [59]:
layer_cora_df = get_agg_pivot_table("layer_test", ["cora"])
styled = pd.io.formats.style.Styler(
    layer_cora_df,
    precision=2,
)

# LATEX_FORMAT_DICT's column spec is laid out for 18 data columns (three dataset
# blocks per measure). This Cora-only table has a single dataset block per measure,
# so derive the spec from the actual columns: the index column, then "||" before
# each measure block.
measure_level = layer_cora_df.columns.get_level_values("Measure")
cora_layer_latex_format = {
    **LATEX_FORMAT_DICT,
    "column_format": "l||" + "||".join("r" * len(list(block)) for _, block in itertools.groupby(measure_level)),
}

# Passing a file path makes Styler.to_latex write the file and return None.
latex_str = styled.highlight_max(axis=0, props="textbf:--rwrap;").to_latex("cora_results_layer_test.tex", **cora_layer_latex_format)
layer_cora_df

Measure,Spearman Correlation,Spearman Correlation,Spearman Correlation,Spearman Correlation,Conformity Rate,Conformity Rate,Conformity Rate,Conformity Rate
Dataset,Cora,Cora,Cora,Cora,Cora,Cora,Cora,Cora
Model,GAT,GCN,SAGE,PGNN,GAT,GCN,SAGE,PGNN
Similarity Measure,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3
AlignCos,1.0,0.98,1.0,0.313,1.0,0.97,1.0,0.88
CKA,1.0,0.98,1.0,1.0,1.0,0.98,1.0,1.0
ConcDiff,0.403,0.85,0.25,0.117,0.62,0.85,0.52,0.51
DistCorr,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
EOS,1.0,1.0,1.0,0.987,1.0,1.0,1.0,0.98
GULP,1.0,1.0,1.0,0.72,1.0,1.0,1.0,0.85
HardCorr,0.54,0.833,0.983,0.777,0.76,0.91,0.99,0.82
IMD,1.0,1.0,1.0,0.887,1.0,1.0,1.0,0.95
Jaccard,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
LinReg,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


#### Random Label Test Results

In [37]:
# Quick look at the aggregated random-label results over all DATASETS.
get_agg_pivot_table("label_test", DATASETS)

Measure,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,Conformity Rate,Conformity Rate,Conformity Rate,Conformity Rate,Conformity Rate,Conformity Rate,Conformity Rate,Conformity Rate,Conformity Rate
Dataset,Cora,Cora,Cora,Flickr,Flickr,Flickr,OGBN-Arxiv,OGBN-Arxiv,OGBN-Arxiv,Cora,Cora,Cora,Flickr,Flickr,Flickr,OGBN-Arxiv,OGBN-Arxiv,OGBN-Arxiv
Model,GAT,GCN,SAGE,GAT,GCN,SAGE,GAT,GCN,SAGE,GAT,GCN,SAGE,GAT,GCN,SAGE,GAT,GCN,SAGE
Similarity Measure,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3,Unnamed: 17_level_3,Unnamed: 18_level_3
AlignCos,0.433,0.457,0.479,0.286,0.839,0.423,0.931,0.981,0.699,0.601,0.659,0.664,0.579,0.936,0.5,0.973,0.996,0.668
CKA,0.424,0.427,0.424,0.27,0.733,0.657,1.0,1.0,1.0,0.507,0.549,0.512,0.542,0.907,0.861,1.0,1.0,1.0
ConcDiff,0.394,0.203,0.391,0.208,0.366,0.57,0.997,0.964,1.0,0.739,0.545,0.781,0.481,0.723,0.848,0.999,0.992,1.0
DistCorr,0.424,0.431,0.444,0.217,0.86,0.427,1.0,1.0,1.0,0.521,0.59,0.631,0.518,0.948,0.558,1.0,1.0,1.0
EOS,0.284,0.295,0.259,0.219,0.413,0.424,0.34,0.253,0.426,0.658,0.649,0.551,0.581,0.5,0.516,0.506,0.498,0.542
GULP,0.285,0.276,0.26,0.227,0.192,0.424,0.303,0.302,0.426,0.659,0.631,0.556,0.571,0.503,0.512,0.565,0.491,0.551
HardCorr,0.424,0.425,0.424,0.332,0.771,0.457,0.544,0.828,0.828,0.508,0.527,0.511,0.667,0.941,0.679,0.766,0.967,0.967
IMD,0.815,0.737,0.975,0.227,0.234,0.335,0.882,1.0,1.0,0.957,0.929,0.993,0.537,0.467,0.693,0.973,1.0,1.0
Jaccard,0.427,0.425,0.424,0.292,0.564,0.43,0.778,0.827,0.432,0.553,0.527,0.508,0.568,0.771,0.579,0.929,0.967,0.533
LinReg,0.456,0.492,0.431,0.23,0.223,0.449,0.467,0.444,0.677,0.688,0.713,0.579,0.516,0.489,0.663,0.626,0.627,0.809


In [54]:
label_test_df = get_agg_pivot_table("label_test", DATASETS)
# Two-decimal LaTeX export with each column's maximum set in bold.
styled = pd.io.formats.style.Styler(label_test_df, precision=2)
styled.highlight_max(axis=0, props="textbf:--rwrap;")
latex_str = styled.to_latex("graphs_results_label_test.tex", **LATEX_FORMAT_DICT)

#### Shortcut Test Results

In [56]:
# Quick look at the aggregated shortcut results over all DATASETS.
get_agg_pivot_table("shortcut_test", DATASETS)

Measure,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,Conformity Rate,Conformity Rate,Conformity Rate,Conformity Rate,Conformity Rate,Conformity Rate,Conformity Rate,Conformity Rate,Conformity Rate
Dataset,Cora,Cora,Cora,Flickr,Flickr,Flickr,OGBN-Arxiv,OGBN-Arxiv,OGBN-Arxiv,Cora,Cora,Cora,Flickr,Flickr,Flickr,OGBN-Arxiv,OGBN-Arxiv,OGBN-Arxiv
Model,GAT,GCN,SAGE,GAT,GCN,SAGE,GAT,GCN,SAGE,GAT,GCN,SAGE,GAT,GCN,SAGE,GAT,GCN,SAGE
Similarity Measure,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3,Unnamed: 17_level_3,Unnamed: 18_level_3
AlignCos,0.303,0.319,0.599,0.48,0.895,1.0,1.0,1.0,1.0,0.589,0.695,0.881,0.769,0.981,1.0,1.0,1.0,1.0
CKA,0.463,0.273,0.777,0.33,0.282,1.0,1.0,1.0,0.981,0.636,0.661,0.903,0.652,0.572,1.0,1.0,1.0,0.996
ConcDiff,0.204,0.182,0.169,0.324,0.182,0.176,1.0,0.815,0.962,0.498,0.468,0.427,0.607,0.497,0.463,1.0,0.963,0.991
DistCorr,0.459,0.284,0.82,0.317,0.333,1.0,1.0,1.0,0.994,0.622,0.688,0.916,0.718,0.687,1.0,1.0,1.0,0.999
EOS,0.163,0.217,0.495,0.544,0.461,0.433,0.968,1.0,0.724,0.425,0.574,0.74,0.809,0.872,0.6,0.993,1.0,0.833
GULP,0.172,0.197,0.496,0.451,0.196,0.433,0.965,0.986,0.724,0.439,0.525,0.739,0.811,0.492,0.6,0.991,0.997,0.833
HardCorr,0.219,0.2,0.347,0.519,0.549,1.0,0.835,0.798,0.724,0.581,0.559,0.768,0.808,0.829,1.0,0.969,0.961,0.833
IMD,0.601,0.732,0.793,0.358,0.759,0.967,0.923,0.609,0.92,0.88,0.928,0.957,0.622,0.943,0.993,0.977,0.837,0.979
Jaccard,0.38,0.305,0.782,0.545,1.0,0.828,1.0,1.0,1.0,0.714,0.697,0.929,0.84,1.0,0.963,1.0,1.0,1.0
LinReg,0.334,0.299,0.742,0.357,0.365,0.609,1.0,1.0,1.0,0.624,0.617,0.947,0.723,0.703,0.811,1.0,1.0,1.0


In [57]:
shortcut_test_df = get_agg_pivot_table("shortcut_test", DATASETS)
# Two-decimal LaTeX export with each column's maximum set in bold.
styled = pd.io.formats.style.Styler(shortcut_test_df, precision=2)
styled.highlight_max(axis=0, props="textbf:--rwrap;")
latex_str = styled.to_latex("graphs_results_shortcut_test.tex", **LATEX_FORMAT_DICT)

#### Augmentation Test Results

In [58]:
# Quick look at the aggregated augmentation results over all DATASETS.
get_agg_pivot_table("augmentation_test", DATASETS)

Measure,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,Conformity Rate,Conformity Rate,Conformity Rate,Conformity Rate,Conformity Rate,Conformity Rate,Conformity Rate,Conformity Rate,Conformity Rate
Dataset,Cora,Cora,Cora,Flickr,Flickr,Flickr,OGBN-Arxiv,OGBN-Arxiv,OGBN-Arxiv,Cora,Cora,Cora,Flickr,Flickr,Flickr,OGBN-Arxiv,OGBN-Arxiv,OGBN-Arxiv
Model,GAT,GCN,SAGE,GAT,GCN,SAGE,GAT,GCN,SAGE,GAT,GCN,SAGE,GAT,GCN,SAGE,GAT,GCN,SAGE
Similarity Measure,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3,Unnamed: 17_level_3,Unnamed: 18_level_3
AlignCos,0.908,0.51,0.635,0.544,0.686,0.698,1.0,0.736,0.427,0.978,0.783,0.881,0.808,0.87,0.901,1.0,0.871,0.536
CKA,0.906,0.697,0.731,0.508,0.572,0.747,0.987,0.997,1.0,0.965,0.897,0.919,0.736,0.827,0.912,0.997,0.999,1.0
ConcDiff,0.623,0.174,0.504,0.534,0.412,0.348,0.508,0.426,0.533,0.867,0.492,0.781,0.742,0.779,0.681,0.84,0.76,0.804
DistCorr,0.891,0.785,0.732,0.609,0.602,0.794,0.989,1.0,1.0,0.965,0.929,0.912,0.812,0.838,0.941,0.997,1.0,1.0
EOS,0.342,0.535,0.618,0.459,0.679,0.529,1.0,0.819,0.488,0.752,0.857,0.896,0.703,0.875,0.797,1.0,0.965,0.75
GULP,0.349,0.448,0.612,0.555,0.223,0.542,1.0,0.557,0.484,0.763,0.79,0.891,0.839,0.523,0.809,1.0,0.805,0.745
HardCorr,0.513,0.512,0.633,0.571,0.712,0.716,0.536,0.466,0.506,0.805,0.849,0.895,0.851,0.894,0.916,0.792,0.587,0.768
IMD,0.238,0.711,0.347,0.281,0.59,0.586,1.0,1.0,1.0,0.47,0.929,0.64,0.531,0.873,0.866,1.0,1.0,1.0
Jaccard,0.953,0.781,0.807,0.58,0.743,0.88,1.0,1.0,0.991,0.984,0.904,0.946,0.841,0.885,0.973,1.0,1.0,0.998
LinReg,0.872,0.943,0.85,0.338,0.302,0.805,0.809,0.724,0.724,0.952,0.983,0.963,0.691,0.527,0.933,0.93,0.835,0.834


In [60]:
augment_test_df = get_agg_pivot_table("augmentation_test", DATASETS)
# Two-decimal LaTeX export with each column's maximum set in bold.
styled = pd.io.formats.style.Styler(augment_test_df, precision=2)
styled.highlight_max(axis=0, props="textbf:--rwrap;")
latex_str = styled.to_latex("graphs_results_augmentation_test.tex", **LATEX_FORMAT_DICT)

#### Layer Monotonicity Test Results

In [61]:
layer_test_df = get_agg_pivot_table("layer_test", DATASETS)
# Two-decimal LaTeX export with each column's maximum set in bold.
styled = pd.io.formats.style.Styler(layer_test_df, precision=2)
styled.highlight_max(axis=0, props="textbf:--rwrap;")
latex_str = styled.to_latex("graphs_results_layer_test.tex", **LATEX_FORMAT_DICT)

In [62]:
# Display the aggregated layer-monotonicity table built in the previous cell.
layer_test_df

Measure,Spearman Correlation,Spearman Correlation,Spearman Correlation,Spearman Correlation,Spearman Correlation,Spearman Correlation,Spearman Correlation,Spearman Correlation,Spearman Correlation,Conformity Rate,Conformity Rate,Conformity Rate,Conformity Rate,Conformity Rate,Conformity Rate,Conformity Rate,Conformity Rate,Conformity Rate
Dataset,Cora,Cora,Cora,Flickr,Flickr,Flickr,OGBN-Arxiv,OGBN-Arxiv,OGBN-Arxiv,Cora,Cora,Cora,Flickr,Flickr,Flickr,OGBN-Arxiv,OGBN-Arxiv,OGBN-Arxiv
Model,GAT,GCN,SAGE,GAT,GCN,SAGE,GAT,GCN,SAGE,GAT,GCN,SAGE,GAT,GCN,SAGE,GAT,GCN,SAGE
Similarity Measure,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3,Unnamed: 17_level_3,Unnamed: 18_level_3
AlignCos,1.0,0.98,1.0,0.04,0.327,0.68,0.837,0.843,0.933,1.0,0.97,1.0,0.46,0.59,0.77,0.89,0.83,0.93
CKA,1.0,0.98,1.0,0.387,0.603,0.89,0.96,0.85,0.457,1.0,0.98,1.0,0.59,0.79,0.92,0.95,0.87,0.64
ConcDiff,0.403,0.85,0.25,0.243,0.383,-0.273,0.397,0.733,0.39,0.62,0.85,0.52,0.58,0.57,0.36,0.67,0.7,0.74
DistCorr,1.0,1.0,1.0,0.467,0.63,0.993,0.933,0.807,0.523,1.0,1.0,1.0,0.63,0.81,0.99,0.92,0.88,0.65
EOS,1.0,1.0,1.0,0.917,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.92,1.0,1.0,1.0,1.0,1.0
GULP,1.0,1.0,1.0,0.33,0.54,1.0,1.0,0.733,1.0,1.0,1.0,1.0,0.72,0.83,1.0,1.0,0.82,1.0
HardCorr,0.54,0.833,0.983,0.213,0.627,0.803,1.0,0.833,0.62,0.76,0.91,0.99,0.68,0.81,0.84,1.0,0.86,0.72
IMD,1.0,1.0,1.0,0.97,1.0,0.82,1.0,0.55,1.0,1.0,1.0,1.0,0.97,1.0,0.94,1.0,0.85,1.0
Jaccard,1.0,1.0,1.0,0.953,0.95,0.973,1.0,0.973,0.987,1.0,1.0,1.0,0.96,0.94,0.96,1.0,0.96,0.98
LinReg,1.0,1.0,1.0,0.03,0.45,1.0,1.0,0.993,1.0,1.0,1.0,1.0,0.59,0.67,1.0,1.0,0.99,1.0


### Results of Output Correlations

In [39]:
def get_output_correlation_table(datasets, corr_func="spearmanr", acc_test=False):
    """Tabulate correlations between representational and functional similarity.

    Reads every CSV in EXPERIMENT_RESULTS_PATH whose file name contains "output"
    and one of ``datasets``. Keeps the rows whose ``quality_measure`` equals
    ``corr_func`` and averages ``corr`` per (similarity measure, functional
    measure, dataset, model).

    Parameters:
      - datasets [list[str]]: dataset names to include; each file's rows are tagged
                              with the dataset string that matched its name.
      - corr_func [str]: value of the ``quality_measure`` column to keep.
      - acc_test [bool]: if True, return only the "AbsoluteAccDiff" block (columns:
                         Dataset, Model); otherwise the "Disagreement" and "JSD" blocks.
    Returns:
      DataFrame indexed by short similarity-measure name (MEASURE_NAME_DICT).
    Raises:
      ValueError: if no matching CSV file is found.
    """
    # Only CSVs are considered; sorting makes the read order deterministic.
    csv_files = sorted(f for f in os.listdir(EXPERIMENT_RESULTS_PATH) if f.endswith(".csv"))
    columns = ["similarity_measure", "functional_similarity_measure", "corr", "architecture"]

    dfs = []
    for dataset in datasets:
        for f in csv_files:
            # NOTE(review): plain substring match on the file name.
            if "output" in f and dataset in f:
                df = pd.read_csv(os.path.join(EXPERIMENT_RESULTS_PATH, f))
                dfs.append(
                    df.loc[df["quality_measure"] == corr_func, columns]
                    .rename(COLUMN_NAME_DICT, axis="columns")
                    .assign(Dataset=dataset)
                )
    if not dfs:
        raise ValueError(
            f"No output-correlation CSV for datasets {datasets!r} found in {EXPERIMENT_RESULTS_PATH!r}"
        )

    df_cc = pd.concat(dfs, axis=0)
    df_res = df_cc.pivot_table(index="Similarity Measure", columns=["Measure", "Dataset", "Model"], values="corr", aggfunc="mean")
    df_res = df_res.rename(MEASURE_NAME_DICT, axis="index")
    df_res = df_res.rename(PIVOT_COL_DICT, axis="columns")
    if acc_test:
        return df_res.loc[:, "AbsoluteAccDiff"]
    return df_res.loc[:, ["Disagreement", "JSD"]]

In [43]:
output_corr_df = get_output_correlation_table(["cora"])
styled = pd.io.formats.style.Styler(
    output_corr_df,
    precision=2,
)
# LATEX_FORMAT_DICT's column spec is laid out for 18 data columns (three dataset
# blocks per measure). This Cora-only table has one dataset block per measure, so
# derive the spec from the actual columns: the index column, then "||" before
# each measure block.
measure_level = output_corr_df.columns.get_level_values("Measure")
cora_output_latex_format = {
    **LATEX_FORMAT_DICT,
    "column_format": "l||" + "||".join("r" * len(list(block)) for _, block in itertools.groupby(measure_level)),
}
styled.highlight_max(axis=0, props="textbf:--rwrap;").to_latex("cora_results_output_correlation_pgnn.tex", **cora_output_latex_format)
output_corr_df

Measure,Disagreement,Disagreement,Disagreement,Disagreement,JSD,JSD,JSD,JSD
Dataset,Cora,Cora,Cora,Cora,Cora,Cora,Cora,Cora
Model,GAT,GCN,SAGE,PGNN,GAT,GCN,SAGE,PGNN
Similarity Measure,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3
AlignCos,-0.051,-0.084,0.271,0.206,0.174,0.017,0.384,-0.143
CKA,0.445,0.363,0.001,0.137,0.535,0.435,-0.031,-0.327
ConcDiff,0.03,-0.038,-0.097,-0.066,0.014,0.032,0.041,0.205
DistCorr,0.456,0.529,-0.078,0.284,0.601,0.597,0.053,-0.362
EOS,-0.088,0.127,0.023,0.294,-0.036,0.168,0.224,-0.225
GULP,-0.129,-0.165,-0.09,0.334,-0.042,-0.096,0.182,0.1
HardCorr,0.261,0.454,0.104,0.45,0.516,0.479,0.162,-0.174
IMD,-0.15,0.18,-0.103,-0.135,-0.086,0.213,-0.404,0.104
Jaccard,0.12,0.436,0.333,0.296,0.379,0.441,0.453,0.051
LinReg,-0.062,0.574,0.09,0.228,0.19,0.58,0.331,-0.267


Measure,Disagreement,Disagreement,Disagreement,Disagreement,JSD,JSD,JSD,JSD
Dataset,Cora,Cora,Cora,Cora,Cora,Cora,Cora,Cora
Model,GAT,GCN,SAGE,PGNN,GAT,GCN,SAGE,PGNN
Similarity Measure,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3
AlignCos,-0.051,-0.084,0.271,0.206,0.174,0.017,0.384,-0.143
CKA,0.445,0.363,0.001,0.137,0.535,0.435,-0.031,-0.327
ConcDiff,0.03,-0.038,-0.097,-0.066,0.014,0.032,0.041,0.205
DistCorr,0.456,0.529,-0.078,0.284,0.601,0.597,0.053,-0.362
EOS,-0.088,0.127,0.023,0.294,-0.036,0.168,0.224,-0.225
GULP,-0.129,-0.165,-0.09,0.334,-0.042,-0.096,0.182,0.1
HardCorr,0.261,0.454,0.104,0.45,0.516,0.479,0.162,-0.174
IMD,-0.15,0.18,-0.103,-0.135,-0.086,0.213,-0.404,0.104
Jaccard,0.12,0.436,0.333,0.296,0.379,0.441,0.453,0.051
LinReg,-0.062,0.574,0.09,0.228,0.19,0.58,0.331,-0.267


In [70]:
output_corr_df = get_output_correlation_table(DATASETS)
# Two-decimal LaTeX export with each column's maximum set in bold.
styled = pd.io.formats.style.Styler(output_corr_df, precision=2)
styled.highlight_max(axis=0, props="textbf:--rwrap;")
styled.to_latex("graphs_results_output_correlation.tex", **LATEX_FORMAT_DICT)
output_corr_df

Measure,Disagreement,Disagreement,Disagreement,Disagreement,Disagreement,Disagreement,Disagreement,Disagreement,Disagreement,JSD,JSD,JSD,JSD,JSD,JSD,JSD,JSD,JSD
Dataset,Cora,Cora,Cora,Flickr,Flickr,Flickr,OGBN-Arxiv,OGBN-Arxiv,OGBN-Arxiv,Cora,Cora,Cora,Flickr,Flickr,Flickr,OGBN-Arxiv,OGBN-Arxiv,OGBN-Arxiv
Model,GAT,GCN,SAGE,GAT,GCN,SAGE,GAT,GCN,SAGE,GAT,GCN,SAGE,GAT,GCN,SAGE,GAT,GCN,SAGE
Similarity Measure,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3,Unnamed: 17_level_3,Unnamed: 18_level_3
AlignCos,-0.051,-0.084,0.271,-0.083,0.147,0.374,0.166,-0.098,0.001,0.174,0.017,0.384,-0.013,0.306,0.443,0.275,-0.026,0.05
CKA,0.445,0.363,0.001,0.061,-0.211,0.534,0.227,0.029,-0.041,0.535,0.435,-0.031,0.166,0.029,0.576,0.381,0.119,-0.021
ConcDiff,0.03,-0.038,-0.097,0.029,-0.214,-0.043,-0.16,-0.254,0.072,0.014,0.032,0.041,0.028,-0.172,-0.034,-0.215,-0.13,0.024
DistCorr,0.456,0.529,-0.078,-0.031,0.172,0.395,0.204,0.11,0.081,0.601,0.597,0.053,0.031,0.456,0.431,0.361,0.163,0.122
EOS,-0.088,0.127,0.023,0.231,0.025,0.334,0.119,-0.255,-0.023,-0.036,0.168,0.224,0.114,-0.027,0.385,0.369,-0.166,0.115
GULP,-0.129,-0.165,-0.09,-0.006,0.029,0.334,0.099,0.154,-0.037,-0.042,-0.096,0.182,0.128,0.149,0.383,0.353,0.098,0.106
HardCorr,0.261,0.454,0.104,-0.093,0.395,0.458,0.239,0.02,-0.241,0.516,0.479,0.162,0.095,0.534,0.505,0.46,-0.046,-0.276
IMD,-0.165,0.2,-0.107,-0.029,-0.103,0.344,0.002,-0.109,-0.02,-0.109,0.24,-0.396,0.033,-0.084,0.307,-0.043,-0.085,-0.046
Jaccard,0.12,0.436,0.333,-0.004,0.037,0.419,0.219,0.142,0.012,0.379,0.441,0.453,0.107,0.331,0.421,0.367,0.2,0.09
LinReg,-0.062,0.574,0.09,0.063,0.119,0.457,0.221,-0.099,-0.188,0.19,0.58,0.331,0.176,0.152,0.477,0.465,-0.089,-0.166


In [44]:
acc_corr_df = get_output_correlation_table(["cora"], acc_test=True)
styled = pd.io.formats.style.Styler(
    acc_corr_df,
    precision=2,
)
# acc_test=True returns only the AbsoluteAccDiff block (Dataset, Model columns) for
# a single dataset, so the 18-column spec in LATEX_FORMAT_DICT does not fit:
# use the index column plus one right-aligned column per model.
cora_acc_latex_format = {**LATEX_FORMAT_DICT, "column_format": "l||" + "r" * acc_corr_df.shape[1]}
styled.highlight_max(axis=0, props="textbf:--rwrap;").to_latex("cora_results_accuracy_correlation.tex", **cora_acc_latex_format)
acc_corr_df

Dataset,Cora,Cora,Cora,Cora
Model,GAT,GCN,SAGE,PGNN
Similarity Measure,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
AlignCos,-0.289,-0.332,0.132,0.05
CKA,0.023,0.163,-0.172,-0.009
ConcDiff,-0.197,0.154,-0.252,0.354
DistCorr,0.026,0.008,-0.175,0.153
EOS,-0.049,-0.236,0.08,0.272
GULP,-0.121,-0.435,0.081,0.276
HardCorr,-0.143,0.106,-0.109,0.322
IMD,-0.023,-0.114,-0.005,-0.157
Jaccard,-0.16,0.023,-0.111,0.079
LinReg,-0.109,0.058,-0.207,0.061


In [71]:
acc_corr_df = get_output_correlation_table(["cora", "flickr", "ogbn-arxiv"], acc_test=True)
styled = pd.io.formats.style.Styler(
    acc_corr_df,
    precision=2,
)
# acc_test=True drops the measure level (columns: Dataset, Model), so the 18-column
# spec in LATEX_FORMAT_DICT does not fit. Build it from the actual columns instead:
# the index column, then "|"-separated blocks of one "r" per model for each dataset.
dataset_level = acc_corr_df.columns.get_level_values("Dataset")
acc_latex_format = {
    **LATEX_FORMAT_DICT,
    "column_format": "l||" + "|".join("r" * len(list(block)) for _, block in itertools.groupby(dataset_level)),
}
styled.highlight_max(axis=0, props="textbf:--rwrap;").to_latex("graphs_results_accuracy_correlation.tex", **acc_latex_format)
acc_corr_df

Dataset,Cora,Cora,Cora,Flickr,Flickr,Flickr,OGBN-Arxiv,OGBN-Arxiv,OGBN-Arxiv
Model,GAT,GCN,SAGE,GAT,GCN,SAGE,GAT,GCN,SAGE
Similarity Measure,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
AlignCos,-0.369,-0.188,0.154,0.088,0.173,0.388,-0.185,-0.069,0.201
CKA,-0.08,-0.31,-0.134,-0.095,-0.15,0.372,-0.129,0.04,0.056
ConcDiff,-0.098,-0.021,-0.213,0.078,-0.11,-0.06,-0.191,-0.247,-0.004
DistCorr,-0.112,-0.299,-0.126,-0.041,0.2,0.36,-0.127,0.138,0.01
EOS,-0.002,0.29,0.085,0.399,0.227,0.224,0.048,-0.05,0.01
GULP,-0.037,0.245,0.085,-0.087,0.041,0.215,0.052,-0.185,-0.013
HardCorr,-0.261,0.03,-0.116,-0.145,0.177,0.323,0.064,0.132,-0.04
IMD,-0.056,-0.118,0.038,-0.091,-0.096,0.37,-0.148,-0.223,-0.076
Jaccard,-0.226,-0.042,-0.1,-0.057,0.079,0.074,-0.145,-0.275,-0.132
LinReg,-0.205,0.009,-0.206,-0.115,-0.228,0.277,0.049,-0.229,-0.021


### Critical Difference (CD) Plots

In [21]:
# autorank package needed for creation of CD plots
from autorank import autorank, plot_stats, create_report, latex_table
# cd_diagram is the only name this notebook uses from autorank's private _util
# module; import it explicitly instead of `import *`.
from autorank._util import cd_diagram
from matplotlib import pyplot as plt

In [22]:
def get_figsize(columnwidth, wf=0.5, hf=(5. ** 0.5 - 1.0) / 2.0):
    r"""Convert a LaTeX column width into a matplotlib figure size.

    Credit: https://stackoverflow.com/a/31527287
    Parameters:
      - columnwidth [float]: width of the column in LaTeX points. Get this from
                             LaTeX using \showthe\columnwidth
      - wf [float]: width fraction in columnwidth units
      - hf [float]: figure height as a fraction of the figure width.
                    Defaults to (sqrt(5) - 1) / 2 ~ 0.618 (golden-ratio aspect).
    Returns: (fig_width, fig_height) tuple in inches, to be given to matplotlib
    """
    inches_per_pt = 1.0 / 72.27  # 72.27 TeX points per inch
    fig_width = columnwidth * wf * inches_per_pt
    fig_height = fig_width * hf
    return fig_width, fig_height


# Base LaTeX width in points (presumably the target document's text/column width).
PLOTS_BASE_WIDTH = 433.62  # pt
PLOTS_CD_WIDTH, PLOTS_CD_HEIGHT = get_figsize(PLOTS_BASE_WIDTH, wf=1.5)

In [23]:
def get_autorank_df(experiments=EXPERIMENTS, datasets=DATASETS, quality_measure = "violation_rate"):
    """Collect one score table per test, side by side, as input for autorank.

    For each experiment the aggregated pivot table is reduced to its
    "Spearman Correlation" block (layer_test) or its "AUPRC" block (all other
    experiments). Then the "JSD" output-correlation block and the accuracy-
    correlation table are appended. Each part has NaN rows dropped before the
    column-wise concatenation.

    Parameters:
      - experiments [list[str]]: experiment keys to include.
      - datasets [list[str]]: dataset names to include.
      - quality_measure [str]: currently unused by this function; kept so that
                               existing calls remain valid.
    Returns:
      DataFrame indexed by short similarity-measure name.
    """
    frames = [
        get_agg_pivot_table(experiment, datasets)
        .loc[:, "Spearman Correlation" if experiment == "layer_test" else "AUPRC"]
        .dropna()
        for experiment in experiments
    ]
    frames.append(get_output_correlation_table(datasets).dropna().loc[:, "JSD"])
    frames.append(get_output_correlation_table(datasets, acc_test=True).dropna())
    return pd.concat(frames, axis=1)


In [24]:
get_autorank_df(datasets=["cora", "flickr", "ogbn-arxiv"])

Dataset,Cora,Cora,Cora,Flickr,Flickr,Flickr,OGBN-Arxiv,OGBN-Arxiv,OGBN-Arxiv,Cora,...,OGBN-Arxiv,Cora,Cora,Cora,Flickr,Flickr,Flickr,OGBN-Arxiv,OGBN-Arxiv,OGBN-Arxiv
Model,GAT,GCN,SAGE,GAT,GCN,SAGE,GAT,GCN,SAGE,GAT,...,SAGE,GAT,GCN,SAGE,GAT,GCN,SAGE,GAT,GCN,SAGE
Similarity Measure,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
AlignCos,0.278,0.289,0.287,0.183,0.513,0.275,0.459,0.467,0.464,1.0,...,0.05,-0.369,-0.188,0.154,0.088,0.173,0.388,-0.185,-0.069,0.201
CKA,0.274,0.273,0.274,0.141,0.445,0.415,0.728,0.846,0.8,1.0,...,-0.021,-0.08,-0.31,-0.134,-0.095,-0.15,0.372,-0.129,0.04,0.056
ConcDiff,0.212,0.131,0.198,0.156,0.218,0.235,0.505,0.518,0.543,0.403,...,0.024,-0.098,-0.021,-0.213,0.078,-0.11,-0.06,-0.191,-0.247,-0.004
DistCorr,0.274,0.28,0.287,0.116,0.491,0.308,0.722,0.804,0.822,1.0,...,0.122,-0.112,-0.299,-0.126,-0.041,0.2,0.36,-0.127,0.138,0.01
EOS,0.132,0.16,0.168,0.123,0.268,0.284,0.222,0.17,0.276,1.0,...,0.115,-0.002,0.29,0.085,0.399,0.227,0.224,0.048,-0.05,0.01
GULP,0.133,0.16,0.168,0.118,0.106,0.285,0.197,0.194,0.277,1.0,...,0.106,-0.037,0.245,0.085,-0.087,0.041,0.215,0.052,-0.185,-0.013
HardCorr,0.27,0.272,0.273,0.191,0.389,0.332,0.334,0.52,0.515,0.54,...,-0.276,-0.261,0.03,-0.116,-0.145,0.177,0.323,0.064,0.132,-0.04
Jaccard,0.274,0.273,0.273,0.131,0.326,0.296,0.385,0.565,0.283,1.0,...,0.09,-0.226,-0.042,-0.1,-0.057,0.079,0.074,-0.145,-0.275,-0.132
LinReg,0.286,0.282,0.284,0.125,0.13,0.326,0.297,0.279,0.367,1.0,...,-0.166,-0.205,0.009,-0.206,-0.115,-0.228,0.277,0.049,-0.229,-0.021
MagDiff,0.185,0.119,0.134,0.11,0.41,0.273,0.172,0.27,0.199,0.893,...,-0.127,-0.073,0.051,-0.106,0.101,0.178,-0.006,-0.101,-0.214,0.033


In [32]:
# Transpose so similarity measures become columns and the result columns of
# get_autorank_df become rows; the original row labels are discarded.
res_df = get_autorank_df().T.reset_index(drop=True)
result = autorank(res_df, alpha=0.05, verbose=False)
# NOTE(review): positional arguments assumed to be (result, reverse, ax, width) -- confirm against autorank.
cd_diagram(result, False, None, PLOTS_CD_WIDTH)
fig = plt.gcf()
fig.set_size_inches(PLOTS_CD_WIDTH, PLOTS_CD_HEIGHT)
plt.savefig("cd_plot_graphs.png", bbox_inches="tight")
plt.close(fig)

  if abs(sorted_ranks[i] - sorted_ranks[j]) <= critical_difference:
  plot_line([(rankpos(sorted_ranks[i]), cline),
  (rankpos(sorted_ranks[i]), chei),
  plot_line([(rankpos(sorted_ranks[i]), cline),
  (rankpos(sorted_ranks[i]), chei),
  plot_line([(rankpos(sorted_ranks[l]) - side, start),
  (rankpos(sorted_ranks[r]) + side, start)],


In [29]:
# Shape of the transposed autorank input (rows = result columns, columns = similarity
# measures); transposing only swaps the two dimensions.
get_autorank_df().shape[::-1]

(54, 20)