In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import pandas as pd

import wandb

import functools

# Global matplotlib settings applied to every figure produced below.
# fonttype 42 makes the PDF backend embed TrueType fonts (instead of Type 3).
plt.rcParams['pdf.fonttype'] = 42
# Render all figure text in a light-weight "Times" face.
plt.rcParams["font.family"] = "Times"
plt.rcParams["font.weight"] = "light"

%matplotlib inline

In [None]:
# Establish the colorblind-friendly seaborn palette and preview its swatches.
# `colors` is indexed by the color dictionaries defined further down.
colors = sns.color_palette('colorblind')
sns.palplot(colors)

# Weights and Biases API Work

In [None]:
# Shared Weights & Biases API client used by every data-loading cell below
# (constructed with timeout=30).
api = wandb.Api(timeout=30)

@functools.lru_cache(maxsize=2048)
def extract_run_df(run, keys):
    """Collect selected history columns of a run into a DataFrame.

    The function is memoised with an LRU cache, so both arguments must be
    hashable (callers pass ``keys`` as a tuple). Repeated calls with equal
    arguments return the very same DataFrame object, so treat the result as
    read-only.

    Args:
        run: Object exposing ``scan_history(keys)`` that yields mapping-like
            history rows.
        keys: Hashable, re-iterable collection of history keys to extract.

    Returns:
        pandas.DataFrame with one row for every history row that contains
        more than one of the requested keys. Each row holds only the
        requested keys that the history row provides, in the order of
        ``keys``.
    """
    wanted = set(keys)
    records = []

    for history_row in run.scan_history(list(keys)):
        available = history_row.keys()
        # Skip rows that carry at most one of the requested keys.
        if len(wanted.intersection(available)) <= 1:
            continue
        records.append({key: history_row[key] for key in keys if key in available})

    return pd.DataFrame(records)

def filter_runs_by_tag(tag, all_runs):
    """Return a list of the runs in ``all_runs`` whose ``tags`` contain ``tag``.

    The original iteration order of ``all_runs`` is preserved.
    """
    return list(filter(lambda run: tag in run.tags, all_runs))

# More Utils

In [None]:
# Maps a name fragment to the display label of a selection method.
# NOTE(review): presumably meant to be matched against run names with
# str_to_selection_method_from_dict (substring, first match wins); it is not
# used in the cells visible here — confirm. "_loss" is a substring of the two
# "*_reducible_loss" fragments, so it has to stay last for first-match lookups.
cinic10_name_to_select_method_dict = {
    "uniform": "Uniform Sampling", 
    "_reducible_loss": "Reducible Loss (Ours)",
    "importance_sampling": "Gradient Norm IS",
    "_irreducible_loss": "Irreducible Loss", 
    "_gradnorm_ub": "Gradient Norm", 
    "_loss": "Loss",
}

# Display labels of the methods to plot, in plotting order.
# This list is reassigned later by the active-learning section of the notebook.
selection_methods = ["Reducible Loss (Ours)", "Uniform Sampling", 
                     "Irreducible Loss", "Gradient Norm", "Loss", "SVP",  "Gradient Norm IS"]

# Display label -> palette color used for that method's curve and shaded band.
# The "Small Irrloss Model" variants share the color of their base method, and
# the active-learning baselines (BALD, Entropy, ...) reuse palette slots 2-5.
selection_methods_color_dict = {
    "Reducible Loss (Ours)": colors[0], 
    "Reducible Loss (Ours)\nSmall Irrloss Model": colors[0], 
    "Uniform Sampling": colors[1], 
    "Irreducible Loss": colors[2],
    "Irreducible Loss\nSmall Irrloss Model": colors[2],
    "Gradient Norm": colors[3],
    "Loss": colors[4],
    "SVP": colors[5],
    "Gradient Norm IS": colors[6],
    "BALD": colors[2], 
    "Entropy": colors[3], 
    "Conditional Entropy": colors[4],
    "Loss Minus Conditional Entropy": colors[5]
}

