In [None]:
import numpy as np

import os
import os.path as op

import matplotlib.pyplot as plt
from  nilearn import plotting

import nibabel as nib

import importlib

import interpolation as interp

In [None]:
importlib.reload(interp)

# Root of the multivariate-template derivatives and the histogram-matched inputs.
path_to_der = "/home/acionca/Documents/data/hcph-template/multivar-v00/derivatives/"
der_name = "histomatch"

path_to_dir = op.join(path_to_der, der_name)

# Target grid in MNI space; `mni_grid` exposes `ndindex`, `ndcoords` and `affine`.
# NOTE(review): units of `resolution` are whatever interp.generate_MNI_grid expects
# (presumably mm) — confirm.
resolution = 0.4
mni_grid = interp.generate_MNI_grid(resolution)

n_coords_to_print = 4
print(f"First {n_coords_to_print} indices are:\n {mni_grid.ndindex[:, :n_coords_to_print].T}")
print(f"First {n_coords_to_print} coordinates are:\n {mni_grid.ndcoords[:, :n_coords_to_print].T}")

transform_dir = "ANTs_iteration_0"
path_to_transforms = op.join(path_to_dir, transform_dir)

# Per-image affine transforms (the template's own files are excluded). Sorting is
# presumably what pairs them with the anat files below — confirm naming matches.
transforms_files = sorted(
    op.join(path_to_transforms, file)
    for file in os.listdir(path_to_transforms)
    if "Affine" in file and "template" not in file
)
if not transforms_files:
    raise FileNotFoundError(f"No affine transform files found in {path_to_transforms}")

print(transforms_files[:3])

# Inputs are the pre-smoothed images stored under derivatives/smooth1.2.
anat_files = interp.get_anat_filenames(op.join(path_to_der, "smooth1.2"))

# Every transform is resolved against the same MNI reference grid.
affine_transforms = interp.get_transforms(transforms_files, [mni_grid] * len(transforms_files))
print(f"{len(affine_transforms)} transforms found")
print(f"{len(anat_files)} anat_files found")

# Later cells slice both lists with the same `n_sub` and pass them side by side,
# so differing counts most likely mean images and transforms are mis-paired.
if len(affine_transforms) != len(anat_files):
    import warnings
    warnings.warn(
        f"{len(affine_transforms)} transforms vs {len(anat_files)} anat files: "
        "image/transform pairing is probably wrong"
    )

print(affine_transforms[0])
print(affine_transforms[0].reference)

In [None]:
importlib.reload(interp)

# --- run parameters -------------------------------------------------------
batch_limit = None   # forwarded unchanged to the interpolation routine
n_batches = 50
n_sub = -1           # a non-positive value keeps every image
weight = True
dist_kernel_order = 1

# Optionally restrict the run to the first `n_sub` image/transform pairs.
affine_transforms_sub = affine_transforms[:n_sub] if n_sub > 0 else affine_transforms
anat_files_sub = anat_files[:n_sub] if n_sub > 0 else anat_files

# Output-name tag: "N<count>" + optional "DisWei<order>" + "AllInRes_".
suffix = (
    f"N{len(anat_files_sub)}"
    + (f"DisWei{dist_kernel_order}" if weight else "")
    + "AllInRes_"
)

interpolated_map = interp.distance_weighted_interpolation(
    mni_grid, anat_files_sub, affine_transforms_sub,
    n_batches=n_batches, batch_limit=batch_limit,
    weight=weight, normalize=True, interpolate=True,
    dist_kernel_order=dist_kernel_order, spline_order=3,
    n_jobs=10,
)

In [None]:
# Show `n_slices` axial slices of the interpolated template plus a zoom on the last one.
n_slices = 2
# n_slices + 3 evenly spaced z indices; [1:-2] keeps `n_slices` interior ones.
slices = np.linspace(0, interpolated_map.shape[-1], n_slices + 3, dtype=int)
shown_slices = slices[1:-2]

# Hand-picked zoom window (rows, columns), in voxels, for each grid resolution.
# A lookup table avoids silently reusing stale ZOOM1/ZOOM2 from a previous run.
ZOOM_BY_RESOLUTION = {
    1: (slice(50, 100), slice(100, 150)),
    0.8: (slice(120, 180), slice(120, 180)),
    0.6: (slice(150, 250), slice(100, 200)),
    0.4: (slice(150, 250), slice(100, 200)),
    0.3: (slice(200, 300), slice(200, 300)),
}
if resolution not in ZOOM_BY_RESOLUTION:
    raise ValueError(
        f"No zoom window defined for resolution={resolution}; "
        f"add one to ZOOM_BY_RESOLUTION (known: {sorted(ZOOM_BY_RESOLUTION)})"
    )
