In [None]:
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import wandb

%matplotlib inline

In [None]:

def wandb2pd(exp_runs):
    """Flatten an iterable of W&B runs into a single DataFrame, one row per run.

    Parameters
    ----------
    exp_runs : iterable
        Run objects exposing ``name``, ``config`` (a mapping) and
        ``summary._json_dict`` (a mapping).

    Returns
    -------
    pandas.DataFrame
        Columns are ``name``, followed by every config key that does not start
        with an underscore, followed by every summary key. Rows keep the
        iteration order of ``exp_runs``.
    """
    names, configs, summaries = [], [], []

    # Collect everything in one pass so ``exp_runs`` is only iterated once.
    for run in exp_runs:
        names.append(run.name)
        # Config keys with a leading underscore are dropped.
        configs.append({k: v for k, v in run.config.items() if not k.startswith('_')})
        summaries.append(run.summary._json_dict)

    # The three frames share the default 0..n-1 index, so a column-wise concat
    # lines each run's name, config and summary up on the same row.
    return pd.concat(
        [
            pd.DataFrame({'name': names}),
            pd.DataFrame.from_records(configs),
            pd.DataFrame.from_records(summaries),
        ],
        axis=1,
    )

In [None]:
# Display names used when printing per-optimizer statistics; indexed by the
# position of the optimizer's result frame.
optimizer_name_list = [
    'SGD',
    'Momentum SGD',
    'Nesterov Momentum SGD',
    'RMSProp',
    'Adam',
]

In [None]:
# Short x-axis labels, in the same order as the per-optimizer projects below.
_OPTIMIZER_TICK_LABELS = ['SGD', 'Momentum', 'Nesterov', 'RMSProp', 'Adam']

# Project-name stems, one per optimizer, in plotting order.
_OPTIMIZER_PROJECT_STEMS = [
    'vanilla_sgd',
    'momentum_sgd',
    'nesterov_momentum_sgd',
    'rmsprop',
    'adam',
]

# Datasets whose projects carry a suffix; every other dataset uses none.
_PROJECT_SUFFIX_BY_DATASET = {
    'ColoredMNIST': '_with_wd',
    'RotatedMNIST': '_with_wd_100k',
}

# Directory the PDF figures are written to.
_BOXPLOT_DIR = 'figs/box/non-filtered'


def _project_paths(dataset):
    """Return the W&B project paths, one per optimizer, for ``dataset``."""
    suffix = _PROJECT_SUFFIX_BY_DATASET.get(dataset, '')
    return [f'entity_name/project_{stem}{suffix}' for stem in _OPTIMIZER_PROJECT_STEMS]


def _save_boxplot(df_boxplot, metric, dataset, algorithm, ylim=(0, 1.)):
    """Draw, save (as PDF) and show one per-optimizer box plot of ``metric``."""
    fig, ax = plt.subplots()
    # ``order`` pins every optimizer to a fixed slot, so the labels stay correct
    # even when an optimizer has no rows left after filtering.
    sns.boxplot(
        x='optimizer_label',
        y=metric,
        data=df_boxplot,
        order=_OPTIMIZER_TICK_LABELS,
        ax=ax,
    )
    ax.set_xlabel('optimizer_name')
    ax.set_ylim(*ylim)
    ax.set_title(f"{dataset}: {algorithm}")
    ax.grid(linewidth=1)

    # savefig does not create missing directories.
    os.makedirs(_BOXPLOT_DIR, exist_ok=True)
    fig.savefig(f'{_BOXPLOT_DIR}/boxplot_{metric}_{algorithm}_{dataset}.pdf')
    plt.show()
    plt.close(fig)


