In [46]:
import nibabel as nib
import numpy as np
import pandas as pd
import os

class Atlas:
    """A labelled atlas: a NIfTI label volume paired with a CSV key.

    The key is read with ``pd.read_csv`` and must contain a ``name`` column
    (region names) and a ``value`` column (the voxel value that marks each
    region in the volume).
    """

    def __init__(self, filepath, csv_key_path=None):
        """Load an atlas volume and its CSV key.

        Args:
            filepath (str): Path to the atlas ``.nii``/``.nii.gz`` file or to its
                ``.csv`` key. Only the root name (extension stripped) is used to
                locate the volume: ``<root>.nii.gz`` is tried first, then
                ``<root>.nii``.
            csv_key_path (str, optional): Path to the CSV key. Defaults to
                ``<root>.csv``.

        Raises:
            FileNotFoundError: If neither volume file, or the CSV key, exists.
        """
        # Strip the extension; ".nii.gz" needs a second pass to drop ".nii".
        root_filepath = os.path.splitext(filepath)[0]
        if root_filepath.endswith('.nii'):
            root_filepath = os.path.splitext(root_filepath)[0]

        if not csv_key_path:
            csv_key_path = root_filepath + ".csv"

        # Prefer the compressed volume, fall back to the uncompressed one.
        nii_file_path = root_filepath + ".nii.gz"
        if not os.path.exists(nii_file_path):
            nii_file_path = root_filepath + ".nii"
            if not os.path.exists(nii_file_path):
                raise FileNotFoundError(f"Could not find a .nii or .nii.gz file for {root_filepath}.")

        if not os.path.exists(csv_key_path):
            raise FileNotFoundError(f"Could not find the csv file {csv_key_path}.")

        self.csv_key_path = csv_key_path
        self.key = pd.read_csv(self.csv_key_path)
        self.atlas = nib.load(nii_file_path)
        self.labels = self.key['name'].tolist()
        # Cache of (key DataFrame, {value: [names]}) used by name_at_index; built lazily.
        self._names_by_value = None

    def _value_to_names(self):
        """Return a dict mapping each key ``value`` to its region names, in key-row order.

        Built once and reused; rebuilt if ``self.key`` is replaced by another object.
        """
        cached = getattr(self, '_names_by_value', None)
        if cached is None or cached[0] is not self.key:
            mapping = {}
            for value, name in zip(self.key['value'], self.key['name']):
                mapping.setdefault(value, []).append(name)
            cached = (self.key, mapping)
            self._names_by_value = cached
        return cached[1]

    def name_at_index(self, index=(48, 94, 35)):
        """Return the region name(s) stored at a voxel of the atlas.

        Args:
            index (sequence of int): x, y, z coordinates of the voxel (voxel space).

        Returns:
            The region name when exactly one key row has the voxel's value, a list
            of names when several rows do, or ``"No matching region found"`` when
            none does.
        """
        value_at_index = self.atlas.get_fdata()[tuple(index)]
        # Dict lookup instead of filtering the key DataFrame on every call; numeric
        # keys match the float voxel value the same way ``==`` does (3 == 3.0).
        names = self._value_to_names().get(value_at_index, [])
        if len(names) == 1:
            return names[0]
        if names:
            # Return a copy so callers cannot corrupt the cached lookup.
            return list(names)
        return "No matching region found"

