In [1]:
import sys
sys.path.insert(0, "..")

import random
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import random
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import pyro
import foraging_toolkit as ft
import torch.nn.functional as F
import pyro.distributions as dist
import pyro.optim as optim
from pyro.nn import PyroModule
from pyro.infer.autoguide import AutoNormal, AutoDiagonalNormal, AutoMultivariateNormal, init_to_mean, init_to_value
from pyro.contrib.autoguide import AutoLaplaceApproximation
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from pyro.infer import Predictive
from pyro.infer import MCMC, NUTS

import os
import logging
# log bare messages (no level/logger prefix) at INFO and above
logging.basicConfig(level=logging.INFO, format='%(message)s')
# True whenever a `CI` environment variable is present
smoke_test = 'CI' in os.environ


import foraging_toolkit as ft

## Random birds

### Running a simulation

In [2]:
# Seed every RNG the notebook relies on. `random` and numpy are seeded exactly
# as before for the simulation. torch is seeded too, so the Pyro inference
# further down (SVI, Predictive, MCMC) is reproducible on Restart & Run All.
# NOTE(review): assumes ft.RandomBirds does not draw from torch; if it does,
# the saved output below will change.
SEED = 2
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# generate a simulation object with fixed simulation parameters
random_birds = ft.RandomBirds(grid_size=40,
                              probabilities=[1, 2, 3, 2, 1, 2, 3, 2, 1],
                              num_birds=3,
                              num_frames=10,
                              num_rewards=15,
                              grab_range=3)

# run a particular simulation with these parameters
random_birds()
print(random_birds.birdsDF)

      x     y bird  time
0  19.0  21.0    1     1
1  15.0  21.0    1     2
2  16.0  18.0    1     3
3  15.0  18.0    1     4
4  14.0  16.0    1     5
5  12.0  18.0    1     6
6  10.0  21.0    1     7
7  11.0  21.0    1     8
8   9.0  24.0    1     9
9   7.0  21.0    1    10
0  20.0  19.0    2     1
1  17.0  17.0    2     2
2  16.0  18.0    2     3
3  13.0  18.0    2     4
4  10.0  18.0    2     5
5  11.0  17.0    2     6
6   9.0  19.0    2     7
7   6.0  20.0    2     8
8   4.0  17.0    2     9
9   2.0  19.0    2    10
0  24.0  21.0    3     1
1  24.0  18.0    3     2
2  27.0  17.0    3     3
3  25.0  20.0    3     4
4  26.0  19.0    3     5
5  25.0  15.0    3     6
6  24.0  13.0    3     7
7  26.0  10.0    3     8
8  27.0  14.0    3     9
9  31.0  18.0    3    10


In [None]:
# The simulation has placed birds and rewards on a space-time grid.
# Each row below is one bird's x/y coordinates at one time step.
random_birds.birdsDF.head(n=5)

In [None]:
# Each row below is one reward's x/y coordinates at one time step.
random_birds.rewardsDF.head(n=5)

In [None]:
# Animate the raw simulation: bird positions, with rewards drawn too
# (no derived layers such as traces or visibility yet).
ft.animate_birds(
    random_birds,
    plot_rewards=True,
    width=600,
    height=600,
    point_size=10,
)


### Transformed data

In [None]:
# Derive the transformed data. The cells below use it to plot traces,
# visibility and proximity, and read its `transformationsDF`.
rbirds_transformed = ft.transform_birds(
    random_birds
)

In [None]:
# With the transformed data, food traces can be drawn over time.
ft.animate_birds(
    rbirds_transformed,
    plot_rewards=True,
    plot_traces=True,
    width=600,
    height=600,
    point_size=10,
)

In [None]:
# Visibility is shown for a single bird (here bird 2): overlaying every
# bird's visibility at once is hard to read.
ft.animate_birds(
    rbirds_transformed,
    plot_rewards=True,
    plot_traces=True,
    plot_visibility=2,
    width=600,
    height=600,
    point_size=10,
)


