# Comparison: Unprocessed HCP DWI Against Clinically Acquired DWI

Tyler Spears

Dr. Tom Fletcher

## Setup

In [None]:
import itertools
import functools
import pathlib
from pathlib import Path
import os
import io
import subprocess

import numpy as np
import scipy
import skimage
import skimage.filters
import dipy
import dipy.align, dipy.align.imaffine, dipy.viz, dipy.viz.regtools, dipy.segment, dipy.segment.mask
import ants
import nibabel as nib
import dotenv
import box
from box import Box

# visualization libraries
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d
import seaborn as sns

# Every figure: tight automatic layout and an opaque white background.
plt.rcParams.update(
    {
        "figure.autolayout": True,
        "figure.facecolor": [1.0, 1.0, 1.0, 1.0],
    }
)

# ndarray printing: fixed-point floats, summarize arrays above 100 elements,
# wrap lines at 88 characters.
np.set_printoptions(linewidth=88, suppress=True, threshold=100)

In [None]:
# Update the notebook's environment variables with direnv.
# This requires the python-dotenv package, and that direnv is installed on the
# system. This will not work on Windows.
# NOTE: This is kind of hacky, and not necessarily safe: every variable printed by
# `env` below overrides the notebook's current value. Be careful...
# Libraries needed on the python side:
# - os
# - subprocess
# - io
# - dotenv

# Run `env` inside direnv's context for the current directory; it prints all
# environment variables defined in that sub-process. The command is an argument
# list run without a shell, so a working directory containing spaces or shell
# metacharacters is passed to direnv intact instead of being split/interpreted.
# If `direnv` is not on PATH, Popen raises FileNotFoundError.
command = ["direnv", "exec", os.getcwd(), "/usr/bin/env"]
with subprocess.Popen(command, stdout=subprocess.PIPE, cwd=os.getcwd()) as proc:
    # Store and format the subprocess' output.
    proc_out = proc.communicate()[0].strip().decode("utf-8")
# Use python-dotenv to load the environment variables by using the output of
# 'direnv exec ...' as a 'dummy' .env file.
dotenv.load_dotenv(stream=io.StringIO(proc_out), override=True);

In [None]:
# Watermark: print run metadata for reproducibility. The flags request the author,
# last-updated timestamp (ISO 8601 format), Python and machine info, versions of
# imported packages, and the current git commit hash.
%load_ext watermark
%watermark --author "Tyler Spears" --updated --iso8601  --python --machine --iversions --githash

### Locate and Load Data Files

In [None]:
# Set up directories


def _existing_dir(path):
    """Return ``path`` unchanged, raising FileNotFoundError if it does not exist."""
    # Explicit check rather than `assert`, which is skipped under `python -O`.
    if not path.exists():
        raise FileNotFoundError(f"Required data directory does not exist: {path}")
    return path


# HCP directories, rooted at the DATA_DIR / WRITE_DATA_DIR environment variables.
hcp_data_dir = _existing_dir(pathlib.Path(os.environ["DATA_DIR"]) / "hcp")
write_data_dir = _existing_dir(pathlib.Path(os.environ["WRITE_DATA_DIR"]) / "hcp")

# Clinical directory. Defaults to the original hard-coded location, but can be
# overridden with the CLINICAL_DATA_DIR environment variable.
clinical_data_dir = _existing_dir(
    Path(os.environ.get("CLINICAL_DATA_DIR", "/mnt/storage/data/pitn/uva"))
)

In [None]:
# Import files
# HCP
# Subjects of interest.
hcp_subj_ids = [
    "140117",
]
# default_box=True so nested subject -> scan entries can be assigned directly.
hcp = Box(default_box=True)

