In [63]:
import pandas as pd
import numpy as np
import plotly.express as px

import torch
import pyro
import pyro.distributions as dist

from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.infer.autoguide import AutoMultivariateNormal

from pyro.optim import Adam

In [7]:
# Load the simulation results (presumably one row per simulation run — see the
# "sim index" / "run index" columns).
impactDF = pd.read_csv("communicators_impact_data/resultsDF.csv", index_col=0)

# Fail fast if the results file no longer has the columns the rest of this
# notebook reads, instead of failing later with a confusing KeyError.
_expected_columns = {
    "c_trust", "sight_radius", "reward_patch_dim", "sim index",
    "run index", "time to first food", "num birds failed",
}
_missing_columns = _expected_columns - set(impactDF.columns)
assert not _missing_columns, f"resultsDF.csv is missing columns: {sorted(_missing_columns)}"

display(impactDF)

Unnamed: 0,c_trust,sight_radius,reward_patch_dim,sim index,run index,time to first food,num birds failed
0,0.0,5,1,0,0,1.555556,0
1,0.0,5,1,0,1,1.333333,0
2,0.0,5,1,0,2,4.222222,0
3,0.0,5,1,0,3,13.222222,1
4,0.0,5,1,0,4,2.555556,0
...,...,...,...,...,...,...,...
445,0.9,5,5,44,5,9.777778,0
446,0.9,5,5,44,6,10.222222,1
447,0.9,5,5,44,7,14.111111,0
448,0.9,5,5,44,8,4.333333,0


In [13]:
# Distribution of failed birds per run. This is likely not a useful success
# metric when the simulation runs long enough for birds to have time to succeed.
px.histogram(
    data_frame=impactDF,
    x="num birds failed",
    nbins=20,
    title="Overall number of birds failed",
    template="plotly_dark",
)


In [14]:
# Time to first food might be the more useful success metric.
px.histogram(
    data_frame=impactDF,
    x="time to first food",
    nbins=20,
    title="Overall times to first food",
    template="plotly_dark",
)

In [17]:
# The two success metrics are rather correlated anyway (Spearman rank correlation).
# Select the metrics by name, not position: the jitter cell further down appends
# "c_trust_jittered" to impactDF, so iloc[:, -2:] would pick the wrong columns
# if this cell is re-run afterwards.
impactDF[["time to first food", "num birds failed"]].corr(method="spearman")

Unnamed: 0,time to first food,num birds failed
time to first food,1.0,0.591975
num birds failed,0.591975,1.0


In [18]:
# Relationship between the two success metrics, run by run.
px.scatter(
    data_frame=impactDF,
    x="num birds failed",
    y="time to first food",
    title="Number of birds failed vs time to first food",
    template="plotly_dark",
)

In [43]:
# Note that initially the utility of communicating grows faster.
# Jitter c_trust horizontally so runs at the same trust level don't overplot.
# A seeded generator keeps the jittered figure reproducible across re-runs.
jitter = 0.02
rng = np.random.default_rng(seed=0)
impactDF["c_trust_jittered"] = impactDF["c_trust"] + rng.uniform(-jitter, jitter, len(impactDF))

# NOTE(review): reward_patch_dim is numeric, so plotly presumably uses a continuous
# colour scale and fits one LOWESS trendline over all points (on the jittered x);
# cast the colour column to str if per-patch-size trendlines are wanted.
fig = px.scatter(impactDF, x="c_trust_jittered", y="time to first food", color="reward_patch_dim",
                 template="plotly_dark", title="Trust vs time to first food",
                 trendline="lowess")

fig.update_traces(marker={"size": 5, "opacity": 0.5})

fig.show()

In [44]:
# Zoom the previous figure in on the bulk of the data (y between 0 and 15).
fig.layout.yaxis.range = [0, 15]
fig.show()

In [52]:
# Model inputs as 1-D tensors, one entry per row of impactDF.
# Cast to torch's default floating dtype: the distributions in `model` are built
# from Python float literals and therefore use that dtype, so float64 data would
# otherwise silently promote the likelihood computations to float64.
# NOTE(review): `time` shadows the stdlib module name; kept because later cells use it.
model_dtype = torch.get_default_dtype()
trust = torch.tensor(impactDF["c_trust"].to_numpy(dtype=float), dtype=model_dtype)
patch = torch.tensor(impactDF["reward_patch_dim"].to_numpy(dtype=float), dtype=model_dtype)
time = torch.tensor(impactDF["time to first food"].to_numpy(dtype=float), dtype=model_dtype)

# A Normal likelihood cannot handle missing values; fail fast instead of a NaN ELBO.
for _input_name, _input_values in {"trust": trust, "patch": patch, "time": time}.items():
    assert torch.isfinite(_input_values).all(), f"non-finite values in model input '{_input_name}'"



