In [2]:
import numpy as np
import pandas as pd
import glob
from tqdm import tqdm
import pickle
import cv2
from math import atan2

In [3]:
# Root folders: DLC outputs / crop files live under home_dir, raw images under data_dir.
home_dir = '/data/LiftFly3D/prism/'
data_dir = '/mnt/NAS/SG/prism_data/'

# DeepLabCut scorer suffixes appended to the video names when loading the .h5 files.
scorer_bottom = 'DLC_resnet50_jointTrackingDec13shuffle1_200000'
scorer_side = 'DLC_resnet50_sideJointTrackingDec17shuffle1_200000'

#joints whose confidence to consider (three left tips first, then three right tips)
leg_tips = ['tarsus tip {} {}'.format(pos, lr)
            for lr in ('L', 'R')
            for pos in ('front', 'mid', 'back')]

# Every list below has one entry per recording, ordered Fly1 001-004 then Fly2 001-004,
# so the same integer index addresses the same recording in all of them.
crop_positions = ['/bottom_view/videos/crop_loc_191125_PR_{}_{}_prism.txt'.format(fly, trial)
                  for fly in ('Fly1', 'Fly2')
                  for trial in ('001', '002', '003', '004')]

videos_side = ['side_view/videos/video_191125_PR_{}_{}_prism'.format(fly, trial)
               for fly in ('Fly1', 'Fly2')
               for trial in ('001', '002', '003', '004')]

videos_bottom = ['bottom_view/videos/video_191125_PR_{}_{}_prism'.format(fly, trial)
                 for fly in ('Fly1', 'Fly2')
                 for trial in ('001', '002', '003', '004')]

images_side = ['191125_PR/{0}/{1}_prism/behData/images/side_view_prism_data_191125_PR_{0}/'.format(fly, trial)
               for fly in ('Fly1', 'Fly2')
               for trial in ('001', '002', '003', '004')]

images_bottom = ['191125_PR/{0}/{1}_prism/behData/images/bottom_view_prism_data_191125_PR_{0}/'.format(fly, trial)
                 for fly in ('Fly1', 'Fly2')
                 for trial in ('001', '002', '003', '004')]

assert len(videos_side)==len(videos_bottom), 'Number of video files must be the same from side and bottom!'

In [4]:
def read_crop_pos(file):
    '''Parse a crop-location text file.

    The first four lines of the file are skipped (presumably a header). Every
    following non-blank line is split on single spaces; its first field is
    collected as an image file name and its second field as the crop x position.

    Parameters
    ----------
    file : str
        Path of the crop-location file.

    Returns
    -------
    im_file : list of str
        First field of every data line, in file order.
    x_pos : list of str
        Second field of every data line, kept as strings (not converted to numbers).

    Raises
    ------
    IndexError
        If a non-blank data line has fewer than two space-separated fields.
    '''
    # the context manager guarantees the handle is closed even if reading fails
    with open(file, "r") as f:
        contents = f.readlines()

    im_file = []
    x_pos = []
    for raw in contents[4:]:
        # blank lines (e.g. a trailing empty line) carry no data
        if not raw.strip():
            continue
        # strip only the newline: slicing off the last character would truncate
        # the value on a final line that lacks a trailing newline
        line = raw.rstrip('\n').split(' ')
        im_file.append(line[0])
        x_pos.append(line[1])

    return im_file, x_pos