# Loop over subjects
for subj_id in hcp_subj_ids:
    subj_dir = hcp_data_dir / str(subj_id)
    unproc_dir = subj_dir / "unprocessed/3T/Diffusion"
    # Each scan is a <scan_name>.bval / .bvec / .nii.gz triple. Sort the glob so
    # scans are inserted (and later iterated) in a deterministic order; later
    # cells take the first stored scan as their example.
    for bval_file in sorted(unproc_dir.glob("*.bval")):
        scan_name = bval_file.stem
        bvec_file = unproc_dir / f"{scan_name}.bvec"
        nifti_file = unproc_dir / f"{scan_name}.nii.gz"
        hcp[subj_id][scan_name].bval = np.loadtxt(bval_file)
        hcp[subj_id][scan_name].bvec = np.loadtxt(bvec_file)
        hcp[subj_id][scan_name].dwi = nib.load(nifti_file)

# Clinical
clinic_subj_ids = [
    "001",
]
clinic = Box(default_box=True)

# Loop over subjects
for subj_id in clinic_subj_ids:
    subj_dir = clinical_data_dir / str(subj_id)
    # File names are built from the subject ID (sub-<id>_ses-01_run-2_dwi.*)
    # rather than hard-coding subject 001, so every listed subject loads its own scan.
    scan_stem = f"sub-{subj_id}_ses-01_run-2_dwi"
    nifti_file = subj_dir / f"{scan_stem}.nii.gz"
    bval_file = subj_dir / f"{scan_stem}.bval"
    bvec_file = subj_dir / f"{scan_stem}.bvec"
    clinic[subj_id].bval = np.loadtxt(bval_file)
    clinic[subj_id].bvec = np.loadtxt(bvec_file)
    clinic[subj_id].dwi = nib.load(nifti_file)

### Calculate Masks

In [None]:
def rough_b0_mask(dwi):
    """Estimate a rough mask from a single b0 volume.

    Runs dipy's ``median_otsu`` with ``dilate=5`` and keeps only the mask it
    returns, discarding the masked volume.

    Parameters
    ----------
    dwi : ndarray
        One b0 volume. The callers in this notebook pass a single volume
        sliced out of a 4D DWI series.

    Returns
    -------
    ndarray
        The mask output of ``median_otsu``.
    """
    _masked_volume, b0_mask = dipy.segment.mask.median_otsu(dwi, dilate=5)
    return b0_mask

In [None]:
# Calculate HCP masks from the b0 volumes (b-value <= 100): each scan's mask is the
# union (logical OR) of the rough masks of its b0s, and the subject's mask is the
# union of all scan masks.
for subj_id in hcp_subj_ids:
    # Only Box/dict entries are scans. Skipping anything else (such as the
    # subject-level `mask` array stored below) keeps this cell safe to re-run.
    scan_names = [
        name for name, entry in hcp[subj_id].items() if isinstance(entry, dict)
    ]
    subject_mask = np.zeros(hcp[subj_id][scan_names[0]].dwi.shape[:-1], dtype=bool)
    for scan_name in scan_names:
        scan = hcp[subj_id][scan_name]
        # Fetch the 4D data once per scan and slice each b0 volume from it.
        scan_data = scan.dwi.get_fdata()
        scan_mask = np.zeros(scan.dwi.shape[:-1], dtype=bool)
        for j_b0 in np.where(scan.bval <= 100)[0]:
            scan_mask = scan_mask | rough_b0_mask(scan_data[..., j_b0])
        hcp[subj_id][scan_name].mask = scan_mask
        subject_mask = subject_mask | scan_mask

    hcp[subj_id].mask = subject_mask

In [None]:
# Display an example: overlay volume 0 of the first HCP subject's first scan with
# that scan's mask, once per slicing axis (0, 1, 2).
# NOTE(review): volume 0 is assumed to be a b0 (per the "HCP b0" title); confirm
# against scan.bval.
s = hcp[hcp_subj_ids[0]]
first_scan_key = list(s)[0]
scan = s[first_scan_key]
img = scan.dwi.get_fdata()[..., 0]
mask = scan.mask
for i in range(3):
    overlay_fig = dipy.viz.regtools.overlay_slices(
        img, mask, None, i, "HCP b0", "Mask"
    )
    overlay_fig.set_dpi(170)

