In [1]:
import logging
import os
import time

import dill
import numpy as np
import pandas as pd
import plotly.io as pio
import pyro
import pyro.distributions as dist
import pyro.optim as optim
import torch
import torch.nn.functional as F
from plotly import express as px
from pyro.infer import SVI, Predictive, Trace_ELBO
from pyro.infer.autoguide import AutoMultivariateNormal, init_to_mean

from collab.foraging import toolkit as ft
from collab.utils import find_repo_root

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
root = find_repo_root()

smoke_test = "CI" in os.environ
num_svi_iters = 50 if smoke_test else 1000
num_samples = 50 if smoke_test else 1000
keys = [50] if smoke_test else [10, 20, 30, 40, 50, 60, 70, 80]
sampling_rate = 0.01 if smoke_test else 0.01

notebook_starts = time.time()

In [2]:
# this file is generated using `centrap_park_birds_predictors.ipynb`
path = os.path.join(
    root,
    f"data/foraging/central_park_birds_cleaned_2022/central_park_objects_sampling_rate_{sampling_rate}.pkl",
)

assert os.path.exists(path), "Please run `central_park_birds_predictors.ipynb` to prep the data first."

with open(path, "rb") as file:
    central_park_objects = dill.load(file)

In [3]:
def cp_prep_data_for_iference(obj):
    df = obj.how_farDF.copy()
    print("Initial dataset size:", len(df))
    df.dropna(inplace=True)
    print("After dropping NAs:", len(df))

    columns_to_normalize = [
        "distance",
        "proximity_standardized",
    ]

    for column in columns_to_normalize:
        df[column] = ft.normalize(df[column])

    return (
        torch.tensor(df["distance"].values),
        torch.tensor(df["proximity_standardized"].values),
        torch.tensor(df["how_far_squared_scaled"].values),
    )

In [4]:
def model_sigmavar_proximity(distance, proximity, how_far):
    d = pyro.sample("d", dist.Normal(0, 0.6))
    p = pyro.sample("p", dist.Normal(0, 0.6))
    b = pyro.sample("b", dist.Normal(0.5, 0.6))

    ds = pyro.sample("ds", dist.Normal(0, 0.6))
    ps = pyro.sample("ps", dist.Normal(0, 0.6))
    bs = pyro.sample("bs", dist.Normal(0.2, 0.6))

    sigmaRaw = bs + ds * distance + ps * proximity
    sigma = pyro.deterministic("sigma", F.softplus(sigmaRaw))
    mean = b + d * distance + p * proximity

    with pyro.plate("data", len(how_far)):
        pyro.sample("obs", dist.Normal(mean, sigma), obs=how_far)

In [5]:
def get_samples(
    distance,
    proximity,
    how_far,
    model=model_sigmavar_proximity,
    num_svi_iters=num_svi_iters,
    num_samples=num_samples,
):
    guide = AutoMultivariateNormal(model, init_loc_fn=init_to_mean)
    svi = SVI(
        model_sigmavar_proximity, guide, optim.Adam({"lr": 0.01}), loss=Trace_ELBO()
    )

    iterations = []
    losses = []

    logging.info(f"Starting SVI inference with {num_svi_iters} iterations.")
    start_time = time.time()
    pyro.clear_param_store()
    for i in range(num_svi_iters):
        elbo = svi.step(distance, proximity, how_far)
        iterations.append(i)
        losses.append(elbo)
        if i % 50 == 0:
            logging.info("Elbo loss: {}".format(elbo))
    end_time = time.time()
    elapsed_time = end_time - start_time
    logging.info("SVI inference completed in %.2f seconds.", elapsed_time)

    fig = px.line(x=iterations, y=losses, title="ELBO loss", template="presentation")
    labels = {"iterations": "iteration", "losses": "loss"}
    fig.update_xaxes(showgrid=False, title_text=labels["iterations"])
    fig.update_yaxes(showgrid=False, title_text=labels["losses"])
    fig.update_layout(width=700)
    fig.show()

    predictive = Predictive(model, guide=guide, num_samples=num_samples)

    proximity_svi = {
        k: v.flatten().reshape(num_samples, -1).detach().cpu().numpy()
        for k, v in predictive(distance, proximity, how_far).items()
        if k != "obs"
    }

    print("SVI-based coefficient marginals:")
    for site, values in ft.summary(proximity_svi, ["d", "p"]).items():
        print("Site: {}".format(site))
        print(values, "\n")

    return {
        "svi_samples": proximity_svi,
        "svi_guide": guide,
        "svi_predictive": predictive,
    }

