## SUMMARY: Will correct the h5 prediction files for all of the training cooperation videos for a specified color pair

# get color vids

In [1]:
import os
import random
import h5py
import numpy as np
import time

from utils import find_node_velocity, get_stats, fill_missing, graph_vels, nan_vals

In [2]:
defaultdir = '/gpfs/radev/pi/saxena/aj764'
rootdir = f'{defaultdir}/Training_COOPERATION/'

Compiles a list of all of the sessions in PairedTestingSessions that have a Videos folder in them

In [7]:
vid_subdirs = []
for subdir, dirs, files in os.walk(rootdir):
    vid_subdirs.append(subdir)
vid_subdirs = sorted(vid_subdirs)

Seperates all of the videos into single instance videos and multi instance videos (and takes out videos from before April).

In [8]:
multi_vids = {}
for vids in vid_subdirs:
    files = os.listdir(vids)
    cut_vids = vids[28:]
    multi_vids[cut_vids] = []
    for file in files:
        if file.endswith('.mp4'): # and int(file[:2]) >= 4:
            multi_vids[cut_vids].append(file)
            

In [9]:
multi_len_tot = 0
for key, value in multi_vids.items():
    multi_len_tot += len(value)
print(f'There are {multi_len_tot} multi instance videos')

There are 212 multi instance videos


In [10]:
# split the multi-instance videos into their respective color pairs...
color_vids = {}
for key, value in multi_vids.items():
    for vid in value:
        parsed = vid.split('-')
        
        trial_color = [parsed[0][-1], parsed[1][5]]
        trial_key = ''
        if 'R' in trial_color:
            trial_key += 'R'
        if 'G' in trial_color:
            trial_key += 'G'
        if 'Y' in trial_color:
            trial_key += 'Y'
        if 'B' in trial_color:
            trial_key += 'B'
        if trial_key not in color_vids.keys():
            color_vids[trial_key] = []
        color_vids[trial_key].append(vid)

In [11]:
len_tot = 0
for key, value in color_vids.items():
    print(f'There are {len(value)} videos from {key} color pair')
    len_tot += len(value)
print('\n')
print(f'There are {len_tot} multi instance videos')

There are 69 videos from YB color pair
There are 18 videos from GB color pair
There are 28 videos from RG color pair
There are 19 videos from GY color pair
There are 32 videos from RB color pair
There are 46 videos from RY color pair


There are 212 multi instance videos


# fills in missing vals

In [13]:
CHECK = False
ACTUALLY_FILL = False
color_pair = 'YB'

In [15]:
start_time = time.time()

total_intial_nan = 0
total_after_out_nan = 0
total_final_nan = 0
bad_vids = []

for i, session in enumerate(multi_vids.keys()): 
    video_list = multi_vids[session]
    analysis_path = defaultdir + '/' + session + '/Tracking/h5/'
    
    for video in video_list:
        if video in color_vids[color_pair]:
            # open analysis file
            analysis_file = analysis_path + video[:-3] + 'predictions.h5'
            with h5py.File(analysis_file,'r+') as f:
                locations = f["tracks"][:].T 
    
                # find nan values
                intial = nan_vals(locations)

                if intial > 20:
                    bad_vids.append(video)
    
                # just to check you haven't already done this vid or it isn't empty
                if intial != 0:
                    # take out positional outliers
                    for rat in range(locations.shape[-1]): # for each rat (not actually necessary, the dims work out without this loop but I don't feel like thinking abt that)
                        all_vels = {}
                        for node in range(locations.shape[1]): # for each node
                            # find the velocities
                            all_vels[node] = find_node_velocity(locations[:, node, :,  rat:rat+1])
                        
                            # get values need to find outliers
                            mean, std, low, high = get_stats(all_vels[node])
                        
                            # if you want to check that these values looks good
                            graph_vels(all_vels[node], CHECK)
                        
                            # replace outliers in locations with nan
                            nan_index = [i for i in range(len(all_vels[node])) if (all_vels[node][i] > high or all_vels[node][i] < low)]
                            for index in nan_index:
                                locations[index + 1, node, 0, rat], locations[index + 1, node, 0, rat] = np.nan, np.nan
                        
                            # if you want to check that new locations look good
                            test_vels = find_node_velocity(locations[:, node, :])
                            graph_vels(test_vels, check=CHECK, old_low=low, old_high=high)
        
                    # find nan values again
                    after_out = nan_vals(locations)
        
                    # fill in missing locations
                    print(f'video name: {video}')
                    new_locations = fill_missing(locations)
                    if ACTUALLY_FILL:
                        f["tracks"][:] = new_locations.T
    
                    # finds nan values for final time
                    after_fill = nan_vals(new_locations)
                    
                    total_intial_nan += intial
                    total_after_out_nan += after_out
                    total_final_nan += after_fill
        
                    # if you want to check the nan/fill values for a each video
                    if True:
                        # print(f'video name: {video}')
                        print(f'intial nan: {round(intial, 2)} %, after out nan: {round(after_out, 2)} %, final nan: {round(after_fill, 2)} %')
        
