In [1]:
%load_ext autoreload
%autoreload 2
import load_prism
import load_tether
import torch
import yaml
import logging
import matplotlib.pyplot as plt
from liftpose.vision_3d import *
from liftpose.preprocess import concat_dict, total_frames, center_poses, anchor_to_root, flatten_dict, unflatten_dict
from liftpose.plot import plot_pose_3d, plot_pose_2d
from tqdm import tqdm
# Replace the lock list held by tqdm's shared lock object with an empty list.
# NOTE(review): presumably a workaround so that tqdm progress bars cooperate with the
# multiprocessing-based process_map used later in this notebook - TODO confirm.
tqdm.get_lock().locks = []

# Load data in the source domain (tethered fly) and rotate it to ventral camera angle

In [2]:
# Data parameters for the source domain (tethered fly).
par_train = {
    'data_dir'       : '/data/LiftPose3D/fly_tether/data_DF3D',  # change the path
    #'data_dir'      : r'\Users\NeLy\Desktop\fly_tether',        # windows path format
    'out_dir'        : '/data/LiftPose3D/domain_adaptation',
    'train_subjects' : [1, 2, 3, 4, 5],
    'test_subjects'  : [6, 7],
    'actions'        : ['all'],
    'cam_id'         : [0],  # dummy camera
}

# Merge with the training parameters. The context manager guarantees that the
# file handle is closed (a bare open() inside the call would leave it dangling).
# NOTE(review): prefer yaml.safe_load if param.yaml only contains plain data.
with open('param.yaml', 'rb') as param_file:
    par_data = yaml.full_load(param_file)
par_source = {**par_data["data"], **par_train}

# Load 3D data
pts3d_source, _, _ = load_tether.load_3D(
    par_source["data_dir"],
    par_source,
    cam_id=par_source["cam_id"],
    subjects=par_source["train_subjects"],
    actions=par_source["actions"],
)

# Concatenate the dicts of the individual experiments
pts3d_source = concat_dict(pts3d_source)

# The roots are a bit wobbly across frames, so stabilise them: anchor every pose
# to its root and then add back the root offset of the very first frame.
pts3d_source = flatten_dict(pts3d_source)
pts3d_source, offset = anchor_to_root(pts3d_source, par_source['roots'], par_source['target_sets'], 3)
k0 = next(iter(pts3d_source), None)  # first key; its frame-0 offset is applied to every entry
for k in pts3d_source:
    pts3d_source[k] += offset[k0][0, :]

pts3d_source = unflatten_dict(pts3d_source, 3)
pts3d_source = pts3d_source['']

# Project to the ventral view. The 2D projection has to be computed first,
# because the second call overwrites pts3d_source with its rotated version.
pts2d_source = project_to_eangle(pts3d_source, [-90, 0, 0], axsorder='xzy', project=True)
pts3d_source = project_to_eangle(pts3d_source, [-90, 0, 0], axsorder='xzy')


# Load data in the target domain (fly in prism-mirror setup)

In [3]:
# Data parameters for the target domain (fly in the prism-mirror setup).
par_train = {
    'data_dir'       : "/data/LiftPose3D/fly_prism/data_oriented/test_data",  # change the path
    'out_dir'        : '/data/LiftPose3D/domain_adaptation',
    'train_subjects' : ["001", "002", "003", "004"],
    'test_subjects'  : ["004"],
    'actions'        : ['PR'],
}

# Merge with the training parameters. The context manager guarantees that the
# file handle is closed (a bare open() inside the call would leave it dangling).
# NOTE(review): prefer yaml.safe_load if param.yaml only contains plain data.
with open('param.yaml', 'rb') as param_file:
    par_data = yaml.full_load(param_file)
par = {**par_data["data"], **par_train}

# Load data
pts3d_target, keypts_target, _ = load_prism.load_3D(
    par["data_dir"],
    subjects=par['train_subjects'],
    actions=par['actions'],
)

