In [None]:
from collections import OrderedDict as OD
import logging
from matplotlib import pyplot as plt
import numpy as np
from pathlib import Path
import sys

import jsonpickle
import jsonpickle.ext.numpy as jsonpickle_numpy

In [None]:
# Select the data: folder holding one sub-folder per run, and the run ids to load.
main_folder = '/Users/mixheikk/Documents/git/DP-PVI/pytorch-code-results/mimic_bal_trade_off_plotting_10clients_5seeds_runs/'
runs_to_plot = np.arange(1, 19, dtype=int)  # run ids 1..18 (inclusive)

In [None]:
# Parameter grids used in the runs.
# Row 0 of all_eps_sigma is read as the eps values and row 1 as the matching dp_sigma
# values by the formatting/plotting cells of this notebook.
all_eps_sigma = np.array([
    [np.inf, np.inf, 4.],
    [0., 0., 34.1849],
])
all_q = np.array([5e-3, 1e-2, 5e-2, .1, .5, 1.])
all_steps = np.array([50])
all_C = np.array([1., 1000.])

nondp_C = 100.  # C at least this big with dp_sigma=0 considered to be nonDP

# Filters on the run configs: None means "no restriction on this key", otherwise the
# config value has to be one of the listed values for the run to be plotted.
restrictions = OD([
    ('dp_sigma', [0.]),  # alternative: 34.1849
    ('dp_C', [1000.]),
    ('n_global_updates', [5]),
    ('n_steps', None),
    ('batch_size', None),
    ('sampling_frac_q', [1., 1e-1, 1e-2, 5e-3]),
    ('pseudo_client_q', None),
    ('learning_rate', None),
    ('damping_factor', None),
    ('init_var', None),
    ('dp_mode', None),
    ('pre_clip_sigma', None),
    # possible balance settings: (0,0), (.7,-3), (.75,.95)
    ('data_bal_rho', [.0]),
    ('data_bal_kappa', [.0]),
])

dataset_name = 'mimic3_bal'

In [None]:
# set colors: `colors` is the list of line colors the plotting cells index into
# (always modulo len(colors), so more lines than colors simply wrap around)

# default color cycle
# NOTE(review): pylab is not referenced anywhere in the visible notebook - confirm before removing
import pylab

colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] # note: standard color cycler has 10 colors
# alternative: sample NUM_COLORS colors from a colormap (NUM_COLORS is not defined in this notebook)
#cm = plt.get_cmap('viridis')
#colors = (cm(1.*i/NUM_COLORS) for i in range(NUM_COLORS))

#print(colors, len(colors))
#sys.exit()

In [None]:
# read in data and do initial formatting

# arrays prepared for the plotting cells
to_plot = OD()

# One initially empty ordered dict per result category; each is filled per run id
# by the loading loops.
all_res = OD(
    (category, OD())
    for category in ('config', 'client_train_res', 'train_res', 'validation_res')
)
all_baselines = OD(
    (category, OD())
    for category in ('config', 'client_train_res', 'train_res', 'validation_res')
)

# optional baselines: parallel lists of folders, display names and run ids per folder
baseline_folders, baseline_runs_to_plot, baseline_names = [], [], []

jsonpickle_numpy.register_handlers()
failed_runs = []

# Load config and results of every requested run; runs without a readable config are
# recorded in failed_runs (by read_config) and skipped.
for i_run in runs_to_plot:
    run_id = str(i_run)
    print(f'run {run_id}')

    filename = f'{main_folder}{run_id}/config.json'
    tmp = read_config(filename, failed_runs)
    if i_run not in failed_runs:
        all_res['config'][run_id] = tmp[0]

        # try opening sacred records, if missing open manual bck instead
        filename = f'{main_folder}{run_id}/info.json'
        filename_bck = f'{main_folder}{run_id}/info_bck.json'
        apu = read_results(filename, filename_bck)

        # format results for plotting
        client_measures = ['elbo', 'kl', 'logl']
        format_results(apu, run_id, client_measures, all_res)

# drop the failed runs so that later cells only see runs that were actually loaded
if failed_runs:
    print(f'failed runs:\n{failed_runs}')
    runs_to_plot = list(runs_to_plot)
    for i_run in failed_runs:
        runs_to_plot.remove(i_run)
    runs_to_plot = np.array(runs_to_plot)

# read baselines
# Baseline runs are stored in all_baselines under their own consecutive ids ('0', '1', ...),
# counted over all baseline folders; a baseline whose config file is missing is skipped.
if len(baseline_folders) > 0:
    running_id = 0
    # kept separate from failed_runs so that baseline failures don't mix with the main runs
    failed_baseline_runs = []
    for folder, baseline_name, baseline_ids in zip(baseline_folders, baseline_names, baseline_runs_to_plot):
        for i_run in baseline_ids:
            run_id = str(running_id)
            print(f'baseline run {run_id}')
            filename = folder + str(i_run) + '/config.json'
            tmp = read_config(filename, failed_baseline_runs)
            if tmp is None:
                # read_config returns None when the config file can't be opened;
                # indexing it would raise a TypeError
                print(f'skipping baseline {baseline_name}, run {i_run}: no config found')
                continue
            all_baselines['config'][run_id] = tmp[0]

            # try opening sacred records, if missing open manual bck instead
            filename = folder + str(i_run) + '/info.json'
            filename_bck = folder + str(i_run) + '/info_bck.json'
            apu = read_results(filename, filename_bck)

            # format results for plotting
            client_measures = ['elbo','kl','logl']
            format_results(apu, run_id, client_measures, all_baselines)

            running_id += 1

# check restrictions
# A run is selected when, for every restricted key (value not None), its config value is
# one of the allowed values. Keys that a run's config doesn't contain are not checked.
list_to_print = []
for i_run in runs_to_plot:
    run_config = all_res['config'][str(i_run)]
    print_this = True
    for k, allowed in restrictions.items():
        if allowed is None:
            continue
        try:
            value = run_config[k]
        except KeyError:
            # config doesn't record this key: treat the restriction as satisfied
            # (only this case is tolerated; any other error is a real problem and should surface)
            continue
        if value not in allowed:
            print_this = False
            break
    if print_this:
        list_to_print.append(i_run)
if len(list_to_print) == 0:
    sys.exit('No runs satisfying restrictions found!')
else:
    print(f'Found {len(list_to_print)} runs to plot')


In [None]:
# Choose the plot type and allocate the arrays that the formatting cell fills in.
# In all arrays axis 0 holds (mean, std) over seeds.
#plot_type = 'q_trade_off_rebuttal'
plot_type = 'q_trade_off_rebuttal2'

if plot_type == 'q_trade_off_rebuttal':
    # plot mean (over seeds) acc/logl against q values, with fixed eps, C; use max performance on any global
    # axis 1: eps setting, axis 2: q value
    best_shape = (2, len(all_eps_sigma[0]), len(all_q))
    to_plot['best_mean_acc'] = np.zeros(best_shape)
    to_plot['best_mean_logl'] = np.zeros(best_shape)
    to_plot['best_mean_avg_prec_score'] = np.zeros(best_shape)
    to_plot['mean_ROC_at_best_mean_logl'] = OD()
    to_plot['dp_C'] = np.zeros(len(all_eps_sigma[0]))

elif plot_type == 'q_trade_off_rebuttal2':
    # plot q as different lines; acc/logl against global update to show convergence speed
    # these are now not best means, but just means over seeds
    # take number of globals from the first selected run (not from a hard-coded run id, which
    # may be missing or may not satisfy the restrictions); should be same for all to make any sense
    n_global_updates = all_res['config'][str(list_to_print[0])]['n_global_updates']
    # axis 1: q value (in the order of restrictions['sampling_frac_q']), axis 2: global update
    mean_shape = (2, len(restrictions['sampling_frac_q']), n_global_updates)
    to_plot['mean_acc'] = np.zeros(mean_shape)
    to_plot['mean_logl'] = np.zeros(mean_shape)
    to_plot['mean_avg_prec_score'] = np.zeros(mean_shape)
    to_plot['mean_ROC_at_best_mean_logl'] = OD()
    to_plot['dp_C'] = np.zeros(len(restrictions['sampling_frac_q']))

else:
    sys.exit(f'Unknown plot type: {plot_type}')


In [None]:
# format data for plotting
# For every selected run, aggregate the validation measures over seeds (last axis) and store
# the mean in to_plot[...][0] and the standard deviation in to_plot[...][1].

for i_line,i_run in enumerate(list_to_print):

    config = all_res['config'][str(i_run)]
    res = all_res['validation_res'][str(i_run)]

    tmp = ['acc', 'logl', 'avg_prec_score']
    for i_tmp,tmp_name in enumerate(tmp):
        # take argmax logl as the best model: one index per seed, axis 0 being the global update
        i_max = np.argmax(res['logl'],0)

        if plot_type in ['eps_trade_off']:
            # not usable at the moment: the earlier implementation picked the global update
            # by argmax of the seed-mean of each measure, which still has to be fixed
            raise NotImplementedError('fix argmax')

        elif plot_type == 'q_trade_off':
            # same open issue as above (results were indexed by n_steps and q)
            raise NotImplementedError('fix argmax')

        elif plot_type == 'q_trade_off_with_C':
            # NOTE: check that works with several seeds if used
            c_mask = all_C == config['dp_C']
            q_mask = all_q == config['sampling_frac_q']
            to_plot[f'best_mean_{tmp_name}'][0, c_mask, q_mask] = res[tmp_name][i_max].mean(-1)
            to_plot[f'best_mean_{tmp_name}'][1, c_mask, q_mask] = res[tmp_name][i_max].std(-1)

        elif plot_type in ['q_trade_off_rebuttal']:
            # want separate line for each eps and (possibly) dp_C
            mean_over_seeds = res[tmp_name].mean(-1)
            std_over_seeds = res[tmp_name].std(-1)
            q_mask = all_q == config['sampling_frac_q']

            if config['dp_sigma'] != 0 and config['dp_sigma'] is not None:
                # DP run: the line is the entry of all_eps_sigma whose sigma matches
                i_eps = all_eps_sigma[1] == config['dp_sigma']
            else:
                if i_tmp == 0:
                    print(f"nondp run: {i_line}: dp_C={config['dp_C']}, dp_sigma={config['dp_sigma']}, sampling q={config['sampling_frac_q']}")

                if config['dp_C'] < nondp_C:
                    # only clipping
                    i_eps = 1
                else:
                    # nonDP
                    i_eps = 0

            # best global update = the one with the highest mean (over seeds) of this measure
            i_max = np.argmax(mean_over_seeds)
            to_plot[f'best_mean_{tmp_name}'][0, i_eps, q_mask] = mean_over_seeds[i_max]
            to_plot[f'best_mean_{tmp_name}'][1, i_eps, q_mask] = std_over_seeds[i_max]
            to_plot['dp_C'][i_eps] = config['dp_C']

        elif plot_type in ['q_trade_off_rebuttal2']:
            # separate line for each q
            # The list of q values has to be an array here: comparing a plain list with a float
            # gives a single False, and indexing with that assigns nothing.
            q_mask = np.asarray(restrictions['sampling_frac_q']) == config['sampling_frac_q']
            to_plot['dp_C'][q_mask] = config['dp_C']

            to_plot[f'mean_{tmp_name}'][0, q_mask, :] = res[tmp_name].mean(-1)
            to_plot[f'mean_{tmp_name}'][1, q_mask, :] = res[tmp_name].std(-1)

In [None]:
# show which runs are included in the plots
print(list_to_print)

In [None]:
#print(to_plot)

In [None]:
# plot acc/logl vs q (number of local splits)
# separate line for each eps or dp_C
# NOTE: reads the 'best_mean_*' arrays, which are only allocated for plot_type 'q_trade_off_rebuttal'

fig, axs = plt.subplots(2,2)
#plt.suptitle(f"Included clipping C: {restrictions['dp_C']}")
for i_line, eps in enumerate(all_eps_sigma[0]):
    C = to_plot['dp_C'][i_line]
    
    axs[0,0].errorbar(np.log10(all_q), to_plot['best_mean_acc'][0,i_line,:], 
                    yerr= 2*to_plot['best_mean_acc'][1,i_line,:]/np.sqrt(config['n_rng_seeds']), # 2*SEM errorbar over seeds
                    label=f'eps={eps}, C={C}', 
                    color=colors[i_line%len(colors)]
                    )
    axs[1,0].errorbar(np.log10(all_q), to_plot['best_mean_logl'][0,i_line,:], 
                    yerr= 2*to_plot['best_mean_logl'][1,i_line,:]/np.sqrt(config['n_rng_seeds']), # 2*SEM errorbar over seeds
                    label=f'eps={eps}, C={C}', 
                    color=colors[i_line%len(colors)]
                    )
    # avg prec score panel is disabled; the whole call has to be commented out, leaving
    # only the continuation lines active is a syntax error
    #axs[0,1].errorbar(np.log10(all_q), to_plot['best_mean_avg_prec_score'][0,i_line,:], 
    #                yerr= 2*to_plot['best_mean_avg_prec_score'][1,i_line,:]/np.sqrt(config['n_rng_seeds']), # 2*SEM errorbar over seeds
    #                label=f'eps={eps}, C={C}', 
    #                color=colors[i_line%len(colors)]
    #                )
    axs[1,1].plot(0,0, label=f"eps={eps}, C={C}") # this is currently just used for labels
axs[1,1].tick_params(axis='both',which='both',bottom=False,left=False,labelbottom=False, labelleft=False)
#
# invert axes when plotting number of local splits
#axs[0,0].set_xlim(100,1)
axs[1,1].legend()
axs[0,0].grid()
axs[1,0].grid()
axs[0,1].grid()
axs[0,0].set_ylabel('Acc')
axs[1,0].set_ylabel('Logl')
axs[0,1].set_ylabel('Avg prec score')
# x values are np.log10(all_q), so the label has to say log10
axs[0,0].set_xlabel('log10 fraction of local samples per split')
axs[1,0].set_xlabel('log10 fraction of local samples per split')
axs[0,1].set_xlabel('log10 fraction of local samples per split')
plt.tight_layout()
#if to_disk:
#    plt.savefig(fig_folder + plot_filename)
#else:
plt.show()

In [None]:
# eps value(s) of the first selected run: entries of row 0 of all_eps_sigma at the positions
# where row 1 equals the run's dp_sigma
all_eps_sigma[0][all_res['config'][str(list_to_print[0])]['dp_sigma'] == all_eps_sigma[1]]

In [None]:
# plot acc/logl vs global update, one line per sampling fraction q
# NOTE: reads the 'mean_*' arrays, which are only allocated for plot_type 'q_trade_off_rebuttal2'

legend_font_size = 14

# number of global updates = length of the last axis of the plotted arrays
# (taken from the data instead of being hard-coded, so it can't get out of sync with the runs)
n_globals = to_plot['mean_acc'].shape[-1]
global_updates = np.arange(1, n_globals + 1)

fig, axs = plt.subplots(2,1, figsize=(8,10))
plt.suptitle("Log. regr. with balanced MIMIC-III, 10 clients, balanced split", fontsize=legend_font_size+2)
for i_q, q in enumerate(restrictions['sampling_frac_q']):
    # round before converting: 1/q can land just below the integer because of float round-off,
    # in which case a plain int() would be off by one
    n_local_splits = int(round(1/q))

    # axs[0]: accuracy, axs[1]: logl
    for ax, measure in zip(axs, ('mean_acc', 'mean_logl')):
        ax.errorbar(global_updates, to_plot[measure][0,i_q,:],
                    yerr= 2*to_plot[measure][1,i_q,:]/np.sqrt(config['n_rng_seeds']), # 2*SEM errorbar over seeds
                    label=f'Number of local splits: {n_local_splits}',
                    color=colors[i_q%len(colors)], markersize=12,
                    )

axs[0].set_ylim((.5,.78))
axs[1].set_ylim((-1.,-.45))
axs[0].grid()
axs[1].grid()
axs[0].set_ylabel('Accuracy', fontsize=legend_font_size+1)
axs[1].set_ylabel('Logl', fontsize=legend_font_size+1)
axs[0].set_xlabel('Communications', fontsize=legend_font_size+1)
axs[1].set_xlabel('Communications', fontsize=legend_font_size+1)
axs[0].legend(loc=4, fontsize=legend_font_size)
axs[1].legend(loc=4, fontsize=legend_font_size)
plt.tight_layout()
#if to_disk:
#plt.savefig('figs/' + 'mimic_bal_trade-off_non-DP.pdf')
#else:
plt.show()

In [None]:
# Accuracy (top panel) and logl (bottom panel) against the number of local splits 1/q,
# one line per eps / dp_C setting.
legend_font_size = 14

fig, axs = plt.subplots(2, figsize=(8,10))
plt.suptitle("Log. regr. with balanced MIMIC-III, 10 clients, balanced split", fontsize=legend_font_size+2)

for i_line, eps in enumerate(all_eps_sigma[0]):
    C = to_plot['dp_C'][i_line]
    for ax, measure in zip(axs, ('best_mean_acc', 'best_mean_logl')):
        ax.errorbar(
            1/all_q,
            to_plot[measure][0,i_line,:],
            # 2*SEM errorbar over seeds
            yerr=2*to_plot[measure][1,i_line,:]/np.sqrt(config['n_rng_seeds']),
            label=f'eps={eps}, C={int(C)}',
            color=colors[i_line%len(colors)],
            markersize=12,
        )

# invert axes when plotting number of local splits
#axs[0].set_xlim(200,1)
for ax, y_label, y_limits in zip(axs, ('Accuracy', 'Logl'), ((.5,.78), (-1.,-.45))):
    ax.legend(loc=4, fontsize=legend_font_size)
    ax.set_ylim(y_limits)
    ax.grid()
    ax.set_ylabel(y_label, fontsize=legend_font_size+1)
    ax.set_xlabel('Number of local splits', fontsize=legend_font_size+1)

plt.tight_layout()
#if to_disk:
#plt.savefig('figs/' + 'mimic_bal_trade-off.pdf')
#else:
plt.show()

In [None]:
# plot legend; NOT FIXED
# The code below was an indented fragment (a syntax error at cell level) that read names never
# defined in this notebook; it is wrapped into a function that takes them as arguments, so the
# cell is valid and nothing is drawn until the function is called.
def plot_legend(plot_group, filename=None, markersize=None, linewidth=None, legend_font_size=14):
    """Draw a figure that only contains a legend, with one entry per item of plot_group.

    Args:
        plot_group: iterable of dicts with keys 'name' (legend label) and 'colour'.
            Entries with 'BCM' in the name get a star marker, entries with 'global' or
            'trusted' in the name a dashed line, all others a solid line.
        filename: where to save the figure; if None the figure is shown instead.
        markersize: marker size for the 'BCM' entries (None = matplotlib default).
        linewidth: line width for the line entries (None = matplotlib default).
        legend_font_size: font size of the legend.
    """
    fig, axs = plt.subplots(1,1, figsize=(8,10))
    for params in plot_group:
        x = 1
        y = 1
        yerr = None
        if 'BCM' in params['name']:
            # NOTE(review): None is passed as the data for this entry, as in the original
            # fragment - confirm that matplotlib accepts this before relying on it
            axs.errorbar(None,None,
                        yerr=yerr,
                        marker='*', markersize=markersize,
                        linewidth=0,
                        linestyle = None,
                        label=params['name'],
                        color=params['colour'])
        else:
            if 'global' in params['name'] or 'trusted' in params['name']:
                ls = '--'
            else:
                ls = '-'
            axs.plot(x, y, label=params['name'],linestyle=ls, linewidth=linewidth, color=params['colour'])
    axs.legend(loc=10, fontsize=legend_font_size, framealpha=1., mode='expand')
    axs.set_axis_off()
    plt.tight_layout()
    if filename is None:
        plt.show()
    else:
        plt.savefig(filename)

In [None]:
# total samples with balanced mimic, 10 clients uniform split: 447
# Sanity check: splitting into 1/q parts of floor(q*447) samples each must not use more
# samples than there are. (config is the config of the last run handled by the formatting cell.)
n_total_samples = 447
n_local_splits = 1/config['sampling_frac_q']
samples_per_split = np.floor(config['sampling_frac_q']*n_total_samples)
print(n_local_splits, samples_per_split * n_local_splits)
# the floor has to be applied to q*447 (as in the printed value); floor(q)*447 is 0 for every
# q < 1, which made the check pass trivially
assert samples_per_split * n_local_splits <= n_total_samples

In [None]:
# q grid next to its reciprocal 1/q (plotted as the number of local splits)
print( all_q, 1/all_q)

In [None]:
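# define helper functions for reading data
# NOTE: the data-loading cell earlier in the notebook calls these, so this cell has to be run before it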

def read_config(filename, failed_runs):
    """Read and decode the jsonpickle-encoded config file of one run.

    Args:
        filename: path of the config file; the name of the folder containing it is used
            as the run id when the file can't be opened.
        failed_runs: list collecting the ids of runs without a readable config file;
            modified in place.

    Returns:
        Tuple (decoded file content, failed_runs) on success, None if the file doesn't exist.
    """
    try:
        with open(filename, 'r') as f:
            apu = f.read()
    except FileNotFoundError:
        print(f"Can't open file {filename}! Skipping")
        # Take the run id from the path rather than from a global loop variable, so the
        # function also works when called outside the loading loop. Numeric folder names
        # are stored as int so that they compare equal to the integer run ids.
        folder_name = Path(filename).parent.name
        failed_runs.append(int(folder_name) if folder_name.isdigit() else folder_name)
        return None
    apu = jsonpickle.unpickler.decode(apu)
    return apu, failed_runs


def read_results(filename, filename_bck):
    """Read the results of one run, falling back to the backup file if the main file is missing.

    Args:
        filename: path of the main results file (decoded with jsonpickle).
        filename_bck: path of the backup file, used only when `filename` doesn't exist.
            Its content is parsed with json first and the result is then decoded with jsonpickle.

    Returns:
        The decoded results.

    Raises:
        FileNotFoundError: if neither file exists.
        SystemExit: if the content of the file that was read can't be decoded.
    """
    try:
        with open(filename, 'r') as f:
            raw = f.read()
    except FileNotFoundError:
        raw = None

    if raw is not None:
        try:
            return jsonpickle.unpickler.decode(raw)
        except Exception:
            print(f'error in JSON decoding in run {filename}')
            print('results from file: {}\n{}'.format(filename, raw))
            # non-zero status: this is a failure, not a normal exit
            sys.exit(1)

    # main file missing: use the manual backup
    import json
    with open(filename_bck, 'r') as f:
        raw = json.loads(f.read())
    try:
        return jsonpickle.decode(raw, keys=True)
    except Exception:
        print(f'error in JSON decoding in {filename_bck}')
        sys.exit(1)

def format_results(apu, run_id, client_measures, all_res):
    """Unpack the per-seed results of one run into numpy arrays stored in all_res.

    Creates all_res['client_train_res'][run_id], all_res['train_res'][run_id] and
    all_res['validation_res'][run_id]. Train/validation arrays have shape
    (n_global_updates, n_rng_seeds), the 'best_<measure>' arrays hold the maximum over
    global updates for each seed. Unless the global dataset_name is 'mnist', the
    positive/negative measures and the TPR/TNR at the global update with the best validation
    logl are filled as well; problems in that part are reported but don't stop the function.

    Args:
        apu: decoded results, indexed as apu[f'client_train_res_seed{i}'][measure],
            apu[f'train_res_seed{i}'][measure] and apu[f'validation_res_seed{i}'][measure];
            the train/validation entries may contain a 'posneg' list with one dict per
            global update.
        run_id: key of the run; all_res['config'][run_id] has to hold the run's config.
        client_measures: names of the client-level training measures to collect.
        all_res: dict with 'config', 'client_train_res', 'train_res' and 'validation_res'
            entries; modified in place.

    Raises:
        KeyError: if a client measure is missing from the results.
    """
    all_res['client_train_res'][run_id] = {}
    config = all_res['config'][run_id]

    for k in client_measures:
        all_res['client_train_res'][run_id][k] = np.zeros((
            config['clients'],
            config['n_global_updates'],
            config['n_steps'],
            config['n_rng_seeds']
            ))
    measures = ['acc','logl']
    posneg_measures = ['avg_prec_score','balanced_acc','f1_score']

    n_globals = config['n_global_updates']
    n_seeds = config['n_rng_seeds']

    train_res = {}
    val_res = {}
    all_res['train_res'][run_id] = train_res
    all_res['validation_res'][run_id] = val_res
    for k in (measures+posneg_measures):
        train_res[k] = np.zeros((n_globals, n_seeds))
        val_res[k] = np.zeros((n_globals, n_seeds))
        train_res['best_'+k] = np.zeros(n_seeds)
        val_res['best_'+k] = np.zeros(n_seeds)

    # does this work with sampling=seq?
    if dataset_name != 'mnist':
        # for plotting ROC curve for max logl global update, one for each seed
        try:
            n_points = apu['validation_res_seed0']['posneg'][0]['n_points']
            val_res['TPR'] = np.zeros((n_seeds, n_points))
            val_res['TNR'] = np.zeros((n_seeds, n_points))
            val_res['ROC_thresholds'] = np.linspace(0,1,n_points)
        except Exception as err:
            # best effort: runs without posneg records are still usable for acc/logl
            print(f'error in posneg results: {err!r}')

    for i_seed in range(n_seeds):
        # logl, elbo, kl
        for k in client_measures:
            try:
                all_res['client_train_res'][run_id][k][:,:,:,i_seed] = apu[f'client_train_res_seed{i_seed}'][k]
            except KeyError:
                # report the id this function was called with; a global loop variable may be
                # undefined or belong to a different run
                print(f'KeyError in run {run_id}')
                print(f'got\n{apu}')
                print("config: batch_size={}, jobid={}".format(config['batch_size'], config['job_id'] ))
                raise

        train_seed = apu[f'train_res_seed{i_seed}']
        val_seed = apu[f'validation_res_seed{i_seed}']
        for k in measures:
            train_res[k][:,i_seed] = train_seed[k]
            val_res[k][:,i_seed] = val_seed[k]

            train_res['best_'+k][i_seed] = np.amax(train_res[k][:,i_seed])
            val_res['best_'+k][i_seed] = np.amax(val_res[k][:,i_seed])

        # calculate true positive and true negative rates at global update with best logl
        best_global_logl = np.argmax( val_res['logl'][:,i_seed] )

        if dataset_name != 'mnist':
            try:
                best_posneg = val_seed['posneg'][best_global_logl]
                val_res['TPR'][i_seed,:] = best_posneg['TP']/( best_posneg['TP'] + best_posneg['FN'])
                val_res['TNR'][i_seed,:] = best_posneg['TN']/( best_posneg['TN'] + best_posneg['FP'])

                # posneg = list of posneg dicts with len=n_global_updates
                # NOTE: need to check format when n_seeds > 1
                for i_global in range(n_globals):
                    for k in posneg_measures:
                        train_res[k][i_global,i_seed] = train_seed['posneg'][i_global][k]
                        val_res[k][i_global,i_seed] = val_seed['posneg'][i_global][k]

                for k in posneg_measures:
                    train_res['best_'+k][i_seed] = np.amax(train_res[k][:,i_seed])
                    val_res['best_'+k][i_seed] = np.amax(val_res[k][:,i_seed])

            except Exception as err:
                # best effort, as above; include the cause so the problem can be found
                print(f'error in AUCROC: {err!r}')