In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from matplotlib.colors import LinearSegmentedColormap

from chatbot_personalization.utils.helper_functions import get_base_dir_path

# Generate Figure 5

## Load data

In [None]:
# Location of the ratings table, resolved relative to the directory returned
# by the project helper `get_base_dir_path`.
path = get_base_dir_path() / "data/ratings_and_matching.csv"

# Parse the CSV once (the file used to be read twice in a row), drop the two
# columns that are not plotted, and index the rows by the "assignments" column.
df = (
    pd.read_csv(path)
    .drop(columns=["utility", "user_id"])
    .set_index("assignments")
)

# Every remaining column is treated as one statement.
statements = df.columns

k = len(statements)

In [None]:
# Short display labels, numbered from 1 in column order: the i-th statement
# becomes "S<i>" as a column label and "G<i>" as a row (group) label.
group_to_label = {name: f"G{pos}" for pos, name in enumerate(statements, start=1)}
statement_to_label = {name: f"S{pos}" for pos, name in enumerate(statements, start=1)}
label_to_statement = {label: name for name, label in statement_to_label.items()}

# Relabel both axes, then order columns and rows by their new labels.
df = df.rename(index=group_to_label, columns=statement_to_label)
df = df.sort_index(axis="columns").sort_index(axis="index")

# Boolean mask with the same shape as `df`: True only where a row's group
# label "G<i>" meets the statement column "S<i>" carrying the same number.
assignment_table = pd.DataFrame(False, index=df.index, columns=df.columns)
for statement_label in df.columns:
    group_label = "G" + statement_label[1:]
    assignment_table.loc[group_label, statement_label] = True

# All ratings flattened to one integer Series, and the subset under the mask.
levels = df.stack().astype(int)
matched_levels = df[assignment_table].stack().astype(int)

# Rating name -> numeric rating value.
level_names = pd.Series(
    {"not at all": 0, "poorly": 1, "somewhat": 2, "mostly": 3, "perfectly": 4}
)

# Relative frequency of each rating, keyed by rating name; ratings that never
# occur are reported as 0.
matched_level_frequencies = level_names.map(
    matched_levels.value_counts(normalize=True)
).fillna(0)
level_frequencies = level_names.map(levels.value_counts(normalize=True)).fillna(0)

# Distinct group labels, in sorted row order.
unique_indices = df.index.unique()

## Generate Figure 5

In [None]:
# Figure 5: one subplot row per group label found in the index of `df`.
# Inside a row, every statement column gets a stack of horizontal bars, one
# per rating value present in that group, centred on the column position and
# as wide as the share of the group's rows that gave that rating.
color_map = LinearSegmentedColormap.from_list("", ["#8BA9EB", "#061B5F"])
base_color = "#3753A5"  # used for main paper

# Rating scale from lowest to highest; a rating value is a position in this list.
rating_labels = ["not at all", "poorly", "somewhat", "mostly", "perfectly"]
# One colour per rating level. Sized by the rating scale rather than by the
# number of statement columns, which only coincides when there are exactly
# five statements.
rating_colors = color_map(np.linspace(0, 1, len(rating_labels)))

n_groups = len(unique_indices)
# squeeze=False keeps `axes` an array even when there is a single group, so
# the zip below works for any number of groups.
fig, axes = plt.subplots(n_groups, 1, figsize=(10, 1 * n_groups), squeeze=False)
axes = axes.ravel()

for ax, unique_index in zip(axes, unique_indices):
    # A list selector always yields a DataFrame, even for a one-row group.
    subframe = df.loc[[unique_index]]
    n_statements = len(subframe.columns)
    total_count = len(subframe)

    # Group label to the left of the row, in axes coordinates.
    ax.text(
        -0.1,
        0.5,
        str(unique_index),
        horizontalalignment="center",
        verticalalignment="center",
        transform=ax.transAxes,
        fontsize=15,
    )

    for col_index, col_name in enumerate(subframe.columns):
        # Rows per rating value in this column; value_counts skips missing
        # values, and the denominator stays the number of rows in the group.
        rating_counts = subframe[col_name].value_counts()

        for rating, count in rating_counts.items():
            rating = int(rating)  # plain int so it can index rating_colors
            share = count / total_count

            ax.barh(
                rating,
                width=share,
                left=col_index - 0.5 * share,  # centre the bar on the column
                height=1,
                align="center",
                color=rating_colors[rating],
                alpha=1.0,
            )

    ax.set_xlim(-0.5, n_statements - 0.5)
    ax.set_yticks(range(len(rating_labels)))
    ax.set_yticklabels(rating_labels, rotation=0, fontsize=6)

    # Only the bottom row carries x ticks and statement labels; every other
    # row shows tick marks on its top edge and no x labels.
    is_last_row = unique_index == unique_indices[-1]
    ax.tick_params(
        axis="both",
        which="both",
        bottom=is_last_row,
        top=not is_last_row,
        left=True,
        right=False,
        labelbottom=is_last_row,
        labelleft=True,
    )
    if is_last_row:
        ax.set_xticks(range(n_statements))
        ax.set_xticklabels(subframe.columns, rotation=0, fontsize=15)

fig.subplots_adjust(hspace=0.1)
# Save before showing, and through the figure handle: with the notebook inline
# backend plt.show() closes the figure, so a later plt.savefig() would write a
# new, empty figure.
fig.savefig("fig5_assigned_utilities_histogram.pdf", bbox_inches="tight", pad_inches=0)
plt.show()