In [None]:
# make the project root importable: prepend it so it is searched before anything else on sys.path
import sys
sys.path[:0] = ['../..']

In [None]:
# parameter settings
is_plot, is_export = False, False

# input data locations
landmarks_path = '../../data/landmarks/refine_6kmh_braless_18markers_12fps.pkl'
meshes_path = '../../data/meshes/6kmh_braless_26markers/'
test_landmarks_path = '../../data/test/braless_random_landmarks.pkl'

# frame window passed to obj3d.load_mesh_series
start, stride, end = 0, 12, 120

# folder handed to utils.save_pkl_object for the ablation results
export_folder = 'output/rbf/ablation/'

# Data Loading

In [None]:
from mesh4d import obj3d

# load the mesh series (and textures) from meshes_path using the start/stride/end window
mesh_ls, texture_ls = obj3d.load_mesh_series(folder=meshes_path, start=start, stride=stride, end=end)

In [None]:
from mesh4d import utils

# NOTE: load_pkl_object presumably unpickles the file — only point landmarks_path at trusted data
landmarks = utils.load_pkl_object(landmarks_path)

# interpolation step provided by the landmark object itself (see mesh4d for details)
landmarks.interp_field()

In [None]:
from mesh4d.analyse.crave import clip_with_contour

# markers whose trajectories form the clipping contour
contour = landmarks.extract(
    ('marker 0', 'marker 2', 'marker 3', 'marker 14', 'marker 15', 'marker 17')
)

# clip each mesh to the contour in the xy plane (30 units of margin)
mesh_clip_ls = clip_with_contour(
    mesh_ls,
    start_time=0,
    fps=120 / stride,
    contour=contour,
    clip_bound='xy',
    margin=30,
)

In [None]:
# wrap the unclipped meshes as Obj3d_Deform objects
body_ls = obj3d.init_obj_series(mesh_ls, obj_type=obj3d.Obj3d_Deform)

In [None]:
# wrap the contour-clipped meshes as Obj3d_Deform objects; these drive every experiment below
breast_ls = obj3d.init_obj_series(mesh_clip_ls, obj_type=obj3d.Obj3d_Deform)

# Ablation testing classes

In [None]:
import mesh4d
from mesh4d import field, obj4d
from mesh4d.analyse import measure
from scipy.interpolate import RBFInterpolator

# Ablation switches, read as globals by Trans_Nonrigid_RBF_ablation at call time.
# The experiment cells below overwrite them before each run.
POST_ALIGN = True             # project RBF-shifted vertices with measure.nearest_points_from_plane
KERNEL = 'thin_plate_spline'  # kernel name given to scipy's RBFInterpolator
FIELD_NBR = 100               # `neighbors` of the dense per-vertex RBF field

class Trans_Nonrigid_RBF_ablation(field.Trans_Nonrigid):
    """Landmark-driven RBF non-rigid transformation used by the ablation study.

    The RBF kernel and the post-alignment switch are read from the globals
    ``KERNEL`` and ``POST_ALIGN`` at call time. The neighbour count of the dense
    field comes from the ``field_nbr`` argument, or from the global ``FIELD_NBR``
    when it is not given. This lets the experiment cells vary them between runs.
    """

    def regist(self, landmark_name: str, field_nbr=None, **kwargs):
        """Fit an RBF to the landmark displacement, then build the dense vertex field.

        Args:
            landmark_name: key into ``self.source.kps_group`` and ``self.target.kps_group``.
            field_nbr: ``neighbors`` of the dense vertex field. ``None`` (default)
                falls back to the global ``FIELD_NBR``.
            **kwargs: forwarded to the landmark ``RBFInterpolator``.
        """
        landmarks_source = self.source.kps_group[landmark_name].get_points_coord()
        landmarks_target = self.target.kps_group[landmark_name].get_points_coord()
        landmarks_field = RBFInterpolator(landmarks_source, landmarks_target, kernel=KERNEL, **kwargs)

        self.post_align(landmarks_field, field_nbr)

    def post_align(self, landmarks_field, field_nbr=None):
        """Deform the source vertices with ``landmarks_field`` and fit the dense field.

        Sets ``self.source_points``, ``self.deform_points`` and ``self.field``.
        When ``POST_ALIGN`` is true, the shifted vertices are additionally passed
        through ``measure.nearest_points_from_plane`` against the target mesh.

        Args:
            landmarks_field: callable mapping source vertices to shifted positions
                (the landmark ``RBFInterpolator`` built in ``regist``).
            field_nbr: ``neighbors`` of the dense field. ``None`` (default) uses the
                global ``FIELD_NBR``. Explicit values were previously ignored.
        """
        if field_nbr is None:
            field_nbr = FIELD_NBR

        self.source_points = self.source.get_vertices()
        shift_points = landmarks_field(self.source_points)

        if POST_ALIGN:
            self.deform_points = measure.nearest_points_from_plane(self.target.mesh, shift_points)
        else:
            self.deform_points = shift_points

        self.field = RBFInterpolator(self.source_points, self.deform_points, neighbors=field_nbr, kernel=KERNEL)

