In [20]:
from scipy.stats import poisson
from memo import memo
import jax.numpy as jnp
import jax
import numpy as np
from jax.scipy.special import logsumexp
from typing import Union
from tqdm import tqdm
from scipy.special import gammaln, logsumexp, softmax
from enum import IntEnum
from memo.lib import * 
import matplotlib.pyplot as plt
from scipy.stats import poisson

from world_jax import WorldJAX

from rsa_memo import *
from test_utils import *

### 1. Generate Cognitive Hierarchy with Poisson Distribution

In [21]:
def sample_level_from_poisson(max_level, tau=1.5, random_seed=42):
    """Draw one reasoning level from a Poisson(tau) truncated to 0..max_level.

    Args:
        max_level: Largest level that can be drawn (inclusive).
        tau: Rate of the Poisson distribution before truncation.
        random_seed: Accepted for interface compatibility but currently NOT
            used; the draw comes from NumPy's global random state.

    Returns:
        The sampled level, as returned by ``np.random.choice``.
    """
    candidate_levels = np.arange(max_level + 1)
    weights = poisson.pmf(candidate_levels, mu=tau)
    # Renormalize because truncation drops the mass above max_level.
    probabilities = weights / weights.sum()
    return np.random.choice(candidate_levels, p=probabilities)

def get_level_distribution(max_level, tau=1.5):
    """Return the Poisson(tau) distribution truncated to levels 0..max_level.

    Args:
        max_level: Largest level to include (inclusive).
        tau: Rate of the Poisson distribution before truncation.

    Returns:
        dict mapping each level (Python ``int``) to its renormalized
        probability (Python ``float``). Native types keep the printed dict
        readable and make it JSON-serializable; lookups with NumPy integer
        keys still work because they hash and compare equal to ``int``.
    """
    levels = np.arange(max_level + 1)
    pmf = poisson.pmf(levels, mu=tau)
    pmf = pmf / pmf.sum()  # Renormalize after truncating the tail
    # .tolist() converts NumPy scalars to plain Python int / float.
    return dict(zip(levels.tolist(), pmf.tolist()))

# Demo with tau = 2: one sampled level, then the full truncated distribution.
print(
    sample_level_from_poisson(10, tau=2.0),
    get_level_distribution(10, tau=2.0),
    sep="\n",
)

1
{np.int64(0): np.float64(0.13533640764185262), np.int64(1): np.float64(0.27067281528370524), np.int64(2): np.float64(0.27067281528370524), np.int64(3): np.float64(0.1804485435224701), np.int64(4): np.float64(0.09022427176123506), np.int64(5): np.float64(0.03608970870449403), np.int64(6): np.float64(0.012029902901498002), np.int64(7): np.float64(0.003437115114713715), np.int64(8): np.float64(0.0008592787786784285), np.int64(9): np.float64(0.00019095083970631786), np.int64(10): np.float64(3.819016794126347e-05)}


### 2. Auto-generate Memo High Level Code

**Manual copy-paste**: Run `create_high_level_agents(max_level=N)` to generate code, then copy and paste it into the codebase.

memo needs to read each model's source code, so functions created with `exec()` cannot be compiled; attempting it raises:

*memo.core.MemoError: Python couldn't find your memo source code*

*hint: You cannot define a new @memo in the Python interactive REPL.*
    *Try writing your memo code to a file and running via `python*
    *filename.py`. If you really want an interactive experience,*
    *memo also works inside Jupyter notebooks.*
]

**Recursion Pattern:**
- **Sn**: Speaker at level n models listener L(n-1)
- **Ln_obs**: Listener at level n infers observation, modeling speaker Sn  
- **S(n+1)**: Speaker at level n+1 models listener Ln

