In [1]:
from pathlib import Path

import torch
from tqdm import tqdm
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
import SimpleITK as sitk

from lazy_imports import itkwidgets
from lazy_imports import itkview
from lazy_imports import interactive
from lazy_imports import ipywidgets
from lazy_imports import pv

In [2]:
from mtch.RegistrationFunc3D import *
from mtch.SplitEbinMetric3D import *
from mtch.GeoPlot import *

In [3]:
# from Packages.disp.vis import show_2d, show_2d_tensors
from disp.vis import vis_tensors, vis_path, disp_scalar_to_file
from disp.vis import disp_vector_to_file, disp_tensor_to_file
from disp.vis import disp_gradG_to_file, disp_gradA_to_file
from disp.vis import view_3d_tensors, tensors_to_mesh

In [4]:
import algo.metricModSolver2d as mms
import algo.geodesic as geo
import algo.euler as euler
import algo.dijkstra as dijkstra

## I/O convention
Because of how Kris' simulated data were generated, please follow the symmetric I/O convention below to make sure files are read and written correctly. The following example is for the 2D case; the 3D case is analogous.
### Read
Shape of input_tensor.nhdr is [w, h, 3], and Shape of input_mask.nhdr is [w, h]
```
input_tensor = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(path)),(2,1,0))
input_mask = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(path)),(1,0))
```
input_tensor.shape is [3, h, w], and input_mask.shape is [h, w]
### Write
output_tensor.shape is [3, h, w], and output_mask.shape is [h, w]
```
sitk.WriteImage(sitk.GetImageFromArray(np.transpose(output_tensor,(2,1,0))), path)
sitk.WriteImage(sitk.GetImageFromArray(np.transpose(output_mask,(1,0))), path)
```
Shape of output_tensor.nhdr is [w, h, 3], and Shape of output_mask.nhdr is [w, h]

### Note
`sitk.WriteImage(sitk.GetImageFromArray())` and `sitk.GetArrayFromImage(sitk.ReadImage(path))` are a pair of inverse operations, so a round trip through them introduces no inconsistency in dimension ordering:
```
output_tensor = np.zeros((12,34,56,78))
sitk.WriteImage(sitk.GetImageFromArray(output_tensor), path)
input_tensor = sitk.GetArrayFromImage(sitk.ReadImage(path))
print(input_tensor.shape)
# (12, 34, 56, 78)
```

## Import data

In [9]:
index0, index1 = 4, 6
# NOTE(review): absolute cluster path -- consider making this configurable.
input_dir = '/usr/sci/projects/HCP/Kris/NSFCRCNS/TestResults/working_3d_python'
output_dir = 'output/Cubic12Geo'


def read_lin_tensors(path):
    """Read a multi-component tensor image as a double tensor with its components first.

    The array returned by sitk.GetArrayFromImage is permuted with (3, 2, 1, 0),
    following the notebook's I/O convention, so index 0 selects the component.
    """
    return torch.from_numpy(sitk.GetArrayFromImage(sitk.ReadImage(path))).double().permute(3, 2, 1, 0)


def lin_to_met(tens_lin):
    """Expand 6 upper-triangular components into full symmetric 3x3 matrices.

    tens_lin: tensor whose first axis holds the components of entries
    (0,0), (0,1), (0,2), (1,1), (1,2), (2,2), in that order.
    Returns a double tensor of shape (*tens_lin.shape[1:], 3, 3).
    """
    # Component index for each (row, col) entry; symmetric, so comp[r][c] == comp[c][r].
    comp = ((0, 1, 2),
            (1, 3, 4),
            (2, 4, 5))
    met = torch.zeros(*tens_lin.shape[1:], 3, 3, dtype=torch.double)
    for r in range(3):
        for c in range(3):
            met[..., r, c] = tens_lin[comp[r][c]]
    return met


