# Core Agents

In [1]:
import math
import warnings
import itertools
from typing import List, Dict, Union, Optional, Tuple, TypeVar, Iterator, Callable, Any, Literal
import numpy as np
import pandas as pd
from scipy.special import gammaln, logsumexp

# Notebook-wide flag for the log-space precision mode. In the cells visible
# here it is only echoed by the demo printouts; the helper functions take
# their own `precise` argument (default False).
USE_PRECISE_LOGSPACE = False

# Silence NumPy's divide-by-zero and underflow floating-point warnings.
# The helpers below represent log(0) as -inf, so np.log(0) is expected.
np.seterr(divide='ignore', under='ignore')

{'divide': 'warn', 'over': 'warn', 'under': 'ignore', 'invalid': 'warn'}

In [2]:
# Gate for the "CODE_DEMO" cells below: they only define and run their demos when this is truthy.
RUN_DEMO =True

## Core Helper Functions

In [3]:
def log_M_product(LA, LB, precise=False, margin=None):
    """
    Compute the matrix product of two non-negative matrices A and B in log-space.

    The inputs LA, LB are the element-wise logs of A and B (with log(0) = -np.inf).
    The output LC is the element-wise log of the product A @ B:

        LC[i,j] = log( sum_{s=1}^k exp(LA[i,s] + LB[s,j]) )
                = sh + log( sum_{s=1}^k exp(LA[i,s] + LB[s,j] - sh) )

    With precise=False the sum is evaluated by scipy.special.logsumexp.

    With precise=True a more spacious shift is used that still prevents
    overflow but is more robust to underflow:
          sh[i,j] = max_s (LA[i,s] + LB[s,j]) + log(k) - log(M_max) + margin
    where k is the inner dimension, M_max is the largest float64 and margin
    defaults to -log(eps) (about 36 for float64). The shifted exponentials
    are accumulated with math.fsum.

    Degenerate shapes are handled explicitly: if any of the three dimensions
    (m, k, n) is zero, an (m, n) array filled with -np.inf is returned, since
    an empty sum is 0 in linear space.

    Input Validations:
      - Both LA and LB must be numpy ndarrays.
      - Both LA and LB must be 2-dimensional.
      - The number of columns in LA must equal the number of rows in LB.
      - Both LA and LB must have a floating-point data type (to represent log values).
      - Neither LA nor LB should contain positive infinity.

    Parameters:
      LA : np.ndarray
          2D numpy array of shape (m, k) with LA[i,s] = log(a[i,s]), a[i,s] >= 0.
      LB : np.ndarray
          2D numpy array of shape (k, n) with LB[s,j] = log(b[s,j]), b[s,j] >= 0.
      precise : bool
          If True, use the spacious-shift / math.fsum implementation.
          If False (default), use scipy.special.logsumexp.
      margin : float or None
          Extra head-room subtracted from the overflow limit in precise mode.
          None selects -log(eps) for float64. Ignored when precise is False.

    Returns:
      LC : np.ndarray
          float64 array of shape (m, n) with
          LC[i,j] = log(sum_{s=1}^k exp(LA[i,s] + LB[s,j])).

    Raises:
      TypeError: If LA or LB is not a numpy ndarray, or if the arrays are not of a floating-point type.
      ValueError: If LA or LB is not 2-dimensional, if their inner dimensions are incompatible,
                  or if either input contains positive infinity.
    """

    # --- Input Validations ---
    if not (isinstance(LA, np.ndarray) and isinstance(LB, np.ndarray)):
        raise TypeError("Both LA and LB must be numpy ndarrays.")
    if LA.ndim != 2 or LB.ndim != 2:
        raise ValueError("Both LA and LB must be 2-dimensional.")
    if LA.shape[1] != LB.shape[0]:
        raise ValueError("The number of columns in LA must equal the number of rows in LB for multiplication.")
    if not (np.issubdtype(LA.dtype, np.floating) and np.issubdtype(LB.dtype, np.floating)):
        raise TypeError("Both LA and LB must have a floating-point data type representing log values.")
    if np.any(LA == np.inf) or np.any(LB == np.inf):
        raise ValueError("Input matrices must not contain positive infinity.")
    LA = np.asarray(LA, dtype=np.float64)
    LB = np.asarray(LB, dtype=np.float64)

    n_rows, k = LA.shape
    n_cols = LB.shape[1]

    # --- Degenerate shapes ---
    # A reduction over a zero-length axis has no maximum, and apply_along_axis
    # refuses zero iteration dimensions; the mathematical answer is an empty
    # sum (0 in linear space, -inf in log-space), so return it directly.
    if n_rows == 0 or n_cols == 0 or k == 0:
        return np.full((n_rows, n_cols), -np.inf)

    # --- Compute all pairwise products in log-space: shape (m, k, n) ---
    log_products = LA[:, :, None] + LB[None, :, :]

    if not precise:
        LC = logsumexp(log_products, axis=1)

    else:
        # --- Prepare constants ---
        if margin is None:
            margin = -np.log(np.finfo(LA.dtype).eps)   # ≈ 36 for float64
        M_max = np.finfo(LA.dtype).max

        # --- Spacious Log-Sum-Exp trick ---
        max_lp = np.max(log_products, axis=1)   # shape (m,n)
        all_neginf = (max_lp == -np.inf)        # mask for all -inf slices
        shift = max_lp + np.log(k) - np.log(M_max) + margin
        # An all -inf slice would give shift = -inf and (-inf) - (-inf) = NaN below.
        shift = np.where(all_neginf, 0.0, shift)

        shifted_lp = log_products - shift[:, None, :]
        exp_shifted_lp = np.exp(shifted_lp)

        # math.fsum gives a correctly rounded sum along the inner axis.
        sum_exp_shifted_lp = np.apply_along_axis(math.fsum, 1, exp_shifted_lp)  # shape (m,n)

        # Clip only to keep np.log quiet; exact zeros are mapped to -inf right after.
        LC = shift + np.log(np.clip(sum_exp_shifted_lp, np.finfo(float).tiny, None))
        LC = np.where(sum_exp_shifted_lp == 0.0, -np.inf, LC)

    return LC

In [4]:
def log_column_normalize(LX, precise=False):
    """
    Normalize a matrix of log-values so that, in linear space, every column sums to one.

    For each column j the result is
        NLX[i,j] = LX[i,j] - log( sum_s exp(LX[s,j]) )

    With precise=False the log column sums come from scipy.special.logsumexp.
    With precise=True each column is first shifted by
          sh[j] = max_s {LX[s,j]} + log(n_rows) - log(M_max) + margin
    (margin = -log(eps), both taken from LX's dtype), the shifted values are
    exponentiated and summed with math.fsum, and the log of that sum is
    subtracted from the shifted values.

    A column consisting only of -inf triggers a UserWarning and is returned
    unchanged (all -inf).

    Input Validations:
      - LX must be a numpy ndarray.
      - LX must be 2-dimensional.
      - LX must have a floating-point data type (to represent log values).
      - LX must NOT contain positive infinity or NaN.

    Parameters:
    LX (np.array): A 2d-array containing the log(x_i) values.
    precise (bool): Selects the math.fsum based implementation when True,
                   the scipy.special.logsumexp based one when False.

    Returns:
    np.array: A 2d-array containing the log of the column-normalized values.

    Raises:
      TypeError: If LX is not a numpy ndarray or not of a floating-point type.
      ValueError: If LX is not 2-dimensional, or contains positive infinity or NaN.
    """

    # --- Guard clauses on the input ---
    if not isinstance(LX, np.ndarray):
        raise TypeError("LX must be numpy ndarrays.")
    if LX.ndim != 2:
        raise ValueError("LX must be 2-dimensional.")
    if not np.issubdtype(LX.dtype, np.floating):
        raise TypeError("LX must have a floating-point data type representing log values.")
    if (LX == np.inf).any():
        raise ValueError("Input matrices must not contain positive infinity.")
    if np.isnan(LX).any():
        raise ValueError("Input contains NaN")

    if precise:
        col_peak = LX.max(axis=0)  # one entry per column
        if (col_peak == -np.inf).any():
            warnings.warn("Input has column of all -inf", UserWarning)

        float_info = np.finfo(LX.dtype)
        n_rows = LX.shape[0]
        head_room = -np.log(float_info.eps)
        offset = col_peak + np.log(n_rows) - np.log(float_info.max) + head_room

        # Columns made only of -inf get a zero offset so no (-inf) - (-inf) occurs.
        dead_cols = offset == -np.inf
        offset = np.where(dead_cols, 0, offset)

        shifted = LX - offset
        col_totals = np.apply_along_axis(math.fsum, axis=0, arr=np.exp(shifted))
        log_totals = np.log(np.clip(col_totals, np.finfo(float).tiny, None))
        log_totals[dead_cols] = 0

        return shifted - log_totals

    # Fast path: scipy's log-sum-exp per column.
    log_totals = logsumexp(LX, axis=0)

    dead_cols = np.all(LX == -np.inf, axis=0)
    if np.any(dead_cols):
        warnings.warn("Input has column of all -inf", UserWarning)
        # Subtract 0 from such columns so they stay -inf instead of becoming NaN.
        log_totals = np.where(dead_cols, 0, log_totals)

    return LX - log_totals

In [5]:
def log_column_softmax(X, alpha, precise=False):
    """
    Take in a score matrix X of shape (m, n), compute either
      - the log-softmax (alpha is a float) or
      - the log-argmax (alpha == "determ") down each column,
    and return a (m, n) array of log-probabilities.

    For a float alpha, each column j is computed as
        log_prob[i,j] = alpha * X[i,j] - log( sum_l exp(alpha * X[l,j]) )
    With precise=False the log-sum-exp comes from scipy.special.logsumexp.
    With precise=True the scaled scores are shifted by
        sh[j] = max_k (alpha * X[k,j]) + log(m) - log(M_max) + margin
    (margin = -log(eps), both from X's dtype) and summed with math.fsum.

    For alpha == "determ", ties at the max split equally:
        log_prob[i,j] = -log(count_max_j)  if X[i,j] == max_k X[k,j]
                      = -inf               otherwise

    Input Validations:
      - X must be a 2-dimensional numpy.ndarray of floats without NaN or +inf.
        Individual -inf entries are allowed; a column that is entirely -inf
        is rejected.
      - alpha must be a finite positive float or the string "determ".

    Parameters:
      X : np.ndarray
          Score matrix of shape (m, n).
      alpha : float or "determ"
          Softmax scaling factor, or "determ" for the hard argmax.
      precise : bool
          Selects the math.fsum based softmax when True; ignored for "determ".

    Returns:
      log_prob : np.ndarray of shape (m, n), the log-probabilities

    Raises:
      TypeError: If X is not a floating-point ndarray, or alpha is neither a
                 float nor the string "determ".
      ValueError: If X is not 2-dimensional, contains NaN or +inf, has a
                  column that is all -inf, or alpha is not finite and positive.
    """

    # Validate X
    if not isinstance(X, np.ndarray):
        raise TypeError("X must be a numpy.ndarray")
    if X.ndim != 2:
        raise ValueError(f"X must be 2-dimensional (got ndim={X.ndim})")
    if not np.issubdtype(X.dtype, np.floating):
        raise TypeError("X must have a floating-point dtype")
    if np.isnan(X).any():
        raise ValueError("X must not contain NaN values")
    if np.any(X == np.inf):
        raise ValueError("X must not contain positive infinity.")

    # Validate alpha
    if not (isinstance(alpha, float) or (isinstance(alpha, str) and alpha == "determ")):
        raise TypeError("alpha must be a float or the string 'determ'")
    if isinstance(alpha, float):
        if alpha <= 0:
            raise ValueError("alpha must be positive when using softmax")
        # NaN slips through `alpha <= 0` and +inf turns alpha * X into inf/NaN,
        # both of which would silently yield an all-NaN result.
        if not math.isfinite(alpha):
            raise ValueError("alpha must be finite when using softmax")

    # Branch on alpha
    if alpha == "determ":
        # Hard argmax with equal splitting of ties
        col_max = np.max(X, axis=0)
        if np.any(col_max == -np.inf):  # Reject columns that are all -inf
            raise ValueError("Cannot compute softmax: some columns are all -inf")
        mask = X == col_max
        counts = mask.sum(axis=0)
        log_prob = np.where(mask, -np.log(counts)[np.newaxis, :], -np.inf)

    else:

        if not precise:
            # Softmax branch - simplified with scipy.special.logsumexp
            if np.any(np.max(X, axis=0) == -np.inf):  # Reject columns that are all -inf
                raise ValueError("Cannot compute softmax: some columns are all -inf")
            scaled = alpha * X
            log_prob = scaled - logsumexp(scaled, axis=0)

        else:
            # Softmax branch with the spacious shift
            m, n = X.shape
            scaled = alpha * X
            col_max = np.max(scaled, axis=0)
            if np.any(col_max == -np.inf):  # Reject columns that are all -inf
                raise ValueError("Cannot compute softmax: some columns are all -inf")
            margin = -np.log(np.finfo(X.dtype).eps)
            M_max = np.finfo(X.dtype).max
            shift = col_max + np.log(m) - np.log(M_max) + margin

            # Shift, exponentiate, and sum each column with math.fsum
            scaled_shifted = scaled - shift
            exp_scaled_shifted = np.exp(scaled_shifted)
            sum_exp_scaled_shifted = np.apply_along_axis(math.fsum, 0, exp_scaled_shifted)

            # Compute log-sum-exp and normalize
            log_sum_exp_scaled_shifted = np.log(sum_exp_scaled_shifted)
            log_prob = scaled_shifted - log_sum_exp_scaled_shifted

    return log_prob

In [6]:
## CODE_DEMO ##
## matrix product
if RUN_DEMO:
    def demo_log_M_product():
        """Print a small matrix product computed directly and through log_M_product in each mode."""
        left = np.array([
            [0.1, 0.6],
            [0.2, 0.2],
            [0.7, 0.2]])

        right = np.array([
            [0.7, 0.6, 0.3, 0.4],
            [0.3, 0.4, 0.7, 0.6]])

        log_left, log_right = np.log(left), np.log(right)

        print("In linear space:")
        print(left @ right)

        # (header pieces, keyword arguments for log_M_product)
        runs = (
            (("In log-space (default):", USE_PRECISE_LOGSPACE), {}),
            (("In log-space (precise):",), {"precise": True}),
            (("In log-space (inprecise):",), {"precise": False}),
        )
        for header, options in runs:
            print()
            print(*header)
            print(np.exp(log_M_product(log_left, log_right, **options)))


    demo_log_M_product()

In linear space:
[[0.25 0.3  0.45 0.4 ]
 [0.2  0.2  0.2  0.2 ]
 [0.55 0.5  0.35 0.4 ]]

In log-space (default): False
[[0.25 0.3  0.45 0.4 ]
 [0.2  0.2  0.2  0.2 ]
 [0.55 0.5  0.35 0.4 ]]

In log-space (precise):
[[0.25 0.3  0.45 0.4 ]
 [0.2  0.2  0.2  0.2 ]
 [0.55 0.5  0.35 0.4 ]]

In log-space (inprecise):
[[0.25 0.3  0.45 0.4 ]
 [0.2  0.2  0.2  0.2 ]
 [0.55 0.5  0.35 0.4 ]]


In [7]:
## CODE_DEMO ##
# Column normalization
if RUN_DEMO:
    def demo_log_column_normalize():
        """Print a small count matrix and its column-normalized form from log_column_normalize in each mode."""
        counts = np.array(
            [[0, 1, 2],
            [100, 2, 2],
            [0, 3, 2],
            [0, 4, 0]])

        log_counts = np.log(counts)
        print("Input")
        print(counts)

        # (header pieces, keyword arguments for log_column_normalize)
        runs = (
            (("In log-space (default):", USE_PRECISE_LOGSPACE), {}),
            (("In log-space (precise):",), {"precise": True}),
            (("In log-space (inprecise):",), {"precise": False}),
        )
        for header, options in runs:
            print()
            print(*header)
            print(np.exp(log_column_normalize(log_counts, **options)))

    demo_log_column_normalize()

Input
[[  0   1   2]
 [100   2   2]
 [  0   3   2]
 [  0   4   0]]

In log-space (default): False
[[0.         0.1        0.33333333]
 [1.         0.2        0.33333333]
 [0.         0.3        0.33333333]
 [0.         0.4        0.        ]]

In log-space (precise):
[[0.         0.1        0.33333333]
 [1.         0.2        0.33333333]
 [0.         0.3        0.33333333]
 [0.         0.4        0.        ]]

In log-space (inprecise):
[[0.         0.1        0.33333333]
 [1.         0.2        0.33333333]
 [0.         0.3        0.33333333]
 [0.         0.4        0.        ]]


In [8]:
## CODE_DEMO ##
# Column Softmax
if RUN_DEMO:
    def demo_log_column_softmax():
        """Print scipy's column softmax next to log_column_softmax (alpha=10.0) in each mode."""
        scores = np.array([
            [0, 0.2, 20.],
            [1, 0.2, 20.],
            [0, 0.1, 10.]
        ])

        from scipy.special import softmax

        print("Input")
        print(scores)
        print()

        print("In linear space:")
        print(softmax(10 * scores, axis=0))

        # (header pieces, extra keyword arguments for log_column_softmax)
        runs = (
            (("In log-space (default): precise = ", USE_PRECISE_LOGSPACE), {}),
            (("In log-space (precise):",), {"precise": True}),
            (("In log-space (inprecise):",), {"precise": False}),
        )
        for header, options in runs:
            print()
            print(*header)
            print(np.exp(log_column_softmax(X=scores, alpha=10.0, **options)))

    demo_log_column_softmax()

Input
[[ 0.   0.2 20. ]
 [ 1.   0.2 20. ]
 [ 0.   0.1 10. ]]

In linear space:
[[4.53958078e-05 4.22318798e-01 5.00000000e-01]
 [9.99909208e-01 4.22318798e-01 5.00000000e-01]
 [4.53958078e-05 1.55362403e-01 1.86003799e-44]]

In log-space (default): precise =  False
[[4.53958078e-05 4.22318798e-01 5.00000000e-01]
 [9.99909208e-01 4.22318798e-01 5.00000000e-01]
 [4.53958078e-05 1.55362403e-01 1.86003799e-44]]

In log-space (precise):
[[4.53958078e-05 4.22318798e-01 5.00000000e-01]
 [9.99909208e-01 4.22318798e-01 5.00000000e-01]
 [4.53958078e-05 1.55362403e-01 1.86003799e-44]]

In log-space (inprecise):
[[4.53958078e-05 4.22318798e-01 5.00000000e-01]
 [9.99909208e-01 4.22318798e-01 5.00000000e-01]
 [4.53958078e-05 1.55362403e-01 1.86003799e-44]]


## World

In [None]:
class World:
    """
    A class representing the world state in the pragmatic communication game.

    This class encapsulates the complete state space of possible observations,
    their likelihoods under different theta values, and semantic truth values of utterances.
    """

    # Class constants
    # Each operator takes (x, N) and returns 1 or 0:
    #   "all"  -> x equals N
    #   "most" -> x is strictly more than half of N
    #   "some" -> x is at least 1
    #   "no"   -> x is 0
    SEMANTIC_OPERATORS = {
        "all": lambda x, N: int(x == N),
        "most": lambda x, N: int(x > N / 2),
        "some": lambda x, N: int(x >= 1),
        "no": lambda x, N: int(x == 0)
    }
    # Quantifier and predicate vocabularies combined into utterances.
    QUANTIFIERS = ["all", "most", "some", "no"]
    PREDICATES = ["successful", "unsuccessful"]
    # Default theta grid: 0.0, 0.1, ..., 1.0 (11 points rounded to one decimal).
    DEFAULT_THETA_VALUES: np.ndarray = np.round(np.linspace(0, 1, 11), 1)

    def __init__(
        self,
        n: int,
        m: int,
        theta_values: Optional[np.ndarray] = None
    ) -> None:
        """Initialize the world with given parameters and compute all necessary tables."""
        # Validate n and m parameters
        if not isinstance(n, int) or not isinstance(m, int):
            raise ValueError("n and m must be integers")
        if n < 1 or m < 1:
            raise ValueError("n and m must be positive")

        self.n = n
        self.m = m
        self.complex = n > 1

        # Validate and process theta values
        self.theta_values = self._validate_theta_values(theta_values)

        try:
            # Generate possible outcomes
            self.possible_outcomes = self._generate_possible_outcomes(self.n, self.m)

            # Compute success likelihoods
            self.suc_log_likelihood_theta = self._compute_successes_log_likelihoods(
                self.n, self.m, self.theta_values
            )

            # Compute observation likelihoods
            self.obs_log_likelihood_theta = self._compute_observation_log_likelihoods(
                self.n, self.m, self.theta_values, self.possible_outcomes
            )

            # Compute utterance truth values
            self.utterance_truth = self._compute_utterance_truth_values(
                self.n, self.m, self.possible_outcomes,
                self.QUANTIFIERS, self.PREDICATES, self.SEMANTIC_OPERATORS
            )
        except Exception as e:
            raise RuntimeError(f"Failed to initialize world state: {str(e)}")

    def _validate_theta_values(
        self,
        theta_values: Optional[np.ndarray]
    ) -> np.ndarray:
        """
        Validate and process theta values.

        Parameters
        ----------
        theta_values : Optional[np.ndarray]
            Array of possible theta values between 0 and 1. If None, uses DEFAULT_THETA_VALUES.

        Returns
        -------
        np.ndarray
            Validated array of theta values.
        """
        if theta_values is None:
            return self.DEFAULT_THETA_VALUES

        if not isinstance(theta_values, np.ndarray):
            raise ValueError("theta_values must be a numpy array")
        if not np.all((theta_values >= 0) & (theta_values <= 1)):
            raise ValueError("All theta values must be between 0 and 1")
        if not np.array_equal(theta_values, np.unique(theta_values)):
            raise ValueError("theta values must be arranged and not duplicating")
        if not np.array_equal(theta_values, np.round(theta_values, decimals=10)):
            warnings.warn("theta values above precision is rounded to 10 decimals",
                         UserWarning)
            if not np.array_equal(np.round(theta_values, decimals=10),
                                 np.unique(np.round(theta_values, decimals=10))):
                warnings.warn("Rounded theta values are duplicating, they will be collapsed",
                             UserWarning)

        return np.unique(np.round(theta_values, decimals=10))

    def _generate_possible_outcomes(self, n: int, m: int) -> List[Tuple[int, ...]]:
        """
        Generate all possible outcomes as frequency tuples.

        Parameters
        ----------
        n : int
            Number of independent Binomial experiments.
        m : int
            Number of Bernoulli trials per experiment.

        Returns
        -------
        List[Tuple[int, ...]]
            List of all possible frequency tuples.
        """
        try:
            # Convert the generator to a list and return
            return list(self._generate_outcome_tuples(n, m))
        except Exception as e:
            raise RuntimeError(f"Failed to generate possible outcomes: {str(e)}")

    def _generate_outcome_tuples(self, n: int, m: int) -> Iterator[Tuple[int, ...]]:
        """
        Generate all tuples (n_0, n_1, ..., n_m) of nonnegative integers
        such that sum(n_i) = n.
        Uses the stars and bars method.

        Parameters
        ----------
        n : int
            Number of independent Binomial experiments.
        m : int
            Number of Bernoulli trials per experiment.

        Yields
        ------
        Tuple[int, ...]
            Each possible outcome frequency tuple.
        """
        for dividers in itertools.combinations(range(n + m), m):
            counts = []
            prev = -1
            for d in dividers:
                counts.append(d - prev - 1)
                prev = d
            counts.append(n + m - prev - 1)
            yield tuple(counts)

    def _compute_successes_log_likelihoods(
        self,
        n: int,
        m: int,
        theta_values: np.ndarray
    ) -> pd.DataFrame:
        """
        Compute log P(S=s | theta) for S ~ Binomial(N, theta), where N = n*m.

        Parameters
        ----------
        n : int
            Number of independent Binomial experiments.
        m : int
            Number of Bernoulli trials per experiment.
        theta_values : np.ndarray
            Array of possible theta values.

        Returns
        -------
        pd.DataFrame
            DataFrame with rows indexed by s (total successes),
            columns by theta values, and values are log-probabilities.
        """
        try:
            N = n * m
            thetas = theta_values

            # Precompute log binomial coefficient ln[ C(N, s) ] for s=0..N
            s = np.arange(N+1)
            log_binom = (gammaln(N+1) - gammaln(s+1) - gammaln(N-s+1))

            # Prepare an array of shape (N+1, len(thetas))
            log_probs = np.empty((N+1, thetas.size))

            # Handle interior thetas (0 < theta < 1)
            mask = (thetas > 0) & (thetas < 1)
            theta_int = thetas[mask]
            if mask.any():
                log_theta = np.log(theta_int)[None, :]   # shape (1, k)
                log_one_minus = np.log(1 - theta_int)[None, :]
                # broadcast: log_binom[:,None] + s[:,None]*log_theta + (N-s)[:,None]*log_one_minus
                log_probs[:, mask] = (log_binom[:, None] + s[:, None] * log_theta +
                                      (N - s)[:, None] * log_one_minus)

            # theta = 0: only s=0 has log‐prob 0, others −inf
            mask_zero = thetas == 0
            log_probs[:, mask_zero] = -np.inf
            log_probs[0, mask_zero] = 0.0

            # theta = 1: only s=N has log‐prob 0, others −inf
            mask_one = thetas == 1
            log_probs[:, mask_one] = -np.inf
            log_probs[N, mask_one] = 0.0

            # Build DataFrame: rows indexed by s, columns by theta
            df = pd.DataFrame(log_probs,
                             index=range(N+1),
                             columns=thetas)
            return df
        except Exception as e:
            raise RuntimeError(f"Failed to compute success likelihoods: {str(e)}")

    def _compute_observation_log_likelihoods(
        self,
        n: int,
        m: int,
        theta_values: np.ndarray,
        possible_outcomes: List[Tuple[int, ...]]
    ) -> pd.DataFrame:
        """
        Compute a table of log-probabilities for each frequency tuple under each theta value.

        Parameters
        ----------
        n : int
            Number of independent Binomial experiments.
        m : int
            Number of Bernoulli trials per experiment.
        theta_values : np.ndarray
            Array of possible theta values.
        possible_outcomes : List[Tuple[int, ...]]
            List of all possible frequency tuples.

        Returns
        -------
        pd.DataFrame
            DataFrame with rows as frequency tuples and columns as theta values,
            containing log-probabilities.
        """
        try:
            mask_interior = (theta_values > 0) & (theta_values < 1)
            theta_interior = theta_values[mask_interior]

            j_vals = np.arange(m + 1)  # 0, 1, ..., m
            log_binom = gammaln(m + 1) - gammaln(j_vals + 1) - gammaln(m - j_vals + 1)
            base_const = gammaln(n + 1)

            results = []      # list to hold log-probability vectors (one per frequency tuple)
            index_labels = [] # frequency tuples

            for counts in possible_outcomes:
                counts_arr = np.array(counts)
                full_log_prob = np.empty(theta_values.size, dtype=float)

                # For theta = 0 and theta = 1, assign manually:
                for idx, theta in enumerate(theta_values):
                    if theta == 0:
                        full_log_prob[idx] = 0.0 if counts[0] == n else -np.inf
                    elif theta == 1:
                        full_log_prob[idx] = 0.0 if counts[m] == n else -np.inf

                if np.any(mask_interior):
                    interior_log_theta = np.log(theta_interior)
                    interior_log_one_minus_theta = np.log(1 - theta_interior)
                    base = base_const - np.sum(gammaln(counts_arr + 1))
                    terms = (log_binom[:, None] +
                            j_vals[:, None] * interior_log_theta +
                            (m - j_vals)[:, None] * interior_log_one_minus_theta)
                    log_term = np.sum(counts_arr[:, None] * terms, axis=0)
                    interior_result = base + log_term
                    full_log_prob[mask_interior] = interior_result

                results.append(full_log_prob)
                index_labels.append(counts)

            results_array = np.array(results)
            df = pd.DataFrame(results_array, index=index_labels, columns=theta_values)
            return df
        except Exception as e:
            raise RuntimeError(f"Failed to compute observation likelihoods: {str(e)}")

    def _compute_utterance_truth_values(
        self,
        n: int,
        m: int,
        counts_list: List[Tuple[int, ...]],
        quantifiers: List[str],
        predicates: List[str],
        semantic_operators: Dict[str, Callable]
    ) -> pd.DataFrame:
        """
        Generate a truth table where each row is a possible utterance (as a string)
        and each column is a possible frequency tuple (as a tuple of ints).

        Parameters
        ----------
        n : int
            Number of independent Binomial experiments.
        m : int
            Number of Bernoulli trials per experiment.
        counts_list : List[Tuple[int, ...]]
            List of all possible frequency tuples.
        quantifiers : List[str]
            List of quantifiers ("all", "most", "some", "no").
        predicates : List[str]
            List of predicates ("successful", "unsuccessful").
        semantic_operators : Dict[str, Callable]
            Dictionary mapping quantifiers to their semantic functions.

        Returns
        -------
        pd.DataFrame
            DataFrame with rows as utterance strings, columns as frequency tuples,
            and values as truth values (1 or 0).
        """
        try:
            # Inner helper to list all utterances
            def _generate_utterance(n: int, quantifiers: List[str], predicates: List[str]) -> List[Tuple[str, ...]]:
                if n > 1:
                    return list(itertools.product(quantifiers, quantifiers, predicates))
                else:
                    return list(itertools.product(quantifiers, predicates))

            utterances = _generate_utterance(n, quantifiers, predicates)
            counts_array = np.array(counts_list)             # shape (num_outcomes, n)
            truth_dict = {}

            if n == 1:
                # Single experiment: utterance = (quantifier, predicate)
                for utter in utterances:
                    q, p = utter
                    if p == "successful":
                        vec = np.array([semantic_operators[q](j, m)
                                      for j in range(m + 1)])
                    else:  # p == "unsuccessful"
                        vec = np.array([semantic_operators[q](m - j, m)
                                      for j in range(m + 1)])
                    truth_vals = counts_array.dot(vec)
                    utter_str = ",".join(utter)
                    truth_dict[utter_str] = truth_vals
            else:
                # Multiple experiments: utterance = (quantifier1, quantifier2, predicate)
                for utter in utterances:
                    q1, q2, p = utter
                    if p == "successful":
                        vec = np.array([semantic_operators[q2](j, m)
                                      for j in range(m + 1)])
                    else:  # p == "unsuccessful"
                        vec = np.array([semantic_operators[q2](m - j, m)
                                      for j in range(m + 1)])
                    inner_sum = counts_array.dot(vec)
                    truth_func = np.vectorize(lambda x: semantic_operators[q1](x, n))
                    truth_vals = truth_func(inner_sum)
                    utter_str = ",".join(utter)
                    truth_dict[utter_str] = truth_vals

            # Keep the actual tuples as column labels
            freq_labels = counts_list
            df = pd.DataFrame(
                data = np.array(list(truth_dict.values())).T,
                index = freq_labels,
                columns = list(truth_dict.keys())
            )
            # Transpose so rows are utterances, columns are outcome‑tuples
            df = df.T

            uncovered = [obs for obs in df.columns if df[obs].sum() == 0]
            if uncovered:
                raise ValueError(
                    f"No utterance covers the following observations: {uncovered}"
                )

            return df

        except Exception as e:
            raise RuntimeError(f"Failed to compute utterance truth values: {str(e)}")


    def sample(
        self,
        theta: float,
        seed: Optional[int] = None,
        reuse: bool = False
    ) -> Tuple[int, ...]:
        """
        Draw a single observation from the distribution P(O | theta).

        Parameters
        ----------
        theta : float
            Value in [0, 1] that must match (within 1e-10) one of the entries
            of ``self.theta_values``.
        seed : Optional[int], default=None
            Seed handed to ``np.random.default_rng`` whenever a new generator
            is created.
        reuse : bool, default=False
            If False, a fresh generator is built on every call.
            If True, the generator cached on the instance is reused when the
            cached seed equals ``seed``; otherwise a new generator is created,
            cached, and a ``UserWarning`` is emitted.

        Returns
        -------
        Tuple[int, ...]
            The sampled entry of the observation index.

        Raises
        ------
        ValueError
            If theta is outside [0, 1] or is not on the theta grid.
        """
        if not 0 <= theta <= 1:
            raise ValueError("theta must be between 0 and 1")

        # Snap to the nearest grid point, then insist on a (near-)exact match.
        nearest_position = np.abs(self.theta_values - theta).argmin()
        closest_theta = self.theta_values[nearest_position]
        if not np.isclose(theta, closest_theta, rtol=1e-10, atol=1e-10):
            raise ValueError(
                f"theta {theta} not found in theta_values. "
                f"Closest available value is {closest_theta}. "
                f"Available values are {self.theta_values}"
            )

        # Column of log P(O | theta) -> probabilities indexed by observation.
        probabilities = np.exp(self.obs_log_likelihood_theta[closest_theta])

        if not reuse:
            # Default path: independent generator for every call.
            rng = np.random.default_rng(seed)
        else:
            cache_exists = hasattr(self, '_cached_rng')
            if (cache_exists and
                    hasattr(self, '_cached_seed') and
                    self._cached_seed == seed):
                # Continue the stream of the generator created earlier.
                rng = self._cached_rng
            else:
                # Tell the caller why the cached generator could not be used.
                if cache_exists:
                    warnings.warn(
                        f"reuse=True but seed mismatch (cached: {self._cached_seed}, requested: {seed}). "
                        f"Creating new RNG with seed={seed}",
                        UserWarning
                    )
                else:
                    warnings.warn(
                        f"reuse=True but no cached RNG exists. Creating new RNG with seed={seed}",
                        UserWarning
                    )
                rng = np.random.default_rng(seed)
                self._cached_rng = rng
                self._cached_seed = seed

        return rng.choice(
            a=probabilities.index,
            p=probabilities.values
        )
    

    def sample_run(
        self,
        theta: float,
        n_round: int,
        run_seed: Optional[int]) -> pd.DataFrame:
        """
        Sample multiple observations for a single simulation run.

        Parameters
        ----------
        theta : float
            The theta value to sample from. The closest entry of
            ``self.theta_values`` is used; a ``UserWarning`` is emitted when
            it is not a (near-)exact match.
        n_round : int
            Number of observations to sample in this run. Built-in ints and
            NumPy integer scalars are both accepted.
        run_seed : Optional[int]
            Seed passed to ``np.random.default_rng``. ``None`` (as passed by
            ``sample_multiple_runs`` when no base seed is given) yields a
            non-reproducible generator.

        Returns
        -------
        pd.DataFrame
            DataFrame with columns: ['observation', 'theta', 'run_seed', 'round_index']
            Each row represents one sampled observation with its position in the sequence.

        Raises
        ------
        ValueError
            If n_round is not a positive integer, theta is outside [0, 1], or
            the probabilities for the selected theta do not sum to 1.
        """
        # Accept NumPy integer scalars (e.g. counts read out of an array) in
        # addition to built-in ints; a plain isinstance(..., int) rejects them
        # with a misleading "must be a positive integer" error.
        if not isinstance(n_round, (int, np.integer)) or n_round < 1:
            raise ValueError("n_round must be a positive integer")
        n_round = int(n_round)

        if not 0 <= theta <= 1:
            raise ValueError("theta must be between 0 and 1")

        # Find closest theta on the grid; off-grid values only warn here.
        closest_theta = self.theta_values[np.abs(self.theta_values - theta).argmin()]
        if not np.isclose(theta, closest_theta, rtol=1e-10, atol=1e-10):
            warnings.warn(
                f"theta {theta} not exactly in theta_values. Using closest: {closest_theta}",
                UserWarning
            )

        # Get probabilities for this theta (computed once for the whole run)
        log_probs = self.obs_log_likelihood_theta[closest_theta]
        probabilities = np.exp(log_probs)
        observations_list = list(probabilities.index)
        prob_values = probabilities.values

        # Sanity checks before handing the vector to the sampler
        if not np.isclose(np.sum(prob_values), 1.0, rtol=1e-10):
            raise ValueError(f"Probabilities don't sum to 1: {np.sum(prob_values)}")
        if np.any(prob_values < 0):
            raise ValueError("Found negative probabilities")

        # Sample positions (not the labels themselves) so that tuple-valued
        # observations are returned exactly as stored in the index.
        rng = np.random.default_rng(run_seed)
        sampled_indices = rng.choice(
            len(observations_list),
            size=n_round,
            p=prob_values
        )

        # Convert indices to actual observations
        sampled_observations = [observations_list[idx] for idx in sampled_indices]

        return pd.DataFrame({
            "observation": sampled_observations,
            "theta": closest_theta,
            "run_seed": run_seed,
            "round_index": range(n_round)
        })
    
    def sample_multiple_runs(
        self,
        theta: float,
        n_run: int,
        n_round: int,
        base_seed: Optional[int] = None
    ) -> pd.DataFrame:
        """
        Sample observations reproducibly for multiple simulation runs.

        Each run is delegated to ``self.sample_run`` and tagged with its
        ``run_id``; the per-run frames are then concatenated.

        Parameters
        ----------
        theta : float
            The theta value to sample from (forwarded to ``sample_run``).
        n_run : int
            Number of independent simulation runs. Built-in ints and NumPy
            integer scalars are both accepted.
        n_round : int
            Number of observations to sample per run (validated by
            ``sample_run``).
        base_seed : Optional[int], default=None
            Base random seed for reproducibility. Run ``run_id`` uses the seed
            ``base_seed + run_id``; with ``None`` every run is unseeded.

        Returns
        -------
        pd.DataFrame
            DataFrame with columns: ['theta', 'run_id', 'round_index', 'observation', 'run_seed']
            Each row represents one sampled observation, with round_index indicating
            the sequence position (0 to n_round-1) within each run.

        Raises
        ------
        ValueError
            If n_run is not a positive integer.
        """
        # NumPy integer scalars are legitimate counts too; a bare
        # isinstance(..., int) check would reject them.
        if not isinstance(n_run, (int, np.integer)) or n_run < 1:
            raise ValueError("n_run must be a positive integer")
        n_run = int(n_run)

        # Collect results from each run
        run_dataframes = []
        for run_id in range(n_run):
            # Derive a distinct, deterministic seed per run when seeding is requested
            run_seed = None if base_seed is None else base_seed + run_id

            run_df = self.sample_run(theta=theta, n_round=n_round, run_seed=run_seed)

            # Add run_id to distinguish between runs
            run_df['run_id'] = run_id
            run_dataframes.append(run_df)

        # Combine all runs into single DataFrame
        combined_df = pd.concat(run_dataframes, ignore_index=True)

        # Fixed column order for downstream consumers
        return combined_df[['theta', 'run_id', 'round_index', 'observation', 'run_seed']]
    
    
    @property
    def utterances(self) -> List[str]:
        """Row labels of the utterance truth table, returned as a new list."""
        utterance_labels = self.utterance_truth.index
        return list(utterance_labels)

    @property
    def observations(self) -> List[Tuple[int, ...]]:
        """Row labels of the observation log-likelihood table, returned as a new list."""
        observation_labels = self.obs_log_likelihood_theta.index
        return list(observation_labels)

    @property
    def suc_likelihood_theta(self) -> pd.DataFrame:
        """
        Success likelihoods as plain probabilities.

        Obtained by applying ``np.exp`` element-wise to
        ``self.suc_log_likelihood_theta``.
        """
        log_table = self.suc_log_likelihood_theta
        return np.exp(log_table)

    @property
    def obs_likelihood_theta(self) -> pd.DataFrame:
        """
        Observation likelihoods as plain probabilities.

        Obtained by applying ``np.exp`` element-wise to
        ``self.obs_log_likelihood_theta``.
        """
        log_table = self.obs_log_likelihood_theta
        return np.exp(log_table)