# Metric name -> palette color; uses the same palette slots as the matching
# entries of selection_methods_color_dict.
metrics_color_dict = {
    "Reducible Loss": colors[0], 
    "Gradient Norm": colors[3],
    "Loss": colors[4],
}

def str_to_selection_method_from_dict(string, name_dict):
    """Translate ``string`` into a display label via substring lookup.

    ``name_dict`` is scanned in insertion order and the value of the first
    key that occurs inside ``string`` is returned. When no key matches, the
    placeholder ``"<string> not found"`` is returned instead of raising.
    """
    matching_labels = (label for fragment, label in name_dict.items() if fragment in string)
    return next(matching_labels, f"{string} not found")
        
def run_to_selection_method(run):
    """Return the display label of the selection method used by ``run``.

    The decision is taken from ``run.config`` in this order:

    1. project ``"svp_final"`` -> ``"SVP"``;
    2. the importance-sampling model target -> ``"Gradient Norm IS"``;
    3. otherwise the ``selection_method/_target_`` entry is translated by
       substring lookup; an unknown target yields ``"<target> not found"``.
    """
    config = run.config

    if config["logger/wandb/project"] == "svp_final":
        return "SVP"

    if config["model/_target_"] == "src.models.ImportanceSamplingModel.ImportanceSamplingModel":
        return "Gradient Norm IS"

    # Fully qualified selection-function target -> label shown in the figures.
    target_to_label = {
        "src.curricula.selection_methods.uniform_selection": "Uniform Sampling",
        "src.curricula.selection_methods.reducible_loss_selection": "Reducible Loss (Ours)",
        "src.curricula.selection_methods.irreducible_loss_selection": "Irreducible Loss",
        "src.curricula.selection_methods.gradnorm_ub_selection": "Gradient Norm",
        "src.curricula.selection_methods.ce_loss_selection": "Loss",
        "src.curricula.selection_methods.bald_selection": "BALD",
        "src.curricula.selection_methods.entropy_selection": "Entropy",
        "src.curricula.selection_methods.conditional_entropy_selection": "Conditional Entropy",
        "src.curricula.selection_methods.loss_minus_conditional_entropy_selection": "Loss Minus Conditional Entropy",
    }
    selection_target = config["selection_method/_target_"]
    return str_to_selection_method_from_dict(selection_target, target_to_label)
        
def df_to_xvals(df):
    """Return the ``trainer/global_step`` column of ``df`` as a NumPy array."""
    step_column = df["trainer/global_step"]
    return step_column.to_numpy()

def _stack_val_acc(run_dfs, n_steps):
    """Stack each run's ``val_acc_epoch`` (scaled to percent) into an (n_steps, n_runs) array.

    Runs with fewer than ``n_steps`` entries are zero-padded at the end; runs
    with more entries are truncated to ``n_steps``.
    """
    stacked = np.zeros(shape=(n_steps, len(run_dfs)))
    for col, run_df in enumerate(run_dfs):
        acc = 100 * run_df["val_acc_epoch"].to_numpy()
        n_used = min(acc.size, n_steps)
        stacked[:n_used, col] = acc[:n_used]
    return stacked


