In [45]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sys, copy, os, shutil
from tqdm.notebook import tqdm
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
import itertools
from scipy.stats import norm

# Lookup table: dataset key -> (display title, feature dimension D).
dataset_descs = {
    "avazu-app_binary_sparse": ("Avazu (App)", 1000000),
    "avazu-site_binary_sparse": ("Avazu (Site)", 1000000),
    "criteo_binary_sparse": ("Criteo", 1000000),
    "dexter_binary_sparse": ("Dexter", 20000),
    "dorothea_binary_sparse": ("Dorothea", 100000),
    "kdd2010-a_binary_sparse": ("KDD2010 (Algebra)", 20216830),
    "mnist8-4+9_binary_sparse": ("MNIST8 (4+9)", 784),
    "news20_binary_sparse": ("News20", 1355191),
    "newsgroups_binary_sparse": ("Newsgroups (Binary, CS)", 101631),
    "pcmac_binary_sparse": ("PCMAC", 3289),
    "rcv1_binary_sparse": ("RCV1", 47236),
    "real-sim_binary_sparse": ("Real-Sim", 20958),
    "sst2_binary_sparse": ("SST-2", 13757),
    "url_binary_sparse": ("URL", 3231961),
    "w8a_binary_sparse": ("W8A", 300),
    "webspam_binary_sparse": ("Webspam", 254),
}

# Dataset keys in alphabetical order; iterating a dict yields its keys,
# so sorting the dict directly gives the sorted key list.
ordered_datasets = sorted(dataset_descs)

# Computational Costs Per Iteration

In [47]:
'''
Table of seconds per iteration of {PAC, FSOL} x {WRS, moving average, exponential average}:
1. Standardize against WRS (dense + K=64) having 1 unit of time.
2. Even if partial results, just note them.
'''


def _mean_time_elapsed(paths):
    """Return the NaN-ignoring mean of `time_elapsed` pooled over several metrics CSVs.

    Replaces the former `icfi(...)` call, which is not defined or imported
    anywhere in the visible notebook and would raise NameError on a fresh kernel.

    Parameters
    ----------
    paths : iterable of str
        Paths of metrics CSV files (one per seed). Each file must contain a
        `time_elapsed` column.

    Returns
    -------
    float
        `np.nanmean` over every row of every file, all files pooled together.

    Raises
    ------
    FileNotFoundError
        Propagated from `pd.read_csv` if one of the files does not exist.
    """
    pooled = np.concatenate([pd.read_csv(path)["time_elapsed"].to_numpy() for path in paths])
    return np.nanmean(pooled)


# column layout of each per-model table
time_columns = ["Dataset", "D", "WRS", "Moving Avg.", "Expo. Avg."]

# create a dictionary to store our dataframes
time_dfs = {}

# go thru all of our base models
for base_model in ["PAC", "FSOL"]:

    # collect one row per dataset and build the table once at the end
    # (avoids growing a DataFrame row by row)
    rows = []

    # iterate thru our datasets
    for dataset in ordered_datasets:

        # start our row with the name + D
        row = list(dataset_descs[dataset])

        # WRS (dense, K=64), pooled over seeds 0..4
        WRS = _mean_time_elapsed(f"WRS/results/{base_model}/{dataset}"
                                 f"/model={base_model}_ws=dense_K=64_seed={seed}_metrics.csv"
                                 for seed in range(5))

        # Moving-Avg. (K=64), pooled over seeds 0..4
        MAG = _mean_time_elapsed(f"moving_average/results/{base_model}/{dataset}"
                                 f"/model={base_model}_K=64_seed={seed}_metrics.csv"
                                 for seed in range(5))

        # Expo. Avg., pooled over seeds 0..4
        EAG = _mean_time_elapsed(f"exponential_average/results/{base_model}/{dataset}"
                                 f"/model={base_model}_seed={seed}_metrics.csv"
                                 for seed in range(5))

        # express every cost relative to WRS (so the WRS column is 1.0, or NaN if WRS is NaN)
        row += [np.round(WRS/WRS, 3), np.round(MAG/WRS, 3), np.round(EAG/WRS, 3)]
        rows.append(row)

    # our table, where each row corresponds to a dataset
    df = pd.DataFrame(rows, columns=time_columns)

    # at the end, add to our dataframe dictionary; a stable sort keeps datasets
    # with equal D (e.g. the three with D = 1000000) in alphabetical order
    time_dfs[base_model] = df.sort_values(by="D", kind="stable")