In [None]:
# Calculate clinical data masks: the union (logical OR) of the rough masks of every
# b0 volume (b-value <= 100) of each subject.
for subj_id in clinic_subj_ids:
    subject_data = clinic[subj_id].dwi.get_fdata()
    subject_mask = np.zeros(clinic[subj_id].dwi.shape[:-1], dtype=bool)

    # Index with this loop's own b0 index. Using any other name would silently
    # pick up a stale global from another cell.
    for i_b0 in np.where(clinic[subj_id].bval <= 100)[0]:
        img_mask = rough_b0_mask(subject_data[..., i_b0])
        subject_mask = subject_mask | img_mask

    clinic[subj_id].mask = subject_mask

In [None]:
# Display an example: overlay volume 0 of the first clinical subject with its
# mask, once per slicing axis (0, 1, 2).
# NOTE(review): volume 0 is assumed to be a b0 (per the "Clinical b0" title);
# confirm against s.bval.
s = clinic[clinic_subj_ids[0]]
mask = s.mask
img = s.dwi.get_fdata()[..., 0]
for i in range(3):
    overlay_fig = dipy.viz.regtools.overlay_slices(
        img, mask, None, i, "Clinical b0", "Mask"
    )
    overlay_fig.set_dpi(170)

## b-Values and b-Vectors

### HCP Unprocessed DWI Attributes

In [None]:
# Scatter the b-vectors of every scan of the first HCP subject in 3D, plus their
# 2D projections onto the planes z=-1 (x-y, blue), y=1 (x-z, green), and
# x=-1 (y-z, red).
# There's a stackoverflow post specifically for visualizing b vectors!
# <https://stackoverflow.com/a/63708529>

# Optional overlays (off by default):
draw_unit_sphere = False  # translucent green unit sphere for reference
draw_arrows = False  # arrows from the origin to each b-vector

fig = plt.figure(dpi=150)
ax = fig.add_subplot(projection="3d")
# Only Box/dict entries are scans (this skips e.g. a subject-level mask array).
# Scans are concatenated along axis 1, so row k holds component k of every vector.
bvec = np.concatenate(
    [scan.bvec for scan in hcp[hcp_subj_ids[0]].values() if isinstance(scan, dict)],
    axis=1,
)

if draw_unit_sphere:
    u, v = np.mgrid[0 : 2 * np.pi : 50j, 0 : np.pi : 50j]
    # alpha controls opacity
    ax.plot_surface(
        np.cos(u) * np.sin(v), np.sin(u) * np.sin(v), np.cos(v), color="g", alpha=0.5
    )
if draw_arrows:
    origin = np.zeros(len(bvec[0]))
    ax.quiver(
        origin,
        origin,
        origin,
        bvec[0],
        bvec[1],
        bvec[2],
        length=1.0,
        arrow_length_ratio=0.1,
        alpha=0.8,
        lw=0.3,
        color="black",
    )
ax.scatter(bvec[0], bvec[1], bvec[2], lw=0, color="black")

# Project each axis in 2D
ax.scatter(bvec[0], bvec[1], zs=-1, zdir="z", marker=".", lw=0, color="blue")
ax.scatter(bvec[0], bvec[2], zs=1, zdir="y", marker=".", lw=0, color="green")
ax.scatter(bvec[1], bvec[2], zs=-1, zdir="x", marker=".", lw=0, color="red")

ax.set_xlim(-1.0, 1.0)
ax.set_ylim(-1.0, 1.0)
ax.set_zlim(-1.0, 1.0)
ax.grid(True);

In [None]:
# Plot bvec directions colored by gradient strength (b-value), with a colorbar so
# the colors can be read back as b-values.
draw_arrows = False  # set True to also draw colored arrows from the origin

