In [1]:
import logging
import os
import time

import dill
import numpy as np
import pandas as pd
import plotly.io as pio
import pyro
import pyro.distributions as dist
import pyro.optim as optim
import torch
from plotly import express as px
from pyro.infer import SVI, Predictive, Trace_ELBO
from pyro.infer.autoguide import AutoMultivariateNormal, init_to_mean

from collab.foraging import toolkit as ft
from collab.utils import find_repo_root

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
root = find_repo_root()

# `CI` is set in continuous-integration runs; there the notebook only smoke-tests
# the pipeline with tiny settings instead of fitting every model.
smoke_test = "CI" in os.environ
num_svi_iters = 50 if smoke_test else 1000
num_samples = 50 if smoke_test else 1000
# "optimal" keys of the preprocessed objects to fit one model for each
keys = [50] if smoke_test else [10, 20, 30, 40, 50, 60, 70, 80]
# Must match a pickle produced by `central_park_birds_predictors.ipynb`; the same
# rate is used for smoke-test and full runs.
sampling_rate = 0.01

# Seed python `random`, numpy and torch (via Pyro) so that SVI fits, posterior
# draws and the R^2 row subsampling are reproducible on Restart & Run All.
seed = 42
pyro.set_rng_seed(seed)

notebook_starts = time.time()

In [2]:
# this file is generated using `central_park_birds_predictors.ipynb`
path = os.path.join(
    root,
    f"data/foraging/central_park_birds_cleaned_2022/central_park_objects_sampling_rate_{sampling_rate}.pkl",
)

if not smoke_test:
    # An explicit exception (unlike `assert`) is not stripped under `python -O`.
    if not os.path.exists(path):
        raise FileNotFoundError(
            f"{path} not found. "
            "Please run `central_park_birds_predictors.ipynb` to prep the data first."
        )

    # NOTE: dill, like pickle, can execute arbitrary code while loading — only
    # open files produced locally by the preprocessing notebook.
    with open(path, "rb") as file:
        central_park_objects = dill.load(file)

In [3]:
def cp_prep_data_for_inference(obj):
    """Convert a forager object's ``how_farDF`` into model-ready tensors.

    Rows with any missing value are dropped (row counts before and after are
    printed), then the ``distance`` and ``proximity_standardized`` columns are
    rescaled with ``ft.normalize``. ``obj.how_farDF`` itself is not modified.

    Args:
        obj: object exposing a ``how_farDF`` DataFrame that contains the columns
            ``distance``, ``proximity_standardized`` and ``how_far_squared_scaled``.

    Returns:
        Tuple ``(distance, proximity, how_far)`` of torch tensors built from those
        three columns, in that order.
    """
    predictor_columns = ("distance", "proximity_standardized")
    output_columns = (*predictor_columns, "how_far_squared_scaled")

    raw = obj.how_farDF
    print("Initial dataset size:", len(raw))
    complete = raw.dropna()  # new frame; the source DataFrame stays untouched
    print("After dropping NAs:", len(complete))

    normalized = complete.assign(
        **{column: ft.normalize(complete[column]) for column in predictor_columns}
    )
    return tuple(torch.tensor(normalized[column].values) for column in output_columns)

In [4]:
def model_sigmavar_proximity(distance, proximity, how_far):
    """Heteroscedastic linear model of ``how_far`` given distance and proximity.

    Both the location and the scale of the Normal likelihood are linear in the
    two predictors::

        loc   = b  + d  * distance + p  * proximity
        scale = bs + ds * distance + ps * proximity
        how_far[i] ~ Normal(loc[i], scale[i])   (conditionally independent rows)

    Priors: d, p ~ Normal(0, 0.2); b ~ Normal(0.5, 0.3); bs, ds, ps ~ Exponential(7).
    """
    # Keep the sample-site declaration order unchanged; it presumably determines
    # the latent layout of the autoguide fitted against this model.
    coef_distance = pyro.sample("d", dist.Normal(0, 0.2))
    coef_proximity = pyro.sample("p", dist.Normal(0, 0.2))
    intercept = pyro.sample("b", dist.Normal(0.5, 0.3))

    scale_distance = pyro.sample("ds", dist.Exponential(7))
    scale_proximity = pyro.sample("ps", dist.Exponential(7))
    scale_intercept = pyro.sample("bs", dist.Exponential(7))

    loc = intercept + coef_distance * distance + coef_proximity * proximity
    # NOTE(review): the scale is only guaranteed positive when `distance` and
    # `proximity` are non-negative (presumably ensured by `ft.normalize`) — confirm.
    scale = scale_intercept + scale_distance * distance + scale_proximity * proximity

    with pyro.plate("data", len(how_far)):
        pyro.sample("obs", dist.Normal(loc, scale), obs=how_far)