class Obj4d_RBF_ablation(obj4d.Obj4d_Deform):
    """4D object whose non-rigid step uses Trans_Nonrigid_RBF_ablation."""

    def regist(self, landmark_name: str, **kwargs):
        """Register each frame to its successor using the named landmark set.

        Frame 0 goes through ``process_first_obj``. Every later frame ``cur`` is
        paired with frame ``cur - 1``: rigid step first (if enabled), then the
        non-rigid step (if enabled). ``kwargs`` are forwarded to both steps.
        """
        frame_count = len(self.obj_ls)
        if frame_count == 0:
            return

        self.process_first_obj()

        for cur in range(1, frame_count):
            prev = cur - 1

            if self.enable_rigid:
                # aligned to the previous one
                self.process_rigid_dynamic(prev, cur, **kwargs)

            if self.enable_nonrigid:
                # aligned to the later one
                self.process_nonrigid_dynamic(prev, cur, landmark_name, **kwargs)

            if mesh4d.output_msg:
                utils.progress_bar((cur + 1) / frame_count, back_str=" registered the {}-th frame".format(cur))

    def process_nonrigid_dynamic(self, idx_source: int, idx_target: int, landmark_name: str, **kwargs):
        """Fit a non-rigid transform from frame ``idx_source`` to ``idx_target``.

        The resulting transform is attached to the source frame via ``set_trans_nonrigid``.
        """
        source_obj = self.obj_ls[idx_source]
        trans = Trans_Nonrigid_RBF_ablation(source_obj=source_obj, target_obj=self.obj_ls[idx_target])
        trans.regist(landmark_name, **kwargs)
        source_obj.set_trans_nonrigid(trans)

# Kernel selection

In [None]:
# result containers for the kernel ablation, one sub-dict per landmark category
kernel_results = {category: {} for category in ('control landmarks', 'non-control landmarks')}

## Control landmarks

In [None]:
import time
import mesh4d
from mesh4d import kps
mesh4d.output_msg = False

# settings held fixed while the RBF kernel varies
POST_ALIGN, FIELD_NBR = True, 100

for kernel in ('thin_plate_spline', 'linear', 'cubic', 'quintic'):
    KERNEL = kernel
    print('=' * 70)
    print(f'kernel {KERNEL}')

    # fresh result slot for this kernel; `results` is an alias to it
    results = kernel_results['control landmarks'][KERNEL] = {}

    # --- registration with the full landmark set (timed) ---
    start_time = time.time()
    o4 = Obj4d_RBF_ablation(fps=120 / stride, enable_rigid=False, enable_nonrigid=True)
    o4.add_obj(*breast_ls)
    o4.load_markerset('landmarks', landmarks)
    o4.regist('landmarks')
    duration = time.time() - start_time

    print(f"computation time: {duration}")
    results['duration'] = duration

    # --- track the frame-0 landmark coordinates and compare against the recorded landmarks ---
    kps_source = landmarks.get_time_coord(0)
    o4.vkps_track(kps_source, start_id=0)
    vkps = o4.assemble_markerset(name='vkps')
    diff = kps.MarkerSet.diff(vkps, landmarks)

    print(diff['diff_str'])
    results.update(dist_mean=diff['dist_mean'], dist_std=diff['dist_std'])

## Non-control landmarks

In [None]:
import time

import mesh4d
import numpy as np
from mesh4d import kps
mesh4d.output_msg = False

POST_ALIGN = True
FIELD_NBR = 100

