In [14]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
import collections
import functools
from functools import partial
import itertools
import math

# visualization libraries
%matplotlib inline
from pprint import pprint
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import skimage
import torch
import torch.nn.functional as F
import functorch
import einops
import monai
import dipy
import dipy.reconst
import dipy.reconst.csdeconv, dipy.reconst.shm, dipy.viz
import dipy.denoise
import nibabel as nib

import jax
import jax.config

# Disable jit for debugging.
# jax.config.update("jax_disable_jit", True)
# Enable 64-bit precision.
# jax.config.update("jax_enable_x64", True)
# jax.config.update("jax_default_matmul_precision", 32)
import jax.numpy as jnp
from jax import lax
import jax.dlpack

import pitn


plt.rcParams.update({"figure.autolayout": True})
plt.rcParams.update({"figure.facecolor": [1.0, 1.0, 1.0, 1.0]})
plt.rcParams.update({"image.cmap": "gray"})
plt.rcParams.update({"image.interpolation": "antialiased"})

# Set print options for ndarrays/tensors.
np.set_printoptions(suppress=True, threshold=100, linewidth=88)
torch.set_printoptions(sci_mode=False, threshold=100, linewidth=88)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [15]:
# torch setup
# allow for CUDA usage, if available
if torch.cuda.is_available():
    # Pick only one device for the default, may use multiple GPUs for training later.
    dev_idx = 0
    device = torch.device(f"cuda:{dev_idx}")
    print("CUDA Device IDX ", dev_idx)
    torch.cuda.set_device(device)
    print("CUDA Current Device ", torch.cuda.current_device())
    print("CUDA Device properties: ", torch.cuda.get_device_properties(device))
    # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
    # in PyTorch 1.12 and later.
    torch.backends.cuda.matmul.allow_tf32 = True
    # See
    # <https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices>
    # for details.

    # Activate cudnn benchmarking to optimize convolution algorithm speed.
    if torch.backends.cudnn.enabled:
        torch.backends.cudnn.benchmark = True
        print("CuDNN convolution optimization enabled.")
        # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
        torch.backends.cudnn.allow_tf32 = True

else:
    device = torch.device("cpu")
# keep device as the cpu
# device = torch.device('cpu')
print(device)

CUDA Device IDX  0
CUDA Current Device  0
CUDA Device properties:  _CudaDeviceProperties(name='NVIDIA RTX A5000', major=8, minor=6, total_memory=24256MB, multi_processor_count=64)
CuDNN convolution optimization enabled.
cuda:0


In [16]:
hcp_full_res_data_dir = Path("/data/srv/data/pitn/hcp")
hcp_full_res_fodf_dir = Path("/data/srv/outputs/pitn/hcp/full-res/fodf")
hcp_low_res_data_dir = Path("/data/srv/outputs/pitn/hcp/downsample/scale-2.00mm/vol")
hcp_low_res_fodf_dir = Path("/data/srv/outputs/pitn/hcp/downsample/scale-2.00mm/fodf")

assert hcp_full_res_data_dir.exists()
assert hcp_full_res_fodf_dir.exists()
assert hcp_low_res_data_dir.exists()
assert hcp_low_res_fodf_dir.exists()

## fODF Peak Finding

In [17]:
sample_fod_f = (
    hcp_full_res_fodf_dir / "162329" / "T1w" / "postproc_wm_msmt_csd_fod.nii.gz"
)
fod_coeff_im = nib.load(sample_fod_f)
fod_coeff_im = nib.as_closest_canonical(fod_coeff_im)
mask_f = sample_fod_f.parent / "postproc_nodif_brain_mask.nii.gz"
mask_im = nib.load(mask_f)
mask_im = nib.as_closest_canonical(mask_im)

# Pre-select voxels of interest for this specific subject.
# CC forceps minor, strong L-R uni-modal lobe
cc_lr_lobe_idx = (55, 98, 53)
# Dual-polar approx. equal volume fiber crossing
lr_and_ap_bipolar_lobe_idx = (70, 106, 54)
# Vox. adjacent to CST, tri-polar
tri_polar_lobe_idx = (60, 68, 43)

In [18]:
coeffs = fod_coeff_im.get_fdata()
coeffs = torch.from_numpy(coeffs)
# Move to channels-first layout.
coeffs = coeffs.movedim(-1, 0)
mask = mask_im.get_fdata().astype(bool)
mask = torch.from_numpy(mask)[None]

