In [None]:
import numpy as np
import os, skimage, sys
import SimpleITK as sitk
import matplotlib.patches as patches

sys.path.append('../Packages')
import algo.geodesic as geo
import algo.euler as euler
import disp.vis as vis
import util.riemann as riemann
import util.tensors as tensors
import data.convert as convert
import disp.vis as vis

from skimage import filters
from matplotlib import pyplot as plt

%matplotlib widget

def subtract(lis, num):
    """Shift every element of `lis` down by `num`.

    Parameters
    ----------
    lis : iterable of numbers
        Values to shift.
    num : number
        Offset subtracted from each element.

    Returns
    -------
    list
        A new list whose i-th entry is ``lis[i] - num``; `lis` itself is not modified.
    """
    shifted = []
    for value in lis:
        shifted.append(value - num)
    return shifted

def crop_path(path_x, path_y, x_range, y_range):
    """Keep only the path points that fall inside an axis-aligned box.

    Parameters
    ----------
    path_x, path_y : array_like
        Coordinates of the path points; must have the same length.
        Lists are accepted as well as ndarrays.
    x_range, y_range : (lo, hi)
        Inclusive bounds on x and y respectively.

    Returns
    -------
    (ndarray, ndarray)
        x and y coordinates of the retained points, in their original order.
        Both are empty when no point lies inside the box.
    """
    path_x = np.asarray(path_x)
    path_y = np.asarray(path_y)
    # One vectorized boolean mask instead of a Python loop over every point
    inside = ((x_range[0] <= path_x) & (path_x <= x_range[1])
              & (y_range[0] <= path_y) & (path_y <= y_range[1]))
    return path_x[inside], path_y[inside]

## Use `numpy`, not `torch`, in the geodesic-shooting algorithm — `torch` is far too slow there

In [None]:
brain_id = 100610
input_dir = '../Brains'
output_dir = f'../Checkpoints/{brain_id}'


def _load_nhdr(suffix):
    """Read ``{input_dir}/{brain_id}/{brain_id}_{suffix}.nhdr`` with ``convert.read_nhdr``."""
    return convert.read_nhdr(f'{input_dir}/{brain_id}/{brain_id}_{suffix}.nhdr')


# Axis reorderings as required by the files on disk:
#   mask          permute(1, 0)    swaps its two axes
#   *_tensors     permute(2, 1, 0) reverses the axis order
#   vector_field  permute(2, 0, 1) moves the last axis to the front
mask = _load_nhdr('filt_mask').double().permute(1, 0).numpy()
tensor_pred_lin = _load_nhdr('learned_tensors_final').permute(2, 1, 0).numpy()
tensor_scaled_lin = _load_nhdr('scaled_tensors').permute(2, 1, 0).numpy()
tensor_orig_lin = _load_nhdr('orig_tensors').permute(2, 1, 0).numpy()
vector_lin = _load_nhdr('vector_field').permute(2, 0, 1).numpy()

In [None]:
# Erode the mask with a 3x3 square so the outermost ring of voxels is hidden from the plot below
eroded_mask = skimage.morphology.erosion(mask, skimage.morphology.square(3))

# Predicted metric = inverse of the predicted tensor, kept in both matrix and linear layouts
tensor_pred_mat = tensors.lin2mat(tensor_pred_lin)
metric_pred_mat = np.linalg.inv(tensor_pred_mat)
metric_pred_lin = tensors.mat2lin(metric_pred_mat)

# Covariant derivative of V along itself under the predicted metric
nabla_vv_pred = riemann.covariant_derivative_2d(vector_lin, metric_pred_mat, mask, differential_accuracy=2)

# sigma = <V, nabla_V V> / <V, V> with plain (Euclidean) dot products, i.e. the coefficient of the
# projection of nabla_V V onto V. Where V == 0 this is 0/0 -> NaN; compute it silently instead of
# emitting a RuntimeWarning for every such voxel.
v_dot_nabla = vector_lin[0]*nabla_vv_pred[0] + vector_lin[1]*nabla_vv_pred[1]
v_norm_sq = vector_lin[0]**2 + vector_lin[1]**2
with np.errstate(divide='ignore', invalid='ignore'):
    sigma = v_dot_nabla / v_norm_sq

# epsilon = nabla_V V - sigma*V: the component of nabla_V V orthogonal to V (zero where they are parallel),
# restricted to the eroded mask
epsilon_x = (nabla_vv_pred[0] - sigma*vector_lin[0]) * eroded_mask
epsilon_y = (nabla_vv_pred[1] - sigma*vector_lin[1]) * eroded_mask

