# RSA Model Fitting Notebook

This notebook fits RSA speaker models to participant data from the N5M1 experiment.

For each participant and each goal condition (inf, persp, persm), we fit **all 6 models**:
1. **Literal** - Uniform distribution over true utterances
2. **inf_T** - Informative pragmatic, dynamic (update_internal=True)
3. **inf_F** - Informative pragmatic, static (update_internal=False)
4. **persp_T** - Persuade+ pragmatic, dynamic
5. **persp_F** - Persuade+ pragmatic, static
6. **persm_T** - Persuade- pragmatic, dynamic
7. **persm_F** - Persuade- pragmatic, static

Output formats:
- **Wide format**: One row per participant (3 conditions × 6 models = 18 fits per participant)
- **Long format**: One row per participant × condition × model

## Cell 1: Imports and Setup

In [19]:
import pandas as pd
import numpy as np
import warnings
from pathlib import Path
from typing import List, Dict, Tuple, Any, Optional
from tqdm import tqdm

# Import RSA modules (these should be in the same directory or in PYTHONPATH)
import sys
# Uncomment and modify if needed:
# sys.path.append('/path/to/rsa_modules/')

from rsa_optimal_exp_core import World, LiteralSpeaker, PragmaticSpeaker_obs
from rsa_optimal_exp_fitting import log_likelihood_utt_seq, log_likelihood_alpha_opt_utt_seq

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore')

## Cell 2: Configuration

In [20]:
# Experiment configuration
# n=1: One group/experiment
# m=5: 5 patients (Bernoulli trials) per group
N_EXPERIMENTS = 1  # Number of independent experiments
M_PATIENTS = 5     # Number of patients per experiment (Bernoulli trials)

# Model fitting configuration
ALPHA_BOUNDS = (0.001, 100.0)  # Bounds for alpha optimization
GRID_SEARCH = False            # Use grid search (more robust)
GRID_POINTS = 300             # Number of grid points
GRID_SPACING = "log"          # Logarithmic spacing
INCLUDE_DETERM = True         # Include deterministic alpha

# Expected number of trials per condition
EXPECTED_TRIALS = 10

# Mapping from experiment data to RSA model
PREDICATE_MAP = {'Effective': 'successful', 'Ineffective': 'unsuccessful'}
QUANTIFIER_MAP = {'All': 'all', 'Most': 'most', 'Some': 'some', 'No': 'no'}

# Goal conditions in the experiment
GOAL_CONDITIONS = ['inf', 'persp', 'persm']

# All models to fit for each condition
# Format: (model_name, psi, update_internal)
# psi: 'inf', 'pers+', 'pers-'
# update_internal: True (dynamic), False (static)
MODELS_TO_FIT = [
    ('literal', None, None),           # Literal model (no psi/update)
    ('inf_T', 'inf', True),            # Informative, dynamic
    ('inf_F', 'inf', False),           # Informative, static
    ('persp_T', 'pers+', True),        # Persuade+, dynamic
    ('persp_F', 'pers+', False),       # Persuade+, static
    ('persm_T', 'pers-', True),        # Persuade-, dynamic
    ('persm_F', 'pers-', False),       # Persuade-, static
]

MODEL_NAMES = [m[0] for m in MODELS_TO_FIT]

## Cell 3: Helper Functions

In [21]:
def create_world(n: int = N_EXPERIMENTS, m: int = M_PATIENTS) -> World:
    """Create the RSA World object for the experiment."""
    return World(n=n, m=m)


def num_effective_to_observation(num_effective: int, m: int = M_PATIENTS) -> Tuple[int, ...]:
    """
    Convert num_effective (count of effective patients) to observation tuple.
    
    For n=1, m=5, observation is a one-hot encoding of length m+1=6:
    - num_effective=0 -> (1, 0, 0, 0, 0, 0)
    - num_effective=3 -> (0, 0, 0, 1, 0, 0)
    - num_effective=5 -> (0, 0, 0, 0, 0, 1)
    """
    n_effective = int(num_effective)
    return tuple(1 if i == n_effective else 0 for i in range(m + 1))