In [None]:
# Proximity score, again for bird 2.
ft.animate_birds(
    rbirds_transformed,
    plot_rewards=True,
    plot_proximity=2,
    width=600,
    height=600,
    point_size=10,
)

### Inference


In [None]:
# Prepare data: keep only the complete rows of the transformed bird data and
# pull out the columns the model below uses (three predictors + response hf).
# A descriptive frame name avoids the generic `df`, which other cells reuse
# for unrelated tables.
model_columns = ["proximity_standardized",
                 "trace_standardized",
                 "visibility",
                 "how_far_squared_scaled"]

transformed_df = rbirds_transformed.transformationsDF.dropna()
assert len(transformed_df) > 0, "no complete rows left after dropna()"

# to_numpy(dtype=...) coerces any non-float column (e.g. booleans, if present)
# so that torch.tensor always receives a plain numeric array
data = torch.tensor(transformed_df[model_columns].to_numpy(dtype=np.float32),
                    dtype=torch.float32)

prox, tr, vis, hf = data[:, 0], data[:, 1], data[:, 2], data[:, 3]

In [None]:
# Visualise the key predictors (traces, proximity) alongside hf.
ft.visualise_bird_predictors(
    tr,
    prox,
    hf,
)

In [None]:
# setting up the MCMC inference method

def summary(samples):
    """Tabulate descriptive statistics for every sampled site.

    Args:
        samples: mapping of site name -> array of draws (first axis indexes
            the draws; any remaining axis becomes one row per column).

    Returns:
        dict with the same keys, each value a DataFrame with columns
        mean, std, 5%, 25%, 50%, 75% and 95%.
    """
    wanted_stats = ["mean", "std", "5%", "25%", "50%", "75%", "95%"]
    quantile_levels = [.05, 0.25, 0.5, 0.75, 0.95]
    return {
        name: (pd.DataFrame(draws)
               .describe(percentiles=quantile_levels)
               .transpose()[wanted_stats])
        for name, draws in samples.items()
    }

def mcmc_training(model, *args, num_samples=1000, warmup_steps=400,
                  report_sites=("t", "p")):
    """Fit ``model`` with NUTS and return the posterior draws as numpy arrays.

    Args:
        model: Pyro model callable.
        *args: positional arguments forwarded to the model by ``MCMC.run``.
        num_samples: number of post-warmup draws to keep (keyword-only).
        warmup_steps: number of warmup iterations (keyword-only).
        report_sites: sample-site names whose summary statistics are printed
            (keyword-only; defaults to the coefficients of interest).

    Returns:
        dict mapping each sampled site name to a numpy array of its draws.
    """
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps)
    mcmc.run(*args)

    hmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}

    # Only summarise the sites that are printed. Iterating over the samples
    # dict keeps the print order the same as the order of sites in the samples.
    reported = {k: v for k, v in hmc_samples.items() if k in report_sites}
    for site, values in summary(reported).items():
        print("Site: {}".format(site))
        print(values, "\n")

    return hmc_samples

In [None]:
# define the model

def model_sigmavar(prox, tr, vis, hf):
    """Gaussian regression of ``hf`` whose mean AND scale depend on the predictors.

    mean  = b  + p  * prox + t  * tr + v  * vis
    sigma = softplus(bs + ps * prox + ts * tr + vs * vis)

    Every coefficient gets a Normal prior with scale 0.3. The intercepts are
    centred at 0.5 (mean) and 0.2 (scale); all other coefficients are centred
    at 0. ``hf`` is observed under a plate sized ``len(hf)``.
    """
    def coef(name, loc=0):
        # one scalar coefficient with a Normal(loc, 0.3) prior
        return pyro.sample(name, dist.Normal(loc, .3))

    # mean coefficients (site order is kept: p, t, v, b)
    p, t, v = coef("p"), coef("t"), coef("v")
    b = coef("b", .5)

    # scale coefficients (site order is kept: ps, ts, vs, bs)
    ps, ts, vs = coef("ps"), coef("ts"), coef("vs")
    bs = coef("bs", .2)

    # softplus keeps the observation scale strictly positive
    sigma = pyro.deterministic("sigma", F.softplus(bs + ps * prox + ts * tr + vs * vis))
    mean = b + p * prox + t * tr + v * vis

    with pyro.plate("data", len(hf)):
        pyro.sample("obs", dist.Normal(mean, sigma), obs=hf)

