In [1]:
import numpy as np
from texttable import Texttable
import latextable

In [2]:
def print_overall_performance_mean_std(title:str, results:np.array, compare_names_all:list,
                               dataset_names:list, print_latex:bool=True, print_std:bool=True):
    r"""Prints performance table (and possibly with latex) with mean and standard deviations.
        The best two performing methods are highlighted in \red and \blue respectively.

    Args:
        dataset: (string) Name of the data set considered.
        results: (np.array) Results with shape (num_trials, num_methods, num_metrics).
        compare_names_all: (list of strings, optional) Methods names to compare.
        metric_names: (list of strings, optional) Metrics to use (deemed better with larger values).
        print_latex: (bool, optional) Whether to print latex table also. Default True.
        print_std: (bool, optinoal) Whether to print standard deviations or just mean. Default False.
    """
    t = Texttable(max_width=120)
    t.set_deco(Texttable.HEADER)
    final_res_show = np.chararray(
        [len(dataset_names)+1, len(compare_names_all)+1], itemsize=100)
    final_res_show[0, 0] = title+'Data/Method'
    final_res_show[0, 1:] = compare_names_all
    final_res_show[1:, 0] = dataset_names
    std = np.chararray(
        [len(dataset_names), len(compare_names_all)], itemsize=20)
    results_std = np.transpose(np.round(np.nanstd(results,0),3))
    results_mean = np.transpose(np.round(np.nanmean(results,0),3))
    for i in range(results_mean.shape[0]):
        for j in range(results_mean.shape[1]):
            final_res_show[1+i, 1+j] = '{:.3f}'.format(results_mean[i, j])
            std[i, j] = '{:.3f}'.format(1.0*results_std[i, j])
    if print_std:
        plus_minus = np.chararray(
            [len(dataset_names), len(compare_names_all)], itemsize=20)
        plus_minus[:] = '$\pm$'
        final_res_show[1:, 1:] = final_res_show[1:, 1:] + plus_minus + std
    if len(compare_names_all)>1:
        red_start = np.chararray([1], itemsize=20)
        blue_start = np.chararray([1], itemsize=20)
        both_end = np.chararray([1], itemsize=20)
        red_start[:] = '\\red{'
        blue_start[:] = '\\blue{'
        both_end[:] = '}'
        for i in range(results_mean.shape[0]):
            best_values = np.sort(results_mean[i])[:2] # the smaller, the better
            final_res_show[i+1, 1:][results_mean[i]==best_values[0]] = red_start + final_res_show[i+1, 1:][results_mean[i]==best_values[0]] + both_end
            if best_values[0] != best_values[1]:
                final_res_show[i+1, 1:][results_mean[i]==best_values[1]] = blue_start + final_res_show[i+1, 1:][results_mean[i]==best_values[1]] + both_end

    t.add_rows(final_res_show)
    print(t.draw())
    if print_latex:
        print(latextable.draw_latex(t, caption=title +
                                    " performance.", label="table:"+title) + "\n")

In [12]:
baselines = ['spectral', 'row_norm_spectral', 'GPM', 'TranSync', 'CEMP_GCW', 'CEMP_MST', 'TAS']
results_to_print = np.zeros((24, 10, len(baselines)+1))
dataset_name_print = []
dataset_ind = 0
for eta in [0, 0.05, 0.1, 0.15, 0.2, 0.25]:
    for outlier_ind, outlier_style in enumerate(['gamma', 'multi_normal0', 'multi_normal1', 'block_normal6']):
        dataset_name_print.append('{}&{}'.format(eta, outlier_ind+1))
        try:
            results_to_print[dataset_ind, :, 0] = np.load('../result_arrays/uscities/100eta'+str(int(100*eta))+outlier_style+'/upset/GNNSync/dropout50upset_coe100cycle_coe0spectral_step_num5alpha100train_alphaFalsehid8lr5userow_norm_spectralSGDtrials2seeds2_9_11_20_40.npy')[0,:,2]
        except FileNotFoundError:
            results_to_print[dataset_ind, :, 0] = np.nan
        for baseline_ind, baseline in enumerate(baselines):
            try:
                results_to_print[dataset_ind, :, baseline_ind+1] = np.load('../result_arrays/uscities/100eta'+str(int(100*eta))+outlier_style+'/upset/'+baseline+'/trials2seeds2_9_11_20_40.npy')[0, :,2]
            except FileNotFoundError:
                results_to_print[dataset_ind, :, baseline_ind+1] = np.nan
        dataset_ind += 1
print_overall_performance_mean_std('ANE', (results_to_print.swapaxes(0, 1)).swapaxes(1, 2),
                                    ['GNNSync']+baselines, dataset_name_print, True)
# MSE
baselines = ['spectral', 'row_norm_spectral', 'GPM', 'TranSync', 'CEMP_GCW', 'CEMP_MST', 'TAS']
results_to_print = np.zeros((24, 10, len(baselines)+1))
dataset_name_print = []
dataset_ind = 0
for eta in [0, 0.05, 0.1, 0.15, 0.2, 0.25]:
    for outlier_ind, outlier_style in enumerate(['gamma', 'multi_normal0', 'multi_normal1', 'block_normal6']):
        dataset_name_print.append('{}&{}'.format(eta, outlier_ind+1))
        try:
            results_to_print[dataset_ind, :, 0] = np.load('../result_arrays/uscities/100eta'+str(int(100*eta))+outlier_style+'/MSE/GNNSync/dropout50upset_coe100cycle_coe0spectral_step_num5alpha100train_alphaFalsehid8lr5userow_norm_spectralSGDtrials2seeds2_9_11_20_40.npy')[0,:,2, 0]
        except FileNotFoundError:
            results_to_print[dataset_ind, :, 0] = np.nan
        for baseline_ind, baseline in enumerate(baselines):
            try:
                results_to_print[dataset_ind, :, baseline_ind+1] = np.load('../result_arrays/uscities/100eta'+str(int(100*eta))+outlier_style+'/MSE/'+baseline+'/trials2seeds2_9_11_20_40.npy')[0, :,2, 0]
            except FileNotFoundError:
                results_to_print[dataset_ind, :, baseline_ind+1] = np.nan
        dataset_ind += 1
print_overall_performance_mean_std('MSE', (results_to_print.swapaxes(0, 1)).swapaxes(1, 2),
                                    ['GNNSync']+baselines, dataset_name_print, True)

ANEData/Met     GNNSync      spectral     row_norm_sp       GPM        TranSync      CEMP_GCW     CEMP_MST       TAS    
    hod                                     ectral                                                                      
0&1           0.075$\pm$0   \red{0.000$   \red{0.000$   \red{0.000$   \red{0.000$   \red{0.000   \red{0.000   0.740$\pm$
              .028          \pm$0.000}    \pm$0.000}    \pm$0.000}    \pm$0.000}    $\pm$0.000   $\pm$0.000   0.021     
                                                                                    }            }                      
0&2           0.047$\pm$0   \red{0.000$   \red{0.000$   \red{0.000$   \red{0.000$   \red{0.000   \red{0.000   0.425$\pm$
              .010          \pm$0.000}    \pm$0.000}    \pm$0.000}    \pm$0.000}    $\pm$0.000   $\pm$0.000   0.015     
                                                                                    }            }                      
0&3           0.011$\pm$0   \red

