In [1]:
import os
import pickle

import cv2
from matplotlib import pyplot as plt
import numpy as np

from falatra.utils import find_bbox_centre, distance_from_line
from falatra.keypoints import Frame, FrameMatcher
from falatra.model.head3d import deserialize_headmodel, HeadModel3D
from falatra.model.stereo import StereoCalibration
from falatra.markers import MarkerDetection
import procrustes.utils
from procrustes import generic
from mpl_toolkits.mplot3d import Axes3D

%matplotlib qt

In [2]:
def listFullPaths(folder):
    
    paths = []
    for filename in sorted(os.listdir(folder)):
        path = os.path.join(folder, filename)
        paths.append(path)
        
    return paths

training = {}
training['side'] = listFullPaths('./data/training/sonny_ba/left')
training['front'] = listFullPaths('./data/training/sonny_ba/centre')

labels = {}
labels['side'] = listFullPaths('./data/training/sonny_ba/left_labels')
labels['front'] = listFullPaths('./data/training/sonny_ba/centre_labels')

with open('./data/framefront.ser', 'rb') as fp:
    modelframe_front = pickle.load(fp)
    
with open('./data/frameside.ser', 'rb') as fp:
    modelframe_side = pickle.load(fp)

headmodel_front = deserialize_headmodel('./data/headmodelfront.ser')
headmodel_side  = deserialize_headmodel('./data/headmodelleft.ser')
calibration = StereoCalibration()
calibration.load('./data/calibration/calibration1')
matcher = cv2.BFMatcher(cv2.NORM_L2, crossCheck=False)

In [3]:
modelframe_front.display()
modelframe_side.display()

In [13]:
def lowe_ratio_filter(knnmatches, loweratio=0.7):
    
    good = []
    for knnmatch in knnmatches:
        m, n = knnmatch[:2]  # get the 1st and 2nd closest
        if m.distance < loweratio * n.distance:
            good.append([m])
            
    return good

def display_matches(matches, queryimg, querykps, trainimg, trainkps):
    vis = cv2.drawMatchesKnn(queryimg, querykps, 
                             trainimg, trainkps, 
                             matches,
                             None,
                             flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS)
    plt.figure()
    plt.imshow(vis[...,[2,1,0]])
    plt.show()
    
def plot_atom_coordinates(coords1, coords2,
                          figsize=(12, 10),
                          fontsize_label=14,
                          fontsize_title=16,
                          fontsize_legend=16,
                          label1=None,
                          label2=None,
                          title=None,
                          figfile=None):
    """Plot Cartesian coordinates of given atoms.

    Parameters
    ----------
    coords1: np.ndarray
        Cartesian coordinates of given atom set 1.
    coords2: np.ndarray
        Cartesian coordinates of given atom set 2.
    figsize : (float, float), optional
        Figure size with width and height in inchies.
    fontsize_label: int, optional
        The font size for labels. Default=14.
    fontsize_label: int, optional
        The font size for title. Default=16.
    label1 : str, optional
        Label for coords1. Default=None.
    label2 : str, optional
        Label for coords2. Default=None.
    title : str, optional
        Figure title. Default=None.
    figfile : str, optional
        Figure file name to save it. Default=None.

    """
    fig = plt.figure(figsize=figsize)
    ax = Axes3D(fig)

    ax.scatter(xs=coords1[:, 0], ys=coords1[:, 1], zs=coords1[:, 2],
               marker="o", color="blue", s=40, label=label1)
    ax.scatter(xs=coords2[:, 0], ys=coords2[:, 1], zs=coords2[:, 2],
               marker="o", color="red", s=40, label=label2)

    ax.set_xlabel("X", fontsize=fontsize_label)
    ax.set_ylabel("Y", fontsize=fontsize_label)
    ax.set_zlabel("Z", fontsize=fontsize_label)
    ax.legend(fontsize=fontsize_legend, loc="best")

    plt.title(title,
              fontsize=fontsize_title)
    # save figure to a file
    if figfile:
        plt.savefig(figfile)

    plt.show()
    