def format_utterance(predicate: str, quantifier: str) -> str:
    """
    Convert experiment predicate/quantifier to RSA utterance format.
    
    RSA uses comma-separated format: "quantifier,predicate"
    Example: ('Effective', 'Most') -> 'most,successful'
    """
    pred = PREDICATE_MAP.get(predicate, predicate.lower())
    quant = QUANTIFIER_MAP.get(quantifier, quantifier.lower())
    return f"{quant},{pred}"


def extract_trials_for_condition(df_row: pd.Series, condition: str) -> Dict[str, Any]:
    """
    Extract observation and utterance sequences for a specific condition.
    
    Parameters
    ----------
    df_row : pd.Series
        A row from the processed data (one participant)
    condition : str
        Condition prefix: 'inf', 'persp', or 'persm'
    
    Returns
    -------
    Dict with:
        - obs_seq: List of observation tuples
        - utt_seq: List of utterance strings
        - n_trials: Number of valid trials
        - is_complete: Whether all expected trials are present
    """
    obs_seq = []
    utt_seq = []
    
    for r in range(1, 11):  # Rounds 1-10
        num_eff_col = f'{condition}_r{r}_num_effective'
        pred_col = f'{condition}_r{r}_predicate'
        quant_col = f'{condition}_r{r}_quantifier'
        
        # Check if columns exist and have valid values
        if num_eff_col not in df_row.index:
            continue
            
        num_eff = df_row.get(num_eff_col)
        pred = df_row.get(pred_col)
        quant = df_row.get(quant_col)
        
        # Skip if any value is missing
        if pd.isna(num_eff) or pd.isna(pred) or pd.isna(quant):
            continue
        if pred == '' or quant == '':
            continue
            
        obs = num_effective_to_observation(int(num_eff))
        utt = format_utterance(pred, quant)
        
        obs_seq.append(obs)
        utt_seq.append(utt)
    
    n_trials = len(obs_seq)
    is_complete = (n_trials == EXPECTED_TRIALS)
    
    return {
        'obs_seq': obs_seq,
        'utt_seq': utt_seq,
        'n_trials': n_trials,
        'is_complete': is_complete
    }

## Cell 4: Model Fitting Functions

In [22]:
def fit_literal_model(
    world: World,
    obs_seq: List[Tuple[int, ...]],
    utt_seq: List[str]
) -> Dict[str, Any]:
    """
    Fit literal speaker model (no alpha optimization needed).
    
    Returns
    -------
    Dict with:
        - ll: float (log-likelihood)
        - alpha: None (literal has no alpha)
    """
    config = {
        'speaker_type': 'literal',
        'initial_beliefs_theta': None
    }
    
    try:
        ll = log_likelihood_utt_seq(world, obs_seq, utt_seq, config)
    except Exception as e:
        ll = np.nan
    
    return {
        'll': ll,
        'alpha': None
    }


def fit_pragmatic_model(
    world: World,
    obs_seq: List[Tuple[int, ...]],
    utt_seq: List[str],
    psi: str,
    update_internal: bool,
    alpha_bounds: Tuple[float, float] = ALPHA_BOUNDS,
    grid_search: bool = GRID_SEARCH,
    grid_points: int = GRID_POINTS,
    grid_spacing: str = GRID_SPACING,
    include_determ: bool = INCLUDE_DETERM
) -> Dict[str, Any]:
    """
    Fit pragmatic speaker model with alpha optimization.
    
    Parameters
    ----------
    world : World
        RSA World object
    obs_seq : List of observation tuples
    utt_seq : List of utterance strings
    psi : str
        Speaker goal: 'inf', 'pers+', or 'pers-'
    update_internal : bool
        True for dynamic model, False for static model
    
    Returns
    -------
    Dict with:
        - ll: float (log-likelihood at optimal alpha)
        - alpha: float or 'determ' or None
    """
    config = {
        'speaker_type': 'pragmatic',
        'omega': 'strat',  # Strategic speaker
        'psi': psi,
        'update_internal': update_internal,
        'beta': 0.0,  # Pure goal (no informativeness mixing)
        'initial_beliefs_theta': None
    }
    
    try:
        result = log_likelihood_alpha_opt_utt_seq(
            world=world,
            obs_seq=obs_seq,
            utt_seq=utt_seq,
            speaker_config=config,
            alpha_bounds=alpha_bounds,
            grid_search=grid_search,
            grid_points=grid_points,
            grid_spacing=grid_spacing,
            include_determ=include_determ
        )
        
        return {
            'll': result['max_log_likelihood'],
            'alpha': result['optimal_alpha']
        }
        
    except Exception as e:
        return {
            'll': np.nan,
            'alpha': None
        }