In [22]:
def create_high_level_agents(max_level=5, agent_type='inf'):
    """
    Generate memo source code for higher-level Speaker and Listener agents.

    For every level n in ``range(2, max_level)`` the following definitions are
    emitted (so the highest speaker generated is S{max_level}):

    - 'inf': ``L{n}_cred_obs`` (calls ``L{n-1}_cred_obs``) and ``S{n+1}_inf``
      (calls ``L{n}_cred_obs``).
    - 'vig': ``L{n}_vig_obs`` and ``L{n}_vig_expected_theta`` (both call the
      level n-1 vigilant listeners) and ``S{n+1}_pers`` (calls the level n
      vigilant listeners).

    Recursion Pattern:
    - Sn models L(n-1)_obs: "What observation will L(n-1) infer from my utterance?"
    - Ln_obs models Sn: "What observation led to this utterance from Sn?"
    - S(n+1) models Ln_obs: "What observation will Ln infer from my utterance?"

    The generated text is meant to be pasted into a file or notebook cell; the
    level-1 listeners and level-2 speakers it builds on must already be
    defined by hand.

    Args:
        max_level: Maximum speaker level to generate (default=5). Values <= 2
            produce an empty string.
        agent_type: 'inf' for informed agents, 'vig' for vigilant/persuasive
            agents, 'all' for both.

    Returns:
        str: All generated definitions, separated by blank lines.

    Raises:
        ValueError: If ``agent_type`` is not one of 'inf', 'vig', 'all'.
    """
    valid_agent_types = ('inf', 'vig', 'all')
    if agent_type not in valid_agent_types:
        # Previously an unknown type silently produced an empty string.
        raise ValueError(
            f"agent_type must be one of {valid_agent_types}, got {agent_type!r}"
        )

    code_parts = []

    # Signatures are identical at every level, so build them once.
    # All levels from L2+ need full parameters because they model speakers that use L1_cred_obs
    listener_params = "prior: ..., l0_prior: ..., alpha_dist: ..., alpha_vals: ..."
    speaker_params = "prior: ..., l0_prior: ..., alpha, alpha_dist: ..., alpha_values: ..."
    vig_listener_params = "prior: ..., l0_prior: ..., alpha_dist: ..., alpha_values: ..., psi_dist: ..., theta_values: ..."
    pers_speaker_params = "prior: ..., l0_prior: ..., alpha, alpha_dist: ..., alpha_values: ..., theta_values: ..., psi, psi_dist: ..."

    # Starting from level 2, since we already have L1_cred_obs and S2_inf manually defined
    for level in range(2, max_level):

        # ==================== INFORMED AGENTS ====================
        if agent_type in ['inf', 'all']:
            # Generate Ln_cred_obs that models Sn_inf
            listener_name = f"L{level}_cred_obs"
            prev_listener = f"L{level-1}_cred_obs"

            prev_listener_call = f"{prev_listener}[u, world.obs](prior, l0_prior, alpha_dist, alpha_vals)"
            speaker_choice = f"is_utt_true_for_obs(u, world.obs) * exp(array_index(alpha_vals, alp) * log({prev_listener_call} + 1e-10))"

            listener_code = f'''@memo
def {listener_name}[_u: U, _obs: Obs]({listener_params}):
    """Level {level} listener inferring observation, modeling S{level}_inf"""
    listener: knows(_u, _obs)
    listener: thinks[
        world: chooses(theta in Theta, wpp=array_index(prior, theta)),
        world: chooses(obs in Obs, wpp=get_obs_prob(obs, theta)),
        speaker: knows(world.obs),
        speaker: chooses(alp in Alpha, wpp=array_index(alpha_dist, alp)),
        speaker: chooses(
            u in U,
            wpp={speaker_choice}
        )
    ]
    listener: observes[speaker.u] is _u
    return listener[Pr[world.obs == _obs]]'''

            # Generate S(n+1)_inf that models Ln_cred_obs
            next_speaker_name = f"S{level+1}_inf"
            listener_call = f"{listener_name}[u, _obs](prior, l0_prior, alpha_dist, alpha_values)"

            speaker_code = f'''@memo
def {next_speaker_name}[_obs: Obs, _u: U]({speaker_params}):
    """Level {level+1} informed speaker, modeling L{level}_cred_obs"""
    speaker: knows(_obs, _u)
    speaker: chooses(
        u in U,
        wpp=is_utt_true_for_obs(u, _obs) * exp(alpha * log({listener_call} + 1e-10))
    )
    return Pr[speaker.u == _u]'''

            code_parts.append(listener_code)
            code_parts.append(speaker_code)

        # ==================== VIGILANT AGENTS (for persuasive modeling) ====================
        if agent_type in ['vig', 'all']:
            # Generate Ln_vig_obs (models speakers with uncertain type psi)
            vig_listener_name = f"L{level}_vig_obs"
            prev_vig_listener = f"L{level-1}_vig_obs"
            prev_vig_expected = f"L{level-1}_vig_expected_theta"

            # NOTE(review): the L{n}_vig_obs template below indexes the previous
            # expected-theta listener with `_theta`, but `_theta` is not among
            # the axes it declares ([_u, _obs, _alpha, _psi]). Confirm that the
            # pasted code compiles under memo before relying on it.
            vig_listener_code = f'''@memo
def {vig_listener_name}[_u: U, _obs: Obs, _alpha: Alpha, _psi: Psi]({vig_listener_params}):
    """Level {level} vigilant listener inferring observation, modeling S{level}_pers"""
    listener: knows(_u, _obs, _alpha, _psi)
    listener: thinks[
        world: chooses(theta in Theta, wpp=array_index(prior, theta)),
        world: chooses(obs in Obs, wpp=get_obs_prob(obs, theta)),
        world: given(stype in Psi, wpp=array_index(psi_dist, stype)),
        speaker: knows(world.obs, world.stype),
        speaker: chooses(alp in Alpha, wpp=array_index(alpha_dist, alp)),
        speaker: chooses(
            u in U,
            wpp=is_utt_true_for_obs(u, world.obs) * exp(
                array_index(alpha_values, alp) * log(
                    (world.stype == 0) * {prev_vig_listener}[u, world.obs, _alpha, _psi](prior, l0_prior, alpha_dist, alpha_values, psi_dist, theta_values) +
                    (world.stype == 1) * {prev_vig_expected}[u, _theta, _alpha, _psi](prior, l0_prior, alpha_dist, alpha_values, psi_dist, theta_values) +
                    (world.stype == 2) * (1.0 - {prev_vig_expected}[u, _theta, _alpha, _psi](prior, l0_prior, alpha_dist, alpha_values, psi_dist, theta_values))
                    + 1e-10
                )
            )
        )
    ]
    listener: observes[speaker.u] is _u
    return listener[Pr[world.obs == _obs]]'''

            code_parts.append(vig_listener_code)

            # Also generate Ln_vig_expected_theta for persuasive speakers at next level
            vig_expected_name = f"L{level}_vig_expected_theta"
            vig_expected_code = f'''@memo
def {vig_expected_name}[_u: U, _theta: Theta, _alpha: Alpha, _psi: Psi]({vig_listener_params}):
    """Level {level} vigilant listener computing expected theta, for S{level+1}_pers"""
    listener: knows(_u, _theta, _alpha, _psi)
    listener: thinks[
        world: chooses(theta in Theta, wpp=array_index(prior, theta)),
        world: chooses(obs in Obs, wpp=get_obs_prob(obs, theta)),
        world: given(stype in Psi, wpp=array_index(psi_dist, stype)),
        speaker: knows(world.obs, world.stype),
        speaker: chooses(alp in Alpha, wpp=array_index(alpha_dist, alp)),
        speaker: chooses(
            u in U,
            wpp=is_utt_true_for_obs(u, world.obs) * exp(
                array_index(alpha_values, alp) * log(
                    (world.stype == 0) * {prev_vig_listener}[u, world.obs, _alpha, _psi](prior, l0_prior, alpha_dist, alpha_values, psi_dist, theta_values) +
                    (world.stype == 1) * {prev_vig_expected}[u, _theta, _alpha, _psi](prior, l0_prior, alpha_dist, alpha_values, psi_dist, theta_values) +
                    (world.stype == 2) * (1.0 - {prev_vig_expected}[u, _theta, _alpha, _psi](prior, l0_prior, alpha_dist, alpha_values, psi_dist, theta_values))
                    + 1e-10
                )
            )
        )
    ]
    listener: observes[speaker.u] is _u
    return listener[E[array_index(theta_values, world.theta)]]'''

            code_parts.append(vig_expected_code)

            # Generate S(n+1)_pers that models Ln_vig
            next_pers_speaker = f"S{level+1}_pers"

            pers_speaker_code = f'''@memo
def {next_pers_speaker}[_obs: Obs, _u: U, _alpha: Alpha, _psi: Psi, _theta: Theta]({pers_speaker_params}):
    """Level {level+1} persuasive speaker, modeling L{level}_vig"""
    speaker: knows(_obs, _u, _alpha, _psi, _theta)
    speaker: chooses(
        u in U,
        wpp=(
            # psi == 2.0: pers_down - minimize expected theta
            is_utt_true_for_obs(u, _obs) * exp(
                alpha * log(1.0 - {vig_expected_name}[u, _theta, _alpha, _psi](prior, l0_prior, alpha_dist, alpha_values, psi_dist, theta_values) + 1e-10)
            )
            if psi == 2.0
            else (
                # psi == 1.0: pers_up - maximize expected theta
                is_utt_true_for_obs(u, _obs) * exp(
                    alpha * log({vig_expected_name}[u, _theta, _alpha, _psi](prior, l0_prior, alpha_dist, alpha_values, psi_dist, theta_values) + 1e-10)
                )
                if psi == 1.0
                # psi == 0.0: inf - maximize informativeness about obs
                else is_utt_true_for_obs(u, _obs) * exp(
                    alpha * log({vig_listener_name}[u, _obs, _alpha, _psi](prior, l0_prior, alpha_dist, alpha_values, psi_dist, theta_values) + 1e-10)
                )
            )
        )
    )
    return Pr[speaker.u == _u]'''

            code_parts.append(pers_speaker_code)

    return "\n\n".join(code_parts)

