In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt

In [None]:
# Load the results table; later cells read its `model`, `dataset` and `acc` columns.
RESULTS_CSV = "./best_worst_results.csv"
df = pd.read_csv(RESULTS_CSV)
df

In [None]:
# Dataset names encode their generation parameters. Select the column by label,
# the same way later cells refer to it (`df.dataset`), instead of relying on it
# being the second column of the CSV.
dataset_info_column = df["dataset"]


def dataset_info(name):
    """Decode the generation parameters embedded in a dataset name.

    The name must split on ``_`` into exactly six fields. The first two are
    ignored; the remaining four are read as
    ``<n_reg>_<concurrent_reg>_<length>_<ss flag>``.

    Returns:
        tuple: ``(n_reg, concurrent_reg, length, ss)``. The first three are
        ints. ``ss`` is True when the last field contains the text "True".

    Raises:
        ValueError: if the name does not have exactly six fields, or one of
        the three numeric fields is not an integer.
    """
    (_, _, n_reg_text, concurrent_text, length_text, ss_text) = name.split("_")
    counts = tuple(int(text) for text in (n_reg_text, concurrent_text, length_text))
    return (*counts, "True" in ss_text)


# Add each of the dataset info columns to the dataframe. Each name is parsed
# exactly once; the resulting tuples are then split field by field into columns
# (int for n_reg / concurrent_reg / length, bool for ss).
DATASET_INFO_FIELDS = ("n_reg", "concurrent_reg", "length", "ss")
parsed_dataset_info = dataset_info_column.map(dataset_info)
for position, column in enumerate(DATASET_INFO_FIELDS):
    # `map` runs eagerly, so the lambda always sees the current `position`.
    df[column] = parsed_dataset_info.map(lambda info: info[position])

# An "identity" row pairs a model with the dataset it was trained on.


def dataset_same_as_model(row):
    """Return True when ``row['model']`` names the model trained on this row's dataset.

    The model name is compared with names built from the row's parsed dataset
    fields (N = ``n_reg``, C = ``concurrent_reg``, L = ``length``):

    * ``ss`` falsy: matches ``model_N_C`` (length not compared) or ``model_N_C_lL``.
    * ``ss`` truthy: matches only ``model_N_Css`` (length not compared).

    Args:
        row: a mapping such as a DataFrame row, with keys ``n_reg``,
            ``concurrent_reg``, ``length``, ``ss`` and ``model``.

    Returns:
        bool
    """
    n_reg, concurrent_reg, length, is_ss, model = (
        row[key] for key in ("n_reg", "concurrent_reg", "length", "ss", "model")
    )
    base_name = f"model_{n_reg}_{concurrent_reg}"
    if is_ss:
        return model == f"{base_name}ss"
    return model in (base_name, f"{base_name}_l{length}")


# Flag, row by row, whether the dataset is the model's own training dataset.
df["identity"] = df.apply(dataset_same_as_model, axis=1)
df

In [None]:
# Distinct dataset names present in the results table.
df["dataset"].unique()

In [None]:
# Accuracy of every model whose name contains "100", broken down by dataset.
# Bars for a model's own training dataset ("identity" rows) are drawn in red.
# .copy() so adding a column below modifies an independent frame, not a slice of df.
view = df[df.model.apply(lambda x: "100" in x)].copy()

view["dataset_x_identity"] = view.apply(
    lambda x: f"{x['dataset']}{'_identity' if x['identity'] else ''}", axis=1
)
view = view.sort_values(by=["identity", "dataset"], ascending=True)

# Assign one colour per hue level (not per row): red for identity labels,
# otherwise a tab20 colour.
unique_values = view["dataset_x_identity"].unique()
palette = sns.color_palette("tab20", len(unique_values))
color_mapping = {
    value: "red" if value.endswith("_identity") else palette[i]
    for i, value in enumerate(unique_values)
}

fig, ax = plt.subplots(figsize=(20, 7))
sns.barplot(
    data=view,
    hue="dataset_x_identity",
    y="acc",
    x="model",
    # A dict ties each colour to its hue label; a list palette is matched to
    # the hue levels by position only.
    palette=color_mapping,
    alpha=1,
    dodge=True,
    ax=ax,
)
ax.set_yticks(np.arange(0, 1.1, 0.1))
ax.set_ylim(0.5, 1.05)
# Rotate model names 45 degrees so long labels do not overlap.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
ax.set(
    title="Accuracy per model and dataset (red = model's training dataset)",
    xlabel="model",
    ylabel="acc",
)
ax.grid(
    axis="y",
    linestyle="--",
    linewidth=0.5,
    alpha=0.7,
)
ax.legend(bbox_to_anchor=(1.23, 1), loc="upper right")
fig.tight_layout()
plt.show()
