In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
from utils import get_wandb_logs, num_params, pd, clean_model_names

In [None]:
# Benchmarks shown in the figures, in panel order.
name_list = ["PACS", "OfficeHome", "DomainNet", "VLCS"]

# Dataset tags exactly as they appear inside the W&B project names.
_project_datasets = ("OfficeHome", "PACS", "DomainNet", "VLCS_3")
# Every momentum-SGD sweep exists in three variants: no suffix and two weight-decay suffixes.
_sgd_suffixes = ("", "_wd_1e-3", "_wd_1e-2")
_sgd_versions = (3, 4, 6, 7, 8, 9, 10, 11)
# NOTE(review): 5 is not among the SGD versions iterated below, so it never matches here.
_adam_versions = (3, 4, 5, 8)

wandb_path_list = []
for version in _sgd_versions:
    # Momentum-SGD projects: grouped by dataset, then by weight-decay variant.
    wandb_path_list.extend(
        f"Timm_{dataset}_ERM_momentum_sgd_v{version}{suffix}"
        for dataset in _project_datasets
        for suffix in _sgd_suffixes
    )

    # Adam projects only exist for a subset of the versions.
    if version in _adam_versions:
        wandb_path_list.extend(
            f"Timm_{dataset}_ERM_adam_v{version}" for dataset in _project_datasets
        )

In [None]:
entity = "anonymized"

# Download the logs of every W&B project once, keyed by project name.
_dfs = {
    f'{wandb_path}': get_wandb_logs(f"{entity}/{wandb_path}")
    for wandb_path in wandb_path_list
}

# Stack all projects whose name contains a benchmark name into one frame per benchmark.
dfs = {
    name: pd.concat(_dfs[w] for w in wandb_path_list if name in w)
    for name in name_list
}

# From here on the plotting cells look frames up by benchmark name.
# After this rebinding the download loop above would iterate over benchmark names,
# so re-run the cell that builds the project list before re-running this one.
wandb_path_list = name_list

# Plot

