In [None]:
import numpy as np

import nibabel as nib
from nilearn import plotting, surface

import skimage

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

import os
import os.path as op
import sys

from tqdm.notebook import tqdm

sys.path.append('../interpolation')

import interpolation as interp

In [None]:
# Output directory for exported figures; created on the first run and reused afterwards.
figure_loc = op.join("..", "figures")
os.makedirs(figure_loc, exist_ok=True)

# Histogram matching

In [None]:
# Data root and derivative folder; the commented pair is the alternative (Linux) setup.
path_to_data = "/Users/acionca/data"
derivative_name = "ants-t1N4bfcorr-ss"

#path_to_data = "/home/acionca/Documents/data"
#derivative_name = "ants-t1N4bfcorr-b80-noSkull"

# Runs left out of every file listing below (a fresh list is passed to each call).
excluded_runs = ("in0048", "rerun")

path_to_imgs = op.join(path_to_data, f"hcph-template/multivar-v00/derivatives/{derivative_name}")

anat_files_list = interp.get_anat_filenames(path_to_imgs, pattern="corrden", exclude=list(excluded_runs))
mask_files_list = interp.get_anat_filenames(path_to_imgs, pattern="brainmask", exclude=list(excluded_runs))

# MNI reference template at 0.8 mm, plus a copy multiplied by vmax.
path_to_mni = op.join(path_to_data, "hcph-template/mni_template/mni_template-res0.8mm.nii.gz")
ref_img = nib.load(path_to_mni)
ref_data = ref_img.get_fdata()

vmax = 255
scaled_ref_data = ref_data * vmax

# Same derivative folder as path_to_imgs, listed without a filename pattern.
path_to_normalized = path_to_imgs
norm_files_list = interp.get_anat_filenames(path_to_normalized, exclude=list(excluded_runs))

# First session, multiplied by its brain mask: the reference for histogram matching.
first_img = nib.load(anat_files_list[0])
#first_img = nib.load(norm_files_list[0])
first_data = first_img.get_fdata()
first_mask = nib.load(mask_files_list[0]).get_fdata()
first_data_masked = first_data * first_mask

In [None]:
# Histogram-match a reduced subset (first 5 sessions) onto the masked first session.
anat_files_list_red = anat_files_list[:5]
mask_files_list_red = mask_files_list[:5]

# (Re)load the first session of the subset; first_data sets the plotting ranges below.
first_img = nib.load(anat_files_list_red[0])
first_data = first_img.get_fdata()

masked, matched, img_list = [], [], []
for idx, anat_path in enumerate(tqdm(anat_files_list_red)):
    session_img = nib.load(anat_path)
    session_mask = nib.load(mask_files_list_red[idx]).get_fdata()
    in_brain = session_mask > 0

    session_masked = session_mask * session_img.get_fdata()
    masked.append(session_masked)

    # Match the whole masked volume, then keep only the in-brain voxels.
    session_matched = np.zeros_like(session_masked)
    session_matched[in_brain] = skimage.exposure.match_histograms(session_masked, first_data_masked)[in_brain]
    matched.append(session_matched)

    # NOTE(review): float intensities are wrapped with a uint8 on-disk dtype, while the
    # histogram axis below spans up to 700 — confirm values survive saving as intended.
    img_list.append(nib.Nifti1Image(session_matched, affine=session_img.affine, dtype="uint8"))

In [None]:
# Compare intensity histograms (reference vs. original vs. matched) for one session.
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(12, 10))
#fig, axes = plt.subplots(nrows=len(anat_files_list_red), ncols=1, figsize=(10, 5*len(anat_files_list_red)))

percentile_to_show = 99
# Upper histogram bound: the 99th percentile of the reference volume plus a small margin.
img_percentile = np.percentile(first_data.flatten(), percentile_to_show, method="nearest") + 20
# One bin per two intensity units of the reference volume.
nbins = int(first_data.max() / 2)

img_id = 2  # index into `masked` / `matched`: the session to display


