In [1]:
import numpy as np
import pandas as pd
import glob
from tqdm import tqdm
import pickle
import cv2
import sys
from itertools import groupby, repeat
from operator import itemgetter
sys.path.append('./src')
import procrustes as proc

In [2]:
# Root folders: DeepLabCut outputs / crop files live under home_dir, raw images under data_dir
home_dir = '/data/LiftFly3D/prism/'
data_dir = '/mnt/NAS/SG/prism_data/'

# DeepLabCut scorer suffixes appended to the video names to form the .h5 file names
scorer_bottom = 'DLC_resnet50_jointTrackingDec13shuffle1_200000'
scorer_side = 'DLC_resnet50_sideJointTrackingDec17shuffle1_200000'

#joints whose confidence to consider (three left tips first, then three right tips)
leg_tips = ['tarsus tip {} {}'.format(position, body_side)
            for body_side in ('L', 'R')
            for position in ('front', 'mid', 'back')]

# Every recording is identified by (fly, trial); all file lists below share this order,
# so entry k of each list refers to the same recording.
recordings = [(fly, trial)
              for fly in ('Fly1', 'Fly2')
              for trial in ('001', '002', '003', '004')]

crop_positions = ['/bottom_view/videos/crop_loc_191125_PR_{}_{}_prism.txt'.format(fly, trial)
                  for fly, trial in recordings]

videos_side = ['side_view/videos/video_191125_PR_{}_{}_prism'.format(fly, trial)
               for fly, trial in recordings]

videos_bottom = ['bottom_view/videos/video_191125_PR_{}_{}_prism'.format(fly, trial)
                 for fly, trial in recordings]

images_side = ['191125_PR/{0}/{1}_prism/behData/images/side_view_prism_data_191125_PR_{0}/'.format(fly, trial)
               for fly, trial in recordings]

images_bottom = ['191125_PR/{0}/{1}_prism/behData/images/bottom_view_prism_data_191125_PR_{0}/'.format(fly, trial)
                 for fly, trial in recordings]

assert len(videos_side)==len(videos_bottom), 'Number of video files must be the same from side and bottom!'

In [3]:
def read_crop_pos(file):
    """Parse a crop-location text file into image names and crop x positions.

    The first four lines of the file are skipped. Every following non-empty
    line is split on single spaces; the first field is taken as the image
    file name and the second as the x position.

    Parameters
    ----------
    file : str
        Path of the crop-location file.

    Returns
    -------
    im_file : list of str
        First field of every data line, in file order.
    x_pos : list of str
        Second field of every data line, kept as strings (no numeric conversion).

    Raises
    ------
    IndexError
        If a non-empty data line has fewer than two space-separated fields.
    """
    # context manager guarantees the handle is closed even if parsing fails
    with open(file, "r") as f:
        contents = f.readlines()

    im_file = []
    x_pos = []
    for raw_line in contents[4:]:
        # strip only the newline: the last line of a file may not have one,
        # and slicing off one character unconditionally would corrupt it
        stripped = raw_line.rstrip('\n')
        if not stripped:
            # tolerate blank lines (e.g. a trailing empty line)
            continue
        line = stripped.split(' ')
        im_file.append(line[0])
        x_pos.append(line[1])

    return im_file, x_pos


def flip_LR(data):
    """Swap the values of the first and second half of the columns, in place.

    Column labels stay where they are; only the values stored under the first
    half of the columns and under the second half are exchanged. The same
    (mutated) DataFrame is returned.
    """
    columns = list(data.columns)
    midpoint = int(len(columns)/2)
    first_half, second_half = columns[:midpoint], columns[midpoint:]

    # keep the first half's values before they are overwritten
    first_values = data.loc[:, first_half].values
    data.loc[:, first_half] = data.loc[:, second_half].values
    data.loc[:, second_half] = first_values

    return data


def rotate_to_horizontal(pts2d, exp):
    """Centre and align one frame of 2D points using its bottom-view image.

    Parameters
    ----------
    pts2d : pandas.Series
        One row of coordinates; its ``name`` (the row label) is used as the
        position of the frame's image in the crop-location file.
    exp : int
        Position of the recording in ``crop_positions`` / ``images_bottom``.

    Returns
    -------
    Whatever ``proc.center_and_align`` returns for this frame
    (presumably the transformed points -- confirm against ``procrustes``).
    """
    frame_label = pts2d.name

    # the crop file lists the image file names; the x positions are not needed here
    image_names = read_crop_pos(home_dir + crop_positions[exp])[0]
    image_path = data_dir + images_bottom[exp] + image_names[frame_label]

    # flag 0 makes OpenCV load the image as single-channel grayscale
    bottom_image = cv2.imread(image_path, 0)

    return proc.center_and_align(pts2d, bottom_image)


