In [None]:
import functools
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import wandb

# Embed fonts in saved PDFs as TrueType (Type 42) rather than Type 3 fonts.
plt.rcParams['pdf.fonttype'] = 42
# NOTE(review): requires a "Times" font to be installed; otherwise matplotlib
# warns and falls back to its default family.
plt.rcParams["font.family"] = "Times"
plt.rcParams["font.weight"] = "light"

%matplotlib inline

import re

In [None]:
# Establish and plot the colorblind-friendly color palette.
colors = sns.color_palette('colorblind')
# Make it the active color cycle for every plot drawn later in the notebook.
sns.set_palette(colors)
# Render the palette swatches as a quick visual check.
sns.palplot(colors)

In [None]:
# Load the precomputed speedup results. As consumed by the DataFrame-building
# cell, this is a nested mapping:
#   experiment-group key -> {experiment key -> iterable of speedup values}.
# NOTE(review): pickle.load can execute arbitrary code stored in the file; only
# load an all_speedups_dict.pkl produced by a trusted run.
with open("all_speedups_dict.pkl", "rb") as speedups_pickle:  # closes the handle
    all_speedups_dict = pickle.load(speedups_pickle)

def all_speedups_dict_key_to_dataset(key):
    """Return the dataset name embedded in an experiment-group key.

    Candidate names are tried in a fixed order and the first substring hit
    wins. "CIFAR100" must be tried before "CIFAR10" because "CIFAR10" is itself
    a substring of "CIFAR100".

    Args:
        key: Experiment-group key from ``all_speedups_dict``.

    Returns:
        "CIFAR100", "CIFAR10" or "CINIC10", or None if none of them occurs
        in ``key``.
    """
    candidates = ("CIFAR100", "CIFAR10", "CINIC10")
    return next((name for name in candidates if name in key), None)

def all_speedups_dict_key_to_exp_type(key, exp_key):
    """Map an experiment-group key to the display name of its experiment group.

    ``key`` is scanned for the markers below in order. The first marker found
    decides the group, so earlier entries take precedence over later ones.
    Holdout-set runs get a per-run name built from the number that
    ``exp_key_to_label`` parses out of ``exp_key``.

    Args:
        key: Experiment-group key from ``all_speedups_dict``.
        exp_key: Key of the individual run within that group.

    Returns:
        The group's display name, or None if no marker occurs in ``key``.
    """
    ordered_markers = [
        ("hypers", "Hyperparameter transfer"),
        ("archs", "Architecture transfer"),
        ("holdout_set", "Holdout set size"),
        ("double_IrLoMo", "No holdout set"),
        ("default", "Default"),
        ("_small_CNN", "Small irreducible loss model"),
    ]
    for marker, group_name in ordered_markers:
        if marker not in key:
            continue
        if marker == "holdout_set":
            # NOTE(review): the parsed value is halved before being shown as a
            # percentage of available data; confirm this matches how the
            # holdout split was defined in the experiments.
            holdout_value = exp_key_to_label('Holdout set size', exp_key)
            return f"Holdout set {100*holdout_value/2:.1f}% of available data"
        return group_name
    return None

def exp_key_to_label(exp_type, exp_key):
    """Return the per-run label for an experiment within its group.

    Args:
        exp_type: Experiment-group display name.
        exp_key: Key of an individual run inside that group.

    Returns:
        - "" for groups whose runs need no extra label.
        - For "Holdout set size", the first number (integer or decimal) in
          ``exp_key``, as a float.
        - None for any other ``exp_type``.

    Raises:
        ValueError: if ``exp_type`` is "Holdout set size" and ``exp_key``
            contains no number.
    """
    if exp_type in ("Hyperparameter transfer", "Architecture transfer", "Default",
                    "Small irreducible loss model", "No holdout set"):
        return ""
    if exp_type == "Holdout set size":
        # Raw string: "\d" in a plain string is an invalid escape (warning).
        # The dot is escaped so that "12.5" parses as 12.5. The previous
        # "\d.\d*" let "." match any character, which turned "12.5" into 12.0
        # and raised IndexError on a bare "1".
        match = re.search(r"\d+(?:\.\d+)?", exp_key)
        if match is None:
            raise ValueError(f"No numeric holdout value found in exp_key {exp_key!r}")
        return float(match.group(0))
    # NOTE(review): the formatted holdout group names ("Holdout set ...% of
    # available data") built in all_speedups_dict_key_to_exp_type fall through
    # to here, so those rows get a None label.
    return None
    