ZOOM1, ZOOM2 = ZOOM_BY_RESOLUTION[resolution]

fig, axes = plt.subplots(nrows=1, ncols=n_slices + 1, figsize=(18, 18))

for ax, i in zip(axes[:-1], shown_slices):
    ax.set_title(f"z = {i}")
    ax.imshow(interpolated_map[..., i], cmap="binary_r")

# The zoom panel shows the last of the displayed slices.
zoom_z = shown_slices[-1]
axes[-1].set_title(f"z = {zoom_z} - zoom")
axes[-1].imshow(interpolated_map[..., zoom_z][ZOOM1][:, ZOOM2], cmap="binary_r")

In [None]:
# Save the interpolated template as an int16 NIfTI on the MNI grid.
saveloc = op.join(path_to_der, "diswe_interpolation")
os.makedirs(saveloc, exist_ok=True)

# Round first (a bare astype truncates toward zero), then clip so out-of-range
# intensities saturate instead of wrapping around in int16.
int16_info = np.iinfo(np.int16)
template_int16 = np.clip(np.rint(interpolated_map), int16_info.min, int16_info.max).astype(np.int16)

dw_interp = nib.Nifti1Image(template_int16, affine=mni_grid.affine)
dw_interp.header.set_data_dtype("int16")
fname = f"distance_weighted_template_res-{resolution}_desc-{suffix}smth_T1w.nii.gz"
print(fname)
dw_interp.to_filename(op.join(saveloc, fname))

In [None]:
## SMOOTH INPUT IMAGES
from nilearn.image import concat_imgs, smooth_img
from skimage.filters import gaussian

# Pre-smooth the histogram-matched inputs with an isotropic Gaussian and write them
# to derivatives/smooth{fwhm}, which is where the setup cell reads `anat_files` from.
fwhm = 1.2
# FWHM = sigma * sqrt(8 ln 2)
sigma = fwhm / np.sqrt(8 * np.log(2))

path_to_smooth = op.join(path_to_der, f"smooth{fwhm}")
os.makedirs(path_to_smooth, exist_ok=True)

# Read from the un-smoothed histomatch directory, NOT from `anat_files`: those are
# loaded from this very output directory, so iterating over them would re-smooth
# already-smoothed volumes and overwrite them in place on every run.
raw_anat_files = interp.get_anat_filenames(path_to_dir)

int16_info = np.iinfo(np.int16)
for file in raw_anat_files:
    out_path = op.join(path_to_smooth, op.basename(file))
    if op.abspath(file) == op.abspath(out_path):
        raise ValueError(f"Refusing to overwrite input file {file}")

    img = nib.load(file)
    # NOTE(review): skimage's `sigma` is in voxels whereas `fwhm` looks like mm
    # (nilearn's smooth_img takes mm). Equivalent only for 1 mm isotropic voxels — confirm.
    gauss_smoothed = gaussian(img.get_fdata(), sigma=sigma, mode="constant")

    # Round (astype alone truncates toward zero) and clip to the int16 range.
    smoothed_int16 = np.clip(np.rint(gauss_smoothed), int16_info.min, int16_info.max).astype(np.int16)
    smth_img = nib.Nifti1Image(smoothed_int16, affine=img.affine)
    smth_img.to_filename(out_path)


# Optional: concatenate the warped inputs into 4-D "movies".
# NOTE: `path_to_movie` is only defined in the animation cell.
#for modality in ["T1w", "T2w"]:
#    files = interp.get_anat_filenames(path_to_dir, pattern="WarpedToTemplate",
#                                      modality_filter=[modality], template_prefix=".mat")
#    movie_name = f"template_input_res-{resolution}_concat_{modality}.nii.gz"
#    concat_imgs(files).to_filename(op.join(path_to_movie, movie_name))


In [None]:
import matplotlib.animation as animation
from nilearn.plotting import plot_anat

path_to_der = "/home/acionca/Documents/data/hcph-template/multivar-v00/derivatives/"
der_name = "histomatch"

path_to_dir = op.join(path_to_der, der_name)

path_to_movie = "/home/acionca/Documents/data/hcph-template"

# Images already warped to the template, one list per modality.
filest1w = interp.get_anat_filenames(path_to_dir, pattern="WarpedToTemplate",
                                     modality_filter=["T1w"], template_prefix=".mat")
