In [58]:
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import yaml
import glob

# Global plot styling.
# NOTE: `sns.set(...)` is an alias of `sns.set_theme(...)`. It resets every
# argument it is not given to its default (context="notebook"), so an earlier
# `sns.set_context("paper")` was silently overridden. This single explicit call
# reproduces the settings the figures were actually rendered with. Change it to
# context="paper" if the smaller paper-scale fonts were intended; that will
# change text sizes in every saved figure.
sns.set_theme(context="notebook", font_scale=1.5)

# Colours drawn from the IBM colour-blind-safe palette; currently one colour per
# compared solver. Unused palette entries: "#785EF0", "#DC267F", "#FFB000".
ibm_colors = [
    "#648FFF",
    "#FE6100",
]

In [59]:
# Root directory holding one sub-directory per experiment run.
follow_dir = Path.home() / "dev/data/exp/follow"

# All figures are written here; created up-front so savefig() cannot fail.
fig_dir = follow_dir / "figures"
fig_dir.mkdir(parents=True, exist_ok=True)

# Batch of runs to analyse. The trailing "/" makes glob match directories only.
RUN_GLOB = "2024-11-19*/"

# glob() returns matches in arbitrary (filesystem-dependent) order. Sort them so
# the concatenated results, and therefore the figures, are reproducible.
exp_runs = sorted(glob.glob(RUN_GLOB, root_dir=follow_dir))
assert exp_runs, f"no run directories matching {RUN_GLOB!r} in {follow_dir}"

data_dirs = [follow_dir / run for run in exp_runs]

In [60]:
def read_yaml(file_path: str | Path):
    """Parse a YAML file and return its content.

    Uses ``yaml.safe_load``, so only plain YAML types are constructed (no
    arbitrary Python objects). An empty file yields ``None``.

    Args:
        file_path: Path to the YAML file.

    Returns:
        The parsed document (typically a dict for a config file).
    """
    # YAML is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(file_path, mode="r", encoding="utf-8") as file:
        return yaml.safe_load(file)


# Raw solver identifiers in result.csv -> display labels used in the figures.
SOLVER_LABELS = {
    "lacam": "Vanilla LaCAM",
    "lacam_no_following": "Following-Free LaCAM",
}

configs = [read_yaml(data_dir / "config.yaml") for data_dir in data_dirs]
assert configs, "no experiment runs found; check data_dirs"
congestion_percents = configs[0]["congestion_levels"]
# All runs must sweep the same congestion levels. The plots pad the x-axis from
# `congestion_percents`, which would misrepresent runs with other levels.
assert all(
    config["congestion_levels"] == congestion_percents for config in configs
), "runs were generated with different congestion_levels"

results_raw = pd.concat(
    [pd.read_csv(data_dir / "result.csv") for data_dir in data_dirs], ignore_index=True
)

# Keep only the two compared solvers and relabel them. Building a new frame with
# .assign, instead of assigning into a boolean-mask slice, avoids pandas'
# SettingWithCopyWarning.
df = results_raw[results_raw["solver"].isin(list(SOLVER_LABELS))].assign(
    solver=lambda d: d["solver"].map(SOLVER_LABELS)
)

# Derived after filtering: a map with no results for these solvers would
# otherwise produce empty (and crashing) figures.
map_names = df["map_name"].unique()

# Solver display label -> plot colour.
ibm_palette = dict(zip(SOLVER_LABELS.values(), ibm_colors))

In [None]:
# Metrics plotted for every map (one PDF each) and their y-axis labels.
# The leading arrow indicates that lower values are better.
y_vars = [
    "makespan",
    "normalized_cost",
    "comp_time",
]
ylabels = {
    "makespan": "$\\leftarrow$ Makespan (steps)",
    "comp_time": "$\\leftarrow$ Runtime (ms)",
    "normalized_cost": "$\\leftarrow$ Cost (soc/soc_lb)",
}


