In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import os
from tqdm.notebook import tqdm

from p_tqdm import p_umap
from time import time
import numpy as np
from tools import *

In [None]:

from collections import OrderedDict

# Named dash patterns in matplotlib's (offset, (on, off, ...)) tuple form.
# Insertion order matters: the time-plot cell zips the sorted algorithm names
# with ``all_linestyles.values()``.
all_linestyles = OrderedDict(
    [('solid',               (0, ())),
     ('loosely dotted',      (0, (1, 2))),
     ('dotted',              (0, (1, 3))),
     ('densely dotted',      (0, (1, 1))),

     ('loosely dashed',      (0, (3, 3))),
     ('dashed',              (0, (3, 2))),
     ('densely dashed',      (0, (3, 1))),

     ('loosely dashdotted',  (0, (2, 2, 1, 2))),
     ('dashdotted',          (0, (2, 3, 1, 3))),
     ('densely dashdotted',  (0, (2, 1, 1, 1))),

     ('loosely dashdotdotted', (0, (2, 3, 1, 3, 1, 3))),
     ('dashdotdotted',         (0, (2, 3, 1, 2, 1, 2))),
     ('densely dashdotdotted', (0, (2, 1, 1, 1, 1, 1)))])

# Display name of CORVAL (bold, via mathtext); used as a key in every style mapping.
corval_key = "$\\bf{CORVAL}$"

# Marker per displayed algorithm name.
markers = {"CORRAL": "s", corval_key: ".",
           'MCB[µ=0]': '>', 'MCB[µ=0.005]': '<', 'MCB[µ=0.05]': '^', 'MCB[µ=0.5]': 'v',
           'EXP4.S[µ=0]': 's', 'EXP4.S[µ=0.005]': 'p', 'EXP4.S[µ=0.05]': 'h', 'EXP4.S[µ=0.5]': 'D',
           'D-TS[µ=0]': 'P', 'D-TS[µ=0.005]': 'X', 'D-TS[µ=0.05]': 's', 'D-TS[µ=0.5]': 'd',
           'Best Expert': '*'}


cc = [  # colorblind friendly colors
    "#000000",  # Black
    "#E69F00",  # Orange
    "#56B4E9",  # Sky blue
    "#009E73",  # Bluish green
    "#CC79A7",  # Reddish purple
    "#0072B2",  # Blue
    "#D55E00",  # Vermilion
    "#F0E442",  # Yellow
]

# Palette in which the colour encodes the decay value µ (shared across algorithm families).
color_per_decay_palette = {corval_key: cc[0], "CORRAL": cc[2],
                           'MCB[µ=0]': cc[3], 'MCB[µ=0.005]': cc[4], 'MCB[µ=0.05]': cc[5], 'MCB[µ=0.5]': cc[6],
                           'EXP4.S[µ=0]': cc[3], 'EXP4.S[µ=0.005]': cc[4], 'EXP4.S[µ=0.05]': cc[5], 'EXP4.S[µ=0.5]': cc[6],
                           'D-TS[µ=0]': cc[3], 'D-TS[µ=0.005]': cc[4], 'D-TS[µ=0.05]': cc[5], 'D-TS[µ=0.5]': cc[6],
                           'Best Expert': cc[1]}

# Palette in which the colour encodes the algorithm family (shared across decay values).
color_per_alg_palette = {corval_key: cc[0], "CORRAL": cc[2],
                         'MCB[µ=0]': cc[3], 'MCB[µ=0.005]': cc[3], 'MCB[µ=0.05]': cc[3], 'MCB[µ=0.5]': cc[3],
                         'EXP4.S[µ=0]': cc[4], 'EXP4.S[µ=0.005]': cc[4], 'EXP4.S[µ=0.05]': cc[4], 'EXP4.S[µ=0.5]': cc[4],
                         # 'D-TS[µ=0]' was previously spelled 'D-TS', unlike every other mapping here.
                         'D-TS[µ=0]': cc[5], 'D-TS[µ=0.005]': cc[5], 'D-TS[µ=0.05]': cc[5], 'D-TS[µ=0.5]': cc[5],
                         'D-TS': cc[5],  # legacy alias kept so existing lookups keep working
                         'Best Expert': cc[1]}

