# Fit model for earlier-later experiment

In [1]:
import numpy as np
import pandas as pd

from memo import memo
import jax
import jax.numpy as jnp
from jax import lax
from enum import IntEnum

Make an array from the transparency means, with columns corresponding to referential informativity and length, and rows corresponding to the index of the label (indexed by `tangram_set` and then `target`, `earlier`, and `length`) (see `transparency-analysis.Rmd`)

In [2]:
# Per-label transparency means. The "target" column is renamed to "tangram" so
# it lines up with the column naming used by the other tables in this notebook.
df_transparency = (
    pd.read_csv("../data/3pp/transparency/means_by_label.csv")
    .rename(columns={"target": "tangram"})
)
# Write the first five columns out as a label-index lookup file.
df_transparency.iloc[:, :5].to_csv("label_indices.csv")

df_transparency.head()

Unnamed: 0,tangram_set,tangram,earlier,length,label,n,empirical_stat,ci_lower,mean,ci_upper
0,0,A,earlier,27,it looks like someone stood up with their feet...,42,0.952381,0.878788,0.951934,1.0
1,0,A,earlier,41,a guy facing left with his head to the right o...,45,0.6,0.462875,0.598966,0.734779
2,0,A,earlier,43,it looks like a person kicking one foot in fro...,43,0.930233,0.847826,0.928347,1.0
3,0,A,later,1,thriller,35,0.571429,0.416616,0.573758,0.741977
4,0,A,later,1,zombie,36,0.638889,0.473648,0.639046,0.8


In [3]:
# Keep only the label length and its empirical informativity as a plain NumPy
# array: one row per label, columns = [length, empirical_stat].
label_info = df_transparency.loc[:, ["length", "empirical_stat"]].to_numpy()
label_info.shape

(90, 2)

Load label info from earlier-later experiment

In [4]:
# Labels shown in the earlier-later experiment; "item_id" is renamed so the
# set identifier has the same name as in the transparency table.
labels = (
    pd.read_csv("../data/3pp/earlier-later/labels.csv")
    .rename(columns={"item_id": "tangram_set"})
    .drop(columns="n")
)
labels.head()

Unnamed: 0,tangram_set,counterbalance,tangram,shared,audience_group,earlier_label,later_label
0,0,a,A,unique,blue,it looks like someone stood up with their feet...,zombie
1,0,a,A,unique,red,a guy facing left with his head to the right o...,thriller
2,0,a,B,shared,blue,"two big ears, pointing nose, facing left. A sm...",mouse
3,0,a,B,shared,red,the head is on a flat surface with a triangle ...,mouse
4,0,a,C,shared,blue,it looks like a person is falling backwards wi...,falling guy


Add length and informativity

In [5]:
# Attach each label's referential informativity (`empirical_stat`) and length
# from the transparency table: once for the earlier label, once for the later.
merge_keys = ["tangram_set", "tangram"]
for label_type in ["earlier", "later"]:
    n_rows_before = len(labels)
    transparency_subset = df_transparency.loc[
        df_transparency["earlier"] == label_type,
        [*merge_keys, "label", "empirical_stat", "length"],
    ]
    labels = (
        labels.merge(
            transparency_subset,
            left_on=[*merge_keys, f"{label_type}_label"],
            right_on=[*merge_keys, "label"],
        )
        .rename(
            columns={
                "empirical_stat": f"{label_type}_info",
                "length": f"{label_type}_length",
            }
        )
        .drop(columns=["label"])
    )
    # The merge is an inner join: a label with no exact text match in the
    # transparency table is dropped, and a label matching several transparency
    # rows is duplicated. Both would otherwise go unnoticed, so report them.
    if len(labels) != n_rows_before:
        print(
            f"WARNING: merging {label_type} label info changed the row count "
            f"from {n_rows_before} to {len(labels)}; check for unmatched or "
            "duplicated labels."
        )

In [6]:
# Display the merged label table (pandas truncates the middle rows).
labels

