# Imports

In [1]:
import os
import errno
from importlib import reload
import pickle 
from pprint import pprint
from itertools import product
from glob import glob
from os.path import basename, splitext, split

In [2]:
import itertools
from numba import jit



In [3]:
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import matplotlib
# just so saved plots aren't also shown
matplotlib.use('Agg')
%matplotlib inline
import seaborn as sns

import pymc3 as pm
import theano as T
import theano.tensor as tt
import arviz as az
import xarray as xr
from json import dump
from os.path import exists, join

  import pkg_resources


In [4]:
from functions.helper_functions import save_trace, load_trace
from functions.data_functions import get_and_clean_data_exp2

In [5]:
from functions.argstrengths import (
    calculate_nonparametric_argstrength, 
    theano_calculate_pragmatic_argstrength, 
    theano_calculate_pragmatic_speaker, 
    calculate_argumentative_strength, 
    calculate_maximin_argstrength
)

from functions.helper_functions import (
    verify, 
    normalize, 
    theano_calculate_pragmatic_speaker,
    get_costs,
    calculate_pragmatic_speaker,
    theano_normalize, 
    theano_softmax
)

In [6]:
from functions.models_variablearray import (
    factory_model_base,
    factory_model_lr_argstrength,
    factory_model_maximin_argstrength,
    factory_model_prag_argstrength,
    factory_model_nonparametric_argstrength,
)

In [7]:
# packages versions
print(
    '\n'.join(
        f'{m.__name__}=={m.__version__}' 
        for m in globals().values() 
        if getattr(m, '__version__', None)
    )
)

pandas==2.2.3
numpy==1.23.5
matplotlib==3.8.4
seaborn==0.13.2
pymc3==3.11.4
theano==1.1.2
arviz==0.12.1
xarray==2023.7.0


In [8]:
# Fitted experiment-2 traces (one saved trace per model)
folder_exp2_traces = '../../arglang_model_fitting/experiment2_traces'
# Cleaned experiment-2 data, written/read by the data-loading cells below
folder_exp2_cleaned_data = '../data/data_experiment2/'
# Cache for argument strengths: the full-state-space variants are slow to compute
path_to_argstrengths_folder = './argstrengths/'
folder_exp2_cleaned_data = '../data/data_experiment2/'

In [9]:
# Raw response files (CSV file paths, despite the `folder_` prefix)
folder_exp2_data = '../data/data_experiment2/data.csv'
folder_exp1_data = '../data/data_experiment1/data.csv'

In [10]:
# Clean the raw data only when the raw file is present. Without the guard this cell
# crashes on machines that only have the pre-cleaned artefacts, so the fallback
# branch of the next cell could never be reached under "Restart & Run All".
if exists(folder_exp2_data):
    exp1and2_data = get_and_clean_data_exp2(
        pathdata_firstexp=folder_exp1_data,
        pathdata=folder_exp2_data
    )
    _, data, list_possible_observations, possible_utterances = exp1and2_data

0  were excluded because incompletely recorded
14  of the participants were excluded as they gave more than 4 false responses
113  of the observations in the included participants were excluded because literally false


In [11]:
# `dump` is imported at the top of the notebook; the fallback branch also needs `load`
# (previously missing -> NameError whenever the raw data was absent).
from json import load

if exists(folder_exp2_data):
    # Raw data available: clean it, then cache the cleaned artefacts so the notebook
    # can also run where only the cleaned files exist.
    exp1and2_data = get_and_clean_data_exp2(
        pathdata_firstexp=folder_exp1_data,
        pathdata=folder_exp2_data
    )
    _, data, list_possible_observations, possible_utterances = exp1and2_data
    data.to_csv(join(folder_exp2_cleaned_data, 'cleaned_data_1and2.csv'), index=False)
    with open(join(folder_exp2_cleaned_data, 'obs_1and2.json'), 'w') as openfile:
        dump([x.tolist() for x in list_possible_observations], openfile)
    pd.DataFrame(possible_utterances).to_csv(join(folder_exp2_cleaned_data, 'utts.csv'), index=False)
else:
    # No raw data: read back the artefacts written by the branch above.
    data = pd.read_csv(join(folder_exp2_cleaned_data, 'cleaned_data_1and2.csv'))
    with open(join(folder_exp2_cleaned_data, 'obs_1and2.json'), 'r') as openfile:
        # one numpy array per entry, mirroring the `x.tolist()` serialization above
        list_possible_observations = [np.array(x) for x in load(openfile)]
    possible_utterances = pd.read_csv(join(folder_exp2_cleaned_data, 'utts.csv')).to_numpy()

0  were excluded because incompletely recorded
14  of the participants were excluded as they gave more than 4 false responses
113  of the observations in the included participants were excluded because literally false


# Define all models

In [12]:
# Base (informational) speaker model, with the S1 layer included
model_base = factory_model_base(
    data, list_possible_observations, possible_utterances, include_S1=True,
)

Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model


In [13]:
# Likelihood-ratio argumentative-strength model, with the S1 layer included
model_lr_argstrength = factory_model_lr_argstrength(
    data, list_possible_observations, possible_utterances, include_S1=True,
)

Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model


In [14]:
# Maximin argumentative-strength model, with the S1 layer included
model_maximin_argstrength = factory_model_maximin_argstrength(
    data, list_possible_observations, possible_utterances, include_S1=True,
)

Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model


  return np.nanmin(logp_for-logp_against, 1)


In [15]:
# Pragmatic argumentative-strength model, with the S1 layer included
model_prag_argstrength = factory_model_prag_argstrength(
    data, list_possible_observations, possible_utterances, include_S1=True,
)

Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model


In [16]:
# Nonparametric argumentative-strength model, with the S1 layer included
model_nonparametric_argstrength = factory_model_nonparametric_argstrength(
    data, list_possible_observations, possible_utterances, include_S1=True,
)

Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model
Defining non-hierarchical model


# Calculate datapoints to use

Load all traces

In [17]:
# Load the saved trace of every fitted model from the experiment-2 trace folder
# (same load order as listed below).
_loaded_traces = {
    trace_name: load_trace(trace_name, folder_exp2_traces)['trace']
    for trace_name in (
        'base',
        'lr_argstrength',
        'maximin_argstrength',
        'nonparametric_argstrength',
        'prag_argstrength',
    )
}
trace_base = _loaded_traces['base']
trace_lr_argstrength = _loaded_traces['lr_argstrength']
trace_maximin_argstrength = _loaded_traces['maximin_argstrength']
trace_nonparametric_argstrength = _loaded_traces['nonparametric_argstrength']
trace_prag_argstrength = _loaded_traces['prag_argstrength']
del _loaded_traces

Display loaded models.

In [18]:
# Show every model. A cell of bare expressions only renders the LAST one,
# which is why only a single Model repr appeared previously.
display(
    model_base,
    model_lr_argstrength,
    model_maximin_argstrength,
    model_nonparametric_argstrength,
    model_prag_argstrength,
)

<pymc3.model.Model at 0x30854fd90>

Display loaded traces.

In [19]:
# Show every loaded trace. A cell of bare expressions only renders the LAST one.
display(
    trace_base,
    trace_lr_argstrength,
    trace_maximin_argstrength,
    trace_nonparametric_argstrength,
    trace_prag_argstrength,
)

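Store the traces of the argumentative-strength models in a dictionary keyed by model name. The base model's trace is not included; it is used separately as the informational speaker.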

In [20]:
# Model name -> trace for the argumentative-strength models, in processing order.
# The base model is not included: its trace supplies the informational speaker
# and is handled separately.
traces = dict(
    lr=trace_lr_argstrength,
    maximin=trace_maximin_argstrength,
    nonparametric=trace_nonparametric_argstrength,
    prag=trace_prag_argstrength,
)

## Computing listener posteriors over speaker types from the speaker models' posterior utterance probabilities

This code constructs listener-side belief distributions over speaker types
(low, info, high) from posterior samples of speaker models.

The final object of interest is the listener belief:

    p(arg_type | observation, utterance, posterior draw)

which is computed separately for each model and then averaged across posterior
draws.

---

### Core idea

Each speaker model provides posterior samples of speaker behavior in the form:

    p(utterance | observation, speaker_type, parameters)

for three speaker types:
- low-argumentative
- informational (base)
- high-argumentative

The listener belief over speaker types is computed by stacking these predictions
and applying Bayes’ rule, i.e. multiplying by a prior over speaker types and
normalizing across types.

---

### make_L_type_given_ou

This function builds the listener posterior over speaker types for each posterior
draw.

Inputs:
- S1_low_c, S1_base_c, S1_high_c:
  Arrays of shape (S, U, O), where
  - S = number of posterior draws
  - U = number of utterances
  - O = number of observations
- prior:
  Optional prior over speaker types (defaults to uniform over low/info/high)

Processing steps:
- Stack the three speaker predictions into a single array of shape (S, 3, U, O)
  in the fixed order: [low, info, high].
- Multiply by the prior over speaker types.
- Normalize along the speaker-type dimension so probabilities sum to 1.
- If an utterance–observation pair has zero total probability across all types,
  fall back to the prior instead of producing NaNs.
- Run a sanity check to ensure normalization worked correctly.

Output:
- L: listener posterior of shape (S, 3, U, O)

---

### reshape_base_S1

Utility function for the base (informational) model.

Purpose:
- Convert a trace array with shape (chain, draw, U, O)
  into a flat array with shape (S, U, O),
  where S = chain × draw.

This ensures compatibility with argument-direction models.

---

### reshape_argdir_S1

Utility function for models with argument direction.

Purpose:
- Take a trace array with shape (chain, draw, D, U, O),
  where D indexes argument direction.
- Flatten chain and draw into a single sample dimension.
- Extract:
  - S1_low: predictions for the low-argumentative speaker
  - S1_high: predictions for the high-argumentative speaker

Both outputs have shape (S, U, O).

---

### Main processing loop

For each model in the `traces` dictionary:

1. Extract speaker predictions
   - Load low- and high-argumentative predictions from the model trace.
   - Load base (informational) predictions once from the base trace.

2. Align posterior draws
   - Truncate all prediction arrays to a common number of posterior samples
     so that low, info, and high predictions are aligned draw-by-draw.

3. Compute listener posteriors
   - Call make_L_type_given_ou to obtain listener beliefs over speaker types
     for every posterior draw.

4. Average over posterior draws
   - Compute the mean listener posterior across draws, yielding an array
     of shape (3, U, O).

5. Store results
   - L_by_model:
     Full listener posterior per model (shape: S × 3 × U × O).
   - Lmean_by_model:
     Listener posterior averaged over draws (shape: 3 × U × O).

---

### Final stacked output

All per-model averaged listener posteriors are stacked into a single array:

- Lmean_all_models with shape (M, 3, U, O)

where M is the number of models.

These outputs are later used for:
- visualization of listener beliefs,
- model comparison,
- and item selection based on dominance criteria such as
  whether the informational speaker is more probable than the
  low and high speakers combined.

In [21]:
import xarray as xr

# Kept for backward compatibility with code that may reference it; the normalization
# below divides by the exact evidence instead of Z + EPS.
EPS = 1e-12

def make_L_type_given_ou(S1_low_c, S1_base_c, S1_high_c, prior=None):
    """
    Listener posterior over speaker types, L(type | o, u, theta), via Bayes' rule.

    Parameters
    ----------
    S1_low_c, S1_base_c, S1_high_c : array-like, each (S, U, O)
        Speaker likelihoods p(u | o, type, theta) for the low, info and high types.
    prior : array-like of length 3, optional
        Prior over [low, info, high]; uniform if None. It is renormalized to sum to 1.

    Returns
    -------
    L : np.ndarray, shape (S, 3, U, O)
        Posterior over types along axis 1, ordered [low, info, high].
        Cells whose total evidence is not positive (zero or NaN) fall back to the prior.

    Raises
    ------
    ValueError
        If the result does not sum to 1 along the type axis.
    """
    if prior is None:
        prior = np.ones(3) / 3.0
    prior = np.asarray(prior, dtype=float)
    prior = prior / prior.sum()

    S1_types = np.stack([S1_low_c, S1_base_c, S1_high_c], axis=1)  # (S,T,U,O)
    unnorm = S1_types * prior[None, :, None, None]                 # (S,T,U,O)
    Z = unnorm.sum(axis=1, keepdims=True)                          # (S,1,U,O)

    # Start from the prior everywhere, then overwrite cells with positive evidence
    # with the exact ratio unnorm / Z. Dividing by Z + EPS (previous version) skewed
    # cells with tiny evidence (Z on the order of EPS) so they no longer summed to 1,
    # which made the sanity check below raise.
    L = np.broadcast_to(prior[None, :, None, None], unnorm.shape).copy()
    np.divide(unnorm, Z, out=L, where=(Z > 0))  # `where` broadcasts over the type axis

    # sanity
    if not np.allclose(L.sum(axis=1), 1.0, atol=1e-6):
        bad = L.sum(axis=1)
        raise ValueError(f"L not normalized. min={bad.min()} max={bad.max()}")

    return L  # (S,3,U,O)

def reshape_base_S1(trace_base, var="S1_0"):
    """
    Flatten the base-model speaker predictions over (chain, draw).

    Reads ``trace_base.posterior.get(var).values``, expected to be shaped
    (chain, draw, U, O), and returns it reshaped to (chain*draw, U, O).
    """
    samples = trace_base.posterior.get(var).values
    cell_shape = samples.shape[-2:]  # (U, O)
    return samples.reshape((-1, *cell_shape))

def reshape_argdir_S1(trace, var="S1_0", low_idx=0, high_idx=1):
    """
    Split an argument-direction speaker trace into low and high speakers.

    ``trace.posterior.get(var).values`` is expected to be (chain, draw, D, U, O),
    with D indexing argument direction. Chain and draw are merged into one sample
    axis. Returns ``(S1_low, S1_high)``, the slices at ``low_idx`` and ``high_idx``
    along D, each shaped (chain*draw, U, O).
    """
    draws = trace.posterior.get(var).values
    stacked = draws.reshape((-1,) + draws.shape[-3:])  # (S, D, U, O)
    return stacked[:, low_idx], stacked[:, high_idx]

# ---------- MAIN: per-model listener posteriors L(type | o, u), averaged over draws ----------
# For every argumentative-strength model, combine its low/high speakers with the shared
# informational (base) speaker via Bayes' rule, then average the posterior over draws.

# The informational speaker comes from the base trace and is shared by all models.
S1_base = reshape_base_S1(trace_base, var="S1_0")  # (S_base,U,O)

model_names = list(traces.keys())
L_by_model = {}        # name -> full (S,3,U,O)
Lmean_by_model = {}    # name -> draw-averaged (3,U,O)
Lmean_stack = []       # draw-averaged arrays, in model_names order

for name in model_names:
    trace = traces[name]
    S1_low, S1_high = reshape_argdir_S1(trace, var="S1_0", low_idx=0, high_idx=1)  # each (S,U,O)

    # Truncate to the smallest sample count so the three speaker arrays have equal length.
    # NOTE(review): base and argdir models come from separate trace files, so draw i of
    # one is presumably not paired with draw i of the other; this only equalizes counts.
    S_common = min(arr.shape[0] for arr in (S1_low, S1_high, S1_base))
    S1_low_c, S1_high_c, S1_base_c = (arr[:S_common] for arr in (S1_low, S1_high, S1_base))

    L = make_L_type_given_ou(S1_low_c, S1_base_c, S1_high_c, prior=np.ones(3) / 3.0)  # (S,3,U,O)
    L_mean = L.mean(axis=0)  # (3,U,O)

    L_by_model[name] = L
    Lmean_by_model[name] = L_mean
    Lmean_stack.append(L_mean)

# (M,3,U,O), models in model_names order
Lmean_all_models = np.stack(Lmean_stack, axis=0)

print("Per-model full L shapes:", {k: v.shape for k, v in L_by_model.items()})
print("Per-model mean L shapes:", {k: v.shape for k, v in Lmean_by_model.items()})
print("Stacked mean shape (M,T,U,O):", Lmean_all_models.shape)


Per-model full L shapes: {'lr': (4000, 3, 32, 20), 'maximin': (4000, 3, 32, 20), 'nonparametric': (4000, 3, 32, 20), 'prag': (4000, 3, 32, 20)}
Per-model mean L shapes: {'lr': (3, 32, 20), 'maximin': (3, 32, 20), 'nonparametric': (3, 32, 20), 'prag': (3, 32, 20)}
Stacked mean shape (M,T,U,O): (4, 3, 32, 20)


## Plotting listener beliefs, exporting tables, and computing model-averaged predictions

### Note: This code block takes several minutes to run.
This code visualizes draw-averaged listener belief distributions, exports corresponding
tabular summaries, and constructs a model-averaged listener representation using
predictive performance–based weights.

The input to this section consists of listener belief tensors `L_mean` with shape
(3 × U × O), where the first dimension corresponds to speaker type
(low, info, high), U indexes utterances, and O indexes observations.

---

## Setup and configuration

The code initializes plotting utilities and defines an output directory for figures
and CSV files. Two global orderings are specified:

- `MODEL_ORDER`: the sequence in which models are processed and visualized.
- `TYPE_ORDER`: the fixed ordering of speaker types, assumed to match the first axis
  of all listener belief tensors.

These constants ensure consistency across plots, tables, and downstream analyses.

---

## Label construction utilities

### `dense_obs_label`
Constructs a compact string representation of an observation suitable for use as
a subplot label.

- String observations are normalized by collapsing whitespace.
- List- or array-like observations are converted into a `|`-separated sequence of
  integer-like values.
- Long observations are truncated into a head–ellipsis–tail format to maintain
  readability in dense panel plots.

### `utt_label_short`
Generates a short utterance label by concatenating the three components
(Q, A1, A2) associated with a given utterance index.

### `obs_label_short`
Retrieves an observation by index and formats it using `dense_obs_label`.

---

## Panel plot construction

### `plot_L_panels_compact`

This function produces a grid of bar plots visualizing listener beliefs.

Panel structure:
- Rows correspond to utterances.
- Columns correspond to observations.
- Each panel shows a bar chart over speaker types representing
  p(low), p(info), and p(high) for the given utterance–observation pair.

Plot formatting:
- Axis labels and ticks are minimized to reduce visual clutter.
- Y-axis ticks are shown only on the leftmost column.
- X-axis speaker-type labels are shown only on the bottom row.
- Utterance labels are placed along the left margin.
- Observation labels are displayed as column titles, optionally rotated.

Highlighting rule:
- Panels in which p(info) exceeds the sum of p(low) and p(high) are outlined
  with a colored rectangular frame.
- This visual marker identifies utterance–observation pairs that strongly
  favor an informational speaker interpretation.

### `save_panel_plot`

A thin wrapper around the plotting function that saves each figure to disk,
closes the Matplotlib figure to manage memory, and returns the output path.

---

## Listener belief DataFrame construction

### `build_listener_df_for_model`

Converts a listener belief tensor of shape (3 × U × O) into a long-format
DataFrame with one row per utterance–observation pair.

Each row includes:
- model identifier,
- utterance index and observation index,
- probabilities for low, info, and high speaker types,
- a boolean flag indicating whether the info-dominance criterion
  (p(info) > p(low) + p(high)) is satisfied.

Human-readable utterance and observation labels are optionally attached to
facilitate inspection and merging with empirical data.

---

## Predictive-performance–based model weighting

### Motivation

Because `az.compare(method="stacking")` does not work in the current environment,
model averaging uses a simpler alternative: weights derived from each model's PSIS-LOO
expected log predictive density (ELPD).

### `_extract_elpd_loo`

Extracts the ELPD value from the object returned by PSIS-LOO in a version-robust
manner, accommodating differences across ArviZ releases.

### `get_model_weights_via_loo`

For each model trace:
- computes PSIS-LOO ELPD,
- assembles an ELPD comparison table,
- converts ELPD values into normalized weights via a softmax transform.

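The resulting weights provide a performance-based scheme for combining
listener belief distributions across models. Because the softmax acts on raw ELPD
values, differences of tens of log units already put essentially all weight on the
best model. In the run shown below, `nonparametric` receives a weight of ≈1, so the
"average" is effectively that single model.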

---

## Main execution flow

### Per-model outputs

For each model specified in `MODEL_ORDER`:
- the draw-averaged listener belief tensor is retrieved,
- a long-format DataFrame summarizing listener beliefs is constructed,
- a panel plot visualizing the belief distribution is generated and saved.

All per-model DataFrames are concatenated into a single omnibus table.

### Model-averaged listener representation

Using the LOO-derived weights:
- a weighted average of per-model listener belief tensors is computed,
- normalization across speaker types is verified for every utterance–observation pair,
- a fifth panel plot visualizing the model-averaged listener beliefs is produced,
- a corresponding DataFrame is generated and appended to the omnibus table
  under an explicit “average” model label.

---

## Exported outputs

The code produces the following artifacts:
- one panel plot per model and one model-averaged panel plot (PNG),
- a combined CSV containing listener belief predictions for all models,
- CSV files documenting PSIS-LOO ELPD values and derived model weights.

These outputs support visualization, model comparison, and principled
item selection based on listener-side inferences.

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# global constants
TYPE_ORDER = ["low", "info", "high"]  # speaker-type order; must match axis 0 of every L_mean
MODEL_ORDER = ["lr", "maximin", "nonparametric", "prag"]  # processing/plotting order; adjust if needed

# output directory for figures and CSV files (created if missing)
OUT_DIR = "./item_generation_listener_side/"
os.makedirs(OUT_DIR, exist_ok=True)


# ---------------------------
# Label helpers
# ---------------------------
def dense_obs_label(obs, max_elems=20):
    """
    Compact, pipe-separated label for an observation, e.g. "12|12|...|0".

    Parameters
    ----------
    obs : str, iterable, or scalar
        Strings have their whitespace collapsed to single spaces. Iterables are
        rendered element by element, with integer-valued numbers (including numpy
        scalars) shown without a decimal part. Non-iterables go through ``str``.
    max_elems : int
        If the observation has more than this many elements, keep only the first
        and last ``max_elems // 2`` elements, with "..." between them.

    Returns
    -------
    str
    """
    if isinstance(obs, str):
        return " ".join(obs.split())

    try:
        vals = list(obs)
    except TypeError:
        return str(obs)

    def to_str(x):
        # isinstance/is_integer cannot raise here, so no exception guard is needed
        if isinstance(x, (int, np.integer)):
            return str(int(x))
        if isinstance(x, (float, np.floating)) and float(x).is_integer():
            return str(int(x))
        return str(x)

    parts = [to_str(v) for v in vals]
    if len(parts) > max_elems:
        half = max(max_elems, 0) // 2
        # Slice the tail by explicit start index: parts[-0:] would return the
        # whole list when half == 0 (max_elems < 2) instead of nothing.
        parts = parts[:half] + ["..."] + parts[len(parts) - half:]
    return "|".join(parts)

def utt_label_short(possible_utterances, utt, sep=" / "):
    """Label for utterance ``utt``: its three components (Q, A1, A2) joined by ``sep``."""
    q, a1, a2 = possible_utterances[utt]
    return "{0}{3}{1}{3}{2}".format(q, a1, a2, sep)

def obs_label_short(list_possible_observations, obs_idx):
    """Dense label for observation ``obs_idx``, looked up in ``list_possible_observations[0]``."""
    return dense_obs_label(list_possible_observations[0][obs_idx])


# ---------------------------
# Plotting
# ---------------------------
def plot_L_panels_compact(
    L_mean,
    possible_utterances,
    list_possible_observations,
    type_labels=("low", "info", "high"),
    model_name=None,
    sharey=True,
    highlight_rule=True,
    highlight_color="red",
    highlight_lw=2.0,
    obs_rotation=90
):
    """
    Panel plot of listener beliefs: rows = utterances, columns = observations.

    Each panel is a bar chart of p(type | o, u) over the three speaker types.

    Parameters
    ----------
    L_mean : array-like, shape (3, U, O)
        Listener beliefs with the type axis ordered (low, info, high).
    possible_utterances, list_possible_observations :
        Passed to ``utt_label_short`` / ``obs_label_short`` to build row/column labels.
    type_labels : sequence of str
        X tick labels, shown on the bottom row only.
    model_name : str, optional
        Appended to the figure title.
    sharey : bool
        Share the y axis across all panels.
    highlight_rule : bool
        If True, frame every panel where p(info) > p(low) + p(high).
    highlight_color, highlight_lw :
        Colour and line width of that frame.
    obs_rotation : float
        Rotation (degrees) of the observation column titles.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    AssertionError
        If ``L_mean`` is not 3-D or its type axis is not of length 3.
    """
    L_mean = np.asarray(L_mean)
    assert L_mean.ndim == 3, f"Expected (T,U,O), got {L_mean.shape}"
    T, U, O = L_mean.shape
    assert T == 3, "Highlight rule assumes type order is (low, info, high) with T=3."

    utt_labels = [utt_label_short(possible_utterances, u) for u in range(U)]
    obs_labels = [obs_label_short(list_possible_observations, o) for o in range(O)]

    fig_w = max(12, O * 0.55)
    fig_h = max(6, U * 0.33)
    # squeeze=False always yields a (U, O) grid. np.atleast_2d on the squeezed
    # result turned a single-column (U, 1) grid into shape (1, U), breaking axes[u, 0].
    fig, axes = plt.subplots(U, O, figsize=(fig_w, fig_h), sharey=sharey, squeeze=False)

    x = np.arange(T)

    for u in range(U):
        for o in range(O):
            ax = axes[u, o]
            vals = L_mean[:, u, o]

            ax.bar(x, vals)
            ax.set_ylim(0, 1)

            # x ticks only on the bottom row (x axes are not shared, so this is per-axis)
            if u == U - 1:
                ax.set_xticks(x)
                ax.set_xticklabels(type_labels, rotation=90, fontsize=7)
            else:
                ax.set_xticks([])

            # y ticks only on the leftmost column. With sharey=True all panels share a
            # single tick locator, so set_yticks([]) on an inner panel would also erase
            # the leftmost column's ticks. tick_params only affects this Axes.
            if o == 0:
                ax.set_yticks([0, 0.5, 1.0])
                ax.set_yticklabels(["0", "0.5", "1"], fontsize=7)
            else:
                ax.tick_params(axis="y", which="both", left=False, labelleft=False)

            if highlight_rule:
                low, info, high = vals[0], vals[1], vals[2]
                if info > (low + high):
                    rect = Rectangle(
                        (0, 0), 1, 1,
                        transform=ax.transAxes,
                        fill=False,
                        edgecolor=highlight_color,
                        linewidth=highlight_lw,
                        zorder=10,
                        clip_on=False
                    )
                    ax.add_patch(rect)

    for u in range(U):
        axes[u, 0].set_ylabel(utt_labels[u], fontsize=7, rotation=0, labelpad=36, va="center")

    for o in range(O):
        axes[0, o].set_title(obs_labels[o], fontsize=7, rotation=obs_rotation, va="bottom")

    title = "Averaged listener beliefs p(type | o,u)"
    if model_name is not None:
        title += f" — {model_name}"
    fig.suptitle(title, y=1.01)

    fig.tight_layout()
    return fig

def save_panel_plot(model_name, L_mean, filename):
    """
    Draw the compact listener-belief panel plot for one model and save it to OUT_DIR/filename.

    Uses the module-level ``possible_utterances``, ``list_possible_observations``,
    ``TYPE_ORDER`` and ``OUT_DIR``. The figure is closed after saving so that many
    large figures do not accumulate in memory. Returns the path that was written.
    """
    out_path = os.path.join(OUT_DIR, filename)
    panel_kwargs = dict(
        possible_utterances=possible_utterances,
        list_possible_observations=list_possible_observations,
        type_labels=TYPE_ORDER,
        model_name=model_name,
        highlight_rule=True,
        obs_rotation=90,
    )
    fig = plot_L_panels_compact(L_mean, **panel_kwargs)
    fig.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    return out_path


# ---------------------------
# DataFrame builder
# ---------------------------
def build_listener_df_for_model(model_name, L_mean):
    """
    Long-format table of one model's listener beliefs, one row per (utterance, observation).

    ``L_mean`` is the draw-averaged (3, U, O) array with types ordered (low, info, high).
    Rows are utterance-major (all observations of utterance 0 first). Columns:
    model, utt_idx, obs_idx, p_low, p_info, p_high, highlight_info_gt_low_plus_high
    (p_info > p_low + p_high), plus utt_label / obs_label built from the module-level
    ``possible_utterances`` and ``list_possible_observations``.
    """
    beliefs = np.asarray(L_mean)
    assert beliefs.shape[0] == 3, f"{model_name}: expected (3,U,O), got {beliefs.shape}"

    _, n_utt, n_obs = beliefs.shape
    utt_idx = np.repeat(np.arange(n_utt), n_obs)   # 0,0,...,1,1,...
    obs_idx = np.tile(np.arange(n_obs), n_utt)     # 0,1,...,0,1,...
    p_low, p_info, p_high = beliefs.reshape(3, n_utt * n_obs)

    df = pd.DataFrame({
        "model": model_name,
        "utt_idx": utt_idx.astype(int),
        "obs_idx": obs_idx.astype(int),
        "p_low": p_low,
        "p_info": p_info,
        "p_high": p_high,
        "highlight_info_gt_low_plus_high": p_info > (p_low + p_high),
    })

    # Human-readable labels (they can be long; drop these columns for a smaller CSV)
    utt_texts = []
    for u in df["utt_idx"].to_numpy():
        row = possible_utterances[u]
        utt_texts.append(f"{row[0]} / {row[1]} / {row[2]}")
    df["utt_label"] = utt_texts
    df["obs_label"] = [str(list_possible_observations[0][o]) for o in df["obs_idx"].to_numpy()]

    return df


# ---------------------------
# Robust stacking weights
# ---------------------------
def _extract_elpd_loo(loo_res):
    """
    Robustly extract elpd_loo from ArviZ loo result across versions.
    ArviZ returns an ELPDData (often pandas-like). Prefer key access.
    """
    # Most common: ELPDData behaves like a pandas Series with index keys
    for key in ["elpd_loo", "loo", "elpd"]:
        try:
            if key in loo_res:
                return float(loo_res[key])
        except TypeError:
            # some versions don't support "key in loo_res"
            pass

    # Fallback: try .get (works for pandas-like objects)
    try:
        val = loo_res.get("elpd_loo", None)
        if val is not None:
            return float(val)
    except Exception:
        pass

    raise AttributeError(
        f"Could not find elpd_loo in loo result. Available keys/index: "
        f"{getattr(loo_res, 'index', getattr(loo_res, 'keys', lambda: 'UNKNOWN')())}"
    )

def get_model_weights_via_loo(traces_dict, model_order):
    """
    Model weights from PSIS-LOO: a softmax over each model's elpd_loo.

    Used instead of ``az.compare(method='stacking')``, which is broken in this environment.

    Parameters
    ----------
    traces_dict : dict
        Model name -> trace; ``az.loo`` requires a log_likelihood group in each trace.
    model_order : iterable of str
        Models (keys of ``traces_dict``) to evaluate, in this order.

    Returns
    -------
    weights : pd.Series
        Named "weight", indexed by model, sorted by elpd_loo (best first); sums to 1.
    loo_table : pd.DataFrame
        Column "elpd_loo" indexed by model, in the same descending order.
    """
    records = [
        {"model": m, "elpd_loo": _extract_elpd_loo(az.loo(traces_dict[m]))}
        for m in model_order
    ]
    loo_table = (
        pd.DataFrame(records)
        .set_index("model")
        .sort_values("elpd_loo", ascending=False)
    )

    elpd = loo_table["elpd_loo"].to_numpy()
    # shift by the maximum before exponentiating so the largest term is exp(0) = 1
    unnormalized = np.exp(elpd - np.max(elpd))
    weights = pd.Series(unnormalized / unnormalized.sum(), index=loo_table.index, name="weight")
    return weights, loo_table
# ------------------------------------------------------------
# 1) Per-model long tables and 2) per-model panel plots
# ------------------------------------------------------------
dfs = []
saved_figs = []

for model_name in MODEL_ORDER:
    if model_name not in Lmean_by_model:
        raise KeyError(f"Missing {model_name} in Lmean_by_model. Available: {list(Lmean_by_model.keys())}")
    L_mean = Lmean_by_model[model_name]  # (3,U,O)

    df_m = build_listener_df_for_model(model_name, L_mean)
    dfs.append(df_m)

    fig_name = f"listener_panels_{model_name}.png"
    saved_figs.append(save_panel_plot(model_name, L_mean, fig_name))

df_all = pd.concat(dfs, ignore_index=True)

# ------------------------------------------------------------
# 3) Fifth plot: LOO-weighted average over models
#    (softmax over elpd_loo instead of az.compare stacking)
# ------------------------------------------------------------
weights, loo_table = get_model_weights_via_loo(traces, MODEL_ORDER)

print("LOO ELPD table:\n", loo_table)
print("\nWeights (softmax over elpd_loo):\n", weights)

# diagnostics next to the figures
loo_table.to_csv(os.path.join(OUT_DIR, "model_loo_elpd_table.csv"))
weights.to_csv(os.path.join(OUT_DIR, "model_weights_softmax_elpd_loo.csv"))

# weighted sum of the per-model (3,U,O) listener matrices, in elpd_loo order
L_mean_avg = np.zeros_like(Lmean_by_model[MODEL_ORDER[0]], dtype=float)
for m, w_m in weights.items():
    L_mean_avg += float(w_m) * Lmean_by_model[m]

# every (u,o) cell must still sum to 1 across the three types
type_totals = L_mean_avg.sum(axis=0)
if not np.allclose(type_totals, 1.0, atol=1e-6):
    raise ValueError(f"Model-averaged L not normalized. min={type_totals.min()} max={type_totals.max()}")

avg_model_name = "model_avg_softmax_elpd_loo"
saved_figs.append(
    save_panel_plot(avg_model_name, L_mean_avg, f"listener_panels_{avg_model_name}.png")
)

# append the averaged beliefs to df_all, labelled "average" in the model column
df_avg = build_listener_df_for_model(avg_model_name, L_mean_avg).assign(model="average")
df_all = pd.concat([df_all, df_avg], ignore_index=True)

# ------------------------------------------------------------
# 4) Export
# ------------------------------------------------------------
csv_path = os.path.join(OUT_DIR, "listener_beliefs_by_model_u_o.csv")
df_all.to_csv(csv_path, index=False)

print("\nSaved figures:")
for fig_path in saved_figs:
    print(" -", fig_path)

print("\nSaved CSVs:")
print(" -", csv_path)

LOO ELPD table:
                   elpd_loo
model                     
nonparametric -5751.582702
maximin       -5784.516547
lr            -6057.390274
prag          -6073.275790

Weights (softmax over elpd_loo):
 model
nonparametric     1.000000e+00
maximin           4.977515e-15
lr               1.546887e-133
prag             1.951941e-140
Name: weight, dtype: float64
 - ./item_generation_listener_side/listener_beliefs_model_avg_softmax_elpd_loo.csv

Saved figures:
 - ./item_generation_listener_side/listener_panels_lr.png
 - ./item_generation_listener_side/listener_panels_maximin.png
 - ./item_generation_listener_side/listener_panels_nonparametric.png
 - ./item_generation_listener_side/listener_panels_prag.png
 - ./item_generation_listener_side/listener_panels_model_avg_softmax_elpd_loo.png

Saved CSVs:
 - ./item_generation_listener_side/listener_beliefs_by_model_u_o.csv


## Selecting dominant items for item generation and exporting results

This block performs item selection from model-predicted listener beliefs by identifying
utterance–observation pairs that are strongly dominated by a single speaker type. The
selection is carried out on a chosen model variant (here: the `"average"` model), and
produces three ranked sets of candidates (info-dominant, high-dominant, low-dominant),
which are then concatenated and exported to a single CSV.

---

## 1) Select model variant for item selection

- A subset of `df_all` is created by filtering rows where `model == "average"`.
- This yields `df_generation`, the working table used for selecting candidate items.

This step ensures that item selection is based on a single, explicitly defined model
variant (e.g., model-averaged predictions), rather than mixing different models.

---

## 2) Safety check (required probability columns)

- The code verifies that `df_generation` contains the three required probability columns:
  - `p_low`, `p_info`, `p_high`
- If any column is missing, execution stops with a clear error message.

This prevents silent failures and guarantees that subsequent dominance calculations are valid.

---

## 3) Compute dominance margins

Three dominance margins are computed for each utterance–observation pair:

- `info_margin` = p_info − (p_low + p_high)
- `high_margin` = p_high − (p_low + p_info)
- `low_margin`  = p_low  − (p_high + p_info)

Each margin measures how strongly one speaker type dominates the other two combined.
A positive margin indicates that the corresponding type is more probable than the other
two types together.

These margins are added as new columns to `df_generation`.

---

## 4) Define display columns for inspection

A list of essential columns is assembled for printing ranked items:
- indices: `utt_idx`, `obs_idx`
- probabilities: `p_low`, `p_info`, `p_high`
- model identifier: `model`

If present, short human-readable labels are also included:
- `utt_label`
- `obs_label`

This provides a compact, interpretable view of selected candidates.

---

## 5) Select top 12 candidates per dominance category

Three candidate sets are constructed independently:

### Info-dominant candidates
- Filter rows where `info_margin > 0`
- Sort by `info_margin` descending
- Take the top 12 rows

### High-dominant candidates
- Filter rows where `high_margin > 0`
- Sort by `high_margin` descending
- Take the top 12 rows

### Low-dominant candidates
- Filter rows where `low_margin > 0`
- Sort by `low_margin` descending
- Take the top 12 rows

Each resulting table contains the strongest candidates according to the corresponding
dominance criterion. The code prints each ranked set for immediate inspection.

---

## 6) Export combined selection to CSV

- The three top-12 tables are concatenated into one DataFrame (`df_top12_all`).
- The combined table is written to:
  `listener_top12_dominant_items_by_condition.csv`

This output provides a single, consolidated CSV containing the best candidate items
for all three speaker-type conditions under the selected model variant.

In [7]:
# Use df_all from the cells above when it is in memory; otherwise (e.g. a fresh kernel)
# reload the exported table so this selection step can also run on its own.
import os
import pandas as pd

if "df_all" not in globals():
    # OUT_DIR is defined in the plotting cell; fall back to the same path if that cell did not run.
    _listener_out_dir = globals().get("OUT_DIR", "./item_generation_listener_side/")
    df_all = pd.read_csv(os.path.join(_listener_out_dir, "listener_beliefs_by_model_u_o.csv"))

# Select the model variant for item selection
df_generation = df_all[df_all["model"] == "average"].copy()

# Safety check
required = {"p_low", "p_info", "p_high"}
missing = required - set(df_generation.columns)
if missing:
    raise ValueError(f"Missing required columns in df_generation for selection: {sorted(missing)}")

# --------------------------------------------------
# Dominance margins: how far one type's probability exceeds the other two combined
# (positive margin => that type is more probable than the other two together)
# --------------------------------------------------
df_generation["info_margin"] = (
    df_generation["p_info"]
    - (df_generation["p_low"] + df_generation["p_high"])
)

df_generation["high_margin"] = (
    df_generation["p_high"]
    - (df_generation["p_low"] + df_generation["p_info"])
)

df_generation["low_margin"] = (
    df_generation["p_low"]
    - (df_generation["p_high"] + df_generation["p_info"])
)

# Columns to show
cols_show = ["utt_idx", "obs_idx", "p_low", "p_info", "p_high", "model"]
if "utt_label" in df_generation.columns:
    cols_show.append("utt_label")
if "obs_label" in df_generation.columns:
    cols_show.append("obs_label")

# --------------------------------------------------
# Top 12 INFO-dominant items
# (.assign adds the condition column on a fresh frame instead of writing into a slice)
# --------------------------------------------------
top12_info_items = (
    df_generation[df_generation["info_margin"] > 0]
    .sort_values("info_margin", ascending=False)
    .head(12)
    .assign(condition='info')
)

print("\nTop 12 INFO-dominant (u,o) pairs:")
print(top12_info_items[cols_show + ["info_margin"]])

# --------------------------------------------------
# Top 12 HIGH-dominant items
# --------------------------------------------------
top12_high_items = (
    df_generation[df_generation["high_margin"] > 0]
    .sort_values("high_margin", ascending=False)
    .head(12)
    .assign(condition='high')
)

print("\nTop 12 HIGH-dominant (u,o) pairs:")
print(top12_high_items[cols_show + ["high_margin"]])

# --------------------------------------------------
# Top 12 LOW-dominant items
# --------------------------------------------------
top12_low_items = (
    df_generation[df_generation["low_margin"] > 0]
    .sort_values("low_margin", ascending=False)
    .head(12)
)

top12_low_items['condition'] = 'low'
print("\nTop 12 LOW-dominant (u,o) pairs:")
print(top12_low_items[cols_show + ["low_margin"]])

# ------------------------------------------------------------
# 4) Export
# ------------------------------------------------------------
# Stack top12_condition_items all together
dfs_top12 = [
    top12_info_items,
    top12_high_items,
    top12_low_items
]
df_top12_all = pd.concat(dfs_top12, ignore_index=True)
csv_path_top12 = os.path.join(OUT_DIR, "listener_top12_dominant_items_by_condition.csv")
df_top12_all.to_csv(csv_path_top12, index=False)
print("\nSaved CSVs:")
print(" -", csv_path_top12)
# ------------------------------------------------------------



Top 12 INFO-dominant (u,o) pairs:
      utt_idx  obs_idx     p_low    p_info    p_high    model  \
3107       27        7  0.258905  0.608529  0.132565  average   
3083       26        3  0.132565  0.608529  0.258905  average   
2803       12        3  0.270484  0.587582  0.141934  average   
2827       13        7  0.141934  0.587582  0.270484  average   
2563        0        3  0.142578  0.578815  0.278607  average   
2703        7        3  0.142578  0.578815  0.278607  average   
2687        6        7  0.278607  0.578815  0.142578  average   
2587        1        7  0.278607  0.578815  0.142578  average   
2987       21        7  0.336704  0.572920  0.090376  average   
2963       20        3  0.090376  0.572920  0.336704  average   
2947       19        7  0.116081  0.571251  0.312668  average   
2923       18        3  0.312668  0.571251  0.116081  average   

                utt_label    obs_label  info_margin  
3107   all / some / wrong  [3 3 3 3 3]     0.217059  
3083   all 

## Convert the selected item table into a format that the downstream pipeline can parse directly.

In [14]:
import os
import re
import ast
import pandas as pd

# --------------------------------------------------
# Load the top-12 selection produced by the item-selection cell
# --------------------------------------------------
# NOTE(review): the comparison cell further down uses OUT_DIR =
# "./listener_side/item_generation"; confirm this directory is where the
# selection CSV was actually written.
OUT_DIR = "./item_generation_listener_side/"
path_new_item_tables = os.path.join(OUT_DIR, "listener_top12_dominant_items_by_condition.csv")
df_new_items = pd.read_csv(path_new_item_tables)

print("Original columns:")
print(list(df_new_items.columns))
print(df_new_items.head())


# --------------------------------------------------
# 1) Parse utt_label -> Q1, Q2, A
# --------------------------------------------------
def parse_utt_label(s):
    """
    Split an utterance label of the form 'Q1 / Q2 / A' into its three parts.

    Parameters
    ----------
    s : str or NA
        The utterance label.

    Returns
    -------
    pd.Series
        Indexed by ["Q1", "Q2", "A"], so ``Series.apply(parse_utt_label)``
        expands into three columns. Missing labels, and labels that do not
        split into exactly three '/'-separated parts, give all-None entries.
        Malformed (non-missing) labels also emit a warning so the problem is
        visible instead of silently producing empty items.
    """
    import warnings

    index = ["Q1", "Q2", "A"]
    if pd.isna(s):
        return pd.Series([None, None, None], index=index)

    parts = [p.strip() for p in str(s).split("/")]
    if len(parts) != 3:
        warnings.warn(
            f"Could not parse utterance label {s!r}: expected 3 '/'-separated parts, got {len(parts)}."
        )
        return pd.Series([None, None, None], index=index)

    return pd.Series(parts, index=index)


# Expand each "Q1 / Q2 / A" utterance label into three columns (None where the label cannot be parsed)
df_new_items[["Q1", "Q2", "A"]] = df_new_items["utt_label"].apply(parse_utt_label)


# --------------------------------------------------
# 2) Parse obs_label -> observation list [0,0,0,0,0]
# --------------------------------------------------
def parse_obs_label(s):
    """
    Convert an observation label such as '[3 3 3 3 3]' or '[3, 3, 3, 3, 3]'
    into a list of ints. A missing label yields an empty list.
    """
    if pd.isna(s):
        return []
    # Every (optionally negative) integer token, in order; brackets, commas
    # and whitespace are simply skipped by the pattern.
    return list(map(int, re.findall(r"-?\d+", str(s))))


df_new_items["observation_list"] = df_new_items["obs_label"].apply(parse_obs_label)


# --------------------------------------------------
# Canonical string form "[0, 0, 0, 0, 0]" expected downstream
# --------------------------------------------------
df_new_items["observation"] = df_new_items["observation_list"].apply(
    lambda x: "[" + ", ".join(map(str, x)) + "]"
)

# Items whose utterance or observation could not be parsed would be exported
# as broken rows; report them instead of writing them out silently.
unparsed = (
    df_new_items[["Q1", "Q2", "A"]].isna().any(axis=1)
    | df_new_items["observation_list"].map(len).eq(0)
)
if unparsed.any():
    print(f"\n[WARN] {int(unparsed.sum())} item(s) have an unparsed utterance or observation:")
    print(df_new_items.loc[unparsed, ["utt_label", "obs_label", "condition"]])


# --------------------------------------------------
# Inspect result
# --------------------------------------------------
df_final_output = (
    df_new_items[["Q1", "Q2", "A", "observation", "condition"]]
    .reset_index(drop=True)
)
print("\nFinal item table:")
print(df_final_output.head(15))

# --------------------------------------------------
# 3) Export final table (create the target folder if needed; to_csv does not)
# --------------------------------------------------
csv_path_final = "../experiments/listener_side/items/final_listener_items.csv"
os.makedirs(os.path.dirname(csv_path_final), exist_ok=True)
df_final_output.to_csv(csv_path_final, index=False)
print("\nSaved final item table to:", csv_path_final)

Original columns:
['model', 'utt_idx', 'obs_idx', 'p_low', 'p_info', 'p_high', 'highlight_info_gt_low_plus_high', 'utt_label', 'obs_label', 'info_margin', 'high_margin', 'low_margin', 'condition']
     model  utt_idx  obs_idx     p_low    p_info    p_high  \
0  average       27        7  0.258905  0.608529  0.132565   
1  average       26        3  0.132565  0.608529  0.258905   
2  average       12        3  0.270484  0.587582  0.141934   
3  average       13        7  0.141934  0.587582  0.270484   
4  average        0        3  0.142578  0.578815  0.278607   

   highlight_info_gt_low_plus_high            utt_label    obs_label  \
0                             True   all / some / wrong  [3 3 3 3 3]   
1                             True   all / some / right  [9 9 9 9 9]   
2                             True  some / most / right  [9 9 9 9 9]   
3                             True  some / most / wrong  [3 3 3 3 3]   
4                             True  none / none / right  [9 9 9 9 9]  

# Compare model predictions with empirical distributions

Procedure

1.	Load the item tables, which specify the utterance–observation (u, o) pairs associated with each experimental condition. For each condition (high, info, low), the table contains 10 items.
2.	Subset the item tables by condition and retain only the items from the high and low conditions for further analysis.
3.	Load the empirical results from a CSV file. Using these data, compute the empirical response proportions for speaker types (high, info, low), aggregated by utterance–observation pairs and by condition (high vs. low).
4.	For each of the four models, as well as for the weighted model average (five predictors in total), compute a measure of closeness between the model-predicted and empirical distributions over speaker types, aggregated by utterance–observation pairs and conditions (high vs. low). Identify the model that provides the closest fit to the empirical data.
5.	Based on this best-fitting model, select 10 utterance–observation pairs that satisfy the criterion $p(\text{info}) > p(\text{high}) + p(\text{low})$.
These pairs are then used as new items in the info condition.

In [None]:
# Load item tables + results (robust to small formatting differences)
import os
import re
import numpy as np
import pandas as pd


# -----------------------------
# Config
# -----------------------------
path_item_tables = "../experiments/listener_side/items/final_listener_items.csv"
path_results     = "../data/data_listenerside/data_pilot1.csv"

OUT_DIR = "./listener_side/item_generation"
os.makedirs(OUT_DIR, exist_ok=True)

# Column config (explicit, since you know the schema)
ITEMS_COLS = {
    "condition": "condition",
    "Q1": "Q1",
    "Q2": "Q2",
    "A": "A",
    "observation": "observation",
}
RESULTS_COLS = {
    "condition": "condition",
    "Q1": "Q1",
    "Q2": "Q2",
    "A": "A",
    "observation": "studentsArray",
    "response": "response",
}

# Response coding in pilot data
RESPONSE_MAPPING = {"Student": "low", "Teacher": "high", "Principal": "info"}
VALID_TYPES = {"low", "info", "high"}

# Conditions to keep (drop "sample*" etc.)
KEEP_CONDITIONS = {"high", "low", "info"}


# -----------------------------
# Load
# -----------------------------
# Both tables use the same NA policy: only empty cells count as missing.
# pandas' default NA list includes tokens such as "None"/"NA"/"null"; applying
# it to one table but not the other would turn such utterance words into NaN on
# one side only and break the text join below.
items = pd.read_csv(path_item_tables, keep_default_na=False, na_values=[""])
results = pd.read_csv(path_results, keep_default_na=False, na_values=[""])

print("items columns:", items.columns.tolist())
print("results columns:", results.columns.tolist())

# NOTE(review): Q1 is forced to str; presumably guards against a non-string
# Q1 column in the results export — confirm.
results['Q1'] = results['Q1'].astype(str)
# -----------------------------
# Helpers: validation & normalization
# -----------------------------
def require_columns(df, cols, df_name="df"):
    """Raise ValueError naming every column in `cols` that `df` lacks; otherwise do nothing."""
    absent = [name for name in cols if name not in df.columns]
    if not absent:
        return
    raise ValueError(f"{df_name} missing columns: {absent}\nAvailable: {df.columns.tolist()}")

def normalize_text(x) -> str:
    """Lowercase, strip, collapse whitespace; return empty string for NA.

    Parameters
    ----------
    x : any scalar
        Value to normalize; non-strings are converted with ``str``.

    Returns
    -------
    str
        Normalized text, or "" for missing scalars (None / NaN / pd.NA).
    """
    # str(nan) would give "nan"; handle scalar NA explicitly so the docstring's
    # contract holds (matching the later redefinition of this helper).
    if pd.api.types.is_scalar(x) and pd.isna(x):
        return ""
    s = str(x).strip().lower()
    s = re.sub(r"\s+", " ", s)
    return s

def normalize_utt(x) -> str:
    """
    Canonical form of a single utterance component (Q1, Q2 or A): the
    ``normalize_text`` form with any run of trailing periods removed
    (answers in the results file sometimes end with '.').
    """
    # normalize_text already stripped whitespace, so rstrip('.') removes
    # exactly the trailing dot run.
    return normalize_text(x).rstrip(".")

def parse_items_observation(obs_str: str):
    """Turn an item-table observation such as '[0, 0, 0, 0, 0]' into a list of ints ([] for NA)."""
    if not pd.isna(obs_str):
        return [int(token) for token in re.findall(r"-?\d+", str(obs_str))]
    return []

def parse_results_observation(students_array: str):
    """Turn a results-file observation such as '12|3|0|0|0' into a list of ints.

    Blank '|' fields are skipped, a field with no integer in it becomes 0,
    and a missing value yields [].
    """
    if pd.isna(students_array):
        return []
    values = []
    for field in str(students_array).split("|"):
        field = field.strip()
        if not field:
            continue
        match = re.search(r"-?\d+", field)
        values.append(int(match.group()) if match else 0)
    return values

def obs_to_key(obs_list):
    """Join an observation's integer counts into the canonical key, e.g. [12, 3, 0] -> '12|3|0'."""
    return "|".join(map(str, map(int, obs_list)))

def summarize_unmatched(merged, join_cols, out_dir, n=200):
    """Print how many merged rows have no item_id and, if any, save up to `n` of them
    (join columns plus the response column) to out_dir/unmatched_results_rows.csv."""
    no_item = merged["item_id"].isna()
    n_unmatched = no_item.sum()
    print(f"Unmatched results rows (no item_id found): {n_unmatched}")
    if not n_unmatched > 0:
        return
    dbg_path = os.path.join(out_dir, "unmatched_results_rows.csv")
    debug_cols = join_cols + [RESULTS_COLS["response"]]
    merged.loc[no_item, debug_cols].head(n).to_csv(dbg_path, index=False)
    print(f"Saved up to {n} unmatched rows to: {dbg_path}")


# -----------------------------
# Validate schema
# -----------------------------
require_columns(items, ITEMS_COLS.values(), df_name="items")
require_columns(results, RESULTS_COLS.values(), df_name="results")


# -----------------------------
# Data cleansing (results)
# -----------------------------
_cond_col = RESULTS_COLS["condition"]
_resp_col = RESULTS_COLS["response"]

# Rows without a response carry no information
results = results.loc[results[_resp_col].notna()].copy()

# Practice trials: any condition whose (case-insensitive) name starts with 'sample'
cond_series = results[_cond_col].astype(str)
is_practice = cond_series.str.lower().str.startswith("sample")
results = results.loc[~is_practice].copy()

# Canonical lowercase condition labels, restricted to the expected set
results[_cond_col] = results[_cond_col].map(normalize_text)
results = results.loc[results[_cond_col].isin(KEEP_CONDITIONS)].copy()

# Role names (Student/Teacher/Principal) -> speaker types; anything unmapped
# becomes NaN and is removed by the isin filter. Index is reset afterwards.
results[_resp_col] = results[_resp_col].map(RESPONSE_MAPPING)
results = results.loc[results[_resp_col].isin(VALID_TYPES)].reset_index(drop=True)

print("\nAfter cleansing results:")
print("  n rows:", len(results))
print("  conditions:", results[_cond_col].value_counts().to_dict())
print("  responses:", results[_resp_col].value_counts().to_dict())


# -----------------------------
# Build canonical join keys
# -----------------------------
# Both tables receive the same normalized key columns so they can be joined on
# content rather than on formatting details.
items_key = items.copy()
results_key = results.copy()

# Give every item a positional integer id unless the table already has one
if "item_id" not in items_key.columns:
    items_key["item_id"] = np.arange(len(items_key), dtype=int)

# (frame, column mapping, observation parser) for each side of the join
key_specs = (
    (items_key, ITEMS_COLS, parse_items_observation),
    (results_key, RESULTS_COLS, parse_results_observation),
)
for frame, colmap, parse_obs in key_specs:
    for part in ("Q1", "Q2", "A"):
        frame[f"{part}_key"] = frame[colmap[part]].map(normalize_utt)
    frame["obs_key"] = frame[colmap["observation"]].map(parse_obs).map(obs_to_key)
    frame["condition_key"] = frame[colmap["condition"]].map(normalize_text)

# Responses were already mapped to low/info/high during cleansing
results_key["response_key"] = results_key[RESULTS_COLS["response"]].map(normalize_text)

join_cols = ["condition_key", "Q1_key", "Q2_key", "A_key", "obs_key"]

print("\nExample keys (items):")
print(items_key[join_cols + ["item_id"]].head())

print("\nExample keys (results):")
print(results_key[join_cols + ["response_key"]].head())


# -----------------------------
# Merge results -> items (attach item_id)
# -----------------------------
# Left join: every response row is kept; rows without a matching item get
# item_id = NaN and are reported by summarize_unmatched.
# validate="m:1" makes pandas raise if a (cond, Q1, Q2, A, obs) key occurs more
# than once in items; if that happens, inspect the duplicated items.
items_lookup = items_key[join_cols + ["item_id"]]
merged = pd.merge(results_key, items_lookup, how="left", on=join_cols, validate="m:1")

print(f"\nMerged rows: {len(merged)}")
summarize_unmatched(merged, join_cols, OUT_DIR, n=200)

# Only responses that found their item, with item_id back to an integer dtype
merged_matched = merged[merged["item_id"].notna()].astype({"item_id": int})


# -----------------------------
# Empirical proportions per (condition, item_id)
# -----------------------------
# Count responses: one row per (condition, item, response type) with its count n
counts = (
    merged_matched
    .groupby(["condition_key", "item_id", "response_key"])
    .size()
    .rename("n")
    .reset_index()
)

# Total per item: number of responses for each (condition, item)
totals = (
    counts
    .groupby(["condition_key", "item_id"])["n"]
    .sum()
    .rename("n_total")
    .reset_index()
)

# Proportion of each response type within its (condition, item)
emp = counts.merge(totals, on=["condition_key", "item_id"], how="left")
emp["p"] = emp["n"] / emp["n_total"]

# Pivot to wide probabilities. Each (condition, item, response) cell holds a
# single value after the groupby, so pivot_table's default mean aggregation
# just passes it through; response types never observed for an item get 0.0.
emp_wide = (
    emp.pivot_table(index=["condition_key", "item_id"], columns="response_key", values="p", fill_value=0.0)
    .reset_index()
)

# Ensure all three columns exist (even if absent in some subsets)
for col in ["low", "info", "high"]:
    if col not in emp_wide.columns:
        emp_wide[col] = 0.0

emp_wide = emp_wide.rename(columns={"low": "p_low_emp", "info": "p_info_emp", "high": "p_high_emp"})

# Attach original item info for readability (Q1/Q2/A/observation text + obs_key,
# which later cells use to match model predictions)
emp_wide = emp_wide.merge(
    items_key[["item_id", ITEMS_COLS["condition"], ITEMS_COLS["Q1"], ITEMS_COLS["Q2"], ITEMS_COLS["A"], ITEMS_COLS["observation"], "obs_key"]],
    on="item_id",
    how="left"
)

# Basic sanity checks: each row's three proportions should sum to 1.
# NOTE(review): sorting ascending shows the rows with the smallest sums first;
# rows summing to more than 1 would not be among the "worst 10" shown.
prob_sum = emp_wide[["p_low_emp", "p_info_emp", "p_high_emp"]].sum(axis=1)
if not np.allclose(prob_sum.to_numpy(), 1.0, atol=1e-8):
    print("\n[WARN] Some empirical probability rows do not sum to 1. Showing worst 10:")
    tmp = emp_wide.assign(prob_sum=prob_sum).sort_values("prob_sum")
    print(tmp[["condition_key", "item_id", "p_low_emp", "p_info_emp", "p_high_emp", "prob_sum"]].head(10))

print("\nEmpirical proportions (head):")
print(emp_wide.head())

# Save empirical table
emp_out = os.path.join(OUT_DIR, "empirical_proportions_by_item.csv")
emp_wide.to_csv(emp_out, index=False)
print(f"\nSaved empirical proportions table to: {emp_out}")

In [None]:
# ------------------------------------------------------------
# Step 1) Load model prediction CSVs and stack them into one df
# ------------------------------------------------------------
path_models = {
    "average": "./listener_side/item_generation/listener_beliefs_model_avg_softmax_elpd_loo.csv",
    "four_models": "./listener_side/item_generation/listener_beliefs_by_model_u_o.csv",
}

df_four = pd.read_csv(path_models["four_models"])
df_avg  = pd.read_csv(path_models["average"])

# Basic sanity checks (required columns)
req = {"model", "utt_idx", "obs_idx", "p_low", "p_info", "p_high"}
missing_four = req - set(df_four.columns)
missing_avg  = req - set(df_avg.columns)
if missing_four:
    raise ValueError(f"four_models CSV missing columns: {sorted(missing_four)}")
if missing_avg:
    raise ValueError(f"average CSV missing columns: {sorted(missing_avg)}")

# Ensure consistent dtypes
for df in (df_four, df_avg):
    df["utt_idx"] = df["utt_idx"].astype(int)
    df["obs_idx"] = df["obs_idx"].astype(int)
    for c in ["p_low", "p_info", "p_high"]:
        df[c] = df[c].astype(float)

# Standardize the average model name so it stays clear in the stacked table
df_avg = df_avg.copy()
df_avg["model"] = "average"

# The per-model file can itself contain a model called "average" (the
# item-selection cell filters that file on model == "average"). Stacking would
# then create two indistinguishable "average" groups (double-counted rows and
# duplicate merge matches), so relabel the per-model-file copy.
collides = df_four["model"] == "average"
if collides.any():
    print(
        f"\n[WARN] four_models CSV already has {int(collides.sum())} rows with model == 'average'; "
        "relabelling them 'average_four_models_csv' to keep them distinct from the dedicated average CSV."
    )
    df_four = df_four.copy()
    df_four.loc[collides, "model"] = "average_four_models_csv"

# Stack: 4 models + average
pred_all = pd.concat([df_four, df_avg], ignore_index=True)

print("Loaded predictions:")
print("  four_models rows:", len(df_four))
print("  average rows:    ", len(df_avg))
print("  stacked rows:    ", len(pred_all))
print("\nModels in stacked df:", pred_all["model"].value_counts())

# Optional sanity: probabilities sum to 1
prob_sum = pred_all[["p_low", "p_info", "p_high"]].sum(axis=1)
if not (abs(prob_sum - 1.0) < 1e-6).all():
    print("\n[WARN] Some prediction rows do not sum to 1 (showing 10 worst):")
    tmp = pred_all.assign(_sum=prob_sum).sort_values("_sum")
    print(tmp[["model","utt_idx","obs_idx","p_low","p_info","p_high","_sum"]].head(10))

pred_all.head()

In [None]:
import re

# ------------------------------------------------------------
# Step 2) Build robust text keys in BOTH empirical + prediction dfs
#   pred_all keys: utt_label, obs_label
#   emp_hl keys:   Q1, Q2, A, plus observation (string like "[0, 0, 0, 0, 0]")
# Goal: merge on (condition_key, utt_key, obs_key)
# ------------------------------------------------------------

# ---------- helpers ----------
def normalize_text(x) -> str:
    """Lowercase, trim, and collapse each whitespace run to a single space; NA becomes ''."""
    if pd.isna(x):
        return ""
    # str.split() with no argument splits on (and drops) every whitespace run
    return " ".join(str(x).lower().split())

def normalize_utt_component(x) -> str:
    """Normalized utterance component with trailing periods dropped (results' A sometimes ends in '.')."""
    return normalize_text(x).rstrip(".")

def utt_triplet_to_key(q1, q2, a, sep=" / "):
    """Canonical utterance key: the three normalized components joined by `sep`."""
    return sep.join(normalize_utt_component(part) for part in (q1, q2, a))

def parse_obs_any_to_list(x):
    """
    Parse an observation given in any of the notebook's formats into a list of ints.

    Handles:
      - "[0, 0, 0, 0, 0]"   (item tables)
      - "[9 9 3 3 3]"       (model obs_label)
      - "12|3|0|0|0"        (results studentsArray; blank fields skipped,
                             fields without an integer become 0)
      - actual lists / tuples / numpy arrays
    Missing scalars (None / NaN) give [].
    """
    # List-likes must be handled before the NA check: pd.isna on a list returns
    # an element-wise array, and `if array:` raises ValueError for >1 element.
    if isinstance(x, (list, tuple, np.ndarray)):
        return [int(v) for v in x]
    if pd.isna(x):
        return []
    s = str(x).strip()
    # if pipe-delimited
    if "|" in s:
        out = []
        for p in s.split("|"):
            if p.strip() == "":
                continue
            m = re.findall(r"-?\d+", p)
            out.append(int(m[0]) if m else 0)
        return out
    # otherwise extract all ints from string (covers brackets/commas/spaces)
    return [int(n) for n in re.findall(r"-?\d+", s)]

def obs_list_to_key(lst):
    """Join integer counts with '|', e.g. [12, 3, 0] -> '12|3|0'."""
    return "|".join(map(str, map(int, lst)))

def obs_to_key(x):
    """Canonical '|'-joined observation key for any supported observation format."""
    parsed = parse_obs_any_to_list(x)
    return obs_list_to_key(parsed)


# ---------- 2A) build keys for pred_all ----------
required_label_cols = {"utt_label", "obs_label"}
if not required_label_cols <= set(pred_all.columns):
    raise ValueError("pred_all must contain utt_label and obs_label columns.")

pred_key = pred_all.copy()
# Predictions normally carry no condition; fall back to "" (step 2C then joins
# without the condition and takes it from the empirical side).
pred_key["condition_key"] = pred_key.get("condition_key", "")
pred_key = pred_key.assign(
    # normalize every '/'-separated component and re-join with " / "
    utt_key=pred_key["utt_label"].map(
        lambda label: " / ".join(normalize_utt_component(part) for part in str(label).split("/"))
    ),
    obs_key=pred_key["obs_label"].map(obs_to_key),
)

# quick sanity
print("pred_key example:")
print(pred_key[["model", "utt_label", "utt_key", "obs_label", "obs_key"]].head())


# ---------- 2B) build keys for emp_hl ----------
# emp_hl: empirical proportions restricted to the high and low conditions
# (procedure step 2). It is derived explicitly from emp_wide here; previously
# it was used without ever being defined, so Restart & Run All failed.
emp_hl = emp_wide[emp_wide["condition_key"].isin(["high", "low"])].copy()
if emp_hl.empty:
    print("\n[WARN] emp_hl is empty: no empirical items in the high/low conditions.")

# Expect emp_hl to have Q1,Q2,A and an observation column OR already has utt_label/obs_label.
emp_key = emp_hl.copy()

# condition_key should already exist; if not, create from condition
if "condition_key" not in emp_key.columns:
    if "condition" in emp_key.columns:
        emp_key["condition_key"] = emp_key["condition"].map(normalize_text)
    else:
        raise ValueError("emp_hl must have condition_key or condition.")

# Build utt_key
if {"Q1", "Q2", "A"}.issubset(emp_key.columns):
    emp_key["utt_key"] = emp_key.apply(lambda r: utt_triplet_to_key(r["Q1"], r["Q2"], r["A"]), axis=1)
elif "utt_label" in emp_key.columns:
    emp_key["utt_key"] = emp_key["utt_label"].map(lambda s: " / ".join([normalize_utt_component(p) for p in str(s).split("/")]))
else:
    raise ValueError("emp_hl must have (Q1,Q2,A) or utt_label.")

# Build obs_key from the first observation-like column present
obs_col_emp = None
for c in ["observation", "obs_label", "studentsArray", "obs"]:
    if c in emp_key.columns:
        obs_col_emp = c
        break
if obs_col_emp is None:
    raise ValueError("emp_hl must have an observation column like 'observation' (or provide obs_label).")

emp_key["obs_key"] = emp_key[obs_col_emp].map(obs_to_key)

print("\nemp_key example:")
show_cols = ["condition_key", "utt_key", "obs_key", "p_low_emp", "p_info_emp", "p_high_emp"]
print(emp_key[show_cols].head())


# ---------- 2C) merge predictions with empirical ----------
# condition_key is used as a join key only when the predictions actually carry
# a non-empty condition; otherwise join on (utt_key, obs_key) and take the
# condition from the empirical side.
pred_has_condition = "condition_key" in pred_key.columns and pred_key["condition_key"].astype(str).str.len().gt(0).any()

if pred_has_condition:
    merge_cols = ["condition_key", "utt_key", "obs_key"]
else:
    merge_cols = ["utt_key", "obs_key"]

merged = pred_key.merge(
    emp_key,
    on=merge_cols,
    how="inner",
    suffixes=("_pred", "_emp"),
)

print("\nMerged rows:", len(merged))
print("Merged models:", merged["model"].value_counts().to_dict())

# All models must be scored on the same empirical items. The inner join drops
# items whose keys fail to match, so report per-model coverage explicitly.
n_emp_items = len(emp_key)
coverage = (
    merged.groupby("model").size()
    .reindex(pd.unique(pred_key["model"]), fill_value=0)
)
off_coverage = coverage[coverage != n_emp_items]
if not off_coverage.empty:
    print(f"\n[WARN] Expected {n_emp_items} matched empirical items per model; got:")
    print(off_coverage.to_dict())

# If condition_key was not a join key it got suffixed; use the empirical one
if "condition_key" not in merged.columns:
    merged["condition_key"] = merged["condition_key_emp"]

# Optional: save merged for inspection
# merged.to_csv(os.path.join(OUT_DIR, "pred_vs_emp_merged_by_text_keys.csv"), index=False)

merged.head()

In [None]:
import numpy as np
import pandas as pd

# ------------------------------------------------------------
# Step 3) Quantify model–data alignment and identify best model
# ------------------------------------------------------------
# We assume `merged` exists from the previous step and contains:
#   model,
#   p_low, p_info, p_high           (model predictions)
#   p_low_emp, p_info_emp, p_high_emp (empirical)
#   condition_key, utt_key, obs_key
#
# We compute multiple goodness-of-fit metrics for robustness.
# ------------------------------------------------------------

# ---------- helpers ----------
def tv_distance(p, q):
    """
    Total-variation distance between categorical distributions stored along
    the last axis of `p` and `q` (e.g. arrays of shape (..., 3)).
    """
    abs_diff = np.abs(p - q)
    return abs_diff.sum(axis=-1) / 2.0

def l2_distance(p, q):
    """Squared Euclidean distance between `p` and `q` along the last axis."""
    return np.square(p - q).sum(axis=-1)

def cross_entropy(p_emp, p_model, eps=1e-12):
    """
    Cross-entropy H(p_emp, p_model) along the last axis (lower is better).
    Model probabilities are clipped to [eps, 1] so zeros give a large finite
    penalty instead of infinity.
    """
    log_model = np.log(np.clip(p_model, eps, 1.0))
    return -(p_emp * log_model).sum(axis=-1)


# ---------- prepare arrays ----------
# Fail with a clear message (rather than a KeyError, or an IndexError at
# `fit_summary.iloc[0]`) if the previous step produced nothing usable.
pred_prob_cols = ["p_low", "p_info", "p_high"]
emp_prob_cols = ["p_low_emp", "p_info_emp", "p_high_emp"]
missing_fit_cols = [
    c for c in ["model", "condition_key", "utt_key", "obs_key"] + pred_prob_cols + emp_prob_cols
    if c not in merged.columns
]
if missing_fit_cols:
    raise ValueError(f"merged is missing columns required for the fit comparison: {missing_fit_cols}")
if merged.empty:
    raise ValueError("merged has no rows: no model prediction matched an empirical item, so models cannot be compared.")

p_model = merged[pred_prob_cols].to_numpy()
p_emp   = merged[emp_prob_cols].to_numpy()

# ---------- compute per-item distances ----------
merged = merged.copy()
merged["tv"]  = tv_distance(p_model, p_emp)
merged["l2"]  = l2_distance(p_model, p_emp)
merged["ce"]  = cross_entropy(p_emp, p_model)

# NaN/inf here (e.g. from missing probabilities) would silently distort the means below
assert np.isfinite(merged[["tv", "l2", "ce"]].to_numpy()).all(), "non-finite fit values; check for NaN probabilities"

print("\nPer-item fit (head):")
print(
    merged[
        ["model", "condition_key", "utt_key", "obs_key",
         "tv", "l2", "ce"]
    ].head()
)


# ------------------------------------------------------------
# Step 4) Aggregate fit by model
# ------------------------------------------------------------
fit_summary = (
    merged
    .groupby("model", as_index=False)
    .agg(
        n_items=("tv", "size"),
        tv_mean=("tv", "mean"),
        tv_median=("tv", "median"),
        l2_mean=("l2", "mean"),
        ce_mean=("ce", "mean"),
    )
    # ce_mean breaks exact ties on the primary metric deterministically
    .sort_values(["tv_mean", "ce_mean"])
)

print("\nModel fit summary (lower is better):")
print(fit_summary)

# Identify best model according to primary metric (TV distance)
best_model = fit_summary.iloc[0]["model"]
print("\nBest-fitting model (by TV distance):", best_model)

# Optional: save summary
# fit_summary.to_csv(os.path.join(OUT_DIR, "model_fit_summary.csv"), index=False)


# ------------------------------------------------------------
# Step 5) (Optional but recommended) Condition-wise diagnostics
# ------------------------------------------------------------
fit_by_condition = (
    merged
    .groupby(["model", "condition_key"], as_index=False)
    .agg(
        tv_mean=("tv", "mean"),
        ce_mean=("ce", "mean"),
        n=("tv", "size"),
    )
    .sort_values(["model", "condition_key"])
)

print("\nModel fit by condition:")
print(fit_by_condition)

# Optional: save
# fit_by_condition.to_csv(os.path.join(OUT_DIR, "model_fit_by_condition.csv"), index=False)