In [None]:
import os
import pickle
import yaml

from IPython.display import display, Markdown
import numpy as np
import pandas as pd

In [None]:
def load_metrics(path, baseline=True):
    """Read one pickled metrics object from ``path`` and return it.

    Args:
        path: Location of the pickle file; opened in binary mode.
        baseline: Accepted for call compatibility; the function does not use it.

    Returns:
        Whatever object was stored in the pickle.

    Note:
        ``pickle.load`` can execute arbitrary code embedded in the file, so
        only point this at metrics files from a trusted source.
    """
    with open(path, 'rb') as stream:
        metrics = pickle.load(stream)
    return metrics


def load_all_metrics(models):
    """Load every metrics file referenced by ``models``, grouped by dataset.

    Args:
        models: Configuration dict. It is read for ``datasets`` (dataset
            names) and for the sections ``bl``, ``line2vec``, ``mlp2``,
            ``mlp3`` and ``ae``. Each section has a ``base_path`` template;
            all but ``line2vec`` also list their ``models``.

    Returns:
        ``{dataset: {key: metrics}}`` where ``metrics`` is whatever
        ``load_metrics`` returns and ``key`` is one of
        ``baseline/<variant>/<name>``, ``baseline/line2vec``,
        ``mlp2/<name>``, ``mlp3/<name>`` or ``ae/<name>``.
    """
    def _resolve(section, substitutions):
        # Fill in the section's path template one placeholder at a time,
        # in the order given.
        resolved = models[section]['base_path']
        for placeholder, value in substitutions:
            resolved = resolved.replace(placeholder, value)
        return resolved

    all_metrics = {ds: {} for ds in models['datasets']}

    # Stage 1: classic baselines (one file per model variant) and line2vec.
    for ds in models['datasets']:
        collected = all_metrics[ds]

        for bl_cfg in models['bl']['models']:
            for variant in bl_cfg['variants']:
                bl_name = bl_cfg['name']
                bl_path = _resolve('bl', (
                    ('${DATASET}', ds),
                    ('${NAME}', bl_name),
                    ('${VARIANT}', variant),
                ))
                collected[f'baseline/{variant}/{bl_name}'] = load_metrics(bl_path)

        line2vec_path = _resolve('line2vec', (('${DATASET}', ds),))
        collected['baseline/line2vec'] = load_metrics(line2vec_path)

    # Stage 2: MLP baselines, first every MLP2 model, then every MLP3 model.
    for ds in models['datasets']:
        collected = all_metrics[ds]
        for mlptype in ('mlp2', 'mlp3'):
            for name in models[mlptype]['models']:
                mlp_path = _resolve(mlptype, (
                    ('${DATASET}', ds),
                    ('${NAME}', name),
                ))
                collected[f'{mlptype}/{name}'] = load_metrics(mlp_path)

    # Stage 3: AttrE2vec models.
    for ds in models['datasets']:
        collected = all_metrics[ds]
        for ae_name in models['ae']['models']:
            ae_path = _resolve('ae', (
                ('${DATASET}', ds),
                ('${NAME}', ae_name),
            ))
            collected[f'ae/{ae_name}'] = load_metrics(ae_path)

    return all_metrics


In [None]:
def summarize_supports(all_metrics):
    """Display, per dataset, the class supports of every split as 'mean +/- std'.

    Supports are read from the first method stored for each dataset; the
    other methods are not inspected. Entries named ``accuracy``, ``auc`` or
    ``cm``, and any entry whose name contains ``avg``, are skipped so that
    only per-class rows remain.

    Args:
        all_metrics: ``{dataset: {method: [sample, ...]}}`` where each sample
            maps a split name to ``{class: {'support': ...}, ...}``.

    The displayed table has one row per split (``train``, ``val``, ``test``,
    in that order) and one column per class.
    """
    display(Markdown('# Supports'))

    for ds_name, ds_metrics in all_metrics.items():
        # Only the first method is used as the reference for supports.
        reference_samples = next(iter(ds_metrics.values()), [])

        supports = [
            (tt, c, v['support'])
            for sample in reference_samples
            for tt, cm in sample.items()
            for c, v in cm.items()
            if not (c in ('accuracy', 'auc', 'cm') or 'avg' in c)
        ]

        if not supports:
            # Nothing to aggregate: say so instead of failing on empty input.
            display(Markdown(f'## {ds_name}'))
            display(Markdown('_No supports available._'))
            continue

        df = pd.DataFrame.from_records(
            supports, columns=['train/val/test', 'class', 'support'],
        )
        stats = (
            df.groupby(['train/val/test', 'class'])['support']
            .agg(['mean', 'std'])
        )
        # pandas' std is NaN when a group holds a single sample; report it as
        # zero spread rather than crashing in int(NaN).
        stats['std'] = stats['std'].fillna(0.0)

        summary = pd.Series(
            [
                f'{int(mean)} +/- {int(std)}'
                for mean, std in zip(stats['mean'], stats['std'])
            ],
            index=stats.index,
        )

        # Classes become columns; splits are forced into a fixed row order.
        table = summary.unstack('class').reindex(index=['train', 'val', 'test'])

        display(Markdown(f'## {ds_name}'))
        display(table)