Unnamed: 0,tangram_set,counterbalance,tangram,shared,audience_group,earlier_label,later_label,earlier_info,earlier_length,later_info,later_length
0,0,a,A,unique,blue,it looks like someone stood up with their feet...,zombie,0.952381,27,0.638889,1
1,0,a,A,unique,red,a guy facing left with his head to the right o...,thriller,0.600000,41,0.571429,1
2,0,a,B,shared,blue,"two big ears, pointing nose, facing left. A sm...",mouse,0.827586,28,0.568182,1
3,0,a,B,shared,red,the head is on a flat surface with a triangle ...,mouse,0.615385,28,0.568182,1
4,0,a,C,shared,blue,it looks like a person is falling backwards wi...,falling guy,0.956522,32,0.948718,2
...,...,...,...,...,...,...,...,...,...,...,...
67,2,b,G,shared,red,It looks like a person looking up and to the r...,leg,0.650000,79,0.181818,1
68,2,b,I,shared,blue,picture looks like it has 2 hands going forwar...,ice skater,0.942857,20,0.911111,2
69,2,b,I,shared,red,"ice skating one, square head balancing, two tr...",ice skater,1.000000,13,0.911111,2
70,2,b,K,unique,blue,it's a diamond on a diamond with a triangle ou...,duck,0.875000,25,0.348837,1


Extract corresponding 

Condition x tangram x group for each item, counterbalance pair


Functions for getting referential informativity, social informativity, and utterance length, given a label

- Assign an index to each label

In [7]:
# Dense lookup tables filled from `labels`.
# info_mtx / length_mtx axes:
#   tangram_set x counterbalance x tangram x audience_group x earlier/later
info_mtx = np.zeros((3, 2, 12, 2, 2))
length_mtx = np.zeros((3, 2, 12, 2, 2))

# tangram_mtx: 1 where a tangram appears in a given tangram set + counterbalance.
tangram_mtx = np.zeros((3, 2, 12))
# shared_unique_mtx: 0 for a shared-label tangram, 1 for a unique-label tangram
# (entries for tangrams that never appear also stay 0).
shared_unique_mtx = np.zeros((3, 2, 12))

counterbalance_to_idx = {"a": 0, "b": 1}
audience_group_to_idx = {"red": 0, "blue": 1}
tangram_to_idx = {letter: position for position, letter in enumerate("ABCDEFGHIJKL")}

for tangram_set in [0, 1, 2]:
    for counterbalance in ["a", "b"]:
        in_set = labels["tangram_set"] == tangram_set
        in_counterbalance = labels["counterbalance"] == counterbalance
        filtered_labels = labels[in_set & in_counterbalance]
        counterbalance_idx = counterbalance_to_idx[counterbalance]

        for _, row in filtered_labels.iterrows():
            # Index prefixes shared by every write for this row.
            tangram_cell = (
                tangram_set,
                counterbalance_idx,
                tangram_to_idx[row["tangram"]],
            )
            label_cell = tangram_cell + (audience_group_to_idx[row["audience_group"]],)

            # Last axis: 0 = earlier label, 1 = later label.
            info_mtx[label_cell + (0,)] = row["earlier_info"]
            info_mtx[label_cell + (1,)] = row["later_info"]
            length_mtx[label_cell + (0,)] = row["earlier_length"]
            length_mtx[label_cell + (1,)] = row["later_length"]

            tangram_mtx[tangram_cell] = 1
            shared_unique_mtx[tangram_cell] = 0 if row["shared"] == "shared" else 1

In [8]:
# Convert all four NumPy lookup tables into JAX arrays.
info_mtx, length_mtx, tangram_mtx, shared_unique_mtx = (
    jnp.array(table)
    for table in (info_mtx, length_mtx, tangram_mtx, shared_unique_mtx)
)

## Model

In [54]:
class TangramSet(IntEnum):
    """Tangram set index; used as the first axis of the lookup matrices."""

    Set1 = 0
    Set2 = 1
    Set3 = 2


class Counterbalance(IntEnum):
    """Counterbalance index; values match `counterbalance_to_idx`."""

    a = 0
    b = 1


class Tangram(IntEnum):
    """Tangram index; values match `tangram_to_idx`."""

    A = 0
    B = 1
    C = 2
    D = 3
    E = 4
    F = 5
    G = 6
    H = 7
    I = 8
    J = 9
    K = 10
    L = 11


class AudienceConditions(IntEnum):
    """Audience condition passed to `audience_wpp` to weight possible audiences."""

    EitherGroup = 0
    OneGroup = 1