fig = plt.figure(dpi=150)
ax = fig.add_subplot(projection="3d")
# Only Box/dict entries are scans (this skips e.g. a subject-level mask array).
scans = [scan for scan in hcp[hcp_subj_ids[0]].values() if isinstance(scan, dict)]
bvec = np.concatenate([scan.bvec for scan in scans], axis=1)
bval = np.concatenate([scan.bval for scan in scans])

cmap = mpl.cm.inferno
# Colors are b-values scaled by the largest b-value; the colorbar uses the same range.
colors = cmap(bval / bval.max())
norm = mpl.colors.Normalize(vmin=0.0, vmax=bval.max())

if draw_arrows:
    origin = np.zeros(len(bvec[0]))
    ax.quiver(
        origin,
        origin,
        origin,
        bvec[0],
        bvec[1],
        bvec[2],
        length=1.0,
        arrow_length_ratio=0.1,
        lw=0.5,
        alpha=0.8,
        color=colors,
    )
ax.scatter(bvec[0], bvec[1], bvec[2], lw=0, color=colors, depthshade=False)
fig.colorbar(
    mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, shrink=0.7, label="b-value"
)

ax.set_xlim(-1.0, 1.0)
ax.set_ylim(-1.0, 1.0)
ax.set_zlim(-1.0, 1.0)

ax.grid(True);

In [None]:
# Histogram of the b-values pooled over every scan of subject 140117, using as
# many (equal-width) bins as there are distinct b-values.
fig, ax = plt.subplots()
bval = np.concatenate(
    [scan.bval for scan in hcp["140117"].values() if isinstance(scan, dict)]
)
n_bins = len(np.unique(bval))
ax.hist(bval, bins=n_bins);

In [None]:
# Overlay volume 0 of the dir95 and dir97 LR scans, once per slicing axis (0, 1, 2).
# The variable names match the scans they hold (dir95 -> dir_95).
dir_95 = hcp["140117"]["140117_3T_DWI_dir95_LR"].dwi.get_fdata()[..., 0]
dir_97 = hcp["140117"]["140117_3T_DWI_dir97_LR"].dwi.get_fdata()[..., 0]

for i in range(3):
    dipy.viz.regtools.overlay_slices(
        dir_95, dir_97, None, i, "95 Scan", "97 Scan"
    ).set_dpi(170)

### Clinical Unprocessed DWI Attributes

In [None]:
# Scatter the first clinical subject's b-vectors in 3D, plus their 2D projections
# onto the planes z=-1 (x-y, blue), y=1 (x-z, green), and x=-1 (y-z, red).

# Optional overlays (off by default):
draw_unit_sphere = False  # translucent green unit sphere for reference
draw_arrows = False  # arrows from the origin to each b-vector

fig = plt.figure(dpi=150)
ax = fig.add_subplot(projection="3d")
# Row k of bvec holds component k of every gradient direction.
bvec = clinic[clinic_subj_ids[0]].bvec

if draw_unit_sphere:
    u, v = np.mgrid[0 : 2 * np.pi : 50j, 0 : np.pi : 50j]
    # alpha controls opacity
    ax.plot_surface(
        np.cos(u) * np.sin(v), np.sin(u) * np.sin(v), np.cos(v), color="g", alpha=0.5
    )
if draw_arrows:
    origin = np.zeros(len(bvec[0]))
    ax.quiver(
        origin,
        origin,
        origin,
        bvec[0],
        bvec[1],
        bvec[2],
        length=1.0,
        arrow_length_ratio=0.1,
        alpha=0.8,
        lw=0.3,
        color="black",
    )
ax.scatter(bvec[0], bvec[1], bvec[2], color="black", lw=0)

# Project each axis in 2D
ax.scatter(bvec[0], bvec[1], zs=-1, zdir="z", marker=".", lw=0, color="blue")
ax.scatter(bvec[0], bvec[2], zs=1, zdir="y", marker=".", lw=0, color="green")
ax.scatter(bvec[1], bvec[2], zs=-1, zdir="x", marker=".", lw=0, color="red")

ax.set_xlim(-1.0, 1.0)
ax.set_ylim(-1.0, 1.0)
ax.set_zlim(-1.0, 1.0)
ax.grid(True);