def create_boxplot_domainbed(dataset, acc_envs, threshold=0):
    """Plot per-optimizer box plots of accuracy metrics for one dataset.

    For each algorithm in ``['ERM', 'IRM']`` (IRM is skipped for
    ``'RotatedMNIST'``), the finished runs of the five per-optimizer W&B
    projects are downloaded, summary statistics are printed, and four box
    plots (``train_acc``, ``avg_val_acc``, ``avg_test_acc``, ``gap``) are
    saved as PDFs under ``figs/box/non-filtered/`` and shown.

    Parameters
    ----------
    dataset : str
        Dataset name; selects the project-name suffix and appears in figure
        titles and file names.
    acc_envs : sequence of str
        Column names whose row-wise average becomes ``train_acc``.
    threshold : float, default 0
        Only runs with ``avg_val_acc > threshold`` are kept.

    Raises
    ------
    ValueError
        If ``acc_envs`` is empty.
    """
    if not acc_envs:
        raise ValueError('acc_envs must contain at least one column name')

    algorithm_list = ['ERM', 'IRM']

    # One API client is enough for every project query.
    api = wandb.Api()

    for algorithm in algorithm_list:

        # This combination is not plotted.
        if algorithm == 'IRM' and dataset == 'RotatedMNIST':
            continue

        # NOTE(review): the project paths do not depend on ``algorithm``, so both
        # algorithms currently read the same runs - confirm this is intended.
        df_list = []
        for p in _project_paths(dataset):
            exp_runs = api.runs(path=p, filters={'state': 'finished'})
            df = wandb2pd(exp_runs)
            print(p, len(df))
            df_list.append(df)

        box_plot_list = []
        for idx_optim, df_item in enumerate(df_list):
            if df_item.empty:
                # No finished runs means none of the metric columns exist.
                print(f'{optimizer_name_list[idx_optim]}: no finished runs, skipped')
                continue

            # Training accuracy is the average over the given environment columns.
            train_acc = sum(df_item[acc_env] for acc_env in acc_envs) / len(acc_envs)

            # assign/.loc build new frames, so no column is written onto a slice.
            df_kept = (
                df_item
                .assign(train_acc=train_acc)
                .loc[lambda d: d['avg_val_acc'] > threshold]
                .assign(
                    gap=lambda d: d['avg_val_acc'] - d['avg_test_acc'],
                    optimizer_label=_OPTIMIZER_TICK_LABELS[idx_optim],
                )
            )
            box_plot_list.append(df_kept)

            mean_train_acc = df_kept['train_acc'].mean()
            mean_test_acc = df_kept['avg_test_acc'].mean()
            # NOTE(review): the printed gap is train - test, whereas the plotted
            # 'gap' column is val - test.
            mean_gap = mean_train_acc - mean_test_acc
            print(f'{optimizer_name_list[idx_optim]}: {mean_train_acc} / {mean_test_acc} / {mean_gap}')

        if not box_plot_list:
            print(f'{dataset}: {algorithm}: no runs to plot')
            continue

        df_boxplot = pd.concat(box_plot_list)

        for metric in ('train_acc', 'avg_val_acc', 'avg_test_acc'):
            _save_boxplot(df_boxplot, metric, dataset, algorithm)

        # VLCS gets extra room below zero on the gap plot.
        gap_ylim = (-0.3, 1.) if dataset == 'VLCS' else (0, 1.)
        _save_boxplot(df_boxplot, 'gap', dataset, algorithm, ylim=gap_ylim)

In [None]:
# Box plots for VLCS, averaging the three listed in-distribution accuracy columns.
dataset, acc_envs = 'VLCS', [f'env{i}_in_acc' for i in (1, 2, 3)]
create_boxplot_domainbed(dataset, acc_envs)

In [None]:
# (dataset, accuracy columns averaged into train_acc), in plotting order.
domainbed_plot_jobs = [
    ('PACS', ['env1_in_acc', 'env2_in_acc', 'env3_in_acc']),
    ('VLCS', ['env1_in_acc', 'env2_in_acc', 'env3_in_acc']),
    ('OfficeHome', ['env1_in_acc', 'env2_in_acc', 'env3_in_acc']),
    ('TerraIncognita', ['env1_in_acc', 'env2_in_acc', 'env3_in_acc']),
    ('DomainNet', ['env1_in_acc', 'env2_in_acc', 'env3_in_acc', 'env4_in_acc', 'env5_in_acc']),
    ('RotatedMNIST', ['env1_in_acc', 'env2_in_acc', 'env3_in_acc', 'env4_in_acc', 'env5_in_acc']),
    ('ColoredMNIST', ['env1_in_acc', 'env0_in_acc']),
]

# One call per dataset instead of seven copy-pasted stanzas; `dataset` and
# `acc_envs` end up holding the last entry, as before.
for dataset, acc_envs in domainbed_plot_jobs:
    create_boxplot_domainbed(dataset, acc_envs)