In [1]:
import sys

sys.path.insert(0, "..")

import random
import numpy as np
from sklearn.preprocessing import LabelEncoder

import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import random
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import torch
import pyro
import foraging_toolkit as ft
import torch.nn.functional as F
import pyro.distributions as dist
import pyro.optim as optim
from pyro.nn import PyroModule
from pyro.infer.autoguide import (
    AutoNormal,
    AutoDiagonalNormal,
    AutoMultivariateNormal,
    init_to_mean,
    init_to_value,
)
from pyro.contrib.autoguide import AutoLaplaceApproximation
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from pyro.infer import Predictive
from pyro.infer import MCMC, NUTS

import os
import logging

# Emit bare log messages (no level/logger prefix) at INFO and above.
logging.basicConfig(level=logging.INFO, format="%(message)s")
# True whenever a CI environment variable is set, regardless of its value.
smoke_test = os.environ.get("CI") is not None


In [2]:
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open("communicatorsDFs.pkl", "rb") as f:
    communicatorsDFs = pickle.load(f)
birdsDF, rewardsDF = communicatorsDFs[0], communicatorsDFs[1]

# Keep birds with id >= 3, then relabel 3 -> 1 and 4 -> 2 (ids >= 5 are unchanged).
# .copy() detaches the filtered frame so the column assignment below does not
# raise pandas' SettingWithCopyWarning; replace() on the whole column avoids the
# label-based .loc lookup, which would misfire if index labels were duplicated.
birdsDF = birdsDF[birdsDF["bird"] >= 3].copy()
birdsDF["bird"] = birdsDF["bird"].replace({3: 1, 4: 2})
communicators = ft.object_from_data(birdsDF, rewardsDF)

rewardsDF.head()


Unnamed: 0,x,y,time
0,0,24,1
1,1,24,1
2,2,24,1
3,3,24,1
4,0,25,1


In [3]:
# Load the raw locust tracks and rename columns to the names used by
# foraging_toolkit (x, y, time, bird), keeping only those four columns.
# Built as a chain (no inplace=True); .copy() makes the result an independent
# frame so the column assignment below cannot trigger SettingWithCopyWarning.
locust = (
    pd.read_csv("locust.csv")
    .drop(columns="cnt")
    .rename(columns={"pos_x": "x", "pos_y": "y", "id": "bird", "frame": "time"})
    [["x", "y", "time", "bird"]]
    .copy()
)

# LabelEncoder maps the raw ids onto consecutive integers 0..n-1; shift by one
# so bird ids start at 1.
encoder = LabelEncoder()
locust["bird"] = encoder.fit_transform(locust["bird"]) + 1

locust.shape


(1350000, 4)


In [4]:
# Coarsen the locust tracks: positions are divided by 2000/200 = 10 and frame
# numbers by 90000/1800 = 50, then rounded to integers.
# NOTE(review): np.round rounds halves to even, so with consecutive integer
# frames the interior time bins alternate between 51 and 49 raw frames;
# floor division would give equal-width bins if that matters.
SPACE_SCALE = 2000 / 200
TIME_SCALE = 90000 / 1800

# Guard so re-running this cell does not rescale already-binned data.
# Re-running the loading cell produces a fresh frame without the flag.
if not locust.attrs.get("binned", False):
    locust = locust.assign(
        x=np.round(locust["x"] / SPACE_SCALE).astype(int),
        y=np.round(locust["y"] / SPACE_SCALE).astype(int),
        time=np.round(locust["time"] / TIME_SCALE).astype(int),
    )
    # Several raw frames collapse onto one time step; keep the first row per bird.
    locust = locust.drop_duplicates(subset=["time", "bird"], keep="first")
    locust.attrs["binned"] = True


In [7]:

print(locust.tail())
# Restrict to the first 180 binned time steps (times up to and including 180).
locust_initial = locust.loc[locust["time"].le(180)]
print(locust_initial.tail())

           x    y  time  bird
