In [1]:
import sys
sys.path.append('../')

import torch
from torch.utils.data import Dataset, DataLoader, Sampler
from stepselector.viewshed import generate_viewshed, generate_downsample_viewshed
from stepselector.data_loader import ZebraDataset, ZebraBatchSampler, custom_collate
import os
import glob
import pandas as pd
from osgeo import gdal
import numpy as np
gdal.UseExceptions()

In [2]:
# Root of the mounted `herd_hover` share. Set the HERD_HOVER_MOUNT environment
# variable to point at your own mount instead of editing this cell; the default
# preserves the original author's path.
server_mount = os.environ.get('HERD_HOVER_MOUNT', '/home/blair/server/herd_hover')

# Specify radius of viewshed (in meters)
viewshed_radius = 100
# Specify height/width of downsampled viewshed (e.g. 512 will return an array of 512x512 pixels)
viewshed_hw = 512
# Specify radius (in meters) to define social density (number of conspecifics within radius)
social_radius = 10

# How many reference steps do you want per target step? (max = 20)
n_ref_steps = 5
# Fail fast here instead of deep inside the dataset/dataloader if the value is out of range.
assert 1 <= n_ref_steps <= 20, f"n_ref_steps must be between 1 and 20, got {n_ref_steps}"

# Make list of columns to keep - these will be fetched by the dataloader
columns_to_keep = ['angle_to_observers', 'dist_to_observer', 'delta_observer_dist', 'road', 'ground_slope', 'visibility', 'social_dens', 'social_vis']

In [3]:
# Data directories and metadata files, all resolved under the configured server mount.
data_folder = os.path.join(server_mount, 'zebra_movement_data')

# Observed (target) and simulated (reference) five-meter steps live side by side.
observed_steps_directory, simulated_steps_directory = (
    os.path.join(data_folder, 'five_meter_steps', step_kind)
    for step_kind in ('observed', 'simulated')
)

rasters_directory = os.path.join(data_folder, 'rasters')

ob_metadata_file, track_metadata_file = (
    os.path.join(data_folder, csv_name)
    for csv_name in ('observation_metadata.csv', 'track_metadata.csv')
)

In [4]:
# Step dataset: observed steps are used as targets and simulated steps as
# references, with n_ref_steps references requested per target step.
# All other settings are forwarded from the configuration cell.
dataset = ZebraDataset(
    target_dir=observed_steps_directory,
    reference_dir=simulated_steps_directory,
    rasters_dir=rasters_directory,
    ob_metadata_file=ob_metadata_file,
    viewshed_radius=viewshed_radius,
    viewshed_hw=viewshed_hw,
    social_radius=social_radius,
    threads=4,  # NOTE(review): presumably parallelism inside ZebraDataset — TODO confirm and move to config
    n_ref_steps=n_ref_steps,
    columns_to_keep=columns_to_keep,
)

In [5]:
# Project batch sampler (stepselector.data_loader) that decides which dataset
# indices are grouped into each batch — presumably one target step with its
# reference steps; TODO confirm against ZebraBatchSampler.
batch_sampler = ZebraBatchSampler(dataset)

In [6]:
# Batching is fully delegated to batch_sampler, so batch_size / shuffle /
# sampler / drop_last must NOT also be passed (PyTorch treats them as mutually
# exclusive with batch_sampler). custom_collate builds each batch; the preview
# cell below unpacks its output as (target, references).
dataloader = DataLoader(
    dataset,
    batch_sampler=batch_sampler,
    collate_fn=custom_collate,
)

In [None]:
from itertools import islice

# Sanity-check the dataloader by printing only the first few batches.
# Iterating and printing the whole dataset floods the notebook output and makes
# Restart & Run All very slow (batches may include viewshed arrays — TODO confirm).
# Set max_preview_batches = None to iterate over every batch.
max_preview_batches = 1

for target, references in islice(dataloader, max_preview_batches):
    print(f"Target: {target}")
    print(f"References: {references}")