x = np.linspace(0, mask.shape[0]-1, mask.shape[0])
y = np.linspace(0, mask.shape[1]-1, mask.shape[1])
xx, yy = np.meshgrid(x, y, indexing='ij')
plt.figure(figsize=(8,10))
plt.title('epsilon=nabla VV-sigma*V')
plt.quiver(xx, yy, epsilon_x, epsilon_y, scale=1e-2)
plt.axis('off');

## Geodesic plotting

In [None]:
# Seed point for brain 100610, given as indices into the two spatial axes of `vector_lin`
start_coords = np.array([73,124])

# Every geodesic starts with the vector field's value at the seed
init_velocities = vector_lin[:, int(start_coords[0]), int(start_coords[1])]
geo_delta_t, euler_delta_t = 5e-3, 5e-3
geo_iters, euler_iters = 60000, 60000


def _sanitize_singular(tensor_lin):
    """Return ``(mat, lin)`` for `tensor_lin` with every singular (det == 0) 2x2 tensor replaced by the identity.

    This keeps the ``np.linalg.inv`` calls below from hitting exactly singular matrices.
    """
    tensor_mat = tensors.lin2mat(tensor_lin)
    tensor_mat[np.linalg.det(tensor_mat) == 0] = np.eye(2)
    return tensor_mat, tensors.mat2lin(tensor_mat)


def _shoot_geodesic(mode, tensor_lin):
    """Shoot a geodesic from `start_coords` with `init_velocities` using `tensor_lin`, in both directions.

    `mode` is passed straight through to ``geo.geodesicpath`` ('f' or 'b' in this notebook).
    Returns the ``(x, y)`` path produced by ``geo.geodesicpath``.
    """
    return geo.geodesicpath(mode, tensor_lin, vector_lin, mask,
                            start_coords, init_velocities,
                            geo_delta_t, iter_num=geo_iters, both_directions=True)


# Integral curve of the vector field via euler.eulerpath_vectbase_2d_w_dv (both directions)
eulxb, eulyb = euler.eulerpath_vectbase_2d_w_dv(vector_lin, mask, start_coords, euler_delta_t,
                                                iter_num=euler_iters, both_directions=True)

# Inverted: original tensors, mode 'f'
tensor_orig_mat, tensor_orig_lin = _sanitize_singular(tensor_orig_lin)
geox_orig, geoy_orig = _shoot_geodesic('f', tensor_orig_lin)

# Adjugate: metric = adj(T) = det(T) * T^-1, and the tensor handed to the shooter is its inverse.
# `tensor_orig_mat` is already sanitized, so compute det(T) once and broadcast it over the 2x2 entries.
det_orig = np.linalg.det(tensor_orig_mat)
metric_adj_mat = np.linalg.inv(tensor_orig_mat) * det_orig[..., np.newaxis, np.newaxis]
tensor_adj_mat = np.linalg.inv(metric_adj_mat)
tensor_adj_lin = tensors.mat2lin(tensor_adj_mat)
geox_adj, geoy_adj = _shoot_geodesic('f', tensor_adj_lin)

# Conformal: scaled tensors, mode 'f'
tensor_scaled_mat, tensor_scaled_lin = _sanitize_singular(tensor_scaled_lin)
geox_scaled, geoy_scaled = _shoot_geodesic('f', tensor_scaled_lin)

# Proposed: learned tensors, mode 'b'
geox_predb, geoy_predb = _shoot_geodesic('b', tensor_pred_lin)

In [None]:
# Palette kept in the namespace; not used by the calls in this cell
interp_colors = ['#777777','#253494', '#2c7fb8', '#41b6c4', '#a1dab4', '#fed98e', '#fe9929', '#d95f0e', '#993404']
x = np.linspace(0, mask.shape[0]-1, mask.shape[0])
y = np.linspace(0, mask.shape[1]-1, mask.shape[1])
xx, yy = np.meshgrid(x, y, indexing='ij')
plt.figure(figsize=(8,11))
scale, slic = 1e0, 12
plt.axis('off')

# Background: the vector field inside the mask, drawn without arrow heads
in_mask = mask > 0
vect_fig = plt.quiver(xx[in_mask], yy[in_mask], vector_lin[0][in_mask], vector_lin[1][in_mask],
                      scale=1e2, color='#666666', headaxislength=0, headlength=0)

