In [1]:
# importing packages. See https://github.com/BasisResearch/collab-creatures for repo setup
import logging
import os
import time
import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)

from collab.utils import find_repo_root

root = find_repo_root()
from collab.foraging import random_hungry_followers as rhf
from collab.foraging import toolkit as ft
from collab.foraging.toolkit import (
    add_velocities_to_data_object,
    construct_visibility,
    filter_by_visibility,
    generate_grid,
    generate_velocity_scores,
)

# Configure the root logger: INFO level, message-only format. Note that
# logging.basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(level=logging.INFO, format="%(message)s")

# Readers can skip smoke_test: it only matters for the automated GitHub runs
# that check the notebook still executes after updates to the repository.
# When the CI environment variable is present, the run sizes below shrink.
smoke_test = os.environ.get("CI") is not None
if smoke_test:
    num_frames, num_svi_iters, num_samples = 5, 10, 10
else:
    num_frames, num_svi_iters, num_samples = 50, 1000, 1000

# Wall-clock timestamp (seconds since the epoch) taken when this cell runs.
notebook_starts = time.time()

# from .trace import rewards_trace
# from .utils import generate_grid

In [2]:
# Build the random-foragers simulation under the short alias first, call the
# instance (presumably this runs the simulation -- confirm against
# rhf.RandomForagers), then bind the descriptive name to the same object.
# NOTE(review): no random seed is set in the visible cells, so results may
# differ between runs unless RandomForagers seeds internally -- confirm.
sim = rhf.RandomForagers(
    num_foragers=3,
    num_frames=num_frames,  # 5 under CI, 50 otherwise (set in the setup cell)
    num_rewards=15,
    grid_size=40,
    grab_range=3,
    probabilities=[1, 2, 3, 2, 1, 2, 3, 2, 1],
)
sim()

random_foragers_sim = sim



In [10]:
# Derive predictors from the simulation object and keep the returned result.
# NOTE(review): visibility_range is 20 here, while a later cell calls
# construct_visibility with visibility_range=10 and sets
# sim.visibility_range = 10 on this same object -- confirm this is intended.
random_foragers_derived = ft.derive_predictors(
    random_foragers_sim,
    visibility_range=20,
    velocity_time_decay=1,
    generate_velocity_indicator=True,
    dropna=False,
)

2024-05-28 11:38:51,226 - traces done
2024-05-28 11:38:51,477 - visibility done
2024-05-28 11:38:51,742 - proximity done
2024-05-28 11:38:51,838 - how_far done
2024-05-28 11:38:51,912 - derivedDF done
2024-05-28 11:38:51,912 - starting to generate velocity
2024-05-28 11:38:52,425 - velocity done


In [11]:
# Animate the derived foragers with width and height of 600, rewards plotted
# and traces turned off. This call differs from the notebook's other
# animate_foragers call only in plot_visibility (0 here) and plot_velocity
# (1 here).
ft.animate_foragers(
    random_foragers_derived,
    width=600,
    height=600,
    point_size=10,
    plot_rewards=True,
    plot_traces=False,
    plot_velocity=1,
    plot_visibility=0,
)

In [5]:
# Preview the first rows of the velocity-score table stored on sim.
# NOTE(review): within this notebook the attribute is assigned explicitly only
# in a later cell (from generate_velocity_scores); presumably
# derive_predictors(..., generate_velocity_indicator=True) has already set it
# by this point -- confirm with a fresh Restart & Run All.
sim.velocity_scoresDF.head()

Unnamed: 0,x,y,forager,time,velocity_score,velocity_score_standardized
526,14,7,1,1,0.111675,-0.60356
354,9,35,1,1,0.138411,-0.513324
168,5,9,1,1,0.059519,-0.779595
135,4,16,1,1,0.090792,-0.674043
937,24,18,1,1,0.803107,1.730102


In [8]:
# Animate the derived foragers with width and height of 600, rewards plotted
# and traces turned off. This call differs from the notebook's other
# animate_foragers call only in plot_visibility (2 here) and plot_velocity
# (0 here).
ft.animate_foragers(
    random_foragers_derived,
    width=600,
    height=600,
    point_size=10,
    plot_rewards=True,
    plot_traces=False,
    plot_velocity=0,
    plot_visibility=2,
)

In [None]:
# Build the grid for this grid size and shuffle its rows in one step; the fixed
# random_state makes the shuffle reproducible (grid is presumably a pandas
# DataFrame, given the .sample(frac=..., random_state=...) call -- confirm).
grid = generate_grid(sim.grid_size).sample(frac=1, random_state=42)
sim.grid = grid

# One value feeds both the visibility computation and the attribute recorded
# on sim, so the two cannot drift apart.
visibility_range = 10
vis = construct_visibility(
    sim.foragers,
    sim.grid_size,
    grid=grid,
    time_shift=0,
    visibility_range=visibility_range,
)
sim.visibility_range = visibility_range
sim.visibility = vis["visibility"]
sim.visibilityDF = vis["visibilityDF"]

# Finally, call add_velocities_to_data_object on the simulation object.
add_velocities_to_data_object(sim)

In [3]:
# Show the first rows returned by filter_by_visibility for subject 1.
# The commented-out loop header below is kept from the original; it would
# iterate over every forager instead of the single subject used here.
# for b in range(1, sim.num_foragers + 1):

# NOTE(review): these three names are assigned but never passed to the call
# below, which hard-codes time_shift=0, info_time_decay=1 and
# finders_tolerance=1 (not the 3 and 2 assigned here). Confirm which values
# are intended and whether a later cell relies on these variables.
time_shift = 0
finders_tolerance = 2
info_time_decay = 3

filter_by_visibility(
    sim,
    subject=1,
    time_shift=0,
    visibility_restriction="visible",
    info_time_decay=1,
    finders_tolerance=1,
    filter_by_on_reward=False,
).head()

Unnamed: 0,x,y,time,forager,type,velocity_x,velocity_y,distance,out_of_range
0,21.0,19.0,1,2,random,0.0,0.0,4.472136,False
0,18.0,22.0,1,3,random,0.0,0.0,1.414214,False
1,20.0,22.0,2,2,random,-1.0,3.0,3.162278,False
1,20.0,20.0,2,3,random,2.0,-2.0,5.09902,False
2,23.0,24.0,3,2,random,3.0,2.0,5.0,False


In [4]:
# Compute velocity scores for sim, then attach both entries of the returned
# result ("velocity_scores" and "velocity_scoresDF") to the simulation object.
vs = generate_velocity_scores(sim)
sim.velocity_scores = vs["velocity_scores"]
sim.velocity_scoresDF = vs["velocity_scoresDF"]