In [1]:
import sys

# Extend the module search path two directories up from the working directory
# (presumably so the project-local `attre2vec` package imported below resolves
# -- TODO confirm the notebook's location relative to the package root).
sys.path.append('../..')

In [2]:
import os
import pickle

from IPython.display import display, Markdown
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from attre2vec.evaluation import metrics as ae_mtr

In [3]:
def evaluate_edge_clustering(model_names, dataset_names, base_path):
    """Collect accuracy statistics for every (dataset, model) combination.

    For each dataset name the pickled dataset at
    ``<base_path>/data/datasets/<dataset_name>.pkl`` is loaded. Then, for each
    model and each of the ``ds['num_datasets']`` dataset samples:

    * the labels of the ``train``, ``val`` and ``test`` splits are
      concatenated, in that order, into one array;
    * the pickled embeddings are read from
      ``<base_path>/data/vectors/<dataset_name>/<model_name>/<sample_idx>.pkl``;
    * both are passed to ``ae_mtr.acc(embs=..., y_true=...)``.

    NOTE(review): both inputs are read with ``pickle.load``; only point this
    function at trusted files.

    Args:
        model_names: Iterable of model identifiers; each one is used as a
            sub-path below the per-dataset vectors directory.
        dataset_names: Iterable of dataset names.
        base_path: Directory that contains the ``data`` folder.

    Returns:
        dict keyed by ``(dataset_name, model_name)`` whose values are
        ``{'mean': ..., 'std': ...}`` -- mean and standard deviation of the
        per-sample scores, multiplied by 100 and rounded to 2 decimals.
    """
    split_order = ('train', 'val', 'test')
    results = {}

    for dataset_name in tqdm(dataset_names, desc='Dataset'):
        dataset_file = os.path.join(
            base_path, f'data/datasets/{dataset_name}.pkl'
        )
        with open(dataset_file, 'rb') as dataset_fin:
            dataset = pickle.load(dataset_fin)

        for model_name in tqdm(model_names, desc='Models', leave=False):
            # All embedding files of one model live in the same directory.
            vectors_dir = os.path.join(
                base_path, 'data/vectors', dataset_name, model_name
            )
            scores = []

            sample_indices = range(dataset['num_datasets'])
            for sample_idx in tqdm(sample_indices, desc='Dataset samples', leave=False):
                # Labels are stitched together as train -> val -> test.
                # assumes the stored embeddings follow the same order -- TODO confirm
                labels = np.array([
                    label
                    for split in split_order
                    for label in dataset['Xy'][sample_idx][split]['y']
                ])

                with open(os.path.join(vectors_dir, f'{sample_idx}.pkl'), 'rb') as vectors_fin:
                    embeddings = np.array(pickle.load(vectors_fin))

                scores.append(ae_mtr.acc(embs=embeddings, y_true=labels))

            # Report as percentages with two decimals.
            results[(dataset_name, model_name)] = {
                'mean': np.round(np.mean(scores) * 100.0, 2),
                'std': np.round(np.std(scores) * 100.0, 2),
            }

    return results


In [4]:
# Names of the datasets to evaluate.
datasets = ['cora', 'citeseer', 'pubmed']

# Model identifiers to evaluate, built family by family; the resulting order
# is: BL_* per-method baselines, standalone baselines, AttrE2vec variants,
# then MLP_* variants.

# BL_<method>/<features>/full: each method with 'nf' and then 'nfef'.
model_names = [
    f'BL_{method}/{features}/full'
    for method in ('dw', 'n2v', 'sdne', 'struc2vec', 'graphsage')
    for features in ('nf', 'nfef')
]

# Baselines that do not follow the <method>/<features> naming scheme.
model_names += ['BL_simple/full', 'BL_line2vec']

# AttrE2vec variants.
model_names += [
    f'AttrE2vec_{variant}'
    for variant in ('Avg', 'Exp', 'GRU', 'ConcatGRU')
]

# MLP_<method>/<head>: all methods with MLP2 first, then all with MLP3.
model_names += [
    f'MLP_{method}/{head}'
    for head in ('MLP2', 'MLP3')
    for method in ('dw', 'n2v', 'sdne', 'struc2vec', 'graphsage')
]

In [5]:
# Evaluate every (dataset, model) pair; data files are looked up under the
# directory two levels above the working directory.
res = evaluate_edge_clustering(
    base_path='../../',
    dataset_names=datasets,
    model_names=model_names,
)

HBox(children=(FloatProgress(value=0.0, description='Dataset', max=3.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Models', max=26.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Models', max=26.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Models', max=26.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='Dataset samples', max=10.0, style=ProgressStyle(descripti…




In [6]:
# Build one record per (dataset, model) pair and reshape it into a table with
# one row per model and one 'accuracy' column per dataset.
#
# Mean and std are rendered with a fixed two decimals. Plain `{value}`
# formatting of the rounded floats drops trailing zeros ('66.0 +/- 2.21',
# '59.82 +/- 3.3'), which produces ragged cells in the exported table.
df = pd.DataFrame.from_records(
    [
        (dataset, model, f'{stats["mean"]:.2f} +/- {stats["std"]:.2f}')
        for (dataset, model), stats in res.items()
    ],
    columns=['dataset', 'model', 'accuracy'],
).pivot(index='model', columns='dataset')

# Trailing expression so the notebook renders the table.
df

Unnamed: 0_level_0,accuracy,accuracy,accuracy
dataset,citeseer,cora,pubmed
model,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2
AttrE2vec_Avg,59.82 +/- 3.3,65.42 +/- 1.71,48.86 +/- 2.46
AttrE2vec_ConcatGRU,60.71 +/- 2.75,66.0 +/- 2.21,50.27 +/- 3.75
AttrE2vec_Exp,59.07 +/- 4.65,66.36 +/- 3.62,48.02 +/- 2.55
AttrE2vec_GRU,60.16 +/- 2.25,66.15 +/- 3.71,49.41 +/- 1.49
BL_dw/nf/full,28.89 +/- 1.06,21.93 +/- 0.86,27.24 +/- 0.5
BL_dw/nfef/full,54.13 +/- 2.73,54.7 +/- 5.85,46.33 +/- 1.53
BL_graphsage/nf/full,18.79 +/- 0.62,17.7 +/- 1.05,27.04 +/- 0.71
BL_graphsage/nfef/full,54.06 +/- 2.54,54.82 +/- 6.86,46.49 +/- 1.64
BL_line2vec,54.73 +/- 2.56,63.5 +/- 1.92,55.26 +/- 1.36
BL_n2v/nf/full,26.82 +/- 0.67,21.32 +/- 0.62,27.17 +/- 0.74


In [7]:
# Save the accuracy table as LaTeX source under data/paper.
with open('../../data/paper/table-edge-clustering-acc.tex', mode='w') as tex_file:
    # end='' keeps the file content exactly equal to the rendered string.
    print(df.to_latex(), end='', file=tex_file)