# Choosing between earlier and later expressions, given two groups' observed interactions

December 2024

In [15]:
# import jax
# import jax.numpy as jnp
# from jax import lax
import pandas as pd 
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import chi2

import data_tools
import model_tools

from enums import *

In [2]:
# Load the raw data, derive the tangram metadata from it, and reorganize the
# data into the structure that the grid search below consumes
# (presumably matrices, going by the helper's name — confirm in data_tools).
data = data_tools.get_data()
tangram_info = data_tools.get_tangram_info(data)
data_organized = data_tools.make_data_matrices(data)


In [None]:
# Show tangram_info as the cell output for a quick visual check.
tangram_info

In [4]:
# Grid-search both model variants on the same inputs. Each call yields the best
# parameters and the best NLL (presumably negative log-likelihood — confirm in
# model_tools); the third return value is not used in this notebook section.
best_params_no_social, best_nll_no_social, _ = model_tools.grid_search_nll(data_organized, tangram_info, model_type="no_social")
best_params_social, best_nll_social, _ = model_tools.grid_search_nll(data_organized, tangram_info, model_type="social")

In [None]:
# Report the grid-search result (best parameters and best NLL) for each model.
print(f"Best params no social: {best_params_no_social}")
print(f"NLL no social: {best_nll_no_social}")
print(f"Best params social: {best_params_social}")
print(f"NLL social: {best_nll_social}")

Get model predictions (probability of later utterance)

In [6]:
# Evaluate each model at its best-fitting parameters. The parameters are
# unpacked positionally, so their order must match get_model_preds' signature.
preds_no_social = model_tools.get_model_preds(*best_params_no_social)
preds_social = model_tools.get_model_preds(*best_params_social)

Put the model predictions in a dataframe

In [None]:
# Collect the predictions of both models into one tidy frame: one row per
# (model type, condition, tangram type) with the probability of the *earlier*
# expression. The models predict the probability of the later utterance, hence
# the `1 - ...` complement.
#
# The frame is built in a single pd.DataFrame(...) call from a list of records.
# Concatenating onto an empty, column-only frame is deprecated in recent pandas
# (FutureWarning) and needlessly forces object dtypes.
rows = []
for model_type, model_preds in (
    ("no_social", preds_no_social),
    ("social", preds_social),
):
    for condition in Conditions:
        for tangram_type in TangramTypes:
            rows.append(
                {
                    "type": model_type,
                    "condition": condition.name,
                    "tangram_type": tangram_type.name,
                    # .item() unwraps the 0-d array scalar into a plain Python
                    # number so the column holds ordinary floats.
                    "p_earlier": (1 - model_preds[condition, tangram_type]).item(),
                }
            )

# Passing `columns` pins the column order (and keeps the schema even if
# `rows` were empty).
df_preds = pd.DataFrame(rows, columns=["type", "condition", "tangram_type", "p_earlier"])
df_preds

Plot

In [None]:
# Panel titles, in the same order as the facet columns (no_social, social).
# NOTE(review): the parameter values and NLLs below are hard-coded text, not
# read from the grid-search results — update them by hand if the fit changes.
titles = [
    "No social utility\nBest: $\\alpha=1$, $w_r=1.9$, $w_c=0.7$\nNLL: 1413.0602",
    "Social utility only on 'one social'\nBest: $\\alpha=0.5$, $w_r=3.4$, $w_s=0.5$, $w_c=1.2$\nNLL: 1410.3212",
]

# One panel per model type; points only (no connecting lines), slightly dodged
# so the tangram types do not overlap.
g = sns.catplot(
    data=df_preds,
    kind="point",
    col="type",
    x="condition",
    y="p_earlier",
    hue="tangram_type",
    palette="colorblind",
    dodge=0.1,
    alpha=0.6,
    linestyle="none",
    height=6,
    aspect=1,
)

# Probabilities live on [0, 1]; the dashed line marks chance level (0.5).
g.set(ylim=(0, 1))
for ax in g.axes.flat:
    ax.axhline(0.5, ls="--", color="gray")

# Replace seaborn's default facet titles with the descriptive ones above.
for ax, title in zip(g.axes.flat, titles):
    ax.set_title(title)

plt.show()

Run a likelihood ratio test comparing the no-social and social models

In [None]:
# Likelihood ratio test of the "social" model against the "no_social" model.

# The "social" model has one extra free parameter (w_s), so the test has
# one degree of freedom.
df = 1

# Twice the improvement in NLL gained by adding the social term.
lr_stat = 2 * (best_nll_no_social - best_nll_social)

# Upper-tail probability of the chi-square distribution at the observed statistic.
# NOTE(review): the chi-square reference assumes the models are nested and that
# the grid search found (close to) the true likelihood maxima — confirm.
p_value = chi2.sf(lr_stat, df=df)

print(f"Likelihood Ratio Test Statistic: {lr_stat}")
print(f"Degrees of Freedom: {df}")
print(f"P-value: {p_value}")