In [5]:
def get_samples(
    distance,
    proximity,
    how_far,
    model=model_sigmavar_proximity,
    num_svi_iters=num_svi_iters,
    num_samples=num_samples,
    lr=0.01,
    plot_elbo=False,
):
    """Fit ``model`` with SVI and draw posterior samples of its latent sites.

    An ``AutoMultivariateNormal`` guide (initialised at prior means) is trained
    with Adam for ``num_svi_iters`` steps. ``num_samples`` posterior draws are then
    taken with ``Predictive``, and summary tables for sites ``d`` and ``p`` are
    printed.

    Args:
        distance, proximity, how_far: tensors passed positionally to ``model``
            (as produced by ``cp_prep_data_for_inference``).
        model: Pyro model with signature ``model(distance, proximity, how_far)``.
            It is used consistently for the guide, the SVI objective and the
            Predictive.
        num_svi_iters: number of SVI steps (default taken from the config cell).
        num_samples: number of posterior draws (default taken from the config cell).
        lr: Adam learning rate.
        plot_elbo: if True, show a plotly line chart of the ELBO loss per step.

    Returns:
        dict with
        - ``"svi_samples"``: site name -> ndarray reshaped to ``(num_samples, -1)``,
          for every returned site except ``"obs"``;
        - ``"svi_guide"``: the trained guide;
        - ``"svi_predictive"``: the ``Predictive`` object;
        - ``"svi_losses"``: list of per-step ELBO loss values.
    """
    # Drop parameters left over from a previous fit before building a new guide.
    pyro.clear_param_store()
    guide = AutoMultivariateNormal(model, init_loc_fn=init_to_mean)
    # Train the model that was passed in (previously the default model was
    # hard-coded here, so a custom `model` was paired with the wrong objective).
    svi = SVI(model, guide, optim.Adam({"lr": lr}), loss=Trace_ELBO())

    losses = []
    logging.info("Starting SVI inference with %d iterations.", num_svi_iters)
    start_time = time.time()
    for i in range(num_svi_iters):
        elbo = svi.step(distance, proximity, how_far)
        losses.append(elbo)
        if i % 50 == 0:
            logging.info("Elbo loss: %s", elbo)
    elapsed_time = time.time() - start_time
    logging.info("SVI inference completed in %.2f seconds.", elapsed_time)

    if plot_elbo:
        fig = px.line(
            x=list(range(len(losses))),
            y=losses,
            title="ELBO loss",
            template="presentation",
        )
        fig.update_xaxes(showgrid=False, title_text="iteration")
        fig.update_yaxes(showgrid=False, title_text="loss")
        fig.update_layout(width=700)
        fig.show()

    predictive = Predictive(model, guide=guide, num_samples=num_samples)

    svi_samples = {
        site: values.flatten().reshape(num_samples, -1).detach().cpu().numpy()
        for site, values in predictive(distance, proximity, how_far).items()
        if site != "obs"
    }

    print("SVI-based coefficient marginals:")
    for site, values in ft.summary(svi_samples, ["d", "p"]).items():
        print("Site: {}".format(site))
        print(values, "\n")

    return {
        "svi_samples": svi_samples,
        "svi_guide": guide,
        "svi_predictive": predictive,
        "svi_losses": losses,
    }

In [6]:
# Fit one model per "optimal" key for the ducks; results are cached to
# `duck_outcomes.pkl` and read back from disk in a later cell.
path = os.path.join(
    root, "data/foraging/central_park_birds_cleaned_2022/duck_outcomes.pkl"
)

if os.path.exists(path):
    print("The duck samples exist, skipping inference, will load later on.")
elif not smoke_test:
    ducks_objects = central_park_objects[0]
    # NOTE: low keys may not make sense for ducks; an alternative key set noted
    # by the author is [19, 46, 85].
    duck_outcomes = {}

    for key in keys:
        obj = ducks_objects[key]
        print(f"Working on ducks with optimal={key}")
        distance, proximity, how_far = cp_prep_data_for_inference(obj)

        plot_titles = ["Distance (ducks)", f"Proximity (ducks, optimal={key})"]
        ft.visualise_forager_predictors(
            distance,
            proximity,
            how_far,
            vis_sampling_rate=0.05,
            titles=plot_titles,
            x_axis_labels=["distance", "proximity"],
        )
        duck_outcomes[key] = get_samples(distance, proximity, how_far)

        # a short pause between fits helps prevent crashing the kernel on slower machines
        time.sleep(1)

    with open(path, "wb") as file:
        dill.dump(duck_outcomes, file)

Working on ducks with optimal=10
Initial dataset size: 101213
After dropping NAs: 99637


2024-03-02 17:55:21,778 - Starting SVI inference with 1000 iterations.
2024-03-02 17:55:22,130 - Elbo loss: 133490.86319372943
2024-03-02 17:55:22,803 - Elbo loss: -26164.31939461905
2024-03-02 17:55:23,143 - Elbo loss: -56576.16523980917
2024-03-02 17:55:23,446 - Elbo loss: -78540.73526366733
2024-03-02 17:55:23,743 - Elbo loss: -52167.37213240118
2024-03-02 17:55:24,038 - Elbo loss: -93498.99779707985
2024-03-02 17:55:24,334 - Elbo loss: -90750.88422295536
2024-03-02 17:55:24,627 - Elbo loss: -86337.54944152941
2024-03-02 17:55:24,923 - Elbo loss: -93650.46318031754
2024-03-02 17:55:25,217 - Elbo loss: -100677.0655762277
2024-03-02 17:55:25,509 - Elbo loss: -97436.85435878002
2024-03-02 17:55:25,805 - Elbo loss: -103042.11429985569
2024-03-02 17:55:26,092 - Elbo loss: -98043.53206940871
2024-03-02 17:55:26,386 - Elbo loss: -100988.76472354996
2024-03-02 17:55:26,712 - Elbo loss: -104006.47182635155
2024-03-02 17:55:27,039 - Elbo loss: -103709.40543040098
2024-03-02 17:55:27,363 - Elb

SVI-based coefficient marginals:
Site: d
       mean       std        5%       25%       50%       75%       95%
0 -0.274453  0.024272 -0.314393 -0.291668 -0.274737 -0.258001 -0.235279 

Site: p
       mean       std        5%       25%       50%       75%       95%
0  0.089536  0.030342  0.038957  0.068806  0.088836  0.111575  0.137295 

Working on ducks with optimal=20
Initial dataset size: 101213
After dropping NAs: 99637


2024-03-02 17:55:32,410 - Starting SVI inference with 1000 iterations.
2024-03-02 17:55:32,427 - Elbo loss: 11913.207079876127
2024-03-02 17:55:33,009 - Elbo loss: -19494.41043642687
2024-03-02 17:55:33,571 - Elbo loss: -46099.557520616334
2024-03-02 17:55:34,133 - Elbo loss: -87650.61084231189
2024-03-02 17:55:34,692 - Elbo loss: -96543.5566607577
2024-03-02 17:55:35,253 - Elbo loss: -88917.29681703288
2024-03-02 17:55:35,781 - Elbo loss: -100554.08642273257
2024-03-02 17:55:36,346 - Elbo loss: -90115.33671262354
2024-03-02 17:55:36,910 - Elbo loss: -101253.69818016567
2024-03-02 17:55:37,458 - Elbo loss: -101865.32230858057
2024-03-02 17:55:37,999 - Elbo loss: -103037.31280116262
2024-03-02 17:55:38,290 - Elbo loss: -103454.5317131088
2024-03-02 17:55:38,577 - Elbo loss: -103159.62857710267
2024-03-02 17:55:38,875 - Elbo loss: -104496.98689545995
2024-03-02 17:55:39,173 - Elbo loss: -99628.28143326513
2024-03-02 17:55:39,466 - Elbo loss: -102359.872372344
2024-03-02 17:55:39,808 - El

