In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
import pandas as pd
import wandb
import seaborn as sns

In [None]:

def wandb2pd(exp_runs):
    """Flatten a collection of W&B runs into a single DataFrame.

    Each run contributes one row. Its columns come from three groups,
    concatenated side by side in this order:

    * ``name``: the run's ``name`` attribute;
    * config columns: every ``exp.config`` entry whose key does not start
      with ``'_'`` (underscore-prefixed keys are dropped);
    * summary columns: the entries of ``exp.summary._json_dict``.

    Args:
        exp_runs: Iterable of run objects exposing ``name``, ``config`` (a
            mapping with ``.items()``) and ``summary._json_dict`` (a dict),
            e.g. the result of ``wandb.Api().runs(...)``. It is iterated once.

    Returns:
        pandas.DataFrame with one row per run. A run that lacks a key present
        in another run gets NaN in that column. With no runs, the result has
        zero rows and only the ``name`` column.
    """
    run_summaries = []
    run_configs = []
    run_names = []
    for exp in exp_runs:
        run_summaries.append(exp.summary._json_dict)
        # Underscore-prefixed keys are excluded (presumably W&B-internal
        # entries; TODO confirm).
        run_configs.append({k: v for k, v in exp.config.items() if not k.startswith('_')})
        run_names.append(exp.name)

    return pd.concat(
        [
            pd.DataFrame({'name': run_names}),
            pd.DataFrame.from_records(run_configs),
            pd.DataFrame.from_records(run_summaries),
        ],
        axis=1,
    )

In [None]:
def create_scatter_plot_domainbed(dataset, acc_envs):
    """Scatter-plot avg_val_acc vs. avg_test_acc of finished W&B runs.

    One figure is produced per algorithm in ``['ERM', 'IRM']``. Each figure
    overlays five optimizer projects (SGD, Momentum SGD, Nesterov Momentum
    SGD, RMSProp, Adam), one colour each. It only uses runs whose state is
    ``'finished'``. Each figure is saved to
    ``figs/scatter/val-test/{algorithm}_{dataset}.pdf`` (the directory is
    created if missing) and then shown.

    Args:
        dataset: Dataset name. ``'ColoredMNIST'`` reads the ``*_with_wd``
            projects, ``'RotatedMNIST'`` reads the ``*_with_wd_100k`` projects,
            and any other name reads the plain projects. No IRM figure is
            produced for ``'RotatedMNIST'``.
        acc_envs: Per-environment accuracy keys. Not used by this function;
            kept so existing call sites keep working.
    """
    import os  # local import: only needed to create the output directory

    # Project base names and legend labels, in matching order.
    optimizer_projects = ['vanilla_sgd', 'momentum_sgd', 'nesterov_momentum_sgd', 'rmsprop', 'adam']
    optimizer_name_list = ['SGD', 'Momentum SGD', 'Nesterov Momentum SGD', 'RMSProp', 'Adam']
    algorithm_list = ['ERM', 'IRM']
    # Datasets whose runs live in the weight-decay projects, with their suffix.
    with_weight_decay_suffix = {
        'ColoredMNIST': '_with_wd',
        'RotatedMNIST': '_with_wd_100k',
    }
    colorlist = ["#377eb8", "#ff7f00", "#4daf4a", "darkred", "#984ea3", "y", "k", "w"]
    out_dir = 'figs/scatter/val-test'

    suffix = with_weight_decay_suffix.get(dataset, '')
    # NOTE(review): the project paths depend only on `dataset`, not on
    # `algorithm`, and runs are not filtered by algorithm, so the ERM and IRM
    # figures are drawn from the same runs. Confirm whether the real
    # (anonymised) project names or a config filter separate them.
    path_list = [f'entity_name/project_{proj}{suffix}' for proj in optimizer_projects]

    api = wandb.Api()  # one client reused for every project query
    os.makedirs(out_dir, exist_ok=True)  # savefig does not create directories

    for algorithm in algorithm_list:
        # No IRM figure for RotatedMNIST (presumably no such runs; TODO confirm).
        if algorithm == 'IRM' and dataset == 'RotatedMNIST':
            continue

        fig, ax = plt.subplots()
        for i, (p, optimizer_name) in enumerate(zip(path_list, optimizer_name_list)):
            exp_runs = api.runs(path=p, filters={'state': 'finished'})
            df = wandb2pd(exp_runs)

            print(p, len(df))
            # With no finished runs, wandb2pd returns only a 'name' column,
            # so there is nothing to plot for this project.
            if df.empty:
                continue

            ax.scatter(
                df['avg_val_acc'],
                df['avg_test_acc'],
                color=colorlist[i],
                s=50,
                label=f'{algorithm} {optimizer_name}',
                alpha=0.5,
            )

        ax.set_xlim(0, 1.)
        ax.set_ylim(0, 1.)
        ax.set_title(f"{dataset}:{algorithm}")
        ax.set_xlabel('avg_val_acc', fontsize=15, labelpad=2)
        ax.set_ylabel('avg_test_acc', fontsize=15, labelpad=5)
        # Legend placed outside the axes on the right, without a frame border.
        ax.legend(bbox_to_anchor=(1.05, 0), borderaxespad=0, loc='lower left', fontsize=10).get_frame().set_linewidth(0)
        ax.grid(linewidth=1)

        fig.savefig(os.path.join(out_dir, f'{algorithm}_{dataset}.pdf'))
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across calls

In [None]:
# Datasets to plot, in order, each paired with the per-environment accuracy
# keys passed to create_scatter_plot_domainbed as `acc_envs`.
DATASET_ACC_ENVS = {
    'PACS': ['env1_in_acc', 'env2_in_acc', 'env3_in_acc'],
    'VLCS': ['env1_in_acc', 'env2_in_acc', 'env3_in_acc'],
    'OfficeHome': ['env1_in_acc', 'env2_in_acc', 'env3_in_acc'],
    'TerraIncognita': ['env1_in_acc', 'env2_in_acc', 'env3_in_acc'],
    'DomainNet': ['env1_in_acc', 'env2_in_acc', 'env3_in_acc', 'env4_in_acc', 'env5_in_acc'],
    'RotatedMNIST': ['env1_in_acc', 'env2_in_acc', 'env3_in_acc', 'env4_in_acc', 'env5_in_acc'],
    'ColoredMNIST': ['env1_in_acc', 'env0_in_acc'],
}

for dataset, acc_envs in DATASET_ACC_ENVS.items():
    create_scatter_plot_domainbed(dataset, acc_envs)