In [1]:
import os
import numpy as np
import scipy.sparse as sp
import scipy.sparse.csgraph as csgraph
import nibabel as nib
import pyvista as pv
from nilearn import plotting
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

In [28]:
# Template surface directory and which cortical surface level to load.
SURF_DIR = '../../datasets/NMT_v2.0_sym/NMT_v2.0_sym_surfaces'
level = 'mid'  # one of: mid, gray, white

# Inflated surfaces for both hemispheres; only the left one is loaded here,
# the right path is used later by the two-hemisphere plot.
left_inflated = os.path.join(SURF_DIR, f'lh.{level}_surface.inf_300.rsl.gii')
right_inflated = os.path.join(SURF_DIR, f'rh.{level}_surface.inf_300.rsl.gii')

# Vertex coordinates and triangle indices of the left inflated GIFTI.
x = nib.load(left_inflated)
coords, faces = (
    x.agg_data(intent) for intent in ('NIFTI_INTENT_POINTSET', 'NIFTI_INTENT_TRIANGLE')
)

# PyVista wants one flat cell array: [3, i, j, k, 3, i, j, k, ...].
faces_pv = np.column_stack((np.full(len(faces), 3), faces)).astype(np.int64).ravel()
mesh = pv.PolyData(coords, faces_pv)

In [34]:
## plot a slightly less inflated version
# Visualization only: Laplacian-smooth the folded (non-inflated) surface to get a
# "partially inflated" look. The smoothed geometry is not anatomically accurate.
# All names here carry a *_folded / *_s suffix so this cell does NOT overwrite the
# inflated `coords` / `faces` / `mesh` that the ROI cells below pick vertices on and
# paint labels onto (otherwise Restart & Run All would silently switch them to this surface).
left_folded = os.path.join(SURF_DIR, f'lh.{level}_surface.rsl.gii')

gii_folded = nib.load(left_folded)
coords_folded = gii_folded.agg_data('NIFTI_INTENT_POINTSET')
faces_folded = gii_folded.agg_data('NIFTI_INTENT_TRIANGLE')

# PyVista wants one flat cell array: [3, i, j, k, 3, i, j, k, ...]
faces_folded_pv = (
    np.hstack([np.full((faces_folded.shape[0], 1), 3), faces_folded])
    .astype(np.int64)
    .ravel()
)
mesh_folded = pv.PolyData(coords_folded, faces_folded_pv)

# Smoothing knobs: increase n_iter for more inflation; keep relaxation_factor in
# [0.01, 0.1]. n_iter=300, relaxation_factor=0.05 looks good.
# smooth() returns a new mesh (it is not in-place by default), so no copy is needed.
mesh_s = mesh_folded.smooth(n_iter=300, relaxation_factor=0.05, feature_smoothing=False)

pv.set_jupyter_backend('trame')

pl = pv.Plotter()
pl.add_mesh(
    mesh_s,
    color='lightgray',  # solid color; a cmap is only applied when mapping scalars
    smooth_shading=True,
)
pl.reset_camera()
pl.show()