SVI-based coefficient marginals:
Site: d
       mean       std        5%       25%       50%       75%       95%
0 -0.261491  0.022639 -0.299872 -0.276203 -0.261162 -0.245603 -0.225349 

Site: p
       mean       std        5%       25%       50%       75%       95%
0  0.143899  0.029116  0.096885  0.123617  0.143643  0.163716  0.192201 

Working on ducks with optimal=30
Initial dataset size: 101213
After dropping NAs: 99637


2024-03-02 17:55:43,603 - Starting SVI inference with 1000 iterations.
2024-03-02 17:55:43,612 - Elbo loss: 68259.70497867386
2024-03-02 17:55:43,990 - Elbo loss: -26557.450502735963
2024-03-02 17:55:44,290 - Elbo loss: -71607.30019721105
2024-03-02 17:55:44,594 - Elbo loss: -72712.65568013953
2024-03-02 17:55:44,900 - Elbo loss: -88752.65889921573
2024-03-02 17:55:45,194 - Elbo loss: -101405.50940075456
2024-03-02 17:55:45,494 - Elbo loss: -90904.80817412164
2024-03-02 17:55:45,777 - Elbo loss: -78703.83532638862
2024-03-02 17:55:46,060 - Elbo loss: -102939.67780576019
2024-03-02 17:55:46,355 - Elbo loss: -99591.89069041207
2024-03-02 17:55:46,635 - Elbo loss: -99031.29670545987
2024-03-02 17:55:46,925 - Elbo loss: -104553.4656634425
2024-03-02 17:55:47,208 - Elbo loss: -102110.2780599423
2024-03-02 17:55:47,528 - Elbo loss: -100762.57219778396
2024-03-02 17:55:47,993 - Elbo loss: -103346.60689168415
2024-03-02 17:55:48,520 - Elbo loss: -102855.29439821315
2024-03-02 17:55:49,048 - El

SVI-based coefficient marginals:
Site: d
       mean       std        5%       25%       50%       75%       95%
0 -0.247109  0.020906 -0.281122 -0.261125 -0.246656 -0.232687 -0.213996 

Site: p
       mean      std        5%       25%       50%       75%       95%
0  0.194628  0.02605  0.153208  0.177844  0.193202  0.211179  0.237082 

Working on ducks with optimal=40
Initial dataset size: 101213
After dropping NAs: 99637


2024-03-02 17:55:54,763 - Starting SVI inference with 1000 iterations.
2024-03-02 17:55:54,781 - Elbo loss: 155436.51431990962
2024-03-02 17:55:55,419 - Elbo loss: -16163.066232614932
2024-03-02 17:55:55,818 - Elbo loss: -41312.844625504666
2024-03-02 17:55:56,373 - Elbo loss: -85147.00140122064
2024-03-02 17:55:56,922 - Elbo loss: -90721.54907248878
2024-03-02 17:55:57,471 - Elbo loss: -99398.84973808781
2024-03-02 17:55:58,017 - Elbo loss: -85340.21786739552
2024-03-02 17:55:58,563 - Elbo loss: -99851.36289506011
2024-03-02 17:55:59,038 - Elbo loss: -100348.1108582188
2024-03-02 17:55:59,527 - Elbo loss: -99961.14154966085
2024-03-02 17:56:00,077 - Elbo loss: -99780.95936911399
2024-03-02 17:56:00,624 - Elbo loss: -104941.03229023072
2024-03-02 17:56:01,172 - Elbo loss: -105136.3030598964
2024-03-02 17:56:01,721 - Elbo loss: -104207.00111640329
2024-03-02 17:56:02,022 - Elbo loss: -105056.88477717085
2024-03-02 17:56:02,307 - Elbo loss: -105121.67509904924
2024-03-02 17:56:02,600 - E

SVI-based coefficient marginals:
Site: d
       mean       std        5%       25%       50%       75%       95%
0 -0.282695  0.024332 -0.321949 -0.299076 -0.282389 -0.266826 -0.243136 

Site: p
       mean       std        5%       25%      50%     75%       95%
0  0.141697  0.021476  0.106827  0.126387  0.14161  0.1561  0.177535 

Working on ducks with optimal=50
Initial dataset size: 101213
After dropping NAs: 99637


2024-03-02 17:56:06,622 - Starting SVI inference with 1000 iterations.
2024-03-02 17:56:06,640 - Elbo loss: 10537.812178424158
2024-03-02 17:56:07,234 - Elbo loss: 7542.019707144247
2024-03-02 17:56:07,753 - Elbo loss: -46362.61162898034
2024-03-02 17:56:08,272 - Elbo loss: -76784.6073377997
2024-03-02 17:56:08,788 - Elbo loss: -84981.92399054913
2024-03-02 17:56:09,303 - Elbo loss: -74175.77093763184
2024-03-02 17:56:09,817 - Elbo loss: -81725.87472428552
2024-03-02 17:56:10,332 - Elbo loss: -100720.99208028457
2024-03-02 17:56:10,830 - Elbo loss: -97149.0434907578
2024-03-02 17:56:11,260 - Elbo loss: -94874.7390764861
2024-03-02 17:56:11,779 - Elbo loss: -97128.48086854679
2024-03-02 17:56:12,295 - Elbo loss: -98471.65293151513
2024-03-02 17:56:12,814 - Elbo loss: -101153.98207717862
2024-03-02 17:56:13,332 - Elbo loss: -103799.75387014206
2024-03-02 17:56:13,795 - Elbo loss: -100335.23856560183
2024-03-02 17:56:14,314 - Elbo loss: -103164.03634451039
2024-03-02 17:56:14,832 - Elbo l

SVI-based coefficient marginals:
Site: d
       mean       std        5%       25%       50%       75%       95%
0 -0.300343  0.026586 -0.344098 -0.317907 -0.300578 -0.282419 -0.258913 