In [6]:
path = os.path.join(
    root, "data/foraging/central_park_birds_cleaned_2022/duck_outcomes.pkl"
)


if not smoke_test and not os.path.exists(path):
    ducks_objects = central_park_objects[0]
    # for ducks starting that low might not make sense
    # [19, 46, 85]
    duck_outcomes = {}

    for key in keys:
        obj = ducks_objects[key]
        print(f"Working on ducks with optimal={key}")
        distance, proximity, how_far = cp_prep_data_for_iference(obj)
        # ft.visualise_forager_predictors(
        #     distance,
        #     proximity,
        #     how_far,
        #     vis_sampling_rate=0.05,
        #     titles=[f"Distance (ducks)", f"Proximity (ducks, optimal={key})"],
        #     x_axis_labels=["distance", "proximity"],
        # )
        duck_outcomes[key] = get_samples(distance, proximity, how_far)

        with open(path, "wb") as file:
            dill.dump(duck_outcomes, file)

2024-03-01 08:44:38,921 - Starting SVI inference with 1000 iterations.


Working on ducks with optimal=10
Initial dataset size: 101213
After dropping NAs: 99637


2024-03-01 08:44:39,266 - Elbo loss: 81646.38998601849
2024-03-01 08:44:39,878 - Elbo loss: -11900.564172505417
2024-03-01 08:44:40,565 - Elbo loss: -56298.52634936318
2024-03-01 08:44:41,196 - Elbo loss: -46784.44613538318
2024-03-01 08:44:41,818 - Elbo loss: -89888.94227598196
2024-03-01 08:44:42,410 - Elbo loss: -92607.54228040657
2024-03-01 08:44:43,054 - Elbo loss: -100541.8506750777
2024-03-01 08:44:43,669 - Elbo loss: -96934.5709931607
2024-03-01 08:44:44,297 - Elbo loss: -100814.81208239541
2024-03-01 08:44:44,885 - Elbo loss: -99686.88126711154
2024-03-01 08:44:45,611 - Elbo loss: -102359.84684890427
2024-03-01 08:44:46,458 - Elbo loss: -103520.35331777114
2024-03-01 08:44:47,183 - Elbo loss: -105958.5233176655
2024-03-01 08:44:48,163 - Elbo loss: -96901.64077530555
2024-03-01 08:44:49,157 - Elbo loss: -103500.35752704318
2024-03-01 08:44:49,976 - Elbo loss: -105306.62811953961
2024-03-01 08:44:50,831 - Elbo loss: -104398.85793432241
2024-03-01 08:44:51,666 - Elbo loss: -10647

SVI-based coefficient marginals:
Site: d
       mean       std        5%       25%       50%      75%       95%
0 -0.355279  0.022706 -0.392536 -0.370816 -0.355624 -0.33943 -0.317951 

Site: p
       mean       std        5%       25%      50%       75%     95%
0  0.032601  0.027941 -0.012548  0.014374  0.03271  0.052263  0.0777 



2024-03-01 08:45:01,158 - Starting SVI inference with 1000 iterations.
2024-03-01 08:45:01,188 - Elbo loss: 77655.25156474454


Working on ducks with optimal=20
Initial dataset size: 101213
After dropping NAs: 99637


2024-03-01 08:45:01,918 - Elbo loss: 5176.464522594793
2024-03-01 08:45:02,851 - Elbo loss: -54726.643512194365
2024-03-01 08:45:03,485 - Elbo loss: -47054.54738616293
2024-03-01 08:45:04,088 - Elbo loss: -100988.24488905289
2024-03-01 08:45:04,717 - Elbo loss: -101458.04773419056
2024-03-01 08:45:05,641 - Elbo loss: -101988.29416202288
2024-03-01 08:45:06,466 - Elbo loss: -103428.3701273314
2024-03-01 08:45:07,148 - Elbo loss: -94508.582053966
2024-03-01 08:45:07,862 - Elbo loss: -106588.45339529426
2024-03-01 08:45:08,537 - Elbo loss: -102909.58798841079
2024-03-01 08:45:09,294 - Elbo loss: -103247.71877375501
2024-03-01 08:45:10,058 - Elbo loss: -107096.3582353378
2024-03-01 08:45:10,752 - Elbo loss: -97702.36744773196
2024-03-01 08:45:11,442 - Elbo loss: -105460.76832444486
2024-03-01 08:45:12,126 - Elbo loss: -106189.6167549754
2024-03-01 08:45:12,813 - Elbo loss: -107231.63777493377
2024-03-01 08:45:13,505 - Elbo loss: -107984.0011321905
2024-03-01 08:45:14,212 - Elbo loss: -1078