filest2w = interp.get_anat_filenames(path_to_dir, pattern="WarpedToTemplate",
                                     modality_filter=["T2w"], template_prefix=".mat")

imgt1w = [nib.load(op.join(path_to_dir, file)) for file in filest1w]
imgt2w = [nib.load(op.join(path_to_dir, file)) for file in filest2w]

# Each frame shows one T1w/T2w pair, so never index past the shorter list.
n_frames = min(len(imgt1w), len(imgt2w))
if len(imgt1w) != len(imgt2w):
    print(f"WARNING: {len(imgt1w)} T1w vs {len(imgt2w)} T2w images; "
          f"animating only the first {n_frames} pairs")

CUT_COORDS = (15, -17, 8)

fig, axes = plt.subplots(nrows=2, figsize=(18, 12))


def animate(i):
    """Draw frame `i`: the i-th warped T1w image (top) and T2w image (bottom)."""
    # NOTE(review): plot_anat(axes=...) appears to add its own sub-axes inside the
    # given box on each call; start every frame from a clean figure so they don't pile up.
    fig.clear()
    ax_t1w, ax_t2w = fig.subplots(nrows=2)

    plot_anat(imgt1w[i], cut_coords=CUT_COORDS, axes=ax_t1w, vmin=0, vmax=300,
              title=op.basename(filest1w[i]))

    plot_anat(imgt2w[i], cut_coords=CUT_COORDS, axes=ax_t2w, vmin=0, vmax=300,
              title=op.basename(filest2w[i]))


ani = animation.FuncAnimation(fig, animate, repeat=False, frames=n_frames, interval=500)

## To save the animation using Pillow as a gif
writer = animation.PillowWriter(fps=2, bitrate=128000)
ani.save(op.join(path_to_movie, "warped_movie_inref.gif"), writer=writer)

In [None]:
importlib.reload(interp)

# Settings for the per-image distance maps (spline_order=0 in the call below).
batch_limit = None
n_batches = 50
n_sub = -1      # non-positive -> keep every image
weight = True   # not read by get_distance_map; kept for parity with the other cells

if n_sub <= 0:
    affine_transforms_sub, anat_files_sub = affine_transforms, anat_files
else:
    affine_transforms_sub, anat_files_sub = affine_transforms[:n_sub], anat_files[:n_sub]

distance_map = interp.get_distance_map(
    mni_grid, anat_files_sub, affine_transforms_sub,
    spline_order=0, n_jobs=12,
    n_batches=n_batches, batch_limit=batch_limit,
)

In [None]:
# Half the diagonal of a unit cube (~0.866). Presumably the largest distance to the
# nearest voxel centre in voxel units — TODO confirm the units of `distance_map`.
MAX_DIST = np.sqrt(3 * (0.5 ** 2))

slice_id = 100
distance_map_red = distance_map  # slice here (e.g. distance_map[:4]) to plot fewer maps
kernel_orders = [1, 2]

# Column 0: raw distance map; one extra column per kernel order.
# squeeze=False keeps `axes` 2-D even when a single map is plotted.
fig, axes = plt.subplots(nrows=len(distance_map_red), ncols=1 + len(kernel_orders),
                         figsize=(15, 5 * len(distance_map_red)), squeeze=False)

axes[0, 0].set_title("Distance")
for ax, dist in zip(axes[:, 0], distance_map_red):
    ax.imshow(dist[..., slice_id])

for col, kernel_order in enumerate(kernel_orders, start=1):
    axes[0, col].set_title(f"Kernel order = {kernel_order}")
    norm_distances = interp.normalize_distances(distance_map_red, dist_kernel_order=kernel_order)
    for ax, dist in zip(axes[:, col], norm_distances):
        ax.imshow(dist[..., slice_id])

In [None]:
# Distance-kernel orders to compare.
k_order = list(range(1, 4))

In [None]:
# Top: distance -> weight for each kernel order (random subsample of voxels).
# Bottom: histogram of all distances.
fig, axes = plt.subplots(nrows=2, figsize=(10, 10), sharex=True)
fig.subplots_adjust(hspace=0)

k_order = [1, 2, 3]
n_scatter_points = 1000

# ravel() returns a view when it can; flatten() copied the whole multi-image map each call.
flat_distances = distance_map.ravel()
# Draw one seeded subsample shared by every kernel order so the curves are comparable
# and the figure is reproducible.
subsample_rng = np.random.default_rng(0)
subsample = subsample_rng.choice(flat_distances.size, n_scatter_points)

