# Fit model for earlier-later experiment

In [1]:
import numpy as np
import pandas as pd

from memo import memo
import jax
import jax.numpy as jnp
from jax import lax
from enum import IntEnum

Make an array from the transparency means, with columns corresponding to referential informativity and length, and rows corresponding to the index of the label (indexed by `tangram_set` and then `target`, `earlier`, and `length`) (see `transparency-analysis.Rmd`)

In [2]:
df_transparency = pd.read_csv("../data/3pp/transparency/means_by_label.csv").rename(columns={"target": "tangram"})
df_transparency.iloc[:, :5].to_csv("label_indices.csv")

df_transparency.head()

Unnamed: 0,tangram_set,tangram,earlier,length,label,n,empirical_stat,ci_lower,mean,ci_upper
0,0,A,earlier,27,it looks like someone stood up with their feet...,42,0.952381,0.878788,0.951934,1.0
1,0,A,earlier,41,a guy facing left with his head to the right o...,45,0.6,0.462875,0.598966,0.734779
2,0,A,earlier,43,it looks like a person kicking one foot in fro...,43,0.930233,0.847826,0.928347,1.0
3,0,A,later,1,thriller,35,0.571429,0.416616,0.573758,0.741977
4,0,A,later,1,zombie,36,0.638889,0.473648,0.639046,0.8


In [3]:
# Save 'length' and 'empirical_stat' columns into a 90x2 numpy array
label_info = df_transparency[["length", "empirical_stat"]].values
label_info.shape

(90, 2)

Load label info from earlier-later experiment

In [4]:
labels = pd.read_csv("../data/3pp/earlier-later/labels.csv").rename(
    columns={"item_id": "tangram_set"}
).drop(columns=["n"])
labels.head()

Unnamed: 0,tangram_set,counterbalance,tangram,shared,audience_group,earlier_label,later_label
0,0,a,A,unique,blue,it looks like someone stood up with their feet...,zombie
1,0,a,A,unique,red,a guy facing left with his head to the right o...,thriller
2,0,a,B,shared,blue,"two big ears, pointing nose, facing left. A sm...",mouse
3,0,a,B,shared,red,the head is on a flat surface with a triangle ...,mouse
4,0,a,C,shared,blue,it looks like a person is falling backwards wi...,falling guy


Add length and informativity

In [5]:
for label_type in ["earlier", "later"]:
    labels = (
        labels.merge(
            df_transparency[df_transparency["earlier"] == label_type][
                ["tangram_set", "tangram", "label", "empirical_stat", "length"]
            ],
            left_on=["tangram_set", "tangram", f"{label_type}_label"],
            right_on=["tangram_set", "tangram", "label"],
        )
        .rename(
            columns={
                "empirical_stat": f"{label_type}_info",
                "length": f"{label_type}_length",
            }
        )
        .drop(columns=["label"])
    )

In [6]:
labels

Unnamed: 0,tangram_set,counterbalance,tangram,shared,audience_group,earlier_label,later_label,earlier_info,earlier_length,later_info,later_length
0,0,a,A,unique,blue,it looks like someone stood up with their feet...,zombie,0.952381,27,0.638889,1
1,0,a,A,unique,red,a guy facing left with his head to the right o...,thriller,0.600000,41,0.571429,1
2,0,a,B,shared,blue,"two big ears, pointing nose, facing left. A sm...",mouse,0.827586,28,0.568182,1
3,0,a,B,shared,red,the head is on a flat surface with a triangle ...,mouse,0.615385,28,0.568182,1
4,0,a,C,shared,blue,it looks like a person is falling backwards wi...,falling guy,0.956522,32,0.948718,2
...,...,...,...,...,...,...,...,...,...,...,...
67,2,b,G,shared,red,It looks like a person looking up and to the r...,leg,0.650000,79,0.181818,1
68,2,b,I,shared,blue,picture looks like it has 2 hands going forwar...,ice skater,0.942857,20,0.911111,2
69,2,b,I,shared,red,"ice skating one, square head balancing, two tr...",ice skater,1.000000,13,0.911111,2
70,2,b,K,unique,blue,it's a diamond on a diamond with a triangle ou...,duck,0.875000,25,0.348837,1