class TangramTypes(IntEnum):
    """Shared- vs unique-label tangram; same 0/1 coding as `shared_unique_mtx`."""

    Shared = 0
    Unique = 1


class AudienceGroup(IntEnum):
    """Audience group colour; values match `audience_group_to_idx`."""

    Red = 0
    Blue = 1


class Audiences(IntEnum):
    """Audience the speaker may address (outgroup or ingroup)."""

    Outgroup = 0
    Ingroup = 1

class Utterances(IntEnum):
    """Utterance choice; indexes the last (earlier/later) axis of the lookup matrices."""

    Earlier = 0
    Later = 1


@jax.jit
def audience_wpp(audience_condition, audience):
    """Unnormalised weight the speaker places on addressing `audience`.

    Args:
        audience_condition: 0 for the "either group" condition, 1 for the
            "one group" condition.
        audience: 0 for the outgroup, 1 for the ingroup.

    Returns:
        1 for any audience in the "either group" condition; in the "one group"
        condition, 1 for the ingroup and 0 for the outgroup.
    """
    # TODO: check this logic
    # Lookup table indexed by audience condition: slot 0 is the constant 1,
    # slot 1 passes the audience index straight through.
    weight_by_condition = jnp.array([1, audience])
    return weight_by_condition[audience_condition]

@jax.jit
def ref_info(
    tangram_set, counterbalance, tangram, audience_group,
    tangram_type, audience, utt
):
    """Referential informativity of utterance `utt` for the given audience.

    The first four arguments select a label pair in `info_mtx`; `tangram_type`,
    `audience`, and `utt` are the values that vary within the model.

    Returns:
        Informativity of the earlier (utt == 0) or later (utt == 1) label.
    """
    # Naive-listener informativity of [earlier, later]. This is also what the
    # outgroup gets when the tangram has unique labels.
    base_info = info_mtx[tangram_set, counterbalance, tangram, audience_group, :]

    # Ingroup: boost the earlier label by 0.1 and the later label by 0.3,
    # capped at 1 (assumes the ingroup has probably seen later labels more often).
    boosted_info = jnp.minimum(base_info + jnp.array([0.1, 0.3]), 1)

    # Outgroup with a shared-label tangram: the earlier label stays at the naive
    # value, the later label is as informative as it is for the ingroup.
    shared_outgroup_info = jnp.array([base_info[0], boosted_info[1]])

    outgroup_info = jnp.where(
        tangram_type == TangramTypes.Shared, shared_outgroup_info, base_info
    )
    audience_info = jnp.where(
        audience == Audiences.Outgroup, outgroup_info, boosted_info
    )

    return audience_info[utt]


@jax.jit
def social_info(utt):
    """Social informativity of an utterance: 0.1 for earlier (0), 0.6 for later (1)."""
    # The later utterance carries more social information.
    social_info_by_utt = jnp.array([0.1, 0.6])
    return social_info_by_utt[utt]


@jax.jit
def utt_length(tangram_set, counterbalance, tangram, audience_group, utt):
    """Look up the length of the earlier (utt == 0) or later (utt == 1) label in `length_mtx`."""
    label_index = (tangram_set, counterbalance, tangram, audience_group, utt)
    return length_mtx[label_index]

In [56]:
# Spot check: in the "either group" condition the ingroup gets weight 1.
audience_wpp(AudienceConditions.EitherGroup, Audiences.Ingroup)

Array(1, dtype=int32)

In [57]:
# type: ignore

# the idea is that the goal is in the weights so each goal is fit separately?
# TODO: figure out a way to organize the params, and document
#
# Speaker model written in the `memo` DSL. The speaker first picks an audience
# (weighted by `audience_wpp`), then picks an utterance with weight
#   exp(alpha * (w_r * ref_info + w_s * social_info - w_c * utt_length)),
# and the model returns the probability that the chosen utterance equals `u`.
# NOTE(review): the example call below returns an array of shape (2, 1), so the
# result appears not to vary along the `audience` axis -- confirm this is intended.

