### Volume to Tetra mapping
This is a Python-based prototype of the volume2tet routine which will map grey matter ribbon properties into the tetrahedral mesh generated by simNIBS. 

This will be accomplished using a Monte Carlo intersectional volume estimation technique. The main idea is to estimate what fraction of each tetrahedron's volume lies inside each parcel of the ribbon volume, which gives a partial parcel membership for each tetrahedron. 

Monte Carlo volumetric sampling relies on uniform sampling from a tetrahedron; this will be done using the parallelepiped folding algorithm. A variance metric will be used to test convergence of the sampling algorithm for each tetrahedron.

#### To Check
1. Version differences between gmsh 3 (which was never released URGHHH) and gmsh 4

In [None]:
import sys
import warnings
# Suppress every warning for the rest of the session.
# NOTE(review): this also hides deprecation warnings from numpy/nibabel/nilearn.
warnings.filterwarnings('ignore')

In [None]:
# Run if on SCC JupyterHub
# Prepend the gmsh SDK and the project's conda site-packages to the module
# search path; the gmsh SDK entry ends up first, followed by site-packages.
sys.path[:0] = [
    '/imaging/home/kimel/jjeyachandra/projects/jjeyachandra/gmsh-sdk/lib/',
    '/imaging/home/kimel/jjeyachandra/.conda/envs/3.6_tetra/lib/python3.6/site-packages/',
]

In [None]:
import os
import gmsh
import numpy as np
import nibabel as nib
import nilearn as nil
from nilearn import plotting, image
from random import uniform
import matplotlib
from timeit import default_timer as timer
%matplotlib inline

In [None]:
def map_vox2coord(coord, affine):
    '''
    Apply an affine to a single voxel coordinate (voxel --> RAS transformation)
    Arguments:
        coord                     1D - numpy array (x,y,z); must hold exactly 3 values
        affine                    Affine transformation matrix (qform), expected to be (4,4)

    Output:
        1D numpy array with the transformed coordinate (the voxel midpoint in
        scanner space when affine is the image's vox --> RAS matrix)
    '''

    #Homogeneous form (x,y,z,1); np.append without an axis returns a flat array
    homog = np.append(coord.reshape(3,1), np.ones((1,1)))

    #Apply the affine, then drop the trailing homogeneous component (value=1)
    return np.matmul(affine,homog)[:-1]
    

In [None]:
# Test whether a point falls strictly inside the cube centred on a voxel midpoint
def point_in_vox(point,midpoint,voxdim=1):
    '''
    Arguments:
        point                         Iterable of length 3
        midpoint                      Voxel midpoint (indexable, length 3)
        voxdim                        Voxel edge length, assuming isotropic voxels

    Output:
        Boolean: True if point lies strictly inside the voxel bounds
                 (a point exactly on a voxel face counts as outside)
    '''

    half_width = voxdim/2

    #Reject as soon as one axis falls on or outside the voxel faces
    for axis in range(3):
        value = point[axis]
        if value <= midpoint[axis] - half_width or value >= midpoint[axis] + half_width:
            return False

    return True

In [None]:
def uniform_tet(coords):
    '''
    Draw one point uniformly at random from a tetrahedron by folding the unit
    cube into the tetrahedron of barycentric coordinates.
    Argument:
        coords                A (4,3) matrix with rows representing nodes
    Output:
        point                 A random point inside the tetrahedral volume
    '''

    #Three independent draws from the unit cube
    s, t, u = uniform(0,1), uniform(0,1), uniform(0,1)

    #First cut: fold the cube into a prism
    if s + t > 1:
        s, t = 1.0 - s, 1.0 - t

    #Second set of cuts: fold the prism into the tetrahedron
    if t + u > 1:
        t, u = 1.0 - u, 1.0 - s - t
    elif s + t + u > 1:
        s, u = 1 - t - u, s + t + u - 1

    #Weight of the first node so that the four barycentric weights sum to 1
    a = 1 - s - t - u

    v0, v1, v2, v3 = coords[0], coords[1], coords[2], coords[3]
    return a*v0 + s*v1 + t*v2 + u*v3

