In [5]:
import pandas as pd
import itertools
import os

In [73]:
# Show floats in displayed frames with thousands separators and three decimals.
pd.set_option("display.float_format", "{:,.3f}".format)

In [6]:
# Root directory that holds the experiment result CSV files.
# NOTE(review): absolute, machine-specific path - adjust before running elsewhere.
EXPERIMENT_RESULTS_PATH = 'C:/Users/Tobias/Eigene Dokumente/Research/similaritybench/experiments/paper_results'

# Experiment name -> comparison type; both parts appear in the result file names.
COMPARISON_TYPE_DICT = {
    "label_test": "group_separation",
    "layer_test": "monotonicity",
    "augmentation_test": "group_separation",
    "shortcut_test": "group_separation",
}

# Experiment names, in the order they are declared above.
EXPERIMENTS = [*COMPARISON_TYPE_DICT]
DATASETS = ["cora", "flickr", "ogbn-arxiv"]
ARCHITECTURES = ["GCN", "GraphSAGE", "GAT"]

def FULL_DF_FILE_NAME(experiment, comparison_type, dataset, groups=5):
    """Build the name of the full-results CSV for one experiment and dataset.

    Parameters:
      - experiment [str]: experiment name, first part of the file name
      - comparison_type [str]: comparison type, second part of the file name
      - dataset [str]: dataset name, third part of the file name
      - groups [int]: number of groups; a "_<groups>groups" infix is added
                      only when this is smaller than 5
    Returns: the file name (no directory part)
    """
    groups_infix = f"_{groups}groups" if groups < 5 else ""
    return f"{experiment}_{comparison_type}_{dataset}{groups_infix}_full.csv"

In [7]:
# Full similarity-measure name (as found in the `similarity_measure` column of the
# result files) -> short display name. Applied to the index of the pivot tables.
MEASURE_NAME_DICT = {
    "AlignedCosineSimilarity": "AlignCos",
    "CKA": "CKA",
    "ConcentricityDifference": "ConcDiff",
    "DistanceCorrelation": "DistCorr",
    "EigenspaceOverlapScore": "EOS",
    "GeometryScore": "GS",
    "Gulp": "GULP",
    "HardCorrelationMatch": "HardCorr",
    "IMDScore": "IMD",
    "JaccardSimilarity": "Jaccard",
    "LinearRegression": "LinReg",
    "MagnitudeDifference": "MagDiff",
    "OrthogonalAngularShapeMetricCentered": "ShapeMet",
    "OrthogonalProcrustesCenteredAndNormalized": "OrthProc",
    "PWCCA": "PWCCA",
    "PermutationProcrustes": "PermProc",
    "ProcrustesSizeAndShapeDistance": "ProcDist",
    "RSA": "RSA",
    "RSMNormDifference": "RSMDiff",
    "RankSimilarity": "RankSim",
    "SVCCA": "SVCCA",
    "SecondOrderCosineSimilarity": "2nd-Cos",
    "SoftCorrelationMatch": "SoftCorr",
    "UniformityDifference": "UnifDiff",
}

# Display names for labels that appear in the pivot-table column levels
# (dataset names, an architecture name and measure names). get_agg_pivot_table
# applies it with DataFrame.rename on the columns; labels without an entry
# here are left unchanged by rename.
PIVOT_COL_DICT = {
    "cora": "Cora",
    "flickr": "Flickr",
    "ogbn-arxiv": "OGBN-Arxiv",
    "GraphSAGE": "SAGE",
    "violation_rate": "Violation Rate",
    "correlation": "Spearman Correlation"
}

# Raw CSV column name -> column name used for pivoting.
# Gotcha: "quality_measure" and "functional_similarity_measure" both map to
# "Measure", so a frame must contain only one of the two when it is renamed
# (otherwise the result has two columns called "Measure").
COLUMN_NAME_DICT = {
    "similarity_measure": "Similarity Measure", 
    "quality_measure": "Measure",
    "functional_similarity_measure": "Measure",
    "architecture": "Model",
    "representation_dataset": "Dataset"
}

# Formatting options for LaTeX table export.
# NOTE(review): not used in the visible cells; the keys look like keyword
# arguments of DataFrame.to_latex - confirm at the call site.
# column_format declares one left-aligned column followed by 18 right-aligned
# columns in groups of three.
LATEX_FORMAT_DICT = {
    "float_format": "%.2f", 
    "column_format": "l||rrr|rrr|rrr||rrr|rrr|rrr",
    "multicolumn_format": "c",
    "index_names": False,
}