def fit_all_models_for_sequence(
    world: World,
    obs_seq: List[Tuple[int, ...]],
    utt_seq: List[str]
) -> Dict[str, Dict[str, Any]]:
    """
    Fit all 6 models for one observation-utterance sequence.
    
    Returns
    -------
    Dict with model names as keys, each containing {'ll': float, 'alpha': value}
    """
    results = {}
    
    for model_name, psi, update_internal in MODELS_TO_FIT:
        if model_name == 'literal':
            results[model_name] = fit_literal_model(world, obs_seq, utt_seq)
        else:
            results[model_name] = fit_pragmatic_model(
                world, obs_seq, utt_seq, psi, update_internal
            )
    
    return results

## Cell 5: Main Processing Functions

In [23]:
def process_participant_wide(
    df_row: pd.Series,
    world: World
) -> Dict[str, Any]:
    """
    Process all conditions for a single participant (wide format).
    
    Returns dict with one entry per condition-model combination.
    """
    participant_id = df_row.get('participant_id', df_row.name)
    result = {'participant_id': participant_id}
    
    for condition in GOAL_CONDITIONS:
        # Extract trial data
        trial_data = extract_trials_for_condition(df_row, condition)
        
        # Store trial info
        result[f'{condition}_n_trials'] = trial_data['n_trials']
        result[f'{condition}_is_complete'] = trial_data['is_complete']
        
        # Fit all models if we have any trials
        if trial_data['n_trials'] > 0:
            model_results = fit_all_models_for_sequence(
                world,
                trial_data['obs_seq'],
                trial_data['utt_seq']
            )
            
            # Store results for each model
            for model_name in MODEL_NAMES:
                mr = model_results[model_name]
                result[f'{condition}_{model_name}_ll'] = mr['ll']
                result[f'{condition}_{model_name}_alpha'] = mr['alpha']
        else:
            # No trials - fill with NaN
            for model_name in MODEL_NAMES:
                result[f'{condition}_{model_name}_ll'] = np.nan
                result[f'{condition}_{model_name}_alpha'] = None
    
    return result


def fit_all_participants_wide(
    df: pd.DataFrame,
    verbose: bool = True
) -> pd.DataFrame:
    """
    Fit all models for all participants (wide format).
    
    Parameters
    ----------
    df : pd.DataFrame
        Processed data with one row per participant
    verbose : bool
        Whether to show progress
    
    Returns
    -------
    pd.DataFrame
        Wide-format results with one row per participant
        Columns: participant_id, {condition}_{model}_ll, {condition}_{model}_alpha
    """
    # Create world
    world = create_world(n=N_EXPERIMENTS, m=M_PATIENTS)
    
    if verbose:
        print(f"World created: n={N_EXPERIMENTS}, m={M_PATIENTS}")
        print(f"  Utterances: {world.utterances}")
        print(f"  Possible outcomes: {len(world.possible_outcomes)}")
        print(f"\nFitting {len(MODEL_NAMES)} models × {len(GOAL_CONDITIONS)} conditions = {len(MODEL_NAMES) * len(GOAL_CONDITIONS)} fits per participant")
    
    # Process all participants
    all_results = []
    
    iterator = df.iterrows()
    if verbose:
        iterator = tqdm(list(iterator), desc="Fitting models")
    
    for idx, row in iterator:
        result = process_participant_wide(row, world)
        all_results.append(result)
    
    # Convert to DataFrame
    results_df = pd.DataFrame(all_results)
    
    # Reorder columns
    cols = ['participant_id']
    for condition in GOAL_CONDITIONS:
        cols.extend([f'{condition}_n_trials', f'{condition}_is_complete'])
        for model in MODEL_NAMES:
            cols.extend([f'{condition}_{model}_ll', f'{condition}_{model}_alpha'])
    
    # Only include columns that exist
    cols = [c for c in cols if c in results_df.columns]
    results_df = results_df[cols]
    
    return results_df

