In [None]:
import os
from pprint import pprint

import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import pandas as pd
import wandb

In [None]:

def wandb2pd(exp_runs):
    """Flatten an iterable of W&B runs into one DataFrame with one row per run.

    Parameters
    ----------
    exp_runs : iterable
        Run objects exposing ``name``, ``config`` (a mapping) and
        ``summary._json_dict`` (a mapping). The iterable is consumed once.

    Returns
    -------
    pandas.DataFrame
        Columns are, in order: ``name``, every config key that does not start
        with an underscore, then every summary key. Rows follow the iteration
        order of ``exp_runs``.

    Notes
    -----
    The three parts are joined side by side with ``pd.concat(axis=1)``, so a
    key present in both config and summary yields two columns with the same
    label.
    """
    names, configs, summaries = [], [], []
    for run in exp_runs:
        names.append(run.name)
        # Keys starting with '_' are dropped from the config.
        configs.append({k: v for k, v in run.config.items() if not k.startswith('_')})
        # NOTE(review): `_json_dict` is a private attribute of the summary
        # object; confirm it is still available when upgrading wandb.
        summaries.append(run.summary._json_dict)

    return pd.concat(
        [
            pd.DataFrame({'name': names}),
            pd.DataFrame.from_records(configs),
            pd.DataFrame.from_records(summaries),
        ],
        axis=1,
    )

In [None]:
# Number of equal-width bins used along the x-axis metric.
num_bin = 10

# metrics[0] is the column that gets binned; metrics[1] is averaged per bin.
metrics = [
    'avg_val_acc',
    'avg_test_acc',
]
# Axis labels shown for metrics[0] (x) and metrics[1] (y).
metric_names = [
    'avg_val_acc',
    'avg_test_acc',
]

# Experiment grid iterated by the driver cell.
dataset_list = [
    'ColoredMNIST',
    'PACS',
    'VLCS',
    'OfficeHome',
    'TerraIncognita',
    'DomainNet',
    'RotatedMNIST',
]
algorithm_list = [
    'ERM',
    'IRM',
]

# Legend labels and bar colors; both are matched by position.
optimizer_list = [
    'momentum_sgd',
    'adam',
]
# Wider palette kept for reference:
# ["#377eb8", "#ff7f00", "#4daf4a", "darkred", "#984ea3", "y", "k", "w"]
colorlist = [
    "#ff7f00",
    "#984ea3",
]

In [None]:
def _binned_means(frame, x_col, y_col, n_bins):
    """Return the mean of ``y_col`` in each of ``n_bins`` equal-width bins of ``x_col`` over [0, 1]."""
    x_values = frame[x_col]
    last = n_bins - 1
    means = []
    for j in range(n_bins):
        bottom = j / n_bins
        upper = (j + 1) / n_bins
        # Bins are half-open [bottom, upper) so a value sitting exactly on an
        # edge lands in exactly one bin; the last bin is closed so that a
        # value of exactly 1.0 is kept.
        if j == last:
            in_bin = (x_values >= bottom) & (x_values <= upper)
        else:
            in_bin = (x_values >= bottom) & (x_values < upper)
        # A bin with no rows yields NaN, which is drawn as a missing bar.
        means.append(frame.loc[in_bin, y_col].mean())
    return means


def plot_reliability_diagram(algorithm, dataset):
    """Plot, save and show a binned ``metrics[1]``-vs-``metrics[0]`` bar chart.

    For each W&B project path (one per optimizer) the finished runs are
    fetched, converted with ``wandb2pd``, split into ``num_bin`` equal-width
    bins of ``metrics[0]`` over [0, 1], and the per-bin mean of ``metrics[1]``
    is drawn as a semi-transparent bar series. A dashed ``y = x`` diagonal is
    overlaid for reference.

    Parameters
    ----------
    algorithm : str
        Used in the figure title and the output file name.
    dataset : str
        Used in the figure title and the output file name; also selects the
        ``*_with_wd_100k`` projects when it is ``'RotatedMNIST'``.

    Returns
    -------
    None
        The figure is written to
        ``figs/reliable_diag/bin{num_bin}_plot_{algorithm}_{dataset}.pdf``
        (the directory is created if missing), shown, and closed.

    Notes
    -----
    Reads the module-level ``colorlist``, ``metrics``, ``metric_names``,
    ``optimizer_list`` and ``num_bin``. Values of ``metrics[0]`` outside
    [0, 1] fall into no bin (assumes the metric is a fraction; confirm).

    NOTE(review): neither ``algorithm`` nor ``dataset`` is passed to the run
    query, so every call that shares a project path plots the same runs. If
    the projects mix algorithms or datasets, a config filter is probably
    needed; the config key names are not visible here, so confirm.
    """
    with_weight_decay_data = ['RotatedMNIST']
    if dataset in with_weight_decay_data:
        path_list = [
            'entity_name/project_momentum_sgd_with_wd_100k',
            'entity_name/project_adam_with_wd_100k',
        ]
    else:
        path_list = [
            'entity_name/project_momentum_sgd',
            'entity_name/project_adam',
        ]

    # One API client is enough for all project paths.
    api = wandb.Api()

    bin_width = 1 / num_bin
    # Bars are anchored on each bin's left edge so they cover exactly the
    # interval [j/num_bin, (j+1)/num_bin] they summarise.
    left_edges = [j / num_bin for j in range(num_bin)]

    fig, ax = plt.subplots()
    for i, path in enumerate(path_list):
        exp_runs = api.runs(path=path, filters={'state': 'finished'})
        df = wandb2pd(exp_runs)
        bin_ood_accs = _binned_means(df, metrics[0], metrics[1], num_bin)
        ax.bar(
            left_edges,
            bin_ood_accs,
            width=bin_width,
            align='edge',
            edgecolor='black',
            color=colorlist[i],
            alpha=0.5,
            # Labelling each series keeps the legend independent of draw order.
            label=optimizer_list[i],
        )

    ax.set_aspect('equal', adjustable='box')
    ax.grid(color='gray', linestyle='dashed', linewidth=1)
    ax.set_xlim(0, 1.05)
    ax.set_ylim(0, 1.)

    # The diagonal has no label, so it is left out of the legend.
    ax.plot([0, 1], [0, 1], '--', color='gray', linewidth=2)
    ax.legend()

    ax.set_xlabel(f'{metric_names[0]}', fontsize=15, labelpad=2)
    ax.set_ylabel(f'{metric_names[1]}', fontsize=15, labelpad=5)
    ax.set_title(f"{dataset}:{algorithm}")

    out_dir = 'figs/reliable_diag'
    # savefig does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(f'{out_dir}/bin{num_bin}_plot_{algorithm}_{dataset}.pdf')
    plt.show()
    # Release the figure; this function is called once per dataset/algorithm.
    plt.close(fig)

In [None]:

# Render one diagram per (dataset, algorithm) pair, datasets in the outer order.
for dataset, algorithm in [(d, a) for d in dataset_list for a in algorithm_list]:
    plot_reliability_diagram(algorithm, dataset)