In [None]:
def get_dfs(periods=None,
            seeds=None,
            arms=None,
            experts=None,
            minimize_seeds=False,
            algorithms=None,
            algs_only=False,
            minimize=False, algs=None,
            trim_decay_before=True, base='server/nonstationary_results/'):
    """Read the feather result files in ``base`` and return them as filtered DataFrames.

    File names are split on ``"_"``: token 1 is parsed as the seed, tokens 2 and 3
    as the integers ``K`` and ``N`` (matched against ``arms`` and ``experts``), and
    the last token, up to its first ``"."``, as the period (an int, or the literal
    ``'static'``).

    Parameters
    ----------
    periods, seeds, arms, experts : collection or None
        When not None, only files whose period / seed / K / N is contained in the
        collection are read.
    minimize_seeds : bool
        Average each file over the grouping keys algorithm, type, n_arms, t,
        n_experts, period, decay and problem, then drop rows containing NaN.
    algorithms : collection or None
        When not None, keep only rows whose ``algorithm`` is in the collection.
    algs_only : bool
        Drop rows whose algorithm name contains "expert", "optimal" or "random".
    minimize : bool
        Same as ``minimize_seeds`` but without ``t`` among the grouping keys.
    algs : collection or None
        Second algorithm filter, applied after ``algorithms``.
    trim_decay_before : bool
        Shorten algorithm names before filtering: keep the text before "decay=",
        before "(B=" and before "alpha", then remove every "(".
    base : str
        Directory containing the result files.

    Returns
    -------
    list of pandas.DataFrame
        One frame per retained file. Files holding a single distinct ``decay``
        value are skipped, and rows with ``decay >= 1`` are dropped. The order of
        the list is whatever ``p_umap`` returns.
    """
    filenames = os.listdir(base)
    # Fixed-seed shuffle of the file list. NOTE: this reseeds NumPy's global RNG.
    np.random.seed(0)
    np.random.shuffle(filenames)

    def parse_file(filename):
        """Load one result file; return None when it is filtered out."""
        # Imported locally, presumably so the function is self-contained when
        # p_umap runs it in worker processes -- TODO confirm.
        import os
        import pandas as pd
        import numpy as np

        tokens = filename.split("_")

        file_seed = int(tokens[1])
        if seeds is not None and file_seed not in seeds:
            return None

        K, N = map(int, tokens[2:4])
        if arms is not None and K not in arms:
            return None
        if experts is not None and N not in experts:
            return None

        file_period = tokens[-1].split(".")[0]
        if file_period != 'static':
            file_period = int(file_period)
        if periods is not None and file_period not in periods:
            return None

        # os.path.join works whether or not ``base`` ends with a path separator.
        file_df = pd.read_feather(os.path.join(base, filename))

        if len(file_df.decay.unique()) == 1:
            return None

        # Sanity check: the first row with a positive decay must have a valid t.
        assert not np.isnan((file_df[file_df.decay > 0].t).values[0])

        if trim_decay_before:
            file_df.algorithm = file_df.algorithm.apply(
                lambda s: s.split("decay=")[0].split("(B=")[0].split("alpha")[0].replace("(", ""))

        if algorithms is not None:
            file_df = file_df[file_df.algorithm.isin(algorithms)]

        if algs is not None:
            file_df = file_df.query("algorithm in @algs")

        if algs_only:
            is_reference = (file_df.algorithm.str.contains("expert")
                            | file_df.algorithm.str.contains("optimal")
                            | file_df.algorithm.str.contains("random"))
            file_df = file_df[~is_reference]

        file_df = file_df[file_df.decay < 1]

        if minimize:
            file_df = file_df.groupby(by=['algorithm', 'type', 'n_arms',
                                          'n_experts', 'period',
                                          'decay', 'problem'], observed=True).mean().reset_index().dropna()

        if minimize_seeds:
            file_df = file_df.groupby(by=['algorithm', 'type', 'n_arms', 't',
                                          'n_experts', 'period',
                                          'decay', 'problem'], observed=True).mean().reset_index().dropna()

        return file_df

    # p_umap returns results unordered; filtered-out files come back as None.
    loaded = p_umap(parse_file, filenames, smoothing=0, desc="reading files")

    return [frame for frame in loaded if frame is not None]