In [None]:
# Plot the first clinical subject's bvec directions colored by gradient strength
# (b-value), with a colorbar so the colors can be read back as b-values.
draw_arrows = False  # set True to also draw colored arrows from the origin

fig = plt.figure(dpi=150)
ax = fig.add_subplot(projection="3d")
bvec = clinic[clinic_subj_ids[0]].bvec
bval = clinic[clinic_subj_ids[0]].bval

cmap = mpl.cm.inferno
# Colors are b-values scaled by the largest b-value; the colorbar uses the same range.
colors = cmap(bval / bval.max())
norm = mpl.colors.Normalize(vmin=0.0, vmax=bval.max())

if draw_arrows:
    origin = np.zeros(len(bvec[0]))
    ax.quiver(
        origin,
        origin,
        origin,
        bvec[0],
        bvec[1],
        bvec[2],
        length=1.0,
        arrow_length_ratio=0.1,
        lw=0.5,
        alpha=0.8,
        color=colors,
    )
ax.scatter(bvec[0], bvec[1], bvec[2], lw=0, color=colors, depthshade=False)
fig.colorbar(
    mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, shrink=0.7, label="b-value"
)

ax.set_xlim(-1.0, 1.0)
ax.set_ylim(-1.0, 1.0)
ax.set_zlim(-1.0, 1.0)

ax.grid(True);

## Patch-Level Statistics

In [None]:
def random_sample_windows(vol, window_shape, n_patches, step=1, mask=None):
    """Randomly sample ``n_patches`` distinct sliding windows from ``vol``.

    Window origins are placed every ``step`` elements along each axis (the same
    origins as ``skimage.util.view_as_windows(vol, window_shape, step)``). Only
    the sampled windows are copied, so memory use scales with ``n_patches``
    rather than with the total number of possible windows.

    Parameters
    ----------
    vol : array_like
        Input array. ``window_shape`` needs one entry per axis of ``vol``.
    window_shape : int or tuple of int
        Shape of each window. An int is used for every axis.
    n_patches : int
        Number of windows to sample, without replacement.
    step : int or tuple of int, optional
        Stride between consecutive window origins, per axis. Default 1.
    mask : array_like of bool, optional
        Same shape as ``vol``. If given, only windows containing at least one
        True element of ``mask`` are eligible.

    Returns
    -------
    ndarray
        Array of shape ``(n_patches, *window_shape)`` holding copies of the
        sampled windows.

    Raises
    ------
    ValueError
        If fewer than ``n_patches`` windows are eligible (raised by
        ``np.random.choice``), or if ``window_shape`` does not fit ``vol``.
    """
    vol = np.asarray(vol)
    if np.isscalar(window_shape):
        window_shape = (window_shape,) * vol.ndim
    window_shape = tuple(window_shape)
    if np.isscalar(step):
        step = (step,) * vol.ndim
    # Subsampling the stride-1 window view by `step` yields the same origins as
    # view_as_windows(..., step). Both are strided views, so nothing is copied yet.
    step_slices = tuple(slice(None, None, s) for s in step)
    windows = np.lib.stride_tricks.sliding_window_view(vol, window_shape)[step_slices]
    # Leading axes of `windows` index the window origin; trailing axes the window.
    grid_shape = windows.shape[: vol.ndim]

    if mask is None:
        candidates = np.arange(int(np.prod(grid_shape)))
    else:
        mask_windows = np.lib.stride_tricks.sliding_window_view(
            np.asarray(mask), window_shape
        )[step_slices]
        window_axes = tuple(range(vol.ndim, mask_windows.ndim))
        # Flat (C-order) indices of the windows that overlap the mask.
        candidates = np.flatnonzero(mask_windows.any(axis=window_axes))

    # Draw positions into the list of eligible windows, then map them back to
    # window origins. Fancy indexing copies only the selected windows.
    sample_idx = np.random.choice(len(candidates), n_patches, replace=False)
    origins = np.unravel_index(candidates[sample_idx], grid_shape)
    return windows[origins]