In [48]:
# Table destined for the .tex file, assembled column-wise (aligned on the row index):
# dataset name + D from the PAC table, then the two averaging columns of PAC, then of FSOL.
avg_columns = ["Moving Avg.", "Expo. Avg."]
combined_parts = [time_dfs["PAC"][["Dataset", "D"]]]
combined_parts += [time_dfs[base][avg_columns] for base in ("PAC", "FSOL")]
combined = pd.concat(combined_parts, axis=1)

In [49]:
# show the full time complexity table for LaTeX; the ratios are already rounded to
# 3 decimals, so print floats with exactly 3 decimals rather than the default
# 6-decimal padding (e.g. 0.953 instead of 0.953000)
print(combined.to_latex(index=False, float_format="%.3f"))

\begin{tabular}{lrrrrr}
\toprule
Dataset & D & Moving Avg. & Expo. Avg. & Moving Avg. & Expo. Avg. \\
\midrule
Webspam & 254 & 0.953000 & 0.717000 & 0.738000 & 0.613000 \\
W8A & 300 & 0.420000 & 0.353000 & 0.368000 & 0.336000 \\
MNIST8 (4+9) & 784 & 1.125000 & 0.885000 & 0.968000 & 0.887000 \\
PCMAC & 3289 & 0.483000 & 0.312000 & 0.377000 & 0.235000 \\
SST-2 & 13757 & 0.903000 & 0.431000 & 0.748000 & 0.411000 \\
Dexter & 20000 & 0.577000 & 0.143000 & 0.285000 & 0.066000 \\
Real-Sim & 20958 & 0.949000 & 0.624000 & 0.707000 & 0.564000 \\
RCV1 & 47236 & 1.825000 & 0.892000 & 1.366000 & 1.011000 \\
Dorothea & 100000 & 1.194000 & 0.190000 & 0.580000 & 0.254000 \\
Newsgroups (Binary, CS) & 101631 & 3.587000 & 1.042000 & 1.376000 & 0.543000 \\
Avazu (App) & 1000000 & 8.357000 & 2.507000 & 8.078000 & 2.544000 \\
Avazu (Site) & 1000000 & 18.704000 & 3.207000 & 11.612000 & 3.683000 \\
Criteo & 1000000 & 14.267000 & 2.876000 & 7.009000 & 2.381000 \\
News20 & 1355191 & 4.288000 & 1.356000 & 3.3760

# WRS vs. Moving Average vs. Exponential Average

In [29]:
'''
Test accuracies of {PAC} x {WRS, moving average, exponential average}:
- 4x4 grid of subplots to cover 16 datasets.
'''
# all datasets are plotted (nothing is filtered out); work on a copy of the ordered list
datasets = list(ordered_datasets)

# the base model shown in this figure
model = "PAC"

# custom zoom-in for each plot (lower bound of the y-axis)
ylims = {"avazu-app_binary_sparse" : 0.6,
         "avazu-site_binary_sparse" : 0.75,
         "criteo_binary_sparse" : 0.75,
         "dexter_binary_sparse" : 0.6,
         "dorothea_binary_sparse" : 0.85,
         "kdd2010-a_binary_sparse" : 0.85,
         "mnist8-4+9_binary_sparse" : 0.8,
         "news20_binary_sparse": 0.6,
         "newsgroups_binary_sparse" : 0.6,
         "pcmac_binary_sparse" : 0.7,
         "rcv1_binary_sparse" : 0.85,
         "real-sim_binary_sparse" : 0.75,
         "sst2_binary_sparse" : 0.6,
         "url_binary_sparse" : 0.9,
         "w8a_binary_sparse" : 0.9,
         "webspam_binary_sparse" : 0.8}