# Concatenate the dicts of the individual experiments
pts3d_target = concat_dict(pts3d_target)
keypts_target = concat_dict(keypts_target)

# Keep only high-confidence frames in which all roots are present.
# NOTE(review): assumes keypts_target is a 0/1 (or boolean) mask indexed as
# (frame, joint, coordinate); the thresholds 12 and 3 are hard-coded - TODO confirm.
ind = (keypts_target.max(2).sum(1) > 12) & (keypts_target[:, par['roots'], :].max(2).sum(1) == 3)
pts3d_target = pts3d_target[ind]
keypts_target = keypts_target[ind]

# The roots are a bit wobbly across frames, so stabilise them: anchor every pose
# to its root and then add back the root offset of the very first frame.
pts3d_target = flatten_dict(pts3d_target)
pts3d_target, offset = anchor_to_root(pts3d_target, par['roots'], par['target_sets'], 3)
k0 = next(iter(pts3d_target), None)  # first key; its frame-0 offset is applied to every entry
for k in pts3d_target:
    pts3d_target[k] += offset[k0][0, :]

pts3d_target = unflatten_dict(pts3d_target, 3)

# Project data to the ventral view
pts2d_target = XY_coord_dict(pts3d_target)

pts2d_target = pts2d_target['']
pts3d_target = pts3d_target['']

# Restrict all target arrays to the same fixed window of frames (2000-5999).
# This requires at least 6000 frames to survive the confidence filter above.
ind = list(range(2000, 6000))
pts2d_target = pts2d_target[ind]
pts3d_target = pts3d_target[ind]
keypts_target = keypts_target[ind]

In [None]:
fig = plt.figure(figsize=plt.figaspect(1), dpi=50)
ax = fig.add_subplot(111)

# Frame to preview. This used to read `poses[0]`, but `poses` is only created by
# the nearest-neighbour cell further down, so the cell raised NameError when the
# notebook was run top to bottom on a fresh kernel. Any valid frame index works
# for a quick visual check of the target-domain data.
ind = 0

plot_pose_2d(
    ax,
    pts2d_target[ind, :, :2],
    bones=par_data["vis"]["bones"],
    limb_id=par_data["vis"]["limb_id"],
)

In [None]:
%matplotlib inline

# pts3d_target has already been restricted to frames 2000-5999 in the loading
# cell, so it only holds 4000 frames. Indexing it again with range(2000, 6000)
# runs past the end of the array (IndexError); use it as it is instead.
tar = pts3d_target

from liftpose.plot import plot_video_3d

fig = plt.figure(figsize=plt.figaspect(1), dpi=50)
ax = fig.add_subplot(111)


def f(ax, idx):
    """Draw frame `idx` of `tar` on `ax`; passed to plot_video_3d as its draw_function.

    Only the first two coordinates of every joint are drawn (2D pose plot), and
    the axis limits are fixed so that the view does not jump between frames.
    """
    ax.cla()

    plot_pose_2d(
        ax,
        tar[idx, :, :2],
        bones=par_data["vis"]["bones"],
        limb_id=par_data["vis"]["limb_id"],
    )
    ax.set_xlim([-200, 200])
    ax.set_ylim([-200, 200])


plot_video_3d(fig, ax, n=2000, fps=20, draw_function=f, name='LiftPose3D_prediction.mp4')

# Find nearest neighbours

In [None]:
from tqdm.contrib.concurrent import process_map
from functools import partial
import pickle

nn = 20              # number of nearest neighbours to look up per pose
frac_samples = 0.25  # fraction of target frames to sample
neighbors_file = par_train['out_dir'] + '/neighbors.pkl'

try:
    # The cache stores [poses, nns_2d, nns_3d], so all three must be unpacked.
    # Unpacking only two names always raised ValueError, which a bare `except`
    # swallowed - the cache was never used and everything was recomputed.
    # NOTE: pickle.load can execute arbitrary code; only load files you created.
    with open(neighbors_file, 'rb') as cache:
        poses, nns_2d, nns_3d = pickle.load(cache)