error_dict = {}
errors = []
end = len(training['side']) - 1
for image_index in range(1):

    image_front = cv2.imread(training['front'][image_index])
    label_front = MarkerDetection()
    label_front.load(labels['front'][image_index])
    frame_front = Frame(image_front)
    frame_front.detect(detectFace=True)

    image_side = cv2.imread(training['side'][image_index])
    label_side = MarkerDetection()
    label_side.load(labels['side'][image_index])
    frame_side = Frame(image_side)
    frame_side.detect(detectFace=True)
    
    # ---------- 1st step: FIND CORRESPONDENCES BETWEEN STEREO IMAGES ---------- 

    stereomatches = matcher.knnMatch(frame_side.des, frame_front.des, 2)

    # Lowe ratio filtering
    loweratio = 0.9
    stereomatches = lowe_ratio_filter(stereomatches, loweratio)

    # Epipolar constraint
    epipolar_threshold = 20.0
    good = []
    for match in stereomatches:

        pt_src = frame_side.kps[match[0].queryIdx].pt
        pt_dst = frame_front.kps[match[0].trainIdx].pt

        # undistort these points
        pt_src = calibration.undistort_points([pt_src], view='right').squeeze()
        pt_dst = calibration.undistort_points([pt_dst], view='left').squeeze()

        pt_dst = pt_dst[np.newaxis, ...]  # junk to make this work if my code wasn't so bad
        line = calibration.compute_correspond_epilines(pt_dst, view='right').squeeze()

        d = distance_from_line(pt_src, *line)
        if d < epipolar_threshold:
            good.append(match)
    stereomatches = good

    # display result
    #display_matches(stereomatches, frame_side.image, frame_side.kps, frame_front.image, frame_front.kps)

    # ---------- 2nd step: FIND CORRESPONDENCES BETWEEN IMAGE AND HEAD MODEL ---------- 
    modelmatches_centre = matcher.knnMatch(frame_front.des, headmodel_front.descriptors, 2)
    modelmatches_side = matcher.knnMatch(frame_side.des, headmodel_side.descriptors, 2)

    # Lowe ratio filtering
    loweratio = 0.95
    modelmatches_centre = lowe_ratio_filter(modelmatches_centre, loweratio)
    #display_matches(modelmatches_centre, frame_front.image, frame_front.kps, modelframe_front.image, modelframe_front.kps)


    # ---------- 3rd step: 3-WAY MATCHING BETWEEN STEREO MODEL & VECTRA 3D MODEL ----------

    highfi_model = HeadModel3D()
    lowfi_model = HeadModel3D()


    #O(N^2), can be improved to O(Nlog) if we use a more efficient biparte graph matching or
    # using a hash map to store stereo matches...
    new_modelmatches_centre = []
    for match in modelmatches_centre:

        stereoimg_indx = match[0].queryIdx  
        for match2 in stereomatches:
            if stereoimg_indx == match2[0].trainIdx:
                new_modelmatches_centre.append(match)

                # triangulate keypoint point
                pt_src = frame_side.kps[match2[0].queryIdx].pt
                pt_dst = frame_front.kps[match2[0].trainIdx].pt

                # undistort these points
                pt_src = calibration.undistort_points([pt_src], view='right').squeeze()
                pt_dst = calibration.undistort_points([pt_dst], view='left').squeeze()

                X, X1, X2 = calibration.triangulate(pt_dst, pt_src) # src and dst is swap because calibration inverse camera definitions
                lowfi_model.addFeaturePoint(X1.flatten(), frame_side.des[match2[0].queryIdx])

                # create a new model from highfi model that only contain correspondence keypoints
                highfi_model.addFeaturePoint(headmodel_front.keypoints[match[0].trainIdx],
                                             headmodel_front.descriptors[match[0].trainIdx])

    # Triangulate landmarks
    for name in label_front.bboxes.keys():
        if name in label_side.bboxes.keys():

            # Triangulate landmark
            pt_src = find_bbox_centre(label_side.bboxes[name])
            pt_dst = find_bbox_centre(label_front.bboxes[name])

            # undistort these points
            pt_src = calibration.undistort_points([pt_src], view='right').squeeze()
            pt_dst = calibration.undistort_points([pt_dst], view='left').squeeze()

            X, X1, X2 = calibration.triangulate(pt_dst, pt_src) # src and dst is swap because calibration inverse camera
            lowfi_model.setLandmark(name, X1)


    # Transfer landmarks to highfi mode
    for landmark in lowfi_model._landmarks.keys():
        if landmark in headmodel_front._landmarks:
            highfi_model.setLandmark(landmark, headmodel_front._landmarks[landmark])


    #display_matches(new_modelmatches_centre, frame_front.image, frame_front.kps, modelframe_front.image, modelframe_front.kps)       

    # ---------- 4th step: Procrustes analysis ----------
        
    result = generic(highfi_model.keypoints, lowfi_model.keypoints, translate=True, scale=True)
    kp_new = np.dot(result.new_a, result.t)
    """
    plot_atom_coordinates(coords1=kp_new,
                          coords2=result.new_b,
                          figsize=(8, 6),
                          label1="Highfi model points",
                          label2="Lowfi model points",
                          title="Keypoint registration")
    """
    lndmk_ref = []  # reference landmarks from highfi model
    lndmk_gt = []   # ground truth 
    for landmark in lowfi_model._landmarks.keys():
        lndmk_gt.append(np.array(lowfi_model._landmarks[landmark]).flatten())
        lndmk_ref.append(np.array(highfi_model._landmarks[landmark]).flatten())

    lndmk_gt = np.array(lndmk_gt)
    lndmk_ref = np.array(lndmk_ref)

    # normalisation
    lndmk_gt, lndmk_ref = procrustes.utils.setup_input_arrays(lndmk_gt, lndmk_ref, remove_zero_col=False,
                                                         remove_zero_row=False,
                                                         pad=False, translate=True, scale=True,
                                                         check_finite=True, weight=None)

    # registration
    lndmk_inferred = np.dot(lndmk_ref, result.t)
    """
    plot_atom_coordinates(coords1=lndmk_inferred,
                          coords2=lndmk_gt,
                          figsize=(8, 6),
                          label1="Inferred Landmarks",
                          label2="Ground-truth landmarks",
                          title="Landmarks registration")
    """
    rmsd= np.sqrt(np.mean(np.sum((lndmk_ref - lndmk_inferred)**2, axis=1)))
    print(rmsd)
    errors.append(rmsd)
    
    for i, landmark in enumerate(lowfi_model._landmarks.keys()):
        pt3d_1 = lndmk_inferred[i]
        pt3d_2 = lndmk_gt[i]
        
        if landmark not in error_dict:
            error_dict[landmark] = 0
            
        error_dict[landmark] += np.sqrt(np.sum((pt3d_1 - pt3d_2)**2))
        