In [None]:
# Sample 100 random 9x9x9 windows from volume 0 of the dir95 LR scan, restricted
# to windows overlapping that scan's mask, and plot a histogram (with KDE) of
# the per-window intensity variance.
example_scan = hcp["140117"]["140117_3T_DWI_dir95_LR"]
img = example_scan.dwi.get_fdata()[..., 0]
mask = example_scan.mask

sample = random_sample_windows(img, (9, 9, 9), 100, mask=mask)
sample_var = sample.var(axis=(1, 2, 3))
sns.histplot(sample_var, kde=True);

---

Local-mean downscaling is *not* the same as N-dimensional interpolation:

In [None]:
# x = np.random.randint(0, 10, (10, 10)).astype(float)
# lm = skimage.transform.downscale_local_mean(x, (2, 2))
# interp = skimage.transform.rescale(x, (0.5, 0.5), anti_aliasing=True)
# lm - interp

In [None]:
# img = np.random.random((162, 190, 162))

# xs = np.arange(0 + ((2 - 1.25) / 2), (162 * 1.25), 2) / 1.25
# ys = np.arange(0 + ((2 - 1.25) / 2), (190 * 1.25), 2) / 1.25
# zs = np.arange(0 + ((2 - 1.25) / 2), (162 * 1.25), 2) / 1.25

# downsample_factor = 2 / 1.25
# target_window_size = downsample_factor

# window_size = max(
#     np.concatenate(
#         [
#             np.ceil(xs[1:]) - np.floor(xs[:-1]),
#             np.ceil(ys[1:]) - np.floor(ys[:-1]),
#             np.ceil(zs[1:]) - np.floor(zs[:-1]),
#         ]
#     ).astype(int)
# )
# print(window_size)

# g = np.stack(np.meshgrid(xs, ys, zs))

In [None]:
# %%prun

# c = 0
# g_iter = g.transpose(1, 2, 3, 0)

# full_indices = np.zeros(g.shape[1:] + (3,) + (window_size,) * 3, dtype=int)
# full_weights = np.zeros(g.shape[1:] + (window_size,) * 3, dtype=float)

# downsampled = np.zeros(g.shape[1:])

# for i in np.ndindex(*g_iter.shape[:-1]):

#     target_start_index = g_iter[i]
#     target_end_index = target_start_index + target_window_size

#     source_start_index = np.floor(target_start_index).astype(int)
#     source_end_index = source_start_index + window_size

#     ranges = list()
#     for i_start, i_end in zip(source_start_index, source_end_index):
#         ranges.append(np.arange(i_start, i_end))

#     window_idx = np.stack(np.meshgrid(*ranges))

#     vol_intersections = np.ones(window_idx.shape[1:])
#     for dim_i in range(window_idx[0].ndim):

#         axis_intersect = np.clip(
#             np.min(
#                 [
#                     np.broadcast_to(target_end_index[dim_i], window_idx[dim_i].shape),
#                     window_idx[dim_i] + 1,
#                 ],
#                 axis=0,
#             )
#             - np.max(
#                 [
#                     np.broadcast_to(target_start_index[dim_i], window_idx[dim_i].shape),
#                     window_idx[dim_i],
#                 ],
#                 axis=0,
#             ),
#             0,
#             np.inf,
#         )

#         vol_intersections = vol_intersections * axis_intersect

#     weights = vol_intersections / vol_intersections.sum()
#     values = img[tuple(window_idx)]
#     target_mean = (values * weights).sum()

#     downsampled[i] = target_mean
#     full_weights[i] = weights
#     full_indices[i] = window_idx

#     if i[1] == 100 and i[2] == 100:
#         print(i)

#     c += 1
#     if c > np.inf:
#         break