In [None]:
def get_overlap_voxels(coords, img, affine):
    '''
    Take advantage of voxel space ordering to find voxels in axis-aligned bounding box intersection
    Arguments:
        coords                      (4x3) array containing 4 nodes of tetrahedron in 3 dimensional space
        img                         Image array consisting of parcels, indexed by integer voxel coordinates
        affine                      The affine transformation matrix for vox --> RAS
    Output:
        vox_arr                     (4 x n) float array; each column is the homogeneous voxel
                                    index (i,j,k,1) of one voxel inside the bounding box

    NOTE(review): the voxel midpoints (mid_arr) and parcel labels (parcel_list) are
    still computed but are NOT returned; the return that yields them is left as a
    comment at the end of the function. The parcel lookup indexes img directly, so
    a bounding box reaching past the image extent raises IndexError and negative
    indices wrap around.
    '''

    #Compute inverse affine
    inv_aff = np.linalg.inv(affine)

    #Convert into homogenous coordinate system of column vectors (4 dim x 4 nodes)
    coord_set = np.append(coords,np.ones((4,1)),axis=1).transpose()

    #Transform into voxel space
    coord_vox = np.matmul(inv_aff,coord_set)[:-1,:]

    #Get axis-aligned bounding box
    #(builtin int replaces the np.int alias, which was removed in NumPy 1.24)
    min_vox = np.floor(np.min(coord_vox,axis=1)).astype(int)
    max_vox = np.ceil(np.max(coord_vox,axis=1)).astype(int)

    #Get set of voxels
    x_range = np.arange(min_vox[0],max_vox[0]+1)
    y_range = np.arange(min_vox[1],max_vox[1]+1)
    z_range = np.arange(min_vox[2],max_vox[2]+1)
    vox_grid = np.meshgrid(x_range,y_range,z_range)

    #Format into column vectors (3 x n), one voxel index per column
    vox_arr = np.concatenate([grid.reshape((1,grid.size)) for grid in vox_grid],axis=0)

    #Extract parcels
    parcel_list = [img[tuple(col)] for col in vox_arr.T]

    #Convert into homogenous coordinate system, apply affine to get midpoints
    vox_arr = np.append(vox_arr,np.ones((1,vox_arr.shape[1])),axis=0)
    mid_arr = np.matmul(affine,vox_arr)[:-1,:]

    #Intended final return, currently disabled:
    #return mid_arr.astype(int).T,parcel_list
    return vox_arr

In [None]:
# Initialise the gmsh Python API before it is used in the cells below
gmsh.initialize()

In [None]:
# Input files for subject sub-CMH090 (relative paths)
# Mesh file (.msh) from the simnibs output directory
f_tetra = '../data/simnibs_output/sub-CMH090.msh'
# Right / left hemisphere ribbon volumes
f_ribbon_r = '../data/sub-CMH090/ribbon/sub-CMH090_R_ribbon.nii.gz'
f_ribbon_l = '../data/sub-CMH090/ribbon/sub-CMH090_L_ribbon.nii.gz'
# T1 volume from the m2m folder; its affine is used for the vox --> RAS mapping
f_t1 = '../data/simnibs_output/m2m_sub-CMH090/T1fs_nu_conform.nii.gz'

In [None]:
# Load the T1 volume; its affine is later passed to get_overlap_voxels as the vox --> RAS transform
t1_img = image.load_img(f_t1)
affine = t1_img.affine

#Load in ribbon files and merge hemispheres
# The merge is a voxelwise sum of the two images.
# NOTE(review): assumes the two hemisphere ribbons never overlap, otherwise labels would add up -- confirm
r_ribbon_img = image.load_img(f_ribbon_r)
l_ribbon_img = image.load_img(f_ribbon_l)
ribbon_img = image.math_img('a+b',a=r_ribbon_img,b=l_ribbon_img)
#plotting.plot_img(ribbon_img,bg_img=t1_img)