Site: p
       mean       std        5%       25%       50%       75%       95%
0  0.070156  0.024905  0.030116  0.052038  0.069732  0.087911  0.109841 

Working on ducks with optimal=60
Initial dataset size: 101213
After dropping NAs: 99637


2024-03-02 17:56:20,598 - Starting SVI inference with 1000 iterations.
2024-03-02 17:56:20,614 - Elbo loss: 30683.501656880195
2024-03-02 17:56:21,195 - Elbo loss: -16571.303665945547
2024-03-02 17:56:21,715 - Elbo loss: -49235.304395119674
2024-03-02 17:56:22,237 - Elbo loss: -32676.63380480015
2024-03-02 17:56:22,760 - Elbo loss: -79939.22964761002
2024-03-02 17:56:23,281 - Elbo loss: -89836.32092915759
2024-03-02 17:56:23,803 - Elbo loss: -93812.72723743065
2024-03-02 17:56:24,327 - Elbo loss: -101289.01009890393
2024-03-02 17:56:24,849 - Elbo loss: -101558.86424710017
2024-03-02 17:56:25,371 - Elbo loss: -102680.87174976114
2024-03-02 17:56:25,894 - Elbo loss: -91495.83583760413
2024-03-02 17:56:26,415 - Elbo loss: -103383.53914844654
2024-03-02 17:56:26,937 - Elbo loss: -103752.81737657738
2024-03-02 17:56:27,460 - Elbo loss: -99032.78694037221
2024-03-02 17:56:27,984 - Elbo loss: -102942.38639769185
2024-03-02 17:56:28,505 - Elbo loss: -104058.08237331889
2024-03-02 17:56:29,049 

SVI-based coefficient marginals:
Site: d
       mean       std        5%      25%       50%       75%       95%
0 -0.283254  0.025859 -0.326535 -0.30046 -0.282619 -0.265967 -0.240679 

Site: p
       mean      std        5%       25%       50%       75%       95%
0 -0.034539  0.02344 -0.072627 -0.050783 -0.034652 -0.018531  0.003471 

Working on ducks with optimal=70
Initial dataset size: 101213
After dropping NAs: 99637


2024-03-02 17:56:34,819 - Starting SVI inference with 1000 iterations.
2024-03-02 17:56:34,839 - Elbo loss: -7619.061950553333
2024-03-02 17:56:35,371 - Elbo loss: -42028.34167633892
2024-03-02 17:56:35,915 - Elbo loss: -56686.744169085214
2024-03-02 17:56:36,461 - Elbo loss: -75890.08059230057
2024-03-02 17:56:37,010 - Elbo loss: -86839.77482066484
2024-03-02 17:56:37,556 - Elbo loss: -97413.97380876746
2024-03-02 17:56:38,081 - Elbo loss: -100023.64672931991
2024-03-02 17:56:38,627 - Elbo loss: -98358.05868996234
2024-03-02 17:56:39,174 - Elbo loss: -102057.65910559153
2024-03-02 17:56:39,719 - Elbo loss: -104120.01370755974
2024-03-02 17:56:40,265 - Elbo loss: -98305.93470856461
2024-03-02 17:56:40,810 - Elbo loss: -103667.8504743967
2024-03-02 17:56:41,358 - Elbo loss: -104520.06708533858
2024-03-02 17:56:41,907 - Elbo loss: -101433.72045248775
2024-03-02 17:56:42,454 - Elbo loss: -104606.58623647073
2024-03-02 17:56:43,022 - Elbo loss: -104253.76563872669
2024-03-02 17:56:43,567 -

SVI-based coefficient marginals:
Site: d
       mean     std        5%       25%       50%       75%       95%
0 -0.267039  0.0297 -0.314602 -0.286969 -0.266247 -0.247367 -0.218864 

Site: p
       mean       std        5%       25%       50%       75%      95%
0 -0.077231  0.026799 -0.122086 -0.094745 -0.076211 -0.059435 -0.03415 

Working on ducks with optimal=80
Initial dataset size: 101213
After dropping NAs: 99637


2024-03-02 17:56:49,462 - Starting SVI inference with 1000 iterations.
2024-03-02 17:56:49,482 - Elbo loss: 8724.336093654352
2024-03-02 17:56:50,278 - Elbo loss: 6201.577967477129
2024-03-02 17:56:50,734 - Elbo loss: -53437.219097977635
2024-03-02 17:56:51,247 - Elbo loss: -40246.60975046248
2024-03-02 17:56:51,779 - Elbo loss: -79036.47176721656
2024-03-02 17:56:52,328 - Elbo loss: -92789.73502037574
2024-03-02 17:56:52,879 - Elbo loss: -100968.56738047962
2024-03-02 17:56:53,431 - Elbo loss: -100700.01341536055
2024-03-02 17:56:53,944 - Elbo loss: -101191.1181011845
2024-03-02 17:56:54,494 - Elbo loss: -98546.28777627109
2024-03-02 17:56:55,043 - Elbo loss: -100408.69263296462
2024-03-02 17:56:55,593 - Elbo loss: -99173.85503858373
2024-03-02 17:56:56,144 - Elbo loss: -103246.56535847115
2024-03-02 17:56:56,658 - Elbo loss: -102154.80606453274
2024-03-02 17:56:57,126 - Elbo loss: -102507.16183179023
2024-03-02 17:56:57,675 - Elbo loss: -100623.962049275
2024-03-02 17:56:58,227 - Elb

SVI-based coefficient marginals:
Site: d
      mean       std        5%      25%       50%       75%       95%
0 -0.25963  0.029541 -0.309029 -0.28042 -0.258902 -0.238874 -0.210387 

Site: p
       mean       std       5%       25%       50%       75%       95%
0 -0.107191  0.023952 -0.14688 -0.122985 -0.106479 -0.090773 -0.068853 