class AtlasLabeler:
    """Label the voxels of a nifti volume with region names from an Atlas."""

    def __init__(self, nifti, atlas="atlases/HarvardOxford-cort-maxprob-thr0-2mm.nii.gz", min_threshold=1, max_threshold=1):
        """Init function for AtlasLabeler class.

        Args:
            nifti (str | nib.nifti1.Nifti1Image): Nifti image, or path to the nifti
                volume to be labeled.
            atlas (Atlas | str): An Atlas object, or a path accepted by ``Atlas``.
            min_threshold (int): Minimum voxel value (inclusive) to be labeled.
            max_threshold (int): Maximum voxel value (inclusive) to be labeled.

        If the atlas or the nifti cannot be loaded, a message is printed and the
        constructor returns early, leaving the instance only partially initialised.
        """
        if not isinstance(atlas, Atlas):
            print(f"Looking for atlas {atlas}...")
            try:
                atlas = Atlas(atlas)
            except Exception as err:
                print(f"""Could not find atlas {atlas}. Make sure the path is correct and try again. 
                      \n Example: atlases/HarvardOxford-cort-maxprob-thr0-2mm.nii.gz""")
                print(f"Reason: {err}")
                return
        self.atlas = atlas
        self.labels = atlas.labels

        if not isinstance(nifti, nib.nifti1.Nifti1Image):
            print(f"Looking for nifti {nifti}...")
            try:
                nifti = nib.load(nifti)
            except Exception as err:
                print(f"""Could not find nifti {nifti}. Make sure the path is correct and try again. 
                      \n Example: volumes/30lechowiczglogowska2019.nii.gz""")
                print(f"Reason: {err}")
                return
        self.nifti = nifti
        self.volume_data = nifti.get_fdata()
        self.labeled_data = None
        self.voxel_counts = None
        self.unique_labels = None
        self.min_threshold = min_threshold
        self.max_threshold = max_threshold

    def label_volume(self):
        """Label every in-threshold voxel of the volume with its atlas region.

        Sets ``labeled_data`` (DataFrame of voxel ``index`` and ``atlas_label``),
        ``voxel_counts`` (label -> number of voxels) and ``unique_labels``.

        Returns:
            AtlasLabeler: ``self`` for chaining, or ``None`` (after printing a
            message) when the atlas and volume shapes differ.
        """
        atlas_shape = self.atlas.atlas.shape
        nifti_shape = self.nifti.shape
        if atlas_shape != nifti_shape:
            print(f"The shape of the atlas ({atlas_shape}) and the nifti volume ({nifti_shape}) do not match. Please provide a nifti volume with the same shape as the atlas.")
            return

        # Voxels whose value lies in [min_threshold, max_threshold].
        in_range = (self.volume_data >= self.min_threshold) & (self.volume_data <= self.max_threshold)
        voxel_coords = list(zip(*np.nonzero(in_range)))
        labels = [self.atlas.name_at_index(coords) for coords in voxel_coords]

        results_df = pd.DataFrame({'index': voxel_coords, 'atlas_label': labels})
        self.labeled_data = results_df

        # name_at_index returns a list when several key rows share a value; lists
        # are unhashable, so count and de-duplicate those labels as tuples.
        hashable_labels = results_df['atlas_label'].map(
            lambda label: tuple(label) if isinstance(label, list) else label
        )
        self.voxel_counts = hashable_labels.value_counts().to_dict()
        self.unique_labels = hashable_labels.unique().tolist()
        return self

class MultiAtlasLabeler:
    """Wrapper class for labeling a nifti volume with multiple atlases."""

    def __init__(self, nifti_path, atlas_list='atlases/harvoxf_atlas_list.csv'):
        """Load the atlases that ``label_volume`` will apply.

        Args:
            nifti_path (str | nib.nifti1.Nifti1Image): Volume to label, or its path.
            atlas_list (str | list | dict | pd.DataFrame): Path to a CSV listing the
                atlases, or the listing itself. Atlas paths are read from an
                ``atlas_path`` column (a single unnamed column is treated as
                ``atlas_path``). If an ``atlas`` column is present it is used
                instead; its entries may be Atlas objects or paths.

        A CSV load failure or an invalid ``nifti_path`` prints a message and returns
        early. Errors raised while constructing an Atlas propagate to the caller.
        """
        if isinstance(atlas_list, str):
            print(f"Looking for {atlas_list}")
            try:
                atlas_list = pd.read_csv(atlas_list)
            except Exception as err:
                print(f"""Could not find csv {atlas_list}. Make sure the path is correct and try again. 
                      \n Example: atlases/harvoxf_atlas_list.csv""")
                print(f"Reason: {err}")
                return
        elif isinstance(atlas_list, pd.DataFrame):
            # Work on a copy so the caller's DataFrame is not modified below.
            atlas_list = atlas_list.copy()
        else:
            atlas_list = pd.DataFrame(atlas_list)

        # Build the Atlas objects and their labels for every input form (the CSV
        # path included), so label_volume can rely on the 'atlas' column.
        if 'atlas' in atlas_list.columns:
            atlas_list['atlas'] = atlas_list['atlas'].apply(lambda a: a if isinstance(a, Atlas) else Atlas(a))
        else:
            if 'atlas_path' not in atlas_list.columns:
                # e.g. a plain list of paths becomes a single column named 0.
                atlas_list.columns = ['atlas_path']
            atlas_list['atlas'] = atlas_list['atlas_path'].apply(lambda x: Atlas(x))
        atlas_list['labels'] = atlas_list['atlas'].apply(lambda x: x.labels)

        self.atlas_list = atlas_list
        # Union of every atlas's labels, de-duplicated in first-seen order.
        self.labels = list(dict.fromkeys(label for sublist in atlas_list['labels'] for label in sublist))

        if not isinstance(nifti_path, (str, nib.nifti1.Nifti1Image)):
            print(f"Could not find nifti {nifti_path}. Make sure the path is correct and try again. \n Example: volumes/30lechowiczglogowska2019.nii.gz")
            return

        self.nifti_path = nifti_path
        self.voxel_counts = None
        self.unique_labels = None

    def label_volume(self):
        """Label the volume with every atlas and merge the per-atlas results.

        Sets ``voxel_counts`` (label -> voxel count; counts for a label name
        produced by more than one atlas are summed) and ``unique_labels`` (labels
        found by any atlas, first-seen order). Atlases whose labeling raises or
        produces no labeled data are skipped.

        Returns:
            MultiAtlasLabeler: ``self``, for chaining.
        """

        def safe_label_volume(atlas, nifti_path):
            """Run one AtlasLabeler; print the error and return None on failure."""
            try:
                labeler = AtlasLabeler(nifti_path, atlas)
                labeler.label_volume()
                return labeler
            except Exception as e:
                print(f"Error processing {atlas}: {e}")
                return None

        results_list = []
        for atlas in self.atlas_list['atlas']:
            labeler = safe_label_volume(atlas, self.nifti_path)
            if labeler is not None and labeler.labeled_data is not None:
                results_list.append({'voxel_counts': labeler.voxel_counts, 'unique_labels': labeler.unique_labels})

        if not results_list:
            print("No valid results were generated from the atlases.")
            self.voxel_counts = {}
            self.unique_labels = []
            return self

        # Sum voxel counts across atlases.
        combined_voxel_counts = {}
        for result in results_list:
            for label, count in result['voxel_counts'].items():
                combined_voxel_counts[label] = combined_voxel_counts.get(label, 0) + count
        self.voxel_counts = combined_voxel_counts

        # Union of labels, de-duplicated in first-seen order.
        self.unique_labels = list(dict.fromkeys(
            label for result in results_list for label in result['unique_labels']
        ))
        return self

