In [1]:
import logging
import random

import chirho
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import pyro
import pyro.distributions as dist
import pyro.optim as optim
import torch
import torch.nn.functional as F
from plotly.subplots import make_subplots
from pyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO
from pyro.infer.autoguide import (AutoDiagonalNormal, AutoMultivariateNormal,
                                  AutoNormal, init_to_mean, init_to_value)
from pyro.nn import PyroModule
from scipy.stats import lognorm

In [2]:
def add_jittered_com(data, jitter=0.00):
    if "c_trust_jittered" not in data.columns:
        data["c_trust_jittered"] = data["c_trust"] + np.random.uniform(-jitter, jitter, len(data))
    return data

impactDF_grid45 = pd.read_csv("../data/communicators/communicators_impact/resultsDF_grid45.csv", index_col=0)
impactDF_grid45["reward_patch_dim"] = impactDF_grid45["reward_patch_dim"].astype("category")
impactDF_grid45 = add_jittered_com(impactDF_grid45)


In [19]:
color_scale = {
    "1": "red",
    "2": "green",
    "4": "blue",
}

fig_grid45 = px.scatter(
    impactDF_grid45,
    x="c_trust_jittered",
    y="time to first food",
    color="reward_patch_dim",
    template="presentation",
    title="Trust vs time to first food (grid 45)",
    opacity=0.5,
    labels={"c_trust_jittered": "coefficient of trust", "reward_patch_dim": "patch size"}
)


fig_grid45.update_xaxes(showgrid=False)
fig_grid45.update_yaxes(showgrid=False)
fig_grid45.update_traces(marker={"size": 7})
fig_grid45.update_layout(
    width=800,
    height=800,
)

In [4]:
# prior on standard deviation is 
# meaningfull but not too restrictive
mu = 3
sigma = 0.7
sd = np.linspace(0.1, 100, 1000)

pdf = np.exp(-(np.log(sd) - mu)**2 / (2 * sigma**2)) / (sd * sigma * np.sqrt(2 * np.pi))
fig = px.line(x=sd, y=pdf, labels={"x": "sd", "y": "PDF"}, title=f"Log-Normal Distribution (mu={mu}, sigma={sigma})")
fig.update_xaxes(showgrid=False)
fig.update_yaxes(showgrid=False)
fig.update_layout(template="presentation", width = 600)
fig.show()


In [5]:
# define model and training

def get_samples(data, base_m, base_s, slope_m=0, slope_s=30, n_iters=2000, n_samples=1000):
    trust = torch.tensor(data["c_trust"].values)
    patch = torch.tensor(data["reward_patch_dim"].astype("category").cat.codes.values).long()
    time = torch.tensor(data["time to first food"].values)

    def model(trust, time, patch):
        with pyro.plate("coefs", 3):
            base = pyro.sample("base", dist.Normal(base_m, base_s))
            slope = pyro.sample("slope", dist.Normal(slope_m, slope_s))

        sig = pyro.sample("sig", dist.LogNormal(3, 0.7))

        with pyro.plate("obs", len(time)):
            pyro.sample("time", dist.Normal(base[patch] + slope[patch] * trust, sig), obs=time)

    guide = AutoMultivariateNormal(model, init_loc_fn=init_to_mean)
    svi = SVI(model, guide, optim.Adam({"lr": 0.01}), loss=Trace_ELBO())

    iterations = []
    losses = []

    pyro.clear_param_store()
    num_iters = n_iters
    for i in range(num_iters):
        elbo = svi.step(trust, time, patch)
        iterations.append(i)
        losses.append(elbo)
        if i % 200 == 0:
            logging.info("Elbo loss: {}".format(elbo))

    fig = px.line(x=iterations, y=losses, title="ELBO loss", template="presentation")
    labels={"iterations": "iteration", "losses": "loss"}
    fig.update_xaxes(showgrid=False, title_text=labels["iterations"])
    fig.update_yaxes(showgrid=False, title_text=labels["losses"])
    
    fig.show()

    predictive = Predictive(
        model,
        guide=guide,
        num_samples=n_samples,
    )

    sample = {
        k: v.flatten().reshape(n_samples, -1).detach().cpu().numpy()
        for k, v in predictive(trust, time, patch).items()
        if k != "obs"
    }

    return sample

In [6]:
def plot_coefs(sample):

    df = pd.DataFrame(sample['slope'])
    df.columns = ["1", "2", "4"]
    
    prob_sub_0 = df.applymap(lambda x: 1 if x < 0 else 0).sum()

    print(f"posterior probability of negative slope by patch size: \n {prob_sub_0 / len(df)}")
    
    df_long = df.melt(var_name="patch size", value_name="coefficient of trust")

    fig_svi = px.histogram(df_long["coefficient of trust"],
                           color = df_long["patch size"],
                           template="presentation",
                           nbins=100,
                           title = f"Posterior coefficients of trust by patch size (grid 45)",
                           opacity=0.5,
                           labels={"value": "coefficient of trust", "color": "patch_size"},
                           width = 700, height = 700  
                        ) 
    
    fig.update_yaxes(showgrid=False)
    fig.update_xaxes(showgrid=False)
    
    fig_svi.update_layout(
        xaxis_title="coefficient of trust",
        yaxis_title="",
        xaxis_range = [-25,5],
        legend=dict(yanchor="top", y=0.9, xanchor="left", x=0.9),
        yaxis=dict(showgrid=False, showticklabels=False),
    )

    fig_svi.show()
    
    return(fig_svi)

In [7]:
sample45 = get_samples(impactDF_grid45, base_m= 30, base_s= 25, slope_m= 0, slope_s= 30, n_iters=2000, n_samples=3000)

In [8]:
fig = plot_coefs(sample45)

posterior probability of negative slope by patch size: 
 1    1.000000
2    0.998667
4    0.976333
dtype: float64


In [12]:
base_means = sample45['base'].mean(axis=0)
slope_means = sample45['slope'].mean(axis=0)

trust_range = np.linspace(0, 1, 15)


print(trust_range)
preds = []
for i in range(3):
    local_pred = base_means[i] + slope_means[i] * trust_range
    preds.append(local_pred)
    
    fig_grid45.add_shape(type="line",
    x0=trust_range[0], y0=local_pred[0], x1=trust_range[14], y1=local_pred[14],
    line=dict(
        color="MediumPurple",
        width=1,
        dash="dot",
    )
    )
    
fig_grid45.update_shapes(dict(xref='x', yref='y'))
fig_grid45.show()




[0.         0.07142857 0.14285714 0.21428571 0.28571429 0.35714286
 0.42857143 0.5        0.57142857 0.64285714 0.71428571 0.78571429
 0.85714286 0.92857143 1.        ]


In [None]:
fig_grid45 = px.scatter(
    impactDF_grid45,
    x="c_trust_jittered",
    y="time to first food",
    color="reward_patch_dim",
    template="presentation",
    title="Trust vs time to first food (grid 45)",
    opacity=0.5,
    labels={"c_trust_jittered": "coefficient of trust", "reward_patch_dim": "patch size"}
)


fig_grid45.update_xaxes(showgrid=False)
fig_grid45.update_yaxes(showgrid=False)
fig_grid45.update_traces(marker={"size": 7})
fig_grid45.update_layout(
    width=800,
    height=800,
)