In [1]:
import torch
from tqdm import tqdm
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
import SimpleITK as sitk
import os
from lazy_imports import itkwidgets
from lazy_imports import itkview
from lazy_imports import interactive
from lazy_imports import ipywidgets
from lazy_imports import pv

In [2]:
from mtch.RegistrationFunc3DCuda import *
from mtch.SplitEbinMetric3DCuda import *
from mtch.GeoPlot import *

In [3]:
# from Packages.disp.vis import show_2d, show_2d_tensors
from disp.vis import vis_tensors, vis_path, disp_scalar_to_file
from disp.vis import disp_vector_to_file, disp_tensor_to_file
from disp.vis import disp_gradG_to_file, disp_gradA_to_file
from disp.vis import view_3d_tensors, tensors_to_mesh

In [4]:
import algo.metricModSolver2d as mms
import algo.geodesic as geo
import algo.euler as euler
import algo.dijkstra as dijkstra

In [5]:
# Pick the compute device. `torch.cuda.set_device` raises without CUDA, so only
# select a GPU (and make CUDA doubles the default tensor type) when CUDA exists;
# otherwise fall back to CPU doubles.
# After changing CUDA_DEVICE_INDEX you need to restart the kernel.
CUDA_DEVICE_INDEX = 1
if torch.cuda.is_available():
    device = torch.device('cuda')
    torch.cuda.set_device(CUDA_DEVICE_INDEX)
    torch.set_default_tensor_type('torch.cuda.DoubleTensor')
else:
    device = torch.device('cpu')
    torch.set_default_tensor_type('torch.DoubleTensor')

## Data I/O convention

### Read
Shape of input_tensor.nhdr is `[d, w, h, 6]`, and Shape of input_mask.nhdr is `[d, w, h]`
```
input_tensor = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(path)),(3,2,1,0))
input_mask = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(path)),(2,1,0))
```
input_tensor.shape is `[6, h, w, d]`, and input_mask.shape is `[h, w, d]`
### Write
output_tensor.shape is `[3, h, w, d]`, and output_mask.shape is `[h, w, d]`
```
sitk.WriteImage(sitk.GetImageFromArray(np.transpose(output_tensor,(3,2,1,0))), path)
sitk.WriteImage(sitk.GetImageFromArray(np.transpose(output_mask,(2,1,0))), path)
```
Shape of output_tensor.nhdr is `[d, w, h, 3]`, and Shape of output_mask.nhdr is `[d, w, h]`

### Note
`sitk.WriteImage(sitk.GetImageFromArray())` and `sitk.GetArrayFromImage(sitk.ReadImage(path))` are inverse operations, so a round trip preserves the array's dimension order, as the example below shows.
```
output_tensor = np.zeros((12,34,56,78))
sitk.WriteImage(sitk.GetImageFromArray(output_tensor), path)
input_tensor = sitk.GetArrayFromImage(sitk.ReadImage(path))
print(input_tensor.shape)
# (12, 34, 56, 78)
```

## Data dim convention

Make sure you follow the conventions below to make the algorithm consistent.
- Tensor fields: All tensor field variables are by default of size `[h, w, d, 3, 3]`, with the last two dimensions indexing the metric matrix, to comply with PyTorch's batched-matrix convention. In my code, the arguments and outputs of all functions meet this requirement.
- Compressed tensor fields: `atlas_lin` is always of size `[6, h, w, d]`, when it comes to the argument of `view_3d_tensors()`, using `np.transpose(atlas_lin,(1,2,3,0))` to satisfy the requirement temporarily.
- Diffeomorphisms: All the diffeo variables by default are of size `[3, h, w, d]`.
- Masks: All the mask variables by default are of size `[h, w, d]`, when it comes to `torch.einsum()`, you can use `.unsqueeze(0)` temporarily.

## Data plotting convention

To avoid `x`/`y` ambiguity in indexing and plotting, the best choice is to name the first two dimensions of `[h, w, 2, 2]` `x` and `y`, in that order.
- When indexing the array, `x` indexes row and `y` indexes column, the way I typically do and the way how matplotlib plot the 2d image. 
- When plotting the tensors, matplotlib rotates the array counterclockwise by 90 degrees, so the vertical axis is `y` and the horizontal axis is `x`, which is consistent with how we draw the Cartesian coordinate system. Fortunately, Kris' code already does it this way, e.g. ellipse(x, y).


