# Choosing between earlier and later expressions, given two groups' observed interactions

December 2024

In [1]:
import numpy as np
import pandas as pd

from memo import memo
import jax
import jax.numpy as jnp
from jax import lax
from enum import IntEnum

## 1. Generate model predictions with binary utilities based on design

What are the utilities involved? 

- Referential informativity
    - 1 for all earlier utterances regardless of group; 1 for later utterances to in-group; 0 for later utterances to out-group if 'group-specific label' tangram, 1 if 'shared label' tangram
    - so that means, for later utterances, 0.5 to 'either' group for the 'group-specific' labels but 1 to 'either' group for the 'shared' labels
- Social informativity: 0 for all earlier utterances, 1 for all later utterances
- Utterance cost: 1 for all earlier utterances, 0 for all later utterances

In [2]:
# type: ignore
class AudienceConditions(IntEnum):
    EitherGroup = 0
    OneGroup = 1


class Audiences(IntEnum):
    Ingroup = 1
    Outgroup = 0


class TangramTypes(IntEnum):
    Shared = 0
    GroupSpecific = 1


class Utterances(IntEnum):
    Earlier = 0
    Later = 1


@jax.jit
def audience_wpp(audience_condition, audience):
    # for the "either group" condition, return 1 regardless of audience
    # for the "one group" condition, return 1 for the ingroup and 0 for the outgroup
    return jnp.array([1, audience])[audience_condition]


@jax.jit
def ref_info(audience, tangram_type, utterance):
    ingroup_info = jnp.array([1, 1])  # [earlier, later]
    outgroup_info = lax.cond(
        tangram_type == TangramTypes.Shared,
        lambda _: jnp.array(
            [1, 1]
        ),  # if it's a shared-label tangram, the later utterance is informative regardless of group
        lambda _: jnp.array([1, 0]),
        operand=None,
    )

    info = lax.cond(
        audience == Audiences.Ingroup,
        lambda _: ingroup_info,
        lambda _: outgroup_info,
        operand=None,
    )

    return info[utterance]


@jax.jit
def social_info(utterance):
    return jnp.array([0, 1])[utterance]


@jax.jit
def cost(utterance):
    return jnp.array([1, 0])[utterance]


@memo
def speaker[
    utterance: Utterances, audience: Audiences
](
    audience_condition: AudienceConditions,
    tangram_type: TangramTypes,
    alpha,
    w_r,
    w_s,
    w_c,
):
    cast: [speaker]
    speaker: chooses(
        audience in Audiences, wpp=audience_wpp(audience_condition, audience)
    )
    speaker: chooses(
        utterance in Utterances,
        wpp=exp(
            alpha
            * (
                w_r * ref_info(audience, tangram_type, utterance)
                + w_s * social_info(utterance)
                - w_c * cost(utterance)
            )
        ),
    )
    return Pr[speaker.utterance == utterance]

## Generate model predictions and fit weights

In [3]:
class Conditions(IntEnum):
    ReferEither = 0
    ReferOne = 1
    SocialOne = 2

