# OpenNeuro Data Loader
A data loader for OpenNeuro MRI datasets (https://openneuro.org/).

Getting usable data from OpenNeuro is more difficult than it should be. I aim to create a three-part system to expedite this process.

The architecture is as follows:
1. Given a dataset ID (ds#######) download the dataset to a specified folder and extract it using datalad
1. A 'patient' class to hold data relevant to model training as well as data related to the patient
1. A dataset class that has various dataset-related methods (preprocessing, train-val-test splits or stratified k-fold cross-validation, etc.)

## Todos
1. Using datalad and git, download dataset
1. Figure out memory measuring tool
1. Load batch of n scans based on available memory
1. Create generator of m batches of n scans which load on demand

## Install Packages

In [2]:
!pip install -r requirements.txt

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting nipy
  Downloading nipy-0.6.1-cp39-cp39-win_amd64.whl (3.0 MB)
     ---------------------------------------- 3.0/3.0 MB 8.4 MB/s eta 0:00:00
Collecting nibabel
  Downloading nibabel-5.3.3-py3-none-any.whl (3.3 MB)
     ---------------------------------------- 3.3/3.3 MB 69.9 MB/s eta 0:00:00
Collecting nilearn
  Downloading nilearn-0.12.1-py3-none-any.whl (12.7 MB)
     --------------------------------------- 12.7/12.7 MB 50.4 MB/s eta 0:00:00
Collecting scipy
  Downloading scipy-1.13.1-cp39-cp39-win_amd64.whl (46.2 MB)
     --------------------------------------- 46.2/46.2 MB 65.6 MB/s eta 0:00:00
Collecting matplotlib
  Downloading matplotlib-3.9.4-cp39-cp39-win_amd64.whl (7.8 MB)
     ---------------------------------------- 7.8/7.8 MB 50.1 MB/s eta 0:00:00
Collecting transforms3d
  Downloading transforms3d-0.4.2-py3-none-any.whl (1.4 MB)
     ---------------------------------------- 1.4/1.4 MB 91.1 

You should consider upgrading via the 'D:\Side_Projects\MRI_Project\env_mri\Scripts\python.exe -m pip install --upgrade pip' command.


In [2]:
import time
import nibabel as nib
import numpy as np
import os
import json
import random
import SimpleITK as sitk
import psutil


In [7]:
class patient:
    '''
    Struct for holding patient information and scan data.

    Walks one patient folder, records where every compressed NIfTI scan
    (``*.nii.gz``) and its optional JSON sidecar live, and loads them on demand.

    Attributes:
        info: dict keyed by the scan's pre-extension file name; each value is
            {'scan': path to the .nii.gz file,
             'metadata': path to the same-named .json file, or None if absent}.
        folder_path: the folder that was walked.
        date_loaded: time.time() timestamp taken when the object was created.
    '''
    def __init__(self,path):
        self.info = {} #data-metadata pairs using pre-extension name
        self.folder_path = path
        self.date_loaded = time.time()
        self.parse_and_assign_filenames(self.folder_path)

    def __str__(self):
        return f'{len(self.info)} scans from {self.folder_path}'

    def parse_and_assign_filenames(self,path):
        '''
        Recursively index every ``*.nii.gz`` file under ``path`` into self.info.

        Only the ``.nii.gz`` suffix is accepted: other gzip files (for example
        ``.tsv.gz`` tables) are not scans and cannot be opened by load().
        The sidecar is looked up next to the scan as ``<name>.json``.
        '''
        for root, _dirs, files in os.walk(path):
            for file in files:
                if not file.endswith('.nii.gz'):
                    continue
                # Key is everything before the first dot, e.g. 'sub-01_T1w'.
                name = file.split('.')[0]
                sidecar = os.path.join(root, name + '.json')
                self.info[name] = {
                    'scan': os.path.join(root, file),
                    'metadata': sidecar if os.path.exists(sidecar) else None,
                }

    def load(self):
        '''
        Load every indexed scan and its metadata.

        Returns:
            dict with
              'data': list of SimpleITK images, one per entry of self.info;
              'metadata': list of [name, parsed JSON or None] pairs in the same
                  order. None marks a scan that has no JSON sidecar.
        '''
        def load_json(path):
            # Scans without a sidecar are stored with metadata=None; pass that
            # through instead of trying to open it.
            if path is None:
                return None
            with open(path, encoding='utf-8') as f:
                return json.load(f)

        def load_scan(path):
            # replace with datalad
            img = nib.load(path)
            data = np.asarray(img.dataobj)
            # Only the voxel array is handed over, so the NIfTI affine
            # (spacing/origin/direction) is not carried into the SimpleITK image.
            # NOTE(review): GetImageFromArray interprets the array as (z, y, x),
            # so the axis order is reversed relative to nibabel - confirm intended.
            return sitk.GetImageFromArray(data)

        return {
            'data': [load_scan(entry['scan']) for entry in self.info.values()],
            'metadata': [
                [name, load_json(entry['metadata'])]
                for name, entry in self.info.items()
            ],
        }

    def unload(self):
        '''Placeholder: intended to release scan data via datalad; currently a no-op.'''
        #use datalad to unload scan
        pass
class patient_dataset:
    '''
    Responsible for organizing and grouping scans + metadata per patient.
    Passes each patient folder path to the patient class.
    Also responsible for image preprocessing methods and k-fold assignment.

    Attributes:
        path: dataset root folder (should end in ds007045 or similar).
        standard_size: target (x, y, z) size used by preprocess().
        patients: list of patient objects that contain at least one scan.
        length: number of patients.
        folds: dict fold number -> list of patient indices; empty until
            generate_folds() is called.
    '''
    def __init__(self,path,standard_size=(256,256,200)):
        #where path is the path to the dataset (should end in ds007045 or similar)
        self.path = path
        self.standard_size = standard_size
        self.folds = {} #defined up front so get_fold() can report "no folds yet"
        self.patients = []
        # os.listdir order is arbitrary; sort so patient indices are reproducible.
        for folder in sorted(os.listdir(self.path)):
            if not self._is_folder(folder):
                continue
            p = patient(os.path.join(self.path,folder))
            if len(p.info) != 0: #filter non-patient folders
                self.patients.append(p)
        print('length patients', len(self.patients))
        self.length = len(self.patients)

    def _is_folder(self,folder):
        '''Return True when ``folder`` is a directory whose '-'-separated name contains 'sub'.'''
        #temp fix for picking up non-patient folders
        if 'sub' not in folder.split('-'):
            return False
        return os.path.isdir(os.path.join(self.path,folder))

    def __iter__(self):
        '''
        Stream samples one-by-one without holding everything in memory.
        '''
        for file_id in range(self.length):
            yield self.get(file_id)

    def __getitem__(self, file_id):
        '''
        Load one or several patients.

        Accepts an int (0 <= index < length), a slice, or a list of ints.
        Raises IndexError for an out-of-range int and TypeError for any other
        index type.
        '''
        if isinstance(file_id, slice):
            start, stop, step = file_id.indices(self.length)
            return [self.get(i) for i in range(start, stop, step)]
        elif isinstance(file_id, list):
            return [self.get(i) for i in file_id]
        elif isinstance(file_id, int):
            if file_id < 0 or file_id >= self.length:
                raise IndexError("patient index out of range")
            return self.get(file_id)
        else:
            raise TypeError("Indices must be integers, slices, or a list")

    def get(self,file_id):
        '''Return patient.load() for the patient at index ``file_id``.'''
        return self.patients[file_id].load()

    def sample(self):
        '''
        Load one patient chosen uniformly at random.

        Raises IndexError when the dataset holds no patients.
        '''
        if self.length == 0:
            raise IndexError("cannot sample from an empty dataset")
        # randrange excludes the upper bound; randint(0, length) could return
        # `length` itself, which is one past the last valid index.
        random_idx = random.randrange(self.length)
        return self.get(random_idx)

    def resample_to_shape(
        self,
        images, #list of sitk images
        out_size,
        interpolator=sitk.sitkLinear
    ):
        '''
        Resample each 3D SimpleITK image to ``out_size`` voxels.

        The output spacing is chosen so the resampled grid spans the same
        extent as the input, assuming unit input spacing. Direction and
        origin are copied from each input image.

        Returns a list of resampled images in the same order as ``images``.
        '''
        resampled_images = []
        for img in images:
            original_size = img.GetSize()
            original_spacing = [1.0,1.0,1.0] #change to grabbing this from metadata

            # Keep physical extent constant: extent = size * spacing.
            new_spacing = [
                (original_size[i] * original_spacing[i]) / out_size[i]
                for i in range(3)
            ]

            resampler = sitk.ResampleImageFilter()
            resampler.SetSize(out_size)
            resampler.SetOutputSpacing(new_spacing)
            resampler.SetInterpolator(interpolator)
            resampler.SetOutputDirection(img.GetDirection())
            resampler.SetOutputOrigin(img.GetOrigin())
            resampled_images.append(resampler.Execute(img))
        return resampled_images

    def preprocess(self,idx,count):
        '''
        Load patients ``idx`` .. ``idx+count-1`` and resample their scans to
        self.standard_size.

        Returns a list with one entry per patient; each entry is that patient's
        list of resampled SimpleITK images.
        '''
        #standardize size
        selected = self.patients[idx:idx+count]
        # patient objects are not subscriptable; the scans come from load().
        return [
            self.resample_to_shape(p.load()['data'], self.standard_size)
            for p in selected
        ]

    def generate_folds(self,k=10):
        '''
        Randomly assign every patient index to one of ``k`` folds.

        Folds 0 .. k-2 each receive length // k indices; fold k-1 receives the
        rest, so it also holds the items left over by rounding. When k exceeds
        the number of patients every fold except the last is empty.

        Raises ValueError when k < 1.
        '''
        if k < 1:
            raise ValueError("k must be at least 1")
        #Create an array from 0 to self.length, shuffle, and make k-1 even cuts
        assignments = list(range(self.length))
        random.shuffle(assignments)
        fold_size = self.length//k #last fold will have extra items from excluded by rounding
        self.folds = {
            foldnum: assignments[fold_size*foldnum:fold_size*(foldnum+1)]
            for foldnum in range(k-1)
        }
        self.folds[k-1] = assignments[fold_size*(k-1):]

    def get_fold(self,fold_num):
        '''
        Load every patient assigned to fold ``fold_num``.

        generate_folds() must have been called first. Raises KeyError for an
        unknown fold number.
        '''
        assert len(self.folds) > 0, "no folds available: call generate_folds() first"
        return self.__getitem__(self.folds[fold_num])#what if this ALSO returned a generator??
# Index the patients of a locally downloaded dataset; the path is relative to
# the working directory. The commented line selects an alternative dataset.
# dataset = patient_dataset('ds007045')
dataset = patient_dataset('ds007156')


length patients 14


In [8]:
# Time fold generation (15 folds) plus loading every patient assigned to fold 2.
start = time.time()
dataset.generate_folds(15)
fold = dataset.get_fold(2)
end = time.time()
# Displayed tuple: elapsed minutes and the number of entries in the fold.
# Each entry is one patient's load() result, so this counts patients.
(end-start)/60, "Minutes for ",len(fold)," Scans"

(0.0, 'Minutes for ', 0, ' Scans')

In [52]:
# Number of loaded patients in the fold fetched with get_fold().
len(fold)

6

In [34]:
# Load the patient at index 10 and show its list of SimpleITK images.
dataset.get(10)['data']


[<SimpleITK.SimpleITK.Image; proxy of <Swig Object of type 'itk::simple::Image *' at 0x000001E2A601B990> >,
 <SimpleITK.SimpleITK.Image; proxy of <Swig Object of type 'itk::simple::Image *' at 0x000001E2A601B300> >,
 <SimpleITK.SimpleITK.Image; proxy of <Swig Object of type 'itk::simple::Image *' at 0x000001E2A601BF60> >,
 <SimpleITK.SimpleITK.Image; proxy of <Swig Object of type 'itk::simple::Image *' at 0x000001E2A601B960> >]

In [30]:
# Time loading the first 100 patients (fewer if the dataset is smaller) via slicing.
start = time.time()
print(len([d['data'] for d in dataset[0:100]]))
end = time.time()
# Displayed tuple: elapsed minutes and dataset.length, which is the total
# number of indexed patients rather than the number actually loaded above.
(end-start)/60,'Minutes for ',dataset.length,' images' #2.7min for scans and metadata

100


(2.7974769433339435, 'Minutes for ', 337, ' images')