In [224]:
import numpy as np

class Atlas:
    """
    Holds a single atlas volume and its metadata.

    Parameters
    ----------
    name : str
        Human-readable identifier of the atlas (e.g., "aal" or "brodmann").
    vol : array-like, shape (I, J, K)
        The 3D integer/float array representing the atlas, where each voxel
        holds a region index (e.g., 0, 1, 2, ...).
    hdr : array-like, shape (4, 4)
        The affine transform for voxel->world coordinates.
    labels : array-like or dict, optional
        A structure mapping region indices to region names.
        Either a dict {index: label} or a sequence of names that is paired
        element-wise with `index`. Default is None.
    index : np.ndarray or list, optional
        Explicit region indices that correspond to `labels` when `labels`
        is a sequence. For example, if `index = [1,2,3]` and
        `labels = ["Area1","Area2","Area3"]`, a voxel valued 2 => "Area2".
        Ignored when `labels` is a dict. Default is None.
    system : {"mni", "tal", "unknown"}, default "mni"
        The anatomical coordinate space or reference system.
    offset : int or float, default 0
        Stored on the instance as-is; no method of this class reads it.

    Attributes
    ----------
    name : str
    vol : np.ndarray
    hdr : np.ndarray
    labels : array-like or dict, optional
    index : np.ndarray or list, optional
    system : str
    offset : int or float
    shape : tuple of int
        The shape (I, J, K) of the atlas data.

    Notes
    -----
    Index<->label lookup tables are built once in ``__init__``; mutating
    ``labels``/``index`` afterwards is not reflected in lookups.
    Region indices are matched on a canonical key, so ``2001``, ``2001.0``
    and ``"2001"`` all address the same region regardless of how the label
    table was keyed.

    Methods
    -------
    get_region_name(value)
        Region name for a region index.
    get_region_index(label)
        Region index for a region name.
    get_list_of_regions()
        Unique region names known to the atlas.
    pos_to_source(pos), pos_to_index(pos), pos_to_region(pos)
        World coordinate -> voxel indices / region index / region name.
    source_to_pos(source), index_to_pos(index), region_to_pos(region)
        Voxel indices / region index / region name -> world coordinates.
    label_at_coordinate(pos)
        Like `pos_to_region` but returns None when nothing is found.
    """

    def __init__(self,
                 name,
                 vol,
                 hdr,
                 labels=None,
                 index=None,
                 system='mni',
                 offset=0):
        # Coerce first, validate the coerced arrays: array-likes (lists,
        # memmaps, ...) are accepted as long as they have the right shape.
        vol = np.asarray(vol)
        hdr = np.asarray(hdr)
        if vol.ndim != 3:
            raise ValueError("`vol` must be a 3D numpy array.")
        if hdr.shape != (4, 4):
            raise ValueError("`hdr` must be a 4x4 transform matrix.")

        self.name = name
        self.vol = vol
        self.hdr = hdr
        self.labels = labels
        self.index = index
        self.system = system
        self.offset = offset
        self.shape = vol.shape  # convenience

        self._index2label, self._label2index = self._build_lookup_tables(
            labels, index
        )

    def __repr__(self):
        return (f"{type(self).__name__}(name={self.name!r}, "
                f"shape={self.shape}, system={self.system!r})")

    # ------------------------------------------------------------------
    # Label tables
    # ------------------------------------------------------------------
    @staticmethod
    def _lookup_key(value):
        """
        Canonical dictionary key for a region index.

        Integer-valued inputs of any type (int, float, numpy scalar,
        numeric string) collapse to the same string, e.g. ``"2001"``;
        anything else falls back to ``str(value)``.
        """
        try:
            as_float = float(value)
        except (TypeError, ValueError):
            return str(value)
        # nan/inf report is_integer() == False and fall through to str().
        if as_float.is_integer():
            return str(int(as_float))
        return str(value)

    @classmethod
    def _build_lookup_tables(cls, labels, index):
        """
        Build ``(index_key -> label, label -> original index)`` dicts.

        Returns two empty dicts when no usable mapping exists (no labels,
        or sequence labels without an `index` of matching length).
        """
        index2label, label2index = {}, {}
        if labels is None:
            return index2label, label2index

        if isinstance(labels, dict):
            for idx, label in labels.items():
                index2label[cls._lookup_key(idx)] = label
                label2index[label] = idx
            return index2label, label2index

        if index is None:
            return index2label, label2index
        try:
            paired = len(index) == len(labels)
        except TypeError:
            paired = False
        if not paired:
            return index2label, label2index

        for idx, label in zip(index, labels):
            # setdefault: the first occurrence wins, like list.index().
            index2label.setdefault(cls._lookup_key(idx), label)
            label2index.setdefault(label, idx)
        return index2label, label2index

    def _get_region_name(self, value):
        """
        Return the raw label for the region index `value`.

        Returns
        -------
        str
            The label (region name), or 'Unknown' if not found.
        """
        return self._index2label.get(self._lookup_key(value), "Unknown")

    def get_region_name(self, value):
        """
        Return the region name for a given index value in the volume,
        or 'Unknown' if the index has no label.
        """
        # TODO: clean up the raw label; for now it is returned unchanged.
        return self._get_region_name(value)

    def get_region_index(self, label):
        """
        Return the region index corresponding to the region name `label`.

        Returns
        -------
        object
            The index exactly as it was supplied in `labels`/`index`
            (its original type is preserved), or 'Unknown' if not found.
        """
        return self._label2index.get(label, "Unknown")

    def _get_hemisphere(self, region):
        """
        Return 'L' or 'R' for a region name (or index) whose name ends in
        '_l'/'_r' (case-insensitive); None otherwise.
        """
        if not isinstance(region, str):
            region = self.get_region_name(region)
        if not isinstance(region, str) or region == "Unknown":
            return None
        region_lower = region.lower()
        if region_lower.endswith('_l'):
            return 'L'
        if region_lower.endswith('_r'):
            return 'R'
        return None

    def get_list_of_regions(self):
        """
        Return a list of all unique region names in the atlas, in the
        order they appear in `labels`. Empty list if there are no labels.
        """
        if self.labels is None:
            return []
        names = self.labels.values() if isinstance(self.labels, dict) else self.labels
        # dict.fromkeys de-duplicates while preserving first-seen order.
        return list(dict.fromkeys(str(name) for name in names))

    # ------------------------------------------------------------------
    # World coordinate -> voxel / region
    # ------------------------------------------------------------------
    def pos_to_source(self, pos):
        """
        Return the voxel indices (i, j, k) nearest to a world coordinate,
        using the inverse of `hdr`. The result may lie outside the volume.

        Raises
        ------
        ValueError
            If `pos` is not a 3-element coordinate.
        """
        pos = np.asarray(pos, dtype=float)
        if pos.shape != (3,):
            raise ValueError("pos must be a 3-element coordinate (x,y,z).")
        ijk = np.linalg.inv(self.hdr) @ np.append(pos, 1.0)
        return tuple(int(v) for v in np.round(ijk[:3]))

    def pos_to_index(self, pos):
        """
        Return the region index (int) stored at a world coordinate, or
        'Unknown' when the coordinate falls outside the volume.
        """
        ijk = self.pos_to_source(pos)
        # Explicit bounds check: negative indices would otherwise wrap.
        if any(i < 0 or i >= s for i, s in zip(ijk, self.shape)):
            return "Unknown"
        return int(self.vol[ijk])

    def pos_to_region(self, pos):
        """
        Return the region name at a world coordinate, or 'Unknown'.
        """
        index = self.pos_to_index(pos)
        if isinstance(index, str) and index == "Unknown":
            return "Unknown"
        return self.get_region_name(index)

    def label_at_coordinate(self, pos):
        """
        Return the region name at a world coordinate, or None when the
        coordinate is outside the volume or its index has no label.
        """
        name = self.pos_to_region(pos)
        if isinstance(name, str) and name == "Unknown":
            return None
        return name

    # ------------------------------------------------------------------
    # Voxel / region -> world coordinate
    # ------------------------------------------------------------------
    def source_to_pos(self, source):
        """
        Return world coordinates for voxel indices using `hdr`.

        Parameters
        ----------
        source : array-like, shape (3,) or (N, 3)

        Returns
        -------
        np.ndarray
            Shape (3,) for a single voxel, (N, 3) for N > 1 voxels.
        """
        voxels = np.atleast_2d(np.asarray(source, dtype=float))
        if voxels.shape[1] != 3:
            raise ValueError("source must have shape (3,) or (N, 3).")
        homogeneous = np.hstack([voxels, np.ones((voxels.shape[0], 1))])
        transformed = homogeneous @ self.hdr.T
        xyz = transformed[:, :3] / transformed[:, 3, np.newaxis]
        return xyz if len(voxels) > 1 else xyz[0]

    def index_to_pos(self, index):
        """
        Return the world coordinates of every voxel equal to `index`.

        Returns
        -------
        np.ndarray, shape (N, 3)
            Always 2-D; shape (0, 3) when no voxel carries that index.
        """
        index = int(index)
        voxels = np.argwhere(self.vol == index)  # shape (N, 3)
        if voxels.size == 0:
            return np.empty((0, 3))
        # atleast_2d keeps the (N, 3) contract for one-voxel regions.
        return np.atleast_2d(self.source_to_pos(voxels))

    def region_to_pos(self, region):
        """
        Return the world coordinates of every voxel of a named region,
        shape (N, 3); shape (0, 3) when the name is unknown.
        """
        index = self.get_region_index(region)
        if isinstance(index, str) and index == "Unknown":
            return np.empty((0, 3))
        return self.index_to_pos(index)