def _pad_missing_congestion(melted, y_var, percents):
    """Return `melted` plus NaN rows for congestion levels that have no data.

    Seaborn's categorical plots only create an x position for values present
    in the data. This adds a NaN placeholder, once per solver, for every
    integer percent in ``[min(percents), max(percents)]`` that is absent. The
    result is an evenly spaced ("linearized") x-axis. The input frame is not
    modified.
    """
    present = set(melted["congestion"])
    solvers = melted["solver"].unique()
    # Collect all filler rows first and concatenate once. Growing a frame with
    # pd.concat inside the loop is quadratic.
    filler = [
        {
            "congestion": percent,
            "solver": solver,
            "solved": 1,
            "measurement": y_var,
            "value": float("nan"),
        }
        for percent in range(min(percents), max(percents) + 1)
        if percent not in present
        for solver in solvers
    ]
    if not filler:
        return melted
    return pd.concat([melted, pd.DataFrame(filler)], ignore_index=True)


def _add_solver_legend(fig, solvers, palette):
    """Attach one figure-level legend, listing `solvers`, below the axes.

    Fresh proxy handles are built from `palette`. Restyling the legend markers
    (large dots, no line) therefore cannot alter the artists in the plot.
    """
    handles = [
        plt.Line2D([], [], color=palette[solver], marker="o", markersize=20, linewidth=0)
        for solver in solvers
    ]
    fig.legend(
        handles,
        list(solvers),
        loc="lower center",
        bbox_to_anchor=(0.5, -0.3),
        ncol=1,
        frameon=True,
        title="Solver",
    )


for map_name in map_names:
    filtered_df = (
        df[df["map_name"] == map_name]
        .assign(
            # Agents as a percentage of open vertices, rounded to an integer so
            # it can serve as a categorical x position.
            congestion=lambda d: (100 * d["num_agents"] / d["num_open_vertices"])
            .round(0)
            .astype(int),
            # Sum of costs relative to its lower bound (1.0 == the bound).
            normalized_cost=lambda d: d["soc"] / d["soc_lb"],
        )
        .sort_values(by="solver")
    )

    for y_var in y_vars:
        melted = filtered_df.melt(
            id_vars=["congestion", "solver", "solved"],
            var_name="measurement",
            value_vars=[y_var],
        )
        # Only successfully solved instances contribute data points.
        melted = melted[melted["solved"] == 1]
        if melted.empty:
            # Nothing to draw, and the dodge width below would divide by zero.
            continue
        melted = _pad_missing_congestion(melted, y_var, congestion_percents)
        # Same order seaborn would infer (first appearance). It is made explicit
        # so the strip plot, point plot and legend are guaranteed to agree.
        hue_order = list(melted["solver"].unique())

        fig, ax = plt.subplots()
        # Keep left/bottom spines; top/right are removed (despine default).
        sns.despine(ax=ax, left=False, bottom=False)

        # Every individual solved instance, faint and without jitter.
        sns.stripplot(
            data=melted,
            x="congestion",
            y="value",
            hue="solver",
            hue_order=hue_order,
            dodge=True,
            alpha=0.3,
            jitter=0,
            ax=ax,
            legend=False,
            zorder=1,
            palette=ibm_palette,
        )

        # Median per (congestion, solver). The dodge width matches the strip
        # plot's default 0.8-wide dodge so the markers sit on their strips.
        sns.pointplot(
            data=melted,
            x="congestion",
            y="value",
            hue="solver",
            hue_order=hue_order,
            dodge=0.8 - 0.8 / len(hue_order),
            errorbar=None,
            markers="x",
            markersize=10,
            linestyle="none",
            ax=ax,
            legend=False,
            estimator="median",
            palette=ibm_palette,
        )

        # The cost is normalized by its lower bound, so it should not go below 1.
        ax.set_ylim(bottom=1 if y_var == "normalized_cost" else 0)
        ax.set_xlabel("Congestion (%)")
        ax.set_ylabel(ylabels[y_var])

        # Only the last metric's figure carries the (shared) solver legend.
        if y_var == y_vars[-1]:
            _add_solver_legend(fig, hue_order, ibm_palette)

        ax.grid(True)
        ax.tick_params(axis="x", rotation=90)
        fig.savefig(fig_dir / f"{Path(map_name).stem}_{y_var}.pdf", bbox_inches="tight")
        # Show inline, then release the figure. Otherwise one open figure per
        # map and metric accumulates, and matplotlib warns past 20.
        plt.show()
        plt.close(fig)