SVI-based coefficient marginals:
Site: d
      mean       std      5%      25%       50%       75%       95%
0 -0.33419  0.025539 -0.3748 -0.35092 -0.334262 -0.317377 -0.291556 

Site: p
       mean      std        5%       25%       50%      75%       95%
0  0.113668  0.02956  0.066266  0.093405  0.114334  0.13279  0.161836 



2024-03-01 08:45:24,283 - Starting SVI inference with 1000 iterations.
2024-03-01 08:45:24,338 - Elbo loss: 88637.23403533624


Working on ducks with optimal=30
Initial dataset size: 101213
After dropping NAs: 99637


2024-03-01 08:45:25,287 - Elbo loss: 820.4475702842486
2024-03-01 08:45:26,065 - Elbo loss: -80509.6285680953
2024-03-01 08:45:26,837 - Elbo loss: -100161.0847592077
2024-03-01 08:45:27,622 - Elbo loss: -93762.99791284648
2024-03-01 08:45:28,347 - Elbo loss: -98091.58898346232
2024-03-01 08:45:29,063 - Elbo loss: -108783.79785406691
2024-03-01 08:45:29,820 - Elbo loss: -104852.10221984607
2024-03-01 08:45:30,608 - Elbo loss: -103301.21656089206
2024-03-01 08:45:31,454 - Elbo loss: -104531.79497087542
2024-03-01 08:45:32,285 - Elbo loss: -109234.06229761332
2024-03-01 08:45:33,015 - Elbo loss: -109933.90451399298
2024-03-01 08:45:33,736 - Elbo loss: -106627.3974800553
2024-03-01 08:45:34,426 - Elbo loss: -110580.80568490237
2024-03-01 08:45:35,102 - Elbo loss: -106906.16408265868
2024-03-01 08:45:35,769 - Elbo loss: -109061.23301516636
2024-03-01 08:45:36,430 - Elbo loss: -109264.53751965312
2024-03-01 08:45:37,115 - Elbo loss: -108634.30784208255
2024-03-01 08:45:37,746 - Elbo loss: -1

SVI-based coefficient marginals:
Site: d
       mean       std        5%       25%       50%       75%       95%
0 -0.315092  0.024565 -0.354822 -0.331659 -0.315322 -0.298442 -0.275763 

Site: p
       mean       std        5%       25%       50%       75%       95%
0  0.160659  0.028743  0.112873  0.141489  0.160352  0.181104  0.207424 

Working on ducks with optimal=40
Initial dataset size: 101213
After dropping NAs: 99637


2024-03-01 08:45:54,297 - Starting SVI inference with 1000 iterations.
2024-03-01 08:45:54,778 - Elbo loss: 105001.32972486239
2024-03-01 08:45:55,617 - Elbo loss: -17827.676808991324
2024-03-01 08:45:56,402 - Elbo loss: -67221.74321656415
2024-03-01 08:45:57,045 - Elbo loss: -92669.24519698197
2024-03-01 08:45:57,661 - Elbo loss: -80304.9648342426
2024-03-01 08:45:58,254 - Elbo loss: -107506.49523139425
2024-03-01 08:45:58,882 - Elbo loss: -109361.64654670251
2024-03-01 08:45:59,552 - Elbo loss: -100204.26940651477
2024-03-01 08:46:00,135 - Elbo loss: -106601.2718291183
2024-03-01 08:46:00,704 - Elbo loss: -95007.26648165111
2024-03-01 08:46:01,300 - Elbo loss: -106910.92348970272
2024-03-01 08:46:01,895 - Elbo loss: -108433.26689043132
2024-03-01 08:46:02,483 - Elbo loss: -103504.13256675455
2024-03-01 08:46:03,171 - Elbo loss: -106375.80039309245
2024-03-01 08:46:03,770 - Elbo loss: -109930.64303030018
2024-03-01 08:46:04,358 - Elbo loss: -97375.12216441998
2024-03-01 08:46:04,961 -

SVI-based coefficient marginals:
Site: d
       mean       std        5%       25%       50%       75%       95%
0 -0.332434  0.024912 -0.372798 -0.349289 -0.332533 -0.315552 -0.290783 

Site: p
       mean       std        5%       25%       50%       75%       95%
0  0.134983  0.022149  0.098445  0.120246  0.135254  0.149452  0.172003 

Working on ducks with optimal=50
Initial dataset size: 101213
After dropping NAs: 99637


