In [17]:
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import yaml
import glob

# Default seaborn theme, with element sizes scaled for the "paper" context.
sns.set_theme(context="paper")

In [18]:
# Root directory that holds one sub-directory per experiment run.
follow_dir = Path.home() / "dev/data/exp/follow"

# Figures are written to a "figures" folder inside the experiment root.
fig_dir = follow_dir / "figures"
fig_dir.mkdir(parents=True, exist_ok=True)

# glob returns matches in arbitrary (filesystem-dependent) order; sort them so
# the order in which runs are loaded is the same on every execution.
exp_runs = sorted(glob.glob("2024-11-19*/", root_dir=follow_dir))

# Absolute path of every selected run directory.
data_dirs = [follow_dir / run for run in exp_runs]

In [19]:
def read_yaml(file_path):
    """Parse a YAML file and return its content.

    Args:
        file_path: Path of the YAML file (anything accepted by ``open``).

    Returns:
        The object produced by ``yaml.safe_load`` for the document.
    """
    # Without an explicit encoding, open() falls back to the platform's locale
    # encoding, which mis-decodes non-ASCII content on non-UTF-8 systems.
    with open(file_path, mode="r", encoding="utf-8") as file:
        return yaml.safe_load(file)


# Fail early with an explicit message instead of an IndexError on configs[0]
# (or a pandas "No objects to concatenate" error) further down.
if not data_dirs:
    raise FileNotFoundError(f"No experiment runs found under {follow_dir}")

configs = [read_yaml(data_dir / "config.yaml") for data_dir in data_dirs]

# Every run must have been executed with the same congestion levels, otherwise
# their results cannot be pooled into one set of plots.
congestion_percents = configs[0]["congestion_levels"]
for data_dir, config in zip(data_dirs, configs):
    assert config["congestion_levels"] == congestion_percents, (
        f"{data_dir} uses congestion levels {config['congestion_levels']}, "
        f"expected {congestion_percents}"
    )

# Pool the per-run result tables into a single frame.
df = pd.concat(
    [pd.read_csv(data_dir / "result.csv") for data_dir in data_dirs],
    ignore_index=True,
)
map_names = df["map_name"].unique()

In [None]:
# Maps that are excluded from plotting.
maps_to_skip = [
    "brc202d.map",
    "Berlin_1_256.map",
    "Boston_0_256.map",
    "Paris_1_256.map",
]

# Metrics to plot: one panel per metric, left to right.
fields = ["makespan", "normalized_cost", "comp_time"]
ylabels = {
    "makespan": "$\\leftarrow$ Makespan (steps)",
    "comp_time": "$\\leftarrow$ Computation Time (ms)",
    "normalized_cost": "$\\leftarrow$Cost (soc/soc_lb)",
}


def pad_missing_congestion(melted, field, percents):
    """Add NaN placeholder rows for congestion values that have no data.

    Every integer between ``min(percents)`` and ``max(percents)`` (inclusive)
    that does not occur in ``melted["congestion"]`` gets one placeholder row
    per solver present in ``melted``. This keeps the categorical x axis
    identical across panels and maps.

    Args:
        melted: Long-format frame with the columns ``congestion``, ``solver``,
            ``solved``, ``measurement`` and ``value``.
        field: Name stored in the ``measurement`` column of the new rows.
        percents: Congestion levels whose min/max bound the range to fill.

    Returns:
        ``melted`` itself when nothing is missing, otherwise a new frame with
        the placeholder rows appended.
    """
    present = set(melted["congestion"].tolist())
    solvers = melted["solver"].unique()
    placeholders = [
        {
            "congestion": percent,
            "solver": solver,
            "solved": 1,
            "measurement": field,
            "value": float("nan"),
        }
        for percent in range(min(percents), max(percents) + 1)
        if percent not in present
        for solver in solvers
    ]
    if not placeholders:
        return melted
    # Single concat instead of growing the frame row by row.
    return pd.concat([melted, pd.DataFrame(placeholders)], ignore_index=True)


for map_name in map_names:
    if map_name in maps_to_skip:
        continue

    filtered_df = df[df["map_name"] == map_name]
    # Drop every solver whose name contains "star" or "mccg".
    filtered_df = filtered_df[~filtered_df["solver"].str.contains("star|mccg")]
    filtered_df = filtered_df.assign(
        # Share of open vertices occupied by agents, as a whole percentage.
        congestion=(100 * filtered_df["num_agents"] / filtered_df["num_open_vertices"])
        .round(0)
        .astype(int),
        # Sum of costs relative to its lower bound.
        normalized_cost=filtered_df["soc"] / filtered_df["soc_lb"],
    ).sort_values(by="solver")

    # Without a single solved instance there is nothing to draw, and the dodge
    # computation below would divide by zero.
    if not (filtered_df["solved"] == 1).any():
        print(f"Skipping {map_name}: no solved instances to plot")
        continue

    f, axs = plt.subplots(1, len(fields), figsize=(18, 6), squeeze=False)
    axs = axs.ravel()
    sns.despine(fig=f, left=False, bottom=False)

    for ax, field in zip(axs, fields):
        melted = filtered_df.melt(
            id_vars=["congestion", "solver", "solved"],
            var_name="measurement",
            value_vars=[field],
        )
        # Only solved instances carry meaningful measurements.
        melted = melted[melted["solved"] == 1]
        melted = pad_missing_congestion(melted, field, congestion_percents)
        n_solvers = len(melted["solver"].unique())

        # Individual results, one column of faint dots per solver.
        sns.stripplot(
            data=melted,
            x="congestion",
            y="value",
            hue="solver",
            dodge=True,
            alpha=0.3,
            jitter=0,
            ax=ax,
            legend=False,
            zorder=1,
        )

        # Median marker per solver; the dodge width lines the markers up with
        # the dodged strip plot columns.
        sns.pointplot(
            data=melted,
            x="congestion",
            y="value",
            hue="solver",
            dodge=0.8 - 0.8 / n_solvers,
            palette="dark",
            errorbar=None,
            markers="_",
            markersize=10,
            linestyle="none",
            ax=ax,
            legend=True,
            estimator="median",
        )

        # The normalized cost is soc / soc_lb, so its axis starts at 1.
        ax.set_ylim(bottom=1 if field == "normalized_cost" else 0)
        ax.set_xlabel("Congestion (%)")
        ax.set_ylabel(ylabels[field])
        ax.grid(True)

        # The per-axis legend is replaced by one shared figure legend below.
        axis_legend = ax.get_legend()
        if axis_legend is not None:
            axis_legend.remove()

    # One shared legend per figure (every panel shows the same solvers).
    legend_handles, legend_labels = axs[-1].get_legend_handles_labels()
    f.legend(
        legend_handles,
        legend_labels,
        loc="lower center",
        bbox_to_anchor=(0.5, -0.1),
        ncol=n_solvers,
        frameon=True,
        title="Solver",
    )
    f.subplots_adjust(left=0.05, right=1.0, wspace=0.2)
    # The legend sits below the figure area; a tight bounding box keeps it in
    # the saved file instead of clipping it.
    f.savefig(fig_dir / f"{map_name}.pdf", bbox_inches="tight")
    plt.show()
    # Release the figure explicitly; one figure is created per map.
    plt.close(f)