In [4]:
import numpy as np
import pandas as pd
import glob
from tqdm import tqdm
import pickle
import src.procrustes as procrustes
import src.utils as utils

In [5]:
home_dir = '/data/LiftFly3D/prism/'
data_dir = '/mnt/NAS/SG/prism_data/'
scorer_bottom = 'DLC_resnet50_jointTrackingDec13shuffle1_200000'
scorer_side = 'DLC_resnet50_sideJointTrackingDec17shuffle1_200000'

#joints 
leg_tips = ['tarsus tip front L', 'tarsus tip mid L', 'tarsus tip back L',
          'tarsus tip front R', 'tarsus tip mid R', 'tarsus tip back R']

legs = [['body-coxa front L', 'coxa-femur front L', 'femur-tibia front L', 'tibia-tarsus front L', 'tarsus tip front L'],
       ['body-coxa mid L', 'coxa-femur mid L', 'femur-tibia mid L', 'tibia-tarsus mid L', 'tarsus tip mid L'],
       ['body-coxa back L', 'coxa-femur back L', 'femur-tibia back L', 'tibia-tarsus back L', 'tarsus tip back L'],
       ['body-coxa front R', 'coxa-femur front R', 'femur-tibia front R', 'tibia-tarsus front R', 'tarsus tip front R'],
       ['body-coxa mid R', 'coxa-femur mid R', 'femur-tibia mid R', 'tibia-tarsus mid R', 'tarsus tip mid R'],
       ['body-coxa back R', 'coxa-femur back R', 'femur-tibia back R', 'tibia-tarsus back R', 'tarsus tip back R']]

#lateral images of enclosure
images_side = ['191125_PR/Fly1/001_prism/behData/images/side_view_prism_data_191125_PR_Fly1/',
               '191125_PR/Fly1/002_prism/behData/images/side_view_prism_data_191125_PR_Fly1/',
               '191125_PR/Fly1/003_prism/behData/images/side_view_prism_data_191125_PR_Fly1/',
               '191125_PR/Fly1/004_prism/behData/images/side_view_prism_data_191125_PR_Fly1/',
               '191125_PR/Fly2/001_prism/behData/images/side_view_prism_data_191125_PR_Fly2/',
               '191125_PR/Fly2/002_prism/behData/images/side_view_prism_data_191125_PR_Fly2/',
               '191125_PR/Fly2/003_prism/behData/images/side_view_prism_data_191125_PR_Fly2/',
               '191125_PR/Fly2/004_prism/behData/images/side_view_prism_data_191125_PR_Fly2/']

#ventral images of enclosure
images_bottom =  ['191125_PR/Fly1/001_prism/behData/images/bottom_view_prism_data_191125_PR_Fly1/',
                  '191125_PR/Fly1/002_prism/behData/images/bottom_view_prism_data_191125_PR_Fly1/',
                  '191125_PR/Fly1/003_prism/behData/images/bottom_view_prism_data_191125_PR_Fly1/',
                  '191125_PR/Fly1/004_prism/behData/images/bottom_view_prism_data_191125_PR_Fly1/',
                  '191125_PR/Fly2/001_prism/behData/images/bottom_view_prism_data_191125_PR_Fly2/',
                  '191125_PR/Fly2/002_prism/behData/images/bottom_view_prism_data_191125_PR_Fly2/',
                  '191125_PR/Fly2/003_prism/behData/images/bottom_view_prism_data_191125_PR_Fly2/',
                  '191125_PR/Fly2/004_prism/behData/images/bottom_view_prism_data_191125_PR_Fly2/']

#position of crop around moving fly
crop_positions = ['/bottom_view/videos/crop_loc_191125_PR_Fly1_001_prism.txt',
                  '/bottom_view/videos/crop_loc_191125_PR_Fly1_002_prism.txt',
                  '/bottom_view/videos/crop_loc_191125_PR_Fly1_003_prism.txt',
                  '/bottom_view/videos/crop_loc_191125_PR_Fly1_004_prism.txt',
                  '/bottom_view/videos/crop_loc_191125_PR_Fly2_001_prism.txt',
                  '/bottom_view/videos/crop_loc_191125_PR_Fly2_002_prism.txt',
                  '/bottom_view/videos/crop_loc_191125_PR_Fly2_003_prism.txt',
                  '/bottom_view/videos/crop_loc_191125_PR_Fly2_004_prism.txt']

#lateral cropped video of moving fly
videos_side = ['side_view/videos/video_191125_PR_Fly1_001_prism',
               'side_view/videos/video_191125_PR_Fly1_002_prism',
               'side_view/videos/video_191125_PR_Fly1_003_prism',
               'side_view/videos/video_191125_PR_Fly1_004_prism',
               'side_view/videos/video_191125_PR_Fly2_001_prism',
               'side_view/videos/video_191125_PR_Fly2_002_prism',
               'side_view/videos/video_191125_PR_Fly2_003_prism',
               'side_view/videos/video_191125_PR_Fly2_004_prism']