In [7]:
def plot_coefs(
    outcomes,
    title,
    ann_start_y=100,
    ann_break_y=50,
    generate_object=False,
    keys=None,
):
    """Overlay histograms of the posterior proximity coefficient ``p`` per key.

    Args:
        outcomes: mapping from "optimal" key to a ``get_samples`` result dict; only
            ``outcomes[key]["svi_samples"]["p"]`` is read.
        title: figure title.
        ann_start_y, ann_break_y: currently unused; kept so existing calls keep
            working (they positioned median annotations that are no longer drawn).
        generate_object: if True return the figure, otherwise show it.
        keys: keys of ``outcomes`` to plot, in order. Defaults to all keys of
            ``outcomes`` in sorted order. This replaces a hard-coded
            ``[10, 20, ..., 80]`` that raised KeyError for any other key set.

    Returns:
        The plotly figure when ``generate_object`` is True, otherwise None.
    """
    if keys is None:
        keys = sorted(outcomes)

    samples_df = pd.DataFrame(
        {key: outcomes[key]["svi_samples"]["p"].flatten() for key in keys}
    )

    fig_coefs = px.histogram(
        samples_df,
        template="presentation",
        opacity=0.4,
        labels={"variable": "preferred proximity", "value": "proximity coefficient"},
        width=700,
        title=title,
    )
    # overlay (rather than stack) the per-key histograms so they can be compared
    fig_coefs.update_layout(barmode="overlay")

    if generate_object:
        return fig_coefs
    fig_coefs.show()

In [8]:
duck_outcomes_path = os.path.join(
    root, "data/foraging/central_park_birds_cleaned_2022/duck_outcomes.pkl"
)

if not smoke_test:
    # Use a context manager so the file handle is closed after loading.
    with open(duck_outcomes_path, "rb") as file:
        duck_outcomes = dill.load(file)

    ducks_coefs_plot = plot_coefs(
        duck_outcomes, "Ducks", ann_start_y=350, ann_break_y=50, generate_object=True
    )
    ducks_coefs_plot.show()

    # Export a high-resolution static copy for the docs (kaleido engine).
    pio.write_image(
        ducks_coefs_plot,
        os.path.join(root, "docs/figures/duck_coefs_plot.png"),
        engine="kaleido",
        width=700,
        scale=5,
    )

In [9]:
def calculate_R_squared_prox(
    distance,
    proximity,
    how_far,
    guide,
    subsample_size=1000,
    num_samples=1000,
    model=model_sigmavar_proximity,
):
    """R^2 of the posterior-mean linear predictor ``b + p*proximity + d*distance``.

    Posterior draws of ``b``, ``p`` and ``d`` are obtained from ``Predictive``
    run on a random row subsample. The resulting predictor is then scored
    against ``how_far`` on the full data.

    Args:
        distance, proximity, how_far: full-data tensors (as from
            ``cp_prep_data_for_inference``).
        guide: trained guide for ``model`` (e.g. ``get_samples(...)["svi_guide"]``).
        subsample_size: number of rows (without replacement, capped at the data
            size) passed to ``Predictive``.
        num_samples: number of posterior draws.
        model: the model the guide was trained for.

    Returns:
        R^2 as a Python float.
    """
    # Predictive also draws the per-row `obs` site, so only a random subsample
    # of rows is passed to keep that cheap. The coefficient draws come from the
    # guide, which (AutoMultivariateNormal in `get_samples`) is not
    # data-conditioned.
    n_rows = len(distance)
    idx = np.random.choice(n_rows, size=min(subsample_size, n_rows), replace=False)
    predictive = pyro.infer.Predictive(model, guide=guide, num_samples=num_samples)
    predictions = predictive(distance[idx], proximity[idx], how_far[idx])

    # Averaging b + p*proximity + d*distance over posterior draws equals plugging
    # in the posterior-mean coefficients (linearity). This gives the same
    # prediction without materialising a (num_samples x n_rows) tensor on the full
    # data, and keeps everything in torch (no numpy/torch mixing).
    b_mean = predictions["b"].detach().double().mean()
    p_mean = predictions["p"].detach().double().mean()
    d_mean = predictions["d"].detach().double().mean()
    predicted = b_mean + p_mean * proximity + d_mean * distance

    tss = torch.sum((how_far - torch.mean(how_far)) ** 2)
    rss = torch.sum((how_far - predicted) ** 2)

    r_squared = 1 - (rss / tss)
    return r_squared.float().item()

In [10]:
# R^2 of each duck model, scored on the same prepared data it was fitted to.
if not smoke_test:
    ducks_objects = central_park_objects[0]

    # data prep is identical for every key, so it is simply redone per key
    for key in keys:
        distance, proximity, how_far = cp_prep_data_for_inference(ducks_objects[key])
        guide = duck_outcomes[key]["svi_guide"]
        r_squared = calculate_R_squared_prox(distance, proximity, how_far, guide)
        print(f"R^2 for ducks with optimal={key}:", r_squared)

# interestingly, knowing where they won't go is useful

Initial dataset size: 101213
After dropping NAs: 99637
R^2 for ducks with optimal=10: 0.3432106673717499
Initial dataset size: 101213
After dropping NAs: 99637
R^2 for ducks with optimal=20: 0.36489877104759216
Initial dataset size: 101213
After dropping NAs: 99637
R^2 for ducks with optimal=30: 0.36773115396499634
Initial dataset size: 101213
After dropping NAs: 99637
R^2 for ducks with optimal=40: 0.38994890451431274
Initial dataset size: 101213
After dropping NAs: 99637
R^2 for ducks with optimal=50: 0.36634671688079834
Initial dataset size: 101213
After dropping NAs: 99637
R^2 for ducks with optimal=60: 0.36187848448753357
Initial dataset size: 101213
After dropping NAs: 99637
R^2 for ducks with optimal=70: 0.37045010924339294
Initial dataset size: 101213
After dropping NAs: 99637
R^2 for ducks with optimal=80: 0.3725975453853607


In [11]:
# Same pipeline as for the ducks, now for "sparrows et al." (second entry of
# `central_park_objects`); results are cached to `sps_outcomes.pkl`.
path = os.path.join(
    root, "data/foraging/central_park_birds_cleaned_2022/sps_outcomes.pkl"
)

