In [1]:
import pandas as pd
import numpy as np
import plotly.express as px

import torch
import pyro
import pyro.distributions as dist

from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.infer.autoguide import AutoMultivariateNormal

from pyro.optim import Adam

In [68]:
impactDF = pd.read_csv("communicators_impact_data/resultsDF.csv", index_col=0)

display(impactDF)

Unnamed: 0,c_trust,sight_radius,reward_patch_dim,sim index,run index,time to first food,num birds failed
0,0.0,5,1,0,0,14.666667,0
1,0.0,5,1,0,1,14.000000,0
2,0.0,5,1,0,2,4.888889,0
3,0.0,5,1,0,3,2.555556,0
4,0.0,5,2,1,0,7.888889,0
...,...,...,...,...,...,...,...
415,1.0,5,4,103,3,8.444444,0
416,1.0,5,5,104,0,3.555556,0
417,1.0,5,5,104,1,6.111111,0
418,1.0,5,5,104,2,7.333333,0


In [7]:
# this is not going to be a useful metric, if the number of frames is sufficient
# for birds to have enough time to succeed
px.histogram(
    impactDF,
    x="num birds failed",
    nbins=20,
    title="Overall number of birds failed",
    template="plotly_dark",
)


In [8]:
# this migth be more useful
px.histogram(
    impactDF,
    x="time to first food",
    nbins=30,
    title="Overall times to first food",
    template="plotly_dark",
)


In [9]:
# the success metrics are rather correlated anyway
impactDF.iloc[:, -2:].corr(method="spearman")


Unnamed: 0,time to first food,num birds failed
time to first food,1.0,0.545523
num birds failed,0.545523,1.0


In [10]:
px.scatter(
    impactDF,
    x="num birds failed",
    y="time to first food",
    title="Number of birds failed vs time to first food",
    template="plotly_dark",
)


In [12]:
# note that initially the utility of communicating grows faster
jitter = 0.01
impactDF["c_trust_jittered"] = impactDF["c_trust"] + np.random.uniform(
    -jitter, jitter, len(impactDF)
)


fig = px.scatter(
    impactDF,
    x="c_trust_jittered",
    y="time to first food",
    color="reward_patch_dim",
    template="plotly_dark",
    title="Trust vs time to first food",
    trendline="lowess",
)

fig.update_traces({"marker": {"size": 5, "opacity": 0.5}})

fig.show()


In [19]:
# let's take a closer look

fig.update_yaxes(range=[0, 25])
fig.show()

In [22]:
# regular_impactDF = impactDF[(impactDF["c_trust"] != 0) & (impactDF["c_trust"] != 1)]
# regular_impactDF.shape


(380, 8)

In [69]:
# restriction = (impactDF["c_trust"] > 0.05) & (impactDF["c_trust"] < 0.95)
# impactDF = impactDF[restriction]
print(impactDF.shape)

trust = torch.tensor(impactDF["c_trust"].values.astype(float))
patch = torch.tensor(impactDF["reward_patch_dim"].values.astype(float))
time = torch.tensor(impactDF["time to first food"].values.astype(float))


(420, 7)


In [70]:
def model(trust, patch, time):
    w_trust = pyro.sample("weight_trust", dist.Normal(0.0, 1.0))
    w_patch = pyro.sample("weight_patch", dist.Normal(0.0, 1.0))
    bias = pyro.sample("bias", dist.Normal(5.0, 4.0))

    sd = pyro.sample("sd", dist.Exponential(0.5))
    mu = w_trust * trust + w_patch * patch + bias

    #    considered this interaction
    #    but there are no good reasons to think the interaction coefficient is non-null
    #    w_tp = pyro.sample("interaction_trust_patch", dist.Normal(0.0, 1.0))
    #    mu = w_trust * trust + w_patch * patch + w_tp * trust * patch + bias

    with pyro.plate("data", len(time)):
        pyro.sample("obs", dist.Normal(mu, sd), obs=time)


In [71]:
guide = AutoMultivariateNormal(model)

pyro.clear_param_store()
svi = SVI(model, guide, Adam({"lr": 0.03}), loss=Trace_ELBO())

num_iterations = 2000
for i in range(num_iterations):
    loss = svi.step(trust, patch, time)
    if (i + 1) % 100 == 0:
        print(f"Iteration {i + 1}/{num_iterations}, Loss: {loss:.2f}")


Iteration 100/2000, Loss: 2050.32
Iteration 200/2000, Loss: 1806.25
Iteration 300/2000, Loss: 1585.79
Iteration 400/2000, Loss: 1480.49
Iteration 500/2000, Loss: 1466.94
Iteration 600/2000, Loss: 1452.06
Iteration 700/2000, Loss: 1452.21
Iteration 800/2000, Loss: 1453.42
Iteration 900/2000, Loss: 1451.57
Iteration 1000/2000, Loss: 1453.08
Iteration 1100/2000, Loss: 1451.02
Iteration 1200/2000, Loss: 1455.26
Iteration 1300/2000, Loss: 1452.86
Iteration 1400/2000, Loss: 1453.81
Iteration 1500/2000, Loss: 1450.87
Iteration 1600/2000, Loss: 1452.04
Iteration 1700/2000, Loss: 1451.71
Iteration 1800/2000, Loss: 1452.88
Iteration 1900/2000, Loss: 1452.85
Iteration 2000/2000, Loss: 1453.16


In [72]:
predictive = Predictive(model, guide=guide, num_samples=1000)
samples = predictive(trust, patch, time)


In [94]:
def calculate_R_squared_com(guide):
    predictive = pyro.infer.Predictive(model, guide=guide, num_samples=1000)
    predictions = predictive(trust, patch, time)

    print(predictions.keys())
    simulated_outcome = (
        predictions["weight_trust"] * trust
        + predictions["weight_patch"] * patch
        + predictions["bias"]
    )
    print(simulated_outcome.shape)
    mean_sim_outcome = simulated_outcome.mean(dim=0)
    print(mean_sim_outcome.shape)

    fig = px.histogram(time - mean_sim_outcome)
    fig.show()

    observed_mean = torch.mean(time)

    tss = torch.sum((time - observed_mean) ** 2)
    print(tss)

    rss = torch.sum((time - mean_sim_outcome) ** 2)

    r_squared = 1 - (rss / tss)

    return r_squared


print(calculate_R_squared_com(guide))


dict_keys(['weight_trust', 'weight_patch', 'bias', 'sd', 'obs'])
torch.Size([1000, 420])
torch.Size([420])


tensor(25338.1308, dtype=torch.float64)
rss tensor(23579.7555, dtype=torch.float64)
Div tensor(0.9306, dtype=torch.float64)
tensor(0.0694, dtype=torch.float64)


In [73]:
fig1 = px.histogram(
    samples["weight_trust"],
    title="Posterior marginal distribution of trust weights",
    template="plotly_dark",
)

mean_trust = samples["weight_trust"].mean()


fig1.update_traces(showlegend=False)
fig1.add_vline(
    x=mean_trust,
    line_dash="dash",
    line_color="red",
    annotation_text=f"Mean = {mean_trust:.2f}",
    annotation_position="top right",
    annotation_xshift=20,
)


In [74]:
fig2 = px.histogram(
    samples["weight_patch"],
    title="Posterior marginal distribution of patch weights",
    template="plotly_dark",
)

mean_patch = samples["weight_patch"].mean()


fig2.update_traces(showlegend=False)
fig2.add_vline(
    x=mean_patch,
    line_dash="dash",
    line_color="red",
    annotation_text=f"Mean = {mean_patch:.2f}",
    annotation_position="top right",
    annotation_xshift=20,
)