pyro.render_model(model_sigmavar,
                  model_args=(prox, tr, vis, hf),
                  render_distributions=True)

In [None]:
# Inference with SVI
# note how long this takes
# and compare with MCMC

# start from an empty Pyro parameter store before building the guide
pyro.clear_param_store()

guide = AutoMultivariateNormal(model_sigmavar, init_loc_fn=init_to_mean)
svi = SVI(model_sigmavar,
          guide,
          optim.Adam({"lr": .01}),
          loss=Trace_ELBO())

iterations = []
losses = []

# a handful of steps is enough to exercise the code in CI smoke tests
num_iters = 10 if smoke_test else 1000
for i in range(num_iters):
    # svi.step takes one optimisation step and returns that step's loss
    elbo = svi.step(prox, tr, vis, hf)
    iterations.append(i)
    losses.append(elbo)
    if i % 200 == 0:
        logging.info("Elbo loss: {}".format(elbo))


In [None]:
# Plot the SVI loss trace. A dedicated name keeps this small table from
# overwriting other notebook-level variables such as a generic `df`.
elbo_df = pd.DataFrame({'iterations': iterations, 'ELBO Loss': losses})

fig = px.line(elbo_df, x='iterations', y='ELBO Loss')

fig.update_layout(
    title="ELBO Loss over iterations",
    xaxis_title="iterations",
    yaxis_title="ELBO Loss",
    template="plotly_dark",
    showlegend=False,
    xaxis_showgrid=False,
    yaxis_showgrid=False,
)

fig.show()

In [None]:
# Draw from the fitted SVI guide and summarise the coefficients of key
# interest, t and p.

num_samples = 1000
predictive = Predictive(
    model_sigmavar,
    guide=guide,
    num_samples=num_samples,
    return_sites=["t", "p"],
)

# one row per posterior draw, converted to numpy for summary()
random_sigmavar = {}
for name, draws in predictive(prox, tr, vis, hf).items():
    if name == "obs":
        continue
    random_sigmavar[name] = draws.flatten().reshape(num_samples, -1).detach().cpu().numpy()

for site, values in summary(random_sigmavar).items():
    print("Site: {}".format(site))
    print(values, "\n")

In [None]:
# As a sanity check, compare the SVI results with those from the MCMC method.
# The plotting cell below reads `random_sigmavar_mcmc`, so this cell must run;
# leaving the call commented out made Restart & Run All fail with NameError.

# warning: MCMC training might take a bit of time to run
random_sigmavar_mcmc = mcmc_training(model_sigmavar, prox, tr, vis, hf)

In [None]:
# Plot the marginal posterior distributions for the coefficients of interest.
# MCMC draws are overlaid only when the (slow) MCMC cell has been run;
# otherwise only the SVI posterior is shown instead of raising NameError.

sites = ["p", "t"]
mcmc_draws = globals().get("random_sigmavar_mcmc")

# scope the dark style to this figure instead of changing it for every later cell
with plt.style.context("dark_background"):
    fig, axs = plt.subplots(nrows=1, ncols=len(sites), figsize=(12, 10), squeeze=False)
    fig.suptitle("Random birds (posterior densities)", fontsize=16)
    for ax, site in zip(axs.ravel(), sites):
        # SVI draws are stored as (num_samples, 1). Flatten them so seaborn
        # plots one labelled series, matching the 1-D MCMC draws, rather than
        # treating the array as wide-form data.
        sns.histplot(np.ravel(random_sigmavar[site]), ax=ax, label="svi")
        if mcmc_draws is not None:
            sns.histplot(np.ravel(mcmc_draws[site]), ax=ax, label="mcmc")
        ax.set_title(site)
    # every panel carries the same labels, so take the legend from the last one
    handles, labels = axs.ravel()[-1].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper right')
    plt.show()