def orientation(img, threshold=130, min_area=10000):
    '''Estimate the in-plane orientation of the main bright shape in an image.

    Pixels darker than `threshold` are zeroed, contours are extracted from the
    result, and a PCA is run on the points of the selected contour. The angle of
    the first principal axis is returned.

    Parameters
    ----------
    img : numpy.ndarray
        Single-channel image (it is passed directly to cv2.findContours).
    threshold : int, optional
        Intensities below this value are set to 0 before contour extraction.
    min_area : float, optional
        The first contour whose area exceeds this value is used. If no contour
        is that large, the largest contour is used instead.

    Returns
    -------
    float or None
        Angle in radians, atan2(v_y, v_x) of the first principal axis, or None
        when the thresholded image contains no contour at all.
    '''
    #_, img_th = cv2.threshold(img, 250, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
    img_th = img.copy()
    img_th[img_th < threshold] = 0 # was 140

    # Find all the contours in the thresholded image.
    # OpenCV 3 returns (image, contours, hierarchy), OpenCV 4 returns
    # (contours, hierarchy); the contours are second-to-last in both.
    contours = cv2.findContours(img_th, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)[-2]
    if len(contours) < 1:
        return None

    # Take the first sufficiently large contour ...
    contour = next((cnt for cnt in contours if cv2.contourArea(cnt) > min_area), None)
    if contour is None:
        # ... and fall back to the largest one rather than whichever contour
        # happened to be last in the list.
        contour = max(contours, key=cv2.contourArea)

    # Contours are indexed as (n_points, 1, 2); PCA needs an (n_points, 2) float array
    img_pts = contour[:, 0, :].astype(np.float64)

    # PCA analysis (an empty mean lets OpenCV compute the mean itself)
    mean = np.empty((0))
    _, eigenvectors, _ = cv2.PCACompute2(img_pts, mean)

    return atan2(eigenvectors[0,1], eigenvectors[0,0])


# Image file names per crop-location file, so that the file is parsed once
# instead of once per frame (center_and_align is called row by row).
_crop_file_cache = {}


def center_and_align(pts2d, exp):
    '''Rotate the 2D points of one frame about the image centre to align the fly.

    The cropped bottom-view image belonging to the frame is loaded, its body
    orientation is estimated with `orientation`, and all points are rotated by
    the negative of that angle about the centre of the image.

    Parameters
    ----------
    pts2d : pandas.Series
        One row of coordinates, read as consecutive (x, y) pairs. Its `name`
        is used as the position of the frame in the crop-location file.
        The series is modified in place.
    exp : int
        Index of the recording in `crop_positions` / `images_bottom`.

    Returns
    -------
    pandas.Series
        The same `pts2d` object holding the rotated coordinates.

    Raises
    ------
    FileNotFoundError
        If the image of the frame cannot be read.
    ValueError
        If no contour is found in the image, so no orientation can be computed.
    '''

    #access corresponding image file
    idx = pts2d.name
    crop_file = home_dir + crop_positions[exp]
    if crop_file not in _crop_file_cache:
        _crop_file_cache[crop_file], _ = read_crop_pos(crop_file)
    im_path = data_dir + images_bottom[exp] + _crop_file_cache[crop_file][idx]
    im_crop_bottom = cv2.imread(im_path, 0)
    if im_crop_bottom is None:
        # cv2.imread signals failure by returning None instead of raising
        raise FileNotFoundError('Could not read image: ' + im_path)

    #get orientation and centre
    angle = orientation(im_crop_bottom)
    if angle is None:
        raise ValueError('No contour found in image: ' + im_path)
    # shape is (rows, cols) = (height, width) whereas the points are (x, y),
    # so the shape has to be reversed to obtain the centre in point coordinates
    c = np.array(im_crop_bottom.shape[::-1]) / 2

    #rotate points (row vectors times R rotates by -angle)
    cos, sin = np.cos(angle), np.sin(angle)
    R = np.array(((cos, -sin), (sin, cos)))
    tmp = pts2d.to_numpy().reshape(-1, 2)
    tmp = np.matmul(tmp - c, R) + c
    pts2d.iloc[:] = tmp.flatten()

    return pts2d

Filter for high-quality frames

In [16]:
# For every recording: load the side and bottom DLC predictions, keep only frames
# that pass the confidence and L/R consistency filters, align the bottom view and
# write the result as a pickle in DF3D-style format.
th1 = 0.95 #confidence threshold
th2 = 10 #max L-R discrepancy in x coordinate

