In [None]:
import sys

In [None]:
#Prepend the project-local package locations to the import search path.
#Each entry is pushed to the front in turn, so the gmsh SDK directory ends up
#ahead of the rtms_bayesopt site-packages directory.
for extra_path in ('../../optimization/rtms_bayesopt/lib/python2.7/site-packages/',
                   '/KIMEL/tigrlab/projects/jjeyachandra/gmsh-sdk/lib/'):
    sys.path.insert(0, extra_path)

In [None]:
import os
from nilearn import image as img
from nilearn import plotting as plot
import nibabel as nib
import matplotlib.pyplot as plt
import numpy as np

This part will need to be generalized and moved into external scripts that compute the "centroid" voxel/vertex. For now we use a naive volumetric centroid based on voxels (roughly equivalent to computing the spatial centroid in FEM).

Alternatively, we could find a surface-based centroid based on shortest pairwise paths on a mesh-derived graph. Both procedures can be abstracted away from the routine that forms the parameter space.

In [None]:
#Locate the per-hemisphere ribbon images and the T1 used as plotting background
ribbon_dir = '../../data/sub-CMH090/ribbon/'
t1 = '../../data/simnibs_output/m2m_sub-CMH090/T1fs_nu_conform.nii.gz'

def _ribbon_paths(marker):
    '''Return the sorted paths of the files in ribbon_dir whose name contains marker.'''
    matches = [os.path.join(ribbon_dir, fname)
               for fname in os.listdir(ribbon_dir)
               if marker in fname]
    matches.sort()
    return matches

L_ribbons = _ribbon_paths('.L.network')
R_ribbons = _ribbon_paths('.R.network')

In [None]:
#Position (within the sorted file lists) of the left/right ribbon pair to work on
ind = 0
#Voxel-wise sum of the left and right ribbon images at that position
r1 = img.math_img('a+b', a=L_ribbons[ind], b=R_ribbons[ind])

In [None]:
#Interactive view of the combined ribbon image with the T1 as background
plot.view_img(r1,bg_img=t1,symmetric_cmap=False)

In [None]:
#Parcel 4 of Control Network A on Left Hemisphere as an example
L_img = img.load_img(L_ribbons[ind])
#get_data() is deprecated in nibabel (removed in 5.0); read the stored values through
#dataobj instead. np.array() returns an independent copy, so the in-place edit of
#r_data made further down cannot alter the data held by L_img.
r_data = np.array(L_img.dataobj)

In [None]:
#Voxel indices (one array per axis) of every voxel labelled 4 in the left-side image
x,y,z = np.nonzero(r_data == 4)

#Centroid = mean index along each axis (sub-voxel precision)
mu_x,mu_y,mu_z = (axis_inds.mean() for axis_inds in (x,y,z))

#Nearest whole voxel, only needed for visualization
vx,vy,vz = (int(np.round(mu)) for mu in (mu_x,mu_y,mu_z))

In [None]:
#Write a 3x3x3 marker cube centred on the centroid voxel into r_data.
#Slice ends are exclusive, so the upper bound must be +2 to include the voxel on the far
#side of the centre (the previous +1 gave a 2-voxel cube shifted towards the origin).
#The lower bound is clamped at 0 so a centroid on the volume edge cannot wrap around
#into a negative index.
r_data[max(vx-1,0):vx+2, max(vy-1,0):vy+2, max(vz-1,0):vz+2] = 9000

#Make into NIFTI1 (affine and header are taken from the combined ribbon image r1)
roi_img = nib.Nifti1Image(r_data,r1.affine,r1.header)

In [None]:
#Interactive view of the ribbon image with the centroid marker, over the T1 background
plot.view_img(roi_img,bg_img=t1,symmetric_cmap=False,cmap='ocean_hot')

In [None]:
#Display the rounded centroid voxel indices
vx,vy,vz

In [None]:
#Write the marked image to the current working directory for inspection in an external viewer
nib.save(roi_img,'./testing.nii.gz')

With the centroid coordinate, apply the affine transformation. Then project it to the closest head vertex (using gmsh), and display the result in voxel space using the inverse affine transformation.

In [None]:
import gmsh
#Start the gmsh API session used by the mesh queries below
gmsh.initialize()