## Algorithm caveat
- In the energy calculation, only use the binary mask provided by Kris, not a weighted map. A weighted map would alter the alpha field previously applied to the tensor field and make the geodesics go astray.
- Both metric matching and mean calculating should be implemented on the inverse of the original DTI tensor field, since the geodesics are running on the inverse of the tensor field.
- When accumulating the diffeomorphisms, remember that phi and its inverse are accumulated in opposite orders:
```
phi_acc = compose_function(phi, phi_acc)
phi_inv_acc = compose_function(phi_inv_acc, phi_inv)
```
- When an error like the one below is raised, it is probably caused by too large an epsilon, which means the composed tensor field is no longer positive definite everywhere.
```
cholesky_cpu: For batch 0: U(1,1) is zero, singular U.
```
- `a` in `Squared_distance_Ebin(g0, g1, a, mask)`, `get_karcher_mean(G, a)`, `get_geo(g0, g1, a, Tpts)`, `inv_RieExp_extended(g0, g1, a)`, `Rie_Exp_extended(g0, u, a)`, `Rie_Exp(g0, u, a)`, `inv_RieExp(g0, g1, a)` equals the reciprocal of the dimension, `1/dim`, where `dim` is the last entry of the tensor field's shape.

- When an out-of-range error is raised, check that all tensors have the right dimension order.
- Pay attention to the indices when assigning `atlas` to `atlas_lin`. To the best of my knowledge, no bugs have been found in `SplitEbinMetric.py`.

In [6]:
def phi_pullback(phi, g):
    """Pull a field of 3x3 matrices back through the deformation ``phi``.

    At every voxel this evaluates ``J^T (g o phi) J``. Here ``J`` is the
    Jacobian of the displacement ``phi - identity`` plus the 3x3 identity.
    ``g o phi`` is ``compose_function(g, phi)``, presumably ``g`` resampled at
    ``phi``.

    Args:
        phi: deformation of shape [3, h, w, d].
        g: matrix field of shape [h, w, d, 3, 3].

    Returns:
        Pulled-back matrix field of shape [h, w, d, 3, 3].
    """
    g_chan_first = g.permute(3, 4, 0, 1, 2)  # -> [3, 3, h, w, d]
    grid_shape = g_chan_first.shape[-3:]
    identity_map = get_idty(*grid_shape)
    # Broadcast the 3x3 identity over the grid so it can be added to the displacement Jacobian.
    eye_field = torch.einsum("ij,mno->ijmno", torch.eye(3), torch.ones(*grid_shape))
    jacobian = get_jacobian_matrix(phi - identity_map) + eye_field
    g_at_phi = compose_function(g_chan_first, phi)
    # out[..., j, l] = sum_{i,k} J[i, j] * g[i, k] * J[k, l]  ==  (J^T g J)[j, l]
    return torch.einsum("ij...,ik...,kl...->...jl", jacobian, g_at_phi, jacobian)


def energy_ebin(phi, g0, g1, f0, f1, sigma, dim, mask):
    """Ebin-metric matching energy between (g0, f0) and the pull-backs of (g1, f1) by ``phi``.

    Computes ``sigma * d(f0, phi^* f1) + d(g0, phi^* g1)``. Here ``d`` is
    ``Squared_distance_Ebin`` with ``a = 1/dim``.

    Args:
        phi: deformation of shape [3, h, w, d]. ``metric_matching`` passes the
            identity grid with ``requires_grad`` set, so the energy can be
            differentiated with respect to it.
        g0, g1, f0, f1: matrix fields of shape [h, w, d, 3, 3].
        sigma: weight of the f-term.
        dim: dimension; ``1/dim`` is passed as ``a``.
        mask: forwarded unchanged to ``Squared_distance_Ebin``. The callers in
            this notebook pass an [h, w, d] mask.

    Returns:
        Scalar energy.
    """
    a = 1. / dim
    pulled_g1 = phi_pullback(phi, g1)
    # NOTE: the compose step of this pull-back was observed to use a couple of GB of memory.
    pulled_f1 = phi_pullback(phi, f1)
    f_term = sigma * Squared_distance_Ebin(f0, pulled_f1, a, mask)
    g_term = Squared_distance_Ebin(g0, pulled_g1, a, mask)
    return f_term + g_term


def energy_L2(phi, g0, g1, f0, f1, sigma, mask):
    """Masked squared-L2 analogue of ``energy_ebin``.

    Returns ``sigma * sum(mask * (f0 - phi^* f1)**2) + sum(mask * (g0 - phi^* g1)**2)``.
    The sums run over every voxel and every matrix entry.

    Args:
        phi: deformation of shape [3, h, w, d].
        g0, g1, f0, f1: matrix fields of shape [h, w, d, 3, 3].
        sigma: weight of the f-term.
        mask: shape [h, w, d]. It is unsqueezed to [1, h, w, d] to match the
            four-label ``lijk`` einsum operand.

    Returns:
        Scalar energy.
    """
    mask_4d = mask.unsqueeze(0)
    pulled_g1 = phi_pullback(phi, g1)
    pulled_f1 = phi_pullback(phi, f1)
    f_term = sigma * torch.einsum("ijk...,lijk->", (f0 - pulled_f1) ** 2, mask_4d)
    g_term = torch.einsum("ijk...,lijk->", (g0 - pulled_g1) ** 2, mask_4d)
    return f_term + g_term


