In [None]:
import pandas as pd

# --- Run selection -------------------------------------------------------------
project_name = 'CANDID_DAC'
# Known runs on 5D CANDID Sigmoid; set run_id_to_plot to the one to visualize:
#   '58x8jy9x' -> SAQL
#   'oi9p2pvs' -> SDQN
run_id_to_plot = 'oi9p2pvs'

# --- Settings used to rebuild the importance-weighted Sigmoid benchmark --------
dim = 5
importance_base = 0.5
reward_shape = 'exponential'
# NOTE(review): exp_reward is not referenced by any visible cell — confirm whether it
# should be forwarded to the benchmark.
exp_reward = 4.6

# --- Run configurations exported to CSV (one row per run, keyed by 'run_id') ----
config_path = '../run_data/' + project_name + '_configs.csv'
config_df = pd.read_csv(config_path)

In [None]:
# Load the trained policy of the selected run and rebuild the environment it is
# evaluated on.
import warnings
# Silences *all* warnings for the remainder of the kernel session (later cells too).
warnings.filterwarnings('ignore')
from plotting_helpers import load_policy_from_checkpoint
from dacbench.benchmarks import SigmoidBenchmark
import numpy as np

# Directory holding the saved checkpoints of the selected run.
checkpoint_dir = f'../../results/models/{project_name}/{run_id_to_plot}'

# Environment to plot on: dimension i is weighted by importance_base**i.
# NOTE(review): exp_reward from the config cell is not passed here — verify whether the
# benchmark needs it to reproduce the training reward.
importances = np.array([importance_base**i for i in range(dim)])
env = SigmoidBenchmark().get_importances_benchmark(dimension=dim, importances=importances, reward_shape=reward_shape)

# Look up the run's training configuration. Fail with a clear message instead of an
# opaque IndexError from `.iloc[0]` when the run id is absent; if the id appears more
# than once, the first matching row is used.
matching_runs = config_df[config_df['run_id'] == run_id_to_plot]
if matching_runs.empty:
    raise ValueError(f"run_id {run_id_to_plot!r} not found in {config_path}")
config = matching_runs.iloc[0].to_dict()

# NOTE(review): the meaning of episode=10000 / final=True is defined by
# load_policy_from_checkpoint — presumably it selects the final checkpoint; confirm.
policy = load_policy_from_checkpoint(config=config, env=env, ckpt_directory=checkpoint_dir, episode=10000, final=True)

In [None]:
# Run the policy on several instances of the importance benchmark and collect, per
# instance, the chosen actions, the accumulated reward and the sigmoid target curve.
import numpy as np
from candid_dac.policies import AtomicPolicy
import torch


def rollout_episode(policy, env, instance_id, n_dims=dim):
    """Run one full episode of `policy` on `env` after resetting it to `instance_id`.

    Parameters
    ----------
    policy : AtomicPolicy or object with a ``get_action(obs)`` method
        Policy to evaluate.
    env : environment
        Environment providing ``reset(instance_id=...)``, ``step(action)``, ``n_steps``
        and ``action_space.nvec``.
    instance_id : int
        Instance to reset the environment to.
    n_dims : int
        Number of action dimensions (rows of the returned action matrix).

    Returns
    -------
    episode_actions : np.ndarray of shape (n_dims, env.n_steps)
        Column t holds the action taken at step t (freshly allocated per call).
    total_reward : float
        Sum of rewards collected over the episode.
    """
    episode_actions = np.zeros((n_dims, env.n_steps))
    obs, _ = env.reset(instance_id=instance_id)
    total_reward = 0
    for t in range(env.n_steps):
        if isinstance(policy, AtomicPolicy):
            # Atomic policies return one flat index into the joint action space;
            # unravel it into one sub-action per dimension.
            flat_action = policy(torch.tensor(obs))
            action = np.unravel_index(flat_action, env.action_space.nvec)
        else:
            action = policy.get_action(obs)
        # NOTE(review): terminated/truncated are ignored; the episode is assumed to last
        # exactly env.n_steps steps — confirm for this benchmark.
        obs, reward, terminated, truncated, _ = env.step(action)
        total_reward += reward
        episode_actions[:, t] = action
    return episode_actions, total_reward


# Six evenly spaced instance ids between 0 and 299
# (assumes the instance set contains ids 0..299 — TODO confirm).
instances_ids = np.linspace(0, 299, 6, endpoint=True, dtype=int)
print(instances_ids)

# The target sigmoid is evaluated on the same dense time grid for every instance.
points_in_time = np.linspace(0, env.n_steps - 1, 100, endpoint=True)

truth_on_instances = []
actions_on_instances = []
obtained_rewards = []

for inst_id in instances_ids:
    episode_actions, episode_reward = rollout_episode(policy, env, inst_id)
    actions_on_instances.append(episode_actions)
    obtained_rewards.append(episode_reward)
    # NOTE(review): only the first dimension's slope/shift is used as the target curve.
    truth_on_instances.append(env._sig(points_in_time, env.slopes[0], env.shifts[0]))

