# Example: PCA with Brain Shapes

## Downloading the data

In [None]:
! wget https://www.dropbox.com/s/4xraqjtplz5e8ku/brainshape-data.zip
! unzip brainshape-data.zip

In [None]:
!pip install vtk

In [None]:
# Root folder searched for the *BrStem*.vtk meshes.
# NOTE(review): presumably the folder extracted from brainshape-data.zip — confirm.
data_dir = "data/brainshapes/"

In [None]:
import numpy as np
import vtk
import matplotlib.pyplot as plt
%matplotlib inline

## Shape registration

In [None]:
def umeyama_rigid(X, Y):
    """Estimate the rigid transform (rotation + translation) mapping X onto Y.

    Least-squares solution of Umeyama (1991) for ``Y ≈ R @ X + t[:, None]``,
    including the reflection correction so that ``R`` is always a proper
    rotation (det(R) = +1) and never a mirror.

    Parameters
    ----------
    X, Y : ndarray, shape (m, n)
        Corresponding point sets; column j of X corresponds to column j of Y.

    Returns
    -------
    R : ndarray, shape (m, m)
        Rotation matrix with det(R) = +1.
    t : ndarray, shape (m,)
        Translation vector.
    """
    # Demean the point sets; [:, None] broadcasts each mean over all columns
    X_mean = X.mean(axis=1)
    Y_mean = Y.mean(axis=1)
    X_demean = X - X_mean[:, None]
    Y_demean = Y - Y_mean[:, None]

    # XY' of the demeaned sets (the transpose of the paper's Sigma_xy = Y X')
    XY = X_demean @ Y_demean.T

    # XY = U diag(D) Vt, hence Sigma_xy = V diag(D) U'
    U, _, Vt = np.linalg.svd(XY)
    V = Vt.T

    # Reflection guard (S in the paper): if V U' would have det -1, flip the
    # axis belonging to the smallest singular value (svd sorts them descending).
    S = np.ones(X.shape[0])
    if np.linalg.det(V) * np.linalg.det(U) < 0:
        S[-1] = -1.0

    # R = V diag(S) U'; (V * S) scales the columns of V by S
    R = (V * S) @ U.T

    # Translation maps the rotated X centroid onto the Y centroid
    t = Y_mean - R @ X_mean

    return R, t

In [None]:
def umeyama_similarity(X, Y):
    """Estimate the similarity transform mapping X onto Y.

    Least-squares solution of Umeyama (1991) for
    ``Y ≈ c * R @ X + t[:, None]`` with rotation R, isotropic scale c and
    translation t. The reflection correction (matrix S in the paper) is
    applied, so R is always a proper rotation and c = tr(D S) / sigma_x^2.

    Parameters
    ----------
    X, Y : ndarray, shape (m, n)
        Corresponding point sets; column j of X corresponds to column j of Y.

    Returns
    -------
    R : ndarray, shape (m, m)
        Rotation matrix with det(R) = +1.
    t : ndarray, shape (m,)
        Translation vector.
    c : float
        Isotropic scale factor.
    """
    # Get dimension and number of points
    m, n = X.shape

    # Demean the point sets X and Y
    X_mean = X.mean(axis=1)
    Y_mean = Y.mean(axis=1)
    X_demean = X - X_mean[:, None]
    Y_demean = Y - Y_mean[:, None]

    # Cross-covariance of the demeaned sets, normalised by n.
    # Equation (38) in the paper, transposed (XY' instead of YX').
    XY = X_demean @ Y_demean.T / n

    # Variance of X (mean squared distance to its centroid), Equation (36)
    X_var = np.mean(np.sum(X_demean * X_demean, axis=0))

    # XY = U diag(D) Vt, hence Sigma_xy = V diag(D) U'
    U, D, Vt = np.linalg.svd(XY)
    V = Vt.T

    # Reflection guard S, Equation (39): flip the axis of the smallest
    # singular value when V U' would be a reflection (det = -1).
    S = np.ones(m)
    if np.linalg.det(U) * np.linalg.det(V) < 0:
        S[-1] = -1.0

    # Rotation R = V diag(S) U', Equation (40)
    R = (V * S) @ U.T

    # Scale c = tr(diag(D) diag(S)) / sigma_x^2, Equation (42)
    c = np.dot(D, S) / X_var

    # Translation, Equation (41)
    t = Y_mean - c * (R @ X_mean)

    return R, t, c

In [None]:
def read_vtk(filename):
    """Load the vertex coordinates stored in a legacy VTK polydata file.

    Parameters
    ----------
    filename : str
        Path to the ``.vtk`` file.

    Returns
    -------
    ndarray
        One row per vertex holding its (x, y, z) coordinates, as returned
        by ``vtkPolyData.GetPoint``. If the mesh has no points the result
        is an empty 1-D array.
    """
    poly_reader = vtk.vtkPolyDataReader()
    poly_reader.SetFileName(filename)
    poly_reader.Update()
    mesh = poly_reader.GetOutput()

    point_count = mesh.GetNumberOfPoints()
    coords = [mesh.GetPoint(idx) for idx in range(point_count)]
    return np.array(coords)