print('totals:')
print(f'intial nan: {round(total_intial_nan / len(color_vids[color_pair]), 2)} %, after out nan: {round(total_after_out_nan / len(color_vids[color_pair]), 2)} %, final nan: {round(total_final_nan / len(color_vids[color_pair]), 2)} %')
print(f'time elapse: {time.time() - start_time}')

video name: 032824_COOPTRAIN_LARGEARENA_KL002B-KL002Y_Camera2.mp4
intial nan: 60.7 %, after out nan: 60.91 %, final nan: 0.0 %
video name: 032824_COOPTRAIN_LARGEARENA_KL001B-KL001Y_Camera1.mp4
intial nan: 87.74 %, after out nan: 88.28 %, final nan: 0.0 %
video name: 032924_COOPTRAIN_LARGEARENA_KL001B-KL001Y_Camera1.mp4
intial nan: 10.45 %, after out nan: 11.16 %, final nan: 0.0 %
video name: 032924_COOPTRAIN_LARGEARENA_KL005B-KL005Y_Camera1.mp4
intial nan: 8.37 %, after out nan: 8.72 %, final nan: 0.0 %
video name: 032924_COOPTRAIN_LARGEARENA_KL002B-KL002Y_Camera2.mp4
intial nan: 10.93 %, after out nan: 11.54 %, final nan: 0.0 %
video name: 033024_COOPTRAIN_LARGEARENA_KL005B-KL005Y_Camera1.mp4
intial nan: 6.83 %, after out nan: 7.24 %, final nan: 0.0 %
video name: 033024_COOPTRAIN_LARGEARENA_KL001B-KL001Y_Camera1.mp4
intial nan: 12.67 %, after out nan: 13.3 %, final nan: 0.0 %
video name: 033024_COOPTRAIN_LARGEARENA_KL002B-KL002Y_Camera2.mp4
intial nan: 9.25 %, after out nan: 9.81 %, f

# check our work... (abt 28 vids that I WOULDN'T trust!!)

In [16]:
print(f'intial nan: {round(total_intial_nan / len(color_vids[color_pair]), 2)} %, after out nan: {round(total_after_out_nan / len(color_vids[color_pair]), 2)} %, final nan: {round(total_final_nan / len(color_vids[color_pair]), 2)} %')


intial nan: 23.99 %, after out nan: 24.46 %, final nan: 0.0 %


In [17]:
print(f'percent of videos intitially have over 1/5 of values nan: {round(100 * len(bad_vids) / len(color_vids[color_pair]) ,2)}% ')

percent of videos intitially have over 1/5 of values nan: 40.58% 


In [18]:
# the videos in question :(
bad_vids

['032824_COOPTRAIN_LARGEARENA_KL002B-KL002Y_Camera2.mp4',
 '032824_COOPTRAIN_LARGEARENA_KL001B-KL001Y_Camera1.mp4',
 '040124_COOPTRAIN_LARGEARENA_KL001B-KL001Y_Camera1.mp4',
 '040324_COOPTRAIN_LARGEARENA_KL001B-KL001Y_Camera1.mp4',
 '040824_COOPTRAIN_LARGEARENA_KL005B-KL005Y_Camera1.mp4',
 '041024_COOPTRAIN_LARGEARENA_EB031B-EB033Y_Camera3.mp4',
 '041224_COOPTRAIN_LARGEARENA_EB009B-EB019Y_Camera4.mp4',
 '041324_COOPTRAIN_LARGEARENA_EB031B-EB033Y_Camera4.mp4',
 '041324_COOPTRAIN_LARGEARENA_EB009B-EB019Y_Camera3.mp4',
 '041424_COOPTRAIN_LARGEARENA_EB009B-EB019Y_Camera2.mp4',
 '041524_COOPTRAIN_LARGEARENA_EB009B-EB019Y_Camera2.mp4',
 '041624_COOPTRAIN_LARGEARENA_EB009B-EB019Y_Camera2.mp4',
 '042424_COOPTRAIN_LARGEARENA_EB003B-EB019Y_Camera2.mp4',
 '042524_COOPTRAIN_LARGEARENA_EB003B-EB019Y_Camera2.mp4',
 '061224_COOPTRAIN_LARGEARENA_HF003B-HF004Y_Camera2.mp4',
 '061824_COOPTRAIN_LARGEARENA_HF003B-HF004Y_Camera2.mp4',
 '062024_COOPTRAIN_LARGEARENA_HF003B-HF004Y_Camera2.mp4',
 '062424_COOPT