Generate model predictions for all conditions (FYI for best fit we might need a `w_s` (i.e. don't fix it to 0) for the 'refer one' condition)

In [4]:
@jax.jit
def get_model_preds(alpha, w_r, w_s, w_c):
    """
    Get the model predictions for all conditions and tangram types
    Output: 3x2 (condition, tangram type) array of probability of choosing the later utterance
    """
    conditions_vals = jnp.array(
        [Conditions.ReferEither, Conditions.ReferOne, Conditions.SocialOne]
    )
    tangrams_vals = jnp.array([TangramTypes.Shared, TangramTypes.GroupSpecific])

    def single_pred(cond, ttype):
        return lax.cond(
            cond == Conditions.ReferEither,
            lambda _: speaker(
                AudienceConditions.EitherGroup, ttype, alpha, w_r, 0, w_c
            ),
            lambda _: lax.cond(
                cond == Conditions.ReferOne,
                lambda __: speaker(
                    AudienceConditions.OneGroup, ttype, alpha, w_r, 0, w_c
                ),
                lambda __: speaker(
                    AudienceConditions.OneGroup, ttype, alpha, w_r, w_s, w_c
                ),
                operand=None,
            ),
            operand=None,
        )[Utterances.Later, 0]

    # Vectorize over tangrams first, then over conditions
    vmap_tangrams = jax.vmap(
        lambda c: jax.vmap(lambda t: single_pred(c, t))(tangrams_vals)
    )
    return vmap_tangrams(conditions_vals)

In [5]:
get_model_preds(1, 1, 1, 0.1)

Array([[0.5249792 , 0.40701485],
       [0.5249792 , 0.5249792 ],
       [0.7502601 , 0.7502601 ]], dtype=float32)

## Load data and get it into a matrix format

Data array, for each participant: tangram (A through L) x audience group (red vs. blue) x condition ('refer either' vs. 'refer one' vs. 'social one'). 

For each participant, only half of the tangrams have values because each participant only sees 6 tangrams. There are 6 * 2 * 3 = 36 critical trials per participant. 

In [6]:
class Tangram(IntEnum):
    A = 0
    B = 1
    C = 2
    D = 3
    E = 4
    F = 5
    G = 6
    H = 7
    I = 8
    J = 9
    K = 10
    L = 11

class AudienceGroup(IntEnum):
    Red = 0
    Blue = 1

class Counterbalance(IntEnum):
    a = 0
    b = 1

In [7]:
data = pd.read_csv("../data/3pp/earlier-later/selection_trials_clean.csv").rename(
    columns={"item_id": "tangram_set", 
             "shared": "tangram_type"}
)
data.loc[data["condition"] == "either refer", "condition"] = "ReferEither"
data.loc[data["condition"] == "one refer", "condition"] = "ReferOne"
data.loc[data["condition"] == "one social", "condition"] = "SocialOne"

data.loc[data["response.earlier"] == "earlier", "response.earlier"] = 1
data.loc[data["response.earlier"] == "later", "response.earlier"] = 0

data["response.later"] = 1 - data["response.earlier"]

data.loc[data["tangram_type"] == "shared", "tangram_type"] = "Shared"
data.loc[data["tangram_type"] == "unique", "tangram_type"] = "GroupSpecific"

data.loc[data["audience_group"] == "red", "audience_group"] = "Red"
data.loc[data["audience_group"] == "blue", "audience_group"] = "Blue"

data.head()

Unnamed: 0,subject_id,tangram_set,counterbalance,condition,audience_group,tangram,tangram_type,response.earlier,response.later
0,dd2c8d9a-acab-46b0-a2d6-58e7c914ac4e,0,a,SocialOne,Red,C,Shared,0,1
1,dd2c8d9a-acab-46b0-a2d6-58e7c914ac4e,0,a,ReferEither,Blue,B,Shared,0,1
2,dd2c8d9a-acab-46b0-a2d6-58e7c914ac4e,0,a,SocialOne,Blue,L,Shared,0,1
3,dd2c8d9a-acab-46b0-a2d6-58e7c914ac4e,0,a,ReferOne,Blue,L,Shared,0,1
4,dd2c8d9a-acab-46b0-a2d6-58e7c914ac4e,0,a,ReferOne,Red,H,GroupSpecific,1,0


Save what tangrams are shared and what tangrams are group-specific

tangram set x counterbalance x tangram 

In [8]:
is_group_specific = np.zeros((3, 2, 12)) # tangram_set, counterbalance, tangram
# group by tangram_set, counterbalance, tangram
tangram_info = data.groupby(["tangram_set", "counterbalance", "tangram_type", "tangram"]).size()
tangram_info = tangram_info.reset_index(name="count")

for _, row in tangram_info.iterrows():
    tangram_set = row["tangram_set"]
    counterbalance = row["counterbalance"]
    tangram = row["tangram"]
    tangram_type = row["tangram_type"]
    is_group_specific[tangram_set, Counterbalance[counterbalance], Tangram[tangram]] = tangram_type == "GroupSpecific"

In [9]:
tangram_info

Unnamed: 0,tangram_set,counterbalance,tangram_type,tangram,count
0,0,a,GroupSpecific,A,48
1,0,a,GroupSpecific,D,48
2,0,a,GroupSpecific,H,48
3,0,a,Shared,B,48
4,0,a,Shared,C,48
5,0,a,Shared,L,48
6,0,b,GroupSpecific,B,54
7,0,b,GroupSpecific,C,54
8,0,b,GroupSpecific,L,54
9,0,b,Shared,A,54


In [10]:
data_organized = {}
for tangram_set in [0, 1, 2]:
    for counterbalance in ["a", "b"]:
        filtered_data = data[
            (data["tangram_set"] == tangram_set)
            & (data["counterbalance"] == counterbalance)
        ]
        filtered_data_participants = filtered_data["subject_id"].unique()

        this_set_mtx = np.zeros(
            (12, 2, len(Conditions), len(filtered_data_participants))
        )
        for i, participant in enumerate(filtered_data_participants):
            this_data = filtered_data[filtered_data["subject_id"] == participant]
            this_participant_mtx = np.zeros((12, 2, len(Conditions)))
            for _, row in this_data.iterrows():
                this_participant_mtx[
                    Tangram[row["tangram"]],
                    AudienceGroup[row["audience_group"]],
                    Conditions[row["condition"]],
                ] = row["response.earlier"]
                this_set_mtx[:, :, :, i] = this_participant_mtx

        data_organized[(tangram_set, counterbalance)] = this_set_mtx

In [11]:
for key, mtx in data_organized.items():
    print(key, mtx.shape)
# tangram x audience x condition x participant

(0, 'a') (12, 2, 3, 8)
(0, 'b') (12, 2, 3, 9)
(1, 'a') (12, 2, 3, 11)
(1, 'b') (12, 2, 3, 13)
(2, 'a') (12, 2, 3, 11)
(2, 'b') (12, 2, 3, 9)


Iterate through sets and counterbalance, and create model predictions for each

In [43]:
def format_model_preds(model_preds, data_organized):
    """model_preds and data_organized are in the same format
    needs tangram_info, which is information about what tangrams belong to which set and counterbalance etc. (rename later)
    this can probably be sped up
    """
    all_preds = {}
    for key, mtx in data_organized.items():
        tangram_set, counterbalance = key
        preds = np.zeros((12, 2, len(Conditions)))
        available = tangram_info[(tangram_info["tangram_set"] == tangram_set) & (tangram_info["counterbalance"] == counterbalance)]
        for _, row in available.iterrows():
            for condition in Conditions: 
                preds[Tangram[row["tangram"]], :, condition] = model_preds[condition, TangramTypes[row["tangram_type"]]]

        # repeat by number of participants
        all_preds[key] = np.repeat(preds[:, :, :, np.newaxis], mtx.shape[-1], axis=-1)
    return all_preds


In [44]:
model_organized = format_model_preds(get_model_preds(1, 1, 1, 0.1), data_organized)

Stack all the matrices

In [45]:
data_all = np.concatenate([mtx for mtx in data_organized.values()], axis=-1)
model_all = np.concatenate([mtx for mtx in model_organized.values()], axis=-1)
assert data_all.shape == model_all.shape

Compute NLL and optimize

In [46]:
def compute_nll(data, model):
    mask = model != 0
    return -np.sum(data[mask] * np.log(model[mask]) + (1 - data[mask]) * np.log(1 - model[mask]))

In [47]:
compute_nll(data_all, model_all)

1820.6600212938047

### Optimization

In [18]:
# Make grid of parameter values
alphas = jnp.arange(0, 5, 0.1)
w_rs = jnp.arange(0, 5, 0.1)
w_ss = jnp.arange(0, 5, 0.1)
w_cs = jnp.arange(0, 1, 0.01)

# make grid of all values
param_grid = jnp.meshgrid(alphas, w_rs, w_ss, w_cs)

In [22]:
def grid_search_nll():
    # Flatten the parameter grid
    all_alphas = param_grid[0].ravel()
    all_wrs    = param_grid[1].ravel()
    all_wss    = param_grid[2].ravel()
    all_wcs    = param_grid[3].ravel()

    def single_nll(params):
        alpha, w_r, w_s, w_c = params
        preds = get_model_preds(alpha, w_r, w_s, w_c)
        print(preds)
        mo    = format_model_preds(preds, data_organized)
        mo_all = jnp.concatenate(list(mo.values()), axis=-1)
        return compute_nll(data_all, mo_all)

    # Vectorize NLL computation
    params_list = jnp.stack([all_alphas, all_wrs, all_wss, all_wcs], axis=1)
    nll_values  = jax.vmap(single_nll)(params_list)

    # Find best NLL
    best_idx   = jnp.argmin(nll_values)
    best_nll   = nll_values[best_idx]
    best_params = params_list[best_idx]
    return best_params, best_nll

best_params, best_nll_val = grid_search_nll()
print("Best NLL:", best_nll_val)
print("Best params (alpha, w_r, w_s, w_c):", best_params)

Traced<ShapedArray(float32[3,2])>with<BatchTrace(level=1/0)> with
  val = Array([[[0.5       , 0.5       ],
        [0.5       , 0.5       ],
        [0.5       , 0.5       ]],

       [[0.5       , 0.5       ],
        [0.5       , 0.5       ],
        [0.5       , 0.5       ]],

       [[0.5       , 0.5       ],
        [0.5       , 0.5       ],
        [0.5       , 0.5       ]],

       ...,

       [[0.991448  , 0.495724  ],
        [0.991448  , 0.991448  ],
        [1.        , 1.        ]],

       [[0.9918536 , 0.4959268 ],
        [0.9918536 , 0.9918536 ],
        [1.        , 1.        ]],

       [[0.99224013, 0.49612007],
        [0.99224013, 0.99224013],
        [1.        , 1.        ]]], dtype=float32)
  batch_dim = 0


TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[]
This BatchTracer with object id 13800616400 was created on line:
  /var/folders/fc/814yph2s11jgjvdqqz2x9w0c0000gn/T/ipykernel_32003/1366208535.py:13:63 (model_mtx)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

In [19]:
param_grid

[Array([[[[0.       , 0.       , 0.       , ..., 0.       , 0.       ,
           0.       ],
          [0.       , 0.       , 0.       , ..., 0.       , 0.       ,
           0.       ],
          [0.       , 0.       , 0.       , ..., 0.       , 0.       ,
           0.       ],
          ...,
          [0.       , 0.       , 0.       , ..., 0.       , 0.       ,
           0.       ],
          [0.       , 0.       , 0.       , ..., 0.       , 0.       ,
           0.       ],
          [0.       , 0.       , 0.       , ..., 0.       , 0.       ,
           0.       ]],
 
         [[0.1      , 0.1      , 0.1      , ..., 0.1      , 0.1      ,
           0.1      ],
          [0.1      , 0.1      , 0.1      , ..., 0.1      , 0.1      ,
           0.1      ],
          [0.1      , 0.1      , 0.1      , ..., 0.1      , 0.1      ,
           0.1      ],
          ...,
          [0.1      , 0.1      , 0.1      , ..., 0.1      , 0.1      ,
           0.1      ],
          [0.1      , 0.1  