class CustomLabeler:
    """Label a nifti volume with a set of user-supplied region masks."""

    def __init__(self, nifti_path, name_mask_dict="atlases/joseph_custom_atlas.csv", min_threshold=1, max_threshold=1):
        """Set up the labeler.

        Args:
            nifti_path (str | nib.nifti1.Nifti1Image): Volume to label, or its path.
            name_mask_dict (str | dict | pd.DataFrame): CSV path, DataFrame, or dict
                describing the regions. A DataFrame (or column-style dict) needs
                ``region_name`` and ``mask_path`` columns. A dict without a
                ``region_name`` key is treated as a ``{region_name: mask_path}`` mapping.
            min_threshold (int): Minimum voxel value (inclusive) to be labeled.
            max_threshold (int): Maximum voxel value (inclusive) to be labeled.

        On invalid input or a load failure a message is printed and the constructor
        returns early, leaving the instance only partially initialised.
        """
        if isinstance(name_mask_dict, str):
            try:
                name_mask_dict = pd.read_csv(name_mask_dict)
            except Exception as err:
                print(f"""Could not find csv {name_mask_dict}. Make sure the path is correct and try again. 
                      \n Example: masks/harvoxf_masks.csv""")
                print(f"Reason: {err}")
                return
        # Anything other than a dict or DataFrame (including None) cannot be used.
        if not isinstance(name_mask_dict, (dict, pd.DataFrame)):
            print("""Please provide a dictionary, DataFrame of name-mask pairs.
                  Example: {'hippocampus': harvoxf_hippocampus.nii.gz, 'basal_ganglia': harvoxf_basal_ganglia.nii.gz}""")
            return
        if isinstance(nifti_path, nib.nifti1.Nifti1Image):
            # An already-loaded image is used as-is.
            nifti = nifti_path
        else:
            print(f"Looking for nifti {nifti_path}...")
            try:
                nifti = nib.load(nifti_path)
            except Exception as err:
                print(f"""Could not find nifti {nifti_path}. Make sure the path is correct and try again. 
                      \n Example: volumes/30lechowiczglogowska2019.nii.gz""")
                print(f"Reason: {err}")
                return

        if isinstance(name_mask_dict, pd.DataFrame):
            self.name_mask_df = name_mask_dict
        elif 'region_name' in name_mask_dict:
            # Column-style dict: {'region_name': [...], 'mask_path': [...]}.
            self.name_mask_df = pd.DataFrame(name_mask_dict)
        else:
            # Mapping-style dict: {region_name: mask_path, ...}.
            self.name_mask_df = pd.DataFrame({
                'region_name': list(name_mask_dict.keys()),
                'mask_path': list(name_mask_dict.values()),
            })

        self.labels = self.name_mask_df['region_name'].tolist()
        self.nifti = nifti
        self.volume_data = nifti.get_fdata()
        self.min_threshold = min_threshold
        self.max_threshold = max_threshold
        self.labeled_data = None
        self.voxel_counts = None
        self.unique_labels = None

    def label_volume(self):
        """Assign region labels to in-threshold voxels that fall inside each mask.

        A voxel is labeled with a region when its value lies in
        [min_threshold, max_threshold] and the region's mask is > 0 at the same
        coordinates. A voxel inside several masks gets one row per region.
        Sets ``labeled_data``, ``voxel_counts`` and ``unique_labels``.

        Returns:
            CustomLabeler: ``self``, for chaining.
        """
        results = {'index': [], 'atlas_label': []}

        # Coordinates of voxels meeting the threshold criteria, as a set for intersection.
        in_range = (self.volume_data >= self.min_threshold) & (self.volume_data <= self.max_threshold)
        within_threshold_voxels = set(zip(*np.where(in_range)))

        for region_name, mask_path in zip(self.name_mask_df['region_name'], self.name_mask_df['mask_path']):
            mask_volume = nib.load(mask_path).get_fdata()
            mask_active_voxels = set(zip(*np.where(mask_volume > 0)))
            # Coordinates are matched directly; no shape check is made between
            # the mask and the volume.
            for voxel_coords in within_threshold_voxels & mask_active_voxels:
                results['index'].append(voxel_coords)
                results['atlas_label'].append(region_name)

        results_df = pd.DataFrame(results)
        self.labeled_data = results_df
        self.voxel_counts = results_df['atlas_label'].value_counts().to_dict()
        self.unique_labels = results_df['atlas_label'].unique().tolist()
        return self

