In [None]:
import pandas as pd
import os
import json


import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt

# Seaborn theme for all figures: white grid, "paper" context, fonts scaled by 1.2,
# followed by the colorblind-safe palette.
sns.set_theme(style="whitegrid", context="paper", font_scale=1.2)
sns.set_palette("colorblind")

# Matplotlib overrides layered on top of the seaborn theme: bold, larger axis/tick/legend
# text; 100 dpi on screen, 300 dpi when saving; sans-serif fonts with a fallback chain.
# (`mpl` comes from the import block at the top of this cell.)
mpl.rcParams.update({
    'axes.labelweight': 'bold',
    'axes.titlesize': 'x-large',
    'xtick.labelsize': 'large',
    'ytick.labelsize': 'large',
    'legend.fontsize': 'large',
    'figure.dpi': 100,
    'savefig.dpi': 300,
    'axes.labelsize': 'x-large',
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial', 'DejaVu Sans', 'Liberation Sans', 'sans-serif'],
})


In [None]:
# Load classifications

def parse_new_tom(msg_data):
    """
    Flatten one ToM-annotated message (a parsed JSON dict) into a flat record.

    For every annotation field (e.g. 'desire') the record contains:
      - ``<field>``: the annotation value (``None`` if absent),
      - ``<field>_rationale``: the rationale (``None`` if absent),
      - ``<field>_verification``: the verifier's answer for that category, taken from
        ``msg_data["verification"]["ToM_categories"]["results"]`` (``None`` if there is
        no matching result).
    The unmodified input dict is also stored under ``raw_text``.

    Verification sub-blocks that are missing or JSON ``null``, and malformed result
    entries (non-dicts, or entries whose ``category`` is not a string), are tolerated:
    they simply leave the corresponding ``*_verification`` value as ``None``.

    Args:
        msg_data: dict loaded from one annotation JSON file.

    Returns:
        dict with the flattened fields described above.
    """
    # annotation fields
    fields = ["knowledge", "mentalistic", "belief", "intention", "desire", "percept", "emotion"]
    ret = {}
    for field in fields:
        ret[field] = msg_data.get(field)
        ret[f"{field}_rationale"] = msg_data.get(f"{field}_rationale")

    # Verification fields. `or {}` / `or []` also covers keys that are present but
    # null in the JSON, which `.get(key, default)` would pass through as None.
    verification_block = msg_data.get("verification") or {}
    verification_categories = verification_block.get("ToM_categories") or {}
    verification_results = verification_categories.get("results") or []

    # One verification field per category (e.g., desire_verification: "A"), default None.
    for field in fields:
        ret[f"{field}_verification"] = None
    if isinstance(verification_results, list):
        for res in verification_results:
            if not isinstance(res, dict):
                continue  # skip malformed entries instead of crashing on .get
            cat = res.get("category")
            if not isinstance(cat, str):
                continue  # a null/non-string category cannot be matched
            cat = cat.strip().lower()
            if cat in fields:
                ret[f"{cat}_verification"] = res.get("answer")
    ret["raw_text"] = msg_data  # the full parsed dict (not a string), kept for reference
    return ret

folder_path = "../dummy_data/tom_annotated_data"