In [None]:
def scatter_plot(
    ax: plt.Axes,
    title: str,
    df: pd.DataFrame,
    key_1: str,
    key_2: str,
    key_1_title: str,
    key_2_title: str,
    num_of_trial: int = 1,
    fill_std: bool = False,
    save: bool = False,
    different_markers: bool = False,
    show_num_params: bool = False,
    marker_size_base: int = 50,
    criterion_key: str = "avg_val_acc",
):
    """Draw one scatter point per model on ``ax``: ``key_1`` (x) against ``key_2`` (y).

    Rows with a missing ``avg_val_acc``, ``avg_test_acc`` or ``avg_test_ece`` are
    dropped first. For each model of the built-in model list, the row with the
    largest ``criterion_key`` is selected (for model names containing ``resnet``,
    only rows with ``weight_decay == 1e-4`` are considered) and its ``key_1`` /
    ``key_2`` values are plotted. Models without any remaining row are reported
    and skipped, as are models whose selected values are both ``0.0``.

    Args:
        ax: Axes to draw on.
        title: Title of the panel.
        df: Results table; must provide the columns ``model``, ``weight_decay``,
            ``avg_val_acc``, ``avg_test_acc``, ``avg_test_ece``, ``key_1``,
            ``key_2`` and ``criterion_key``.
        key_1: Column plotted on the x axis.
        key_2: Column plotted on the y axis.
        key_1_title: Label of the x axis.
        key_2_title: Label of the y axis.
        num_of_trial: Must be 1; other values are not implemented.
        fill_std: Accepted for call compatibility; not used by this function.
        save: Accepted for call compatibility; not used by this function.
        different_markers: Accepted for call compatibility; not used by this function.
        show_num_params: If true, the marker area is ``marker_size_base`` scaled by
            the model's parameter count relative to ``resnet50.tv_in1k``.
        marker_size_base: Marker area used when ``show_num_params`` is false, and the
            base of the scaling otherwise.
        criterion_key: Column used to pick the single row plotted per model.

    Raises:
        AssertionError: If ``num_of_trial`` is not 1.
    """
    assert num_of_trial == 1, "num_of_trial>1 is not implemented"

    model_list = [
        "vgg11.tv_in1k",
        "vgg13.tv_in1k",
        "vgg16.tv_in1k",
        "vgg19.tv_in1k",

        "resnet50.tv_in1k",
        "resnet101.tv_in1k",
        "resnet152.tv_in1k",

        "regnety_002.pycls_in1k",
        "regnety_004.pycls_in1k",
        "regnety_006.pycls_in1k",
        "regnety_008.pycls_in1k",
        "regnety_016.pycls_in1k",
        "regnety_032.pycls_in1k",
        "regnety_040.pycls_in1k",
        "regnety_064.pycls_in1k",
        "regnety_080.pycls_in1k",
        "regnety_120.pycls_in1k",
        "regnety_160.pycls_in1k",
        "regnety_320.pycls_in1k",

        "tf_efficientnet_b0.in1k",
        "tf_efficientnet_b1.in1k",
        "tf_efficientnet_b2.in1k",
        "tf_efficientnet_b3.in1k",
        "tf_efficientnet_b4.in1k",
        "tf_efficientnet_b5.in1k",

        "convnext_tiny.fb_in1k",
        "convnext_small.fb_in1k",
        "convnext_base.fb_in1k",
        "convnext_large.fb_in1k",
        "vit_small_patch16_224.augreg_in1k",
        "vit_base_patch16_224.augreg_in1k",

        "vit_tiny_patch16_224.augreg_in21k_ft_in1k",
        "vit_small_patch16_224.augreg_in21k_ft_in1k",
        "vit_base_patch16_224.augreg_in21k_ft_in1k",
        "vit_large_patch16_224.augreg_in21k_ft_in1k",

        "mixer_b16_224.goog_in21k_ft_in1k",
        "mixer_l16_224.goog_in21k_ft_in1k",
    ]

    # Name fragment -> colour slot. Checked in this order; when several fragments
    # match, the last one wins. Models matching none keep slot 0.
    family_color_index = (
        ("resnet", 1),
        ("convnext", 5),
        ("efficient", 3),
        ("regnet", 2),
        ("vit", 6),
        ("mixer", 7),
    )

    cmap = plt.get_cmap("hsv")

    df = df.dropna(subset=['avg_val_acc', 'avg_test_acc', 'avg_test_ece'])

    for model in model_list:
        row = df[df['model'] == model]
        if 'resnet' in model:
            row = row[row['weight_decay'] == 1e-4]
        row = row.nlargest(1, criterion_key)

        # A model may have no row left (never logged, or filtered out above);
        # indexing an empty selection would raise, so report it and move on.
        if row.empty:
            print(f"{model} has no usable results; skipping.")
            continue

        key_1_list = row[key_1].iloc[0]
        key_2_list = row[key_2].iloc[0]

        if (key_1_list, key_2_list) == (0., 0.):
            print(f"{model}'s results may be broken!")
            continue

        marker_size = (
            marker_size_base * (num_params[model] / num_params["resnet50.tv_in1k"])
            if show_num_params
            else marker_size_base
        )

        ind = 0
        for family, family_ind in family_color_index:
            if family in model:
                ind = family_ind

        # Model names containing 'ft' get the "8" marker instead of a circle.
        marker = "8" if 'ft' in model else "o"

        ax.scatter(
            key_1_list,
            key_2_list,
            marker=marker,
            color=cmap(ind / 10),
            s=marker_size,
            label=clean_model_names[model],
            alpha=0.7,
        )

    ax.set_title(title, fontsize=16)
    ax.set_xlabel(key_1_title, fontsize=16, labelpad=2)
    ax.set_ylabel(key_2_title, fontsize=16, labelpad=2)
    ax.grid(alpha=0.5)

In [None]:
num_of_trial = 1

# (x column, y column) pairs and their axis labels; a single pair is configured here.
key_tuple_list = [('num_params', 'avg_test_ece')]
axis_titlte_tuple_list = [('#parameters ($\\times10^8$)', 'Expected Calibration Error')]

# One panel per benchmark.
fig, axes = plt.subplots(1, 4, figsize=(12, 3.4))

for (key_1, key_2), (key_1_title, key_2_title) in zip(key_tuple_list, axis_titlte_tuple_list):
    for wandb_path, name, ax in zip(wandb_path_list, name_list, axes):
        df = dfs[f'{wandb_path}']
        # Adds the parameter-count column to the shared frame in place.
        df['num_params'] = df['model'].apply(lambda x: num_params[x])
        scatter_plot(
            ax,
            name,
            df,
            key_1,
            key_2,
            key_1_title,
            key_2_title,
            num_of_trial,
            fill_std=True,
            save=False,
            show_num_params=True,
            different_markers=True,
        )
        # Hide the x-axis offset text; the x label already carries the scale factor.
        ax.xaxis.get_offset_text().set_visible(False)

# The legend entries are taken from the last panel drawn and shared by the whole figure.
handles, labels = ax.get_legend_handles_labels()
legend_options = dict(
    loc='center',
    bbox_to_anchor=(0.5, -0.4),
    ncol=6,
    borderpad=0.3,
    labelspacing=1,
    fontsize=12,
)
fig.legend(handles, labels, **legend_options)
fig.tight_layout()

In [None]:
from utils import imagenet_results
num_of_trial = 1

