  Returns P(theta) \propto exp( sum_{psi, alpha} log_joint(theta,psi,alpha) ).
  Returns P(psi) \propto exp( sum_{theta, alpha} log_joint(theta,psi,alpha) ).
  Returns P(alpha) \propto exp( sum_{theta, psi} log_joint(theta,psi,alpha)).


In [None]:
import numpy as np
import pandas as pd
import itertools
from scipy.special import logsumexp
import copy
from typing import Dict, Tuple, List, Optional


# =============================================================================
# Step 1: Set up the world and basic quantities
# =============================================================================

print("="*70)
print("STEP 1: Setting up the world")
print("="*70)

world = World(n=1, m=5)

utterances = world.utterances
theta_values = world.theta_values
n_utterances = len(utterances)
n_theta = len(theta_values)
n_rounds = 5
n_sequences = n_utterances ** n_rounds

# Flat prior over theta
log_prior_theta = np.full(n_theta, -np.log(n_theta))

# P(O|θ) matrix for marginalizing
log_P_O_given_theta = world.obs_log_likelihood_theta.values  # (n_obs, n_theta)

# All sequences as tuples of indices
all_sequences = list(itertools.product(range(n_utterances), repeat=n_rounds))
sequence_labels = [tuple(utterances[i] for i in seq) for seq in all_sequences]

# Alpha values
alpha_values = [1.000, 1.275, 1.626, 2.073, 2.643, 3.371, 4.299, 5.482, 
                6.988, 8.909, 11.36, 14.48, 18.46, 23.54, 30.02, 38.28, 
                48.80, 62.23, 79.32, 100.0]
n_alpha = len(alpha_values)

print(f"Utterances ({n_utterances}): {utterances}")
print(f"Theta values ({n_theta}): {list(theta_values)}")
print(f"Rounds: {n_rounds}")
print(f"Total sequences: {n_sequences}")
print(f"Alpha values: {n_alpha} values from {alpha_values[0]} to {alpha_values[-1]}")

# =============================================================================
# Step 2: Define helper functions
# =============================================================================

print("\n" + "="*70)
print("STEP 2: Defining helper functions")
print("="*70)

def get_log_P_u_given_theta_from_speaker(speaker, log_P_O_given_theta):
    """
    Compute log P(u|θ) = log Σ_O P(u|O) · P(O|θ) for a speaker.
    Returns shape (n_utterances, n_theta).
    """
    log_P_u_given_O = speaker.utterance_log_prob_obs.values  # (n_utt, n_obs)
    log_P_u_given_theta = log_M_product(
        log_P_u_given_O,
        log_P_O_given_theta,
        precise=USE_PRECISE_LOGSPACE
    )
    return log_P_u_given_theta


def compute_log_P_seq_given_theta_independent(log_P_u_given_theta, all_sequences, n_rounds):
    """
    Compute log P(seq | θ) when utterances are conditionally independent given θ.
    Used for: Literal speaker, and Pragmatic speakers with update_internal=False.
    
    log P(seq | θ) = Σᵢ log P(uᵢ | θ)
    
    Returns shape (n_sequences, n_theta).
    """
    n_sequences = len(all_sequences)
    n_theta = log_P_u_given_theta.shape[1]
    
    log_P_seq_given_theta = np.zeros((n_sequences, n_theta))
    
    for seq_idx, seq in enumerate(all_sequences):
        for round_idx in range(n_rounds):
            u_idx = seq[round_idx]
            log_P_seq_given_theta[seq_idx, :] += log_P_u_given_theta[u_idx, :]
    
    return log_P_seq_given_theta


def marginalize_over_theta(log_P_seq_given_theta, log_prior_theta):
    """
    Compute log P(seq) = log Σ_θ P(seq | θ) · P(θ).
    Returns shape (n_sequences,).
    """
    log_P_seq_and_theta = log_P_seq_given_theta + log_prior_theta  # broadcast
    log_P_seq = logsumexp(log_P_seq_and_theta, axis=1)
    return log_P_seq

print("Helper functions defined.")

# =============================================================================
# Step 3: Compute for LITERAL SPEAKER (no alpha dependence)
# =============================================================================

print("\n" + "="*70)
print("STEP 3: Computing for LITERAL SPEAKER")
print("="*70)

literal_listener = LiteralListener(world)
literal_speaker = literal_listener.literal_speaker

# Get P(u|θ) - this is already computed in literal_listener
log_P_u_given_theta_literal = literal_listener.utterance_log_likelihood_theta.values

# Compute P(seq|θ) - utterances are conditionally independent
log_P_seq_given_theta_literal = compute_log_P_seq_given_theta_independent(
    log_P_u_given_theta_literal, all_sequences, n_rounds
)

# Marginalize over θ
log_P_seq_literal = marginalize_over_theta(log_P_seq_given_theta_literal, log_prior_theta)

print(f"Sum P(seq) for literal: {np.exp(logsumexp(log_P_seq_literal)):.10f}")

# =============================================================================
# Step 4: Compute for PRAGMATIC SPEAKERS with update_internal=FALSE
# =============================================================================

print("\n" + "="*70)
print("STEP 4: Computing for PRAGMATIC SPEAKERS (update_internal=False)")
print("="*70)

# These are conditionally independent given θ, so we use the simple approach.
# We need to compute for each (psi, alpha) combination.

# Storage: dict[psi] -> array of shape (n_sequences, n_alpha)
log_P_seq_pragmatic_F = {
    "inf": np.zeros((n_sequences, n_alpha)),
    "pers+": np.zeros((n_sequences, n_alpha)),
    "pers-": np.zeros((n_sequences, n_alpha))
}

for psi in ["inf", "pers+", "pers-"]:
    omega = "coop" if psi == "inf" else "strat"
    print(f"\n  Processing psi='{psi}' (omega='{omega}')...")
    
    for alpha_idx, alpha in enumerate(alpha_values):
        # Create pragmatic speaker with update_internal=False
        speaker = PragmaticSpeaker_obs(
            world=world,
            omega=omega,
            psi=psi,
            update_internal=False,
            alpha=alpha,
            beta=0.0
        )
        
        # Get P(u|θ)
        log_P_u_given_theta = get_log_P_u_given_theta_from_speaker(speaker, log_P_O_given_theta)
        
        # Compute P(seq|θ) - conditionally independent
        log_P_seq_given_theta = compute_log_P_seq_given_theta_independent(
            log_P_u_given_theta, all_sequences, n_rounds
        )
        
        # Marginalize over θ
        log_P_seq = marginalize_over_theta(log_P_seq_given_theta, log_prior_theta)
        
        log_P_seq_pragmatic_F[psi][:, alpha_idx] = log_P_seq
        
        if alpha_idx == 0 or alpha_idx == n_alpha - 1:
            print(f"    alpha={alpha:.2f}: Sum P(seq) = {np.exp(logsumexp(log_P_seq)):.10f}")

print("\nCompleted update_internal=False speakers.")

# =============================================================================
# Step 5: Compute for PRAGMATIC SPEAKERS with update_internal=TRUE
# =============================================================================

print("\n" + "="*70)
print("STEP 5: Computing for PRAGMATIC SPEAKERS (update_internal=True)")
print("="*70)

# This requires computing P(u|θ) for each history prefix.
# Total histories: 1 + 8 + 64 + 512 + 4096 = 4681

def create_speaker_with_history(world, omega, psi, alpha, history_utterances):
    """
    Create a pragmatic speaker and apply utterance history to its internal listener.
    """
    speaker = PragmaticSpeaker_obs(
        world=world,
        omega=omega,
        psi=psi,
        update_internal=True,
        alpha=alpha,
        beta=0.0
    )
    
    # Apply each utterance in the history
    for u in history_utterances:
        speaker.literal_listener.listen_and_update(u)
        speaker.utterance_log_prob_obs = speaker._compute_utterance_log_prob_obs(alpha)
    
    return speaker


def compute_log_P_seq_given_theta_history_dependent(
    world, omega, psi, alpha, all_sequences, n_rounds, log_P_O_given_theta, utterances
):
    """
    Compute log P(seq | θ) when utterances depend on history (update_internal=True).
    
    log P(seq | θ) = Σᵢ log P^{h_i}(uᵢ | θ)
    
    where h_i = (u_0, ..., u_{i-1}) is the history before round i.
    
    Returns shape (n_sequences, n_theta).
    """
    n_sequences = len(all_sequences)
    n_theta = len(world.theta_values)
    n_utterances = len(utterances)
    
    # Precompute P(u|θ) for all histories
    # Key: history tuple (indices), Value: log P(u|θ) matrix of shape (n_utt, n_theta)
    history_log_P_u_given_theta = {}
    
    for history_length in range(n_rounds):
        if history_length == 0:
            # Empty history
            speaker = create_speaker_with_history(world, omega, psi, alpha, [])
            history_log_P_u_given_theta[()] = get_log_P_u_given_theta_from_speaker(
                speaker, log_P_O_given_theta
            )
        else:
            # All histories of this length
            for history in itertools.product(range(n_utterances), repeat=history_length):
                history_utterances = [utterances[idx] for idx in history]
                speaker = create_speaker_with_history(world, omega, psi, alpha, history_utterances)
                history_log_P_u_given_theta[history] = get_log_P_u_given_theta_from_speaker(
                    speaker, log_P_O_given_theta
                )
    
    # Now compute P(seq|θ) for each sequence
    log_P_seq_given_theta = np.zeros((n_sequences, n_theta))
    
    for seq_idx, seq in enumerate(all_sequences):
        log_prob = np.zeros(n_theta)
        for round_idx in range(n_rounds):
            history = seq[:round_idx]
            u_idx = seq[round_idx]
            log_prob += history_log_P_u_given_theta[history][u_idx, :]
        log_P_seq_given_theta[seq_idx, :] = log_prob
    
    return log_P_seq_given_theta


# Storage: dict[psi] -> array of shape (n_sequences, n_alpha)
log_P_seq_pragmatic_T = {
    "inf": np.zeros((n_sequences, n_alpha)),
    "pers+": np.zeros((n_sequences, n_alpha)),
    "pers-": np.zeros((n_sequences, n_alpha))
}

for psi in ["inf", "pers+", "pers-"]:
    omega = "coop" if psi == "inf" else "strat"
    print(f"\n  Processing psi='{psi}' (omega='{omega}')...")
    
    for alpha_idx, alpha in enumerate(alpha_values):
        # Compute P(seq|θ) with history dependence
        log_P_seq_given_theta = compute_log_P_seq_given_theta_history_dependent(
            world, omega, psi, alpha, all_sequences, n_rounds, 
            log_P_O_given_theta, utterances
        )
        
        # Marginalize over θ
        log_P_seq = marginalize_over_theta(log_P_seq_given_theta, log_prior_theta)
        
        log_P_seq_pragmatic_T[psi][:, alpha_idx] = log_P_seq
        
        if alpha_idx % 5 == 0 or alpha_idx == n_alpha - 1:
            print(f"    alpha={alpha:.2f}: Sum P(seq) = {np.exp(logsumexp(log_P_seq)):.10f}")

print("\nCompleted update_internal=True speakers.")

# =============================================================================
# Step 6: Organize all results into a single DataFrame
# =============================================================================

print("\n" + "="*70)
print("STEP 6: Organizing results into DataFrame")
print("="*70)

# We'll create a DataFrame with columns:
# - sequence (tuple of utterances)
# - sequence_idx (tuple of indices)
# - literal (single column, no alpha dependence)
# - pragmatic_inf_F_alpha=X (one column per alpha)
# - pragmatic_persp_F_alpha=X
# - pragmatic_persm_F_alpha=X
# - pragmatic_inf_T_alpha=X
# - pragmatic_persp_T_alpha=X
# - pragmatic_persm_T_alpha=X

results = {
    'sequence': sequence_labels,
    'sequence_idx': all_sequences,
    'literal': np.exp(log_P_seq_literal)
}

# Add pragmatic F columns
for psi, short_name in [("inf", "inf"), ("pers+", "persp"), ("pers-", "persm")]:
    for alpha_idx, alpha in enumerate(alpha_values):
        col_name = f"pragmatic_{short_name}_F_alpha={alpha:.2f}"
        results[col_name] = np.exp(log_P_seq_pragmatic_F[psi][:, alpha_idx])

# Add pragmatic T columns
for psi, short_name in [("inf", "inf"), ("pers+", "persp"), ("pers-", "persm")]:
    for alpha_idx, alpha in enumerate(alpha_values):
        col_name = f"pragmatic_{short_name}_T_alpha={alpha:.2f}"
        results[col_name] = np.exp(log_P_seq_pragmatic_T[psi][:, alpha_idx])

results_df = pd.DataFrame(results)

print(f"DataFrame shape: {results_df.shape}")
print(f"Columns: {len(results_df.columns)}")
# Expected: 2 (sequence info) + 1 (literal) + 3*20 (F) + 3*20 (T) = 123 columns

# =============================================================================
# Step 7: Verify all probabilities sum to 1
# =============================================================================

print("\n" + "="*70)
print("STEP 7: Verification - All probability columns should sum to 1")
print("="*70)

prob_cols = [c for c in results_df.columns if c not in ['sequence', 'sequence_idx']]
sums = results_df[prob_cols].sum()

print("\nColumn sums (should all be 1.0):")
print(f"  literal: {sums['literal']:.10f}")

for speaker_type in ['inf_F', 'persp_F', 'persm_F', 'inf_T', 'persp_T', 'persm_T']:
    cols = [c for c in prob_cols if speaker_type in c]
    col_sums = sums[cols]
    print(f"  pragmatic_{speaker_type}: min={col_sums.min():.10f}, max={col_sums.max():.10f}")

# =============================================================================
# Step 8: Summary statistics - entropy and concentration
# =============================================================================

print("\n" + "="*70)
print("STEP 8: Summary statistics - Entropy and concentration")
print("="*70)

def compute_entropy(P):
    """Compute entropy -Σ P log P."""
    P_safe = P[P > 0]
    return -np.sum(P_safe * np.log(P_safe))

max_entropy = np.log(n_sequences)
print(f"Max possible entropy (uniform): {max_entropy:.4f}")

# Compute entropy for each speaker type and alpha
summary_rows = []

# Literal
P_literal = np.exp(log_P_seq_literal)
summary_rows.append({
    'speaker_type': 'literal',
    'update_internal': '-',
    'psi': '-',
    'alpha': '-',
    'entropy': compute_entropy(P_literal),
    'top1_prob': np.max(P_literal),
    'top10_prob': np.sort(P_literal)[-10:].sum()
})

# Pragmatic
for update_internal, log_P_dict, label in [
    (False, log_P_seq_pragmatic_F, 'F'),
    (True, log_P_seq_pragmatic_T, 'T')
]:
    for psi in ["inf", "pers+", "pers-"]:
        for alpha_idx, alpha in enumerate(alpha_values):
            P = np.exp(log_P_dict[psi][:, alpha_idx])
            summary_rows.append({
                'speaker_type': f'pragmatic_{label}',
                'update_internal': update_internal,
                'psi': psi,
                'alpha': alpha,
                'entropy': compute_entropy(P),
                'top1_prob': np.max(P),
                'top10_prob': np.sort(P)[-10:].sum()
            })

summary_df = pd.DataFrame(summary_rows)

print("\nLiteral speaker:")
print(summary_df[summary_df['speaker_type'] == 'literal'].to_string(index=False))

print("\nPragmatic speakers - entropy by alpha (selected alphas):")
selected_alphas = [1.0, 2.073, 5.482, 14.48, 38.28, 100.0]
for psi in ["inf", "pers+", "pers-"]:
    print(f"\n  psi='{psi}':")
    subset = summary_df[
        (summary_df['psi'] == psi) & 
        (summary_df['alpha'].isin(selected_alphas))
    ][['update_internal', 'alpha', 'entropy', 'top1_prob', 'top10_prob']]
    print(subset.to_string(index=False))

# =============================================================================
# Step 9: Compare top sequences across speaker types
# =============================================================================

print("\n" + "="*70)
print("STEP 9: Top 10 sequences for each speaker type (alpha=5.48)")
print("="*70)

alpha_demo = 5.482
alpha_str = f"{alpha_demo:.2f}"

print("\n--- LITERAL ---")
top_literal = results_df.nlargest(10, 'literal')[['sequence', 'literal']]
print(top_literal.to_string(index=False))

for psi, short in [("inf", "inf"), ("pers+", "persp"), ("pers-", "persm")]:
    print(f"\n--- PRAGMATIC psi='{psi}' update_internal=False ---")
    col = f"pragmatic_{short}_F_alpha={alpha_str}"
    print(results_df.nlargest(10, col)[['sequence', col]].to_string(index=False))
    
    print(f"\n--- PRAGMATIC psi='{psi}' update_internal=True ---")
    col = f"pragmatic_{short}_T_alpha={alpha_str}"
    print(results_df.nlargest(10, col)[['sequence', col]].to_string(index=False))

# =============================================================================
# Step 10: Save results
# =============================================================================

print("\n" + "="*70)
print("STEP 10: Saving results")
print("="*70)

STEP 1: Setting up the world
Utterances (8): ['all,successful', 'all,unsuccessful', 'most,successful', 'most,unsuccessful', 'some,successful', 'some,unsuccessful', 'no,successful', 'no,unsuccessful']
Theta values (11): [np.float64(0.0), np.float64(0.1), np.float64(0.2), np.float64(0.3), np.float64(0.4), np.float64(0.5), np.float64(0.6), np.float64(0.7), np.float64(0.8), np.float64(0.9), np.float64(1.0)]
Rounds: 5
Total sequences: 32768
Alpha values: 20 values from 1.0 to 100.0

STEP 2: Defining helper functions
Helper functions defined.

STEP 3: Computing for LITERAL SPEAKER
Sum P(seq) for literal: 1.0000000000

STEP 4: Computing for PRAGMATIC SPEAKERS (update_internal=False)

  Processing psi='inf' (omega='coop')...
    alpha=1.00: Sum P(seq) = 1.0000000000
    alpha=100.00: Sum P(seq) = 1.0000000000

  Processing psi='pers+' (omega='strat')...
    alpha=1.00: Sum P(seq) = 1.0000000000
    alpha=100.00: Sum P(seq) = 1.0000000000

  Processing psi='pers-' (omega='strat')...
    alpha=1

OSError: Cannot save file into a non-existent directory: '/home/claude'

In [5]:
# Save main results
results_df.to_csv('./sequence_probabilities.csv', index=False)
print("Saved: ./sequence_probabilities.csv")

# Save summary
summary_df.to_csv('./sequence_summary.csv', index=False)
print("Saved: ./sequence_summary.csv")

# Also save in a more compact format (log probabilities)
log_results = {
    'sequence': sequence_labels,
    'sequence_idx': all_sequences,
    'log_literal': log_P_seq_literal
}

for psi, short_name in [("inf", "inf"), ("pers+", "persp"), ("pers-", "persm")]:
    for alpha_idx, alpha in enumerate(alpha_values):
        log_results[f"log_pragmatic_{short_name}_F_alpha={alpha:.2f}"] = log_P_seq_pragmatic_F[psi][:, alpha_idx]
        log_results[f"log_pragmatic_{short_name}_T_alpha={alpha:.2f}"] = log_P_seq_pragmatic_T[psi][:, alpha_idx]

log_results_df = pd.DataFrame(log_results)
log_results_df.to_csv('./sequence_log_probabilities.csv', index=False)
print("Saved: ./sequence_log_probabilities.csv")

print("\n" + "="*70)
print("COMPLETE!")
print("="*70)

Saved: ./sequence_probabilities.csv
Saved: ./sequence_summary.csv
Saved: ./sequence_log_probabilities.csv

COMPLETE!


In [7]:
import numpy as np
import pandas as pd
from scipy.stats import spearmanr, kendalltau
import itertools

# =============================================================================
# Step 1: Assuming we have results_df from earlier computation
# Let's first check its structure
# =============================================================================

print("="*70)
print("STEP 1: Checking results_df structure")
print("="*70)

# List all columns
print(f"Total columns: {len(results_df.columns)}")
print(f"\nColumn names sample:")
print([c for c in results_df.columns if 'inf_T' in c][:5])

# Extract alpha values from column names
alpha_values = [1.000, 1.275, 1.626, 2.073, 2.643, 3.371, 4.299, 5.482, 
                6.988, 8.909, 11.36, 14.48, 18.46, 23.54, 30.02, 38.28, 
                48.80, 62.23, 79.32, 100.0]

# =============================================================================
# Step 2: Define function to analyze rank stability
# =============================================================================

print("\n" + "="*70)
print("STEP 2: Analyzing rank stability across alpha values")
print("="*70)

def analyze_rank_stability(results_df, speaker_type, alpha_values):
    """
    Analyze how stable the ranking of sequences is across different alpha values.
    
    Parameters
    ----------
    results_df : pd.DataFrame
        DataFrame with sequence probabilities
    speaker_type : str
        e.g., 'inf_T', 'inf_F', 'persp_T', etc.
    alpha_values : list
        List of alpha values
    
    Returns
    -------
    dict with stability metrics
    """
    # Get columns for this speaker type
    cols = [f"pragmatic_{speaker_type}_alpha={a:.2f}" for a in alpha_values]
    
    # Check all columns exist
    missing = [c for c in cols if c not in results_df.columns]
    if missing:
        print(f"Warning: Missing columns: {missing[:3]}...")
        # Try alternative format
        cols = [f"pragmatic_{speaker_type}_alpha={a}" for a in alpha_values]
    
    # Extract probability matrix: (n_sequences, n_alphas)
    prob_matrix = results_df[cols].values
    
    # Compute ranks for each alpha (rank 1 = highest probability)
    rank_matrix = np.zeros_like(prob_matrix)
    for j in range(len(alpha_values)):
        rank_matrix[:, j] = (-prob_matrix[:, j]).argsort().argsort() + 1
    
    # Spearman correlation between all pairs of alphas
    n_alphas = len(alpha_values)
    spearman_corr = np.zeros((n_alphas, n_alphas))
    for i in range(n_alphas):
        for j in range(n_alphas):
            spearman_corr[i, j], _ = spearmanr(rank_matrix[:, i], rank_matrix[:, j])
    
    # Track the top-k sequences for each alpha
    top_k_values = [1, 5, 10, 50, 100]
    top_k_stability = {}
    
    for k in top_k_values:
        # For each alpha, get the top-k sequence indices
        top_k_sets = []
        for j in range(n_alphas):
            top_k_idx = np.argsort(-prob_matrix[:, j])[:k]
            top_k_sets.append(set(top_k_idx))
        
        # Compute Jaccard similarity between consecutive alphas
        jaccard_consecutive = []
        for j in range(n_alphas - 1):
            intersection = len(top_k_sets[j] & top_k_sets[j+1])
            union = len(top_k_sets[j] | top_k_sets[j+1])
            jaccard_consecutive.append(intersection / union if union > 0 else 0)
        
        # Compute Jaccard between first and last alpha
        jaccard_first_last = len(top_k_sets[0] & top_k_sets[-1]) / len(top_k_sets[0] | top_k_sets[-1])
        
        # How many sequences are in top-k for ALL alphas?
        common_all = set.intersection(*top_k_sets)
        
        top_k_stability[k] = {
            'jaccard_consecutive_mean': np.mean(jaccard_consecutive),
            'jaccard_consecutive_min': np.min(jaccard_consecutive),
            'jaccard_first_last': jaccard_first_last,
            'common_all_count': len(common_all),
            'common_all_sequences': common_all
        }
    
    # Track rank of the #1 sequence (at alpha_min) across all alphas
    top1_at_alpha_min = np.argmax(prob_matrix[:, 0])
    top1_ranks_across_alpha = rank_matrix[top1_at_alpha_min, :]
    
    # Track rank of the #1 sequence (at alpha_max) across all alphas
    top1_at_alpha_max = np.argmax(prob_matrix[:, -1])
    top1_max_ranks_across_alpha = rank_matrix[top1_at_alpha_max, :]
    
    return {
        'spearman_corr': spearman_corr,
        'top_k_stability': top_k_stability,
        'top1_at_alpha_min_idx': top1_at_alpha_min,
        'top1_at_alpha_min_ranks': top1_ranks_across_alpha,
        'top1_at_alpha_max_idx': top1_at_alpha_max,
        'top1_at_alpha_max_ranks': top1_max_ranks_across_alpha,
        'prob_matrix': prob_matrix,
        'rank_matrix': rank_matrix
    }