In [8]:
def get_agg_pivot_table(experiment, datasets, groups=5):
    """Load the full-result CSVs of one experiment and pivot them into one table.

    Parameters:
      - experiment [str]: key of COMPARISON_TYPE_DICT; selects the files to read
      - datasets [iterable of str]: dataset names, one CSV is read per entry
      - groups [int]: forwarded to FULL_DF_FILE_NAME
    Returns: DataFrame indexed by short similarity-measure name, with
             (Measure, Dataset, Model) column levels holding the mean of `value`.
    """
    wanted_columns = ["similarity_measure", "quality_measure", "value", "architecture", "representation_dataset"]

    frames = []
    for dataset in datasets:
        csv_name = FULL_DF_FILE_NAME(experiment, COMPARISON_TYPE_DICT[experiment], dataset, groups)
        raw = pd.read_csv(os.path.join(EXPERIMENT_RESULTS_PATH, csv_name))
        # keep only complete rows of the columns needed for the pivot
        frames.append(raw.loc[:, wanted_columns].dropna().rename(COLUMN_NAME_DICT, axis="columns"))

    stacked = pd.concat(frames, axis=0)
    pivot = stacked.pivot_table(
        index="Similarity Measure",
        columns=["Measure", "Dataset", "Model"],
        values="value",
        aggfunc="mean",
    )
    # short names on the rows first, then display names on the column labels
    return pivot.rename(MEASURE_NAME_DICT, axis="index").rename(PIVOT_COL_DICT, axis="columns")

In [22]:
def get_dataset(filename, domain):
    """Derive the dataset label of a result file from its file name.

    Parameters:
      - filename [str]: name of the result file
      - domain [str]: "nlp" selects the NLP rule; any other value uses DATASETS
    Returns: for "nlp", "MNLI" if the name contains "mnli", otherwise "SST2";
             for other domains, the first entry of DATASETS contained in the
             name, or a placeholder string when none matches.

    NOTE(review): the placeholder looks like a debugging leftover; a file that
    matches no dataset is silently labelled with it.
    """
    if domain != "nlp":
        matched = next((name for name in DATASETS if name in filename), None)
        if matched is not None:
            return matched
        return "FUCKDISSHIT"

    if "mnli" in filename:
        return "MNLI"
    return "SST2"

def get_measure(df):
    """Unfinished stub: the notebook defines this function without a body.

    A `def` with no body is a syntax error that prevents the whole cell from
    running, so the stub is made explicit until the real logic is written.

    Parameters:
      - df: unused for now (intended meaning is not visible in the notebook)
    Raises: NotImplementedError, always
    """
    raise NotImplementedError("get_measure was left unimplemented in the notebook")

In [69]:
def _pivot_by_measure(data, value_col):
    """Pivot long-format rows into one row per similarity measure, using short measure names."""
    pivot = data.pivot_table(
        index="Similarity Measure",
        columns=["Measure", "Dataset", "Model"],
        values=value_col,
        aggfunc="mean",
    )
    return pivot.rename(MEASURE_NAME_DICT, axis="index")


