In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import pandas as pd
import wandb

from typing import Union

In [None]:
def wandb2pd(exp_runs):
    """Flatten an iterable of W&B runs into a single DataFrame, one row per run.

    Columns are laid out as: ``name``, then the run's config entries, then the
    run's summary entries. Config keys starting with an underscore are dropped.
    Because the three parts are concatenated side by side, a key present in
    both the config and the summary shows up as two columns with the same label.

    Args:
        exp_runs: Iterable of run objects exposing ``name``, ``config`` (a
            mapping) and ``summary._json_dict`` (a mapping).

    Returns:
        pd.DataFrame: One row per run, in iteration order.
    """
    summary = []
    config = []
    name = []
    for exp in exp_runs:
        # NOTE(review): `_json_dict` is a private attribute of the summary
        # object; confirm it is still available when upgrading wandb.
        summary.append(exp.summary._json_dict)
        config.append({k: v for k, v in exp.config.items() if not k.startswith('_')})
        name.append(exp.name)

    summary_df = pd.DataFrame.from_records(summary)
    config_df = pd.DataFrame.from_records(config)
    name_df = pd.DataFrame({'name': name})

    # All three frames share the default 0..n-1 index, so axis=1 aligns row-wise.
    return pd.concat([name_df, config_df, summary_df], axis=1)


In [None]:
def plot_status_by_trial_budget(
        df             :pd.DataFrame,
        plot_column    :str,
        top_k          :Union[int, str],
        optimizer_name :str,
        in_or_out      :str,
        max_xlim       :int,
        max_ylim       :float,
        algorithm      :str,
        dataset        :str,
        is_log_xscale  :bool=False,
    ):
    """Plot how the best / top-k / mean score evolves with the trial budget.

    For each budget ``n`` in 10, 20, ..., ``10 * max_xlim`` the first ``n``
    rows of ``df`` are taken and three curves are drawn from ``plot_column``:
    the maximum ("top-1"), the mean of the ``top_k`` largest values, and the
    mean of all values. The figure is written as a PDF under
    ``figs/accplot_by_trial_budget/{algorithm}/`` (created if missing); the
    log-scale variant gets a ``_logscale`` file-name suffix.

    Args:
        df: One row per trial, already in the order trials should be consumed.
        plot_column: Column of ``df`` holding the score to plot.
        top_k: Number of best trials averaged for the "top-k average" curve;
            it is passed to ``Series.head`` and also shown in the legend.
        optimizer_name: Used in the title and the output file name.
        in_or_out: Prefix for the y-axis label and file name (e.g. "in"/"out").
        max_xlim: Largest budget to plot, in units of 10 trials.
        max_ylim: Upper y-limit before the 1% head-room is added.
        algorithm: Used in the title and as the output sub-directory.
        dataset: Used in the title and the output file name.
        is_log_xscale: Draw the x-axis on a log scale when True.

    Returns:
        None. The figure is left open so the notebook backend can display it.
    """
    fig, ax = plt.subplots(figsize=(6, 4))

    # Start at 10 rather than 0: an empty head() only yields NaN and 0 has no
    # position on a log axis. End at 10 * max_xlim so the requested maximum
    # budget is actually drawn.
    ns = [step * 10 for step in range(1, max_xlim + 1)]

    scores = df[plot_column]

    acc_maxes = [scores.head(n).max() for n in ns]
    ax.plot(ns, acc_maxes, label="top-1")

    top_k_averages = [
        scores.head(n).sort_values(ascending=False).head(top_k).mean()
        for n in ns
    ]
    ax.plot(ns, top_k_averages, label=f"top-{top_k} average")

    all_averages = [scores.head(n).mean() for n in ns]
    ax.plot(ns, all_averages, label="all average")

    ax.set_title(f"{dataset} {algorithm} {optimizer_name}", fontsize=12)
    ax.set_xlabel("trials", fontsize=14)
    ax.set_ylabel(f"{in_or_out}-distribution acc", fontsize=14)
    # 1% head-room so the best curve does not sit on the frame.
    ax.set_ylim(top=max_ylim + max_ylim * 0.01)

    if is_log_xscale:
        ax.set_xscale("log")

    ax.grid(linewidth=1, alpha=0.5)
    ax.legend(loc='lower right', framealpha=0.5)

    out_dir = f"figs/accplot_by_trial_budget/{algorithm}/"
    os.makedirs(out_dir, exist_ok=True)
    suffix = "_logscale" if is_log_xscale else ""
    fig.savefig(f"{out_dir}{dataset}_{in_or_out}dist_acc_{optimizer_name}{suffix}.pdf")

    return

In [None]:
# Fetched runs per optimizer: display name -> DataFrame (filled by the download loop).
exp_df_dict = dict()

# Display names; the download loop pairs entry i with the i-th W&B project path.
optimizer_name_list = ["Momentum", "Adam"]

algorithm_list = ["ERM", "IRM"]
dataset_list = ["WILDS_civilcomments", "WILDS_Amazon"]

# Summary columns to plot: index 0 is plotted as "in"-distribution, index 1 as "out".
metrics = ["val_eval/acc_avg", "test_eval/acc_wg"]


In [None]:
# One API client is enough for every project queried below.
api = wandb.Api()

plot_column_list = metrics
in_or_out_list = ["in", "out"]

for dataset in dataset_list:
    for algorithm in algorithm_list:

        path_list = [
            f"entity_name/ICML2022_{dataset}_{algorithm}_momentum_sgd",
            f"entity_name/ICML2022_{dataset}_{algorithm}_adam"
        ]

        # Forget the previous (dataset, algorithm) pair so its frames cannot
        # leak into this pair's statistics or be saved under this pair's names.
        exp_df_dict.clear()

        # 1) Download every optimizer's runs first, oldest run first.
        for path, optimizer_name in zip(path_list, optimizer_name_list):
            exp_runs = api.runs(
                path=path,
                filters={"state":"finished"}
            )
            df = wandb2pd(exp_runs).sort_values(by="_timestamp", axis=0, ascending=True)
            exp_df_dict[optimizer_name] = df

        # 2) Shared trial budget: the shortest experiment decides how many
        #    trials every optimizer is compared on.
        min_exp_length = min(len(exp_df) for exp_df in exp_df_dict.values())

        # 3) Shared y-limits: the best score any optimizer reaches within that
        #    budget. Starts at 0 and only moves up, so an all-NaN column
        #    (NaN > x is False) leaves the limit untouched.
        max_ylim_list = []
        for plot_column in plot_column_list:
            best_score = 0
            for exp_df in exp_df_dict.values():
                candidate = exp_df.head(min_exp_length)[plot_column].max()
                if candidate > best_score:
                    best_score = candidate
            max_ylim_list.append(best_score)
        max_in_acc, max_out_acc = max_ylim_list

        print(min_exp_length, max_in_acc, max_out_acc)

        # 4) One linear and one log-scale figure per optimizer and metric.
        for optimizer_name, exp_df in exp_df_dict.items():
            for plot_column, in_or_out, max_ylim in zip(plot_column_list, in_or_out_list, max_ylim_list):
                for is_log_xscale in (False, True):
                    plot_status_by_trial_budget(
                        df=exp_df,
                        plot_column=plot_column,
                        top_k=10,
                        optimizer_name=optimizer_name,
                        in_or_out=in_or_out,
                        max_xlim=min_exp_length//10,
                        max_ylim=max_ylim,
                        algorithm=algorithm,
                        dataset=dataset,
                        is_log_xscale=is_log_xscale,
                    )