#ventral cropped video of moving fly
videos_bottom =  ['bottom_view/videos/video_191125_PR_Fly1_001_prism',
                  'bottom_view/videos/video_191125_PR_Fly1_002_prism',
                  'bottom_view/videos/video_191125_PR_Fly1_003_prism',
                  'bottom_view/videos/video_191125_PR_Fly1_004_prism',
                  'bottom_view/videos/video_191125_PR_Fly2_001_prism',
                  'bottom_view/videos/video_191125_PR_Fly2_002_prism',
                  'bottom_view/videos/video_191125_PR_Fly2_003_prism',
                  'bottom_view/videos/video_191125_PR_Fly2_004_prism']

assert len(videos_side)==len(videos_bottom), 'Number of video files must be the same from side and bottom!'

# Select mode

In [62]:
mode='train' #0: train, 1: prediction, 2: DLC_video, 3: train_low_res

if mode=='train':
    th1 = 0.99 #confidence threshold
    th2 = 10 #max L-R discrepancy in x coordinate
    orient=1
    align=1
    fine_align=0
    nice_frames=1
    register_floor=1
if mode=='prediction':
    th1 = 0.99 #confidence threshold
    th2 = 10 #max L-R discrepancy in x coordinate
    orient=1
    align=1
    fine_align=1
    nice_frames=0
    register_floor=1
if mode=='DLC_video':
    th1 = 0.99 #confidence threshold
    th2 = 10 #max L-R discrepancy in x coordinate
    orient=0
    align=0
    fine_align=0
    nice_frames=0
    register_floor=0

Process data

In [63]:
#frames mislabelled by DLC
bad_frames = [[],
              [663, 668, 676, 1012, 1013, 1014, 1015, 1016, 1017, 1019, 1024, 1294, 2099, 2114, 2149, 2152, 2860, 3506],
              [],
              [5, 306, 871, 945],
              [595],
              [],
              [],
              []]

In [None]:
for i in range(len(videos_side)):
    print(home_dir + videos_side[i])
    
    #load data of side and bottom view
    _side = pd.read_hdf(home_dir + videos_side[i] + scorer_side + '.h5')
    _bottom = pd.read_hdf(home_dir + videos_bottom[i] + scorer_bottom + '.h5')
    _side = _side.droplevel('scorer',axis=1) 
    _bottom = _bottom.droplevel('scorer',axis=1) 
    
    #flip left and right side due to prism reflection
    _side = utils.flip_LR(_side)
    
    #select for high confidence datapoints
    _bottom, _side, flip_idx, good_keypts = utils.select_best_data(_bottom, _side, th1, th2, leg_tips)
    
    #take only those frames where all keypoints on at least one side are correct
    if nice_frames: #1 for training, 0 for prediction
        print('nice frames')
        
        mask = (good_keypts.loc[:,leg_tips[:3]].sum(1)==3) |\
               (good_keypts.loc[:,leg_tips[3:]].sum(1)==3)
        _side = _side[mask].dropna()
        _bottom = _bottom[mask].dropna()
        flip_idx = flip_idx[mask]
        good_keypts = good_keypts.loc[mask,:]
    
    #center and orient all flies left
    if orient: 
        print('orient')
        _bottom = utils.orient_left(_bottom, th1, flip_idx)
        
    #frame indices
    index = _bottom.index.values
    _bottom = _bottom.reset_index()
    _side = _side.reset_index()
    
    if register_floor:
        print('align with x-y plane')
        floor = 0
        for ind in _side.index:
            good_tips = _side.loc[:,(slice(None),'y')].iloc[:,good_keypts.iloc[ind,:].values].loc[ind,(leg_tips,'y')]
            floor_new = np.max(good_tips.to_numpy())
            if ~np.isnan(floor_new):
                floor = floor_new
            _side.loc[ind,(slice(None),'y')] = _side.loc[ind,(slice(None),'y')] - floor
        
    #align horizontally
    if align: #1 for training and prediction, 0 for making of DLC video
        print('align')
        path_crop_pos = home_dir + crop_positions[i]
        path_img = data_dir + images_bottom[i]
        angle, c = procrustes.get_orientation(path_crop_pos, path_img, index, flip_idx)
        _bottom.loc[:,(slice(None),['x','y'])] = \
        _bottom.loc[:,(slice(None),['x','y'])].apply(lambda x: procrustes.center_and_align(x, angle, c), axis=1)
        
    #fine-align flies 
    if fine_align: #0 for training, 1 for prediction
        print('fine align')
        epochs = utils.get_epochs(_bottom) #list of lists of consecutive timesteps
        _bottom = procrustes.procrustes_on_epochs(_bottom, epochs)
    
    #convert & save to DF3D format
    side_np = _side.loc[:,(slice(None),['x','y'])].to_numpy()
    z = _side.loc[:,(slice(None),'y')].to_numpy()
    side_np = np.stack((side_np[:,::2], side_np[:,1::2]), axis=2)

    bottom_np = _bottom.loc[:,(slice(None),['x','y'])].to_numpy()
    bottom_np = np.stack((bottom_np[:,::2], bottom_np[:,1::2]), axis=2)
    points2d = np.stack((bottom_np, side_np), axis=0)
    points3d = np.concatenate((bottom_np, -z[:,:,None]), axis=2)
    good_keypts = np.array(good_keypts)
    
    #remove some bad frames manually
    for b_frame in bad_frames[i]:
        points2d = np.delete(points2d, b_frame, 1)
        points3d = np.delete(points3d, b_frame, 0)
        index = np.delete(index, b_frame, 0)
        flip_idx = np.delete(flip_idx, b_frame, 0)
        good_keypts = np.delete(good_keypts, b_frame, 0)
        
    if np.isnan(z).any():
        print('NaNs found, something went wrong...')
    
    poses = {'points2d': points2d,
             'points3d': points3d,
             'index': index,
             'flip_idx': flip_idx,
             'good_keypts': good_keypts
            }

    pickle.dump(poses,open(home_dir + videos_side[i].split('/')[-1][6:] + '.pkl', 'wb'))