In [5]:
from sourcelocalizer.load_atlas_data import fetch_atlas
from sourcelocalizer import Atlas
# Fetch the raw AAL atlas; the result is indexed below by the keys
# 'vol', 'hdr' and 'labels'.
atlas = fetch_atlas('aal')
# Rebind `atlas` from the fetched mapping to an Atlas instance built from it.
# NOTE: the right-hand side is evaluated first, so reading atlas['vol'] etc.
# still refers to the fetched mapping.
atlas = Atlas(name='aal', vol=atlas['vol'], hdr=atlas['hdr'], labels=atlas['labels'], system='mni')

[get_dataset_dir] Dataset found in /Users/hamzaabdelhedi/Projects/olfaction_local/aal_SPM12


In [3]:
import xml.etree.ElementTree as ET
# Hard-coded, machine-specific location of the AAL label file.
path_xml = "/Users/hamzaabdelhedi/nilearn_data/aal_SPM12/aal/ROI_MNI_V4.xml"
# Parse the XML file
tree = ET.parse(path_xml)  # Parse the label file at `path_xml`
root = tree.getroot()  # Get the root element

# Print the root tag
print(root.tag)

# Collect (index, name) pairs from every <label> under the <data> section.
# Both values are taken from `.text`, so the index stays a string here.
regions = []