In [None]:
#Now pull tetrahedral coordinate data
# Opens the mesh file so the gmsh.model.mesh queries in the following cells read from it
gmsh.open(f_tetra)

In [None]:
#Tetrahedral volume entity, given as (dimension, tag)
tet_gm = (3,2)
#Node tags and their flat coordinate array for that entity
tet_node_tag, tet_node_coord, tet_node_param = gmsh.model.mesh.getNodes(*tet_gm)
#Elements of that entity; the third value is used later as the per-element node tags
tet_elem_tag, tet_elem_coord, tet_elem_param = gmsh.model.mesh.getElements(*tet_gm)

In [None]:
# Grey matter boundary surface entity, given as (dimension, tag)
surf_gm = (2,2)
# Node tags and flat coordinates of the surface; the first getElements value is discarded
gm_node_tag, gm_node_coord, gm_node_param = gmsh.model.mesh.getNodes(*surf_gm)
_, gm_elem_coord, gm_elem_param = gmsh.model.mesh.getElements(*surf_gm)

In [None]:
# White matter boundary surface entity, given as (dimension, tag)
surf_wm = (2,1)
# Node tags and flat coordinates of the surface; the first getElements value is discarded
wm_node_tag, wm_node_coord, wm_node_param = gmsh.model.mesh.getNodes(*surf_wm)
_, wm_elem_coord, wm_elem_param = gmsh.model.mesh.getElements(*surf_wm)

Mapping volumetric tetrahedrons into T1fs_nu_conform volume space

Here we explore the correspondence between the tetrahedrons derived from the volumetric mesh (which should retain coordinates in surfaceRAS) to the ribbon T1fs_nu_conform volume space. 

In [None]:
# Map node tags to spatial coordinates
def _coords_by_tag(tags, flat_coords):
    '''Pair the k-th tag with the k-th consecutive (x,y,z) triple of a flat coordinate array'''
    return { tag : flat_coords[3*k:(3*k)+3] for k,tag in enumerate(tags) }

tet_coord_map = _coords_by_tag(tet_node_tag, tet_node_coord)
gm_coord_map = _coords_by_tag(gm_node_tag, gm_node_coord)
wm_coord_map = _coords_by_tag(wm_node_tag, wm_node_coord)

# Bring together dictionaries (on duplicate tags the later map wins: tet < gm < wm)
node_2_coord = {**tet_coord_map, **gm_coord_map, **wm_coord_map}

# One row of 4 node tags per tetrahedron
tet_node_list = tet_elem_param[0].reshape((-1, 4))

In [None]:
# Pull the merged ribbon labels out of the image as a numpy array.
# get_data() is deprecated in nibabel (removed in nibabel 5); np.asanyarray(img.dataobj)
# is nibabel's documented replacement that keeps get_data()'s behaviour.
ribbon = np.asanyarray(ribbon_img.dataobj)

#### Method:
For each tetrahedron:
1. Get subset of voxels which contain at least 1 vertex of the tetrahedron
2. Sample a random point using parallelepiped folding method
3. Check which voxel 'owns' the sampled point
4. Loop (2-3) until volume convergence using Monte-Carlo variance criterion (implement later)

In [None]:
def estimate_partial_parcel(coord, vox, parcels, n_iter=300):
    '''
    Estimate, by Monte-Carlo sampling, how much of a tetrahedron falls into each parcel.

    Arguments:
        coord               (4,3) numpy array of tetrahedral node coordinates
        vox                 (n,3) indexable iterable of voxel midpoints
        parcels             (n,) indexable iterable of parcel labels associated with jth voxel coordinate
        n_iter              Number of Monte-Carlo sampling iterations
    Output:
        parcel_dict         Dictionary {parcel label (int) : fraction of the n_iter samples
                            that landed in a voxel carrying that label}

    Samples that land in none of the supplied voxels still count towards n_iter,
    so the returned fractions can sum to less than 1.
    '''

    distinct = set(parcels)

    #Every candidate voxel has the same label: the tetrahedron belongs to it entirely
    if len(distinct) == 1:
        return {int(parcels[0]) : 1}

    #Set up dictionary to store parcel membership score
    hits = {int(p) : 0 for p in distinct}

    #Shift tetrahedron to origin
    trans = coord[0]
    shift_coord = coord - trans

    for _ in np.arange(0,n_iter):
        #Sample once, move the sample back into the original frame
        sample = uniform_tet(shift_coord) + trans

        #Check membership; stop at the first voxel that owns the sample
        #(assumes the candidate voxels do not overlap, so at most one can match)
        for j,v in enumerate(vox):
            if point_in_vox(sample,v):
                hits[int(parcels[j])] += 1
                break

    #Normalize membership
    return {p: count/n_iter for p,count in hits.items()}