In [None]:
from pathlib import Path

# Load every brain-stem mesh and flatten it into a single column vector.
pts = []

# rglob yields files in filesystem order, which is not guaranteed to be stable.
# Sorting keeps sample indices (id_source, id_target, the alignment reference)
# pointing at the same files on every machine.
vtk_paths = sorted(Path(data_dir).rglob('*BrStem*.vtk'))
if not vtk_paths:
    raise FileNotFoundError(f"No '*BrStem*.vtk' files found under {data_dir!r}")

for path in vtk_paths:
    v = read_vtk(str(path))
    # Flatten (num_nodes, 3) into [x_1..x_N, y_1..y_N, z_1..z_N]
    v = np.hstack((v[:, 0], v[:, 1], v[:, 2]))
    pts.append(v)

# Samples as columns: shape (3 * num_nodes, num_samples)
pts = np.array(pts).transpose()

In [None]:
m, n = pts.shape

# Rows of pts are stacked as [x-block | y-block | z-block], num_nodes rows each
num_nodes = m // 3
x_ind = range(0, num_nodes)
y_ind = range(num_nodes, 2 * num_nodes)
z_ind = range(2 * num_nodes, 3 * num_nodes)

# Per-axis coordinate matrices of shape (num_nodes, n), one column per sample
cx, cy, cz = (pts[rows, :] for rows in (x_ind, y_ind, z_ind))

print('Dimension:\t' + str(m))
print('Samples:\t' + str(n))

In [None]:
from mpl_toolkits.mplot3d import Axes3D

def plot_pts(x, y, z, max_range=None, marker_size=10, figure_size=5):
    """Scatter-plot one or more 3-D point sets inside a cubic viewing box.

    Parameters
    ----------
    x, y, z : ndarray, shape (num_points, num_sets)
        Coordinates; each column is drawn with its own ``ax.scatter`` call.
    max_range : float, optional
        Half-width of the cubic box centred on the data midpoint. If None,
        half of the largest coordinate extent over the three axes is used.
    marker_size : float
        Marker size passed to ``scatter``.
    figure_size : float
        Figure width and height in inches.
    """
    fig = plt.figure(figsize=(figure_size, figure_size), dpi=100)
    ax = fig.add_subplot(projection='3d')
    for s in range(x.shape[1]):
        ax.scatter(x[:, s], y[:, s], z[:, s], s=marker_size, marker='.')

    # Identity test for the "not given" sentinel; `== None` relies on how the
    # argument implements equality.
    if max_range is None:
        max_range = np.array([x.max() - x.min(), y.max() - y.min(), z.max() - z.min()]).max() / 2.0

    # Same half-width on every axis so the shape is not distorted
    mid_x = (x.max() + x.min()) * 0.5
    mid_y = (y.max() + y.min()) * 0.5
    mid_z = (z.max() + z.min()) * 0.5
    ax.set_xlim(mid_x - max_range, mid_x + max_range)
    ax.set_ylim(mid_y - max_range, mid_y + max_range)
    ax.set_zlim(mid_z - max_range, mid_z + max_range)
    ax.view_init(10, 45)
    ax.grid()

In [None]:
# Spatial normalisation: move every sample's centroid to the origin.
# Subtracting the per-column mean broadcasts it over all nodes.
cx_norm = cx - cx.mean(axis=0)
cy_norm = cy - cy.mean(axis=0)
cz_norm = cz - cz.mean(axis=0)

plot_pts(cx_norm, cy_norm, cz_norm, marker_size=10)

In [None]:
# Sample to be moved (source) and sample it is registered onto (target)
id_source = 0
id_target = 3

# Each shape as a (3, num_nodes) array: row 0 = x, row 1 = y, row 2 = z
source = np.stack([cx[:, id_source], cy[:, id_source], cz[:, id_source]])
target = np.stack([cx[:, id_target], cy[:, id_target], cz[:, id_target]])

In [None]:
# Overlay before registration: column 0 = source, column 1 = target
shapes_x = np.column_stack((source[0], target[0]))
shapes_y = np.column_stack((source[1], target[1]))
shapes_z = np.column_stack((source[2], target[2]))
plot_pts(shapes_x, shapes_y, shapes_z, marker_size=10)

In [None]:
# Rigidly register source onto target, then overlay warped source and target
R, t = umeyama_rigid(source, target)
warped = R @ source + t[:, np.newaxis]

shapes_x = np.column_stack((warped[0], target[0]))
shapes_y = np.column_stack((warped[1], target[1]))
shapes_z = np.column_stack((warped[2], target[2]))
plot_pts(shapes_x, shapes_y, shapes_z, marker_size=10)

In [None]:
# Register every sample onto one reference sample so only shape variation remains.
# Switch here between the two methods:
#   use_rigid = True  -> rotation + translation
#   use_rigid = False -> rotation + translation + isotropic scale (similarity)
use_rigid = False

# Index of the reference sample that all others are aligned to
id_target = 0

target = np.vstack((cx[:, id_target], cy[:, id_target], cz[:, id_target]))