for label in root.find("data").findall("label"):
    index = label.find("index").text
    name = label.find("name").text
    regions.append((index, name))

# Print first 5 brain regions
for idx, name in regions[:5]:
    print(f"Index: {idx}, Name: {name}")

# Optional: convert to a dictionary {index (str): region name}
region_dict = {idx: name for idx, name in regions}

# Optional: save the table as CSV in the current working directory
import pandas as pd
df = pd.DataFrame(regions, columns=["Index", "Region Name"])
df.to_csv("aal_brain_regions.csv", index=False)

atlas
Index: 2001, Name: Precentral_L
Index: 2002, Name: Precentral_R
Index: 2101, Name: Frontal_Sup_L
Index: 2102, Name: Frontal_Sup_R
Index: 2111, Name: Frontal_Sup_Orb_L


In [8]:
# Smoke test of the Atlas API on the AAL atlas; expects `atlas` to be the
# Atlas instance built in an earlier cell.

# Index <-> name round trip.
print(atlas.get_region_name(2001))
print(atlas.get_region_index('Precentral_L'))
# Hemisphere lookup accepts either a region index or a region name.
print(atlas._get_hemisphere(2001))
print(atlas._get_hemisphere('Precentral_R'))
# World coordinate -> voxel indices -> region index -> region name.
print(atlas.pos_to_source([-14.7, -31.2, 45.45]))
print(atlas.pos_to_index([-14.7, -31.2, 45.45]))
print(atlas.pos_to_region([-14.7, -31.2, 45.45]))
# Voxel indices -> world coordinate (centre of that voxel, so it need not
# equal the coordinate queried above).
print(atlas.source_to_pos([52, 47, 59]))
# All world coordinates of one region, addressed by index and by name.
print(atlas.index_to_pos(2001))
print(atlas.region_to_pos('Precentral_L'))

Precentral_L
2001
L
R
(52, 47, 59)
4011
Cingulum_Mid_L
[-14. -32.  46.]
[[-14. -20.  70.]
 [-14. -18.  68.]
 [-14. -18.  70.]
 ...
 [-62.   8.  30.]
 [-62.   8.  32.]
 [-62.   8.  34.]]
[[-14. -20.  70.]
 [-14. -18.  68.]
 [-14. -18.  70.]
 ...
 [-62.   8.  30.]
 [-62.   8.  32.]
 [-62.   8.  34.]]


In [121]:
import nibabel as nib
import numpy as np
import xml.etree.ElementTree as ET
from scipy.spatial import distance

# ---- STEP 1: Load AAL Atlas ---- #
# Hard-coded, machine-specific path to the AAL NIfTI volume.
nii_path = "/Users/hamzaabdelhedi/nilearn_data/aal_SPM12/aal/ROI_MNI_V4.nii"
nii_img = nib.load(nii_path)
atlas_data = nii_img.get_fdata()  # Voxel array (module-level; read by find_closest_region)
affine = nii_img.affine  # Voxel -> world transformation matrix (also read by find_closest_region)

# ---- STEP 2: Load AAL Labels from XML ---- #
xml_path = "/Users/hamzaabdelhedi/nilearn_data/aal_SPM12/aal/ROI_MNI_V4.xml"
tree = ET.parse(xml_path)
root = tree.getroot()

# Create a dictionary: {index (int) -> region name}
region_dict = {}
for label in root.find("data").findall("label"):
    index = int(label.find("index").text)  # Convert index to int so it matches int(voxel value)
    name = label.find("name").text
    region_dict[index] = name