In [None]:
if RUN_DEMO:
    # Demo cell: exercise the three sampling entry points of World.
    test_world = World(n=5, m=7)
    print(test_world.n, "experments each with", test_world.m, "trials")
    print(test_world.sample(0.0))
    print(test_world.sample(0.5))
    print(test_world.sample(1.0))
    print()

    # Single source of truth for the demo settings: the calls below use these
    # variables (rather than repeating the literals) so the printed
    # descriptions always match what is actually executed.
    test_seed = 12
    test_true_theta = 0.7
    test_num_obs = 3
    print("Sample each observation inidividually")
    print(f"Call .sample() {test_num_obs} times with initial seed ={test_seed} under theta of {test_true_theta}")
    for _ in range(test_num_obs):
        # reuse=True keeps drawing from the generator cached on the world, so
        # successive calls continue one stream instead of restarting it.
        print(test_world.sample(test_true_theta, seed=test_seed, reuse=True))
    print()

    print("Sample one sequence of observations")
    print(f"Call .sample_run() one time with n_round = {test_num_obs} and initial seed ={test_seed}")
    print(test_world.sample_run(
        theta=test_true_theta,
        n_round=test_num_obs,
        run_seed=test_seed))
    print()

    print("Sample two sequences of observations")
    print(f"Call .sample_multiple_runs() one time with n_run = 2, n_round = {test_num_obs} and initial seed ={test_seed}")
    print(test_world.sample_multiple_runs(
        theta=test_true_theta,
        n_run=2,
        n_round=test_num_obs,
        base_seed=test_seed
    ))

## Literal Speaker

In [None]:
class LiteralSpeaker:
    """
    A literal speaker in the RSA communication game.

    This speaker observes outcomes from the `World` at each round, updates
    its beliefs over theta using Bayes' rule on the total number of successes,
    and selects an utterance uniformly at random among those that are
    semantically (literally) true of the observed data.

    Attributes
    ----------
    world : World
        The shared World model containing likelihoods and truth tables.
    un_current_log_belief : np.ndarray
        Unnormalized log-probabilities over theta values reflecting the speaker's prior/posterior.
    utterance_log_prob_obs : pd.DataFrame
        Log-probabilities of each utterance given each observation (frequency tuple).
        Rows are utterance strings, columns are observations.
    """

    def __init__(
        self,
        world: 'World',
        initial_beliefs_theta: Optional[np.ndarray] = None
    ) -> None:
        """
        Initialize the literal speaker.

        Parameters
        ----------
        world : World
            Instance of the World class, containing theta grid and likelihood tables.
        initial_beliefs_theta : np.ndarray, optional
            1D array of prior probabilities over theta values (must sum to 1).
            If None, a uniform prior is assumed.

        Raises
        ------
        ValueError
            If ``initial_beliefs_theta`` is not a valid probability vector
            over ``world.theta_values``.
        """
        self.world = world

        # Process or initialize the log-prior over theta
        self.un_current_log_belief = self._process_initial_beliefs(
            initial_beliefs_theta, self.world
        )

        # Precompute log P(u | O) for all utterance-observation pairs
        self.utterance_log_prob_obs = self._compute_utterance_log_prob_obs(
            self.world.utterance_truth
        )

    def _process_initial_beliefs(
        self,
        initial_beliefs_theta: Optional[np.ndarray],
        world: 'World'
    ) -> np.ndarray:
        """
        Validate and convert initial belief vector to log-space.

        Parameters
        ----------
        initial_beliefs_theta : Optional[np.ndarray]
            Initial beliefs over theta values, or None for uniform prior.
        world : World
            The World object containing theta_values.

        Returns
        -------
        np.ndarray
            Array of log-probabilities over theta_values.

        Raises
        ------
        ValueError
            If the input is not a numpy array, has the wrong shape, has
            entries outside [0, 1], or does not sum to 1.
        """
        n_theta = len(world.theta_values)

        # Uniform prior if none provided
        if initial_beliefs_theta is None:
            return np.full(n_theta, -np.log(n_theta), dtype=float)

        # Validate shape and range
        if not isinstance(initial_beliefs_theta, np.ndarray):
            raise ValueError("initial_beliefs_theta must be a numpy array")
        if initial_beliefs_theta.shape != (n_theta,):
            raise ValueError(
                f"initial_beliefs_theta length {initial_beliefs_theta.size} must match "
                f"number of theta values {n_theta}."
            )
        if not np.all((0 <= initial_beliefs_theta) & (initial_beliefs_theta <= 1)):
            raise ValueError("All probabilities must be between 0 and 1.")
        if not np.isclose(initial_beliefs_theta.sum(), 1.0):
            raise ValueError("Probabilities must sum to 1.")

        # Zero-probability entries legitimately map to -inf; silence the
        # divide-by-zero warning locally instead of relying on global settings.
        with np.errstate(divide="ignore"):
            return np.log(initial_beliefs_theta)

    def _compute_utterance_log_prob_obs(
        self,
        utterance_truth: pd.DataFrame
    ) -> pd.DataFrame:
        """
        Compute the log-probabilities P(u | O) for every utterance and observation.

        For each observation O, the literal speaker chooses uniformly among all u
        such that Truth(u; O) = 1.

        Parameters
        ----------
        utterance_truth : pd.DataFrame
            Truth values for utterance-observation pairs
            (rows = utterances, columns = observations).

        Returns
        -------
        pd.DataFrame
            DataFrame with rows = utterances, columns = observations,
            and values = log P(u | O); literally false pairs are -inf.
        """
        truth = utterance_truth.astype(bool)

        # Count how many utterances are true for each observation
        true_counts = truth.sum(axis=0)

        # Log-prob for each (u, O) is -log(num_true_utterances(O)) if true, else -inf.
        # A column with no true utterance produces log(0) here; every cell of
        # such a column is masked to -inf below, so the warning is suppressed.
        with np.errstate(divide="ignore"):
            base_logp = -np.log(true_counts.values)
        logp_matrix = np.tile(base_logp, (truth.shape[0], 1))

        df = pd.DataFrame(
            data=logp_matrix,
            index=truth.index,
            columns=truth.columns
        )

        # Mask out false utterance-observation pairs
        return df.where(truth, -np.inf)

    def update_and_speak(
        self,
        observation: Tuple[int, ...],
        rng: Optional[np.random.Generator] = None
    ) -> str:
        """
        Given a new frequency-tuple observation, update beliefs and sample an utterance.

        The belief update is committed only after an utterance has been
        sampled successfully, so a failing call leaves
        ``un_current_log_belief`` untouched.

        Parameters
        ----------
        observation : tuple of int
            A frequency tuple (n_0, n_1, ..., n_m) observed this round.
        rng : np.random.Generator, optional
            Generator used to pick the utterance, allowing reproducible
            simulations. If None, the global ``np.random`` state is used.

        Returns
        -------
        str
            The chosen utterance as a built-in ``str``.

        Raises
        ------
        ValueError
            If the observation is not in the world's possible outcomes.
        RuntimeError
            If belief update or utterance sampling fails.
        """
        # 1) Validate observation
        if observation not in self.world.observations:
            raise ValueError(f"Observation {observation} not supported by the world.")

        # 2) Compute total successes S = sum_j (j * n_j)
        counts = np.array(observation)
        successes = int(counts.dot(np.arange(self.world.m + 1)))

        # 3) Belief update in log-space: log P_new(theta) ∝ log P_old(theta) + log P(S | theta).
        #    Kept in a local until sampling has succeeded (see step 5).
        try:
            log_lik = self.world.suc_log_likelihood_theta.loc[successes].values
            updated_log_belief = self.un_current_log_belief + log_lik
        except Exception as e:
            raise RuntimeError(f"Belief update failed: {e}") from e

        # 4) Sample uniformly among the utterances whose truth value is 1 for
        #    this observation (read straight from the world's truth table).
        try:
            # The list wrapper keeps the tuple from being read as a multi-axis key.
            uttrs = self.world.utterance_truth.loc[:, [observation]]
            uttrs_true = uttrs.index[uttrs.iloc[:, 0] == 1].tolist()
            if not uttrs_true:
                raise RuntimeError(f"No valid utterances for observation {observation}")
            draw = np.random.choice if rng is None else rng.choice
            # Convert the NumPy string scalar to a plain str, as annotated.
            utterance = str(draw(uttrs_true))
        except Exception as e:
            raise RuntimeError(f"Utterance sampling failed: {e}") from e

        # 5) Commit the new belief only now that nothing can fail anymore
        self.un_current_log_belief = updated_log_belief
        return utterance

    @property
    def current_belief_theta(self) -> np.ndarray:
        """
        Return the current beliefs normalized so they form a valid distribution.

        The unnormalized log-belief is passed as a single column to
        ``log_column_normalize`` and the result is exponentiated.
        """
        return np.exp(log_column_normalize(self.un_current_log_belief[:, None],
                                           precise= USE_PRECISE_LOGSPACE).ravel())


In [None]:
if RUN_DEMO:
    # Demo cell: one round of a literal speaker on an observation sampled at theta = 0.
    test_LS = LiteralSpeaker(test_world)
    test_sample = test_world.sample(0.0)
    print("Sample a ", test_sample, sep="\n")

    # Speak about the sample, then show the speaker's normalized belief.
    test_uttr = test_LS.update_and_speak(test_sample)
    print(test_uttr)
    print(test_LS.current_belief_theta.round(3))

    # Leave a fresh speaker behind for later cells.
    test_LS = LiteralSpeaker(test_world)

## Literal Listener

In [None]:
class LiteralListener:
    """
    Literal listener (L0) of the RSA communication game.

    On hearing an utterance from a literal speaker S_0, the listener applies
    Bayes' rule to its belief over the hidden parameter theta. The likelihood
    used is P_S0(u | theta), obtained by marginalizing the speaker's
    P_S0(u | O) over observations O.

    Attributes
    ----------
    world : World
        The shared World model containing likelihoods and truth tables.
    un_current_log_belief : np.ndarray
        Current unnormalized log-belief over theta values.
    literal_speaker : LiteralSpeaker
        Helper speaker whose ``utterance_log_prob_obs`` table is reused here.
    utterance_log_likelihood_theta : pd.DataFrame
        log P_S0(u | theta); rows are utterance strings, columns are theta values.
    theta_log_post_utterance : pd.DataFrame
        Unnormalized log-posteriors; rows are theta values, columns are utterances.
    """

    def __init__(
        self,
        world: 'World',
        initial_beliefs_theta: Optional[np.ndarray] = None
    ) -> None:
        """
        Build the listener and precompute its likelihood and posterior tables.

        Parameters
        ----------
        world : World
            World instance supplying the theta grid, likelihoods and truth tables.
        initial_beliefs_theta : np.ndarray, optional
            1D prior over theta values (must sum to 1). None means uniform.

        Raises
        ------
        RuntimeError
            If any step of the setup fails (including prior validation).
        """
        self.world = world

        try:
            # Prior over theta, in log-space
            log_prior = self._process_initial_beliefs(initial_beliefs_theta, self.world)
            self.un_current_log_belief = log_prior

            # Helper speaker: only its P(u|O) table is needed below
            self.literal_speaker = LiteralSpeaker(self.world, initial_beliefs_theta)
            speaker_table = self.literal_speaker.utterance_log_prob_obs
            observation_table = self.world.obs_log_likelihood_theta

            # log P(u|theta) = log sum_O P(u|O) P(O|theta)
            self.utterance_log_likelihood_theta = (
                self._compute_utterance_log_likelihood_theta(speaker_table, observation_table)
            )

            # Prior + likelihood -> unnormalized log-posterior per utterance
            self.theta_log_post_utterance = self._compute_theta_log_post_utterance(
                self.utterance_log_likelihood_theta,
                self.un_current_log_belief
            )

        except Exception as e:
            raise RuntimeError(f"Failed to initialize listener: {str(e)}")

    def _process_initial_beliefs(
        self,
        initial_beliefs_theta: Optional[np.ndarray],
        world: 'World'
    ) -> np.ndarray:
        """
        Check the supplied prior and return it as log-probabilities.

        Parameters
        ----------
        initial_beliefs_theta : Optional[np.ndarray]
            Prior over theta values, or None for a uniform prior.
        world : World
            Object whose ``theta_values`` fixes the expected length.

        Returns
        -------
        np.ndarray
            Log-probabilities, one per entry of ``world.theta_values``.

        Raises
        ------
        ValueError
            If the prior is not a numpy array, has the wrong shape, contains
            values outside [0, 1], or does not sum to 1.
        """
        n_theta = len(world.theta_values)
        prior = initial_beliefs_theta

        if prior is None:
            # Uniform prior: every entry is log(1 / n_theta)
            return np.full(n_theta, -np.log(n_theta), dtype=float)

        if not isinstance(prior, np.ndarray):
            raise ValueError("initial_beliefs_theta must be a numpy array")

        if prior.shape != (n_theta,):
            raise ValueError(
                f"initial_beliefs_theta length {prior.size} must match "
                f"number of theta values {n_theta}."
            )

        within_unit_interval = (0 <= prior) & (prior <= 1)
        if not np.all(within_unit_interval):
            raise ValueError("All probabilities must be between 0 and 1.")

        if not np.isclose(prior.sum(), 1.0):
            raise ValueError("Probabilities must sum to 1.")

        return np.log(prior)

    def _compute_utterance_log_likelihood_theta(
        self,
        utterance_log_prob_obs: pd.DataFrame,
        obs_log_likelihood_theta: pd.DataFrame) -> pd.DataFrame:
        """
        Marginalize over observations to get log P(u | theta).

        The two tables are combined with ``log_M_product``, i.e. a matrix
        product carried out on log-transformed entries.

        Parameters
        ----------
        utterance_log_prob_obs: pd.DataFrame
            log P(u | O); rows are utterance strings, columns are observations.
        obs_log_likelihood_theta: pd.DataFrame
            log P(O | theta); rows are observations, columns are theta values.

        Returns
        -------
        pd.DataFrame
            log P(u | theta); rows are utterance strings, columns are theta values.

        Raises
        ------
        RuntimeError
            If the product or the frame construction fails.
        """
        try:
            log_product = log_M_product(
                utterance_log_prob_obs.values,
                obs_log_likelihood_theta.values,
                precise= USE_PRECISE_LOGSPACE
            )
            # Re-attach labels: utterances on rows, thetas on columns
            return pd.DataFrame(
                log_product,
                index=utterance_log_prob_obs.index,
                columns=obs_log_likelihood_theta.columns
            )

        except Exception as e:
            raise RuntimeError(f"Failed to compute utterance likelihoods: {str(e)}")

    def _compute_theta_log_post_utterance(
        self,
        utterance_log_likelihood_theta: pd.DataFrame,
        un_current_log_belief: np.ndarray) -> pd.DataFrame:
        """
        Combine prior and likelihood into unnormalized log-posteriors.

        For every utterance u and theta:
        log P(theta, u) = log P(theta) + log P(u | theta).

        Parameters
        ----------
        utterance_log_likelihood_theta: pd.DataFrame
            log P(u | theta); rows are utterance strings, columns are theta values.
        un_current_log_belief: np.ndarray
            Unnormalized log-belief over theta values.

        Returns
        -------
        pd.DataFrame
            Rows are theta values, columns are utterances;
            entries are unnormalized log P(theta | u).

        Raises
        ------
        RuntimeError
            If the addition or transpose fails.
        """
        try:
            # The belief vector (one entry per theta column) is added to every utterance row
            joint_log = utterance_log_likelihood_theta + un_current_log_belief
            # Flip so that rows = theta, columns = utterance
            return joint_log.T

        except Exception as e:
            raise RuntimeError(f"Failed to update posterior distributions: {str(e)}")

    def listen_and_update(self, utterance: str) -> None:
        """
        Condition the listener's belief on a received utterance.

        Parameters
        ----------
        utterance : str
            The utterance string produced by the speaker in this round.

        Updates
        -------
        un_current_log_belief : np.ndarray
            Set to the column of ``theta_log_post_utterance`` for this utterance.
        theta_log_post_utterance : pd.DataFrame
            Rebuilt from the new belief, ready for the next round.

        Raises
        ------
        ValueError
            If the utterance is not recognized.
        RuntimeError
            If the belief update fails.
        """
        known_utterances = self.world.utterances
        if utterance not in known_utterances:
            raise ValueError(
                f"Utterance '{utterance}' not found in possible utterances.\n"
                f"Valid utterances: {known_utterances}"
            )

        try:
            # Posterior for the heard utterance becomes the new (unnormalized) belief
            posterior_column = self.theta_log_post_utterance[utterance]
            self.un_current_log_belief = posterior_column.values

            # Refresh the per-utterance posterior table with the new belief as prior
            self.theta_log_post_utterance = self._compute_theta_log_post_utterance(
                self.utterance_log_likelihood_theta,
                self.un_current_log_belief
            )

        except Exception as e:
            raise RuntimeError(f"Failed to update beliefs: {str(e)}")

    @property
    def current_belief_theta(self) -> np.ndarray:
        """
        Listener's belief over theta as normalized probabilities.

        The log-belief is passed as a single column to
        ``log_column_normalize`` and the result is exponentiated.
        """
        belief_column = self.un_current_log_belief[:, None]
        normalized_log = log_column_normalize(belief_column,
                                              precise= USE_PRECISE_LOGSPACE)
        return np.exp(normalized_log.ravel())

In [None]:
if RUN_DEMO:
    # Demo cell: let a literal listener hear one utterance and show its belief.
    test_LL = LiteralListener(test_world)

    print("some,most,unsuccessful")
    test_LL.listen_and_update("some,most,unsuccessful")
    print(test_LL.current_belief_theta.round(3))

    # Leave a fresh listener behind for later cells.
    test_LL = LiteralListener(test_world)

## Pragmatic Speaker

