# ACM Summer School on Shape Modeling - 2022
# IIIT Delhi
## Rigid and Non-Rigid Shape Matching - July 26, 2022
## Aditya Tatu, DAIICT Gandhinagar

In [4]:
# Import required packages
%matplotlib notebook
import igl
import scipy as sp
import numpy as np
from meshplot import plot, subplot, interact
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import axes3d
import os
root_folder = os.getcwd()

# 1. Rigid Shape Matching

## Iterative Closest Point (ICP) Matching Algorithm
#### You are given functions that will help you write the code for ICP algorithm
#### 1: Find Correspondences based on Closest Euclidean distance point: findCorr
#### 2: Given two sets of points X and Y, map X using the optimal rigid transformation that minimizes the least-squares
####       error between X and Y: myrigidalign

In [37]:
# Returns indices of points in Y that are closest to each point in X. If X is px3, and Y is qx3 the returned 
# array size is p
def findCorr(X,Y):
    """Nearest-neighbour correspondence from X to Y under Euclidean distance.

    Parameters
    ----------
    X : array of shape (p, 3)
        Query points.
    Y : array of shape (q, 3)
        Candidate points.

    Returns
    -------
    Integer array of length p; entry j is the row index of the point in Y
    closest to X[j] (the first such index when several are equally close).
    """
    # One pass over the query points; each pass scans all of Y.
    nearest = [np.argmin(np.linalg.norm(Y - src_pt, axis=1)) for src_pt in X]
    return np.asarray(nearest, dtype=int)

In [38]:
# Shape of X and Y must be N x 3
# Function that returns the aligned set of points X_new obtained by rigidly aligning points X to Y
# Function does not assume that X and Y are origin-centered.
def myrigidalign(X,Y):
    """Rigidly align the points X to the points Y in the least-squares sense.

    X and Y must both be N x 3 with row i of X corresponding to row i of Y.
    Neither set needs to be origin-centred: both are centred internally and
    the result is translated onto the centroid of Y.

    Parameters
    ----------
    X : array of shape (N, 3)
        Points to be moved.
    Y : array of shape (N, 3)
        Points to align to.

    Returns
    -------
    Xaligned : array of shape (N, 3)
        X after the optimal rotation (no reflection) and translation.
    """
    y_mn = np.mean(Y,axis=0)
    Xc = X - np.mean(X,axis=0)
    Yc = Y - y_mn
    # 3 x 3 cross-covariance between the centred sets. For N x 3 inputs this
    # must be Xc^T Yc; Xc Yc^T would be N x N and cannot be combined with the
    # 3 x 3 correction matrix below.
    Xcorr = np.dot(Xc.T,Yc)
    uu,ss,vv = np.linalg.svd(Xcorr)
    # numpy returns V^T as the third factor (MATLAB returns V), hence vv.T.
    # The determinant term flips the last axis when needed so that R is a
    # proper rotation rather than a reflection.
    dg = np.diag([1,1,np.linalg.det(np.dot(vv.T,uu.T))])
    R = np.dot(np.dot(vv.T,dg),uu.T)
    Xaligned = np.dot(R,Xc.T).T + y_mn
    return Xaligned
    
    

## Q.1 Complete the code below for ICP algorithm using findCorr and myrigidalign

In [None]:
# ICP built from the two helper functions above (findCorr and myrigidalign)
def myICP(vsrc,vtgt,max_iter=100):
    """Align the source points to the target points with Iterative Closest Point.

    Parameters
    ----------
    vsrc : array of shape (p, 3)
        Source points to be moved.
    vtgt : array of shape (q, 3)
        Target points; they are not modified.
    max_iter : int, optional
        Upper bound on the number of ICP iterations, as a safeguard against
        an error that keeps decreasing by negligible amounts.

    Returns
    -------
    array of shape (p, 3)
        The source points after the last alignment that reduced the
        least-squares error. If no iteration is accepted (for example
        max_iter <= 0) the mean-centred source is returned.
    """
    # Step 1: Compute mean centered vsrc. Every alignment restarts from this
    # copy so that successive rigid transforms are not compounded.
    src_centered = vsrc - np.mean(vsrc,axis=0)
    # Step 2: Variable that stores the current estimate of vsrc after every iteration
    src_current = src_centered
    # Step 3: Loop driven by the least-squares error
    best_err = np.inf
    for _ in range(max_iter):
        # Step 4: Correspondence between the current vsrc and vtgt
        corr = findCorr(src_current,vtgt)
        # Step 5: Matching subset of vtgt, ordered like the source points
        tgt_matched = vtgt[corr]
        # Step 6: Rigidly align the original centered vsrc with that subset
        src_candidate = myrigidalign(src_centered,tgt_matched)
        # Step 7: New error; stop as soon as it no longer decreases
        err = np.sum((src_candidate - tgt_matched)**2)
        if err >= best_err:
            break
        src_current = src_candidate
        best_err = err
    return src_current