print(coeffs.shape)
print(mask.shape)

torch.Size([45, 110, 134, 108])
torch.Size([1, 110, 134, 108])


In [45]:
aff = fod_coeff_im.affine
aff = torch.from_numpy(aff)
print(aff)
print(aff.shape)

p1 = torch.as_tensor(cc_lr_lobe_idx).float()
p2 = torch.as_tensor(lr_and_ap_bipolar_lobe_idx).float()
p3 = torch.as_tensor(tri_polar_lobe_idx).float()
p = torch.stack([p1, p2, p3], 0)
print(p)
print(p.shape)
p_mm = pitn.affine.coord_transform_3d(p, aff)
print(p_mm)

tensor([[  1.2500,   0.0000,   0.0000, -67.5000],
        [  0.0000,   1.2500,   0.0000, -99.7500],
        [  0.0000,   0.0000,   1.2500, -60.7500],
        [  0.0000,   0.0000,   0.0000,   1.0000]], dtype=torch.float64)
torch.Size([4, 4])
tensor([[ 55.,  98.,  53.],
        [ 70., 106.,  54.],
        [ 60.,  68.,  43.]])
torch.Size([3, 3])
tensor([[  1.2500,  22.7500,   5.5000],
        [ 20.0000,  32.7500,   6.7500],
        [  7.5000, -14.7500,  -7.0000]], dtype=torch.float64)


torch.Size([45, 110, 134, 108])

In [66]:
vol_shape = coeffs.shape
aff_mm2vox = torch.linalg.inv(aff)
aff_vox2grid = torch.eye(4).to(aff_mm2vox)
aff_diag = 2 / (torch.as_tensor(vol_shape[-3:]) - 1)
aff_diag = torch.cat([aff_diag, aff_diag.new_ones(1)], 0)
aff_vox2grid = aff_vox2grid.diagonal_scatter(aff_diag)
# aff_vox2grid = aff_vox2grid.diag_embed()
aff_vox2grid[:3, 3:4] = -1
print(aff_vox2grid)

pitn.affine.coord_transform_3d(p_mm, aff_vox2grid @ aff_mm2vox)
print(
    pitn.affine.coord_transform_3d(
        torch.as_tensor([68.75, 66.5, 73]), aff_vox2grid @ aff_mm2vox
    )
)
print(
    pitn.affine.coord_transform_3d(
        torch.as_tensor([-67.5, -99.75, -60.75]), aff_vox2grid @ aff_mm2vox
    )
)

tensor([[ 0.0183,  0.0000,  0.0000, -1.0000],
        [ 0.0000,  0.0150,  0.0000, -1.0000],
        [ 0.0000,  0.0000,  0.0187, -1.0000],
        [ 0.0000,  0.0000,  0.0000,  1.0000]], dtype=torch.float64)
tensor([1.0000, 1.0000, 1.0000], dtype=torch.float64)
tensor([-1., -1., -1.], dtype=torch.float64)


In [141]:
# Test mask sampling.
print(mask.shape)
print(mask_im.affine == fod_coeff_im.affine)
aff = torch.from_numpy(fod_coeff_im.affine).to(torch.float32)
p1 = torch.tensor([-51.25, 22.75, -2])  # Should be True, may have off-by-one error
p2 = torch.tensor([-67.5, -99.75, -60.75])  # False
p3 = torch.tensor([-67.5, -99.75, -48.25])  # False
p4 = torch.tensor([-1.25, -2.25, -33])  # False
p5 = torch.tensor([3.75, -18.5, 0.5])  # True
p6 = torch.tensor([-46.25, -53.5, -58.25])  # True, inserted manually into mask.
p = torch.stack([p1, p2, p3, p4, p5, p6], 0)
m = torch.clone(mask)
m[:, 17, 37, 2] = 1  # Corresponds to p6
mask_samples = pitn.affine.sample_3d(m, p, aff, mode="nearest", align_corners=True)
print(mask_samples)
print(mask_samples.shape)

torch.Size([1, 110, 134, 108])
[[ True  True  True  True]
 [ True  True  True  True]
 [ True  True  True  True]
 [ True  True  True  True]]
tensor([[ True, False, False, False,  True,  True]])
torch.Size([1, 6])