In [None]:
# Voxel indices (one array per axis) of every non-zero ribbon voxel
x,y,z = np.nonzero(ribbon != 0)

In [None]:
# Driver loop: intended to estimate the parcel membership of every tetrahedron.
# NOTE(review): this cell is in a debugging state. The first `break` stops the loop
# right after the bounding-box voxels of the first tetrahedron are computed, so
# tet_2_parcel stays empty. Everything after that `break` is unreachable, and it
# refers to `parcels`, which is not assigned anywhere in the visible notebook code.
tet_2_parcel = {}
start = timer()
for i,tet in enumerate(tet_node_list):
    
   
    
    #Get tetrahedral coordinates: (4,3) array built from the node tag --> xyz lookup
    coords = np.array([node_2_coord[n] for n in tet])
    
    #Get voxels and parcel list
    #(get_overlap_voxels currently returns only the homogeneous voxel index array)
    vox2_coords = get_overlap_voxels(coords,ribbon,affine)
    break
    #Estimate voxel ownership of tetrahedron
    partial_dict = estimate_partial_parcel(coords,vox2_coords,parcels)
    
    #Need to store this data
    tet_2_parcel[i] = partial_dict
    
    end = timer()
    
    #Print progress: percentage of tetrahedra processed and seconds since the loop started
    sys.stdout.write('\r {}% done, time elapsed {}'.format(i/len(tet_node_list)*100, end-start))
    break

In [None]:
# Back-of-envelope runtime estimate for the full loop.
# NOTE(review): presumably 0.12 is the measured seconds per tetrahedron; 60*60*48 is the
# number of seconds in 48 hours, so the result is the total time as a multiple of 48 h -- confirm intent
(len(tet_node_list) * 0.12)/(60*60*48)

## Components to Optimize

#### METHOD 
1. Perform affine transformation of tetrahedral coordinates into voxel space
2. Use axis-aligned bounding box
3. Use ribbon image for selecting voxels
4. Will bring you to O(1) complexity for finding associated voxels

EXPLOIT THE IMAGE GEOMETRY

In [None]:
# Old brute-force implementation, kept for reference: O(mn), since every non-zero ribbon voxel is tested against every tetrahedron
# tet_2_parcel = {}
# for i,tet in enumerate(tet_node_list):
    
#     start = timer()
    
#     #Get tetrahedral coordinates 
#     coords = np.array([node_2_coord[n] for n in tet])
    
#     #Get voxels intersecting tetrahedron O(4m) (can average a 2x speedup amortized if skipped on first find)
#     test = [ 
#         i for i,v in enumerate(zip(x,y,z))
#         if (any([point_in_vox(vert,map_vox2coord(np.array(v),affine))for vert in coords])) 
#     ]
      
#     #Make voxel list
#     vox_coords = np.array([map_vox2coord(np.array([x[v],y[v],z[v]]),affine) for v in test],dtype=np.int)
    
#     #Get parcels
#     parcels = [int(ribbon[tuple(v)]) for v in vox_coords]
    
#     #Estimate voxel ownership of tetrahedron
#     partial_dict = estimate_partial_parcel(coords,vox_coords,parcels)
    
#     #Need to store this data
#     tet_2_parcel[i] = partial_dict
    
#     end = timer()
    
#     #Print progress
#     sys.stdout.write('\r {}% done, time elapsed {}'.format(i/len(tet_node_list)*100, end-start))
    
#     break