def compute_speedup(runs_all_info, baseline_name, ours_name):
    """Compute how much sooner ``ours_name`` beats the best baseline accuracy.

    The baseline's peak is the maximum over steps of the across-run mean
    accuracy, together with the step at which that peak occurs. For each run
    of ``ours_name`` the first step whose accuracy strictly exceeds that peak
    is recorded. The speedup is ``peak_step / mean(first_exceeding_steps)``.

    Args:
        runs_all_info: Iterable of ``(run, method_label, history_df)`` triples;
            each DataFrame needs ``trainer/global_step`` and ``val_acc_epoch``.
        baseline_name: Method label of the baseline group.
        ours_name: Method label of the group being compared.

    Returns:
        The speedup factor as a float. If any run of ``ours_name`` never
        exceeds the baseline peak, its step count is infinite and the result
        is ``0.0``.

    Raises:
        ValueError: If no run carries ``baseline_name`` or ``ours_name``
            (previously an opaque ``IndexError``).

    NOTE(review): the step axis of each group is taken from its first run and
    shorter runs are zero-padded, which lowers the across-run mean at late
    steps — confirm that all runs of a group have the same length.
    """
    baseline_dfs = [d for _, c, d in runs_all_info if c == baseline_name]
    ours_dfs = [d for _, c, d in runs_all_info if c == ours_name]

    if not baseline_dfs:
        raise ValueError(f"No runs found for baseline {baseline_name!r}")
    if not ours_dfs:
        raise ValueError(f"No runs found for method {ours_name!r}")

    x_vals_baseline = df_to_xvals(baseline_dfs[0])
    y_vals_baseline = _stack_val_acc(baseline_dfs, x_vals_baseline.size)

    x_vals_ours = df_to_xvals(ours_dfs[0])
    y_vals_ours = _stack_val_acc(ours_dfs, x_vals_ours.size)

    # Peak of the mean baseline curve and the step at which it is reached.
    baseline_mean = np.mean(y_vals_baseline, axis=-1)
    peak_index = np.argmax(baseline_mean)
    baseline_max = baseline_mean[peak_index]
    baseline_max_step = x_vals_baseline[peak_index]

    # Runs that never beat the baseline keep an infinite step count.
    steps_outperformed = np.full(len(ours_dfs), np.inf)
    for run_i in range(len(ours_dfs)):
        nzs = np.nonzero(y_vals_ours[:, run_i] > baseline_max)[0]
        if nzs.size > 0:
            steps_outperformed[run_i] = x_vals_ours[nzs[0]]

    return baseline_max_step / np.mean(steps_outperformed)

In [None]:
def run_to_dataset(run):
    """Return the dataset name found in ``run.config["datamodule"]``.

    Returns ``"CINIC10"``, ``"CIFAR100"`` or ``"CIFAR10"``, or ``None`` when
    none of them occurs in the datamodule entry.
    """
    datamodule = run.config["datamodule"]
    # "CIFAR100" has to be tried before "CIFAR10", which is a prefix of it.
    for dataset in ("CINIC10", "CIFAR100", "CIFAR10"):
        if dataset in datamodule:
            return dataset
    return None

# Active Learning Baselines

# MNIST

In [None]:
# Replaces the earlier `selection_methods` list with the methods compared in
# the active-learning figures. Every later cell that reads `selection_methods`
# sees this list when the notebook is run top to bottom.
selection_methods = ["Reducible Loss (Ours)", "Uniform Sampling", 
                     "BALD", "Entropy", "Conditional Entropy", "Loss Minus Conditional Entropy"]

In [None]:
# Quick look at the loaded CIFAR runs. `cifar_runs` is only assigned by the
# data-loading cell further down, so look it up defensively: on a fresh kernel
# this displays nothing instead of raising NameError and aborting "Run All".
globals().get("cifar_runs")

In [None]:
# History columns needed for the active-learning curves.
keys = ["trainer/global_step", "val_acc_epoch"]

# Each *_runs_all_info entry is a (run, method label, history DataFrame) triple;
# runs whose extracted history is empty are dropped.
# `keys` is passed as a tuple so the LRU cache on extract_run_df can hash it.
mnist_runs = list(filter_runs_by_tag("mnist_active_learning", api.runs("goldiprox/mnist_active_learning")))
mnist_runs_all_info = [
    info
    for info in zip(
        mnist_runs,
        [run_to_selection_method(r) for r in mnist_runs],
        [extract_run_df(r, tuple(keys)) for r in mnist_runs],
    )
    if not info[2].empty
]