@memo
def speaker[
    u: Utterances, audience: Audiences
](
    #### experimental design stuff
    tangram_set,
    counterbalance,
    tangram,
    audience_group,
    #### variables
    tangram_type,
    audience_condition,
    #### parameters
    alpha,
    w_r,
    w_s,
    w_c,
):
    # memo syntax; presumably declares `speaker` as the only agent -- confirm against memo docs
    cast: [speaker]
    # Audience choice: weights come from the audience condition.
    speaker: chooses(
        audience in Audiences, wpp=audience_wpp(audience_condition, audience)
    )
    # Utterance choice: exponentiated, alpha-scaled weighted utility.
    speaker: chooses(
        u in Utterances,
        wpp=exp(
            alpha
            * (
                w_r * ref_info(tangram_set, counterbalance, tangram, audience_group, tangram_type, audience, u)
                + w_s * social_info(u)
                - w_c * utt_length(tangram_set, counterbalance, tangram, audience_group, u)
            )
        ),
    )
    return Pr[speaker.u == u]

In [59]:
# Example: utterance distribution for one unique-label tangram in the
# "one group" condition (rows are indexed by utterance: earlier, later).
speaker(
    tangram_set=TangramSet.Set1,
    counterbalance=Counterbalance.a,
    tangram=Tangram.A,
    audience_group=AudienceGroup.Red,
    tangram_type=TangramTypes.Unique,
    audience_condition=AudienceConditions.OneGroup,
    alpha=2,
    w_r=3,
    w_s=0,
    w_c=0.01,
)

Array([[0.13840853],
       [0.86159146]], dtype=float32)

In [60]:
# Same tangram in the "either group" condition; [0, 0] picks out the
# probability of the earlier utterance.
speaker(
    tangram_set=TangramSet.Set1,
    counterbalance=Counterbalance.a,
    tangram=Tangram.A,
    audience_group=AudienceGroup.Red,
    tangram_type=TangramTypes.Unique,
    audience_condition=AudienceConditions.EitherGroup,
    alpha=2,
    w_r=3,
    w_s=0,
    w_c=0.01,
)[0, 0]

Array(0.24312153, dtype=float32)

In [61]:
# Naive informativity of the [earlier, later] labels used in the examples above.
info_mtx[TangramSet.Set1, Counterbalance.a, Tangram.A, AudienceGroup.Red, :]

Array([0.6      , 0.5714286], dtype=float32)

## Generate model predictions and fit weights
A different set of weights per condition

Make matrix of model predictions: tangram_set x counterbalance x tangram x tangram_type x condition x audience_group 

Model predicts: probability of choosing the earlier label (utterance)

In [129]:
def generate_model_preds(alpha, w_r, w_s, w_c, audience_condition):
    """
    Generate the model's probability of choosing the EARLIER label for every
    available tangram and audience group.

    Args:
        alpha, w_r, w_s, w_c: speaker parameters, passed straight to `speaker`
            (goals are encoded in the weights).
        audience_condition: an `AudienceConditions` member (either group or
            one group).

    Returns:
        np.ndarray of shape (3, 2, 12, 2), indexed by
        tangram_set x counterbalance x tangram x audience_group.
        Entries for tangrams that are not available in a given
        (tangram_set, counterbalance) are left at 0. Tangram type is not an
        axis; it is looked up per tangram from `shared_unique_mtx`.
    """
    model_preds = np.zeros(
        (len(TangramSet), len(Counterbalance), len(Tangram), len(AudienceGroup))
    )

    # Copy the lookup tables to host memory once rather than reading single
    # elements of device arrays inside the Python loops.
    available = np.asarray(tangram_mtx)
    type_of_tangram = np.asarray(shared_unique_mtx)

    for tangram_set in TangramSet:
        for counterbalance in Counterbalance:
            for tangram in Tangram:
                if not available[tangram_set, counterbalance, tangram]:
                    continue
                # Each available tangram has exactly one type (0 = shared,
                # 1 = unique), so look it up directly.
                tangram_type = TangramTypes(
                    int(type_of_tangram[tangram_set, counterbalance, tangram])
                )
                for audience_group in AudienceGroup:
                    utterance_probs = speaker(
                        tangram_set=tangram_set,
                        counterbalance=counterbalance,
                        tangram=tangram,
                        audience_group=audience_group,
                        tangram_type=tangram_type,
                        audience_condition=audience_condition,
                        alpha=alpha,
                        w_r=w_r,
                        w_s=w_s,
                        w_c=w_c,
                    )
                    # Keep only the probability of the earlier utterance.
                    model_preds[
                        tangram_set, counterbalance, tangram, audience_group
                    ] = utterance_probs[Utterances.Earlier, 0]

    return model_preds