g0_lin = read_lin_tensors(f'{input_dir}/cubic{index0}_scaled_tensors.nhdr')
g1_lin = read_lin_tensors(f'{input_dir}/cubic{index1}_scaled_tensors.nhdr')
assert g0_lin.shape == g1_lin.shape, f'grid mismatch: {tuple(g0_lin.shape)} vs {tuple(g1_lin.shape)}'

# Grid size comes from the data instead of being hard-coded (previously 100, 100, 41).
height, width, depth = g0_lin.shape[1:]
g0_met, g1_met = lin_to_met(g0_lin), lin_to_met(g1_lin)

## Calculate Geodesic

In [6]:
# Number of time points; geo_met_list is indexed 0..Tpts-1 by the cells below.
# NOTE(review): the 1/3 argument's meaning is defined by get_geo (star-imported
# from mtch) -- confirm it and consider giving it a named constant.
Tpts = 7
geo_met_list = get_geo(g0_met, g1_met, 1.0 / 3.0, Tpts)

This overload of nonzero is deprecated:
	nonzero()
Consider using one of the following signatures instead:
	nonzero(*, bool as_tuple) (Triggered internally at  /opt/conda/conda-bld/pytorch_1607370141920/work/torch/csrc/utils/python_arg_parser.cpp:882.)


## Save Result

In [None]:
# Pack each geodesic time point back into the 6-component layout and build a
# det > 1 mask; optionally write both to output_dir.
SAVE_GEODESIC = False  # set True to write the files (these writes were commented out before)


def met_to_lin(met):
    """Pack symmetric 3x3 matrices into their 6 upper-triangular components.

    met: tensor of shape (*spatial, 3, 3).
    Returns a tensor of shape (6, *spatial) holding entries
    (0,0), (0,1), (0,2), (1,1), (1,2), (2,2). The dtype of `met` is kept, so
    double-precision geodesic tensors are no longer downcast to float32.
    """
    rows = [0, 0, 0, 1, 1, 2]
    cols = [0, 1, 2, 1, 2, 2]
    return torch.movedim(met.detach()[..., rows, cols], -1, 0)


geo_lin_list = []
geo_mask_list = []

for i in range(Tpts):
    met = geo_met_list[i]
    geo_lin = met_to_lin(met)
    # NOTE(review): det > 1 is a hard-coded threshold -- confirm it is intended.
    # Casting the comparison avoids relying on scalar branches in torch.where.
    geo_mask = (torch.det(met) > 1).long()
    geo_lin_list.append(geo_lin)
    geo_mask_list.append(geo_mask)
    if SAVE_GEODESIC:
        # NOTE(review): the viewer cells below read cubic12_{i}_*.nhdr, not these names.
        Path(output_dir).mkdir(parents=True, exist_ok=True)
        prefix = f'{output_dir}/cubic_{index0}{index1}_geodesic_{i}'
        sitk.WriteImage(sitk.GetImageFromArray(geo_lin.permute(3, 2, 1, 0).numpy()), f'{prefix}_tens.nhdr')
        sitk.WriteImage(sitk.GetImageFromArray(geo_mask.permute(2, 1, 0).numpy()), f'{prefix}_mask.nhdr')

In [16]:
# Seed point(s) for geo.geodesicpath_3d. A None initial velocity is passed through
# as-is; judging by the logged output, the routine then picks a velocity itself.
start_coords = [
    [13, 14, 21],
]
init_velocities = [
    None,
]

# Integration step size and number of iterations for geo.geodesicpath_3d.
geo_delta_t = 0.08
geo_iters = 1300