# =============================================================================
# Step 3: Analyze each speaker type
# =============================================================================

print("\n" + "="*70)
print("STEP 3: Results for each speaker type")
print("="*70)

speaker_types = ['inf_T', 'inf_F', 'persp_T', 'persp_F', 'persm_T', 'persm_F']

all_results = {}

for speaker_type in speaker_types:
    print(f"\n{'='*60}")
    print(f"Speaker type: {speaker_type}")
    print(f"{'='*60}")
    
    results = analyze_rank_stability(results_df, speaker_type, alpha_values)
    all_results[speaker_type] = results
    
    # Print Spearman correlation summary
    spearman = results['spearman_corr']
    print(f"\nSpearman correlation matrix (selected alphas):")
    selected_idx = [0, 4, 9, 14, 19]  # alpha = 1, 2.64, 8.91, 30.0, 100
    selected_alphas = [alpha_values[i] for i in selected_idx]
    print(f"Alphas: {selected_alphas}")
    for i in selected_idx:
        row = [f"{spearman[i, j]:.3f}" for j in selected_idx]
        print(f"  α={alpha_values[i]:>6.2f}: {row}")
    
    # Print top-k stability
    print(f"\nTop-k stability:")
    print(f"  {'k':<6} {'Jaccard(consec)':<18} {'Jaccard(1st-last)':<18} {'Common to ALL':<15}")
    for k in [1, 5, 10, 50, 100]:
        stats = results['top_k_stability'][k]
        print(f"  {k:<6} {stats['jaccard_consecutive_mean']:.3f} (min:{stats['jaccard_consecutive_min']:.3f})   "
              f"{stats['jaccard_first_last']:.3f}              {stats['common_all_count']:<15}")
    
    # Print rank trajectory of top sequence
    print(f"\nRank trajectory of top-1 sequence at α=1.0:")
    print(f"  Sequence index: {results['top1_at_alpha_min_idx']}")
    print(f"  Sequence: {results_df.loc[results['top1_at_alpha_min_idx'], 'sequence']}")
    print(f"  Ranks across α: ", end="")
    for i in [0, 4, 9, 14, 19]:
        print(f"α={alpha_values[i]:.1f}→#{int(results['top1_at_alpha_min_ranks'][i])}, ", end="")
    print()
    
    print(f"\nRank trajectory of top-1 sequence at α=100:")
    print(f"  Sequence index: {results['top1_at_alpha_max_idx']}")
    print(f"  Sequence: {results_df.loc[results['top1_at_alpha_max_idx'], 'sequence']}")
    print(f"  Ranks across α: ", end="")
    for i in [0, 4, 9, 14, 19]:
        print(f"α={alpha_values[i]:.1f}→#{int(results['top1_at_alpha_max_ranks'][i])}, ", end="")
    print()

# =============================================================================
# Step 4: Visualize rank stability with a summary table
# =============================================================================

print("\n" + "="*70)
print("STEP 4: Summary comparison across speaker types")
print("="*70)

summary_rows = []
for speaker_type in speaker_types:
    results = all_results[speaker_type]
    spearman = results['spearman_corr']
    
    summary_rows.append({
        'speaker_type': speaker_type,
        'spearman_1_vs_100': spearman[0, -1],
        'spearman_1_vs_10': spearman[0, 9],  # alpha ~8.9
        'spearman_10_vs_100': spearman[9, -1],
        'top1_stable': results['top1_at_alpha_min_idx'] == results['top1_at_alpha_max_idx'],
        'top5_common_all': results['top_k_stability'][5]['common_all_count'],
        'top10_common_all': results['top_k_stability'][10]['common_all_count'],
        'top50_common_all': results['top_k_stability'][50]['common_all_count'],
        'top100_jaccard_1_100': results['top_k_stability'][100]['jaccard_first_last'],
    })

summary_df = pd.DataFrame(summary_rows)
print("\nRank stability summary:")
print(summary_df.to_string(index=False))

# =============================================================================
# Step 5: Check if same sequence is top-1 across all alphas
# =============================================================================

print("\n" + "="*70)
print("STEP 5: Which sequences are consistently top-ranked?")
print("="*70)

for speaker_type in speaker_types:
    results = all_results[speaker_type]
    prob_matrix = results['prob_matrix']
    
    # Find sequence that is #1 most often
    top1_counts = np.zeros(prob_matrix.shape[0])
    for j in range(len(alpha_values)):
        top1_idx = np.argmax(prob_matrix[:, j])
        top1_counts[top1_idx] += 1
    
    # Get sequences that are ever #1
    ever_top1 = np.where(top1_counts > 0)[0]
    
    print(f"\n{speaker_type}:")
    print(f"  Number of sequences that are #1 for at least one alpha: {len(ever_top1)}")
    
    if len(ever_top1) <= 5:
        print(f"  These sequences:")
        for idx in ever_top1:
            seq = results_df.loc[idx, 'sequence']
            count = int(top1_counts[idx])
            print(f"    {seq}: #1 for {count}/{len(alpha_values)} alphas")
    else:
        print(f"  Top sequences by #1 count:")
        for idx in np.argsort(-top1_counts)[:5]:
            if top1_counts[idx] > 0:
                seq = results_df.loc[idx, 'sequence']
                count = int(top1_counts[idx])
                print(f"    {seq}: #1 for {count}/{len(alpha_values)} alphas")

# =============================================================================
# Step 6: Detailed look at rank changes
# =============================================================================

print("\n" + "="*70)
print("STEP 6: Detailed rank change analysis for inf_T")
print("="*70)

results = all_results['inf_T']
prob_matrix = results['prob_matrix']
rank_matrix = results['rank_matrix']

# Find sequences with biggest rank changes
rank_range = np.max(rank_matrix, axis=1) - np.min(rank_matrix, axis=1)
most_volatile = np.argsort(-rank_range)[:10]

print("\nMost volatile sequences (largest rank range across alphas):")
print(f"{'Sequence':<60} {'Min Rank':<10} {'Max Rank':<10} {'Range':<10}")
for idx in most_volatile:
    seq = results_df.loc[idx, 'sequence']
    min_rank = int(np.min(rank_matrix[idx, :]))
    max_rank = int(np.max(rank_matrix[idx, :]))
    rng = int(rank_range[idx])
    print(f"{str(seq):<60} {min_rank:<10} {max_rank:<10} {rng:<10}")

# Find most stable sequences (among top 100 at any alpha)
top100_any = set()
for j in range(len(alpha_values)):
    top100_any.update(np.argsort(-prob_matrix[:, j])[:100])

top100_list = list(top100_any)
rank_range_top100 = rank_range[top100_list]
most_stable_top100 = [top100_list[i] for i in np.argsort(rank_range_top100)[:10]]

print("\nMost stable sequences (among those in top-100 for any alpha):")
print(f"{'Sequence':<60} {'Min Rank':<10} {'Max Rank':<10} {'Range':<10}")
for idx in most_stable_top100:
    seq = results_df.loc[idx, 'sequence']
    min_rank = int(np.min(rank_matrix[idx, :]))
    max_rank = int(np.max(rank_matrix[idx, :]))
    rng = int(rank_range[idx])
    print(f"{str(seq):<60} {min_rank:<10} {max_rank:<10} {rng:<10}")

# =============================================================================
# Step 7: Compare rank stability between T and F versions
# =============================================================================

print("\n" + "="*70)
print("STEP 7: Comparing update_internal=T vs F stability")
print("="*70)

for psi in ['inf', 'persp', 'persm']:
    type_T = f'{psi}_T'
    type_F = f'{psi}_F'
    
    spearman_T = all_results[type_T]['spearman_corr'][0, -1]
    spearman_F = all_results[type_F]['spearman_corr'][0, -1]
    
    top10_T = all_results[type_T]['top_k_stability'][10]['jaccard_first_last']
    top10_F = all_results[type_F]['top_k_stability'][10]['jaccard_first_last']
    
    print(f"\n{psi}:")
    print(f"  Spearman(α=1 vs α=100): T={spearman_T:.3f}, F={spearman_F:.3f}")
    print(f"  Top-10 Jaccard(α=1 vs α=100): T={top10_T:.3f}, F={top10_F:.3f}")
    
    # Are the top sequences the same between T and F?
    top1_T = all_results[type_T]['top1_at_alpha_max_idx']
    top1_F = all_results[type_F]['top1_at_alpha_max_idx']
    seq_T = results_df.loc[top1_T, 'sequence']
    seq_F = results_df.loc[top1_F, 'sequence']
    print(f"  Top-1 at α=100: T={seq_T}, F={seq_F}")
    print(f"  Same? {seq_T == seq_F}")

STEP 1: Checking results_df structure
Total columns: 123

Column names sample:
['pragmatic_inf_T_alpha=1.00', 'pragmatic_inf_T_alpha=1.27', 'pragmatic_inf_T_alpha=1.63', 'pragmatic_inf_T_alpha=2.07', 'pragmatic_inf_T_alpha=2.64']

STEP 2: Analyzing rank stability across alpha values

STEP 3: Results for each speaker type

Speaker type: inf_T

Spearman correlation matrix (selected alphas):
Alphas: [1.0, 2.643, 8.909, 30.02, 100.0]
  α=  1.00: ['1.000', '0.969', '0.628', '-0.020', '-0.230']
  α=  2.64: ['0.969', '1.000', '0.778', '0.184', '-0.029']
  α=  8.91: ['0.628', '0.778', '1.000', '0.724', '0.553']
  α= 30.02: ['-0.020', '0.184', '0.724', '1.000', '0.969']
  α=100.00: ['-0.230', '-0.029', '0.553', '0.969', '1.000']

Top-k stability:
  k      Jaccard(consec)    Jaccard(1st-last)  Common to ALL  
  1      0.421 (min:0.000)   0.000              0              
  5      0.730 (min:0.000)   0.000              0              
  10     0.847 (min:0.000)   0.000              0            

In [8]:
# =============================================================================
# Analyze stability around α ≈ 8.9
# =============================================================================

print("="*70)
print("Analyzing stability around α ≈ 8.9")
print("="*70)

# Find indices for alphas in the range 5-15 (neighborhood of 8.9)
alpha_values = [1.000, 1.275, 1.626, 2.073, 2.643, 3.371, 4.299, 5.482, 
                6.988, 8.909, 11.36, 14.48, 18.46, 23.54, 30.02, 38.28, 
                48.80, 62.23, 79.32, 100.0]

# Indices for α ∈ [5, 15]: indices 7, 8, 9, 10, 11 → alphas 5.48, 6.99, 8.91, 11.36, 14.48
mid_range_indices = [7, 8, 9, 10, 11]
mid_range_alphas = [alpha_values[i] for i in mid_range_indices]
print(f"Mid-range alphas: {mid_range_alphas}")

for speaker_type in ['inf_T', 'inf_F', 'persp_T', 'persm_T']:
    print(f"\n--- {speaker_type} ---")
    
    # Get columns for this speaker type
    cols = [f"pragmatic_{speaker_type}_alpha={alpha_values[i]:.2f}" for i in mid_range_indices]
    prob_matrix = results_df[cols].values
    
    # Compute ranks
    rank_matrix = np.zeros_like(prob_matrix)
    for j in range(len(mid_range_indices)):
        rank_matrix[:, j] = (-prob_matrix[:, j]).argsort().argsort() + 1
    
    # Find sequences in top-k for ALL mid-range alphas
    for k in [1, 5, 10, 20, 50]:
        top_k_sets = []
        for j in range(len(mid_range_indices)):
            top_k_idx = set(np.argsort(-prob_matrix[:, j])[:k])
            top_k_sets.append(top_k_idx)
        
        common = set.intersection(*top_k_sets)
        
        print(f"  k={k}: {len(common)} sequences in top-{k} for ALL mid-range alphas")
        
        if k <= 10 and len(common) > 0:
            for idx in list(common)[:5]:  # Show up to 5
                seq = results_df.loc[idx, 'sequence']
                ranks = [int(rank_matrix[idx, j]) for j in range(len(mid_range_indices))]
                print(f"    {seq}")
                print(f"      Ranks: {dict(zip(mid_range_alphas, ranks))}")

# =============================================================================
# Find the MOST stable sequences in the mid-range
# =============================================================================

print("\n" + "="*70)
print("Most stable sequences in α ∈ [5.5, 14.5] for inf_T")
print("="*70)

cols = [f"pragmatic_inf_T_alpha={alpha_values[i]:.2f}" for i in mid_range_indices]
prob_matrix = results_df[cols].values

# Compute rank range (max - min) across mid-range alphas only
rank_matrix = np.zeros_like(prob_matrix)
for j in range(len(mid_range_indices)):
    rank_matrix[:, j] = (-prob_matrix[:, j]).argsort().argsort() + 1

rank_range_mid = np.max(rank_matrix, axis=1) - np.min(rank_matrix, axis=1)
min_rank_mid = np.min(rank_matrix, axis=1)

# Find sequences that are: (1) in top-100 for at least one mid-range alpha, (2) have small rank range
top100_any_mid = set()
for j in range(len(mid_range_indices)):
    top100_any_mid.update(np.argsort(-prob_matrix[:, j])[:100])

top100_list = list(top100_any_mid)
print(f"Sequences in top-100 for any mid-range alpha: {len(top100_list)}")

# Sort by rank range (most stable first)
stability_data = []
for idx in top100_list:
    stability_data.append({
        'idx': idx,
        'sequence': results_df.loc[idx, 'sequence'],
        'min_rank': int(np.min(rank_matrix[idx, :])),
        'max_rank': int(np.max(rank_matrix[idx, :])),
        'rank_range': int(rank_range_mid[idx]),
        'mean_rank': np.mean(rank_matrix[idx, :])
    })

stability_df = pd.DataFrame(stability_data)
stability_df = stability_df.sort_values(['rank_range', 'mean_rank'])

print("\nTop 20 most stable sequences (smallest rank range in mid-range alphas):")
print(f"{'Sequence':<70} {'Min':<6} {'Max':<6} {'Range':<6}")
for _, row in stability_df.head(20).iterrows():
    print(f"{str(row['sequence']):<70} {row['min_rank']:<6} {row['max_rank']:<6} {row['rank_range']:<6}")

Analyzing stability around α ≈ 8.9
Mid-range alphas: [5.482, 6.988, 8.909, 11.36, 14.48]