In [None]:

# Load every result file (no filters), averaged per configuration inside get_dfs.
all_data = get_dfs(arms=None, experts=None, algs_only=False,
                   minimize=True, seeds=None,
                   algorithms=None)

print("merging...")
df = pd.concat(all_data)

# Hierarchical average of the numeric columns: first within each
# (period, algorithm, decay, n_experts) cell, then over periods, then over
# n_experts, so that every level contributes equally. Sorted by performance.
# numeric_only=True: the frames still hold non-numeric columns at these stages
# (at least `period`, which mixes ints with 'static' once it is no longer a
# grouping key), and pandas >= 2.0 raises on those instead of dropping them.
(df.groupby(["period", "algorithm", "decay", "n_experts"]).mean(numeric_only=True).reset_index()
   .groupby(["algorithm", "decay", "n_experts"]).mean(numeric_only=True).reset_index()
   .groupby(["algorithm", "decay"]).mean(numeric_only=True)
   .sort_values(by="performance"))

In [None]:
# Directory that receives the figures saved by this notebook; created if missing.
figure_dir = 'figures'
os.makedirs(name=figure_dir, exist_ok=True)

In [None]:
# Display names of the decay variants of each baseline, one entry per decay value µ.
_decay_values = (0, 0.005, 0.05, 0.5)
mcb_variants = list(map(lambda decay: f"MCB[µ={decay}]", _decay_values))
exp4s_variants = list(map(lambda decay: f"EXP4.S[µ={decay}]", _decay_values))
dts_variants = list(map(lambda decay: f"D-TS[µ={decay}]", _decay_values))

In [None]:
# Average reward as a function of the number of experts N (8-arm problems),
# one panel per period. Two figures are produced:
#   * "baseperf_N.pdf": CORVAL vs. the non-zero-decay variants of every baseline;
#   * "perf_N.pdf":     CORVAL vs. CORRAL and all MCB variants.
# NOTE: `data` is deliberately left bound after the loop; the next cell reads it.
for base_algs in (
        ["expert 0", "CORVAL[max]"] + mcb_variants[1:] + exp4s_variants[1:] + dts_variants[1:],
        ["expert 0", "CORVAL[max]", "CORRAL"] + mcb_variants,
):

    data = get_dfs(trim_decay_before=True, arms=[8], algs=base_algs,
                   algs_only=False, minimize=True, seeds=None,
                   algorithms=None)

    sns.set_context('poster', font_scale=1.2, rc={"lines.linewidth": 5})

    df = pd.concat(data).reset_index()
    df['N'] = df.n_experts

    # .copy() makes the column subset an independent frame, so the assignments
    # below do not trigger pandas' SettingWithCopyWarning.
    df = df[['algorithm', 'type', 'N', 'period',
             'decay', 't', 'experiment', 'performance',
             'problem']].copy()

    # Map raw algorithm identifiers to their display names.
    df.loc[df.algorithm == 'expert 0', 'algorithm'] = 'Best Expert'
    df["algorithm"] = df["algorithm"].str.replace(
        "CORVAL[max]", corval_key, regex=False)

    df['average reward'] = df.performance

    # Without CORRAL several algorithm families are shown, so colour encodes the
    # family; with CORRAL only MCB variants are shown, so colour encodes the decay.
    with_corral = "CORRAL" in base_algs
    palette = color_per_decay_palette if with_corral else color_per_alg_palette

    g = sns.relplot(data=df[(df.N >= 2)], x="N", y='average reward', hue='algorithm', markers=markers, palette=palette,
                    kind='line', aspect=1.05, hue_order=sorted(df.algorithm.unique()), style="algorithm", col="period",
                    # numeric periods in ascending order, 'static' last
                    col_order=sorted((df.period.unique()), key=lambda p: 100000 if p == 'static' else p))
    plt.subplots_adjust(wspace=0.05)
    sns.move_legend(g, "lower left", bbox_to_anchor=(
        0.1, .95), ncol=4, frameon=False, title='')

    plt.xscale('log')
    prefix = "" if with_corral else "base"
    plt.savefig(os.path.join(figure_dir, prefix + "perf_N.pdf"),
                bbox_inches='tight')
    plt.show()

