In [1]:
import os
import sys

current_dir =  os.getcwd()
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, parent_dir)

import math
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import torch
import pyro
from collab import foraging_toolkit as ft
import torch.nn.functional as F
import logging
import time
import dill
import copy

from scipy.signal import find_peaks

import plotly.io as pio
from plotly import express as px, graph_objects as go, figure_factory as ff
from pyro.nn import PyroModule
import pyro.distributions as dist
from pyro.infer.autoguide import (
    AutoNormal,
    AutoDiagonalNormal,
    AutoMultivariateNormal,
    init_to_mean,
    init_to_value,
)


from pyro.contrib.autoguide import AutoLaplaceApproximation
from pyro.infer import SVI, Trace_ELBO, MCMC, NUTS
from pyro.optim import Adam
import pyro.optim as optim
from pyro.infer import Predictive
from pyro.infer import MCMC, NUTS


In [2]:
path = "../data/central_park_birds_cleaned_2022/central_park_objects.pkl"
with open(path, "rb") as file:
    central_park_objects = dill.load(file)


In [3]:
def cp_prep_data_for_iference(obj):
    df = obj.how_farDF.copy()
    print("Initial dataset size:", len(df))
    df.dropna(inplace=True)
    print("After dropping NAs:", len(df))
    
    columns_to_normalize = [
        "distance",
        "proximity_standardized",
    ]

    for column in columns_to_normalize:
        df[column] = ft.normalize(df[column])
    
    return torch.tensor(df['distance'].values), torch.tensor(df['proximity_standardized'].values), torch.tensor(df['how_far_squared_scaled'].values), 
    

In [4]:
def model_sigmavar_proximity(distance, proximity, how_far):
    d = pyro.sample("d", dist.Normal(0, .6))
    p = pyro.sample("p", dist.Normal(0, .6))
    b = pyro.sample("b", dist.Normal(.5, .6))

    ds = pyro.sample("ds", dist.Normal(0, .6))
    ps = pyro.sample("ps", dist.Normal(0, .6))
    bs = pyro.sample("bs", dist.Normal(.2, .6))

    sigmaRaw = bs + ds * distance +  ps * proximity
    sigma = pyro.deterministic("sigma", F.softplus(sigmaRaw))
    mean = b + d * distance +  p * proximity

    with pyro.plate("data", len(how_far)):
        pyro.sample("obs", dist.Normal(mean, sigma), obs=how_far)

In [5]:
def get_samples(distance, proximity, how_far, model = model_sigmavar_proximity,
                num_svi_iters = 1000,
                num_mcmc_samples = 200, 
                num_samples = 1000):
    
    guide = AutoMultivariateNormal(model, init_loc_fn=init_to_mean)
    svi = SVI(model_sigmavar_proximity,
        guide,
        optim.Adam({"lr": .01}),
        loss=Trace_ELBO())

    iterations = []
    losses = []
    
    logging.info(f"Starting SVI inference with {num_svi_iters} iterations.")
    start_time = time.time()
    pyro.clear_param_store()
    for i in range(num_svi_iters):
        elbo = svi.step(distance, proximity, how_far)
        iterations.append(i)
        losses.append(elbo)
        if i % 200 == 0:
            logging.info("Elbo loss: {}".format(elbo))
    end_time = time.time()
    elapsed_time = end_time - start_time
    logging.info("SVI inference completed in %.2f seconds.", elapsed_time)

    fig = px.line(x=iterations, y=losses, title="ELBO loss", template="presentation")
    labels={"iterations": "iteration", "losses": "loss"}
    fig.update_xaxes(showgrid=False, title_text=labels["iterations"])
    fig.update_yaxes(showgrid=False, title_text=labels["losses"])
    fig.update_layout(width=700)
    fig.show()
        
    predictive = Predictive(model, guide=guide, 
                        num_samples=num_samples)
    
    proximity_svi = {k: v.flatten().reshape(num_samples, -1).detach().cpu().numpy()
            for k, v in predictive(distance, proximity, how_far).items()
            if k != "obs"}


    print ("SVI-based coefficient marginals:")
    for site, values in ft.summary(proximity_svi, ["d", "p"]).items():
            print("Site: {}".format(site))
            print(values, "\n")
            
    return {"svi_samples": proximity_svi, "svi_guide": guide, "svi_predictive": predictive}

In [6]:
ducks_objects = central_park_objects[0]
keys = [19, 46, 85]

duck_outcomes = {}

for key in keys:
    obj = ducks_objects[key]
    print (f"Working on ducks with optimal={key}")
    distance, proximity, how_far= cp_prep_data_for_iference(obj)
    ft.visualise_bird_predictors(distance, proximity, how_far, vis_sampling_rate=.05, titles = [f"Distance (ducks)", f"Proximity (ducks, optimal={key})"],
                                 x_axis_labels = ["distance", "proximity"])

    duck_outcomes[key] = get_samples(distance, proximity, how_far)
    

Working on ducks with optimal=19
Initial dataset size: 202152
After dropping NAs: 199001


2023-09-28 13:25:29,996:  Starting SVI inference with 1000 iterations.
2023-09-28 13:25:30,066:  Elbo loss: 113559.8530292369
2023-09-28 13:25:35,063:  Elbo loss: -196496.45820252062
2023-09-28 13:25:39,597:  Elbo loss: -208449.36847726052
2023-09-28 13:25:43,966:  Elbo loss: -208679.27770173643
2023-09-28 13:25:48,272:  Elbo loss: -210581.48319280826
2023-09-28 13:25:52,996:  SVI inference completed in 23.00 seconds.