def _brain_histogram(volume):
    """Histogram of `volume` over [10, img_percentile]; the lower bound drops the zero background."""
    return np.histogram(volume.flatten(), bins=nbins, range=(10, img_percentile))


# Reference image: thick outline with the area under the curve shaded.
hist_values, hist_bins = _brain_histogram(first_data_masked)
axes.plot(hist_bins[:-1], hist_values, lw=10, color="silver", label="Reference")
# fill_between shades between the curve and y=0 (fill_betweenx would treat the counts
# as y-coordinates and shade towards x=0 instead).
axes.fill_between(hist_bins[:-1], hist_values, color="silver")

# Original and matched versions of the selected session, drawn semi-transparent on top.
for volume, color, label in [(masked[img_id], "red", "Original"), (matched[img_id], "green", "Matched")]:
    hist_values, hist_bins = _brain_histogram(volume)
    axes.plot(hist_bins[:-1], hist_values, lw=10, alpha=.6, color=color, label=label)

# Frameless panel: hidden spines, no ticks.
axes.spines[:].set_visible(False)
axes.set_xticks([])
axes.set_yticks([])
axes.axis([100, 700, 0, 20000])
axes.legend(fontsize=44)
#fig.legend(loc="upper right", bbox_to_anchor=(0.9, 0.9))

#fig.savefig(op.join(figure_loc, "histogram_matching.png"), dpi=300, bbox_inches="tight")

## Example images and shared plotting settings

In [None]:
# Example session (sub-001, ses-001) at successive preprocessing stages.
path_to_images = "/Users/acionca/data/hcph-template"


def _session_file(relative_path):
    """Load a NIfTI image given its path relative to `path_to_images`."""
    return nib.load(op.join(path_to_images, relative_path))


original_img = _session_file("multivar-v00/sub-001_ses-001_T1w.nii.gz")
skullstrip = _session_file("multivar-v00/derivatives/ants-t1N4bfcorr-ss/sub-001_ses-001_desc-noSkull_T1w.nii.gz")
brainmask = _session_file("multivar-v00/derivatives/ants-t1N4bfcorr-ss/sub-001_ses-001_desc-brainmask_T1w.nii.gz")
n4corrden = _session_file("multivar-v00/derivatives/allInRef-noskull/sub-001_ses-001_space-inRef_desc-N4corrdenhist_T1w.nii.gz")
aligned_T2w = _session_file("multivar-v00/derivatives/allInRef-noskull/sub-001_ses-001_space-inRef_desc-N4corrdenhist_T2w.nii.gz")

path_to_surfaces = op.join(path_to_images, "template-surf/template")

In [None]:
# Shared display settings for the volume panels below.
mask_data = brainmask.get_fdata()

from scipy.ndimage import binary_dilation, binary_erosion, binary_closing

z_coords = 115  # axial slice shown for images in the reference space
max_int = 1000  # upper intensity limit passed to imshow

all_imgs = [original_img, skullstrip, n4corrden, aligned_T2w]

# Gray colormap whose lowest LUT entry is fully transparent. Work on a copy: on older
# matplotlib versions plt.get_cmap returns the globally registered instance, so editing
# its lookup table in place would also alter every later `cmap="gray"` plot.
my_gray = plt.get_cmap("gray").copy()

n_alpha = 1  # number of lowest LUT entries made transparent
my_gray._init()  # private matplotlib API: builds `_lut` (N colors + under/over/bad rows)
alpha_vect = [0] * n_alpha + [1] * (my_gray.N + 3 - n_alpha)

print(my_gray._lut.shape)
my_gray._lut[:, -1] = alpha_vect

# Initial images

In [None]:
# Raw T1w and T2w acquisitions of sub-001/ses-001, with the background removed by a closed threshold mask.
original_images = [nib.load(op.join(path_to_images, f"multivar-v00/sub-001_ses-001_T{i}w.nii.gz")) for i in [1, 2]]

fig, axes = plt.subplots(nrows=1, ncols=len(original_images), figsize=(12*len(original_images), 10))

