In [None]:
%load_ext autoreload 
%autoreload 2

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
from chatbot_personalization.utils.helper_functions import get_base_dir_path
import os

# Generate Figure 2

## Load data

In [None]:
## For replicating the paper ##
# Directory whose `.csv` files are read and concatenated by the loading cell below.
dirname = get_base_dir_path() / "data/gen_query_eval/"

## For running the code on data generated by the script `gen_query_eval.py` ##
# To use your own run instead, uncomment the line below and replace the
# timestamped folder name with the one produced by the script.
# dirname = get_base_dir_path() / "data/demo_data/2024-09-06-22:01:20__gen_query_eval"

In [None]:
# Collect every CSV in the data directory. `os.listdir` returns entries in an
# arbitrary, filesystem-dependent order, so sort the names to make the row
# order of the concatenated frame (and everything derived from it) reproducible.
filenames = sorted(file for file in os.listdir(dirname) if file.endswith(".csv"))
if not filenames:
    # Fail with a clear message instead of pandas' "No objects to concatenate".
    raise FileNotFoundError(f"No .csv files found in {dirname}")

dfs = [pd.read_csv(os.path.join(dirname, csv)) for csv in filenames]
# ignore_index=True gives every row a unique label; otherwise each file's
# 0..n-1 index would be repeated in the combined frame.
df = pd.concat(dfs, ignore_index=True)

# Fall back to the "statement" column when the files carry no "comment" column.
if "comment" not in df.columns:
    df["comment"] = df["statement"]

# Distinct non-missing round numbers and statement types, in order of appearance.
round_nums = list(df["round_num"].dropna().unique())
statement_types = list(df["statement_type"].dropna().unique())

# Nested lookup: round_num -> statement_type -> statement text.
# Only rows that carry a statement_type contribute; if a (round, type) pair
# occurs more than once, the last occurrence wins.
_statement_rows = df[df["statement_type"].notna()]
round_num_to_statement_type_to_statement = {
    round_num: (
        _statement_rows[_statement_rows["round_num"] == round_num]
        .set_index("statement_type")["statement"]
        .to_dict()
    )
    for round_num in round_nums
}

## Plotting code

In [None]:
def create_heatmap(
    pivot_table, title_name, save_path, vmin=None, vmax=None, legend=True
):
    """Draw an annotated heatmap of ``pivot_table`` and save it as a PDF.

    The main diagonal is blanked out (set to NaN) before plotting. The masking
    is applied to a float copy of the data, so the caller's ``pivot_table`` is
    never modified, and integer-valued tables can be masked as well.

    Parameters
    ----------
    pivot_table : pandas.DataFrame
        Numeric table to plot; its index and columns become the tick labels.
    title_name : str or None
        Title of the figure. No title is drawn when this is empty / falsy.
    save_path : str or path-like
        Output location without extension; ``.pdf`` is appended.
    vmin, vmax : float, optional
        Limits of the colour scale. Each defaults to the minimum / maximum of
        the off-diagonal values.
    legend : bool, default True
        Whether to draw the colour bar next to the heatmap.
    """
    # Two-colour gradient from near-white to blue.
    color_map = LinearSegmentedColormap.from_list("custom_cmap", ["#F8F8FF", "#3753A5"])

    # Mask the diagonal on a private, writable float copy. Writing into
    # `pivot_table.values` directly would mutate the caller's frame, fail on
    # integer data / read-only views, or silently do nothing when `.values`
    # returns a copy.
    masked_values = pivot_table.to_numpy(dtype=float, copy=True)
    np.fill_diagonal(masked_values, np.nan)
    heatmap_data = pd.DataFrame(
        masked_values, index=pivot_table.index, columns=pivot_table.columns
    )

    # Determine vmin and vmax if not provided (NaNs on the diagonal are skipped)
    if vmin is None:
        vmin = heatmap_data.min().min()
    if vmax is None:
        vmax = heatmap_data.max().max()

    # Report the colour range in use, e.g. to reuse it across several heatmaps.
    print(vmin, vmax)

    # Set the figure size
    plt.figure(figsize=(11, 8))  # Adjust the size to reduce tick overlap

    # Create the heatmap
    ax = sns.heatmap(
        heatmap_data,
        annot=True,
        cmap=color_map,
        linewidths=2.0,
        fmt=".2f",
        vmin=vmin,
        vmax=vmax,
        annot_kws={"size": 20},
        cbar=legend,
    )

    # Adjust the size of the ticks and rotate them to prevent overlap
    plt.xticks(rotation=60, ha="right", fontsize=25)
    plt.yticks(rotation=0, fontsize=25)

    # Axis titles are intentionally left blank; the tick labels are self-explanatory.
    ax.set_xlabel("")
    ax.set_ylabel("")

    # Set the title if provided
    if title_name:
        plt.title(title_name, fontsize=25)

    # Save the figure; in a notebook it is also rendered inline at the end of the cell.
    plt.tight_layout()
    plt.savefig(f"{save_path}.pdf", bbox_inches="tight")