In [None]:
#Head mesh for this subject from the simnibs output directory
msh_file = '../../data/simnibs_output/sub-CMH090.msh'
#Load the mesh into the current gmsh session
gmsh.open(msh_file)

In [None]:
#Load head vertices
#surf_head is the (dimension, tag) pair of the mesh entity used as the head surface
surf_head = (2,5)
head_node_tag, head_node_coord, head_node_param = gmsh.model.mesh.getNodes(*surf_head)
head_node_tag = np.array(head_node_tag)
#Coordinates arrive as one flat sequence; regroup them into one (x,y,z) row per node
head_node_coord = np.array(head_node_coord)
head_node_coord = head_node_coord.reshape((head_node_coord.shape[0]//3,3))

#Load head elements of the same entity
head_tag, head_el, head_tri = gmsh.model.mesh.getElements(dim=surf_head[0],tag=surf_head[1])
#Node tags of the first returned element group, regrouped into one row of 3 tags per triangle
tri_nodes = np.array(head_tri[0])
head_tri = tri_nodes.reshape((tri_nodes.shape[0]//3,3))

In [None]:
#Map the (sub-voxel) centroid from voxel indices into the image's world coordinates
aff = r1.affine
#Homogeneous voxel coordinate [i, j, k, 1]
centroid_vox = np.array([mu_x,mu_y,mu_z,1],dtype=np.float32)
#Apply the 4x4 affine and keep only the three spatial components
centroid_coord = aff.dot(centroid_vox)[:3]

In [None]:
#Offset of every head node from the centroid, and its Euclidean length
offsets_to_centroid = head_node_coord - centroid_coord
eudist = np.linalg.norm(offsets_to_centroid,axis=1)
#Row of the head node closest to the centroid
min_ind = eudist.argmin()

In [None]:
#Display the coordinate of the head node closest to the centroid
head_node_coord[min_ind]

Now, using the closest head vertex coordinate <code> head_node_coord[min_ind] </code>, define the parametric surface with a simple Euclidean distance metric (we would rather overestimate the parametric surface than underestimate it by using a geodesic distance).

Using the vertex subset, compute the average normalized normal of the surrounding faces, then push the vertices outward along it. This defines the parametric mesh used for spatial positioning.

In [None]:
#Radius (in mm) of the Euclidean sphere around the closest head vertex
search_rad= 25
#Distance from that vertex to every head vertex
seed_coord = head_node_coord[min_ind]
head_eudist = np.linalg.norm(head_node_coord - seed_coord,axis=1)
#Rows of the head vertices that fall strictly inside the sphere (tuple, usable as an index)
search_inds = np.nonzero(head_eudist < search_rad)

In [None]:
#Set up algorithm
#Step 1: Sort the vertex list (doesn't actually matter unless using binary tree)
#Tags and coordinates are reordered with the same permutation: sorting the tags alone
#would leave the rows of vert_coords paired with the wrong vertex tags.
vert_list = head_node_tag[search_inds]
vert_coords = head_node_coord[search_inds]
sort_order = np.argsort(vert_list)
vert_list = vert_list[sort_order]
vert_coords = vert_coords[sort_order]

In [None]:
from timeit import default_timer as timer

In [None]:
import numba

In [None]:
def get_relevant_triangles(verts, triangles):
    '''
    From an array of vertices and triangles. Get triangles that contain at least one vertex
    Arguments:
        verts                                 1-D array of vertexIDs
        triangles                             (Nx3) array of triangles, where each column is a vertex
    Output:
        t_arr                                 1-D int64 array of length N, indices correspond to triangles
                                              1 if triangle contains at least one vertex from list, else 0
    '''

    #Vectorized membership test replaces the explicit triangle x corner x vertex loops
    #(and the JIT compilation they needed); the result is the same 0/1 flag per triangle
    triangles = np.asarray(triangles)
    has_vert = np.isin(triangles, verts).any(axis=1)

    #Keep the integer 0/1 output that callers already rely on
    return has_vert.astype(np.int64)

In [None]:
#Get an array with relevant triangles
#(the wall-clock duration of the call, in seconds, is printed)
start = timer()
t_arr = get_relevant_triangles(vert_list,head_tri)
stop = timer()
print(stop-start)

In [None]:
#Keep only the triangles that were flagged in t_arr
t_rel = np.nonzero(t_arr > 0)
rel_trigs = head_tri[t_rel[0]]

#Sorted unique node tags used by those triangles, and their positions 0..n-1
u_val = np.unique(rel_trigs)
u_ind = np.arange(u_val.shape[0])

#Lookup table: node tag --> position within u_val
sort_map = dict(zip(u_val,u_ind))

#Element-wise wrapper around the lookup so it can be applied to a whole triangle array
def map_func(x):
    return sort_map[x]

vmap_func = np.vectorize(map_func)

In [None]:
#Apply map to triangles then get associated coordinates
#u_val is sorted and unique, so a binary search returns each node tag's position in u_val
#directly: same indices as a per-element dict lookup, without a Python call per entry,
#and it also works when rel_trigs is empty.
mapped_trigs = np.searchsorted(u_val, rel_trigs)

#Find the row of every u_val tag in head_node_tag without assuming the tags are stored sorted
tag_order = np.argsort(head_node_tag)
tag_pos = np.searchsorted(head_node_tag, u_val, sorter=tag_order)
tag_pos = np.clip(tag_pos, 0, head_node_tag.shape[0] - 1)
rtrig_verts = (tag_order[tag_pos],)

#Row i of rvert_coords must belong to u_val[i]; otherwise mapped_trigs would point at
#the wrong vertices, so fail loudly instead of computing normals from misaligned data
if not np.array_equal(head_node_tag[rtrig_verts], u_val):
    raise ValueError('some triangle node tags are missing from head_node_tag')
rvert_coords = head_node_coord[rtrig_verts]

In [None]:
@numba.njit
def unitize_arr(arr):
    '''
    Scale every row of an array to unit Euclidean length
    Arguments:
        arr                    (Nx3) array of row vectors
    Output:
        narr                   (Nx3) float64 array, row i is arr[i] divided by its norm
                               (rows of zero length are not special-cased)
    '''

    n_rows = arr.shape[0]
    narr = np.zeros((n_rows,3),dtype=np.float64)
    for row in range(n_rows):
        vec = arr[row,:]
        narr[row] = vec/np.linalg.norm(vec)

    return narr

In [None]:
@numba.njit
def cross(a,b):
    '''
    Cross product a x b of two 3-component vectors
    Arguments:
        a,b                    A single vector each (indexable, 3 components)

    Output
        out                    float64 array of length 3 holding a x b
    '''
    out = np.zeros(3,dtype=np.float64)

    #Component k is a[i]*b[j] - a[j]*b[i], with (k,i,j) a cyclic permutation of (0,1,2).
    #The first product is stored before the second is subtracted.
    for k in range(3):
        i = (k + 1) % 3
        j = (k + 2) % 3
        out[k] = a[i]*b[j]
        out[k] -= a[j]*b[i]

    return out

In [None]:
@numba.njit
def get_vert_norms(trigs, coords):
    '''
    Compute unit vertex normals by summing the normals of the adjoining faces
    Arguments:
        trigs                                (Mx3) array of triangles, entries are row indices into coords
        coords                               (Nx3) array of vertex coordinates
    Output:
        norm_arr                             (Nx3) array, row i is the unit normal of vertex i
    '''

    #Running sum of (un-normalized) face normals for every vertex
    cnorm_arr = np.zeros((coords.shape[0],3),dtype=np.float64)

    for t in range(trigs.shape[0]):

        #Face normal from the two edges leaving the first corner. Its length is left as is,
        #so larger faces contribute more to the sum.
        origin = coords[trigs[t,0],:]
        edge_a = coords[trigs[t,1],:] - origin
        edge_b = coords[trigs[t,2],:] - origin
        face_normal = cross(edge_a,edge_b)

        #Credit the face normal to each of the triangle's three corners
        for k in range(3):
            cnorm_arr[trigs[t,k],:] += face_normal

    #Scale the accumulated normals to unit length
    return unitize_arr(cnorm_arr)

In [None]:
#Compute vertex normals
#mapped_trigs indexes rows of rvert_coords; the wall-clock duration (seconds) is printed
start = timer()
norm_arr = get_vert_norms(mapped_trigs,rvert_coords)
stop = timer()
print(stop-start)

In [None]:
#Get the rows of norm_arr that belong to the vertices in vert_list.
#norm_arr is ordered like u_val (row i <--> node tag u_val[i]), so the indices must be
#positions within u_val. Testing vert_list against u_val instead yields positions within
#vert_list, which would pick unrelated normals out of norm_arr.
norm_vinds = np.where(np.isin(u_val,vert_list))[0]
norm_varr = norm_arr[norm_vinds]
#Sanity check: the three leading dimensions should agree
print(vert_list.shape,norm_varr.shape, vert_coords.shape)

In [None]:
#Apply vertex-wise dilation (1 unit = 1mm), use (c)mm
#Every vertex is shifted by the same vector: c times the mean of the selected vertex normals
#NOTE(review): the mean of unit normals is generally shorter than 1, so the actual shift is
#somewhat less than c mm -- confirm whether the mean should be re-normalized
c = 5
dil_coords = vert_coords + c*np.mean(norm_varr,axis=0)

In [None]:
#Get subset of triangles for vertex subset
def get_subset_triangles(verts, triangles):
    '''
    From an array of vertices and triangles. Get triangles whose vertices are all in the list
    Arguments:
        verts                                 1-D array of vertexIDs
        triangles                             (Nx3) array of triangles, where each column is a vertex
    Output:
        t_arr                                 1-D int64 array of length N, indices correspond to triangles
                                              1 if all 3 vertices of triangle found in verts, else 0
    '''

    #Vectorized membership test replaces the explicit triangle x corner x vertex loops
    #(and the JIT compilation they needed); the result is the same 0/1 flag per triangle
    triangles = np.asarray(triangles)
    all_in_verts = np.isin(triangles, verts).all(axis=1)

    #Keep the integer 0/1 output that callers already rely on
    return all_in_verts.astype(np.int64)
    

In [None]:
#Get face information for parametric surface (for visualization)
#Flags the triangles in rel_trigs whose three vertices all lie in vert_list;
#the wall-clock duration (seconds) is printed
start = timer()
dil_faces_ind = get_subset_triangles(vert_list,rel_trigs)
stop = timer()
print(stop-start)

In [None]:
#Vars for the parametric surface mesh. Every node tag (in the vertex list and in the
#triangles) is offset by the largest tag in vert_list to prevent nodal overlap
tag_shift = vert_list.max()

#Triangles lying entirely inside the vertex subset, as one flat run of shifted node tags
kept_trigs = rel_trigs[np.nonzero(dil_faces_ind)]
dil_faces = list(kept_trigs.flatten(order='C') + tag_shift)

#Shifted node tags and the dilated coordinates as a flat x,y,z,x,y,z,... sequence
dil_verts = vert_list + tag_shift
dil_coords = dil_coords.flatten()
print(len(dil_faces),dil_verts.shape,dil_coords.shape)

In [None]:
# Generate parametric surface mesh and save
# NOTE(review): gmsh.initialize() was already called earlier in the notebook without a
# matching finalize -- confirm that initializing a second time is intended
gmsh.initialize()
gmsh.model.add('param_surf')
# Discrete 2-D entity (requested tag 2001) that holds the dilated patch
tag = gmsh.model.addDiscreteEntity(2,2001)
gmsh.model.mesh.setNodes(2,tag,nodeTags=dil_verts,coord=dil_coords)
# One element group of type 2 (presumably gmsh's 3-node triangle -- TODO confirm);
# elements are numbered 1..(number of faces) and dil_faces supplies 3 node tags per element
gmsh.model.mesh.setElements(2,tag,[2],
                            elementTags=[range(1,len(dil_faces)//3 + 1)],
                            nodeTags=[dil_faces])
gmsh.write('../../output/param_surf.msh')
gmsh.finalize() 

In [None]:
#Write sampling coordinates into numpy binary
#tofile() stores only the raw values (no shape or dtype); dil_coords is the flattened
#x,y,z,... sequence at this point, so a reader has to know the dtype and regroup by 3
dil_coords.tofile('../../output/param_surf')

In [None]:
#Write the mean of the selected vertex normals (a single 3-vector) to disk as raw binary
#NOTE(review): the mean is not re-normalized, so the stored vector is generally not unit length
v_norm = np.mean(norm_varr,axis=0)
v_norm.tofile('../../output/norm_varr')