# Expected layout: <folder_path>/<user_id>/<conversation_id>/<message_id>.json.
# Listings are sorted because os.listdir returns entries in arbitrary order; sorting
# makes the row order of df_new_tom reproducible across runs and platforms.
data_new_tom = []
for user_folder in sorted(os.listdir(folder_path)):
    user_path = os.path.join(folder_path, user_folder)
    if not os.path.isdir(user_path):
        continue
    for conv_folder in sorted(os.listdir(user_path)):
        conv_path = os.path.join(user_path, conv_folder)
        if not os.path.isdir(conv_path):
            continue
        for msg_file in sorted(os.listdir(conv_path)):
            if not msg_file.endswith(".json"):
                continue
            msg_path = os.path.join(conv_path, msg_file)
            # Best-effort loading: unreadable or malformed files are reported and skipped.
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError; JSON is UTF-8,
            # so the encoding is explicit rather than the platform default.
            try:
                with open(msg_path, "r", encoding="utf-8") as f:
                    msg_data = json.load(f)
            except (OSError, ValueError) as e:
                print(f"Error reading {msg_path}: {e}")
                continue
            if not isinstance(msg_data, dict):
                # parse_new_tom expects a JSON object; skip arrays/scalars.
                print(f"Error reading {msg_path}: expected a JSON object, got {type(msg_data).__name__}")
                continue
            res = parse_new_tom(msg_data)
            data_new_tom.append({
                "user_id": user_folder,
                "conversation_id": conv_folder,
                # splitext strips only the trailing ".json"; str.replace would strip every occurrence
                "message_id": os.path.splitext(msg_file)[0],
                **res,
            })

df_new_tom = pd.DataFrame(data_new_tom)


In [None]:
def apply_verification(df, categories, rejected_answers=("C", "D", "E")):
    """
    Clear ToM category labels the verifier rejected, then add a ``ToM`` column.

    For each ``cat`` in ``categories`` that has both a ``cat`` column and a
    ``cat_verification`` column, rows whose verification answer is in
    ``rejected_answers`` get ``cat`` set to False. Answers are compared after
    stripping whitespace and upper-casing. Categories whose column is missing are
    skipped, both in this filtering step and when computing ``ToM``.

    Note: ``df`` is modified in place and also returned.

    Args:
        df: DataFrame built from ``parse_new_tom`` records.
        categories: list of category column names (e.g. ["belief", "desire"]).
        rejected_answers: verifier answers that invalidate a category label.
            Defaults to ("C", "D", "E").

    Returns:
        The same DataFrame. ``ToM`` is True where any present category value is truthy.
    """
    rejected = {str(answer).strip().upper() for answer in rejected_answers}
    present = [cat for cat in categories if cat in df.columns]
    for cat in present:
        ver_col = f"{cat}_verification"
        if ver_col not in df.columns:
            continue
        # Normalise so " c" / "c" match "C"; None/NaN become "NONE"/"NAN" and never match.
        answers = df[ver_col].astype(str).str.strip().str.upper()
        df.loc[answers.isin(rejected), cat] = False

    # Aggregate only existing columns: df[categories] would raise KeyError for a
    # category the loop above deliberately skipped.
    df["ToM"] = df[present].any(axis=1)
    return df

# Categories filtered by the verifier and OR-ed together into the `ToM` column
# (edit this list to change which categories count).
categories = [
    "belief",
    "desire",
    "intention",
    "emotion",
    "percept",
    "knowledge",
    "mentalistic",
]
df_new_tom = apply_verification(df_new_tom, categories)


In [None]:
# Percentage of unique messages ("memories") with at least one ToM=True annotation.
# A message is identified by (user_id, conversation_id, message_id): message files live
# in per-conversation folders, so a file name (message_id) is only guaranteed unique
# within its own folder, and counting message_id alone could merge distinct messages.
message_key = ["user_id", "conversation_id", "message_id"]
num_unique_memories = len(df_new_tom[message_key].drop_duplicates())
num_memories_with_tom = len(df_new_tom.loc[df_new_tom["ToM"], message_key].drop_duplicates())
percent_memories_with_tom = 100.0 * num_memories_with_tom / num_unique_memories if num_unique_memories > 0 else 0
print(f"Percentage of unique memory_id with ToM: {percent_memories_with_tom:.2f}% ({num_memories_with_tom}/{num_unique_memories})")

In [None]:
# Share of users with at least one ToM=True message: collapse each user's messages
# to a single "any ToM?" flag, then count the flags.
tom_by_user = df_new_tom.groupby("user_id")["ToM"].any()
num_total_users = tom_by_user.size
num_users_with_tom = int(tom_by_user.sum())
if num_total_users > 0:
    percent_users_with_tom = 100.0 * num_users_with_tom / num_total_users