cifar_runs = list(filter_runs_by_tag("cifar10_active_learning", api.runs("goldiprox/cifar10_active_learning_updatedv2")))
cifar_runs_all_info = [
    info
    for info in zip(
        cifar_runs,
        [run_to_selection_method(r) for r in cifar_runs],
        [extract_run_df(r, tuple(keys)) for r in cifar_runs],
    )
    if not info[2].empty
]

In [None]:
def figure2a_subplot(selection_methods, runs_all_info, ylim):
    """Draw accuracy-versus-step curves for each method on the current axes.

    For every label in ``selection_methods`` the matching runs are stacked
    (accuracy in percent), their mean is drawn as a line and the min-max range
    across runs as a shaded band. Methods without any run are reported with a
    printed message and skipped.

    Args:
        selection_methods: Method labels to draw, in plotting order.
        runs_all_info: Iterable of ``(run, method_label, history_df)`` triples.
        ylim: y-axis limits applied after each plotted method.

    Notes:
        The step axis comes from the first run of a method and shorter runs
        are zero-padded. The plotted quantity is ``val_acc_epoch`` although
        the axis is labelled "Test Accuracy (%)".
    """
    for method in selection_methods:
        method_dfs = [df for _, label, df in runs_all_info if label == method]

        if not method_dfs:
            print(f"Could not find any dfs corresponding to {method}")
            continue

        steps = df_to_xvals(method_dfs[0])
        acc_pct = np.zeros(shape=(steps.size, len(method_dfs)))
        for col, run_df in enumerate(method_dfs):
            run_acc = 100 * run_df["val_acc_epoch"].to_numpy()
            acc_pct[:run_acc.size, col] = run_acc

        shade = selection_methods_color_dict[method]
        plt.plot(steps, acc_pct.mean(axis=-1), color=shade, linewidth=1, label=method)
        plt.fill_between(
            steps,
            acc_pct.min(axis=-1),
            acc_pct.max(axis=-1),
            color=shade,
            alpha=0.15,
            linewidth=0,
        )

        # Axis cosmetics are re-applied after every plotted method.
        plt.xlabel("Steps", fontsize=10)
        plt.ylabel("Test Accuracy (%)", fontsize=10)
        plt.xticks(fontsize=8)
        plt.yticks(fontsize=8)
        plt.ylim(ylim)
        plt.xlim([steps.min(), steps.max()])
        
def figure2a_subplot_alt(selection_methods, runs_all_info, xlim):
    """Draw steps-required-versus-target-accuracy curves on the current axes.

    For every label in ``selection_methods`` and each of 50 evenly spaced
    target accuracies in ``xlim``, the first step at which a run's accuracy
    (in percent) strictly exceeds the target is recorded. The mean across runs
    is drawn as a line and the min-max range as a shaded band. A target that a
    run never exceeds is stored as NaN, which makes the mean/min/max at that
    target NaN as well. Methods without any run are reported and skipped.

    Args:
        selection_methods: Method labels to draw, in plotting order.
        runs_all_info: Iterable of ``(run, method_label, history_df)`` triples.
        xlim: ``[low, high]`` target-accuracy range; also used as x-limits.

    Notes:
        The step axis comes from the first run of a method, and every run of
        that method must have the same number of history rows.
    """
    for method in selection_methods:
        method_dfs = [df for _, label, df in runs_all_info if label == method]
        if not method_dfs:
            print(f"Could not find any dfs corresponding to {method}")
            continue

        n_runs = len(method_dfs)
        steps = df_to_xvals(method_dfs[0])

        acc_pct = np.zeros(shape=(steps.size, n_runs))
        for col, run_df in enumerate(method_dfs):
            acc_pct[:, col] = 100 * run_df["val_acc_epoch"].to_numpy()

        targets = np.linspace(xlim[0], xlim[1], 50)

        # Start from NaN ("never reached") and fill in the targets that are hit.
        steps_needed = np.full((targets.size, n_runs), np.nan)
        for row, target in enumerate(targets):
            for col in range(n_runs):
                hits = np.nonzero(acc_pct[:, col] > target)[0]
                if hits.size > 0:
                    steps_needed[row, col] = steps[hits[0]]

        shade = selection_methods_color_dict[method]
        plt.plot(targets, steps_needed.mean(axis=-1), color=shade, linewidth=1.5, label=method)
        plt.fill_between(
            targets,
            steps_needed.min(axis=-1),
            steps_needed.max(axis=-1),
            color=shade,
            alpha=0.15,
            linewidth=0,
        )

        # Axis cosmetics are re-applied after every plotted method.
        plt.ylabel("Steps Required\nLower is Better", fontsize=10)
        plt.xlabel("Target Accuracy (%)", fontsize=10)
        plt.xticks(fontsize=8)
        plt.yticks(fontsize=8)
        plt.xlim(xlim)