Extract corresponding 

Condition x tangram x group for each item, counterbalance pair


Functions for getting referential informativity, social informativity, and utterance length, given a label

- Assign an index to each label

In [7]:
# tangram_set x counterbalance x tangram x audience_group x earlier/later
info_mtx = np.zeros((3, 2, 12, 2, 2))
length_mtx = np.zeros((3, 2, 12, 2, 2))

counterbalance_to_idx = {"a": 0, "b": 1}
audience_group_to_idx = {"red": 0, "blue": 1}
tangram_to_idx = {
    "A": 0,
    "B": 1,
    "C": 2,
    "D": 3,
    "E": 4,
    "F": 5,
    "G": 6,
    "H": 7,
    "I": 8,
    "J": 9,
    "K": 10,
    "L": 11,
}

for tangram_set in [0, 1, 2]:
    for counterbalance in ["a", "b"]:
        filtered_labels = labels[
            (labels["tangram_set"] == tangram_set)
            & (labels["counterbalance"] == counterbalance)
        ]

        for _, row in filtered_labels.iterrows():
            info_mtx[
                tangram_set,
                counterbalance_to_idx[counterbalance],
                tangram_to_idx[row["tangram"]],
                audience_group_to_idx[row["audience_group"]],
                0,
            ] = row["earlier_info"]
            info_mtx[
                tangram_set,
                counterbalance_to_idx[counterbalance],
                tangram_to_idx[row["tangram"]],
                audience_group_to_idx[row["audience_group"]],
                1,
            ] = row["later_info"]

            length_mtx[
                tangram_set,
                counterbalance_to_idx[counterbalance],
                tangram_to_idx[row["tangram"]],
                audience_group_to_idx[row["audience_group"]],
                0,
            ] = row["earlier_length"]
            length_mtx[
                tangram_set,
                counterbalance_to_idx[counterbalance],
                tangram_to_idx[row["tangram"]],
                audience_group_to_idx[row["audience_group"]],
                1,
            ] = row["later_length"]

In [8]:
# turn both into jax arrays
info_mtx = jnp.array(info_mtx)
length_mtx = jnp.array(length_mtx)

## Model

In [9]:
class TangramSet(IntEnum):
    Set1 = 0
    Set2 = 1
    Set3 = 2


class Counterbalance(IntEnum):
    A = 0
    B = 1


class Tangram(IntEnum):
    A = 0
    B = 1
    C = 2
    D = 3
    E = 4
    F = 5
    G = 6
    H = 7
    I = 8
    J = 9
    K = 10
    L = 11


class AudienceConditions(IntEnum):
    EitherGroup = 0
    OneGroup = 1


class TangramTypes(IntEnum):
    Shared = 0
    Unique = 1


class AudienceGroup(IntEnum):
    Red = 0
    Blue = 1


class Audiences(IntEnum):
    Outgroup = 0
    Ingroup = 1

class Utterances(IntEnum):
    Earlier = 0
    Later = 1


@jax.jit
def audience_wpp(audience_condition, audience):
    # for the "either group" condition, return 1 regardless of audience
    # for the "one group" condition, return 1 for the ingroup and 0 for the outgroup
    # TODO: check this logic
    return jnp.array([1, audience])[audience_condition]