else:
    percent_users_with_tom = 0
print(f"Percentage of unique users with at least one ToM: {percent_users_with_tom:.2f}% ({num_users_with_tom}/{num_total_users})")

In [None]:
import matplotlib.pyplot as plt

# Distribution of ToM categories over the messages that carry any ToM content
# (ToM == True). Each bar is 100 * (rows where the category is True) / (ToM rows).
# Bars need not sum to 100% because one message can be assigned several categories.

# Categories to plot (bars are re-ordered by percentage below).
categories_to_keep = ['belief', 'emotion', 'desire', 'intention', 'percept', 'knowledge', 'mentalistic']

# ToM=True rows form the denominator for every bar.
df_tom_true = df_new_tom[df_new_tom["ToM"] == True]
total_tom_true = len(df_tom_true)

# Per-category True counts (column sums) and the matching percentages.
cat_true_counts = [df_tom_true[cat].sum() for cat in categories_to_keep]
cat_true_percents = [
    100.0 * true_count / total_tom_true if total_tom_true > 0 else 0
    for true_count in cat_true_counts
]

# Tallest bar first; sorted() is stable, so ties keep the list order above.
cat_data = list(zip(categories_to_keep, cat_true_counts, cat_true_percents))
cat_data_sorted = sorted(cat_data, key=lambda row: row[2], reverse=True)
if cat_data_sorted:
    sorted_categories, sorted_counts, sorted_percents = zip(*cat_data_sorted)
else:
    sorted_categories, sorted_counts, sorted_percents = [], [], []

fig, ax = plt.subplots(figsize=(6, 4))
x = range(len(sorted_categories))
ax.bar(x, sorted_percents, color="C0")

ax.set_ylabel('Percentage')
ax.set_xlabel('')
ax.set_xticks(x)
tick_labels = [name.capitalize() for name in sorted_categories]
ax.set_xticklabels(tick_labels, rotation=40, fontweight='bold', fontsize=12)

# Label each bar with its rounded percentage and absolute count, raised by 1% of the tallest bar.
label_offset = 0.01 * max(sorted_percents) if sorted_percents else 0.0
for i, (count, percent) in enumerate(zip(sorted_counts, sorted_percents)):
    label = f"{round(percent, 0):.0f}%\n(n={count})"
    ax.text(i, percent + label_offset, label, ha='center', va='bottom', fontsize=11)

# Fixed 0-100 axis with ticks at 0/50/100.
ax.set_ylim(0, 100)
ax.set_yticks([0, 50, 100])
ax.set_yticklabels([0, 50, 100], fontweight='bold')

plt.tight_layout()
plt.savefig("tom_category_dist_tom_true.pdf")

In [None]:
import matplotlib.pyplot as plt

# Share of ALL messages (no ToM filter) in which each category is True:
# each bar is 100 * (rows where the category is True) / (total rows).

# Categories to plot (bars are re-ordered by percentage below).
categories_to_keep = ['belief', 'emotion', 'desire', 'intention', 'percept', 'knowledge', 'mentalistic']

# Every loaded message is part of the denominator here.
total_msgs = len(df_new_tom)

cat_true_counts = [df_new_tom[cat].sum() for cat in categories_to_keep]
cat_true_percents = [
    100.0 * true_count / total_msgs if total_msgs > 0 else 0
    for true_count in cat_true_counts
]

# Tallest bar first; sorted() is stable, so ties keep the list order above.
cat_data = list(zip(categories_to_keep, cat_true_counts, cat_true_percents))
cat_data_sorted = sorted(cat_data, key=lambda row: row[2], reverse=True)
if cat_data_sorted:
    sorted_categories, sorted_counts, sorted_percents = zip(*cat_data_sorted)
else:
    sorted_categories, sorted_counts, sorted_percents = [], [], []