# full_indices = full_indices.reshape(
#     *full_indices.shape[0:3],
#     3,
#     -1,
# )
# full_indices = full_indices.transpose(3, 4, 0, 1, 2)
# full_weights = full_weights.reshape(*full_weights.shape[0:3], -1)
# full_weights = full_weights.transpose(3, 0, 1, 2)

---

In [None]:
# img = np.random.random((162, 190, 162))
# target_dim_size = 2
# source_dim_size = 1.25

# xs = (
#     np.arange(
#         0,
#         (162 * source_dim_size),
#         target_dim_size,
#     )
#     / source_dim_size
# )
# ys = (
#     np.arange(
#         0,
#         (190 * source_dim_size),
#         target_dim_size,
#     )
#     / source_dim_size
# )
# zs = (
#     np.arange(
#         0,
#         (162 * source_dim_size),
#         target_dim_size,
#     )
#     / source_dim_size
# )

# downsample_factor = target_dim_size / source_dim_size
# target_window_size = downsample_factor

# window_size = max(
#     np.concatenate(
#         [
#             np.ceil(xs[1:]) - np.floor(xs[:-1]),
#             np.ceil(ys[1:]) - np.floor(ys[:-1]),
#             np.ceil(zs[1:]) - np.floor(zs[:-1]),
#         ]
#     ).astype(int)
# )
# print(window_size)

# g = np.stack(np.meshgrid(xs, ys, zs, indexing='ij'))

In [None]:
# c = 0
# g_iter = g.transpose(1, 2, 3, 0)

# full_indices = np.zeros(g.shape[1:] + (3,) + (window_size,) * 3, dtype=int)
# full_weights = np.zeros(g.shape[1:] + (window_size,) * 3, dtype=float)

# downsampled = np.zeros(g.shape[1:])

# padded_img = np.pad(
#     img,
#     ((0, window_size), (0, window_size), (0, window_size)),
#     "constant",
#     constant_values=0.0,
# )

# for i in np.ndindex(*g_iter.shape[:-1]):

#     target_start_index = g_iter[i]
#     target_end_index = target_start_index + target_window_size

#     source_start_index = np.floor(target_start_index).astype(int)
#     source_end_index = source_start_index + window_size

#     ranges = list()
#     for i_start, i_end in zip(source_start_index, source_end_index):
#         ranges.append(np.arange(i_start, i_end))

#     window_idx = np.stack(np.meshgrid(*ranges))

#     vol_intersections = np.ones(window_idx.shape[1:])
#     for dim_i in range(window_idx[0].ndim):

#         axis_intersect = np.clip(
#             np.min(
#                 [
#                     np.broadcast_to(target_end_index[dim_i], window_idx[dim_i].shape),
#                     window_idx[dim_i] + 1,
#                 ],
#                 axis=0,
#             )
#             - np.max(
#                 [
#                     np.broadcast_to(target_start_index[dim_i], window_idx[dim_i].shape),
#                     window_idx[dim_i],
#                 ],
#                 axis=0,
#             ),
#             0,
#             np.inf,
#         )

#         vol_intersections = vol_intersections * axis_intersect

#     weights = vol_intersections / vol_intersections.sum()
#     values = padded_img[tuple(window_idx)]
#     target_mean = (values * weights).sum()

#     downsampled[i] = target_mean
#     full_weights[i] = weights
#     full_indices[i] = window_idx

#     if i[1] == 100 and i[2] == 100:
#         print(i)

#     c += 1
#     if c > np.inf:
#         break

# full_indices = full_indices.reshape(
#     *full_indices.shape[0:3],
#     3,
#     -1,
# )
# full_indices = full_indices.transpose(3, 4, 0, 1, 2)
# full_weights = full_weights.reshape(*full_weights.shape[0:3], -1)
# full_weights = full_weights.transpose(3, 0, 1, 2)

In [None]:
# skimage_downsampled = skimage.transform.downscale_local_mean(
#     img, (int(downsample_factor),) * 3, cval=0
# )
# print(np.isclose(downsampled, skimage_downsampled).all())
# print(np.isclose(downsampled, skimage_downsampled))