for kernel in ['thin_plate_spline', 'linear', 'cubic', 'quintic']:
    KERNEL = kernel
    print('='*70)
    print('kernel {}'.format(KERNEL))

    kernel_results['non-control landmarks'][KERNEL] = {}
    results = kernel_results['non-control landmarks'][KERNEL]
    dist_ls = []
    duration_ls = []

    # leave-one-out cross-validation over markers: each marker in turn is split off
    # as the test (non-control) landmark; the remaining markers drive the registration
    for name in landmarks.markers.keys():
        # split dataset
        landmarks_test, landmarks_train = landmarks.split((name, ))

        # registration
        start_time = time.time()

        o4 = Obj4d_RBF_ablation(
            fps=120 / stride,
            enable_rigid=False,
            enable_nonrigid=True,
        )

        o4.add_obj(*breast_ls)
        o4.load_markerset('landmarks_train', landmarks_train)
        o4.load_markerset('landmarks_test', landmarks_test)
        o4.regist('landmarks_train')

        duration = time.time() - start_time
        duration_ls.append(duration)

        # virtual key points tracking of the held-out marker
        kps_source = landmarks_test.get_time_coord(0)
        o4.vkps_track(kps_source, start_id=0, name='vkps')
        vkps = o4.assemble_markerset(name='vkps')
        diff = kps.MarkerSet.diff(vkps, landmarks_test)

        for marker_diff in diff['diff_dict'].values():
            dist_ls.append(marker_diff['dist'])

    # pool the errors of every held-out marker over all folds
    dist_mean = np.mean(np.array(dist_ls))
    dist_std = np.std(np.array(dist_ls))
    # mean registration time per fold, comparable with the single-run control-landmark timing
    duration_mean = np.mean(duration_ls)

    print("computation time: {}".format(duration_mean))
    print('overall error: {:.2f} ± {:.2f} (mm)'.format(dist_mean, dist_std))

    results['duration'] = duration_mean
    results['dist_mean'] = dist_mean
    results['dist_std'] = dist_std

In [None]:
# hand the kernel-ablation results to utils.save_pkl_object (folder: export_folder, name: 'kernel_results')
utils.save_pkl_object(
    kernel_results,
    export_folder,
    'kernel_results',
)

# Post-alignment

In [None]:
# result containers for the post-alignment ablation
align_results = {experiment: {} for experiment in ('merge', 'control landmarks', 'non-control landmarks')}

## Shape merging improvement

In [None]:
import time

import mesh4d
import numpy as np
from scipy.spatial import KDTree
mesh4d.output_msg = False

KERNEL = 'thin_plate_spline'
FIELD_NBR = 100

for post_align in [True, False]:
    POST_ALIGN = post_align
    print('='*70)
    print('post-align {}'.format(POST_ALIGN))

    align_results['merge'][POST_ALIGN] = {}
    results = align_results['merge'][POST_ALIGN]

    # registration
    start_time = time.time()

    o4 = Obj4d_RBF_ablation(
        fps=120 / stride,
        enable_rigid=False,
        enable_nonrigid=True,
    )

    o4.add_obj(*breast_ls)
    o4.load_markerset('landmarks', landmarks)
    o4.regist('landmarks')

    duration = time.time() - start_time

    # computation time
    print("computation time: {}".format(duration))
    results['duration'] = duration

    # shape mergence: for each frame i with a non-rigid transform (all but the last),
    # distance from its deformed vertices to the surface of frame i+1 (the transform's
    # target), via measure.nearest_points_from_plane.
    # NOTE: with POST_ALIGN=True the deformed vertices were already produced by this
    # same projection in post_align, so distances are presumably ~0 by construction.
    dist_ls = []

    # `frame_idx`, not `id`: a notebook loop variable would shadow the builtin globally
    for frame_idx in range(len(o4.obj_ls) - 1):
        trans = o4.obj_ls[frame_idx].trans_nonrigid
        deform_points = trans.deform_points

        closest_points = measure.nearest_points_from_plane(trans.target.mesh, deform_points)
        dist_ls.append(np.linalg.norm(deform_points - closest_points, axis=1))

    # fewer than two frames -> no transforms; avoid np.concatenate([]) raising
    dist_array = np.concatenate(dist_ls) if dist_ls else np.empty(0)
    results['dist_mean'] = np.mean(dist_array)
    results['dist_std'] = np.std(dist_array)

    print("nearest point distance: {:.2f} ± {:.2f} (mm)".format(results['dist_mean'], results['dist_std']))

In [None]:
import pyvista as pv
import numpy as np