def orient_left(bottom, side):
    """Rotate by 180 degrees the bottom-view frames in which the fly points right.

    A frame counts as "pointing right" when all three right-side tarsus tips
    (``leg_tips[3:]``) in the side view have a likelihood above the
    module-level threshold ``th1``.

    The x/y coordinates of the selected frames are pooled, the pooled mean is
    subtracted, and the result is multiplied by a 180-degree rotation matrix.
    The mean is NOT added back afterwards, so rotated frames end up expressed
    relative to that pooled mean.

    Parameters
    ----------
    bottom : pandas.DataFrame
        Bottom-view data with (bodypart, coordinate) columns; modified in place.
    side : pandas.DataFrame
        Side-view data with (bodypart, coordinate) columns; only read.

    Returns
    -------
    bottom : pandas.DataFrame
        The same object that was passed in, with selected rows rotated.
    side : pandas.DataFrame
        Returned unchanged.
    flip_idx : pandas.Series of bool
        True for every frame that was rotated.
    """
    #rotate flies pointing right
    side_R_lk = side.loc[:,(leg_tips[3:],'likelihood')] #high confidence on R joints means fly points right
    flip_idx = (side_R_lk>th1).sum(1)==3

    theta = np.radians(180)
    cos, sin = np.cos(theta), np.sin(theta)
    R = np.array(((cos, -sin), (sin, cos)))

    if np.sum(flip_idx) != 0:
        xy_cols = (slice(None),['x','y'])
        tmp = bottom.loc[flip_idx,xy_cols].to_numpy()
        # remember the real row width instead of assuming 30 joints * 2 coordinates
        n_coords = tmp.shape[1]
        tmp = np.reshape(tmp, [-1, 2])
        mu = tmp.mean(axis=0)
        tmp = np.matmul(tmp-mu,R)# + mu
        tmp = np.reshape( tmp, [-1, n_coords] )
        bottom.loc[flip_idx,xy_cols] = tmp

    return bottom, side, flip_idx
 

def procrustes_on_epochs(data, epochs):
    """Fine-align consecutive frames within each epoch.

    For every epoch, each frame after the first is aligned to the frame whose
    row label is one less (the preceding frame of a consecutive run). Because
    the aligned coordinates are written back immediately, later frames are
    aligned to the already-aligned predecessor. Only the matrix ``T`` returned
    by ``proc.compute_similarity_transform`` is applied to the points.

    Parameters
    ----------
    data : pandas.DataFrame
        Frames with (bodypart, coordinate) columns; the 'x'/'y' columns are
        overwritten in place.
    epochs : list of list
        Row labels grouped into runs of consecutive frames (see ``get_epochs``).

    Returns
    -------
    pandas.DataFrame
        The same ``data`` object with updated 'x'/'y' columns.
    """
    xy_cols = (slice(None),['x','y'])
    # work on an explicit copy so the row assignments below never touch a view
    xy = data.loc[:,xy_cols].copy()
    for e in epochs:
        for step in range(1,len(e)):
            X = xy.loc[e[step]-1,:].to_numpy().reshape([-1, 2])
            Xtransf = xy.loc[e[step],:].to_numpy().reshape([-1, 2])

            _, _, T, _, _ = proc.compute_similarity_transform(Xtransf, X)
            Xtransf = (T@Xtransf.T).T

            # flatten back to one row; no hard-coded joint count needed
            xy.loc[e[step],:] = Xtransf.flatten()

    # write the aligned coordinates back (previously referenced an undefined name)
    data.loc[:,xy_cols] = xy.to_numpy()
    return data
    
        
def get_epochs(data):
    """Group the row labels of ``data`` into runs of consecutive integers.

    Returns a list of lists; each inner list holds labels that increase by
    exactly one from element to element, in the order they appear in the index.
    """
    labels = list(data.index)
    # position minus label stays constant as long as labels increase by one
    runs = groupby(enumerate(labels), key=lambda pair: pair[0] - pair[1])
    return [[label for _, label in run] for _, run in runs]