Sanity check: Generate and plot model predictions with toy params

In [63]:
# Toy parameter settings, one per experimental condition, used for a sanity check.
conditions_params = {
    "refer+either": dict(
        alpha=2, w_r=3, w_s=0, w_c=0.01,
        audience_condition=AudienceConditions.EitherGroup,
    ),
    "refer+one": dict(
        alpha=2, w_r=3, w_s=1, w_c=0.01,
        audience_condition=AudienceConditions.OneGroup,
    ),
    "social+one": dict(
        alpha=2, w_r=0, w_s=1, w_c=0.01,
        audience_condition=AudienceConditions.OneGroup,
    ),
}

# One prediction array per condition.
model_preds = {
    name: generate_model_preds(**kwargs)
    for name, kwargs in conditions_params.items()
}

For each condition, take the mean of all values across tangram type (drop the zeros)

In [64]:
# `generate_model_preds` returns tangram_set x counterbalance x tangram x
# audience_group with no tangram-type axis, so shared / unique tangrams are
# selected with masks built from the lookup tables instead of a 5th index.
tangram_available = np.asarray(tangram_mtx).astype(bool)
tangram_type_of = np.asarray(shared_unique_mtx)
shared_mask = tangram_available & (tangram_type_of == TangramTypes.Shared)
unique_mask = tangram_available & (tangram_type_of == TangramTypes.Unique)

means = {}
for condition, pred in model_preds.items():
    assert pred.shape == tangram_available.shape + (len(AudienceGroup),), (
        "expected predictions indexed by tangram_set x counterbalance x "
        "tangram x audience_group"
    )
    # Boolean indexing over the first three axes leaves
    # (n_selected_tangrams, n_audience_groups); unavailable tangrams (the
    # zero-filled entries) are excluded by the mask.
    preds_shared = pred[shared_mask]
    preds_unique = pred[unique_mask]

    means[condition] = (preds_shared.mean(), preds_unique.mean())

In [65]:
# Per condition: (mean prediction for shared tangrams, mean for unique tangrams).
means

{'refer+either': (0.3307983160743283, 0.4409951085431708),
 'refer+one': (0.19126081210561097, 0.19381432020519343),
 'social+one': (0.174693637010124, 0.17244326074918112)}

### Fit weights


In [80]:
# Load the human selection data and recode its string values so they match the
# member names of the IntEnums used in this notebook.
# Plan: organise the responses into a matrix, then, for each condition, compare
# model predictions against the human data participant by participant.
data = pd.read_csv("../data/3pp/earlier-later/selection_trials_clean.csv").rename(
    columns={"item_id": "tangram_set", "shared": "tangram_type"}
)

# column -> {original value: recoded value}; values not listed are left as-is.
recodes = {
    "condition": {
        "either refer": "EitherRefer",
        "one refer": "OneRefer",
        "one social": "OneSocial",
    },
    "response.earlier": {"earlier": 1, "later": 0},
    "tangram_type": {"shared": "Shared", "unique": "Unique"},
    "audience_group": {"red": "Red", "blue": "Blue"},
}
for column, value_map in recodes.items():
    for old_value, new_value in value_map.items():
        data.loc[data[column] == old_value, column] = new_value

data.head()

Unnamed: 0,subject_id,tangram_set,counterbalance,condition,audience_group,tangram,tangram_type,response.earlier
0,dd2c8d9a-acab-46b0-a2d6-58e7c914ac4e,0,a,OneSocial,Red,C,Shared,0
1,dd2c8d9a-acab-46b0-a2d6-58e7c914ac4e,0,a,EitherRefer,Blue,B,Shared,0
2,dd2c8d9a-acab-46b0-a2d6-58e7c914ac4e,0,a,OneSocial,Blue,L,Shared,0
3,dd2c8d9a-acab-46b0-a2d6-58e7c914ac4e,0,a,OneRefer,Blue,L,Shared,0
4,dd2c8d9a-acab-46b0-a2d6-58e7c914ac4e,0,a,OneRefer,Red,H,Unique,1


Make data into a matrix

