In [None]:
# Requires seaborn in addition to requirements.txt packages
import math
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import evaluation

# --- Plot style configuration ---
# Reset to matplotlib's defaults first so the overrides below always start
# from the same baseline, then apply every override in a single update.
plt.style.use("default")
plt.rcParams.update(
    {
        # Fonts: serif family, trying Times New Roman before the default serif list.
        "font.family": "serif",
        "font.serif": ["Times New Roman", *plt.rcParams["font.serif"]],
        # Font type 42 for fonts embedded in PDF / PostScript output.
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        # Text sizes.
        "font.size": 14,
        "axes.labelsize": 14,
        "axes.titlesize": 12,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "legend.fontsize": 12,
        # Grid shown on every axes, semi-transparent.
        "axes.grid": True,
        "grid.alpha": 0.5,
        # Figure geometry and resolution (on screen and when saved).
        "figure.figsize": [12, 10],
        "figure.dpi": 300,
        "savefig.dpi": 300,
    }
)

# --- Load data from final_returns folder ---
df = evaluation.load_results_dataframe()
# Last expression of the cell: preview the first rows of the loaded results.
df.head()

In [None]:
# --- Define bandit evaluation and plotting ---
def plot_trial(final_scores, num_subsample=8, *, n_bootstraps=2000, num_repeats=500, **kwargs):
    """Run bootstrapped bandit trials on a set of runs and plot the result.

    Draws the mean curve on a logarithmic x-axis together with a shaded band
    between the lower and upper confidence bounds, on the current matplotlib
    axes. Written to be used with ``sns.FacetGrid.map``, which selects the
    facet's axes and passes ``color`` / ``label`` through ``**kwargs``.

    Args:
        final_scores: Iterable (e.g. a pandas Series) whose elements are
            equally-shaped arrays of scores; they are stacked along a new
            leading axis before being handed to the evaluation routine.
        num_subsample: Forwarded to ``evaluation.bootstrap_bandit_trials``.
        n_bootstraps: Forwarded to ``evaluation.bootstrap_bandit_trials``
            (previously hard-coded to 2000).
        num_repeats: Forwarded to ``evaluation.bootstrap_bandit_trials``
            (previously hard-coded to 500).
        **kwargs: Extra keyword arguments forwarded to ``Axes.semilogx``.

    Raises:
        ValueError: If ``final_scores`` is empty or its elements have
            mismatching shapes (raised by ``np.stack``).
    """
    # Stack the per-row score arrays into a single array with one entry per row.
    runs = np.stack(list(final_scores))
    res = evaluation.bootstrap_bandit_trials(
        runs,
        num_subsample=num_subsample,
        n_bootstraps=n_bootstraps,
        num_repeats=num_repeats,
    )

    # Draw on the current axes explicitly instead of via the pyplot wrappers.
    ax = plt.gca()
    (mean_line,) = ax.semilogx(res["pulls"], res["estimated_bests_mean"], **kwargs)
    # Shade the confidence band in the same colour as the mean curve.
    ax.fill_between(
        res["pulls"],
        res["estimated_bests_ci_low"],
        res["estimated_bests_ci_high"],
        alpha=0.3,
        color=mean_line.get_color(),
    )

In [None]:
# --- Filter dataframe (optional) ---
plot_df = df
# E.g. plot_df = df.loc[df['algorithm'] != 'td3_bc']

# Make "dataset" categorical on a *new* frame: `assign` returns a copy, so the
# loaded `df` (which plot_df may alias) is never mutated, the cell can be
# re-run safely, and no SettingWithCopyWarning is raised for a filtered slice.
plot_df = plot_df.assign(dataset=pd.Categorical(plot_df["dataset"]))

# --- Run and plot bandit evaluation ---
num_datasets = len(plot_df["dataset"].unique())
algorithms = plot_df["algorithm"].unique()
num_algorithms = len(algorithms)
# Request exactly one colour per algorithm; seaborn cycles the palette when
# more colours are requested than it contains, so every algorithm gets an
# entry (zipping against the fixed-size palette silently dropped the extras).
ALGORITHM_COLOR_DICT = dict(
    zip(algorithms, sns.color_palette("colorblind", n_colors=num_algorithms))
)
g = sns.FacetGrid(
    plot_df,
    col="dataset",
    hue="algorithm",
    # Roughly square grid of facets.
    col_wrap=math.ceil(np.sqrt(num_datasets)),
    sharex=False,
    sharey=False,
    palette=ALGORITHM_COLOR_DICT,
)
print("Running bandit evaluation... ", end='')
g.map(plot_trial, "final_scores")
print("Done.")

# --- Format combined plot ---
# Remove individual x and y labels, and keep only the dataset name as title
for ax in g.axes.flat:
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.set_title(ax.get_title().replace("dataset = ", ""))

# Add and format big axis for shared labels: an invisible axes spanning the
# whole figure whose only purpose is to carry the common x/y labels.
fig = g.figure
big_ax = fig.add_subplot(111, frameon=False)
big_ax.tick_params(labelcolor="none", top=False, bottom=False, left=False, right=False)
big_ax.grid(False)
big_ax.set_xlabel("Number of policy evaluations")
big_ax.set_ylabel("Mean score")

# Add and format legend
g.add_legend(title="Algorithm")
leg = g.legend  # public accessor for the legend created by add_legend
# Thicken the legend handles so the colours are readable.
for legend_line in leg.get_lines():
    legend_line.set_linewidth(5.0)
# Place the legend below the grid, spread over several columns.
sns.move_legend(
    g,
    "upper center",
    bbox_to_anchor=(0.45, 0.0),
    ncol=1+(num_algorithms//2),
    frameon=True,
    edgecolor="gray",
    fancybox=True,
    shadow=False,
)

# Save the figure ("tight" bounding box so the legend outside the axes is kept)
fig.savefig("evaluation_plots.pdf", bbox_inches="tight")