In [None]:
def highlight_max(s):
    '''
    Highlight the maximum in a Series yellow.

    Returns one CSS string per element of ``s``: ``'background-color: yellow'``
    for every element equal to the maximum and ``''`` for the others.

    Numeric Series are compared as they are. For any other Series, string
    cells of the form ``'<mean> +/- <std>'`` are ranked by their numeric
    ``<mean>``; comparing the raw strings would be lexicographic and would
    rank ``'9.5 +/- 1.0'`` above ``'85.3 +/- 1.2'``. Cells that cannot be read
    as a number are never highlighted, unless no cell is numeric, in which
    case the plain ``s == s.max()`` comparison is used.
    '''
    def _leading_number(value):
        # '85.3 +/- 1.2' -> 85.3; anything unparsable -> NaN.
        if isinstance(value, str):
            value = value.split('+/-')[0]
        try:
            return float(value)
        except (TypeError, ValueError):
            return float('nan')

    if pd.api.types.is_numeric_dtype(s):
        is_max = s == s.max()
    else:
        numeric = s.map(_leading_number)
        if numeric.notna().any():
            # NaN never compares equal, so unparsable cells stay unstyled.
            is_max = numeric == numeric.max()
        else:
            is_max = s == s.max()

    return ['background-color: yellow' if v else '' for v in is_max]

In [None]:
def print_auc_table(all_metrics):
    """Display per-dataset AUC tables for the 'baseline' and 'ae' model groups.

    Args:
        all_metrics: ``{dataset: {model: [sample, ...]}}`` where every sample
            provides ``sample[split]['auc']`` for the splits ``train``,
            ``val`` and ``test``.

    For each dataset a heading is shown, followed by one styled table per
    group (models whose name starts with ``baseline``, then those starting
    with ``ae``). Each cell holds ``'<mean> +/- <std>'`` of the AUC over all
    samples, in percent and rounded to two decimals (``np.mean`` / ``np.std``).
    """
    display(Markdown('# AUC'))

    splits = ('train', 'val', 'test')

    for ds_name, ds_mtrs in all_metrics.items():
        display(Markdown(f'## {ds_name}'))

        rows = []
        for model_name, samples in ds_mtrs.items():
            for split in splits:
                scores = [sample[split]['auc'] for sample in samples]

                avg = np.round(np.mean(scores) * 100.0, 2)
                spread = np.round(np.std(scores) * 100.0, 2)

                rows.append((model_name, split, f'{avg} +/- {spread}'))

        table = pd.DataFrame.from_records(rows, columns=['model', 'tt', 'value'])
        # One row per model, one column per split, in train/val/test order.
        table = (
            table
            .pivot(index='model', columns='tt')
            .reindex(columns=['train', 'val', 'test'], level='tt')
        )

        for prefix in ('baseline', 'ae'):
            in_group = table.index.str.startswith(prefix)
            display(table.loc[in_group].style.apply(highlight_max))

In [None]:
from collections import defaultdict


def print_metric_table(all_metrics, metric_name):
    """Display per-dataset, per-class tables of one classification metric.

    Args:
        all_metrics: ``{dataset: {model: [sample, ...]}}`` where
            ``sample[split][entry][metric_name]`` is available for the splits
            ``train``, ``val`` and ``test`` and for every entry other than
            ``accuracy``, ``auc`` and ``cm``.
        metric_name: Key of the metric to report, e.g. ``'precision'``.

    A top-level heading with the metric name is shown first. Each dataset then
    gets a second-level heading followed by one styled table per model group
    (names starting with ``baseline``, then ``ae``). Cells hold
    ``'<mean> +/- <std>'`` over all samples, in percent, rounded to two
    decimals.
    """
    display(Markdown(f'# {metric_name}'))

    for ds_name, ds_mtrs in all_metrics.items():
        # Dataset headings sit one level below the metric title, matching the
        # '## <dataset>' headings used by the other reports in this notebook.
        display(Markdown(f'## {ds_name}'))
        vals = []

        for model_name, model_mtrs in ds_mtrs.items():
            for tt in ('train', 'val', 'test'):
                # entry name -> metric values collected over all samples
                mtrs = defaultdict(list)
                for sample_mtrs in model_mtrs:
                    for c in sample_mtrs[tt]:
                        # These entries are not indexed by metric name here.
                        if c in ('accuracy', 'auc', 'cm'):
                            continue

                        mtrs[c].append(sample_mtrs[tt][c][metric_name])

                for c, mtrs_vals in mtrs.items():
                    mean = np.round(np.mean(mtrs_vals) * 100.0, 2)
                    std = np.round(np.std(mtrs_vals) * 100.0, 2)

                    vals.append((model_name, tt, c, f'{mean} +/- {std}'))

        df = pd.DataFrame.from_records(vals, columns=['model', 'tt', 'class', 'value'])
        # Cells are strings, so 'first' simply carries each value through.
        df = pd.pivot_table(df, index='model', columns=['tt', 'class'], values='value', aggfunc='first')
        df = df.reindex(columns=['train', 'val', 'test'], level='tt')

        for group in ('baseline', 'ae'):
            _df = df.loc[df.index.str.startswith(group)]
            _df = _df.style.apply(highlight_max)
            display(_df)

