In [168]:
import re

In [169]:
import pandas as pd
import wandb
api = wandb.Api()

# Project is specified by <entity/project-name>
runs = api.runs("joelavond/FedDecay")

summary_list, config_list, name_list = [], [], []
for run in runs:
    # .summary contains the output keys/values for metrics like accuracy.
    #  We call ._json_dict to omit large files
    summary_list.append(run.summary._json_dict)

    # .config contains the hyperparameters.
    #  We remove special values that start with _.
    config_list.append(
        {k: v for k,v in run.config.items()
          if not k.startswith('_')})

    # .name is the human-readable name of the run.
    name_list.append(run.name)

runs_df = pd.DataFrame({
    "summary": summary_list,
    "config": config_list,
    "name": name_list
    })

print(runs_df.shape)
runs_df.head()

(68, 3)


Unnamed: 0,summary,config,name
0,"{'best_Client #38/val_correct': 173, 'best_Cli...","{'lr': 0.05, 'hpo': {'pbt': {'max_stage': 5, '...",sst2--fedavg--n_epochs6--batch_size32--lr0.05-...
1,"{'best_Unseen Client #4/test_correct': 3, 'bes...","{'lr': 0.1, 'hpo': {'pbt': {'max_stage': 5, 'p...",sst2--fedavg--n_epochs3--batch_size32--lr0.1--...
2,{'best_Unseen Client #28/test_avg_loss': 0.203...,"{'lr': 0.01, 'hpo': {'pbt': {'max_stage': 5, '...",sst2--fedavg--n_epochs6--batch_size16--lr0.01-...
3,"{'best_Client #2/val_correct': 2, 'Results_fai...","{'lr': 0.1, 'hpo': {'pbt': {'max_stage': 5, 'p...",sst2--fedavg--n_epochs3--batch_size16--lr0.1--...
4,"{'best_Client #44/test_total': 45, 'best_Clien...","{'lr': 0.05, 'hpo': {'pbt': {'max_stage': 5, '...",sst2--fedavg--n_epochs6--batch_size64--lr0.05-...


In [192]:
## Recover all metrics of interest from runs_df
# get columns from summary dict
metrics = runs_df['summary'].apply(pd.Series)

# all metrics start with Results_
# only keep test metrics other than validation columns
validation_metric = 'Results_avg/val_acc'
metrics = metrics[[name for name in metrics.columns if re.search('^Results', name)]]
metrics = metrics[[
    name for name in metrics.columns
    if re.search('/test', name) or name == validation_metric
]]

# remove excess information from column names
metrics.columns = [re.sub('_avg', '', re.sub('_fairness', '', name)) for name in metrics.columns]
metrics = metrics[sorted(metrics.columns)]
validation_metric = re.sub('_avg', '', re.sub('_fairness', '', validation_metric))

# subset to metrics of interest
metrics_of_interest = ['acc', 'avg_loss', 'f1']
metrics = metrics[[
    name for name
    in metrics.columns
    if (
        any([bool(re.search(metric, name)) for metric in metrics_of_interest])
        and not(bool(re.search('top', name)))  # not interested in metrics containing "top"
    )
]]

# reorder metrics based on name

## Extract run hyperparameters from name
# convert to numeric
hyperparameters = runs_df.name.str.split('--').apply(
    lambda x: {
        re.sub('[0-9.]*', '', obj):re.sub('[a-z_]*', '', obj)
        for i, obj in enumerate(x)
        if re.search('[0-9]', obj)
           and i > 0
    }
).apply(pd.Series).apply(pd.to_numeric, errors='coerce')

# identify method used
methods = ['exact', 'fedavg', 'pfedme', 'fedbn', 'ditto', 'fedem']
method = runs_df.name.apply(
    lambda x: [method for method in methods if re.search(method, x)].pop(0)
)
dataset = runs_df.name.apply(lambda x: re.sub('-.*', '', x))
finetune = runs_df.name.apply(lambda x: bool(re.search('finetune', x)))

# combine with metrics
df = pd.concat({'dataset':dataset, 'method':method, 'finetune':finetune}, axis=1)
df = df.join(hyperparameters)
df = df.join(metrics)

## Copy fedsgd runs to exact decay with beta as 0
fedsgd_runs = df.loc[(df.method == 'fedavg')].copy()
fedsgd_runs.method = 'exact'
fedsgd_runs.beta = 1.0
fedsgd_runs.loc[(df.n_epochs == 1), 'beta'] = 0.0
fedsgd_runs.head()

# combine with previous data
df = pd.concat([df, fedsgd_runs])
df.head()



Unnamed: 0,dataset,method,finetune,n_epochs,batch_size,lr,beta,lr_,regular_weight_,regular_weight,...,Results_unseen/test_acc,Results_unseen/test_acc_bottom_decile,Results_unseen/test_acc_std,Results_unseen/test_f1,Results_unseen/test_f1_bottom_decile,Results_unseen/test_f1_std,Results_weighted/test_acc,Results_weighted/test_f1,Results_weighted_unseen/test_acc,Results_weighted_unseen/test_f1
0,sst2,exact,False,6,32,0.05,0.8,,,,...,0.795853,0.689655,0.095085,0.645415,0.408163,0.175969,0.760671,0.528876,0.787037,0.590872
1,sst2,exact,False,3,32,0.1,0.8,,,,...,0.725388,0.5,0.172125,0.550466,0.333333,0.193473,0.698171,0.494359,0.708333,0.548012
2,sst2,exact,False,6,16,0.01,0.8,,,,...,0.835741,0.666667,0.11607,0.699438,0.4,0.217982,0.769817,0.530319,0.796296,0.617272
3,sst2,exact,False,3,16,0.1,0.8,,,,...,0.794374,0.689655,0.088426,0.649479,0.408163,0.181342,0.775915,0.547154,0.768519,0.5743
4,sst2,exact,False,6,64,0.05,0.8,,,,...,0.782765,0.666667,0.094767,0.639589,0.414141,0.185805,0.77439,0.538243,0.768519,0.573656


In [198]:
## Get best run for each group
filtered_df = df.loc[(df.n_epochs < 6)]
idx = filtered_df.groupby(['method', 'finetune'])[validation_metric].idxmax()
filtered_df = filtered_df.loc[idx]
print(len(df))
print(len(filtered_df))

descending_metrics = [name for name in filtered_df.columns if re.match('Results', name)]
ascending_metrics = [
    descending_metrics.pop(descending_metrics.index(name))
    for name in descending_metrics
    if re.search('std', name)
]
filtered_runs = filtered_df[[
    name for name in filtered_df.columns
    if name not in descending_metrics + ascending_metrics
]]

ranked_descending = filtered_df[descending_metrics].rank(
    method='first',
    ascending=False
)
ranked_ascending = filtered_df[ascending_metrics].rank(
    method='first',
    ascending=True
)

ranked_metrics = pd.concat([ranked_descending, ranked_ascending], axis=1)
ranked_metrics = ranked_metrics[sorted(ranked_metrics.columns)]
filtered_runs.join(ranked_metrics)


80
15


Unnamed: 0,dataset,method,finetune,n_epochs,batch_size,lr,beta,lr_,regular_weight_,regular_weight,...,Results_unseen/test_acc,Results_unseen/test_acc_bottom_decile,Results_unseen/test_acc_std,Results_unseen/test_f1,Results_unseen/test_f1_bottom_decile,Results_unseen/test_f1_std,Results_weighted/test_acc,Results_weighted/test_f1,Results_weighted_unseen/test_acc,Results_weighted_unseen/test_f1
3,sst2,exact,False,3,16,0.1,0.8,,,,...,2.0,3.0,2.0,1.0,5.0,3.0,2.0,2.0,1.0,6.0
12,sst2,fedavg,False,3,16,,,0.05,,,...,10.0,11.0,10.0,10.0,6.0,9.0,9.0,3.0,7.0,8.0
12,sst2,fedavg,False,3,16,,,0.05,,,...,11.0,12.0,11.0,11.0,7.0,10.0,10.0,4.0,8.0,9.0
12,sst2,exact,False,3,16,,1.0,0.05,,,...,10.0,11.0,10.0,10.0,6.0,9.0,9.0,3.0,7.0,8.0
12,sst2,exact,False,3,16,,1.0,0.05,,,...,11.0,12.0,11.0,11.0,7.0,10.0,10.0,4.0,8.0,9.0
16,sst2,fedem,False,3,32,,,0.1,,,...,3.0,2.0,3.0,3.0,4.0,8.0,8.0,11.0,3.0,7.0
20,sst2,fedbn,False,3,64,,,0.5,,,...,13.0,13.0,12.0,13.0,13.0,12.0,13.0,14.0,13.0,13.0
28,sst2,ditto,False,3,32,0.05,,,0.05,,...,14.0,14.0,14.0,14.0,14.0,14.0,15.0,15.0,14.0,14.0
34,sst2,pfedme,False,3,32,0.05,,,,0.9,...,4.0,1.0,1.0,12.0,3.0,1.0,12.0,5.0,2.0,10.0
42,sst2,fedem,True,3,16,,,,,,...,12.0,10.0,13.0,9.0,12.0,13.0,3.0,12.0,4.0,1.0