fig, ax = plt.subplots(figsize=(6, 4))
x = range(len(sorted_categories))
ax.bar(x, sorted_percents, color="C0")

ax.set_ylabel('Percentage')
ax.set_xlabel('')
ax.set_xticks(x)
tick_labels = [name.capitalize() for name in sorted_categories]
ax.set_xticklabels(tick_labels, rotation=40, fontweight='bold', fontsize=12)

# Label each bar with its rounded percentage and absolute count, raised by 1% of the tallest bar.
label_offset = 0.01 * max(sorted_percents) if sorted_percents else 0.0
for i, (count, percent) in enumerate(zip(sorted_counts, sorted_percents)):
    label = f"{round(percent, 0):.0f}%\n(n={count})"
    ax.text(i, percent + label_offset, label, ha='center', va='bottom', fontsize=11)

# Fixed 0-100 axis with ticks at 0/50/100.
ax.set_ylim(0, 100)
ax.set_yticks([0, 50, 100])
ax.set_yticklabels([0, 50, 100], fontweight='bold')

plt.tight_layout()
plt.savefig("tom_category_dist_all.pdf")

In [None]:
# Per-user coverage of ToM categories: for each category, the number and share of
# ALL users that have at least one message where that category is True.

# Categories to plot (bars are re-ordered by percentage below).
categories_to_keep = ['belief', 'emotion', 'desire', 'intention', 'percept', 'knowledge', 'mentalistic']

# user_id -> categories seen at least once for that user (in categories_to_keep order).
user_ids = df_new_tom['user_id'].unique()
user_cat_true = {uid: [] for uid in user_ids}
for cat in categories_to_keep:
    for uid in df_new_tom.loc[df_new_tom[cat] == True, 'user_id'].unique():
        user_cat_true[uid].append(cat)
user_cat_series = pd.Series(user_cat_true)

# Each user's list holds a category at most once, so tallying the lists
# gives the number of distinct users per category.
n_users = len(user_cat_series)
user_counts_per_cat = dict.fromkeys(categories_to_keep, 0)
for cats in user_cat_series:
    for cat in cats:
        user_counts_per_cat[cat] += 1
user_percents_per_cat = {
    cat: (100 * count / n_users if n_users > 0 else 0)
    for cat, count in user_counts_per_cat.items()
}

# Highest user share first; sorted() is stable, so ties keep categories_to_keep order.
sorted_items = sorted(user_percents_per_cat.items(), key=lambda kv: kv[1], reverse=True)
sorted_cats = [cat for cat, _ in sorted_items]
bar_labels = [c.capitalize() for c in sorted_cats]
bar_positions = range(len(sorted_cats))

fig2, ax2 = plt.subplots(figsize=(6, 4))
bars = ax2.bar(
    bar_positions,
    [user_percents_per_cat[c] for c in sorted_cats],
    tick_label=bar_labels,
)
ax2.set_ylabel('Percentage of Users')
ax2.set_xlabel('')
ax2.set_xticks(bar_positions)
ax2.set_xticklabels(bar_labels, rotation=40, fontweight='bold', fontsize=12)

# Label each bar with its rounded percentage and user count, raised by 1% of the tallest bar.
label_offset = max(user_percents_per_cat.values(), default=0) * 0.01
for i, cat in enumerate(sorted_cats):
    percent = user_percents_per_cat[cat]
    count = user_counts_per_cat[cat]
    ax2.text(i, percent + label_offset, f"{round(percent, 0):.0f}%\n(n={count})",
             ha='center', va='bottom', fontsize=11)

# Fixed 0-100 axis with ticks at 0/50/100.
ax2.set_ylim(0, 100)
ax2.set_yticks([0, 50, 100])
ax2.set_yticklabels([0, 50, 100], fontweight='bold')
plt.tight_layout()
plt.savefig("tom_category_dist_per_user_new.pdf")
plt.show()