tangram_set x counterbalance x tangram x tangram_type x audience_group x condition x participant

In [81]:
class Conditions(IntEnum):
    """Experimental condition; member names match the recoded `condition` column."""

    EitherRefer = 0
    OneRefer = 1
    OneSocial = 2

# One member per unique subject_id, numbered from 0 in order of first appearance.
Participants = IntEnum("Participants", {str(i): idx for idx, i in enumerate(data["subject_id"].unique())})

In [120]:
# Organise responses per (tangram_set, counterbalance) into arrays indexed by
# tangram x audience_group x condition x participant.
# NOTE(review): cells with no trial stay at 0, which is indistinguishable from
# a recorded "later" response (also coded 0) -- confirm every participant has a
# trial for each cell that the likelihood later scores.
data_organized = {}
for tangram_set in [0, 1, 2]:
    for counterbalance in ["a", "b"]:
        filtered_data = data[
            (data["tangram_set"] == tangram_set)
            & (data["counterbalance"] == counterbalance)
        ]
        filtered_data_participants = filtered_data["subject_id"].unique()

        this_set_mtx = np.zeros(
            (
                len(Tangram),
                len(AudienceGroup),
                len(Conditions),
                len(filtered_data_participants),
            )
        )
        for i, participant in enumerate(filtered_data_participants):
            this_data = filtered_data[filtered_data["subject_id"] == participant]
            # Write each response straight into this participant's slice
            # instead of re-copying a whole per-participant matrix on every row.
            for _, row in this_data.iterrows():
                this_set_mtx[
                    Tangram[row["tangram"]],
                    AudienceGroup[row["audience_group"]],
                    Conditions[row["condition"]],
                    i,
                ] = row["response.earlier"]

        data_organized[(tangram_set, counterbalance)] = this_set_mtx

In [122]:
# Show the response-array shape for each (tangram_set, counterbalance) pair;
# the last axis is the number of participants in that pair.
for group_key, responses in data_organized.items():
    print(f"{group_key} {responses.shape}")

(0, 'a') (12, 2, 3, 8)
(0, 'b') (12, 2, 3, 9)
(1, 'a') (12, 2, 3, 11)
(1, 'b') (12, 2, 3, 13)
(2, 'a') (12, 2, 3, 11)
(2, 'b') (12, 2, 3, 9)


In [130]:
# Start with (0, "a")
this_condition = Conditions.EitherRefer
refer_either_preds = generate_model_preds(**conditions_params["refer+either"])
print(refer_either_preds.shape) # tangram_set x counterbalance x tangram x audience_group
test_preds = refer_either_preds[TangramSet.Set1, Counterbalance.a, :, :]
print(test_preds.shape) # tangram x audience_group

# Get (0, "a") model preds


(3, 2, 12, 2)
(12, 2)


In [135]:
# Pair observed responses with model predictions for the either + refer condition.
refer_either_preds = generate_model_preds(**conditions_params["refer+either"])
data_all = []
model_all = []
for key, mtx in data_organized.items():
    set_idx, counterbalance_name = key
    data_mtx = mtx[..., Conditions.EitherRefer, :]
    model_mtx = refer_either_preds[set_idx, Counterbalance[counterbalance_name]]
    # The model makes one prediction per cell; repeat it once per participant
    # along a new last axis so it lines up with the data.
    n_participants = data_mtx.shape[-1]
    model_mtx_stacked = np.stack([model_mtx] * n_participants, axis=-1)
    assert data_mtx.shape == model_mtx_stacked.shape
    data_all.append(data_mtx)
    model_all.append(model_mtx_stacked)

# Pool all (tangram_set, counterbalance) groups along the participant axis.
data_all = np.concatenate(data_all, axis=-1)
model_all = np.concatenate(model_all, axis=-1)

data_all.shape, model_all.shape

((12, 2, 61), (12, 2, 61))