In [None]:
# Two-panel figure: accuracy vs. training step for the active-learning baselines.
plt.figure(figsize=(5.75, 2.45), dpi=300)
# Left panel: MNIST.
plt.subplot(121)
figure2a_subplot(selection_methods, mnist_runs_all_info, [0, 100])
plt.title("MNIST Active Learning", fontsize=10)
# These calls replace the x/y limits that figure2a_subplot applied to this axes.
plt.xlim([0, 2000])
plt.ylim([75, 98.5])
# Right panel: CIFAR (only the y-limits are overridden here).
plt.subplot(122)
figure2a_subplot(selection_methods, cifar_runs_all_info, [0, 100])
plt.title("CIFAR Active Learning", fontsize=10)
plt.ylim([15, 75])
plt.tight_layout()

# The legend is created after tight_layout on the current (right-hand) axes and
# anchored below it via bbox_to_anchor.
plt.legend(fancybox=True, shadow=True, fontsize=8, loc="upper center", bbox_to_anchor=(-0.15, -0.25), ncol=3)
# NOTE(review): assumes the figure_outputs/ directory already exists — confirm.
plt.savefig("figure_outputs/figure_al_baselines.pdf", bbox_inches="tight")

# Figure 2a—alt

In [None]:
# Three-panel figure: steps required vs. target accuracy, one panel per dataset.
# NOTE(review): cifar10_runs_all_info, cifar100_runs_all_info and
# cinic10_runs_all_info are not assigned in any cell visible in this notebook —
# confirm where they are loaded before relying on Restart & Run All.
# NOTE(review): when run top to bottom, `selection_methods` here is the list
# set in the active-learning section — confirm that is the intended method set.
plt.figure(figsize=(5.75, 2), dpi=300)
plt.subplot(131)
figure2a_subplot_alt(selection_methods, cifar10_runs_all_info, [0, 110])
plt.title("(a) CIFAR10", fontsize=10)
print(f"CIFAR10 speedup: {compute_speedup(cifar10_runs_all_info, 'Uniform Sampling', 'Reducible Loss (Ours)'):.2f}")
plt.subplot(132)
figure2a_subplot_alt(selection_methods, cifar100_runs_all_info, [0, 70])
plt.title("(b) CIFAR100", fontsize=10)
# Only the left-most panel keeps the y-axis label set by figure2a_subplot_alt.
plt.ylabel(None)
print(f"CIFAR100 speedup: {compute_speedup(cifar100_runs_all_info, 'Uniform Sampling', 'Reducible Loss (Ours)'):.2f}")
plt.subplot(133)
figure2a_subplot_alt(selection_methods, cinic10_runs_all_info, [0, 90])
plt.title("(c) CINIC10", fontsize=10)
plt.ylabel(None)
plt.ylim([0, 20000])
print(f"CINIC10 speedup: {compute_speedup(cinic10_runs_all_info, 'Uniform Sampling', 'Reducible Loss (Ours)'):.2f}")