if not (smoke_test or os.path.exists(path)):
    sps_objects = central_park_objects[1]  # alternative key set noted: [19, 46, 85]
    sps_outcomes = {}

    for key in keys:
        obj = sps_objects[key]
        print(f"Working on sparrows et al. with optimal={key}")
        distance, proximity, how_far = cp_prep_data_for_inference(obj)

        plot_titles = [
            "Distance (sparrows et al.)",
            f"Proximity (sparrows et al., optimal={key})",
        ]
        ft.visualise_forager_predictors(
            distance,
            proximity,
            how_far,
            vis_sampling_rate=0.05,
            titles=plot_titles,
            x_axis_labels=["distance", "proximity"],
        )

        sps_outcomes[key] = get_samples(distance, proximity, how_far)

        # brief pause between fits to keep slower machines' kernels alive
        time.sleep(1)

    with open(path, "wb") as file:
        dill.dump(sps_outcomes, file)

Working on sparrows et al. with optimal=10
Initial dataset size: 61115
After dropping NAs: 60594


2024-03-02 17:58:30,932 - Starting SVI inference with 1000 iterations.
2024-03-02 17:58:30,948 - Elbo loss: 37168.72259862744
2024-03-02 17:58:31,537 - Elbo loss: -14688.169519539555
2024-03-02 17:58:32,094 - Elbo loss: -44402.91731632241
2024-03-02 17:58:32,653 - Elbo loss: -59418.543272119496
2024-03-02 17:58:33,215 - Elbo loss: -66471.99456089284
2024-03-02 17:58:33,769 - Elbo loss: -73308.75777715638
2024-03-02 17:58:34,321 - Elbo loss: -81240.0093166657
2024-03-02 17:58:34,844 - Elbo loss: -74022.26486824991
2024-03-02 17:58:35,369 - Elbo loss: -76837.73232141903
2024-03-02 17:58:35,892 - Elbo loss: -84248.78443545394
2024-03-02 17:58:36,416 - Elbo loss: -85800.3992684854
2024-03-02 17:58:36,938 - Elbo loss: -82468.01161090162
2024-03-02 17:58:37,459 - Elbo loss: -81078.45533470015
2024-03-02 17:58:37,983 - Elbo loss: -83972.65456467535
2024-03-02 17:58:38,507 - Elbo loss: -83761.75070282667
2024-03-02 17:58:39,066 - Elbo loss: -86556.37544377489
2024-03-02 17:58:39,621 - Elbo los

SVI-based coefficient marginals:
Site: d
       mean       std        5%       25%       50%       75%       95%
0 -0.129959  0.022941 -0.169437 -0.144399 -0.129776 -0.114323 -0.095283 

Site: p
       mean      std        5%       25%       50%       75%       95%
0  0.117867  0.02672  0.074945  0.098959  0.117477  0.135474  0.160463 

Working on sparrows et al. with optimal=20
Initial dataset size: 61115
After dropping NAs: 60594


2024-03-02 17:58:45,511 - Starting SVI inference with 1000 iterations.
2024-03-02 17:58:45,527 - Elbo loss: 70811.48207955793
2024-03-02 17:58:45,999 - Elbo loss: -20174.49360375203
2024-03-02 17:58:46,286 - Elbo loss: -42608.26704184759
2024-03-02 17:58:46,588 - Elbo loss: -57611.043596811054
2024-03-02 17:58:46,888 - Elbo loss: -64529.39037845608
2024-03-02 17:58:47,189 - Elbo loss: -80341.02344366303
2024-03-02 17:58:47,460 - Elbo loss: -84795.78709028786
2024-03-02 17:58:47,738 - Elbo loss: -68503.20209446135
2024-03-02 17:58:48,008 - Elbo loss: -70582.24305246837
2024-03-02 17:58:48,310 - Elbo loss: -81501.34046232433
2024-03-02 17:58:48,588 - Elbo loss: -73237.69220528138
2024-03-02 17:58:48,860 - Elbo loss: -86933.37875167158
2024-03-02 17:58:49,135 - Elbo loss: -84213.22114806311
2024-03-02 17:58:49,432 - Elbo loss: -85153.39847810796
2024-03-02 17:58:49,733 - Elbo loss: -85225.10303021228
2024-03-02 17:58:50,035 - Elbo loss: -86556.63924155595
2024-03-02 17:58:50,319 - Elbo lo

SVI-based coefficient marginals:
Site: d
       mean       std        5%      25%       50%      75%       95%
0 -0.119766  0.023267 -0.157841 -0.13371 -0.120232 -0.10438 -0.080936 

Site: p
       mean       std        5%       25%       50%       75%       95%
0  0.104441  0.023733  0.065439  0.089044  0.104739  0.120628  0.142733 

Working on sparrows et al. with optimal=30
Initial dataset size: 61115
After dropping NAs: 60594


2024-03-02 17:58:53,988 - Starting SVI inference with 1000 iterations.
2024-03-02 17:58:54,020 - Elbo loss: 6107.562199116548
2024-03-02 17:58:54,702 - Elbo loss: -11062.610521993458
2024-03-02 17:58:55,250 - Elbo loss: -11478.13852015456
2024-03-02 17:58:55,801 - Elbo loss: -56964.09804149469
2024-03-02 17:58:56,351 - Elbo loss: -61366.14416105091
2024-03-02 17:58:56,638 - Elbo loss: -74793.7675710361
2024-03-02 17:58:56,920 - Elbo loss: -76164.41121042622
2024-03-02 17:58:57,190 - Elbo loss: -81541.30760100434
2024-03-02 17:58:57,464 - Elbo loss: -80202.45372948449
2024-03-02 17:58:57,734 - Elbo loss: -82842.10596468532
2024-03-02 17:58:58,011 - Elbo loss: -84483.69769975565
2024-03-02 17:58:58,278 - Elbo loss: -64685.91983894343
2024-03-02 17:58:58,560 - Elbo loss: -81231.6327508171
2024-03-02 17:58:58,865 - Elbo loss: -84354.65275106736
2024-03-02 17:58:59,166 - Elbo loss: -84743.77555892835
2024-03-02 17:58:59,469 - Elbo loss: -78251.04654364876
2024-03-02 17:58:59,757 - Elbo loss

SVI-based coefficient marginals:
Site: d
       mean       std        5%       25%       50%       75%       95%
0 -0.145994  0.023296 -0.182417 -0.161376 -0.146202 -0.131275 -0.105689 