def laplace_inverse(u):
    '''
    Apply the inverse of the periodic discrete Laplace operator to each component of ``u``.

    The operator is ``L v = 6 v - (sum of the 6 axis neighbours of v)``, i.e.
    minus the 7-point discrete Laplacian with periodic boundaries. It is
    diagonalised by the FFT with symbol ``6 - 2 (cos kx + cos ky + cos kz)``.
    That symbol vanishes only at the DC frequency (0, 0, 0). There the
    multiplier is set to 1, so the mean of each component passes through
    unchanged.

    Args:
        u: tensor of shape [C, h, w, d] (C = 3 for a vector field). It may
            require grad; it is detached.

    Returns:
        float64 tensor of the same shape as ``u``, on ``u``'s device.
    '''
    size_h, size_w, size_d = u.shape[-3:]
    # Integer frequency indices along each axis, broadcast to [h, w, d].
    kx = np.arange(size_h).reshape(-1, 1, 1)
    ky = np.arange(size_w).reshape(1, -1, 1)
    kz = np.arange(size_d).reshape(1, 1, -1)
    lap = 6. - 2. * (np.cos(2. * np.pi * kx / size_h) +
                     np.cos(2. * np.pi * ky / size_w) +
                     np.cos(2. * np.pi * kz / size_d))
    # Only the DC entry is zero. Index all three axes: `lap[0, 0]` would also
    # overwrite every (0, 0, k) frequency and leave that whole line unsmoothed.
    lap[0, 0, 0] = 1.
    lapinv = 1. / lap

    u_np = u.detach().cpu().numpy()
    spatial_axes = (-3, -2, -1)
    v = np.real(np.fft.ifftn(np.fft.fftn(u_np, axes=spatial_axes) * lapinv, axes=spatial_axes))
    return torch.from_numpy(np.ascontiguousarray(v)).to(device=u.device)

        
def _clamp_into_grid(field, height, width, depth):
    """Clamp each coordinate channel of a [3, h, w, d] map into [0, size - 1], in place."""
    for axis, size in enumerate((height, width, depth)):
        coord = field[axis]
        coord[coord > size - 1] = size - 1
    field[field < 0] = 0
    return field


def metric_matching(gi, gm, height, width, depth, mask, iter_num, epsilon, sigma, dim):
    """Deform the metric field ``gi`` towards ``gm`` by gradient descent on ``energy_ebin``.

    Each iteration does the following:
      1. Pull ``gi`` (and an identity auxiliary field) back by the current ``phi_inv``.
      2. Differentiate the energy with respect to an identity grid.
      3. Smooth the gradient with ``laplace_inverse``.
      4. Take a step of size ``epsilon``. Both the forward step ``psi`` and
         the backward step ``psi_inv`` are clamped to the voxel grid.
      5. Fold the step into ``phi`` (on the left) and ``phi_inv`` (on the right).

    Args:
        gi: moving field [h, w, d, dim, dim].
        gm: target field [h, w, d, dim, dim].
        height, width, depth: grid size.
        mask: forwarded to ``energy_ebin``.
        iter_num: number of descent steps.
        epsilon: step size.
        sigma: weight of the auxiliary (identity) term.
        dim: matrix dimension.

    Returns:
        ``(gi pulled back by phi_inv, phi, phi_inv)``.
    """
    phi_inv = get_idty(height, width, depth)
    phi = get_idty(height, width, depth)
    idty = get_idty(height, width, depth)
    idty.requires_grad_()
    n = int(dim)
    f0 = torch.eye(n).repeat(height, width, depth, 1, 1)
    f1 = torch.eye(n).repeat(height, width, depth, 1, 1)

    for _ in range(iter_num):
        g_moved = phi_pullback(phi_inv, gi)
        f_moved = phi_pullback(phi_inv, f0)
        energy = energy_ebin(idty, g_moved, gm, f_moved, f1, sigma, dim, mask)
        energy.backward()
        # Smooth the raw gradient by applying the inverse Laplacian.
        velocity = -laplace_inverse(idty.grad)
        with torch.no_grad():
            psi = _clamp_into_grid(idty + epsilon * velocity, height, width, depth)
            psi_inv = _clamp_into_grid(idty - epsilon * velocity, height, width, depth)
            phi = compose_function(psi, phi)
            phi_inv = compose_function(phi_inv, psi_inv)
            idty.grad.data.zero_()

    return phi_pullback(phi_inv, gi), phi, phi_inv