# ---- STEP 3: Convert MNI Coordinates to Region Names ---- #
def find_closest_region(mni_coord, search_radius=5):
    """
    Convert an MNI coordinate to an AAL region name, falling back to the
    nearest labeled voxel when the voxel under the coordinate is unlabeled.

    Reads the module-level `affine`, `atlas_data` and `region_dict`.

    :param mni_coord: (X, Y, Z) MNI coordinate
    :param search_radius: Maximum distance (in mm) accepted for the nearest
        labeled voxel. ``None`` disables the limit.
    :return: (Region Name, Distance in mm).
        - ("Out of bounds", None) if the coordinate maps outside the volume.
        - (name, 0) if the voxel under the coordinate has a known label.
        - (name, distance) for the nearest labeled voxel within the radius.
        - ("Unknown Region", None) if no labeled voxel lies within the radius.
    """
    mni_coord = np.asarray(mni_coord, dtype=float)

    # Convert MNI to voxel coordinates (nearest voxel centre)
    voxel_coord = np.round(
        nib.affines.apply_affine(np.linalg.inv(affine), mni_coord)
    ).astype(int)

    # Ensure voxel is inside image bounds (negative indices would wrap)
    if not (0 <= voxel_coord[0] < atlas_data.shape[0] and
            0 <= voxel_coord[1] < atlas_data.shape[1] and
            0 <= voxel_coord[2] < atlas_data.shape[2]):
        return "Out of bounds", None

    # Check if voxel has a direct region match
    region_index = int(atlas_data[tuple(voxel_coord)])
    if region_index in region_dict:
        return region_dict[region_index], 0  # Distance = 0 if exact match

    # Search for the nearest labeled (non-zero) voxel
    labeled_voxels = np.argwhere(atlas_data > 0)  # shape (N, 3)
    if labeled_voxels.size == 0:
        return "Unknown Region", None

    # Distances are measured in world (mm) space, not in voxel units
    labeled_mni_coords = nib.affines.apply_affine(affine, labeled_voxels)
    distances = distance.cdist([mni_coord], labeled_mni_coords)[0]

    min_index = int(np.argmin(distances))
    min_distance = float(distances[min_index])

    # Honour the documented search radius instead of accepting any distance
    if search_radius is not None and min_distance > search_radius:
        return "Unknown Region", None

    closest_index = int(atlas_data[tuple(labeled_voxels[min_index])])
    closest_region = region_dict.get(closest_index, "Unknown Region")

    return closest_region, round(min_distance, 2)

# ---- STEP 5: Example Usage ---- #
# Example MNI coordinates (x, y, z); also reused by a later cell via example_coords[0]
example_coords = [(-14.7, -31.2, 45.45), (-15.50, -31.45, 49.30), (-17.10, -31.75, 56.95)]  # Replace with actual coordinates

# Look up each coordinate with the default search_radius and report it.
for coord in example_coords:
    region, dist = find_closest_region(coord)
    print(f"MNI: {coord} -> Closest Region: {region} (Distance: {dist} mm)")


MNI: (-14.7, -31.2, 45.45) -> Closest Region: Cingulum_Mid_L (Distance: 0 mm)
MNI: (-15.5, -31.45, 49.3) -> Closest Region: Cingulum_Mid_L (Distance: 1.74 mm)
MNI: (-17.1, -31.75, 56.95) -> Closest Region: Postcentral_L (Distance: 3.06 mm)


In [124]:
from nilearn import datasets
from nilearn.image import load_img
import nibabel as nib
import numpy as np

# Step 1: Load the atlas and labels
# NOTE: this cell rebinds the notebook globals `atlas`, `affine` and
# `atlas_data` that earlier cells (and find_closest_region) rely on.
atlas = datasets.fetch_atlas_harvard_oxford('cort-maxprob-thr25-1mm')
atlas_img = load_img(atlas.maps)  # This is the atlas image
atlas_labels = atlas.labels       # This is a list of region labels

# Step 2: Pick the world coordinate to look up ([x, y, z] in MNI space)
voxel_coords = np.array(example_coords[0])  # Replace with your coordinates

# Step 3: Map the world coordinate to voxel indices
# Apply the inverse affine, then round to the NEAREST voxel. A bare
# astype(int) truncates toward zero and can select the neighbouring voxel.
affine = atlas_img.affine
voxel_indices = np.round(
    nib.affines.apply_affine(np.linalg.inv(affine), voxel_coords)
).astype(int)

# Step 4: Find the atlas region corresponding to the voxel
atlas_data = atlas_img.get_fdata()
# Out-of-volume coordinates are treated as unlabeled (0); without this check
# a negative index would silently wrap to the opposite side of the volume.
in_bounds = all(0 <= i < s for i, s in zip(voxel_indices, atlas_data.shape))
region_index = atlas_data[tuple(voxel_indices)] if in_bounds else 0

# Step 5: Retrieve the label of the region
if region_index > 0:  # Check if the voxel belongs to a labeled region
    region_name = atlas_labels[int(region_index)]
    print(f"The voxel {voxel_coords} belongs to the region: {region_name}")
else:
    print(f"The voxel {voxel_coords} is not assigned to any region in the atlas.")


[get_dataset_dir] Dataset found in /Users/hamzaabdelhedi/nilearn_data/fsl
The voxel [-14.7  -31.2   45.45] belongs to the region: Precentral Gyrus


In [34]:


###############################################################################
# Coordinate transform stubs (MNI <-> TAL, etc.)
# You can expand or replace as needed.
###############################################################################
def mni_to_tal(xyz):
    """
    Convert MNI coordinates to approximate Talairach coordinates using
    Matthew Brett's piecewise-linear ``mni2tal`` transform.

    Points at or above the AC plane (z >= 0) and points below it (z < 0)
    use different matrices:

        z >= 0:  x' = 0.9900x
                 y' = 0.9688y + 0.0460z
                 z' = -0.0485y + 0.9189z
        z <  0:  x' = 0.9900x
                 y' = 0.9688y + 0.0420z
                 z' = -0.0485y + 0.8390z

    Parameters
    ----------
    xyz : array-like of shape (N, 3) or (3,)
        MNI coordinates. If shape is (3,), it is treated as a single point.

    Returns
    -------
    tal_xyz : np.ndarray (float64) of the same shape
        Approximated Talairach coordinates.

    Raises
    ------
    ValueError
        If `xyz` is neither a single 3-vector nor an (N, 3) array.
    """
    # float64: float32 would discard precision from the inputs.
    xyz = np.asarray(xyz, dtype=np.float64)

    single_input = xyz.ndim == 1 and xyz.size == 3
    points = xyz.reshape(1, 3) if single_input else xyz
    if points.ndim != 2 or points.shape[1] != 3:
        raise ValueError("xyz must have shape (3,) or (N, 3).")

    # Rows map (x, y, z) -> (x', y', z'); the transform has no translation.
    above_ac = np.array([[0.9900, 0.0, 0.0],
                         [0.0, 0.9688, 0.0460],
                         [0.0, -0.0485, 0.9189]])
    below_ac = np.array([[0.9900, 0.0, 0.0],
                         [0.0, 0.9688, 0.0420],
                         [0.0, -0.0485, 0.8390]])

    is_above = points[:, 2] >= 0
    tal_xyz = np.where(is_above[:, np.newaxis],
                       points @ above_ac.T,
                       points @ below_ac.T)

    return tal_xyz[0] if single_input else tal_xyz