Site: p
       mean       std        5%       25%       50%       75%       95%
0  0.092748  0.024598  0.052578  0.076217  0.092826  0.109911  0.132743 

Working on sparrows et al. with optimal=40
Initial dataset size: 61115
After dropping NAs: 60594


2024-03-02 17:59:05,448 - Starting SVI inference with 1000 iterations.
2024-03-02 17:59:05,464 - Elbo loss: 87024.74926925931
2024-03-02 17:59:06,094 - Elbo loss: -19035.831077718187
2024-03-02 17:59:06,561 - Elbo loss: -44524.95518707573
2024-03-02 17:59:07,031 - Elbo loss: -50317.13177445869
2024-03-02 17:59:07,581 - Elbo loss: -60948.18946109664
2024-03-02 17:59:08,017 - Elbo loss: -69154.52701751128
2024-03-02 17:59:08,568 - Elbo loss: -78980.03062460052
2024-03-02 17:59:09,118 - Elbo loss: -82092.78202203508
2024-03-02 17:59:09,633 - Elbo loss: -77727.84641004808
2024-03-02 17:59:10,150 - Elbo loss: -80647.99453421323
2024-03-02 17:59:10,667 - Elbo loss: -84124.91983550924
2024-03-02 17:59:11,165 - Elbo loss: -83033.96205305401
2024-03-02 17:59:11,682 - Elbo loss: -79708.18881191105
2024-03-02 17:59:12,201 - Elbo loss: -82288.32185945247
2024-03-02 17:59:12,719 - Elbo loss: -84191.28707526012
2024-03-02 17:59:13,236 - Elbo loss: -85343.59094527528
2024-03-02 17:59:13,507 - Elbo lo

SVI-based coefficient marginals:
Site: d
       mean       std        5%       25%      50%       75%       95%
0 -0.145112  0.026527 -0.188694 -0.163749 -0.14522 -0.125896 -0.103455 

Site: p
       mean       std        5%       25%       50%       75%       95%
0  0.033372  0.020174  0.000999  0.019679  0.033374  0.046642  0.066933 

Working on sparrows et al. with optimal=50
Initial dataset size: 61115
After dropping NAs: 60594


2024-03-02 17:59:17,111 - Starting SVI inference with 1000 iterations.
2024-03-02 17:59:17,121 - Elbo loss: 10582.216309616837
2024-03-02 17:59:17,617 - Elbo loss: -5261.733718798461
2024-03-02 17:59:17,957 - Elbo loss: -42731.185084507604
2024-03-02 17:59:18,474 - Elbo loss: -42169.34298834983
2024-03-02 17:59:18,788 - Elbo loss: -67208.79258653062
2024-03-02 17:59:19,074 - Elbo loss: -72355.22066079355
2024-03-02 17:59:19,343 - Elbo loss: -74365.97178364072
2024-03-02 17:59:19,634 - Elbo loss: -81214.4097018256
2024-03-02 17:59:19,905 - Elbo loss: -71831.89060965789
2024-03-02 17:59:20,184 - Elbo loss: -78291.42099598481
2024-03-02 17:59:20,466 - Elbo loss: -83377.77462674984
2024-03-02 17:59:20,736 - Elbo loss: -78580.32046533187
2024-03-02 17:59:21,016 - Elbo loss: -78459.8360028395
2024-03-02 17:59:21,286 - Elbo loss: -85136.8607505072
2024-03-02 17:59:21,567 - Elbo loss: -83564.23030676265
2024-03-02 17:59:21,841 - Elbo loss: -78041.27472417589
2024-03-02 17:59:22,115 - Elbo loss

SVI-based coefficient marginals:
Site: d
       mean       std        5%       25%       50%       75%       95%
0 -0.139776  0.025766 -0.181095 -0.156701 -0.140346 -0.123509 -0.097613 

Site: p
       mean       std        5%       25%       50%       75%       95%
0 -0.015544  0.022155 -0.050626 -0.030083 -0.015286 -0.001223  0.021765 

Working on sparrows et al. with optimal=60
Initial dataset size: 61115
After dropping NAs: 60594


2024-03-02 17:59:25,700 - Starting SVI inference with 1000 iterations.
2024-03-02 17:59:25,709 - Elbo loss: 86581.33179153988
2024-03-02 17:59:26,114 - Elbo loss: -17258.58739628821
2024-03-02 17:59:26,403 - Elbo loss: -11558.083453266516
2024-03-02 17:59:26,667 - Elbo loss: -46299.420028899665
2024-03-02 17:59:26,943 - Elbo loss: -52055.216101428385
2024-03-02 17:59:27,231 - Elbo loss: -73025.62717348128
2024-03-02 17:59:27,509 - Elbo loss: -77890.64722599555
2024-03-02 17:59:27,784 - Elbo loss: -68177.87007771098
2024-03-02 17:59:28,051 - Elbo loss: -80577.09370308046
2024-03-02 17:59:28,324 - Elbo loss: -84519.8962385424
2024-03-02 17:59:28,635 - Elbo loss: -80673.43667068591
2024-03-02 17:59:28,908 - Elbo loss: -84313.26780441895
2024-03-02 17:59:29,182 - Elbo loss: -83000.65322991842
2024-03-02 17:59:29,452 - Elbo loss: -81801.92449993471
2024-03-02 17:59:29,728 - Elbo loss: -85853.72741446644
2024-03-02 17:59:29,994 - Elbo loss: -81209.5866413112
2024-03-02 17:59:30,265 - Elbo lo

SVI-based coefficient marginals:
Site: d
       mean       std        5%       25%      50%       75%       95%
0 -0.140205  0.028517 -0.188327 -0.159563 -0.14053 -0.121577 -0.091796 

Site: p
       mean      std        5%       25%       50%       75%       95%
0 -0.063394  0.02035 -0.097444 -0.077137 -0.062526 -0.049538 -0.029974 

Working on sparrows et al. with optimal=70
Initial dataset size: 61115
After dropping NAs: 60594


