In [None]:
%load_ext autoreload
%autoreload 2

import sys
from pathlib import Path

# Make the parent of the kernel's working directory importable, so project
# modules can be imported from this notebook. The membership guard keeps the
# cell idempotent: re-running it does not append duplicate entries.
_project_root = str(Path("..").resolve())
if _project_root not in sys.path:
    sys.path.append(_project_root)

In [None]:
# Directory holding one sub-folder per run (presumably exported from W&B);
# `read_file` loads `root_path/<run folder>/file_name`.
root_path = "outputs/wandb_export"
# Which exported CSV to load from every run folder. Swap to the commented-out
# line to load the training events file instead.
# file_name = "training_events.csv"
file_name = "validation_events.csv"

# Each condition maps a run folder name (under `root_path`) to the label shown
# for that run in plots. Going by the folder names, "carry Z" runs are the
# reseeding runs and "reset Z" runs are the no-reseeding runs.
condition_four_rooms_reseed = {
    "reseeding-4rooms": "4-rooms, carry Z (1)",
    "reseeding-4rooms-dup1": "4-rooms, carry Z (2)",
}

condition_four_rooms_reset = {
    "no-reseeding-4rooms": "4-rooms, reset Z (1)",
    "no-reseeding-4rooms-dup1": "4-rooms, reset Z (2)",
}

condition_maze_reseed = {
    "reseeding-dynamic-maze": "Maze, carry Z (1)",
    "reseeding-dynamic-maze-dup1": "Maze, carry Z (2)",
    "reseeding-dynamic-maze-dup2": "Maze, carry Z (3)",
    "reseeding-dynamic-maze-dup3": "Maze, carry Z (4)",
    "reseeding-dynamic-maze-dup4": "Maze, carry Z (5)",
}

condition_maze_reset = {
    "no-reseeding-dynamic-maze": "Maze, reset Z (1)",
    # NOTE(review): "dynamize" looks like a typo of "dynamic"; confirm it
    # matches the actual folder name on disk before changing it.
    "no-reseeding-dynamize-maze-dup1": "Maze, reset Z (2)",
    "no-reseeding-dynamic-maze-dup2": "Maze, reset Z (3)",
    "no-reseeding-dynamic-maze-dup3": "Maze, reset Z (4)",
    # NOTE(review): "dup-4" breaks the "dupN" pattern used by the other runs;
    # confirm the folder name on disk.
    "no-reseeding-dynamic-maze-dup-4": "Maze, reset Z (5)",
}

In [None]:
import pandas as pd
import os
import seaborn as sns
import matplotlib.pyplot as plt


def read_file(root_path, expt_path, file_name):
    """Load one run's exported CSV into a DataFrame.

    The file is located at ``root_path/expt_path/file_name``. The resolved
    path is printed before reading so loading progress is visible.
    """
    csv_path = os.path.join(root_path, expt_path, file_name)
    print(f"Reading {csv_path}")
    return pd.read_csv(csv_path)


# Example (load the first maze run; dicts cannot be indexed by position):
# df = read_file(root_path, next(iter(condition_maze_reseed)), file_name)

In [None]:
def filter_df(df):
    """Keep the validation metric columns and drop rows without a success value.

    Args:
        df: A run's exported events frame. It must contain the columns
            ``val_frac_envs_terminated_reward_1``, ``_step`` and
            ``val_avg_episode_length``; a missing column raises ``KeyError``.

    Returns:
        A new DataFrame with just those three columns (in that order). It keeps
        only rows where ``val_frac_envs_terminated_reward_1`` is not NaN, and
        the original row index is preserved.
    """
    columns = ["val_frac_envs_terminated_reward_1", "_step", "val_avg_episode_length"]
    has_metric = df["val_frac_envs_terminated_reward_1"].notna()
    # Explicit copy: a boolean-indexed slice is flagged by pandas as a copy of
    # `df`. Adding a column to it later (as load_condition does) would raise
    # SettingWithCopyWarning.
    return df.loc[has_metric, columns].copy()


def load_condition(expts: dict, root=None, events_file=None):
    """Load and filter every run belonging to one experimental condition.

    Args:
        expts: Mapping of run folder name -> display label for that run.
        root: Directory containing the run folders. Defaults to the notebook's
            module-level ``root_path``.
        events_file: CSV file name to read inside each run folder. Defaults to
            the notebook's module-level ``file_name``.

    Returns:
        dict mapping each run folder name to its filtered DataFrame (see
        ``filter_df``), with an added ``Experiment`` column holding the run's
        display label.
    """
    # Fall back to the notebook-level config so existing calls behave as before.
    root = root_path if root is None else root
    events_file = file_name if events_file is None else events_file

    df_dict = {}
    for expt, label in expts.items():
        df = read_file(root, expt, events_file)
        # .assign returns a new frame rather than writing into a filtered
        # slice, which would trigger pandas' SettingWithCopyWarning.
        df_dict[expt] = filter_df(df).assign(Experiment=label)
    return df_dict