threshs = [20, 10]         # per-contrast intensity threshold for the head mask (T1w, T2w)
z_coords_list = [90, 205]  # per-contrast axial slice index

for ax, image, thresh, z_slice in zip(axes, original_images, threshs, z_coords_list):
    ax.axis("off")
    data = image.get_fdata()[:, :, z_slice]

    # Threshold, then morphologically close to fill holes in the head mask.
    mask_from_data = binary_closing(data > thresh, iterations=10)
    ax.imshow(np.abs(data)*mask_from_data, interpolation="none", cmap=my_gray, vmin=0)

#fig.savefig(op.join(figure_loc, "original_imgs.png"), dpi=300, bbox_inches="tight", transparent=True)

# Brain plots

In [None]:
# One slice per preprocessing stage: raw T1w, skull-stripped, N4-corrected/denoised, aligned T2w.
fig, axes = plt.subplots(nrows=1, ncols=len(all_imgs), figsize=(12*len(all_imgs), 10))

for ax in axes:
    ax.axis("off")

# .copy() matters: get_fdata() returns the array nibabel caches on original_img, so the
# slice is a view and the in-place `+=` below would otherwise modify that cache (adding
# another +10 to the stored volume every time this cell is re-run).
original_dilate = original_img.get_fdata()[:, :, 90].copy()

mask_from_data = (original_dilate > 40)
#mask_from_data = binary_dilation(mask_from_data, iterations=7)
mask_from_data = binary_closing(mask_from_data, iterations=9)
# +10 inside the head; presumably keeps dark in-head voxels out of my_gray's
# transparent lowest entry.
original_dilate += 10*mask_from_data

axes[0].imshow(original_dilate*mask_from_data, cmap=my_gray, vmin=0, vmax=max_int)

for ax, image in zip(axes[1:], all_imgs[1:]):
    data = image.get_fdata()[:, :, z_coords]
    # Same +10 offset inside the brain mask for the images already in the reference space.
    ax.imshow(np.abs(data)+10*mask_data[:, :, z_coords], interpolation="none", cmap=my_gray, vmin=0, vmax=max_int)

#fig.savefig(op.join(figure_loc, "all_brains.png"), dpi=300, bbox_inches="tight", transparent=True)

# Mosaic plots

In [None]:
# First n_frames volumes of the concatenated 4D image, restricted to axial slice z_coords.
path_to_movie = op.join(path_to_images, "concat_img_noskull_N37.nii.gz")

n_frames = 5

movie_img = nib.load(path_to_movie)
first_frames = movie_img.slicer[..., :n_frames]
movie_data = first_frames.get_fdata()[:, :, z_coords]
movie_data.shape

In [None]:
# One panel per frame of the aligned movie.
fig, axes = plt.subplots(nrows=1, ncols=n_frames, figsize=(12*n_frames, 10))

for frame_idx, ax in enumerate(axes):
    ax.axis("off")
    ax.imshow(movie_data[..., frame_idx], interpolation="none", cmap=my_gray, vmin=0, vmax=max_int)

#fig.savefig(op.join(figure_loc, "aligned_brains.png"), dpi=300, bbox_inches="tight", transparent=True)

# Plotting the first template

In [None]:
# ANTs templates after affine iteration 2 (template0 and template1), masked to the brain.
path_to_ants = op.join(path_to_images, "multivar-v00/derivatives/allInRef-noskull/intermediateTemplates")

path_to_ants_templates = [op.join(path_to_ants, f"Affine_iteration2_A_tpl_template{i}.nii.gz") for i in range(2)]

fig, axes = plt.subplots(ncols=2, figsize=(12*2, 10))

brain_slice = mask_data[:, :, z_coords]
for ax, template_path in zip(axes, path_to_ants_templates):
    ax.axis("off")
    first_template_data = nib.load(template_path).get_fdata()[:, :, z_coords]
    # +1 inside the mask; presumably keeps in-brain voxels out of my_gray's transparent lowest entry.
    ax.imshow((first_template_data+1)*brain_slice, interpolation="none", cmap=my_gray, vmin=0, vmax=22)