for key, value in error_dict.items():
    error_dict[key] = value / len(training['side'])
    
print(error_dict)

0.5653836966262983
{'RightSide_ch': 0.0008523237473579557, 'LeftSide_ch': 0.0017814899434235832, 'Medial_m': 0.0013754545761547213, 'Medial_g': 0.0013302169167161704, 'Medial_prn': 0.0010822501093186567, 'Medial_pg': 0.0014583810811557428, 'Medial_ls': 0.0003550208974285491, 'RightSide_mvi': 0.0021438987122036866, 'RightSide_cph': 0.0006022809726044809, 'LeftSide_mvi': 0.002570711989227479, 'LeftSide_cph': 0.00031040050698610957}


In [16]:


fig = plt.figure()
ax = Axes3D(fig)

ax.scatter(xs=lndmk_inferred[:, 0], ys=lndmk_inferred[:, 1], zs=lndmk_inferred[:, 2],
               marker="o", color="blue", s=20, label="Landmarks inferred")
ax.scatter(xs=lndmk_gt[:, 0], ys=lndmk_gt[:, 1], zs=lndmk_gt[:, 2],
               marker="o", color="red", s=20, label="Landmarks ground-truth")

for i, landmark in enumerate(lowfi_model._landmarks.keys()):
    pt3d = lndmk_inferred[i]
    ax.text(*pt3d, f"{landmark}", color='blue', fontsize='xx-small')
    pt3d = lndmk_gt[i]
    ax.text(*pt3d, f"{landmark}", color='red', fontsize='xx-small')
    
    
ax.set_xlabel("X", fontsize=8)
ax.set_ylabel("Y", fontsize=8)
ax.set_zlabel("Z", fontsize=8)
ax.legend(fontsize=8, loc="best")

plt.title("Landmark Registration",
          fontsize=16)

Text(0.5, 0.92, 'Landmark Registration')

In [15]:
plt.figure()
plt.bar(range(len(errors)),errors)
plt.xlabel('Image pair')
plt.xticks(range(len(errors)))
plt.ylabel('Root-Mean Squared Error')
plt.title('Registration Error normalised')
plt.show()

In [9]:
fig = plt.figure()
ax = fig.subplots()
ax.barh(list(error_dict.keys()),error_dict.values(), align='center')
ax.set_title('Average L2 error for each landmark of one participant')
ax.set_xlabel('Normalised L2 Error')

Text(0.5, 0, 'Normalised L2 Error')