def tensor_cleaning(g, scale_factor, det_threshold=10.):
    """Rescale abnormal tensors by their determinant.

    Wherever ``det(g) > det_threshold`` the tensor is replaced by ``g / det(g)``.
    Every other voxel is returned unchanged. This includes voxels with a zero
    determinant, which the previous ``mask / det`` formulation turned into
    ``0/0 = NaN``.

    Args:
        g: tensor field of shape [..., 3, 3].
        scale_factor: unused by this code path. It is kept for backward
            compatibility; an earlier variant replaced abnormal voxels by
            ``scale_factor * I``.
        det_threshold: determinant above which a voxel counts as abnormal.

    Returns:
        Tensor field with the same shape as ``g``.
    """
    det = torch.det(g)
    abnormal = det > det_threshold
    # Divide only where abnormal; use 1 elsewhere so no 0/0 can occur.
    safe_det = torch.where(abnormal, det, torch.ones_like(det))
    return g / safe_det[..., None, None]

    
def fractional_anisotropy(g):
    """Fractional anisotropy of a field of symmetric 3x3 tensors.

    ``FA = sqrt(3 * sum_i (l_i - mean)^2) / sqrt(2 * sum_i l_i^2)``, where
    ``l_i`` are the eigenvalues. Voxels whose tensor is all zero (denominator
    0) get FA = 0 instead of NaN.

    Args:
        g: symmetric tensors of shape [..., 3, 3] (e.g. [h, w, d, 3, 3]).

    Returns:
        FA of shape ``g.shape[:-2]``.
    """
    # `torch.symeig` is deprecated/removed; eigvalsh returns ascending eigenvalues of symmetric input.
    e = torch.linalg.eigvalsh(g)
    mean = e.mean(dim=-1, keepdim=True)
    num = torch.sqrt(3. * ((e - mean) ** 2).sum(dim=-1))
    den = torch.sqrt(2. * (e ** 2).sum(dim=-1))
    nonzero = den > 0
    safe_den = torch.where(nonzero, den, torch.ones_like(den))
    return torch.where(nonzero, num / safe_den, torch.zeros_like(num))

## Data organization

In [7]:
# Iteration window for the atlas-building loop (set start_iter > 0 to resume).
start_iter, iter_num = 0, 800
print(f'Starting from iteration {start_iter} to iteration {iter_num+start_iter}')

Starting from iteration 0 to iteration 800


In [8]:
# Load each subject's 6-component DTI tensors and mask, then do the following:
#   - build the union mask;
#   - unpack each subject's tensors into a full 3x3 field and compute its FA;
#   - invert the field (matching runs on inverted tensors) and down-weight
#     voxels with a huge determinant;
#   - initialise the accumulated diffeomorphisms, either from the identity or
#     from a checkpoint written by the build loop.
torch.set_default_tensor_type('torch.DoubleTensor')
file_name = [108222, 102715, 105923]
input_dir = '/usr/sci/projects/HCP/Kris/NSFCRCNS/TestResults/UKF_experiments'
output_dir = 'BrainAtlasUkf1Cuda'
os.makedirs(output_dir, exist_ok=True)
height, width, depth = 145, 174, 145
sample_num = len(file_name)
tensor_lin_list, tensor_met_list, mask_list, mask_thresh_list, fa_list = [], [], [], [], []
mask_union = torch.zeros(height, width, depth).double().to(device)
phi_inv_acc_list, phi_acc_list, energy_list = [], [], []
resume = False
# Checkpoint iteration to resume from. The build loop writes
# `{output_dir}/{id}_{i}_phi[_inv].mat` only when i % 25 == 0, so pick
# start_iter such that start_iter - 1 is a saved iteration.
resume_iter = start_iter - 1

# LIN_INDEX[r][c] = index of matrix entry (r, c) among the 6 stored components
# [xx, xy, xz, yy, yz, zz]; the matrix is symmetric.
LIN_INDEX = ((0, 1, 2), (1, 3, 4), (2, 4, 5))