In [None]:
# Average reward as a function of the (binned) normalized model error.
# NOTE: this cell reads `data` as left behind by the last iteration of the
# previous cell's loop, so that cell must have been run first.
linestyles = {"Best Expert":  (0, (4, 1, 4, 1, 1, 1)), "D-TS": "--", "TS": "-", "EXP4": "-", "EXP4.S": "--",
              "MCB[µ=0.005]": "--", "MCB[µ=0.05]": "-.", "MCB[µ=0.5]": ":", "MCB[µ=0]": "-", "Random Criterion": "-", "CORRAL": "--", corval_key: "-"}

df = pd.concat(data).reset_index()
if len(df.model_error.unique()) == 1:
    print("To plot this, please run the experiments with the --extra-figure flag, see README.")
else:
    df["random_reward"] = df['shape'].apply(
        lambda s: np.mean((np.arange(8)/7)**s))
    print(df.algorithm.unique())
    fig, ax = plt.subplots()
    fig.set_size_inches(10, 5)

    print("fixing names...")
    df["algorithm"] = df["algorithm"].str.replace(
        "CORVAL[max]", corval_key, regex=False)
    df.loc[df.algorithm == 'expert 0', 'algorithm'] = 'Best Expert'
    ddf = df.copy()

    print(ddf.algorithm.unique())

    # Ten equal-width model-error bins on [0, 1]; curves are drawn at bin centres.
    bins = np.linspace(0, 1, 11)
    x_vals = (bins[1:]+bins[:-1])/2

    for alg in sorted(ddf.algorithm.unique())[::-1]:

        sdf = ddf.query("algorithm==@alg")

        misspecifications = sdf.model_error.values
        performances = sdf.performance.values

        categories = pd.cut(misspecifications, bins, include_lowest=True)

        bin_df = pd.DataFrame(
            {'misspecification': misspecifications, 'average reward': performances})

        # observed=False keeps one group per bin (empty bins included), so the
        # result always lines up with the ten entries of x_vals, independently
        # of the pandas default for categorical groupers.
        grouped = bin_df.groupby(categories, observed=False)

        # bootstrap_ci comes from `tools`; its three outputs are unpacked below
        # as (mean, lo, hi).
        bootstrap_results = grouped['average reward'].apply(bootstrap_ci)

        bootstrap_df = pd.DataFrame(bootstrap_results.tolist(
        ), index=bootstrap_results.index, columns=['mean', 'lo', 'hi'])
        mean, lo, hi = (bootstrap_df.values).T

        ax.plot(x_vals, mean, label=alg,
                color=color_per_decay_palette[alg], marker=markers[alg], linestyle=linestyles[alg])
        ax.fill_between(x_vals, lo, hi, alpha=.2,
                        color=color_per_decay_palette[alg], linestyle=linestyles[alg])

    ax.legend(loc="upper left", bbox_to_anchor=(1.05, 1), frameon=False)
    ax.set_ylabel("average reward")
    ax.set_xlabel("normalized model error")

    sns.despine()
    # Save next to the other figures instead of hard-coding the directory name.
    fig.savefig(os.path.join(figure_dir, "misspecification_plot.pdf"),
                bbox_inches="tight")