# (x column, y column) pairs and their axis labels; a single pair is configured here.
key_tuple_list = [('avg_test_acc', 'avg_test_ece')]
axis_titlte_tuple_list = [('OOD Test ACC', 'Expected Calibration Error')]

# One panel per benchmark.
fig, axes = plt.subplots(1, 4, figsize=(12, 3.4))

for (key_1, key_2), (key_1_title, key_2_title) in zip(key_tuple_list, axis_titlte_tuple_list):
    for wandb_path, name, ax in zip(wandb_path_list, name_list, axes):
        df = dfs[f'{wandb_path}']
        # Adds the per-model lookup columns to the shared frame in place.
        df['num_params'] = df['model'].apply(lambda x: num_params[x])
        df['in1k_acc'] = df['model'].apply(lambda x: imagenet_results[x])
        scatter_plot(
            ax,
            name,
            df,
            key_1,
            key_2,
            key_1_title,
            key_2_title,
            num_of_trial,
            fill_std=True,
            save=False,
            show_num_params=True,
            different_markers=True,
        )
        # Hide the x-axis offset text.
        ax.xaxis.get_offset_text().set_visible(False)

# The legend entries are taken from the last panel drawn and shared by the whole figure.
handles, labels = ax.get_legend_handles_labels()
legend_options = dict(
    loc='center',
    bbox_to_anchor=(0.5, -0.4),
    ncol=6,
    borderpad=0.3,
    labelspacing=1,
    fontsize=12,
)
fig.legend(handles, labels, **legend_options)
fig.tight_layout()

In [None]:
from utils import imagenet_results
num_of_trial = 1

# (x column, y column) pairs and their axis labels; a single pair is configured here.
key_tuple_list = [('in1k_acc', 'avg_test_acc')]
axis_titlte_tuple_list = [('ImageNet-1k ACC', 'OOD Test ACC')]

# One panel per benchmark.
fig, axes = plt.subplots(1, 4, figsize=(12, 3.4))

for (key_1, key_2), (key_1_title, key_2_title) in zip(key_tuple_list, axis_titlte_tuple_list):
    for wandb_path, name, ax in zip(wandb_path_list, name_list, axes):
        df = dfs[f'{wandb_path}']
        # Adds the per-model lookup columns to the shared frame in place;
        # 'in1k_acc' is the x column of this figure.
        df['num_params'] = df['model'].apply(lambda x: num_params[x])
        df['in1k_acc'] = df['model'].apply(lambda x: imagenet_results[x])
        scatter_plot(
            ax,
            name,
            df,
            key_1,
            key_2,
            key_1_title,
            key_2_title,
            num_of_trial,
            fill_std=True,
            save=False,
            show_num_params=True,
            different_markers=True,
        )
        # Hide the x-axis offset text.
        ax.xaxis.get_offset_text().set_visible(False)

# The legend entries are taken from the last panel drawn and shared by the whole figure.
handles, labels = ax.get_legend_handles_labels()
legend_options = dict(
    loc='center',
    bbox_to_anchor=(0.5, -0.4),
    ncol=6,
    borderpad=0.3,
    labelspacing=1,
    fontsize=12,
)
fig.legend(handles, labels, **legend_options)
fig.tight_layout()

In [None]:
num_of_trial = 1

# (x column, y column) pairs and their axis labels; a single pair is configured here.
key_tuple_list = [('num_params', "avg_test_acc")]
axis_titlte_tuple_list = [('#parameters ($\\times10^8$)', "OOD Test Accuracy")]

# One panel per benchmark.
fig, axes = plt.subplots(1, 4, figsize=(12, 3.4))

for (key_1, key_2), (key_1_title, key_2_title) in zip(key_tuple_list, axis_titlte_tuple_list):
    for wandb_path, name, ax in zip(wandb_path_list, name_list, axes):
        df = dfs[f'{wandb_path}']
        # Adds the parameter-count column to the shared frame in place.
        df['num_params'] = df['model'].apply(lambda x: num_params[x])
        scatter_plot(
            ax,
            name,
            df,
            key_1,
            key_2,
            key_1_title,
            key_2_title,
            num_of_trial,
            fill_std=True,
            save=False,
            show_num_params=True,
            different_markers=True,
        )
        # Hide the x-axis offset text; the x label already carries the scale factor.
        ax.xaxis.get_offset_text().set_visible(False)

# The legend entries are taken from the last panel drawn and shared by the whole figure.
handles, labels = ax.get_legend_handles_labels()
legend_options = dict(
    loc='center',
    bbox_to_anchor=(0.5, -0.4),
    ncol=6,
    borderpad=0.3,
    labelspacing=1,
    fontsize=12,
)
fig.legend(handles, labels, **legend_options)
fig.tight_layout()