for s, subject in enumerate(file_name):
    # read tensor and mask files
    tensor_np = sitk.GetArrayFromImage(sitk.ReadImage(f'{input_dir}/{subject}_scaled_tensors.nhdr'))
    mask_np = sitk.GetArrayFromImage(sitk.ReadImage(f'{input_dir}/{subject}_filt_mask.nhdr'))
    tensor_lin_list.append(torch.from_numpy(tensor_np).double().permute(3, 2, 1, 0).to(device))
    # create union of masks
    mask_hwd = torch.from_numpy(mask_np).double().permute(2, 1, 0)
    mask_union += mask_hwd.to(device)
    mask_list.append(mask_hwd)
    # Unpack tensor_lin [6, h, w, d] into a full symmetric field [h, w, d, 3, 3].
    # (`tensor_met_zeros` is also read as a global by `tensor_cleaning`.)
    tensor_met_zeros = torch.zeros(height, width, depth, 3, 3, dtype=torch.float64)
    for r in range(3):
        for c in range(3):
            tensor_met_zeros[:, :, :, r, c] = tensor_lin_list[s][LIN_INDEX[r][c]]
    # balance the background and subject by rescaling
    # tensor_met_zeros = tensor_cleaning(tensor_met_zeros, scale_factor=torch.tensor(1, dtype=torch.float64))
    fa_list.append(fractional_anisotropy(tensor_met_zeros))
    tensor_met_list.append(torch.inverse(tensor_met_zeros))
    # Scale voxels whose inverted tensor has det > 1e2 by 1e-3.
    fore_back_adaptor = torch.where(torch.det(tensor_met_list[s]) > 1e2, 1e-3, 1.)
    mask_thresh_list.append(fore_back_adaptor)
    tensor_met_list[s] = torch.einsum('ijk...,lijk->ijk...', tensor_met_list[s], mask_thresh_list[s].unsqueeze(0))
    # initialize the accumulative diffeomorphism
    if not resume:
        print('start from identity')
        phi_inv_acc_list.append(get_idty(height, width, depth))
        phi_acc_list.append(get_idty(height, width, depth))
    else:
        print('start from checkpoint')
        # File names must match what the build loop's checkpoint branch writes (no `brain` prefix).
        phi_inv_acc_list.append(torch.from_numpy(sio.loadmat(f'{output_dir}/{subject}_{resume_iter}_phi_inv.mat')['diffeo']))
        phi_acc_list.append(torch.from_numpy(sio.loadmat(f'{output_dir}/{subject}_{resume_iter}_phi.mat')['diffeo']))
        tensor_met_list[s] = phi_pullback(phi_inv_acc_list[s], tensor_met_list[s])
    energy_list.append([])

mask_union[mask_union > 0] = 1

RuntimeError: Exception thrown in SimpleITK ImageFileReader_Execute: /tmp/SimpleITK/Code/IO/src/sitkImageReaderBase.cxx:97:
sitk::ERROR: The file "/usr/sci/projects/HCP/Kris/NSFCRCNS/TestResults/UKF_experiments/108222_scaled_tensors.nhdr" does not exist.

## Building process

In [11]:
# Atlas building. Each outer iteration does the following:
#   - recompute the Karcher mean of the current (inverted) metric fields;
#   - move every subject one matching step towards that mean;
#   - fold the step into the subject's accumulated diffeomorphisms.
dim, sigma, epsilon, iter_num_matching = 3., 0, 5e-3, 1  # epsilon = 3e-3 for orig tensor
CHECKPOINT_EVERY = 25
# (row, col) entries of a symmetric 3x3 tensor, in the order stored in `atlas_lin`.
LIN_ENTRIES = ((0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2))


def _metric_to_lin(met):
    """Pack an [h, w, d, 3, 3] tensor field into a float64 numpy array of shape [6, h, w, d]."""
    met_np = met.detach().cpu().numpy()
    return np.stack([met_np[..., r, c] for r, c in LIN_ENTRIES]).astype(np.float64)


