In [None]:
%matplotlib inline


from collections import defaultdict
import os
import warnings

import matplotlib.pyplot as plt
import matplotlib.cm as cm
from IPython.display import display

from utils import get_wandb_logs, num_params, pd

# Default colormap for image-style plots.
plt.rcParams.update({'image.cmap': 'viridis'})

# Silence every Python warning to keep cell output clean.
# NOTE(review): this also hides deprecation warnings from matplotlib/pandas.
warnings.filterwarnings('ignore')

In [None]:


# Datasets to plot, in panel order.
name_list = ["PACS", "VLCS", "VLCS_1", "OfficeHome", "DomainNet"]

# W&B project for each dataset. Every project follows the naming scheme
# ``Timm_<dataset>_ERM_momentum_sgd_v3``, so deriving the names from
# ``name_list`` keeps the two lists index-aligned by construction.
wandb_path_list = [f"Timm_{name}_ERM_momentum_sgd_v3" for name in name_list]


In [None]:
# W&B entity (user/team) that owns the projects in ``wandb_path_list``.
entity = "anonymized"

# Run logs per W&B project. Keyed by the bare project name (an entry of
# ``wandb_path_list``), because that is how the plotting cells look them up.
# Keying by "<entity>/<project>" made every later lookup raise KeyError.
dfs = {}
history_dfs = {}  # NOTE(review): never populated or read in this notebook
for wandb_path in wandb_path_list:
    # presumably returns a pandas DataFrame (scatter_plot calls .iterrows() on it)
    dfs[wandb_path] = get_wandb_logs(f"{entity}/{wandb_path}")

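# Plots: per-dataset scatter panels (OOD vs. ID accuracy, and parameter count vs. OOD accuracy / calibration error)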

In [None]:
# Models drawn by ``scatter_plot``, in legend/colour order (timm model names
# including the pretrained-weights tag).
_SCATTER_MODELS = (
    "resnet50.tv_in1k",
    "resnet101.tv_in1k",
    "resnet152.tv_in1k",
    "convnext_tiny.fb_in1k",
    "convnext_small.fb_in1k",
    "convnext_base.fb_in1k",
    "convnext_large.fb_in1k",
    "vit_small_patch16_224.augreg_in1k",
    "vit_base_patch16_224.augreg_in1k",
)

# Parameter counts keyed by timm architecture name (the part of the model name
# before the first "."). Used to scale marker areas relative to ResNet-50.
# Deliberately not named ``num_params`` so it does not shadow the utils import.
_NUM_PARAMS_BY_ARCH = {
    "vit_tiny_patch16_224": 5.717416e06,
    "vit_small_patch16_224": 2.205066e07,
    "vit_base_patch16_224": 8.656766e07,
    "vit_large_patch16_224": 3.043266e08,
    "resnet50": 2.555703e07,
    "resnet101": 4.454916e07,
    "resnet152": 6.019281e07,
    "convnext_tiny": 2.858913e07,
    "convnext_small": 5.022369e07,
    "convnext_base": 8.859146e07,
    "convnext_large": 1.977673e08,
}


def _is_missing(value) -> bool:
    """Return True for the string sentinel ``"NaN"`` or a scalar NaN/None."""
    if isinstance(value, str):
        return value == "NaN"
    return pd.api.types.is_scalar(value) and bool(pd.isna(value))


def _collect_points(df, key_1, key_2, num_of_trial):
    """Gather up to ``num_of_trial`` usable (key_1, key_2) pairs per model, in ``df`` row order."""
    xs_by_model = {model: [] for model in _SCATTER_MODELS}
    ys_by_model = {model: [] for model in _SCATTER_MODELS}
    for _, row in df.iterrows():
        x, y, model = row[key_1], row[key_2], str(row["model"])
        if _is_missing(x) or _is_missing(y):
            continue
        # Models outside _SCATTER_MODELS have no colour/legend slot; skip them
        # instead of failing with a KeyError.
        if model not in xs_by_model:
            continue
        # Keep the first ``num_of_trial`` rows per model. This only yields the
        # "top" trials if ``df`` is already sorted best-first — TODO confirm.
        if len(xs_by_model[model]) < num_of_trial:
            xs_by_model[model].append(x)
            ys_by_model[model].append(y)
    return xs_by_model, ys_by_model


def _marker_for(model: str) -> str:
    """Marker shape per model family: ConvNeXt "h", ViT "^", everything else (ResNet) "o"."""
    if model.startswith("convnext"):
        return "h"
    if model.startswith("vit"):
        return "^"
    return "o"


def scatter_plot(
    ax: plt.Axes,
    title: str,
    df: pd.DataFrame,
    key_1: str,
    key_2: str,
    key_1_title: str,
    key_2_title: str,
    num_of_trial: int,
    fill_std: bool = False,
    save: bool = False,
    different_markers: bool = False,
    show_num_params: bool = False,
):
    """Draw one DomainBed panel: ``key_1`` (x) vs. ``key_2`` (y), one series per model.

    Args:
        ax: Axes to draw on.
        title: Dataset name; the panel title becomes ``"DomainBed / <title>"``.
        df: Run table with a ``"model"`` column plus the ``key_1``/``key_2`` columns.
        key_1: Column plotted on the x axis.
        key_2: Column plotted on the y axis.
        key_1_title: x-axis label.
        key_2_title: y-axis label.
        num_of_trial: Maximum number of rows plotted per model (first rows in
            ``df`` order).
        fill_std: Accepted for call compatibility; currently unused.
        save: Accepted for call compatibility; currently unused (nothing is
            written to disk).
        different_markers: Use a marker shape per model family instead of "o".
        show_num_params: Scale marker area by parameter count relative to ResNet-50.

    Rows whose x or y value is missing (the string ``"NaN"`` or a real NaN/None)
    are skipped, as are rows for models not listed in ``_SCATTER_MODELS``.
    """
    ax.grid(alpha=0.5)
    # pyplot.get_cmap survives Matplotlib 3.9, where matplotlib.cm.get_cmap was removed.
    cmap = plt.get_cmap("hsv")

    xs_by_model, ys_by_model = _collect_points(df, key_1, key_2, num_of_trial)

    resnet50_params = _NUM_PARAMS_BY_ARCH["resnet50"]
    for i, model in enumerate(_SCATTER_MODELS):
        arch = model.split(".")[0]
        marker_size = 200 * (_NUM_PARAMS_BY_ARCH[arch] / resnet50_params) if show_num_params else 200
        ax.scatter(
            xs_by_model[model],
            ys_by_model[model],
            marker=_marker_for(model) if different_markers else "o",
            # i / len(...) never reaches 1.0, so the cyclic hsv map does not
            # give the first and last models the same colour.
            color=cmap(i / len(_SCATTER_MODELS)),
            s=marker_size,
            label=model,
            alpha=0.7,
        )

    ax.set_title(f"DomainBed / {title}", fontsize=32)
    # Legend sits below the axes so it does not cover the points.
    ax.legend(fontsize=16, ncol=2, loc="lower center", bbox_to_anchor=(0.5, -0.5))
    ax.set_xlabel(f"{key_1_title}", fontsize=32, labelpad=2)
    ax.set_ylabel(f"{key_2_title}", fontsize=32, labelpad=2)

In [None]:
num_of_trial = 1  # rows plotted per model (see scatter_plot)

# (x-axis column, y-axis column) pairs and their axis labels, index-aligned.
key_tuple_list = [
    ('avg_test_acc', 'avg_val_acc'),
]
axis_titlte_tuple_list = [
    ('OOD Test Accuracy', 'ID Validation Accuracy'),
]

# One panel per dataset, two panels per row (e.g. 5 datasets -> 3 rows).
n_rows = (len(name_list) + 1) // 2
fig, axes = plt.subplots(n_rows, 2, figsize=(18, 9 * n_rows))
axes = axes.reshape(-1)

for (key_1, key_2), (key_1_title, key_2_title) in zip(key_tuple_list, axis_titlte_tuple_list):
    for wandb_path, name, ax in zip(wandb_path_list, name_list, axes):
        scatter_plot(
            ax, name, dfs[wandb_path], key_1, key_2, key_1_title, key_2_title, num_of_trial,
            fill_std=True, save=True, show_num_params=True, different_markers=True,
        )

# With an odd number of datasets the last grid slot is empty; hide it rather
# than rendering a blank framed axes.
for unused_ax in axes[len(name_list):]:
    unused_ax.set_visible(False)

fig.tight_layout()

In [None]:
num_of_trial = 1  # rows plotted per model (see scatter_plot)

# (x-axis column, y-axis column) pairs and their axis labels, index-aligned.
key_tuple_list = [
    ('num_params', 'avg_test_acc'),
]
axis_titlte_tuple_list = [
    ('Number of parameters', 'OOD Test Accuracy'),
]

# One panel per dataset, two panels per row (e.g. 5 datasets -> 3 rows).
n_rows = (len(name_list) + 1) // 2
fig, axes = plt.subplots(n_rows, 2, figsize=(18, 9 * n_rows))
axes = axes.reshape(-1)

for (key_1, key_2), (key_1_title, key_2_title) in zip(key_tuple_list, axis_titlte_tuple_list):
    for wandb_path, name, ax in zip(wandb_path_list, name_list, axes):
        run_df = dfs[wandb_path]
        # Add the parameter-count column on a copy (``assign``) so the cached
        # frames in ``dfs`` are not mutated by this cell.
        # NOTE(review): assumes the utils ``num_params`` mapping is indexed by the
        # values of the ``model`` column.
        run_df = run_df.assign(num_params=run_df['model'].apply(lambda model: num_params[model]))
        scatter_plot(
            ax, name, run_df, key_1, key_2, key_1_title, key_2_title, num_of_trial,
            fill_std=True, save=False, show_num_params=True, different_markers=True,
        )
        ax.set_xscale('log')

# With an odd number of datasets the last grid slot is empty; hide it rather
# than rendering a blank framed axes.
for unused_ax in axes[len(name_list):]:
    unused_ax.set_visible(False)

fig.tight_layout()

In [None]:
num_of_trial = 1  # rows plotted per model (see scatter_plot)

# (x-axis column, y-axis column) pairs and their axis labels, index-aligned.
key_tuple_list = [
    ('num_params', 'avg_test_ece'),
]
axis_titlte_tuple_list = [
    ('Number of parameters', 'Expected Calibration Error'),
]

# One panel per dataset, two panels per row (e.g. 5 datasets -> 3 rows).
n_rows = (len(name_list) + 1) // 2
fig, axes = plt.subplots(n_rows, 2, figsize=(18, 9 * n_rows))
axes = axes.reshape(-1)

for (key_1, key_2), (key_1_title, key_2_title) in zip(key_tuple_list, axis_titlte_tuple_list):
    for wandb_path, name, ax in zip(wandb_path_list, name_list, axes):
        run_df = dfs[wandb_path]
        # Add the parameter-count column on a copy (``assign``) so the cached
        # frames in ``dfs`` are not mutated by this cell.
        # NOTE(review): assumes the utils ``num_params`` mapping is indexed by the
        # values of the ``model`` column.
        run_df = run_df.assign(num_params=run_df['model'].apply(lambda model: num_params[model]))
        scatter_plot(
            ax, name, run_df, key_1, key_2, key_1_title, key_2_title, num_of_trial,
            fill_std=True, save=False, show_num_params=True, different_markers=True,
        )
        ax.set_xscale('log')

# With an odd number of datasets the last grid slot is empty; hide it rather
# than rendering a blank framed axes.
for unused_ax in axes[len(name_list):]:
    unused_ax.set_visible(False)

fig.tight_layout()