vis.vis_path(geox_orig, geoy_orig, vect_fig, "Inverted", '#ffb901', 10, 1, False, show_legend=False)
vis.vis_path(geox_adj, geoy_adj, vect_fig, "Adjugate", '#f25022', 10, 1, False, show_legend=False)
vis.vis_path(geox_scaled, geoy_scaled, vect_fig, "Conformal", '#7fba00', 10, 1, False, show_legend=False)
# Proposed (mode 'b'): g_ddot = -g_dot*G*g_dot + sigma*V
vis.vis_path(geox_predb, geoy_predb, vect_fig, "Proposed", '#41b6c4', 10, 1, False, show_legend=False)
vis.vis_path(eulxb, eulyb, vect_fig, "Integral curve", 'black', 2, 1, False, show_legend=False)

# Output file is keyed by brain id and seed point; create the checkpoint directory if needed
os.makedirs(output_dir, exist_ok=True)
plt.savefig(f'{output_dir}/{brain_id}_{start_coords[0]}_{start_coords[1]}_final_vect.png', bbox_inches='tight', dpi=300)

In [None]:
# Zoom window (background slices exclude the upper bound; crop_path keeps it inclusively)
x_range = (51,92)
y_range = (113,152)
vector_field_part = vector_lin[:, x_range[0]:x_range[1], y_range[0]:y_range[1]]
mask_part = mask[x_range[0]:x_range[1], y_range[0]:y_range[1]]
x = np.linspace(0, x_range[1]-x_range[0]-1, x_range[1]-x_range[0])
y = np.linspace(0, y_range[1]-y_range[0]-1, y_range[1]-y_range[0])
xx, yy = np.meshgrid(x, y, indexing='ij')
plt.figure(figsize=(7,7))
scale, slic = 1e0, 12
plt.axis('off')
in_mask_part = mask_part > 0
vect_fig = plt.quiver(xx[in_mask_part], yy[in_mask_part],
                      vector_field_part[0][in_mask_part], vector_field_part[1][in_mask_part],
                      scale=5e1, color='#666666', headaxislength=0, headlength=0)

# Crop into new `_zoom` variables so the full-length paths from the shooting cell stay intact
# and the full-view plot cell can be re-run after this one.
geox_orig_zoom, geoy_orig_zoom = crop_path(geox_orig, geoy_orig, x_range, y_range)
geox_adj_zoom, geoy_adj_zoom = crop_path(geox_adj, geoy_adj, x_range, y_range)
geox_scaled_zoom, geoy_scaled_zoom = crop_path(geox_scaled, geoy_scaled, x_range, y_range)
geox_predb_zoom, geoy_predb_zoom = crop_path(geox_predb, geoy_predb, x_range, y_range)
eulxb_zoom, eulyb_zoom = crop_path(eulxb, eulyb, x_range, y_range)

# Shift paths into the window's local coordinates before plotting
vis.vis_path(subtract(geox_orig_zoom, x_range[0]), subtract(geoy_orig_zoom, y_range[0]), vect_fig, "Inverted", '#ffb901', 15, 1, False)
vis.vis_path(subtract(geox_adj_zoom, x_range[0]), subtract(geoy_adj_zoom, y_range[0]), vect_fig, "Adjugate", '#f25022', 15, 1, False)
vis.vis_path(subtract(geox_scaled_zoom, x_range[0]), subtract(geoy_scaled_zoom, y_range[0]), vect_fig, "Conformal", '#7fba00', 15, 1, False)
vis.vis_path(subtract(geox_predb_zoom, x_range[0]), subtract(geoy_predb_zoom, y_range[0]), vect_fig, "Proposed", '#41b6c4', 15, 1, False)
vis.vis_path(subtract(eulxb_zoom, x_range[0]), subtract(eulyb_zoom, y_range[0]), vect_fig, "Integral curve", 'black', 3, 1, False)

# Mark the seed point
plt.plot([start_coords[0]-x_range[0]], [start_coords[1]-y_range[0]], linestyle='', marker='*', color='black', markersize=20)

# Output file is keyed by brain id and seed point; create the checkpoint directory if needed
os.makedirs(output_dir, exist_ok=True)
plt.savefig(f'{output_dir}/{brain_id}_{start_coords[0]}_{start_coords[1]}_final_zoomin_vect.png', bbox_inches='tight', dpi=300)

## Tensor visualization

In [None]:
# Predicted metric tensors inside the zoom window, drawn at half opacity.
# NOTE(review): 'title' looks like a placeholder plot title — TODO give it a real one.
zoom_window = (slice(x_range[0], x_range[1]), slice(y_range[0], y_range[1]))
vis.vis_tensors(metric_pred_lin[(slice(None),) + zoom_window], 'title',
                save_file=False, filename='',
                mask=mask[zoom_window],
                scale=1, opacity=0.5, show_axis_labels=True,
                ax=None, zorder=1, stride=None)