2024-03-01 08:47:08,150 - Starting SVI inference with 1000 iterations.
2024-03-01 08:47:08,433 - Elbo loss: 95841.53199982145
2024-03-01 08:47:09,187 - Elbo loss: 717.9404173274332
2024-03-01 08:47:09,864 - Elbo loss: 17177.64738418947
2024-03-01 08:47:10,496 - Elbo loss: -90466.00825910454
2024-03-01 08:47:11,119 - Elbo loss: -97209.48493899623
2024-03-01 08:47:11,735 - Elbo loss: -90676.89590796473
2024-03-01 08:47:12,379 - Elbo loss: -98115.6463327602
2024-03-01 08:47:12,979 - Elbo loss: -96059.33860733005
2024-03-01 08:47:13,562 - Elbo loss: -92512.9678615845
2024-03-01 08:47:14,168 - Elbo loss: -104894.8159554147
2024-03-01 08:47:14,757 - Elbo loss: -97920.45400384761
2024-03-01 08:47:15,363 - Elbo loss: -105913.12448426257
2024-03-01 08:47:15,954 - Elbo loss: -102796.48556854884
2024-03-01 08:47:16,547 - Elbo loss: -106332.42963613431
2024-03-01 08:47:17,139 - Elbo loss: -107922.26747489141
2024-03-01 08:47:17,729 - Elbo loss: -107221.65467257223
2024-03-01 08:47:18,367 - Elbo lo

SVI-based coefficient marginals:
Site: d
       mean       std        5%       25%       50%       75%       95%
0 -0.352656  0.027456 -0.399927 -0.370307 -0.352462 -0.334956 -0.306546 

Site: p
       mean       std        5%       25%       50%       75%       95%
0  0.099846  0.024814  0.058545  0.084012  0.099289  0.116809  0.140109 



: 

In [None]:
def plot_coefs(outcomes, title, ann_start_y=100, ann_break_y=50, generate_object=False):
    keys = [10, 20, 30, 40, 50, 60, 70, 80]
    # [19, 46, 85]
    samples = {}

    for key in keys:
        samples[key] = outcomes[key]["svi_samples"]["p"].flatten()

    samples_df = pd.DataFrame(samples)
    # samples_df_medians = samples_df.median(axis=0).tolist()

    fig_coefs = px.histogram(
        samples_df,
        template="presentation",
        opacity=0.4,
        labels={"variable": "preferred proximity", "value": "proximity coefficient"},
        width=700,
        title=title,
    )

    # for i, color in enumerate(['#1f77b4', '#ff7f0e', '#2ca02c']):
    #         fig_coefs.add_vline(x=samples_df_medians[i], line_dash="dash", line_color=color, name=f"Median ({samples_df_medians[i]})")

    #         fig_coefs.add_annotation(
    #         x=samples_df_medians[i],
    #         y= ann_start_y + ann_break_y * i,  # Adjust the vertical position of the label
    #         text=f"{samples_df_medians[i]:.2f}",
    #         bgcolor="white",
    #         showarrow=False,
    #         opacity=0.8,
    #         )

    fig_coefs.update_layout(
        barmode="overlay"
    )  # , yaxis=dict(showticklabels=False, title=None, showgrid=False))

    if generate_object:
        return fig_coefs
    else:
        fig_coefs.show()

In [None]:
duck_outcomes_path = os.path.join(
    root, "data/foraging/central_park_birds_cleaned_2022/duck_outcomes.pkl"
)

if not smoke_test:
    duck_outcomes = dill.load(open(duck_outcomes_path, "rb"))

    ducks_coefs_plot = plot_coefs(
        duck_outcomes, "Ducks", ann_start_y=350, ann_break_y=50, generate_object=True
    )

    ducks_coefs_plot.show()

    pio.write_image(
        ducks_coefs_plot,
        os.path.join(root, "docs/figures/duck_coefs_plot.png"),
        engine="kaleido",
        width=700,
        scale=5,
    )

In [None]:
def calculate_R_squared_prox(distance, proximity, how_far, guide, subsample_size=1000):
    predictive = pyro.infer.Predictive(
        model_sigmavar_proximity, guide=guide, num_samples=1000
    )

    random_indices = np.random.choice(len(distance), size=subsample_size, replace=False)
    distance_sub = distance[random_indices]
    proximity_sub = proximity[random_indices]
    how_far_sub = how_far[random_indices]

    predictions = predictive(distance_sub, proximity_sub, how_far_sub)

    simulated_outcome = (
        predictions["b"] + predictions["p"] * proximity + predictions["d"] * distance
    )

    mean_sim_outcome = simulated_outcome.mean(0).detach().cpu().numpy()

    observed_mean = torch.mean(how_far)

    tss = torch.sum((how_far - observed_mean) ** 2)
    rss = torch.sum((how_far - mean_sim_outcome) ** 2)

    r_squared = 1 - (rss / tss)

    return r_squared.float().item()