## Example usage
# Label one lesion volume three ways: with a single atlas, with two atlases
# combined, and with a custom set of region masks. Paths are relative to the
# notebook's working directory.

# Atlas volumes; Atlas() expects a .csv key with the same root name next to each.
atlas_cort = r"atlases/HarvardOxford-cort-maxprob-thr0-2mm.nii.gz"

atlas_sub = r"atlases/HarvardOxford-sub-maxprob-thr0-2mm.nii.gz"

# CSV of custom regions; CustomLabeler reads its 'region_name' and 'mask_path' columns.
custom_atlas = r"atlases/joseph_custom_atlas.csv"

# Volume to label. With the default thresholds (1, 1) only voxels equal to 1 are counted.
volume_to_label = r'volumes/30lechowiczglogowska2019.nii.gz'

# label_volume() returns the labeler, so .voxel_counts can be chained.
# NOTE(review): AtlasLabeler.label_volume() returns None when the atlas and volume
# shapes differ, which would make this chained access raise AttributeError.
voxel_counts_cort = AtlasLabeler(volume_to_label, atlas_cort).label_volume().voxel_counts

# Counts from both atlases merged into one dict (same-named labels are summed).
voxel_counts_cort_and_sub = MultiAtlasLabeler(volume_to_label, [atlas_cort, atlas_sub]).label_volume().voxel_counts

voxel_counts_custom = CustomLabeler(volume_to_label, custom_atlas).label_volume().voxel_counts

# Last expression of the cell: the notebook displays the three dicts as a tuple.
voxel_counts_cort, voxel_counts_cort_and_sub, voxel_counts_custom

Looking for atlas atlases/HarvardOxford-cort-maxprob-thr0-2mm.nii.gz...
Looking for nifti volumes/30lechowiczglogowska2019.nii.gz...
Looking for nifti volumes/30lechowiczglogowska2019.nii.gz...
Looking for nifti volumes/30lechowiczglogowska2019.nii.gz...
Looking for nifti volumes/30lechowiczglogowska2019.nii.gz...


({'Superior Frontal Gyrus': 108,
  'Middle Frontal Gyrus': 92,
  'Juxtapositional Lobule Cortex (formerly Supplementary Motor Cortex)': 79,
  'No matching region found': 34,
  'Cingulate Gyrus, anterior division': 6},
 {'Superior Frontal Gyrus': 108,
  'Middle Frontal Gyrus': 92,
  'Juxtapositional Lobule Cortex (formerly Supplementary Motor Cortex)': 79,
  'No matching region found': 34,
  'Cingulate Gyrus, anterior division': 6,
  'Left Cerebral Cortex': 181,
  'Left Cerebral White Matter': 138},
 {'frontal_lobe': 227, 'cerebral_cortex': 172, 'subcortex': 139})