#fig.savefig(op.join(figure_loc, "ANTs_Templates.png"), dpi=300, bbox_inches="tight", transparent=True)

In [None]:
from skimage.filters import laplace

def laplacian_sharpen(img_array, alpha=1.0):
    """Add a weighted Laplacian to an image.

    Parameters
    ----------
    img_array : array-like
        Input image.
    alpha : float, default 1.0
        Weight applied to the Laplacian before it is added.

    Returns
    -------
    ``img_array + alpha * laplace(img_array)``, where ``laplace`` is
    ``skimage.filters.laplace``.
    """
    edges = laplace(img_array)
    return img_array + alpha * edges

# Voxel-wise change between consecutive affine iterations of the 4-iteration ("i4") ANTs run.
# /Users/acionca/data/hcph-template/multivar-v00/derivatives/allInRef-noskull-i4/intermediateTemplates/Affine_iteration0_A_tpl_template0.nii.gz
path_to_i4 = "/Users/acionca/data/hcph-template/multivar-v00/derivatives/allInRef-noskull-i4"
path_to_ants_templates = [op.join(path_to_i4, "intermediateTemplates", f"Affine_iteration{i}_A_tpl_template0.nii.gz") for i in range(4)]

print(path_to_ants_templates)

# Fixed, symmetric color limit so all difference panels share one scale
# (a data-driven alternative would be np.abs(diff).max() per panel).
max_diff = .02

# Read each iteration's slice once (consecutive pairs share a template). The .copy()
# lets each full cached volume be released instead of being kept alive by the slice view.
template_slices = [nib.load(path).get_fdata()[:, :, z_coords].copy() for path in path_to_ants_templates]

fig, axes = plt.subplots(ncols=len(path_to_ants_templates)-1, figsize=(12*2, 10))

for ax, first_template_data, second_template_data in zip(axes, template_slices[:-1], template_slices[1:]):
    ax.axis("off")
    diff = second_template_data - first_template_data
    #diff = second_template_data/second_template_data.max() - first_template_data/first_template_data.max()
    ax.imshow(diff, cmap="bwr", interpolation="none", vmin=-max_diff, vmax=max_diff)

#fig.savefig(op.join(figure_loc, "diff_w_previous_iter.png"), dpi=300, bbox_inches="tight", transparent=True)

# Plotting the final templates

In [None]:
# Zoomed comparison of the final templates: DisWe at 0.2 / 0.4 mm vs. ANTs at 0.8 mm.
path_to_ants = op.join(path_to_images, "multivar-v00/derivatives/allInRef-noskull-i4/intermediateTemplates")
path_to_diswe = op.join(path_to_images, "multivar-v00/derivatives/diswe-interp-noskull-i4")

#path_to_diswe_02mm = "/Users/acionca/data/hcph-template/multivar-v00/derivatives/diswe-interp/distance_weighted_template_res-0.2_desc-N35T1w.nii.gz"
path_to_diswe_02mm = op.join(path_to_diswe, "distance_weighted_template_res-0.2_desc-N35DisWei1T1w.nii.gz")
path_to_diswe_template = op.join(path_to_diswe, "distance_weighted_template_res-0.4_desc-N35DisWei1T1w.nii.gz")
path_to_ants_template = op.join(path_to_ants, "Affine_iteration3_A_tpl_template0.nii.gz")

templates_to_show = [path_to_diswe_02mm, path_to_diswe_template, path_to_ants_template]
template_titles = ["DisWe 0.2mm", "DisWe 0.4mm", "ANTs 0.8mm"]

# Presumably the ratio of 0.8 mm to each template's voxel size; rescales slice and crop indices.
resolution_factors = [4, 2, 1]

z_coords = 115
z_coords_template = [z_coords*res for res in resolution_factors]

#maxvals = [700, 14]
maxvals = [255]*len(templates_to_show)

# Crop window, expressed in 0.8 mm voxel units.
zooms_x = [110, 150]
zooms_y = [80, 120]