# NOTE: `o4` is whatever the previous cell left behind, i.e. its last loop iteration (POST_ALIGN=False)
source_points = breast_ls[2].get_sample_points(150)
deform_points = o4.obj_ls[2].trans_nonrigid.shift_points(source_points)
disp = deform_points - source_points

# arrows from each sample point along its displacement vector
pdata = pv.vector_poly_data(source_points, disp)
glyph = pdata.glyph(factor=1)

scene = pv.Plotter()
scene.add_mesh(glyph, scalar_bar_args={'title': "displacement (mm)"})
scene.camera_position = 'xy'
scene.show()

In [None]:
import pyvista as pv
import numpy as np

# deform frame 2 with its non-rigid transform (from the `o4` left by the merge cell, POST_ALIGN=False)
mesh = breast_ls[2].mesh
trans = o4.obj_ls[2].trans_nonrigid
mesh_deform = trans.shift_mesh(mesh)

scene = pv.Plotter()
scene.add_mesh(mesh_deform, show_edges=True)
scene.camera_position = 'xy'
scene.show()

## Control landmarks

In [None]:
import time
import mesh4d
from mesh4d import kps
mesh4d.output_msg = False

# settings held fixed while post-alignment is toggled
KERNEL, FIELD_NBR = 'thin_plate_spline', 100

for post_align in (True, False):
    POST_ALIGN = post_align
    print('=' * 70)
    print(f'post-align {POST_ALIGN}')

    # fresh result slot for this setting; `results` is an alias to it
    results = align_results['control landmarks'][POST_ALIGN] = {}

    # --- registration with the full landmark set (timed) ---
    start_time = time.time()
    o4 = Obj4d_RBF_ablation(fps=120 / stride, enable_rigid=False, enable_nonrigid=True)
    o4.add_obj(*breast_ls)
    o4.load_markerset('landmarks', landmarks)
    o4.regist('landmarks')
    duration = time.time() - start_time

    print(f"computation time: {duration}")
    results['duration'] = duration

    # --- track the frame-0 landmark coordinates and compare against the recorded landmarks ---
    kps_source = landmarks.get_time_coord(0)
    o4.vkps_track(kps_source, start_id=0)
    vkps = o4.assemble_markerset(name='vkps')
    diff = kps.MarkerSet.diff(vkps, landmarks)

    print(diff['diff_str'])
    results.update(dist_mean=diff['dist_mean'], dist_std=diff['dist_std'])

## Non-control landmarks

In [None]:
import time

import mesh4d
import numpy as np
from mesh4d import kps
mesh4d.output_msg = False

KERNEL = 'thin_plate_spline'
FIELD_NBR = 100

for post_align in [True, False]:
    POST_ALIGN = post_align
    print('='*70)
    print('post-align {}'.format(POST_ALIGN))

    align_results['non-control landmarks'][POST_ALIGN] = {}
    results = align_results['non-control landmarks'][POST_ALIGN]
    dist_ls = []
    duration_ls = []

    # leave-one-out cross-validation over markers: each marker in turn is split off
    # as the test (non-control) landmark; the remaining markers drive the registration
    for name in landmarks.markers.keys():
        # split dataset
        landmarks_test, landmarks_train = landmarks.split((name, ))

        # registration
        start_time = time.time()

        o4 = Obj4d_RBF_ablation(
            fps=120 / stride,
            enable_rigid=False,
            enable_nonrigid=True,
        )

        o4.add_obj(*breast_ls)
        o4.load_markerset('landmarks_train', landmarks_train)
        o4.load_markerset('landmarks_test', landmarks_test)
        o4.regist('landmarks_train')

        duration = time.time() - start_time
        duration_ls.append(duration)

        # virtual key points tracking of the held-out marker
        kps_source = landmarks_test.get_time_coord(0)
        o4.vkps_track(kps_source, start_id=0, name='vkps')
        vkps = o4.assemble_markerset(name='vkps')
        diff = kps.MarkerSet.diff(vkps, landmarks_test)

        for marker_diff in diff['diff_dict'].values():
            dist_ls.append(marker_diff['dist'])

    # pool the errors of every held-out marker over all folds
    dist_mean = np.mean(np.array(dist_ls))
    dist_std = np.std(np.array(dist_ls))
    # mean registration time per fold, comparable with the single-run control-landmark timing
    duration_mean = np.mean(duration_ls)

    print("computation time: {}".format(duration_mean))
    print('overall error: {:.2f} ± {:.2f} (mm)'.format(dist_mean, dist_std))

    results['duration'] = duration_mean
    results['dist_mean'] = dist_mean
    results['dist_std'] = dist_std