def get_autorank_df(domain="graphs", quality_measure="violation_rate"):
    """Collect all result CSVs of one domain into a single wide table for ranking.

    Every CSV file in EXPERIMENT_RESULTS_PATH/<domain> is read, filtered and
    pivoted to (Measure, Dataset, Model) columns; the pivots are concatenated
    column-wise and similarity measures with any missing value are dropped.

    Files whose name contains "correlation" contribute the `corr` column,
    restricted to quality_measure == "spearmanr" and to the functional
    similarity measures "JSD" and "AbsoluteAccDiff". All other files contribute
    the `value` column for the quality measures "correlation" and "AUPRC".
    The dataset label is always derived from the file name via get_dataset.

    Parameters:
      - domain [str]: sub-directory of EXPERIMENT_RESULTS_PATH to read.
                      NOTE(review): the default is "graphs" while the calls in
                      this notebook pass "graph" - confirm which directory exists.
      - quality_measure [str]: currently unused; kept so existing calls keep working.
    Returns: DataFrame indexed by short similarity-measure name with one column
             per (Measure, Dataset, Model) combination.
    Raises: ValueError if the directory contains no CSV file.
    """
    res_path = os.path.join(EXPERIMENT_RESULTS_PATH, domain)

    pivots = []
    # os.listdir returns entries in arbitrary order; sort so the column order
    # of the result is reproducible across runs and machines.
    for fname in sorted(os.listdir(res_path)):
        # skip sub-directories and stray non-CSV files instead of crashing in read_csv
        if not fname.lower().endswith(".csv"):
            continue
        df = pd.read_csv(os.path.join(res_path, fname))

        if "correlation" in fname:
            data = df.loc[:, ["similarity_measure", "functional_similarity_measure", "quality_measure", "corr", "architecture"]].dropna()
            data = data.loc[data["quality_measure"] == "spearmanr"]
            # drop before renaming: quality_measure and functional_similarity_measure
            # would otherwise both become "Measure"
            data = data.drop(columns=["quality_measure"])
            data = data.loc[data["functional_similarity_measure"].isin(["JSD", "AbsoluteAccDiff"])]
            value_col = "corr"
        else:
            data = df.loc[:, ["similarity_measure", "quality_measure", "value", "architecture", "representation_dataset"]].dropna()
            data = data.loc[data["quality_measure"].isin(["correlation", "AUPRC"])]
            value_col = "value"

        data = data.rename(COLUMN_NAME_DICT, axis="columns")
        # the file name, not the CSV content, decides the dataset label
        data["Dataset"] = get_dataset(fname, domain)
        pivots.append(_pivot_by_measure(data, value_col))

    if not pivots:
        raise ValueError(f"No CSV result files found in {res_path}")

    return pd.concat(pivots, axis=1).dropna()


In [108]:
# rank(axis=1) ranks along the columns, i.e. within each similarity-measure row
# across the (Measure, Dataset, Model) columns.
get_autorank_df(domain="graph").rank(axis=1)


Measure,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC,AUPRC
Dataset,cora,cora,cora,flickr,flickr,flickr,cora,cora,cora,flickr,...,cora,flickr,flickr,flickr,ogbn-arxiv,ogbn-arxiv,ogbn-arxiv,ogbn-arxiv,ogbn-arxiv,ogbn-arxiv
Model,GAT,GCN,GraphSAGE,GAT,GCN,GraphSAGE,GAT,GCN,GraphSAGE,GAT,...,GraphSAGE,GAT,GCN,GraphSAGE,GAT,GCN,GraphSAGE,GAT,GCN,GraphSAGE
Similarity Measure,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3,Unnamed: 17_level_3,Unnamed: 18_level_3,Unnamed: 19_level_3,Unnamed: 20_level_3,Unnamed: 21_level_3


In [114]:
# Transpose so that each row is one (Measure, Dataset, Model) setting and each
# column a similarity measure, then rank the measures within every setting.
# method='max' gives tied values the highest rank of their group.
get_autorank_df(domain="nlp").T.reset_index(drop=True).rank(axis=1, method='max')

Similarity Measure,AlignCos,CKA,ConcDiff,DistCorr,EOS,GULP,HardCorr,IMD,Jaccard,LinReg,...,OrthProc,PWCCA,PermProc,ProcDist,RSA,RankSim,SVCCA,2nd-Cos,SoftCorr,UnifDiff
0,13.0,20.0,9.0,19.0,7.0,6.0,5.0,16.0,12.0,4.0,...,15.0,3.0,2.0,10.0,21.0,11.0,17.0,18.0,8.0,22.0
1,16.0,20.0,5.0,22.0,2.0,7.0,8.0,1.0,11.0,12.0,...,18.0,9.0,4.0,15.0,21.0,10.0,19.0,13.0,14.0,6.0
2,18.0,15.0,6.0,14.0,9.0,5.0,10.0,1.0,3.0,16.0,...,20.0,8.0,11.0,21.0,7.0,4.0,22.0,2.0,12.0,13.0
3,20.0,17.0,7.0,21.0,19.0,18.0,2.0,1.0,10.0,16.0,...,14.0,12.0,4.0,8.0,15.0,9.0,22.0,11.0,3.0,5.0
4,16.0,14.0,18.0,11.0,2.0,3.0,4.0,15.0,20.0,1.0,...,9.0,5.0,12.0,17.0,7.0,19.0,13.0,21.0,6.0,22.0
5,18.0,16.0,9.0,20.0,3.0,1.0,4.0,10.0,21.0,6.0,...,12.0,2.0,7.0,17.0,13.0,22.0,15.0,19.0,5.0,14.0
6,22.0,15.0,22.0,14.0,7.0,6.0,13.0,12.0,10.0,5.0,...,19.0,9.0,3.0,20.0,4.0,8.0,11.0,2.0,16.0,17.0
7,22.0,16.0,21.0,17.0,5.0,2.0,9.0,11.0,4.0,8.0,...,15.0,1.0,10.0,20.0,12.0,3.0,19.0,6.0,13.0,18.0
8,19.0,17.0,14.0,12.0,15.0,2.0,20.0,1.0,11.0,4.0,...,13.0,22.0,5.0,9.0,21.0,7.0,8.0,10.0,16.0,6.0
9,19.0,13.0,15.0,17.0,9.0,1.0,20.0,3.0,12.0,8.0,...,18.0,22.0,5.0,7.0,14.0,11.0,4.0,16.0,10.0,6.0