In [18]:
# Trace a geodesic from start_coords[0] in the tensor field stored as time point 0, then view it.
# NOTE(review): these cubic12_* files are not produced by this notebook (the save
# cell's writes, when enabled, use cubic_46_geodesic_* names) -- confirm their source.
tensor_lin = sitk.GetArrayFromImage(
    sitk.ReadImage(f'{output_dir}/cubic12_{0}_tens.nhdr')
).transpose(3, 2, 1, 0)
mask = sitk.GetArrayFromImage(
    sitk.ReadImage(f'{output_dir}/cubic12_{0}_mask.nhdr')
).transpose(2, 1, 0)
geox, geoy, geoz = geo.geodesicpath_3d(
    tensor_lin,
    mask,
    start_coords[0],
    init_velocities[0],
    geo_delta_t,
    iter_num=geo_iters,
    both_directions=False,
)
# Components move to the last axis for the viewer; the final two path samples are not drawn.
vwr = view_3d_tensors(
    np.moveaxis(tensor_lin, 0, -1),
    mask,
    mask,
    paths=[(geox[:-2], geoy[:-2], geoz[:-2])],
    stride=6,
    scale=6,
)
vwr

Finding geodesic path from [13, 14, 21] with initial velocity [0.32780118 0.94474673 0.        ]
Found 0 voxels where unable to take 1st derivative.
Found 0 reduced accuracy 2nd derivative voxels.
numpts 97
smallest,largest max eigenvalue 3.5280353735091987 12.981618818751645


