# Load DeepFly3D Data

In [5]:
%load_ext autoreload
%autoreload 2
from load import *
import torch
import yaml
import logging
from imp import reload
import matplotlib.pyplot as plt
from liftpose.vision_3d import world_to_camera_dict, reprojection_error
# Re-execute the logging module before configuring the notebook logger.
# NOTE(review): presumably this resets logging state left over from earlier
# cell runs — confirm it is still required.
reload(logging)
# Logger.setLevel() returns None, so the logger must be bound first and
# configured in a separate statement; chaining the calls left `logger` as None.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
from tqdm import tqdm
# Replace the `locks` attribute of the object returned by tqdm.get_lock()
# with an empty list.
# NOTE(review): presumably a workaround for progress-bar locking problems in
# the notebook environment — confirm it is still needed with the installed tqdm.
tqdm.get_lock().locks = []

# Declare the data parameters for this run: input/output directories, which
# subjects go into the train and test splits, and the action / camera selection.
par_train = dict(
    data_dir='/data/LiftPose3D/fly_tether/data_DF3D',  # change the path
    out_dir='/data/LiftPose3D/fly_tether/cams_test',
    train_subjects=[1],
    test_subjects=[6, 7],
    actions=['all'],
    cam_id=[0, 6],
)

# Merge with training parameters.
# The file is opened in a `with` block so the handle is closed as soon as the
# YAML has been parsed (the original left the file object open).
# NOTE(review): yaml.full_load accepts more than plain data types; if
# param.yaml only holds plain mappings/lists/scalars, yaml.safe_load is the
# stricter choice — confirm before switching.
with open('param.yaml', "rb") as param_file:
    par_data = yaml.full_load(param_file)
# Entries in par_train take precedence over same-named keys from the "data"
# section of the YAML file because they are unpacked last.
par = {**par_data["data"], **par_train}

# Load 2D data
# The train and test calls differ only in the subject list; the camera ids and
# actions are shared, so they are collected once and passed by keyword.
shared_2d = dict(cam_id=par["cam_id"], actions=par["actions"])
train_2d = load_2D(par["data_dir"], par, subjects=par["train_subjects"], **shared_2d)
test_2d = load_2D(par["data_dir"], par, subjects=par["test_subjects"], **shared_2d)

# Load 3D data
# load_3D returns three values per split: the 3D data, the keypoints, and the
# camera data that is handed to world_to_camera_dict below.
shared_3d = dict(cam_id=par["cam_id"], actions=par["actions"])
train_3d, train_keypts, rcams_train = load_3D(
    par["data_dir"], par, subjects=par["train_subjects"], **shared_3d
)
test_3d, test_keypts, rcams_test = load_3D(
    par["data_dir"], par, subjects=par["test_subjects"], **shared_3d
)

# Pass each split through world_to_camera_dict together with its camera data,
# rebinding the same names to the result.
train_3d = world_to_camera_dict(train_3d, rcams_train)
test_3d = world_to_camera_dict(test_3d, rcams_test)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Train LiftPose3D Network on DeepFly3D Data

In [6]:
from liftpose.main import train as lp3d_train
from liftpose.lifter.augmentation import add_noise

# Train on the 2D/3D train and test splits; the roots, target sets and output
# directory are taken from the merged parameter dict `par`.
lp3d_train(
    train_2d=train_2d,
    test_2d=test_2d,
    train_3d=train_3d,
    test_3d=test_3d,
    train_keypts=train_keypts,
    test_keypts=test_keypts,
    roots=par['roots'],
    target_sets=par['target_sets'],
    out_dir=par['out_dir'],
    training_kwargs={"epochs": 10},
    # Noise augmentation is left disabled, exactly as in the original call:
    # augmentation=[add_noise(noise_amplitude=.1)],
)

(1, 'all', 'pose_result_fix__data_paper_180919_MDN_CsCh_Fly5_004_SG1_behData_images.cam_0')
[ 0.00000000e+00  0.00000000e+00  0.00000000e+00 -2.49090967e-01
  5.12187802e-01 -4.42739790e-02  6.90228533e-01  4.17887631e-01
 -2.45959561e-01  6.79966847e-01  1.05711406e+00 -6.01945760e-01
  9.72010672e-01  1.65738145e+00 -1.02723053e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  1.38485967e-03]
(1, 'all', 'pose_result_fix__data_paper_180919_MDN_CsCh_Fly5_004_SG1_behData_images.cam_6')
[ 0.          0.          0.          0.1194247   0.51891605  0.20689784
 -0.58091966  0.41496775 -0.44930582 -0.84975981  1.07126955 -0.26920912
 -1.34786421  1.6857394  -0.2974312   0.          0.          0.
 -0.10519913]
(1, 'all', 'pose_result_fix__data_paper_180919_MDN_CsCh_Fly6_002_SG1_behData_images.cam_0')
[ 0.          0.          0.         -0.12422136  0.49320193 -0.03421081
  0.85522261  0.50754417 -0.06307381  1.05488659  1.18493846 -0.28585134
  1.58548864  1.80715428 -0.46559002  0.   

SystemExit: 

In [None]:
%matplotlib inline
from liftpose.plot import read_log_train, plot_log_train
epoch, lr, loss_train, loss_test, err_test = read_log_train(par['out_dir'])
plot_log_train(plt.gca(), loss_train, loss_test, epoch)

# Run Trained LiftPose3D Network on the Test Data

In [None]:
from liftpose.main import test as lp3d_test
# Run liftpose's test entry point on the same output directory that was passed
# to training.
# NOTE(review): presumably this is what writes test_results.pth.tar, which the
# visualisation cell reads from par['out_dir'] — confirm.
lp3d_test(par['out_dir'])

# Visualize Predictions

In [None]:
# `os` is imported explicitly here: the visible notebook never imports it, so
# os.path.join only worked if `from load import *` happened to expose it.
import os

from liftpose.postprocess import load_test_results

# Read the saved test results and the 2D / 3D statistics files from the
# output directory, then hand all three to load_test_results.
# NOTE(review): these files are unpickled by torch.load — only load files
# produced by your own runs.
data = torch.load(os.path.join(par['out_dir'], "test_results.pth.tar"))
stat_2d = torch.load(os.path.join(par['out_dir'], "stat_2d.pth.tar"))
stat_3d = torch.load(os.path.join(par['out_dir'], "stat_3d.pth.tar"))
test_3d_gt, test_3d_pred, good_keypts = load_test_results(data, stat_2d, stat_3d)

In [None]:
# https://stackoverflow.com/a/38865534/7554774
# conda install ipympl
#%matplotlib widget
%matplotlib inline
from liftpose.plot import plot_pose_3d

fig = plt.figure(figsize=plt.figaspect(1), dpi=100)
ax = fig.add_subplot(111, projection='3d')
ax.view_init(elev=90, azim=0)

t = 0
plot_pose_3d(ax=ax, tar=test_3d_gt[t], 
            pred=test_3d_pred[t], 
            bones=par_data["vis"]["bones"], 
            limb_id=par_data["vis"]["limb_id"], 
            colors=par_data["vis"]["colors"],
            good_keypts=good_keypts[t])