# cx_norm / cy_norm / cz_norm must already exist with the same shape as cx;
# every column is overwritten below. The reference keeps its coordinates.
cx_norm[:, id_target] = target[0, :]
cy_norm[:, id_target] = target[1, :]
cz_norm[:, id_target] = target[2, :]

for i in range(n):
    # Skip the reference wherever it is; a hard-coded range(1, n) is only
    # correct for id_target == 0 and leaves sample 0 unaligned otherwise.
    if i == id_target:
        continue

    source = np.vstack((cx[:, i], cy[:, i], cz[:, i]))

    if use_rigid:
        R, t = umeyama_rigid(source, target)
        warped = R @ source + t[:, np.newaxis]
    else:
        R, t, c = umeyama_similarity(source, target)
        warped = c * (R @ source) + t[:, np.newaxis]

    cx_norm[:, i] = warped[0, :]
    cy_norm[:, i] = warped[1, :]
    cz_norm[:, i] = warped[2, :]

plot_pts(cx_norm, cy_norm, cz_norm, marker_size=10)

In [None]:
# Mean shape of the aligned population: average each node over all samples
cx_mean, cy_mean, cz_mean = (coords.mean(axis=1) for coords in (cx_norm, cy_norm, cz_norm))

# plot_pts expects one column per point set, so view each mean as (num_nodes, 1)
plot_pts(cx_mean[:, None], cy_mean[:, None], cz_mean[:, None], marker_size=10)

## Principal Component Analysis

In [None]:
import sklearn.decomposition as decomp

# PCA data matrix: each column is one aligned shape flattened as [x; y; z]
X = np.concatenate((cx_norm, cy_norm, cz_norm), axis=0)
m, n = X.shape
print('Dimension:\t' + str(m))
print('Samples:\t' + str(n))

In [None]:
# sklearn expects samples as rows, hence X.T (shape: samples x dimensions)
pca = decomp.PCA()
pca.fit(X.T)

# Mean shape vector
mu_X = pca.mean_

# Principal modes as columns (sklearn stores them as rows of components_)
U = pca.components_.T

# Eigenvalues of the sample covariance, i.e. the variance along each mode
D = np.square(pca.singular_values_) / (n - 1)

In [None]:
# Cumulative fraction of the total variance captured by the leading modes
cumulative_variance = np.cumsum(pca.explained_variance_ratio_)

fig, ax = plt.subplots()
ax.plot(cumulative_variance)
ax.set_xlabel('Mode')
ax.set_ylabel('Retained Variance')
plt.show()

In [None]:
# Visualise the first modes: mean shape plus/minus num_std standard deviations
num_modes = 3
num_std = 3

# Never ask for more modes than PCA produced
for i in range(min(num_modes, len(D))):

    # Displacement of num_std standard deviations (sqrt of the eigenvalue) along mode i
    offset = U[:, i] * np.sqrt(D[i]) * num_std
    sp = mu_X + offset
    sn = mu_X - offset

    # Columns: mean shape, positive extreme, negative extreme
    cxx = np.vstack((mu_X[x_ind], sp[x_ind], sn[x_ind])).T
    cyy = np.vstack((mu_X[y_ind], sp[y_ind], sn[y_ind])).T
    czz = np.vstack((mu_X[z_ind], sp[z_ind], sn[z_ind])).T

    plot_pts(cxx, cyy, czz, marker_size=10)

In [None]:
from ipywidgets import interact, fixed

def plot_points(mean_shape, modes, s1, s2, s3, s4, s5, s6):
    """Plot the shape synthesised from a mean shape and the first six modes.

    Parameters
    ----------
    mean_shape : ndarray, shape (m,)
        Mean shape vector laid out as [x-block, y-block, z-block].
    modes : ndarray, shape (m, k) with k >= 6
        Principal modes stored as columns.
    s1, s2, s3, s4, s5, s6 : float
        Weights applied to modes 1 to 6 respectively.
    """
    weights = np.array([s1, s2, s3, s4, s5, s6])

    # Use the arrays passed in rather than the notebook globals mu_X / U,
    # so the function plots whatever model the caller supplies.
    shape = mean_shape + modes[:, :len(weights)] @ weights

    sx = shape[x_ind].reshape(-1, 1)
    sy = shape[y_ind].reshape(-1, 1)
    sz = shape[z_ind].reshape(-1, 1)
    plot_pts(sx, sy, sz, max_range=30, marker_size=10)

def interactive_pca(mu_X,U,D):
    """Display ``plot_points`` with one slider per shape weight s1..s6.

    Slider k spans -10 to +10 times sqrt(D[k-1]) (the standard deviation of
    mode k), in steps of one standard deviation.
    """
    sliders = {}
    for k in range(6):
        sd = np.sqrt(D[k])
        sliders['s%d' % (k + 1)] = (-sd * 10, sd * 10, sd)
    interact(plot_points, mean_shape=fixed(mu_X), modes=fixed(U), **sliders)

# Launch the interactive shape explorer with the fitted PCA model
interactive_pca(mu_X=mu_X, U=U, D=D)