In [None]:
# Load every run of each condition. Each result is a dict mapping run folder
# name -> filtered DataFrame tagged with an `Experiment` label (see
# `load_condition`).
results_condition_four_rooms_reseed = load_condition(condition_four_rooms_reseed)
results_condition_four_rooms_reset = load_condition(condition_four_rooms_reset)
results_condition_maze_reseed = load_condition(condition_maze_reseed)
results_condition_maze_reset = load_condition(condition_maze_reset)

In [None]:
def combine(expts: dict, expts_results: dict, df_list: list, hue, style):
    """Tag each run's DataFrame with plot semantics and append it to ``df_list``.

    Args:
        expts: Mapping of run folder name -> display label. It is used only
            for the progress printout, and must contain every key of
            ``expts_results``.
        expts_results: Mapping of run folder name -> DataFrame with an
            ``Experiment`` column (as produced by ``load_condition``).
        df_list: List that the tagged frames are appended to (mutated in place).
        hue: Unused; kept so existing calls keep working. The hue is always
            each run's ``Experiment`` label.
        style: Value written to the ``Style`` column of every appended frame.

    The frames stored in ``expts_results`` are left untouched; tagged copies
    are appended instead, so re-running a plotting cell starts from clean data.
    """
    for run_key, df_results in expts_results.items():
        print(run_key, expts[run_key])
        df_list.append(
            df_results.assign(Hue=df_results["Experiment"], Style=style)
        )


# --- Plot configuration -------------------------------------------------------
# Which environment to plot: "4rooms" or "maze".
ENVIRONMENT = "maze"
# Which validation metric to plot: "success" or "episode_length".
METRIC = "success"
# Clip every series at this training step, because some runs are much longer
# than others.
MAX_STEP = 200_000

# Line dash per condition: solid for carried Z, dashed for reset Z.
dash_styles = {
    "Carry Z": "",
    "Reset Z": (2, 2),
}

# environment key -> (title suffix, [(run label mapping, loaded runs, style label), ...])
plot_conditions = {
    "4rooms": (
        "4-rooms environment",
        [
            (condition_four_rooms_reseed, results_condition_four_rooms_reseed, "Carry Z"),
            # The label mapping (not the results dict) must be passed as `expts`.
            (condition_four_rooms_reset, results_condition_four_rooms_reset, "Reset Z"),
        ],
    ),
    "maze": (
        "Maze environment",
        [
            (condition_maze_reseed, results_condition_maze_reseed, "Carry Z"),
            (condition_maze_reset, results_condition_maze_reset, "Reset Z"),
        ],
    ),
}

# metric key -> (column name, axis/title label, legend location)
metric_options = {
    "success": (
        "val_frac_envs_terminated_reward_1",
        "Fraction of validation episodes reaching goal",
        "lower right",
    ),
    "episode_length": (
        "val_avg_episode_length",
        "Mean episode length",
        "upper right",
    ),
}

# --- Assemble data --------------------------------------------------------------
environment, condition_runs = plot_conditions[ENVIRONMENT]
y_key, y_name, legend_loc = metric_options[METRIC]
x_key = "_step"

df_list = []
for expts, expts_results, style in condition_runs:
    combine(expts, expts_results, df_list, hue="x", style=style)

combined_df = pd.concat(df_list)
combined_df = combined_df[combined_df[x_key] <= MAX_STEP]

# --- Plot -----------------------------------------------------------------------
fig, ax = plt.subplots(figsize=(10, 6))
sns.lineplot(
    data=combined_df,
    x=x_key,
    y=y_key,
    hue="Hue",
    style="Style",
    palette="colorblind",
    dashes=dash_styles,
    ax=ax,
)

# Seaborn's legend lists every run (hue) plus subtitle pseudo-entries. Keep only
# the condition (style) entries, matched by label rather than by position.
handles, labels = ax.get_legend_handles_labels()
style_levels = set(combined_df["Style"].unique())
style_handles = [h for h, lab in zip(handles, labels) if lab in style_levels]
style_labels = [lab for lab in labels if lab in style_levels]
# Axes.legend replaces the legend seaborn already attached to `ax`.
ax.legend(style_handles, style_labels, loc=legend_loc, title="Condition")

ax.set_title(f"{y_name}: {environment}")
ax.set_xlabel("Number of training minibatches")
ax.set_ylabel(y_name)
plt.show()

In [None]:
# Preview the first 100 rows of the clipped, combined plotting frame.
combined_df.iloc[:100]