2024-03-02 17:59:33,859 - Starting SVI inference with 1000 iterations.
2024-03-02 17:59:33,870 - Elbo loss: 7348.85600399659
2024-03-02 17:59:34,231 - Elbo loss: -27189.89927797802
2024-03-02 17:59:34,728 - Elbo loss: -20920.9255702897
2024-03-02 17:59:35,229 - Elbo loss: -37845.654036198255
2024-03-02 17:59:35,735 - Elbo loss: -72768.76538775247
2024-03-02 17:59:36,235 - Elbo loss: -73642.56250230332
2024-03-02 17:59:36,750 - Elbo loss: -83197.7233163468
2024-03-02 17:59:37,040 - Elbo loss: -83458.65493486336
2024-03-02 17:59:37,328 - Elbo loss: -86302.26534177578
2024-03-02 17:59:37,608 - Elbo loss: -83636.85878845493
2024-03-02 17:59:37,881 - Elbo loss: -84383.86980882558
2024-03-02 17:59:38,152 - Elbo loss: -86502.64162107487
2024-03-02 17:59:38,455 - Elbo loss: -86926.48050767745
2024-03-02 17:59:38,757 - Elbo loss: -84885.09700635797
2024-03-02 17:59:39,059 - Elbo loss: -84661.11997276207
2024-03-02 17:59:39,365 - Elbo loss: -85338.48837565529
2024-03-02 17:59:39,668 - Elbo loss:

SVI-based coefficient marginals:
Site: d
       mean       std        5%       25%       50%       75%       95%
0 -0.128891  0.023635 -0.166968 -0.145109 -0.129458 -0.113171 -0.088541 

Site: p
      mean       std        5%      25%       50%       75%       95%
0 -0.08666  0.019891 -0.117968 -0.10045 -0.086955 -0.073539 -0.053357 

Working on sparrows et al. with optimal=80
Initial dataset size: 61115
After dropping NAs: 60594


2024-03-02 17:59:43,601 - Starting SVI inference with 1000 iterations.
2024-03-02 17:59:43,658 - Elbo loss: 13594.511842821003
2024-03-02 17:59:44,211 - Elbo loss: -18832.45274705038
2024-03-02 17:59:44,713 - Elbo loss: -40829.30925687797
2024-03-02 17:59:45,218 - Elbo loss: -59325.162676551015
2024-03-02 17:59:45,494 - Elbo loss: -60431.01393239891
2024-03-02 17:59:45,777 - Elbo loss: -76601.4766231435
2024-03-02 17:59:46,058 - Elbo loss: -67328.48102191486
2024-03-02 17:59:46,340 - Elbo loss: -79910.91328872163
2024-03-02 17:59:46,609 - Elbo loss: -84724.34271909272
2024-03-02 17:59:46,882 - Elbo loss: -79939.3839736892
2024-03-02 17:59:47,151 - Elbo loss: -77412.43286252295
2024-03-02 17:59:47,417 - Elbo loss: -82983.38024300631
2024-03-02 17:59:47,683 - Elbo loss: -87617.78236067832
2024-03-02 17:59:47,951 - Elbo loss: -85510.9758145652
2024-03-02 17:59:48,219 - Elbo loss: -86341.90513738376
2024-03-02 17:59:48,488 - Elbo loss: -87204.4280067479
2024-03-02 17:59:48,756 - Elbo loss:

SVI-based coefficient marginals:
Site: d
       mean       std        5%       25%      50%       75%       95%
0 -0.130009  0.026516 -0.174464 -0.148581 -0.12883 -0.111254 -0.086984 

Site: p
       mean       std        5%       25%       50%      75%       95%
0 -0.099221  0.020486 -0.132564 -0.112938 -0.099098 -0.08507 -0.067802 



In [13]:
sps_outcomes_path = os.path.join(
    root, "data/foraging/central_park_birds_cleaned_2022/sps_outcomes.pkl"
)

if not smoke_test:
    # Use a context manager so the file handle is closed after loading.
    with open(sps_outcomes_path, "rb") as file:
        sps_outcomes = dill.load(file)

    sps_coefs_plot = plot_coefs(
        sps_outcomes,
        "Sparrows et al.",
        ann_start_y=200,
        ann_break_y=30,
        generate_object=True,
    )
    sps_coefs_plot.show()

    # Export a high-resolution static copy for the docs (kaleido engine).
    pio.write_image(
        sps_coefs_plot,
        os.path.join(root, "docs/figures/sps_coefs_plot.png"),
        engine="kaleido",
        width=700,
        scale=5,
    )

In [14]:
# note sparrows' movements are harder to predict

if not smoke_test:
    sps_objects = central_park_objects[1]

    for key in keys:
        distance, proximity, how_far = cp_prep_data_for_inference(sps_objects[key])
        guide = sps_outcomes[key]["svi_guide"]
        r_squared = calculate_R_squared_prox(distance, proximity, how_far, guide)
        print(f"R^2 for sparrows et al. with optimal={key}:", r_squared)

# total wall-clock time since the config cell
notebook_ends = time.time()
elapsed = notebook_ends - notebook_starts
print(f"notebook took {elapsed} seconds, {elapsed/60} minutes to run")

Initial dataset size: 61115
After dropping NAs: 60594
R^2 for sparrows et al. with optimal=10: 0.18643923103809357
Initial dataset size: 61115
After dropping NAs: 60594
R^2 for sparrows et al. with optimal=20: 0.19863438606262207
Initial dataset size: 61115
After dropping NAs: 60594
R^2 for sparrows et al. with optimal=30: 0.20319433510303497
Initial dataset size: 61115
After dropping NAs: 60594
R^2 for sparrows et al. with optimal=40: 0.17665237188339233
Initial dataset size: 61115
After dropping NAs: 60594
R^2 for sparrows et al. with optimal=50: 0.16481170058250427
Initial dataset size: 61115
After dropping NAs: 60594
R^2 for sparrows et al. with optimal=60: 0.15161633491516113
Initial dataset size: 61115
After dropping NAs: 60594
R^2 for sparrows et al. with optimal=70: 0.200607568025589
Initial dataset size: 61115
After dropping NAs: 60594
R^2 for sparrows et al. with optimal=80: 0.21199537813663483
notebook took 331.26502561569214 seconds, 5.5210837602615355 minutes to run