In [52]:
csv_path = """all_takotsubo_mgh_all_symptoms_2d.csv"""
custom_atlas_path = """atlases/joseph_custom_atlas.csv"""


def label_csv(csv_path, atlas, nifti_column='orig_roi_vol', n_rows=None):
    """Label every volume listed in a CSV with a custom mask atlas.

    Args:
        csv_path (str): CSV with one row per subject; ``nifti_column`` holds the
            path to each subject's nifti volume.
        atlas (str | dict | pd.DataFrame): Region table passed to ``CustomLabeler``
            as ``name_mask_dict``.
        nifti_column (str): Column holding the volume paths.
        n_rows (int, optional): Only label the first ``n_rows`` rows of the CSV.

    Returns:
        pd.DataFrame: The (possibly truncated) CSV with one extra column per region
        label, holding that region's voxel count (0 when the region was not hit).
    """
    df = pd.read_csv(csv_path)
    if n_rows is not None:
        df = df.head(n_rows).copy()

    labelers = [CustomLabeler(path, atlas).label_volume() for path in df[nifti_column]]
    if not labelers:
        return df

    # Every labeler shares the same atlas, so the first one's labels cover all rows.
    for label in labelers[0].labels:
        df[label] = [labeler.voxel_counts.get(label, 0) for labeler in labelers]
    return df


custom_atlas_df = pd.read_csv(custom_atlas_path)
df = label_csv(csv_path, custom_atlas_df, n_rows=72)
labels = custom_atlas_df['region_name'].tolist()
display(df)

Looking for nifti /data/nimlab/dl_archive/takutsuboLesions_GSP1000_V3/sub-01an2020/roi/sub-01an2020_space-original_lesionMask.nii.gz...
Looking for nifti /data/nimlab/dl_archive/takutsuboLesions_GSP1000_V3/sub-02androdias2017u1/roi/sub-02androdias2017u1_space-original_lesionMask.nii.gz...
Looking for nifti /data/nimlab/dl_archive/takutsuboLesions_GSP1000_V3/sub-03androdias2017u2/roi/sub-03androdias2017u2_space-original_lesionMask.nii.gz...
Looking for nifti /data/nimlab/dl_archive/takutsuboLesions_GSP1000_V3/sub-04androdias2017u3/roi/sub-04androdias2017u3_space-original_lesionMask.nii.gz...
Looking for nifti /data/nimlab/dl_archive/takutsuboLesions_GSP1000_V3/sub-05androdias2017u4/roi/sub-05androdias2017u4_space-original_lesionMask.nii.gz...
Looking for nifti /data/nimlab/dl_archive/takutsuboLesions_GSP1000_V3/sub-06androdias2017u5/roi/sub-06androdias2017u5_space-original_lesionMask.nii.gz...
Looking for nifti /data/nimlab/dl_archive/takutsuboLesions_GSP1000_V3/sub-07banuelos2008/roi/s

Unnamed: 0,dataset,subject,orig_roi_vol,roi_2mm,avgR,avgRFz,t,has_takotsubo,hippocampus_and_amygdala,cerebral_cortex,subcortex,basal_ganglia,brainstem,thalamus,cerebellum,frontal_lobe,parietal_lobe,insular_lobe,occipital_lobe,temporal_lobe
0,takutsuboLesions_GSP1000_V3,01an2020,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,1,0,0,130,0,144,0,21,0,0,0,0,0
1,takutsuboLesions_GSP1000_V3,02androdias2017u1,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,1,0,0,18,0,23,0,0,0,0,0,0,0
2,takutsuboLesions_GSP1000_V3,03androdias2017u2,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,1,0,0,27,0,25,0,0,0,0,0,0,0
3,takutsuboLesions_GSP1000_V3,04androdias2017u3,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,1,0,0,34,0,31,0,1,0,0,0,0,0
4,takutsuboLesions_GSP1000_V3,05androdias2017u4,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,1,0,0,43,0,44,0,4,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
67,takutsuboLesions_GSP1000_V3,68papanikolaou2009,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,1,0,1080,598,0,0,0,0,0,0,0,1285,340
68,takutsuboLesions_GSP1000_V3,69shiromoto2012,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,1,0,0,200,0,0,0,200,0,0,0,0,0
69,takutsuboLesions_GSP1000_V3,70summers2012,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,1,0,775,598,0,0,0,0,0,0,0,1157,118
70,takutsuboLesions_GSP1000_V3,71tempaku2012,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,/data/nimlab/dl_archive/takutsuboLesions_GSP10...,1,0,0,397,0,53,0,386,0,0,0,0,0
