In [1]:
import pandas as pd
import numpy as np
import plotly.express as px
from plotly.subplots import make_subplots

import torch
import pyro
import pyro.distributions as dist

from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.infer.autoguide import AutoMultivariateNormal

from jax import random
import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.diagnostics import print_summary
from numpyro.infer import Predictive, SVI, Trace_ELBO, MCMC, NUTS
from numpyro.infer.autoguide import AutoLaplaceApproximation

from pyro.optim import Adam


In [4]:
threshold = 0.25
impactDFgrid30 = pd.read_csv(
    "../data/communicators/communicators_impact/resultsDFgrid30.csv", index_col=0
)
impactDFgrid30_2 = pd.read_csv(
    "../data/communicators/communicators_impact/resultsDFgrid30_2.csv", index_col=0
)

impactDFgrid30 = pd.concat([impactDFgrid30, impactDFgrid30_2], axis=0)


impactDFgrid30.keys()
impactDFgrid30 = impactDFgrid30[impactDFgrid30["reward_patch_dim"].isin([1, 2, 4])]
impactDFgrid30["reward_patch_dim"] = impactDFgrid30["reward_patch_dim"].astype(
    "category"
)

restriction30_low = (impactDFgrid30["c_trust"] >= 0.0) & (
    impactDFgrid30["c_trust"] <= threshold
)
restriction30_high = impactDFgrid30["c_trust"] > threshold

impactDFgrid30_low = impactDFgrid30[restriction30_low].copy()
impactDFgrid30_high = impactDFgrid30[restriction30_high].copy()

display(impactDFgrid30_low.tail())
display(impactDFgrid30_high.head())


Unnamed: 0,c_trust,sight_radius,reward_patch_dim,sim index,run index,time to first food,num birds failed
303,0.25,5,1,150,1,2.222222,0
304,0.25,5,2,151,0,3.777778,0
305,0.25,5,2,151,1,7.777778,0
306,0.25,5,4,152,0,5.666667,0
307,0.25,5,4,152,1,6.111111,0


Unnamed: 0,c_trust,sight_radius,reward_patch_dim,sim index,run index,time to first food,num birds failed
266,0.26,5,1,130,0,2.333333,0
267,0.26,5,1,130,1,1.666667,0
268,0.26,5,2,131,0,5.666667,0
269,0.26,5,2,131,1,5.777778,0
272,0.26,5,4,133,0,3.111111,0


In [5]:
# this is not going to be a useful metric, if the number of frames is sufficient
# for birds to have enough time to succeed
px.histogram(
    impactDFgrid30,
    x="num birds failed",
    nbins=20,
    title="Overall number of birds failed",
    template="plotly_dark",
)

In [6]:
# this migth be more useful
px.histogram(
    impactDFgrid30,
    x="time to first food",
    nbins=30,
    title="Overall times to first food",
    template="plotly_dark",
)

In [7]:
# the success metrics are rather correlated anyway
impactDFgrid30.iloc[:, -2:].corr(method="spearman")

Unnamed: 0,time to first food,num birds failed
time to first food,1.0,0.463305
num birds failed,0.463305,1.0


In [8]:
px.scatter(
    impactDFgrid30,
    x="num birds failed",
    y="time to first food",
    title="Number of birds failed vs time to first food",
    template="plotly_dark",
)

In [12]:
jitter = 0.004
impactDFgrid30_low["c_trust_jittered"] = impactDFgrid30_low["c_trust"] + np.random.uniform(
    -jitter, jitter, len(impactDFgrid30_low)
)
impactDFgrid30_high["c_trust_jittered"] = impactDFgrid30_high["c_trust"] + np.random.uniform(
    -jitter, jitter, len(impactDFgrid30_high)
)

fig_grid30_low = px.scatter(
    impactDFgrid30_low,
    x="c_trust_jittered",
    y="time to first food",
    color="reward_patch_dim",
    template="plotly_dark",
    title="Trust vs time to first food (low trust)",
    trendline="lowess",
    opacity=0.5,
)

fig_grid30_low.update_traces({"marker": {"size": 5, "opacity": 0.5}})