In [132]:
# Test sampling.
vol = torch.arange(0, 4**3).reshape(1, 1, 4, 4, 4).float()
print(vol.shape)
aff = torch.eye(4)
p = torch.tensor(
    [
        [0, 0, 0],
        [3, 3, 3],
        [0, 0, 3],
        [2, 1, 1],
        [4, 4, 4],
        [2.7095, 1.75, 1.5],
    ]
)
print(p.shape)
samples = pitn.affine.sample_3d(
    vol, torch.stack([p[..., 2], p[..., 1], p[..., 0]], -1), aff
)
print(samples)
print(samples.shape)
print(vol.squeeze()[tuple(p[:-2].T.long())])

torch.Size([1, 1, 4, 4, 4])
torch.Size([6, 3])
tensor([[ 0.0000, 63.0000,  3.0000, 37.0000,  0.0000, 51.8520]])
torch.Size([1, 6])
tensor([ 0., 63.,  3., 37.])


In [6]:
# # Change orientation for visualization.
# new_ornt = nib.orientations.axcodes2ornt(tuple("IPR"))
# ornt_tf = nib.orientations.ornt_transform(
#     nib.orientations.axcodes2ornt(nib.orientations.aff2axcodes(fod_coeff_im.affine)), new_ornt
# )
# coeffs = fod_coeff_im.as_reoriented(ornt_tf).get_fdata()
# coeffs = torch.from_numpy(coeffs)
# # Move to channels-first layout.
# coeffs = coeffs.movedim(-1, 0)
# mask = mask_im.as_reoriented(ornt_tf).get_fdata().astype(bool)
# mask = torch.from_numpy(mask)[None]

# print(coeffs.shape)
# print(mask.shape)

# # Transform the points of interest to the new coord layout.
# print("\nTransforming voxel coordinates of interest.")
# affine_vox2ras_phys = fod_coeff_im.affine
# affine_vox2ipr_phys = fod_coeff_im.as_reoriented(ornt_tf).affine
# affine_ipr_phys2vox = np.linalg.inv(affine_vox2ipr_phys)
# p_vox_ipr = list()
# for p in (cc_lr_lobe_idx, lr_and_ap_bipolar_lobe_idx, tri_polar_lobe_idx):
#     p = np.asarray(p)[:, None]
#     p_phys = (affine_vox2ras_phys[:3, :3] @ p) + affine_vox2ras_phys[:3, 3:4]
#     p_orient = (affine_ipr_phys2vox[:3, :3] @ p_phys) + affine_ipr_phys2vox[:3, 3:4]
#     print(p_orient.flatten().astype(int))
#     p_vox_ipr.append(tuple(p_orient.flatten().astype(int)))
# cc_lr_lobe_idx, lr_and_ap_bipolar_lobe_idx, tri_polar_lobe_idx = tuple(p_vox_ipr)
# print(cc_lr_lobe_idx, lr_and_ap_bipolar_lobe_idx, tri_polar_lobe_idx)

In [7]:
# sphere = dipy.data.HemiSphere.from_sphere(dipy.data.get_sphere("repulsion200"))
sphere = dipy.data.HemiSphere.from_sphere(dipy.data.get_sphere("repulsion724"))

theta, phi = pitn.odf.get_torch_sample_sphere_coords(
    sphere, coeffs.device, coeffs.dtype
)
with torch.no_grad():
    # Function applies non-negativity constraint.
    sphere_samples = pitn.odf.sample_sphere_coords(
        coeffs, theta=theta, phi=phi, sh_order=8, sh_order_dim=0, mask=mask
    )

nearest_sphere_samples = pitn.odf.adjacent_sphere_points_idx(theta=theta, phi=phi)
nearest_sphere_samples_idx = nearest_sphere_samples[0]
nearest_sphere_samples_valid_mask = nearest_sphere_samples[1]

### Fast-Marching Level Set (FMLS) Segmentation

In [8]:
# Threshold parameter from Algorithm 1 in Appendix A of SIFT paper.
min_sample_pdf_threshold = 0.0001

peak_diff_threshold = 0.8

min_lobe_pdf_peak_threshold = 1e-5
min_lobe_pdf_integral_threshold = 0.05

# Single voxel lobe segmentation
# vox_idx = cc_lr_lobe_idx
# vox_idx = lr_and_ap_bipolar_lobe_idx
# vox_idx = tri_polar_lobe_idx

