# Tracking of Xenopus Embryos using Kalman Filter
This notebook uses segmentation masks generated with watershed to identify the centres of gravity of the embryos. These centres of mass are tracked across frames using a Kalman filter. The assignment of new measurements to existing tracks is based on the maximum likelihood of a given measurement belonging to a track, which is calculated from the covariance matrices extracted from the Kalman filter.

### Next steps:
- Label paths on video for troubleshooting (optional)
- Show segmentation regions for troubleshooting (optional)
- Apply Kalman smoothing (optional)

The user should set the parameters in the Parameters section, where the ID of the video to be processed can also be specified. Locations of the input and output data are specified in the Data locations cell.

Note: Running this notebook requires that initial markers are defined for the video to be processed. The notebook [4a_place_and_evaluate_markers](4a_place_and_evaluate_markers.ipynb) in this folder can be used to define those markers.
Also note that running this notebook takes a long time (up to 4 hours on a 32-core node), so make sure enough time is available in the HPC OnDemand session.

# Imports

In [None]:
%load_ext autoreload

In [None]:
import numpy as np
import scipy as sp
from scipy import signal
import cv2
import os
import copy
import yaml

from IPython import display
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
from tqdm import tqdm

from fam13a import consts, utils, track, image, kalman
from fam13a.paths import path, get_new_unique_label, get_path_by_label, check_paths_alive
from collections import defaultdict

In [None]:
%autoreload 2

# Parameters

In [None]:
# --- Input selection -------------------------------------------------------
# Identifier of the video to process; it is used to build the names of the
# marker file, the input video and the output files.
VIDEO_ID = '5_L2_1'

# Number of leading frames to process; -1 means "use every frame".
n_frames = -1

# --- Tracking behaviour ----------------------------------------------------
# Number of frames a path may be propagated forwards without a measurement
# before it is considered dead.
PROP_THD = 10

# Minimum length a path must have to be included in the analysis.
MIN_LEN_PATH = 100

# --- Output ----------------------------------------------------------------
# An intermediate video is written whenever the processed frame index is a
# multiple of this value.
save_frequency = 500

# Data locations

In [None]:
# All data directories are resolved relative to the project root.
PROJ_ROOT = utils.here(True)

# Input: the videos to be tracked.
VIDEO_ROOT = os.path.join(PROJ_ROOT, 'data', 'interim', 'xenopus')

# Everything else lives below the shared "processed" directory.
_PROCESSED_XENOPUS = os.path.join(PROJ_ROOT, 'data', 'processed', 'xenopus')

# Input: initial marker definitions, one YAML file per video.
MARKERS_SEGMENTED_ROOT = os.path.join(_PROCESSED_XENOPUS, 'segmented', 'markers')

# Output: per-video velocity statistics and annotated videos.
ANALYSIS_OUTPUT_DIR = os.path.join(_PROCESSED_XENOPUS, 'statistics')
VIDEO_OUTPUT_DIR = os.path.join(_PROCESSED_XENOPUS, 'videos')

## Set which video to track

In [None]:
# Read the initial marker points for the selected video. Only the safe YAML
# loader is used, so the file cannot instantiate arbitrary Python objects.
markers_file = os.path.join(MARKERS_SEGMENTED_ROOT, VIDEO_ID + ".yml")
with open(markers_file, "r") as f:
    points = yaml.safe_load(f)

In [None]:
# Timestep between consecutive measurements/updates.
DT = 1

# State vector: x = [X, Y, X_dot, Y_dot].
# Continuous-time state model: x_dot = A.x + G.w, where w is the state noise.
# The rows of A for X_dot and Y_dot are zero, so the velocity changes only
# through the noise term.
A = np.array([
    [0, 0, 1, 0],
    [0, 0, 0, 1],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
])

# Discrete-time state transition matrix over one timestep: PHI = expm(A * DT).
# The timestep is included explicitly so that PHI stays consistent with GAM
# (which already depends on DT) if DT is ever changed from 1.
# For this A, A @ A == 0, so PHI equals I + A * DT.
PHI = sp.linalg.expm(A * DT)

# Standard deviation of the state (process) noise; Q below is built from it.
SIG_WK = 0.03
# NOTE(review): NOISE_ARR is not used anywhere in this cell -- confirm
# whether it is still needed elsewhere.
NOISE_ARR = np.array([SIG_WK, SIG_WK])

# Standard deviation of the measurement noise; R below is built from it.
SIG_VK = 100

# Noise input matrix: maps the two noise components onto the state
# (0.5 * DT**2 on the positions, DT on the velocities).
GAM = np.array([
    [0.5 * DT**2, 0],
    [0, 0.5 * DT**2],
    [DT, 0],
    [0, DT],
])

# Measurement matrix: only the positions X and Y are observed.
H = np.array([[1, 0, 0, 0], [0, 1, 0, 0]])

# Measurement noise covariance (uncorrelated, equal variance in X and Y).
R = np.identity(2) * (SIG_VK**2)

# State noise covariance (uncorrelated, equal variance in both components).
Q = np.array([
    [SIG_WK**2, 0],
    [0, SIG_WK**2],
])