for i in tqdm(range(start_iter, start_iter + iter_num)):
    G = torch.stack(tuple(tensor_met_list))
    atlas = get_karcher_mean(G, 1. / dim)

    phi_inv_list, phi_list = [], []
    for s in range(sample_num):
        # Masked squared L2 distance to the current atlas, recorded before this step.
        energy_list[s].append(torch.einsum("ijk...,lijk->", [(tensor_met_list[s] - atlas) ** 2, mask_union.unsqueeze(0)]).item())
        tensor_met_list[s], phi, phi_inv = metric_matching(tensor_met_list[s], atlas, height, width, depth, mask_union, iter_num_matching, epsilon, sigma, dim)
        phi_inv_list.append(phi_inv)
        phi_list.append(phi)
        # The forward step is composed on the left and the inverse step on the
        # right, so the two accumulators are built in opposite orders.
        phi_inv_acc_list[s] = compose_function(phi_inv_acc_list[s], phi_inv)
        phi_acc_list[s] = compose_function(phi, phi_acc_list[s])
        mask_list[s] = compose_function(mask_list[s], phi_inv)

    # Checkpoint. Matching runs on inverted tensors, so invert the atlas back
    # before writing it in the input (DTI) convention.
    if i % CHECKPOINT_EVERY == 0:
        atlas_lin = _metric_to_lin(torch.inverse(atlas))
        mask_acc = np.zeros((height, width, depth))
        for s in range(sample_num):
            sio.savemat(f'{output_dir}/{file_name[s]}_{i}_phi_inv.mat', {'diffeo': phi_inv_acc_list[s].cpu().detach().numpy()})
            sio.savemat(f'{output_dir}/{file_name[s]}_{i}_phi.mat', {'diffeo': phi_acc_list[s].cpu().detach().numpy()})
            sio.savemat(f'{output_dir}/{file_name[s]}_{i}_energy.mat', {'energy': energy_list[s]})
            mask_acc += mask_list[s].cpu().numpy()
        mask_acc[mask_acc > 0] = 1
        sitk.WriteImage(sitk.GetImageFromArray(np.transpose(atlas_lin, (3, 2, 1, 0))), f'{output_dir}/atlas_{i}_tens.nhdr')
        # mask_union may live on the GPU; numpy needs a host copy.
        sitk.WriteImage(sitk.GetImageFromArray(np.transpose(mask_union.cpu().numpy(), (2, 1, 0))), f'{output_dir}/atlas_{i}_mask.nhdr')

  0%|          | 0/800 [00:02<?, ?it/s]


RuntimeError: inverse_cpu: For batch 0: U(1,1) is zero, singular U.

## Save Result

In [None]:
%matplotlib widget
atlas_lin = np.zeros((6,height,width,depth))
mask_acc = np.zeros((height,width,depth))

for s in range(sample_num):
    sio.savemat(f'{output_dir}/brain{file_name[s]}_phi_inv.mat', {'diffeo': phi_inv_acc_list[s].cpu().detach().numpy()})
    sio.savemat(f'{output_dir}/brain{file_name[s]}_phi.mat', {'diffeo': phi_acc_list[s].cpu().detach().numpy()})
    sio.savemat(f'{output_dir}/brain{file_name[s]}_energy.mat', {'energy': energy_list[s]})
    
    plt.plot(energy_list[s])
    mask_acc += mask_list[s].cpu().numpy()

atlas = torch.inverse(atlas)
atlas_lin[0] = atlas[:,:,:,0,0].cpu()
atlas_lin[1] = atlas[:,:,:,0,1].cpu()
atlas_lin[2] = atlas[:,:,:,0,2].cpu()
atlas_lin[3] = atlas[:,:,:,1,1].cpu()
atlas_lin[4] = atlas[:,:,:,1,2].cpu()
atlas_lin[5] = atlas[:,:,:,2,2].cpu()
mask_acc[mask_acc>0]=1
sitk.WriteImage(sitk.GetImageFromArray(np.transpose(atlas_lin,(3,2,1,0))), f'{output_dir}/atlas_tens.nhdr')
sitk.WriteImage(sitk.GetImageFromArray(np.transpose(mask_union,(2,1,0))), f'{output_dir}/atlas_mask.nhdr')

## Visualization

### tensor field

In [None]:
# Reload the saved atlas, converting from file order back to the in-memory convention.
mask_image = sitk.ReadImage(f'{output_dir}/atlas_mask.nhdr')
mask_acc = np.transpose(sitk.GetArrayFromImage(mask_image), (2, 1, 0))      # [h, w, d]
tens_image = sitk.ReadImage(f'{output_dir}/atlas_tens.nhdr')
atlas_lin = np.transpose(sitk.GetArrayFromImage(tens_image), (3, 2, 1, 0))  # [6, h, w, d]

### diffeomorphism

In [10]:
def view_3d_diffeos(diffeo, stride, interp):
    """Render a 3-D deformation as a lattice of deformed grid lines.

    Every ``stride``-th grid line along each axis, starting at index 1, is
    taken through ``diffeo``. Its deformed coordinates become a ``pv.Spline``;
    the second argument, ``interp``, is presumably the number of spline
    points. The splines are returned in an ``itkview`` viewer.

    Args:
        diffeo: array of shape [3, h, w, d].
        stride: spacing between drawn grid lines.
        interp: forwarded as the second argument of ``pv.Spline``.
    """
    height, width, depth = diffeo.shape[1:]
    # Lines along the depth axis first, then along width, then along height.
    splines = [pv.Spline(np.transpose(diffeo[:, i, j, :]), interp)
               for i in range(1, height, stride)
               for j in range(1, width, stride)]
    splines += [pv.Spline(np.transpose(diffeo[:, i, :, k]), interp)
                for i in range(1, height, stride)
                for k in range(1, depth, stride)]
    splines += [pv.Spline(np.transpose(diffeo[:, :, j, k]), interp)
                for j in range(1, width, stride)
                for k in range(1, depth, stride)]
    return itkview(geometries=splines)