In [40]:
class SourceLocalizer:
    """
    Locates (x, y, z) source positions in one or more atlas volumes.

    Parameters
    ----------
    atlases : Atlas, str, or list of Atlas/str
        Atlas objects are used as-is. A string is passed to
        `load_volume_from_file`, which must return
        ``(vol, hdr, labels, index, system)``, and wrapped in an Atlas.
    offset : int or float, default 0
        Stored as the ``offset`` attribute of every Atlas that this class
        builds from a string. Atlas objects passed in are left untouched.

    Raises
    ------
    TypeError
        If an element of `atlases` is neither an Atlas nor a str.
    """

    # Columns that identify a source; every other column is an atlas label.
    _KEY_COLUMNS = ('Text', 'X', 'Y', 'Z', 'hemisphere')

    def __init__(self, atlases, offset=0):
        import os

        self._atlas_list = []
        self._analysis = None  # Will store last result

        # If user passes a single string or single Atlas, unify to a list
        if isinstance(atlases, (str, Atlas)):
            atlases = [atlases]

        for atlas in atlases:
            if isinstance(atlas, Atlas):
                self._atlas_list.append(atlas)
            elif isinstance(atlas, str):
                vol, hdr, labels, index, system = load_volume_from_file(atlas)
                # Strip directories first so "dir.v2/aal.nii" -> "aal".
                name = os.path.basename(atlas).split('.')[0]
                av = Atlas(
                    name=name,
                    vol=vol,
                    hdr=hdr,
                    labels=labels,
                    index=index,
                    system=system
                )
                # Set as an attribute rather than a constructor keyword so
                # this works whether or not Atlas.__init__ accepts `offset`.
                av.offset = offset
                self._atlas_list.append(av)
            else:
                raise TypeError(f"Invalid atlas type: {type(atlas)}")

    @staticmethod
    def _lookup_label(atlas_vol, point):
        """
        Return the label of `point` in `atlas_vol`, or None if not found.

        Uses ``label_at_coordinate`` when the atlas provides it; otherwise
        falls back to ``pos_to_region`` and maps its "Unknown" sentinel
        to None.
        """
        lookup = getattr(atlas_vol, "label_at_coordinate", None)
        if lookup is not None:
            return lookup(point)
        label = atlas_vol.pos_to_region(point)
        if isinstance(label, str) and label == "Unknown":
            return None
        return label

    @classmethod
    def _atlas_columns(cls, df):
        """Return the columns of `df` that hold atlas labels."""
        return [c for c in df.columns if c not in cls._KEY_COLUMNS]

    @staticmethod
    def _is_bad(value, bad_patterns):
        """True if `value` is (or equals) one of `bad_patterns`."""
        return any(value is bp or value == bp for bp in bad_patterns)

    def localize_sources(self,
                         xyz,
                         source_names=None,
                         replace_bad=True,
                         bad_patterns=None,
                         replace_with='Not found',
                         distance=None,
                         keep_only=None):
        """
        Localize each source in xyz in every atlas of this localizer.

        Parameters
        ----------
        xyz : array-like, shape (N, 3)
            Source coordinates.
        source_names : list of str, optional
            If provided, must match len(xyz). Used as labels in the output.
            Defaults to "s0", "s1", ...
        replace_bad : bool
            Whether to replace bad pattern values with `replace_with`.
            Only atlas label columns are touched, never the coordinates.
        bad_patterns : list
            Values considered 'bad'. Defaults to
            [-1, None, 'undefined', 'None'].
        replace_with : str
            Replace any 'bad' values with this string.
        distance : float or None
            If provided, rows labeled `replace_with` take the label of the
            nearest labeled source within this distance.
        keep_only : str, list of str or None
            If provided, keep only rows where at least one atlas column
            equals one of these labels.

        Returns
        -------
        df : pd.DataFrame
            One row per source coordinate, in input order, with columns
            Text, X, Y, Z, hemisphere plus one label column per atlas
            (named after the atlas; later duplicates get a numeric suffix).

        Raises
        ------
        ValueError
            If `xyz` is not (N, 3) or `source_names` has the wrong length.
        """
        import pandas as pd

        xyz = np.asarray(xyz)
        if xyz.ndim != 2 or xyz.shape[1] != 3:
            raise ValueError("xyz must be a (N, 3) array of coordinates.")
        n_sources = xyz.shape[0]

        if source_names is None:
            source_names = [f"s{i}" for i in range(n_sources)]
        elif len(source_names) != n_sources:
            raise ValueError("source_names length must match xyz rows.")

        if bad_patterns is None:
            bad_patterns = [-1, None, 'undefined', 'None']

        df_merged = pd.DataFrame({
            'Text': list(source_names),
            'X': xyz[:, 0],
            'Y': xyz[:, 1],
            'Z': xyz[:, 2],
            'hemisphere': np.where(xyz[:, 0] > 0, 'Right', 'Left'),
        })

        # Every atlas is queried with the same rows, so its labels are added
        # as a new column. A key-based outer merge would re-sort the rows
        # and duplicate sources that share a name and coordinate.
        for atlas_vol in self._atlas_list:
            colname = base = atlas_vol.name  # e.g. "brodmann" or "myAtlas"
            suffix = 2
            while colname in df_merged.columns:
                colname = f"{base}_{suffix}"
                suffix += 1
            found = (self._lookup_label(atlas_vol, point) for point in xyz)
            df_merged[colname] = [
                label if label is not None else 'Not found' for label in found
            ]

        # Replace bad patterns if requested (label columns only, so that a
        # coordinate equal to e.g. -1 is never overwritten).
        if replace_bad:
            for c in self._atlas_columns(df_merged):
                df_merged[c] = [
                    replace_with if self._is_bad(v, bad_patterns) else v
                    for v in df_merged[c]
                ]

        # Distance-based fix if requested
        if distance is not None:
            self._distance_fix(df_merged, distance, replace_with)

        # Keep only certain labels if requested
        if keep_only is not None:
            df_merged = self._filter_dataframe(df_merged, keep_only)

        self._analysis = df_merged
        return df_merged

    @staticmethod
    def _distance_fix(df, distance, replace_with):
        """
        Reassign rows labeled `replace_with` to the label of the nearest
        labeled row within `distance`. Works per atlas column and modifies
        `df` in place.

        Only rows that carried a label before the fix are used as donors,
        so the result does not depend on row order.
        """
        from scipy.spatial.distance import cdist

        if len(df) == 0:
            return
        coords = df[['X', 'Y', 'Z']].to_numpy(dtype=float)
        dist_matrix = cdist(coords, coords, metric='euclidean')

        key_cols = {'Text', 'X', 'Y', 'Z', 'hemisphere'}
        atlas_cols = [c for c in df.columns if c not in key_cols]

        for c in atlas_cols:
            # Explicit copy: the array behind a column may be read-only.
            col_data = df[c].to_numpy(dtype=object, copy=True)
            missing = np.array([v == replace_with for v in col_data], dtype=bool)
            labeled = ~missing
            if not missing.any() or not labeled.any():
                continue
            for i in np.flatnonzero(missing):
                row_dists = dist_matrix[i]
                candidates = np.flatnonzero(labeled & (row_dists <= distance))
                if candidates.size:
                    # pick the closest labeled source
                    col_data[i] = col_data[candidates[row_dists[candidates].argmin()]]
            df[c] = col_data

    @staticmethod
    def _filter_dataframe(df, keep_only):
        """
        Keep only rows where at least one atlas label matches any value
        in keep_only (a single string is treated as a one-element list).
        """
        if isinstance(keep_only, str):
            keep_only = [keep_only]
        keep_only = list(keep_only)

        key_cols = {'Text', 'X', 'Y', 'Z', 'hemisphere'}
        atlas_cols = [c for c in df.columns if c not in key_cols]

        mask = np.zeros(len(df), dtype=bool)
        for c in atlas_cols:
            mask |= df[c].isin(keep_only).to_numpy()

        return df[mask]

    @property
    def analysis(self):
        """Return the last DataFrame produced by localize_sources."""
        return self._analysis