In [None]:
def exp_filter_out(exp_group_key, exp_key):
    """Return True for runs that should be excluded from the speedup figure.

    The decision depends only on ``exp_key``. ``exp_group_key`` is accepted
    but not used.

    Args:
        exp_group_key: Experiment-group key (unused).
        exp_key: Key of the individual run.

    Returns:
        True if ``exp_key`` is one of the excluded runs, else False.
    """
    excluded_exp_keys = (
        "small CNN, 0.75",
        "small CNN, 0.33",
        "small CNN, 1",
        "src.models.modules.resnet_cifar.ResNet18",
    )
    return exp_key in excluded_exp_keys

In [None]:
# Flatten the nested speedup results into one DataFrame row per recorded value.
all_speedups_df_list = []

for exp_group_key, exp_group_exps in all_speedups_dict.items():
    dataset = all_speedups_dict_key_to_dataset(exp_group_key)

    for exp_key, exp_vals in exp_group_exps.items():
        # The exclusion test depends only on the keys, so decide it once per
        # run instead of once per value.
        if exp_filter_out(exp_group_key, exp_key):
            continue

        # Group name and label are the same for every value of this run.
        exp_type = all_speedups_dict_key_to_exp_type(exp_group_key, exp_key)
        label = exp_key_to_label(exp_type, exp_key)

        for val in exp_vals:
            exp_dict = {
                "Experiment Group": exp_type,
                "Dataset": dataset,
                "Speedup": val,
                "Label": label,
            }
            all_speedups_df_list.append(exp_dict)

all_speedups_df = pd.DataFrame(all_speedups_df_list)

# Easy Plot Version

In [None]:
# Figure 5: one strip of speedups per experiment group, coloured by dataset.
# The holdout entries must match the names formatted by
# all_speedups_dict_key_to_exp_type.
order = ["Default", "Small irreducible loss model", "Holdout set 25.0% of available data", "Holdout set 12.5% of available data", 
         "No holdout set", "Architecture transfer", "Hyperparameter transfer"]
hue_order = ["CIFAR10", "CIFAR100", "CINIC10"]
plt.figure(figsize=(5.75, 2), dpi=300)
# Shade every other category row to make the rows easier to follow.
for i in range(0, len(order), 2):
    plt.fill_between([-1, 10], [i-0.5, i-0.5], [i+0.5, i+0.5], color="gainsboro", alpha=0.3, linewidth=0)

sns.stripplot(x="Speedup", y="Experiment Group", hue="Dataset",
              data=all_speedups_df, dodge=True, alpha=.6, zorder=1, order=order, hue_order=hue_order)
plt.ylabel(None)
plt.xlabel("RHOLS speedup over uniform training", fontsize=8)
plt.yticks(fontsize=8)
plt.legend(fontsize=8, loc="upper right", shadow=True, fancybox=True, bbox_to_anchor=(1.1, 1.1), title="Dataset")
# Dashed reference line at 1x ("No speedup"); axvline spans the full axis
# height regardless of how many categories are plotted.
plt.axvline(1, color="k", linestyle="--", linewidth=0.5)
# Categories sit at y = 0 .. len(order)-1. Listing -0.5 as the upper limit
# keeps order[0] at the top.
plt.ylim([len(order) - 0.5, -0.5])
plt.xlim([-0.1, 7])
plt.xticks([0, 1, 3, 6], ["0", "No speedup", "3x", "6x"], fontsize=7)
plt.tight_layout()
# savefig does not create missing directories; make sure the target exists.
os.makedirs("figure_outputs", exist_ok=True)
plt.savefig("figure_outputs/figure5.pdf", bbox_inches="tight")