In [1]:
import numpy as np
import pandas as pd
import glob
from tqdm import tqdm
import pickle
import cv2
from itertools import repeat
import src.procrustes as procrustes
import src.utils as utils

In [2]:
# Base directories; the relative paths defined below are appended to these.
home_dir = '/data/LiftFly3D/prism/'
data_dir = '/mnt/NAS/SG/prism_data/'

# Scorer suffixes for the bottom-view and side-view tracking files.
scorer_bottom = 'DLC_resnet50_jointTrackingDec13shuffle1_200000'
scorer_side = 'DLC_resnet50_sideJointTrackingDec17shuffle1_200000'

# Joints whose confidence to consider: the three left tarsus tips first,
# followed by the three right ones.
leg_tips = ['tarsus tip {} {}'.format(_pos, _lr)
            for _lr in ('L', 'R')
            for _pos in ('front', 'mid', 'back')]

# Every recording is identified by (fly, trial); all per-recording path lists
# below are generated from this single table so they stay in the same order.
_experiment = '191125_PR'
_recordings = [(_fly, _trial)
               for _fly in ('Fly1', 'Fly2')
               for _trial in ('001', '002', '003', '004')]
_tags = ['{}_{}_{}_prism'.format(_experiment, _fly, _trial)
         for _fly, _trial in _recordings]

# Crop-location text files (note the leading '/', kept as in the original paths).
crop_positions = ['/bottom_view/videos/crop_loc_' + _tag + '.txt' for _tag in _tags]

# Video path stems for both camera views.
videos_side = ['side_view/videos/video_' + _tag for _tag in _tags]
videos_bottom = ['bottom_view/videos/video_' + _tag for _tag in _tags]

# Image folders for both camera views.
_image_dir = '{0}/{1}/{2}_prism/behData/images/{3}_view_prism_data_{0}_{1}/'
images_side = [_image_dir.format(_experiment, _fly, _trial, 'side')
               for _fly, _trial in _recordings]
images_bottom = [_image_dir.format(_experiment, _fly, _trial, 'bottom')
                 for _fly, _trial in _recordings]

assert len(videos_side)==len(videos_bottom), 'Number of video files must be the same from side and bottom!'

# Select mode

In [3]:
# Processing mode. Must be one of the string keys of _mode_settings below:
# 'train', 'prediction', 'DLC_video' or 'train_low_res'.
mode = 'prediction'

# Settings per mode:
#   th1         - confidence threshold
#   th2         - max L-R discrepancy in x coordinate
#   orient      - 1 to orient all flies left
#   align       - 1 to align horizontally
#   fine_align  - 1 to fine-align flies
#   nice_epochs - 1 to keep only the frames selected by the epoch filter
_mode_settings = {
    'train':         {'th1': 0.95, 'th2': 10, 'orient': 1, 'align': 1, 'fine_align': 0, 'nice_epochs': 0},
    'prediction':    {'th1': 0.9,  'th2': 10, 'orient': 1, 'align': 1, 'fine_align': 1, 'nice_epochs': 1},
    'DLC_video':     {'th1': 0.9,  'th2': 10, 'orient': 0, 'align': 0, 'fine_align': 0, 'nice_epochs': 1},
    'train_low_res': {'th1': 0.95, 'th2': 10, 'orient': 1, 'align': 1, 'fine_align': 0, 'nice_epochs': 0},
}

# Fail loudly on an unknown mode: otherwise the settings would stay undefined,
# or silently keep the values left in the kernel by a previous run of this cell.
if mode not in _mode_settings:
    raise ValueError('Unknown mode %r; expected one of %s' % (mode, sorted(_mode_settings)))

_selected = _mode_settings[mode]
th1 = _selected['th1']
th2 = _selected['th2']
orient = _selected['orient']
align = _selected['align']
fine_align = _selected['fine_align']
nice_epochs = _selected['nice_epochs']

Process data

In [4]:
# For each recording: load the side/bottom tracking tables, filter and align
# them according to the mode flags (orient, align, fine_align, nice_epochs),
# then write the resulting poses to a pickle file under home_dir.
index = []  # one array of retained frame labels per recording, saved for later
side = pd.DataFrame()
bottom = pd.DataFrame()
flip_idx = None

# Thresholds used by the 'nice epochs' filter below (both in frames).
min_epoch_len = 5  # an epoch with at least this many frames is kept
max_epoch_gap = 5  # a shorter epoch is kept if the next one starts within this gap

