In [1]:
import sys

sys.path.insert(0, "..")

import random
import numpy as np
from sklearn.preprocessing import LabelEncoder

import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import random
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import torch
import pyro
import foraging_toolkit as ft


import torch.nn.functional as F
import pyro.distributions as dist
import pyro.optim as optim
from pyro.nn import PyroModule
from pyro.infer.autoguide import (
    AutoNormal,
    AutoDiagonalNormal,
    AutoMultivariateNormal,
    init_to_mean,
    init_to_value,
)
from pyro.contrib.autoguide import AutoLaplaceApproximation
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from pyro.infer import Predictive
from pyro.infer import MCMC, NUTS

import os
import logging

# Log bare messages (no level/logger prefix) at INFO level and above.
logging.basicConfig(level=logging.INFO, format="%(message)s")
# True whenever a CI environment variable is present, whatever its value.
smoke_test = os.environ.get("CI") is not None

In [2]:

# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only load this file if it comes from a trusted source.
with open('locust_data/loc_derived.pkl', 'rb') as file:
    loc_derived = pickle.load(file)

# Unpack the predictors and the outcome so the cells below can refer to them by
# name; without this they raise NameError on a fresh kernel.
# Assumes loc_derived is a 5-item sequence in exactly this order -- TODO confirm
# against the code that wrote the pickle.
proximity, trace, visibility, communicate, how_far_score = loc_derived

FileNotFoundError: [Errno 2] No such file or directory: 'locust_data/loc_derived.pkl'

In [None]:
# Hand the derived predictors and the outcome to the toolkit's visualisation
# helper (presumably produces the predictor plots -- see foraging_toolkit).
ft.visualise_bird_predictors(
    trace,
    proximity,
    how_far_score,
    com=communicate,
)

In [None]:
def model_sigmavar_com(proximity, trace, visibility, communicate, how_far_score):
    """Gaussian regression whose mean and noise scale both depend on the predictors.

    Sample sites are declared in this order: ``p``, ``t``, ``v``, ``c``, ``b``
    (coefficients and intercept of the mean), then ``ps``, ``ts``, ``vs``,
    ``cs``, ``bs`` (coefficients and intercept of the raw scale). Every site has
    a Normal prior with scale 0.3; ``b`` is centred at 0.5, ``bs`` at 0.2 and
    all other sites at 0.

    Args:
        proximity, trace, visibility, communicate: predictors, each multiplied
            by its own coefficient in both the mean and the raw scale.
        how_far_score: observed outcome; its length sets the size of the
            ``data`` plate and it is passed as ``obs`` to the likelihood.
    """
    def coefficient(site, loc):
        # Every latent shares the same prior scale; only the centre differs.
        return pyro.sample(site, dist.Normal(loc, .3))

    # Declaration order of the sample sites is kept exactly as listed above.
    p, t, v, c = [coefficient(site, 0) for site in ("p", "t", "v", "c")]
    b = coefficient("b", .5)

    ps, ts, vs, cs = [coefficient(site, 0) for site in ("ps", "ts", "vs", "cs")]
    bs = coefficient("bs", .2)

    mean = b + p * proximity + t * trace + v * visibility + c * communicate

    # softplus maps the unconstrained linear predictor to a positive scale.
    sigma_raw = bs + ps * proximity + ts * trace + vs * visibility + cs * communicate
    sigma = pyro.deterministic("sigma", F.softplus(sigma_raw))

    with pyro.plate("data", len(how_far_score)):
        pyro.sample("obs", dist.Normal(mean, sigma), obs=how_far_score)


# Draw the model graph, annotating each site with its distribution.
pyro.render_model(
    model_sigmavar_com,
    model_args=(proximity, trace, visibility, communicate, how_far_score),
    render_distributions=True,
)

In [None]:

# Seed the random number generators so the stochastic fit is reproducible
# under Restart & Run All.
pyro.set_rng_seed(0)
# Start from an empty parameter store *before* building the guide so nothing
# learned in an earlier run of this cell can leak into this fit.
pyro.clear_param_store()

guide = AutoMultivariateNormal(model_sigmavar_com, init_loc_fn=init_to_mean)
svi = SVI(model_sigmavar_com,
        guide,
        optim.Adam({"lr": .01}),
        loss=Trace_ELBO())

iterations = []
losses = []