def select_best_data(bottom, side, th1, th2):
    """Keep only frames that are confidently and consistently tracked in both views.

    Two filters are applied in sequence:

    1. Likelihood: in the side view either all three left tarsus tips exceed
       ``th1`` and none of the right ones do, or the reverse; in the bottom
       view all six tarsus tips exceed ``th1``.
    2. Consistency: for three tips of one body side, the bottom-view x
       coordinate differs by less than ``th2`` from the side-view x coordinate
       of the opposite-labelled tips (guards against swapped limbs).

    Rows containing NaN are dropped after each filter.

    Returns
    -------
    bottom, side : pandas.DataFrame
        The filtered frames (note the order: bottom first).
    """
    left_tips, right_tips = leg_tips[:3], leg_tips[3:]

    #per frame: how many tarsus tips exceed the confidence threshold
    side_L_hits = (side.loc[:,(left_tips,'likelihood')] > th1).sum(1)
    side_R_hits = (side.loc[:,(right_tips,'likelihood')] > th1).sum(1)
    bottom_hits = (bottom.loc[:,(leg_tips,'likelihood')] > th1).sum(1)

    #exactly one body side visible from the side, everything visible from below
    left_only = (side_L_hits == 3) & (side_R_hits == 0)
    right_only = (side_R_hits == 3) & (side_L_hits == 0)
    confident = (left_only | right_only) & (bottom_hits == 6)
    side = side[confident].dropna()
    bottom = bottom[confident].dropna()

    #sometimes DLC mixes up limbs so take only those frames where the x coordinate matches on side and bottom views
    bottom_L_x = bottom.loc[:,(left_tips,'x')].values
    diff_L = np.abs(bottom_L_x - side.loc[:,(right_tips,'x')].values)
    bottom_R_x = bottom.loc[:,(right_tips,'x')].values
    diff_R = np.abs(bottom_R_x - side.loc[:,(left_tips,'x')].values)
    consistent = ((diff_L < th2).sum(1) == 3) | ((diff_R < th2).sum(1) == 3)
    side = side[consistent].dropna()
    bottom = bottom[consistent].dropna()

    assert side.shape[0]==bottom.shape[0], 'Number of rows must match in filtered data!'

    return bottom, side
    

def augment(bottom, typ, rng):
    tmp = bottom.loc[:,(slice(None),['x','y'])].to_numpy()
    bottom_old = bottom.copy()
    
    if typ == 'rot':
        _rng = rng
    if typ == 'noise':
        _rng = range(len(rng))
    
    for angle in _rng: 
        bottom_rot = bottom_old.copy()
        
        if typ=='noise':
            bottom_rot.loc[:,(slice(None),['x','y'])] = tmp + np.random.normal(0,6,size=tmp.shape)
        
        if typ=='rot':
            theta = np.radians(angle)
            cos, sin = np.cos(theta), np.sin(theta)
            R1 = np.array(((cos, -sin), (sin, cos)))
            tmp1 = np.reshape(tmp, [-1, 2])
            mu = tmp1.mean(axis=0)
            tmp1 = np.matmul(tmp1-mu,R1)# + mu
            tmp1 = np.reshape( tmp1, [-1, 60] )
            bottom_rot.loc[:,(slice(None),['x','y'])] = tmp1
            
        bottom = bottom.append(bottom_rot) #append
    
    return bottom

# Select mode

In [6]:
mode = 'train'  # one of: 'train', 'prediction', 'DLC_video', 'train_low_res'

# Pipeline switches for every supported mode:
#   th1         -- confidence (likelihood) threshold
#   th2         -- max L-R discrepancy in x coordinate
#   align       -- rotate flies to horizontal and orient them left
#   fine_align  -- procrustes alignment within epochs
#   nice_epochs -- keep only sufficiently long / closely spaced epochs
#   aug_angles  -- augment with rotated copies
#   aug_noise   -- augment with noisy copies
# Keeping all modes in one table guarantees that each mode defines every
# switch (a misspelt name in one branch previously left aug_noise undefined).
mode_settings = {
    'train':         dict(th1=0.95, th2=10, align=1, fine_align=0, nice_epochs=0, aug_angles=1, aug_noise=0),
    'prediction':    dict(th1=0.5,  th2=10, align=1, fine_align=1, nice_epochs=1, aug_angles=0, aug_noise=0),
    'DLC_video':     dict(th1=0.5,  th2=10, align=0, fine_align=0, nice_epochs=1, aug_angles=0, aug_noise=0),
    'train_low_res': dict(th1=0.95, th2=10, align=1, fine_align=0, nice_epochs=0, aug_angles=1, aug_noise=1),
}

if mode not in mode_settings:
    raise ValueError('Unknown mode {!r}; choose one of {}'.format(mode, sorted(mode_settings)))

# expose the selected settings under the names the rest of the notebook uses
th1 = mode_settings[mode]['th1']
th2 = mode_settings[mode]['th2']
align = mode_settings[mode]['align']
fine_align = mode_settings[mode]['fine_align']
nice_epochs = mode_settings[mode]['nice_epochs']
aug_angles = mode_settings[mode]['aug_angles']
aug_noise = mode_settings[mode]['aug_noise']

Process data