# Build the source text for both agent families first, then display each
# under its own banner so it can be copied into a cell or file.
generated_inf_code = create_high_level_agents(max_level=5, agent_type='inf')
generated_vig_code = create_high_level_agents(max_level=3, agent_type='vig')  # Smaller max for demo

print("=== Generated INFORMED High-Level Agents (Copy and paste below) ===\n")
print(generated_inf_code)

print("\n\n" + "=" * 80 + "\n")
print("=== Generated VIGILANT High-Level Agents (Copy and paste below) ===\n")
print(generated_vig_code)

=== Generated INFORMED High-Level Agents (Copy and paste below) ===

@memo
def L2_cred_obs[_u: U, _obs: Obs](prior: ..., l0_prior: ..., alpha_dist: ..., alpha_vals: ...):
    """Level 2 listener inferring observation, modeling S2_inf"""
    listener: knows(_u, _obs)
    listener: thinks[
        world: chooses(theta in Theta, wpp=array_index(prior, theta)),
        world: chooses(obs in Obs, wpp=get_obs_prob(obs, theta)),
        speaker: knows(world.obs),
        speaker: chooses(alp in Alpha, wpp=array_index(alpha_dist, alp)),
        speaker: chooses(
            u in U,
            wpp=is_utt_true_for_obs(u, world.obs) * exp(array_index(alpha_vals, alp) * log(L1_cred_obs[u, world.obs](prior, l0_prior, alpha_dist, alpha_vals) + 1e-10))
        )
    ]
    listener: observes[speaker.u] is _u
    return listener[Pr[world.obs == _obs]]