fig_grid30_high = px.scatter(
    impactDFgrid30_high,
    x="c_trust_jittered",
    y="time to first food",
    color="reward_patch_dim",
    template="plotly_dark",
    title="Trust vs time to first food (high trust)",
    trendline="lowess",
    opacity=0.5,
)

fig_grid30_high.update_traces({"marker": {"size": 5, "opacity": 0.5}})


fig_grid30_low.show()

fig_grid30_high.show()

In [29]:
def model(patch_id, c_trust, time):
    baseline = numpyro.sample("baseline", dist.Normal(10, 5).expand([len(set(patch_id))]))
    tr = numpyro.sample("tr", dist.Normal(0, 10).expand([len(set(patch_id))]))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = baseline[patch_id] + tr[patch_id] * c_trust
    numpyro.sample("time", dist.Normal(mu, sigma), obs=time)


def get_samples(data):
    data["patch_id"] = data.reward_patch_dim.astype("category").cat.codes
    data["time"] = data["time to first food"]

    guide = AutoLaplaceApproximation(model)
    svi = SVI(
        model,
        guide,
        optim.Adam(1),
        Trace_ELBO(),
        patch_id=data.patch_id.values,
        c_trust=data.c_trust.values,
        time=data.time.values,
    )
    svi_result = svi.run(random.PRNGKey(0), 2000)
    params = svi_result.params
    post_svi = guide.sample_posterior(random.PRNGKey(1), params, (1000,))

    nuts_kernel = NUTS(model)

    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=2000)

    rng_key = random.PRNGKey(0)
    mcmc.run(
        rng_key,
        patch_id=data.patch_id.values,
        c_trust=data.c_trust.values,
        time=data.time.values,
    )

    post_mcmc = mcmc.get_samples()

    return {"post_svi": post_svi, "post_mcmc": post_mcmc}

In [30]:
post30high = get_samples(impactDFgrid30_high)

TypeError: numpyro.handlers.trace.get_trace() got multiple values for keyword argument 'patch_id'

In [25]:
post30low = get_samples(impactDFgrid30_low)

100%|██████████| 1000/1000 [00:00<00:00, 2092.98it/s, init loss: 374753.6250, avg. loss [951-1000]: 1516.7517]
sample: 100%|██████████| 2500/2500 [00:03<00:00, 655.77it/s, 3 steps of size 3.39e-01. acc. prob=0.92]  


In [26]:
def plot_coefs(post, group):
    df_mcmc = pd.DataFrame(
        post["post_mcmc"]["tr"], columns=["patch 1", "patch 2", "patch 4"]
    )
    df_svi = pd.DataFrame(
        post["post_svi"]["tr"], columns=["patch 1", "patch 2", "patch 4"]
    )

    fig_mcmc = px.histogram(
        df_mcmc,
        x=["patch 1", "patch 2", "patch 4"],
        nbins=100,
        template="plotly_dark",
        title=f"Posterior coefficients of trust (MCMC, {group})",
    )

    fig_svi = px.histogram(
        df_svi,
        x=["patch 1", "patch 2", "patch 4"],
        nbins=100,
        template="plotly_dark",
        title=f"Posterior coefficients of trust (SVI, {group})",
    )

    fig_mcmc.show()
    fig_svi.show()


In [27]:
plot_coefs(post30high, ">0.3")

In [28]:
plot_coefs(post30low, "<0.3")


In [None]:
import pyro
import pyro.distributions as dist
import torch


def pyro_model(data):
    num_patches = 3
    num_samples = len(data)

    # Prior for patch categories
    patch_probs = torch.tensor([0.3, 0.4, 0.3])  # Adjust these probabilities as needed

    # Sample the patch category for each data point
    patch = pyro.sample("patch", dist.Categorical(patch_probs), obs=data["patch"].to(torch.int64))

    # Parameters for other variables
    alpha = pyro.param("alpha", torch.tensor(1.0))
    beta = pyro.param("beta", torch.tensor(1.0))

    # Model for other variables based on patch category
    with pyro.plate("data", num_samples):
        trust = pyro.sample("trust", dist.Beta(alpha, beta))
        food_latency = pyro.sample("food_latency", dist.Normal(0, 1))

    return trust, food_latency, patch