plt.tight_layout()
# The legend is attached to the current (right-most) axes and anchored below the row of panels.
plt.legend(fancybox=True, shadow=True, fontsize=8, ncol=3, bbox_to_anchor=(-1, -0.35), loc="upper center")
plt.savefig("figure_outputs/figure_2a.pdf", bbox_inches='tight')

# Figure 2b

In [None]:
# History columns needed for the label-noise curves.
keys = ["trainer/global_step", "val_acc_epoch"]

# For each dataset: gather the tagged label-noise runs from the dataset's own
# project plus the "svp_final" project, then build (run, method label, history
# DataFrame) triples. `keys` is converted to a tuple so the LRU cache on
# extract_run_df can hash it.
cifar10_runs_2b = [*filter_runs_by_tag("cifar10_labelnoise", api.runs("goldiprox/jb_cifar10")), *filter_runs_by_tag("cifar10_labelnoise", api.runs("goldiprox/svp_final"))]
cifar10_runs_all_info_2b = list(zip(cifar10_runs_2b, [run_to_selection_method(r) for r in cifar10_runs_2b], [extract_run_df(r, tuple(keys)) for r in cifar10_runs_2b]))

cinic10_runs_2b = [*filter_runs_by_tag("cinic10_labelnoise", api.runs("goldiprox/goldiprox")), *filter_runs_by_tag("cinic10_labelnoise", api.runs("goldiprox/svp_final"))]
cinic10_runs_all_info_2b = list(zip(cinic10_runs_2b, [run_to_selection_method(r) for r in cinic10_runs_2b], [extract_run_df(r, tuple(keys)) for r in cinic10_runs_2b]))

cifar100_runs_2b = [*filter_runs_by_tag("cifar100_labelnoise", api.runs("goldiprox/cifar100")), *filter_runs_by_tag("cifar100_labelnoise", api.runs("goldiprox/svp_final"))]
# Fixed: the CIFAR100 triples previously zipped the CINIC10 runs and iterated
# the undefined name `cifar100_runs_2a`; all three columns now come from
# `cifar100_runs_2b`.
cifar100_runs_all_info_2b = list(zip(cifar100_runs_2b, [run_to_selection_method(r) for r in cifar100_runs_2b], [extract_run_df(r, tuple(keys)) for r in cifar100_runs_2b]))

# Figure 2b – alt

In [None]:
# Three-panel label-noise figure: steps required vs. target accuracy per dataset.
plt.figure(figsize=(5.75, 2), dpi=300)
plt.subplot(131)
# Fixed: this panel is titled "CIFAR10" and its speedup is computed from the
# CIFAR10 runs, but it used to plot the CINIC10 label-noise runs.
figure2a_subplot_alt(selection_methods, cifar10_runs_all_info_2b, [0, 110])
plt.title("CIFAR10", fontsize=10)
print(f"CIFAR10 speedup: {compute_speedup(cifar10_runs_all_info_2b, 'Uniform Sampling', 'Reducible Loss (Ours)'):.2f}")
plt.subplot(132)
figure2a_subplot_alt(selection_methods, cifar100_runs_all_info_2b, [0, 70])
plt.title("CIFAR100", fontsize=10)
# Only the left-most panel keeps the y-axis label set by figure2a_subplot_alt.
plt.ylabel(None)
print(f"CIFAR100 speedup: {compute_speedup(cifar100_runs_all_info_2b, 'Uniform Sampling', 'Reducible Loss (Ours)'):.2f}")
plt.subplot(133)
figure2a_subplot_alt(selection_methods, cinic10_runs_all_info_2b, [0, 90])
print(f"CINIC10 speedup: {compute_speedup(cinic10_runs_all_info_2b, 'Uniform Sampling', 'Reducible Loss (Ours)'):.2f}")
plt.title("CINIC10", fontsize=10)
plt.ylabel(None)
plt.tight_layout()
# plt.legend(fancybox=True, shadow=True, fontsize=8, ncol=3, bbox_to_anchor=(-1, -0.35), loc="upper center")
plt.savefig("figure_outputs/figure_2b.pdf", bbox_inches='tight')