In [11]:
# Load one subject's accumulated inverse diffeomorphism and render it as a deformed grid.
# diffeo = sio.loadmat(f'{output_dir}/105923_799_phi_inv.mat')['diffeo']
# NOTE(review): absolute path to another user's run; adjust for your machine.
diffeo_path = '/home/sci/hdai/Projects/Atlas3D/output/BrainAtlasUkfBallMetSept11/105923_800_phi_inv.mat'
diffeo = sio.loadmat(diffeo_path)['diffeo']
vwr = view_3d_diffeos(diffeo, stride=5, interp=1000)
vwr

INFO:numexpr.utils:Note: NumExpr detected 32 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO:numexpr.utils:NumExpr defaulting to 8 threads.


Viewer(geometries=[{'vtkClass': 'vtkPolyData', 'points': {'vtkClass': 'vtkPoints', 'name': '_points', 'numberO…

### path

- `geo.geodesicpath`'s input tensor field should be the original DTI tensor field, not the inverted one. So make sure to save the DTI-like result rather than the inverted atlas.
- If you want to visualize the path with itkview, it's recommended to **set the `both_directions` argument to `False`**. itkview connects all the points in order, unlike the dense point plotting used in the 2D case. `geo.geodesicpath_3d` returns both directions concatenated without reversing either of them, so the returned list is not actually in path order.
    - Likewise, a large `geo_iters` is better than a small one, because of how `geo.geodesicpath_3d` works. If the path does not leave the mask region, the last element of the returned list is `[0,0,0]`, which produces a twisted path after calling `view_3d_tensors` because of how `pyvista.Spline` behaves. If the path does leave the mask region, the algorithm stops early and the last element is not `[0,0,0]`.
- At `[13, 14, 21]`, the four cubics and the atlas approximately overlap; at other positions this is not guaranteed.
- When calling `geo.geodesicpath_3d`, **make sure the mask you pass as the `mask_image` (second) argument fits the object as tightly as possible**. `mask_union` covers unnecessary area and makes `util.diff.gradient_mask_3d` and `util.maskops.determine_boundary_3d` run extremely slowly. For visualization, however, using `mask_union` is acceptable.
- To distinguish the paths, use the drop-down menu in the bottom-left corner of the interactive window: switch to the geometry you want to change and choose a color.

#### geodesic on atlas

In [None]:
# Seed voxel(s) and optional initial velocities for geodesic tracing.
start_coords = [[13, 14, 21]]
init_velocities = [None] * len(start_coords)
path_set = []

In [None]:
# Atlas tensor glyphs only (no paths yet); view_3d_tensors takes the tensors as [h, w, d, 6].
vwr = view_3d_tensors(
    np.transpose(atlas_lin, (1, 2, 3, 0)),
    mask_acc,
    atlas_lin[3],
    paths=[],
    stride=6,
    scale=6,
)
vwr

In [None]:
# Trace one geodesic on the atlas from start_coords[0] and overlay it on the tensor glyphs.
geo_delta_t, geo_iters = -0.1, 1200
euler_delta_t, euler_iters = 0.1, 8000  # only for the (disabled) euler.eulerpath_3d comparison

# geodesicpath_3d([6,h,w,d],[h,w,d],...)
# NOTE(review): the notes above recommend a tight mask here; mask_union is slow.
geox, geoy, geoz = geo.geodesicpath_3d(
    atlas_lin, mask_union,
    start_coords[0], init_velocities[0],
    geo_delta_t, iter_num=geo_iters, both_directions=False)
# Drop the last point: per the notes above it can be a [0, 0, 0] sentinel.
path_set = [(geox[:-1], geoy[:-1], geoz[:-1])]

# view_3d_tensors([h,w,d,6],[h,w,d],...)
vwr = view_3d_tensors(
    np.transpose(atlas_lin, (1, 2, 3, 0)), mask_acc, atlas_lin[3],
    paths=path_set, stride=6, scale=6)
vwr

#### geodesics on atlas and cubics

In [None]:
# Trace a geodesic from the same seed on every subject's own tensor field and
# append the paths (in each subject's native space) to path_set.
geo_delta_t = 0.1
geo_iters = 1200

# `file_name` holds subject IDs, so iterate them directly; indexing the list
# with an ID (`file_name[s]`) raises IndexError.
# NOTE(review): this path layout (`{id}/scaled_tensors.nhdr`) differs from the
# loading cell (`{id}_scaled_tensors.nhdr`); confirm which one exists.
for subject in file_name:
    tensor_lin = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(f'{input_dir}/{subject}/scaled_tensors.nhdr')), (3, 2, 1, 0))
    mask = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(f'{input_dir}/{subject}/filt_mask.nhdr')), (2, 1, 0))
    geox, geoy, geoz = geo.geodesicpath_3d(tensor_lin, mask,
                                           start_coords[0], init_velocities[0],
                                           geo_delta_t, iter_num=geo_iters, both_directions=False)
    path_set.append((geox[:-1], geoy[:-1], geoz[:-1]))

