In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import yaml

# Use seaborn's "whitegrid" style for the plots drawn in this notebook.
sns.set_theme(style="whitegrid")

In [None]:
# Root folder of the "follow" experiment; each run lives in its own sub-folder.
follow_dir = Path.home().joinpath("dev/data/exp/follow")

# Folder for figures; created (including parents) when it does not exist yet.
fig_dir = Path.home().joinpath("dev/data/exp/follow/figures")
fig_dir.mkdir(parents=True, exist_ok=True)

# Timestamped run folders whose results are aggregated below.
exp_runs = [
    "2024-11-18T16-29-12.968",
    "2024-11-18T16-55-49.954",
    "2024-11-18T17-27-40.530",
    "2024-11-18T17-59-12.519",
    "2024-11-18T18-31-35.759",
    "2024-11-18T19-15-48.482",
]
data_dirs = [follow_dir.joinpath(run_id) for run_id in exp_runs]

In [None]:
def read_yaml(file_path):
    """Parse a YAML file and return the loaded data.

    Parameters
    ----------
    file_path : str or os.PathLike
        Path of the YAML file to read.

    Returns
    -------
    object
        The value produced by ``yaml.safe_load`` for the file's content.

    Raises
    ------
    OSError
        If the file cannot be opened.
    """
    # Pin the encoding so decoding does not depend on the platform's locale
    # default (which is not UTF-8 everywhere).
    with open(file_path, mode="r", encoding="utf-8") as file:
        return yaml.safe_load(file)


# One parsed config per experiment run.
configs = [read_yaml(run_dir / "config.yaml") for run_dir in data_dirs]

# Every run must have been executed on the same list of maps; compare each
# config with its predecessor.
for previous, current in zip(configs, configs[1:]):
    assert current["maps"] == previous["maps"]

# Map names and congestion levels are taken from the first run's config.
map_names = [f"{x}.map" for x in configs[0]["maps"]]
congestion_percents = configs[0]["congestion_levels"]

# Stack the per-run result tables into a single frame with a fresh index.
result_frames = [pd.read_csv(run_dir / "result.csv") for run_dir in data_dirs]
df = pd.concat(result_frames, ignore_index=True)

In [None]:
solvers_of_interest = [
    "lacam",
    "lacam_no_following",
    "lacam_star",
    "lacam_star_no_following",
]

# Per-panel settings. They do not depend on the map, so define them once.
fields = ["makespan", "comp_time"]
ylabels = {"makespan": "Makespan (steps)", "comp_time": "Computation Time (ms)"}
jitter = {"makespan": 0.05, "comp_time": 0.05}

# One figure per map: a panel per field, with the individual runs as a strip
# plot and the per-solver mean drawn on top as a horizontal marker.
for map_name in map_names:
    # Rows of this map, restricted to the solvers being compared.
    filtered_df = df[
        (df["map_name"] == map_name) & df["solver"].isin(solvers_of_interest)
    ]
    if filtered_df.empty:
        # Nothing to draw; the dodge computation below would divide by zero.
        print(f"No results for {map_name}; skipping.")
        continue

    # Congestion as an integer percentage: agents per open vertex.
    filtered_df = filtered_df.assign(
        congestion=(100 * filtered_df["num_agents"] / filtered_df["num_open_vertices"])
        .round(0)
        .astype(int)
    )

    # Fixed hue order (declaration order, present solvers only) so that colours
    # and legend entries do not depend on the row order of the data.
    present_solvers = set(filtered_df["solver"])
    solver_order = [s for s in solvers_of_interest if s in present_solvers]

    # Explicit category order for the x axis. It includes every configured
    # congestion level, so levels missing for this map show up as empty slots
    # and every tick label matches the data drawn at that position. Levels are
    # coerced to int to match the integer `congestion` column.
    congestion_order = sorted(
        {int(round(level)) for level in congestion_percents}
        | set(filtered_df["congestion"].unique().tolist())
    )

    f, axs = plt.subplots(1, len(fields))
    f.suptitle(map_name)
    f.set_size_inches(18, 6)
    f.subplots_adjust(wspace=0.2)

    # f.axes is always a list, regardless of how many panels were created.
    for field, ax in zip(fields, f.axes):
        melted = filtered_df.melt(
            id_vars=["congestion", "solver", "solved"],
            var_name="measurement",
            value_vars=[field],
        )
        # NOTE(review): unsolved runs are plotted as well; filter on the
        # `solved` column here if they should be excluded.

        # Individual runs.
        sns.stripplot(
            data=melted,
            x="congestion",
            y="value",
            hue="solver",
            order=congestion_order,
            hue_order=solver_order,
            dodge=True,
            alpha=0.3,
            jitter=jitter[field],
            ax=ax,
            legend=False,
            zorder=1,
        )

        # Per-solver mean; the dodge width is chosen so that the markers line
        # up with the dodged strips above.
        sns.pointplot(
            data=melted,
            x="congestion",
            y="value",
            hue="solver",
            order=congestion_order,
            hue_order=solver_order,
            dodge=0.8 - 0.8 / len(solver_order),
            palette="dark",
            errorbar=None,
            markers="_",
            markersize=10,
            linestyle="none",
            ax=ax,
            legend=True,
        )
        sns.despine(ax=ax, left=False, bottom=False)

        ax.set_ylim(bottom=0)
        ax.set_xlabel("Congestion (%)")
        ax.set_ylabel(ylabels[field])

        # Keep the legend entries but drop the per-axes legend; a single
        # figure-level legend is added once all panels are drawn.
        legend_handles, legend_labels = ax.get_legend_handles_labels()
        axes_legend = ax.get_legend()
        if axes_legend is not None:
            axes_legend.remove()

    # One shared legend per figure (previously added once per panel).
    f.legend(
        legend_handles,
        legend_labels,
        loc="lower center",
        bbox_to_anchor=(0.5, -0.1),
        ncol=len(solver_order),
        frameon=True,
        title="Solver",
    )
    plt.show()