except (OSError, EOFError, ValueError, pickle.UnpicklingError):
    # No usable cache (missing, truncated or written in another layout): recompute.
    total = pts3d_target.shape[0]
    poses = np.random.choice(total, size=int(frac_samples*total), replace=False)

    # 3D poses
    nns_3d = process_map(partial(find_neighbours,
                                 pts=pts3d_source,
                                 target_pts=pts3d_target,
                                 nn=nn,
                                 good_keypts=keypts_target),
                         poses, max_workers=16)

    # 2D poses
    nns_2d = process_map(partial(find_neighbours,
                                 pts=pts2d_source,
                                 target_pts=pts2d_target,
                                 nn=nn),
                         poses, max_workers=16)

    with open(neighbors_file, 'wb') as cache:
        pickle.dump([poses, nns_2d, nns_3d], cache)

  0%|          | 0/1000 [00:00<?, ?it/s]

# Check how much data is needed to generalize

In [None]:
import random

poses = np.array(poses)
nns_2d = np.array(nns_2d)
frac = np.linspace(0.1, 1, 10)  # fractions of the training pool to try
error = []
ind = range(len(poses))
fold = 10
for fold_idx in range(fold):  # repeated random hold-out, `fold` repetitions
    ind_test = random.sample(ind, int(1/fold*len(ind)))
    # random.sample needs a sequence: sampling straight from a set raises
    # TypeError on Python 3.11+, so the training pool is materialised as a list.
    ind_pool = sorted(set(ind) - set(ind_test))

    err = []
    # The loop variable is deliberately not called `f`, so that it does not
    # overwrite the video drawing callback `f` defined earlier in the notebook.
    for frac_train in frac:
        ind_train = random.sample(ind_pool, int(frac_train*len(ind_pool)))
        A_est_2D = best_linear_map(pts2d_source, pts2d_target[poses[ind_train]], nns_2d[ind_train], nn=5)
        pts2d_prism = apply_linear_map(A_est_2D, pts2d_target[poses[ind_test]])
        # Compare every mapped test pose with ITS OWN nearest source neighbour.
        # Indexing with [:, 0] drops the neighbour axis; [:, :1] would keep a
        # singleton axis and make the subtraction broadcast to all test/test
        # pairs, which is not the nearest-neighbour distance plotted below.
        nearest = pts2d_source[nns_2d[ind_test][:, 0]]
        err.append(np.abs(nearest - pts2d_prism).mean())

    error.append(np.array(err))

error = np.array(error)
plt.plot(frac*len(poses)*(1-1/fold), error.mean(0)/error.mean(0)[0])
plt.ylabel('Normalized distance to nearest neighbor (a.u.)')
plt.xlabel('Number of poses')

# Find best linear transformation for 2D, d2

In [None]:
ind = 10  # which of the sampled poses to visualise
nn = 10   # number of nearest neighbours used for the fit and drawn below

# Fit the 2D linear map on the sampled target poses and apply it to them.
sampled_2d = pts2d_target[poses]
A_est_2D = best_linear_map(pts2d_source, sampled_2d, nns_2d, nn=nn)
pts2d_prism = apply_linear_map(A_est_2D, pts2d_target[poses])

# Skeleton description shared by all pose plots of this cell.
vis_2d = par_data["vis"]
skeleton_2d = dict(bones=vis_2d["bones"], limb_id=vis_2d["limb_id"])

fig = plt.figure(figsize=(10, 5))

# Left panel: the untouched pose from the target domain.
ax = fig.add_subplot(121)
plot_pose_2d(ax, pts2d_target[poses[ind]], **skeleton_2d)
ax.set_title('Target domain pose')

# Right panel: the mapped pose together with its nearest source-domain neighbours.
ax = fig.add_subplot(122)
plot_pose_2d(ax, pts2d_prism[ind], **skeleton_2d)
ax.set_title('Source domain poses')