In [None]:
# Atlas glyphs with the atlas and per-subject geodesics overlaid.
vwr = view_3d_tensors(
    np.transpose(atlas_lin, (1, 2, 3, 0)),
    mask_union,
    atlas_lin[3],
    paths=path_set,
    stride=6,
    scale=6,
)
vwr

#### geodesics on cubics pushforwarded to the atlas space

In [None]:
import math


def coord_register(point_x, point_y, point_z, diffeo):
    '''
    Map points through a deformation by trilinear interpolation of its channels.

    For each point (point_x[i], point_y[i], point_z[i]), channels 0, 1 and 2
    of ``diffeo`` are trilinearly interpolated at that location. Neighbouring
    grid indices wrap around periodically (modulo the grid size), so a point
    past the last index blends with index 0.

    Args:
        point_x, point_y, point_z: coordinate sequences along the first,
            second and third grid axes; indexed up to len(point_x).
        diffeo: indexable as diffeo[c, i, j, k] with spatial size
            diffeo.shape[-3:]; only channels 0-2 are read.

    Returns:
        Tuple of three lists (x, y, z) of mapped coordinates.
    '''
    # TODO work out which is y and which is x, maintain consistency.
    # For now, pass in y for point_x, x for point_y
    height, width, depth = diffeo.shape[-3:]
    registered = ([], [], [])
    for idx in range(len(point_x)):
        px, py, pz = point_x[idx], point_y[idx], point_z[idx]
        # Fractional offsets inside the enclosing cell.
        cx = px - math.floor(px)
        cy = py - math.floor(py)
        cz = pz - math.floor(pz)
        # Lower/upper neighbour indices with periodic wrap-around.
        x0, x1 = math.floor(px) % height, math.ceil(px) % height
        y0, y1 = math.floor(py) % width, math.ceil(py) % width
        z0, z1 = math.floor(pz) % depth, math.ceil(pz) % depth
        for channel in range(3):
            comp = diffeo[channel]
            registered[channel].append(
                (1.-cx)*(1.-cy)*(1.-cz)*comp[x0, y0, z0]
                + (1.-cx)*cy*(1.-cz)*comp[x0, y1, z0]
                + cx*(1.-cy)*(1.-cz)*comp[x1, y0, z0]
                + cx*cy*(1.-cz)*comp[x1, y1, z0]
                + (1.-cx)*(1.-cy)*cz*comp[x0, y0, z1]
                + (1.-cx)*cy*cz*comp[x0, y1, z1]
                + cx*(1.-cy)*cz*comp[x1, y0, z1]
                + cx*cy*cz*comp[x1, y1, z1])

    return registered

In [None]:
# Trace each subject's geodesic in its native space, then map the path into
# atlas space with that subject's accumulated forward diffeomorphism.
path_set = []
start_coords = [[13, 14, 21]]  # golden test start point
init_velocities = [None]
geo_delta_t = 0.1
geo_iters = 1300

# `file_name` holds subject IDs, so iterate them directly; indexing the list
# with an ID (`file_name[s]`) raises IndexError.
# NOTE(review): the build loop checkpoints only when i % 25 == 0, so confirm
# that the `_799_` files exist (e.g. from an earlier run) before running this.
for subject in file_name:
    tensor_lin = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(f'{input_dir}/{subject}/scaled_tensors.nhdr')), (3, 2, 1, 0))
    mask = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(f'{input_dir}/{subject}/filt_mask.nhdr')), (2, 1, 0))
    diffeo = sio.loadmat(f'{output_dir}/{subject}_799_phi.mat')['diffeo']
    geox, geoy, geoz = geo.geodesicpath_3d(tensor_lin, mask,
                                           start_coords[0], init_velocities[0],
                                           geo_delta_t, iter_num=geo_iters, both_directions=False)
    path_set.append(coord_register(geox[:-1], geoy[:-1], geoz[:-1], diffeo))

In [None]:
# Atlas glyphs with the subjects' geodesics mapped into atlas space.
vwr = view_3d_tensors(
    np.transpose(atlas_lin, (1, 2, 3, 0)),
    mask_acc,
    atlas_lin[3],
    paths=path_set,
    stride=6,
    scale=6,
)
vwr