In [9]:
# lr_and_ap_bipolar_lobe_idx = (70, 106, 54)
fodf_idx_range = (slice(60, 81), slice(96, 117), slice(44, 65))

b_fodf = sphere_samples[:, fodf_idx_range[0], fodf_idx_range[1], fodf_idx_range[2]]
b_fodf = einops.rearrange(b_fodf, "s ... -> (...) s")
# Remove low fodf values (count them as "noise").
b_fodf = pitn.odf.thresh_fodf_samples_by_pdf(b_fodf, min_sample_pdf_threshold)

b_fodf = b_fodf.to(device).to(torch.float32)
b_theta = theta.to(device).to(torch.float32)
b_phi = phi.to(device).to(torch.float32)

In [10]:
# Perform FMLS segmentation.
lobe_labels = pitn.tract.peak.fmls_fodf_seg(
    b_fodf, peak_diff_threshold, theta=b_theta, phi=b_phi
)

2023-01-12 11:33:33,658 - Remote TPU is not linked into jax; skipping remote TPU.
2023-01-12 11:33:33,659 - Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
2023-01-12 11:33:33,739 - Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter Host CUDA
2023-01-12 11:33:33,741 - Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
2023-01-12 11:33:33,742 - Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.


In [11]:
# Refine lobe labels.
lobe_labels = pitn.tract.peak.remove_fodf_labels_by_pdf(
    lobe_labels,
    b_fodf,
    pdf_peak_min=min_lobe_pdf_peak_threshold,
    pdf_integral_min=min_lobe_pdf_integral_threshold,
)

In [12]:
unique_labels = lobe_labels.unique()
unique_labels = unique_labels[unique_labels > 0]

peak_vals = torch.zeros(lobe_labels.shape[0], len(unique_labels)).to(b_fodf)
peak_idx = -torch.ones_like(peak_vals).to(torch.long)
for i, l in enumerate(unique_labels):
    select_vals = torch.where(lobe_labels == l, b_fodf, -1)
    l_peak_idx = torch.argmax(select_vals, dim=1)[:, None]
    peak_idx[:, i] = l_peak_idx.flatten()
    peak_idx[:, i] = torch.where(
        select_vals.take_along_dim(l_peak_idx, dim=1) > 0, peak_idx[:, i, None], -1
    ).flatten()

valid_peak_mask = peak_idx >= 0
peak_vals = torch.where(
    peak_idx >= 0, b_fodf.take_along_dim(peak_idx.clamp_min(0), dim=1), -1
)
# The invalid indices are set to 0 to avoid subtle indexing errors later on; cuda in
# particular hates indexing out-of-bounds of a Tensor. Even though it is possible that
# an index value of 0 is valid, this is the only way to avoid those errors. The valid
# peak mask must be used to distinguish between real peak indices and those that are
# actually valued at 0.
peak_idx.clamp_min_(0)
print(peak_vals.shape)
print(peak_idx.shape)
print(valid_peak_mask.shape)

torch.Size([9261, 8])
torch.Size([9261, 8])
torch.Size([9261, 8])


In [32]:
peak_theta = torch.take(b_theta, index=peak_idx) * valid_peak_mask
peak_phi = torch.take(b_phi, index=peak_idx) * valid_peak_mask

entry_dirs = torch.stack(
    [
        torch.ones_like(peak_theta[:, 0]) * torch.pi / 4,
        torch.ones_like(peak_phi[:, 0]) * -torch.pi / 2,
    ],
    dim=-1,
)
peak_dirs = torch.stack([peak_theta, peak_phi], dim=-1)
near_directs, near_peaks = pitn.tract.direct.closest_opposing_direction(
    entry_dirs, peak_vals, peak_dirs, valid_peak_mask
)

In [64]:
torch.finfo(torch.float32)

finfo(resolution=1e-06, min=-3.40282e+38, max=3.40282e+38, eps=1.19209e-07, smallest_normal=1.17549e-38, tiny=1.17549e-38, dtype=float32)

In [None]:
# # Visualize lobe segmentation result
# # Plot 3D surface of odf.
# %matplotlib widget

# # "surface" or "points"
# to_plot = "points"
# post_seg_filter = True

