In [None]:
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from time import time
from lib import utils, app
from lib.misc import get_markers

plt.style.use(os.path.join('..', 'configs', 'mplstyle.yaml'))

%load_ext autoreload
%autoreload 2

ROOT_DATA_DIR = os.path.join("..", "data")

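# Reconstruction Parameters
Define the parameters in the cell below, then run all cells. Frame numbers are 1-based, and the reconstructed window includes both `start_frame` and `end_frame`.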

In [None]:
# Alternative dataset (swap with the line below to use it):
# DATA_DIR = os.path.join(ROOT_DATA_DIR, "2019_03_09", "lily", "run")
DATA_DIR = os.path.join(ROOT_DATA_DIR, "2017_08_29", "top", "jules", "run1_1")

# Frame window to reconstruct: 1-based frame numbers, both ends inclusive.
# (The optimisation cell converts start_frame to 0-based indexing.)
start_frame = 50
end_frame = 115

# DLC p_cutoff - only points with likelihood strictly greater than dlc_thresh are
# used in the optimisation; points at or below it are discarded as untrusted.
dlc_thresh = 0.8  # change this only if optimisation result is unsatisfactory

# Optimization

In [None]:
t0 = time()

assert os.path.exists(DATA_DIR), f"Data directory not found: {DATA_DIR}"
OUT_DIR = os.path.join(DATA_DIR, 'sba')
DLC_DIR = os.path.join(DATA_DIR, 'dlc')
assert os.path.exists(DLC_DIR), f"DLC directory not found: {DLC_DIR}"
os.makedirs(OUT_DIR, exist_ok=True)

app.start_logging(os.path.join(OUT_DIR, 'sba.log'))
try:
    # load video info
    res, fps, tot_frames, _ = app.get_vid_info(DATA_DIR)  # path to original videos
    assert end_frame <= tot_frames, (
        f"end_frame ({end_frame}) exceeds the number of video frames ({tot_frames})"
    )

    # NOTE(review): `start_frame` comes from the params cell as a 1-based number and is
    # converted to 0-based IN PLACE here; the save cell relies on the 0-based value.
    # Re-run the params cell before re-running this cell, otherwise the window shifts
    # back by one frame on every re-run.
    start_frame -= 1  # 0 based indexing
    assert start_frame >= 0, "start_frame must be >= 1 (frame numbers are 1-based)"
    N = end_frame - start_frame  # number of frames in the window

    *_, n_cams, scene_fpath = utils.find_scene_file(DATA_DIR)

    # glob() returns paths in arbitrary, filesystem-dependent order; sort so the file
    # order handed to the loader is deterministic (it presumably maps list position to
    # camera index — TODO confirm in utils.load_dlc_points_as_df).
    dlc_points_fpaths = sorted(glob(os.path.join(DLC_DIR, '*.h5')))
    assert n_cams == len(dlc_points_fpaths), (
        f"Scene file describes {n_cams} cameras but {len(dlc_points_fpaths)} DLC files were found"
    )

    # Load measurement data (pixels, likelihood), keeping only frames inside the
    # 0-based window [start_frame, end_frame - 1] whose likelihood exceeds dlc_thresh.
    points_2d_df = utils.load_dlc_points_as_df(dlc_points_fpaths, verbose=False)
    in_window = points_2d_df["frame"].between(start_frame, end_frame - 1)
    confident = points_2d_df["likelihood"] > dlc_thresh  # ignore points with low likelihood
    points_2d_df = points_2d_df[in_window & confident]

    t1 = time()
    print("Initialization took {0:.2f} seconds\n".format(t1 - t0))

    points_3d_df, residuals = app.sba_points_fisheye(scene_fpath, points_2d_df)
finally:
    # Always stop logging, even if an assertion or the optimisation itself raises.
    app.stop_logging()

fig, ax = plt.subplots()
ax.plot(residuals['before'], label="Cost before")
ax.plot(residuals['after'], label="Cost after")
ax.set_title("SBA cost before and after optimisation")
ax.legend()
fig_fpath = os.path.join(OUT_DIR, 'sba.svg')
fig.savefig(fig_fpath, transparent=True)
print(f'Saved to {fig_fpath}\n')
plt.show()

# Save SBA results

In [None]:
markers = get_markers()

# Dense (frame, marker, xyz) array; any frame/marker pair without a 3D point stays NaN.
positions = np.full((N, len(markers), 3), np.nan)
for marker_idx, marker_name in enumerate(markers):
    is_marker = points_3d_df["marker"] == marker_name
    rows = points_3d_df.loc[is_marker, ["frame", "x", "y", "z"]].to_numpy()
    for row in rows:
        # start_frame is already 0-based here (converted in the optimisation cell),
        # so this maps each absolute frame number to its row in `positions`.
        frame_idx = int(row[0]) - start_frame
        positions[frame_idx, marker_idx] = row[1:]

app.save_sba(positions, OUT_DIR, scene_fpath, start_frame, dlc_thresh)

# Plot the cheetah!

In [None]:
# Visualise the SBA reconstruction (presumably the pickle written by app.save_sba above).
data_fpath = os.path.join(OUT_DIR, 'sba.pickle')
app.plot_cheetah_reconstruction(
    data_fpath,
    hide_lure=True,
    reprojections=False,
    centered=True,
    dark_mode=True,
)

In [None]:
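# Optional: compare the SBA, EKF and FTE reconstructions of this run side by side.
# Uncomment once the ekf/ekf.pickle and fte/fte.pickle results exist in DATA_DIR.
# data_fpaths = [os.path.join(DATA_DIR, 'sba', 'sba.pickle'),
#                os.path.join(DATA_DIR, 'ekf', 'ekf.pickle'),
#                os.path.join(DATA_DIR, 'fte', 'fte.pickle')]
# app.plot_multiple_cheetah_reconstructions(data_fpaths, hide_lure=True, reprojections=False, centered=True, dark_mode=True)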