In [None]:
# Process every recording: load both DLC outputs, filter, align, optionally
# augment, and save one pickle per recording under home_dir.
index = []
side = pd.DataFrame()
bottom = pd.DataFrame()
flip_idx = None  # stays None when align is off (orient_left is never called)
for i in tqdm(range(len(videos_side))):
    #load data of side and bottom view
    _side = pd.read_hdf(home_dir + videos_side[i] + scorer_side + '.h5')
    _bottom = pd.read_hdf(home_dir + videos_bottom[i] + scorer_bottom + '.h5')
    _side = _side.droplevel('scorer',axis=1)
    _bottom = _bottom.droplevel('scorer',axis=1)

    #select for high confidence datapoints
    _bottom, _side = select_best_data(_bottom, _side, th1, th2)

    #flip left and right side of bottom due to prism reflection
    _bottom = flip_LR(_bottom)

    #get epochs (list of lists of consecutive timesteps)
    epochs = get_epochs(_bottom)

    #align horizontally
    if align: #1 for training and prediction, 0 for making of DLC video
        _bottom.loc[:,(slice(None),['x','y'])] = _bottom.loc[:,(slice(None),['x','y'])].apply(lambda x: rotate_to_horizontal(x, i), axis=1)

        #orient all flies left
        _bottom, _side, flip_idx = orient_left(_bottom, _side)

    #procrustes to fine-align flies
    if fine_align: #0 for training, 1 for prediction
        _bottom = procrustes_on_epochs(_bottom, epochs)

    #keep an epoch together with its successor when it has at least 5 frames,
    #or when the gap to the next epoch is shorter than 5 frames
    if nice_epochs: #0 for training, 1 for prediction
        long_epochs = []
        for e in range(len(epochs)-1):
            if (len(epochs[e]) >= 5) | ((len(epochs[e])<5) & ((epochs[e+1][0] - epochs[e][-1]<5))):
                long_epochs += epochs[e]
                long_epochs += epochs[e+1]

        long_epochs=np.unique(long_epochs)

        _bottom=_bottom.loc[long_epochs,:]
        _side=_side.loc[long_epochs,:]
        #keep the flip flags row-aligned with the frames that survive the filter
        if flip_idx is not None:
            flip_idx = flip_idx.loc[long_epochs]
        print('long epochs')

    #save indices for later
    index.append(_bottom.index.values)

    #augment with rotated flies +/- 10 deg
    if aug_angles: #1 for training, 0 for prediction
        angles = [-10, 10]
        _bottom = augment(_bottom, typ='rot', rng=angles)
        _side = pd.concat([_side]*(len(angles)+1))
        if flip_idx is not None:
            flip_idx = pd.concat([flip_idx]*(len(angles)+1))

    #augment with noise to scale down
    if aug_noise: #1 only if applied to lower resolution setups, 0 otherwise
        samples = 6
        #augment this recording's frames (not the empty global `bottom`);
        #a range is passed because augment takes the length of `rng` for noise
        _bottom = augment(_bottom, typ='noise', rng=range(samples))
        _side = pd.concat([_side]*(samples+1))
        if flip_idx is not None:
            flip_idx = pd.concat([flip_idx]*(samples+1))

    #convert & save to DF3D format
    #x/y columns alternate, so even columns are x and odd columns are y
    side_np = _side.loc[:,(slice(None),['x','y'])].to_numpy()
    z = _side.loc[:,(slice(None),'y')].to_numpy() #side-view y is used as the third coordinate
    side_np = np.stack((side_np[:,::2], side_np[:,1::2]), axis=2)

    bottom_np = _bottom.loc[:,(slice(None),['x','y'])].to_numpy()
    bottom_np = np.stack((bottom_np[:,::2], bottom_np[:,1::2]), axis=2)

    poses = {'points2d': np.stack((bottom_np, side_np), axis=0),
             'points3d': np.concatenate((bottom_np, z[:,:,None]), axis=2),
             'index':_side.index.values,
             'flip_idx': np.array(flip_idx)
            }

    #save as <home_dir>/<video name without its 6-character 'video_' prefix>.pkl
    with open(home_dir + videos_side[i].split('/')[-1][6:] + '.pkl', 'wb') as pkl_file:
        pickle.dump(poses, pkl_file)


  0%|          | 0/8 [00:00<?, ?it/s][A
 12%|█▎        | 1/8 [00:32<03:47, 32.46s/it][A
 25%|██▌       | 2/8 [02:46<06:18, 63.02s/it][A

In [None]:
import matplotlib.pyplot as plt

# Visual sanity check: overlay two frames of the last processed recording
# (rows 2 and 3) after aligning the second onto the first.
xy_rows = _bottom.loc[:,(slice(None),['x','y'])].to_numpy()

# reference frame, one (x, y) pair per row
coord1 = xy_rows[2,:].reshape(-1,2)
x1, y1 = coord1[:,0], coord1[:,1]
plt.scatter(x1,y1)

# frame to be aligned onto the reference
coord2 = xy_rows[3,:].reshape(-1,2)
_, _, T, _, c = proc.compute_similarity_transform(coord2, coord1 )
coord2 = (T@coord2.T).T

x2, y2 = coord2[:,0], coord2[:,1]
plt.scatter(x2,y2)