In [None]:
class PragmaticSpeaker_obs:

    """
    A pragmatic speaker (S1) in an RSA-style communication game.

    This speaker balances literal truth, informativeness, and persuasiveness
    when choosing an utterance.  It relies on:
      - a LiteralSpeaker to provide P_S0(u | O)
      - a LiteralListener to track P_L0(theta) and related posteriors

    Attributes
    ----------
    world : World
        The shared World model with likelihoods and truth tables.
    omega : str
        Type of world:
        "coop" for cooperative world where speakers are all informative,
        "strat" for stratigic world where speakers can be also persuasive,
    psi : str
        Speaker goal: "inf" for purely informative,
        "pers+" to persuade the listener up,
        "pers-" to persuade the listener down.
    alpha : float or "determ"
        Softmax temperature (or "determ" for deterministic tie-split).
    beta : float
        Weight on informativeness (beta=1 for pure info, 0 for pure persuasion).
    update_internal : bool
        If True, update the literal listener's internal state after speaking.
    literal_speaker : LiteralSpeaker
        Helper to access P_S0(u | O).
    literal_listener : LiteralListener
        Helper to track listener beliefs P_L0(theta) and P_L0(O).
    utility : pd.DataFrame
        (U × O) matrix of log-utilities V(u; O).
    utterance_log_prob_obs : pd.DataFrame
        (U × O) matrix of log P_S1(u | O).
    """

    VALID_OMEGA_TYPES = {"coop", "strat"}
    VALID_PSI_TYPES = {"inf", "pers+", "pers-"}

    def __init__(
        self,
        world: 'World',
        omega: str,
        psi: str,
        update_internal: bool,
        alpha: Union[float, str],
        beta: float = 0.0,
        initial_beliefs_theta: Optional[np.ndarray] = None
    ) -> None:
        """
        Initialize the pragmatic speaker.

        Validates ``omega`` and ``psi`` against the class-level sets of valid
        values, stores the configuration, builds the internal literal speaker
        and literal listener, and computes the utterance log-probability
        table via ``self._compute_utterance_log_prob_obs``.

        Raises
        ------
        ValueError
            If ``omega`` or ``psi`` is not one of the valid values.
        RuntimeError
            If building the helpers or the utterance table fails.
        """
        if omega not in self.VALID_OMEGA_TYPES:
            raise ValueError(
                f"omega must be one of {self.VALID_OMEGA_TYPES}, got '{omega}'"
            )
        if psi not in self.VALID_PSI_TYPES:
            raise ValueError(
                f"psi must be one of {self.VALID_PSI_TYPES}, got '{psi}'"
            )

        self.world = world
        self.omega = omega

        # In a cooperative world the speaker goal is always informative;
        # any other requested goal is overridden with a warning.
        psi_overridden = self.omega == "coop" and psi != "inf"
        if psi_overridden:
            warnings.warn("when omega == coop, psi is forced to inf",
                            UserWarning)
        self.psi = "inf" if psi_overridden else psi

        self.alpha = alpha
        self.beta = beta
        self.update_internal = update_internal

        try:
            # Both helpers start from the same initial beliefs over theta
            self.literal_speaker = LiteralSpeaker(world, initial_beliefs_theta)
            self.literal_listener = LiteralListener(world, initial_beliefs_theta)

            # Utterance log-probability table for the configured alpha
            utterance_table = self._compute_utterance_log_prob_obs(self.alpha)
            self.utterance_log_prob_obs = utterance_table

        except Exception as e:
            raise RuntimeError(f"Failed to initialize pragmatic speaker: {str(e)}")

    def _compute_log_informativeness(
        self,
        obs_log_likelihood_theta_values: np.ndarray,
        un_current_log_belief: np.ndarray,
        utterance_log_prob_obs: pd.DataFrame
    ) -> Tuple[np.ndarray, pd.DataFrame]:
        """
        Compute log-Inf(u; O) = log-P_L0(O|u).

        Parameters
        ----------
        obs_log_likelihood_theta_values : np.ndarray
            Observation log-likelihoods for each theta.
        un_current_log_belief : np.ndarray
            Unnormalized log-beliefs over theta values.
        utterance_log_prob_obs : pd.DataFrame
            Log P(u|O) for all utterances and observations.

        Returns
        -------
        Tuple[np.ndarray, pd.DataFrame]
            First: unnormalized log-probability of each observation (1D array).
            Second: log-informativeness table with the same row/column
            orientation as ``utterance_log_prob_obs``.
        """
        # Unnormalized log-P(O): log-space product of the likelihood table
        # with the belief vector taken as a single column.
        belief_column = un_current_log_belief[:, np.newaxis]
        unnormalized_log_prob_O = log_M_product(
            obs_log_likelihood_theta_values,
            belief_column,
            precise= USE_PRECISE_LOGSPACE
        ).flatten()

        # Joint log P(u|O) + log P(O), transposed before normalization
        joint_log = (
            utterance_log_prob_obs +
            unnormalized_log_prob_O[np.newaxis, :]
        ).T

        normalized_values = log_column_normalize(joint_log.values,
                                                 precise= USE_PRECISE_LOGSPACE)
        obs_log_post_utterance = pd.DataFrame(
            normalized_values,
            index=joint_log.index,
            columns=joint_log.columns
        )

        # Transpose back to the orientation of the input table
        return unnormalized_log_prob_O, obs_log_post_utterance.T

    def _compute_log_persuasiveness(
        self,
        psi: str,
        theta_values: np.ndarray,
        theta_log_post_utterance_values: np.ndarray
    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """
        Compute log-PersStr(u; psi) based on speaker type.

        Parameters
        ----------
        psi : str
            Speaker goal: "inf", "pers+", or "pers-".
        theta_values : np.ndarray
            Array of possible theta values.
        theta_log_post_utterance_values : np.ndarray
            Unnormalized log-posteriors for each theta and utterance.

        Returns
        -------
        Tuple[pd.DataFrame, pd.DataFrame]
            First: single-row table of log E[theta | u], one column per utterance.
            Second: log-persuasiveness table with utterances as rows and
            ``self.world.possible_outcomes`` as columns; each row holds one
            value repeated across all columns.
        """
        smallest_positive = np.finfo(float).tiny

        def _log_posterior_mean(weights: np.ndarray, zero_mask: np.ndarray) -> np.ndarray:
            # log of the posterior expectation of `weights` for every utterance.
            # Clipping avoids log(0); exact zeros are then restored to -inf.
            log_weights = np.log(np.clip(weights, smallest_positive, None))
            log_weights[zero_mask] = -np.inf
            return log_M_product(
                log_weights[np.newaxis, :],
                log_column_normalize(theta_log_post_utterance_values,
                                     precise= USE_PRECISE_LOGSPACE),
                precise= USE_PRECISE_LOGSPACE
            )

        # log-E[theta|u] is computed for every speaker type
        theta_log_expectation_utterance = pd.DataFrame(
            _log_posterior_mean(theta_values, theta_values == 0),
            columns=self.world.utterances
        )

        n_cols = len(self.world.possible_outcomes)

        if psi == "pers+":
            # Persuade-up utility: log E[theta|u]
            values = theta_log_expectation_utterance.values.flatten()
        elif psi == "pers-":
            # Persuade-down utility: log(1 - E(theta)) = log(E(1 - theta))
            values = _log_posterior_mean(
                1 - theta_values, theta_values == 1
            ).flatten()
        else:  # psi == "inf"
            # Purely informative: persuasiveness contributes log(1) = 0
            values = np.zeros(len(self.world.utterances))

        # Repeat the per-utterance value across every outcome column
        log_persuasiveness = pd.DataFrame(
            np.tile(values, (n_cols, 1)).T,
            index=self.world.utterances,
            columns=self.world.possible_outcomes
        )

        return theta_log_expectation_utterance, log_persuasiveness

    def _compute_utility(self, psi: str, beta: float) -> pd.DataFrame:
        """
        Compute utility V(u; O, psi) combining truth, informativeness, and persuasiveness.

        The informativeness and persuasiveness tables are recomputed on every
        call and cached on the instance (``unnormalized_log_prob_O``,
        ``log_informativeness``, ``theta_log_expectation_utterance``,
        ``log_persuasiveness``). The utility is then selected as follows:

        - ``psi == "inf"``: log-informativeness only (``beta`` is ignored).
        - ``beta == 0``: log-persuasiveness only.
        - otherwise: ``beta * informativeness + (1 - beta) * persuasiveness``,
          where every cell that is ``-inf`` in either term is forced to ``-inf``.

        Finally, every cell where ``self.world.utterance_truth == 0`` is set
        to ``-inf`` so literally false utterances are never preferred.

        Parameters
        ----------
        psi : str
            Speaker goal: "inf", "pers+", or "pers-".
        beta : float
            Weight on informativeness in the mixed case; ``0`` selects pure
            persuasion. It has no effect when ``psi == "inf"``.

        Returns
        -------
        pd.DataFrame
            DataFrame with rows as utterances, columns as observations, values as utility.

        Raises
        ------
        RuntimeError
            If any step fails. The original exception is attached as
            ``__cause__`` so the root cause stays visible in the traceback.
        """
        try:
            # Compute informativeness
            self.unnormalized_log_prob_O, self.log_informativeness = self._compute_log_informativeness(
                self.world.obs_log_likelihood_theta.values,
                self.literal_listener.un_current_log_belief,
                self.literal_listener.literal_speaker.utterance_log_prob_obs
            )

            # Compute persuasiveness
            self.theta_log_expectation_utterance, self.log_persuasiveness = self._compute_log_persuasiveness(
                psi,
                self.world.theta_values,
                self.literal_listener.theta_log_post_utterance.values
            )

            # Mask of literally false cells (Truth == 0)
            uttr_false = (self.world.utterance_truth == 0)

            if psi == "inf":
                # Purely informative
                util = self.log_informativeness.copy()
            elif beta == 0:
                # Purely persuasive
                util = self.log_persuasiveness.copy()
            else:
                # Mixed informative & persuasive.
                # Remember which cells are impossible in either term before
                # the -inf values are replaced by finite numbers below.
                impossible = (
                    (self.log_informativeness == -np.inf) |
                    (self.log_persuasiveness == -np.inf)
                )
                # Replace -inf by the most negative finite float so that the
                # weighted sum cannot produce NaN (e.g. 0 * -inf when a weight
                # is zero). ``clip`` already returns a new frame, so the cached
                # tables are left untouched without an extra copy.
                lowest_finite = -np.finfo(float).max
                inf_term = self.log_informativeness.clip(lower=lowest_finite)
                pers_term = self.log_persuasiveness.clip(lower=lowest_finite)
                util = beta * inf_term + (1 - beta) * pers_term
                util[impossible] = -np.inf

            # Finally enforce literal truth
            util[uttr_false] = -np.inf

            return util

        except Exception as e:
            # Chain explicitly so the underlying failure is reported as the cause.
            raise RuntimeError(f"Failed to compute utility: {e}") from e

    def _compute_utterance_log_prob_obs(self, alpha: Union[float, str]) -> pd.DataFrame:
        """
        Compute log probabilities of utterances given observations.
        This is the main computation pipeline that calls all other computations.

        The utility table is recomputed from the instance's own ``self.psi``
        and ``self.beta`` and stored on ``self.utility``; only the softmax
        parameter is taken from the argument.

        Parameters
        ----------
        alpha : float or "determ"
            Softmax temperature parameter, forwarded to ``log_column_softmax``.

        Returns
        -------
        pd.DataFrame
            DataFrame with rows as utterances, columns as observations,
            values as log P(u|O). Index and columns are those of the
            utility table.

        Raises
        ------
        RuntimeError
            If the utility or the softmax computation fails. The original
            exception is attached as ``__cause__``.
        """
        try:
            # First compute utility (this will cascade to compute informativeness and persuasiveness)
            self.utility = self._compute_utility(self.psi, self.beta)

            # Then apply a column-wise softmax to get utterance probabilities
            log_probs = log_column_softmax(
                self.utility.values,
                alpha,
                precise=USE_PRECISE_LOGSPACE,
            )
            return pd.DataFrame(
                log_probs,
                index=self.utility.index,
                columns=self.utility.columns
            )

        except Exception as e:
            # Chain explicitly so the underlying failure is reported as the cause.
            raise RuntimeError(f"Failed to compute utterance probability table: {e}") from e

    def update_and_speak(self, observation: Tuple[int, ...]) -> str:
        """
        Given an observation, update beliefs and sample an utterance.

        Steps:
        1. Forward the observation to the embedded literal speaker
           (its return value is discarded; the call is made for its
           belief update).
        2. Draw one utterance from the column of ``self.utterance_log_prob_obs``
           that belongs to ``observation``, using NumPy's global RNG.
        3. If ``self.update_internal`` is true, let the internal literal
           listener hear the utterance and recompute the utterance table
           with ``self.alpha``.

        Parameters
        ----------
        observation : Tuple[int, ...]
            Observation used both for the belief update and as the column
            label of the utterance probability table.

        Returns
        -------
        str
            The sampled utterance (a label from the table's index).

        Raises
        ------
        RuntimeError
            If any step fails. The original exception is attached as
            ``__cause__``.
        """
        try:
            # Bayesian optimal update using observation
            # i.e. update literal speaker's beliefs
            self.literal_speaker.update_and_speak(observation)

            # Sample utterance according to P(u|O).
            # The observation is wrapped in a list so that the tuple is
            # looked up as a single column label.
            utterance_log_probs = self.utterance_log_prob_obs.loc[:, [observation]]
            selected_utterance = np.random.choice(
                utterance_log_probs.index,
                p=np.exp(utterance_log_probs.values.flatten())
            )

            # Only update internal state if update_internal is True
            if self.update_internal:
                self.literal_listener.listen_and_update(selected_utterance)

                # Recompute utterance probabilities (this will cascade through all calculations)
                self.utterance_log_prob_obs = self._compute_utterance_log_prob_obs(self.alpha)

            return selected_utterance

        except Exception as e:
            # Chain explicitly so the underlying failure is reported as the cause.
            raise RuntimeError(f"Failed to update and select utterance: {e}") from e

    @property
    def current_belief_theta(self) -> np.ndarray:
        """
        Current belief over theta of the embedded literal speaker, in linear space.

        The literal speaker's unnormalized log-belief vector is normalized
        with ``log_column_normalize`` and exponentiated; the result is a
        1-D array with one entry per theta value.
        """
        unnormalized_log_belief = self.literal_speaker.un_current_log_belief

        # log_column_normalize works column-wise, so present the vector as a
        # single column of shape (theta, 1).
        as_column = unnormalized_log_belief[:, None]
        log_posterior = log_column_normalize(as_column,
                                             precise=USE_PRECISE_LOGSPACE)

        # Leave log-space and flatten back to shape (theta,).
        return np.exp(log_posterior).ravel()

In [None]:
# Demo cell: exercise a persuade-up ("pers+") strategic pragmatic speaker on
# `test_world` for one round. Runs only when RUN_DEMO is truthy.
if RUN_DEMO:
    # Speaker whose internal listener model is NOT updated after speaking.
    test_PS = PragmaticSpeaker_obs(
        world = test_world,
        omega = "strat",
        psi = "pers+",
        update_internal = False,
        alpha = 4.0)

    # Belief over theta before any observation has been processed.
    print(test_PS.current_belief_theta)
    # Draw one observation from the world (presumably at theta = 0.0 -- confirm
    # against World.sample).
    test_sample = test_world.sample(0.0)
    print(test_sample)
    # Update on the observation and sample an utterance for it.
    test_uttr = test_PS.update_and_speak(test_sample)
    print(test_uttr)
    # Belief over theta after the update, rounded for readability.
    print(np.round(test_PS.current_belief_theta, 3))

    # Rebind test_PS to a fresh speaker with the same settings, discarding the
    # state accumulated above.
    test_PS = PragmaticSpeaker_obs(
        world = test_world,
        omega = "strat",
        psi = "pers+",
        update_internal = False,
        alpha = 4.0)

# Sampling utterances and observations

In [None]:
import numpy as np
import pandas as pd
import warnings
from typing import List, Tuple, Dict, Optional, Any, Union
from joblib import Parallel, delayed, cpu_count

from rsa_optimal_exp_core import (
    World, LiteralSpeaker, PragmaticSpeaker_obs, USE_PRECISE_LOGSPACE
)

np.seterr(divide='ignore', under='ignore')

In [None]:
# Softmax temperature shared by every simulated ("true") pragmatic speaker.
TRUE_ALPHA = 4.0

# Ground-truth speaker line-up: the literal speaker first, followed by the
# strategic ("strat") pragmatic speaker for each goal ("inf", "pers+", "pers-"),
# each with update_internal switched off and then on (in that order).
TRUE_SPEAKER_CONFIGS = [{"speaker_type": "literal"}] + [
    {
        "speaker_type": "pragmatic",
        "omega": "strat",
        "psi": goal,
        "alpha": TRUE_ALPHA,
        "update_internal": updates_listener,
    }
    for goal in ("inf", "pers+", "pers-")
    for updates_listener in (False, True)
]

## Demonstration (all-in-one, not optimized)

In [None]:
def generate_simulation_data(
    n: int,
    m: int,
    thetas: Union[List[float], np.ndarray],
    T: int,
    speaker_config: Dict[str, Any],
    n_obs_seq: int,
    n_utt_seq: int,
    random_seed: Optional[int] = None,
    theta_values: Optional[np.ndarray] = None,
    return_format: str = "nested",
    compute_obs_likelihood: str = "none"
) -> Union[Dict[str, Any], pd.DataFrame]:
    """
    Generate simulated observation and utterance sequences for multiple theta values.

    This function creates a complete simulation dataset by:
    1. Creating a World instance
    2. For each theta: sampling multiple observation sequences
    3. For each observation sequence: generating multiple utterance sequences using the speaker model
    4. Optionally: computing observation sequence likelihoods

    Parameters
    ----------
    n : int
        Number of independent experiments in the World.
    m : int
        Number of Bernoulli trials per experiment in the World.
    thetas : List[float] or np.ndarray
        Theta values to simulate. Each must be in the World's theta_values.
        Any iterable is accepted; it is converted to a list.
    T : int
        Length of each observation/utterance sequence (number of rounds).
    speaker_config : Dict[str, Any]
        Speaker configuration for generating utterances.
        Required keys:
        - speaker_type : str
            "literal" or "pragmatic"
        For pragmatic speaker (required):
        - omega : str
            "coop" or "strat"
        - psi : str
            "inf", "pers+", or "pers-"
        - alpha : float or "determ"
            Softmax temperature
        - update_internal : bool
            Whether to update internal listener model
        For pragmatic speaker (optional):
        - beta : float
            Weight on informativeness (default 0.0)
        For both speaker types (optional):
        - initial_beliefs_theta : np.ndarray or None
            Initial prior over theta (default None)
    n_obs_seq : int
        Number of observation sequences to generate per theta.
    n_utt_seq : int
        Number of utterance sequences to generate per observation sequence.
    random_seed : Optional[int], default None
        Base random seed for reproducibility.
        - Observation sampling: ``World.sample_multiple_runs`` receives the base
          seed ``random_seed + theta_idx * 1000``. When ``random_seed`` is None
          the fixed base seed 123 is used instead, so observations are sampled
          from a fixed seed either way.
        - Utterance sampling: NumPy's *global* RNG is re-seeded with
          ``random_seed + theta_idx * 1000000 + obs_idx * 1000 + utt_idx``
          before each utterance sequence. When ``random_seed`` is None the
          global RNG is left untouched and ``utt_seed`` is reported as None.
    theta_values : Optional[np.ndarray], default None
        Custom theta values for the World. If None, uses World's default.
    return_format : str, default "nested"
        Output format:
        - "nested": Nested dictionary structure
        - "dataframe": Flattened pandas DataFrame (one row per utterance sequence)
    compute_obs_likelihood : str, default "none"
        Whether to compute observation sequence log-likelihoods:
        - "none": Don't compute likelihoods
        - "true": Compute log P(obs_seq | true_theta) only
            Adds column: obs_seq_log_likelihood_true
        - "all": Compute log P(obs_seq | theta) for all thetas in the world
            Adds columns: obs_seq_log_likelihood_true, obs_seq_log_likelihood_0p0,
            obs_seq_log_likelihood_0p1, ..., obs_seq_log_likelihood_1p0, obs_seq_mle_theta

    Returns
    -------
    If return_format == "nested":
        Dict[str, Any] with structure:
        {
            "world": World,
            "config": {
                "n": int, "m": int, "thetas": list, "T": int,
                "speaker_config": dict, "n_obs_seq": int, "n_utt_seq": int,
                "random_seed": int or None, "compute_obs_likelihood": str
            },
            "data": {
                theta_1: {
                    "obs_sequences": [
                        {
                            "obs_seq_idx": int,
                            "obs_seq": List[Tuple[int, ...]],
                            "run_seed": int,
                            "utt_sequences": [
                                {"utt_seq_idx": int, "utt_seq": List[str],
                                 "utt_seed": int or None},
                                ...
                            ],
                            "log_likelihood": float,  # if compute_obs_likelihood in ("true", "all")
                            "log_likelihood_all_theta": Dict[float, float],  # if compute_obs_likelihood == "all"
                            "mle_theta": float  # if compute_obs_likelihood == "all"
                        },
                        ...
                    ]
                },
                ...
            }
        }

    If return_format == "dataframe":
        pd.DataFrame with one row per utterance sequence:
        - theta: float (true theta used for simulation)
        - obs_seq_idx: int
        - utt_seq_idx: int
        - obs_seq: List[Tuple[int, ...]] (full observation sequence)
        - utt_seq: List[str] (full utterance sequence)
        - run_seed: int
        - utt_seed: int (None when random_seed is None)

        If compute_obs_likelihood == "true":
        - obs_seq_log_likelihood_true: float

        If compute_obs_likelihood == "all":
        - obs_seq_log_likelihood_true: float
        - obs_seq_log_likelihood_0p0: float (log P(obs_seq | theta=0.0))
        - obs_seq_log_likelihood_0p1: float (log P(obs_seq | theta=0.1))
        - ... (one column per theta in world.theta_values)
        - obs_seq_log_likelihood_1p0: float (log P(obs_seq | theta=1.0))
        - obs_seq_mle_theta: float

    Raises
    ------
    ValueError
        If parameters are invalid or theta values not in World's theta_values.
    RuntimeError
        If World creation, observation sampling, speaker creation or utterance
        generation fails. The original exception is attached as ``__cause__``.

    Examples
    --------
    >>> result = generate_simulation_data(
    ...     n=3, m=2,
    ...     thetas=[0.3, 0.5, 0.7],
    ...     T=10,
    ...     speaker_config={
    ...         "speaker_type": "pragmatic",
    ...         "omega": "coop",
    ...         "psi": "inf",
    ...         "alpha": 5.0,
    ...         "update_internal": True
    ...     },
    ...     n_obs_seq=5,
    ...     n_utt_seq=3,
    ...     random_seed=42,
    ...     compute_obs_likelihood="all",
    ...     return_format="dataframe"
    ... )
    >>> # Each row is one utterance sequence
    >>> print(result.columns.tolist())
    ['theta', 'obs_seq_idx', 'utt_seq_idx', 'obs_seq', 'utt_seq', 'run_seed', 'utt_seed',
     'obs_seq_log_likelihood_true', 'obs_seq_log_likelihood_0p0', ..., 'obs_seq_mle_theta']
    """

    # --- Input Validation ---
    if not isinstance(n, int) or n < 1:
        raise ValueError("n must be a positive integer")
    if not isinstance(m, int) or m < 1:
        raise ValueError("m must be a positive integer")
    if not isinstance(T, int) or T < 1:
        raise ValueError("T must be a positive integer")
    if not isinstance(n_obs_seq, int) or n_obs_seq < 1:
        raise ValueError("n_obs_seq must be a positive integer")
    if not isinstance(n_utt_seq, int) or n_utt_seq < 1:
        raise ValueError("n_utt_seq must be a positive integer")
    if return_format not in ["nested", "dataframe"]:
        raise ValueError("return_format must be 'nested' or 'dataframe'")
    if compute_obs_likelihood not in ["none", "true", "all"]:
        raise ValueError("compute_obs_likelihood must be 'none', 'true', or 'all'")

    # Normalise to a list (same handling as sample_observation_sequences)
    thetas = list(thetas)
    if not thetas:
        raise ValueError("thetas cannot be empty")

    # --- Validate speaker config ---
    speaker_type = speaker_config.get("speaker_type")
    if speaker_type not in ["literal", "pragmatic"]:
        raise ValueError(
            f"speaker_type must be 'literal' or 'pragmatic', got '{speaker_type}'"
        )

    if speaker_type == "pragmatic":
        required_keys = ["omega", "psi", "alpha", "update_internal"]
        missing_keys = [k for k in required_keys if k not in speaker_config]
        if missing_keys:
            raise ValueError(
                f"Missing required keys for pragmatic speaker: {missing_keys}"
            )

    # Which likelihood outputs were requested ("all" implies "true")
    want_true_lik = compute_obs_likelihood in ("true", "all")
    want_all_lik = compute_obs_likelihood == "all"

    # --- Create World ---
    try:
        world = World(n=n, m=m, theta_values=theta_values)
    except Exception as e:
        raise RuntimeError(f"Failed to create World: {e}") from e

    # --- Validate thetas against World's theta_values ---
    for theta in thetas:
        closest = world.theta_values[np.abs(world.theta_values - theta).argmin()]
        if not np.isclose(theta, closest, rtol=1e-10, atol=1e-10):
            raise ValueError(
                f"theta {theta} not in World's theta_values. "
                f"Closest: {closest}. Available: {world.theta_values}"
            )

    # --- Prepare base config for speaker creation ---
    base_speaker_config = {
        "speaker_type": speaker_type,
        "initial_beliefs_theta": speaker_config.get("initial_beliefs_theta", None)
    }
    if speaker_type == "pragmatic":
        base_speaker_config.update({
            "omega": speaker_config["omega"],
            "psi": speaker_config["psi"],
            "alpha": speaker_config["alpha"],
            "update_internal": speaker_config["update_internal"],
            "beta": speaker_config.get("beta", 0.0)
        })

    # --- Helper function to convert theta to column name ---
    def theta_to_colname(theta_val: float) -> str:
        """Column name for a theta value, e.g. 0.1 -> 'obs_seq_log_likelihood_0p1'."""
        return f"obs_seq_log_likelihood_{str(theta_val).replace('.', 'p')}"

    # Per-theta likelihood column names are the same for every record,
    # so build them once instead of once per utterance sequence.
    likelihood_cols = (
        [theta_to_colname(tv) for tv in world.theta_values] if want_all_lik else []
    )
    n_thetas = len(world.theta_values)

    # --- Generate Data ---
    data = {}
    flat_records = []  # For dataframe format (one record per utterance sequence)

    for theta_idx, theta in enumerate(thetas):
        theta_data = {"obs_sequences": []}

        # Base seed for this theta's observation runs; a fixed fallback (123)
        # is used when no random_seed was given.
        obs_base_seed = (random_seed + theta_idx * 1000) if random_seed is not None else 123

        try:
            obs_df = world.sample_multiple_runs(
                theta=theta,
                n_run=n_obs_seq,
                n_round=T,
                base_seed=obs_base_seed
            )
        except Exception as e:
            raise RuntimeError(
                f"Failed to sample observations for theta={theta}: {e}"
            ) from e

        # --- Collect all observation sequences for this theta ---
        # Group once instead of re-filtering the whole frame for every run.
        runs_by_id = obs_df.groupby("run_id")
        obs_seq_metadata = []
        for obs_idx in range(n_obs_seq):
            run_df = runs_by_id.get_group(obs_idx).sort_values("round_index")
            obs_seq_metadata.append({
                "obs_seq_idx": obs_idx,
                "obs_seq": list(run_df["observation"]),
                "run_seed": run_df["run_seed"].iloc[0]
            })

        # --- Compute observation likelihoods (VECTORIZED) ---
        obs_log_likelihoods_true = None   # log P(obs_seq | true_theta)
        obs_log_likelihoods_all = None    # log P(obs_seq | all thetas) - shape (n_obs_seq, n_thetas)
        obs_mle_thetas = None             # argmax_theta P(obs_seq | theta)

        if want_true_lik:
            # Flatten all observations (as tuples) for one batch lookup
            all_obs_keys = [
                obs if isinstance(obs, tuple) else tuple(obs)
                for meta in obs_seq_metadata
                for obs in meta["obs_seq"]
            ]

            # Batch lookup: (n_obs_seq * T, n_thetas)
            log_probs_flat = world.obs_log_likelihood_theta.loc[all_obs_keys].values

            # Reshape to (n_obs_seq, T, n_thetas), then sum the per-round
            # log-probabilities over T -> (n_obs_seq, n_thetas)
            log_liks_all_theta = log_probs_flat.reshape(n_obs_seq, T, n_thetas).sum(axis=1)

            # Column of the true theta within the World's theta grid
            true_theta_col = np.where(np.isclose(world.theta_values, theta))[0][0]
            obs_log_likelihoods_true = log_liks_all_theta[:, true_theta_col]

            if want_all_lik:
                obs_log_likelihoods_all = log_liks_all_theta  # (n_obs_seq, n_thetas)
                obs_mle_thetas = world.theta_values[np.argmax(log_liks_all_theta, axis=1)]

        # --- Process each observation sequence ---
        for obs_idx, meta in enumerate(obs_seq_metadata):
            obs_seq = meta["obs_seq"]
            run_seed = meta["run_seed"]

            obs_seq_data = {
                "obs_seq_idx": meta["obs_seq_idx"],
                "obs_seq": obs_seq,
                "run_seed": run_seed,
                "utt_sequences": []
            }

            # Likelihood values depend only on the observation sequence, so they
            # are prepared once here and shared by all its utterance records.
            likelihood_fields = {}
            if want_true_lik:
                log_lik_true = float(obs_log_likelihoods_true[obs_idx])
                obs_seq_data["log_likelihood"] = log_lik_true
                likelihood_fields["obs_seq_log_likelihood_true"] = log_lik_true

            if want_all_lik:
                log_liks_row = obs_log_likelihoods_all[obs_idx]
                mle_theta = float(obs_mle_thetas[obs_idx])
                obs_seq_data["log_likelihood_all_theta"] = dict(
                    zip(world.theta_values, log_liks_row)
                )
                obs_seq_data["mle_theta"] = mle_theta
                for col_name, log_lik in zip(likelihood_cols, log_liks_row):
                    likelihood_fields[col_name] = float(log_lik)
                likelihood_fields["obs_seq_mle_theta"] = mle_theta

            # Generate multiple utterance sequences for this observation sequence
            for utt_idx in range(n_utt_seq):
                # Seed NumPy's global RNG so each utterance sequence is reproducible
                if random_seed is not None:
                    utt_seed = random_seed + theta_idx * 1000000 + obs_idx * 1000 + utt_idx
                    np.random.seed(utt_seed)
                else:
                    utt_seed = None

                # Create fresh speaker instance
                try:
                    if speaker_type == "literal":
                        speaker = LiteralSpeaker(
                            world=world,
                            initial_beliefs_theta=base_speaker_config.get("initial_beliefs_theta")
                        )
                    else:  # pragmatic
                        speaker = PragmaticSpeaker_obs(
                            world=world,
                            omega=base_speaker_config["omega"],
                            psi=base_speaker_config["psi"],
                            update_internal=base_speaker_config["update_internal"],
                            alpha=base_speaker_config["alpha"],
                            beta=base_speaker_config.get("beta", 0.0),
                            initial_beliefs_theta=base_speaker_config.get("initial_beliefs_theta")
                        )
                except Exception as e:
                    raise RuntimeError(
                        f"Failed to create speaker for theta={theta}, "
                        f"obs_idx={obs_idx}, utt_idx={utt_idx}: {e}"
                    ) from e

                # Generate utterance sequence (one utterance per observation, in order)
                try:
                    utt_seq = [speaker.update_and_speak(obs) for obs in obs_seq]
                except Exception as e:
                    raise RuntimeError(
                        f"Failed to generate utterance at theta={theta}, "
                        f"obs_idx={obs_idx}, utt_idx={utt_idx}: {e}"
                    ) from e

                obs_seq_data["utt_sequences"].append({
                    "utt_seq_idx": utt_idx,
                    "utt_seq": utt_seq,
                    "utt_seed": utt_seed
                })

                # --- Build flat record (one per utterance sequence) ---
                flat_records.append({
                    "theta": theta,
                    "obs_seq_idx": obs_idx,
                    "utt_seq_idx": utt_idx,
                    "obs_seq": obs_seq,      # Full observation sequence as list
                    "utt_seq": utt_seq,      # Full utterance sequence as list
                    "run_seed": run_seed,
                    "utt_seed": utt_seed,
                    **likelihood_fields
                })

            theta_data["obs_sequences"].append(obs_seq_data)

        data[theta] = theta_data

    # --- Prepare output ---
    if return_format == "dataframe":
        df = pd.DataFrame(flat_records)

        # Ensure consistent column ordering: base + true + all thetas + mle
        col_order = ["theta", "obs_seq_idx", "utt_seq_idx", "obs_seq", "utt_seq", "run_seed", "utt_seed"]
        if want_true_lik:
            col_order = col_order + ["obs_seq_log_likelihood_true"]
        if want_all_lik:
            col_order = col_order + likelihood_cols + ["obs_seq_mle_theta"]

        return df[col_order]

    # nested
    return {
        "world": world,
        "config": {
            "n": n,
            "m": m,
            "thetas": thetas,
            "T": T,
            "speaker_config": speaker_config,
            "n_obs_seq": n_obs_seq,
            "n_utt_seq": n_utt_seq,
            "random_seed": random_seed,
            "compute_obs_likelihood": compute_obs_likelihood
        },
        "data": data
    }

## Single T

### Observations Sampling

In [None]:
def sample_observation_sequences(
    n: int,
    m: int,
    thetas: Union[List[float], np.ndarray],
    T: int,
    n_obs_seq: int,
    random_seed: Optional[int] = None,
    theta_values: Optional[np.ndarray] = None,
    compute_obs_likelihood: str = "none"
) -> Dict[str, Any]:
    """
    Sample observation sequences for given theta values.

    This function handles Stage 1 of the simulation pipeline:
    1. Create a World instance
    2. Sample observation sequences for each theta
    3. Optionally compute observation likelihoods

    The output can be passed to `generate_utterances_for_observations`
    to generate utterances under different speaker configurations while
    keeping observations fixed.

    Parameters
    ----------
    n : int
        Number of independent experiments in the World.
    m : int
        Number of Bernoulli trials per experiment.
    thetas : List[float] or np.ndarray
        Theta values to simulate. Each must be in the World's theta_values.
        Any iterable is accepted; it is converted to a list.
    T : int
        Length of each sequence (number of rounds).
    n_obs_seq : int
        Number of observation sequences per theta.
    random_seed : Optional[int], default None
        Base random seed for reproducibility. It is passed unchanged as
        ``base_seed`` to ``World.sample_multiple_runs`` for every theta,
        i.e. all thetas share the same base seed.
    theta_values : Optional[np.ndarray], default None
        Custom theta grid for the World. If None, uses World's default [0, 0.1, ..., 1].
    compute_obs_likelihood : str, default "none"
        Whether to compute observation sequence log-likelihoods:
        - "none": Don't compute likelihoods
        - "true": Compute log P(obs_seq | true_theta) only
                  (optimized: only loads single column from likelihood table)
        - "all": Compute log P(obs_seq | theta) for all thetas in the world,
                 plus MLE theta for each observation sequence

    Returns
    -------
    Dict[str, Any]
        Dictionary containing:

        - "world": World
            The World object (reusable for utterance generation)

        - "config": Dict
            Configuration parameters used:
            {
                "n": int,
                "m": int,
                "thetas": List[float],
                "T": int,
                "n_obs_seq": int,
                "random_seed": Optional[int],
                "compute_obs_likelihood": str
            }

        - "observations": Dict[float, List[Dict]]
            Mapping from theta -> list of observation data.
            Each observation data dict contains:
            {
                "obs_idx": int,
                "obs_seq": List[Tuple[int, ...]],
                "obs_run_seed": int,
                "theta": float,
                "log_lik_true_theta": Optional[float],    # if compute != "none"
                "log_lik_all_theta": Optional[np.ndarray], # if compute == "all"
                "mle_theta": Optional[float],              # if compute == "all"
                "utterances": None                         # Placeholder for Stage 2
            }

    Raises
    ------
    ValueError
        If parameters are invalid or theta values not in World's theta_values.
    RuntimeError
        If World creation or observation sampling fails. The original
        exception is attached as ``__cause__``.

    Examples
    --------
    >>> # Basic usage
    >>> obs_data = sample_observation_sequences(
    ...     n=3, m=2, thetas=[0.3, 0.5, 0.7], T=10, n_obs_seq=50,
    ...     random_seed=42
    ... )
    >>> print(f"Sampled {len(obs_data['observations'][0.3])} sequences for theta=0.3")

    >>> # With likelihood computation
    >>> obs_data = sample_observation_sequences(
    ...     n=3, m=2, thetas=[0.3, 0.5, 0.7], T=10, n_obs_seq=50,
    ...     random_seed=42, compute_obs_likelihood="all"
    ... )
    >>> # Access likelihood and MLE
    >>> log_lik = obs_data["observations"][0.3][0]["log_lik_true_theta"]
    >>> mle = obs_data["observations"][0.3][0]["mle_theta"]
    """

    # -----------------------------------------------------------------
    # INPUT VALIDATION
    # -----------------------------------------------------------------

    if not isinstance(n, int) or n < 1:
        raise ValueError("n must be a positive integer")
    if not isinstance(m, int) or m < 1:
        raise ValueError("m must be a positive integer")
    if not isinstance(T, int) or T < 1:
        raise ValueError("T must be a positive integer")
    if not isinstance(n_obs_seq, int) or n_obs_seq < 1:
        raise ValueError("n_obs_seq must be a positive integer")
    if compute_obs_likelihood not in ["none", "true", "all"]:
        raise ValueError("compute_obs_likelihood must be 'none', 'true', or 'all'")

    # Process thetas to list (arrays, tuples and other iterables alike)
    thetas = list(thetas)
    if not thetas:
        raise ValueError("thetas cannot be empty")

    # -----------------------------------------------------------------
    # CREATE WORLD
    # -----------------------------------------------------------------

    try:
        world = World(n=n, m=m, theta_values=theta_values)
    except Exception as e:
        raise RuntimeError(f"Failed to create World: {e}") from e

    # Validate thetas against World's theta_values
    for theta in thetas:
        closest = world.theta_values[np.abs(world.theta_values - theta).argmin()]
        if not np.isclose(theta, closest, rtol=1e-10, atol=1e-10):
            raise ValueError(
                f"theta {theta} not in World's theta_values. "
                f"Closest: {closest}. Available: {list(world.theta_values)}"
            )

    # -----------------------------------------------------------------
    # SAMPLE OBSERVATIONS FOR EACH THETA
    # -----------------------------------------------------------------

    observations = {}
    n_theta_vals = len(world.theta_values)

    for theta in thetas:

        # Sample observation sequences using World's method
        # NOTE: Same seed for all thetas (aligned with multiT version)
        try:
            obs_df = world.sample_multiple_runs(
                theta=theta,
                n_run=n_obs_seq,
                n_round=T,
                base_seed=random_seed
            )
        except Exception as e:
            raise RuntimeError(
                f"Failed to sample observations for theta={theta}: {e}"
            ) from e

        # EXTRACT OBSERVATION SEQUENCES (VECTORIZED VIA GROUPBY)
        # One row per run: its observations in round order and its seed.
        per_run = (
            obs_df
            .sort_values(["run_id", "round_index"])
            .groupby("run_id", sort=True)
            .agg(
                obs_seq=("observation", list),
                obs_run_seed=("run_seed", "first")  # Aligned field name
            )
        )

        obs_seqs = per_run["obs_seq"].tolist()
        obs_run_seeds = per_run["obs_run_seed"].tolist()

        # BUILD OBSERVATION RECORDS (aligned with multiT version)
        # Indexing by range(n_obs_seq) (rather than zipping) makes a sampler
        # that returned fewer runs than requested fail loudly.
        obs_list = [
            {
                "obs_idx": obs_idx,
                "obs_seq": obs_seqs[obs_idx],
                "obs_run_seed": obs_run_seeds[obs_idx],
                "theta": theta,
                "log_lik_true_theta": None,
                "log_lik_all_theta": None,
                "mle_theta": None,
                "utterances": None
            }
            for obs_idx in range(n_obs_seq)
        ]

        # COMPUTE OBSERVATION LIKELIHOODS

        if compute_obs_likelihood != "none":

            # Flatten all observations (as tuples) for batch lookup
            all_obs_keys = [
                obs if isinstance(obs, tuple) else tuple(obs)
                for seq in obs_seqs
                for obs in seq
            ]

            # Find column index for true theta
            theta_col_idx = np.where(np.isclose(world.theta_values, theta))[0][0]
            true_theta_val = world.theta_values[theta_col_idx]

            if compute_obs_likelihood == "true":

                # Select ONLY the column for true theta
                log_probs_flat_true = world.obs_log_likelihood_theta.loc[
                    all_obs_keys, true_theta_val
                ].values

                # Reshape to (n_obs_seq, T) and sum the per-round log-probs
                log_liks_true = log_probs_flat_true.reshape(n_obs_seq, T).sum(axis=1)

                # Distribute results
                for record, log_lik in zip(obs_list, log_liks_true):
                    record["log_lik_true_theta"] = float(log_lik)

            else:  # compute_obs_likelihood == "all"

                # Load full matrix
                log_probs_flat = world.obs_log_likelihood_theta.loc[all_obs_keys].values

                # Reshape to (n_obs_seq, T, n_theta_vals) and sum over T
                log_liks_all = log_probs_flat.reshape(n_obs_seq, T, n_theta_vals).sum(axis=1)

                # Distribute results (each record gets its own copy of its row)
                for record, log_liks_row in zip(obs_list, log_liks_all):
                    record["log_lik_true_theta"] = float(log_liks_row[theta_col_idx])
                    record["log_lik_all_theta"] = log_liks_row.copy()
                    record["mle_theta"] = float(
                        world.theta_values[np.argmax(log_liks_row)]
                    )

        # Store observations for this theta
        observations[theta] = obs_list

    # -----------------------------------------------------------------
    # RETURN RESULT
    # -----------------------------------------------------------------

    return {
        "world": world,
        "config": {
            "n": n,
            "m": m,
            "thetas": thetas,
            "T": T,
            "n_obs_seq": n_obs_seq,
            "random_seed": random_seed,
            "compute_obs_likelihood": compute_obs_likelihood
        },
        "observations": observations
    }

### Utterances Sampling

In [None]:
def generate_utterances_for_observations(
    obs_data: Dict[str, Any],
    speaker_config: Dict[str, Any],
    n_utt_seq: int,
    random_seed: Optional[int] = None,
    return_format: str = "dataframe",
    n_jobs: int = 1,
    backend: str = "loky",
    verbose: int = 0
) -> Union[Dict[str, Any], pd.DataFrame]:
    """
    Generate utterance sequences for pre-sampled observations.
    
    This function handles Stage 2 of the simulation pipeline:
    given pre-sampled observation sequences (from `sample_observation_sequences`),
    generate utterance sequences under a specified speaker configuration.
    
    This design allows comparing different speaker configurations on the
    exact same observation sequences, ensuring fair comparisons.

    Parameters
    ----------
    obs_data : Dict[str, Any]
        Output from sample_observation_sequences containing:
        - "world": World object
        - "config": Dict with sampling configuration
        - "observations": Dict mapping theta -> list of observation data
    speaker_config : Dict[str, Any]
        Speaker configuration. Required keys:
        - speaker_type : str ("literal" or "pragmatic")
        For pragmatic speaker:
        - omega, psi, alpha, update_internal (required)
        - beta, initial_beliefs_theta (optional)
    n_utt_seq : int
        Number of utterance sequences per observation sequence.
        Values above 10_000 make seeds of neighbouring observation sequences
        overlap (see `random_seed`); a RuntimeWarning is issued in that case.
    random_seed : Optional[int], default None
        Base random seed for utterance generation.
        Seed hierarchy: random_seed + task_idx * 10_000 + utt_idx
    return_format : str, default "dataframe"
        "dataframe" or "nested"
    n_jobs : int, default 1
        Number of parallel jobs (-1 for all cores).
    backend : str, default "loky"
        Joblib backend.
    verbose : int, default 0
        Verbosity level.
    
    Returns
    -------
    pd.DataFrame (if return_format == "dataframe")
        Columns (aligned with multiT):
        - theta, obs_idx, utt_idx, obs_seq, utt_seq
        - obs_run_seed, utt_seed
        - log_lik_true_speaker
        - log_lik_true_theta (if computed)
        - log_lik_theta_0p0, log_lik_theta_0p1, ... (if compute == "all")
        - mle_theta (if compute == "all")
        If there are no observation sequences at all, an empty DataFrame
        carrying the expected columns is returned.
    
    Dict[str, Any] (if return_format == "nested")
        Hierarchical structure.

    Raises
    ------
    ValueError
        If n_utt_seq, return_format, backend, the obs_data keys or the
        speaker configuration are invalid.
    TypeError
        If obs_data or speaker_config is not a dictionary.
    """
    
    # ---------------------------------------------------------------------
    # INPUT VALIDATION
    # ---------------------------------------------------------------------
    
    if not isinstance(n_utt_seq, int) or n_utt_seq < 1:
        raise ValueError("n_utt_seq must be a positive integer")
    
    if return_format not in ["nested", "dataframe"]:
        raise ValueError("return_format must be 'nested' or 'dataframe'")
    
    if backend not in ["loky", "multiprocessing", "threading"]:
        raise ValueError("backend must be 'loky', 'multiprocessing', or 'threading'")
    
    if not isinstance(obs_data, dict):
        raise TypeError("obs_data must be a dictionary")
    
    required_obs_data_keys = ["world", "config", "observations"]
    missing_keys = [k for k in required_obs_data_keys if k not in obs_data]
    if missing_keys:
        raise ValueError(f"obs_data missing required keys: {missing_keys}")
    
    if not isinstance(speaker_config, dict):
        raise TypeError("speaker_config must be a dictionary")
    
    speaker_type = speaker_config.get("speaker_type")
    if speaker_type not in ["literal", "pragmatic"]:
        raise ValueError(f"speaker_type must be 'literal' or 'pragmatic', got '{speaker_type}'")
    
    if speaker_type == "pragmatic":
        required_pragmatic_keys = ["omega", "psi", "alpha", "update_internal"]
        missing_pragmatic_keys = [k for k in required_pragmatic_keys if k not in speaker_config]
        if missing_pragmatic_keys:
            raise ValueError(f"Missing required keys for pragmatic speaker: {missing_pragmatic_keys}")
        
        if speaker_config["omega"] not in ["coop", "strat"]:
            raise ValueError(f"omega must be 'coop' or 'strat', got '{speaker_config['omega']}'")
        if speaker_config["psi"] not in ["inf", "pers+", "pers-"]:
            raise ValueError(f"psi must be 'inf', 'pers+', or 'pers-', got '{speaker_config['psi']}'")
    
    # Per-task seed blocks are 10_000 wide, so more utterance sequences than
    # that would reuse seeds belonging to the next task. Warn instead of
    # raising so existing callers keep working.
    if random_seed is not None and n_utt_seq > 10_000:
        warnings.warn(
            "n_utt_seq > 10_000: utterance seeds "
            "(random_seed + task_idx * 10_000 + utt_idx) overlap between "
            "observation sequences, so some utterance sequences share a seed.",
            RuntimeWarning,
            stacklevel=2
        )
    
    # ---------------------------------------------------------------------
    # EXTRACT DATA FROM obs_data
    # ---------------------------------------------------------------------
    
    world = obs_data["world"]
    observations = obs_data["observations"]
    obs_config = obs_data["config"]
    
    thetas = obs_config["thetas"]
    compute_obs_likelihood = obs_config["compute_obs_likelihood"]
    
    # ---------------------------------------------------------------------
    # SETUP PARALLELIZATION
    # ---------------------------------------------------------------------
    
    if n_jobs == -1:
        n_workers = cpu_count()
    elif n_jobs < 0:
        # joblib convention: -2 means "all cores but one", and so on
        n_workers = max(1, cpu_count() + 1 + n_jobs)
    else:
        n_workers = max(1, n_jobs)
    
    is_parallel = n_workers > 1
    
    # ---------------------------------------------------------------------
    # VERBOSE OUTPUT
    # ---------------------------------------------------------------------
    
    if verbose > 0:
        total_obs_seqs = sum(len(obs_list) for obs_list in observations.values())
        total_utt_seqs = total_obs_seqs * n_utt_seq
        
        print("Utterance generation configuration:")
        print(f"  Thetas: {thetas}")
        print(f"  {len(thetas)} thetas × {obs_config['n_obs_seq']} obs_seq × {n_utt_seq} utt_seq "
              f"= {total_utt_seqs} total utterance sequences")
        print(f"  Sequence length T = {obs_config['T']}")
        print(f"  Speaker: {speaker_type}", end="")
        if speaker_type == "pragmatic":
            print(f" (omega={speaker_config['omega']}, psi={speaker_config['psi']}, "
                  f"alpha={speaker_config['alpha']})")
        else:
            print()
        
        if is_parallel:
            print(f"  Parallel execution: {n_workers} workers, backend='{backend}'")
        else:
            print("  Sequential execution")
        print()
    
    # ---------------------------------------------------------------------
    # PREPARE FULL SPEAKER CONFIG
    # ---------------------------------------------------------------------
    
    full_speaker_config = {
        "speaker_type": speaker_type,
        "initial_beliefs_theta": speaker_config.get("initial_beliefs_theta", None)
    }
    
    if speaker_type == "pragmatic":
        full_speaker_config.update({
            "omega": speaker_config["omega"],
            "psi": speaker_config["psi"],
            "alpha": speaker_config["alpha"],
            "update_internal": speaker_config["update_internal"],
            "beta": speaker_config.get("beta", 0.0)
        })
    
    # ---------------------------------------------------------------------
    # BUILD FLAT TASK LIST (aligned with multiT approach)
    # ---------------------------------------------------------------------
    
    tasks = []
    
    for theta in thetas:
        for obs_info in observations[theta]:
            tasks.append({
                "theta": theta,
                # Global running index; it drives the per-task seed offset
                "task_idx": len(tasks),
                "obs_idx": obs_info["obs_idx"],
                "obs_seq": obs_info["obs_seq"],
                "obs_run_seed": obs_info["obs_run_seed"],
                # Likelihood fields are optional: tolerate records where the
                # sampler did not store them at all.
                "obs_log_lik_true_theta": obs_info.get("log_lik_true_theta"),
                "obs_log_lik_all_theta": obs_info.get("log_lik_all_theta"),
                "obs_mle_theta": obs_info.get("mle_theta")
            })
    
    # ---------------------------------------------------------------------
    # EXECUTE TASKS (PARALLEL OR SEQUENTIAL)
    # ---------------------------------------------------------------------
    
    def run_task(task):
        return _process_single_obs_seq(
            theta=task["theta"],
            task_idx=task["task_idx"],
            obs_idx=task["obs_idx"],
            world=world,
            obs_seq=task["obs_seq"],
            obs_run_seed=task["obs_run_seed"],
            n_utt_seq=n_utt_seq,
            speaker_config=full_speaker_config,
            speaker_type=speaker_type,
            random_seed=random_seed,
            obs_log_lik_true_theta=task["obs_log_lik_true_theta"],
            obs_log_lik_all_theta=task["obs_log_lik_all_theta"],
            obs_mle_theta=task["obs_mle_theta"],
            compute_obs_likelihood=compute_obs_likelihood
        )
    
    if is_parallel:
        results = Parallel(n_jobs=n_workers, backend=backend, verbose=verbose)(
            delayed(run_task)(task) for task in tasks
        )
    else:
        results = [run_task(task) for task in tasks]
    
    # ---------------------------------------------------------------------
    # AGGREGATE RESULTS
    # ---------------------------------------------------------------------
    
    data = {theta: {"obs_sequences": []} for theta in thetas}
    all_flat_records = []
    
    for task, (obs_seq_data, flat_records) in zip(tasks, results):
        data[task["theta"]]["obs_sequences"].append(obs_seq_data)
        all_flat_records.extend(flat_records)
    
    # Sort obs_sequences by obs_idx for consistent ordering
    for theta in thetas:
        data[theta]["obs_sequences"].sort(key=lambda x: x["obs_idx"])
    
    # ---------------------------------------------------------------------
    # PREPARE OUTPUT
    # ---------------------------------------------------------------------
    
    if return_format == "nested":
        return {
            "obs_data_config": obs_config,
            "speaker_config": speaker_config,
            "n_utt_seq": n_utt_seq,
            "random_seed": random_seed,
            "data": data
        }
    
    # return_format == "dataframe"
    
    # Define column order. Unknown likelihood modes fall back to the base
    # columns, because no likelihood fields are written to the records then.
    col_order = [
        "theta",
        "obs_idx",
        "utt_idx",
        "obs_seq",
        "utt_seq",
        "obs_run_seed",
        "utt_seed",
        "log_lik_true_speaker"
    ]
    
    if compute_obs_likelihood in ("true", "all"):
        col_order.append("log_lik_true_theta")
    
    if compute_obs_likelihood == "all":
        col_order.extend(
            f"log_lik_theta_{str(tv).replace('.', 'p')}"
            for tv in world.theta_values
        )
        col_order.append("mle_theta")
    
    # Nothing was generated: sorting an empty, column-less frame would fail,
    # so hand back an empty frame with the expected schema instead.
    if not all_flat_records:
        return pd.DataFrame(columns=col_order)
    
    df = pd.DataFrame(all_flat_records)
    df = df.sort_values(["theta", "obs_idx", "utt_idx"]).reset_index(drop=True)
    
    # Only include columns that exist
    col_order = [c for c in col_order if c in df.columns]
    
    return df[col_order]



def _process_single_obs_seq(
    theta: float,
    task_idx: int,
    obs_idx: int,
    world: World,
    obs_seq: List[Tuple[int, ...]],
    obs_run_seed: int,
    n_utt_seq: int,
    speaker_config: Dict[str, Any],
    speaker_type: str,
    random_seed: Optional[int],
    obs_log_lik_true_theta: Optional[float],
    obs_log_lik_all_theta: Optional[np.ndarray],
    obs_mle_theta: Optional[float],
    compute_obs_likelihood: str
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """
    Process a single observation sequence: generate all utterance sequences.
    
    This function is designed to be called in parallel via joblib.
    
    Parameters
    ----------
    theta : float
        True theta value for this observation sequence.
    task_idx : int
        Global task index (used for seed computation, replaces theta_idx).
    obs_idx : int
        Index of this observation sequence within its theta group.
    world : World
        The World object.
    obs_seq : List[Tuple[int, ...]]
        The observation sequence.
    obs_run_seed : int
        Seed used to sample this observation sequence (aligned name).
    n_utt_seq : int
        Number of utterance sequences to generate.
    speaker_config : Dict[str, Any]
        Speaker configuration dictionary.
    speaker_type : str
        "literal" or "pragmatic".
    random_seed : Optional[int]
        Base random seed for utterance generation.
        Base seed for this task: random_seed + task_idx * 10_000
        (aligned with multiT). It is handed to
        `_generate_utterances_for_obs_seq` as `base_utt_seed`.
        If None, no base seed is passed on.
    obs_log_lik_true_theta : Optional[float]
        Pre-computed log P(obs_seq | true theta) (aligned name).
    obs_log_lik_all_theta : Optional[np.ndarray]
        Pre-computed log P(obs_seq | theta) for all theta values (aligned name).
    obs_mle_theta : Optional[float]
        Pre-computed MLE theta.
    compute_obs_likelihood : str
        "none", "true", or "all".
    
    Returns
    -------
    Tuple[Dict[str, Any], List[Dict[str, Any]]]
        1. obs_seq_data: Nested data structure
        2. flat_records: Flat records for DataFrame format
    """
    
    # Which pre-computed observation likelihood fields are to be attached
    attach_true_lik = (
        compute_obs_likelihood in ["true", "all"]
        and obs_log_lik_true_theta is not None
    )
    attach_all_lik = (
        compute_obs_likelihood == "all"
        and obs_log_lik_all_theta is not None
    )
    
    # ---------------------------------------------------------------------
    # INITIALIZE NESTED STRUCTURE (aligned field names)
    # ---------------------------------------------------------------------
    
    obs_seq_data = {
        "obs_idx": obs_idx,
        "obs_seq": obs_seq,
        "obs_run_seed": obs_run_seed,
        "theta": theta,
        "utt_sequences": []
    }
    
    # ---------------------------------------------------------------------
    # ATTACH PRE-COMPUTED OBSERVATION LIKELIHOOD DATA (aligned field names)
    # ---------------------------------------------------------------------
    
    if attach_true_lik:
        obs_seq_data["log_lik_true_theta"] = obs_log_lik_true_theta
    
    if attach_all_lik:
        obs_seq_data["log_lik_all_theta"] = dict(
            zip(world.theta_values, obs_log_lik_all_theta)
        )
        obs_seq_data["mle_theta"] = obs_mle_theta
    
    # ---------------------------------------------------------------------
    # COMPUTE BASE SEED
    # ---------------------------------------------------------------------
    
    base_utt_seed = None
    if random_seed is not None:
        base_utt_seed = random_seed + task_idx * 10_000
    
    # ---------------------------------------------------------------------
    # GENERATE UTTERANCE SEQUENCES
    # ---------------------------------------------------------------------
    
    utt_records = _generate_utterances_for_obs_seq(
        world=world,
        obs_seq=obs_seq,
        n_utt_seq=n_utt_seq,
        speaker_config=speaker_config,
        base_utt_seed=base_utt_seed,
        speaker_type=speaker_type
    )
    
    obs_seq_data["utt_sequences"] = utt_records
    
    # ---------------------------------------------------------------------
    # BUILD FLAT RECORDS FOR DATAFRAME FORMAT
    # ---------------------------------------------------------------------
    
    # The observation-likelihood columns are identical for every utterance
    # sequence of this observation sequence, so build them once instead of
    # re-formatting the column names for each record.
    obs_lik_fields = {}
    
    if attach_true_lik:
        obs_lik_fields["log_lik_true_theta"] = obs_log_lik_true_theta
    
    if attach_all_lik:
        for t_idx, t_val in enumerate(world.theta_values):
            col_name = f"log_lik_theta_{str(t_val).replace('.', 'p')}"
            obs_lik_fields[col_name] = float(obs_log_lik_all_theta[t_idx])
        obs_lik_fields["mle_theta"] = obs_mle_theta
    
    flat_records = [
        {
            # Base record with aligned field names
            "theta": theta,
            "obs_idx": obs_idx,
            "utt_idx": utt_data["utt_idx"],
            "obs_seq": obs_seq,
            "utt_seq": utt_data["utt_seq"],
            "obs_run_seed": obs_run_seed,
            "utt_seed": utt_data["utt_seed"],
            "log_lik_true_speaker": utt_data["log_lik_true_speaker"],
            # Shared observation-likelihood columns (possibly none)
            **obs_lik_fields
        }
        for utt_data in utt_records
    ]
    
    return obs_seq_data, flat_records



def _generate_utterances_for_obs_seq(
    world: World,
    obs_seq: List[Tuple[int, ...]],
    n_utt_seq: int,
    speaker_config: Dict[str, Any],
    base_utt_seed: Optional[int],
    speaker_type: str
) -> List[Dict[str, Any]]:
    """
    Produce `n_utt_seq` utterance sequences for one observation sequence.
    
    For every requested sequence a new speaker object is built, the speaker
    is fed the observations one by one, and the log-probability of each
    utterance it emits is summed into a sequence log-likelihood.
    
    Parameters
    ----------
    world : World
        The World object handed to the speaker constructor.
    obs_seq : List[Tuple[int, ...]]
        Observations to walk through, in order.
    n_utt_seq : int
        How many utterance sequences to produce.
    speaker_config : Dict[str, Any]
        Speaker settings. "initial_beliefs_theta" and "beta" are optional;
        "omega", "psi", "update_internal" and "alpha" are read for the
        pragmatic speaker.
    base_utt_seed : Optional[int]
        When given, sequence number `utt_idx` seeds NumPy's global RNG with
        `base_utt_seed + utt_idx` before its speaker is created. When None,
        the RNG is left untouched.
    speaker_type : str
        "literal" selects LiteralSpeaker; anything else selects
        PragmaticSpeaker_obs.
    
    Returns
    -------
    List[Dict[str, Any]]
        One record per utterance sequence with the keys:
        - utt_idx: int
        - utt_seq: list of the utterances returned by the speaker
        - utt_seed: Optional[int]
        - log_lik_true_speaker: float (-inf once any step has a
          non-finite log-probability)
        - log_lik_all_speaker: None (placeholder)
    """
    records = []
    
    for utt_idx in range(n_utt_seq):
        
        # --- seeding -----------------------------------------------------
        seed = None if base_utt_seed is None else base_utt_seed + utt_idx
        if seed is not None:
            np.random.seed(seed)
        
        # --- new speaker for this sequence -------------------------------
        if speaker_type != "literal":
            speaker = PragmaticSpeaker_obs(
                world=world,
                omega=speaker_config["omega"],
                psi=speaker_config["psi"],
                update_internal=speaker_config["update_internal"],
                alpha=speaker_config["alpha"],
                beta=speaker_config.get("beta", 0.0),
                initial_beliefs_theta=speaker_config.get("initial_beliefs_theta")
            )
        else:
            speaker = LiteralSpeaker(
                world=world,
                initial_beliefs_theta=speaker_config.get("initial_beliefs_theta")
            )
        
        # --- walk the observations ---------------------------------------
        utterances = []
        total_log_lik = 0.0
        
        for obs in obs_seq:
            key = obs if isinstance(obs, tuple) else tuple(obs)
            
            # Snapshot the utterance log-probabilities first: the call to
            # update_and_speak below may change the speaker's state.
            pre_speak_log_probs = speaker.utterance_log_prob_obs[key].copy()
            
            chosen = speaker.update_and_speak(obs)
            utterances.append(chosen)
            
            step_log_p = float(pre_speak_log_probs.loc[chosen])
            
            # A non-finite step pins the total at -inf; otherwise add it on.
            if np.isfinite(step_log_p):
                total_log_lik += step_log_p
            else:
                total_log_lik = -np.inf
        
        records.append({
            "utt_idx": utt_idx,
            "utt_seq": utterances,
            "utt_seed": seed,
            "log_lik_true_speaker": total_log_lik,
            "log_lik_all_speaker": None
        })
    
    return records

## Multiple T

### Observation Sampling

In [None]:
def sample_observation_sequences_multiT(
    n: int,
    m: int,
    thetas: Union[List[float], np.ndarray],
    Ts: Union[List[int], np.ndarray],
    n_obs_seq: int,
    random_seed: Optional[int] = None,
    theta_values: Optional[np.ndarray] = None,
    compute_obs_likelihood: str = "none"
) -> Dict[str, Any]:
    """
    Sample observation sequences and compute likelihoods for multiple sequence lengths.

    Parameters
    ----------
    n : int
        Number of independent experiments in the World.
    m : int
        Number of Bernoulli trials per experiment.
    thetas : List[float] or np.ndarray
        List of theta values to simulate. Each must be in the World's theta_values.
    Ts : List[int] or np.ndarray
        List of sequence lengths to compute likelihoods for.
        Duplicates are removed and the values are sorted.
        Observations are sampled with length max(Ts).
        Likelihoods are computed for each T using obs_seq[:T].
    n_obs_seq : int
        Number of observation sequences per theta.
    random_seed : Optional[int], default None
        Base random seed for reproducibility.
        The same seed is used for all thetas (sequences differ due to different
        theta parameters, not different seeds).
    theta_values : Optional[np.ndarray], default None
        Custom theta grid for the World. If None, uses World's default [0, 0.1, ..., 1].
    compute_obs_likelihood : str, default "none"
        Whether to compute observation sequence log-likelihoods:
        - "none": Don't compute likelihoods
        - "true": Compute log P(obs_seq[:T] | true_theta) for each T
                  (optimized: only loads single column from likelihood table)
        - "all": Compute log P(obs_seq[:T] | theta) for all thetas and each T,
                 plus MLE theta for each T
    
    Returns
    -------
    Dict[str, Any]
        Dictionary containing:
        
        - "world": World
            The World object (reusable for utterance generation)
        
        - "config": Dict
            Configuration parameters used:
            {
                "n": int,
                "m": int,
                "thetas": List[float],
                "Ts": List[int],
                "max_T": int,
                "n_obs_seq": int,
                "random_seed": Optional[int],
                "compute_obs_likelihood": str
            }
        
        - "observations": Dict[float, List[Dict]]
            Mapping from theta -> list of observation data.
            Each observation data dict contains:
            {
                "obs_idx": int,
                "obs_seq": List[Tuple[int, ...]],   # Full sequence (length max_T)
                "obs_run_seed": int, # run_seed output from world.sample_multiple_runs()
                "theta": float,
                "log_lik_true_theta": Optional[Dict[int, float]],
                    # {T: log P(obs_seq[:T] | true_theta)} for each T in Ts
                    # None if compute_obs_likelihood == "none"
                "log_lik_all_theta": Optional[Dict[int, np.ndarray]],
                    # {T: array of log P(obs_seq[:T] | theta) for all thetas}
                    # None if compute_obs_likelihood != "all"
                "mle_theta": Optional[Dict[int, float]],
                    # {T: argmax_theta P(obs_seq[:T] | theta)} for each T
                    # None if compute_obs_likelihood != "all"
                "utterances": {}
                    # Empty dict, to be filled by utterance generation (Stage 2)
            }
    
    Raises
    ------
    ValueError
        If parameters are invalid or theta values not in World's theta_values.
    RuntimeError
        If World creation or observation sampling fails. The original
        exception is attached as ``__cause__``.
    
    Examples
    --------
    >>> # Sample observations and compute likelihoods for T=5, 10, 15, 20
    >>> obs_data = sample_observation_sequences_multiT(
    ...     n=3, m=2,
    ...     thetas=[0.3, 0.5, 0.7],
    ...     Ts=[5, 10, 15, 20],
    ...     n_obs_seq=50,
    ...     random_seed=42,
    ...     compute_obs_likelihood="all"
    ... )
    
    >>> # Access likelihood at T=10 for first observation of theta=0.5
    >>> obs_info = obs_data["observations"][0.5][0]
    >>> log_lik_T10 = obs_info["log_lik_true_theta"][10]
    >>> mle_T10 = obs_info["mle_theta"][10]
    
    >>> # Compare MLE accuracy across different T values
    >>> for T in obs_data["config"]["Ts"]:
    ...     mle_errors = [
    ...         abs(obs["mle_theta"][T] - obs["theta"])
    ...         for obs in obs_data["observations"][0.5]
    ...     ]
    ...     print(f"T={T}: mean MLE error = {np.mean(mle_errors):.3f}")
    """
    
    # =========================================================================
    # INPUT VALIDATION
    # =========================================================================
    
    if not isinstance(n, int) or n < 1:
        raise ValueError("n must be a positive integer")
    if not isinstance(m, int) or m < 1:
        raise ValueError("m must be a positive integer")
    if not isinstance(n_obs_seq, int) or n_obs_seq < 1:
        raise ValueError("n_obs_seq must be a positive integer")
    if compute_obs_likelihood not in ["none", "true", "all"]:
        raise ValueError("compute_obs_likelihood must be 'none', 'true', or 'all'")
    
    # Process thetas to list
    thetas = list(thetas)
    if len(thetas) == 0:
        raise ValueError("thetas cannot be empty")
    
    # Process Ts to list
    Ts = list(Ts)
    if len(Ts) == 0:
        raise ValueError("Ts cannot be empty")
    
    # Validate all Ts are positive integers
    for T in Ts:
        if not isinstance(T, (int, np.integer)) or T < 1:
            raise ValueError(f"All values in Ts must be positive integers, got {T}")
    
    # Sort and deduplicate Ts
    Ts = sorted(set(int(T) for T in Ts))
    max_T = max(Ts)
    
    # =========================================================================
    # CREATE WORLD
    # =========================================================================
    
    try:
        world = World(n=n, m=m, theta_values=theta_values)
    except Exception as e:
        # Chain explicitly so the root cause stays visible in the traceback
        raise RuntimeError(f"Failed to create World: {e}") from e
    
    # Validate thetas against World's theta_values
    for theta in thetas:
        closest = world.theta_values[np.abs(world.theta_values - theta).argmin()]
        if not np.isclose(theta, closest, rtol=1e-10, atol=1e-10):
            raise ValueError(
                f"theta {theta} not in World's theta_values. "
                f"Closest: {closest}. Available: {list(world.theta_values)}"
            )
    
    # =========================================================================
    # SAMPLE OBSERVATIONS FOR EACH THETA
    # =========================================================================
    
    observations = {}
    n_theta_vals = len(world.theta_values)
    
    for theta in thetas:
        
        # Sample observation sequences of length max_T
        try:
            obs_df = world.sample_multiple_runs(
                theta=theta,
                n_run=n_obs_seq,
                n_round=max_T,
                base_seed=random_seed
            )
        except Exception as e:
            raise RuntimeError(
                f"Failed to sample observations for theta={theta}: {e}"
            ) from e
        
        # -----------------------------------------------------------------
        # EXTRACT OBSERVATION SEQUENCES (VECTORIZED VIA GROUPBY)
        # -----------------------------------------------------------------
        
        out = (
            obs_df
            .sort_values(["run_id", "round_index"])
            .groupby("run_id", sort=True)
            .agg(
                obs_seq=("observation", list), # Note: source column is "observation"
                obs_run_seed=("run_seed", "first")  # Note: source column is "run_seed"
            )
        )
        
        obs_seqs = out["obs_seq"].tolist()
        obs_run_seeds = out["obs_run_seed"].tolist()
        
        # -----------------------------------------------------------------
        # BUILD OBSERVATION RECORDS
        # -----------------------------------------------------------------
        
        obs_list = [
            {
                "obs_idx": obs_idx,
                "obs_seq": obs_seqs[obs_idx],
                "obs_run_seed": obs_run_seeds[obs_idx],
                "theta": theta,
                "log_lik_true_theta": None,
                "log_lik_all_theta": None,
                "mle_theta": None,
                "utterances": {}
            }
            for obs_idx in range(n_obs_seq)
        ]
        
        # -----------------------------------------------------------------
        # COMPUTE OBSERVATION LIKELIHOODS FOR ALL Ts
        # -----------------------------------------------------------------
        
        if compute_obs_likelihood != "none":
            
            # Flatten all observations for batch lookup
            # obs_seqs is List[List[Tuple]], flatten to List[Tuple]
            # Shape after flatten: (n_obs_seq * max_T,)
            all_obs_flat = [obs for seq in obs_seqs for obs in seq]
            all_obs_keys = [tuple(obs) if not isinstance(obs, tuple) else obs 
                          for obs in all_obs_flat]
            
            # Find column index for true theta (needed for both modes)
            theta_col_idx = np.where(np.isclose(world.theta_values, theta))[0][0]
            true_theta_val = world.theta_values[theta_col_idx]
            
            if compute_obs_likelihood == "true":
                
                # Select ONLY the column for true theta
                # Shape: (n_obs_seq * max_T,)
                log_probs_flat_true = world.obs_log_likelihood_theta.loc[
                    all_obs_keys, true_theta_val
                ].values
                
                # Reshape to (n_obs_seq, max_T)
                log_probs_2d = log_probs_flat_true.reshape(n_obs_seq, max_T)
                
                # Cumulative sum over T dimension (axis=1)
                # cumsum_log_probs[i, t] = log P(O_0, ..., O_t | true_theta)
                # Shape: (n_obs_seq, max_T)
                cumsum_log_probs_true = np.cumsum(log_probs_2d, axis=1)
                
                # Extract likelihoods for each T (vectorized)
                # T is 1-indexed, so T-1 gives 0-based array index
                log_lik_true_all = {
                    T: cumsum_log_probs_true[:, T - 1] for T in Ts
                }
                
                # Distribute results to observation records
                for i in range(n_obs_seq):
                    obs_list[i]["log_lik_true_theta"] = {
                        T: float(log_lik_true_all[T][i]) for T in Ts
                    }
                    # log_lik_all_theta and mle_theta remain None
            
            else:  # compute_obs_likelihood == "all"
                
                # Load full matrix
                # Shape: (n_obs_seq * max_T, n_theta_vals)
                log_probs_flat = world.obs_log_likelihood_theta.loc[all_obs_keys].values
                
                # Reshape to (n_obs_seq, max_T, n_theta_vals)
                log_probs_3d = log_probs_flat.reshape(n_obs_seq, max_T, n_theta_vals)
                
                # Cumulative sum over T dimension (axis=1)
                # cumsum_log_probs[i, t, :] = log P(O_0, ..., O_t | all thetas)
                # Shape: (n_obs_seq, max_T, n_theta_vals)
                cumsum_log_probs = np.cumsum(log_probs_3d, axis=1)
                
                # Storage for vectorized results
                log_lik_true_all = {}   # {T: shape (n_obs_seq,)}
                log_lik_all_all = {}    # {T: shape (n_obs_seq, n_theta_vals)}
                mle_all = {}            # {T: shape (n_obs_seq,)}
                
                for T in Ts:
                    # Extract likelihoods at position T-1 for all obs_seqs
                    # Shape: (n_obs_seq, n_theta_vals)
                    log_liks_at_T = cumsum_log_probs[:, T - 1, :]
                    
                    # True theta likelihood
                    log_lik_true_all[T] = log_liks_at_T[:, theta_col_idx]
                    
                    # Full likelihood array
                    log_lik_all_all[T] = log_liks_at_T
                    
                    # MLE theta: argmax across theta dimension
                    mle_indices = np.argmax(log_liks_at_T, axis=1)
                    mle_all[T] = world.theta_values[mle_indices]
                
                # Distribute results to observation records
                for i in range(n_obs_seq):
                    obs_list[i]["log_lik_true_theta"] = {
                        T: float(log_lik_true_all[T][i]) for T in Ts
                    }
                    obs_list[i]["log_lik_all_theta"] = {
                        T: log_lik_all_all[T][i].copy() for T in Ts
                    }
                    obs_list[i]["mle_theta"] = {
                        T: float(mle_all[T][i]) for T in Ts
                    }
        
        # Store observations for this theta
        observations[theta] = obs_list
    
    # =========================================================================
    # RETURN RESULT
    # =========================================================================
    
    return {
        "world": world,
        "config": {
            "n": n,
            "m": m,
            "thetas": thetas,
            "Ts": Ts,
            "max_T": max_T,
            "n_obs_seq": n_obs_seq,
            "random_seed": random_seed,
            "compute_obs_likelihood": compute_obs_likelihood
        },
        "observations": observations
    }

### Utterance Sampling

In [None]:
def generate_utterances_for_observations_multiT(
    obs_data: Dict[str, Any],
    speaker_config: Dict[str, Any],
    n_utt_seq: int,
    n_jobs: int = 1,
    backend: str = "loky",
    verbose: int = 0
) -> None:
    """
    Generate utterance sequences for pre-sampled observations (in-place).

    Mutates obs_data by filling in the 'utterances' field for each observation.

    Parameters
    ----------
    obs_data : Dict[str, Any]
        Output from sample_observation_sequences_multiT(). Will be mutated.
        Read here: obs_data["world"], obs_data["config"] ("Ts", "thetas",
        "random_seed") and obs_data["observations"][theta] (a list of
        observation records, each with an "obs_seq" entry).
    speaker_config : Dict[str, Any]
        Speaker configuration (see _generate_utterances_for_single_obs_seq_multiT).
    n_utt_seq : int
        Number of utterance sequences per observation sequence.
        Should be <= 10000 when a random seed is configured: task seeds are
        spaced 10000 apart, so larger values make the seeds of different
        observation sequences collide. A RuntimeWarning is issued in that case.
    n_jobs : int, default 1
        Number of parallel jobs (-1 for all cores).
    backend : str, default "loky"
        Joblib backend.
    verbose : int, default 0
        Verbosity level.

    Returns
    -------
    None
        Mutates obs_data IN PLACE.

    Notes
    -----
    Storage structure:
        obs_data["observations"][theta][obs_idx]["utterances"][speaker_key][alpha_key]

    Speaker keys: "literal", "inf_T", "inf_F", "persp_T", "persp_F", "persm_T", "persm_F"
    Alpha keys: 0.0 (literal), float (pragmatic), or "determ"

    Results are written to the very record object each observation sequence
    was read from, so they cannot land on a different record even if a
    record's "obs_idx" does not match its position in the list.

    If there are no observation records at all, the function is a no-op.

    Raises
    ------
    ValueError
        If n_utt_seq is not a positive integer, or if utterances already
        exist for the given speaker/alpha combination.
    """

    # INPUT VALIDATION

    if not isinstance(n_utt_seq, int) or n_utt_seq < 1:
        raise ValueError("n_utt_seq must be a positive integer")

    # EXTRACT FROM obs_data

    world = obs_data["world"]
    config = obs_data["config"]
    Ts = config["Ts"]
    thetas = config["thetas"]
    random_seed = config["random_seed"]

    # DETERMINE SPEAKER KEY AND ALPHA KEY

    speaker_type = speaker_config["speaker_type"]

    if speaker_type == "literal":
        speaker_key = "literal"
        alpha_key = 0.0
    else:
        psi = speaker_config["psi"]
        update_internal = speaker_config["update_internal"]
        alpha = speaker_config["alpha"]

        psi_prefix = {"inf": "inf", "pers+": "persp", "pers-": "persm"}[psi]
        speaker_key = f"{psi_prefix}_{'T' if update_internal else 'F'}"
        alpha_key = alpha if alpha == "determ" else float(alpha)

    # COLLECT OBSERVATION RECORDS (theta order, then list order)

    obs_records = [
        obs_info
        for theta in thetas
        for obs_info in obs_data["observations"][theta]
    ]

    # CHECK FOR DUPLICATE GENERATION

    # Check first observation (all will have same structure). A missing or
    # None 'utterances' entry simply means nothing has been generated yet.
    first_utterances = obs_records[0].get("utterances") if obs_records else None
    if (first_utterances is not None and
        speaker_key in first_utterances and
        alpha_key in first_utterances[speaker_key]):

        delete_code = (
            f"for theta in obs_data['observations']:\n"
            f"    for obs in obs_data['observations'][theta]:\n"
            f"        if obs['utterances'] and '{speaker_key}' in obs['utterances']:\n"
            f"            obs['utterances']['{speaker_key}'].pop({alpha_key!r}, None)"
        )

        raise ValueError(
            f"Utterances already exist for speaker_key='{speaker_key}', alpha_key={alpha_key!r}.\n"
            f"To regenerate, first delete existing entries:\n\n{delete_code}"
        )

    # BUILD TASK LIST

    # Each task gets its own block of `seed_stride` consecutive seeds; the
    # worker uses base_seed + utt_idx for the individual utterance sequences.
    seed_stride = 10_000
    tasks = [
        {
            "obs_seq": obs_info["obs_seq"],
            "base_seed": (random_seed + task_idx * seed_stride
                          if random_seed is not None else None),
        }
        for task_idx, obs_info in enumerate(obs_records)
    ]

    if random_seed is not None and n_utt_seq > seed_stride and len(tasks) > 1:
        import warnings  # local import: only needed on this rare path

        warnings.warn(
            f"n_utt_seq={n_utt_seq} exceeds the per-task seed stride "
            f"({seed_stride}); utterance seeds of different observation "
            f"sequences will collide.",
            RuntimeWarning,
            stacklevel=2,
        )

    # VERBOSE OUTPUT

    if verbose > 0:
        n_workers = (cpu_count() if n_jobs == -1
                    else max(1, cpu_count() + 1 + n_jobs) if n_jobs < 0
                    else max(1, n_jobs))
        print(f"Generating utterances: {len(tasks)} obs_seq × {n_utt_seq} utt_seq")
        print(f"  Speaker: {speaker_key}, alpha: {alpha_key}, Ts: {Ts}")
        print(f"  Workers: {n_workers}")

    # EXECUTE TASKS

    def run_task(task):
        return _generate_utterances_for_single_obs_seq_multiT(
            obs_seq=task["obs_seq"],
            world=world,
            speaker_config=speaker_config,
            n_utt_seq=n_utt_seq,
            Ts=Ts,
            base_seed=task["base_seed"]
        )

    if n_jobs == 1:
        results = [run_task(task) for task in tasks]
    else:
        results = Parallel(n_jobs=n_jobs, backend=backend, verbose=verbose)(
            delayed(run_task)(task) for task in tasks
        )

    # STORE RESULTS (in-place mutation of obs_data)

    # `tasks` and `obs_records` are index-aligned, so each result goes back
    # onto the record its observation sequence came from.
    for obs_info, utt_records in zip(obs_records, results):
        # Handle explicit None (from sample_observation_sequences_multiT)
        # as well as a missing key.
        if obs_info.get("utterances") is None:
            obs_info["utterances"] = {}

        obs_info["utterances"].setdefault(speaker_key, {})[alpha_key] = utt_records



def _generate_utterances_for_single_obs_seq_multiT(
    obs_seq: List[Tuple[int, ...]],
    world: World,
    speaker_config: Dict[str, Any],
    n_utt_seq: int,
    Ts: List[int],
    base_seed: Optional[int]
) -> List[Dict[str, Any]]:
    """
    Sample utterance sequences for one observation sequence.

    Worker routine behind utterance generation. For every requested
    utterance sequence a new speaker object is built, one utterance is drawn
    per observation, and the cumulative log-probability of the drawn
    utterances is recorded at each horizon T in Ts.

    Parameters
    ----------
    obs_seq : List[Tuple[int, ...]]
        The observation sequence; its length is the largest admissible T.
    world : World
        World object handed to the speaker constructors.
    speaker_config : Dict[str, Any]
        Speaker configuration. Keys read here:
        - speaker_type: "literal" selects LiteralSpeaker; anything else
          selects PragmaticSpeaker_obs
        - initial_beliefs_theta: optional, passed through (None if absent)
        For the pragmatic speaker additionally:
        - omega, psi, update_internal, alpha: required
        - beta: optional, defaults to 0.0
    n_utt_seq : int
        Number of utterance sequences to generate.
    Ts : List[int]
        Horizons (1-indexed) at which cumulative log-likelihoods are
        reported. Each must satisfy 1 <= T <= len(obs_seq).
    base_seed : Optional[int]
        If not None, sequence `utt_idx` is generated after calling
        np.random.seed(base_seed + utt_idx). If None, no seeding is done.

    Returns
    -------
    List[Dict[str, Any]]
        One record per utterance sequence with keys:
        - utt_idx: index of the sequence (0 .. n_utt_seq-1)
        - utt_seq: list of generated utterances, one per observation
        - utt_seed: seed used for this sequence, or None
        - log_lik_true_speaker: {T: cumulative log-probability of the first
          T utterances} for each T in Ts
        - log_lik_all_speaker: always None here (placeholder)

    Raises
    ------
    ValueError
        If any T in Ts lies outside [1, len(obs_seq)].

    Notes
    -----
    - Seeding uses the global NumPy random state (np.random.seed).
    - The speaker's log-probability column for the current observation is
      copied BEFORE update_and_speak() is called, so the recorded value is
      the one in effect when the utterance was drawn even if speaking
      changes the speaker's table.
    - A step log-probability that is not finite is recorded as -inf.

    Examples
    --------
    >>> world = World(n=3, m=2)
    >>> obs_seq = [(1, 1, 1), (0, 2, 1), (0, 0, 3), (1, 2, 0)]
    >>> config = {
    ...     "speaker_type": "pragmatic",
    ...     "omega": "coop",
    ...     "psi": "inf",
    ...     "alpha": 5.0,
    ...     "update_internal": True,
    ...     "beta": 0.0
    ... }
    >>> records = _generate_utterances_for_single_obs_seq_multiT(
    ...     obs_seq=obs_seq,
    ...     world=world,
    ...     speaker_config=config,
    ...     n_utt_seq=3,
    ...     Ts=[2, 4],
    ...     base_seed=42
    ... )
    >>> len(records)
    3
    >>> sorted(records[0]["log_lik_true_speaker"])
    [2, 4]
    """

    # ---- Horizon validation ------------------------------------------------
    # T < 1 would silently index from the end (T-1 == -1); T > length would
    # raise IndexError later. Reject both up front.
    max_T = len(obs_seq)
    invalid_Ts = []
    for T in Ts:
        if T < 1 or T > max_T:
            invalid_Ts.append(T)
    if invalid_Ts:
        raise ValueError(
            f"All T in Ts must satisfy 1 <= T <= len(obs_seq) ({max_T}). "
            f"Invalid values: {invalid_Ts}"
        )

    literal_mode = speaker_config["speaker_type"] == "literal"

    def build_speaker():
        """Construct a brand-new speaker from speaker_config."""
        prior = speaker_config.get("initial_beliefs_theta")
        if literal_mode:
            return LiteralSpeaker(world=world, initial_beliefs_theta=prior)
        return PragmaticSpeaker_obs(
            world=world,
            omega=speaker_config["omega"],
            psi=speaker_config["psi"],
            update_internal=speaker_config["update_internal"],
            alpha=speaker_config["alpha"],
            beta=speaker_config.get("beta", 0.0),
            initial_beliefs_theta=prior
        )

    # ---- Sampling loop -----------------------------------------------------
    records = []
    for utt_idx in range(n_utt_seq):
        # Seed first, so speaker construction and sampling are both covered.
        seed = None if base_seed is None else base_seed + utt_idx
        if seed is not None:
            np.random.seed(seed)

        speaker = build_speaker()

        utterances = []
        step_log_probs = []
        for obs in obs_seq:
            key = obs if isinstance(obs, tuple) else tuple(obs)

            # Snapshot the column before speaking (speaking may update it).
            column_before = speaker.utterance_log_prob_obs[key].copy()

            chosen = speaker.update_and_speak(obs)
            utterances.append(chosen)

            step_lp = float(column_before.loc[chosen])
            # Anything non-finite counts as an impossible utterance.
            step_log_probs.append(step_lp if np.isfinite(step_lp) else -np.inf)

        # Running totals; entry T-1 is the log-likelihood of the first T steps.
        running_totals = np.cumsum(np.array(step_log_probs))
        per_horizon = {}
        for T in Ts:
            per_horizon[T] = float(running_totals[T - 1])

        records.append({
            "utt_idx": utt_idx,
            "utt_seq": utterances,
            "utt_seed": seed,
            "log_lik_true_speaker": per_horizon,
            "log_lik_all_speaker": None
        })

    return records

# Fitting speakers

In [None]:
import warnings
import itertools
import numpy as np
from itertools import chain
from typing import List, Dict, Union, Optional, Tuple, TypeVar, Iterator, Callable, Any, Literal
from joblib import Parallel, delayed, cpu_count
from scipy.special import logsumexp
from scipy.optimize import minimize_scalar

from rsa_optimal_exp_core import (
    World, LiteralSpeaker, PragmaticSpeaker_obs, USE_PRECISE_LOGSPACE
)

np.seterr(divide='ignore', under='ignore')

In [None]:
from rsa_optimal_exp_sampling_fun import (
    sample_observation_sequences_multiT, generate_utterances_for_observations_multiT
)

# Alpha value shared by every pragmatic configuration below.
TRUE_ALPHA = 4.0

# Speaker configurations: the literal speaker first, followed by one
# omega="strat" pragmatic speaker for each (psi, update_internal) pair.
# Order: psi in ("inf", "pers+", "pers-"), and for each psi
# update_internal=False before update_internal=True.
TRUE_SPEAKER_CONFIGS = [{"speaker_type": "literal"}] + [
    {
        "speaker_type": "pragmatic",
        "omega": "strat",
        "psi": psi,
        "alpha": TRUE_ALPHA,
        "update_internal": update_internal,
    }
    for psi in ("inf", "pers+", "pers-")
    for update_internal in (False, True)
]

## Single observation and utterance fitting

### Utterance Sequence Likelihood

In [None]:
def log_likelihood_utt_seq(
    world: World,
    obs_seq: List[Tuple[int, ...]],
    utt_seq: List[str],
    speaker_config: Dict[str, Any]
) -> float:
    """
    Return log p(utt_seq | obs_seq, speaker_config).

    A speaker is built from `speaker_config`; the log-probabilities of the
    given utterances are read from its `utterance_log_prob_obs` table one
    step at a time and summed.

    Parameters
    ----------
    world : World
        The World object; must expose `observations` and `utterances`,
        which are used to validate each step.
    obs_seq : List[Tuple[int, ...]]
        Observations O_1, ..., O_T (non-tuples are converted with tuple()).
    utt_seq : List[str]
        Utterances u_1, ..., u_T; same length as obs_seq.
    speaker_config : Dict[str, Any]
        Keys read:
        - speaker_type : "literal" or "pragmatic" (required)
        - initial_beliefs_theta : optional, None when absent
        Pragmatic only:
        - omega, psi, alpha, update_internal : required
        - beta : optional; missing or None is treated as 0.0

    Returns
    -------
    float
        Sum of the per-step log-probabilities, or -inf as soon as a step
        has a non-finite log-probability.

    Raises
    ------
    ValueError
        If obs_seq is empty, the sequences differ in length, speaker_type
        is unknown, or a step's observation/utterance is not in the world.
    KeyError
        If a required configuration key is missing.

    Notes
    -----
    When the pragmatic speaker's `update_internal` attribute is truthy,
    after every step its literal listener is updated with the utterance and
    the speaker's probability table is recomputed with the speaker's alpha.
    """
    if len(obs_seq) == 0:
        raise ValueError("obs_seq is empty.")

    if len(obs_seq) != len(utt_seq):
        raise ValueError("obs_seq and utt_seq must have the same length.")

    kind = speaker_config["speaker_type"]
    prior = speaker_config.get("initial_beliefs_theta")

    if kind == "pragmatic":
        # Required keys are read in a fixed order so that a missing key
        # surfaces as a KeyError naming the first absent one.
        omega, psi, alpha, update_flag = (
            speaker_config[name]
            for name in ("omega", "psi", "alpha", "update_internal")
        )
        beta_cfg = speaker_config.get("beta")

        speaker = PragmaticSpeaker_obs(
            world=world,
            omega=omega,
            psi=psi,
            update_internal=update_flag,
            alpha=alpha,
            beta=0.0 if beta_cfg is None else beta_cfg,
            initial_beliefs_theta=prior
        )
        refresh_after_step = speaker.update_internal
    elif kind == "literal":
        speaker = LiteralSpeaker(world, prior)
        # The literal speaker's table is never refreshed between steps.
        refresh_after_step = False
    else:
        raise ValueError(f"speaker_type must be 'literal' or 'pragmatic', got '{kind}'")

    total = 0.0
    for obs, utt in zip(obs_seq, utt_seq):
        key = obs if isinstance(obs, tuple) else tuple(obs)

        if key not in world.observations:
            raise ValueError(f"Observation {key} not supported by the world.")
        if utt not in world.utterances:
            raise ValueError(f"Utterance '{utt}' not in world.utterances")

        step_log_p = speaker.utterance_log_prob_obs.at[utt, key]

        # One impossible step makes the whole sequence impossible.
        if not np.isfinite(step_log_p):
            return -np.inf

        total += float(step_log_p)

        if refresh_after_step:
            speaker.literal_listener.listen_and_update(utt)
            speaker.utterance_log_prob_obs = speaker._compute_utterance_log_prob_obs(speaker.alpha)

    return total

### Utterance Sequence Likelihood with optimized alpha

In [None]:
def log_likelihood_alpha_opt_utt_seq(
    world: 'World',
    obs_seq: List[Tuple[int, ...]],
    utt_seq: List[str],
    speaker_config: Dict[str, Any],
    alpha_bounds: Tuple[float, float] = (0.001, 50.0),
    method: str = "bounded",
    include_determ: bool = True,
    grid_search: bool = False,
    grid_points: int = 100,
    grid_spacing: str = "log"
) -> Dict[str, Any]:
    """
    Find the optimal alpha that maximizes log-likelihood of an utterance sequence.

    This function optimizes over the softmax parameter alpha to find
    the value that makes the observed utterance sequence most probable under the
    specified speaker model.

    Parameters
    ----------
    world : World
        The World object defining the communication game.
    obs_seq : List[Tuple[int, ...]]
        Sequence of observations O_1, ..., O_T.
    utt_seq : List[str]
        Sequence of utterances u_1, ..., u_T.
    speaker_config : Dict[str, Any]
        Speaker configuration. Must include:
        - speaker_type : str ("literal" or "pragmatic")
        For pragmatic speakers, must also include:
        - omega : str ("coop" or "strat")
        - psi : str ("inf", "pers+", or "pers-")
        - update_internal : bool
        Optional:
        - beta : float (default 0.0)
        - initial_beliefs_theta : np.ndarray or None (default None)
        Note: 'alpha' in speaker_config is ignored; it will be optimized.
    alpha_bounds : Tuple[float, float], default (0.001, 50.0)
        Lower and upper bounds for continuous alpha optimization.
        Lower bound should be > 0 to avoid numerical issues. With
        grid_search=True and grid_spacing="log" both bounds MUST be > 0,
        because the grid is built from their logarithms.
    method : str, default "bounded"
        Optimization method passed to scipy.optimize.minimize_scalar.
        "bounded" is recommended for bounded optimization.
    include_determ : bool, default True
        Whether to also evaluate alpha="determ" (hard argmax) and compare
        against the continuous optimum.
    grid_search : bool, default False
        If True, use grid search instead of scipy optimization.
        Grid search is more robust but slower for fine-grained search.
    grid_points : int, default 100
        Number of alpha values to evaluate when grid_search=True (>= 1).
    grid_spacing : str, default "log"
        Spacing method for grid search. Options:
        - "log": Logarithmic spacing (recommended). Places more points at lower
          alpha values where the likelihood function typically changes more rapidly.
        - "linear": Linear spacing. Evenly spaced points across the range.

        Logarithmic spacing is generally preferred because the softmax function's
        behavior changes roughly logarithmically with alpha.

    Returns
    -------
    Dict[str, Any]
        Dictionary containing:
        - optimal_alpha : float or str or None
            The overall best alpha (could be float or "determ").
            None for literal speakers.
        - max_log_likelihood : float
            Log-likelihood at the optimal alpha.
        - continuous_optimal_alpha : float or None
            Best alpha from continuous optimization only.
        - continuous_max_log_likelihood : float or None
            Log-likelihood at the continuous optimum.
        - determ_log_likelihood : float or None
            Log-likelihood when alpha="determ" (if include_determ=True).
        - optimization_result : scipy.optimize.OptimizeResult or None
            Full scipy result object (if grid_search=False).
        - grid_results : Dict or None
            Grid search details (if grid_search=True), containing:
            - alphas: array of tested alpha values
            - log_likelihoods: array of corresponding log-likelihoods
            - best_idx: index of best alpha
            - grid_spacing: spacing method used
        - message : str (only for literal speakers)
            Explanation that alpha is not applicable.

    Raises
    ------
    ValueError
        If obs_seq is empty, lengths don't match, speaker_type is invalid,
        grid_spacing is invalid, required pragmatic speaker keys are missing,
        or (grid search only) grid_points < 1 or a non-positive bound is
        combined with logarithmic spacing.
    RuntimeError
        If the scipy optimizer itself raises; the original exception is
        attached as the cause.

    Notes
    -----
    Any exception raised while evaluating the likelihood at a particular
    alpha is reported via warnings.warn and that alpha is treated as having
    log-likelihood -inf, so a single bad evaluation does not abort the search.
    """

    # --- Argument validation ---
    if grid_spacing not in {"log", "linear"}:
        raise ValueError(
            f"grid_spacing must be 'log' or 'linear', got '{grid_spacing}'"
        )

    if len(obs_seq) == 0:
        raise ValueError("obs_seq is empty.")
    if len(obs_seq) != len(utt_seq):
        raise ValueError("obs_seq and utt_seq must have the same length.")

    speaker_type = speaker_config.get("speaker_type")

    # --- Handle literal speaker (alpha does not apply) ---
    if speaker_type == "literal":
        ll = log_likelihood_utt_seq(world, obs_seq, utt_seq, speaker_config)
        return {
            "optimal_alpha": None,
            "max_log_likelihood": ll,
            "continuous_optimal_alpha": None,
            "continuous_max_log_likelihood": None,
            "determ_log_likelihood": None,
            "optimization_result": None,
            "grid_results": None,
            "message": "Literal speaker does not use alpha parameter"
        }
    elif speaker_type != "pragmatic":
        raise ValueError(
            f"speaker_type must be 'literal' or 'pragmatic', got '{speaker_type}'"
        )

    # --- Validate pragmatic speaker config ---
    required_keys = ["omega", "psi", "update_internal"]
    missing_keys = [k for k in required_keys if k not in speaker_config]
    if missing_keys:
        raise ValueError(
            f"Missing required keys for pragmatic speaker: {missing_keys}"
        )

    # --- Create base config (everything except alpha) ---
    base_config = {
        "speaker_type": "pragmatic",
        "omega": speaker_config["omega"],
        "psi": speaker_config["psi"],
        "update_internal": speaker_config["update_internal"],
        "beta": speaker_config.get("beta", 0.0),
        "initial_beliefs_theta": speaker_config.get("initial_beliefs_theta", None)
    }

    # --- Define objective function ---
    def neg_log_likelihood(alpha: float) -> float:
        config = {**base_config, "alpha": float(alpha)}
        try:
            ll = log_likelihood_utt_seq(world, obs_seq, utt_seq, config)
        except Exception as e:
            # Deliberate best-effort: a failing alpha is reported and
            # treated as infinitely bad rather than aborting the search.
            warnings.warn(f"Error at alpha={alpha}: {e}")
            return np.inf
        if ll == -np.inf:
            return np.inf
        return -ll

    # --- Run optimization ---
    if grid_search:
        a_min, a_max = alpha_bounds

        # An empty grid would otherwise fail inside np.argmax with an
        # unrelated-looking message.
        if grid_points < 1:
            raise ValueError(
                f"grid_points must be >= 1 when grid_search=True, got {grid_points}"
            )

        if grid_spacing == "log":
            # log() of a non-positive bound gives -inf/NaN and would
            # silently turn the grid into NaN alphas.
            if a_min <= 0 or a_max <= 0:
                raise ValueError(
                    f"alpha_bounds must be strictly positive for "
                    f"grid_spacing='log', got {alpha_bounds}"
                )
            # Logarithmic spacing: more points at lower alpha values
            alphas = list(np.exp(np.linspace(np.log(a_min), np.log(a_max), grid_points)))
        else:  # linear
            # Linear spacing: evenly distributed points
            alphas = list(np.linspace(a_min, a_max, grid_points))

        # Evaluate log-likelihood at each alpha
        log_likelihoods = np.array([-neg_log_likelihood(alpha) for alpha in alphas])
        best_idx = np.argmax(log_likelihoods)
        optimal_alpha_continuous = alphas[best_idx]
        max_ll_continuous = log_likelihoods[best_idx]

        grid_results = {
            "alphas": np.array(alphas),
            "log_likelihoods": log_likelihoods,
            "best_idx": best_idx,
            "grid_spacing": grid_spacing
        }
        optimization_result = None

    else:
        # Scipy optimization approach
        try:
            result = minimize_scalar(
                neg_log_likelihood,
                bounds=alpha_bounds,
                method=method
            )
        except Exception as e:
            raise RuntimeError(f"Optimization failed: {e}") from e

        if not result.success:
            warnings.warn(
                f"Optimization may not have converged: {result.message}"
            )
        optimal_alpha_continuous = result.x
        max_ll_continuous = -result.fun if np.isfinite(result.fun) else -np.inf
        optimization_result = result
        grid_results = None

    # --- Evaluate deterministic alpha ---
    determ_ll = None
    if include_determ:
        config_determ = {**base_config, "alpha": "determ"}
        try:
            determ_ll = log_likelihood_utt_seq(world, obs_seq, utt_seq, config_determ)
        except Exception as e:
            warnings.warn(f"Failed to evaluate alpha='determ': {e}")
            determ_ll = -np.inf

    # --- Determine overall best alpha ---
    # "determ" wins only when strictly better than the continuous optimum.
    if include_determ and determ_ll is not None and determ_ll > max_ll_continuous:
        optimal_alpha = "determ"
        max_ll = determ_ll
    else:
        optimal_alpha = optimal_alpha_continuous
        max_ll = max_ll_continuous

    return {
        "optimal_alpha": optimal_alpha,
        "max_log_likelihood": max_ll,
        "continuous_optimal_alpha": optimal_alpha_continuous,
        "continuous_max_log_likelihood": max_ll_continuous,
        "determ_log_likelihood": determ_ll,
        "optimization_result": optimization_result,
        "grid_results": grid_results
    }

## SingleT Fitting

In [None]:
def compute_literal_log_likelihood_batched(
    world: 'World',
    df: pd.DataFrame,
    initial_beliefs_theta: Optional[np.ndarray] = None
) -> np.ndarray:
    """
    Compute log P(utt_seq | obs_seq, literal_speaker) for every row of `df`.

    A single LiteralSpeaker is constructed and its `utterance_log_prob_obs`
    table is flattened into a dictionary once; every row is then scored by
    dictionary lookups only.

    Parameters
    ----------
    world : World
        The World object passed to LiteralSpeaker.
    df : pd.DataFrame
        DataFrame with 'obs_seq' and 'utt_seq' columns; each cell holds one
        sequence.
    initial_beliefs_theta : Optional[np.ndarray], default None
        Initial beliefs over theta, passed to LiteralSpeaker.

    Returns
    -------
    np.ndarray
        Array of log-likelihoods, shape (n_rows,). Empty for an empty
        DataFrame.

    Raises
    ------
    ValueError
        If either required column is missing.

    Notes
    -----
    - An (observation, utterance) pair that is absent from the speaker's
      table contributes -inf.
    - Scoring of a row stops as soon as its running total is infinite.
    - Observation and utterance sequences are paired with zip(), so a
      longer sequence is truncated to the length of the shorter one.
    """
    if any(column not in df.columns for column in ('obs_seq', 'utt_seq')):
        raise ValueError("DataFrame must have 'obs_seq' and 'utt_seq' columns")

    if len(df) == 0:
        return np.array([])

    # Build the speaker once and flatten its table:
    # (obs_tuple, utt_str) -> log_prob
    table = LiteralSpeaker(world, initial_beliefs_theta).utterance_log_prob_obs
    log_prob_of = {
        (obs, utt): table.at[utt, obs]
        for utt in table.index
        for obs in table.columns
    }

    totals = []
    for row_obs, row_utts in zip(df['obs_seq'], df['utt_seq']):
        running = 0.0
        for obs, utt in zip(row_obs, row_utts):
            key = obs if isinstance(obs, tuple) else tuple(obs)
            running += log_prob_of.get((key, utt), -np.inf)
            # Once infinite, further steps cannot change the outcome.
            if np.isinf(running):
                break
        totals.append(running)

    return np.array(totals)

## MultiT Vectorized Fitting

### Fitting Literal Speaker

In [None]:
def compute_literal_log_likelihood_multiT(
    obs_data: Dict[str, Any],
    target_speaker_keys: Optional[List[str]] = None,
    target_alpha_keys: Optional[List[Any]] = None,
    initial_beliefs_theta: Optional[np.ndarray] = None,
    verbose: int = 0
) -> None:
    """
    Score stored utterance sequences under the literal speaker for every T in Ts.

    All selected utterance sequences (across thetas, observation sequences,
    generating speakers and generating alphas) are flattened into index arrays,
    looked up in the literal speaker's ``utterance_log_prob_obs`` table with a
    single advanced-indexing call, and accumulated along time with a cumulative
    sum. The per-T totals are then written back into ``obs_data``.

    Parameters
    ----------
    obs_data : Dict[str, Any]
        Nested data structure read through the keys ``"world"``,
        ``"config"`` (with ``"Ts"``, ``"max_T"``, ``"thetas"``) and
        ``"observations"`` (``theta -> list`` of records holding ``"obs_seq"``,
        ``"obs_idx"`` and ``"utterances"``). ``"utterances"`` is either None
        (record skipped) or ``speaker_key -> alpha_key -> list`` of utterance
        records with ``"utt_seq"`` and ``"log_lik_all_speaker"``.
    target_speaker_keys : Optional[List[str]], default None
        If given, only utterances stored under these speaker keys are scored.
        None scores every speaker key present.
    target_alpha_keys : Optional[List[Any]], default None
        If given, only utterances stored under these alpha keys are scored.
        None scores every alpha key present. When both filters are given,
        only combinations that actually exist in the data are processed.
    initial_beliefs_theta : Optional[np.ndarray], default None
        Forwarded unchanged to the ``LiteralSpeaker`` constructor.
        (Per the original notes it does not influence the literal P(u|O)
        table and is accepted for API consistency.)
    verbose : int, default 0
        ``> 0`` prints a summary of what is processed; ``> 1`` additionally
        prints the set of filtered-out keys.

    Returns
    -------
    None
        ``obs_data`` is mutated in place.

    Raises
    ------
    ValueError
        If any T in Ts lies outside ``[1, max_T]``.
        If an observation or utterance sequence does not have length ``max_T``.
        If an observation or utterance is not found in the speaker's table.

    Notes
    -----
    Each processed utterance record receives::

        utt_record["log_lik_all_speaker"]["literal_fitted"] = {
            "max_log_lik": {T: float for T in Ts},
            "optimal_alpha": {T: 0.0 for T in Ts}   # literal speaker: always 0.0
        }

    A ``"log_lik_all_speaker"`` value of None is replaced by a fresh dict first.

    Examples
    --------
    >>> compute_literal_log_likelihood_multiT(obs_data, verbose=1)
    Processing 30000 utterance sequences (300 unique obs positions) across all configurations
    >>> rec = obs_data["observations"][0.5][0]["utterances"]["inf_T"][5.0][0]
    >>> rec["log_lik_all_speaker"]["literal_fitted"]["optimal_alpha"]
    {5: 0.0, 10: 0.0, 15: 0.0, 20: 0.0}
    """

    # ---- configuration -------------------------------------------------
    world = obs_data["world"]
    config = obs_data["config"]
    Ts = config["Ts"]
    max_T = config["max_T"]
    thetas = config["thetas"]

    invalid_Ts = [T for T in Ts if T < 1 or T > max_T]
    if invalid_Ts:
        raise ValueError(
            f"All T in Ts must satisfy 1 <= T <= max_T ({max_T}). "
            f"Invalid values: {invalid_Ts}"
        )

    # ---- literal speaker table -----------------------------------------
    literal = LiteralSpeaker(world, initial_beliefs_theta)
    table = literal.utterance_log_prob_obs
    prob_table = table.values   # indexed below as [utterance, observation]
    obs_index = table.columns
    utt_index = table.index

    # ---- flatten the nested structure ----------------------------------
    # Three parallel lists describe each utterance sequence to score; the
    # observation sequences are stored once per (theta, position) slot.
    obs_seq_per_slot = []
    slot_by_position = {}
    item_utt_seqs = []
    item_slots = []
    item_locations = []
    skipped_combinations = set()

    for theta in thetas:
        for obs_list_pos, obs_info in enumerate(obs_data["observations"][theta]):
            obs_seq = obs_info["obs_seq"]
            obs_idx = obs_info["obs_idx"]

            if len(obs_seq) != max_T:
                raise ValueError(
                    f"Observation sequence length mismatch: "
                    f"expected {max_T}, got {len(obs_seq)} "
                    f"at theta={theta}, obs_idx={obs_idx}"
                )

            utterances = obs_info["utterances"]
            if utterances is None:
                continue

            position = (theta, obs_list_pos)
            slot = slot_by_position.get(position)
            if slot is None:
                slot = len(obs_seq_per_slot)
                slot_by_position[position] = slot
                obs_seq_per_slot.append(obs_seq)

            for speaker_key, alpha_dict in utterances.items():
                if not (target_speaker_keys is None or speaker_key in target_speaker_keys):
                    skipped_combinations.add(f"speaker_key={speaker_key}")
                    continue

                for alpha_key, utt_records in alpha_dict.items():
                    if not (target_alpha_keys is None or alpha_key in target_alpha_keys):
                        skipped_combinations.add(f"alpha_key={alpha_key}")
                        continue

                    for utt_list_idx, utt_rec in enumerate(utt_records):
                        utt_seq = utt_rec["utt_seq"]

                        if len(utt_seq) != max_T:
                            raise ValueError(
                                f"Utterance sequence length mismatch: "
                                f"expected {max_T}, got {len(utt_seq)} "
                                f"at theta={theta}, obs_idx={obs_idx}, "
                                f"speaker_key={speaker_key}, alpha_key={alpha_key}, "
                                f"utt_idx={utt_list_idx}"
                            )

                        item_utt_seqs.append(utt_seq)
                        item_slots.append(slot)
                        item_locations.append(
                            (theta, obs_list_pos, speaker_key, alpha_key, utt_list_idx)
                        )

    n_unique_obs = len(obs_seq_per_slot)
    n_total = len(item_utt_seqs)

    # ---- reporting -----------------------------------------------------
    if verbose > 0:
        filter_desc = [
            f"{label}: {keys}"
            for label, keys in (
                ("speakers", target_speaker_keys),
                ("alphas", target_alpha_keys),
            )
            if keys is not None
        ]
        headline = f"Processing {n_total} utterance sequences ({n_unique_obs} unique obs positions) "
        if filter_desc:
            print(headline + f"for {', '.join(filter_desc)}")
        else:
            print(headline + "across all configurations")

        if verbose > 1 and skipped_combinations:
            print(f"  Skipped: {skipped_combinations}")

    if n_total == 0:
        if verbose > 0:
            print("No utterance sequences to process (check filters or ensure utterances exist)")
        return

    # Every collected item registered its observation slot beforehand.
    assert n_unique_obs > 0, "Internal error: n_total > 0 but no unique obs positions registered"

    # ---- observation indices: (n_total, max_T) -------------------------
    obs_flat_unique = [
        obs if isinstance(obs, tuple) else tuple(obs)
        for obs_seq in obs_seq_per_slot
        for obs in obs_seq
    ]
    obs_codes = obs_index.get_indexer(obs_flat_unique)
    unknown_obs = np.flatnonzero(obs_codes < 0)
    if unknown_obs.size:
        raise ValueError(f"Unknown observation encountered: {obs_flat_unique[unknown_obs[0]]}")

    # Look up each slot once, then repeat rows for the items sharing that slot.
    obs_codes_per_slot = obs_codes.reshape(n_unique_obs, max_T)
    obs_indices = obs_codes_per_slot[np.array(item_slots, dtype=np.int32)]

    # ---- utterance indices: (n_total, max_T) ---------------------------
    utt_flat = [utt for utt_seq in item_utt_seqs for utt in utt_seq]
    utt_codes = utt_index.get_indexer(utt_flat)
    unknown_utt = np.flatnonzero(utt_codes < 0)
    if unknown_utt.size:
        raise ValueError(f"Unknown utterance encountered: {utt_flat[unknown_utt[0]]}")

    utt_indices = utt_codes.reshape(n_total, max_T)

    # ---- likelihoods ---------------------------------------------------
    # Per-step log-probabilities, then running totals; column T-1 holds the
    # log-likelihood of the first T steps.
    step_log_probs = prob_table[utt_indices, obs_indices]
    running_totals = np.cumsum(step_log_probs, axis=1)
    T_columns = np.array(Ts, dtype=np.int32) - 1
    log_liks_all_T = running_totals[:, T_columns]   # (n_total, len(Ts))

    # ---- write results back --------------------------------------------
    for location, ll_row in zip(item_locations, log_liks_all_T):
        theta, obs_list_pos, speaker_key, alpha_key, utt_list_idx = location
        utt_rec = (
            obs_data["observations"][theta][obs_list_pos]
            ["utterances"][speaker_key][alpha_key][utt_list_idx]
        )

        if utt_rec["log_lik_all_speaker"] is None:
            utt_rec["log_lik_all_speaker"] = {}

        utt_rec["log_lik_all_speaker"]["literal_fitted"] = {
            "max_log_lik": {T: float(ll) for T, ll in zip(Ts, ll_row)},
            "optimal_alpha": dict.fromkeys(Ts, 0.0)
        }

    if verbose > 0:
        print("Completed: stored results in log_lik_all_speaker['literal_fitted']")

### Fitting Pragmatic Speaker with update_internal = False

In [None]:
def compute_pragmatic_static_log_likelihood_multiT(
    obs_data: Dict[str, Any],
    fitting_psi: Literal["inf", "pers+", "pers-"],
    target_speaker_keys: Optional[List[str]] = None,
    target_alpha_keys: Optional[List[Any]] = None,
    method: Literal["grid", "scipy"] = "grid",
    alpha_bounds: Tuple[float, float] = (0.1, 50.0),
    grid_spacing: Literal["log", "linear"] = "log",
    n_grid: int = 100,
    include_determ: bool = True,
    n_jobs: int = 1,
    backend: str = "loky",
    verbose: int = 0
) -> None:
    """
    Find optimal alpha and max log-likelihood under a pragmatic speaker
    with update_internal=False.

    Mutates obs_data by filling in log_lik_all_speaker["{psi}_F_fitted"]
    for each selected utterance record.

    Parameters
    ----------
    obs_data : Dict[str, Any]
        Nested data structure read through the keys ``"world"``,
        ``"config"`` (with ``"Ts"``, ``"max_T"``, ``"thetas"``) and
        ``"observations"`` (``theta -> list`` of records holding ``"obs_seq"``,
        ``"obs_idx"`` and ``"utterances"``). Records whose ``"utterances"``
        is None are skipped.
    fitting_psi : {"inf", "pers+", "pers-"}
        The psi parameter for the fitting speaker.
    target_speaker_keys : Optional[List[str]], default None
        Which generating speakers' utterances to evaluate. If None, all.
    target_alpha_keys : Optional[List[Any]], default None
        Which generating alphas' utterances to evaluate. If None, all.
    method : {"grid", "scipy"}, default "grid"
        Optimization method; dispatches to ``_static_grid_search_multiT`` or
        ``_static_scipy_optimization_multiT``.
    alpha_bounds : Tuple[float, float], default (0.1, 50.0)
        Search range for alpha. Must be strictly positive when
        ``method="grid"`` and ``grid_spacing="log"``.
    grid_spacing : {"log", "linear"}, default "log"
        Grid spacing (only for method="grid").
    n_grid : int, default 100
        Number of grid points (only for method="grid"). May be 0 only when
        ``include_determ`` is True.
    include_determ : bool, default True
        Whether to also evaluate alpha="determ".
    n_jobs : int, default 1
        Number of parallel jobs (only for method="scipy").
    backend : str, default "loky"
        Joblib backend (only for method="scipy"); validated regardless of
        method.
    verbose : int, default 0
        Verbosity level.

    Returns
    -------
    None
        Mutates obs_data in place.

    Raises
    ------
    ValueError
        If ``fitting_psi``, ``method``, ``grid_spacing`` or ``backend`` is not
        one of the accepted values.
        If ``method="grid"`` and the grid would be empty or contain
        non-finite alphas (negative ``n_grid``, ``n_grid == 0`` without
        ``include_determ``, or non-positive bounds with log spacing).
        If any T in Ts lies outside ``[1, max_T]``.
        If an observation or utterance sequence does not have length max_T.

    Notes
    -----
    Storage key: "{psi}_F_fitted" (e.g., "inf_F_fitted", "persp_F_fitted")

    Storage structure:
        utt_record["log_lik_all_speaker"]["{psi}_F_fitted"] = {
            "max_log_lik": {T: float for T in Ts},
            "optimal_alpha": {T: float or "determ" for T in Ts}
        }
    """

    # Input validation
    if fitting_psi not in ["inf", "pers+", "pers-"]:
        raise ValueError(f"fitting_psi must be 'inf', 'pers+', or 'pers-', got '{fitting_psi}'")

    if method not in ["grid", "scipy"]:
        raise ValueError(f"method must be 'grid' or 'scipy', got '{method}'")

    if method == "grid" and grid_spacing not in ["log", "linear"]:
        raise ValueError(f"grid_spacing must be 'log' or 'linear', got '{grid_spacing}'")

    if backend not in ["loky", "multiprocessing", "threading"]:
        raise ValueError(
            f"backend must be 'loky', 'multiprocessing', or 'threading', got '{backend}'"
        )

    if method == "grid":
        # An empty grid has no maximiser; fail here with a clear message
        # instead of deep inside the grid search.
        if n_grid < 0 or (n_grid == 0 and not include_determ):
            raise ValueError(
                f"n_grid must be >= 1 (0 is allowed only with include_determ=True), "
                f"got {n_grid}"
            )
        a_min, a_max = alpha_bounds
        # log(alpha) is undefined for alpha <= 0; such bounds would silently
        # yield a NaN grid and NaN likelihoods. The negated form also rejects NaN.
        if grid_spacing == "log" and not (a_min > 0 and a_max > 0):
            raise ValueError(
                f"alpha_bounds must be strictly positive for log spacing, got {alpha_bounds}"
            )

    # Determine storage key
    psi_prefix = {"inf": "inf", "pers+": "persp", "pers-": "persm"}[fitting_psi]
    fitted_key = f"{psi_prefix}_F_fitted"

    # Extract configuration
    world = obs_data["world"]
    Ts = obs_data["config"]["Ts"]
    max_T = obs_data["config"]["max_T"]
    thetas = obs_data["config"]["thetas"]

    # Validate Ts
    invalid_Ts = [T for T in Ts if T < 1 or T > max_T]
    if invalid_Ts:
        raise ValueError(f"All T must satisfy 1 <= T <= {max_T}. Invalid: {invalid_Ts}")

    # Flatten all utterance sequences with location tracking.
    # Each flat_data item keeps the index of its observation sequence in
    # unique_obs_positions so the sequence is stored once per
    # (theta, position) rather than once per utterance sequence.
    flat_data = []
    skipped_combinations = set()

    unique_obs_positions = []
    obs_pos_to_unique_idx = {}

    for theta in thetas:
        for obs_list_pos, obs_info in enumerate(obs_data["observations"][theta]):
            obs_seq = obs_info["obs_seq"]
            obs_idx = obs_info["obs_idx"]

            if len(obs_seq) != max_T:
                raise ValueError(
                    f"Observation sequence length mismatch at theta={theta}, obs_idx={obs_idx}"
                )

            if obs_info["utterances"] is None:
                continue

            # Register unique observation position
            obs_pos_key = (theta, obs_list_pos)
            if obs_pos_key not in obs_pos_to_unique_idx:
                obs_pos_to_unique_idx[obs_pos_key] = len(unique_obs_positions)
                unique_obs_positions.append(obs_seq)

            for speaker_key, alpha_dict in obs_info["utterances"].items():
                if target_speaker_keys is not None and speaker_key not in target_speaker_keys:
                    skipped_combinations.add(f"speaker_key={speaker_key}")
                    continue

                for alpha_key, utt_records in alpha_dict.items():
                    if target_alpha_keys is not None and alpha_key not in target_alpha_keys:
                        skipped_combinations.add(f"alpha_key={alpha_key}")
                        continue

                    for utt_list_idx, utt_rec in enumerate(utt_records):
                        utt_seq = utt_rec["utt_seq"]

                        if len(utt_seq) != max_T:
                            raise ValueError(
                                f"Utterance sequence length mismatch at theta={theta}, "
                                f"obs_idx={obs_idx}, speaker_key={speaker_key}"
                            )

                        flat_data.append({
                            "utt_seq": utt_seq,
                            "obs_unique_idx": obs_pos_to_unique_idx[obs_pos_key],
                            "location": (theta, obs_list_pos, speaker_key, alpha_key, utt_list_idx)
                        })

    # Verbose output
    n_unique_obs = len(unique_obs_positions)
    n_total = len(flat_data)

    if verbose > 0:
        filter_desc = []
        if target_speaker_keys is not None:
            filter_desc.append(f"speakers: {target_speaker_keys}")
        if target_alpha_keys is not None:
            filter_desc.append(f"alphas: {target_alpha_keys}")

        print(f"Static pragmatic speaker fitting (psi={fitting_psi}, method={method}):")
        print(f"  Storage key: '{fitted_key}'")
        if filter_desc:
            print(f"  Processing {n_total} utterance sequences ({n_unique_obs} unique obs) "
                  f"for {', '.join(filter_desc)}")
        else:
            print(f"  Processing {n_total} utterance sequences ({n_unique_obs} unique obs)")
        print(f"  Ts: {Ts}")

        if method == "grid":
            print(f"  Grid: {n_grid} points, spacing={grid_spacing}, bounds={alpha_bounds}")
        else:
            print(f"  Scipy: {n_total * len(Ts)} optimizations")

    # Early exit
    if n_total == 0:
        if verbose > 0:
            print("No utterance sequences to process")
        return

    # Method-specific computation
    if method == "grid":
        results = _static_grid_search_multiT(
            flat_data=flat_data,
            unique_obs_positions=unique_obs_positions,
            world=world,
            psi=fitting_psi,
            Ts=Ts,
            max_T=max_T,
            alpha_bounds=alpha_bounds,
            grid_spacing=grid_spacing,
            n_grid=n_grid,
            include_determ=include_determ,
            verbose=verbose
        )
    else:
        results = _static_scipy_optimization_multiT(
            flat_data=flat_data,
            unique_obs_positions=unique_obs_positions,
            world=world,
            psi=fitting_psi,
            Ts=Ts,
            alpha_bounds=alpha_bounds,
            include_determ=include_determ,
            n_jobs=n_jobs,
            backend=backend,
            verbose=verbose
        )

    # Distribute results: results[...][i] corresponds to flat_data[i]
    for i, item in enumerate(flat_data):
        theta, obs_list_pos, speaker_key, alpha_key, utt_list_idx = item["location"]

        utt_rec = obs_data["observations"][theta][obs_list_pos]["utterances"][speaker_key][alpha_key][utt_list_idx]

        if utt_rec["log_lik_all_speaker"] is None:
            utt_rec["log_lik_all_speaker"] = {}

        utt_rec["log_lik_all_speaker"][fitted_key] = {
            "max_log_lik": results["max_log_lik"][i],
            "optimal_alpha": results["optimal_alpha"][i]
        }

    if verbose > 0:
        print(f"Completed: stored results in log_lik_all_speaker['{fitted_key}']")


def _static_grid_search_multiT(
    flat_data: List[Dict[str, Any]],
    unique_obs_positions: List[List[Tuple[int, ...]]],
    world: 'World',
    psi: str,
    Ts: List[int],
    max_T: int,
    alpha_bounds: Tuple[float, float],
    grid_spacing: str,
    n_grid: int,
    include_determ: bool,
    verbose: int
) -> Dict[str, List[Dict[int, Any]]]:
    """
    Grid search for optimal alpha with static speaker.

    One ``PragmaticSpeaker_obs`` (``update_internal=False``) is built and its
    ``utility`` table is reused for every alpha. For a fixed alpha the softmax
    normaliser (and, for ``"determ"``, the maximum utility and tie count)
    depends only on the observation column, so these quantities are computed
    once per distinct observation column and gathered per (item, time step).
    No array of size ``n_total * max_T * n_utterances`` is materialised.
    Results equal the per-step computation up to floating-point rounding.

    Parameters
    ----------
    flat_data : List[Dict[str, Any]]
        One dict per utterance sequence; the keys ``"utt_seq"`` and
        ``"obs_unique_idx"`` (index into ``unique_obs_positions``) are read.
    unique_obs_positions : List[List[Tuple[int, ...]]]
        Observation sequences, each of length ``max_T``. Non-tuple
        observations are converted with ``tuple(...)`` before lookup.
    world : World
        Passed to ``PragmaticSpeaker_obs``.
    psi : str
        Passed to ``PragmaticSpeaker_obs``.
    Ts : List[int]
        1-based sequence lengths at which likelihoods are reported.
    max_T : int
        Length of every observation / utterance sequence.
    alpha_bounds : Tuple[float, float]
        ``(a_min, a_max)`` end points of the grid.
    grid_spacing : str
        ``"log"`` for a geometric grid; any other value gives a linear grid.
    n_grid : int
        Number of numeric grid points.
    include_determ : bool
        If True, ``"determ"`` (uniform over max-utility utterances) is
        appended to the grid.
    verbose : int
        ``> 1`` prints the utility-table size and progress every 20 alphas.

    Returns
    -------
    Dict with "max_log_lik" and "optimal_alpha" lists, aligned with
    ``flat_data``; each entry maps T to the best log-likelihood / the grid
    value that attains it.

    Raises
    ------
    ValueError
        If an observation is not a column of the utility table, or if the
        alpha grid is empty.
    KeyError
        If an utterance is not in the utility table's index.
    """

    n_total = len(flat_data)
    n_unique_obs = len(unique_obs_positions)
    n_Ts = len(Ts)
    T_indices = np.array(Ts, dtype=np.int32) - 1

    # =========================================================================
    # PHASE 1: Create ONE speaker and extract utility table
    # =========================================================================

    ref_speaker = PragmaticSpeaker_obs(
        world=world,
        omega="strat",
        psi=psi,
        update_internal=False,
        alpha=1.0,
        beta=0.0,
        initial_beliefs_theta=None
    )

    # Utility table: rows are utterances, columns are observations
    utility_table = ref_speaker.utility.values
    obs_index = ref_speaker.utility.columns
    utt_to_idx = {u: i for i, u in enumerate(ref_speaker.utility.index)}

    if verbose > 1:
        print(f"  utility_table: {utility_table.shape}, {utility_table.nbytes/1024/1024:.1f} MB")

    # =========================================================================
    # PHASE 2: Build observation index array (n_total, max_T)
    # =========================================================================

    obs_flat_unique = [
        obs if isinstance(obs, tuple) else tuple(obs)
        for obs_seq in unique_obs_positions
        for obs in obs_seq
    ]

    obs_indices_flat_unique = obs_index.get_indexer(obs_flat_unique)
    if (obs_indices_flat_unique < 0).any():
        bad_idx = np.where(obs_indices_flat_unique < 0)[0][0]
        raise ValueError(f"Unknown observation: {obs_flat_unique[bad_idx]}")

    obs_indices_unique = obs_indices_flat_unique.reshape(n_unique_obs, max_T)

    unique_obs_idx_per_item = np.array(
        [item["obs_unique_idx"] for item in flat_data], dtype=np.int32
    )
    obs_indices = obs_indices_unique[unique_obs_idx_per_item]  # (n_total, max_T)

    # =========================================================================
    # PHASE 3: Build utterance index array (n_total, max_T)
    # =========================================================================

    utt_indices = np.array(
        [utt_to_idx[u] for item in flat_data for u in item["utt_seq"]],
        dtype=np.int32
    ).reshape(n_total, max_T)

    # =========================================================================
    # PHASE 4: Extract utilities
    # =========================================================================

    # Utility of the utterance actually produced at each step: (n_total, max_T)
    observed_utilities = utility_table[utt_indices, obs_indices]

    # Restrict the table to the observation columns that actually occur and
    # remember, for every (item, step), which restricted column it uses.
    used_obs, obs_local = np.unique(obs_indices, return_inverse=True)
    obs_local = obs_local.reshape(obs_indices.shape)   # (n_total, max_T)
    used_utilities = utility_table[:, used_obs]        # (n_utterances, n_used_obs)

    # =========================================================================
    # PHASE 5: Create alpha grid
    # =========================================================================

    a_min, a_max = alpha_bounds
    if grid_spacing == "log":
        alphas = list(np.exp(np.linspace(np.log(a_min), np.log(a_max), n_grid)))
    else:
        alphas = list(np.linspace(a_min, a_max, n_grid))

    if include_determ:
        alphas.append("determ")

    n_alphas = len(alphas)
    if n_alphas == 0:
        raise ValueError(
            "Alpha grid is empty: n_grid must be >= 1 when include_determ is False"
        )

    # =========================================================================
    # PHASE 6: Compute log P(u|O,α) for each alpha
    # =========================================================================

    # Output array: (n_alphas, n_total, n_Ts); every slice is filled below
    all_lls = np.empty((n_alphas, n_total, n_Ts))

    for alpha_idx, alpha in enumerate(alphas):
        if verbose > 1 and alpha_idx % 20 == 0:
            print(f"  Processing alpha {alpha_idx+1}/{n_alphas}")

        # The only string in the grid is the "determ" sentinel; a type check
        # avoids comparing NumPy floats against a string.
        if isinstance(alpha, str):
            # Deterministic: uniform over max-utility utterances
            max_per_obs = np.max(used_utilities, axis=0)                      # (n_used_obs,)
            ties_per_obs = np.isclose(used_utilities, max_per_obs).sum(axis=0)
            is_max = np.isclose(observed_utilities, max_per_obs[obs_local])
            log_probs = np.where(is_max, -np.log(ties_per_obs)[obs_local], -np.inf)
        else:
            # Softmax: P(u|O,α) = exp(α·U(u)) / Σ exp(α·U(u'))
            log_norm_per_obs = logsumexp(alpha * used_utilities, axis=0)      # (n_used_obs,)
            log_probs = alpha * observed_utilities - log_norm_per_obs[obs_local]

        # Cumulative sum and extract for Ts
        cumsum_log_probs = np.cumsum(log_probs, axis=1)  # (n_total, max_T)
        log_liks = cumsum_log_probs[:, T_indices]        # (n_total, n_Ts), a copy

        # Force -inf wherever any of the first T steps has probability zero
        # (running "any" along time, read off at column T-1).
        neginf_seen = np.logical_or.accumulate(np.isneginf(log_probs), axis=1)
        log_liks[neginf_seen[:, T_indices]] = -np.inf

        all_lls[alpha_idx] = log_liks

    # =========================================================================
    # PHASE 7: Find optimal alpha for each (sequence, T)
    # =========================================================================

    best_alpha_indices = np.argmax(all_lls, axis=0)  # (n_total, n_Ts)
    max_lls = np.take_along_axis(
        all_lls, best_alpha_indices[np.newaxis], axis=0
    )[0]                                             # (n_total, n_Ts)

    # =========================================================================
    # PHASE 8: Convert to list of dicts
    # =========================================================================

    max_log_lik_list = [dict(zip(Ts, ll_row)) for ll_row in max_lls.tolist()]
    optimal_alpha_list = [
        {T: alphas[best] for T, best in zip(Ts, best_row)}
        for best_row in best_alpha_indices.tolist()
    ]

    return {
        "max_log_lik": max_log_lik_list,
        "optimal_alpha": optimal_alpha_list
    }


def _static_scipy_optimization_multiT(
    flat_data: List[Dict[str, Any]],
    unique_obs_positions: List[List[Tuple[int, ...]]],
    world: 'World',
    psi: str,
    Ts: List[int],
    alpha_bounds: Tuple[float, float],
    include_determ: bool,
    n_jobs: int,
    backend: str,
    verbose: int
) -> Dict[str, List[Dict[int, Any]]]:
    """
    Scipy optimization for optimal alpha with static speaker.
    
    Returns
    -------
    Dict with "max_log_lik" and "optimal_alpha" lists.
    """
    
    n_total = len(flat_data)
    
    # Reconstruct obs_seq for each item
    obs_seqs = [unique_obs_positions[item["obs_unique_idx"]] for item in flat_data]
    
    # Build tasks: one per (sequence, T)
    tasks = []
    for item_idx, item in enumerate(flat_data):
        obs_seq = obs_seqs[item_idx]
        for T in Ts:
            tasks.append({
                "item_idx": item_idx,
                "T": T,
                "obs_seq": obs_seq[:T],
                "utt_seq": item["utt_seq"][:T]
            })
    
    if verbose > 0:
        print(f"  Running {len(tasks)} scipy optimizations...")
    
    # Worker function
    def optimize_single(task):
        base_config = {
            "speaker_type": "pragmatic",
            "omega": "strat",
            "psi": psi,
            "update_internal": False,
            "beta": 0.0,
            "initial_beliefs_theta": None
        }
        
        result = log_likelihood_alpha_opt_utt_seq(
            world=world,
            obs_seq=task["obs_seq"],
            utt_seq=task["utt_seq"],
            speaker_config=base_config,
            alpha_bounds=alpha_bounds,
            grid_search=False,
            include_determ=include_determ
        )
        
        return {
            "item_idx": task["item_idx"],
            "T": task["T"],
            "optimal_alpha": result["optimal_alpha"],
            "max_log_lik": result["max_log_likelihood"]
        }
    
    # Execute
    if n_jobs == 1:
        results = [optimize_single(task) for task in tasks]
    else:
        results = Parallel(n_jobs=n_jobs, backend=backend, verbose=verbose)(
            delayed(optimize_single)(task) for task in tasks
        )
    
    # Reorganize by item_idx
    max_log_lik_list = [{} for _ in range(n_total)]
    optimal_alpha_list = [{} for _ in range(n_total)]
    
    for res in results:
        item_idx = res["item_idx"]
        T = res["T"]
        max_log_lik_list[item_idx][T] = res["max_log_lik"]
        optimal_alpha_list[item_idx][T] = res["optimal_alpha"]
    
    return {
        "max_log_lik": max_log_lik_list,
        "optimal_alpha": optimal_alpha_list
    }

### Fitting Pragmatic Speaker with update_internal = True

In [None]:
def compute_pragmatic_dynamic_log_likelihood_multiT(
    obs_data: Dict[str, Any],
    fitting_psi: Literal["inf", "pers+", "pers-"],
    fitting_Ts: Optional[List[int]] = None,
    target_speaker_keys: Optional[List[str]] = None,
    target_alpha_keys: Optional[List[Any]] = None,
    method: Literal["grid", "scipy"] = "grid",
    alpha_bounds: Tuple[float, float] = (0.1, 50.0),
    grid_spacing: Literal["log", "linear"] = "log",
    n_grid: int = 100,
    include_determ: bool = True,
    n_jobs: int = 1,
    backend: str = "loky",
    verbose: int = 0
) -> None:
    """
    Find optimal alpha and max log-likelihood under a pragmatic speaker
    with update_internal=True.

    Mutates obs_data by filling in log_lik_all_speaker["{psi}_T_fitted"]
    for each utterance record.

    Parameters
    ----------
    obs_data : Dict[str, Any]
        Output from sample_observation_sequences_multiT with utterances generated.
        Read keys: "world", "config" ("Ts", "max_T", "thetas") and
        "observations" (per theta: list of dicts with "obs_seq", "obs_idx",
        "utterances").
    fitting_psi : {"inf", "pers+", "pers-"}
        The psi parameter for the fitting speaker.
    fitting_Ts : Optional[List[int]], default None
        Subset of Ts to compute likelihoods for. If None, use all Ts.
        Duplicates are dropped and the values are sorted.
    target_speaker_keys : Optional[List[str]], default None
        Which generating speakers' utterances to evaluate. If None, all.
    target_alpha_keys : Optional[List[Any]], default None
        Which generating alphas' utterances to evaluate. If None, all.
    method : {"grid", "scipy"}, default "grid"
        Optimization method.
    alpha_bounds : Tuple[float, float], default (0.1, 50.0)
        Search range for alpha. Must be strictly positive when
        method="grid" and grid_spacing="log".
    grid_spacing : {"log", "linear"}, default "log"
        Grid spacing (only for method="grid").
    n_grid : int, default 100
        Number of grid points (only for method="grid"). Must be >= 1.
    include_determ : bool, default True
        Whether to also evaluate alpha="determ".
    n_jobs : int, default 1
        Number of parallel jobs.
    backend : str, default "loky"
        Joblib backend.
    verbose : int, default 0
        Verbosity level.

    Returns
    -------
    None
        Mutates obs_data in place.

    Raises
    ------
    ValueError
        If fitting_psi, method, grid_spacing, backend, n_grid or alpha_bounds
        is invalid; if fitting_Ts is empty, contains non-integers, or values
        outside the configured Ts / the range [1, max_T]; or if an
        observation or utterance sequence does not have length max_T.
    TypeError
        If fitting_Ts is neither None, a list, nor a numpy array.

    Notes
    -----
    Storage key: "{psi}_T_fitted" (e.g., "inf_T_fitted", "persp_T_fitted")

    Storage structure:
        utt_record["log_lik_all_speaker"]["{psi}_T_fitted"] = {
            "max_log_lik": {T: float for T in fitting_Ts},
            "optimal_alpha": {T: float or "determ" for T in fitting_Ts}
        }

    If the storage key already exists on a record, the new per-T entries are
    merged into the existing dictionaries (same-T entries are overwritten).
    """

    # =========================================================================
    # INPUT VALIDATION
    # =========================================================================

    if fitting_psi not in ["inf", "pers+", "pers-"]:
        raise ValueError(f"fitting_psi must be 'inf', 'pers+', or 'pers-', got '{fitting_psi}'")

    if method not in ["grid", "scipy"]:
        raise ValueError(f"method must be 'grid' or 'scipy', got '{method}'")

    if method == "grid" and grid_spacing not in ["log", "linear"]:
        raise ValueError(f"grid_spacing must be 'log' or 'linear', got '{grid_spacing}'")

    if backend not in ["loky", "multiprocessing", "threading"]:
        raise ValueError(
            f"backend must be 'loky', 'multiprocessing', or 'threading', got '{backend}'"
        )

    if method == "grid":
        # Fail fast: an empty grid only surfaces later as an argmax error, and
        # log-spacing over a non-positive bound silently produces NaN alphas.
        if n_grid < 1:
            raise ValueError(f"n_grid must be at least 1, got {n_grid}")
        if grid_spacing == "log" and min(alpha_bounds) <= 0:
            raise ValueError(
                f"alpha_bounds must be strictly positive for log grid spacing, got {alpha_bounds}"
            )

    # =========================================================================
    # DETERMINE STORAGE KEY
    # =========================================================================

    psi_prefix = {"inf": "inf", "pers+": "persp", "pers-": "persm"}[fitting_psi]
    fitted_key = f"{psi_prefix}_T_fitted"

    # =========================================================================
    # EXTRACT CONFIGURATION
    # =========================================================================

    world = obs_data["world"]
    config_Ts = obs_data["config"]["Ts"]
    max_T = obs_data["config"]["max_T"]
    thetas = obs_data["config"]["thetas"]

    # =========================================================================
    # VALIDATE AND PROCESS fitting_Ts
    # =========================================================================

    if fitting_Ts is None:
        Ts = config_Ts
    else:
        if not isinstance(fitting_Ts, (list, np.ndarray)):
            raise TypeError("fitting_Ts must be a list of integers or None")

        fitting_Ts = list(fitting_Ts)

        if len(fitting_Ts) == 0:
            raise ValueError("fitting_Ts cannot be empty")

        for T in fitting_Ts:
            if not isinstance(T, (int, np.integer)):
                raise ValueError("All values in fitting_Ts must be integers")

        # Normalise to sorted, unique, plain Python ints
        fitting_Ts = sorted(set(int(T) for T in fitting_Ts))

        invalid_Ts = [T for T in fitting_Ts if T not in config_Ts]
        if invalid_Ts:
            raise ValueError(
                f"fitting_Ts contains values not in config Ts. "
                f"Invalid: {invalid_Ts}. Available: {config_Ts}"
            )

        Ts = fitting_Ts

    invalid_Ts = [T for T in Ts if T < 1 or T > max_T]
    if invalid_Ts:
        raise ValueError(f"All T must satisfy 1 <= T <= {max_T}. Invalid: {invalid_Ts}")

    # =========================================================================
    # FLATTEN ALL UTTERANCE SEQUENCES WITH LOCATION TRACKING
    # =========================================================================
    # Each flat item remembers where it came from so that results can be
    # written back into the nested obs_data structure afterwards.

    flat_data = []

    for theta in thetas:
        for obs_list_pos, obs_info in enumerate(obs_data["observations"][theta]):
            obs_seq = obs_info["obs_seq"]
            obs_idx = obs_info["obs_idx"]

            if len(obs_seq) != max_T:
                raise ValueError(
                    f"Observation sequence length mismatch at theta={theta}, obs_idx={obs_idx}"
                )

            # No utterances generated for this observation sequence
            if obs_info["utterances"] is None:
                continue

            for speaker_key, alpha_dict in obs_info["utterances"].items():
                if target_speaker_keys is not None and speaker_key not in target_speaker_keys:
                    continue

                for alpha_key, utt_records in alpha_dict.items():
                    if target_alpha_keys is not None and alpha_key not in target_alpha_keys:
                        continue

                    for utt_list_idx, utt_rec in enumerate(utt_records):
                        utt_seq = utt_rec["utt_seq"]

                        if len(utt_seq) != max_T:
                            raise ValueError(
                                f"Utterance sequence length mismatch at theta={theta}, "
                                f"obs_idx={obs_idx}, speaker_key={speaker_key}"
                            )

                        flat_data.append({
                            "obs_seq": obs_seq,
                            "utt_seq": utt_seq,
                            "location": (theta, obs_list_pos, speaker_key, alpha_key, utt_list_idx)
                        })

    # =========================================================================
    # VERBOSE OUTPUT
    # =========================================================================

    n_total = len(flat_data)

    if verbose > 0:
        print(f"Dynamic pragmatic speaker fitting (psi={fitting_psi}, method={method}):")
        print(f"  Storage key: '{fitted_key}'")
        print(f"  Processing {n_total} utterance sequences")
        print(f"  Ts: {Ts}")

        if method == "grid":
            print(f"  Grid: {n_grid} points, spacing={grid_spacing}, bounds={alpha_bounds}")
        else:
            print(f"  Scipy: {n_total * len(Ts)} optimizations")

        if n_jobs != 1:
            # NOTE(review): cpu_count is not imported inside this function; it is
            # expected to be available at module level - confirm.
            n_workers = (cpu_count() if n_jobs == -1
                        else max(1, cpu_count() + 1 + n_jobs) if n_jobs < 0
                        else max(1, n_jobs))
            print(f"  Parallel: {n_workers} workers")

    # =========================================================================
    # EARLY EXIT
    # =========================================================================

    if n_total == 0:
        if verbose > 0:
            print("No utterance sequences to process")
        return

    # =========================================================================
    # METHOD-SPECIFIC COMPUTATION
    # =========================================================================

    if method == "grid":
        results = _dynamic_grid_search_multiT(
            flat_data=flat_data,
            world=world,
            psi=fitting_psi,
            Ts=Ts,
            alpha_bounds=alpha_bounds,
            grid_spacing=grid_spacing,
            n_grid=n_grid,
            include_determ=include_determ,
            n_jobs=n_jobs,
            backend=backend,
            verbose=verbose
        )
    else:
        results = _dynamic_scipy_optimization_multiT(
            flat_data=flat_data,
            world=world,
            psi=fitting_psi,
            Ts=Ts,
            alpha_bounds=alpha_bounds,
            include_determ=include_determ,
            n_jobs=n_jobs,
            backend=backend,
            verbose=verbose
        )

    # =========================================================================
    # DISTRIBUTE RESULTS BACK TO NESTED STRUCTURE
    # =========================================================================
    # results[...][i] belongs to flat_data[i]; both helpers return one entry
    # per flat item, in the same order.

    for i, item in enumerate(flat_data):
        theta, obs_list_pos, speaker_key, alpha_key, utt_list_idx = item["location"]

        utt_rec = obs_data["observations"][theta][obs_list_pos]["utterances"][speaker_key][alpha_key][utt_list_idx]

        # Treat a missing container the same as an explicit None
        if utt_rec.get("log_lik_all_speaker") is None:
            utt_rec["log_lik_all_speaker"] = {}

        if fitted_key in utt_rec["log_lik_all_speaker"]:
            # Merge with results from an earlier call (e.g. other fitting_Ts)
            existing = utt_rec["log_lik_all_speaker"][fitted_key]
            existing["max_log_lik"].update(results["max_log_lik"][i])
            existing["optimal_alpha"].update(results["optimal_alpha"][i])
        else:
            utt_rec["log_lik_all_speaker"][fitted_key] = {
                "max_log_lik": results["max_log_lik"][i],
                "optimal_alpha": results["optimal_alpha"][i]
            }

    if verbose > 0:
        print(f"Completed: stored results in log_lik_all_speaker['{fitted_key}']")


def _dynamic_evaluate_sequence_all_alphas(
    obs_seq: List[Tuple[int, ...]],
    utt_seq: List[str],
    world: 'World',
    psi: str,
    alphas: List[Any],
    Ts: List[int]
) -> np.ndarray:
    """
    Evaluate log-likelihoods for one sequence across all alphas and Ts.

    For update_internal=True, the listener's beliefs evolve based on observed
    utterances. This evolution is alpha-independent, so we:
    1. Walk through the sequence collecting utilities at each step
    2. Apply softmax for all alphas at once

    Parameters
    ----------
    obs_seq : List[Tuple[int, ...]]
        Observation sequence; must have at least max(Ts) entries.
    utt_seq : List[str]
        Utterance sequence; must have at least max(Ts) entries.
    world : World
        The World object.
    psi : str
        Speaker goal: "inf", "pers+", or "pers-".
    alphas : List[Any]
        List of alpha values (floats and/or "determ"). "determ" may appear
        more than once; every occurrence gets the deterministic likelihood.
    Ts : List[int]
        Sequence lengths to compute likelihoods for.

    Returns
    -------
    np.ndarray
        Shape (n_alphas, n_Ts) log-likelihoods. Entry [a, k] is the summed
        log-probability of the first Ts[k] utterances under alphas[a];
        it is -inf as soon as any of those steps has probability zero.
    """

    steps_needed = max(Ts)
    T_indices = np.asarray(Ts) - 1

    # =========================================================================
    # SEPARATE NUMERIC ALPHAS FROM "determ"
    # =========================================================================
    # Every "determ" position is recorded (not just the last one); otherwise
    # an earlier duplicate would keep its zero-initialised log-probabilities.

    numeric_indices = []
    numeric_alphas = []
    determ_indices = []

    for i, alpha in enumerate(alphas):
        if isinstance(alpha, str) and alpha == "determ":
            determ_indices.append(i)
        else:
            numeric_indices.append(i)
            numeric_alphas.append(float(alpha))

    # =========================================================================
    # CREATE SPEAKER (alpha value doesn't matter for utility extraction)
    # =========================================================================

    speaker = PragmaticSpeaker_obs(
        world=world,
        omega="strat",
        psi=psi,
        update_internal=True,
        alpha=1.0,
        beta=0.0,
        initial_beliefs_theta=None
    )

    utterances = list(speaker.utility.index)
    utt_to_idx = {u: i for i, u in enumerate(utterances)}

    # =========================================================================
    # COLLECT UTILITIES AT EACH TIME STEP
    # =========================================================================
    # Loop is unavoidable: listener state at t depends on utterances u_0...u_{t-1}

    utilities_list = []

    for t in range(steps_needed):
        obs = obs_seq[t]
        obs_key = obs if isinstance(obs, tuple) else tuple(obs)

        # Snapshot the utility column for this observation before the
        # listener is updated with the utterance of the same step.
        utilities_list.append(speaker.utility[obs_key].values.copy())

        # Update listener with observed utterance
        speaker.literal_listener.listen_and_update(utt_seq[t])

        # Recompute the speaker's tables for the new listener state.
        # NOTE(review): this presumably also refreshes speaker.utility as a
        # side effect - confirm against PragmaticSpeaker_obs.
        speaker.utterance_log_prob_obs = speaker._compute_utterance_log_prob_obs(speaker.alpha)

    # Stack: shape (steps_needed, n_utterances)
    utilities_matrix = np.stack(utilities_list)

    # Observed utterance indices: shape (steps_needed,)
    utt_indices = np.array([utt_to_idx[utt_seq[t]] for t in range(steps_needed)])
    step_idx = np.arange(steps_needed)

    # =========================================================================
    # COMPUTE LOG PROBABILITIES FOR ALL ALPHAS AND TIME STEPS
    # =========================================================================

    all_log_probs = np.zeros((len(alphas), steps_needed))

    if numeric_alphas:
        alphas_arr = np.array(numeric_alphas)

        # Scale utilities: alpha * U(u, O)
        # Shape: (n_numeric, steps_needed, n_utterances)
        scaled_utilities = (alphas_arr[:, np.newaxis, np.newaxis] *
                            utilities_matrix[np.newaxis, :, :])

        # Log normalizers: logsumexp over utterances
        # Shape: (n_numeric, steps_needed)
        log_normalizers = logsumexp(scaled_utilities, axis=2)

        # Scaled utilities for observed utterances
        # Shape: (n_numeric, steps_needed)
        observed_scaled = scaled_utilities[:, step_idx, utt_indices]

        # Log P(u|O, alpha) = alpha*U(u) - logsumexp(alpha*U)
        all_log_probs[np.array(numeric_indices), :] = observed_scaled - log_normalizers

    # =========================================================================
    # HANDLE DETERMINISTIC ALPHA
    # =========================================================================

    if determ_indices:
        # Max utility at each step
        max_utilities = np.max(utilities_matrix, axis=1)

        # Utility of observed utterance at each step
        observed_utilities = utilities_matrix[step_idx, utt_indices]

        # Check if observed is among maxima
        is_max = np.isclose(observed_utilities, max_utilities)

        # Count ties (utterances sharing the maximal utility)
        n_ties = np.sum(np.isclose(utilities_matrix, max_utilities[:, np.newaxis]), axis=1)

        # Log prob: -log(n_ties) if max, -inf otherwise
        determ_log_probs = np.where(is_max, -np.log(n_ties), -np.inf)
        all_log_probs[np.array(determ_indices), :] = determ_log_probs

    # =========================================================================
    # CUMULATIVE SUM AND EXTRACT FOR EACH T
    # =========================================================================

    cumsum_log_probs = np.cumsum(all_log_probs, axis=1)
    all_lls = cumsum_log_probs[:, T_indices]

    # Handle -inf propagation: a zero-probability step anywhere in the first
    # T steps forces the length-T likelihood to -inf.
    seen_neginf = np.cumsum(np.isneginf(all_log_probs), axis=1) > 0
    all_lls[seen_neginf[:, T_indices]] = -np.inf

    return all_lls


def _dynamic_grid_search_multiT(
    flat_data: List[Dict[str, Any]],
    world: 'World',
    psi: str,
    Ts: List[int],
    alpha_bounds: Tuple[float, float],
    grid_spacing: str,
    n_grid: int,
    include_determ: bool,
    n_jobs: int,
    backend: str,
    verbose: int
) -> Dict[str, List[Dict[int, Any]]]:
    """
    Pick the likelihood-maximising alpha from a fixed grid (dynamic speaker).

    Every sequence in flat_data is scored against the whole alpha grid for
    all requested Ts in one call; sequences are the unit of parallelism.

    Parameters
    ----------
    flat_data : list of dict
        Items with "obs_seq" and "utt_seq" entries.
    world, psi, Ts
        Forwarded unchanged to _dynamic_evaluate_sequence_all_alphas.
    alpha_bounds : (float, float)
        Lower and upper end of the grid.
    grid_spacing : str
        "log" for a geometric grid; anything else gives a linear grid.
    n_grid : int
        Number of numeric grid points.
    include_determ : bool
        Append the special value "determ" after the numeric grid.
    n_jobs, backend, verbose
        Passed to joblib's Parallel when n_jobs != 1.

    Returns
    -------
    Dict with "max_log_lik" and "optimal_alpha" lists (one {T: value} dict
    per input sequence, in input order).
    """

    # Build the candidate alphas
    lower, upper = alpha_bounds
    if grid_spacing == "log":
        grid_points = np.exp(np.linspace(np.log(lower), np.log(upper), n_grid))
    else:
        grid_points = np.linspace(lower, upper, n_grid)

    alphas = list(grid_points)
    if include_determ:
        alphas.append("determ")

    # Per-sequence worker: returns an (n_alphas, n_Ts) array
    def evaluate_single(item):
        return _dynamic_evaluate_sequence_all_alphas(
            obs_seq=item["obs_seq"],
            utt_seq=item["utt_seq"],
            world=world,
            psi=psi,
            alphas=alphas,
            Ts=Ts
        )

    if n_jobs == 1:
        per_sequence = [evaluate_single(item) for item in flat_data]
    else:
        per_sequence = Parallel(n_jobs=n_jobs, backend=backend, verbose=verbose)(
            delayed(evaluate_single)(item) for item in flat_data
        )

    # Shape (n_sequences, n_alphas, n_Ts)
    stacked = np.array(per_sequence)

    # Winning alpha index per (sequence, T); ties resolve to the first index
    best_idx = np.argmax(stacked, axis=1)

    # Likelihood at the winning index, shape (n_sequences, n_Ts)
    best_ll = np.take_along_axis(stacked, best_idx[:, np.newaxis, :], axis=1)[:, 0, :]

    # One {T: value} dict per sequence
    max_log_lik_list = [
        {T: float(best_ll[row, col]) for col, T in enumerate(Ts)}
        for row in range(len(flat_data))
    ]
    optimal_alpha_list = [
        {T: alphas[best_idx[row, col]] for col, T in enumerate(Ts)}
        for row in range(len(flat_data))
    ]

    return {
        "max_log_lik": max_log_lik_list,
        "optimal_alpha": optimal_alpha_list
    }


def _dynamic_scipy_optimization_multiT(
    flat_data: List[Dict[str, Any]],
    world: 'World',
    psi: str,
    Ts: List[int],
    alpha_bounds: Tuple[float, float],
    include_determ: bool,
    n_jobs: int,
    backend: str,
    verbose: int
) -> Dict[str, List[Dict[int, Any]]]:
    """
    Fit alpha by continuous optimisation for the dynamic pragmatic speaker.

    One independent call to log_likelihood_alpha_opt_utt_seq is made for
    every (sequence, T) combination, using the first T observations and
    utterances of the sequence and a speaker config with
    update_internal=True.

    Parameters
    ----------
    flat_data : list of dict
        Items with "obs_seq" and "utt_seq" entries.
    world, alpha_bounds, include_determ
        Forwarded to log_likelihood_alpha_opt_utt_seq.
    psi : str
        Value stored under "psi" in the speaker config.
    Ts : list of int
        Prefix lengths to fit.
    n_jobs, backend, verbose
        Passed to joblib's Parallel when n_jobs != 1; verbose > 0 also
        prints the number of optimisations.

    Returns
    -------
    Dict with "max_log_lik" and "optimal_alpha" lists (one {T: value} dict
    per input sequence, in input order).
    """

    # Enumerate the work: sequence-major, then T
    tasks = [
        {
            "item_idx": seq_pos,
            "T": T,
            "obs_seq": entry["obs_seq"][:T],
            "utt_seq": entry["utt_seq"][:T]
        }
        for seq_pos, entry in enumerate(flat_data)
        for T in Ts
    ]

    if verbose > 0:
        print(f"  Running {len(tasks)} scipy optimizations...")

    def optimize_single(task):
        # A fresh config dict per call, so the callee never sees shared state
        fit = log_likelihood_alpha_opt_utt_seq(
            world=world,
            obs_seq=task["obs_seq"],
            utt_seq=task["utt_seq"],
            speaker_config={
                "speaker_type": "pragmatic",
                "omega": "strat",
                "psi": psi,
                "update_internal": True,
                "beta": 0.0,
                "initial_beliefs_theta": None
            },
            alpha_bounds=alpha_bounds,
            grid_search=False,
            include_determ=include_determ
        )
        return {
            "item_idx": task["item_idx"],
            "T": task["T"],
            "optimal_alpha": fit["optimal_alpha"],
            "max_log_lik": fit["max_log_likelihood"]
        }

    if n_jobs != 1:
        outcomes = Parallel(n_jobs=n_jobs, backend=backend, verbose=verbose)(
            delayed(optimize_single)(task) for task in tasks
        )
    else:
        outcomes = [optimize_single(task) for task in tasks]

    # Regroup the flat outcome list into one {T: value} dict per sequence
    n_sequences = len(flat_data)
    max_log_lik_list = [{} for _ in range(n_sequences)]
    optimal_alpha_list = [{} for _ in range(n_sequences)]

    for outcome in outcomes:
        seq_pos, T = outcome["item_idx"], outcome["T"]
        max_log_lik_list[seq_pos][T] = outcome["max_log_lik"]
        optimal_alpha_list[seq_pos][T] = outcome["optimal_alpha"]

    return {
        "max_log_lik": max_log_lik_list,
        "optimal_alpha": optimal_alpha_list
    }

In [None]:
# Demo cell: sample observations, generate utterances for every true speaker
# config, then fit the dynamic pragmatic speaker for each psi via scipy.
# NOTE: RUN_DEMO is re-assigned here, so this cell runs regardless of the
# value set near the top of the notebook.
RUN_DEMO =True
if RUN_DEMO:
    # Small problem: two thetas only (the full theta grid is kept in the
    # trailing comment for reference).
    test_1_multiT = sample_observation_sequences_multiT(
        n=5, m=5,
        thetas=[0.3, 0.5],#[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
        Ts=[10, 11, 12, 13, 14, 15],
        n_obs_seq=5,
        random_seed=21,
        compute_obs_likelihood="all"
    )
    
    
    # NOTE(review): TRUE_SPEAKER_CONFIGS is not defined in this cell; in this
    # file it is only assigned in a later cell, so that cell must have been
    # executed first - confirm the intended execution order.
    for true_speaker_config in TRUE_SPEAKER_CONFIGS:
        
        generate_utterances_for_observations_multiT(
            obs_data=test_1_multiT,
            speaker_config=true_speaker_config,
            n_utt_seq=5,
            n_jobs=-1,
            verbose=2
        )
    
    fitting_psis = ["inf", "pers+", "pers-"]
    
    # Fit one dynamic speaker per psi; results are written into test_1_multiT
    # in place (the function returns None).
    for fitting_psi in fitting_psis:
        compute_pragmatic_dynamic_log_likelihood_multiT(
            obs_data=test_1_multiT,
            fitting_psi=fitting_psi,
            method="scipy",
            n_jobs=-1,
            verbose=2
        )

# Test coarse screening

In [1]:
from rsa_optimal_exp_sampling_fun import (
    sample_observation_sequences_multiT, generate_utterances_for_observations_multiT
)

from rsa_optimal_exp_fitting import (
    compute_literal_log_likelihood_multiT, 
    compute_pragmatic_static_log_likelihood_multiT,
    compute_pragmatic_dynamic_log_likelihood_multiT
)


# Settings for the coarse-screening run. N and M are passed to the sampler
# as n and m; the remaining values are forwarded to the sampling, utterance
# generation and fitting calls of this cell.
N = 5
M = 5
Ts = [5, 7, 10, 13, 15]
THETAs = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
SEED = 21
N_OBS_SEQ = 1
N_UTT_SEQ = 45
VERBOSE = 3

# psi values of the speakers that are fitted to the generated utterances
FITTING_SPEAKER_PSIS = ["inf", "pers+", "pers-"]

# Generating ("true") speakers: the literal speaker first, then for each psi
# a pragmatic speaker without and with internal updating, all at alpha = 4.0.
TRUE_SPEAKER_CONFIGS = [{"speaker_type": "literal"}] + [
    {
        "speaker_type": "pragmatic",
        "omega": "strat",
        "psi": psi_value,
        "alpha": 4.0,
        "update_internal": dynamic,
    }
    for psi_value in ("inf", "pers+", "pers-")
    for dynamic in (False, True)
]


# Step 1: sample observation sequences for every theta in THETAs.
raw_data = sample_observation_sequences_multiT(
    n=N, m=M,
    thetas=THETAs,
    Ts=Ts,
    n_obs_seq=N_OBS_SEQ,
    random_seed=SEED,
    compute_obs_likelihood="all"
)


# Step 2: generate utterance sequences for each generating speaker config.
# The return value is discarded, so the utterances are presumably attached to
# raw_data in place - verify against generate_utterances_for_observations_multiT.
for true_speaker_config in TRUE_SPEAKER_CONFIGS:
    generate_utterances_for_observations_multiT(
        obs_data=raw_data,
        speaker_config=true_speaker_config,
        n_utt_seq=N_UTT_SEQ,
        n_jobs=-1,
        verbose=VERBOSE
    )


# Step 3: literal-speaker likelihoods for all generated utterances.
compute_literal_log_likelihood_multiT(
    obs_data=raw_data,
    verbose=VERBOSE
)


# Step 4: for each fitting psi, fit the static and then the dynamic pragmatic
# speaker. Both calls use their default optimisation settings; note that the
# static fit runs serially (n_jobs=1) while the dynamic fit uses n_jobs=-1.
for fitting_speaker_psi in FITTING_SPEAKER_PSIS:
    
    compute_pragmatic_static_log_likelihood_multiT(
        obs_data=raw_data,
        fitting_psi=fitting_speaker_psi,
        n_jobs=1,
        verbose=VERBOSE
    )

    compute_pragmatic_dynamic_log_likelihood_multiT(
        obs_data=raw_data,
        fitting_psi=fitting_speaker_psi,
        n_jobs=-1,
        verbose=VERBOSE
    )


Generating utterances: 2 obs_seq × 45 utt_seq
  Speaker: literal, alpha: 0.0, Ts: [5, 7, 10, 13, 15]
  Workers: 1
Generating utterances: 2 obs_seq × 45 utt_seq
  Speaker: inf_F, alpha: 4.0, Ts: [5, 7, 10, 13, 15]
  Workers: 1
Generating utterances: 2 obs_seq × 45 utt_seq
  Speaker: inf_T, alpha: 4.0, Ts: [5, 7, 10, 13, 15]
  Workers: 1
Generating utterances: 2 obs_seq × 45 utt_seq
  Speaker: persp_F, alpha: 4.0, Ts: [5, 7, 10, 13, 15]
  Workers: 1
Generating utterances: 2 obs_seq × 45 utt_seq
  Speaker: persp_T, alpha: 4.0, Ts: [5, 7, 10, 13, 15]
  Workers: 1
Generating utterances: 2 obs_seq × 45 utt_seq
  Speaker: persm_F, alpha: 4.0, Ts: [5, 7, 10, 13, 15]
  Workers: 1
Generating utterances: 2 obs_seq × 45 utt_seq
  Speaker: persm_T, alpha: 4.0, Ts: [5, 7, 10, 13, 15]
  Workers: 1
Processing 630 utterance sequences (2 unique obs positions) across all configurations
Completed: stored results in log_lik_all_speaker['literal_fitted']
Static pragmatic speaker fitting (psi=inf, method=gri

# END

In [8]:
import os
import pickle

# Define the output directory and file path
# NOTE(review): absolute, user-specific path - adjust before running elsewhere.
output_dir = "/home/users/fangke/prag_net/optimal_design"
os.makedirs(output_dir, exist_ok=True)

# Save the dictionary using pickle
# NOTE(review): the recorded output of this cell reports a different file name
# (obs_seqs_N5M6T15.pkl) than the one built here; the output looks stale -
# confirm which file is the current one.
pickle_file = os.path.join(output_dir, "N5M5T15.pkl")

# Serialise the whole raw_data structure (default pickle protocol)
with open(pickle_file, 'wb') as f:
    pickle.dump(raw_data, f)

print(f"Saved obs_data to {pickle_file}")

Saved obs_data to /home/users/fangke/prag_net/optimal_design/obs_seqs_N5M6T15.pkl


In [9]:
raw_data["observations"][0.2][0]["utterances"]["persp_T"][4.0][1]["log_lik_all_speaker"]["persm_T_fitted"]

dict_keys(['utt_idx', 'utt_seq', 'utt_seed', 'log_lik_true_speaker', 'log_lik_all_speaker'])

In [5]:
raw_data["observations"].keys()

dict_keys([0.1, 0.2])

In [2]:
"""
Stimuli Generator for RSA Human Experiment (N=5, M=1 Version)
UPDATED: Generates ALL possible arrangements for each effectiveness level.
Uses Twemoji images from CDN for consistent, high-quality emoji display.

For 5 patients with k effective:
- k=0: 1 arrangement  (C(5,0) = 1)
- k=1: 5 arrangements (C(5,1) = 5)
- k=2: 10 arrangements (C(5,2) = 10)
- k=3: 10 arrangements (C(5,3) = 10)
- k=4: 5 arrangements (C(5,4) = 5)
- k=5: 1 arrangement  (C(5,5) = 1)
Total: 32 images

Output folder: stimuli_emoji_n5m1/
Naming: effective_{k}_v{variant}.png (e.g., effective_3_v0.png through effective_3_v9.png)
"""
from PIL import Image
from pathlib import Path
from typing import Dict, List, Tuple
from itertools import combinations
import urllib.request

# Configuration
# Number of emoji slots drawn per stimulus row; also the upper bound for the
# number of "effective" patients when enumerating arrangements.
N_PATIENTS = 5

# Emoji Unicode code points for Twemoji URLs
# Keys ("effective" / "ineffective") are the names under which the loaded
# images are stored; "char" is only used for console output, "codepoint"
# selects the PNG file to download.
EMOJI_CONFIG = {
    "effective": {
        "char": "😃",
        "codepoint": "1f603",  # Smiling face with open mouth
    },
    "ineffective": {
        "char": "🤒",
        "codepoint": "1f912",  # Face with thermometer
    }
}

# URL template for the 72x72 PNG assets; "{codepoint}" is filled in by
# download_emoji.
# NOTE(review): "@latest" is not pinned to a fixed release - consider pinning
# a version if the stimuli must be reproducible.
TWEMOJI_URL = "https://cdn.jsdelivr.net/gh/twitter/twemoji@latest/assets/72x72/{codepoint}.png"


def get_cache_dir() -> Path:
    """Return ~/.cache/rsa_stimuli/emoji, creating it (and parents) if absent."""
    emoji_cache = Path.home().joinpath(".cache", "rsa_stimuli", "emoji")
    emoji_cache.mkdir(parents=True, exist_ok=True)
    return emoji_cache


def download_emoji(codepoint: str, cache_dir: Path) -> Path:
    """
    Return the cached PNG for *codepoint*, downloading it from the Twemoji CDN
    on first use.

    The file is first written to a sibling "<name>.png.part" file and only
    renamed to its final name after the transfer succeeded, so an interrupted
    download can never be mistaken for a valid cache entry on the next call.

    Parameters
    ----------
    codepoint : str
        Hex code point used in the Twemoji file name (e.g. "1f603").
    cache_dir : Path
        Existing directory in which "<codepoint>.png" is stored.

    Returns
    -------
    Path
        Path of the cached PNG file.

    Raises
    ------
    RuntimeError
        If the download (or the final rename) fails; the original exception
        is attached as __cause__.
    """
    cache_path = cache_dir / f"{codepoint}.png"

    # Cache hit: no network access needed
    if cache_path.exists():
        return cache_path

    url = TWEMOJI_URL.format(codepoint=codepoint)
    partial_path = cache_path.with_name(cache_path.name + ".part")
    try:
        urllib.request.urlretrieve(url, partial_path)
        partial_path.replace(cache_path)
    except Exception as e:
        # Best-effort cleanup of a half-written file; the download error
        # below is the one worth reporting.
        try:
            partial_path.unlink()
        except OSError:
            pass
        raise RuntimeError(f"Failed to download emoji {codepoint}: {e}\nURL: {url}") from e

    return cache_path


def load_emoji_images(size: int = 72) -> Dict[str, Image.Image]:
    """
    Fetch every emoji listed in EMOJI_CONFIG and return them as RGBA images.

    Each image is obtained through download_emoji (using the directory from
    get_cache_dir), converted to RGBA and resized to size x size pixels with
    LANCZOS resampling. Progress is printed to stdout.

    Returns a dict keyed by the EMOJI_CONFIG names.
    """
    cache_dir = get_cache_dir()
    target_size = (size, size)
    loaded = {}

    print("Loading emoji images from Twemoji CDN...")
    for name, config in EMOJI_CONFIG.items():
        # Progress line is completed with a check mark once the image is ready
        print(f"  {config['char']} ({name})...", end=" ")
        source_file = download_emoji(config["codepoint"], cache_dir)
        rgba = Image.open(source_file).convert("RGBA")
        loaded[name] = rgba.resize(target_size, Image.Resampling.LANCZOS)
        print("✓")

    return loaded


def get_all_arrangements(num_effective: int) -> List[Tuple[int, ...]]:
    """
    Enumerate every way to choose which patients are effective.

    Each arrangement is a tuple of the positions (0 .. N_PATIENTS-1) holding
    an effective patient; arrangements come in the lexicographic order
    produced by itertools.combinations.

    Example for num_effective=2 (with N_PATIENTS = 5):
    [(0,1), (0,2), (0,3), (0,4), (1,2), (1,3), (1,4), (2,3), (2,4), (3,4)]
    """
    patient_slots = range(N_PATIENTS)
    return [*combinations(patient_slots, num_effective)]


def create_stimuli_image(
    effective_positions: Tuple[int, ...],
    emoji_images: Dict[str, Image.Image],
    emoji_size: int = 72,
    padding: int = 10,
    output_path: str = None
) -> Image.Image:
    """
    Render one stimulus: a horizontal row of N_PATIENTS emoji faces.

    Parameters
    ----------
    effective_positions : Tuple[int, ...]
        Slots (0-based) that get the "effective" face; every other slot
        gets the "ineffective" face.
    emoji_images : Dict[str, Image.Image]
        Pre-loaded images under the keys "effective" and "ineffective".
        They are pasted using themselves as mask.
    emoji_size : int
        Edge length assumed for each emoji when laying out the row.
    padding : int
        Gap between neighbouring emojis and around the row.
    output_path : str
        When truthy, the final image is also written to this path.

    Returns
    -------
    Image.Image
        The flattened RGB image on a white background.
    """
    # Canvas geometry: N emojis plus N + 1 gaps wide, one emoji plus two gaps high
    canvas_size = (
        N_PATIENTS * emoji_size + (N_PATIENTS + 1) * padding,
        emoji_size + 2 * padding,
    )
    canvas = Image.new('RGBA', canvas_size, color='#FFFFFF')

    # Drop each face into its slot, left to right
    slot_pitch = emoji_size + padding
    happy_slots = set(effective_positions)
    for slot in range(N_PATIENTS):
        if slot in happy_slots:
            face = emoji_images["effective"]
        else:
            face = emoji_images["ineffective"]
        canvas.paste(face, (padding + slot * slot_pitch, padding), face)

    # Flatten onto white; the canvas is always RGBA, so band 3 is its alpha
    flattened = Image.new('RGB', canvas.size, '#FFFFFF')
    flattened.paste(canvas, mask=canvas.split()[3])

    if output_path:
        flattened.save(output_path, quality=95)

    return flattened


def positions_to_visual(effective_positions: Tuple[int, ...]) -> str:
    """Render the arrangement as a string of N_PATIENTS emoji characters."""
    happy_slots = frozenset(effective_positions)
    return "".join(
        "😃" if slot in happy_slots else "🤒"
        for slot in range(N_PATIENTS)
    )


def generate_all_stimuli(output_dir: str = "stimuli_emoji_n5m1", emoji_size: int = 72) -> None:
    """
    Generate stimuli images for all possible arrangements.

    For every number of effective patients k in 0..N_PATIENTS, each arrangement
    returned by get_all_arrangements(k) is rendered and saved as
    ``effective_{k}_v{variant_idx}.png`` inside ``output_dir``. Progress and a
    per-level summary are printed to stdout.

    Parameters
    ----------
    output_dir : str
        Directory to write the images to; created (including missing parent
        directories) if it does not exist.
    emoji_size : int
        Emoji size passed to the image loader and the layout.
    """
    output_path = Path(output_dir)
    # parents=True so a nested output_dir works even when its parents are missing
    output_path.mkdir(parents=True, exist_ok=True)
    
    emoji_images = load_emoji_images(size=emoji_size)
    
    print(f"\nGenerating all arrangement stimuli in '{output_dir}/'...")
    print("-" * 60)
    
    total_images = 0
    arrangement_counts = {}
    
    for num_effective in range(N_PATIENTS + 1):
        arrangements = get_all_arrangements(num_effective)
        arrangement_counts[num_effective] = len(arrangements)
        
        # Report against the configured patient count rather than a literal 5
        print(f"\n{num_effective}/{N_PATIENTS} effective: {len(arrangements)} arrangement(s)")
        
        for variant_idx, positions in enumerate(arrangements):
            filename = output_path / f"effective_{num_effective}_v{variant_idx}.png"
            
            create_stimuli_image(
                effective_positions=positions,
                emoji_images=emoji_images,
                emoji_size=emoji_size,
                output_path=str(filename)
            )
            
            visual = positions_to_visual(positions)
            # An empty arrangement (no effective patients) is reported as "none"
            positions_str = ",".join(map(str, positions)) if positions else "none"
            print(f"  v{variant_idx}: {visual} (effective at positions: {positions_str}) → {filename.name}")
            total_images += 1
    
    print("\n" + "-" * 60)
    print(f"✓ Done! Generated {total_images} images.")
    print("\nSummary of arrangements per effectiveness level:")
    for k, count in arrangement_counts.items():
        print(f"  {k} effective: {count} variants")


def generate_arrangement_map() -> Dict[int, List[Tuple[int, ...]]]:
    """
    Build a lookup from the number of effective patients to its arrangements.

    Keys run from 0 to N_PATIENTS inclusive; each value is whatever
    get_all_arrangements returns for that key. Handy for verification and
    for generating the JavaScript side of the experiment.
    """
    arrangement_map = {}
    for num_effective in range(N_PATIENTS + 1):
        arrangement_map[num_effective] = get_all_arrangements(num_effective)
    return arrangement_map


def print_javascript_config():
    """
    Print JavaScript configuration for stimuli.js

    Emits two JavaScript object literals to stdout:
    ``ARRANGEMENT_COUNTS`` (number of variants per number of effective
    patients) and ``ARRANGEMENTS`` (the effective positions of every variant).
    Both are derived from get_all_arrangements for k in 0..N_PATIENTS, so the
    counts always agree with the detailed mappings.
    """
    print("\n" + "=" * 60)
    print("JavaScript configuration for stimuli.js:")
    print("=" * 60)

    # Compute every arrangement list once and reuse it for both sections.
    arrangement_map = {k: get_all_arrangements(k) for k in range(N_PATIENTS + 1)}
    counts = {k: len(arrangements) for k, arrangements in arrangement_map.items()}

    # JavaScript entries are comma-terminated except for the last one.
    cells = {
        k: f"{count}," if k < N_PATIENTS else f"{count}"
        for k, count in counts.items()
    }
    # Pad so the trailing comments line up, two spaces after the widest cell.
    cell_width = max(len(cell) for cell in cells.values()) + 2
    count_lines = [
        f"    {k}: {cells[k].ljust(cell_width)}// C({N_PATIENTS},{k}) = {counts[k]}"
        for k in counts
    ]

    print("\n// Arrangement data: maps numEffective to list of variant indices")
    print("const ARRANGEMENT_COUNTS = {")
    print("\n".join(count_lines))
    print("};")
    print(f"\n// Total: {sum(counts.values())} images\n")
    
    print("\n// Detailed arrangement mappings (positions of effective patients):")
    print("const ARRANGEMENTS = {")
    for k, arrangements in arrangement_map.items():
        arr_str = ", ".join(str(list(a)) for a in arrangements)
        print(f"    {k}: [{arr_str}],")
    print("};")


def main():
    """
    Main entry point.

    Prints a banner describing the configuration, generates every stimulus
    image into ``stimuli_emoji_n5m1``, prints the JavaScript configuration,
    and finally reports the cache location.
    """
    print("=" * 60)
    print("  RSA Experiment Stimuli Generator (Randomized Arrangements)")
    print(f"  Configuration: {N_PATIENTS} patients, 1 session")
    print("  Using Twemoji images from CDN")
    print("=" * 60)
    
    print("\n  Emojis:")
    print(f"    Effective:   😃 (Twemoji: {EMOJI_CONFIG['effective']['codepoint']})")
    print(f"    Ineffective: 🤒 (Twemoji: {EMOJI_CONFIG['ineffective']['codepoint']})")
    
    print(f"\n  For {N_PATIENTS} patients, generating ALL possible arrangements:")
    # Use the configured patient count here too, matching the line above
    print(f"    This creates C({N_PATIENTS},k) images for each k effective patients")
    
    generate_all_stimuli(output_dir="stimuli_emoji_n5m1")
    
    print_javascript_config()
    
    print("\n" + "=" * 60)
    print("  Cache location: " + str(get_cache_dir()))
    print("  Done! Images created with Twemoji and randomized arrangements.")
    print("=" * 60)


# Run the stimuli generator only when this code is executed as the main
# program (not when it is imported as a module).
if __name__ == "__main__":
    main()


  RSA Experiment Stimuli Generator (Randomized Arrangements)
  Configuration: 5 patients, 1 session
  Using Twemoji images from CDN

  Emojis:
    Effective:   😃 (Twemoji: 1f603)
    Ineffective: 🤒 (Twemoji: 1f912)

  For 5 patients, generating ALL possible arrangements:
    This creates C(5,k) images for each k effective patients
Loading emoji images from Twemoji CDN...
  😃 (effective)... ✓
  🤒 (ineffective)... ✓

Generating all arrangement stimuli in 'stimuli_emoji_n5m1/'...
------------------------------------------------------------

0/5 effective: 1 arrangement(s)
  v0: 🤒🤒🤒🤒🤒 (effective at positions: none) → effective_0_v0.png

1/5 effective: 5 arrangement(s)
  v0: 😃🤒🤒🤒🤒 (effective at positions: 0) → effective_1_v0.png
  v1: 🤒😃🤒🤒🤒 (effective at positions: 1) → effective_1_v1.png
  v2: 🤒🤒😃🤒🤒 (effective at positions: 2) → effective_1_v2.png
  v3: 🤒🤒🤒😃🤒 (effective at positions: 3) → effective_1_v3.png
  v4: 🤒🤒🤒🤒😃 (effective at positions: 4) → effective_1_v4.png

2/5 effective: 10 arr