In [None]:
import numpy as np

import os
import os.path as op

import matplotlib.pyplot as plt

import nibabel as nib

import importlib

import interpolation as interp

The first step is to generate the high-resolution reference grid, defined from the initial template `A_tpl_template0.nii.gz` (rather than the *MNI template*) and resampled here at `resolution = 0.5` mm. We then load the images to interpolate (located in `path_to_dir`) and the corresponding affine transforms (located in the `transform_dir` subfolder of `path_to_dir`). Sessions listed in `exclude` are left out.

In [None]:
importlib.reload(interp)

# Root of the data tree; set the HCPH_DATA environment variable to override the
# default location (an earlier variant of this notebook used der_name = "histomatch").
path_to_data = os.environ.get("HCPH_DATA", "/Users/acionca/data")
path_to_der = op.join(path_to_data, "hcph-template/multivar-v00/derivatives/")
der_name = "allInRef"

path_to_dir = op.join(path_to_der, der_name)

# Voxel size (mm) of the high-resolution reference grid.
resolution = 0.5

# The grid is built from the initial template image; the MNI-based alternative
# is `interp.generate_MNI_grid(resolution)`.
path_to_initial_template = op.join(path_to_der, der_name, "A_tpl_template0.nii.gz")
mni_grid = interp.generate_grid_from_img(path_to_initial_template, resolution)
print(f"Reference grid at resolution {resolution}mm has shape: {mni_grid.shape}")

transform_dir = "ANTs_iteration_2"
path_to_transforms = op.join(path_to_dir, transform_dir)

# Sessions left out of the template.
exclude = ["ses-017", "ses-pilot019"]
# Number of images kept for this run (set to None to use all of them).
n_subjects = 5

# Per-image affine transforms; files whose name contains "template" are skipped.
transforms_files = sorted(
    op.join(path_to_transforms, fname)
    for fname in os.listdir(path_to_transforms)
    if "Affine" in fname
    and "template" not in fname
    and not any(excl in fname for excl in exclude)
)

# Images are paired with transforms by position, so both lists are sorted the
# same way (os.listdir order is arbitrary).
# NOTE(review): assumes sorted image and transform names line up session by
# session — confirm against the actual file naming.
anat_files = sorted(
    fname
    for fname in interp.get_anat_filenames(path_to_dir, pattern=".nii.gz")
    if not any(excl in fname for excl in exclude)
)

if len(anat_files) != len(transforms_files):
    print(
        f"WARNING: {len(anat_files)} images but {len(transforms_files)} transforms; "
        "images and transforms are paired by position."
    )

# Truncate before loading so that unused transforms are never read.
if n_subjects is not None:
    transforms_files = transforms_files[:n_subjects]
    anat_files = anat_files[:n_subjects]

affine_transforms = interp.get_transforms(transforms_files, [mni_grid] * len(transforms_files))

print(f"{len(affine_transforms)} transforms found")
print(f"{len(anat_files)} anat_files found")

Now we can compute, for each image, its resampled version on the reference grid (`r_map`) and the corresponding distance map (`d_map`).

In [None]:
importlib.reload(interp)

# Resample every image onto the high-resolution grid. The two outputs are
# unpacked as `d_map` (distance maps) and `r_map` (resampled images).
# To reduce memory load and to parallelize the computation, the voxel indices
# of the high-resolution grid are separated into `n_batches` batches.
n_batches = 1000
# This decides if we want to `weight` the interpolation using the projected distances.
# NOTE(review): not passed to `get_individual_map`; only read by later cells.
weight = True
# This kernel is applied to the distances to give more weight to smaller values (see example)
# NOTE(review): also only read by later cells (via `interp.normalize_distances`).
dist_kernel_order = 1
# This is the order of the BSpline interpolation of the target images (usually 3 for cubic BSpline)
spline_order = 3
# The number of jobs to use for parallel execution (using *Joblib*)
n_jobs = 25

d_map, r_map = interp.get_individual_map(
    mni_grid,
    anat_files,
    affine_transforms,
    n_batches=n_batches,
    spline_order=spline_order,
    n_jobs=n_jobs
)

In [None]:
# Keep an independent copy of the raw distance maps (not read by any cell shown here).
d_map_backup = d_map.copy()

We can have a look at the distance maps, as well as at the influence of the normalization kernel order.

In [None]:
importlib.reload(interp)

# sqrt(3 * 0.5**2) is half the diagonal of a unit cube — presumably the largest
# distance to the nearest voxel centre, in voxel units (TODO confirm). Appears unused.
MAX_DIST = np.sqrt(3 * (0.5 ** 2))

slice_id = 100
n_maps_to_show = 4
kernel_orders = [1, 2, 5, 10]

distance_map_red = d_map[:n_maps_to_show]
# One column for the raw distance map plus one per kernel order; squeeze=False
# keeps `axes` 2-D even when only one map is shown.
n_cols = 1 + len(kernel_orders)
fig, axes = plt.subplots(
    nrows=len(distance_map_red),
    ncols=n_cols,
    figsize=(5 * n_cols, 5 * len(distance_map_red)),
    squeeze=False,
)

for ax, dist in zip(axes[:, 0], distance_map_red):
    ax.set_title("Distance map")
    ax.imshow(dist[..., slice_id])

for i, kernel_order in enumerate(kernel_orders):
    norm_distances = interp.normalize_distances(distance_map_red, dist_kernel_order=kernel_order, offset=0.1)
    for ax, dist in zip(axes[:, i + 1], norm_distances):
        ax.set_title(f"Normalized distance map\n(order = {kernel_order})")
        ax.imshow(dist[..., slice_id])

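Here we show the influence of the distance-kernel order (and of the `offset`) on the distance weights. This order is a separate parameter from the `spline_order` used to resample the images.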

In [None]:
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(15, 10), sharex=True)
fig.subplots_adjust(hspace=0)

k_order = [1, 2, 3]
offsets = [0, 0.1]
n_points = 100

# Flatten the distances once (ravel avoids a copy when possible) and draw a
# seeded random subset of voxels so the scatter plot is reproducible.
d_flat = d_map.ravel()
rng = np.random.default_rng(0)
subsample = rng.choice(d_flat.size, n_points)

for i, offset in enumerate(offsets):
    for k in k_order:
        dmap = interp.normalize_distances(d_map, k, offset=offset)
        axes[0, i].scatter(d_flat[subsample], np.ravel(dmap)[subsample], alpha=0.3, label=f"Kernel order = {k}")

    axes[0, i].set_title(f"offset = {offset}")
    axes[0, i].set_ylabel("Weight")
    axes[0, i].legend()

    axes[1, i].hist(d_flat, bins=20, label="Distances")
    axes[1, i].set_xticks(np.arange(0, 1, 0.1))
    axes[1, i].set_ylabel("Count")
    axes[1, i].set_xlabel("Distance")

We can now inspect the output of the interpolation procedure: the distance-weighted sum of the resampled images.

In [None]:
# Distance-weighted interpolation: each resampled image is multiplied voxel-wise
# by its normalized distance weight, and the products are summed over images.
# With `weight = False`, fall back to a plain average of the resampled images.
# NOTE(review): the fallback assumes the normalized weights sum to 1 across
# images — confirm in `interp.normalize_distances`.
if weight:
    weighted_r_maps = r_map * interp.normalize_distances(d_map, dist_kernel_order=dist_kernel_order)
    interpolated_map = weighted_r_maps.sum(axis=0)
else:
    interpolated_map = r_map.mean(axis=0)

n_slices = 2
# n_slices + 3 evenly spaced indices along the last axis; the first index and
# the last two are dropped, leaving n_slices interior slices to display.
slices = np.linspace(0, interpolated_map.shape[-1], n_slices + 3, dtype=int)
shown_slices = slices[1:-2]
# The zoomed panel shows the last displayed slice.
zoom_slice = shown_slices[-1]

# Zoom windows (voxel index ranges) depend on the grid resolution.
if resolution >= 1:
    ZOOM1 = slice(50, 100)
    ZOOM2 = slice(100, 150)
elif resolution >= 0.8:
    ZOOM1 = slice(120, 180)
    ZOOM2 = slice(120, 180)
elif resolution >= 0.4:
    ZOOM1 = slice(150, 250)
    ZOOM2 = slice(100, 200)
else:  # resolution < 0.4
    ZOOM1 = slice(200, 300)
    ZOOM2 = slice(200, 300)

vmax = 600

fig, axes = plt.subplots(nrows=1, ncols=n_slices + 1, figsize=(18, 18))

for ax, z in zip(axes[:-1], shown_slices):
    ax.set_title(f"Interpolated map: z = {z}")
    ax.imshow(interpolated_map[..., z], cmap="binary_r", vmin=0, vmax=vmax)

axes[-1].set_title(f"Zoom: z = {zoom_slice}")
axes[-1].imshow(interpolated_map[..., zoom_slice][ZOOM1][:, ZOOM2], cmap="binary_r", vmin=0, vmax=vmax)

Finally, we save the interpolated template as a compressed NIfTI (`.nii.gz`) file.

In [None]:
# Convert to int16 safely: NaNs become 0, values are rounded to the nearest
# integer (a bare `astype` truncates toward zero) and clipped to the int16 range
# (a bare `astype` would silently wrap out-of-range values).
int16_info = np.iinfo(np.int16)
interpolated_int16 = np.clip(
    np.rint(np.nan_to_num(interpolated_map, nan=0.0)),
    int16_info.min,
    int16_info.max,
).astype(np.int16)

dw_interp = nib.Nifti1Image(interpolated_int16, affine=mni_grid.affine)
dw_interp.header.set_data_dtype("int16")

# Filename tags: number of images, then the distance-kernel order when weighting is on.
suffix = f"N{len(anat_files)}"
if weight:
    suffix += f"DisWei{dist_kernel_order}"
suffix += "AllInRes_"

fname = f"distance_weighted_template_res-{resolution}_desc-{suffix}T1w.nii.gz"
print(fname)
saveloc = op.join(path_to_data, "hcph-template/multivar-v00/derivatives/diswe_interpolation")
# Create the output directory on first use instead of failing on save.
os.makedirs(saveloc, exist_ok=True)
dw_interp.to_filename(op.join(saveloc, fname))

In [None]:
fig, axes = plt.subplots(ncols=2, figsize=(12, 6))

# Weighted sum over every image except the last one, with distance-kernel order 2.
norm_dist = interp.normalize_distances(d_map[:-1], dist_kernel_order=2)
weighted_img = r_map[:-1] * norm_dist
weighted_sum = weighted_img.sum(axis=0)

print(len(weighted_img))

axes[0].set_title(f"Weighted sum of {len(weighted_img)} imgs: z = {slice_id}")
axes[0].imshow(weighted_sum[..., slice_id])
axes[1].set_title("Zoom")
axes[1].imshow(weighted_sum[..., slice_id][ZOOM1][:, ZOOM2])

In [None]:
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(18, 18))
slice_id = 100

# Last loaded image: resampled image (top row) and raw distance map (bottom
# row), each shown once wide and once zoomed.
last_resampled = r_map[-1][..., slice_id]
last_distance = d_map[-1][..., slice_id]
crop_rows, crop_cols = slice(50, 200), slice(50, 250)

axes[0, 0].set_title(f"Resampled img: z = {slice_id}")
axes[0, 0].imshow(last_resampled[crop_rows][:, crop_cols])
axes[0, 1].set_title("Resampled img (zoom)")
axes[0, 1].imshow(last_resampled[ZOOM1][:, ZOOM2])
axes[1, 0].set_title(f"Distance map: z = {slice_id}")
axes[1, 0].imshow(last_distance, vmin=0, vmax=1)
axes[1, 1].set_title("Distance map (zoom)")
axes[1, 1].imshow(last_distance[ZOOM1][:, ZOOM2], vmin=0, vmax=1)

In [None]:
# Cumulative check with kernel order 2. Images are added one at a time, and
# each drawn step shows the unweighted mean, the weight of the newly added image
# and the distance-weighted sum. The last row compares the weighted sum over
# all images with their plain mean.
slice_id = 100
n_maps = d_map.shape[0]

# Drawing every step is costly for large cohorts, so steps start at index 31
# (i.e. 32 images). With fewer images, only the final step is drawn.
first_step = min(31, n_maps - 1)
steps = range(first_step, n_maps)

n_rows = len(steps) + 1  # one row per drawn step plus the summary row
fig, axes = plt.subplots(nrows=n_rows, ncols=3, figsize=(18, 6 * n_rows), squeeze=False)

for axes_row, i in zip(axes[:-1], steps):
    step_norm_dist = interp.normalize_distances(d_map[:i+1], dist_kernel_order=2)
    step_weighted_sum = (r_map[:i+1] * step_norm_dist).sum(axis=0)

    axes_row[0].set_title(f"Mean of {i + 1} imgs")
    axes_row[0].imshow(r_map[:i+1].mean(axis=0)[..., slice_id][ZOOM1][:, ZOOM2])
    axes_row[1].set_title(f"Weight of img {i}")
    axes_row[1].imshow(step_norm_dist[i][..., slice_id][ZOOM1][:, ZOOM2], vmin=0, vmax=1)
    axes_row[2].set_title(f"Using {i + 1} imgs")
    axes_row[2].imshow(step_weighted_sum[..., slice_id][ZOOM1][:, ZOOM2])

# Computed explicitly over all images, so the summary never depends on which
# steps were drawn or on arrays left over from other cells.
full_norm_dist = interp.normalize_distances(d_map, dist_kernel_order=2)
full_weighted_sum = (r_map * full_norm_dist).sum(axis=0)
diff_img = full_weighted_sum - r_map.mean(axis=0)

axes[-1, 0].set_title(f"Weighted sum of {n_maps} imgs: z = {slice_id}")
axes[-1, 0].imshow(full_weighted_sum[..., slice_id])
axes[-1, 1].set_title("Weighted sum - mean")
axes[-1, 1].imshow(diff_img[..., slice_id], vmin=-10, vmax=10, cmap="coolwarm")
axes[-1, 2].set_title("Weighted sum - mean (zoom)")
axes[-1, 2].imshow(diff_img[..., slice_id][ZOOM1][:, ZOOM2], vmin=-10, vmax=10, cmap="coolwarm")

# Free the per-image weight arrays; only the summed maps are kept.
del step_norm_dist, full_norm_dist

In [None]:
# Normalized distance weights for all images, used by the next cell.
# NOTE(review): kernel order is hard-coded to 2 here, which differs from the
# configured `dist_kernel_order` used for the saved template.
norm_dist = interp.normalize_distances(d_map, dist_kernel_order=2)

In [None]:
# Per-image breakdown (zoomed): resampled image, raw distance map and weighted
# image, with a last row summing each column over all images.
# Relies on `norm_dist` computed in the previous cell.
n_maps = d_map.shape[0]
fig, axes = plt.subplots(nrows=n_maps + 1, ncols=3, figsize=(12, 5 * (n_maps + 1)))
slice_id = 100

weighted = []
for i, axes_row in enumerate(axes[:-1]):
    weighted_img = r_map[i] * norm_dist[i]
    weighted.append(weighted_img)

    axes_row[0].set_title(f"Resampled img {i}")
    axes_row[0].imshow(r_map[i][..., slice_id][ZOOM1][:, ZOOM2])
    axes_row[1].set_title(f"Distance map {i}")
    axes_row[1].imshow(d_map[i][..., slice_id][ZOOM1][:, ZOOM2], vmin=0, vmax=1)
    axes_row[2].set_title(f"Weighted img {i}")
    axes_row[2].imshow(weighted_img[..., slice_id][ZOOM1][:, ZOOM2])

# Built-in `sum` adds the per-image arrays one by one, avoiding the extra
# stacked copy that `np.array(weighted)` would allocate.
weighted_sum = sum(weighted)

axes[-1, 0].set_title("Sum of resampled imgs")
axes[-1, 0].imshow(r_map.sum(axis=0)[..., slice_id][ZOOM1][:, ZOOM2])
axes[-1, 1].set_title("Sum of weights")
d_imshow = axes[-1, 1].imshow(norm_dist.sum(axis=0)[..., slice_id][ZOOM1][:, ZOOM2])
axes[-1, 2].set_title("Weighted sum")
axes[-1, 2].imshow(weighted_sum[..., slice_id][ZOOM1][:, ZOOM2])

# Attach the colorbar to the sum-of-weights panel explicitly.
fig.colorbar(d_imshow, ax=axes[-1, 1])

In [None]:
# Plain mean vs. distance-weighted sum of all images (`weighted` is the list of
# per-image weighted arrays built in the previous cell), plus their difference.
# The weighted sum is computed once; built-in `sum` adds the arrays one by one
# instead of materializing the whole stack with `np.array(weighted)`.
weighted_sum = sum(weighted)
mean_map = r_map.mean(axis=0)
diff_map = mean_map - weighted_sum

fig, axes = plt.subplots(ncols=3, figsize=(17, 6))

axes[0].set_title(f"Mean: z = {slice_id}")
axes[0].imshow(mean_map[..., slice_id][ZOOM1][:, ZOOM2])
axes[1].set_title("Weighted sum")
axes[1].imshow(weighted_sum[..., slice_id][ZOOM1][:, ZOOM2])
axes[2].set_title("Mean - weighted sum")
diff_imshow = axes[2].imshow(diff_map[..., slice_id][ZOOM1][:, ZOOM2], vmin=-10, vmax=10, cmap="coolwarm")
fig.colorbar(diff_imshow, ax=axes[2])