index = []
# NOTE(review): `side` and `bottom` are not used anywhere in this cell.
side = pd.DataFrame()
bottom = pd.DataFrame()
for i in tqdm(range(len(videos_side))):
    _side = pd.read_hdf(home_dir + videos_side[i] + scorer_side + '.h5')
    _bottom = pd.read_hdf(home_dir + videos_bottom[i] + scorer_bottom + '.h5')
    
    #drop scorer column label
    _side = _side.droplevel('scorer',axis=1) 
    _bottom = _bottom.droplevel('scorer',axis=1) 
    
    #split L and R (remove if we include flipping)
    side_L_lk = _side.loc[:,(leg_tips[:3],'likelihood')]
    side_R_lk = _side.loc[:,(leg_tips[3:],'likelihood')]
    bottom_lk = _bottom.loc[:,(leg_tips,'likelihood')]
    
    #select for high likelihood frames
    #mask = ((side_L_lk>th1).sum(1)==3) & ((bottom_lk>th1).sum(1)==6) #only flies pointing left
    #mask = ((side_R_lk>th1).sum(1)==3) & ((bottom_lk>th1).sum(1)==6) #only flies pointing right
    #mask = ( ((side_L_lk>th1).sum(1)==3) | ((side_R_lk>th1).sum(1)==3) ) & ((bottom_lk>th1).sum(1)==6)
    #active rule: in the side view exactly one side has all three tips above th1 and
    #the other side has none, and all six tips are above th1 in the bottom view
    mask = ( (((side_L_lk>th1).sum(1)==3) & ((side_R_lk>th1).sum(1)==0)) | \
             (((side_R_lk>th1).sum(1)==3) & ((side_L_lk>th1).sum(1)==0)) ) & \
             ((bottom_lk>th1).sum(1)==6)
    _side = _side[mask].dropna()
    _bottom = _bottom[mask].dropna()
    
    #sometimes DLC mixes up limbs so take only those frames there the x coordinate matches on side and bottom views
    #(bottom L tips are compared with side R tips and vice versa)
    diff_L = np.abs(_bottom.loc[:,(leg_tips[:3],'x')].values - _side.loc[:,(leg_tips[3:],'x')].values)
    diff_R = np.abs(_bottom.loc[:,(leg_tips[3:],'x')].values - _side.loc[:,(leg_tips[:3],'x')].values)
    mask = ((diff_L<th2).sum(1)==3) | ((diff_R<th2).sum(1)==3)
    _side = _side[mask].dropna()
    _bottom = _bottom[mask].dropna()
    
    #(disabled) keep runs of at least 5 consecutive frames, or shorter runs that are
    #followed by the next run within fewer than 5 frames; the following run is kept too
    if 0:
        from itertools import groupby
        from operator import itemgetter
        data = list(_bottom.index)
        epochs = []
        for k, g in groupby(enumerate(data), lambda ix : ix[0] - ix[1]):
            epochs.append(list(map(itemgetter(1), g)))
        long_epochs = []
        for e in range(len(epochs)-1):
            if (len(epochs[e]) >= 5) | ((len(epochs[e])<5) & ((epochs[e+1][0] - epochs[e][-1]<5))):
                long_epochs += epochs[e]
                long_epochs += epochs[e+1]
            
        long_epochs=np.unique(long_epochs)  
    
        _bottom=_bottom.loc[long_epochs,:]
        _side=_side.loc[long_epochs,:]
        print('long epochs')
    
    #save indices for later
    index.append(_bottom.index.values)
    
    assert _side.shape[0]==_bottom.shape[0], 'Number of rows must match in filtered data!'
    
    #flip left and right side due to prism reflection
    #(swaps the first and second half of the columns of the bottom view;
    # assumes one body side occupies each half - TODO confirm against the DLC config)
    cols = list(_side.columns)
    half = int(len(cols)/2)
    tmp = _bottom.loc[:,cols[:half]].values
    _bottom.loc[:,cols[:half]] = _bottom.loc[:,cols[half:]].values
    _bottom.loc[:,cols[half:]] = tmp
    
    #align horizontally
    side_R_lk = _side.loc[:,(leg_tips[3:],'likelihood')] #high confidence on R joints means fly points right
    flip_idx = (side_R_lk>th1).sum(1)==3
    
    if 1:
        _bottom.loc[:,(slice(None),['x','y'])] = _bottom.loc[:,(slice(None),['x','y'])].apply(lambda x: center_and_align(x, i), axis=1)
    
        #rotate flies pointing right
        theta = np.radians(180)
        cos, sin = np.cos(theta), np.sin(theta)
        R = np.array(((cos, -sin), (sin, cos)))     
    
        if np.sum(flip_idx) != 0:
            tmp = _bottom.loc[flip_idx,(slice(None),['x','y'])].to_numpy()
            n_coords = tmp.shape[1] #number of x/y columns per frame, restored after the rotation
            tmp = np.reshape(tmp, [-1, 2])
            mu = tmp.mean(axis=0) #rotation centre: mean over all points of all flipped frames
            tmp = np.matmul(tmp-mu,R) + mu
            tmp = np.reshape( tmp, [-1, n_coords] )
            _bottom.loc[flip_idx,(slice(None),['x','y'])] = tmp

    if 0:
        #augment with rotated flies +/- 10 deg
        tmp = _bottom.loc[:,(slice(None),['x','y'])].to_numpy()
        _bottom_old = _bottom.copy()
        _side_old = _side.copy()
        flip_idx_old = flip_idx.copy()
        for angle in [-10, 10]: 
            _bottom_rot = _bottom_old.copy()
            theta = np.radians(angle)
            cos, sin = np.cos(theta), np.sin(theta)
            R1 = np.array(((cos, -sin), (sin, cos)))
    
            tmp1 = np.reshape(tmp, [-1, 2])
            mu = tmp1.mean(axis=0)
            tmp1 = np.matmul(tmp1-mu,R1) + mu
            tmp1 = np.reshape( tmp1, [-1, tmp.shape[1]] )
            _bottom_rot.loc[:,(slice(None),['x','y'])] = tmp1
            #pd.concat replaces DataFrame.append / Series.append, which pandas 2.0 removed
            _bottom = pd.concat([_bottom, _bottom_rot]) #append
            _side = pd.concat([_side, _side_old])
            flip_idx = pd.concat([flip_idx, flip_idx_old])
       
    if 0:
        #augment with noise perturbation
        tmp = _bottom.loc[:,(slice(None),['x','y'])].to_numpy()
        _bottom_old = _bottom.copy()
        _side_old = _side.copy()
        flip_idx_old = flip_idx.copy()
        for _ in range(5): 
            _bottom_noise = _bottom_old.copy()
            _bottom_noise.loc[:,(slice(None),['x','y'])] = tmp + np.random.normal(0,6,size=tmp.shape)
            _bottom = pd.concat([_bottom, _bottom_noise]) #append
            _side = pd.concat([_side, _side_old])
            flip_idx = pd.concat([flip_idx, flip_idx_old])
    
    #convert & save to DF3D format
    #even columns are taken as x and odd columns as y, stacked to (frames, joints, 2)
    side_np = _side.loc[:,(slice(None),['x','y'])].to_numpy()
    z = _side.loc[:,(slice(None),'y')].to_numpy() #side-view y is used as the third coordinate
    side_np = np.stack((side_np[:,::2], side_np[:,1::2]), axis=2)

    bottom_np = _bottom.loc[:,(slice(None),['x','y'])].to_numpy()
    bottom_np = np.stack((bottom_np[:,::2], bottom_np[:,1::2]), axis=2)
    
    poses = {'points2d': np.stack((bottom_np, side_np), axis=0),
             'points3d': np.concatenate((bottom_np, z[:,:,None]), axis=2),
             'index':_side.index.values,
             'flip_idx': np.array(flip_idx)
            }
    
    #save (file name is the side video name without its first 6 characters, i.e. 'video_')
    #the context manager closes the file so the pickle is fully flushed to disk
    with open(home_dir + videos_side[i].split('/')[-1][6:] + '.pkl', 'wb') as pkl_file:
        pickle.dump(poses, pkl_file)





  0%|          | 0/8 [00:00<?, ?it/s][A[A[A[A



 12%|█▎        | 1/8 [00:20<02:25, 20.72s/it][A[A[A[A



 25%|██▌       | 2/8 [01:31<03:33, 35.61s/it][A[A[A[A



 38%|███▊      | 3/8 [02:24<03:24, 40.83s/it][A[A[A[A



 50%|█████     | 4/8 [02:38<02:11, 32.94s/it][A[A[A[A



 62%|██████▎   | 5/8 [02:56<01:25, 28.38s/it][A[A[A[A



 75%|███████▌  | 6/8 [03:00<00:42, 21.18s/it][A[A[A[A



 88%|████████▊ | 7/8 [03:07<00:16, 16.74s/it][A[A[A[A



100%|██████████| 8/8 [03:11<00:00, 23.97s/it][A[A[A[A