Viewer(geometries=[{'vtkClass': 'vtkPolyData', 'points': {'vtkClass': 'vtkPoints', 'name': '_points', 'numberO…

In [19]:
# Trace a geodesic from start_coords[0] in the tensor field stored as time point 1, then view it.
# NOTE(review): these cubic12_* files are not produced by this notebook (the save
# cell's writes, when enabled, use cubic_46_geodesic_* names) -- confirm their source.
tensor_lin = sitk.GetArrayFromImage(
    sitk.ReadImage(f'{output_dir}/cubic12_{1}_tens.nhdr')
).transpose(3, 2, 1, 0)
mask = sitk.GetArrayFromImage(
    sitk.ReadImage(f'{output_dir}/cubic12_{1}_mask.nhdr')
).transpose(2, 1, 0)
geox, geoy, geoz = geo.geodesicpath_3d(
    tensor_lin,
    mask,
    start_coords[0],
    init_velocities[0],
    geo_delta_t,
    iter_num=geo_iters,
    both_directions=False,
)
# Components move to the last axis for the viewer; the final two path samples are not drawn.
vwr = view_3d_tensors(
    np.moveaxis(tensor_lin, 0, -1),
    mask,
    mask,
    paths=[(geox[:-2], geoy[:-2], geoz[:-2])],
    stride=6,
    scale=6,
)
vwr

Finding geodesic path from [13, 14, 21] with initial velocity [0.31583873 0.94881289 0.        ]
Found 0 voxels where unable to take 1st derivative.
Found 0 reduced accuracy 2nd derivative voxels.
numpts 135
smallest,largest max eigenvalue 1.2490172296091875 10.3267820041831


Viewer(geometries=[{'vtkClass': 'vtkPolyData', 'points': {'vtkClass': 'vtkPoints', 'name': '_points', 'numberO…

In [20]:
# Trace a geodesic from start_coords[0] in the tensor field stored as time point 2, then view it.
# NOTE(review): these cubic12_* files are not produced by this notebook (the save
# cell's writes, when enabled, use cubic_46_geodesic_* names) -- confirm their source.
tensor_lin = sitk.GetArrayFromImage(
    sitk.ReadImage(f'{output_dir}/cubic12_{2}_tens.nhdr')
).transpose(3, 2, 1, 0)
mask = sitk.GetArrayFromImage(
    sitk.ReadImage(f'{output_dir}/cubic12_{2}_mask.nhdr')
).transpose(2, 1, 0)
geox, geoy, geoz = geo.geodesicpath_3d(
    tensor_lin,
    mask,
    start_coords[0],
    init_velocities[0],
    geo_delta_t,
    iter_num=geo_iters,
    both_directions=False,
)
# Components move to the last axis for the viewer; the final two path samples are not drawn.
vwr = view_3d_tensors(
    np.moveaxis(tensor_lin, 0, -1),
    mask,
    mask,
    paths=[(geox[:-2], geoy[:-2], geoz[:-2])],
    stride=6,
    scale=6,
)
vwr

Finding geodesic path from [13, 14, 21] with initial velocity [0.30258274 0.95312312 0.        ]
Found 0 voxels where unable to take 1st derivative.
Found 0 reduced accuracy 2nd derivative voxels.
numpts 135
smallest,largest max eigenvalue 1.5801098277369667 9.559687055469483


Viewer(geometries=[{'vtkClass': 'vtkPolyData', 'points': {'vtkClass': 'vtkPoints', 'name': '_points', 'numberO…

In [21]:
# Trace a geodesic from start_coords[0] in the tensor field stored as time point 3, then view it.
# NOTE(review): these cubic12_* files are not produced by this notebook (the save
# cell's writes, when enabled, use cubic_46_geodesic_* names) -- confirm their source.
tensor_lin = sitk.GetArrayFromImage(
    sitk.ReadImage(f'{output_dir}/cubic12_{3}_tens.nhdr')
).transpose(3, 2, 1, 0)
mask = sitk.GetArrayFromImage(
    sitk.ReadImage(f'{output_dir}/cubic12_{3}_mask.nhdr')
).transpose(2, 1, 0)
geox, geoy, geoz = geo.geodesicpath_3d(
    tensor_lin,
    mask,
    start_coords[0],
    init_velocities[0],
    geo_delta_t,
    iter_num=geo_iters,
    both_directions=False,
)
# Components move to the last axis for the viewer; the final two path samples are not drawn.
vwr = view_3d_tensors(
    np.moveaxis(tensor_lin, 0, -1),
    mask,
    mask,
    paths=[(geox[:-2], geoy[:-2], geoz[:-2])],
    stride=6,
    scale=6,
)
vwr

Finding geodesic path from [13, 14, 21] with initial velocity [0.28787325 0.95766852 0.        ]
Found 0 voxels where unable to take 1st derivative.
Found 0 reduced accuracy 2nd derivative voxels.
numpts 135
smallest,largest max eigenvalue 1.8475599159179434 8.83122635409812


Viewer(geometries=[{'vtkClass': 'vtkPolyData', 'points': {'vtkClass': 'vtkPoints', 'name': '_points', 'numberO…

In [None]:
# Trace a geodesic from start_coords[0] in the tensor field stored as time point 4, then view it.
# NOTE(review): these cubic12_* files are not produced by this notebook (the save
# cell's writes, when enabled, use cubic_46_geodesic_* names) -- confirm their source.
tensor_lin = sitk.GetArrayFromImage(
    sitk.ReadImage(f'{output_dir}/cubic12_{4}_tens.nhdr')
).transpose(3, 2, 1, 0)
mask = sitk.GetArrayFromImage(
    sitk.ReadImage(f'{output_dir}/cubic12_{4}_mask.nhdr')
).transpose(2, 1, 0)
geox, geoy, geoz = geo.geodesicpath_3d(
    tensor_lin,
    mask,
    start_coords[0],
    init_velocities[0],
    geo_delta_t,
    iter_num=geo_iters,
    both_directions=False,
)
# Components move to the last axis for the viewer; the final two path samples are not drawn.
vwr = view_3d_tensors(
    np.moveaxis(tensor_lin, 0, -1),
    mask,
    mask,
    paths=[(geox[:-2], geoy[:-2], geoz[:-2])],
    stride=6,
    scale=6,
)
vwr

In [None]:
# Trace a geodesic from start_coords[0] in the tensor field stored as time point 5, then view it.
# NOTE(review): these cubic12_* files are not produced by this notebook (the save
# cell's writes, when enabled, use cubic_46_geodesic_* names) -- confirm their source.
tensor_lin = sitk.GetArrayFromImage(
    sitk.ReadImage(f'{output_dir}/cubic12_{5}_tens.nhdr')
).transpose(3, 2, 1, 0)
mask = sitk.GetArrayFromImage(
    sitk.ReadImage(f'{output_dir}/cubic12_{5}_mask.nhdr')
).transpose(2, 1, 0)
geox, geoy, geoz = geo.geodesicpath_3d(
    tensor_lin,
    mask,
    start_coords[0],
    init_velocities[0],
    geo_delta_t,
    iter_num=geo_iters,
    both_directions=False,
)
# Components move to the last axis for the viewer; the final two path samples are not drawn.
vwr = view_3d_tensors(
    np.moveaxis(tensor_lin, 0, -1),
    mask,
    mask,
    paths=[(geox[:-2], geoy[:-2], geoz[:-2])],
    stride=6,
    scale=6,
)
vwr

In [None]:
# Trace a geodesic from start_coords[0] in the tensor field stored as time point 6, then view it.
# NOTE(review): these cubic12_* files are not produced by this notebook (the save
# cell's writes, when enabled, use cubic_46_geodesic_* names) -- confirm their source.
tensor_lin = sitk.GetArrayFromImage(
    sitk.ReadImage(f'{output_dir}/cubic12_{6}_tens.nhdr')
).transpose(3, 2, 1, 0)
mask = sitk.GetArrayFromImage(
    sitk.ReadImage(f'{output_dir}/cubic12_{6}_mask.nhdr')
).transpose(2, 1, 0)
geox, geoy, geoz = geo.geodesicpath_3d(
    tensor_lin,
    mask,
    start_coords[0],
    init_velocities[0],
    geo_delta_t,
    iter_num=geo_iters,
    both_directions=False,
)
# Components move to the last axis for the viewer; the final two path samples are not drawn.
vwr = view_3d_tensors(
    np.moveaxis(tensor_lin, 0, -1),
    mask,
    mask,
    paths=[(geox[:-2], geoy[:-2], geoz[:-2])],
    stride=6,
    scale=6,
)
vwr

In [None]:
# Build the union of the filtered masks of cubic{index0} and cubic{index1}, show it, and save it.


def read_mask_reversed(path):
    """Read a mask image as a double tensor with all of its axes reversed.

    Reversing every axis follows the notebook's symmetric I/O convention. For a
    2D mask this equals the previous permute(1, 0); unlike permute(1, 0) it
    also accepts 3D masks.
    """
    arr = sitk.GetArrayFromImage(sitk.ReadImage(path))
    return torch.from_numpy(arr).double().permute(*reversed(range(arr.ndim)))


mask00 = read_mask_reversed(f'Data/Cubic/cubic{index0}_filt_mask.nhdr')
mask11 = read_mask_reversed(f'Data/Cubic/cubic{index1}_filt_mask.nhdr')
mask = mask00 + mask11
# float64 0/1 array: 1 wherever either mask is positive.
mask_bin = (mask > 0).double().numpy()

fig, ax = plt.subplots()
# imshow needs 2D data; for a 3D mask show the middle slice along the last axis.
ax.imshow(mask_bin if mask_bin.ndim == 2 else mask_bin[..., mask_bin.shape[-1] // 2])
ax.set_title(f'union of cubic{index0} and cubic{index1} filtered masks')
plt.show()

# NOTE(review): the directory index 7 is hard-coded as in the original -- confirm its meaning.
mask_path = Path(f'Output/cubic_{index0}{index1}_geodesic_{7}/cubic_{index0}{index1}_filt_mask.nhdr')
# Create the target directory up front; the image writer may fail if it is missing.
mask_path.parent.mkdir(parents=True, exist_ok=True)
sitk.WriteImage(sitk.GetImageFromArray(mask_bin.transpose()), str(mask_path))