In [1]:
import os
import time

import dill
import matplotlib.pyplot as plt
import pyro
import seaborn as sns
import torch

import pandas as pd
import pyro.distributions as dist
from chirho.dynamical.handlers import LogTrajectory
from chirho.dynamical.handlers.solver import TorchDiffEq
from chirho.dynamical.ops import simulate
from pyro.infer import Predictive
from chirho.observational.handlers import condition
from chirho.dynamical.handlers import LogTrajectory, StaticBatchObservation
from chirho.dynamical.handlers.solver import TorchDiffEq
from chirho.dynamical.ops import Dynamics, State, simulate

# Plot styling for every figure in this notebook.
sns.set_style("white")

# Pyro configuration: switch on the `module_local_params` setting.
pyro.settings.set(module_local_params=True)

# Reproducibility: clear Pyro's parameter store, then fix the RNG seed.
seed = 123
pyro.clear_param_store()
pyro.set_rng_seed(seed)

import matplotlib.pyplot as plt
import seaborn as sns
import torch

from collab.foraging import locust as lc
from collab.foraging import toolkit as ft
from collab.utils import find_repo_root, progress_saver


# Run a reduced workload when the CI environment variable is present.
smoke_test = os.environ.get("CI") is not None

if smoke_test:
    num_iterations, num_samples = 50, 20
else:
    num_iterations, num_samples = 100, 100

In [None]:
data_code = "15EQ20191202"
validation_data_code = "15EQ20191205"

# Window start points and window lengths to sweep over; each window is
# passed to LocustDS as (start, start + length).
window_starts = [0, 20, 40, 60, 80, 100]
window_lengths = [
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
    20, 30, 40, 50, 60, 70, 80,
]

# Parallel result lists: one entry per (start, end) window, in sweep order.
starts, ends = [], []
null_mses, model_mses, rsquareds = [], [], []

for start in window_starts:
    for length in window_lengths:
        end = start + length
        starts.append(start)
        ends.append(end)
        print(start, end)

        # Build the data set for this window and validate it against the
        # held-out data code.
        locds = lc.LocustDS(data_code=data_code, start=start, end=end)
        locds.validate(validation_data_code=validation_data_code)

        # Collect the metrics that validate() left in `locds.validation`.
        # Only 'null_mse' is converted with .numpy(); the other two are
        # stored as returned.
        validation = locds.validation
        null_mses.append(validation["null_mse"].numpy())
        model_mses.append(validation["mse_mean"])
        rsquareds.append(validation["rsquared"])
        


In [None]:
# Assemble the per-window validation metrics into a single table.
# NOTE(review): the factor of 10 presumably converts time-step indices into
# seconds (the plots label these columns in seconds) — confirm.
TIME_SCALE = 10

v_results = pd.DataFrame(
    {
        "start": [step * TIME_SCALE for step in starts],
        "end": [step * TIME_SCALE for step in ends],
        "null_mse": null_mses,
        "model_mse": model_mses,
        "rsquared": rsquareds,
    }
)

# "start" and "end" are already scaled above, so their difference is the
# duration in the same unit. The previous version multiplied this difference
# by 10 a second time, which made durations ten times too large relative to
# the start/end columns.
v_results["duration"] = v_results["end"] - v_results["start"]

# Two-panel summary of the validation R^2:
#   left  — R^2 against window start, colored by window duration;
#   right — R^2 against window duration, colored by window start.
fig, (ax_start, ax_duration) = plt.subplots(1, 2, figsize=(12, 6))

scatter = ax_start.scatter(
    v_results["start"],
    v_results["rsquared"],
    marker="o",
    c=v_results["duration"],
)
# Dashed reference line at R^2 = 0.
ax_start.axhline(y=0, color="gray", linestyle="--", linewidth=1)
cbar = fig.colorbar(scatter, ax=ax_start, label="duration")
ax_start.set_xlabel("start time (seconds)")
ax_start.set_ylabel("$R^2$")
ax_start.set_title("$R^2$ vs start time (validation)")

# Color by the validation table's own "start" column. The earlier version
# read from a `results` frame that is never defined in this notebook and
# therefore raised NameError on a fresh kernel.
scatter2 = ax_duration.scatter(
    v_results["duration"],
    v_results["rsquared"],
    marker="o",
    c=v_results["start"],
)
ax_duration.axhline(y=0, color="gray", linestyle="--", linewidth=1)
cbar = fig.colorbar(scatter2, ax=ax_duration, label="start")
ax_duration.set_xlabel("duration (seconds)")
ax_duration.set_ylabel("$R^2$")
ax_duration.set_title("$R^2$ vs duration (validation)")

# No legend is drawn: none of the artists carries a label, so a legend call
# would only emit a "no artists with labels" warning. The colorbars already
# explain the color encoding.
sns.despine(fig=fig)

fig.tight_layout()
plt.show()


# Serialize the validation table with dill to a file located under the
# directory returned by find_repo_root(); the file name encodes both the
# training and the validation data codes.
root = find_repo_root()
results_file = f"data/foraging/locust/ds/length_experiment_v_results_{data_code}_v{validation_data_code}.pkl"
results_path = os.path.join(root, results_file)

with open(results_path, mode="wb") as f:
    dill.dump(v_results, f)