--- inf_T ---
  k=1: 0 sequences in top-1 for ALL mid-range alphas
  k=5: 4 sequences in top-5 for ALL mid-range alphas
    ('most,successful', 'most,unsuccessful', 'most,successful', 'most,successful', 'most,successful')
      Ranks: {5.482: 2, 6.988: 2, 8.909: 1, 11.36: 2, 14.48: 1}
    ('most,unsuccessful', 'most,successful', 'most,unsuccessful', 'most,unsuccessful', 'most,unsuccessful')
      Ranks: {5.482: 1, 6.988: 1, 8.909: 2, 11.36: 1, 14.48: 2}
    ('most,successful', 'most,unsuccessful', 'most,unsuccessful', 'most,unsuccessful', 'most,unsuccessful')
      Ranks: {5.482: 4, 6.988: 4, 8.909: 3, 11.36: 4, 14.48: 4}
    ('most,unsuccessful', 'most,successful', 'most,successful', 'most,successful', 'most,successful')
      Ranks: {5.482: 3, 6.988: 3, 8.909: 4, 11.36: 3, 14.48: 3}
  k=10: 4 sequences in top-10 for ALL mid-range alphas
    ('most,successful', 'most,unsuccessful', 'most,successf

In [None]:
import numpy as np
import pandas as pd
import itertools
from scipy.special import logsumexp

# =============================================================================
# Step 1: Setup
# =============================================================================

print("="*70)
print("STEP 1: Setup")
print("="*70)

world = World(n=1, m=5)

utterances = world.utterances
theta_values = world.theta_values
n_utterances = len(utterances)
n_theta = len(theta_values)
n_rounds = 5
n_sequences = n_utterances ** n_rounds

# P(O|θ) matrix - rows are observations, columns are theta values
log_P_O_given_theta = world.obs_log_likelihood_theta.values  # (n_obs, n_theta)

# All sequences as tuples of utterance indices
all_sequences = list(itertools.product(range(n_utterances), repeat=n_rounds))
sequence_labels = [tuple(utterances[i] for i in seq) for seq in all_sequences]

# Alpha values
alpha_values = [1.000, 1.275, 1.626, 2.073, 2.643, 3.371, 4.299, 5.482, 
                6.988, 8.909, 11.36, 14.48, 18.46, 23.54, 30.02, 38.28, 
                48.80, 62.23, 79.32, 100.0]
n_alpha = len(alpha_values)

print(f"Utterances ({n_utterances}): {utterances}")
print(f"Theta values ({n_theta}): {list(theta_values)}")
print(f"Alpha values: {n_alpha} values from {alpha_values[0]} to {alpha_values[-1]}")
print(f"Sequences: {n_sequences}")
print(f"Rounds: {n_rounds}")

# =============================================================================
# Step 2: Helper function
# =============================================================================

print("\n" + "="*70)
print("STEP 2: Define helper function")
print("="*70)

def get_log_P_u_given_theta(speaker, log_P_O_given_theta):
    """
    Marginalize P(u|O) over O to get P(u|θ).
    
    P(u|θ) = Σ_O P(u|O) P(O|θ)
    
    In log space: log P(u|θ) = logsumexp_O [log P(u|O) + log P(O|θ)]
    
    Parameters
    ----------
    speaker : PragmaticSpeaker_obs
        Speaker with utterance_log_prob_obs attribute (shape: n_utt × n_obs)
    log_P_O_given_theta : np.ndarray
        Log P(O|θ) matrix (shape: n_obs × n_theta)
    
    Returns
    -------
    np.ndarray
        Log P(u|θ) matrix (shape: n_utt × n_theta)
    """
    log_P_u_given_O = speaker.utterance_log_prob_obs.values  # (n_utt, n_obs)
    # Matrix multiply in log space: (n_utt, n_obs) @ (n_obs, n_theta) -> (n_utt, n_theta)
    return log_M_product(log_P_u_given_O, log_P_O_given_theta, precise=USE_PRECISE_LOGSPACE)

print("Helper function defined: get_log_P_u_given_theta(speaker, log_P_O_given_theta)")

# =============================================================================
# Step 3: Storage and computation
# =============================================================================

print("\n" + "="*70)
print("STEP 3: Computing log P(seq | θ) matrices for all speakers")
print("="*70)

# Storage: (speaker_type, alpha) -> np.array of shape (n_sequences, n_theta)
# For literal speaker, alpha is None
log_P_seq_given_theta = {}

# -----------------------------------------------------------------------------
# LITERAL SPEAKER
# -----------------------------------------------------------------------------
print("\nLiteral speaker...")

# LiteralListener already computes P(u|θ) internally via marginalizing over O
literal_listener = LiteralListener(world)
log_P_u_given_theta_literal = literal_listener.utterance_log_likelihood_theta.values  # (n_utt, n_theta)

# Verify shape
assert log_P_u_given_theta_literal.shape == (n_utterances, n_theta), \
    f"Expected shape ({n_utterances}, {n_theta}), got {log_P_u_given_theta_literal.shape}"

# Compute P(seq|θ) - utterances are conditionally independent given θ
# log P(seq|θ) = Σ_r log P(u_r|θ)
log_P_seq_theta_literal = np.zeros((n_sequences, n_theta))
for seq_idx, seq in enumerate(all_sequences):
    for r in range(n_rounds):
        u_idx = seq[r]
        log_P_seq_theta_literal[seq_idx, :] += log_P_u_given_theta_literal[u_idx, :]

log_P_seq_given_theta[('literal', None)] = log_P_seq_theta_literal
print(f"  Shape: {log_P_seq_theta_literal.shape}")
print("  Done.")

# -----------------------------------------------------------------------------
# PRAGMATIC SPEAKERS with update_internal=False
# -----------------------------------------------------------------------------
print("\nPragmatic speakers (update_internal=False)...")

psi_map_F = {
    'inf_F': 'inf',
    'persp_F': 'pers+',
    'persm_F': 'pers-'
}

for speaker_type, psi in psi_map_F.items():
    omega = 'coop' if psi == 'inf' else 'strat'
    print(f"\n  {speaker_type} (psi={psi}, omega={omega})...")
    
    for alpha_idx, alpha in enumerate(alpha_values):
        # Create pragmatic speaker
        speaker = PragmaticSpeaker_obs(
            world=world,
            omega=omega,
            psi=psi,
            update_internal=False,
            alpha=alpha,
            beta=0.0
        )
        
        # Get log P(u|θ) by marginalizing P(u|O) over O
        log_P_u_given_theta = get_log_P_u_given_theta(speaker, log_P_O_given_theta)
        
        # Verify shape
        assert log_P_u_given_theta.shape == (n_utterances, n_theta), \
            f"Expected shape ({n_utterances}, {n_theta}), got {log_P_u_given_theta.shape}"
        
        # Compute P(seq|θ) - conditionally independent given θ
        log_P_seq_theta = np.zeros((n_sequences, n_theta))
        for seq_idx, seq in enumerate(all_sequences):
            for r in range(n_rounds):
                u_idx = seq[r]
                log_P_seq_theta[seq_idx, :] += log_P_u_given_theta[u_idx, :]
        
        log_P_seq_given_theta[(speaker_type, alpha)] = log_P_seq_theta
        
        if alpha_idx % 5 == 0:
            print(f"    alpha={alpha:.3f} done")
    
    print(f"  {speaker_type} complete: {n_alpha} matrices stored")

# -----------------------------------------------------------------------------
# PRAGMATIC SPEAKERS with update_internal=True
# -----------------------------------------------------------------------------
print("\nPragmatic speakers (update_internal=True)...")

psi_map_T = {
    'inf_T': 'inf',
    'persp_T': 'pers+',
    'persm_T': 'pers-'
}

for speaker_type, psi in psi_map_T.items():
    omega = 'coop' if psi == 'inf' else 'strat'
    print(f"\n  {speaker_type} (psi={psi}, omega={omega})...")
    
    for alpha_idx, alpha in enumerate(alpha_values):
        
        # Precompute log P(u|θ) for all possible histories
        # History = tuple of utterance indices for previous rounds
        # We need histories of length 0, 1, 2, 3, 4
        history_log_P_u_given_theta = {}
        
        for hist_len in range(n_rounds):
            if hist_len == 0:
                histories = [()]
            else:
                histories = list(itertools.product(range(n_utterances), repeat=hist_len))
            
            for history in histories:
                # Convert history indices to utterance strings
                history_utterances = [utterances[i] for i in history]
                
                # Create a fresh speaker
                speaker = PragmaticSpeaker_obs(
                    world=world,
                    omega=omega,
                    psi=psi,
                    update_internal=True,
                    alpha=alpha,
                    beta=0.0
                )
                
                # Apply history: update internal listener for each past utterance
                for u in history_utterances:
                    speaker.literal_listener.listen_and_update(u)
                    speaker.utterance_log_prob_obs = speaker._compute_utterance_log_prob_obs(alpha)
                
                # Now get P(u|θ) for this history state
                history_log_P_u_given_theta[history] = get_log_P_u_given_theta(speaker, log_P_O_given_theta)
        
        # Compute P(seq|θ) using history-dependent probabilities
        # log P(seq|θ) = Σ_r log P^{history_r}(u_r|θ)
        log_P_seq_theta = np.zeros((n_sequences, n_theta))
        for seq_idx, seq in enumerate(all_sequences):
            for r in range(n_rounds):
                history = seq[:r]  # Utterances before round r
                u_idx = seq[r]     # Utterance at round r
                log_P_seq_theta[seq_idx, :] += history_log_P_u_given_theta[history][u_idx, :]
        
        log_P_seq_given_theta[(speaker_type, alpha)] = log_P_seq_theta
        
        if alpha_idx % 5 == 0:
            print(f"    alpha={alpha:.3f} done")
    
    print(f"  {speaker_type} complete: {n_alpha} matrices stored")

print(f"\nTotal matrices stored: {len(log_P_seq_given_theta)}")
print(f"  Expected: 1 (literal) + 3×{n_alpha} (F) + 3×{n_alpha} (T) = {1 + 6*n_alpha}")

# =============================================================================
# Step 4: Verification
# =============================================================================

print("\n" + "="*70)
print("STEP 4: Verification - each P(seq|θ) should sum to 1 over sequences")
print("="*70)

all_valid = True
for key, log_P in log_P_seq_given_theta.items():
    P = np.exp(log_P)
    sums = P.sum(axis=0)  # Sum over sequences for each theta
    
    if not np.allclose(sums, 1.0, atol=1e-5):
        print(f"  WARNING {key}: sums range [{sums.min():.6f}, {sums.max():.6f}]")
        all_valid = False

if all_valid:
    print("All matrices verified: Σ_seq P(seq|θ) = 1.0 for each θ ✓")

# =============================================================================
# Step 5: Define marginalization helper functions
# =============================================================================

print("\n" + "="*70)
print("STEP 5: Define marginalization helpers")
print("="*70)

def get_log_P_seq_given_theta_matrix(speaker_type, alpha=None):
    """
    Get the full log P(seq|θ) matrix for a speaker.
    
    Returns array of shape (n_sequences, n_theta).
    """
    if speaker_type == 'literal':
        key = ('literal', None)
    else:
        key = (speaker_type, alpha)
    
    if key not in log_P_seq_given_theta:
        raise KeyError(f"No matrix found for {key}")
    
    return log_P_seq_given_theta[key]


def get_P_seq_for_theta(speaker_type, theta, alpha=None):
    """
    Get P(seq | speaker, alpha, θ) for a specific theta value.
    
    Returns array of shape (n_sequences,).
    """
    log_P = get_log_P_seq_given_theta_matrix(speaker_type, alpha)
    theta_idx = np.where(np.isclose(theta_values, theta))[0]
    
    if len(theta_idx) == 0:
        raise ValueError(f"Theta {theta} not found in theta_values: {list(theta_values)}")
    
    return np.exp(log_P[:, theta_idx[0]])


def get_P_seq_marginalized(speaker_type, alpha=None, log_prior_theta=None):
    """
    Get P(seq | speaker, alpha) marginalized over θ.
    
    P(seq) = Σ_θ P(seq|θ) P(θ)
    
    If log_prior_theta is None, uses flat (uniform) prior over theta.
    
    Returns array of shape (n_sequences,).
    """
    log_P = get_log_P_seq_given_theta_matrix(speaker_type, alpha)
    
    if log_prior_theta is None:
        log_prior_theta = np.full(n_theta, -np.log(n_theta))
    
    # log P(seq) = logsumexp_θ [log P(seq|θ) + log P(θ)]
    log_P_seq = logsumexp(log_P + log_prior_theta, axis=1)
    return np.exp(log_P_seq)

print("Helper functions defined:")
print("  - get_log_P_seq_given_theta_matrix(speaker_type, alpha)")
print("  - get_P_seq_for_theta(speaker_type, theta, alpha)")
print("  - get_P_seq_marginalized(speaker_type, alpha, log_prior_theta)")

# =============================================================================
# Step 6: Build comprehensive DataFrame
# =============================================================================

print("\n" + "="*70)
print("STEP 6: Build comprehensive DataFrame")
print("="*70)

# Collect all columns in a dictionary first (avoids fragmentation warning)
all_columns = {}

# Sequence identifiers
all_columns['sequence'] = sequence_labels
all_columns['sequence_idx'] = all_sequences

# Literal columns: one per theta
for theta_idx, theta in enumerate(theta_values):
    col_name = f"literal_theta={theta:.1f}"
    all_columns[col_name] = np.exp(log_P_seq_given_theta[('literal', None)][:, theta_idx])

# Pragmatic columns: speaker × alpha × theta
for speaker_type in ['inf_T', 'inf_F', 'persp_T', 'persp_F', 'persm_T', 'persm_F']:
    for alpha in alpha_values:
        log_P = log_P_seq_given_theta[(speaker_type, alpha)]
        for theta_idx, theta in enumerate(theta_values):
            col_name = f"{speaker_type}_alpha={alpha:.2f}_theta={theta:.1f}"
            all_columns[col_name] = np.exp(log_P[:, theta_idx])

# Build DataFrame in one shot
results = pd.DataFrame(all_columns)

print(f"DataFrame shape: {results.shape}")
print(f"  Rows (sequences): {results.shape[0]}")
print(f"  Columns: {results.shape[1]}")
print(f"    - 2 identifier columns")
print(f"    - {n_theta} literal columns")
print(f"    - 6 × {n_alpha} × {n_theta} = {6 * n_alpha * n_theta} pragmatic columns")

# =============================================================================
# Step 7: Verify DataFrame
# =============================================================================

print("\n" + "="*70)
print("STEP 7: Verify DataFrame columns sum to 1")
print("="*70)

prob_cols = [c for c in results.columns if c not in ['sequence', 'sequence_idx']]
print(f"Checking {len(prob_cols)} probability columns...")

bad_cols = []
for col in prob_cols:
    total = results[col].sum()
    if not np.isclose(total, 1.0, atol=1e-5):
        bad_cols.append((col, total))

if bad_cols:
    print(f"  WARNING: {len(bad_cols)} columns don't sum to 1:")
    for col, total in bad_cols[:5]:
        print(f"    {col}: {total:.6f}")
else:
    print("  All probability columns sum to 1.0 ✓")


STEP 1: Setup
Utterances (8): ['all,successful', 'all,unsuccessful', 'most,successful', 'most,unsuccessful', 'some,successful', 'some,unsuccessful', 'no,successful', 'no,unsuccessful']
Theta values (11): [np.float64(0.0), np.float64(0.1), np.float64(0.2), np.float64(0.3), np.float64(0.4), np.float64(0.5), np.float64(0.6), np.float64(0.7), np.float64(0.8), np.float64(0.9), np.float64(1.0)]
Alpha values: 20 values from 1.0 to 100.0
Sequences: 32768
Rounds: 5

STEP 2: Define helper function
Helper function defined: get_log_P_u_given_theta(speaker, log_P_O_given_theta)

STEP 3: Computing log P(seq | θ) matrices for all speakers

Literal speaker...
  Shape: (32768, 11)
  Done.

Pragmatic speakers (update_internal=False)...

  inf_F (psi=inf, omega=coop)...
    alpha=1.000 done
    alpha=3.371 done
    alpha=11.360 done
    alpha=38.280 done
  inf_F complete: 20 matrices stored

  persp_F (psi=pers+, omega=strat)...
    alpha=1.000 done
    alpha=3.371 done
    alpha=11.360 done
    alpha=38

OSError: Cannot save file into a non-existent directory: '/home/claude'

In [17]:
# =============================================================================
# Step 8: Save results
# =============================================================================

print("\n" + "="*70)
print("STEP 8: Save results")
print("="*70)

# Save DataFrame as CSV
csv_path = 'P_seq_given_speaker_alpha_theta.csv'
results.to_csv(csv_path, index=False)
print(f"Saved DataFrame: {csv_path}")

# Save raw log matrices as NPZ (preserves precision for later use)
npz_data = {
    'theta_values': theta_values,
    'alpha_values': np.array(alpha_values),
    'sequences': np.array(all_sequences),
    'utterances': np.array(utterances)
}

# Add each log P(seq|θ) matrix
for (speaker_type, alpha), log_P in log_P_seq_given_theta.items():
    if alpha is None:
        key = f"{speaker_type}"
    else:
        key = f"{speaker_type}_alpha={alpha:.2f}"
    npz_data[key] = log_P

npz_path = 'log_P_seq_given_theta_matrices.npz'
np.savez(npz_path, **npz_data)  # Use np.savez instead of np.savez_compressed
print(f"Saved log matrices: {npz_path}")

# =============================================================================
# Step 9: Usage examples
# =============================================================================

print("\n" + "="*70)
print("STEP 9: Usage examples")
print("="*70)

# Example 1: P(seq | inf_T, alpha=8.909, theta=0.5)
print("\nExample 1: P(seq | inf_T, α=8.909, θ=0.5)")
P = get_P_seq_for_theta('inf_T', 0.5, alpha=8.909)
top_idx = np.argmax(P)
print(f"  Top sequence: {sequence_labels[top_idx]}")
print(f"  Probability: {P[top_idx]:.6f}")

# Example 2: P(seq | inf_T, alpha=8.909) with flat prior over theta
print("\nExample 2: P(seq | inf_T, α=8.909) marginalized over θ [flat prior]")
P = get_P_seq_marginalized('inf_T', alpha=8.909)
top_idx = np.argmax(P)
print(f"  Top sequence: {sequence_labels[top_idx]}")
print(f"  Probability: {P[top_idx]:.6f}")

# Example 3: P(seq | literal) with flat prior
print("\nExample 3: P(seq | literal) marginalized over θ [flat prior]")
P = get_P_seq_marginalized('literal')
top_idx = np.argmax(P)
print(f"  Top sequence: {sequence_labels[top_idx]}")
print(f"  Probability: {P[top_idx]:.6f}")

# Example 4: Top sequence for inf_T at each theta
print("\nExample 4: Top sequence for inf_T (α=8.909) at each θ:")
for theta in theta_values:
    P = get_P_seq_for_theta('inf_T', theta, alpha=8.909)
    top_idx = np.argmax(P)
    seq_abbrev = ",".join([u.split(",")[0][:2] for u in sequence_labels[top_idx]])
    print(f"  θ={theta:.1f}: {seq_abbrev:<20} (P={P[top_idx]:.4f})")

# Example 5: Compare marginals across speaker types
print("\nExample 5: Top sequence for each speaker type (α=8.909, flat prior):")
for speaker_type in ['literal', 'inf_T', 'inf_F', 'persp_T', 'persp_F', 'persm_T', 'persm_F']:
    alpha = 8.909 if speaker_type != 'literal' else None
    P = get_P_seq_marginalized(speaker_type, alpha=alpha)
    top_idx = np.argmax(P)
    seq_str = str(sequence_labels[top_idx])
    if len(seq_str) > 60:
        seq_str = seq_str[:57] + "..."
    print(f"  {speaker_type:<10}: {seq_str:<65} (P={P[top_idx]:.4f})")

print("\n" + "="*70)
print("COMPLETE!")
print("="*70)
print(f"\nStored objects:")
print(f"  - log_P_seq_given_theta: dict with {len(log_P_seq_given_theta)} matrices")
print(f"  - results: DataFrame with shape {results.shape}")
print(f"\nFiles saved:")
print(f"  - {csv_path}")
print(f"  - {npz_path}")

# =============================================================================
# Step 9: Usage examples
# =============================================================================

print("\n" + "="*70)
print("STEP 9: Usage examples")
print("="*70)

# Example 1: P(seq | inf_T, alpha=8.909, theta=0.5)
print("\nExample 1: P(seq | inf_T, α=8.909, θ=0.5)")
P = get_P_seq_for_theta('inf_T', 0.5, alpha=8.909)
top_idx = np.argmax(P)
print(f"  Top sequence: {sequence_labels[top_idx]}")
print(f"  Probability: {P[top_idx]:.6f}")

# Example 2: P(seq | inf_T, alpha=8.909) with flat prior over theta
print("\nExample 2: P(seq | inf_T, α=8.909) marginalized over θ [flat prior]")
P = get_P_seq_marginalized('inf_T', alpha=8.909)
top_idx = np.argmax(P)
print(f"  Top sequence: {sequence_labels[top_idx]}")
print(f"  Probability: {P[top_idx]:.6f}")

# Example 3: P(seq | literal) with flat prior
print("\nExample 3: P(seq | literal) marginalized over θ [flat prior]")
P = get_P_seq_marginalized('literal')
top_idx = np.argmax(P)
print(f"  Top sequence: {sequence_labels[top_idx]}")
print(f"  Probability: {P[top_idx]:.6f}")

# Example 4: Top sequence for inf_T at each theta
print("\nExample 4: Top sequence for inf_T (α=8.909) at each θ:")
for theta in theta_values:
    P = get_P_seq_for_theta('inf_T', theta, alpha=8.909)
    top_idx = np.argmax(P)
    seq_abbrev = ",".join([u.split(",")[0][:2] for u in sequence_labels[top_idx]])
    print(f"  θ={theta:.1f}: {seq_abbrev:<20} (P={P[top_idx]:.4f})")

# Example 5: Compare marginals across speaker types
print("\nExample 5: Top sequence for each speaker type (α=8.909, flat prior):")
for speaker_type in ['literal', 'inf_T', 'inf_F', 'persp_T', 'persp_F', 'persm_T', 'persm_F']:
    alpha = 8.909 if speaker_type != 'literal' else None
    P = get_P_seq_marginalized(speaker_type, alpha=alpha)
    top_idx = np.argmax(P)
    seq_str = str(sequence_labels[top_idx])
    if len(seq_str) > 60:
        seq_str = seq_str[:57] + "..."
    print(f"  {speaker_type:<10}: {seq_str:<65} (P={P[top_idx]:.4f})")

print("\n" + "="*70)
print("COMPLETE!")
print("="*70)
print(f"\nStored objects:")
print(f"  - log_P_seq_given_theta: dict with {len(log_P_seq_given_theta)} matrices")
print(f"  - results: DataFrame with shape {results.shape}")
print(f"\nFiles saved:")
print(f"  - {csv_path}")
print(f"  - {npz_path}")


STEP 8: Save results
Saved DataFrame: P_seq_given_speaker_alpha_theta.csv
Saved log matrices: log_P_seq_given_theta_matrices.npz

STEP 9: Usage examples

Example 1: P(seq | inf_T, α=8.909, θ=0.5)
  Top sequence: ('most,successful', 'most,unsuccessful', 'most,successful', 'most,unsuccessful', 'most,successful')
  Probability: 0.021869

Example 2: P(seq | inf_T, α=8.909) marginalized over θ [flat prior]
  Top sequence: ('most,successful', 'most,unsuccessful', 'most,successful', 'most,successful', 'most,successful')
  Probability: 0.008243

Example 3: P(seq | literal) marginalized over θ [flat prior]
  Top sequence: ('some,successful', 'some,successful', 'some,successful', 'some,successful', 'some,successful')
  Probability: 0.001877

Example 4: Top sequence for inf_T (α=8.909) at each θ:
  θ=0.0: al,al,al,al,al       (P=0.0131)
  θ=0.1: mo,al,so,al,so       (P=0.0053)
  θ=0.2: mo,so,mo,mo,so       (P=0.0143)
  θ=0.3: mo,mo,mo,mo,mo       (P=0.0265)
  θ=0.4: mo,mo,mo,mo,mo       (P=0.033

In [19]:
import numpy as np
import pandas as pd
from scipy.stats import spearmanr
import itertools

# =============================================================================
# REVISED ANALYSIS: Stability within theta, across T/F
# =============================================================================

print("="*70)
print("REVISED RANK STABILITY ANALYSIS")
print("="*70)

print("""
Key insight:
- Theta = different true world states (not comparable)
- We analyze stability WITHIN each theta
- We want T and F variants of same psi to agree (experimental robustness)
""")

# -----------------------------------------------------------------------------
# Part 1: Helper functions
# -----------------------------------------------------------------------------

def compute_ranks(P):
    """Convert probabilities to ranks (1 = highest probability)."""
    return (-P).argsort().argsort() + 1

def get_top_k_set(P, k):
    """Get set of indices in top-k by probability."""
    return set(np.argsort(-P)[:k])

def jaccard(set1, set2):
    """Compute Jaccard similarity between two sets."""
    if len(set1) == 0 and len(set2) == 0:
        return 1.0
    return len(set1 & set2) / len(set1 | set2)

# -----------------------------------------------------------------------------
# Part 2: For each THETA, analyze rank stability across ALPHA
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("Part 2: Rank stability across ALPHA (within each theta)")
print("="*70)

# Selected alphas for comparison
alpha_pairs = [
    (1.0, 5.482),
    (5.482, 8.909),
    (8.909, 14.48),
    (14.48, 100.0),
    (5.482, 14.48),  # Mid-range comparison
]

for speaker_type in ['inf_T', 'inf_F', 'persp_T', 'persp_F', 'persm_T', 'persm_F']:
    print(f"\n--- {speaker_type} ---")
    print(f"{'θ':<6}", end="")
    for a1, a2 in alpha_pairs:
        print(f"ρ({a1:.1f},{a2:.1f})".center(14), end="")
    print()
    print("-" * 80)
    
    for theta in theta_values:
        print(f"{theta:<6.1f}", end="")
        for a1, a2 in alpha_pairs:
            P1 = get_P_seq_for_theta(speaker_type, theta, alpha=a1)
            P2 = get_P_seq_for_theta(speaker_type, theta, alpha=a2)
            rho, _ = spearmanr(P1, P2)
            print(f"{rho:^14.3f}", end="")
        print()

# -----------------------------------------------------------------------------
# Part 3: For each (theta, alpha), compare T vs F for same psi
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("Part 3: Stability across T vs F (same psi)")
print("="*70)

print("""
Question: For a given (theta, alpha), do inf_T and inf_F agree on rankings?
If yes, we can treat "informative speaker" as one model regardless of T/F.
""")

selected_alphas = [1.0, 5.482, 8.909, 14.48, 30.02, 100.0]

for psi_name, type_T, type_F in [('inf', 'inf_T', 'inf_F'), 
                                   ('persp', 'persp_T', 'persp_F'), 
                                   ('persm', 'persm_T', 'persm_F')]:
    print(f"\n--- {psi_name}: {type_T} vs {type_F} ---")
    print(f"{'θ':<6}", end="")
    for alpha in selected_alphas:
        print(f"α={alpha:.1f}".center(12), end="")
    print()
    print("-" * 80)
    
    for theta in theta_values:
        print(f"{theta:<6.1f}", end="")
        for alpha in selected_alphas:
            P_T = get_P_seq_for_theta(type_T, theta, alpha=alpha)
            P_F = get_P_seq_for_theta(type_F, theta, alpha=alpha)
            rho, _ = spearmanr(P_T, P_F)
            print(f"{rho:^12.3f}", end="")
        print()

# -----------------------------------------------------------------------------
# Part 4: Top-k overlap between T and F
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("Part 4: Top-k Jaccard overlap between T and F (same psi)")
print("="*70)

# Focus on mid-range alpha
alpha_focus = 8.909

for psi_name, type_T, type_F in [('inf', 'inf_T', 'inf_F'), 
                                   ('persp', 'persp_T', 'persp_F'), 
                                   ('persm', 'persm_T', 'persm_F')]:
    print(f"\n--- {psi_name} (α={alpha_focus}) ---")
    print(f"{'θ':<6} {'Top-1 same?':<12} {'Top-5 Jacc':<12} {'Top-10 Jacc':<12} {'Top-20 Jacc':<12}")
    print("-" * 60)
    
    for theta in theta_values:
        P_T = get_P_seq_for_theta(type_T, theta, alpha=alpha_focus)
        P_F = get_P_seq_for_theta(type_F, theta, alpha=alpha_focus)
        
        top1_same = np.argmax(P_T) == np.argmax(P_F)
        jacc5 = jaccard(get_top_k_set(P_T, 5), get_top_k_set(P_F, 5))
        jacc10 = jaccard(get_top_k_set(P_T, 10), get_top_k_set(P_F, 10))
        jacc20 = jaccard(get_top_k_set(P_T, 20), get_top_k_set(P_F, 20))
        
        print(f"{theta:<6.1f} {str(top1_same):<12} {jacc5:<12.3f} {jacc10:<12.3f} {jacc20:<12.3f}")

# -----------------------------------------------------------------------------
# Part 5: Within each theta, find sequences stable across alpha AND across T/F
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("Part 5: Stable sequences within each theta (across alpha & T/F)")
print("="*70)

# Mid-range alphas
mid_alphas = [5.482, 6.988, 8.909, 11.36, 14.48]

for theta in theta_values:
    print(f"\n{'='*60}")
    print(f"θ = {theta}")
    print(f"{'='*60}")
    
    # For inf: collect top-20 sets for all (alpha, T/F) combinations
    top20_sets_inf = []
    for alpha in mid_alphas:
        for speaker in ['inf_T', 'inf_F']:
            P = get_P_seq_for_theta(speaker, theta, alpha=alpha)
            top20_sets_inf.append(get_top_k_set(P, 20))
    
    common_inf = set.intersection(*top20_sets_inf)
    print(f"\nINF: Sequences in top-20 for ALL (mid-α, T/F) combos: {len(common_inf)}")
    if len(common_inf) > 0 and len(common_inf) <= 5:
        for idx in common_inf:
            print(f"  {sequence_labels[idx]}")
    
    # For persp
    top20_sets_persp = []
    for alpha in mid_alphas:
        for speaker in ['persp_T', 'persp_F']:
            P = get_P_seq_for_theta(speaker, theta, alpha=alpha)
            top20_sets_persp.append(get_top_k_set(P, 20))
    
    common_persp = set.intersection(*top20_sets_persp)
    print(f"\nPERSP: Sequences in top-20 for ALL (mid-α, T/F) combos: {len(common_persp)}")
    if len(common_persp) > 0 and len(common_persp) <= 5:
        for idx in common_persp:
            print(f"  {sequence_labels[idx]}")
    
    # For persm
    top20_sets_persm = []
    for alpha in mid_alphas:
        for speaker in ['persm_T', 'persm_F']:
            P = get_P_seq_for_theta(speaker, theta, alpha=alpha)
            top20_sets_persm.append(get_top_k_set(P, 20))
    
    common_persm = set.intersection(*top20_sets_persm)
    print(f"\nPERSM: Sequences in top-20 for ALL (mid-α, T/F) combos: {len(common_persm)}")
    if len(common_persm) > 0 and len(common_persm) <= 5:
        for idx in common_persm:
            print(f"  {sequence_labels[idx]}")

# -----------------------------------------------------------------------------
# Part 6: Detailed comparison for a specific theta (e.g., 0.5)
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("Part 6: Detailed analysis for θ = 0.5")
print("="*70)

theta_focus = 0.5
alpha_focus = 8.909

print(f"\nComparing top sequences at θ={theta_focus}, α={alpha_focus}")

# Get top-10 for each speaker type
for speaker in ['literal', 'inf_T', 'inf_F', 'persp_T', 'persp_F', 'persm_T', 'persm_F']:
    if speaker == 'literal':
        P = get_P_seq_for_theta('literal', theta_focus, alpha=None)
    else:
        P = get_P_seq_for_theta(speaker, theta_focus, alpha=alpha_focus)
    
    print(f"\n--- {speaker} ---")
    top10_idx = np.argsort(-P)[:10]
    for rank, idx in enumerate(top10_idx, 1):
        seq = sequence_labels[idx]
        prob = P[idx]
        # Abbreviate
        abbrev = ",".join([u.split(",")[0][:2] for u in seq])
        print(f"  {rank:>2}. {abbrev:<25} P={prob:.6f}")

# -----------------------------------------------------------------------------
# Part 7: For each theta, compute rank correlation matrix across all speakers
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("Part 7: Rank correlation between speaker types (within theta)")
print("="*70)

alpha_focus = 8.909
speakers = ['literal', 'inf_T', 'inf_F', 'persp_T', 'persp_F', 'persm_T', 'persm_F']

for theta in [0.3, 0.5, 0.7]:  # Selected thetas
    print(f"\n--- θ = {theta}, α = {alpha_focus} ---")
    print(f"{'':12}", end="")
    for s in speakers:
        print(f"{s:>10}", end="")
    print()
    
    for s1 in speakers:
        print(f"{s1:12}", end="")
        for s2 in speakers:
            if s1 == 'literal':
                P1 = get_P_seq_for_theta('literal', theta, alpha=None)
            else:
                P1 = get_P_seq_for_theta(s1, theta, alpha=alpha_focus)
            
            if s2 == 'literal':
                P2 = get_P_seq_for_theta('literal', theta, alpha=None)
            else:
                P2 = get_P_seq_for_theta(s2, theta, alpha=alpha_focus)
            
            rho, _ = spearmanr(P1, P2)
            print(f"{rho:>10.3f}", end="")
        print()

# -----------------------------------------------------------------------------
# Part 8: Summary table - best alpha range for T/F agreement
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("Part 8: Best alpha range for T/F agreement (by psi, by theta)")
print("="*70)

print("""
For each (psi, theta), find alpha range where Spearman(T, F) > 0.9
""")

all_alphas = alpha_values

for psi_name, type_T, type_F in [('inf', 'inf_T', 'inf_F'), 
                                   ('persp', 'persp_T', 'persp_F'), 
                                   ('persm', 'persm_T', 'persm_F')]:
    print(f"\n--- {psi_name} ---")
    print(f"{'θ':<6} {'Alphas with ρ(T,F) > 0.9':<50} {'Best ρ':<10}")
    print("-" * 70)
    
    for theta in theta_values:
        good_alphas = []
        best_rho = -1
        for alpha in all_alphas:
            P_T = get_P_seq_for_theta(type_T, theta, alpha=alpha)
            P_F = get_P_seq_for_theta(type_F, theta, alpha=alpha)
            rho, _ = spearmanr(P_T, P_F)
            if rho > 0.9:
                good_alphas.append(alpha)
            if rho > best_rho:
                best_rho = rho
        
        if good_alphas:
            alpha_str = f"[{min(good_alphas):.2f}, {max(good_alphas):.2f}] ({len(good_alphas)} vals)"
        else:
            alpha_str = "None"
        print(f"{theta:<6.1f} {alpha_str:<50} {best_rho:<10.3f}")

# -----------------------------------------------------------------------------
# Part 9: Summary
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("SUMMARY")
print("="*70)

print("""
KEY FINDINGS:

1. WITHIN-THETA ALPHA STABILITY:
   - Rankings change across alpha, but mid-range alphas are more stable
   - Extreme alphas (1.0, 100.0) often have different rankings

2. T vs F AGREEMENT:
   - For same psi, T and F versions may or may not agree
   - Agreement depends on theta and alpha
   - Some (theta, alpha) combinations show high T/F correlation

3. IMPLICATIONS FOR EXPERIMENT:
   - Choose (theta, alpha) where T/F agreement is high for each psi
   - This allows treating inf_T and inf_F as "informative model"
   - Select sequences that are stable across this range

NEXT STEPS:
   - Focus on (theta, alpha) combos with high T/F agreement
   - For those conditions, compute discriminability (JS) between psi types
   - Find sequences that discriminate between psi while being likely
""")

REVISED RANK STABILITY ANALYSIS

Key insight:
- Theta = different true world states (not comparable)
- We analyze stability WITHIN each theta
- We want T and F variants of same psi to agree (experimental robustness)


Part 2: Rank stability across ALPHA (within each theta)

--- inf_T ---
θ       ρ(1.0,5.5)    ρ(5.5,8.9)   ρ(8.9,14.5)  ρ(14.5,100.0)  ρ(5.5,14.5)  
--------------------------------------------------------------------------------
0.0       1.000         1.000         1.000         0.998         1.000     
0.1       0.966         0.987         0.977         0.566         0.933     
0.2       0.935         0.977         0.960         0.643         0.881     
0.3       0.892         0.959         0.947         0.728         0.823     
0.4       0.826         0.942         0.937         0.795         0.774     
0.5       0.795         0.926         0.938         0.815         0.756     
0.6       0.826         0.942         0.937         0.795         0.774     
0.7       0.89

In [20]:
import numpy as np
import pandas as pd
from scipy.stats import spearmanr
import itertools

# =============================================================================
# DISCRIMINATION ANALYSIS: High within-speaker, low between-speaker stability
# =============================================================================

print("="*70)
print("FINDING DISCRIMINATING SEQUENCES")
print("="*70)

print("""
Goal: Find sequences that are:
  - STABLE within a speaker type (across alphas) → characteristic
  - UNSTABLE between speaker types → discriminating
""")

# -----------------------------------------------------------------------------
# Helper functions
# -----------------------------------------------------------------------------

def compute_ranks(P):
    """Convert probabilities to ranks (1 = highest probability)."""
    return (-P).argsort().argsort() + 1

def get_top_k_set(P, k):
    """Get set of indices in top-k by probability."""
    return set(np.argsort(-P)[:k])

# Mid-range alphas for stability assessment
mid_alphas = [5.482, 6.988, 8.909, 11.36, 14.48]

# All speaker types (excluding literal for now, or include it)
speaker_types = ['literal', 'inf_T', 'inf_F', 'persp_T', 'persp_F', 'persm_T', 'persm_F']
# For psi-level analysis
psi_speakers = {
    'inf': ['inf_T', 'inf_F'],
    'persp': ['persp_T', 'persp_F'],
    'persm': ['persm_T', 'persm_F']
}

# -----------------------------------------------------------------------------
# Part 1: For each (theta, speaker), compute mean rank across mid-alphas
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("Part 1: Compute mean rank across mid-alphas for each sequence")
print("="*70)

# For each theta, create a DataFrame with mean rank per speaker
# Shape: (n_sequences, n_speakers) for each theta

mean_ranks_by_theta = {}

for theta in theta_values:
    mean_ranks = {}
    
    for speaker in speaker_types:
        if speaker == 'literal':
            # Literal has no alpha
            P = get_P_seq_for_theta('literal', theta, alpha=None)
            ranks = compute_ranks(P)
            mean_ranks[speaker] = ranks.astype(float)  # Just one "alpha"
        else:
            # Average rank across mid-alphas
            all_ranks = []
            for alpha in mid_alphas:
                P = get_P_seq_for_theta(speaker, theta, alpha=alpha)
                ranks = compute_ranks(P)
                all_ranks.append(ranks)
            mean_ranks[speaker] = np.mean(all_ranks, axis=0)
    
    mean_ranks_by_theta[theta] = pd.DataFrame(mean_ranks)

print(f"Computed mean ranks for {len(theta_values)} theta values")
print(f"Shape per theta: {mean_ranks_by_theta[0.5].shape}")

# -----------------------------------------------------------------------------
# Part 2: For each (theta, speaker), compute rank VARIANCE across mid-alphas
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("Part 2: Compute rank variance (stability) across mid-alphas")
print("="*70)

# Low variance = high stability within speaker

rank_variance_by_theta = {}

for theta in theta_values:
    rank_vars = {}
    
    for speaker in speaker_types:
        if speaker == 'literal':
            # No variance for literal
            rank_vars[speaker] = np.zeros(n_sequences)
        else:
            all_ranks = []
            for alpha in mid_alphas:
                P = get_P_seq_for_theta(speaker, theta, alpha=alpha)
                ranks = compute_ranks(P)
                all_ranks.append(ranks)
            rank_vars[speaker] = np.var(all_ranks, axis=0)
    
    rank_variance_by_theta[theta] = pd.DataFrame(rank_vars)

print("Computed rank variance for each theta")

# -----------------------------------------------------------------------------
# Part 3: Find sequences with HIGH within-speaker stability, LOW mean rank
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("Part 3: Sequences characteristic of each speaker (stable & top-ranked)")
print("="*70)

def find_characteristic_sequences(theta, speaker, top_k=50, max_variance=100):
    """
    Find sequences that are:
    - In top-k by mean rank for this speaker
    - Have low variance across alphas (stable)
    """
    mean_rank = mean_ranks_by_theta[theta][speaker].values
    variance = rank_variance_by_theta[theta][speaker].values
    
    # Candidates: top-k by mean rank
    top_k_idx = np.argsort(mean_rank)[:top_k]
    
    # Filter by variance
    stable_idx = [idx for idx in top_k_idx if variance[idx] <= max_variance]
    
    return stable_idx, mean_rank, variance

# For each theta, for each speaker, show characteristic sequences
for theta in [0.3, 0.5, 0.7]:
    print(f"\n{'='*60}")
    print(f"θ = {theta}")
    print(f"{'='*60}")
    
    for speaker in ['inf_T', 'inf_F', 'persp_T', 'persp_F', 'persm_T', 'persm_F']:
        stable_idx, mean_rank, variance = find_characteristic_sequences(theta, speaker, top_k=20, max_variance=50)
        
        print(f"\n--- {speaker}: {len(stable_idx)} stable sequences in top-20 ---")
        if len(stable_idx) > 0:
            for idx in stable_idx[:5]:  # Show top 5
                seq = sequence_labels[idx]
                abbrev = ",".join([u.split(",")[0][:2] for u in seq])
                print(f"  {abbrev:<25} mean_rank={mean_rank[idx]:.1f}, var={variance[idx]:.1f}")

# -----------------------------------------------------------------------------
# Part 4: For each sequence, compute "discrimination score"
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("Part 4: Compute discrimination score for each sequence")
print("="*70)

print("""
Discrimination score for sequence s, target speaker A:
  = (mean rank of s for other speakers) - (mean rank of s for speaker A)
  
High score = sequence is good for A but bad for others = discriminating!
""")

def compute_discrimination_scores(theta):
    """
    For each sequence and each speaker, compute how much better 
    this sequence ranks for that speaker vs others.
    """
    mean_ranks = mean_ranks_by_theta[theta]
    
    scores = {}
    for target_speaker in speaker_types:
        target_rank = mean_ranks[target_speaker].values
        other_speakers = [s for s in speaker_types if s != target_speaker]
        other_ranks = mean_ranks[other_speakers].values  # (n_seq, n_other)
        mean_other_rank = np.mean(other_ranks, axis=1)
        
        # Discrimination = how much worse others rank this sequence
        scores[target_speaker] = mean_other_rank - target_rank
    
    return pd.DataFrame(scores)

# Compute for selected thetas
discrimination_by_theta = {}
for theta in theta_values:
    discrimination_by_theta[theta] = compute_discrimination_scores(theta)

# Show top discriminating sequences for each speaker at θ = 0.5
theta = 0.5
print(f"\n--- θ = {theta}: Top discriminating sequences per speaker ---")

disc_df = discrimination_by_theta[theta]
mean_ranks = mean_ranks_by_theta[theta]
rank_vars = rank_variance_by_theta[theta]

for speaker in ['inf_T', 'inf_F', 'persp_T', 'persp_F', 'persm_T', 'persm_F']:
    print(f"\n{speaker}:")
    
    # Get top 10 by discrimination score
    top_disc_idx = np.argsort(-disc_df[speaker].values)[:10]
    
    print(f"  {'Sequence':<30} {'Disc Score':<12} {'Mean Rank':<12} {'Variance':<12}")
    for idx in top_disc_idx:
        seq = sequence_labels[idx]
        abbrev = ",".join([u.split(",")[0][:2] for u in seq])
        disc = disc_df[speaker].values[idx]
        mr = mean_ranks[speaker].values[idx]
        var = rank_vars[speaker].values[idx]
        print(f"  {abbrev:<30} {disc:<12.1f} {mr:<12.1f} {var:<12.1f}")

# -----------------------------------------------------------------------------
# Part 5: Combined score: discrimination + stability + good rank
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("Part 5: Combined score (discrimination × stability × good rank)")
print("="*70)

print("""
We want sequences that are:
  1. Highly discriminating (high disc score)
  2. Stable across alphas (low variance)
  3. Actually likely for the target speaker (low mean rank)
  
Combined score = disc_score / (1 + sqrt(variance)) / (1 + log(mean_rank))
""")

def compute_combined_scores(theta):
    """Compute combined discrimination-stability-rank score."""
    disc_df = discrimination_by_theta[theta]
    mean_ranks = mean_ranks_by_theta[theta]
    rank_vars = rank_variance_by_theta[theta]
    
    combined = {}
    for speaker in speaker_types:
        disc = disc_df[speaker].values
        mr = mean_ranks[speaker].values
        var = rank_vars[speaker].values
        
        # Combined score (higher is better)
        # Reward: high discrimination, low variance, low mean rank
        combined[speaker] = disc / (1 + np.sqrt(var)) / (1 + np.log1p(mr))
    
    return pd.DataFrame(combined)

# Show for θ = 0.5
theta = 0.5
combined_df = compute_combined_scores(theta)

print(f"\n--- θ = {theta}: Top sequences by combined score ---")

for speaker in ['inf_T', 'inf_F', 'persp_T', 'persp_F', 'persm_T', 'persm_F']:
    print(f"\n{speaker}:")
    
    top_idx = np.argsort(-combined_df[speaker].values)[:10]
    
    disc_df = discrimination_by_theta[theta]
    mean_ranks = mean_ranks_by_theta[theta]
    rank_vars = rank_variance_by_theta[theta]
    
    print(f"  {'Sequence':<30} {'Combined':<10} {'Disc':<10} {'MeanRank':<10} {'Var':<10}")
    for idx in top_idx:
        seq = sequence_labels[idx]
        abbrev = ",".join([u.split(",")[0][:2] for u in seq])
        comb = combined_df[speaker].values[idx]
        disc = disc_df[speaker].values[idx]
        mr = mean_ranks[speaker].values[idx]
        var = rank_vars[speaker].values[idx]
        print(f"  {abbrev:<30} {comb:<10.3f} {disc:<10.1f} {mr:<10.1f} {var:<10.1f}")

# -----------------------------------------------------------------------------
# Part 6: PSI-level analysis (collapsing T and F)
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("Part 6: PSI-level discrimination (inf vs persp vs persm)")
print("="*70)

print("""
Now treating inf_T and inf_F as one "inf" model, etc.
Find sequences that discriminate between PSI types.
""")

def compute_psi_mean_rank(theta, psi):
    """Average mean rank across T and F variants."""
    speakers = psi_speakers[psi]
    ranks = [mean_ranks_by_theta[theta][s].values for s in speakers]
    return np.mean(ranks, axis=0)

def compute_psi_discrimination(theta):
    """Discrimination scores at PSI level."""
    psi_ranks = {psi: compute_psi_mean_rank(theta, psi) for psi in ['inf', 'persp', 'persm']}
    
    # Also include literal
    psi_ranks['literal'] = mean_ranks_by_theta[theta]['literal'].values
    
    scores = {}
    for target in ['literal', 'inf', 'persp', 'persm']:
        target_rank = psi_ranks[target]
        others = [p for p in ['literal', 'inf', 'persp', 'persm'] if p != target]
        other_ranks = np.array([psi_ranks[p] for p in others])
        mean_other = np.mean(other_ranks, axis=0)
        scores[target] = mean_other - target_rank
    
    return pd.DataFrame(scores)

# Show PSI-level discrimination for θ = 0.5
theta = 0.5
psi_disc = compute_psi_discrimination(theta)

print(f"\n--- θ = {theta}: Top PSI-discriminating sequences ---")

for psi in ['inf', 'persp', 'persm']:
    print(f"\n{psi.upper()}:")
    
    top_idx = np.argsort(-psi_disc[psi].values)[:10]
    
    print(f"  {'Sequence':<30} {'PSI Disc':<12} {'inf rank':<12} {'persp rank':<12} {'persm rank':<12}")
    for idx in top_idx:
        seq = sequence_labels[idx]
        abbrev = ",".join([u.split(",")[0][:2] for u in seq])
        disc = psi_disc[psi].values[idx]
        inf_r = compute_psi_mean_rank(theta, 'inf')[idx]
        persp_r = compute_psi_mean_rank(theta, 'persp')[idx]
        persm_r = compute_psi_mean_rank(theta, 'persm')[idx]
        print(f"  {abbrev:<30} {disc:<12.1f} {inf_r:<12.1f} {persp_r:<12.1f} {persm_r:<12.1f}")

# -----------------------------------------------------------------------------
# Part 7: Find "maximally discriminating" sets of sequences
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("Part 7: Sequences that discriminate BETWEEN all PSI types")
print("="*70)

print("""
Find sequences where inf, persp, and persm have VERY DIFFERENT ranks.
Metric: variance of ranks across the three PSI types (higher = more discriminating)
""")

def compute_psi_rank_variance(theta):
    """For each sequence, compute variance of ranks across PSI types."""
    psi_ranks = np.array([
        compute_psi_mean_rank(theta, 'inf'),
        compute_psi_mean_rank(theta, 'persp'),
        compute_psi_mean_rank(theta, 'persm')
    ])  # (3, n_sequences)
    
    return np.var(psi_ranks, axis=0)

for theta in [0.3, 0.5, 0.7]:
    print(f"\n--- θ = {theta} ---")
    
    psi_var = compute_psi_rank_variance(theta)
    
    # Top sequences by PSI rank variance
    top_idx = np.argsort(-psi_var)[:15]
    
    print(f"  {'Sequence':<35} {'Rank Var':<12} {'inf':<8} {'persp':<8} {'persm':<8}")
    for idx in top_idx:
        seq = sequence_labels[idx]
        abbrev = ",".join([u.split(",")[0][:2] for u in seq])
        var = psi_var[idx]
        inf_r = compute_psi_mean_rank(theta, 'inf')[idx]
        persp_r = compute_psi_mean_rank(theta, 'persp')[idx]
        persm_r = compute_psi_mean_rank(theta, 'persm')[idx]
        print(f"  {abbrev:<35} {var:<12.0f} {inf_r:<8.0f} {persp_r:<8.0f} {persm_r:<8.0f}")

# -----------------------------------------------------------------------------
# Part 8: Summary table of best discriminating sequences per theta
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("Part 8: Summary - Best discriminating sequences for each theta")
print("="*70)

print("""
For each theta, show the TOP 5 sequences that maximize rank variance across PSI types.
These are sequences where the three PSI models disagree most about likelihood.
""")

summary_data = []

for theta in theta_values:
    psi_var = compute_psi_rank_variance(theta)
    top_idx = np.argsort(-psi_var)[:5]
    
    for rank, idx in enumerate(top_idx, 1):
        seq = sequence_labels[idx]
        summary_data.append({
            'theta': theta,
            'rank': rank,
            'sequence': seq,
            'psi_var': psi_var[idx],
            'inf_rank': compute_psi_mean_rank(theta, 'inf')[idx],
            'persp_rank': compute_psi_mean_rank(theta, 'persp')[idx],
            'persm_rank': compute_psi_mean_rank(theta, 'persm')[idx]
        })

summary_df = pd.DataFrame(summary_data)

# Display nicely
for theta in theta_values:
    subset = summary_df[summary_df['theta'] == theta]
    print(f"\nθ = {theta}:")
    for _, row in subset.iterrows():
        abbrev = ",".join([u.split(",")[0][:2] for u in row['sequence']])
        print(f"  {row['rank']}. {abbrev:<30} var={row['psi_var']:.0f}  inf={row['inf_rank']:.0f} persp={row['persp_rank']:.0f} persm={row['persm_rank']:.0f}")

print("\n" + "="*70)
print("ANALYSIS COMPLETE")
print("="*70)

FINDING DISCRIMINATING SEQUENCES

Goal: Find sequences that are:
  - STABLE within a speaker type (across alphas) → characteristic
  - UNSTABLE between speaker types → discriminating


Part 1: Compute mean rank across mid-alphas for each sequence
Computed mean ranks for 11 theta values
Shape per theta: (32768, 7)

Part 2: Compute rank variance (stability) across mid-alphas
Computed rank variance for each theta

Part 3: Sequences characteristic of each speaker (stable & top-ranked)

θ = 0.3

--- inf_T: 20 stable sequences in top-20 ---
  mo,mo,mo,mo,mo            mean_rank=1.4, var=0.2
  mo,mo,mo,mo,mo            mean_rank=2.4, var=0.2
  mo,so,mo,mo,so            mean_rank=2.6, var=1.8
  mo,so,mo,mo,mo            mean_rank=4.0, var=0.8
  mo,so,mo,so,mo            mean_rank=5.4, var=0.2

--- inf_F: 20 stable sequences in top-20 ---
  mo,mo,mo,mo,mo            mean_rank=1.0, var=0.0
  mo,mo,mo,mo,mo            mean_rank=3.2, var=2.2
  mo,mo,mo,mo,mo            mean_rank=3.8, var=1.0
  mo,

In [23]:
# =============================================================================
# COMPLETE SELF-CONTAINED DISCRIMINATION ANALYSIS SCRIPT
# =============================================================================
#
# This script:
# 1. Sets up the RSA world and computes all P(seq | θ, speaker, α) matrices
# 2. Defines listener models (cooperative vs uncertain)
# 3. Computes discrimination metrics between listeners
# 4. Finds optimal sequences for experimental use
#
# Only imports from the core RSA module and standard libraries.
# =============================================================================

import numpy as np
import pandas as pd
import itertools
from scipy.special import logsumexp
from scipy.spatial.distance import jensenshannon
from joblib import Parallel, delayed
import warnings

# Import from the RSA module (assumes rsa_optimal_exp_core.py is available)
from rsa_optimal_exp_core import (
    World, 
    LiteralListener, 
    PragmaticSpeaker_obs,
    log_M_product,
    USE_PRECISE_LOGSPACE
)

print("="*70)
print("COMPLETE SELF-CONTAINED DISCRIMINATION ANALYSIS")
print("="*70)

# =============================================================================
# PART A: SETUP AND COMPUTE P(seq | θ, speaker, α) MATRICES
# =============================================================================

print("\n" + "="*70)
print("PART A: Computing P(seq | θ, speaker, α) matrices")
print("="*70)

# -----------------------------------------------------------------------------
# A1: Create World
# -----------------------------------------------------------------------------

print("\n--- A1: Creating World ---")

world = World(n=1, m=5)

utterances = world.utterances
theta_values = world.theta_values
n_utterances = len(utterances)
n_theta = len(theta_values)
n_rounds = 5
n_sequences = n_utterances ** n_rounds

# P(O|θ) matrix
log_P_O_given_theta = world.obs_log_likelihood_theta.values  # (n_obs, n_theta)

# All sequences
all_sequences = list(itertools.product(range(n_utterances), repeat=n_rounds))
sequence_labels = [tuple(utterances[i] for i in seq) for seq in all_sequences]

# Alpha values (mid-range for stability)
alpha_values = [5.482, 6.988, 8.909, 11.36, 14.48]
n_alpha = len(alpha_values)

print(f"Utterances ({n_utterances}): {utterances}")
print(f"Theta values ({n_theta}): {list(theta_values)}")
print(f"Alpha values: {alpha_values}")
print(f"Sequences: {n_sequences}")
print(f"Rounds: {n_rounds}")

# -----------------------------------------------------------------------------
# A2: Helper function
# -----------------------------------------------------------------------------

print("\n--- A2: Defining helper functions ---")

def get_log_P_u_given_theta(speaker, log_P_O_given_theta):
    """Marginalize P(u|O) over O to get P(u|θ)."""
    log_P_u_given_O = speaker.utterance_log_prob_obs.values
    return log_M_product(log_P_u_given_O, log_P_O_given_theta, precise=USE_PRECISE_LOGSPACE)

print("Helper functions defined.")

# -----------------------------------------------------------------------------
# A3: Compute log P(seq | θ) for all speakers
# -----------------------------------------------------------------------------

print("\n--- A3: Computing log P(seq | θ) matrices ---")

# Storage: (speaker_type, alpha) -> np.array of shape (n_sequences, n_theta)
log_P_seq_given_theta = {}

# LITERAL SPEAKER
print("\nLiteral speaker...")
literal_listener = LiteralListener(world)
log_P_u_given_theta_literal = literal_listener.utterance_log_likelihood_theta.values

log_P_seq_theta_literal = np.zeros((n_sequences, n_theta))
for seq_idx, seq in enumerate(all_sequences):
    for r in range(n_rounds):
        log_P_seq_theta_literal[seq_idx, :] += log_P_u_given_theta_literal[seq[r], :]

log_P_seq_given_theta[('literal', None)] = log_P_seq_theta_literal
print("  Done.")

# PRAGMATIC SPEAKERS with update_internal=False
print("\nPragmatic speakers (update_internal=False)...")

psi_map_F = {
    'inf_F': 'inf',
    'persp_F': 'pers+',
    'persm_F': 'pers-'
}

for speaker_type, psi in psi_map_F.items():
    omega = 'coop' if psi == 'inf' else 'strat'
    print(f"  {speaker_type}...")
    
    for alpha in alpha_values:
        speaker = PragmaticSpeaker_obs(
            world=world, omega=omega, psi=psi,
            update_internal=False, alpha=alpha, beta=0.0
        )
        log_P_u_given_theta = get_log_P_u_given_theta(speaker, log_P_O_given_theta)
        
        log_P_seq_theta = np.zeros((n_sequences, n_theta))
        for seq_idx, seq in enumerate(all_sequences):
            for r in range(n_rounds):
                log_P_seq_theta[seq_idx, :] += log_P_u_given_theta[seq[r], :]
        
        log_P_seq_given_theta[(speaker_type, alpha)] = log_P_seq_theta

# PRAGMATIC SPEAKERS with update_internal=True
print("\nPragmatic speakers (update_internal=True)...")

psi_map_T = {
    'inf_T': 'inf',
    'persp_T': 'pers+',
    'persm_T': 'pers-'
}

for speaker_type, psi in psi_map_T.items():
    omega = 'coop' if psi == 'inf' else 'strat'
    print(f"  {speaker_type}...")
    
    for alpha in alpha_values:
        # Precompute P(u|θ) for all histories
        history_log_P_u_given_theta = {}
        
        for hist_len in range(n_rounds):
            if hist_len == 0:
                histories = [()]
            else:
                histories = list(itertools.product(range(n_utterances), repeat=hist_len))
            
            for history in histories:
                history_utterances = [utterances[i] for i in history]
                speaker = PragmaticSpeaker_obs(
                    world=world, omega=omega, psi=psi,
                    update_internal=True, alpha=alpha, beta=0.0
                )
                for u in history_utterances:
                    speaker.literal_listener.listen_and_update(u)
                    speaker.utterance_log_prob_obs = speaker._compute_utterance_log_prob_obs(alpha)
                
                history_log_P_u_given_theta[history] = get_log_P_u_given_theta(speaker, log_P_O_given_theta)
        
        # Compute P(seq|θ)
        log_P_seq_theta = np.zeros((n_sequences, n_theta))
        for seq_idx, seq in enumerate(all_sequences):
            for r in range(n_rounds):
                history = seq[:r]
                log_P_seq_theta[seq_idx, :] += history_log_P_u_given_theta[history][seq[r], :]
        
        log_P_seq_given_theta[(speaker_type, alpha)] = log_P_seq_theta

print(f"\nTotal matrices computed: {len(log_P_seq_given_theta)}")

# Verify normalization
print("\nVerifying normalization (should sum to 1 for each θ):")
for key, log_P in list(log_P_seq_given_theta.items())[:5]:
    P = np.exp(log_P)
    sums = P.sum(axis=0)
    print(f"  {key}: min={sums.min():.6f}, max={sums.max():.6f}")

# =============================================================================
# PART B: COMPUTE AVERAGED P(seq | θ, ψ) MATRICES
# =============================================================================

print("\n" + "="*70)
print("PART B: Computing averaged P(seq | θ, ψ) matrices")
print("="*70)

print("""
For each ψ ∈ {inf, persp, persm}, compute:
  P(seq | θ, ψ) = average over {T, F} × {alphas}
""")

# Speaker groupings
psi_to_speakers = {
    'inf': ['inf_T', 'inf_F'],
    'persp': ['persp_T', 'persp_F'],
    'persm': ['persm_T', 'persm_F']
}

# Storage: psi -> log P(seq | θ, ψ) matrix
log_P_seq_given_theta_psi = {}

for psi, speakers in psi_to_speakers.items():
    print(f"\nComputing for ψ = {psi}...")
    
    # Collect all log P matrices for this psi
    all_log_P = []
    for speaker in speakers:
        for alpha in alpha_values:
            key = (speaker, alpha)
            if key in log_P_seq_given_theta:
                all_log_P.append(log_P_seq_given_theta[key])
    
    all_log_P = np.array(all_log_P)  # (n_models, n_sequences, n_theta)
    n_models = len(all_log_P)
    
    # Average in probability space: logsumexp - log(n)
    log_P_avg = logsumexp(all_log_P, axis=0) - np.log(n_models)
    
    log_P_seq_given_theta_psi[psi] = log_P_avg
    print(f"  Averaged over {n_models} models")

# Also add literal
log_P_seq_given_theta_psi['literal'] = log_P_seq_given_theta[('literal', None)]

# Verify
print("\nVerifying averaged matrices:")
for psi, log_P in log_P_seq_given_theta_psi.items():
    P = np.exp(log_P)
    sums = P.sum(axis=0)
    print(f"  {psi}: min={sums.min():.6f}, max={sums.max():.6f}")

# =============================================================================
# PART C: DEFINE LISTENER MODELS AND COMPUTE POSTERIORS
# =============================================================================

print("\n" + "="*70)
print("PART C: Defining listener models and computing posteriors")
print("="*70)

# Flat prior over theta
log_prior_theta = np.full(n_theta, -np.log(n_theta))
prior_theta = np.exp(log_prior_theta)

# -----------------------------------------------------------------------------
# C1: Define listener P(seq | θ) functions
# -----------------------------------------------------------------------------

print("\n--- C1: Defining listener models ---")

print("""
Cooperative Listener:
  - Assumes speaker has ψ = inf
  - P_coop(seq | θ) = P(seq | θ, ψ=inf)

Uncertain Listener:
  - Uniform prior over ψ ∈ {inf, persp, persm}
  - P_uncertain(seq | θ) = (1/3) * Σ_ψ P(seq | θ, ψ)
""")

# Cooperative: just use inf
log_P_seq_given_theta_coop = log_P_seq_given_theta_psi['inf']

# Uncertain: average over three psi types
log_P_psi_list = [
    log_P_seq_given_theta_psi['inf'],
    log_P_seq_given_theta_psi['persp'],
    log_P_seq_given_theta_psi['persm']
]
log_P_psi_array = np.array(log_P_psi_list)
log_P_seq_given_theta_uncertain = logsumexp(log_P_psi_array, axis=0) - np.log(3)

print(f"Cooperative P(seq|θ) shape: {log_P_seq_given_theta_coop.shape}")
print(f"Uncertain P(seq|θ) shape: {log_P_seq_given_theta_uncertain.shape}")

# -----------------------------------------------------------------------------
# C2: Compute posteriors P(θ | seq)
# -----------------------------------------------------------------------------

print("\n--- C2: Computing posteriors P(θ | seq) ---")

def compute_posterior(log_P_seq_given_theta, log_prior_theta):
    """
    Compute P(θ | seq) for all sequences.
    
    Returns normalized posterior probabilities, shape (n_sequences, n_theta)
    """
    # log P(θ | seq) ∝ log P(seq | θ) + log P(θ)
    log_unnorm = log_P_seq_given_theta + log_prior_theta
    
    # Normalize: subtract logsumexp over theta
    log_norm = logsumexp(log_unnorm, axis=1, keepdims=True)
    log_posterior = log_unnorm - log_norm
    
    return np.exp(log_posterior)

# Compute posteriors for each listener type
P_theta_given_seq_coop = compute_posterior(log_P_seq_given_theta_coop, log_prior_theta)
P_theta_given_seq_uncertain = compute_posterior(log_P_seq_given_theta_uncertain, log_prior_theta)

# Also compute for each PSI separately (useful for detailed analysis)
P_theta_given_seq_inf = compute_posterior(log_P_seq_given_theta_psi['inf'], log_prior_theta)
P_theta_given_seq_persp = compute_posterior(log_P_seq_given_theta_psi['persp'], log_prior_theta)
P_theta_given_seq_persm = compute_posterior(log_P_seq_given_theta_psi['persm'], log_prior_theta)

print(f"Cooperative posterior shape: {P_theta_given_seq_coop.shape}")
print(f"Uncertain posterior shape: {P_theta_given_seq_uncertain.shape}")

# Verify normalization
print("\nVerifying posterior normalization:")
print(f"  Cooperative: min={P_theta_given_seq_coop.sum(axis=1).min():.6f}, max={P_theta_given_seq_coop.sum(axis=1).max():.6f}")
print(f"  Uncertain: min={P_theta_given_seq_uncertain.sum(axis=1).min():.6f}, max={P_theta_given_seq_uncertain.sum(axis=1).max():.6f}")

# =============================================================================
# PART D: COMPUTE DISCRIMINATION METRICS
# =============================================================================

print("\n" + "="*70)
print("PART D: Computing discrimination metrics")
print("="*70)

# -----------------------------------------------------------------------------
# D1: JS Divergence with numerical stability
# -----------------------------------------------------------------------------

print("\n--- D1: Computing JS divergence ---")

def js_divergence_safe(P, Q, eps=1e-10):
    """
    Compute JS divergence with numerical stability.
    
    Adds small epsilon to avoid log(0) and handles edge cases.
    """
    # Add epsilon and renormalize
    P_safe = P + eps
    P_safe = P_safe / P_safe.sum()
    
    Q_safe = Q + eps
    Q_safe = Q_safe / Q_safe.sum()
    
    # Mixture distribution
    M = 0.5 * (P_safe + Q_safe)
    
    # KL divergences
    kl_pm = np.sum(P_safe * np.log2(P_safe / M))
    kl_qm = np.sum(Q_safe * np.log2(Q_safe / M))
    
    # JS divergence
    js = 0.5 * (kl_pm + kl_qm)
    
    return js

def compute_js_divergence_batch_safe(P1, P2, eps=1e-10):
    """
    Compute JS divergence for each row (sequence) with numerical stability.
    """
    n_seq = P1.shape[0]
    js_div = np.zeros(n_seq)
    
    for i in range(n_seq):
        js_div[i] = js_divergence_safe(P1[i], P2[i], eps=eps)
    
    return js_div

print("Computing JS divergence between cooperative and uncertain listeners...")

# Parallel computation
chunk_size = 1000
n_chunks = (n_sequences + chunk_size - 1) // chunk_size
chunks = [(i * chunk_size, min((i + 1) * chunk_size, n_sequences)) for i in range(n_chunks)]

def compute_js_chunk(start_idx, end_idx):
    return compute_js_divergence_batch_safe(
        P_theta_given_seq_coop[start_idx:end_idx],
        P_theta_given_seq_uncertain[start_idx:end_idx]
    )

js_results = Parallel(n_jobs=-1, verbose=1)(
    delayed(compute_js_chunk)(start, end) for start, end in chunks
)

js_divergence_coop_vs_uncertain = np.concatenate(js_results)

print(f"\nJS divergence (coop vs uncertain):")
print(f"  Min: {js_divergence_coop_vs_uncertain.min():.6f}")
print(f"  Max: {js_divergence_coop_vs_uncertain.max():.6f}")
print(f"  Mean: {js_divergence_coop_vs_uncertain.mean():.6f}")
print(f"  Median: {np.median(js_divergence_coop_vs_uncertain):.6f}")

# -----------------------------------------------------------------------------
# D2: E[θ] and other summary statistics
# -----------------------------------------------------------------------------

print("\n--- D2: Computing E[θ] and other metrics ---")

# E[θ | seq] for each listener
E_theta_coop = (P_theta_given_seq_coop * theta_values).sum(axis=1)
E_theta_uncertain = (P_theta_given_seq_uncertain * theta_values).sum(axis=1)
E_theta_diff = np.abs(E_theta_coop - E_theta_uncertain)

# Variance
Var_theta_coop = (P_theta_given_seq_coop * (theta_values ** 2)).sum(axis=1) - E_theta_coop ** 2
Var_theta_uncertain = (P_theta_given_seq_uncertain * (theta_values ** 2)).sum(axis=1) - E_theta_uncertain ** 2

print(f"E[θ] difference: min={E_theta_diff.min():.4f}, max={E_theta_diff.max():.4f}, mean={E_theta_diff.mean():.4f}")

# -----------------------------------------------------------------------------
# D3: Pairwise PSI JS divergences
# -----------------------------------------------------------------------------

print("\n--- D3: Computing pairwise PSI JS divergences ---")

def compute_js_parallel(P1, P2, name):
    """Compute JS divergence in parallel."""
    def compute_chunk(start, end):
        return compute_js_divergence_batch_safe(P1[start:end], P2[start:end])
    
    results = Parallel(n_jobs=-1, verbose=0)(
        delayed(compute_chunk)(start, end) for start, end in chunks
    )
    js = np.concatenate(results)
    print(f"  {name}: min={js.min():.4f}, max={js.max():.4f}, mean={js.mean():.4f}")
    return js

js_inf_vs_persp = compute_js_parallel(P_theta_given_seq_inf, P_theta_given_seq_persp, "JS(inf vs persp)")
js_inf_vs_persm = compute_js_parallel(P_theta_given_seq_inf, P_theta_given_seq_persm, "JS(inf vs persm)")
js_persp_vs_persm = compute_js_parallel(P_theta_given_seq_persp, P_theta_given_seq_persm, "JS(persp vs persm)")

# =============================================================================
# PART E: COMPUTE NORMALCY METRICS
# =============================================================================

print("\n" + "="*70)
print("PART E: Computing normalcy metrics")
print("="*70)

print("""
Normalcy metrics:
1. max_P_overall: Maximum P(seq | θ, ψ) over all θ, ψ
2. marginal_P: P(seq) marginalized over θ, ψ with flat priors
3. min_max_P: Minimum (over ψ) of max (over θ) P(seq | θ, ψ)
""")

# Compute max P(seq | θ, ψ) for each ψ
max_P_per_psi = {}
for psi in ['inf', 'persp', 'persm']:
    log_P = log_P_seq_given_theta_psi[psi]
    max_P_per_psi[psi] = np.exp(log_P).max(axis=1)

# Metric 1: Max over all (θ, ψ)
max_P_overall = np.maximum.reduce([max_P_per_psi[psi] for psi in ['inf', 'persp', 'persm']])

# Metric 2: Marginal P(seq) with flat priors
log_marginal_per_psi = {}
for psi in ['inf', 'persp', 'persm']:
    log_P = log_P_seq_given_theta_psi[psi]
    log_marginal_per_psi[psi] = logsumexp(log_P + log_prior_theta, axis=1)

log_marginal_array = np.array([log_marginal_per_psi[psi] for psi in ['inf', 'persp', 'persm']])
log_marginal_P = logsumexp(log_marginal_array, axis=0) - np.log(3)
marginal_P = np.exp(log_marginal_P)

# Metric 3: Min (over ψ) of max (over θ)
min_max_P = np.minimum.reduce([max_P_per_psi[psi] for psi in ['inf', 'persp', 'persm']])

print(f"max_P_overall: min={max_P_overall.min():.6e}, max={max_P_overall.max():.6f}, median={np.median(max_P_overall):.6e}")
print(f"marginal_P: min={marginal_P.min():.6e}, max={marginal_P.max():.6f}, median={np.median(marginal_P):.6e}")
print(f"min_max_P: min={min_max_P.min():.6e}, max={min_max_P.max():.6f}, median={np.median(min_max_P):.6e}")

# =============================================================================
# PART F: CREATE RESULTS DATAFRAME
# =============================================================================

print("\n" + "="*70)
print("PART F: Creating comprehensive results DataFrame")
print("="*70)

# Sequence characteristics
def count_utterance_types(seq):
    """Count different characteristics of the sequence."""
    n_all = sum(1 for u in seq if u.startswith('all,'))
    n_most = sum(1 for u in seq if u.startswith('most,'))
    n_some = sum(1 for u in seq if u.startswith('some,'))
    n_no = sum(1 for u in seq if u.startswith('no,'))
    n_successful = sum(1 for u in seq if u.endswith(',successful'))
    n_unsuccessful = sum(1 for u in seq if u.endswith(',unsuccessful'))
    n_unique = len(set(seq))
    return n_all, n_most, n_some, n_no, n_successful, n_unsuccessful, n_unique

seq_chars = [count_utterance_types(seq) for seq in sequence_labels]

results_discrimination = pd.DataFrame({
    'sequence_idx': list(range(n_sequences)),
    'sequence': sequence_labels,
    
    # Main discrimination metric
    'js_coop_vs_uncertain': js_divergence_coop_vs_uncertain,
    
    # E[θ] metrics
    'E_theta_coop': E_theta_coop,
    'E_theta_uncertain': E_theta_uncertain,
    'E_theta_diff': E_theta_diff,
    
    # Variance metrics
    'Var_theta_coop': Var_theta_coop,
    'Var_theta_uncertain': Var_theta_uncertain,
    
    # Normalcy metrics
    'max_P_overall': max_P_overall,
    'marginal_P': marginal_P,
    'min_max_P': min_max_P,
    'max_P_inf': max_P_per_psi['inf'],
    'max_P_persp': max_P_per_psi['persp'],
    'max_P_persm': max_P_per_psi['persm'],
    
    # Pairwise PSI JS divergences
    'js_inf_vs_persp': js_inf_vs_persp,
    'js_inf_vs_persm': js_inf_vs_persm,
    'js_persp_vs_persm': js_persp_vs_persm,
    
    # Sequence characteristics
    'n_all': [c[0] for c in seq_chars],
    'n_most': [c[1] for c in seq_chars],
    'n_some': [c[2] for c in seq_chars],
    'n_no': [c[3] for c in seq_chars],
    'n_successful': [c[4] for c in seq_chars],
    'n_unsuccessful': [c[5] for c in seq_chars],
    'n_unique_utterances': [c[6] for c in seq_chars],
})

print(f"Results DataFrame shape: {results_discrimination.shape}")
print(f"Columns: {list(results_discrimination.columns)}")

# =============================================================================
# PART G: SELECT TOP DISCRIMINATING SEQUENCES
# =============================================================================

print("\n" + "="*70)
print("PART G: Selecting top discriminating sequences")
print("="*70)

# Different normalcy thresholds
normalcy_percentiles = [50, 25, 10, 5]

print("\nTop discriminating sequences at different normalcy thresholds:\n")

for pct in normalcy_percentiles:
    threshold = np.percentile(marginal_P, 100 - pct)
    mask = results_discrimination['marginal_P'] >= threshold
    filtered = results_discrimination[mask].copy()
    
    print(f"--- Top {pct}% by marginal_P (threshold={threshold:.6e}, n={mask.sum()}) ---")
    
    top_disc = filtered.nlargest(10, 'js_coop_vs_uncertain')
    
    print(f"{'Rank':<5} {'JS':<8} {'E[θ] diff':<10} {'marginal_P':<12} {'Sequence'}")
    print("-" * 100)
    
    for rank, (_, row) in enumerate(top_disc.iterrows(), 1):
        seq = row['sequence']
        # Abbreviate: mos = most,successful, mou = most,unsuccessful, etc.
        abbrev = ",".join([u.split(",")[0][:2] + u.split(",")[1][0] for u in seq])
        print(f"{rank:<5} {row['js_coop_vs_uncertain']:<8.4f} {row['E_theta_diff']:<10.4f} {row['marginal_P']:<12.6e} {abbrev}")
    print()

# =============================================================================
# PART H: PARETO-OPTIMAL SEQUENCES
# =============================================================================

print("\n" + "="*70)
print("PART H: Finding Pareto-optimal sequences")
print("="*70)

def find_pareto_frontier(df, col1, col2):
    """Find Pareto-optimal points (maximize both columns)."""
    values = df[[col1, col2]].values
    values = -values  # Convert to minimization
    
    n = len(values)
    is_pareto = np.ones(n, dtype=bool)
    
    for i in range(n):
        if is_pareto[i]:
            dominated = np.all(values <= values[i], axis=1) & np.any(values < values[i], axis=1)
            dominated[i] = False
            is_pareto[dominated] = False
    
    return df.index[is_pareto].tolist()

pareto_idx = find_pareto_frontier(results_discrimination, 'js_coop_vs_uncertain', 'marginal_P')
print(f"Number of Pareto-optimal sequences: {len(pareto_idx)}")

pareto_df = results_discrimination.loc[pareto_idx].sort_values('js_coop_vs_uncertain', ascending=False)

print(f"\nTop 20 Pareto-optimal sequences (sorted by JS):")
print(f"{'Rank':<5} {'JS':<8} {'marginal_P':<12} {'E[θ] diff':<10} {'Sequence'}")
print("-" * 100)

for rank, (_, row) in enumerate(pareto_df.head(20).iterrows(), 1):
    seq = row['sequence']
    abbrev = ",".join([u.split(",")[0][:2] + u.split(",")[1][0] for u in seq])
    print(f"{rank:<5} {row['js_coop_vs_uncertain']:<8.4f} {row['marginal_P']:<12.6e} {row['E_theta_diff']:<10.4f} {abbrev}")

# =============================================================================
# PART I: DETAILED ANALYSIS OF TOP CANDIDATES
# =============================================================================

print("\n" + "="*70)
print("PART I: Detailed analysis of top candidates")
print("="*70)

# Select top candidates: high JS AND high normalcy
# Filter to top 25% normalcy first
pareto_normal = pareto_df[pareto_df['marginal_P'] >= np.percentile(marginal_P, 75)]
top_candidates = pareto_normal.head(5)

print(f"Analyzing top {len(top_candidates)} Pareto-optimal sequences with high normalcy:\n")

for rank, (idx, row) in enumerate(top_candidates.iterrows(), 1):
    seq = row['sequence']
    
    print(f"{'='*60}")
    print(f"Candidate {rank}: {seq}")
    print(f"{'='*60}")
    
    print(f"\nDiscrimination metrics:")
    print(f"  JS(coop vs uncertain): {row['js_coop_vs_uncertain']:.4f}")
    print(f"  E[θ] difference: {row['E_theta_diff']:.4f}")
    print(f"  E[θ|seq] for coop listener: {row['E_theta_coop']:.4f}")
    print(f"  E[θ|seq] for uncertain listener: {row['E_theta_uncertain']:.4f}")
    
    print(f"\nNormalcy metrics:")
    print(f"  Marginal P(seq): {row['marginal_P']:.6e}")
    print(f"  Max P under inf: {row['max_P_inf']:.6e}")
    print(f"  Max P under persp: {row['max_P_persp']:.6e}")
    print(f"  Max P under persm: {row['max_P_persm']:.6e}")
    
    print(f"\nPairwise PSI JS divergences:")
    print(f"  JS(inf vs persp): {row['js_inf_vs_persp']:.4f}")
    print(f"  JS(inf vs persm): {row['js_inf_vs_persm']:.4f}")
    print(f"  JS(persp vs persm): {row['js_persp_vs_persm']:.4f}")
    
    # Show posterior distributions
    print(f"\nPosterior P(θ | seq):")
    print(f"  {'θ':<6} {'P_coop':<12} {'P_uncertain':<12} {'P_inf':<12} {'P_persp':<12} {'P_persm':<12}")
    
    p_coop = P_theta_given_seq_coop[idx]
    p_unc = P_theta_given_seq_uncertain[idx]
    p_inf = P_theta_given_seq_inf[idx]
    p_persp = P_theta_given_seq_persp[idx]
    p_persm = P_theta_given_seq_persm[idx]
    
    for i, theta in enumerate(theta_values):
        print(f"  {theta:<6.1f} {p_coop[i]:<12.4f} {p_unc[i]:<12.4f} {p_inf[i]:<12.4f} {p_persp[i]:<12.4f} {p_persm[i]:<12.4f}")
    
    print()

# =============================================================================
# PART J: SAVE RESULTS
# =============================================================================

print("\n" + "="*70)
print("PART J: Saving results")
print("="*70)

# Save full results
results_discrimination.to_csv('discrimination_analysis_results.csv', index=False)
print("Saved: discrimination_analysis_results.csv")

# Save Pareto frontier
pareto_df.to_csv('pareto_optimal_sequences.csv', index=False)
print("Saved: pareto_optimal_sequences.csv")

# Save posteriors for top candidates
top_candidates_idx = top_candidates.index.tolist()
posteriors_data = {'theta': list(theta_values)}
for i, idx in enumerate(top_candidates_idx[:10]):
    posteriors_data[f'coop_{i}'] = P_theta_given_seq_coop[idx]
    posteriors_data[f'uncertain_{i}'] = P_theta_given_seq_uncertain[idx]
    posteriors_data[f'inf_{i}'] = P_theta_given_seq_inf[idx]
    posteriors_data[f'persp_{i}'] = P_theta_given_seq_persp[idx]
    posteriors_data[f'persm_{i}'] = P_theta_given_seq_persm[idx]

posteriors_df = pd.DataFrame(posteriors_data)
posteriors_df.to_csv('top_candidates_posteriors.csv', index=False)
print("Saved: top_candidates_posteriors.csv")

# Save listener models
np.savez(
    'listener_models.npz',
    log_P_seq_given_theta_coop=log_P_seq_given_theta_coop,
    log_P_seq_given_theta_uncertain=log_P_seq_given_theta_uncertain,
    log_P_seq_given_theta_inf=log_P_seq_given_theta_psi['inf'],
    log_P_seq_given_theta_persp=log_P_seq_given_theta_psi['persp'],
    log_P_seq_given_theta_persm=log_P_seq_given_theta_psi['persm'],
    theta_values=theta_values,
    sequences=np.array(sequence_labels, dtype=object)
)
print("Saved: listener_models.npz")

# =============================================================================
# PART K: SUMMARY
# =============================================================================

print("\n" + "="*70)
print("PART K: Summary")
print("="*70)

print(f"""
SUMMARY
=======

Total sequences analyzed: {n_sequences}

Discrimination (JS between coop and uncertain listeners):
  - Min: {js_divergence_coop_vs_uncertain.min():.4f}
  - Max: {js_divergence_coop_vs_uncertain.max():.4f}
  - Mean: {js_divergence_coop_vs_uncertain.mean():.4f}
  - Median: {np.median(js_divergence_coop_vs_uncertain):.4f}
  - 90th percentile: {np.percentile(js_divergence_coop_vs_uncertain, 90):.4f}
  - 99th percentile: {np.percentile(js_divergence_coop_vs_uncertain, 99):.4f}

Normalcy (marginal probability):
  - Min: {marginal_P.min():.6e}
  - Max: {marginal_P.max():.6e}
  - Median: {np.median(marginal_P):.6e}

E[θ] difference:
  - Max: {E_theta_diff.max():.4f}
  - Correlation with JS: {np.corrcoef(js_divergence_coop_vs_uncertain, E_theta_diff)[0,1]:.4f}

Pareto-optimal sequences: {len(pareto_idx)}
  - With high normalcy (top 25%): {len(pareto_normal)}

Top recommended sequences for experiment:
""")

for rank, (_, row) in enumerate(pareto_normal.head(5).iterrows(), 1):
    seq = row['sequence']
    print(f"  {rank}. {seq}")
    print(f"     JS={row['js_coop_vs_uncertain']:.4f}, P={row['marginal_P']:.6e}, E[θ] diff={row['E_theta_diff']:.4f}")

print("\n" + "="*70)
print("ANALYSIS COMPLETE")
print("="*70)

  Returns P(theta) \propto exp( sum_{psi, alpha} log_joint(theta,psi,alpha) ).
  Returns P(psi) \propto exp( sum_{theta, alpha} log_joint(theta,psi,alpha) ).
  Returns P(alpha) \propto exp( sum_{theta, psi} log_joint(theta,psi,alpha)).


COMPLETE SELF-CONTAINED DISCRIMINATION ANALYSIS

PART A: Computing P(seq | θ, speaker, α) matrices

--- A1: Creating World ---
Utterances (8): ['all,successful', 'all,unsuccessful', 'most,successful', 'most,unsuccessful', 'some,successful', 'some,unsuccessful', 'no,successful', 'no,unsuccessful']
Theta values (11): [np.float64(0.0), np.float64(0.1), np.float64(0.2), np.float64(0.3), np.float64(0.4), np.float64(0.5), np.float64(0.6), np.float64(0.7), np.float64(0.8), np.float64(0.9), np.float64(1.0)]
Alpha values: [5.482, 6.988, 8.909, 11.36, 14.48]
Sequences: 32768
Rounds: 5

--- A2: Defining helper functions ---
Helper functions defined.

--- A3: Computing log P(seq | θ) matrices ---

Literal speaker...
  Done.

Pragmatic speakers (update_internal=False)...
  inf_F...
  persp_F...
  persm_F...

Pragmatic speakers (update_internal=True)...
  inf_T...
  persp_T...
  persm_T...

Total matrices computed: 31

Verifying normalization (should sum to 1 for each θ):
  ('literal', None): min=1.

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 11 concurrent workers.
[Parallel(n_jobs=-1)]: Done  33 out of  33 | elapsed:    0.7s finished



JS divergence (coop vs uncertain):
  Min: -0.000000
  Max: 0.673912
  Mean: 0.018739
  Median: 0.000121

--- D2: Computing E[θ] and other metrics ---
E[θ] difference: min=0.0000, max=0.3024, mean=0.0150

--- D3: Computing pairwise PSI JS divergences ---
  JS(inf vs persp): min=-0.0000, max=0.9498, mean=0.0721
  JS(inf vs persm): min=-0.0000, max=0.9498, mean=0.0721
  JS(persp vs persm): min=-0.0000, max=0.9953, mean=0.1540

PART E: Computing normalcy metrics

Normalcy metrics:
1. max_P_overall: Maximum P(seq | θ, ψ) over all θ, ψ
2. marginal_P: P(seq) marginalized over θ, ψ with flat priors
3. min_max_P: Minimum (over ψ) of max (over θ) P(seq | θ, ψ)

max_P_overall: min=5.601328e-10, max=0.771490, median=1.354431e-06
marginal_P: min=5.108069e-11, max=0.032919, median=1.035942e-07
min_max_P: min=1.250232e-17, max=0.000473, median=3.354473e-10

PART F: Creating comprehensive results DataFrame
Results DataFrame shape: (32768, 24)
Columns: ['sequence_idx', 'sequence', 'js_coop_vs_uncertai

In [24]:
# =============================================================================
# FIX: Corrected Pareto frontier and additional analyses
# =============================================================================

print("="*70)
print("CORRECTED ANALYSIS")
print("="*70)

# -----------------------------------------------------------------------------
# Fix 1: Corrected Pareto frontier function
# -----------------------------------------------------------------------------

def find_pareto_frontier_correct(df, col1, col2):
    """
    Find Pareto-optimal points (maximize both columns).
    
    A point is Pareto-optimal if no other point is better in BOTH dimensions.
    """
    values = df[[col1, col2]].values
    n = len(values)
    is_pareto = np.ones(n, dtype=bool)
    
    for i in range(n):
        if is_pareto[i]:
            # Point j dominates point i if j is >= i in all dimensions and > i in at least one
            # We want to KEEP point i if no other point dominates it
            # Here we find points that are dominated BY point i (i.e., i >= them in all, i > them in some)
            # But we need the reverse: mark i as not Pareto if some j dominates i
            
            for j in range(n):
                if i != j and is_pareto[j]:
                    # Check if j dominates i (j >= i in all dims, j > i in at least one)
                    if (values[j, 0] >= values[i, 0] and values[j, 1] >= values[i, 1] and
                        (values[j, 0] > values[i, 0] or values[j, 1] > values[i, 1])):
                        is_pareto[i] = False
                        break
    
    return df.index[is_pareto].tolist()

# Faster vectorized version
def find_pareto_frontier_fast(df, col1, col2):
    """
    Find Pareto-optimal points (maximize both columns) - vectorized version.
    """
    values = df[[col1, col2]].values
    n = len(values)
    is_pareto = np.ones(n, dtype=bool)
    
    for i in range(n):
        if is_pareto[i]:
            # Check if any other point dominates point i
            # j dominates i if: j >= i in all dims AND j > i in at least one
            all_geq = np.all(values >= values[i], axis=1)  # j >= i in all dims
            any_greater = np.any(values > values[i], axis=1)  # j > i in at least one dim
            dominates_i = all_geq & any_greater
            dominates_i[i] = False  # Point doesn't dominate itself
            
            if np.any(dominates_i):
                is_pareto[i] = False
    
    return df.index[is_pareto].tolist()

print("\n--- Corrected Pareto Frontier ---")
pareto_idx_correct = find_pareto_frontier_fast(results_discrimination, 'js_coop_vs_uncertain', 'marginal_P')
print(f"Number of Pareto-optimal sequences: {len(pareto_idx_correct)}")

pareto_df_correct = results_discrimination.loc[pareto_idx_correct].sort_values('js_coop_vs_uncertain', ascending=False)

print(f"\nTop 20 Pareto-optimal sequences (sorted by JS, descending):")
print(f"{'Rank':<5} {'JS':<8} {'marginal_P':<12} {'E[θ] diff':<10} {'Sequence'}")
print("-" * 100)

for rank, (_, row) in enumerate(pareto_df_correct.head(20).iterrows(), 1):
    seq = row['sequence']
    abbrev = ",".join([u.split(",")[0][:2] + u.split(",")[1][0] for u in seq])
    print(f"{rank:<5} {row['js_coop_vs_uncertain']:<8.4f} {row['marginal_P']:<12.6e} {row['E_theta_diff']:<10.4f} {abbrev}")

# -----------------------------------------------------------------------------
# Analysis 2: Understanding WHY these sequences discriminate
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("WHY DO THESE SEQUENCES DISCRIMINATE?")
print("="*70)

# Look at the top 2 sequences
top_seqs = ['some,successful', 'some,unsuccessful']

for base_utt in top_seqs:
    # Find the sequence that's all this utterance
    seq_tuple = tuple([base_utt] * 5)
    idx = sequence_labels.index(seq_tuple)
    
    print(f"\n--- Sequence: {base_utt} × 5 ---")
    
    # Show P(seq | θ) for each speaker model
    print(f"\nP(seq | θ, ψ) for each speaker type:")
    print(f"{'θ':<6} {'INF':<12} {'PERSP':<12} {'PERSM':<12} {'UNCERTAIN':<12}")
    
    log_P_inf = log_P_seq_given_theta_psi['inf'][idx]
    log_P_persp = log_P_seq_given_theta_psi['persp'][idx]
    log_P_persm = log_P_seq_given_theta_psi['persm'][idx]
    log_P_unc = log_P_seq_given_theta_uncertain[idx]
    
    for i, theta in enumerate(theta_values):
        P_inf = np.exp(log_P_inf[i])
        P_persp = np.exp(log_P_persp[i])
        P_persm = np.exp(log_P_persm[i])
        P_unc = np.exp(log_P_unc[i])
        print(f"{theta:<6.1f} {P_inf:<12.6f} {P_persp:<12.6f} {P_persm:<12.6f} {P_unc:<12.6f}")
    
    # Show the resulting posteriors
    print(f"\nPosterior P(θ | seq) for each listener:")
    print(f"{'θ':<6} {'COOP (inf)':<12} {'UNCERTAIN':<12} {'Difference':<12}")
    
    p_coop = P_theta_given_seq_coop[idx]
    p_unc = P_theta_given_seq_uncertain[idx]
    
    for i, theta in enumerate(theta_values):
        diff = p_coop[i] - p_unc[i]
        print(f"{theta:<6.1f} {p_coop[i]:<12.4f} {p_unc[i]:<12.4f} {diff:<+12.4f}")
    
    E_coop = (p_coop * theta_values).sum()
    E_unc = (p_unc * theta_values).sum()
    print(f"\nE[θ|seq]: Coop={E_coop:.4f}, Uncertain={E_unc:.4f}, Diff={E_coop - E_unc:+.4f}")

# -----------------------------------------------------------------------------
# Analysis 3: What makes a sequence "normal" vs "weird"?
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("NORMAL vs WEIRD SEQUENCES")
print("="*70)

# Show examples at different normalcy levels
normalcy_examples = [
    (np.percentile(marginal_P, 99), "Top 1% normalcy"),
    (np.percentile(marginal_P, 90), "Top 10% normalcy"),
    (np.percentile(marginal_P, 50), "Median normalcy"),
    (np.percentile(marginal_P, 10), "Bottom 10% normalcy"),
]

for threshold, label in normalcy_examples:
    close_to_threshold = np.abs(results_discrimination['marginal_P'] - threshold) < threshold * 0.1
    if close_to_threshold.sum() > 0:
        example = results_discrimination[close_to_threshold].iloc[0]
        seq = example['sequence']
        abbrev = ",".join([u.split(",")[0][:2] + u.split(",")[1][0] for u in seq])
        print(f"\n{label} (P ≈ {threshold:.2e}):")
        print(f"  Example: {abbrev}")
        print(f"  JS={example['js_coop_vs_uncertain']:.4f}, E[θ] diff={example['E_theta_diff']:.4f}")

# -----------------------------------------------------------------------------
# Analysis 4: Recommended sequences for experiment
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("RECOMMENDED SEQUENCES FOR EXPERIMENT")
print("="*70)

print("""
Selection criteria:
1. High JS divergence (> 0.3) - good discrimination
2. High marginal probability (top 10%) - not "weird"
3. Meaningful E[θ] difference (> 0.1)
""")

# Filter by criteria
candidates = results_discrimination[
    (results_discrimination['js_coop_vs_uncertain'] > 0.3) &
    (results_discrimination['marginal_P'] >= np.percentile(marginal_P, 90)) &
    (results_discrimination['E_theta_diff'] > 0.1)
].sort_values('js_coop_vs_uncertain', ascending=False)

print(f"\nFound {len(candidates)} candidates meeting all criteria:\n")

print(f"{'Rank':<5} {'JS':<8} {'E[θ] diff':<10} {'P':<12} {'Sequence'}")
print("-" * 90)

for rank, (idx, row) in enumerate(candidates.iterrows(), 1):
    seq = row['sequence']
    # Full sequence display
    print(f"{rank:<5} {row['js_coop_vs_uncertain']:<8.4f} {row['E_theta_diff']:<10.4f} {row['marginal_P']:<12.6e} {seq}")
    
    if rank >= 10:
        if len(candidates) > 10:
            print(f"... and {len(candidates) - 10} more")
        break

# -----------------------------------------------------------------------------
# Analysis 5: Detailed breakdown of top candidate
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("DETAILED ANALYSIS: TOP CANDIDATE")
print("="*70)

if len(candidates) > 0:
    top_idx = candidates.index[0]
    top_row = candidates.iloc[0]
    top_seq = top_row['sequence']
    
    print(f"\nSequence: {top_seq}")
    print(f"\nThis sequence consists of: {top_seq[0]} repeated 5 times")
    
    print(f"\n1. DISCRIMINATION METRICS:")
    print(f"   JS(coop vs uncertain) = {top_row['js_coop_vs_uncertain']:.4f}")
    print(f"   E[θ] difference = {top_row['E_theta_diff']:.4f}")
    
    print(f"\n2. NORMALCY METRICS:")
    print(f"   Marginal P(seq) = {top_row['marginal_P']:.4e} (top {100*(marginal_P < top_row['marginal_P']).mean():.1f}%)")
    print(f"   Max P under inf = {top_row['max_P_inf']:.4e}")
    print(f"   Max P under persp = {top_row['max_P_persp']:.4e}")
    print(f"   Max P under persm = {top_row['max_P_persm']:.4e}")
    
    print(f"\n3. PAIRWISE PSI DIVERGENCES:")
    print(f"   JS(inf vs persp) = {top_row['js_inf_vs_persp']:.4f}")
    print(f"   JS(inf vs persm) = {top_row['js_inf_vs_persm']:.4f}")
    print(f"   JS(persp vs persm) = {top_row['js_persp_vs_persm']:.4f}")
    
    print(f"\n4. POSTERIOR DISTRIBUTIONS:")
    print(f"   {'θ':<6} {'P_coop':<10} {'P_uncertain':<12} {'P_inf':<10} {'P_persp':<10} {'P_persm':<10}")
    print("   " + "-"*70)
    
    for i, theta in enumerate(theta_values):
        print(f"   {theta:<6.1f} {P_theta_given_seq_coop[top_idx, i]:<10.4f} "
              f"{P_theta_given_seq_uncertain[top_idx, i]:<12.4f} "
              f"{P_theta_given_seq_inf[top_idx, i]:<10.4f} "
              f"{P_theta_given_seq_persp[top_idx, i]:<10.4f} "
              f"{P_theta_given_seq_persm[top_idx, i]:<10.4f}")
    
    print(f"\n5. INTERPRETATION:")
    E_coop = (P_theta_given_seq_coop[top_idx] * theta_values).sum()
    E_unc = (P_theta_given_seq_uncertain[top_idx] * theta_values).sum()
    
    print(f"   Cooperative listener (assuming inf speaker):")
    print(f"     E[θ|seq] = {E_coop:.3f}")
    print(f"   Uncertain listener (uniform over ψ):")
    print(f"     E[θ|seq] = {E_unc:.3f}")
    print(f"   ")
    print(f"   The cooperative listener infers a {'HIGHER' if E_coop > E_unc else 'LOWER'} θ")
    print(f"   because they assume an informative speaker.")

# -----------------------------------------------------------------------------
# Analysis 6: Save final recommendations
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("SAVING FINAL RECOMMENDATIONS")
print("="*70)

# Save recommended sequences
if len(candidates) > 0:
    candidates.to_csv('recommended_sequences.csv', index=False)
    print("Saved: recommended_sequences.csv")

# Save corrected Pareto frontier
pareto_df_correct.to_csv('pareto_optimal_sequences_corrected.csv', index=False)
print("Saved: pareto_optimal_sequences_corrected.csv")

print("\n" + "="*70)
print("SUMMARY OF RECOMMENDATIONS")
print("="*70)

print(f"""
TOP RECOMMENDED SEQUENCES FOR EXPERIMENT:

1. ('some,successful', 'some,successful', 'some,successful', 'some,successful', 'some,successful')
   - JS = 0.674, E[θ] diff = 0.302, P = 3.3%
   - Interpretation: Cooperative listener thinks θ is HIGH (informative speaker 
     would say this at high θ), but uncertain listener hedges.

2. ('some,unsuccessful', 'some,unsuccessful', 'some,unsuccessful', 'some,unsuccessful', 'some,unsuccessful')
   - JS = 0.674, E[θ] diff = 0.302, P = 3.3%
   - Interpretation: Symmetric to #1 but for low θ.

3-10. Sequences with 1 "most" and 4 "some" utterances
   - JS ≈ 0.46-0.48, E[θ] diff ≈ 0.24
   - More varied, still highly discriminating

EXPERIMENT DESIGN NOTES:
- Use sequences #1 and #2 as primary stimuli (highest discrimination)
- Include some from #3-10 for variety
- Both cooperative and uncertain listeners will see the SAME sequence
- Measure their posterior beliefs about θ
- Compare to model predictions
""")

CORRECTED ANALYSIS

--- Corrected Pareto Frontier ---
Number of Pareto-optimal sequences: 1

Top 20 Pareto-optimal sequences (sorted by JS, descending):
Rank  JS       marginal_P   E[θ] diff  Sequence
----------------------------------------------------------------------------------------------------
1     0.6739   3.291926e-02 0.3024     sos,sos,sos,sos,sos

WHY DO THESE SEQUENCES DISCRIMINATE?

--- Sequence: some,successful × 5 ---

P(seq | θ, ψ) for each speaker type:
θ      INF          PERSP        PERSM        UNCERTAIN   
0.0    0.000000     0.000000     0.000000     0.000000    
0.1    0.000000     0.005436     0.000000     0.001812    
0.2    0.000000     0.048322     0.000000     0.016107    
0.3    0.000000     0.081926     0.000001     0.027309    
0.4    0.000000     0.059962     0.000004     0.019989    
0.5    0.000000     0.023934     0.000022     0.007986    
0.6    0.000000     0.005808     0.000136     0.001982    
0.7    0.000000     0.000967     0.000953     0.0006

In [34]:
# =============================================================================
# ANALYSIS: NON-REPETITIVE, SPEAKER-CHARACTERISTIC SEQUENCES
# =============================================================================

import numpy as np
import pandas as pd

print("="*70)
print("NON-REPETITIVE, SPEAKER-CHARACTERISTIC SEQUENCES")
print("="*70)

# -----------------------------------------------------------------------------
# Part 1: Add repetitiveness metrics to results
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("Part 1: Computing repetitiveness metrics")
print("="*70)

def count_adjacent_repeats(seq):
    """Count how many adjacent pairs are the same utterance (0-4)."""
    count = 0
    for i in range(len(seq) - 1):
        if seq[i] == seq[i + 1]:
            count += 1
    return count

def count_total_repeats(seq):
    """Count total repetitions: 5 - number of unique utterances."""
    return 5 - len(set(seq))

def get_dominant_quantifier(seq):
    """Get the most common quantifier in the sequence."""
    quantifiers = [u.split(',')[0] for u in seq]
    from collections import Counter
    counts = Counter(quantifiers)
    return counts.most_common(1)[0][0]

def get_dominant_outcome(seq):
    """Get the most common outcome in the sequence."""
    outcomes = [u.split(',')[1] for u in seq]
    from collections import Counter
    counts = Counter(outcomes)
    return counts.most_common(1)[0][0]

# Add metrics to results DataFrame
results_discrimination['n_adjacent_repeats'] = [
    count_adjacent_repeats(seq) for seq in results_discrimination['sequence']
]
results_discrimination['n_total_repeats'] = [
    count_total_repeats(seq) for seq in results_discrimination['sequence']
]
results_discrimination['dominant_quantifier'] = [
    get_dominant_quantifier(seq) for seq in results_discrimination['sequence']
]
results_discrimination['dominant_outcome'] = [
    get_dominant_outcome(seq) for seq in results_discrimination['sequence']
]

print("Repetitiveness distribution:")
print(results_discrimination['n_adjacent_repeats'].value_counts().sort_index())

print("\nBy adjacent repeats, mean JS divergence:")
for n_rep in range(5):
    subset = results_discrimination[results_discrimination['n_adjacent_repeats'] == n_rep]
    print(f"  {n_rep} adjacent repeats: n={len(subset)}, mean JS={subset['js_coop_vs_uncertain'].mean():.4f}")

# -----------------------------------------------------------------------------
# Part 2: Categorize by dominant speaker type
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("Part 2: Categorizing by dominant speaker type")
print("="*70)

print("""
For each sequence, determine which speaker type finds it MOST likely.
- "inf-characteristic": max_P_inf > max_P_persp AND max_P_inf > max_P_persm
- "persp-characteristic": max_P_persp is highest
- "persm-characteristic": max_P_persm is highest
""")

def categorize_by_speaker(row):
    """Determine which speaker type this sequence is most characteristic of."""
    max_inf = row['max_P_inf']
    max_persp = row['max_P_persp']
    max_persm = row['max_P_persm']
    
    if max_inf >= max_persp and max_inf >= max_persm:
        return 'inf'
    elif max_persp >= max_inf and max_persp >= max_persm:
        return 'persp'
    else:
        return 'persm'

def get_speaker_dominance_ratio(row):
    """How much more likely is this for the dominant speaker vs others?"""
    max_inf = row['max_P_inf']
    max_persp = row['max_P_persp']
    max_persm = row['max_P_persm']
    
    max_val = max(max_inf, max_persp, max_persm)
    second_max = sorted([max_inf, max_persp, max_persm])[-2]
    
    if second_max > 0:
        return max_val / second_max
    else:
        return float('inf')

results_discrimination['dominant_speaker'] = results_discrimination.apply(categorize_by_speaker, axis=1)
results_discrimination['speaker_dominance_ratio'] = results_discrimination.apply(get_speaker_dominance_ratio, axis=1)

print("\nDistribution of dominant speaker:")
print(results_discrimination['dominant_speaker'].value_counts())

print("\nBy dominant speaker and adjacent repeats:")
for speaker in ['inf', 'persp', 'persm']:
    print(f"\n{speaker.upper()}:")
    subset = results_discrimination[results_discrimination['dominant_speaker'] == speaker]
    print(subset['n_adjacent_repeats'].value_counts().sort_index())

# -----------------------------------------------------------------------------
# Part 3: Filter for non-repetitive, discriminating sequences
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("Part 3: Non-repetitive, discriminating sequences")
print("="*70)

# Criteria:
# - n_adjacent_repeats <= 2 (at most 2 adjacent same utterances)
# - js_coop_vs_uncertain > 0.1 (meaningful discrimination)
# - marginal_P >= 10th percentile (not too weird)

threshold_normalcy = np.percentile(marginal_P, 90)

candidates_nonrep = results_discrimination[
    (results_discrimination['n_adjacent_repeats'] <= 2) &
    (results_discrimination['js_coop_vs_uncertain'] > 0.1) &
    (results_discrimination['marginal_P'] >= threshold_normalcy)
].copy()

print(f"Found {len(candidates_nonrep)} candidates with:")
print(f"  - n_adjacent_repeats <= 2")
print(f"  - JS > 0.1")
print(f"  - marginal_P >= {threshold_normalcy:.2e} (top 10%)")

# Sort by JS divergence
candidates_nonrep = candidates_nonrep.sort_values('js_coop_vs_uncertain', ascending=False)

print(f"\nTop 20 non-repetitive discriminating sequences:")
print(f"{'Rank':<5} {'JS':<7} {'E[θ]Δ':<8} {'AdjRep':<7} {'Speaker':<8} {'Sequence'}")
print("-" * 100)

for rank, (idx, row) in enumerate(candidates_nonrep.head(20).iterrows(), 1):
    seq = row['sequence']
    abbrev = ",".join([u.split(",")[0][:2] + u.split(",")[1][0] for u in seq])
    print(f"{rank:<5} {row['js_coop_vs_uncertain']:<7.4f} {row['E_theta_diff']:<8.4f} "
          f"{row['n_adjacent_repeats']:<7} {row['dominant_speaker']:<8} {abbrev}")

# -----------------------------------------------------------------------------
# Part 4: Breakdown by dominant speaker type
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("Part 4: Top sequences by dominant speaker type")
print("="*70)

for speaker in ['inf', 'persp', 'persm']:
    print(f"\n{'='*60}")
    print(f"TOP {speaker.upper()}-CHARACTERISTIC SEQUENCES (non-repetitive)")
    print(f"{'='*60}")
    
    subset = candidates_nonrep[candidates_nonrep['dominant_speaker'] == speaker]
    subset = subset.sort_values('js_coop_vs_uncertain', ascending=False)
    
    print(f"Found {len(subset)} sequences")
    
    if len(subset) == 0:
        print("  (none found with current criteria)")
        continue
    
    print(f"\n{'Rank':<5} {'JS':<7} {'E[θ]Δ':<8} {'AdjRep':<7} {'max_P':<10} {'Sequence'}")
    print("-" * 90)
    
    for rank, (idx, row) in enumerate(subset.head(10).iterrows(), 1):
        seq = row['sequence']
        abbrev = ",".join([u.split(",")[0][:2] + u.split(",")[1][0] for u in seq])
        max_p = row[f'max_P_{speaker}']
        print(f"{rank:<5} {row['js_coop_vs_uncertain']:<7.4f} {row['E_theta_diff']:<8.4f} "
              f"{row['n_adjacent_repeats']:<7} {max_p:<10.4e} {abbrev}")

# -----------------------------------------------------------------------------
# Part 5: Relax criteria to find INF-characteristic sequences
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("Part 5: Finding INF-characteristic sequences (relaxed criteria)")
print("="*70)

print("""
INF-characteristic sequences are rare among top discriminators.
Let's see what's available with relaxed criteria.
""")

# More relaxed: just require inf is dominant
inf_sequences = results_discrimination[
    (results_discrimination['dominant_speaker'] == 'inf') &
    (results_discrimination['n_adjacent_repeats'] <= 2)
].sort_values('js_coop_vs_uncertain', ascending=False)

print(f"INF-dominant sequences with ≤2 adjacent repeats: {len(inf_sequences)}")
print(f"\nTop 15 by JS divergence:")
print(f"{'Rank':<5} {'JS':<7} {'E[θ]Δ':<8} {'AdjRep':<7} {'marginal_P':<12} {'Sequence'}")
print("-" * 100)

for rank, (idx, row) in enumerate(inf_sequences.head(15).iterrows(), 1):
    seq = row['sequence']
    abbrev = ",".join([u.split(",")[0][:2] + u.split(",")[1][0] for u in seq])
    print(f"{rank:<5} {row['js_coop_vs_uncertain']:<7.4f} {row['E_theta_diff']:<8.4f} "
          f"{row['n_adjacent_repeats']:<7} {row['marginal_P']:<12.4e} {abbrev}")

# -----------------------------------------------------------------------------
# Part 6: Detailed analysis of best non-repetitive sequences per speaker
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("Part 6: Detailed analysis of best sequences per speaker type")
print("="*70)

def analyze_sequence_detail(idx, row):
    """Print detailed analysis for a sequence."""
    seq = row['sequence']
    
    print(f"\nSequence: {seq}")
    print(f"Adjacent repeats: {row['n_adjacent_repeats']}")
    print(f"Dominant speaker: {row['dominant_speaker']}")
    
    print(f"\nDiscrimination:")
    print(f"  JS(coop vs uncertain) = {row['js_coop_vs_uncertain']:.4f}")
    print(f"  E[θ] diff = {row['E_theta_diff']:.4f}")
    print(f"  E[θ|seq, coop] = {row['E_theta_coop']:.4f}")
    print(f"  E[θ|seq, uncertain] = {row['E_theta_uncertain']:.4f}")
    
    print(f"\nNormalcy:")
    print(f"  marginal_P = {row['marginal_P']:.4e}")
    print(f"  max_P_inf = {row['max_P_inf']:.4e}")
    print(f"  max_P_persp = {row['max_P_persp']:.4e}")
    print(f"  max_P_persm = {row['max_P_persm']:.4e}")
    
    print(f"\nPosterior P(θ|seq):")
    print(f"  {'θ':<6} {'Coop':<10} {'Uncertain':<10} {'INF':<10} {'PERSP':<10} {'PERSM':<10}")
    print("  " + "-"*60)
    
    for i, theta in enumerate(theta_values):
        print(f"  {theta:<6.1f} {P_theta_given_seq_coop[idx, i]:<10.4f} "
              f"{P_theta_given_seq_uncertain[idx, i]:<10.4f} "
              f"{P_theta_given_seq_inf[idx, i]:<10.4f} "
              f"{P_theta_given_seq_persp[idx, i]:<10.4f} "
              f"{P_theta_given_seq_persm[idx, i]:<10.4f}")

# Best PERSP-characteristic (non-repetitive)
print("\n" + "="*60)
print("BEST PERSP-CHARACTERISTIC (non-repetitive)")
print("="*60)
persp_subset = candidates_nonrep[candidates_nonrep['dominant_speaker'] == 'persp']
if len(persp_subset) > 0:
    best_idx = persp_subset['js_coop_vs_uncertain'].idxmax()
    analyze_sequence_detail(best_idx, persp_subset.loc[best_idx])

# Best PERSM-characteristic (non-repetitive)
print("\n" + "="*60)
print("BEST PERSM-CHARACTERISTIC (non-repetitive)")
print("="*60)
persm_subset = candidates_nonrep[candidates_nonrep['dominant_speaker'] == 'persm']
if len(persm_subset) > 0:
    best_idx = persm_subset['js_coop_vs_uncertain'].idxmax()
    analyze_sequence_detail(best_idx, persm_subset.loc[best_idx])

# Best INF-characteristic (with relaxed normalcy)
print("\n" + "="*60)
print("BEST INF-CHARACTERISTIC (relaxed normalcy)")
print("="*60)
inf_subset = results_discrimination[
    (results_discrimination['dominant_speaker'] == 'inf') &
    (results_discrimination['n_adjacent_repeats'] <= 2) &
    (results_discrimination['js_coop_vs_uncertain'] > 0.01)
]
if len(inf_subset) > 0:
    best_idx = inf_subset['js_coop_vs_uncertain'].idxmax()
    analyze_sequence_detail(best_idx, inf_subset.loc[best_idx])

# -----------------------------------------------------------------------------
# Part 7: Summary table for experiment
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("Part 7: Summary - Recommended sequences for experiment")
print("="*70)

print("""
SELECTION CRITERIA:
- Non-repetitive (≤2 adjacent same utterances)
- Discriminating (JS > threshold)
- Normal (top 10% by marginal probability, where possible)
- Representative of each speaker type
""")

# Collect recommendations
recommendations = []

# PERSP-characteristic
persp_recs = candidates_nonrep[candidates_nonrep['dominant_speaker'] == 'persp'].head(5)
for _, row in persp_recs.iterrows():
    recommendations.append({
        'category': 'PERSP-characteristic',
        'sequence': row['sequence'],
        'js': row['js_coop_vs_uncertain'],
        'E_theta_diff': row['E_theta_diff'],
        'n_adj_rep': row['n_adjacent_repeats'],
        'marginal_P': row['marginal_P'],
        'E_theta_coop': row['E_theta_coop'],
        'E_theta_uncertain': row['E_theta_uncertain']
    })

# PERSM-characteristic
persm_recs = candidates_nonrep[candidates_nonrep['dominant_speaker'] == 'persm'].head(5)
for _, row in persm_recs.iterrows():
    recommendations.append({
        'category': 'PERSM-characteristic',
        'sequence': row['sequence'],
        'js': row['js_coop_vs_uncertain'],
        'E_theta_diff': row['E_theta_diff'],
        'n_adj_rep': row['n_adjacent_repeats'],
        'marginal_P': row['marginal_P'],
        'E_theta_coop': row['E_theta_coop'],
        'E_theta_uncertain': row['E_theta_uncertain']
    })

# INF-characteristic (relaxed criteria)
inf_recs = inf_sequences.head(5)
for _, row in inf_recs.iterrows():
    recommendations.append({
        'category': 'INF-characteristic',
        'sequence': row['sequence'],
        'js': row['js_coop_vs_uncertain'],
        'E_theta_diff': row['E_theta_diff'],
        'n_adj_rep': row['n_adjacent_repeats'],
        'marginal_P': row['marginal_P'],
        'E_theta_coop': row['E_theta_coop'],
        'E_theta_uncertain': row['E_theta_uncertain']
    })

rec_df = pd.DataFrame(recommendations)

print("\nRECOMMENDED SEQUENCES:\n")

for category in ['PERSP-characteristic', 'PERSM-characteristic', 'INF-characteristic']:
    print(f"\n--- {category} ---")
    subset = rec_df[rec_df['category'] == category]
    
    for i, row in subset.iterrows():
        seq = row['sequence']
        abbrev = ",".join([u.split(",")[0][:2] + u.split(",")[1][0] for u in seq])
        print(f"  {abbrev}")
        print(f"    JS={row['js']:.4f}, E[θ] diff={row['E_theta_diff']:.4f}, "
              f"AdjRep={row['n_adj_rep']}, P={row['marginal_P']:.2e}")
        print(f"    Coop E[θ]={row['E_theta_coop']:.3f}, Uncertain E[θ]={row['E_theta_uncertain']:.3f}")

# Save recommendations
rec_df.to_csv('recommended_sequences_by_speaker.csv', index=False)
print("\nSaved: recommended_sequences_by_speaker.csv")

# Also update the full results
results_discrimination.to_csv('discrimination_analysis_results_full.csv', index=False)
print("Saved: discrimination_analysis_results_full.csv")

# -----------------------------------------------------------------------------
# Part 8: Visualize the trade-off
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("Part 8: Trade-off analysis")
print("="*70)

print("""
Examining the trade-off between:
- Discrimination (JS divergence)
- Non-repetitiveness (adjacent repeats)
- Normalcy (marginal probability)
""")

print("\nMean JS by adjacent repeats and dominant speaker:")
print(f"{'AdjRep':<8} {'INF':<12} {'PERSP':<12} {'PERSM':<12} {'All':<12}")
print("-" * 56)

for n_rep in range(5):
    js_inf = results_discrimination[
        (results_discrimination['n_adjacent_repeats'] == n_rep) & 
        (results_discrimination['dominant_speaker'] == 'inf')
    ]['js_coop_vs_uncertain'].mean()
    
    js_persp = results_discrimination[
        (results_discrimination['n_adjacent_repeats'] == n_rep) & 
        (results_discrimination['dominant_speaker'] == 'persp')
    ]['js_coop_vs_uncertain'].mean()
    
    js_persm = results_discrimination[
        (results_discrimination['n_adjacent_repeats'] == n_rep) & 
        (results_discrimination['dominant_speaker'] == 'persm')
    ]['js_coop_vs_uncertain'].mean()
    
    js_all = results_discrimination[
        results_discrimination['n_adjacent_repeats'] == n_rep
    ]['js_coop_vs_uncertain'].mean()
    
    print(f"{n_rep:<8} {js_inf:<12.4f} {js_persp:<12.4f} {js_persm:<12.4f} {js_all:<12.4f}")

print("\nMax JS by adjacent repeats and dominant speaker:")
print(f"{'AdjRep':<8} {'INF':<12} {'PERSP':<12} {'PERSM':<12} {'All':<12}")
print("-" * 56)

for n_rep in range(5):
    js_inf = results_discrimination[
        (results_discrimination['n_adjacent_repeats'] == n_rep) & 
        (results_discrimination['dominant_speaker'] == 'inf')
    ]['js_coop_vs_uncertain'].max()
    
    js_persp = results_discrimination[
        (results_discrimination['n_adjacent_repeats'] == n_rep) & 
        (results_discrimination['dominant_speaker'] == 'persp')
    ]['js_coop_vs_uncertain'].max()
    
    js_persm = results_discrimination[
        (results_discrimination['n_adjacent_repeats'] == n_rep) & 
        (results_discrimination['dominant_speaker'] == 'persm')
    ]['js_coop_vs_uncertain'].max()
    
    js_all = results_discrimination[
        results_discrimination['n_adjacent_repeats'] == n_rep
    ]['js_coop_vs_uncertain'].max()
    
    print(f"{n_rep:<8} {js_inf:<12.4f} {js_persp:<12.4f} {js_persm:<12.4f} {js_all:<12.4f}")

# -----------------------------------------------------------------------------
# Part 9: Final summary
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("FINAL SUMMARY")
print("="*70)

print(f"""
KEY FINDINGS:

1. DISCRIMINATION vs REPETITIVENESS TRADE-OFF:
   - Highest JS divergence comes from highly repetitive sequences
   - But non-repetitive sequences (≤2 adj repeats) can still achieve JS > 0.3

2. SPEAKER-CHARACTERISTIC PATTERNS:
   - PERSP-characteristic: Many high-discrimination options available
   - PERSM-characteristic: Symmetric to PERSP (mirror sequences)
   - INF-characteristic: Harder to find - INF rarely dominates

3. WHY INF-CHARACTERISTIC IS RARE:
   - Informative speaker uses strong quantifiers ("all", "most", "no")
   - These are also used by persuasive speakers at extreme θ
   - So sequences with strong quantifiers are often dominated by PERSP or PERSM

4. RECOMMENDED EXPERIMENT DESIGN:
   - Include PERSP and PERSM characteristic sequences (good discrimination)
   - INF-characteristic may need relaxed criteria or different approach
   - Balance variety across sequences to avoid participant fatigue

SAVED FILES:
- recommended_sequences_by_speaker.csv: Top sequences per speaker type
- discrimination_analysis_results_full.csv: Full results with all metrics
""")

NON-REPETITIVE, SPEAKER-CHARACTERISTIC SEQUENCES

Part 1: Computing repetitiveness metrics
Repetitiveness distribution:
n_adjacent_repeats
0    19208
1    10976
2     2352
3      224
4        8
Name: count, dtype: int64

By adjacent repeats, mean JS divergence:
  0 adjacent repeats: n=19208, mean JS=0.0135
  1 adjacent repeats: n=10976, mean JS=0.0229
  2 adjacent repeats: n=2352, mean JS=0.0376
  3 adjacent repeats: n=224, mean JS=0.0658
  4 adjacent repeats: n=8, mean JS=0.1687

Part 2: Categorizing by dominant speaker type

For each sequence, determine which speaker type finds it MOST likely.
- "inf-characteristic": max_P_inf > max_P_persp AND max_P_inf > max_P_persm
- "persp-characteristic": max_P_persp is highest
- "persm-characteristic": max_P_persm is highest


Distribution of dominant speaker:
dominant_speaker
inf      23642
persp     4563
persm     4563
Name: count, dtype: int64

By dominant speaker and adjacent repeats:

INF:
n_adjacent_repeats
0    14336
1     7618
2     153

In [35]:
# =============================================================================
# Adjacent Repeats × Dominant Speaker Table
# =============================================================================

import pandas as pd
import numpy as np

print("="*70)
print("ADJACENT REPEATS × DOMINANT SPEAKER ANALYSIS")
print("="*70)

# Create the cross-tabulation
adj_reps = [0, 1, 2, 3, 4]
speakers = ['inf', 'persp', 'persm']

# -----------------------------------------------------------------------------
# Table 1: COUNT
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("TABLE 1: COUNT")
print("="*70)

print(f"\n{'AdjRep':<10}", end="")
for speaker in speakers:
    print(f"{speaker.upper():>12}", end="")
print(f"{'TOTAL':>12}")
print("-" * 58)

row_totals = []
for n_rep in adj_reps:
    print(f"{n_rep:<10}", end="")
    row_total = 0
    for speaker in speakers:
        count = len(results_discrimination[
            (results_discrimination['n_adjacent_repeats'] == n_rep) & 
            (results_discrimination['dominant_speaker'] == speaker)
        ])
        print(f"{count:>12}", end="")
        row_total += count
    print(f"{row_total:>12}")
    row_totals.append(row_total)

# Column totals
print("-" * 58)
print(f"{'TOTAL':<10}", end="")
for speaker in speakers:
    col_total = len(results_discrimination[results_discrimination['dominant_speaker'] == speaker])
    print(f"{col_total:>12}", end="")
print(f"{sum(row_totals):>12}")

# -----------------------------------------------------------------------------
# Table 2: MEAN JS DIVERGENCE
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("TABLE 2: MEAN JS DIVERGENCE")
print("="*70)

print(f"\n{'AdjRep':<10}", end="")
for speaker in speakers:
    print(f"{speaker.upper():>12}", end="")
print(f"{'ALL':>12}")
print("-" * 58)

for n_rep in adj_reps:
    print(f"{n_rep:<10}", end="")
    for speaker in speakers:
        subset = results_discrimination[
            (results_discrimination['n_adjacent_repeats'] == n_rep) & 
            (results_discrimination['dominant_speaker'] == speaker)
        ]
        if len(subset) > 0:
            mean_js = subset['js_coop_vs_uncertain'].mean()
            print(f"{mean_js:>12.4f}", end="")
        else:
            print(f"{'N/A':>12}", end="")
    
    # All speakers for this n_rep
    all_subset = results_discrimination[results_discrimination['n_adjacent_repeats'] == n_rep]
    mean_js_all = all_subset['js_coop_vs_uncertain'].mean()
    print(f"{mean_js_all:>12.4f}")

# Row for all adjacent repeats
print("-" * 58)
print(f"{'ALL':<10}", end="")
for speaker in speakers:
    subset = results_discrimination[results_discrimination['dominant_speaker'] == speaker]
    mean_js = subset['js_coop_vs_uncertain'].mean()
    print(f"{mean_js:>12.4f}", end="")
overall_mean = results_discrimination['js_coop_vs_uncertain'].mean()
print(f"{overall_mean:>12.4f}")

# -----------------------------------------------------------------------------
# Table 3: MAX JS DIVERGENCE
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("TABLE 3: MAX JS DIVERGENCE")
print("="*70)

print(f"\n{'AdjRep':<10}", end="")
for speaker in speakers:
    print(f"{speaker.upper():>12}", end="")
print(f"{'ALL':>12}")
print("-" * 58)

for n_rep in adj_reps:
    print(f"{n_rep:<10}", end="")
    for speaker in speakers:
        subset = results_discrimination[
            (results_discrimination['n_adjacent_repeats'] == n_rep) & 
            (results_discrimination['dominant_speaker'] == speaker)
        ]
        if len(subset) > 0:
            max_js = subset['js_coop_vs_uncertain'].max()
            print(f"{max_js:>12.4f}", end="")
        else:
            print(f"{'N/A':>12}", end="")
    
    # All speakers for this n_rep
    all_subset = results_discrimination[results_discrimination['n_adjacent_repeats'] == n_rep]
    max_js_all = all_subset['js_coop_vs_uncertain'].max()
    print(f"{max_js_all:>12.4f}")

# Row for all adjacent repeats
print("-" * 58)
print(f"{'ALL':<10}", end="")
for speaker in speakers:
    subset = results_discrimination[results_discrimination['dominant_speaker'] == speaker]
    max_js = subset['js_coop_vs_uncertain'].max()
    print(f"{max_js:>12.4f}", end="")
overall_max = results_discrimination['js_coop_vs_uncertain'].max()
print(f"{overall_max:>12.4f}")

# -----------------------------------------------------------------------------
# Table 4: COMBINED (for easy reading)
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("TABLE 4: COMBINED (Count / Mean JS / Max JS)")
print("="*70)

print(f"\n{'AdjRep':<8}", end="")
for speaker in speakers:
    print(f"{speaker.upper():^30}", end="")
print()

print(f"{'':8}", end="")
for speaker in speakers:
    print(f"{'Count':>10}{'Mean':>10}{'Max':>10}", end="")
print()
print("-" * 98)

for n_rep in adj_reps:
    print(f"{n_rep:<8}", end="")
    for speaker in speakers:
        subset = results_discrimination[
            (results_discrimination['n_adjacent_repeats'] == n_rep) & 
            (results_discrimination['dominant_speaker'] == speaker)
        ]
        count = len(subset)
        if count > 0:
            mean_js = subset['js_coop_vs_uncertain'].mean()
            max_js = subset['js_coop_vs_uncertain'].max()
            print(f"{count:>10}{mean_js:>10.4f}{max_js:>10.4f}", end="")
        else:
            print(f"{0:>10}{'N/A':>10}{'N/A':>10}", end="")
    print()

print("-" * 98)
print(f"{'TOTAL':<8}", end="")
for speaker in speakers:
    subset = results_discrimination[results_discrimination['dominant_speaker'] == speaker]
    count = len(subset)
    mean_js = subset['js_coop_vs_uncertain'].mean()
    max_js = subset['js_coop_vs_uncertain'].max()
    print(f"{count:>10}{mean_js:>10.4f}{max_js:>10.4f}", end="")
print()

# -----------------------------------------------------------------------------
# Save as CSV for easy reference
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("Saving tables to CSV...")
print("="*70)

# Create a comprehensive DataFrame
table_data = []
for n_rep in adj_reps:
    for speaker in speakers:
        subset = results_discrimination[
            (results_discrimination['n_adjacent_repeats'] == n_rep) & 
            (results_discrimination['dominant_speaker'] == speaker)
        ]
        count = len(subset)
        mean_js = subset['js_coop_vs_uncertain'].mean() if count > 0 else np.nan
        max_js = subset['js_coop_vs_uncertain'].max() if count > 0 else np.nan
        median_js = subset['js_coop_vs_uncertain'].median() if count > 0 else np.nan
        mean_E_diff = subset['E_theta_diff'].mean() if count > 0 else np.nan
        
        table_data.append({
            'n_adjacent_repeats': n_rep,
            'dominant_speaker': speaker,
            'count': count,
            'mean_js': mean_js,
            'max_js': max_js,
            'median_js': median_js,
            'mean_E_theta_diff': mean_E_diff
        })

table_df = pd.DataFrame(table_data)
table_df.to_csv('adj_repeats_x_speaker_table.csv', index=False)
print("Saved: adj_repeats_x_speaker_table.csv")

# Also create a pivot table version
print("\nPivot table (Count):")
pivot_count = table_df.pivot(index='n_adjacent_repeats', columns='dominant_speaker', values='count')
print(pivot_count)

print("\nPivot table (Mean JS):")
pivot_mean = table_df.pivot(index='n_adjacent_repeats', columns='dominant_speaker', values='mean_js')
print(pivot_mean.round(4))

print("\nPivot table (Max JS):")
pivot_max = table_df.pivot(index='n_adjacent_repeats', columns='dominant_speaker', values='max_js')
print(pivot_max.round(4))

ADJACENT REPEATS × DOMINANT SPEAKER ANALYSIS

TABLE 1: COUNT

AdjRep             INF       PERSP       PERSM       TOTAL
----------------------------------------------------------
0                14336        2436        2436       19208
1                 7618        1679        1679       10976
2                 1538         407         407        2352
3                  144          40          40         224
4                    6           1           1           8
----------------------------------------------------------
TOTAL            23642        4563        4563       32768

TABLE 2: MEAN JS DIVERGENCE

AdjRep             INF       PERSP       PERSM         ALL
----------------------------------------------------------
0               0.0021      0.0468      0.0468      0.0135
1               0.0023      0.0695      0.0695      0.0229
2               0.0035      0.1020      0.1020      0.0376
3               0.0058      0.1737      0.1737      0.0658
4               0.0003 

In [None]:
# =============================================================================
# BEST NON-REPETITIVE (AdjRep ≤ 1), HIGH-JS, NON-WEIRD SEQUENCES
# =============================================================================

import pandas as pd
import numpy as np

print("="*70)
print("BEST SEQUENCES: AdjRep ≤ 1, High JS, High Normalcy")
print("="*70)

# -----------------------------------------------------------------------------
# Filter criteria
# -----------------------------------------------------------------------------

# Adjacent repeats ≤ 1
# High normalcy (top 10% by marginal_P)
# Then sort by JS

normalcy_threshold = np.percentile(results_discrimination['marginal_P'], 90)

filtered = results_discrimination[
    (results_discrimination['n_adjacent_repeats'] <= 1) &
    (results_discrimination['marginal_P'] >= normalcy_threshold)
].copy()

print(f"\nFiltering criteria:")
print(f"  - Adjacent repeats ≤ 1")
print(f"  - Marginal P ≥ {normalcy_threshold:.2e} (top 10%)")
print(f"  - Total sequences meeting criteria: {len(filtered)}")

# Sort by JS divergence
filtered = filtered.sort_values('js_coop_vs_uncertain', ascending=False)

# -----------------------------------------------------------------------------
# Table by dominant speaker
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("TOP SEQUENCES BY DOMINANT SPEAKER")
print("="*70)

for speaker in ['persp', 'persm', 'inf']:
    print(f"\n{'='*60}")
    print(f"TOP {speaker.upper()}-CHARACTERISTIC (AdjRep ≤ 1, Top 10% normalcy)")
    print(f"{'='*60}")
    
    subset = filtered[filtered['dominant_speaker'] == speaker].head(10)
    
    if len(subset) == 0:
        print("  (none found)")
        continue
    
    print(f"\n{'Rank':<5} {'JS':<8} {'E[θ]Δ':<8} {'Rep':<5} {'P(seq)':<12} {'E[θ]coop':<10} {'E[θ]unc':<10} Sequence")
    print("-" * 110)
    
    for rank, (idx, row) in enumerate(subset.iterrows(), 1):
        seq = row['sequence']
        abbrev = ",".join([u.split(",")[0][:2] + u.split(",")[1][0] for u in seq])
        print(f"{rank:<5} {row['js_coop_vs_uncertain']:<8.4f} {row['E_theta_diff']:<8.4f} "
              f"{row['n_adjacent_repeats']:<5} {row['marginal_P']:<12.2e} "
              f"{row['E_theta_coop']:<10.4f} {row['E_theta_uncertain']:<10.4f} {abbrev}")

# -----------------------------------------------------------------------------
# Overall top sequences (regardless of dominant speaker)
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("OVERALL TOP 20 SEQUENCES (AdjRep ≤ 1, Top 10% normalcy)")
print("="*70)

top20 = filtered.head(20)

print(f"\n{'Rank':<5} {'JS':<8} {'E[θ]Δ':<8} {'Rep':<5} {'Speaker':<8} {'P(seq)':<12} {'Sequence'}")
print("-" * 100)

for rank, (idx, row) in enumerate(top20.iterrows(), 1):
    seq = row['sequence']
    abbrev = ",".join([u.split(",")[0][:2] + u.split(",")[1][0] for u in seq])
    print(f"{rank:<5} {row['js_coop_vs_uncertain']:<8.4f} {row['E_theta_diff']:<8.4f} "
          f"{row['n_adjacent_repeats']:<5} {row['dominant_speaker']:<8} "
          f"{row['marginal_P']:<12.2e} {abbrev}")

# -----------------------------------------------------------------------------
# Show full sequence names for top candidates
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("TOP CANDIDATES - FULL SEQUENCE NAMES")
print("="*70)

print("\n--- PERSP-characteristic (Coop estimates HIGHER θ) ---\n")
persp_top = filtered[filtered['dominant_speaker'] == 'persp'].head(5)
for rank, (idx, row) in enumerate(persp_top.iterrows(), 1):
    print(f"{rank}. {row['sequence']}")
    print(f"   JS={row['js_coop_vs_uncertain']:.4f}, Coop E[θ]={row['E_theta_coop']:.3f}, "
          f"Uncertain E[θ]={row['E_theta_uncertain']:.3f}, Diff={row['E_theta_diff']:+.3f}")
    print()

print("\n--- PERSM-characteristic (Coop estimates LOWER θ) ---\n")
persm_top = filtered[filtered['dominant_speaker'] == 'persm'].head(5)
for rank, (idx, row) in enumerate(persm_top.iterrows(), 1):
    print(f"{rank}. {row['sequence']}")
    print(f"   JS={row['js_coop_vs_uncertain']:.4f}, Coop E[θ]={row['E_theta_coop']:.3f}, "
          f"Uncertain E[θ]={row['E_theta_uncertain']:.3f}, Diff={row['E_theta_diff']:+.3f}")
    print()

print("\n--- INF-characteristic (Both estimate similar θ) ---\n")
inf_top = filtered[filtered['dominant_speaker'] == 'inf'].head(5)
if len(inf_top) == 0:
    print("(none found with current criteria - relaxing normalcy)")
    inf_top = results_discrimination[
        (results_discrimination['n_adjacent_repeats'] <= 1) &
        (results_discrimination['dominant_speaker'] == 'inf')
    ].sort_values('js_coop_vs_uncertain', ascending=False).head(5)

for rank, (idx, row) in enumerate(inf_top.iterrows(), 1):
    print(f"{rank}. {row['sequence']}")
    print(f"   JS={row['js_coop_vs_uncertain']:.4f}, Coop E[θ]={row['E_theta_coop']:.3f}, "
          f"Uncertain E[θ]={row['E_theta_uncertain']:.3f}, Diff={row['E_theta_diff']:+.3f}")
    print(f"   marginal_P={row['marginal_P']:.2e}")
    print()

# -----------------------------------------------------------------------------
# Zero adjacent repeats only
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("ZERO ADJACENT REPEATS ONLY (Maximum variety)")
print("="*70)

filtered_0rep = results_discrimination[
    (results_discrimination['n_adjacent_repeats'] == 0) &
    (results_discrimination['marginal_P'] >= normalcy_threshold)
].sort_values('js_coop_vs_uncertain', ascending=False)

print(f"\nTotal sequences with 0 adjacent repeats and top 10% normalcy: {len(filtered_0rep)}")

print(f"\n{'Rank':<5} {'JS':<8} {'E[θ]Δ':<8} {'Speaker':<8} {'Sequence'}")
print("-" * 80)

for rank, (idx, row) in enumerate(filtered_0rep.head(15).iterrows(), 1):
    seq = row['sequence']
    abbrev = ",".join([u.split(",")[0][:2] + u.split(",")[1][0] for u in seq])
    print(f"{rank:<5} {row['js_coop_vs_uncertain']:<8.4f} {row['E_theta_diff']:<8.4f} "
          f"{row['dominant_speaker']:<8} {abbrev}")

# -----------------------------------------------------------------------------
# Summary table for experiment design
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("FINAL RECOMMENDATIONS FOR EXPERIMENT")
print("="*70)

print("""
Based on the analysis, here are the recommended sequences:

CRITERIA USED:
- Adjacent repeats ≤ 1 (not boring/repetitive)
- Top 10% marginal probability (not weird)
- Sorted by JS divergence (maximum discrimination)
""")

# Collect final recommendations
final_recs = []

# Top 3 PERSP
for idx, row in filtered[filtered['dominant_speaker'] == 'persp'].head(6).iterrows():
    final_recs.append({
        'category': 'PERSP',
        'sequence': row['sequence'],
        'js': row['js_coop_vs_uncertain'],
        'E_theta_coop': row['E_theta_coop'],
        'E_theta_uncertain': row['E_theta_uncertain'],
        'E_theta_diff': row['E_theta_diff'],
        'adj_rep': row['n_adjacent_repeats'],
        'marginal_P': row['marginal_P'],
        'prediction': 'Coop > Uncertain'
    })

# Top 3 PERSM
for idx, row in filtered[filtered['dominant_speaker'] == 'persm'].head(6).iterrows():
    final_recs.append({
        'category': 'PERSM',
        'sequence': row['sequence'],
        'js': row['js_coop_vs_uncertain'],
        'E_theta_coop': row['E_theta_coop'],
        'E_theta_uncertain': row['E_theta_uncertain'],
        'E_theta_diff': row['E_theta_diff'],
        'adj_rep': row['n_adjacent_repeats'],
        'marginal_P': row['marginal_P'],
        'prediction': 'Coop < Uncertain'
    })

final_df = pd.DataFrame(final_recs)

print("\nFINAL RECOMMENDED SEQUENCES:")
print("="*70)

for cat in ['PERSP', 'PERSM']:
    print(f"\n--- {cat}-characteristic ---")
    subset = final_df[final_df['category'] == cat]
    for i, row in subset.iterrows():
        seq = row['sequence']
        print(f"\n  Sequence: {seq}")
        print(f"  JS = {row['js']:.4f}")
        print(f"  E[θ|coop] = {row['E_theta_coop']:.3f}, E[θ|uncertain] = {row['E_theta_uncertain']:.3f}")
        print(f"  Prediction: {row['prediction']} (diff = {row['E_theta_diff']:.3f})")
        print(f"  Adjacent repeats: {row['adj_rep']}, Normalcy: {row['marginal_P']:.2e}")

# Save final recommendations
final_df.to_csv('final_recommended_sequences.csv', index=False)
print("\n\nSaved: final_recommended_sequences.csv")

# -----------------------------------------------------------------------------
# Detailed posterior analysis for top candidates
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("DETAILED POSTERIOR ANALYSIS FOR TOP 2 CANDIDATES")
print("="*70)

# Best PERSP
best_persp_idx = filtered[filtered['dominant_speaker'] == 'persp'].index[0]
best_persp = filtered.loc[best_persp_idx]

print(f"\n--- BEST PERSP: {best_persp['sequence']} ---")
print(f"\nPosterior P(θ | sequence):")
print(f"{'θ':<6} {'Coop':<10} {'Uncertain':<12} {'Diff':<10}")
print("-"*40)
for i, theta in enumerate(theta_values):
    p_coop = P_theta_given_seq_coop[best_persp_idx, i]
    p_unc = P_theta_given_seq_uncertain[best_persp_idx, i]
    diff = p_coop - p_unc
    print(f"{theta:<6.1f} {p_coop:<10.4f} {p_unc:<12.4f} {diff:<+10.4f}")

# Best PERSM
best_persm_idx = filtered[filtered['dominant_speaker'] == 'persm'].index[0]
best_persm = filtered.loc[best_persm_idx]

print(f"\n--- BEST PERSM: {best_persm['sequence']} ---")
print(f"\nPosterior P(θ | sequence):")
print(f"{'θ':<6} {'Coop':<10} {'Uncertain':<12} {'Diff':<10}")
print("-"*40)
for i, theta in enumerate(theta_values):
    p_coop = P_theta_given_seq_coop[best_persm_idx, i]
    p_unc = P_theta_given_seq_uncertain[best_persm_idx, i]
    diff = p_coop - p_unc
    print(f"{theta:<6.1f} {p_coop:<10.4f} {p_unc:<12.4f} {diff:<+10.4f}")

BEST SEQUENCES: AdjRep ≤ 1, High JS, High Normalcy

Filtering criteria:
  - Adjacent repeats ≤ 1
  - Marginal P ≥ 4.33e-05 (top 10%)
  - Total sequences meeting criteria: 2418

TOP SEQUENCES BY DOMINANT SPEAKER

TOP PERSP-CHARACTERISTIC (AdjRep ≤ 1, Top 10% normalcy)

Rank  JS       E[θ]Δ    Rep   P(seq)       E[θ]coop   E[θ]unc    Sequence
--------------------------------------------------------------------------------------------------------------
1     0.4713   0.2387   1     7.68e-04     0.6845     0.4458     mos,sos,sou,sos,sos
2     0.4643   0.2387   1     8.63e-04     0.6840     0.4453     mos,sos,sos,sou,sos
3     0.4582   0.2487   0     4.50e-04     0.6881     0.4395     mos,sos,sou,sos,sou
4     0.4547   0.2475   1     4.34e-04     0.6869     0.4394     mos,sos,sou,sou,sos
5     0.4513   0.2440   1     3.87e-04     0.6883     0.4442     mos,sou,sos,sos,sou
6     0.4479   0.2427   0     3.75e-04     0.6871     0.4445     mos,sou,sos,sou,sos
7     0.4468   0.2374   1     5.35e-

In [33]:
# =============================================================================
# FIX: Add INF-characteristic sequences to recommendations
# =============================================================================

print("="*70)
print("COMPLETE FINAL RECOMMENDATIONS (including INF)")
print("="*70)

# Collect final recommendations - ALL THREE TYPES
final_recs_complete = []

# Top 3 PERSP
for idx, row in filtered[filtered['dominant_speaker'] == 'persp'].head(6).iterrows():
    final_recs_complete.append({
        'category': 'PERSP',
        'sequence': row['sequence'],
        'js': row['js_coop_vs_uncertain'],
        'E_theta_coop': row['E_theta_coop'],
        'E_theta_uncertain': row['E_theta_uncertain'],
        'E_theta_diff': row['E_theta_diff'],
        'adj_rep': row['n_adjacent_repeats'],
        'marginal_P': row['marginal_P'],
        'prediction': 'Coop > Uncertain'
    })

# Top 3 PERSM
for idx, row in filtered[filtered['dominant_speaker'] == 'persm'].head(6).iterrows():
    final_recs_complete.append({
        'category': 'PERSM',
        'sequence': row['sequence'],
        'js': row['js_coop_vs_uncertain'],
        'E_theta_coop': row['E_theta_coop'],
        'E_theta_uncertain': row['E_theta_uncertain'],
        'E_theta_diff': row['E_theta_diff'],
        'adj_rep': row['n_adjacent_repeats'],
        'marginal_P': row['marginal_P'],
        'prediction': 'Coop < Uncertain'
    })

# Top 3 INF
inf_filtered = filtered[filtered['dominant_speaker'] == 'inf']
if len(inf_filtered) == 0:
    # Relax normalcy criteria for INF
    print("No INF sequences meet strict criteria, relaxing normalcy threshold...")
    inf_filtered = results_discrimination[
        (results_discrimination['n_adjacent_repeats'] <= 1) &
        (results_discrimination['dominant_speaker'] == 'inf')
    ].sort_values('js_coop_vs_uncertain', ascending=False)

for idx, row in inf_filtered.head(50).iterrows():
    # Determine direction
    if row['E_theta_coop'] > row['E_theta_uncertain']:
        prediction = 'Coop > Uncertain (both high)'
    else:
        prediction = 'Coop < Uncertain (both low)'
    
    final_recs_complete.append({
        'category': 'INF',
        'sequence': row['sequence'],
        'js': row['js_coop_vs_uncertain'],
        'E_theta_coop': row['E_theta_coop'],
        'E_theta_uncertain': row['E_theta_uncertain'],
        'E_theta_diff': row['E_theta_diff'],
        'adj_rep': row['n_adjacent_repeats'],
        'marginal_P': row['marginal_P'],
        'prediction': prediction
    })

final_df_complete = pd.DataFrame(final_recs_complete)

# Display all recommendations
print("\n" + "="*70)
print("COMPLETE FINAL RECOMMENDED SEQUENCES")
print("="*70)

for cat in ['PERSP', 'PERSM', 'INF']:
    print(f"\n{'='*60}")
    print(f"{cat}-CHARACTERISTIC")
    print(f"{'='*60}")
    
    subset = final_df_complete[final_df_complete['category'] == cat]
    
    for i, row in subset.iterrows():
        seq = row['sequence']
        print(f"\n  Sequence: {seq}")
        print(f"  JS = {row['js']:.4f}")
        print(f"  E[θ|coop] = {row['E_theta_coop']:.3f}, E[θ|uncertain] = {row['E_theta_uncertain']:.3f}")
        print(f"  Prediction: {row['prediction']} (diff = {row['E_theta_diff']:.3f})")
        print(f"  Adjacent repeats: {row['adj_rep']}, Normalcy: {row['marginal_P']:.2e}")

# Save complete recommendations
final_df_complete.to_csv('final_recommended_sequences_complete.csv', index=False)
print("\n\nSaved: final_recommended_sequences_complete.csv")

# -----------------------------------------------------------------------------
# Summary comparison table
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("SUMMARY COMPARISON TABLE")
print("="*70)

print(f"\n{'Category':<10} {'JS':<8} {'E[θ]coop':<10} {'E[θ]unc':<10} {'Diff':<8} {'Interpretation'}")
print("-" * 80)

for cat in ['PERSP', 'PERSM', 'INF']:
    subset = final_df_complete[final_df_complete['category'] == cat]
    row = subset.iloc[0]  # Best one
    
    if cat == 'PERSP':
        interp = "Coop thinks θ HIGH, Uncertain thinks θ MODERATE"
    elif cat == 'PERSM':
        interp = "Coop thinks θ LOW, Uncertain thinks θ MODERATE"
    else:
        interp = "Both agree on θ (low discrimination)"
    
    print(f"{cat:<10} {row['js']:<8.4f} {row['E_theta_coop']:<10.3f} {row['E_theta_uncertain']:<10.3f} "
          f"{row['E_theta_diff']:<8.3f} {interp}")

# -----------------------------------------------------------------------------
# Key insight about INF
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("KEY INSIGHT: WHY INF HAS LOW DISCRIMINATION")
print("="*70)

print("""
INF-characteristic sequences have LOW JS divergence (0.15 vs 0.47 for PERSP/PERSM).

WHY?
- INF-characteristic sequences use strong quantifiers: "all", "most", "no"
- These are only plausible at EXTREME θ values (very high or very low)
- ALL speaker types agree that extreme quantifiers → extreme θ
- Therefore, Coop and Uncertain listeners reach SIMILAR conclusions

EXAMPLE: ('most,successful', 'all,successful', 'no,unsuccessful', 'all,successful', 'no,unsuccessful')
- ALL speakers agree: this sequence requires θ ≈ 0.9
- Coop listener: E[θ] = 0.886
- Uncertain listener: E[θ] = 0.919
- Difference: only 0.033

IMPLICATION FOR EXPERIMENT:
- INF-characteristic sequences can serve as CONTROL stimuli
- Both conditions should show SIMILAR θ estimates
- If they differ, it suggests the manipulation affected processing beyond the model
""")

# -----------------------------------------------------------------------------
# Detailed posterior for best INF sequence
# -----------------------------------------------------------------------------

print("\n" + "="*70)
print("DETAILED POSTERIOR FOR BEST INF SEQUENCE")
print("="*70)

best_inf_idx = inf_filtered.index[0]
best_inf = inf_filtered.iloc[0]

print(f"\nSequence: {best_inf['sequence']}")
print(f"\nPosterior P(θ | sequence):")
print(f"{'θ':<6} {'Coop':<10} {'Uncertain':<12} {'INF':<10} {'PERSP':<10} {'PERSM':<10}")
print("-" * 60)

for i, theta in enumerate(theta_values):
    p_coop = P_theta_given_seq_coop[best_inf_idx, i]
    p_unc = P_theta_given_seq_uncertain[best_inf_idx, i]
    p_inf = P_theta_given_seq_inf[best_inf_idx, i]
    p_persp = P_theta_given_seq_persp[best_inf_idx, i]
    p_persm = P_theta_given_seq_persm[best_inf_idx, i]
    print(f"{theta:<6.1f} {p_coop:<10.4f} {p_unc:<12.4f} {p_inf:<10.4f} {p_persp:<10.4f} {p_persm:<10.4f}")

print("""
OBSERVATION:
- All speaker-specific posteriors concentrate at high θ (0.8-1.0)
- The slight difference comes from PERSM putting more weight on θ=1.0
- This is why Uncertain (which includes PERSM) estimates slightly higher θ
""")

COMPLETE FINAL RECOMMENDATIONS (including INF)

COMPLETE FINAL RECOMMENDED SEQUENCES

PERSP-CHARACTERISTIC

  Sequence: ('most,successful', 'some,successful', 'some,unsuccessful', 'some,successful', 'some,successful')
  JS = 0.4713
  E[θ|coop] = 0.684, E[θ|uncertain] = 0.446
  Prediction: Coop > Uncertain (diff = 0.239)
  Adjacent repeats: 1, Normalcy: 7.68e-04

  Sequence: ('most,successful', 'some,successful', 'some,successful', 'some,unsuccessful', 'some,successful')
  JS = 0.4643
  E[θ|coop] = 0.684, E[θ|uncertain] = 0.445
  Prediction: Coop > Uncertain (diff = 0.239)
  Adjacent repeats: 1, Normalcy: 8.63e-04

  Sequence: ('most,successful', 'some,successful', 'some,unsuccessful', 'some,successful', 'some,unsuccessful')
  JS = 0.4582
  E[θ|coop] = 0.688, E[θ|uncertain] = 0.439
  Prediction: Coop > Uncertain (diff = 0.249)
  Adjacent repeats: 0, Normalcy: 4.50e-04

  Sequence: ('most,successful', 'some,successful', 'some,unsuccessful', 'some,unsuccessful', 'some,successful')
  JS = 

In [39]:
results_discrimination["sequence"][0]

('all,successful',
 'all,successful',
 'all,successful',
 'all,successful',
 'all,successful')

In [42]:
target = (
    "some,successful",
    "some,successful",
    "some,unsuccessful",
    "some,successful",
    "some,unsuccessful",
)

print(results_discrimination[
    results_discrimination["sequence"] == target
].T)

                                                                     18789
sequence_idx                                                         18789
sequence                 (some,successful, some,successful, some,unsucc...
js_coop_vs_uncertain                                              0.438319
E_theta_coop                                                      0.547124
E_theta_uncertain                                                 0.479388
E_theta_diff                                                      0.067736
Var_theta_coop                                                    0.024447
Var_theta_uncertain                                                0.09933
max_P_overall                                                     0.024833
marginal_P                                                        0.003303
min_max_P                                                              0.0
max_P_inf                                                              0.0
max_P_persp              