### Set up initial conditions

## Kalman filter in the loop

In [None]:
# Load the frames of the selected video, optionally keeping only the first
# `n_frames` of them (-1 keeps every frame).
video_file = os.path.join(VIDEO_ROOT, f'{VIDEO_ID}.mp4')
raw_frames = utils.frames_from_video(video_file)
if n_frames != -1:
    raw_frames = raw_frames[:n_frames, ...]

*Improvement suggestion for the future*: Instead of storing the videos, we should store the data from which the videos can be generated. In the loop below, significant time is spent overwriting the video. In addition, it is not trivial to concatenate videos, but it is relatively straightforward to concatenate data frames and lists and then generate the videos afterwards.

In [None]:
# Track points across frames
#
# Overview of one iteration:
#   1. (first iteration only) build markers from the user-defined points,
#      segment frame 0 and start one path per detected point;
#   2. read the measurements from the current markers, then refresh the
#      markers on the current frame;
#   3. propagate all paths, retire stale ones, assign measurements to paths,
#      start new paths for unassigned measurements and update the estimates;
#   4. periodically write an intermediate video.
#
# NOTE(review): `frame_idx` is the loop variable but is also incremented inside
# the body (see below); the `for` statement rebinds it at the start of every
# iteration, so the increment only affects the remainder of that iteration.
for frame_idx in tqdm(range(raw_frames.shape[0])):
    
    if frame_idx == 0:
        # Initialisation: turn the loaded points into a marker image and let the
        # segmentation refine it on the first frame (converted from BGR to HSV).
        new_markers = utils.points_to_markers(points, raw_frames[frame_idx, ...].shape)
        new_markers = image.segment.process_markers(frame=cv2.cvtColor(raw_frames[frame_idx, ...],
                                                                       cv2.COLOR_BGR2HSV),
                                                    min_size=5000,
                                                    max_size=18000,
                                                    markers=new_markers)
        paths = []
        # Passed as two positional arguments to kalman.initialise_point and
        # later to kalman.append_new_paths.
        measure_uncert = [25,25]
        position_estimates = []
        # Start one path per point found in the initial markers.
        for pt in utils.markers_to_pts(new_markers).values():
            new_label = get_new_unique_label(paths)
            new_path = path(new_label)
            x_est, x_pred, P_pred, P_est = kalman.initialise_point(np.stack(pt, axis=0), *measure_uncert)
            # NOTE(review): redundant -- frame_idx is already 0 in this branch.
            frame_idx = 0
            new_path.add_predictions_to_track(x_pred, P_pred, frame_idx)
            new_path.add_estimates_to_track(x_est, P_est, frame_idx)
            new_path.z_meas.append((np.stack(pt, axis=0), frame_idx))
            paths.append(new_path)
            # Keep only the first two state components (the position part).
            position_estimates.append(x_est[0:2])
        # NOTE(review): np.squeeze drops every length-1 axis, so with a single
        # initial point prev_pts presumably becomes 1-D -- confirm that the
        # prev_pts.shape[1:] fallback below still has the intended shape then.
        prev_pts = np.squeeze(position_estimates)
    
    # NOTE(review): the measurements are read from `new_markers` *before* it is
    # refreshed on the current frame, so they stem from the markers produced in
    # the previous iteration (or in the initialisation block when frame_idx is
    # 0). In the first iteration frame 0 is also segmented twice. Confirm that
    # this lag is intended.
    curr_pts = utils.markers_to_pts(new_markers)
    new_markers = image.segment.process_markers(frame=cv2.cvtColor(raw_frames[frame_idx, ...],
                                                                   cv2.COLOR_BGR2HSV),
                                                min_size=5000,
                                                max_size=18000,
                                                markers=new_markers)
    
    frame_idx +=1 # Increment as frame_idx starts counting from 0 but we've already processed the 0th frame above.
    # NOTE(review): tracked_pts is not used anywhere in the visible code.
    tracked_pts = []
    
    # Get measurements from the current frame
    curr_pts = list(curr_pts.values())
    if curr_pts:
        curr_pts = np.stack(curr_pts, axis=0)
    else:
        # No measurements available: build an array with zero rows (np.empty
        # with a leading dimension of 0 holds no values at all) whose trailing
        # dimensions are taken from prev_pts, so that the calls below still
        # receive an array.
        curr_pts = np.empty((0, *prev_pts.shape[1:]))
        
    # Propagate paths from the previous estimate
    paths = kalman.propagate_paths(paths, frame_idx, PHI, GAM, Q)
    
    # Mark paths that have not been propagated for more than PROP_THD as dead
    # TODO: consider errors on the state estimate as a threshold.
    paths = check_paths_alive(paths, PROP_THD)

    # Match the measurements to the existing paths, then store them on the
    # matched paths; measurements left unmatched are returned as new_pts.
    curr_pts_idxs, path_labels = kalman.align_measurements_to_paths(paths, curr_pts, H, R)
    paths, new_pts = kalman.assign_measurements_to_paths(paths, curr_pts, curr_pts_idxs, path_labels, frame_idx)

    # Add new paths for new points:
    # Get any new measurements (unassigned in the munkres algorithm)
    paths = kalman.append_new_paths(new_pts, measure_uncert, frame_idx, paths)
    
    # Update estimates for paths with measurements
    paths = kalman.update_estimates_on_paths_with_measurements(paths, H, R, PHI)
    
    # Periodic checkpoint: write a video of everything processed so far. Deep
    # copies of the frames and paths are handed to save_video (presumably so it
    # cannot modify the originals -- confirm; the copies are costly).
    if frame_idx % save_frequency == 0:
        utils.save_video(copy.deepcopy(raw_frames[:frame_idx, ...]), copy.deepcopy(paths), os.path.join(VIDEO_OUTPUT_DIR, VIDEO_ID))
        print("Saved video")
