In [59]:
import logging
import random

import chirho
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import pyro
import pyro.distributions as dist
import pyro.optim as optim
import torch
import torch.nn.functional as F
from plotly.subplots import make_subplots
from pyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO
from pyro.infer.autoguide import (AutoDiagonalNormal, AutoMultivariateNormal,
                                  AutoNormal, init_to_mean, init_to_value)
from pyro.nn import PyroModule
from scipy.stats import lognorm

import plotly.io as pio

In [60]:
def add_jittered_com(data, jitter=0.00):
    if "c_trust_jittered" not in data.columns:
        data["c_trust_jittered"] = data["c_trust"] + np.random.uniform(-jitter, jitter, len(data))
    return data

impactDF_grid45 = pd.read_csv("../data/communicators/communicators_impact/resultsDF_grid45_150frames.csv", index_col=0)
impactDF_grid45["reward_patch_dim"] = impactDF_grid45["reward_patch_dim"].astype("category")
impactDF_grid45 = add_jittered_com(impactDF_grid45)

In [61]:
color_scale = {
    "1": "red",
    "2": "green",
    "4": "blue",
}

fig_grid45 = px.scatter(
    impactDF_grid45,
    x="c_trust_jittered",
    y="time to first food",
    color="reward_patch_dim",
    template="presentation",
    title="",
    opacity=0.5,
    labels={"c_trust_jittered": "communication parameter", "reward_patch_dim": "patch size"}
)


fig_grid45.update_xaxes(showgrid=False, range = [0, 0.7])
fig_grid45.update_yaxes(showgrid=False)
fig_grid45.update_traces(marker={"size": 7})
fig_grid45.update_layout(
    width=800,
    height=800,
)





In [62]:
# prior on standard deviation is 
# meaningfull but not too restrictive
mu = 4
sigma = 0.7
sd = np.linspace(0.1, 100, 1000)

pdf = np.exp(-(np.log(sd) - mu)**2 / (2 * sigma**2)) / (sd * sigma * np.sqrt(2 * np.pi))
fig = px.line(x=sd, y=pdf, labels={"x": "sd", "y": "PDF"}, title=f"Log-Normal Distribution (mu={mu}, sigma={sigma})")
fig.update_xaxes(showgrid=False)
fig.update_yaxes(showgrid=False)
fig.update_layout(template="presentation", width = 600)
fig.show()

In [63]:
# define model and training

def get_samples(data, base_m, base_s, slope_m=0, slope_s=30, n_iters=2000, n_samples=1000):
    trust = torch.tensor(data["c_trust"].values)
    patch = torch.tensor(data["reward_patch_dim"].astype("category").cat.codes.values).long()
    time = torch.tensor(data["time to first food"].values)

    def model(trust, time, patch):
        with pyro.plate("coefs", 3):
            base = pyro.sample("base", dist.Normal(base_m, base_s))
            slope = pyro.sample("slope", dist.Normal(slope_m, slope_s))

        sig = pyro.sample("sig", dist.LogNormal(4, 0.7))

        with pyro.plate("obs", len(time)):
            pyro.sample("time", dist.Normal(base[patch] + slope[patch] * trust, sig), obs=time)

    guide = AutoMultivariateNormal(model, init_loc_fn=init_to_mean)
    svi = SVI(model, guide, optim.Adam({"lr": 0.01}), loss=Trace_ELBO())

    iterations = []
    losses = []

    pyro.clear_param_store()
    num_iters = n_iters
    for i in range(num_iters):
        elbo = svi.step(trust, time, patch)
        iterations.append(i)
        losses.append(elbo)
        if i % 200 == 0:
            logging.info("Elbo loss: {}".format(elbo))

    fig = px.line(x=iterations, y=losses, title="ELBO loss", template="presentation")
    labels={"iterations": "iteration", "losses": "loss"}
    fig.update_xaxes(showgrid=False, title_text=labels["iterations"])
    fig.update_yaxes(showgrid=False, title_text=labels["losses"])
    fig.update_layout(width=700)
    fig.show()

    predictive = Predictive(
        model,
        guide=guide,
        num_samples=n_samples,
    )

    sample = {
        k: v.flatten().reshape(n_samples, -1).detach().cpu().numpy()
        for k, v in predictive(trust, time, patch).items()
        if k != "obs"
    }


    return sample