In [None]:
# hand the post-alignment results to utils.save_pkl_object (folder: export_folder, name: 'align_results')
utils.save_pkl_object(
    align_results,
    export_folder,
    'align_results',
)

# $M$-nearest-neighbor interpolation

In [None]:
# result containers for the neighbour-count ablation, one sub-dict per landmark category
nbr_results = {category: {} for category in ('control landmarks', 'non-control landmarks')}

## Control landmarks

In [None]:
import time
import mesh4d
from mesh4d import kps
mesh4d.output_msg = False

# settings held fixed while the field neighbour count varies
POST_ALIGN, KERNEL = True, 'thin_plate_spline'

for field_nbr in (20, 50, 100, 200):
    FIELD_NBR = field_nbr
    print('=' * 70)
    print(f'field_nbr {FIELD_NBR}')

    # fresh result slot for this neighbour count; `results` is an alias to it
    results = nbr_results['control landmarks'][FIELD_NBR] = {}

    # --- registration with the full landmark set (timed) ---
    start_time = time.time()
    o4 = Obj4d_RBF_ablation(fps=120 / stride, enable_rigid=False, enable_nonrigid=True)
    o4.add_obj(*breast_ls)
    o4.load_markerset('landmarks', landmarks)
    o4.regist('landmarks')
    duration = time.time() - start_time

    print(f"computation time: {duration}")
    results['duration'] = duration

    # --- track the frame-0 landmark coordinates and compare against the recorded landmarks ---
    kps_source = landmarks.get_time_coord(0)
    o4.vkps_track(kps_source, start_id=0)
    vkps = o4.assemble_markerset(name='vkps')
    diff = kps.MarkerSet.diff(vkps, landmarks)

    print(diff['diff_str'])
    results.update(dist_mean=diff['dist_mean'], dist_std=diff['dist_std'])

## Non-control landmarks

In [None]:
import time

import mesh4d
import numpy as np
from mesh4d import kps
mesh4d.output_msg = False

POST_ALIGN = True
KERNEL = 'thin_plate_spline'

for field_nbr in [20, 50, 100, 200]:
    FIELD_NBR = field_nbr
    print('='*70)
    print('field_nbr {}'.format(FIELD_NBR))

    nbr_results['non-control landmarks'][FIELD_NBR] = {}
    results = nbr_results['non-control landmarks'][FIELD_NBR]
    dist_ls = []
    duration_ls = []

    # leave-one-out cross-validation over markers: each marker in turn is split off
    # as the test (non-control) landmark; the remaining markers drive the registration
    for name in landmarks.markers.keys():
        # split dataset
        landmarks_test, landmarks_train = landmarks.split((name, ))

        # registration
        start_time = time.time()

        o4 = Obj4d_RBF_ablation(
            fps=120 / stride,
            enable_rigid=False,
            enable_nonrigid=True,
        )

        o4.add_obj(*breast_ls)
        o4.load_markerset('landmarks_train', landmarks_train)
        o4.load_markerset('landmarks_test', landmarks_test)
        o4.regist('landmarks_train')

        duration = time.time() - start_time
        duration_ls.append(duration)

        # virtual key points tracking of the held-out marker
        kps_source = landmarks_test.get_time_coord(0)
        o4.vkps_track(kps_source, start_id=0, name='vkps')
        vkps = o4.assemble_markerset(name='vkps')
        diff = kps.MarkerSet.diff(vkps, landmarks_test)

        for marker_diff in diff['diff_dict'].values():
            dist_ls.append(marker_diff['dist'])

    # pool the errors of every held-out marker over all folds
    dist_mean = np.mean(np.array(dist_ls))
    dist_std = np.std(np.array(dist_ls))
    # mean registration time per fold, comparable with the single-run control-landmark timing
    duration_mean = np.mean(duration_ls)

    print("computation time: {}".format(duration_mean))
    print('overall error: {:.2f} ± {:.2f} (mm)'.format(dist_mean, dist_std))

    results['duration'] = duration_mean
    results['dist_mean'] = dist_mean
    results['dist_std'] = dist_std

In [None]:
# hand the neighbour-count results to utils.save_pkl_object (folder: export_folder, name: 'nbr_results')
utils.save_pkl_object(
    nbr_results,
    export_folder,
    'nbr_results',
)