## Cell 6: Create Anonymized Version

In [38]:
def create_anonymized_version(df):
    """
    Create anonymized version by removing identifying columns.
    Keeps participant_id as the only identifier.
    """
    # Columns to remove for anonymization
    id_columns = ['subject_id', 'prolific_pid', 'study_id', 'session_id', 
                  'start_time', 'completion_time', 'source_file']
    
    df_anon = df.copy()
    for col in id_columns:
        if col in df_anon.columns:
            df_anon = df_anon.drop(columns=[col])
    
    return df_anon

## Cell 7: Run Model Fitting

Edit `INPUT_PATH` to point to your processed data file.

In [36]:
INPUT_PATH_FULL = './processed_speaker_n1_full.csv'  

# Load data
df_raw = pd.read_csv(INPUT_PATH_FULL)
print(f"Loaded {len(df_raw)} participants from {INPUT_PATH_FULL}")

# Filter to completed participants only
df = df_raw[df_raw['completion_status'] == 'completed'].copy()
print(f"Filtering {len(df_raw)} to {len(df)} completed participants")

# Run model fitting (wide format)
print(f"\n{'='*70}")
print("FITTING MODELS")
print(f"{'='*70}")
fitted_result = fit_all_participants_wide(df, verbose=True)

Loaded 111 participants from ./processed_speaker_n1_full.csv
Filtering 111 to 109 completed participants

FITTING MODELS
World created: n=1, m=5
  Utterances: ['all,successful', 'all,unsuccessful', 'most,successful', 'most,unsuccessful', 'some,successful', 'some,unsuccessful', 'no,successful', 'no,unsuccessful']
  Possible outcomes: 6

Fitting 7 models × 3 conditions = 21 fits per participant


Fitting models: 100%|██████████| 109/109 [04:28<00:00,  2.47s/it]


_merge
both          109
left_only       0
right_only      0
Name: count, dtype: int64

In [46]:
key = "participant_id"
overlap = df.columns.intersection(fitted_result.columns).difference([key])
print("Overlapping columns to drop before merge:", overlap.tolist())
df_fitted = df.merge(
    fitted_result.drop(columns=overlap),
    on=key,
    how="left",
    validate="one_to_one",
    indicator=True 
)

df_fitted["_merge"].value_counts()

Overlapping columns to drop before merge: ['inf_n_trials', 'persm_n_trials', 'persp_n_trials']


_merge
both          109
left_only       0
right_only      0
Name: count, dtype: int64

In [47]:
# Create anonymized version
df_fitted_anon = create_anonymized_version(df_fitted)

In [48]:
df.columns.tolist()

