In [1]:
import numpy as np
import pandas as pd
import glob
from tqdm import tqdm
import pickle
import src.procrustes as procrustes
import src.utils as utils

In [2]:
home_dir = '/data/LiftFly3D/prism/'
data_dir = '/mnt/NAS/SG/prism_data/'
scorer_bottom = 'DLC_resnet50_jointTrackingDec13shuffle1_200000'
scorer_side = 'DLC_resnet50_sideJointTrackingDec17shuffle1_200000'

#joints whose confidence to consider
leg_tips = ['tarsus tip front L', 'tarsus tip mid L', 'tarsus tip back L',
            'tarsus tip front R', 'tarsus tip mid R', 'tarsus tip back R']

# Every recording to process, as (date, fly, trial).
# All per-recording path lists below are derived from this single list, so they are
# index-aligned by construction: the processing loop indexes them in parallel with `i`.
recordings = [('191125_PR', fly, trial)
              for fly in ('Fly1', 'Fly2')
              for trial in ('001', '002', '003', '004')]

#lateral images of enclosure
images_side = ['{0}/{1}/{2}_prism/behData/images/side_view_prism_data_{0}_{1}/'.format(*rec)
               for rec in recordings]

#ventral images of enclosure (joined onto data_dir in the processing loop)
images_bottom = ['{0}/{1}/{2}_prism/behData/images/bottom_view_prism_data_{0}_{1}/'.format(*rec)
                 for rec in recordings]

#position of crop around moving fly (joined onto home_dir in the processing loop)
crop_positions = ['/bottom_view/videos/crop_loc_{0}_{1}_{2}_prism.txt'.format(*rec)
                  for rec in recordings]

#lateral cropped video of moving fly (DLC output is <home_dir><video><scorer_side>.h5)
videos_side = ['side_view/videos/video_{0}_{1}_{2}_prism'.format(*rec)
               for rec in recordings]

#ventral cropped video of moving fly (DLC output is <home_dir><video><scorer_bottom>.h5)
videos_bottom = ['bottom_view/videos/video_{0}_{1}_{2}_prism'.format(*rec)
                 for rec in recordings]

assert len(videos_side)==len(videos_bottom), 'Number of video files must be the same from side and bottom!'
# The processing loop indexes these lists with the same `i` as the videos, so they must match too.
assert all(len(paths) == len(videos_side) for paths in (images_side, images_bottom, crop_positions)), \
    'Every per-recording path list must have one entry per video!'

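# Select mode

Choose a preprocessing preset (`train`, `prediction`, `DLC_video` or `train_low_res`); it sets the confidence thresholds and which filtering/alignment steps run below.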

In [5]:
# Preprocessing preset: one of 'train', 'prediction', 'DLC_video', 'train_low_res'.
mode = 'prediction'

# Settings per preset:
#   th1         -- confidence threshold (passed to utils.select_best_data and utils.orient_left)
#   th2         -- max L-R discrepancy in x coordinate (passed to utils.select_best_data)
#   orient      -- 1: orient all flies left
#   align       -- 1: align horizontally (1 for training and prediction, 0 for making of DLC video)
#   fine_align  -- 1: procrustes alignment over epochs of consecutive frames (0 for training, 1 for prediction)
#   nice_frames -- 1: keep only frames where all leg tips on at least one side are good (1 for training)
# Note: 'train_low_res' currently uses exactly the same settings as 'train'.
mode_params = {
    'train':         dict(th1=0.95, th2=10, orient=1, align=1, fine_align=0, nice_frames=1),
    'prediction':    dict(th1=0.99, th2=10, orient=1, align=1, fine_align=1, nice_frames=0),
    'DLC_video':     dict(th1=0.99, th2=10, orient=0, align=0, fine_align=0, nice_frames=0),
    'train_low_res': dict(th1=0.95, th2=10, orient=1, align=1, fine_align=0, nice_frames=1),
}

# Fail loudly on a typo: otherwise the settings below would be undefined on a fresh kernel,
# or silently left over from a previous run of this cell.
if mode not in mode_params:
    raise ValueError('Unknown mode {!r}; expected one of {}'.format(mode, sorted(mode_params)))

params = mode_params[mode]
th1 = params['th1']
th2 = params['th2']
orient = params['orient']
align = params['align']
fine_align = params['fine_align']
nice_frames = params['nice_frames']

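# Process data

For each recording: load the DLC side and bottom predictions, keep high-confidence data, optionally orient and align the flies, and save the poses in DF3D format as a `.pkl` file.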

In [6]:
for i in range(len(videos_side)):
    print(home_dir + videos_side[i])

    #load data of side and bottom view (DLC output) and drop the 'scorer' column level
    _side = pd.read_hdf(home_dir + videos_side[i] + scorer_side + '.h5')
    _bottom = pd.read_hdf(home_dir + videos_bottom[i] + scorer_bottom + '.h5')
    _side = _side.droplevel('scorer', axis=1)
    _bottom = _bottom.droplevel('scorer', axis=1)

    #flip left and right side due to prism reflection
    _side = utils.flip_LR(_side)

    #select for high confidence datapoints
    _bottom, _side, flip_idx, good_keypts = utils.select_best_data(_bottom, _side, th1, th2, leg_tips)

    #take only those frames where all keypoints on at least one side are correct
    if nice_frames: #1 for training, 0 for prediction
        print('nice frames')

        mask = (good_keypts.loc[:, leg_tips[:3]].sum(1) == 3) | \
               (good_keypts.loc[:, leg_tips[3:]].sum(1) == 3)
        # Drop rows containing NaNs in EITHER view with one joint mask. Calling dropna() on
        # each view separately can leave the two views with different frames, and leaves
        # flip_idx / good_keypts out of step with the frame index saved below.
        mask = mask & _side.notna().all(axis=1) & _bottom.notna().all(axis=1)
        _side = _side[mask]
        _bottom = _bottom[mask]
        flip_idx = flip_idx[mask]
        good_keypts = good_keypts.loc[mask, :]

    #orient all flies left
    if orient:
        print('orient')
        _bottom = utils.orient_left(_bottom, th1, flip_idx)

    #frame indices (captured before resetting both views to a 0..n-1 row index)
    index = _bottom.index.values
    _bottom = _bottom.reset_index()
    _side = _side.reset_index()

    #align horizontally
    if align: #1 for training and prediction, 0 for making of DLC video
        print('align')
        path_crop_pos = home_dir + crop_positions[i]
        path_img = data_dir + images_bottom[i]
        angle, c = procrustes.get_orientation(path_crop_pos, path_img, index, flip_idx)
        _bottom.loc[:, (slice(None), ['x', 'y'])] = \
            _bottom.loc[:, (slice(None), ['x', 'y'])].apply(lambda x: procrustes.center_and_align(x, angle, c), axis=1)

    #fine-align flies
    if fine_align: #0 for training, 1 for prediction
        print('fine align')
        epochs = utils.get_epochs(_bottom) #list of lists of consecutive timesteps
        _bottom = procrustes.procrustes_on_epochs(_bottom, epochs)

    #convert to DF3D format: (frames, keypoints, 2) arrays built from interleaved x, y columns
    side_np = _side.loc[:, (slice(None), ['x', 'y'])].to_numpy()
    z = _side.loc[:, (slice(None), 'y')].to_numpy() #side-view y is used as the depth coordinate
    side_np = np.stack((side_np[:, ::2], side_np[:, 1::2]), axis=2)

    bottom_np = _bottom.loc[:, (slice(None), ['x', 'y'])].to_numpy()
    bottom_np = np.stack((bottom_np[:, ::2], bottom_np[:, 1::2]), axis=2)

    # Both views must describe the same frames and keypoints to be stacked together.
    assert bottom_np.shape == side_np.shape, \
        'Side and bottom poses are misaligned: {} vs {}'.format(side_np.shape, bottom_np.shape)

    poses = {'points2d': np.stack((bottom_np, side_np), axis=0),
             'points3d': np.concatenate((bottom_np, z[:, :, None]), axis=2),
             'index': index,
             'flip_idx': flip_idx,
             'good_keypts': good_keypts
            }

    #save; file name is the side video name without its 'video_' prefix
    out_path = home_dir + videos_side[i].split('/')[-1][6:] + '.pkl'
    with open(out_path, 'wb') as fh: # context manager guarantees the file is flushed and closed
        pickle.dump(poses, fh)

/data/LiftFly3D/prism/side_view/videos/video_191125_PR_Fly1_001_prism
orient
align
fine align
/data/LiftFly3D/prism/side_view/videos/video_191125_PR_Fly1_002_prism
orient
align
fine align
/data/LiftFly3D/prism/side_view/videos/video_191125_PR_Fly1_003_prism
orient
align
fine align
/data/LiftFly3D/prism/side_view/videos/video_191125_PR_Fly1_004_prism
orient
align
fine align
/data/LiftFly3D/prism/side_view/videos/video_191125_PR_Fly2_001_prism
orient
align
fine align
/data/LiftFly3D/prism/side_view/videos/video_191125_PR_Fly2_002_prism
orient
align
fine align
/data/LiftFly3D/prism/side_view/videos/video_191125_PR_Fly2_003_prism
orient
align
fine align
/data/LiftFly3D/prism/side_view/videos/video_191125_PR_Fly2_004_prism
orient
align
fine align