for k in k_order:
    dmap = interp.normalize_distances(distance_map, k)
    axes[0].scatter(flat_distances[subsample], dmap.ravel()[subsample],
                    alpha=0.3, label=f"Kernel order = {k}")

axes[0].set_ylabel("Weight")
axes[0].legend()

axes[1].hist(flat_distances, bins=20, label="Distances")
axes[1].set_xticks(np.arange(0, 1, 0.1))
axes[1].set_ylabel("Count")
axes[1].set_xlabel("Distance")

In [None]:
importlib.reload(interp)

batch_limit = None
n_batches = 50
n_sub = -1     # > 0 restricts the run to the first n_sub images
weight = True

use_subset = n_sub > 0
affine_transforms_sub = affine_transforms[:n_sub] if use_subset else affine_transforms
anat_files_sub = anat_files[:n_sub] if use_subset else anat_files

# Name tag: optional "subsample<n>_" followed by optional "DisWei_".
suffix = (f"subsample{n_sub}_" if use_subset else "") + ("DisWei_" if weight else "")

# Per-image maps: d_map holds distances, r_map the matching per-image values.
d_map, r_map = interp.get_individual_map(
    mni_grid, anat_files_sub, affine_transforms_sub,
    spline_order=0, n_jobs=12,
    n_batches=n_batches, batch_limit=batch_limit,
)

In [None]:
slice_id = 100  # axial slice shown (same value as the other cells)
fig, axes = plt.subplots(ncols=2, figsize=(12, 6))

# Distance-weighted combination (kernel order 2) of every image except the last one.
norm_dist = interp.normalize_distances(d_map[:-1], dist_kernel_order=2)
weighted_img = r_map[:-1] * norm_dist

print(len(weighted_img))

# Reduce the full stack once and reuse it for both panels.
weighted_sum = weighted_img.sum(axis=0)
axes[0].imshow(weighted_sum[..., slice_id])
axes[1].imshow(weighted_sum[..., slice_id][ZOOM1][:, ZOOM2])

In [None]:
slice_id = 100
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(18, 18))

# Last image only: r_map slice on top, its distance map below;
# a cropped view on the left, the ZOOM window on the right.
last_r_slice = r_map[-1][..., slice_id]
last_d_slice = d_map[-1][..., slice_id]

axes[0, 0].imshow(last_r_slice[50:200, 50:250])
axes[0, 1].imshow(last_r_slice[ZOOM1][:, ZOOM2])
axes[1, 0].imshow(last_d_slice, vmin=0, vmax=1)
axes[1, 1].imshow(last_d_slice[ZOOM1][:, ZOOM2], vmin=0, vmax=1)

In [None]:
slice_id = 100
# Cumulative rows start once at least this many images are combined
# (the former hard-coded `i > 30` test, i.e. 32 images).
min_n_imgs = 32

n_imgs = d_map.shape[0]
n_used_steps = list(range(min(min_n_imgs, n_imgs), n_imgs + 1))

# One row per cumulative step plus a summary row (no empty rows);
# squeeze=False keeps `axes` 2-D.
fig, axes = plt.subplots(nrows=len(n_used_steps) + 1, ncols=3,
                         figsize=(18, 6 * (len(n_used_steps) + 1)), squeeze=False)

for axes_row, n_used in zip(axes[:-1], n_used_steps):
    norm_dist = interp.normalize_distances(d_map[:n_used], dist_kernel_order=2)
    weighted_sum = (r_map[:n_used] * norm_dist).sum(axis=0)
    axes_row[0].imshow(r_map[:n_used].mean(axis=0)[..., slice_id][ZOOM1][:, ZOOM2])
    # Weight map of the most recently added image.
    axes_row[1].imshow(norm_dist[n_used - 1][..., slice_id][ZOOM1][:, ZOOM2], vmin=0, vmax=1)
    axes_row[2].set_title(f"Using {n_used} imgs")
    axes_row[2].imshow(weighted_sum[..., slice_id][ZOOM1][:, ZOOM2])

# The last step always uses every image, so both terms of the difference cover the
# same set (previously the weighted side silently left out the last image).
diff_img = weighted_sum - r_map.mean(axis=0)
axes[-1, 0].imshow(weighted_sum[..., slice_id])
axes[-1, 1].imshow(diff_img[..., slice_id], vmin=-10, vmax=10, cmap="coolwarm")
axes[-1, 2].imshow(diff_img[..., slice_id][ZOOM1][:, ZOOM2], vmin=-10, vmax=10, cmap="coolwarm")