In [143]:
# TODO: this is currently bad; improve later.
def generate_data_model_matrices(condition, params):
    """Build aligned (data, model) arrays for one condition.

    Args:
        condition: a `Conditions` member selecting the condition axis of the
            arrays in `data_organized`.
        params: keyword arguments for `generate_model_preds`.

    Returns:
        (data_all, model_all): two arrays of identical shape
        tangram x audience_group x participant, with participants from every
        (tangram_set, counterbalance) group concatenated along the last axis.
    """
    preds = generate_model_preds(**params)
    data_parts, model_parts = [], []
    for (set_idx, counterbalance_name), responses in data_organized.items():
        observed = responses[..., condition, :]
        predicted = preds[set_idx, Counterbalance[counterbalance_name]]
        # Repeat the single prediction matrix once per participant.
        tiled = np.stack([predicted] * observed.shape[-1], axis=-1)
        assert observed.shape == tiled.shape
        data_parts.append(observed)
        model_parts.append(tiled)

    return np.concatenate(data_parts, axis=-1), np.concatenate(model_parts, axis=-1)

In [173]:
def compute_nll(data, model, eps=1e-12):
    """Bernoulli negative log-likelihood of binary responses under the model.

    Args:
        data: array of 0/1 responses; NaN entries are treated as missing and
            skipped.
        model: array (same shape as `data`) of predicted probabilities of a
            response of 1. Entries equal to exactly 0 mark cells without a
            prediction and are skipped.
        eps: predictions are clipped to [eps, 1 - eps] before taking logs, so a
            saturated prediction yields a large finite penalty rather than
            `0 * log(0) = NaN` or an infinite value.

    Returns:
        The summed negative log-likelihood over all scored cells.
    """
    data = np.asarray(data, dtype=float)
    model = np.asarray(model, dtype=float)

    mask = (model != 0) & ~np.isnan(data)
    observed = data[mask]
    prob = np.clip(model[mask], eps, 1 - eps)
    return -np.sum(observed * np.log(prob) + (1 - observed) * np.log(1 - prob))

In [174]:
# NLL of whatever `data_all` / `model_all` currently hold.
compute_nll(data_all, model_all)

1083.370770486826

In [178]:
# Grid search over w_r and w_s for the either + refer condition
# (alpha and w_c are held fixed).
weight_grid = np.linspace(0, 5, 50)
params_list = [
    {
        "alpha": 2,
        "w_r": w_r,
        "w_s": w_s,
        "w_c": 0.01,
        "audience_condition": AudienceConditions.EitherGroup,
    }
    for w_r in weight_grid
    for w_s in weight_grid
]

nlls = []

for params in params_list:
    data_all, model_all = generate_data_model_matrices(Conditions.EitherRefer, params)
    nll = compute_nll(data_all, model_all)
    nlls.append(nll)

nlls = np.array(nlls)
# np.argmin returns the position of a NaN if any NLL is NaN, which would pick
# an invalid fit as "best"; nanargmin skips NaNs (and raises if all are NaN).
best_params = params_list[np.nanargmin(nlls)]
best_params
    

{'alpha': 2,
 'w_r': 1.4285714285714286,
 'w_s': 0.0,
 'w_c': 0.01,
 'audience_condition': <AudienceConditions.EitherGroup: 0>}

In [179]:
# Do this for the other conditions
params_list = [
    {
        "alpha": 2,
        "w_r": w_r,
        "w_s": w_s,
        "w_c": 0.01,
        "audience_condition": AudienceConditions.OneGroup,
    }
    for w_r in np.linspace(0, 5, 50)
    for w_s in np.linspace(0, 5, 50)
]

nlls = []

for params in params_list:
    data_all, model_all = generate_data_model_matrices(Conditions.OneRefer, params)
    nll = compute_nll(data_all, model_all)
    nlls.append(nll)

nlls = np.array(nlls)
best_params = params_list[np.argmin(nlls)]
best_params

{'alpha': 2,
 'w_r': 1.0204081632653061,
 'w_s': 0.10204081632653061,
 'w_c': 0.01,
 'audience_condition': <AudienceConditions.OneGroup: 1>}

In [180]:
# Do this for the other conditions
params_list = [
    {
        "alpha": 2,
        "w_r": w_r,
        "w_s": w_s,
        "w_c": 0.01,
        "audience_condition": AudienceConditions.OneGroup,
    }
    for w_r in np.linspace(0, 5, 50)
    for w_s in np.linspace(0, 5, 50)
]

nlls = []

for params in params_list:
    data_all, model_all = generate_data_model_matrices(Conditions.OneSocial, params)
    nll = compute_nll(data_all, model_all)
    nlls.append(nll)