In [54]:
def model(trust, patch, time=None):
    """Bayesian linear regression of time to first food on trust and patch size.

    time ~ Normal(w_trust * trust + w_patch * patch + bias, sd)

    Parameters
    ----------
    trust, patch : torch.Tensor
        Predictor tensors of equal length (one entry per run).
    time : torch.Tensor or None, optional
        Observed times to first food, same length as ``trust``. When ``None``
        the ``obs`` site is sampled rather than conditioned on, so the model can
        also be used for prior / posterior predictive simulation.

    Raises
    ------
    ValueError
        If ``time`` is given and its length differs from ``trust``.
    """
    if time is not None and len(time) != len(trust):
        raise ValueError(f"time has {len(time)} entries but trust has {len(trust)}")

    w_trust = pyro.sample("weight_trust", dist.Normal(0., 1.))
    w_patch = pyro.sample("weight_patch", dist.Normal(0., 1.))
    bias = pyro.sample("bias", dist.Normal(5., 4.))

    # Exponential(rate=0.5) prior keeps the noise scale positive (prior mean 2).
    sd = pyro.sample("sd", dist.Exponential(.5))

    mu = w_trust * trust + w_patch * patch + bias

    # Size the plate from the predictors so the model also runs with time=None.
    with pyro.plate("data", len(trust)):
        pyro.sample("obs", dist.Normal(mu, sd), obs=time)


In [60]:
# Fit the model with stochastic variational inference, using a multivariate
# normal guide over all latent sites of `model`.
pyro.set_rng_seed(0)       # the ELBO estimate is stochastic; seed for reproducible fits
pyro.clear_param_store()   # drop parameters left over from earlier runs of this cell

guide = AutoMultivariateNormal(model)
svi = SVI(model, guide, Adam({"lr": 0.03}), loss=Trace_ELBO())

num_iterations = 2000
losses = []  # full loss history, e.g. for plotting convergence
for i in range(num_iterations):
    loss = svi.step(trust, patch, time)
    losses.append(loss)
    if (i + 1) % 100 == 0:
        print(f"Iteration {i + 1}/{num_iterations}, Loss: {loss:.2f}")


Iteration 100/2000, Loss: 1575.86
Iteration 200/2000, Loss: 1572.45
Iteration 300/2000, Loss: 1562.77
Iteration 400/2000, Loss: 1561.19
Iteration 500/2000, Loss: 1560.23
Iteration 600/2000, Loss: 1561.13
Iteration 700/2000, Loss: 1556.95
Iteration 800/2000, Loss: 1559.20
Iteration 900/2000, Loss: 1558.42
Iteration 1000/2000, Loss: 1554.13
Iteration 1100/2000, Loss: 1559.52
Iteration 1200/2000, Loss: 1556.59
Iteration 1300/2000, Loss: 1556.48
Iteration 1400/2000, Loss: 1558.47
Iteration 1500/2000, Loss: 1557.21
Iteration 1600/2000, Loss: 1556.44
Iteration 1700/2000, Loss: 1557.21
Iteration 1800/2000, Loss: 1553.91
Iteration 1900/2000, Loss: 1556.69
Iteration 2000/2000, Loss: 1557.57


In [68]:
# Draw 1000 samples from the fitted guide and run the model with them,
# passing the same inputs (including the observed times) used for fitting.
predictive = Predictive(
    model=model,
    guide=guide,
    num_samples=1000,
)
samples = predictive(trust, patch, time)

In [80]:
# Posterior marginal of the trust weight, with its mean marked.
# Convert the torch samples to a flat NumPy array and the mean to a Python float,
# so plotly receives plain numeric data rather than tensors.
trust_draws = samples["weight_trust"].detach().cpu().numpy().ravel()
mean_trust = float(trust_draws.mean())

fig1 = px.histogram(x=trust_draws, title="Posterior marginal distribution of trust weights",
                    labels={"x": "weight_trust"}, template="plotly_dark")

fig1.update_traces(showlegend=False)
fig1.add_vline(x=mean_trust, line_dash="dash", line_color="red", annotation_text=f"Mean = {mean_trust:.2f}",
               annotation_position="top right",
               annotation_xshift=20)


In [83]:
# Posterior marginal of the patch-size weight, with its mean marked.
# Convert the torch samples to a flat NumPy array and the mean to a Python float,
# so plotly receives plain numeric data rather than tensors.
patch_draws = samples["weight_patch"].detach().cpu().numpy().ravel()
mean_patch = float(patch_draws.mean())

fig2 = px.histogram(x=patch_draws, title="Posterior marginal distribution of patch weights",
                    labels={"x": "weight_patch"}, template="plotly_dark")

fig2.update_traces(showlegend=False)
fig2.add_vline(x=mean_patch, line_dash="dash", line_color="red", annotation_text=f"Mean = {mean_patch:.2f}",
               annotation_position="top right",
               annotation_xshift=20)