# viz_sphere = sphere
# viz_theta, viz_phi = pitn.odf.get_torch_sample_sphere_coords(
#     viz_sphere, coeffs.device, coeffs.dtype
# )
# polar_tri = mpl.tri.Triangulation(viz_phi, viz_theta)
# polar_tri_idx = torch.from_numpy(polar_tri.triangles).long()

# # Take labels from pre-segmentation filtering.
# if not post_seg_filter:
#     viz_tri_labels = lobe_labels.flatten()[polar_tri_idx]
# else:
#     # Take labels from post-segmentation filtering.
#     viz_tri_labels = ll.flatten()[polar_tri_idx]

# label_cmap = sns.cubehelix_palette(
#     n_colors=len(np.unique(viz_tri_labels.flatten())), reverse=True, rot=2, as_cmap=True
# )
# # viz_tri_labels = torch.mean(viz_tri_labels.float(), dim=1)
# viz_tri_labels = torch.median(viz_tri_labels, dim=1).values

# with torch.no_grad():
#     viz_coeffs = coeffs[(slice(None),) + vox_idx][:, None, None, None]
#     viz_mask = mask[(slice(None),) + vox_idx][:, None, None, None]
#     # Function applies non-negativity constraint.
#     viz_sphere_samples = pitn.odf.sample_sphere_coords(
#         viz_coeffs,
#         theta=viz_theta,
#         phi=viz_phi,
#         sh_order=8,
#         sh_order_dim=0,
#         mask=viz_mask,
#     )

# viz_fodf = np.copy(viz_sphere_samples.detach().cpu().numpy().flatten())
# viz_tri_labels = viz_tri_labels.detach().cpu().numpy().flatten()

# viz_theta = viz_theta.detach().cpu().numpy().flatten()
# viz_phi = viz_phi.detach().cpu().numpy().flatten()
# directions, values, indices = dipy.direction.peak_directions(
#     viz_fodf, viz_sphere, relative_peak_threshold=0.5, min_separation_angle=25
# )
# # viz_fodf[viz_fodf < values.min() * 0.3] = 1e-8
# with mpl.rc_context({"figure.autolayout": False}):
#     fig = plt.figure(dpi=120)

#     ax = fig.add_subplot(projection="3d")

#     vals = viz_fodf

#     r = (vals - vals.min()) / (vals - vals.min()).max()
#     r = vals / vals.sum()

#     x = r * np.sin(viz_theta) * np.cos(viz_phi)
#     y = r * np.sin(viz_theta) * np.sin(viz_phi)
#     z = r * np.cos(viz_theta)
#     mapper = mpl.cm.ScalarMappable(cmap=label_cmap)

#     # center_colors = mapper.to_rgba(viz_tri_labels)
#     # center_colors = np.where((viz_tri_labels == 0)[:, None], np.zeros_like(center_colors), center_colors)
#     #     vertex_colors = mapper.to_rgba(ll.numpy().flatten()[polar_tri.edges])
#     # polar_tri_idx = torch.from_numpy(polar_tri.triangles).long()
#     # viz_tri_labels = ll.flatten()[polar_tri_idx]
#     euclid_tri = mpl.tri.Triangulation(x, y, triangles=polar_tri.triangles)
#     surf = ax.plot_trisurf(euclid_tri, z, linewidth=0.3, antialiased=True, zorder=4)

#     if to_plot == "surface":
#         face_colors = mapper.to_rgba(viz_tri_labels)
#         face_colors = np.where(
#             (viz_tri_labels == 0)[:, None], np.zeros_like(face_colors), face_colors
#         )
#         surf.set_fc(face_colors)
#     elif to_plot == "points":
#         surf.set_fc("white")
#         surf.set_edgecolors([0.2, 0.2, 0.2, 0.5])
#         if post_seg_filter:
#             point_colors = mapper.to_rgba(ll.numpy().flatten())
#             point_colors = np.where(
#                 (ll.numpy().flatten() == 0)[:, None],
#                 np.zeros_like(point_colors),
#                 point_colors,
#             )
#         else:
#             point_colors = mapper.to_rgba(lobe_labels.numpy().flatten())
#             point_colors = np.where(
#                 (lobe_labels.numpy().flatten() == 0)[:, None],
#                 np.zeros_like(point_colors),
#                 point_colors,
#             )
#         ax.scatter3D(x, y, z, c=point_colors, s=20, zorder=0.1)
#     plt.colorbar(mapper, shrink=0.67)
#     plt.show()