In [1]:
import sys
sys.path.append('../')

from stepselector.data_loader import ZebraDataset, custom_collate
from torch.utils.data import Dataset, DataLoader, Sampler, BatchSampler
import numpy as np
import pandas as pd
import os
from tqdm.notebook import tqdm

In [2]:
# ---------------------------------------------------------------------------
# Notebook configuration: everything a reader might want to tune lives here.
# ---------------------------------------------------------------------------

# Do you want to load five meter steps or ten meter steps?
step_length = 'five'
if step_length not in ('five', 'ten'):
    raise ValueError("step_length must be 'five' or 'ten', got %r" % (step_length,))

# Specify radius (in meters) to define social density (number of conspecifics within radius)
social_radius = 10

# How many fake steps to use per real step? (max 20)
n_ref_steps = 5
if n_ref_steps > 20:
    raise ValueError("n_ref_steps must be at most 20, got %r" % (n_ref_steps,))

# How many context steps should be included?
num_context_steps = 5

# Make list of columns to keep - these will be fetched by the dataloader
columns_to_keep = [
    'target_id',
    'observation',
    'step_speed_mps',
    'angle_to_observers',
    'dist_to_observer',
    'delta_observer_dist',
    'road',
    'ground_class',
    'ground_slope',
    'viewshed_vis',
    'social_dens',
    'social_vis',
    'age_class',
    'species',
    'individual_ID',
]

# Root of the mounted data server. Set the HERD_HOVER_MOUNT environment
# variable to run this notebook on a machine where the server is mounted
# elsewhere; the original location is used when the variable is not set.
server_mount = os.environ.get('HERD_HOVER_MOUNT', '/home/blair/server/herd_hover')

In [3]:
# Resolve where the movement data lives under the server mount.
data_folder = os.path.join(server_mount, 'zebra_movement_data')

# Observed and simulated steps sit side by side in a folder named after the
# chosen step length.
steps_folder = os.path.join(data_folder, '%s_meter_steps' % step_length)
observed_steps_directory = os.path.join(steps_folder, 'observed')
simulated_steps_directory = os.path.join(steps_folder, 'simulated')

In [4]:
# Build the dataset from the observed (target) and simulated (reference)
# step directories, using the settings chosen in the configuration cell.
dataset_options = dict(
    target_dir=observed_steps_directory,
    reference_dir=simulated_steps_directory,
    social_radius=social_radius,
    columns_to_keep=columns_to_keep,
    num_ref_steps=n_ref_steps,
    num_context_steps=num_context_steps,
)
dataset = ZebraDataset(**dataset_options)

In [5]:
# Wrap the dataset in a DataLoader that yields batches of 10 samples,
# keeps the final (possibly smaller) batch, and assembles batches with the
# project's custom collate function.
loader_options = dict(
    batch_size=10,
    drop_last=False,
    collate_fn=custom_collate,
)
dataloader = DataLoader(dataset, **loader_options)

In [6]:
# Accumulators for everything the DataLoader yields, flattened across batches.
target_test, refs_test, context_test = [], [], []

# Walk the DataLoader once; each batch unpacks into exactly three parts,
# and the contents of each part are appended to the matching accumulator.
for batch in dataloader:
    target, references, context = batch
    target_test += target
    refs_test += references
    context_test += context