In [None]:
# Reward over time on 8-arm problems, one panel per period (100, 500, 2500,
# 'static'), with bootstrap confidence bands. Two figures are produced:
# "basetimeplot.pdf" (without CORRAL) and "timeplot.pdf" (with CORRAL).
for base_algs in (
        ["expert 0", "CORVAL[max]"] + mcb_variants[1:] + exp4s_variants[1:] + dts_variants[1:],
        ["expert 0", "CORVAL[max]", "CORRAL"] + mcb_variants,
):

    # Set the context BEFORE creating the axes: tick and axis-label sizes are
    # taken from rcParams when the axes are built, so setting it afterwards left
    # the first figure styled by whatever context the previous cell had set.
    sns.set_context('poster', font_scale=1.7, rc={"lines.linewidth": 5})
    fig, axs = plt.subplots(2, 2, figsize=(25, 13), sharey=True,)
    fig.tight_layout(h_pad=1, w_pad=1)
    panel_tags = "abcd"
    n_cols = axs.shape[1]

    # Colour encodes the algorithm family without CORRAL, the decay with CORRAL.
    with_corral = "CORRAL" in base_algs
    palette = color_per_decay_palette if with_corral else color_per_alg_palette

    for panel, period in tqdm(list(enumerate([100, 500, 2500, 'static']))):

        data = get_dfs(periods=[(period)], trim_decay_before=True,
                       algs_only=False, minimize=False, minimize_seeds=True, seeds=None,
                       algs=base_algs)

        df = pd.concat(data)

        # Map raw algorithm identifiers to their display names.
        df.loc[df.algorithm == 'expert 0', 'algorithm'] = 'Best Expert'
        df["algorithm"] = df["algorithm"].str.replace(
            "CORVAL[max]", corval_key, regex=False)

        df['average reward'] = df.performance

        algs = sorted(df.algorithm.unique())
        # i-th algorithm (alphabetically) gets the i-th predefined dash pattern.
        line_styles = dict(zip(algs, all_linestyles.values()))

        # Panels are filled row by row.
        i, j = divmod(panel, n_cols)
        ax = axs[i, j]

        for alg in algs[::-1]:
            alg_df = df.query(
                "algorithm == @alg and period==@period and n_arms==8")
            max_t = alg_df.t.max()+10
            # Rows sorted by t are folded into (time step, run, [t, reward]);
            # this assumes t is recorded every 10 steps starting at 0 and that
            # every time step has the same number of rows -- TODO confirm.
            d = alg_df.sort_values(by="t")[["t", "average reward"]].values.reshape(
                (max_t//10, -1, 2))
            X = d[..., 0].T
            Y = d[..., 1].T
            # bootstrap_confidence_interval comes from `tools`; outputs are
            # unpacked as (lo, hi, med).
            lo, hi, med = bootstrap_confidence_interval(Y, axis=0)

            t_values = X[0]
            mean_reward = Y.mean(axis=0)
            # CORVAL is pinned at zorder 1; every other curve uses minus the
            # variance of `med`, so smoother curves are drawn above noisier ones.
            band_zorder = -np.var(med)
            curve_zorder = 1 if alg == corval_key else band_zorder

            ax.plot(t_values, mean_reward, color=palette[alg],
                    linestyle=line_styles[alg], zorder=curve_zorder)

            # Second, labelled line that is NaN everywhere except every 50th
            # sample: it draws sparse markers and provides the legend entry.
            marker_idx = np.arange(0, len(t_values), 50)
            markers_x = np.full(t_values.shape, np.nan)
            markers_y = np.full(mean_reward.shape, np.nan)
            markers_x[marker_idx] = t_values[marker_idx]
            markers_y[marker_idx] = mean_reward[marker_idx]
            ax.plot(markers_x, markers_y, color=palette[alg], marker=markers[alg], linestyle=line_styles[alg],
                    zorder=curve_zorder, label=alg, markersize=20)

            ax.fill_between(t_values, lo, hi, alpha=.2, color=palette[alg],
                            linestyle=line_styles[alg], zorder=band_zorder)

        if period != 'static':
            # Thin vertical guide every `period` steps (uses max_t of the last
            # algorithm plotted).
            for change_point in range(0, max_t+1, int(period)):
                ax.axvline(change_point, color="grey", zorder=-100,
                           linewidth=1, alpha=.5)

        ax.set_title(" "*2+panel_tags[panel], loc="left")
        if j == 0:
            ax.set_ylabel("reward")
        if i > 0:
            ax.set_xlabel("t")
        sns.despine()

    # Build the legend on the last panel and move it above the grid.
    ax.legend()
    sns.move_legend(ax, "lower center", bbox_to_anchor=(-.15,
                    2.4), ncol=4, frameon=False, columnspacing=0.8)

    prefix = "" if with_corral else "base"
    plt.savefig(os.path.join(figure_dir, prefix+"timeplot.pdf"), bbox_inches='tight')
    plt.show()