In [None]:
# Run Examples for ICP
# Use cat0.off as target and cat0h2.off as source
vh, fh = igl.read_triangle_mesh(os.path.join(root_folder, "data", "cat0h2.off")) # Head of the cat0 mesh
v4, f4 = igl.read_triangle_mesh(os.path.join(root_folder, "data", "cat0.off")) # Full cat mesh

# Align the head points to the full mesh
vhnew = myICP(vh,v4)
fig3 = plot(v4,f4,return_plot=True) # Plot cat0 mesh
fig3.add_points(vh,shading={"point_color": "red", "point_size": 5}) # Plot points of original head of cat
# The aligned points are stored in vhnew (not vnew, which is never defined)
fig3.add_points(vhnew,shading={"point_color": "blue", "point_size": 5}) # Plot points of aligned head of cat

# SVD-Based Multidimensional Scaling (MDS)
## Q.2. Complete the code for SVD based MDS given below and observe the results

In [None]:
# Define a function that takes in pointwise distance between points of a mesh, and outputs a 3 dimensional 
# canonical embedding
# Distance matrices for two meshes, cat0 and cat4 are given in the data folder as discat0.npy and discat4.npy
def mySVDbasedMDS(d, squared=False):
    """Classical (SVD based) MDS: 3-dimensional embedding from pairwise distances.

    Parameters
    ----------
    d : array of shape (n, n)
        Pairwise distances between the n points of a mesh.
    squared : bool, optional
        Set to True when d already contains squared distances. By default d is
        treated as plain distances and squared here, since classical MDS
        double-centres the matrix of squared distances.
        NOTE(review): confirm which form the discat*.npy files store.

    Returns
    -------
    Y : array of shape (n, 3)
        Embedding built from the three leading singular vectors of the
        double-centred matrix, each scaled by the square root of its
        singular value.

    Raises
    ------
    ValueError
        If d is not a square 2-D matrix.
    """
    d = np.asarray(d, dtype=float)
    if d.ndim != 2 or d.shape[0] != d.shape[1]:
        raise ValueError("d must be a square (n x n) distance matrix")
    n = d.shape[0]
    D = d if squared else d**2
    #Step 1: Centering matrix J = I - (1/n) * ones, sized by the number of points
    J = np.eye(n) - np.ones((n,n))/n
    #Step 2: Double-centred matrix B = -0.5 * J D J
    B = -0.5*np.dot(np.dot(J,D),J)
    #Step 3: SVD B = U*S*V^T (singular values come back in descending order)
    U,S,_ = np.linalg.svd(B)
    #Step 4: Embedding Y = U[:,0:3] * sqrt(diag([S(0),S(1),S(2)]))
    Y = U[:,:3]*np.sqrt(S[:3])
    return Y

In [5]:
# Run the SVD based MDS on cat0 and cat4 here and display the obtained embedding
v0, f0 = igl.read_triangle_mesh(os.path.join(root_folder, "data", "cat0.off"))
v4, f4 = igl.read_triangle_mesh(os.path.join(root_folder, "data", "cat4.off"))

# Distance matrices for cat0 and cat4, loaded from the same data folder
# (resolved through root_folder, like the mesh files above).
d0 = np.load(os.path.join(root_folder, "data", "discat0.npy"))
d4 = np.load(os.path.join(root_folder, "data", "discat4.npy"))

# v0emb = mySVDbasedMDS(d0)
# v4emb = mySVDbasedMDS(d4)


In [None]:
# Show the two input meshes: cat0 (v0, f0) and cat4 (v4, f4)
plot(v0,f0)
plot(v4,f4)

In [None]:
# Display the Results: the cat0 embedding is drawn as a mesh using cat0's faces (f0),
# and the cat4 embedding is overlaid on the same figure as red points.
# NOTE: v0emb and v4emb only exist once the mySVDbasedMDS calls that are commented
# out in the loading cell have been enabled and run.
fig = plot(v0emb,f0,return_plot=True)
fig.add_points(v4emb,shading={"point_color": "red", "point_size": 2})