In [41]:
if __name__ == "__main__":
    # Example usage: localize three hand-written coordinates in one atlas.
    xyz_coords = np.array([
        [10.0, -15.0, 35.0],
        [-12.0, 22.0, 48.0],
        [50.0, 50.0, 50.0]  # possibly out of volume => "Not found"
    ])
    # You can pass a single string (path) or multiple:
    #   atlases=['brodmann.nii', 'aal.npz']  # for instance
    # Or you can create an Atlas manually and pass it.
    # NOTE: "my_atlas_file.nii" is a placeholder file name.
    localizer = SourceLocalizer("my_atlas_file.nii")

    # "Not found" is listed in bad_patterns, so rows without a label are
    # rewritten to replace_with ("Unknown") before the 10.0 distance fix runs.
    results_df = localizer.localize_sources(
        xyz=xyz_coords,
        source_names=["S1", "S2", "S3"],
        replace_bad=True,
        bad_patterns=[None, -1, "undefined", "Not found"],
        replace_with="Unknown",
        distance=10.0,
        keep_only=None
    )

    print("\n--- Localized Results ---")
    print(results_df)
    """
    Example output:
         Text     X     Y     Z hemisphere   my_atlas_file
    0       S1  10.0 -15.0  35.0       Right        "Area 4"
    1       S2 -12.0  22.0  48.0        Left        "Area 9"
    2       S3  50.0  50.0  50.0       Right       "Unknown"
    """



--- Localized Results ---
  Text     X     Y     Z hemisphere my_atlas_file
0   S1  10.0 -15.0  35.0      Right       Unknown
1   S2 -12.0  22.0  48.0       Left       Unknown
2   S3  50.0  50.0  50.0      Right       Unknown


In [48]:
# metadata_example.py (an illustrative example, not real data)

# Placeholder metadata table keyed by template name. Every field is None
# here and must be filled in before the lookup helpers can use it.
label4mri_metadata = {
    "aal": {
        "coordinate_list": None,   # <-- To be filled: NumPy array or nested list, shape (3, N); rows are x, y, z
        "coordinate_label": None,  # <-- To be filled: list or array, shape (N,); region index of each coordinate column
        "label": None              # <-- To be filled: pandas DataFrame or list of dicts (rows are read via "Region_index" / "Region_name")
    },
    "ba": {
        "coordinate_list": None,   # same layout as "aal"
        "coordinate_label": None,
        "label": None
    }
}


In [49]:
# helpers.py (or put them inside mni_to_region_name.py, etc.)

import numpy as np
# from .metadata_example import label4mri_metadata

def _mni_to_region_index(x, y, z, distance=True, template=None):
    """
    Python equivalent of mni_to_region_index in R.

    Finds the region index for a given MNI coordinate (x, y, z) in
    ``label4mri_metadata[template]``. The coordinate is rounded with
    Python's built-in ``round`` before matching.

    Parameters
    ----------
    x, y, z : number
        MNI coordinate.
    distance : bool, default True
        If True and no exact match exists, fall back to the nearest
        stored coordinate and report its Euclidean distance.
    template : str
        Key of the template in `label4mri_metadata`.

    Returns
    -------
    (region_index, region_distance)
        - exact match: (index, 0) or (index, None) when `distance` is False
        - nearest match: (index, distance as float)
        - (None, None) when there is no exact match and `distance` is
          False, or when the coordinate table is empty.

    Raises
    ------
    ValueError
        If the template is unknown, its coordinate data has not been
        loaded, or `coordinate_list` is not shaped (3, N).
    """
    if template not in label4mri_metadata:
        raise ValueError(f"Template '{template}' does not exist in metadata.")

    data = label4mri_metadata[template]
    labels = data["coordinate_label"]  # shape (N,)
    if data["coordinate_list"] is None or labels is None:
        raise ValueError(f"Template '{template}' has no coordinate data loaded.")

    # asarray lets nested lists work as well as NumPy arrays.
    coords = np.asarray(data["coordinate_list"])  # shape (3, N)
    if coords.ndim != 2 or coords.shape[0] != 3:
        raise ValueError(
            f"coordinate_list of template '{template}' must have shape (3, N)."
        )

    # 1) Round the MNI coordinate as in R
    x, y, z = round(x), round(y), round(z)

    # 2) Columns whose x, y and z all match exactly
    exact = np.flatnonzero(
        (coords[0] == x) & (coords[1] == y) & (coords[2] == z)
    )
    if exact.size:
        # take first match if multiple
        return labels[int(exact[0])], (0 if distance else None)

    if not distance or coords.shape[1] == 0:
        return None, None

    # 3) No exact match: nearest stored coordinate (squared distances first)
    diff = coords - np.array([[x], [y], [z]])
    sqdist = np.sum(diff ** 2, axis=0)  # shape (N,)
    nearest = int(np.argmin(sqdist))
    return labels[nearest], float(np.sqrt(sqdist[nearest]))



def _region_index_to_mni(region_index, template=None):
    """
    Python equivalent of region_index_to_mni in R.

    Returns all coordinates of ``label4mri_metadata[template]`` whose
    coordinate label equals `region_index`.

    Returns
    -------
    list of dict
        One ``{"x": int, "y": int, "z": int}`` per matching coordinate;
        empty if nothing matches.

    Raises
    ------
    ValueError
        If the template is unknown or its coordinate data is not loaded.
    """
    if template not in label4mri_metadata:
        raise ValueError(f"Template '{template}' does not exist in metadata.")

    data = label4mri_metadata[template]
    if data["coordinate_list"] is None or data["coordinate_label"] is None:
        raise ValueError(f"Template '{template}' has no coordinate data loaded.")

    # asarray: comparing a plain list with `==` yields a single bool,
    # which would silently match nothing.
    coords = np.asarray(data["coordinate_list"])   # shape (3, N)
    labels = np.asarray(data["coordinate_label"])  # shape (N,)

    match_cols = np.flatnonzero(labels == region_index)
    matched_coords = coords[:, match_cols]  # shape (3, k)

    return [
        {"x": int(cx), "y": int(cy), "z": int(cz)}
        for cx, cy, cz in matched_coords.T
    ]