In [83]:
# autorank package needed for creation of CD plots 
from autorank import autorank, plot_stats, create_report, latex_table
from autorank._util import *
from matplotlib import pyplot as plt
from scipy import stats

In [91]:
def get_figsize(columnwidth, wf=0.5, hf=(5. ** 0.5 - 1.0) / 2.0):
    r"""Convert a LaTeX width in points into a matplotlib figure size in inches.

    Credit: https://stackoverflow.com/a/31527287

    Parameters:
      - columnwidth [float]: width of the column in LaTeX, in pt. Get this from
                             LaTeX using \showthe\columnwidth
      - wf [float]:  width of the figure as a fraction of columnwidth
      - hf [float]:  height of the figure as a fraction of the figure width
                     (not of columnwidth). Defaults to (sqrt(5) - 1) / 2,
                     the golden-ratio proportion.
    Returns:  (fig_width, fig_height) in inches, to be given to matplotlib
    """
    # The docstring is raw so that \showthe and \columnwidth are not read as
    # (invalid) string escape sequences.
    fig_width_pt = columnwidth * wf
    inches_per_pt = 1.0 / 72.27  # 72.27 TeX points per inch
    fig_width = fig_width_pt * inches_per_pt  # width in inches
    fig_height = fig_width * hf  # height in inches
    return fig_width, fig_height

# Reference width handed to get_figsize.
# NOTE(review): presumably the LaTeX column/text width of the target document - confirm.
PLOTS_BASE_WIDTH = 433.62  # pt
# CD plots are 1.5x the base width; the height follows get_figsize's default hf ratio.
PLOTS_CD_WIDTH, PLOTS_CD_HEIGHT = get_figsize(PLOTS_BASE_WIDTH, wf=1.5)

In [104]:
def build_cd_plot(domain):
    """Create the critical-difference diagram of one domain and save it as a PNG.

    Parameters:
      - domain [str]: passed to get_autorank_df and used in the output file name
    Side effects: writes "cd_plot_<domain>.png" to the current working
                  directory and closes the figure afterwards.
    """
    # rows: one per result column of get_autorank_df; columns: similarity measures
    samples = get_autorank_df(domain).T.reset_index(drop=True)
    stats_result = autorank(samples, alpha=0.05, verbose=False, force_mode="nonparametric")

    # NOTE(review): positional arguments are presumably (result, reverse, ax, width)
    # - verify against autorank._util.cd_diagram
    cd_diagram(stats_result, False, None, PLOTS_CD_WIDTH)

    # the diagram is drawn on the current figure; resize it before saving
    figure = plt.gcf()
    figure.set_size_inches(PLOTS_CD_WIDTH, PLOTS_CD_HEIGHT)
    figure.savefig(f"cd_plot_{domain}.png", bbox_inches="tight")
    plt.close(figure)

In [106]:
# Writes cd_plot_graph.png to the current working directory.
build_cd_plot("graph")

Tests for normality and homoscedacity are ignored for test selection, forcing nonparametric tests


  if abs(sorted_ranks[i] - sorted_ranks[j]) <= critical_difference:
  plot_line([(rankpos(sorted_ranks[i]), cline),
  (rankpos(sorted_ranks[i]), chei),
  plot_line([(rankpos(sorted_ranks[i]), cline),
  (rankpos(sorted_ranks[i]), chei),
  plot_line([(rankpos(sorted_ranks[l]) - side, start),
  (rankpos(sorted_ranks[r]) + side, start)],