print(len(truth_on_instances))
print(len(actions_on_instances))
print(len(actions_on_instances))


In [None]:
# Plot, per evaluated instance, the sigmoid target together with the agent's prediction
# after aggregating the first k action dimensions (k = 1..dim).
import math

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.lines import Line2D
from plotting_helpers import translate_run_name

print(env.inst_id)

N_COLS = 3
n_instances = len(truth_on_instances)
assert n_instances > 0, "no instances to plot"
# Round up so no instance is silently dropped when the count is not a multiple of N_COLS.
rows = math.ceil(n_instances / N_COLS)

plt.rcParams.update({
    'font.size': 8,           # Global font size
    'axes.titlesize': 7,      # Title size of individual plots
    'axes.labelsize': 7,      # Label size for x and y labels
    'xtick.labelsize': 6.5,   # Size of x-tick labels
    'ytick.labelsize': 6.5,   # Size of y-tick labels
    'legend.fontsize': 7,     # Size of the legend text
    'figure.titlesize': 12,   # Title size of the entire figure
    'lines.linewidth': 0.75,
    'lines.markersize': 2.5,
})
width = 6  # latex textwidth
# squeeze=False keeps `axs` 2-D even with a single row, so axs[r, c] indexing always works.
fig, axs = plt.subplots(rows, N_COLS, figsize=(width, width / 2.5), sharex=True, sharey=True,
                        squeeze=False)

algorithm = translate_run_name(run_name=config['run_name'])

# One color per aggregation level: level 1 uses viridis(1) (brightest), the full
# aggregation (level dim) uses viridis(0) (darkest). Shared by the panels and the legend.
colors = plt.cm.viridis(np.linspace(1, 0, dim))
steps = np.arange(env.n_steps)

for ax, inst_id, truth, actions in zip(axs.flat, instances_ids, truth_on_instances, actions_on_instances):
    ax.set_title(f'Instance {inst_id}')
    ax.plot(points_in_time, truth, label='sigmoid target', color='tab:orange')

    for i in range(dim):
        action = env.compute_pred_from_actions(actions, level=i+1)
        # Higher aggregation levels are drawn more opaque and on top.
        ax.scatter(steps, action, label=f'{algorithm} agent prediction (up to dim {i+1})', color=colors[i],
                   alpha=0.5 + 0.5 * (i+1) / dim, zorder=i+10, edgecolors='none')
    ax.set_xticks(range(0, env.n_steps, 1))
    ax.set_yticks(np.linspace(0, 1, 6))
    # `action` now holds the last level (all dims aggregated), i.e. the final prediction.
    ax.plot(steps, action, linestyle=':', label=f'{algorithm} agent prediction (full)', color=colors[-1])

# Blank out grid slots without an instance; axis('off') keeps child artists (e.g. the
# legend attached to axs[0, -1]) visible.
for ax in axs.flat[n_instances:]:
    ax.axis('off')

for ax in axs[:, 0]:
    ax.set_ylabel('prediction value')
for ax in axs[-1, :]:
    ax.set_xlabel('$t$')

# Finalize the subplot layout BEFORE measuring the legend: the colorbar axes below are
# placed in absolute figure coordinates and would otherwise drift from the legend.
plt.subplots_adjust(wspace=0.1, hspace=0.15)

# Customized legend with a colorbar for the aggregation level below it.
custom_lines = [
    Line2D([0], [0], color="tab:orange", linestyle="-", label="target"),
    Line2D([0], [0], linestyle="None", marker="o", color="gray", label="partially aggr. prediction"),
    Line2D([0], [0], color=colors[-1], linestyle=":", label="final prediction"),
]
legend = axs[0, -1].legend(loc='center left', bbox_to_anchor=(1, -0.2), handles=custom_lines)
# Legend bounding box, converted from display units to figure units.
legend_box_fig = legend.get_window_extent().transformed(fig.transFigure.inverted())
cbar_ax = fig.add_axes([legend_box_fig.x0, legend_box_fig.y0 - 0.1, legend_box_fig.width, 0.05])
# Reversed viridis over 0..dim-1 reproduces the panel palette (0 brightest, dim-1 darkest).
# NOTE(review): ticks are 0-based while the scatter labels count dims from 1 — confirm labelling.
cbar = fig.colorbar(cm.ScalarMappable(cmap=cm.viridis_r, norm=plt.Normalize(vmin=0, vmax=dim - 1)),
                    cax=cbar_ax, ticks=np.arange(dim), orientation='horizontal')
cbar.ax.tick_params(length=0)  # hide tick marks, keep labels
cbar.set_label('aggregated dimensions')

print(f"Algorithm: {algorithm}")
fig.savefig(f'./paper_plots/{