# Figure 2 Combined

In [None]:
# 2x3 figure: top row uses the *_runs_all_info sets, bottom row the label-noise
# (*_2b) sets; each panel shows steps required vs. target accuracy.
# NOTE(review): cifar10_runs_all_info, cifar100_runs_all_info and
# cinic10_runs_all_info are not assigned in any cell visible in this notebook —
# confirm where they are loaded before relying on Restart & Run All.
plt.figure(figsize=(5.75, 3.75), dpi=300)
plt.subplot(231)
figure2a_subplot_alt(selection_methods, cifar10_runs_all_info, [0, 110])
plt.title("Half of CIFAR10", fontsize=10)
# Top-row panels drop the x-axis label; the bottom row keeps it.
plt.xlabel(None)
print(f"CIFAR10 speedup: {compute_speedup(cifar10_runs_all_info, 'Uniform Sampling', 'Reducible Loss (Ours)'):.2f}")
plt.subplot(232)
figure2a_subplot_alt(selection_methods, cifar100_runs_all_info, [0, 70])
plt.title("Half of CIFAR100", fontsize=10)
# Only the first panel of each row keeps the y-axis label.
plt.ylabel(None)
plt.xlabel(None)
print(f"CIFAR100 speedup: {compute_speedup(cifar100_runs_all_info, 'Uniform Sampling', 'Reducible Loss (Ours)'):.2f}")
# Keep a handle to this axes: the shared legend is attached to it at the end.
splt = plt.subplot(233)
figure2a_subplot_alt(selection_methods, cinic10_runs_all_info, [0, 90])
plt.title("CINIC10", fontsize=10)
plt.ylabel(None)
plt.xlabel(None)
plt.ylim([0, 20000])
print(f"CINIC10 speedup: {compute_speedup(cinic10_runs_all_info, 'Uniform Sampling', 'Reducible Loss (Ours)'):.2f}")

# Bottom row: label-noise experiments.
plt.subplot(234)
figure2a_subplot_alt(selection_methods, cifar10_runs_all_info_2b, [0, 100])
plt.title("Half of CIFAR10\n(Label Noise)", fontsize=10)
print(f"CIFAR10 speedup: {compute_speedup(cifar10_runs_all_info_2b, 'Uniform Sampling', 'Reducible Loss (Ours)'):.2f}")
plt.subplot(235)
figure2a_subplot_alt(selection_methods, cifar100_runs_all_info_2b, [0, 70])
plt.title("Half of CIFAR100\n(Label Noise)", fontsize=10)
plt.ylabel(None)
print(f"CIFAR100 speedup: {compute_speedup(cifar100_runs_all_info_2b, 'Uniform Sampling', 'Reducible Loss (Ours)'):.2f}")
plt.subplot(236)
figure2a_subplot_alt(selection_methods, cinic10_runs_all_info_2b, [0, 90])
print(f"CINIC10 speedup: {compute_speedup(cinic10_runs_all_info_2b, 'Uniform Sampling', 'Reducible Loss (Ours)'):.2f}")
plt.title("CINIC10\n(Label Noise)", fontsize=10)
plt.ylabel(None)
plt.ylim([0, 20000])
plt.tight_layout()

# Legend built from the top-right axes and anchored far below it via
# bbox_to_anchor so that it sits underneath the bottom row.
splt.legend(fancybox=True, shadow=True, fontsize=8, ncol=3, bbox_to_anchor=(-1, -2.05), loc="upper center")
plt.savefig("figure_outputs/figure_2_combined.pdf", bbox_inches='tight')