# Final video: after the last iteration frame_idx equals the number of frames,
# so the slice covers all of them.
# NOTE(review): if raw_frames contains no frames the loop body never runs and
# frame_idx / paths are undefined here.
utils.save_video(copy.deepcopy(raw_frames[:frame_idx, ...]), copy.deepcopy(paths), os.path.join(VIDEO_OUTPUT_DIR, VIDEO_ID))

In [None]:
# Diagnostic plot for a single path: for each of the first two state
# components, draw the difference between the estimate and the measurement
# together with the +/- one-sigma band taken from the estimate covariance.
idx_path = 5
fig = plt.figure(figsize=(15, 5))
grid = ImageGrid(fig, 111, nrows_ncols=(2, 1), aspect=False, share_all=True)

# The selected path and its frame ids do not depend on the axis being drawn.
chosen_path = paths[idx_path]
frames = [frame_id for _, frame_id in chosen_path.P_est]

for idx, ax in enumerate(grid):
    # Component `idx` of every estimate and of every measurement on the path.
    x_est = [state[idx] for state, _ in chosen_path.x_est]
    z_meas = [meas[idx] for meas, _ in chosen_path.z_meas]
    z_diff = [est - meas for est, meas in zip(x_est, z_meas)]

    # Standard deviation of component `idx`, from the covariance diagonal.
    sigma = np.sqrt([cov[idx, idx] for cov, _ in chosen_path.P_est])

    ax.plot(frames, z_diff)
    ax.plot(frames, sigma, 'r:')
    ax.plot(frames, -sigma, 'r:')

In [None]:
# Keep only the paths whose number of state estimates reaches MIN_LEN_PATH.
# MIN_LEN_PATH is the *minimum* length of a path to be included, so a path of
# exactly that length is kept (only strictly shorter paths are filtered out).
filt_paths = [p for p in paths if len(p.x_est) >= MIN_LEN_PATH]

# Save processed data

In [None]:
def get_velocity(path):
    """Return the speed for every state estimate stored on a path.

    The speed is the magnitude of the velocity part of the state, i.e.
    ``sqrt(X_dot**2 + Y_dot**2)`` for a state laid out as
    ``[X, Y, X_dot, Y_dot]``.

    Parameters
    ----------
    path : object
        Object with an ``x_est`` attribute. Each entry of ``x_est`` is
        indexable, with the state at position 0; components 2 and 3 of that
        state are read as the velocity components. The state may be a column
        vector (each component a length-1 array) or a flat sequence of
        scalars.

    Returns
    -------
    list of float
        One speed value per entry of ``path.x_est``, in the same order.
    """
    velocity = []
    for entry in path.x_est:
        state = entry[0]
        speed = np.sqrt(state[2] ** 2 + state[3] ** 2)
        # np.ravel accepts both a 0-d scalar and a length-1 array, so the
        # result is extracted the same way for flat and column-vector states.
        velocity.append(float(np.ravel(speed)[0]))
    return velocity

# Speed of every state estimate, grouped by path and pooled over the video.
vel_per_path = [get_velocity(x) for x in filt_paths]
all_vel = [x for y in vel_per_path for x in y]

# Per-path and whole-video summary statistics. Everything is converted to
# plain Python floats so that the YAML dump contains ordinary numbers.
# Note: with no paths left after filtering, np.mean / np.std of the empty
# list yield nan (with a warning).
avg_vel_per_path = [float(np.mean(x)) for x in vel_per_path]
std_vel_per_path = [float(np.std(x)) for x in vel_per_path]
avg_vel_video = float(np.mean(all_vel))
std_vel_video = float(np.std(all_vel))

# The key "avgvel_per_path" is spelled differently from the other keys; it is
# kept as-is so that existing readers of the output file keep working.
result = {
    "vel_per_path": vel_per_path,
    "avgvel_per_path": avg_vel_per_path,
    "all_vel": all_vel,
    "avg_vel_video": avg_vel_video,
    "std_vel_per_path": std_vel_per_path,
    "std_vel_video": std_vel_video
}

# Make sure the output directory exists, so that a missing directory does not
# discard the results at the very end of a long-running notebook.
os.makedirs(ANALYSIS_OUTPUT_DIR, exist_ok=True)
with open(os.path.join(ANALYSIS_OUTPUT_DIR, VIDEO_ID + ".yml"), "w") as f:
    yaml.dump([result], f)