fig, axes = plt.subplots(ncols=len(templates_to_show), figsize=(12*2, 10))

panels = zip(axes, templates_to_show, resolution_factors, z_coords_template, maxvals, template_titles)
for ax, template_path, factor, z_slice, panel_vmax, title in panels:
    ax.axis("off")
    first_template_data = nib.load(template_path).get_fdata()[:, :, z_slice]

    x_lo, x_hi = (factor * edge for edge in zooms_x)
    y_lo, y_hi = (factor * edge for edge in zooms_y)
    zoomed_img = first_template_data[x_lo:x_hi, y_lo:y_hi]
    # Rescale the crop to 0-255 so every panel shares the same display range.
    normalized_img = 255*zoomed_img/zoomed_img.max()
    ax.imshow(normalized_img, cmap=my_gray, interpolation="none", vmin=0, vmax=panel_vmax)
    ax.set_title(title, fontsize=24)

#fig.savefig(op.join(figure_loc, f"ANTs_vs_DisWe-zoom_x{zooms_x[0]}_{zooms_x[1]}y{zooms_y[0]}_{zooms_y[1]}.png"), dpi=300, bbox_inches="tight", transparent=True)

In [None]:
# distance_weighted_template_res-0.4_desc-N35DisWei1T1w.nii.gz
# Zoomed T1w / T2w DisWe templates, either native or rigidly aligned to MNI.
# distance_weighted_template_res-0.4_desc-N35DisWei1T1w.nii.gz
path_to_diswe = op.join(path_to_images, "multivar-v00/derivatives/diswe-interp-noskull-i4")

res = 0.2  # template voxel size in mm (used in the filenames and the index scaling)

# Native-space templates (uncomment to use them instead of the MNI-aligned ones):
#path_to_diswe_templates = [op.join(path_to_diswe, f"distance_weighted_template_res-{res:1.1f}_desc-N35DisWei1T{i+1}w.nii.gz") for i in range(2)]
# distance_weighted_template_res-0.4_desc-N35DisWei1to_MNI04mmRigidT1w.nii.gz
path_to_diswe_templates = [op.join(path_to_diswe, f"distance_weighted_template_res-{res:1.1f}_desc-N35DisWei1to_MNI04mmRigidT{i+1}w.nii.gz") for i in range(2)]

in_mni = "MNI" in path_to_diswe_templates[0]

fig, axes = plt.subplots(ncols=2, figsize=(15*2, 15))