# Sample data (replace this with your actual data)
data = {
    "patch": torch.tensor([0, 1, 2, 0, 1, 2]),  # Assuming "patch" is a categorical variable with 3 categories
}

# Run the model
pyro.clear_param_store()
trust, food_latency, patch = pyro_model(data)

# Print the sampled values
print("Sampled Trust:", trust)
print("Sampled Food Latency:", food_latency)
print("Sampled Patch:", patch)

In [29]:
def model(trust, patch, time):
    w_trust = pyro.sample("weight_trust", dist.Normal(0.0, 1.0))
    w_patch = pyro.sample("weight_patch", dist.Normal(0.0, 1.0))
    bias = pyro.sample("bias", dist.Normal(5.0, 4.0))

    sd = pyro.sample("sd", dist.Exponential(0.5))
    mu = w_trust * trust + w_patch * patch + bias

    with pyro.plate("data", len(time)):
        pyro.sample("obs", dist.Normal(mu, sd), obs=time)

In [None]:
def svi_training(impactDF):
    trust = torch.tensor(impactDF["c_trust"].values.astype(float))
    patch = torch.tensor(impactDF["reward_patch_dim"].values.astype(float))
    time = torch.tensor(impactDF["time to first food"].values.astype(float))

    guide = AutoMultivariateNormal(model)

    pyro.clear_param_store()
    svi = SVI(model, guide, Adam({"lr": 0.03}), loss=Trace_ELBO())

    num_iterations = 2000
    for i in range(num_iterations):
        loss = svi.step(trust, patch, time)
        if (i + 1) % 100 == 0:
            print(f"Iteration {i + 1}/{num_iterations}, Loss: {loss:.2f}")

    predictive = Predictive(model, guide=guide, num_samples=1000)
    samples = predictive(trust, patch, time)

    return samples

In [30]:
samples_grid30 = svi_training(impactDFgrid30)


Iteration 100/2000, Loss: 722.13
Iteration 200/2000, Loss: 697.92
Iteration 300/2000, Loss: 686.32
Iteration 400/2000, Loss: 689.47
Iteration 500/2000, Loss: 686.46
Iteration 600/2000, Loss: 672.40
Iteration 700/2000, Loss: 671.55
Iteration 800/2000, Loss: 674.95
Iteration 900/2000, Loss: 675.40
Iteration 1000/2000, Loss: 672.93
Iteration 1100/2000, Loss: 671.95
Iteration 1200/2000, Loss: 671.14
Iteration 1300/2000, Loss: 671.75
Iteration 1400/2000, Loss: 670.95
Iteration 1500/2000, Loss: 670.80
Iteration 1600/2000, Loss: 671.22
Iteration 1700/2000, Loss: 671.42
Iteration 1800/2000, Loss: 670.82
Iteration 1900/2000, Loss: 669.81
Iteration 2000/2000, Loss: 671.01


In [31]:
fig_trust_grid30 = px.histogram(
    samples_grid30["weight_trust"],
    title="Posterior marginal distribution of trust weights (grid 30)",
    template="plotly_dark",
)

mean_trust = samples["weight_trust"].mean()


fig_trust_grid30.update_traces(showlegend=False)
fig_trust_grid30.add_vline(
    x=mean_trust,
    line_dash="dash",
    line_color="red",
    annotation_text=f"Mean = {mean_trust:.2f}",
    annotation_position="top right",
    annotation_xshift=20,
)

In [25]:
fig2 = px.histogram(
    samples["weight_patch"],
    title="Posterior marginal distribution of patch weights",
    template="plotly_dark",
)

mean_patch = samples["weight_patch"].mean()


fig2.update_traces(showlegend=False)
fig2.add_vline(
    x=mean_patch,
    line_dash="dash",
    line_color="red",
    annotation_text=f"Mean = {mean_patch:.2f}",
    annotation_position="top right",
    annotation_xshift=20,
)