In [None]:
%load_ext autoreload
%autoreload 2

from typing import Optional, Tuple
from pathlib import Path
import collections
import functools
from functools import partial
import itertools
import math
import os

# Change default behavior of jax GPU memory allocation.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".20"

# visualization libraries
%matplotlib inline
from pprint import pprint
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import scipy
import skimage
import torch
import torch.nn.functional as F
import functorch
import einops
import monai
import dipy
import dipy.reconst
import dipy.reconst.csdeconv, dipy.reconst.shm, dipy.viz
import dipy.denoise
import dipy.io
import dipy.io.streamline
import nibabel as nib

import jax
import jax.config

# Disable jit for debugging.
# jax.config.update("jax_disable_jit", True)
# Enable 64-bit precision.
# jax.config.update("jax_enable_x64", True)
# jax.config.update("jax_default_matmul_precision", 32)
import jax.numpy as jnp
from jax import lax
import jax.dlpack

import pitn


# Matplotlib defaults for this notebook: automatic layout, opaque white figure
# background, and grayscale antialiased image display.
plt.rcParams.update(
    {
        "figure.autolayout": True,
        "figure.facecolor": [1.0, 1.0, 1.0, 1.0],
        "image.cmap": "gray",
        "image.interpolation": "antialiased",
    }
)

# Compact, non-scientific printing of ndarrays/tensors.
np.set_printoptions(suppress=True, threshold=100, linewidth=88)
torch.set_printoptions(sci_mode=False, threshold=100, linewidth=88)

In [None]:
# torch setup: use a single CUDA device (index 0) when available, otherwise the CPU.
# Multiple GPUs may be used for training later.
if not torch.cuda.is_available():
    device = torch.device("cpu")
else:
    dev_idx = 0
    device = torch.device(f"cuda:{dev_idx}")
    print("CUDA Device IDX ", dev_idx)
    torch.cuda.set_device(device)
    print("CUDA Current Device ", torch.cuda.current_device())
    print("CUDA Device properties: ", torch.cuda.get_device_properties(device))
    # Allow TF32 on matmul; this flag defaults to False in PyTorch 1.12 and later. See
    # <https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices>
    torch.backends.cuda.matmul.allow_tf32 = True

    # cuDNN benchmarking selects the fastest convolution algorithm.
    if torch.backends.cudnn.enabled:
        torch.backends.cudnn.benchmark = True
        print("CuDNN convolution optimization enabled.")
        # Allow TF32 on cuDNN (this flag defaults to True).
        torch.backends.cudnn.allow_tf32 = True

# To force the CPU instead, uncomment:
# device = torch.device('cpu')
print(device)

In [None]:
# Root directories of the HCP input data and derived fODF outputs; fail early (and say
# which one) if any are missing.
hcp_full_res_data_dir = Path("/data/srv/data/pitn/hcp")
hcp_full_res_fodf_dir = Path("/data/srv/outputs/pitn/hcp/full-res/fodf")
hcp_low_res_data_dir = Path("/data/srv/outputs/pitn/hcp/downsample/scale-2.00mm/vol")
hcp_low_res_fodf_dir = Path("/data/srv/outputs/pitn/hcp/downsample/scale-2.00mm/fodf")

for _required_dir in (
    hcp_full_res_data_dir,
    hcp_full_res_fodf_dir,
    hcp_low_res_data_dir,
    hcp_low_res_fodf_dir,
):
    assert _required_dir.exists(), f"Required directory not found: {_required_dir}"
del _required_dir

## Seed-Based Tractography Test

### Data & Parameter Selection

In [None]:
sample_fod_f = (
    hcp_full_res_fodf_dir / "162329" / "T1w" / "postproc_wm_msmt_csd_fod.nii.gz"
)
# Load the subject's fODF SH coefficients, brain mask and 5TT parcellation, each
# re-oriented to the closest canonical (RAS+) orientation.
fod_coeff_im = nib.load(sample_fod_f)
fod_coeff_im = nib.as_closest_canonical(fod_coeff_im)
print("Original shape", fod_coeff_im.shape)
print("Original affine", fod_coeff_im.affine)
mask_f = sample_fod_f.parent / "postproc_nodif_brain_mask.nii.gz"
mask_im = nib.load(mask_f)
mask_im = nib.as_closest_canonical(mask_im)
white_matter_mask_f = sample_fod_f.parent / "postproc_5tt_parcellation.nii.gz"
wm_mask_im = nib.load(white_matter_mask_f)
wm_mask_im = nib.as_closest_canonical(wm_mask_im)
# Keep volume 2 of the 5TT image as the white-matter map (MRtrix3 5TT volume order:
# cortical GM, sub-cortical GM, WM, CSF, pathological tissue).
wm_mask_im = wm_mask_im.slicer[..., 2]

# Pre-select voxels of interest in RAS+ space for this specific subject.
# CC forceps minor, strong L-R uni-modal lobe
cc_lr_lobe_idx = (55, 98, 53)
# Dual-polar approx. equal volume fiber crossing
lr_and_ap_bipolar_lobe_idx = (70, 106, 54)
# Vox. adjacent to CST, tri-polar
tri_polar_lobe_idx = (60, 68, 43)


# Re-orient volumes from RAS to SAR (xyz -> zyx)
nib_affine_vox2ras_mm = fod_coeff_im.affine
affine_ras_vox2ras_mm = torch.from_numpy(nib_affine_vox2ras_mm).to(device)
ornt_ras = nib.orientations.io_orientation(nib_affine_vox2ras_mm)
ornt_sar = nib.orientations.axcodes2ornt(("S", "A", "R"))
ornt_ras2sar = nib.orientations.ornt_transform(ornt_ras, ornt_sar)
# We also need an affine that maps from SAR -> RAS. `inv_ornt_aff` gives the affine
# from voxel indices of the transformed (SAR) array back to those of the original
# (RAS) array, given the original spatial shape.
affine_sar2ras = nib.orientations.inv_ornt_aff(
    ornt_ras2sar, tuple(fod_coeff_im.shape[:-1])
)
affine_sar2ras = torch.from_numpy(affine_sar2ras).to(affine_ras_vox2ras_mm)
affine_ras2sar = torch.linalg.inv(affine_sar2ras)

# This essentially just flips the translation vector in the affine matrix. It may be
# "RAS" relative to the object/volume itself, but it is "SAR" relative to the original
# ordering of the dimensions in the data.
affine_sar_vox2sar_mm = affine_ras2sar @ (affine_ras_vox2ras_mm @ affine_sar2ras)

# Swap spatial dimensions, assign a new vox->world affine space.
sar_fod = einops.rearrange(fod_coeff_im.get_fdata(), "x y z coeffs -> z y x coeffs")
fod_coeff_im = nib.Nifti1Image(
    sar_fod,
    affine=(affine_sar_vox2sar_mm).cpu().numpy(),
    header=fod_coeff_im.header,
)
sar_mask = einops.rearrange(mask_im.get_fdata().astype(bool), "x y z -> z y x")
mask_im = nib.Nifti1Image(
    sar_mask,
    affine=(affine_sar_vox2sar_mm).cpu().numpy(),
    header=mask_im.header,
)
# Any non-zero WM fraction counts as white matter after the bool cast.
sar_wm_mask = einops.rearrange(wm_mask_im.get_fdata().astype(bool), "x y z -> z y x")
wm_mask_im = nib.Nifti1Image(
    sar_wm_mask,
    affine=(affine_sar_vox2sar_mm).cpu().numpy(),
    header=wm_mask_im.header,
)

print(fod_coeff_im.affine)
print(fod_coeff_im.shape)
print(mask_im.affine)
print(mask_im.shape)

# Flip the pre-selected voxels (RAS voxel indices -> SAR voxel indices).
sar_vox_idx = pitn.affine.coord_transform_3d(
    affine_ras2sar.new_tensor(
        [cc_lr_lobe_idx, lr_and_ap_bipolar_lobe_idx, tri_polar_lobe_idx]
    ),
    affine_ras2sar,
)
# The transformed indices should be whole numbers but come out of a floating-point
# transform; round before casting so e.g. 54.9999 is not truncated to 54.
cc_lr_lobe_idx, lr_and_ap_bipolar_lobe_idx, tri_polar_lobe_idx = tuple(
    sar_vox_idx.round().int().cpu().tolist()
)
cc_lr_lobe_idx = tuple(cc_lr_lobe_idx)
lr_and_ap_bipolar_lobe_idx = tuple(lr_and_ap_bipolar_lobe_idx)
tri_polar_lobe_idx = tuple(tri_polar_lobe_idx)
print(cc_lr_lobe_idx, lr_and_ap_bipolar_lobe_idx, tri_polar_lobe_idx)

In [None]:
# Move the re-oriented volumes onto `device` as tensors, releasing nibabel's cached
# arrays as we go. SH coefficients become channels-first; masks get a leading
# singleton channel dim.
coeffs = einops.rearrange(
    torch.from_numpy(fod_coeff_im.get_fdata()).to(device),
    "z y x coeffs -> coeffs z y x",
)
fod_coeff_im.uncache()
brain_mask = einops.rearrange(
    torch.from_numpy(mask_im.get_fdata().astype(bool)).to(device), "z y x -> 1 z y x"
)
mask_im.uncache()
wm_mask = einops.rearrange(
    torch.from_numpy(wm_mask_im.get_fdata().astype(bool)).to(device), "z y x -> 1 z y x"
)
wm_mask_im.uncache()

# Seed from a single voxel; swap in one of the alternatives to test another voxel.
select_vox_idx = cc_lr_lobe_idx
# select_vox_idx = lr_and_ap_bipolar_lobe_idx
# select_vox_idx = tri_polar_lobe_idx
seed_mask = torch.zeros_like(brain_mask, dtype=torch.bool)
seed_mask[(0, *select_vox_idx)] = True

print(coeffs.shape)
print(brain_mask.shape)
print(seed_mask.shape)

In [None]:
# Hemisphere of sample directions on which the fODFs are evaluated during tracking.
# NOTE(review): `dipy.data` is not imported explicitly in the import cell; it is
# presumably loaded as a side effect of another dipy import.
# sphere = dipy.data.HemiSphere.from_sphere(dipy.data.get_sphere("repulsion200"))
sphere = dipy.data.HemiSphere.from_sphere(dipy.data.get_sphere("repulsion724"))

# Spherical coordinates of the sample directions, on `coeffs`' device and dtype.
theta, phi = pitn.odf.get_torch_sample_sphere_coords(
    sphere, coeffs.device, coeffs.dtype
)

# Neighbor lookup between sample directions: element 0 is read as neighbor indices,
# element 1 as a validity mask (meaning assumed from the names; TODO confirm in pitn).
nearest_sphere_samples = pitn.odf.adjacent_sphere_points_idx(theta=theta, phi=phi)
nearest_sphere_samples_idx = nearest_sphere_samples[0]
nearest_sphere_samples_valid_mask = nearest_sphere_samples[1]

In [None]:
# Tracking and segmentation hyperparameters.
max_sh_order = 8

# Element-wise filtering of sphere samples (minimum value passed to
# `pitn.odf.thresh_fodf_samples_by_pdf`).
min_sample_pdf_threshold = 0.0001

# Threshold parameter for FMLS segmentation.
lobe_merge_ratio = 0.8
# Post-segmentation label filtering (`pdf_peak_min` / `pdf_integral_min` of
# `remove_fodf_labels_by_pdf`).
min_lobe_pdf_peak_threshold = 1e-5
min_lobe_pdf_integral_threshold = 0.05

# Seed creation.
# Maximum number of fODF peaks expanded into streamlines per seed voxel.
peaks_per_seed_vox = 3
# NOTE(review): `seed_batch_size` is not used in any visible cell.
seed_batch_size = 2
# Total seeds per voxel will be `seeds_per_vox_axis`^3
seeds_per_vox_axis = 1

# RK4 estimation
# Step length; presumably mm, since steps are added to mm-space points — TODO confirm.
step_size = 0.4
# Weight of the new RK4 step in the exponential moving average with the old tangent.
alpha_exponential_moving_avg = 0.15

# Stopping & invalidation criteria.
# Passed to `streamline_len_mm` as `min_len` / `max_len` (presumably mm).
min_streamline_len = 10
max_streamline_len = 100
gfa_min_threshold = 0.25

### Tractography Reconstruction Loop - Trilinear Interpolation

In [None]:
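# Planned storage layout (not what the tracking loop below currently does; it keeps a
# list of dense torch tensors, one per step, and stacks them at the end):
# temp is x,y,z tuple of scipy.sparse.lil_arrays
# full streamline list is x,y,z tuple of scipy.sparse.csr_arrays
# After every seed batch, the remaining temp tracts are row-wise stacked onto the full
# streamline list with scipy.sparse.vstack()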

In [None]:
def _fn_linear_interp_zyx_tangent_t2theta_phi(
    target_coords_mm_zyx: torch.Tensor,
    init_direction_theta_phi: Optional[torch.Tensor],
    fodf_coeffs_brain_vol: torch.Tensor,
    affine_vox2mm: torch.Tensor,
    sphere_samples_theta: torch.Tensor,
    sphere_samples_phi: torch.Tensor,
    sh_order: int,
    fodf_pdf_thresh_min: float,
    fmls_lobe_merge_ratio: float,
    lobe_fodf_pdf_filter_kwargs: dict,
    duplicate_peaks_whole_sphere: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Choose the tracking direction at each target point from the local fODF.

    The fODF SH coefficients are linearly interpolated at `target_coords_mm_zyx`,
    sampled on the sphere, thresholded, segmented into lobes with FMLS, filtered, and
    reduced to lobe peaks. If `init_direction_theta_phi` is None or entirely zero, the
    single largest peak is returned. Otherwise the peak chosen by
    `pitn.tract.direction.closest_opposing_direction` for that incoming direction is
    returned, optionally after duplicating the hemisphere peaks over the whole sphere.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        The (theta, phi) coordinates of the selected direction.
    """
    # Sphere coordinates shared by the sampling and segmentation calls.
    sphere_kw = {"theta": sphere_samples_theta, "phi": sphere_samples_phi}

    interp_coeffs = pitn.odf.sample_odf_coeffs_lin_interp(
        target_coords_mm_zyx,
        fodf_coeff_vol=fodf_coeffs_brain_vol,
        affine_vox2mm=affine_vox2mm,
    )
    fodf_samples = pitn.odf.thresh_fodf_samples_by_pdf(
        pitn.odf.sample_sphere_coords(interp_coeffs, sh_order=sh_order, **sphere_kw),
        fodf_pdf_thresh_min,
    )

    labels = pitn.tract.peak.fmls_fodf_seg(
        fodf_samples, lobe_merge_ratio=fmls_lobe_merge_ratio, **sphere_kw
    )
    labels = pitn.tract.peak.remove_fodf_labels_by_pdf(
        labels, fodf_samples, **lobe_fodf_pdf_filter_kwargs
    )
    peaks = pitn.tract.peak.peaks_from_segment(
        labels,
        fodf_samples,
        theta_coord=sphere_samples_theta,
        phi_coord=sphere_samples_phi,
    )

    no_incoming_direction = init_direction_theta_phi is None or bool(
        (torch.as_tensor(init_direction_theta_phi) == 0).all()
    )
    if no_incoming_direction:
        # Nothing to continue from: follow the single largest peak.
        # NOTE(review): keyword `phi_peaks` here vs. `phi_peak` in
        # `expand_seeds_from_topk_peaks_rk4`; confirm against `topk_peaks`' signature.
        top_peak = pitn.tract.peak.topk_peaks(
            k=1,
            fodf_peaks=peaks.peaks,
            theta_peak=peaks.theta,
            phi_peaks=peaks.phi,
            valid_peak_mask=peaks.valid_peak_mask,
        )
        return top_peak.theta, top_peak.phi

    if not duplicate_peaks_whole_sphere:
        candidate_vals = peaks.peaks
        candidate_coords = torch.stack([peaks.theta, peaks.phi], -1)
        candidate_valid = peaks.valid_peak_mask
    else:
        # Duplicate the hemisphere peaks for coverage over the whole sphere.
        whole_sphere = pitn.tract.direction.fodf_duplicate_hemisphere2sphere(
            peaks.theta, peaks.phi, (peaks.peaks, peaks.valid_peak_mask), (1, 1)
        )
        candidate_vals = whole_sphere.vals[0]
        candidate_coords = torch.stack([whole_sphere.theta, whole_sphere.phi], -1)
        candidate_valid = whole_sphere.vals[1]

    closest = pitn.tract.direction.closest_opposing_direction(
        init_direction_theta_phi,
        fodf_peaks=candidate_vals,
        peak_coords_theta_phi=candidate_coords,
        peaks_valid_mask=candidate_valid,
    )
    # Element 0 holds the chosen coordinates with (theta, phi) on the last dim.
    closest_theta_phi = closest[0]
    return closest_theta_phi[..., 0], closest_theta_phi[..., 1]


# Bind the subject volume, sphere sampling and segmentation parameters; the result is
# called as `fn(target_coords_mm_zyx, init_direction_theta_phi)`.
fn_linear_interp_zyx_tangent_t2theta_phi = partial(
    _fn_linear_interp_zyx_tangent_t2theta_phi,
    fodf_coeffs_brain_vol=coeffs,
    affine_vox2mm=affine_sar_vox2sar_mm,
    sphere_samples_theta=theta,
    sphere_samples_phi=phi,
    sh_order=max_sh_order,
    fodf_pdf_thresh_min=min_sample_pdf_threshold,
    fmls_lobe_merge_ratio=lobe_merge_ratio,
    lobe_fodf_pdf_filter_kwargs={
        "pdf_peak_min": min_lobe_pdf_peak_threshold,
        "pdf_integral_min": min_lobe_pdf_integral_threshold,
    },
)

In [None]:
def _peaks_only_fn_linear_interp_zyx(
    target_coords_mm_zyx: torch.Tensor,
    fodf_coeffs_brain_vol: torch.Tensor,
    affine_vox2mm: torch.Tensor,
    sphere_samples_theta: torch.Tensor,
    sphere_samples_phi: torch.Tensor,
    sh_order: int,
    fodf_pdf_thresh_min: float,
    fmls_lobe_merge_ratio: float,
    lobe_fodf_pdf_filter_kwargs: dict,
) -> pitn.tract.peak.PeaksContainer:
    """Return the fODF lobe peaks at each target point.

    Reduced version of the full interpolation function, called only when expanding the
    seed points at the start of streamline estimation. It runs the same pipeline
    (linear interpolation -> sphere sampling -> threshold -> FMLS segmentation -> label
    filtering -> peaks) but stops after peak extraction instead of choosing a direction.
    """
    sphere_kw = {"theta": sphere_samples_theta, "phi": sphere_samples_phi}

    interp_coeffs = pitn.odf.sample_odf_coeffs_lin_interp(
        target_coords_mm_zyx,
        fodf_coeff_vol=fodf_coeffs_brain_vol,
        affine_vox2mm=affine_vox2mm,
    )
    fodf_samples = pitn.odf.thresh_fodf_samples_by_pdf(
        pitn.odf.sample_sphere_coords(interp_coeffs, sh_order=sh_order, **sphere_kw),
        fodf_pdf_thresh_min,
    )

    labels = pitn.tract.peak.fmls_fodf_seg(
        fodf_samples, lobe_merge_ratio=fmls_lobe_merge_ratio, **sphere_kw
    )
    labels = pitn.tract.peak.remove_fodf_labels_by_pdf(
        labels, fodf_samples, **lobe_fodf_pdf_filter_kwargs
    )

    return pitn.tract.peak.peaks_from_segment(
        labels,
        fodf_samples,
        theta_coord=sphere_samples_theta,
        phi_coord=sphere_samples_phi,
    )


# Copy the static parameters from the full interpolation function.
peaks_only_fn_linear_interp_zyx = partial(
    _peaks_only_fn_linear_interp_zyx,
    **dict(fn_linear_interp_zyx_tangent_t2theta_phi.keywords),
)

In [None]:
# Create initial seeds and tangent/direction vectors: seed points come from
# `seed_mask`, fODF peaks are found at each seed, and the top peaks are expanded into
# the first tracking step.
seeds_t_neg1 = pitn.tract.seed.seeds_from_mask(
    seed_mask,
    seeds_per_vox_axis=seeds_per_vox_axis,
    affine_vox2mm=affine_sar_vox2sar_mm,
)
seed_peaks = peaks_only_fn_linear_interp_zyx(seeds_t_neg1)

seeds_t_neg1_to_0, tangent_t0_zyx = pitn.tract.seed.expand_seeds_from_topk_peaks_rk4(
    seeds_t_neg1,
    max_peaks_per_voxel=peaks_per_seed_vox,
    seed_peak_vals=seed_peaks.peaks,
    theta_peak=seed_peaks.theta,
    phi_peak=seed_peaks.phi,
    valid_peak_mask=seed_peaks.valid_peak_mask,
    step_size=step_size,
    # Direction function without duplicating peaks over the whole sphere.
    fn_zyx_direction_t2theta_phi=partial(
        fn_linear_interp_zyx_tangent_t2theta_phi, duplicate_peaks_whole_sphere=False
    ),
)

In [None]:
# Handle stopping conditions: precompute the GFA map used by the GFA stopping criterion
# in the tracking loop.
with torch.no_grad():
    # Denser sampling sphere (one subdivision of repulsion724) for estimating GFA.
    gfa_sampling_sphere = dipy.data.get_sphere("repulsion724").subdivide(1)

    gfa_theta, gfa_phi = pitn.odf.get_torch_sample_sphere_coords(
        gfa_sampling_sphere, coeffs.device, coeffs.dtype
    )
    # Sampled on the CPU. Function applies non-negativity constraint.
    gfa_sphere_samples = pitn.odf.sample_sphere_coords(
        coeffs.cpu(),
        theta=gfa_theta.cpu(),
        phi=gfa_phi.cpu(),
        # Same SH order as the tracking direction functions.
        sh_order=max_sh_order,
        sh_order_dim=0,
        mask=brain_mask.cpu(),
    )

    gfa = pitn.odf.gfa(gfa_sphere_samples, sphere_samples_idx=0).to(device)
    # Also, mask out only the white matter in the gfa! Otherwise, gfa can be high in
    # most places...
    gfa = gfa * wm_mask
    # Free the large intermediate sample tensor.
    del gfa_sphere_samples, gfa_theta, gfa_phi, gfa_sampling_sphere

In [None]:
#!DEBUG


def fn_only_right_zyx2theta_phi(
    target_coords_mm_zyx: torch.Tensor, init_direction_theta_phi: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Debug direction function returning one constant direction everywhere.

    Ignores `init_direction_theta_phi` and returns theta = pi/2, phi = 0 for every
    point, i.e. one value per entry of all but the last dim of `target_coords_mm_zyx`.
    """
    leading_shape = tuple(target_coords_mm_zyx.shape[:-1])
    theta = (torch.pi / 2) * target_coords_mm_zyx.new_ones(leading_shape)
    return theta, torch.zeros_like(theta)

In [None]:
# Primary tractography loop.
#
# Each iteration advances every still-active streamline by one RK4 step, blends the step
# with the previous tangent via an exponential moving average (EMA) rescaled to
# `step_size`, then applies the stopping criteria (streamline length and GFA). Stopped
# streamlines are padded with NaN points from then on.

streamline_status = (
    torch.ones(
        seeds_t_neg1_to_0.shape[1], dtype=torch.int8, device=seeds_t_neg1_to_0.device
    )
    * pitn.tract.stopping.CONTINUE
)
# At least one step has been made.
streamline_len = torch.zeros_like(streamline_status).float() + step_size

# One `n_streamlines x 3` tensor of points per time step, starting with the seeds
# (t = -1) and the first expanded step (t = 0).
streamlines = list()
streamlines.append(seeds_t_neg1_to_0[0])
streamlines.append(seeds_t_neg1_to_0[1])

points_t = seeds_t_neg1_to_0[1]
tangent_t_theta_phi = torch.stack(
    pitn.tract.local.zyx2unit_sphere_theta_phi(tangent_t0_zyx), -1
)
tangent_t_zyx = tangent_t0_zyx

# Hard cap on iterations as a safeguard against runaway loops.
# t_max = 1e8
t_max = 300
t = 1
while pitn.tract.stopping.to_continue_mask(streamline_status).any():
    # Rows of streamlines not advanced this step stay NaN.
    points_tp1 = torch.zeros_like(points_t) * torch.nan
    tangent_tp1_theta_phi = torch.zeros_like(tangent_t_theta_phi) * torch.nan

    to_process_mask = pitn.tract.stopping.to_continue_mask(streamline_status)
    valid_tangent_tp1_zyx = pitn.tract.local.gen_tract_step_rk4(
        points_t[to_process_mask],
        init_direction_theta_phi=tangent_t_theta_phi[to_process_mask],
        fn_zyx_direction_t2theta_phi=fn_linear_interp_zyx_tangent_t2theta_phi,
        # fn_zyx_direction_t2theta_phi=fn_only_right_zyx2theta_phi, #!DEBUG
        step_size=step_size,
    )
    # EMA of the new RK4 step with the previous tangent, rescaled to length `step_size`.
    ema_tangent_tp1_zyx = (
        alpha_exponential_moving_avg * valid_tangent_tp1_zyx
        + (1 - alpha_exponential_moving_avg) * tangent_t_zyx[to_process_mask]
    )
    ema_tangent_tp1_zyx = (
        step_size
        * ema_tangent_tp1_zyx
        / torch.linalg.vector_norm(ema_tangent_tp1_zyx, ord=2, dim=-1, keepdim=True)
    )

    points_tp1.masked_scatter_(
        to_process_mask[..., None], points_t[to_process_mask] + ema_tangent_tp1_zyx
    )
    # `ema_tangent_tp1_zyx` only has rows for the processed streamlines. Scatter it into
    # a full-size tensor so every row stays aligned with its streamline index; scattering
    # the subset later with a different (continue) mask would shift rows.
    tangent_tp1_zyx = torch.zeros_like(tangent_t_zyx).masked_scatter(
        to_process_mask[..., None], ema_tangent_tp1_zyx
    )
    valid_tangent_tp1_theta_phi = torch.stack(
        pitn.tract.local.zyx2unit_sphere_theta_phi(ema_tangent_tp1_zyx), -1
    )
    tangent_tp1_theta_phi.masked_scatter_(
        to_process_mask[..., None], valid_tangent_tp1_theta_phi
    )

    # Update state variables based upon new streamline statuses.
    tmp_len = torch.where(to_process_mask, streamline_len + step_size, streamline_len)
    status_tp1 = pitn.tract.stopping.streamline_len_mm(
        streamline_status,
        tmp_len,
        min_len=min_streamline_len,
        max_len=max_streamline_len,
    )
    status_tp1 = pitn.tract.stopping.gfa_threshold(
        status_tp1,
        sample_coords_mm_zyx=points_tp1,
        gfa_min_threshold=gfa_min_threshold,
        gfa_vol=gfa,
        affine_vox2mm=affine_sar_vox2sar_mm,
    )

    tp1_continue_mask = pitn.tract.stopping.to_continue_mask(status_tp1)

    points_tp1.masked_fill_(~tp1_continue_mask[..., None], torch.nan)

    # t <- t + 1
    print(t, end=" ")
    t += 1
    if t > t_max:
        break

    streamlines.append(points_tp1)
    points_t = points_tp1
    # Zero the tangents of streamlines that just stopped.
    tangent_t_theta_phi = torch.where(
        tp1_continue_mask[..., None], tangent_tp1_theta_phi, 0
    )
    tangent_t_zyx = torch.where(tp1_continue_mask[..., None], tangent_tp1_zyx, 0)
    streamline_len = torch.where(
        tp1_continue_mask,
        streamline_len + step_size,
        streamline_len,
    )
    streamline_status = status_tp1
    # if (streamline_status != pitn.tract.stopping.CONTINUE).any():
    #     print("Stopped a tract!")
    #     break
# Shape `tract_seed x n_steps x 3`
streamlines = torch.stack(streamlines, 1)
# Flush the step counter output.
print("", end="", flush=True)

In [None]:
# One point array per streamline, with the NaN rows of stopped steps dropped.
tracts = [
    pts[(~np.isnan(pts)).any(-1)]
    for pts in map(
        np.squeeze,
        np.split(streamlines.detach().cpu().numpy(), streamlines.shape[0], axis=0),
    )
]
sar_tracts = dipy.io.dpy.Streamlines(tracts)
# `affine_to_rasmm` maps the SAR-ordered coordinates back to RAS+; `to_world` applies it.
# NOTE(review): nibabel's `to_world` may transform `sar_tracto` in place.
sar_tracto = dipy.io.streamline.Tractogram(
    sar_tracts, affine_to_rasmm=affine_sar2ras.cpu().numpy()
)
tracto = sar_tracto.to_world()

# Get the header from an "un-re-oriented" fod volume and give to the tractogram.
ref_header = nib.as_closest_canonical(nib.load(sample_fod_f)).header
tracto = dipy.io.streamline.StatefulTractogram(
    tracto.streamlines,
    space=dipy.io.stateful_tractogram.Space.RASMM,
    reference=ref_header,
)

In [None]:
# Write the tractogram to a temporary .tck file for inspection.
# NOTE(review): the filename says "dipolar", but the seed voxel selected above is
# `cc_lr_lobe_idx` (uni-modal) unless that selection line was changed.
dipy.io.streamline.save_tck(tracto, "/tmp/dipolar_single_vox_test_trax.tck")

In [None]:
# Coordinates of streamline 3 against step index (stored coordinate order is z, y, x).
# NOTE(review): index 3 assumes at least four streamlines were generated.
fig, ax = plt.subplots()
ax.plot(streamlines[3, :, 2].cpu().numpy(), label="x")
ax.plot(streamlines[3, :, 1].cpu().numpy(), label="y")
ax.plot(streamlines[3, :, 0].cpu().numpy(), label="z")

ax.legend();

In [None]:
# Save pitn's GFA map as NIfTI for inspection. The spatial dims are swapped from
# (z, y, x) back to (x, y, z) to match the RAS+ affine and canonical header.
# NOTE(review): this writes into the subject's fODF output directory.
im = nib.Nifti1Image(
    gfa[0].cpu().swapdims(0, 2).numpy(), affine_ras_vox2ras_mm.cpu().numpy(), ref_header
)

nib.save(im, str(sample_fod_f.parent / "gfa.nii.gz"))

In [None]:
# Recompute the GFA sphere samples (the stopping-condition cell deletes its copy) so
# they can be fed to dipy's GFA below.
# NOTE(review): duplicates the stopping-condition cell's sampling code.
gfa_sampling_sphere = dipy.data.get_sphere("repulsion724").subdivide(1)

gfa_theta, gfa_phi = pitn.odf.get_torch_sample_sphere_coords(
    gfa_sampling_sphere, coeffs.device, coeffs.dtype
)
# Function applies non-negativity constraint.
gfa_sphere_samples = pitn.odf.sample_sphere_coords(
    coeffs.cpu(),
    theta=gfa_theta.cpu(),
    phi=gfa_phi.cpu(),
    sh_order=8,
    sh_order_dim=0,
    mask=brain_mask.cpu(),
)

In [None]:
# dipy's GFA on the same sphere samples, for comparison with pitn's GFA. The sample
# axis is moved last and spatial axes 0 and 2 swapped (presumably zyx -> xyz); NaNs
# become 0.
# NOTE(review): `dipy.direction` is not imported explicitly in the import cell.
dipy_gfa = np.nan_to_num(
    dipy.direction.gfa(
        gfa_sphere_samples.cpu().movedim(0, -1).swapdims(0, 2).numpy()
    ),
    nan=0,
)

In [None]:
# Save dipy's GFA map (not pitn's `im`) next to the fODF file, in RAS+ voxel space.
dipy_im = nib.Nifti1Image(dipy_gfa, affine_ras_vox2ras_mm.cpu().numpy(), ref_header)
nib.save(dipy_im, str(sample_fod_f.parent / "dipy_gfa.nii.gz"))

### Tractogram & Sampling Testing

In [None]:
# Display the path of the pitn GFA volume saved above.
sample_fod_f.parent / "gfa.nii.gz"

In [None]:
# Display the Tractogram built from the SAR-ordered streamlines.
sar_tracto

In [None]:
# Debug: write a synthetic straight tract starting at the seed voxel, to check the
# coordinate conventions of the saved .tck file.
# print(seeds_t_neg1[:, 0].unique().mean())

# Vox->mm affine of the re-oriented volumes: the one used by the tracking loop and by
# `fod_coeff_im`, whose header is the reference below. Bound here so the cell does not
# depend on interactive kernel state.
affine_vox2mm = affine_sar_vox2sar_mm

start_point = pitn.affine.coord_transform_3d(
    affine_vox2mm.new_tensor(select_vox_idx), affine_vox2mm
)
# 100 points stepping away from the start point in 0.2 mm increments along the unit
# direction (mm origin -> start point).
fake_tracts = start_point + (
    0.2 * torch.arange(100, dtype=start_point.dtype, device=start_point.device)[:, None]
) * (start_point / torch.linalg.norm(start_point, 2))
fake_tracts = fake_tracts[None]
fake_tracts = np.split(fake_tracts.detach().cpu().numpy(), fake_tracts.shape[0], axis=0)
fake_tracts = [
    arr.squeeze()[(~np.isnan(arr.squeeze())).any(-1)] for arr in fake_tracts
]
fake_tracto = dipy.io.dpy.Streamlines(fake_tracts)
# Points are world (mm) coordinates under `fod_coeff_im`'s affine (assuming it is still
# the re-oriented image), so they are declared RASMM relative to that image's header.
fake_tracto = dipy.io.streamline.StatefulTractogram(
    fake_tracto,
    space=dipy.io.stateful_tractogram.Space.RASMM,
    reference=fod_coeff_im.header,
)

dipy.io.streamline.save_tck(fake_tracto, "/tmp/fake_trax.tck")

In [None]:
# Vox->mm affine of the re-oriented volumes (bound here so the cell runs on a fresh
# kernel), as a numpy array for the orientation checks below.
affine_vox2mm = affine_sar_vox2sar_mm
print(affine_vox2mm)
aff = affine_vox2mm.cpu().numpy()

In [None]:
# Debug: inspect nibabel's orientation of `aff` against SAR and compose the resulting
# orientation affine with `aff`.
# NOTE(review): overwrites the `ornt_ras` / `ornt_sar` globals of the data-loading cell;
# `ornt_ras` here is the orientation of whatever `aff` holds, which need not be RAS.
print(aff)
ornt_ras = nib.orientations.io_orientation(aff)
ornt_sar = nib.orientations.axcodes2ornt(("S", "A", "R"))
print(ornt_ras)
print(ornt_sar)
ornt_sar2ras = nib.orientations.ornt_transform(ornt_sar, ornt_ras)
# Spatial shape = dims of `coeffs` after the leading SH-coefficient dim.
aff_ras2sar = nib.orientations.inv_ornt_aff(ornt_sar2ras, tuple(coeffs.shape[1:]))

print(aff_ras2sar)

# Both lines print the same matrix product.
print(aff_ras2sar.dot(aff))
print(aff_ras2sar @ aff)

## fODF Peak Finding

In [None]:
# Reload the subject's fODF and brain mask in canonical RAS+ orientation for the peak
# finding experiments; this section does not re-orient to SAR.
# NOTE(review): overwrites `sample_fod_f`, `fod_coeff_im`, `mask_im` and the voxel
# indices from the tractography section.
sample_fod_f = (
    hcp_full_res_fodf_dir / "162329" / "T1w" / "postproc_wm_msmt_csd_fod.nii.gz"
)
fod_coeff_im = nib.load(sample_fod_f)
fod_coeff_im = nib.as_closest_canonical(fod_coeff_im)
mask_f = sample_fod_f.parent / "postproc_nodif_brain_mask.nii.gz"
mask_im = nib.load(mask_f)
mask_im = nib.as_closest_canonical(mask_im)

# Pre-select voxels of interest for this specific subject.
# CC forceps minor, strong L-R uni-modal lobe
cc_lr_lobe_idx = (55, 98, 53)
# Dual-polar approx. equal volume fiber crossing
lr_and_ap_bipolar_lobe_idx = (70, 106, 54)
# Vox. adjacent to CST, tri-polar
tri_polar_lobe_idx = (60, 68, 43)

In [None]:
# CPU tensors for this section: SH coefficients channels-first, and the brain mask with
# a leading singleton channel dim.
coeffs = torch.from_numpy(fod_coeff_im.get_fdata()).movedim(-1, 0)
mask = torch.from_numpy(mask_im.get_fdata().astype(bool)).unsqueeze(0)

print(coeffs.shape)
print(mask.shape)

In [None]:
# Map the pre-selected voxel indices to world (mm) coordinates with the RAS+ affine.
aff = torch.from_numpy(fod_coeff_im.affine)
print(aff)
print(aff.shape)

p1, p2, p3 = (
    torch.as_tensor(vox).float()
    for vox in (cc_lr_lobe_idx, lr_and_ap_bipolar_lobe_idx, tri_polar_lobe_idx)
)
p = torch.stack((p1, p2, p3))
print(p)
print(p.shape)
p_mm = pitn.affine.coord_transform_3d(p, aff)
print(p_mm)

In [None]:
# Affine from voxel indices to normalized grid coordinates in [-1, 1]: index 0 -> -1 and
# index `size - 1` -> +1 along each spatial axis. Composed with world->voxel, then
# checked on a few world-space points.
vol_shape = coeffs.shape
aff_mm2vox = torch.linalg.inv(aff)
aff_vox2grid = torch.eye(4).to(aff_mm2vox)
aff_diag = 2 / (torch.as_tensor(vol_shape[-3:]) - 1)
aff_diag = torch.cat([aff_diag, aff_diag.new_ones(1)], 0)
# `aff_diag` has the default float dtype; match the affine's dtype/device explicitly.
aff_vox2grid = aff_vox2grid.diagonal_scatter(aff_diag.to(aff_vox2grid))
aff_vox2grid[:3, 3:4] = -1
print(aff_vox2grid)

aff_mm2grid = aff_vox2grid @ aff_mm2vox
# Grid coordinates of the selected voxels' world positions.
print(pitn.affine.coord_transform_3d(p_mm, aff_mm2grid))
print(
    pitn.affine.coord_transform_3d(torch.as_tensor([68.75, 66.5, 73]), aff_mm2grid)
)
print(
    pitn.affine.coord_transform_3d(
        torch.as_tensor([-67.5, -99.75, -60.75]), aff_mm2grid
    )
)

In [None]:
# Test mask sampling.
# Nearest-neighbor sampling of the brain mask at hand-picked world (mm) coordinates;
# the trailing comment on each point is its expected mask value.
print(mask.shape)
print(mask_im.affine == fod_coeff_im.affine)
aff = torch.from_numpy(fod_coeff_im.affine).to(torch.float32)
p1 = torch.tensor([-51.25, 22.75, -2])  # Should be True, may have off-by-one error
p2 = torch.tensor([-67.5, -99.75, -60.75])  # False
p3 = torch.tensor([-67.5, -99.75, -48.25])  # False
p4 = torch.tensor([-1.25, -2.25, -33])  # False
p5 = torch.tensor([3.75, -18.5, 0.5])  # True
p6 = torch.tensor([-46.25, -53.5, -58.25])  # True, inserted manually into mask.
p = torch.stack([p1, p2, p3, p4, p5, p6], 0)
# Work on a copy so the inserted voxel does not leak into `mask`.
m = torch.clone(mask)
m[:, 17, 37, 2] = 1  # Corresponds to p6
mask_samples = pitn.affine.sample_3d(m, p, aff, mode="nearest", align_corners=True)
print(mask_samples)
print(mask_samples.shape)

In [None]:
# Test sampling: `pitn.affine.sample_3d` on a 4x4x4 ramp volume with an identity affine.
vol = torch.arange(0, 4**3).reshape(1, 1, 4, 4, 4).float()
print(vol.shape)
aff = torch.eye(4)
p = torch.tensor(
    [
        [0, 0, 0],
        [3, 3, 3],
        [0, 0, 3],
        [2, 1, 1],
        [4, 4, 4],  # outside the 0..3 index range
        [2.7095, 1.75, 1.5],  # off-grid point
    ]
)
print(p.shape)
# Sample with the coordinate order reversed along the last axis.
samples = pitn.affine.sample_3d(vol, p.flip(-1), aff)
print(samples)
print(samples.shape)
# Direct-indexing reference values for the first four (on-grid) points.
print(vol.squeeze()[tuple(p[:-2].T.long())])

In [None]:
# # Change orientation for visualization.
# new_ornt = nib.orientations.axcodes2ornt(tuple("IPR"))
# ornt_tf = nib.orientations.ornt_transform(
#     nib.orientations.axcodes2ornt(nib.orientations.aff2axcodes(fod_coeff_im.affine)), new_ornt
# )
# coeffs = fod_coeff_im.as_reoriented(ornt_tf).get_fdata()
# coeffs = torch.from_numpy(coeffs)
# # Move to channels-first layout.
# coeffs = coeffs.movedim(-1, 0)
# mask = mask_im.as_reoriented(ornt_tf).get_fdata().astype(bool)
# mask = torch.from_numpy(mask)[None]

# print(coeffs.shape)
# print(mask.shape)

# # Transform the points of interest to the new coord layout.
# print("\nTransforming voxel coordinates of interest.")
# affine_vox2ras_phys = fod_coeff_im.affine
# affine_vox2ipr_phys = fod_coeff_im.as_reoriented(ornt_tf).affine
# affine_ipr_phys2vox = np.linalg.inv(affine_vox2ipr_phys)
# p_vox_ipr = list()
# for p in (cc_lr_lobe_idx, lr_and_ap_bipolar_lobe_idx, tri_polar_lobe_idx):
#     p = np.asarray(p)[:, None]
#     p_phys = (affine_vox2ras_phys[:3, :3] @ p) + affine_vox2ras_phys[:3, 3:4]
#     p_orient = (affine_ipr_phys2vox[:3, :3] @ p_phys) + affine_ipr_phys2vox[:3, 3:4]
#     print(p_orient.flatten().astype(int))
#     p_vox_ipr.append(tuple(p_orient.flatten().astype(int)))
# cc_lr_lobe_idx, lr_and_ap_bipolar_lobe_idx, tri_polar_lobe_idx = tuple(p_vox_ipr)
# print(cc_lr_lobe_idx, lr_and_ap_bipolar_lobe_idx, tri_polar_lobe_idx)

In [None]:
# Hemisphere of sample directions, and the fODF evaluated at every masked voxel of the
# RAS+ volume.
# sphere = dipy.data.HemiSphere.from_sphere(dipy.data.get_sphere("repulsion200"))
sphere = dipy.data.HemiSphere.from_sphere(dipy.data.get_sphere("repulsion724"))

theta, phi = pitn.odf.get_torch_sample_sphere_coords(
    sphere, coeffs.device, coeffs.dtype
)
with torch.no_grad():
    # Function applies non-negativity constraint.
    sphere_samples = pitn.odf.sample_sphere_coords(
        coeffs, theta=theta, phi=phi, sh_order=8, sh_order_dim=0, mask=mask
    )

# Element 0 is read as neighbor indices, element 1 as a validity mask (meaning assumed
# from the names; TODO confirm in pitn.odf).
nearest_sphere_samples = pitn.odf.adjacent_sphere_points_idx(theta=theta, phi=phi)
nearest_sphere_samples_idx = nearest_sphere_samples[0]
nearest_sphere_samples_valid_mask = nearest_sphere_samples[1]

### Fast-Marching Level Set (FMLS) Segmentation

In [None]:
# Threshold parameter from Algorithm 1 in Appendix A of SIFT paper.
min_sample_pdf_threshold = 0.0001

# FMLS threshold passed to `fmls_fodf_seg` below, presumably as its
# `lobe_merge_ratio` (same value as in the tractography section).
peak_diff_threshold = 0.8

# Post-segmentation lobe filtering thresholds.
min_lobe_pdf_peak_threshold = 1e-5
min_lobe_pdf_integral_threshold = 0.05

# Single voxel lobe segmentation
# vox_idx = cc_lr_lobe_idx
# vox_idx = lr_and_ap_bipolar_lobe_idx
# vox_idx = tri_polar_lobe_idx

In [None]:
# A 21^3-voxel block centered on the dual-polar crossing voxel
# (lr_and_ap_bipolar_lobe_idx = (70, 106, 54)), flattened to `voxels x samples`.
fodf_idx_range = (slice(60, 81), slice(96, 117), slice(44, 65))

b_fodf = einops.rearrange(
    sphere_samples[(slice(None), *fodf_idx_range)], "s ... -> (...) s"
)
# Remove low fodf values (count them as "noise").
b_fodf = pitn.odf.thresh_fodf_samples_by_pdf(b_fodf, min_sample_pdf_threshold)

b_fodf = b_fodf.to(device=device, dtype=torch.float32)
b_theta = theta.to(device=device, dtype=torch.float32)
b_phi = phi.to(device=device, dtype=torch.float32)

In [None]:
# Perform FMLS segmentation.
lobe_labels = pitn.tract.peak.fmls_fodf_seg(
    b_fodf, peak_diff_threshold, theta=b_theta, phi=b_phi
)

In [None]:
# Refine lobe labels.
lobe_labels = pitn.tract.peak.remove_fodf_labels_by_pdf(
    lobe_labels,
    b_fodf,
    pdf_peak_min=min_lobe_pdf_peak_threshold,
    pdf_integral_min=min_lobe_pdf_integral_threshold,
)

In [None]:
# Per voxel and per positive lobe label, find the index of the lobe's largest fODF
# sample. Invalid entries are marked -1 until `valid_peak_mask` is computed.
unique_labels = lobe_labels.unique()
unique_labels = unique_labels[unique_labels > 0]

peak_idx = torch.full(
    (lobe_labels.shape[0], len(unique_labels)),
    -1,
    dtype=torch.long,
    device=b_fodf.device,
)
for col, label in enumerate(unique_labels):
    # fODF values inside this lobe; -1 everywhere else.
    in_lobe_vals = torch.where(lobe_labels == label, b_fodf, -1)
    argmax_idx = in_lobe_vals.argmax(dim=1, keepdim=True)
    # Keep the argmax only where the lobe actually has a positive value in that voxel.
    has_peak = in_lobe_vals.take_along_dim(argmax_idx, dim=1) > 0
    peak_idx[:, col] = torch.where(has_peak, argmax_idx, -1).flatten()

valid_peak_mask = peak_idx >= 0
peak_vals = torch.where(
    valid_peak_mask, b_fodf.take_along_dim(peak_idx.clamp_min(0), dim=1), -1
)
# The invalid indices are set to 0 to avoid subtle indexing errors later on; cuda in
# particular hates indexing out-of-bounds of a Tensor. Even though it is possible that
# an index value of 0 is valid, this is the only way to avoid those errors. The valid
# peak mask must be used to distinguish between real peak indices and those that are
# actually valued at 0.
peak_idx = peak_idx.clamp_min(0)
print(peak_vals.shape)
print(peak_idx.shape)
print(valid_peak_mask.shape)

In [None]:
# Spherical coordinates of each peak, zeroed where the peak is invalid.
peak_theta = torch.take(b_theta, index=peak_idx) * valid_peak_mask
peak_phi = torch.take(b_phi, index=peak_idx) * valid_peak_mask

# Synthetic incoming direction for every voxel: theta = pi/4, phi = -pi/2.
entry_dirs = torch.stack(
    [
        torch.ones_like(peak_theta[:, 0]) * torch.pi / 4,
        torch.ones_like(peak_phi[:, 0]) * -torch.pi / 2,
    ],
    dim=-1,
)
peak_dirs = torch.stack([peak_theta, peak_phi], dim=-1)
# Same module (`pitn.tract.direction`) and keywords as the tracking direction function.
near_directs, near_peaks = pitn.tract.direction.closest_opposing_direction(
    entry_dirs,
    fodf_peaks=peak_vals,
    peak_coords_theta_phi=peak_dirs,
    peaks_valid_mask=valid_peak_mask,
)

In [None]:
torch.finfo(torch.float)  # Inspect float32 limits (torch.float is an alias of torch.float32).

In [None]:
# # Visualize lobe segmentation result
# # Plot 3D surface of odf.
# %matplotlib widget

# # "surface" or "points"
# to_plot = "points"
# post_seg_filter = True

# viz_sphere = sphere
# viz_theta, viz_phi = pitn.odf.get_torch_sample_sphere_coords(
#     viz_sphere, coeffs.device, coeffs.dtype
# )
# polar_tri = mpl.tri.Triangulation(viz_phi, viz_theta)
# polar_tri_idx = torch.from_numpy(polar_tri.triangles).long()

# # Take labels from pre-segmentation filtering.
# if not post_seg_filter:
#     viz_tri_labels = lobe_labels.flatten()[polar_tri_idx]
# else:
#     # Take labels from post-segmentation filtering.
#     viz_tri_labels = ll.flatten()[polar_tri_idx]

# label_cmap = sns.cubehelix_palette(
#     n_colors=len(np.unique(viz_tri_labels.flatten())), reverse=True, rot=2, as_cmap=True
# )
# # viz_tri_labels = torch.mean(viz_tri_labels.float(), dim=1)
# viz_tri_labels = torch.median(viz_tri_labels, dim=1).values

# with torch.no_grad():
#     viz_coeffs = coeffs[(slice(None),) + vox_idx][:, None, None, None]
#     viz_mask = mask[(slice(None),) + vox_idx][:, None, None, None]
#     # Function applies non-negativity constraint.
#     viz_sphere_samples = pitn.odf.sample_sphere_coords(
#         viz_coeffs,
#         theta=viz_theta,
#         phi=viz_phi,
#         sh_order=8,
#         sh_order_dim=0,
#         mask=viz_mask,
#     )

# viz_fodf = np.copy(viz_sphere_samples.detach().cpu().numpy().flatten())
# viz_tri_labels = viz_tri_labels.detach().cpu().numpy().flatten()

# viz_theta = viz_theta.detach().cpu().numpy().flatten()
# viz_phi = viz_phi.detach().cpu().numpy().flatten()
# directions, values, indices = dipy.direction.peak_directions(
#     viz_fodf, viz_sphere, relative_peak_threshold=0.5, min_separation_angle=25
# )
# # viz_fodf[viz_fodf < values.min() * 0.3] = 1e-8
# with mpl.rc_context({"figure.autolayout": False}):
#     fig = plt.figure(dpi=120)

#     ax = fig.add_subplot(projection="3d")

#     vals = viz_fodf

#     r = (vals - vals.min()) / (vals - vals.min()).max()
#     r = vals / vals.sum()

#     x = r * np.sin(viz_theta) * np.cos(viz_phi)
#     y = r * np.sin(viz_theta) * np.sin(viz_phi)
#     z = r * np.cos(viz_theta)
#     mapper = mpl.cm.ScalarMappable(cmap=label_cmap)

#     # center_colors = mapper.to_rgba(viz_tri_labels)
#     # center_colors = np.where((viz_tri_labels == 0)[:, None], np.zeros_like(center_colors), center_colors)
#     #     vertex_colors = mapper.to_rgba(ll.numpy().flatten()[polar_tri.edges])
#     # polar_tri_idx = torch.from_numpy(polar_tri.triangles).long()
#     # viz_tri_labels = ll.flatten()[polar_tri_idx]
#     euclid_tri = mpl.tri.Triangulation(x, y, triangles=polar_tri.triangles)
#     surf = ax.plot_trisurf(euclid_tri, z, linewidth=0.3, antialiased=True, zorder=4)

#     if to_plot == "surface":
#         face_colors = mapper.to_rgba(viz_tri_labels)
#         face_colors = np.where(
#             (viz_tri_labels == 0)[:, None], np.zeros_like(face_colors), face_colors
#         )
#         surf.set_fc(face_colors)
#     elif to_plot == "points":
#         surf.set_fc("white")
#         surf.set_edgecolors([0.2, 0.2, 0.2, 0.5])
#         if post_seg_filter:
#             point_colors = mapper.to_rgba(ll.numpy().flatten())
#             point_colors = np.where(
#                 (ll.numpy().flatten() == 0)[:, None],
#                 np.zeros_like(point_colors),
#                 point_colors,
#             )
#         else:
#             point_colors = mapper.to_rgba(lobe_labels.numpy().flatten())
#             point_colors = np.where(
#                 (lobe_labels.numpy().flatten() == 0)[:, None],
#                 np.zeros_like(point_colors),
#                 point_colors,
#             )
#         ax.scatter3D(x, y, z, c=point_colors, s=20, zorder=0.1)
#     plt.colorbar(mapper, shrink=0.67)
#     plt.show()