# best PAC hyperparameters for every dataset; read once here instead of once per dataset
pac_hparams = pd.read_csv("WRS/base_variants/PAC_hparams.csv")

# create a figure of 4x4 grid of subplots
fig, ax = plt.subplots(4, 4, dpi=200, figsize=(16, 12))

# go thru each dataset + plot the test accuracies for seed=0
for i, dataset in enumerate(datasets):

    # subplot for this dataset (row-major position in the grid)
    panel = ax[i // 4, i % 4]

    ##### INSTANTANEOUS METRICS

    # best log10Cerr for this dataset; cast to int because it is formatted into the file name
    log10Cerr = int(pac_hparams.loc[pac_hparams["dataset"] == dataset, "log10Cerr"].values[0])

    # load in the instantaneous metrics of the base model for seed=0
    logs_inst = pd.read_csv(f"hparam_tuning/results/{model}/{dataset}/model={model}_log10Cerr={log10Cerr}_seed=0_metrics.csv")

    ##### ENSEMBLE METRICS

    # 1. WRS-augmented training (use K=64)
    WRS = pd.read_csv(f"WRS/results/{model}/{dataset}/model={model}_ws=dense_K=64_seed=0_metrics.csv")

    # 2. SIMPLE-(MOVING) AVERAGE (use K=64)
    SA = pd.read_csv(f"moving_average/results/{model}/{dataset}/model={model}_K=64_seed=0_metrics.csv")

    # 3. EXPONENTIAL-AVERAGE (just one setting)
    EA = pd.read_csv(f"exponential_average/results/{model}/{dataset}/model={model}_seed=0_metrics.csv")

    #### PLOTTING

    panel.grid()
    panel.plot(logs_inst["timestep"], logs_inst["inst_test-set-acc"], color="grey", alpha=0.4, label=model)
    # the WRS curve is drawn twice: a thin black line plus a wide translucent yellow highlight
    panel.plot(WRS["timestep"], WRS["WRS_test-set-acc_SA"], color="black", linewidth=1.0, alpha=1.0, label="WRS")
    panel.plot(WRS["timestep"], WRS["WRS_test-set-acc_SA"], color="yellow", linewidth=5.0, alpha=0.5, label="WRS")
    panel.plot(SA["timestep"], SA["SW_test-set-acc"], color="blue", linewidth=1.0, label="Moving Avg.")
    panel.plot(EA["timestep"], EA["EW_test-set-acc"], color="red", linewidth=1.0, label="Expo. Avg.")

    # title with the readable dataset name, larger tick labels, dataset-specific zoom
    panel.set_title(dataset_descs[dataset][0], fontsize=16)
    panel.tick_params("both", labelsize=13)
    panel.set_ylim(bottom=ylims[dataset])

# create our custom legend (one shared legend for the whole figure)
custom_lines = [Line2D([0], [0], color="black", linewidth=1.0, alpha=1.0, label="WRS-Augmented Training"),
                Line2D([0], [0], color="yellow", linewidth=5.0, alpha=0.5, label="WRS-Augmented Training (For Emphasis)"),
                Line2D([0], [0], color="blue", linewidth=1.0, label="Moving Avg."),
                Line2D([0], [0], color="red", linewidth=1.0, label="Expo. Avg."),
                Line2D([0], [0], color="grey", linestyle=None,
                       label=model)]
fig.legend(handles=custom_lines, loc="lower center", ncol=3, fontsize=16, bbox_to_anchor=(0.5, -0.075))

# beautify, save, and release the figure
plt.tight_layout()
plt.savefig("figures/WRS+competitor-wrappers_PAC_final.png", facecolor="white", bbox_inches="tight")
plt.close()

In [31]:
'''
Test accuracies of {FSOL} x {WRS, moving average, exponential average}:
- 4x4 grid of subplots to cover 16 datasets.
'''
# all datasets are plotted (nothing is filtered out); work on a copy of the ordered list
datasets = list(ordered_datasets)

# the base model shown in this figure
model = "FSOL"

# custom zoom-in for each plot (lower bound of the y-axis)
ylims = {"avazu-app_binary_sparse" : 0.6,
         "avazu-site_binary_sparse" : 0.75,
         "criteo_binary_sparse" : 0.75,
         "dexter_binary_sparse" : 0.6,
         "dorothea_binary_sparse" : 0.85,
         "kdd2010-a_binary_sparse" : 0.75,
         "mnist8-4+9_binary_sparse" : 0.8,
         "news20_binary_sparse": 0.6,
         "newsgroups_binary_sparse" : 0.6,
         "pcmac_binary_sparse" : 0.7,
         "rcv1_binary_sparse" : 0.85,
         "real-sim_binary_sparse" : 0.75,
         "sst2_binary_sparse" : 0.6,
         "url_binary_sparse" : 0.9,
         "w8a_binary_sparse" : 0.6,
         "webspam_binary_sparse" : 0.8}

# best FSOL hyperparameters for every dataset; read once here instead of once per dataset
fsol_hparams = pd.read_csv("WRS/base_variants/FSOL_hparams.csv")

# create a figure of 4x4 grid of subplots
fig, ax = plt.subplots(4, 4, dpi=200, figsize=(16, 12))

# go thru each dataset + plot the test accuracies for seed=0
for i, dataset in enumerate(datasets):

    # subplot for this dataset (row-major position in the grid)
    panel = ax[i // 4, i % 4]

    ##### INSTANTANEOUS METRICS

    # best (log2eta, log10lmbda) for this dataset
    # NOTE(review): the values are formatted into the file name exactly as read
    # (no int() cast) -- confirm this matches the saved file names
    log2eta, log10lmbda = fsol_hparams.loc[fsol_hparams["dataset"] == dataset,
                                           ["log2eta", "log10lmbda"]].values[0]

    # load in the instantaneous metrics of the base model for seed=0
    logs_inst = pd.read_csv(f"hparam_tuning/results/{model}/{dataset}/model={model}_log2eta={log2eta}_log10lmbda={log10lmbda}_seed=0_metrics.csv")

    ##### ENSEMBLE METRICS

    # 1. WRS-augmented training (use K=64)
    WRS = pd.read_csv(f"WRS/results/{model}/{dataset}/model={model}_ws=dense_K=64_seed=0_metrics.csv")

    # 2. SIMPLE-(MOVING) AVERAGE (use K=64)
    SA = pd.read_csv(f"moving_average/results/{model}/{dataset}/model={model}_K=64_seed=0_metrics.csv")

    # 3. EXPONENTIAL-AVERAGE (just one setting)
    EA = pd.read_csv(f"exponential_average/results/{model}/{dataset}/model={model}_seed=0_metrics.csv")

    #### PLOTTING

    panel.grid()
    panel.plot(logs_inst["timestep"], logs_inst["inst_test-set-acc"], color="grey", alpha=0.4, label=model)
    # the WRS curve is drawn twice: a thin black line plus a wide translucent yellow highlight
    panel.plot(WRS["timestep"], WRS["WRS_test-set-acc_SA"], color="black", linewidth=1.0, alpha=1.0, label="WRS")
    panel.plot(WRS["timestep"], WRS["WRS_test-set-acc_SA"], color="yellow", linewidth=5.0, alpha=0.5, label="WRS")
    panel.plot(SA["timestep"], SA["SW_test-set-acc"], color="blue", linewidth=1.0, label="Moving Avg.")
    panel.plot(EA["timestep"], EA["EW_test-set-acc"], color="red", linewidth=1.0, label="Expo. Avg.")

    # title with the readable dataset name, larger tick labels, dataset-specific zoom
    panel.set_title(dataset_descs[dataset][0], fontsize=16)
    panel.tick_params("both", labelsize=13)
    panel.set_ylim(bottom=ylims[dataset])

# create our custom legend; it lists only curves that are actually drawn above
# (no Top-K curve is plotted in this figure, so it gets no legend entry)
custom_lines = [Line2D([0], [0], color="black", linewidth=1.0, alpha=1.0, label="WRS-Augmented Training"),
                Line2D([0], [0], color="yellow", linewidth=5.0, alpha=0.5, label="WRS-Augmented Training (For Emphasis)"),
                Line2D([0], [0], color="blue", linewidth=1.0, label="Moving Avg."),
                Line2D([0], [0], color="red", linewidth=1.0, label="Expo. Avg."),
                Line2D([0], [0], color="grey", linestyle=None,
                       label=model)]
fig.legend(handles=custom_lines, loc="lower center", ncol=3, fontsize=16, bbox_to_anchor=(0.5, -0.075))

# beautify, save, and release the figure
plt.tight_layout()
plt.savefig("figures/WRS+competitor-wrappers_FSOL_final.png", facecolor="white", bbox_inches="tight")
plt.close()

# WRS on AdaGrad, SGD+Momentum, and TGD

In [40]:
'''
Test accuracies of {SGD+momentum} x {base, WRS}:
- 4x4 grid of subplots to cover 16 datasets.
'''
# every dataset is used as-is (same list object as ordered_datasets)
datasets = ordered_datasets

# the single base optimizer covered by this figure
model = "sgd+momentum"

# lower y-axis bound per dataset (custom zoom-in for each panel)
ylims = {
    "avazu-app_binary_sparse": 0.8,
    "avazu-site_binary_sparse": 0.6,
    "criteo_binary_sparse": 0.6,
    "dexter_binary_sparse": 0.6,
    "dorothea_binary_sparse": 0.75,
    "kdd2010-a_binary_sparse": 0.75,
    "mnist8-4+9_binary_sparse": 0.82,
    "news20_binary_sparse": 0.6,
    "newsgroups_binary_sparse": 0.6,
    "pcmac_binary_sparse": 0.6,
    "rcv1_binary_sparse": 0.93,
    "real-sim_binary_sparse": 0.85,
    "sst2_binary_sparse": 0.6,
    "url_binary_sparse": 0.9,
    "w8a_binary_sparse": 0.9,
    "webspam_binary_sparse": 0.8,
}

# 4x4 grid of panels, one per dataset
fig, ax = plt.subplots(4, 4, dpi=200, figsize=(16, 12))

# the WRS curve is drawn twice per panel: thin black line, then wide translucent yellow highlight
wrs_styles = [
    dict(color="black", linewidth=1.0, alpha=1.0),
    dict(color="yellow", linewidth=5.0, alpha=0.5),
]

for i, dataset in enumerate(datasets):

    # metrics logged for SEED=0 (both the base and the WRS accuracy columns are read from it)
    logs = pd.read_csv(f"sgd_variants/results/{model}/{dataset}/model={model}_K=64_seed=0_metrics.csv")

    # panel in row i // 4, column i % 4
    panel = ax[i // 4, i % 4]
    panel.grid()

    # base optimizer first, then the WRS curve once per style
    panel.plot(logs["timestep"], logs["inst_test-set-acc"], color="grey", alpha=0.4, label=model)
    for wrs_style in wrs_styles:
        panel.plot(logs["timestep"], logs["WRS_test-set-acc"], label="WRS", **wrs_style)

    # readable dataset name as title, larger tick labels, dataset-specific lower y-limit
    panel.set_title(dataset_descs[dataset][0], fontsize=16)
    panel.tick_params("both", labelsize=13)
    panel.set_ylim(bottom=ylims[dataset])

# hand-built legend handles shared by the whole figure
legend_entries = [
    dict(color="black", linewidth=1.0, alpha=1.0, label=f"WRS-Augmented Training"),
    dict(color="yellow", linewidth=5.0, alpha=0.5, label=f"WRS-Augmented Training (For Emphasis)"),
    dict(color="grey", linestyle=None, label="SGD+Momentum"),
]
custom_lines = [Line2D([0], [0], **entry) for entry in legend_entries]
fig.legend(handles=custom_lines, loc="lower center", ncol=3, fontsize=16, bbox_to_anchor=(0.5, -0.075))

# tidy the layout, write the PNG, then release the figure
fig.tight_layout()
fig.savefig("figures/SGDM+WRS_final.png", facecolor="white", bbox_inches="tight")
plt.close(fig)

In [44]:
'''
Test accuracies of {AdaGrad} x {base, WRS}:
- 4x4 grid of subplots to cover 16 datasets.
'''
# every dataset is used as-is (same list object as ordered_datasets)
datasets = ordered_datasets

# the single base optimizer covered by this figure
model = "adagrad"

# lower y-axis bound per dataset (custom zoom-in for each panel)
ylims = {
    "avazu-app_binary_sparse": 0.88,
    "avazu-site_binary_sparse": 0.75,
    "criteo_binary_sparse": 0.75,
    "dexter_binary_sparse": 0.75,
    "dorothea_binary_sparse": 0.75,
    "kdd2010-a_binary_sparse": 0.85,
    "mnist8-4+9_binary_sparse": 0.93,
    "news20_binary_sparse": 0.6,
    "newsgroups_binary_sparse": 0.6,
    "pcmac_binary_sparse": 0.6,
    "rcv1_binary_sparse": 0.93,
    "real-sim_binary_sparse": 0.85,
    "sst2_binary_sparse": 0.6,
    "url_binary_sparse": 0.96,
    "w8a_binary_sparse": 0.97,
    "webspam_binary_sparse": 0.8,
}

# 4x4 grid of panels, one per dataset
fig, ax = plt.subplots(4, 4, dpi=200, figsize=(16, 12))

# the WRS curve is drawn twice per panel: thin black line, then wide translucent yellow highlight
wrs_styles = [
    dict(color="black", linewidth=1.0, alpha=1.0),
    dict(color="yellow", linewidth=5.0, alpha=0.5),
]

for i, dataset in enumerate(datasets):

    # metrics logged for SEED=0 (both the base and the WRS accuracy columns are read from it)
    logs = pd.read_csv(f"sgd_variants/results/{model}/{dataset}/model={model}_K=64_seed=0_metrics.csv")

    # panel in row i // 4, column i % 4
    panel = ax[i // 4, i % 4]
    panel.grid()

    # base optimizer first, then the WRS curve once per style
    panel.plot(logs["timestep"], logs["inst_test-set-acc"], color="grey", alpha=0.4, label=model)
    for wrs_style in wrs_styles:
        panel.plot(logs["timestep"], logs["WRS_test-set-acc"], label="WRS", **wrs_style)

    # readable dataset name as title, larger tick labels, dataset-specific lower y-limit
    panel.set_title(dataset_descs[dataset][0], fontsize=16)
    panel.tick_params("both", labelsize=13)
    panel.set_ylim(bottom=ylims[dataset])

# hand-built legend handles shared by the whole figure
legend_entries = [
    dict(color="black", linewidth=1.0, alpha=1.0, label=f"WRS-Augmented Training"),
    dict(color="yellow", linewidth=5.0, alpha=0.5, label=f"WRS-Augmented Training (For Emphasis)"),
    dict(color="grey", linestyle=None, label="AdaGrad"),
]
custom_lines = [Line2D([0], [0], **entry) for entry in legend_entries]
fig.legend(handles=custom_lines, loc="lower center", ncol=3, fontsize=16, bbox_to_anchor=(0.5, -0.075))

# tidy the layout, write the PNG, then release the figure
fig.tight_layout()
fig.savefig("figures/ADAGRAD+WRS_final.png", facecolor="white", bbox_inches="tight")
plt.close(fig)

In [43]:
'''
Test accuracies of {Truncated Gradient} x {base, WRS}:
- 4x4 grid of subplots to cover 16 datasets.
'''
# every dataset is used as-is (same list object as ordered_datasets)
datasets = ordered_datasets

# the single base optimizer covered by this figure
model = "tgd"

# lower y-axis bound per dataset (custom zoom-in for each panel)
ylims = {
    "avazu-app_binary_sparse": 0.85,
    "avazu-site_binary_sparse": 0.75,
    "criteo_binary_sparse": 0.75,
    "dexter_binary_sparse": 0.6,
    "dorothea_binary_sparse": 0.6,
    "kdd2010-a_binary_sparse": 0.8,
    "mnist8-4+9_binary_sparse": 0.6,
    "news20_binary_sparse": 0.6,
    "newsgroups_binary_sparse": 0.6,
    "pcmac_binary_sparse": 0.6,
    "rcv1_binary_sparse": 0.85,
    "real-sim_binary_sparse": 0.77,
    "sst2_binary_sparse": 0.6,
    "url_binary_sparse": 0.6,
    "w8a_binary_sparse": 0.95,
    "webspam_binary_sparse": 0.5,
}

# 4x4 grid of panels, one per dataset
fig, ax = plt.subplots(4, 4, dpi=200, figsize=(16, 12))

# the WRS curve is drawn twice per panel: thin black line, then wide translucent yellow highlight
wrs_styles = [
    dict(color="black", linewidth=1.0, alpha=1.0),
    dict(color="yellow", linewidth=5.0, alpha=0.5),
]

for i, dataset in enumerate(datasets):

    # metrics logged for SEED=0 (both the base and the WRS accuracy columns are read from it)
    logs = pd.read_csv(f"sgd_variants/results/{model}/{dataset}/model={model}_K=64_seed=0_metrics.csv")

    # panel in row i // 4, column i % 4
    panel = ax[i // 4, i % 4]
    panel.grid()

    # base optimizer first, then the WRS curve once per style
    panel.plot(logs["timestep"], logs["inst_test-set-acc"], color="grey", alpha=0.4, label=model)
    for wrs_style in wrs_styles:
        panel.plot(logs["timestep"], logs["WRS_test-set-acc"], label="WRS", **wrs_style)

    # readable dataset name as title, larger tick labels, dataset-specific lower y-limit
    panel.set_title(dataset_descs[dataset][0], fontsize=16)
    panel.tick_params("both", labelsize=13)
    panel.set_ylim(bottom=ylims[dataset])

# hand-built legend handles shared by the whole figure
legend_entries = [
    dict(color="black", linewidth=1.0, alpha=1.0, label=f"WRS-Augmented Training"),
    dict(color="yellow", linewidth=5.0, alpha=0.5, label=f"WRS-Augmented Training (For Emphasis)"),
    dict(color="grey", linestyle=None, label="Truncated Gradient Descent (TGD)"),
]
custom_lines = [Line2D([0], [0], **entry) for entry in legend_entries]
fig.legend(handles=custom_lines, loc="lower center", ncol=3, fontsize=16, bbox_to_anchor=(0.5, -0.075))

# tidy the layout, write the PNG, then release the figure
fig.tight_layout()
fig.savefig("figures/TGD+WRS_final.png", facecolor="white", bbox_inches="tight")
plt.close(fig)