In [64]:
def plot_coefs(sample):

    df = pd.DataFrame(sample['slope'])
    df.columns = ["1", "2", "4"]
    
    df_medians = df.median(axis = 0 )
    
    prob_sub_0 = df.applymap(lambda x: 1 if x < 0 else 0).sum()

    print(f"posterior probability of negative slope by patch size: \n {prob_sub_0 / len(df)}")
    
    df_long = df.melt(var_name="patch size", value_name="coefficient of trust")
    

    
    fig_svi = px.histogram(df_long["coefficient of trust"],
                           color = df_long["patch size"],
                           template="presentation",
                           nbins=100,
                           title = f"",
                           opacity=0.5,
                           labels={"value": "communication parameter", "color": "patch_size"},
                           color_discrete_sequence=['#1f77b4', '#ff7f0e', '#2ca02c'], 
                           barmode='overlay'
                        ) 
    
    for i, color in enumerate(['#1f77b4', '#ff7f0e', '#2ca02c']):
        fig_svi.add_vline(x=df_medians[i], line_dash="dash", line_color=color, name=f"Median ({df_medians[i]})")
    
        fig_svi.add_annotation(
        x=df_medians[i],
        y=200 + 40 * i,  # Adjust the vertical position of the label
        text=f"{df_medians[i]:.2f}",
        showarrow=False,
        bordercolor="black",
        borderwidth=2,
        bgcolor="white",
        opacity=0.7
        )

    fig.update_yaxes(showgrid=False)
    fig.update_xaxes(showgrid=False)
    
    fig_svi.update_layout(
        xaxis_title="communication slope coefficient",
        yaxis_title="count",
        xaxis_range = [-30,10],
        legend=dict(yanchor="top", y=0.9, xanchor="left", x=0.9),
        width = 700
    )

    fig_svi.show()
    
    return(fig_svi)

In [65]:

sample45 = get_samples(impactDF_grid45, base_m= 30, base_s= 25, slope_m= 0, slope_s= 30, n_iters=2000, n_samples=3000)

In [66]:
fig_coefs45 = plot_coefs(sample45)


pio.write_image(fig_coefs45, 'exported_figures/fig4_coefs45.png', 
               engine = "kaleido", width=600, height=600, scale=5)

posterior probability of negative slope by patch size: 
 1    0.997333
2    0.995667
4    0.920333
dtype: float64



DataFrame.applymap has been deprecated. Use DataFrame.map instead.


Series.__getitem__ treating keys as positions is deprecated. In a future version, integer keys will always be treated as labels (consistent with DataFrame behavior). To access a value by position, use `ser.iloc[pos]`


Series.__getitem__ treating keys as positions is deprecated. In a future version, integer keys will always be treated as labels (consistent with DataFrame behavior). To access a value by position, use `ser.iloc[pos]`


Series.__getitem__ treating keys as positions is deprecated. In a future version, integer keys will always be treated as labels (consistent with DataFrame behavior). To access a value by position, use `ser.iloc[pos]`



In [67]:
base_means = sample45['base'].mean(axis=0)
slope_means = sample45['slope'].mean(axis=0)

trust_range = np.linspace(0, .7, 15)

line_colors = ["#1f77b4", "#ff7f0e", "#2ca02c"]

preds = []
for i in range(3):
    local_pred = base_means[i] + slope_means[i] * trust_range
    preds.append(local_pred)
    
    fig_grid45.add_shape(type="line",
    x0=trust_range[0], y0=local_pred[0], x1=trust_range[14], y1=local_pred[14],
    line=dict(
        color=line_colors[i],
        width=3,
    )
    )
    
fig_grid45.update_shapes(dict(xref='x', yref='y'))
fig_grid45.show()


pio.write_image(fig_grid45, 'exported_figures/fig4_trust_linear.png', 
               engine = "kaleido", width=600, height=600, scale=5)