In [None]:
def make_pie_plot(pie_df, save_path, title=""):
    """Render a pie chart of ``pie_df["value"]`` and write it to ``<save_path>.pdf``.

    ``pie_df`` must expose ``statement_type`` (as a column or as its index;
    the index is reset first) and a ``value`` column. Each wedge is labelled
    with its statement type and its percentage share.
    """
    # Make sure statement_type is available as a regular column.
    frame = pie_df.reset_index()

    fig, ax = plt.subplots(figsize=(7, 5))

    # White wedge edges produce the thin gaps between the slices.
    wedge_style = {"width": 1, "edgecolor": "w"}
    wedges, texts, autotexts = ax.pie(
        frame["value"],
        labels=frame["statement_type"],
        autopct="%1.1f%%",
        colors=["C1", "C2", "C0", "C4"],
        wedgeprops=wedge_style,
    )

    # Category labels: black, size 13. Percentages inside the wedges: white, size 16.
    for label in texts:
        label.set_color("black")
        label.set_fontsize(13)
    for percent_label in autotexts:
        percent_label.set_color("white")
        percent_label.set_fontsize(16)

    plt.title(title)

    # Persist the chart as a PDF next to the requested path.
    plt.savefig(f"{save_path}.pdf", bbox_inches="tight")

In [None]:
def get_utilities_of_statement_type(
    df, round_num, statement_type, round_num_to_statement_type_to_statement
):
    """Return the sorted, clipped utilities recorded for one statement.

    The statement text is looked up via ``round_num`` and ``statement_type``;
    the ``query2_output`` values of all "discriminative" rows in that round
    whose ``comment`` equals the statement are sorted ascending and then
    clamped to the range [0, 4].
    """
    target_statement = round_num_to_statement_type_to_statement[round_num][
        statement_type
    ]

    is_match = (
        (df["comment"] == target_statement)
        & (df["round_num"] == round_num)
        & (df["query_type"] == "discriminative")
    )
    ordered = df.loc[is_match, "query2_output"].to_list()
    ordered.sort()

    # An occasional out-of-range output (e.g. 13) shows up in the data,
    # so every value is clamped into [0, 4].
    clipped = []
    for value in ordered:
        clipped.append(min(max(value, 0), 4))
    return clipped

# Generate Figure 2b

In [None]:
# Every (round, statement type) pair is expected to carry exactly this many
# utility values; the statistic of interest is the lowest utility among the
# TOP_K highest ones.
EXPECTED_NUM_UTILITIES = 40
TOP_K = 20

data = []
for round_num in round_nums:
    for statement_type in statement_types:
        utilities = get_utilities_of_statement_type(
            df, round_num, statement_type, round_num_to_statement_type_to_statement
        )
        assert len(utilities) == EXPECTED_NUM_UTILITIES, (
            f"round {round_num}, statement type {statement_type!r}: expected "
            f"{EXPECTED_NUM_UTILITIES} utilities, got {len(utilities)}"
        )
        # `utilities` is sorted ascending, so the last TOP_K entries are the highest.
        top20_min = np.min(utilities[-TOP_K:])
        data.append(
            {
                "round_num": round_num,
                "statement_type": statement_type,
                "top20_min": top20_min,
            }
        )
plot_df = pd.DataFrame(data)

# Within each round, flag the statement type(s) reaching the best top20_min.
# Ties are kept: several types can be winners of the same round.
plot_df["top20_min_winner"] = plot_df.groupby("round_num")["top20_min"].transform(
    lambda x: x == x.max()
)

true_winners = plot_df[plot_df["top20_min_winner"]]

# Number of rounds won per statement type, as a one-column ("value") frame
# indexed by statement_type. Types that never win are included with 0, rows
# are ordered by statement type, and the column is named explicitly because
# the name produced by `value_counts` differs between pandas versions.
statement_freq = (
    true_winners["statement_type"]
    .value_counts()
    .reindex(sorted(statement_types), fill_value=0)
    .rename_axis("statement_type")
    .rename("value")
    .to_frame()
)

make_pie_plot(
    statement_freq,
    save_path=get_base_dir_path() / "plots/fig2_slate_composition_pie_chart",
)

# Generate Figure 2a

In [None]:
# Fixed colour per approach, identical to the wedge colours of the pie chart
# (statement types in sorted order get C1, C2, C0, C4). Looking colours up by
# name keeps the mapping stable even if the mean-based column order changes.
approach_colors = {
    "all": "C1",
    "nn(s=5)": "C2",
    "random 1": "C0",
    "random 5": "C4",
    "maximum": "black",
}
approach_columns = ["all", "nn(s=5)", "random 1", "random 5"]

# One row per round, one column per statement type, holding top20_min.
pivot_table = plot_df.pivot_table(
    index="round_num", columns="statement_type", values="top20_min"
)
# Name of a winning statement type per round (the first one listed on ties).
max_approach = (
    plot_df[plot_df["top20_min_winner"]].groupby("round_num")["statement_type"].first()
)
pivot_table["maximum_approach"] = max_approach

# Best top20_min achieved by any approach in each round.
pivot_table["maximum"] = pivot_table[approach_columns].max(axis=1)

# Sorting the columns based on their mean values (lowest mean first)
sorted_columns_renamed = (
    pivot_table[approach_columns + ["maximum"]].mean().sort_values().index
)

# Data for swarm plot with sorted columns
sorted_data_for_swarmplot_renamed = pivot_table[sorted_columns_renamed]

# Setting up a single figure for the swarm plot
plt.figure(figsize=(10, 6))

# Creating the swarm plot; the palette follows the column order so every
# approach keeps its own colour.
sns.swarmplot(
    data=sorted_data_for_swarmplot_renamed,
    palette=[approach_colors[column] for column in sorted_columns_renamed],
    edgecolor="black",
    linewidth=0.0,
    size=8,
)

# Customizing tick size
plt.xticks(fontsize=20)
plt.yticks(fontsize=21)

# Axis labels are intentionally left empty
plt.xlabel("", fontsize=14)
plt.ylabel("", fontsize=14)

save_path = get_base_dir_path() / "plots/fig2_slate_composition_swarm_plot"
# Save the swarm plot as a PDF
plt.savefig(f"{save_path}.pdf", bbox_inches="tight")