neighbours_2d = nns_2d[ind][:nn]
for neighbour in neighbours_2d:
    plot_pose_2d(ax, pts2d_source[neighbour], colors=vis_2d["colors"], **skeleton_2d)

#plt.savefig('2D_mapping.svg')

# Find best linear transformation for 3D, d3

In [None]:
ind = 10  # which of the sampled poses to visualise
nn = 10   # number of nearest neighbours used for the fit and drawn below

# Fit the 3D linear map on the sampled target poses and apply it to them.
sampled_3d = pts3d_target[poses]
A_est_3D = best_linear_map(pts3d_source, sampled_3d, nns_3d, nn=nn)
pts3d_prism = apply_linear_map(A_est_3D, pts3d_target[poses])

# Skeleton description shared by all pose plots of this cell.
vis_3d = par_data["vis"]
skeleton_3d = dict(bones=vis_3d["bones"], limb_id=vis_3d["limb_id"], good_keypts=None)

fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111, projection="3d")

# Mapped target pose first, then its nearest source-domain neighbours on top.
plot_pose_3d(ax, pts3d_prism[ind], **skeleton_3d)

neighbours_3d = nns_3d[ind][:nn]
for neighbour in neighbours_3d:
    plot_pose_3d(ax, pts3d_source[neighbour], colors=vis_3d["colors"], **skeleton_3d)

#plt.savefig('3D_mapping.svg')

# Predict poses with trained network

In [None]:
from liftpose.postprocess import load_test_results
from liftpose.main import set_test_data
from liftpose.main import test as lp3d_test
from liftpose.lifter.utils import filter_data

test_3d_gt, test_3d_pred, good_keypts = [], [], []

# pts2d_prism and pts3d_prism were computed from pts2d_target[poses] and
# pts3d_target[poses], i.e. they only hold the sampled frames. The keypoint
# mask therefore has to be restricted to the same frames; passing the full
# keypts_target would pair every pose with the mask of an unrelated frame.
keypts_prism = keypts_target[poses]

# Normalize test data
test_2d, test_3d, stat_2d, stat_3d = set_test_data(
    par['out_dir'],
    {'a': pts2d_prism.copy()},
    {'a': pts3d_prism.copy()},
    {'a': keypts_prism.copy()},
)

# Run the network on the test data
lp3d_test(par['out_dir'], test_2d, test_3d, keypts_prism.copy())

# Load statistics and test results
gt, pred, _ = load_test_results(par['out_dir'], stat_2d, stat_3d)

# Filter noise
#test_3d_gt = filter_data(test_3d_gt)
#test_3d_pred = filter_data(test_3d_pred)


In [None]:
%matplotlib inline
from liftpose.plot import plot_pose_3d

t = 40  # index into the predicted poses

fig = plt.figure(figsize=plt.figaspect(1), dpi=100)
ax = fig.add_subplot(111, projection='3d')

# The network was fed the mapped poses of the sampled frames `poses`, so
# prediction t belongs to target frame poses[t]; the keypoint mask has to be
# taken from that frame and not from frame t of the full recording.
# NOTE(review): `tar` receives the prediction and `pred` the ground truth, which
# looks swapped - kept as it was, confirm whether this is intentional.
plot_pose_3d(ax=ax, tar=pred[t],
            pred=gt[t],
            bones=par_data["vis"]["bones"],
            limb_id=par_data["vis"]["limb_id"],
            colors=par_data["vis"]["colors"],
            good_keypts=keypts_target[poses[t]],
            show_pred_always=True,
            legend=True)

In [None]:
# Show the selected target-domain pose (sample `ind` of `poses`) in the left
# half of a two-panel figure.
fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(121)

target_pose_2d = pts2d_target[poses[ind]]
vis_target = par_data["vis"]
plot_pose_2d(ax, target_pose_2d, bones=vis_target["bones"], limb_id=vis_target["limb_id"])