In [None]:
if not smoke_test:
    ducks_objects = central_park_objects[0]

    distance, proximity, how_far = cp_prep_data_for_iference(
        ducks_objects[50]
    )  # all objects share the key columns

    for key in keys:
        guide = duck_outcomes[key]["svi_guide"]
        print(
            f"R^2 for ducks with optimal={key}:",
            calculate_R_squared_prox(distance, proximity, how_far, guide),
        )

# interestingly, knowing where they won't go is useful

dict_keys([10, 20, 30, 40, 50, 60, 70, 80])
Initial dataset size: 101213
After dropping NAs: 99637
R^2 for ducks with optimal=10: 0.34993937611579895
R^2 for ducks with optimal=20: 0.18329627811908722
R^2 for ducks with optimal=30: 0.15729406476020813
R^2 for ducks with optimal=40: 0.3640820384025574
R^2 for ducks with optimal=50: 0.32379448413848877
R^2 for ducks with optimal=60: 0.3512199819087982
R^2 for ducks with optimal=70: 0.3209821879863739
R^2 for ducks with optimal=80: 0.3128003180027008


In [None]:
path = os.path.join(
    root, "data/foraging/central_park_birds_cleaned_2022/sps_outcomes.pkl"
)

if not smoke_test and not os.path.exists(path):
    sps_objects = central_park_objects[1]  # [19, 46, 85]

    sps_outcomes = {}

    for key in keys:
        obj = sps_objects[key]
        print(f"Working on sparrows et al. with optimal={key}")
        distance, proximity, how_far = cp_prep_data_for_iference(obj)
        # ft.visualise_forager_predictors(
        #     distance,
        #     proximity,
        #     how_far,
        #     vis_sampling_rate=0.05,
        #     titles=[
        #         f"Distance (sparrows et al.)",
        #         f"Proximity (sparrows et al., optimal={key})",
        #     ],
        #    x_axis_labels=["distance", "proximity"],
        # )

        sps_outcomes[key] = get_samples(distance, proximity, how_far)

    with open(path, "wb") as file:
        dill.dump(sps_outcomes, file)

In [None]:
sps_outcomes_path = os.path.join(
    root, "data/foraging/central_park_birds_cleaned_2022/sps_outcomes.pkl"
)

if not smoke_test:
    sps_outcomes = dill.load(open(sps_outcomes_path, "rb"))

    sps_coefs_plot = plot_coefs(
        sps_outcomes,
        "Sparrows et al.",
        ann_start_y=200,
        ann_break_y=30,
        generate_object=True,
    )

    sps_coefs_plot.show()
    # add title to figure

    pio.write_image(
        sps_coefs_plot,
        os.path.join(root, "docs/figures/sps_coefs_plot.png"),
        engine="kaleido",
        width=700,
        scale=5,
    )

In [None]:
if not smoke_test:
    sps_objects = central_park_objects[1]

    distance, proximity, how_far = cp_prep_data_for_iference(sps_objects[50])

    for key in keys:
        guide = sps_outcomes[key]["svi_guide"]
        print(
            f"R^2 for sparrows et al. with optimal={key}:",
            calculate_R_squared_prox(distance, proximity, how_far, guide),
        )

notebook_ends = time.time()

print(
    f"notebook took {notebook_ends - notebook_starts} seconds, {(notebook_ends - notebook_starts)/60} minutes to run"
)

Initial dataset size: 61115
After dropping NAs: 60594
R^2 for sparrows et al. with optimal=10: 0.07661649584770203
R^2 for sparrows et al. with optimal=20: -0.004683657083660364
R^2 for sparrows et al. with optimal=30: 0.0047910213470458984
R^2 for sparrows et al. with optimal=40: 0.14601436257362366
R^2 for sparrows et al. with optimal=50: 0.17356309294700623
R^2 for sparrows et al. with optimal=60: 0.15320388972759247
R^2 for sparrows et al. with optimal=70: 0.1526634246110916
R^2 for sparrows et al. with optimal=80: 0.10207065939903259
notebook took 45.60255551338196 seconds, 0.7600425918896992 minutes to run