/data/LiftFly3D/prism/side_view/videos/video_191125_PR_Fly1_001_prism
nice frames
orient
align with x-y plane


In [56]:
good_tips = _side.loc[:,(slice(None),'y')].iloc[:,good_keypts.iloc[ind,:].values].loc[ind,(leg_tips,'y')]
floor_new = np.max(good_tips.to_numpy())
if ~np.isnan(floor_new):
    floor = floor_new
_side.loc[ind,(slice(None),'y')] = _side.loc[ind,(slice(None),'y')] - floor

In [61]:
_side.loc[:,(slice(None),'y')].iloc[:,good_keypts.iloc[ind,:].values]

bodyparts,body-coxa front R,coxa-femur front R,femur-tibia front R,tibia-tarsus front R,tarsus tip front R,body-coxa mid R,coxa-femur mid R,femur-tibia mid R,tibia-tarsus mid R,tarsus tip mid R,body-coxa back R,coxa-femur back R,femur-tibia back R,tibia-tarsus back R,tarsus tip back R
coords,y,y,y,y,y,y,y,y,y,y,y,y,y,y,y
0,-56.086594,-30.237946,-59.994995,-44.986404,-46.066422,-47.069611,-30.513123,-74.492386,-19.385223,0.000000,-41.694656,-27.898758,-58.511688,-21.911163,-0.447556
1,-72.250015,-47.263107,-70.373138,-36.740479,-3.489868,-63.600235,-45.678802,-79.088104,-19.233627,-0.014984,-62.720001,-44.673706,-74.174438,-20.994064,0.000000
2,-71.118484,-42.963211,-63.790359,-20.205048,-0.128403,-59.402893,-45.957077,-89.741058,-31.507568,-0.249832,-59.644470,-46.242249,-94.903214,-31.544479,0.000000
3,164.300415,190.998276,170.750275,211.475555,235.183517,175.496674,189.658249,144.028717,204.393143,236.267441,174.909882,189.372192,141.326065,206.468430,236.421265
4,163.993591,190.353317,167.766266,208.490524,235.324753,175.109360,187.188416,142.859360,202.696716,234.713684,175.923798,190.055328,142.966751,205.061234,236.895767
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1571,183.418716,207.984634,184.310440,221.097107,238.362244,197.229218,213.106705,166.755920,217.435410,237.711197,200.546997,211.990997,166.061050,210.678543,239.823715
1572,183.584930,207.933029,184.277130,221.202560,237.988785,197.279984,212.847214,166.581329,216.987808,237.784836,200.456238,211.905960,166.158722,210.933105,239.796432
1573,183.107330,207.641159,184.283539,221.296173,238.276260,196.900955,213.052383,165.196106,217.642700,238.040482,200.099701,212.252441,166.195206,210.846695,239.357941
1574,183.224426,207.694061,183.814392,220.714600,238.283035,196.971115,213.236435,165.116562,217.218277,237.942795,200.246323,212.403030,165.869339,210.461411,239.670502


In [60]:
_side.loc[:10,(leg_tips,'y')]

bodyparts,tarsus tip front L,tarsus tip mid L,tarsus tip back L,tarsus tip front R,tarsus tip mid R,tarsus tip back R
coords,y,y,y,y,y,y
0,-45.11084,-0.008377,-50.605652,-46.066422,0.0,-0.447556
1,-0.099045,-0.402298,-10.801376,-3.489868,-0.014984,0.0
2,3.622955,0.639175,-0.773438,-0.128403,-0.249832,0.0
3,236.647614,236.14566,236.820465,235.183517,236.267441,236.421265
4,237.269058,238.57869,211.24382,235.324753,234.713684,236.895767
5,237.538284,236.465607,210.111542,236.078384,234.483246,237.455688
6,237.881165,234.864548,209.679337,236.013062,235.127686,235.722626
7,237.383377,236.250214,210.220169,236.233688,234.456955,237.093719
8,238.423462,239.654572,154.187683,235.436295,235.099716,236.932541
9,235.61496,236.560425,218.417221,234.724564,235.083801,236.752731


In [38]:
_side.index

RangeIndex(start=0, stop=1576, step=1)