# Template voxels per 0.8 mm voxel.
res_factor = int(0.8//res)

# Crop margins and slice index are given in 0.8 mm voxels. A cell-local name is used for
# the slice so the global `z_coords` relied on by earlier cells is not overwritten.
if in_mni:
    xzoom = 40
    yzoom = 30
    z_slice_08mm = 103
else:
    xzoom = 55
    yzoom = 50
    z_slice_08mm = 115

z_coords_template = [res_factor*z_slice_08mm]*2
zooms = [[res_factor*xzoom, -res_factor*xzoom],
         [res_factor*yzoom, -res_factor*yzoom]]

for ax in axes:
    ax.axis("off")

# Display window [vmin, vmax] per modality (T1w, T2w), after rescaling to 0-255.
maxvals = [[40, 260], [0, 255]]

for i, (ax, z_coord_temp) in enumerate(zip(axes, z_coords_template)):
    first_template_data = nib.load(path_to_diswe_templates[i]).get_fdata()[:, :, z_coord_temp]

    #zoomed_img = first_template_data[110:-110, 100:-100]
    zoomed_img = first_template_data[zooms[0][0]:zooms[0][1], zooms[1][0]:zooms[1][1]]

    norm_img = 255*zoomed_img/zoomed_img.max()

    # Transposed and flipped vertically for display.
    ax.imshow(np.flip(norm_img.T, axis=0), cmap="gray", interpolation="none", vmin=maxvals[i][0], vmax=maxvals[i][1])

res_str = f"{res*10:02.0f}"
#fig.savefig(op.join(figure_loc, f"Interp_Templates_{res_str}mm-({maxvals})"+in_mni*"-MNI"+".png"), dpi=600, bbox_inches="tight", transparent=True)

# Surface plots

In [None]:
# Right-hemisphere template surface and the per-vertex map drawn on it.
path_to_surf_template = op.join(path_to_images, "template-surf/template/surf")

surface_file = "rh.pial"  # alternative: "rh.inflated"
bg_file = "rh.sulc"       # alternatives: "rh.curv", "rh.thickness"

path_to_surf, path_to_bg = (op.join(path_to_surf_template, name) for name in (surface_file, bg_file))

In [None]:
# Template surfaces gathered under the same keys as nilearn's fsaverage dictionary,
# so the dict can be passed as `surf_mesh` to the surface plotting functions below.
fs_keys = ["infl_left", "pial_left", "sulc_left",  "white_left", "curv_left", "infl_right", "pial_right", "sulc_right", "white_right", "curv_right"]
hemi_key = {"left":"lh", "right":"rh"}

surf_like_fs = {}
for key in fs_keys:
    label, hemi = key.split("_")
    if "infl" in label:
        label = "inflated"  # the FreeSurfer file is named `?h.inflated`
    surf_path = op.join(path_to_surf_template, ".".join([hemi_key[hemi], label]))

    # sulc / curv are per-vertex data; every other entry is a mesh.
    if any(tag in label for tag in ("sulc", "curv")):
        surf = surface.load_surf_data(surf_path)
    else:
        surf = surface.load_surf_mesh(surf_path)

    surf_like_fs[key] = surf

# Plot surfaces

In [None]:
# Single right-hemisphere surface colored by the selected per-vertex map.
# plt.figure (not plt.subplots): plot_surf adds its own 3D axes to `fig`, so a
# pre-created 2D axes would only show up as an empty frame behind the surface.
fig = plt.figure(figsize=(10, 10))
#display = plotting.plot_surf(path_to_surf, bg_map=path_to_bg, hemi="right", engine="plotly", darkness=0.1)
#plotting.plot_surf(path_to_surf, bg_map=path_to_bg, hemi="right", darkness=1, figure=fig)
#maxval=0.4
maxval=10

# NOTE(review): vmax is maxval//2 while vmin is -maxval (asymmetric range) — confirm intended.
disp = plotting.plot_surf(path_to_surf, surf_map=path_to_bg, hemi="right", cmap="binary", figure=fig,
                          colorbar=True, vmin=-maxval, vmax=maxval//2, avg_method="median")
#fig.savefig(op.join(figure_loc, "surf_tests", f"surf_{bg_file}_maxval{maxval}.png"), dpi=300)

#plotting.view_surf(path_to_surf, bg_map=path_to_bg, black_bg=True)

#fig = display.show(renderer=None)

In [None]:
# Composite figure: lateral and "medial" views of each hemisphere plus a dorsal view
# shared by both hemispheres.
fig = plt.figure(layout="constrained", figsize=(12, 10))

gs = GridSpec(2, 3, figure=fig)

# Available views: "lateral", "medial", "dorsal", "ventral", "anterior", "posterior"

# Left column: left hemisphere; right column: right hemisphere; middle: shared dorsal view.
leftLateral = fig.add_subplot(gs[0, 0], projection="3d")
leftMedial = fig.add_subplot(gs[-1, 0], projection="3d")
rightLateral = fig.add_subplot(gs[0, -1], projection="3d")
rightMedial = fig.add_subplot(gs[-1, -1], projection="3d")
dorsal = fig.add_subplot(gs[:, 1], projection="3d")

all_axes = [[leftLateral, leftMedial, dorsal],
            [rightLateral, rightMedial, dorsal]]

hemis = ["left", "right"]

# Per-vertex map drawn on the pial mesh ("curv" is the commented alternative).
#bg_type = "curv"
bg_type = "sulc"

# Color limits are [-maxval, maxval] (0.4 was the value paired with "curv").
#maxval = 0.4
maxval = 10

for i_hemi, hemi in enumerate(hemis):
    surface_file = f"{hemi_key[hemi]}.pial"
    bg_file = f"{hemi_key[hemi]}.{bg_type}"

    path_to_surf = op.join(path_to_surf_template, surface_file)
    path_to_bg = op.join(path_to_surf_template, bg_file)

    for view, ax in zip(["lateral", "medial", "dorsal"], all_axes[i_hemi]):
        # The "medial" panel is drawn as a lateral view with the opposite hemisphere
        # setting; the shared dorsal panel gets no title.
        if view == "medial":
            real_view, real_hemi = "lateral", hemis[1 - i_hemi]
        else:
            real_view, real_hemi = view, hemi
        real_title = None if view == "dorsal" else hemi_key[hemi].capitalize()

        disp = plotting.plot_surf(path_to_surf, surf_map=path_to_bg, hemi=real_hemi, cmap="binary",
                                  vmin=-maxval, vmax=maxval, axes=ax, view=real_view, title=real_title)

size_labels = fig.get_size_inches().astype(int).astype(str)
figsize_str = f"({', '.join(size_labels)})"
surf_filename = f"all_surfs-{bg_type}-pyplot{figsize_str}600dpi.png"
#fig.savefig(op.join(figure_loc, surf_filename), dpi=600)

In [None]:
# Each view of each hemisphere in its own figure, colored by curvature.
# Available views: "lateral", "medial", "dorsal", "ventral", "anterior", "posterior"

hemis = ["left", "right"]

# Type of surface data to show
bg_type = "curv"
#bg_type = "sulc"

# Color limits are [-maxval, maxval].
maxval = 10

for i_hemi, hemi in enumerate(hemis):

    surface_file = ".".join([hemi_key[hemi], "pial"])
    bg_file = ".".join([hemi_key[hemi], bg_type])

    path_to_surf = op.join(path_to_surf_template, surface_file)
    path_to_bg = op.join(path_to_surf_template, bg_file)

    for view in ["lateral", "medial", "dorsal"]:
        # A fresh figure per (hemisphere, view), matching the per-view filename below.
        # The axes are created here instead of reusing `all_axes`, which belong to the
        # composite figure of the previous cell (this cell's own figure stayed empty).
        fig, ax = plt.subplots(figsize=(12, 10), subplot_kw={"projection": "3d"})

        real_view = view
        real_hemi = hemi
        real_title = hemi_key[hemi].capitalize()
        if view == "medial":
            # Drawn as a lateral view with the opposite hemisphere setting.
            real_view = "lateral"
            real_hemi = hemis[1 - i_hemi]
        elif view == "dorsal":
            real_title = None

        disp = plotting.plot_surf(path_to_surf, surf_map=path_to_bg, hemi=real_hemi, cmap="binary",
                                  vmin=-maxval, vmax=maxval, axes=ax, view=real_view, title=real_title)

        surf_filename = f"surfs-{bg_type}-{hemi}_{view}-pyplot.png"
        #fig.savefig(op.join(figure_loc, "surf_tests", surf_filename), dpi=300)

# Plot a stat map on the template background

In [None]:
# Sample motor-activation map overlaid on the HCPh template (rigid vs. affine alignment
# to MNI), with nilearn's default MNI152 background as a reference panel.
from nilearn import datasets

# NOTE(review): this reassigns the `path_to_data` defined in the first cells; the next
# cell relies on this new value.
path_to_data = "/Users/acionca/data/hcph-template/multivar-v00/derivatives/diswe-interp-noskull-i4/"

path_to_template_mni = op.join(path_to_data, "distance_weighted_template_res-0.4_desc-N35DisWei1to_MNI04mmRigidT1w.nii.gz")
# PROV
path_to_template_mni_warp = op.join(path_to_data, "distance_weighted_template_res-0.4_desc-N35DisWei1T1w_MovedToMNI04mmAffine.nii.gz")

# Sample activation image shipped with nilearn; it is opened below with nib.load.
# (The localizer dataset previously fetched here was downloaded and then discarded.)
stat_img = datasets.load_sample_motor_activation_image()
statmap = stat_img

all_bgs = [path_to_template_mni, path_to_template_mni_warp]
titles = ["Stat map on HCPh template (Rigid to MNI) overlay",
          "Stat map on HCPh template (Affine to MNI) overlay"]

clip_prop = 0.7  # background intensities are clipped at 70% of their maximum

fig, axes = plt.subplots(nrows=3, figsize=(12, 6*3))

for i, (title, bg_map) in enumerate(zip(titles, all_bgs)):
    bg_img = nib.load(bg_map)
    bg_data = bg_img.get_fdata()

    bg_clipped = nib.Nifti1Image(np.clip(bg_data, a_min=0, a_max=clip_prop*bg_data.max()), affine=bg_img.affine)

    # NOTE(review): threshold=1 here vs. threshold=2 on the MNI panel — confirm intended.
    plotting.plot_stat_map(nib.load(statmap), bg_img=bg_clipped, display_mode="ortho", threshold=1, black_bg=True,
                           title=title, axes=axes[i], interpolation="none")

plotting.plot_stat_map(nib.load(statmap), display_mode="ortho", threshold=2, black_bg=True,
                       title="Stat map on `MNI152TEMPLATE` overlay", axes=axes[-1], interpolation="none")

#fig.savefig(op.join(figure_loc, "statmap-trans-v01.png"), dpi=600, facecolor="black")

### Testing multiple resolutions

In [None]:
# Same overlay on the affinely aligned template at 0.8 mm and 0.4 mm, plus the MNI152 reference.
all_res = [8, 4]

all_bgs = [op.join(path_to_data, f"distance_weighted_template_res-0.{res}_desc-N35DisWei1T1w_MovedToMNI04mmAffine.nii.gz") for res in all_res]
titles = [f"Stat map on HCPh template (Affine to MNI, 0.{res} mm) overlay" for res in all_res]

clip_prop = 0.7  # background intensities are clipped at 70% of their maximum

fig, axes = plt.subplots(nrows=3, figsize=(12, 6*3))

for ax, title, bg_map in zip(axes, titles, all_bgs):
    bg_img = nib.load(bg_map)
    bg_data = bg_img.get_fdata()
    upper = clip_prop*bg_data.max()

    bg_clipped = nib.Nifti1Image(np.clip(bg_data, a_min=0, a_max=upper), affine=bg_img.affine)

    plotting.plot_stat_map(nib.load(statmap), bg_img=bg_clipped, display_mode="ortho", threshold=3, black_bg=True,
                           title=title, axes=ax, interpolation="none")

plotting.plot_stat_map(nib.load(statmap), display_mode="ortho", threshold=2, black_bg=True,
                       title="Stat map on `MNI152TEMPLATE` overlay", axes=axes[-1], interpolation="none")

#fig.savefig(op.join(figure_loc, "statmap-res-v01.png"), dpi=600, facecolor="black")

### Surface projection (attempt)

In [None]:
# Interactive 3D view of the stat map projected onto the template surfaces.
# Optional vol_to_surf sampling depths; disabled so nilearn's defaults are used:
#vol_to_surf_kwargs = {"depth":[-.5, 0, 1, 1.5]}
vol_to_surf_kwargs = {}

view = plotting.view_img_on_surf(nib.load(statmap), threshold='90%', surf_mesh=surf_like_fs,
                                 vol_to_surf_kwargs=vol_to_surf_kwargs)

view

In [None]:
# Static lateral/medial views of both hemispheres with the stat map projected onto them.
surf_views = ['lateral', 'medial']
surf_hemispheres = ['left', 'right']

fig, axes = plotting.plot_img_on_surf(statmap, surf_mesh=surf_like_fs,
                                      views=surf_views, hemispheres=surf_hemispheres,
                                      threshold=2, colorbar=True)
fig.set_size_inches(10, 9)

#fig.savefig(op.join(figure_loc, "statmap-surf-v01.png"), dpi=600)