@jax.jit
def ref_info(
    tangram_set, counterbalance, tangram, audience_group, 
    tangram_type, audience, utt
):
    """the values that change are tangram_type, audience, and utt"""
    # if the tangram is unique-label, outgroup info is the same as naive listener
    naive_info_unique = info_mtx[
        tangram_set, counterbalance, tangram, audience_group, :
    ]
    # print(naive_info_unique)

    # for the ingroup, increase earlier_utt info by 0.1 and later_utt info by 0.3; cap at 1
    # this is the assumption that the later utterances are probably seen more times by the in group
    ingroup_info = jnp.minimum(naive_info_unique + jnp.array([0.1, 0.3]), 1)

    # if the tangram is shared-label, for the outgroup, earlier utt info is same as naive listener but later utt info is same as ingroup
    naive_info_shared = jnp.array([naive_info_unique[0], ingroup_info[1]])

    naive_info = lax.cond(
        tangram_type == TangramTypes.Shared,
        lambda _: naive_info_shared,
        lambda _: naive_info_unique,
        operand=None,
    )

    info = lax.cond(
        audience == Audiences.Outgroup,
        lambda _: naive_info,
        lambda _: ingroup_info,
        operand=None,
    )

    return info[utt]


@jax.jit
def social_info(utt):
    # later utterance has more social informativity
    return jnp.array([0.1, 0.6])[utt]


@jax.jit
def utt_length(tangram_set, counterbalance, tangram, audience_group, utt):
    return length_mtx[tangram_set, counterbalance, tangram, audience_group, utt]

In [10]:
ref_info(TangramSet.Set1, Counterbalance.A, Tangram.A, AudienceGroup.Red, TangramTypes.Shared, Audiences.Ingroup, Utterances.Earlier)
social_info(Utterances.Earlier)
utt_length(TangramSet.Set1, Counterbalance.A, Tangram.A, AudienceGroup.Red, Utterances.Earlier)

Array(41., dtype=float32)

In [11]:
audience_wpp(AudienceConditions.EitherGroup, Audiences.Ingroup)

Array(1, dtype=int32)

In [12]:
# type: ignore

# the idea is that the goal is in the weights so each goal is fit separately? 
# TODO: figure out a way to organize the params, and document

@memo
def speaker[
    u: Utterances, audience: Audiences
](
    #### experimental design stuff
    tangram_set,
    counterbalance,
    tangram,
    audience_group,
    #### variables
    tangram_type,
    audience_condition,
    #### parameters
    alpha,
    w_r,
    w_s,
    w_c,
):
    cast: [speaker]
    speaker: chooses(
        audience in Audiences, wpp=audience_wpp(audience_condition, audience)
    )
    speaker: chooses(
        u in Utterances,
        wpp=exp(
            alpha
            * (
                w_r * ref_info(tangram_set, counterbalance, tangram, audience_group, tangram_type, audience, u)
                + w_s * social_info(u)
                - w_c * utt_length(tangram_set, counterbalance, tangram, audience_group, u)
            )
        ),
    )
    return Pr[speaker.u == u]

In [13]:
speaker(
    tangram_set=TangramSet.Set1,
    counterbalance=Counterbalance.A,
    tangram=Tangram.A,
    audience_group=AudienceGroup.Red,
    tangram_type=TangramTypes.Unique,
    audience_condition=AudienceConditions.OneGroup,
    alpha=2,
    w_r=5,
    w_s=0,
    w_c=0.01,
)

Array([[0.07486279],
       [0.9251372 ]], dtype=float32)

In [14]:
speaker(
    tangram_set=TangramSet.Set1,
    counterbalance=Counterbalance.A,
    tangram=Tangram.A,
    audience_group=AudienceGroup.Red,
    tangram_type=TangramTypes.Unique,
    audience_condition=AudienceConditions.EitherGroup,
    alpha=2,
    w_r=5,
    w_s=0,
    w_c=0.01,
)

Array([[0.22452605],
       [0.77547395]], dtype=float32)

In [15]:
info_mtx[TangramSet.Set1, Counterbalance.A, Tangram.A, AudienceGroup.Red, :]

Array([0.6      , 0.5714286], dtype=float32)

FIT BY GOAL (refer vs. social) NOT condition

Instead of weird dicts use intenum from the beginning