nlls = np.array(nlls)
best_params = params_list[np.argmin(nlls)]
best_params

{'alpha': 2,
 'w_r': 0.7142857142857143,
 'w_s': 0.30612244897959184,
 'w_c': 0.01,
 'audience_condition': <AudienceConditions.OneGroup: 1>}

In [182]:
# Best-fitting parameters per condition, copied by hand from the grid-search
# outputs (the audience condition is omitted).
refer_either_params = dict(alpha=2, w_r=1.4285714285714286, w_s=0.0, w_c=0.01)

refer_one_params = dict(
    alpha=2, w_r=1.0204081632653061, w_s=0.10204081632653061, w_c=0.01
)

social_one_params = dict(
    alpha=2, w_r=0.7142857142857143, w_s=0.30612244897959184, w_c=0.01
)

# Collect the three parameter sets in one table, one row per condition.
params_df = pd.DataFrame(
    [refer_either_params, refer_one_params, social_one_params],
    index=["refer+either", "refer+one", "social+one"],
)

In [183]:
# Fitted parameters, one row per condition.
params_df

Unnamed: 0,alpha,w_r,w_s,w_c
refer+either,2,1.428571,0.0,0.01
refer+one,2,1.020408,0.102041,0.01
social+one,2,0.714286,0.306122,0.01


In [163]:
# Scratch inspection of whatever grid search last populated `nlls`. Only the
# last expression of a cell is displayed, so the first line shows nothing.
nlls[nlls < 250]
np.where(nlls < 230)

(array([  0,  50, 100, 150, 200, 250, 300, 350, 400, 450, 500, 550, 600,
        650, 700, 750]),)

In [169]:
# Look up the parameter setting at grid position 400.
params_list[400]

{'alpha': 2,
 'w_r': 0.8163265306122449,
 'w_s': 0.0,
 'w_c': 0.01,
 'audience_condition': <AudienceConditions.OneGroup: 1>}

TODOS

- Make heatmaps and figure out what's happening with `w_s`
- try plugging in these weights in toy model and see what happens, also just w model predictions. Can change behavior just by changing `w_r`...
- see if data loaded in is correct... plot data maybe? 
- use the weights to plot the trial-level data
- refactor code and organize and make it FASTER 
- is the data right...
- move code to script stop using jupyter

- Can you get the results using just the weights on refer and cost? Do you need a weight on social?


Get rid of cost weight

Fit data

Start with refer+either condition

Model predictions matrix:     Output: 3 x 2 x 12 x 2 x 2  array
    tangram_set x counterbalance x tangram x tangram_type x audience_group

Data matrix: 
tangram_set x counterbalance x tangram x tangram_type x audience_group x condition x participant

In [None]:
# NOTE(review): this cell was never executed and looks stale. It indexes
# `data_mtx` with seven indices, but no cell in this notebook builds a
# seven-dimensional `data_mtx` -- confirm what it should refer to before running.
this_condition = Conditions.EitherRefer
test_model_preds = generate_model_preds(**conditions_params["refer+either"])
test_data = data_mtx[:, :, :, :, :, this_condition, 0]
assert test_model_preds.shape == test_data.shape

In [84]:
def mse(pred, data):
    """Mean squared error between predictions and data (elementwise, then averaged)."""
    residual = pred - data
    return (residual ** 2).mean()

((3, 2, 12, 2, 2), (3, 2, 12, 2, 2))

In [98]:
# Inspect one slice of the predictions; this five-index form matches the
# older 5-D layout shown in the saved output.
test_model_preds[0, 0, :, :, 1]

array([[0.        , 0.62882805],
       [0.3838926 , 0.        ],
       [0.32575074, 0.        ],
       [0.        , 0.52077317],
       [0.        , 0.        ],
       [0.        , 0.        ],
       [0.        , 0.        ],
       [0.        , 0.27000737],
       [0.        , 0.        ],
       [0.        , 0.        ],
       [0.        , 0.        ],
       [0.27071151, 0.        ]])

In [95]:
# Inspect the matching slice of the data for comparison with the predictions.
test_data[0, 0, :, :, 0]

array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 1.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]])

FIT BY GOAL (refer vs. social) NOT condition
Or maybe fit by condition? 

Instead of weird dicts use intenum from the beginning