In [None]:
def print_article_table(all_metrics, metric_names=('auc',), latex_path=None):
    """Display a cross-dataset summary of test-split metrics for every model.

    Args:
        all_metrics: ``{dataset: {model: [sample, ...]}}`` where each sample
            provides ``sample['test'][name]`` for the requested metrics
            (``sample['test']['macro avg']['f1-score']`` for ``'f1'``).
        metric_names: Metrics to include as columns. ``'f1'`` is read from the
            macro-averaged F1 score; any other name is looked up directly in
            the test metrics (e.g. ``'auc'``, ``'accuracy'``). Defaults to
            AUC only.
        latex_path: If given, the unstyled table is also written to this file
            with ``DataFrame.to_latex``. Nothing is written by default.

    The table has one row per model and one column per (dataset, metric).
    Cells hold ``'<mean> +/- <std>'`` over all samples, in percent, rounded to
    two decimals; styling is delegated to ``highlight_max``.
    """
    def _agg_test_metric(sm, name):
        if name == 'f1':
            vals = [m['test']['macro avg']['f1-score'] for m in sm]
        else:
            vals = [m['test'][name] for m in sm]

        mean = np.round(np.mean(vals) * 100.0, 2)
        std = np.round(np.std(vals) * 100.0, 2)

        return f'{mean} +/- {std}'

    records = [
        (ds_name, model_name, metric_name, _agg_test_metric(model_mtrs, metric_name))
        for ds_name, ds_mtrs in all_metrics.items()
        for model_name, model_mtrs in ds_mtrs.items()
        for metric_name in metric_names
    ]

    df = pd.DataFrame.from_records(records, columns=['dataset', 'model', 'metric', 'value'])
    # Cells are strings, so 'first' simply carries each value through.
    df = pd.pivot_table(df, index='model', columns=['dataset', 'metric'], values='value', aggfunc='first')

    if latex_path is not None:
        with open(latex_path, 'w') as fout:
            fout.write(df.to_latex())

    styled = df.style.apply(highlight_max)
    display(Markdown('# Test metrics summary'))
    display(styled)
    
    

In [None]:
# Which metrics files to load. `${DATASET}`, `${NAME}` and `${VARIANT}` in the
# path templates are placeholders that `load_all_metrics` substitutes.
models = dict(
    datasets=['cora', 'citeseer', 'pubmed'],
    # Classic baselines: a 'simple' model plus every embedding method in its
    # 'nf' and 'nfef' flavour, each with a single 'full' variant.
    bl=dict(
        base_path='../../data/metrics/bl/${DATASET}/${NAME}/${VARIANT}.pkl',
        models=[{'name': 'simple', 'variants': ['full']}] + [
            {'name': f'{method}/{features}', 'variants': ['full']}
            for method in ('dw', 'n2v', 'sdne', 'struc2vec', 'graphsage')
            for features in ('nf', 'nfef')
        ],
    ),
    mlp2=dict(
        base_path='../../data/metrics/mlp/${DATASET}/${NAME}/MLP2.pkl',
        models=['dw', 'graphsage', 'n2v', 'sdne', 'struc2vec'],
    ),
    mlp3=dict(
        base_path='../../data/metrics/mlp/${DATASET}/${NAME}/MLP3.pkl',
        models=['dw', 'graphsage', 'n2v', 'sdne', 'struc2vec'],
    ),
    line2vec=dict(
        base_path='../../data/metrics/bl/${DATASET}/line2vec.pkl',
    ),
    ae=dict(
        base_path='../../data/metrics/ae/${DATASET}/${NAME}.pkl',
        models=[
            'AttrE2vec_Avg',
            'AttrE2vec_Exp',
            'AttrE2vec_ConcatGRU',
            'AttrE2vec_GRU',
        ],
    ),
)

In [None]:
# Load every metrics file described by the `models` configuration.
all_metrics = load_all_metrics(models)

# Cross-dataset summary of test-split metrics, one row per model.
print_article_table(all_metrics)

# Per-dataset details: class supports, AUC, then one table set per
# classification metric.
summarize_supports(all_metrics)
print_auc_table(all_metrics)
for metric in ('precision', 'recall', 'f1-score'):
    print_metric_table(all_metrics, metric)