for i in tqdm(range(len(videos_side))):
    #load data of side and bottom view
    _side = pd.read_hdf(home_dir + videos_side[i] + scorer_side + '.h5')
    _bottom = pd.read_hdf(home_dir + videos_bottom[i] + scorer_bottom + '.h5')
    _side = _side.droplevel('scorer',axis=1)
    _bottom = _bottom.droplevel('scorer',axis=1)

    #select for high confidence datapoints
    _bottom, _side = utils.select_best_data(_bottom, _side, th1, th2, leg_tips)

    #check which way the flies are pointing
    side_R_lk = _side.loc[:,(leg_tips[3:],'likelihood')] #high confidence on R joints means fly points right
    flip_idx = (side_R_lk>th1).sum(1)==3  # True where all three R tarsus tips exceed th1

    #flip left and right side due to prism reflection
    _side = utils.flip_LR(_side)

    #get epochs (list of lists of consecutive timesteps)
    epochs = utils.get_epochs(_bottom)

    #orient all flies left
    if orient:
        print('orient')
        _bottom, _side = utils.orient_left(_bottom, _side, th1, flip_idx)

    #align horizontally
    if align: #1 for training and prediction, 0 for making of DLC video
        print('align')
        path_crop_pos = home_dir + crop_positions[i]
        path_img = data_dir + images_bottom[i]
        _bottom.loc[:,(slice(None),['x','y'])] = \
        _bottom.loc[:,(slice(None),['x','y'])].apply(lambda x: procrustes.rotate_to_horizontal(x, path_crop_pos, path_img), axis=1)

    #fine-align flies
    if fine_align: #0 for training, 1 for prediction
        print('fine align')
        _bottom = procrustes.procrustes_on_epochs(_bottom, epochs)

    #keep an epoch (together with the epoch that follows it) when it has at
    #least min_epoch_len frames or the following epoch starts within max_epoch_gap
    if nice_epochs: #0 for training, 1 for prediction
        print('nice epochs')
        long_epochs = []
        # NOTE(review): the final epoch is only kept via its predecessor, so a
        # recording with a single epoch retains no frames -- confirm this is intended.
        for e in range(len(epochs)-1):
            gap_to_next = epochs[e+1][0] - epochs[e][-1]
            if len(epochs[e]) >= min_epoch_len or gap_to_next < max_epoch_gap:
                long_epochs += epochs[e]
                long_epochs += epochs[e+1]

        #sorted, de-duplicated frame labels
        long_epochs=np.unique(long_epochs)

        _bottom=_bottom.loc[long_epochs,:]
        _side=_side.loc[long_epochs,:]
        flip_idx = flip_idx.loc[long_epochs]  # label-based, same as the frames above

    #save indices for later
    index.append(_bottom.index.values)

    #convert & save to DF3D format
    #assumes the selected columns alternate x,y per joint, so even/odd columns
    #split into the two coordinates -- TODO confirm against the DLC column order
    side_np = _side.loc[:,(slice(None),['x','y'])].to_numpy()
    z = _side.loc[:,(slice(None),'y')].to_numpy()
    side_np = np.stack((side_np[:,::2], side_np[:,1::2]), axis=2)

    bottom_np = _bottom.loc[:,(slice(None),['x','y'])].to_numpy()
    bottom_np = np.stack((bottom_np[:,::2], bottom_np[:,1::2]), axis=2)

    #points3d takes x,y from the bottom view and the side-view y as third coordinate
    poses = {'points2d': np.stack((bottom_np, side_np), axis=0),
             'points3d': np.concatenate((bottom_np, z[:,:,None]), axis=2),
             'index':_side.index.values,
             'flip_idx': np.array(flip_idx)
            }

    #save; [6:] drops the leading 'video_' from the file stem, and the context
    #manager closes the file handle even if pickling fails
    out_path = home_dir + videos_side[i].split('/')[-1][6:] + '.pkl'
    with open(out_path, 'wb') as pkl_file:
        pickle.dump(poses, pkl_file)

  0%|          | 0/8 [00:00<?, ?it/s]

orient
align
fine align


 12%|█▎        | 1/8 [00:43<05:04, 43.53s/it]

nice epochs
orient
align
fine align


 25%|██▌       | 2/8 [02:35<06:24, 64.08s/it]

nice epochs
orient
align
fine align


 38%|███▊      | 3/8 [04:11<06:07, 73.49s/it]

nice epochs
orient
align
fine align


 50%|█████     | 4/8 [04:47<04:09, 62.43s/it]

nice epochs
orient
align
fine align


 62%|██████▎   | 5/8 [05:34<02:53, 57.71s/it]

nice epochs
orient
align


 75%|███████▌  | 6/8 [05:49<01:29, 44.80s/it]

fine align
nice epochs
orient
align
fine align


 88%|████████▊ | 7/8 [06:09<00:37, 37.64s/it]

nice epochs
orient
align


100%|██████████| 8/8 [06:25<00:00, 48.21s/it]

fine align
nice epochs