1349620   32   42  1800    11
1349621  186  143  1800    12
1349622  197   86  1800    13
1349623   15   85  1800    14
1349624  153   42  1800    15
          x    y  time  bird
134620  149  147   180    11
134621   51   59   180    12
134622  154  135   180    13
134623   30  138   180    14
134624   66  167   180    15


In [8]:
def _max_step(track):
    """Largest absolute change in ``x`` or ``y`` between time-consecutive rows of one track.

    Returns 0 for a track with fewer than two rows.
    """
    ordered = track.sort_values("time", kind="stable")
    if len(ordered) < 2:
        return 0
    coords = ordered[["x", "y"]].to_numpy()
    # np.diff along rows keeps the integer dtype for integer coordinates.
    return np.abs(np.diff(coords, axis=0)).max()


def object_from_data(birdsDF, rewardsDF, trim=False):
    """Bundle bird tracks and reward locations into a simple attribute object.

    Parameters
    ----------
    birdsDF : pandas.DataFrame
        Tracks with columns ``x``, ``y``, ``time`` and ``bird``.
    rewardsDF : pandas.DataFrame
        Reward rows with at least a ``time`` column.
    trim : bool, default False
        If True, cut both frames at the smaller of their maximum times;
        otherwise at the larger one.

    Returns
    -------
    types.SimpleNamespace
        Attributes:
        - ``grid_size``: max of ``x``/``y`` over the untrimmed tracks.
        - ``num_frames``: the time cut-off.
        - ``birdsDF`` / ``rewardsDF``: the trimmed frames.
        - ``birds``: one frame per bird.
        - ``rewards``: one frame per time value.
        - ``num_birds``.
        - ``step_size_max``: largest per-step change in ``x`` or ``y`` over all
          time-consecutive observations of any bird.

    Raises
    ------
    ValueError
        If ``birdsDF`` or ``rewardsDF`` is empty.
    """
    from types import SimpleNamespace

    if birdsDF.empty or rewardsDF.empty:
        raise ValueError("birdsDF and rewardsDF must both be non-empty")

    # Grid extent is taken from the tracks before trimming.
    grid_max = max(birdsDF["x"].max(), birdsDF["y"].max())
    maxes = [birdsDF["time"].max(), rewardsDF["time"].max()]
    limit = min(maxes) if trim else max(maxes)
    birdsDF = birdsDF[birdsDF["time"] <= limit]
    rewardsDF = rewardsDF[rewardsDF["time"] <= limit]

    birds = [group for _, group in birdsDF.groupby("bird")]

    sim = SimpleNamespace()
    sim.grid_size = int(grid_max)
    # NOTE(review): this is the time cut-off, not a row count -- tracks starting
    # at time 0 have limit + 1 observations.
    sim.num_frames = int(limit)
    sim.birdsDF = birdsDF
    sim.rewardsDF = rewardsDF
    sim.birds = birds
    sim.rewards = [group for _, group in rewardsDF.groupby("time")]
    sim.num_birds = len(birds)
    # Every consecutive pair of each track counts, including the final step.
    sim.step_size_max = max((_max_step(track) for track in birds), default=0)

    return sim

# rewardsDF is the reward table loaded with the communicators data; with the
# default trim=False its times also take part in choosing the cut-off.
loc = object_from_data(birdsDF=locust_initial, rewardsDF=rewardsDF)


In [9]:
# Last five observations of the first bird's track.
loc.birds[0].iloc[-5:]

Unnamed: 0,x,y,time,bird
131610,92,154,176,1
132375,93,154,177,1
133110,92,157,178,1
133875,91,159,179,1
134610,90,160,180,1


In [10]:
# Animate the locust tracks; plot_rewards=False presumably suppresses reward markers.
ft.animate_birds(loc, width=600, height=600, point_size=10, plot_rewards=False)