# Free the full-volume intermediate; it is always defined above, so re-runs are safe.
del weighted_sum

In [None]:
# Distance weights (kernel order 2) for every image in d_map.
norm_dist = interp.normalize_distances(
    d_map,
    dist_kernel_order=2,
)

In [None]:
slice_id = 100
n_imgs = d_map.shape[0]
# One row per image plus a summary row; figure height follows the row count.
fig, axes = plt.subplots(nrows=n_imgs + 1, ncols=3, figsize=(12, 5 * (n_imgs + 1)), squeeze=False)

# `weighted` keeps every per-image weighted volume for the comparison cell; the
# running sum avoids stacking them all into one more full array just to add them up.
weighted = []
weighted_sum = None
for i, axes_row in enumerate(axes[:-1]):
    axes_row[0].imshow(r_map[i][..., slice_id][ZOOM1][:, ZOOM2])
    axes_row[1].imshow(d_map[i][..., slice_id][ZOOM1][:, ZOOM2], vmin=0, vmax=1)
    weighted_img = r_map[i] * norm_dist[i]
    weighted.append(weighted_img)
    if weighted_sum is None:
        weighted_sum = weighted_img.copy()  # copy: don't mutate the stored volume
    else:
        weighted_sum += weighted_img
    axes_row[2].imshow(weighted_img[..., slice_id][ZOOM1][:, ZOOM2])

# Summary row: plain sum, total weight per voxel, weighted sum.
axes[-1, 0].imshow(r_map.sum(axis=0)[..., slice_id][ZOOM1][:, ZOOM2])
d_imshow = axes[-1, 1].imshow(norm_dist.sum(axis=0)[..., slice_id][ZOOM1][:, ZOOM2])
axes[-1, 2].imshow(weighted_sum[..., slice_id][ZOOM1][:, ZOOM2])

# Attach the colorbar explicitly to the panel it describes.
fig.colorbar(d_imshow, ax=axes[-1, 1])

In [None]:
fig, axes = plt.subplots(ncols=3, figsize=(17, 6))

# Compute each full-volume reduction once. Summing the list directly avoids
# materialising a stacked copy of every weighted volume (done twice before).
mean_img = r_map.mean(axis=0)
weighted_sum = sum(weighted)
diff_map = mean_img - weighted_sum

axes[0].set_title("Mean")
axes[0].imshow(mean_img[..., slice_id][ZOOM1][:, ZOOM2])
axes[1].set_title("Distance-weighted sum")
axes[1].imshow(weighted_sum[..., slice_id][ZOOM1][:, ZOOM2])
axes[2].set_title("Mean - weighted")
cbar = axes[2].imshow(diff_map[..., slice_id][ZOOM1][:, ZOOM2], vmin=-10, vmax=10, cmap="coolwarm")
fig.colorbar(cbar, ax=axes[2])

In [None]:
importlib.reload(interp)
# Toy 1-D check of the weighting: n_samples noisy copies of cos(x), where each
# sample's absolute noise is used as its "distance".
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(12, 12))

np.random.seed(220367)

n_samples = 35

x_data = np.arange(0, 2*np.pi, 0.01)
data_vect = np.cos(x_data)

axes[0].plot(x_data, data_vect, label="Ground truth", linewidth=4)

noise = np.random.normal(np.zeros((n_samples, len(x_data))), .1)

noisy_data = data_vect + noise

axes[0].fill_between(x_data, noisy_data.min(axis=0), noisy_data.max(axis=0), alpha = .5,
                     color="tab:orange", label="Noisy data")
axes[0].plot(x_data, noisy_data.mean(axis=0), label="Average noisy data", linewidth=2, alpha=.8)

distances = np.abs(noise)

# Second positional argument is the kernel order (1 here).
norm_distances = interp.normalize_distances(distances, 1, limits=5)

for dist, n_dist in zip(distances, norm_distances):
    axes[1].scatter(dist, n_dist, alpha=.2)
axes[1].set_title("Distance to weight, kernel order 1")
axes[1].set_xlabel("Distance (|noise|)")
axes[1].set_ylabel("Weight")

interpolated_data = (noisy_data * norm_distances).sum(axis=0)
axes[0].plot(x_data, interpolated_data, label="Interpolated data", linewidth=2, alpha=.8)

axes[0].set_xlabel("x")
axes[0].set_ylabel("Signal")
axes[0].legend()