['participant_id',
 'study',
 'subject_id',
 'prolific_pid',
 'study_id',
 'session_id',
 'experiment_version',
 'start_time',
 'completion_status',
 'completion_time',
 'terminated_early',
 'termination_reason',
 'duration_minutes',
 'total_time_elapsed_ms',
 'block_1_scenario',
 'block_2_scenario',
 'block_3_scenario',
 'block_order',
 'attention_total_failures',
 'attention_block_1_passed',
 'attention_block_1_round',
 'attention_block_1_time_elapsed',
 'attention_block_1_num_effective',
 'attention_block_1_stimulus_variant',
 'attention_block_1_required_description',
 'attention_block_2_passed',
 'attention_block_2_round',
 'attention_block_2_time_elapsed',
 'attention_block_2_num_effective',
 'attention_block_2_stimulus_variant',
 'attention_block_2_required_description',
 'attention_block_3_passed',
 'attention_block_3_round',
 'attention_block_3_time_elapsed',
 'attention_block_3_num_effective',
 'attention_block_3_stimulus_variant',
 'attention_block_3_required_description',
 '

In [49]:
df_fitted_anon.head(3)

Unnamed: 0,participant_id,study,experiment_version,completion_status,terminated_early,termination_reason,duration_minutes,total_time_elapsed_ms,block_1_scenario,block_2_scenario,...,persm_inf_F_alpha,persm_persp_T_ll,persm_persp_T_alpha,persm_persp_F_ll,persm_persp_F_alpha,persm_persm_T_ll,persm_persm_T_alpha,persm_persm_F_ll,persm_persm_F_alpha,_merge
0,P002,pilot,1.0.0,completed,False,,11.804617,708246,pers_minus,pers_plus,...,0.001005,-11.850625,0.001005,-11.852123,0.001005,0.0,determ,0.0,determ,both
1,P003,pilot,1.0.0,completed,False,,7.86035,471616,informative,pers_minus,...,0.001005,-11.850625,0.001005,-11.852123,0.001005,0.0,determ,0.0,determ,both
2,P004,pilot,1.0.0,completed,False,,12.351267,741064,pers_minus,informative,...,0.001005,-11.850625,0.001005,-11.852123,0.001005,0.0,determ,0.0,determ,both


## Cell 8: View Wide Format Results

In [50]:
# View first few rows (wide format)
print("Wide format (first 3 rows, selected columns):")
display_cols = ['participant_id']
for cond in GOAL_CONDITIONS[:1]:  # Just show first condition
    display_cols.extend([f'{cond}_n_trials', f'{cond}_is_complete'])
    display_cols.extend([f'{cond}_{m}_ll' for m in MODEL_NAMES])
df_fitted_anon[display_cols].head(3)

Wide format (first 3 rows, selected columns):


Unnamed: 0,participant_id,inf_n_trials,inf_is_complete,inf_literal_ll,inf_inf_T_ll,inf_inf_F_ll,inf_persp_T_ll,inf_persp_F_ll,inf_persm_T_ll,inf_persm_F_ll
0,P002,10,True,-11.273805,-11.2755,-10.550514,-11.275738,-11.275794,-8.567772,-9.346416
1,P003,10,True,-11.273805,-11.276538,-4.527161,-11.276544,-11.277816,-0.693147,-0.693147
2,P004,10,True,-11.561487,-11.311724,-1.386294,-9.209576,-8.041613,-11.563691,-11.565156


In [51]:
# View all LL columns for one participant
print("All log-likelihoods for first participant:")
ll_cols = ['participant_id'] + [c for c in df_fitted.columns if '_ll' in c]
df_fitted[ll_cols].head(1).T

All log-likelihoods for first participant:


Unnamed: 0,0
participant_id,P002
inf_literal_ll,-11.273805
inf_inf_T_ll,-11.2755
inf_inf_F_ll,-10.550514
inf_persp_T_ll,-11.275738
inf_persp_F_ll,-11.275794
inf_persm_T_ll,-8.567772
inf_persm_F_ll,-9.346416
persp_literal_ll,-11.849169
persp_inf_T_ll,-11.8496


## Cell 13: Save Results

In [55]:
# === EDIT THESE PATHS ===
OUTPUT_WIDE = './raw_do_not_track/speaker_n1_fitted.csv'

# Save wide format
df_fitted.to_csv(OUTPUT_WIDE, index=False)
print(f"Saved fitted data ({len(df_fitted)} rows, {len(df_fitted.columns)} columns) to: {OUTPUT_WIDE}")

Saved fitted data (109 rows, 325 columns) to: ./raw_do_not_track/speaker_n1_fitted.csv


In [54]:
# === EDIT THESE PATHS ===
OUTPUT_WIDE = './speaker_n1_fitted_anonymized.csv'

# Save wide format
df_fitted_anon.to_csv(OUTPUT_WIDE, index=False)
print(f"Saved anonymized fitted data ({len(df_fitted_anon)} rows, {len(df_fitted_anon.columns)} columns) to: {OUTPUT_WIDE}")

Saved anonymized fitted data (109 rows, 318 columns) to: ./speaker_n1_fitted_anonymized.csv


# Cell 14: Generate Code Book

In [None]:
def profile_dataframe(df: pd.DataFrame, key_cols=None, max_levels=25) -> pd.DataFrame:
    """
    Build a "data dictionary" style profile of df:
    - dtype, missingness, unique counts
    - numeric stats (min/max/mean)
    - sample levels for low-cardinality columns
    """
    if key_cols is None:
        key_cols = []

    rows = []
    n = len(df)

    for col in df.columns:
        s = df[col]
        dtype = str(s.dtype)

        n_missing = int(s.isna().sum())
        pct_missing = (n_missing / n) if n else np.nan
        n_unique = int(s.nunique(dropna=True))

        info = {
            "column": col,
            "dtype": dtype,
            "n_rows": n,
            "n_missing": n_missing,
            "pct_missing": pct_missing,
            "n_unique": n_unique,
            "is_key": col in set(key_cols),
        }

        # Numeric summary
        if pd.api.types.is_numeric_dtype(s):
            info.update({
                "min": float(np.nanmin(s)) if n_missing < n else np.nan,
                "max": float(np.nanmax(s)) if n_missing < n else np.nan,
                "mean": float(np.nanmean(s)) if n_missing < n else np.nan,
                "example_values": ""
            })

        # Datetime-like (try parse)
        elif pd.api.types.is_datetime64_any_dtype(s):
            info.update({
                "min": str(s.min()) if n_missing < n else "",
                "max": str(s.max()) if n_missing < n else "",
                "mean": "",
                "example_values": ""
            })

        # Object / categorical: show levels if small
        else:
            info.update({"min": "", "max": "", "mean": ""})
            if n_unique <= max_levels:
                levels = s.dropna().astype(str).unique().tolist()
                info["example_values"] = "; ".join(levels[:max_levels])
            else:
                # show a few examples
                examples = s.dropna().astype(str).head(5).tolist()
                info["example_values"] = "; ".join(examples)

        rows.append(info)

    profile = pd.DataFrame(rows).sort_values(
        by=["is_key", "pct_missing", "n_unique"],
        ascending=[False, False, True]
    )

    # Key checks (if provided)
    for k in key_cols:
        if k in df.columns:
            dup_count = int(df[k].duplicated().sum())
            print(f"[KEY CHECK] {k}: unique={df[k].is_unique}, duplicates={dup_count}, missing={df[k].isna().sum()}")
        else:
            print(f"[KEY CHECK] {k}: column not found in df")

    return profile


# Example usage:
profile = profile_dataframe(df_fitted_anon, key_cols=["participant_id"])
print(profile.head(30))
profile.to_csv("data_dictionary.csv", index=False)

[KEY CHECK] participant_id: unique=True, duplicates=0, missing=0
                                 column    dtype  n_rows  n_missing  \
0                        participant_id   object     109          0   
5                    termination_reason   object     109        109   
65                     inf_r1_rt_approx  float64     109        109   
136                  persp_r1_rt_approx  float64     109        109   
207                  persm_r1_rt_approx  float64     109        109   
274                   inf_literal_alpha   object     109        109   
289                 persp_literal_alpha   object     109        109   
304                 persm_literal_alpha   object     109        109   
17   attention_block_1_stimulus_variant  float64     109        100   
28      attention_block_3_num_effective  float64     109        100   
16      attention_block_1_num_effective  float64     109        100   
23   attention_block_2_stimulus_variant  float64     109        100   
22      atte