SVI-based coefficient marginals:
Site: d
       mean       std      5%       25%       50%       75%       95%
0 -0.336466  0.022088 -0.3709 -0.351089 -0.337681 -0.322131 -0.299219 

Site: p
       mean       std        5%       25%       50%       75%       95%
0  0.119372  0.026092  0.075518  0.103214  0.120353  0.136745  0.160861 

Working on ducks with optimal=46
Initial dataset size: 202152
After dropping NAs: 199001


2023-09-28 13:26:00,567:  Starting SVI inference with 1000 iterations.
2023-09-28 13:26:00,627:  Elbo loss: 176165.17556287535
2023-09-28 13:26:05,320:  Elbo loss: -173561.05496494126
2023-09-28 13:26:09,484:  Elbo loss: -193993.98959043174
2023-09-28 13:26:13,674:  Elbo loss: -214916.54616541072
2023-09-28 13:26:18,777:  Elbo loss: -217416.17006127597
2023-09-28 13:26:24,075:  SVI inference completed in 23.51 seconds.


SVI-based coefficient marginals:
Site: d
       mean      std        5%       25%       50%       75%       95%
0 -0.345011  0.02807 -0.390915 -0.364077 -0.345318 -0.326834 -0.299198 

Site: p
       mean       std        5%       25%       50%       75%      95%
0  0.153465  0.023919  0.114452  0.137379  0.152931  0.169849  0.19289 

Working on ducks with optimal=85
Initial dataset size: 202152
After dropping NAs: 199001


2023-09-28 13:26:30,961:  Starting SVI inference with 1000 iterations.
2023-09-28 13:26:31,004:  Elbo loss: 203562.73938632198
2023-09-28 13:26:35,568:  Elbo loss: -153437.29023248173
2023-09-28 13:26:39,861:  Elbo loss: -190109.6665854209
2023-09-28 13:26:44,293:  Elbo loss: -203585.4842394688
2023-09-28 13:26:48,498:  Elbo loss: -208535.03122090243
2023-09-28 13:26:52,815:  SVI inference completed in 21.85 seconds.


SVI-based coefficient marginals:
Site: d
       mean       std        5%       25%       50%       75%      95%
0 -0.332493  0.029181 -0.381055 -0.352789 -0.331816 -0.312552 -0.28477 

Site: p
       mean      std      5%       25%       50%       75%       95%
0 -0.154583  0.02772 -0.1986 -0.173077 -0.155146 -0.135964 -0.106686 



In [35]:
def plot_coefs(outcomes, title, ann_start_y = 100, ann_break_y = 50, generate_object = False):

    keys = [19, 46, 85]
    samples = {}

    for key in keys:
        samples[key] =  outcomes[key]["svi_samples"]["p"].flatten()
    
    samples_df = pd.DataFrame(duck_samples)
    samples_df_medians = samples_df.median(axis = 0 ).tolist()

    fig_coefs = px.histogram(samples_df, template = "presentation", 
                opacity = .4,
                labels={"variable": "optimal proximity", "value": "proximity coefficient"},
                        width  = 700,
                        title  = title,
                        marginal="rug"
                        )


    for i, color in enumerate(['#1f77b4', '#ff7f0e', '#2ca02c']):
            fig_coefs.add_vline(x=samples_df_medians[i], line_dash="dash", line_color=color, name=f"Median ({samples_df_medians[i]})")
        
        
            fig_coefs.add_annotation(
            x=samples_df_medians[i],
            y= ann_start_y + ann_break_y * i,  # Adjust the vertical position of the label
            text=f"{samples_df_medians[i]:.2f}",
            bgcolor="white",
            showarrow=False,
            opacity=0.8,
            )

    fig_coefs.update_layout(barmode='overlay', yaxis=dict(showticklabels=False, title=None, showgrid=False)) 

    if generate_object:
        return fig_coefs
    else:
        fig_coefs.show()


In [38]:
ducks_coefs_plot = plot_coefs(duck_outcomes, "Ducks: proximity coefficients", ann_start_y = 200, ann_break_y = 80, generate_object = True)

ducks_coefs_plot.show()

pio.write_image(ducks_coefs_plot, 'exported_figures/duck_coefs_plot.png', 
               engine = "kaleido", width=600, height=600, scale=5)


[0.12035326659679413, 0.1529311239719391, -0.15514570474624634]


In [1]:


def calculate_R_squared_prox(distance, proximity,
                how_far, guide, subsample_size = 1000) :
    predictive = pyro.infer.Predictive(model_sigmavar_proximity, guide=guide, num_samples=1000)
    
    random_indices = np.random.choice(len(distance), size=subsample_size, replace=False)
    distance_sub = distance[random_indices]
    proximity_sub = proximity[random_indices]
    how_far_sub = how_far[random_indices]
    
    predictions = predictive(distance_sub, proximity_sub, how_far_sub)

    simulated_outcome = ( predictions['b'] + predictions['p'] * proximity +
                      predictions['d'] * distance )

    mean_sim_outcome = simulated_outcome.mean(0).detach().cpu().numpy()

    observed_mean = torch.mean(how_far)

    tss = torch.sum((how_far - observed_mean) ** 2)
    rss = torch.sum((how_far - mean_sim_outcome) ** 2)

    r_squared = 1 - (rss / tss)

    return r_squared.float().item()

    


In [2]:
for key in keys:
    guide = duck_outcomes[key]["svi_guide"]
    print (f"R^2 for ducks with optimal={key}:", calculate_R_squared_prox(distance, proximity, how_far, guide))

NameError: name 'keys' is not defined