In [None]:
import numpy as np

import os
import os.path as op

import matplotlib.pyplot as plt

import nibabel as nib

import importlib

import interpolation as interp

In [None]:
importlib.reload(interp)

#path_to_der = "/home/acionca/Documents/data/hcph-template/multivar-v00/derivatives/"
#der_name = "histomatch"

# Root of the local data tree. Set the HCPH_DATA_DIR environment variable to point
# elsewhere instead of editing this cell; the default keeps the original location.
path_to_data = os.environ.get("HCPH_DATA_DIR", "/Users/acionca/data")
path_to_der = op.join(path_to_data, "hcph-template/multivar-v00/derivatives/")
der_name = "allInRef"

path_to_dir = op.join(path_to_der, der_name)

# Target voxel size (mm) of the high-resolution reference grid.
resolution = 0.2
#mni_grid = interp.generate_MNI_grid(resolution)

path_to_initial_template = op.join(path_to_dir, "A_tpl_template0.nii.gz")
mni_grid = interp.generate_grid_from_img(path_to_initial_template, resolution)
print(f"Reference grid at resolution {resolution}mm has shape: {mni_grid.shape}")

transform_dir = "ANTs_iteration_2"
path_to_transforms = op.join(path_to_dir, transform_dir)

# Sessions left out of the template; matched as substrings of the full file path.
exclude = ["ses-017", "ses-pilot019", "ses-pilot021"]


def is_excluded(path):
    """Return True if `path` contains any of the session tags listed in `exclude`."""
    return any(tag in path for tag in exclude)


# Per-session affine transforms (the template's own transforms are skipped).
candidate_transforms = (
    op.join(path_to_transforms, fname)
    for fname in os.listdir(path_to_transforms)
    if "Affine" in fname and "template" not in fname
)
# os.listdir order is arbitrary, hence the sort.
transforms_files = sorted(path for path in candidate_transforms if not is_excluded(path))

anat_files = [
    path
    for path in interp.get_anat_filenames(path_to_dir, pattern=".nii.gz")
    if not is_excluded(path)
]

# Images and transforms are passed downstream as two parallel lists, presumably paired
# by position. NOTE(review): this assumes `interp.get_anat_filenames` returns sessions in
# the same (sorted) order as `transforms_files` — confirm.
assert len(anat_files) == len(transforms_files), (
    f"{len(anat_files)} anat files vs {len(transforms_files)} transforms: "
    "cannot pair images with transforms"
)

affine_transforms = interp.get_transforms(transforms_files, [mni_grid] * len(transforms_files))

print(f"{len(affine_transforms)} transforms found")
print(f"{len(anat_files)} anat_files found")

In [None]:
importlib.reload(interp)
from tqdm import tqdm

# --- Interpolation parameters: single source of truth for the call below ---
# The voxel indices are split into `n_batches` batches to reduce memory load.
# One batch is enough for the small ROI below; a value around 1000 was planned
# for the full high-resolution grid.
n_batches = 1
# Whether to `weight` the interpolation using the projected distances.
weight = True
# Flags forwarded unchanged to `interp.interpolate_from_indices`.
normalize = True
interpolate = True
# This kernel is applied to the distances to give more weight to smaller values (see example).
# NOTE(review): this cell previously declared `dist_kernel_order = 1` but called the
# interpolation with 2. The value 2 is kept so existing results are reproduced — confirm.
dist_kernel_order = 2
# Order of the BSpline interpolation of the target images (usually 3 for cubic BSpline).
# NOTE(review): not currently forwarded to `interp.interpolate_from_indices`.
spline_order = 3
# Number of jobs for parallel execution (using *Joblib*).
# NOTE(review): not currently forwarded to `interp.interpolate_from_indices`.
n_jobs = 25

fixed_img = mni_grid
moving_list = anat_files
transforms = affine_transforms

# Hand-picked region of interest on the reference grid: half-open [start, stop) voxel
# index ranges along x, y and z, keyed by grid resolution in mm. Looking the ROI up
# from `resolution` avoids silently using extents meant for a different grid.
roi_by_resolution = {
    0.5: ([260, 290], [250, 275], [295, 305]),
    0.2: ([650, 725], [690, 730], [715, 750]),
}
x_ex, y_ex, z_ex = roi_by_resolution[resolution]

# Every (x, y, z) voxel index of the ROI, one row per voxel, with x varying slowest
# and z fastest (the same order as a nested x/y/z loop).
grid_x, grid_y, grid_z = np.meshgrid(
    np.arange(x_ex[0], x_ex[1]),
    np.arange(y_ex[0], y_ex[1]),
    np.arange(z_ex[0], z_ex[1]),
    indexing="ij",
)
indices = np.stack([grid_x, grid_y, grid_z], axis=-1).reshape(-1, 3)
batches = interp.batch_handler(indices, n_batches=n_batches)

# NOTE(review): this allocates the whole reference grid (float64) although only the ROI
# is filled; it may be large at 0.2 mm.
interpolated_array = np.zeros(fixed_img.shape)
for batch_indices in tqdm(batches):
    interpolated_values = interp.interpolate_from_indices(
        batch_indices,
        fixed_img,
        moving_list,
        transforms,
        weight=weight,
        normalize=normalize,
        interpolate=interpolate,
        dist_kernel_order=dist_kernel_order,
    )

    # Scatter the batch's values back into their (x, y, z) voxels.
    interpolated_array[tuple(batch_indices.T)] = interpolated_values

#distances_array = np.zeros((len(transforms), *fixed_img.shape))
#images_array = np.zeros_like(distances_array, dtype=int)
#for batch_indices in tqdm(batches):
#    val_and_dist = interp.get_sample_val_and_dist(batch_indices, fixed_img, moving_list, transforms, weight=True, interpolate=True)
#    
#    images_array[:, batch_indices.T[0], batch_indices.T[1], batch_indices.T[2]] = val_and_dist[0]
#    
#    distances_array[:, batch_indices.T[0], batch_indices.T[1], batch_indices.T[2]] = (
#        interp.normalize_distances(val_and_dist[1], dist_kernel_order=1)
#    )

In [None]:
# Crop the interpolated volume to the ROI and wrap it as a NIfTI image for display.
(x_start, x_stop), (y_start, y_stop), (z_start, z_stop) = x_ex, y_ex, z_ex
sliced_array = interpolated_array[x_start:x_stop, y_start:y_stop, z_start:z_stop]

# Voxel (0, 0, 0) of the crop is voxel (x_start, y_start, z_start) of the reference grid.
# The affine's translation column is therefore shifted by that offset. Reusing the
# full-grid affine unchanged would place the crop at the wrong world coordinates.
grid_affine = np.asarray(mni_grid.affine, dtype=float)
sliced_affine = grid_affine.copy()
sliced_affine[:3, 3] = grid_affine[:3, :3] @ np.array([x_start, y_start, z_start]) + grid_affine[:3, 3]
sliced_img = nib.Nifti1Image(sliced_array, affine=sliced_affine)

sliced_img.orthoview()
#fig, axes = plt.subplots(nrows=len(sliced_array), figsize=(10, 5*len(sliced_array)))
#for i, myslice in enumerate(sliced_array):
#    axes[i].imshow(myslice, cmap="binary_r")