# Under CI (smoke_test is set in the first cell) take only a couple of steps
# to check that the cell executes; otherwise run the full optimisation.
num_iters = 1000 if not smoke_test else 2
for i in range(num_iters):
    # svi.step performs one gradient step and returns the loss estimate.
    elbo = svi.step(proximity, trace, visibility, communicate, how_far_score)
    iterations.append(i)
    losses.append(elbo)
    if i % 200 == 0:
        logging.info("Elbo loss: {}".format(elbo))


In [None]:
num_samples = 1000
predictive = Predictive(
    model_sigmavar_com,
    guide=guide,
    num_samples=num_samples,
    return_sites=["t", "p", "c"],
)
posterior_draws = predictive(proximity, trace, visibility, communicate, how_far_score)

# Convert each returned site to a (num_samples, -1) NumPy array, skipping "obs".
communicate_sigmavar = {}
for site_name, draws in posterior_draws.items():
    if site_name == "obs":
        continue
    communicate_sigmavar[site_name] = (
        draws.flatten().reshape(num_samples, -1).detach().cpu().numpy()
    )

# Print the toolkit's summary for each coefficient of interest.
for site, values in ft.summary(communicate_sigmavar, ["t", "p", "c"]).items():
    print("Site: {}".format(site))
    print(values, "\n")

In [None]:
def calculate_R_squared_com(guide, num_samples=1000):
    """Compute R^2 of the posterior-mean prediction against ``how_far_score``.

    Draws ``num_samples`` values of the mean coefficients from ``guide``,
    evaluates the model's linear mean for every draw, averages over draws, and
    compares the result with the observed outcome.

    Reads the notebook-level ``proximity``, ``trace``, ``visibility``,
    ``communicate`` and ``how_far_score``.

    Args:
        guide: fitted guide for ``model_sigmavar_com``.
        num_samples: number of posterior draws to average over.

    Returns:
        float: ``1 - RSS / TSS``.
    """
    # Only the mean coefficients are needed, so skip collecting "obs"/"sigma".
    predictive = pyro.infer.Predictive(
        model_sigmavar_com,
        guide=guide,
        num_samples=num_samples,
        return_sites=["b", "p", "t", "v", "c"],
    )
    predictions = predictive(proximity, trace, visibility, communicate, how_far_score)

    simulated_outcome = ( predictions['b'] + predictions['p'] * proximity +
                      predictions['t'] * trace + predictions['v'] * visibility +
                      predictions['c'] * communicate)

    # Average over the draw dimension. Keep the result as a tensor: the residual
    # below is tensor arithmetic, and a NumPy round-trip would rely on implicit
    # NumPy/torch interop and would break for tensors that are not on the CPU.
    mean_sim_outcome = simulated_outcome.mean(0).detach()

    observed_mean = torch.mean(how_far_score)

    tss = torch.sum((how_far_score - observed_mean) ** 2)
    rss = torch.sum((how_far_score - mean_sim_outcome) ** 2)

    r_squared = 1 - (rss / tss)

    return r_squared.float().item()

print(calculate_R_squared_com(guide))

In [None]:

site_names = ['p', 't', 'c']
site_titles = {'p': "Proximity", 't': "Trace", 'c': "Communicate"}


def plot_posterior_histogram(samples_df, site, title):
    """Build a fixed-size, semi-transparent histogram of one coefficient's draws.

    Args:
        samples_df: DataFrame with one column of posterior draws per site.
        site: name of the column to plot.
        title: figure title.

    Returns:
        The Plotly figure (not yet shown).
    """
    fig = px.histogram(samples_df, x=site, title=title, template='plotly_dark')
    fig.update_layout(autosize=False, width=800, height=400)
    fig.update_traces(opacity=0.4)
    return fig


# One flat list of SVI draws per site, in the order given by site_names.
site_data = {site: communicate_sigmavar[site].flatten().tolist()
             for site in site_names}
svi_p, svi_t, svi_c = (site_data[site] for site in site_names)

# 'Method' labels the inference method of each draw; only SVI draws exist here.
# If MCMC draws are added later, append them to each column and pass
# color='Method' to px.histogram to compare the two.
site_data['Method'] = ['SVI'] * len(svi_p)

df = pd.DataFrame(site_data)

fig1, fig2, fig3 = (
    plot_posterior_histogram(df, site, site_titles[site]) for site in site_names
)

for fig in (fig1, fig2, fig3):
    fig.show()

# Note: we only had two birds and one food patch; results will likely improve with better data.