In [52]:
# mni_to_region_name.py

# from .helpers import _mni_to_region_index
# from .metadata_example import label4mri_metadata

def mni_to_region_name(x, y, z, distance=True, template=["aal","ba"]):
    """
    Python equivalent of mni_to_region_name in R.
    """
    # Validate template
    not_found = [t for t in template if t not in label4mri_metadata]
    if len(not_found) > 0:
        raise ValueError(f"Template(s) {not_found} do not exist in the metadata.")

    # For each template, find region index + distance, then map index -> label
    output = {}
    for t in template:
        region_index, dist_val = _mni_to_region_index(x, y, z, distance=distance, template=t)

        # Now find the region_name from region_index
        if region_index is not None:
            # look up region_name from label
            label_df = label4mri_metadata[t]["label"]
            # get region name
            region_name = "NULL"
            for row in label_df:
                if row["Region_index"] == region_index:
                    region_name = row["Region_name"]
                    break
        else:
            region_name = "NULL"

        # Build keys that mimic the R structure: e.g. aal.distance, aal.label
        output[f"{t}.distance"] = dist_val if dist_val is not None else "NULL"
        output[f"{t}.label"]    = region_name

    return output


In [53]:
# region_name_to_mni.py

# from .helpers import _region_name_to_index, _region_index_to_mni
# from .metadata_example import label4mri_metadata

def region_name_to_mni(region_names, template="aal"):
    """
    Python equivalent of region_name_to_mni in R.

    region_names: a single region name or a list of region names
    template: "aal" or "ba"

    Returns a dict keyed "<template>.<region name>" whose values are the
    coordinate lists produced by `_region_index_to_mni`.
    Raises ValueError for an unknown template or an unknown region name.
    """
    # A lone string is treated as a one-element list.
    names = [region_names] if isinstance(region_names, str) else region_names

    if template not in label4mri_metadata:
        raise ValueError(f"Template '{template}' does not exist in the metadata.")

    coords_by_region = {}
    for region in names:
        # Resolve name -> index. NOTE(review): `_region_name_to_index` is
        # not defined alongside the other helpers shown here; confirm it
        # is provided elsewhere.
        region_idx = _region_name_to_index(region, template=template)
        if region_idx is None:
            raise ValueError(f"Region '{region}' does not exist in template '{template}'.")
        # Resolve index -> all MNI coordinates of that region.
        coords_by_region[f"{template}.{region}"] = _region_index_to_mni(
            region_idx, template=template
        )

    return coords_by_region


In [54]:
# show_cluster_composition.py

import numpy as np
from collections import Counter
# from .mni_to_region_name import mni_to_region_name
# from .metadata_example import label4mri_metadata

def show_cluster_composition(coordinate_matrix, template=("aal", "ba")):
    """
    Python equivalent of show_cluster_composition in R.

    Parameters
    ----------
    coordinate_matrix : array-like, shape (3, N)
        x in row 0, y in row 1, z in row 2.
    template : str or iterable of str, default ("aal", "ba")
        Template name(s) to summarise the cluster in.

    Returns
    -------
    dict
        For every template ``t``, key ``"{t}.cluster.composition"`` holds a
        list of rows ``{"region", "Number of coordinates",
        "Percentage (%)"}`` ordered from most to least frequent region.
        Coordinates without an exact match are counted as "NULL".

    Raises
    ------
    ValueError
        If a template is missing from `label4mri_metadata` or
        `coordinate_matrix` is not shaped (3, N).
    """
    # A bare string would otherwise be iterated character by character.
    templates = [template] if isinstance(template, str) else list(template)

    # Validate template
    not_found = [t for t in templates if t not in label4mri_metadata]
    if len(not_found) > 0:
        raise ValueError(f"Template(s) {not_found} do not exist in the metadata.")

    # coordinate_matrix can be a list of lists, so ensure it's an array;
    # require 2-D so a flat [x, y, z] gets the documented ValueError.
    coords = np.asarray(coordinate_matrix)
    if coords.ndim != 2 or coords.shape[0] != 3:
        raise ValueError("coordinate_matrix must have shape (3, N).")

    n_coords = coords.shape[1]

    results = {}
    for t in templates:
        # Exact-match label of every coordinate (distance=False => no
        # nearest-neighbour fallback; misses come back as "NULL").
        freq_counter = Counter(
            mni_to_region_name(px, py, pz, distance=False, template=[t])[f"{t}.label"]
            for px, py, pz in coords.T
        )

        # Each row = (region_label, count, percentage)
        results[f"{t}.cluster.composition"] = [
            {
                "region": region_label,
                "Number of coordinates": count,
                "Percentage (%)": round((count / n_coords) * 100, 1),
            }
            for region_label, count in freq_counter.most_common()
        ]

    return results
