In [3]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import seaborn 
import scipy as sp
from torch.utils.data.dataset import TensorDataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.autograd import Variable
from utils import *

In [4]:
# Load the thorax locations and the velocities from disk and give both the
# same 4-D layout, printing each resulting shape as a sanity check.
full_shape = (1, 63033, 2, 2)

locations = np.load('thorax_locations.npy').reshape(full_shape)
print(locations.shape)

velocities = np.load('velocities.npy').reshape(full_shape)
print(velocities.shape)

(1, 63033, 2, 2)
(1, 63033, 2, 2)


In [5]:
# Split the frame axis (axis 1) into train / valid / test windows.
# NOTE: frames 20000-38999 and 51000 onward are not assigned to any split.
train_frames = slice(None, 20000)
valid_frames = slice(39000, 45000)
test_frames = slice(45000, 51000)

locations_train, locations_valid, locations_test = (
    locations[:, frames, :, :]
    for frames in (train_frames, valid_frames, test_frames)
)

velocities_train, velocities_valid, velocities_test = (
    velocities[:, frames, :, :]
    for frames in (train_frames, valid_frames, test_frames)
)

In [6]:
# Sanity check: display the shape of the training split of `locations`.
locations_train.shape

(1, 20000, 2, 2)

In [7]:
# Shape of a single trace along the frame axis (index 0 on every other axis).
# construct_edges uses this same [0, :, 0, 0] slice as one operand of its
# x-difference, so this is presumably the x coordinate of the first animal.
locations_train[0,:,0,0].shape

(20000,)

In [8]:
# Predefine the threshold for interaction: look at the per-frame Euclidean
# distance between the two tracked positions over the whole recording and
# print its shape, mean, standard deviation, minimum and maximum.
dx = locations[0,:,0,0] - locations[0,:,0,1]
dy = locations[0,:,1,0] - locations[0,:,1,1]
x = np.sqrt(dx**2 + dy**2)

print(x.shape)
for stat in (x.mean(), x.std(), x.min(), x.max()):
    print(stat)


(63033,)
170.30845452464433
78.25640579578305
4.422176746248456
422.4069072773607


In [None]:
def construct_edges(locs, threshold=50, batch_size = 50):
    """Build one 2x2 edge matrix per batch of consecutive frames.

    The frames along axis 1 of ``locs`` are cut into consecutive chunks of
    ``batch_size`` frames (the last chunk may be shorter). A chunk receives
    the "interaction present" matrix ``[[0, 1], [1, 0]]`` when the Euclidean
    distance between the two tracked animals is strictly below ``threshold``
    in at least one of its frames, and the all-zero matrix otherwise.

    Parameters
    ----------
    locs : np.ndarray
        Indexed as ``locs[0, frame, coord, animal]``; coord 0/1 are used as
        x/y and animals 0 and 1 are compared. Only index 0 of the first axis
        is read.
    threshold : float, optional
        Distance below which two animals count as interacting.
    batch_size : int, optional
        Number of consecutive frames summarised by one edge matrix.

    Returns
    -------
    np.ndarray
        Float array of shape ``(n_batches, 2, 2)``. When ``locs`` contains no
        frames the result has shape ``(0, 2, 2)``.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    # define edges based on presence of interaction in each batch
    edges_present = np.array([[0,1],[1,0]], dtype=float)
    edges_absent = np.zeros((2,2), dtype=float)

    # Euclidean distance between the two animals, computed once for all frames
    dx = locs[0,:,0,0] - locs[0,:,0,1]
    dy = locs[0,:,1,0] - locs[0,:,1,1]
    close = np.sqrt(dx**2 + dy**2) < threshold

    # a batch counts as interacting if at least one of its frames is 'close'
    has_interaction = np.array(
        [bool(np.any(close[start:start+batch_size]))
         for start in range(0, locs.shape[1], batch_size)],
        dtype=bool,
    )
    n_present = int(has_interaction.sum())

    # np.where broadcasts the per-batch flag over the 2x2 templates and, unlike
    # np.stack on an empty list, also handles the zero-batch case
    edge_matrices = np.where(has_interaction[:, None, None],
                             edges_present, edges_absent)

    print("shape of edge matrices:", edge_matrices.shape)
    print("number of batches with interaction:", n_present)

    return edge_matrices

In [None]:
def dynamic_load():
    """Load the recordings from disk and prepare train/valid/test inputs.

    Reads 'thorax_locations.npy' and 'velocities.npy', reshapes both to
    (1, 63033, 2, 2), cuts three frame windows out of each (train, valid,
    test), builds the per-batch edge matrices for every window with
    ``construct_edges`` and passes every window through ``batch_ready``
    (helper defined outside this cell; per the original comment it reformats
    the data to allow batch sizes bigger than 1).

    Returns
    -------
    tuple of three lists
        ``(locs_list, vel_list, edges_list)``; each list is ordered
        ``[train, valid, test]``.
    """
    full_shape = (1, 63033, 2, 2)

    # load the locations and velocities
    locations = np.reshape(np.load('thorax_locations.npy'), full_shape)
    print("locations.shape at start", locations.shape)

    velocities = np.reshape(np.load('velocities.npy'), full_shape)
    print("velocities.shape at start", velocities.shape)

    # frame windows for train, valid, test (in that order)
    split_frames = (slice(None, 20000), slice(39000, 45000), slice(45000, 51000))
    locs_splits = [locations[:, frames, :, :] for frames in split_frames]
    vel_splits = [velocities[:, frames, :, :] for frames in split_frames]

    # construct the edge matrices from the raw (un-reformatted) locations
    edges_list = [construct_edges(part) for part in locs_splits]

    # reformatting to allow for batch sizes bigger than 1
    locs_list = [batch_ready(part) for part in locs_splits]
    vel_list = [batch_ready(part) for part in vel_splits]

    labels = (
        "formatted edges_train.shape",
        "formatted edges_valid.shape",
        "formatted edges_test.shape",
    )
    for label, edges in zip(labels, edges_list):
        print(label, edges.shape)

    return locs_list, vel_list, edges_list


In [15]:
# NOTE(review): the result is named `edge_train_` but is built from
# `locations_valid` -- confirm whether the train split or a "valid" name
# was intended.
edge_train_ = construct_edges(locations_valid)

shape of edge matrices: (120, 2, 2)
number of batches with interaction: 11


In [11]:
# Number of 50-frame batches in the train split (50 is the default
# batch_size of construct_edges); true division, so the result is a float.
locations_train.shape[1]/50

400.0

In [13]:
# Scratch check of an equality branch: quote is 5, so the branch below is
# skipped and nothing is printed.
quote = 5

if quote == 4:
    print('yes')