@memo
def S3_inf[_obs: Obs, _u: U](prior: ..., l0_prior: ..., alpha, alpha_dist: ..., alpha_values: ...):
    """Level 3 informed speaker, modeling L2_

### 3. Testing the Validity of High-Level Agent Code

#### world setup

In [23]:
# World: n independent binomial experiments, each with m Bernoulli trials,
# and theta on an 11-point grid from 0 to 1.
worldJAX = WorldJAX(
    n=1,
    m=7,
    theta_values=jnp.linspace(0, 1, 11),
)

# Lookup tables read by the helper functions below.
utterance_truth_matrix = jnp.array(worldJAX.utterance_truth.values)
utterance_log_prob_obs = literal_semantics_uniform(jnp.array(worldJAX.utterance_truth.values))
obs_log_likelihood_theta = jnp.array(worldJAX.obs_log_likelihood_theta.values)

# --------------- Index domains used by the memo models ----------------
Theta = jnp.arange(len(worldJAX.theta_values))
U = jnp.arange(len(worldJAX.utterances))
Obs = jnp.arange(len(worldJAX.observations))

# Weights over the three Psi indices (passed to the models as `psi_dist`).
Psi_dist = jnp.array([1/3, 1/3, 1/3])
Psi = jnp.arange(len(Psi_dist))

# Actual alpha values; the memo models work with their indices 0, 1, 2.
alpha_values = jnp.array([1.0, 5.0, 10.0])
Alpha = jnp.arange(len(alpha_values))

uniform_prior = jnp.ones(len(Theta)) / len(Theta)

def is_utt_true_for_obs(u, obs):
    """Return whether utterance index `u` is true of observation index `obs`.

    Reads the module-level `utterance_truth_matrix`; an entry strictly
    greater than zero counts as true.
    """
    truth_entry = utterance_truth_matrix[u, obs]
    return truth_entry > 0
def get_obs_prob(obs, theta):
    """Return exp of the `obs_log_likelihood_theta[obs, theta]` table entry."""
    log_likelihood = obs_log_likelihood_theta[obs, theta]
    return jnp.exp(log_likelihood)
def get_utt_prob(u, obs):
    """Return exp of the `utterance_log_prob_obs[u, obs]` table entry."""
    log_prob = utterance_log_prob_obs[u, obs]
    return jnp.exp(log_prob)
def marginalize_L1_vig(L1_vig_memo):
    """Sum an array over its last two axes.

    Intended for the output of `L1_vig`, whose trailing axes are
    presumably alpha and psi — verify against the caller.
    """
    return jnp.sum(L1_vig_memo, axis=(-1, -2))

#### low levels

In [24]:
# Level 0
#。*
@memo
def S0[_obs: Obs, _u: U]():
    """Literal speaker: probability of utterance _u given observation _obs.

    Utterances are weighted only by `is_utt_true_for_obs`, so no listener
    model or rationality parameter is involved.
    """
    speaker: knows(_obs)
    speaker: chooses(u in U, wpp=is_utt_true_for_obs(u, _obs))
    return Pr[speaker.u == _u]

#。*
# `prior` is this listener's belief over theta (indexed by Theta).
@memo
def L0[_u: U, _theta: Theta](prior: ...):
    """Literal listener: posterior probability of theta index _theta after hearing _u.

    The listener's world model draws theta from `prior`, obs from
    `get_obs_prob(obs, theta)` and the utterance from `get_utt_prob(u, obs)`.
    """
    listener: knows(_u, _theta)
    listener: thinks[
        world: chooses(theta in Theta, wpp=array_index(prior, theta)),
        world: chooses(obs in Obs, wpp=get_obs_prob(obs, theta)),
        world: chooses(u in U, wpp=get_utt_prob(u, obs))
    ]
    # Condition the imagined world on the utterance actually heard.
    listener: observes [world.u] is _u
    return listener[Pr[world.theta == _theta]] 

# Level 1
# Appears to be unused elsewhere in this notebook; kept for now.
@memo
def S0_belief[_obs: Obs, _u: U, _theta: Theta](prior: ...):
    """Speaker's belief about theta after observing _obs.

    The speaker's world model draws theta from `prior` and obs from
    `get_obs_prob(obs, theta)`, then conditions on obs being _obs.
    """
    speaker: knows(_obs, _theta, _u)
    speaker: thinks[
        world: chooses(theta in Theta, wpp=array_index(prior, theta)),
        world: chooses(obs in Obs, wpp=get_obs_prob(obs, theta))
    ]
    speaker: observes [world.obs] is _obs
    speaker: chooses(u in U, wpp=is_utt_true_for_obs(u, _obs)) 
    return E[speaker[Pr[world.theta == _theta]]]
    # NOTE(review): a second `return` directly follows the first. In plain
    # Python it would be unreachable; confirm how memo treats it and which
    # of the two return values is actually intended.
    return Pr[speaker.u == _u]

# Used by S1_inf. `prior` is this listener's belief over theta.
@memo
def L0_obs[_u: U, _obs: Obs](prior: ...):
    """Literal listener: posterior probability of observation _obs after hearing _u.

    Same world model as `L0`, but the query is over the observation
    rather than theta.
    """
    listener: knows(_u, _obs)
    listener: thinks[
        world: chooses(theta in Theta, wpp=array_index(prior, theta)),
        world: chooses(obs in Obs, wpp=get_obs_prob(obs, theta)),
        world: chooses(u in U, wpp=get_utt_prob(u, obs))
    ]
    # Condition the imagined world on the utterance actually heard.
    listener: observes [world.u] is _u
    return listener[Pr[world.obs == _obs]]

#。*
# `prior` is the theta prior of the L0 listener this speaker imagines.
@memo
def S1_inf[_obs: Obs, _u: U](prior: ..., alpha):
    """Level-1 informative speaker: probability of utterance _u given _obs.

    Each true utterance is weighted by L0_obs[u, _obs](prior) raised to the
    power `alpha`; false utterances get weight zero.
    """
    speaker: knows(_obs, _u)
    speaker: chooses(
        u in U, 
        # 1e-10 keeps the log finite when the listener probability is 0.
        wpp=is_utt_true_for_obs(u, _obs) * exp(alpha * log(L0_obs[u, _obs](prior) + 1e-10))
    )
    return Pr[speaker.u == _u]

# Used by the S1 persuasive speakers. `prior` is this listener's belief over theta.
# `theta_values` maps theta indices to numeric values so an expectation can be
# taken; callers pass worldJAX.theta_values.
@memo
def L0_expected_theta[_u: U](prior: ..., theta_values: ...):
    """Literal listener: posterior expected value of theta after hearing _u.

    Same world model as `L0`; returns E[theta_values[theta]] under the
    posterior instead of a probability.
    """
    listener: knows(_u)
    listener: thinks[
        world: chooses(theta in Theta, wpp=array_index(prior, theta)),
        world: chooses(obs in Obs, wpp=get_obs_prob(obs, theta)),
        world: chooses(u in U, wpp=get_utt_prob(u, obs))
    ]
    # Condition the imagined world on the utterance actually heard.
    listener: observes [world.u] is _u
    return listener[E[array_index(theta_values, world.theta)]]


# `prior` is the theta prior of the L0 listener this speaker imagines.
@memo
def S1_pers_up[_obs: Obs, _u: U](prior: ..., alpha, theta_values: ...):
    """Level-1 persuade-up speaker: probability of utterance _u given _obs.

    Each true utterance is weighted by the L0 listener's expected theta
    raised to the power `alpha`, favouring utterances that raise it.
    """
    speaker: knows(_obs, _u)
    speaker: chooses(
        u in U,
        wpp=is_utt_true_for_obs(u, _obs) * exp(
            # 1e-10 keeps the log finite when the expected theta is 0.
            alpha * log(L0_expected_theta[u](prior, theta_values) + 1e-10)
        )
    )
    return Pr[speaker.u == _u]

# `prior` is the theta prior of the L0 listener this speaker imagines.
@memo  
def S1_pers_down[_obs: Obs, _u: U](prior: ..., alpha, theta_values: ...):
    """Level-1 persuade-down speaker: probability of utterance _u given _obs.

    Each true utterance is weighted by (1 - L0 expected theta) raised to the
    power `alpha`, favouring utterances that lower the listener's expectation.
    """
    speaker: knows(_obs, _u)
    speaker: chooses(
        u in U,
        wpp=is_utt_true_for_obs(u, _obs) * exp(
            # 1e-10 keeps the log finite when the expected theta is 1.
            alpha * log(1.0 - L0_expected_theta[u](prior, theta_values) + 1e-10)
        ) 
    )
    return Pr[speaker.u == _u]

#。*
# Combines pers_up, pers_down and inf into one model selected by `psi`.
# `prior` is the theta prior of the L0 listener this speaker imagines.
@memo  
def S1_pers[_obs: Obs, _u: U](prior: ..., alpha, theta_values: ..., psi):
    """Level-1 speaker whose goal is selected by `psi`.

    psi == 2.0 uses the persuade-down weight, psi == 1.0 the persuade-up
    weight, and any other value falls through to the informative weight
    based on L0_obs. False utterances always get weight zero.
    """
    speaker: knows(_obs, _u)
    speaker: chooses(
        u in U,
        wpp=(
            # psi == 2.0: pers_down
            is_utt_true_for_obs(u, _obs) * exp(alpha * log(1.0 - L0_expected_theta[u](prior, theta_values) + 1e-10))
            if psi == 2.0
            else (
                # psi == 1.0: pers_up
                is_utt_true_for_obs(u, _obs) * exp(alpha * log(L0_expected_theta[u](prior, theta_values) + 1e-10))
                if psi == 1.0
                # psi == 0.0: inf
                else is_utt_true_for_obs(u, _obs) * exp(alpha * log(L0_obs[u, _obs](prior) + 1e-10))
            )
        )
    )
    return Pr[speaker.u == _u]

#。*
# `prior` is this listener's belief over theta; `l0_prior` is the prior of the
# L0 listener imagined by the modeled speaker.
# ASSUMPTION (open question from the author): does this listener believe the
# speaker's imagined L0 shares L1's prior? The code keeps the two separate.
@memo
def L1_cred[_u: U, _theta: Theta, _alpha: Alpha](prior: ..., l0_prior: ..., alpha_dist: ..., alpha_vals: ...):
    """
    L1 listener who infers theta from utterance u,
    assuming speaker is S1_inf (informed speaker)

    The speaker's rationality is unknown to the listener: its index `alp` is
    drawn from `alpha_dist` and looked up in `alpha_vals`. Returns the joint
    posterior probability of (theta == _theta, alp == _alpha).
    """
    listener: knows(_u, _theta, _alpha)
    listener: thinks[
        world: chooses(theta in Theta, wpp=array_index(prior, theta)),
        world: chooses(obs in Obs, wpp=get_obs_prob(obs, theta)),
        speaker: knows(world.obs),
        speaker: chooses(alp in Alpha, wpp=array_index(alpha_dist, alp)),
        speaker: chooses(
            u in U,
            # Same weight as S1_inf, with alpha taken from alpha_vals[alp].
            wpp=is_utt_true_for_obs(u, world.obs) * exp(array_index(alpha_vals, alp) * log(L0_obs[u, world.obs](l0_prior) + 1e-10))
        )
    ]
    listener: observes[speaker.u] is _u
    return listener[Pr[world.theta == _theta, speaker.alp == _alpha]]

#。*
# `prior` is this listener's belief over theta; `l0_prior` is the prior of the
# L0 listener imagined by the modeled speaker.
# ASSUMPTION (open question from the author): does this listener believe the
# speaker's imagined L0 shares L1's prior? The code keeps the two separate.
@memo
def L1_vig[_u: U, _theta: Theta, _alpha: Alpha, _psi: Psi](prior: ..., l0_prior: ..., alpha_dist:..., alpha_values:..., psi_dist: ..., theta_values: ...):
    """Vigilant L1 listener: joint posterior over theta, speaker alpha and speaker type.

    The speaker's type `stype` is drawn from `psi_dist` and its rationality
    index `alp` from `alpha_dist`. Returns the posterior probability of
    (theta == _theta, alp == _alpha, stype == _psi) after hearing _u.
    """
    listener: knows(_u, _theta, _alpha, _psi)
    listener: thinks[
        world: chooses(theta in Theta, wpp=array_index(prior, theta)),
        world: chooses(obs in Obs, wpp=get_obs_prob(obs, theta)),
        world: given(stype in Psi, wpp=array_index(psi_dist, stype)),
        speaker: knows(world.obs, world.stype),
        speaker: chooses(alp in Alpha, wpp=array_index(alpha_dist, alp)),
        speaker: chooses(
            u in U,
            wpp=is_utt_true_for_obs(u, world.obs) * exp(
                array_index(alpha_values, alp) * log(
                    # Exactly one indicator is 1, selecting the utility for
                    # stype 0 (informative), 1 (persuade up) or 2 (persuade down).
                    (world.stype == 0) * L0_obs[u, world.obs](l0_prior) +
                    (world.stype == 1) * L0_expected_theta[u](l0_prior, theta_values) +
                    (world.stype == 2) * (1.0 - L0_expected_theta[u](l0_prior, theta_values))
                    + 1e-10
                )
            )
        )
    ]
    listener: observes [speaker.u] is _u
    return listener[Pr[world.theta == _theta, speaker.alp == _alpha, world.stype == _psi]]

# Level 2
# Used by S2_inf. `prior` is this L1's belief over theta; `l0_prior` is the
# prior of the L0 listener imagined by the modeled S1 speaker.
@memo
def L1_cred_obs[_u: U, _obs: Obs](prior: ..., l0_prior:..., alpha_dist: ..., alpha_vals: ...):
    """Credulous L1 listener: posterior probability of observation _obs after hearing _u.

    Same speaker model as `L1_cred` (informative speaker with alpha index
    drawn from `alpha_dist`), but the query is over the observation only.
    """
    listener: knows(_u, _obs)
    listener: thinks[
        world: chooses(theta in Theta, wpp=array_index(prior, theta)),
        world: chooses(obs in Obs, wpp=get_obs_prob(obs, theta)),
        speaker: knows(world.obs),
        speaker: chooses(alp in Alpha, wpp=array_index(alpha_dist, alp)),
        speaker: chooses(
            u in U,
            wpp=is_utt_true_for_obs(u, world.obs) * exp(array_index(alpha_vals, alp) * log(L0_obs[u, world.obs](l0_prior) + 1e-10))
        )
    ]
    listener: observes[speaker.u] is _u
    return listener[Pr[world.obs == _obs]]

#。*
# `prior` is the theta prior of the L1 listener this speaker imagines;
# `l0_prior`, `alpha_dist` and `alpha_values` are forwarded to that listener.
@memo
def S2_inf[_obs: Obs, _u: U](prior: ..., l0_prior:..., alpha,alpha_dist:..., alpha_values:..., ):
    """Level-2 informative speaker: probability of utterance _u given _obs.

    Each true utterance is weighted by L1_cred_obs[u, _obs] raised to the
    power `alpha`; false utterances get weight zero.
    """
    speaker: knows(_obs, _u)
    speaker: chooses(
        u in U, 
        # 1e-10 keeps the log finite when the listener probability is 0.
        wpp=is_utt_true_for_obs(u, _obs) * exp(alpha * log(L1_cred_obs[u, _obs](prior, l0_prior, alpha_dist, alpha_values) + 1e-10))
    )
    return Pr[speaker.u == _u]

# Used by S2_pers. `prior` is this L1's belief over theta; `l0_prior` is the
# prior of the L0 listener imagined by the modeled S1 speaker.
@memo
def L1_vig_obs[_u: U, _obs: Obs, _alpha: Alpha, _psi: Psi](prior: ..., l0_prior:..., alpha_dist:..., alpha_values:..., psi_dist: ..., theta_values: ...):
    """Vigilant L1 listener: posterior probability of observation _obs after hearing _u.

    Same speaker model as `L1_vig`. The _alpha and _psi axes are declared and
    known, but the returned expression does not reference them.
    """
    listener: knows(_u, _obs, _alpha, _psi)
    listener: thinks[
        world: chooses(theta in Theta, wpp=array_index(prior, theta)),
        world: chooses(obs in Obs, wpp=get_obs_prob(obs, theta)),
        world: given(stype in Psi, wpp=array_index(psi_dist, stype)),
        speaker: knows(world.obs, world.stype),
        speaker: chooses(alp in Alpha, wpp=array_index(alpha_dist, alp)),
        speaker: chooses(
            u in U,
            wpp=is_utt_true_for_obs(u, world.obs) * exp(
                array_index(alpha_values, alp) * log(
                    # Exactly one indicator is 1, selecting the utility for
                    # stype 0 (informative), 1 (persuade up) or 2 (persuade down).
                    (world.stype == 0) * L0_obs[u, world.obs](l0_prior) +
                    (world.stype == 1) * L0_expected_theta[u](l0_prior, theta_values) +
                    (world.stype == 2) * (1.0 - L0_expected_theta[u](l0_prior, theta_values))
                    + 1e-10
                )
            )
        )
    ]
    listener: observes [speaker.u] is _u
    return listener[Pr[world.obs == _obs]]

# Used by S2_pers. `prior` is this L1's belief over theta; `l0_prior` is the
# prior of the L0 listener imagined by the modeled S1 speaker.
@memo
def L1_vig_expected_theta[_u: U, _theta: Theta, _alpha: Alpha, _psi: Psi](prior: ..., l0_prior:..., alpha_dist:..., alpha_values:..., psi_dist: ..., theta_values: ...):
    """Vigilant L1 listener: posterior expected value of theta after hearing _u.

    Same speaker model as `L1_vig`. The _theta, _alpha and _psi axes are
    declared and known, but the returned expectation does not reference them.
    """
    listener: knows(_u, _theta, _alpha, _psi)
    listener: thinks[
        world: chooses(theta in Theta, wpp=array_index(prior, theta)),
        world: chooses(obs in Obs, wpp=get_obs_prob(obs, theta)),
        world: given(stype in Psi, wpp=array_index(psi_dist, stype)),
        speaker: knows(world.obs, world.stype),
        speaker: chooses(alp in Alpha, wpp=array_index(alpha_dist, alp)),
        speaker: chooses(
            u in U,
            wpp=is_utt_true_for_obs(u, world.obs) * exp(
                array_index(alpha_values, alp) * log(
                    # Exactly one indicator is 1, selecting the utility for
                    # stype 0 (informative), 1 (persuade up) or 2 (persuade down).
                    (world.stype == 0) * L0_obs[u, world.obs](l0_prior) +
                    (world.stype == 1) * L0_expected_theta[u](l0_prior, theta_values) +
                    (world.stype == 2) * (1.0 - L0_expected_theta[u](l0_prior, theta_values))
                    + 1e-10
                )
            )
        )
    ]
    listener: observes [speaker.u] is _u
    return listener[E[array_index(theta_values, world.theta)]]

# Level-2 persuasive speaker. `prior` is the theta prior of the vigilant L1
# listener this speaker imagines; `l0_prior` is forwarded to that listener.
@memo  
def S2_pers[_obs: Obs, _u: U, _alpha: Alpha, _psi: Psi, _theta: Theta](prior: ..., l0_prior:..., alpha, alpha_dist:..., alpha_values:..., theta_values: ..., psi, psi_dist: ...):
    """Level-2 speaker whose goal is selected by `psi`, reasoning about vigilant L1.

    psi == 2.0 weights utterances by (1 - L1 expected theta), psi == 1.0 by
    L1 expected theta, and any other value by L1_vig_obs for the true
    observation, each raised to the power `alpha`. The _alpha, _psi and
    _theta axes are only passed through as indices to the L1 models.
    """
    speaker: knows(_obs, _u, _alpha, _psi, _theta)
    speaker: chooses(
        u in U,
        wpp=(
            # psi == 2.0: pers_down - minimize L1_vig's expected theta
            is_utt_true_for_obs(u, _obs) * exp(
                alpha * log(1.0 - L1_vig_expected_theta[u,_theta, _alpha, _psi](prior, l0_prior, alpha_dist, alpha_values, psi_dist, theta_values) + 1e-10)
            )
            if psi == 2.0
            else (
                # psi == 1.0: pers_up - maximize L1_vig's expected theta
                is_utt_true_for_obs(u, _obs) * exp(
                    alpha * log(L1_vig_expected_theta[u, _theta, _alpha, _psi](prior, l0_prior, alpha_dist, alpha_values, psi_dist, theta_values) + 1e-10)
                )
                if psi == 1.0
                # psi == 0.0: inf - maximize informativeness about obs for L1_vig
                else is_utt_true_for_obs(u, _obs) * exp(
                    alpha * log(L1_vig_obs[u, _obs, _alpha, _psi](prior,l0_prior, alpha_dist, alpha_values, psi_dist, theta_values) + 1e-10)
                )
            )
        )
    )
    return Pr[speaker.u == _u]

#### updated high levels