Widget(value='<iframe src="http://localhost:50816/index.html?ui=P_0x346355df0_20&reconnect=auto" class="pyvist…

In [16]:
## plot hemi with no ROIs
# Render `mesh` as a plain gray surface. An explicit `color=` is used instead of a
# one-entry cmap: with no color/scalars given, add_mesh colors by the mesh's active
# scalars, so the cmap trick only grayed the surface when 'roi_family' had already
# been attached (hidden state from a later cell) and otherwise fell back to the theme color.
pv.set_jupyter_backend('trame')

pl = pv.Plotter()
pl.add_mesh(
    mesh,
    color='lightgray',
    smooth_shading=False,
)
pl.reset_camera()
pl.show()

Widget(value='<iframe src="http://localhost:50816/index.html?ui=P_0x32c0e29a0_5&reconnect=auto" class="pyvista…

In [4]:
def geodesic_ball(mesh, seed, radius_mm=4.0):
    """Mask the vertices within an edge-path ("graph geodesic") radius of a seed vertex.

    Distances are shortest paths through the mesh's edge graph, each edge weighted by
    its Euclidean length. Edge paths are never shorter than the true surface geodesic,
    so the ball can be slightly smaller than an exact geodesic disk.

    Parameters
    ----------
    mesh : object with ``points`` and ``faces``
        ``points`` is an (n_points, 3) array of vertex coordinates; ``faces`` is a flat
        VTK-style cell array containing only triangles, ``[3, i, j, k, 3, i, j, k, ...]``
        (e.g. a triangulated ``pyvista.PolyData``).
    seed : int
        Index of the center vertex. A sequence of indices is also accepted and yields
        one row per seed.
    radius_mm : float
        Ball radius, in the same units as ``mesh.points``.

    Returns
    -------
    numpy.ndarray of uint8
        1 where the path distance from ``seed`` is <= ``radius_mm``, else 0;
        shape ``(n_points,)`` for a single seed.

    Raises
    ------
    ValueError
        If ``mesh.faces`` is not a pure-triangle cell array.
    """
    pts = np.asarray(mesh.points)
    cells = np.asarray(mesh.faces)

    # Reshaping to (-1, 4) is only meaningful when every cell is a triangle; any other
    # cell size would silently mix indices from neighboring cells.
    if cells.size % 4 or np.any(cells.reshape(-1, 4)[:, 0] != 3):
        raise ValueError('geodesic_ball expects mesh.faces to contain only triangles')
    tri = cells.reshape(-1, 4)[:, 1:]  # (n_faces, 3) vertex indices

    # Undirected edge list: each triangle contributes (0,1), (1,2), (2,0); sorting each
    # pair and deduplicating keeps edges shared by two triangles only once.
    edges = np.vstack([tri[:, [0, 1]], tri[:, [1, 2]], tri[:, [2, 0]]])
    edges = np.unique(np.sort(edges, axis=1), axis=0)

    # edge weights = Euclidean length of each mesh edge
    w = np.linalg.norm(pts[edges[:, 0]] - pts[edges[:, 1]], axis=1)

    # symmetric sparse adjacency
    n = pts.shape[0]
    A = sp.csr_matrix((w, (edges[:, 0], edges[:, 1])), shape=(n, n))
    A = A + A.T

    # `limit` stops expanding paths once they exceed the radius; vertices beyond it get
    # inf, which fails the <= test just as their true distance would, so the mask is
    # unchanged while the search no longer visits the whole mesh. (limit must be >= 0.)
    dist = csgraph.dijkstra(A, directed=False, indices=seed, limit=max(radius_mm, 0.0))
    return (dist <= radius_mm).astype(np.uint8)

In [5]:
# --- user knobs ---
# Axis indices below are columns of the vertex coordinate array.
ap_axis = 1  # anterior/posterior axis; usually y in template space -- verify once by rotating

# Target AP coordinate (mm) for the center of each ROI, grouped by category.
ap_targets = dict(
    # body patches
    MiddleBody=6.3,
    AnteriorBody=11.3,
    # face patches
    MiddleFace=8.0,        # center of ~7–9 mm range
    AnteriorFace=13.6,
    # object patches
    MiddleObject=9.0,
    AnteriorObject=12.8,
    # color patches
    MiddleColor=8.3,
    AnteriorColor=14.5,
    # scene-selective (parahippocampal / retrosplenial-ish)
    Scene=5.3,
    # control / undefined
    Unknown=7.3,
)

# Half-width of the AP window used to collect candidate vertices around each target.
ap_tol_mm = 0.75  # slice thickness-ish; widen if you get too few verts

# NOTE(review): the restrict_* knobs are not read by any cell in this notebook;
# the vertex-picking cell hard-codes its own lateral/ventral selection rule.
restrict_axis = 0      # left/right axis; usually x
restrict_mode = 'abs'  # 'abs' keeps near midline; 'sign' keeps hemisphere side
restrict_val = None    # set e.g. -1 for left hemi, +1 for right hemi if using 'sign'
restrict_tol = 25.0    # mm window if using restrict_mode='abs' (keeps ventral-ish/midline-ish)

In [6]:
# Pick one seed vertex per ROI: take the vertices within ap_tol_mm of the target AP
# coordinate, keep the most lateral LATERAL_QUANTILE of them (most negative x, which is
# only "lateral" for the LEFT hemisphere), then take the most ventral one (smallest z).
# NOTE(review): selection uses `coords` of whichever surface is currently loaded; if that
# is an inflated surface its coordinates are distorted, so mm targets are approximate.
LATERAL_QUANTILE = 0.2  # fraction of AP candidates kept on the lateral side

ap = coords[:, ap_axis]  # AP coordinate of every vertex (loop-invariant)

points = {}
for name, ap_mm in ap_targets.items():
    cand = np.flatnonzero(np.abs(ap - ap_mm) <= ap_tol_mm)
    if cand.size == 0:
        raise ValueError(f'no vertices near ap={ap_mm} (tol={ap_tol_mm} mm)')

    # left hemi: avoid the medial wall by preferring more lateral (more negative x) vertices;
    # the subset is never empty because it always contains the minimum-x candidate
    xcoord = coords[cand, 0]
    cand2 = cand[xcoord <= np.quantile(xcoord, LATERAL_QUANTILE)]

    # within that subset, pick the most ventral vertex (smallest z)
    vidx = int(cand2[np.argmin(coords[cand2, 2])])
    points[name] = {'vidx': vidx, 'pt': coords[vidx]}
points

{'MiddleBody': {'vidx': 34943,
  'pt': array([-22.04805  ,   6.997086 ,   5.1638026], dtype=float32)},
 'AnteriorBody': {'vidx': 3363,
  'pt': array([-22.215214 ,  11.854848 ,   4.1056833], dtype=float32)},
 'MiddleFace': {'vidx': 36483,
  'pt': array([-22.08149  ,   8.5379095,   4.6989346], dtype=float32)},
 'AnteriorFace': {'vidx': 13372,
  'pt': array([-22.131983 ,  12.908673 ,   4.1511455], dtype=float32)},
 'MiddleObject': {'vidx': 36454,
  'pt': array([-22.219948,   9.525031,   4.505883], dtype=float32)},
 'AnteriorObject': {'vidx': 13378,
  'pt': array([-22.237059 ,  12.125127 ,   4.1230063], dtype=float32)},
 'MiddleColor': {'vidx': 36455,
  'pt': array([-22.156792 ,   9.041454 ,   4.5968533], dtype=float32)},
 'AnteriorColor': {'vidx': 848,
  'pt': array([-21.992828 ,  13.838565 ,   4.4226294], dtype=float32)},
 'Scene': {'vidx': 34951,
  'pt': array([-22.06677  ,   5.9308047,   5.531305 ], dtype=float32)},
 'Unknown': {'vidx': 36504,
  'pt': array([-21.993958 ,   8.016357 ,  

In [7]:
# roi -> family: each family is listed once with its member ROIs, then flattened.
roi_family = {
    roi: family
    for family, members in {
        'body':    ('MiddleBody', 'AnteriorBody'),
        'face':    ('MiddleFace', 'AnteriorFace'),
        'object':  ('MiddleObject', 'AnteriorObject'),
        'color':   ('MiddleColor', 'AnteriorColor'),
        'scene':   ('Scene',),
        'unknown': ('Unknown',),
    }.items()
    for roi in members
}

# One row per family in label order: the row index is the integer label written into the
# per-vertex label array (0 = background), and the row's color becomes cmap[label].
# Keeping both in one table means labels and colors cannot drift out of sync.
family_palette = [
    ('background', 'lightgrey'),
    ('face',       'dodgerblue'),
    ('body',       'limegreen'),
    ('object',     'orange'),
    ('color',      'gold'),
    ('scene',      'purple'),
    ('unknown',    'dimgray'),
]
family_label = {family: label for label, (family, _) in enumerate(family_palette)}
cmap = [color for _, color in family_palette]
del family_palette  # only the two derived tables are used downstream


In [8]:
roi_radius = 1  # ball radius passed to geodesic_ball, in the units of mesh.points
labels = np.zeros(mesh.n_points, dtype=np.uint8)

# Paint a geodesic ball around each ROI's seed vertex with its family label
# (points preserves the ap_targets ROI order).
for roi, pick in points.items():
    lab = family_label[roi_family[roi]]
    mask = geodesic_ball(mesh, pick['vidx'], radius_mm=roi_radius).astype(bool)

    # Later ROIs overwrite earlier ones. Within a family that is harmless, but across
    # families it erases the earlier family's vertices -- likely here, since several
    # seeds lie well under roi_radius apart -- so report it instead of hiding it.
    clobbered = mask & (labels != 0) & (labels != lab)
    if clobbered.any():
        print(f'{roi}: overwrote {int(clobbered.sum())} vertices labeled with another family')
    labels[mask] = lab

mesh.point_data['roi_family'] = labels

In [9]:
pv.set_jupyter_backend('trame')

# Color each vertex by its integer family label, one cmap entry per label.
# n_colors=len(cmap) with clim centered on the integers gives every label value its own
# color bin, so label k is always drawn in cmap[k]. `categories=True` is avoided on
# purpose: it sizes the lookup table by the number of distinct labels actually present,
# which shifts every color if some family ends up with no vertices (e.g. fully overwritten).
pl = pv.Plotter()
pl.add_mesh(
    mesh,
    scalars='roi_family',
    cmap=cmap,
    n_colors=len(cmap),
    clim=[-0.5, len(cmap) - 0.5],
    smooth_shading=False,
    interpolate_before_map=False,
)
pl.reset_camera()
pl.show()

Widget(value='<iframe src="http://localhost:50816/index.html?ui=P_0x303f66880_0&reconnect=auto" class="pyvista…

In [None]:
## PLOT BOTH HEMIS
# Static nilearn/matplotlib rendering of both inflated hemispheres, side by side,
# one 3D subplot per hemisphere (left panel first).
fig = plt.figure(figsize=(10, 4))

for col, (hemi, surf_path) in enumerate(
    [('left', left_inflated), ('right', right_inflated)], start=1
):
    ax = fig.add_subplot(1, 2, col, projection='3d')
    plotting.plot_surf(surf_mesh=str(surf_path), hemi=hemi, view='lateral', axes=ax)
    ax.set_title(f'{hemi} inflated')

plt.tight_layout()
plt.show()