# In [1]:
from IPhyton_import import NotebookFinder
import sys
sys.meta_path.append(NotebookFinder())

# TODO: find a better way to store which atlas belongs to which volume.
"""
Data generator for VoxelMorph (pairs each volume with its matching vertebra-range atlas).
"""

import numpy as np
import os
import sys
from myTransforms import Pad, AdjustBrightness, LimitRange, MyNormalize
import myTransforms as mt
from torchvision.transforms import ToTensor, Resize, Compose
from scipy.ndimage.filters import median_filter



def load_example_by_name(current_vol, atlas_dir, pad_size):
    """
    Load one volume together with the atlas matching its vertebra range
    (used in the test case).

    The atlas is chosen from the five characters right before the file
    extension of ``current_vol`` ('L1-L3', 'L2-L4' or 'L3-L5'); ``.nii.gz``
    is treated as a single extension. For the chosen range,
    ``atlas_<range>.nii.gz`` is tried first and ``atlas_<range>.nii`` second.

    :param current_vol: path (string) of the volume to load
    :param atlas_dir: directory containing the atlas files
    :param pad_size: size used to pad the 3D volumes to a square shape
    :return: tuple ``(X, A)``: the padded volume and the padded atlas, each
             with a leading and a trailing singleton axis added
    :raises ValueError: if ``current_vol`` does not encode a known vertebra range
    :raises FileNotFoundError: if neither atlas file exists
    """
    #-----------------pick the atlas from the vertebra range in the file name------------------
    # e.g. 'subject01_L1-L3.nii' -> 'L1-L3'
    if current_vol.endswith('.nii.gz'):
        stem = current_vol[:-len('.nii.gz')]
    else:
        stem = current_vol[:-4]
    vertebrae = stem[-5:]
    if vertebrae not in ('L1-L3', 'L2-L4', 'L3-L5'):
        # Fail fast: without a matching atlas nothing below can work.
        raise ValueError('wrong name: %s' % vertebrae)

    atlas_path = os.path.join(atlas_dir, 'atlas_' + vertebrae)
    # Atlas files are sometimes stored as .nii.gz and sometimes as plain .nii.
    try:
        A = load_volfile(atlas_path + '.nii.gz')
    except FileNotFoundError:
        A = load_volfile(atlas_path + '.nii')

    #--------------median filter, limit the range and normalize the atlas-------------------------
    A = median_filter(A, size=(3, 3, 3))
    A = Compose([LimitRange(), MyNormalize()])(A)
    A_mean = np.mean(A)
    A_std = np.std(A)

    #---------median filter and limit the range of the volume, then adjust its brightness---------
    # mt.adjust_brightness receives the atlas statistics; presumably it matches
    # the volume's intensities to the atlas (defined in myTransforms) -- confirm.
    X = load_volfile(current_vol)
    X = median_filter(X, size=(3, 3, 3))
    X = LimitRange()(X)
    X = mt.adjust_brightness(X, A_mean, A_std)

    #--------------pad both volumes and add leading/trailing singleton axes-----------------------
    pad = Pad(pad_size)
    X = pad(X)[np.newaxis, ..., np.newaxis]
    A = pad(A)[np.newaxis, ..., np.newaxis]
    return X, A


def load_volfile(datafile):
    """
    Load a volume file into a numpy array.

    Supported formats: .nii, .nii.gz, .mgz (read with nibabel and reoriented
    with ``nib.as_closest_canonical``) and .npz (compressed numpy; the array
    is expected under the key 'vol_data').

    :param datafile: path (str or os.PathLike) of the file to load
    :return: the volume data as a numpy array
    :raises ValueError: if the file extension is not supported
    :raises ImportError: if nibabel is required but not installed
    """
    datafile = os.fspath(datafile)
    # Explicit check instead of assert, which is stripped when running with -O.
    if not datafile.endswith(('.nii', '.nii.gz', '.mgz', '.npz')):
        raise ValueError('Unknown data file: %s' % datafile)

    if datafile.endswith('.npz'):
        # The context manager closes the NpzFile's underlying file handle;
        # indexing it reads the array fully into memory first.
        with np.load(datafile) as npz:
            return npz['vol_data']

    try:
        import nibabel as nib
    except ImportError as err:
        raise ImportError('Failed to import nibabel. need nibabel library for these data file types.') from err

    img = nib.as_closest_canonical(nib.load(datafile))
    # img.get_data() was removed in nibabel 5.0; np.asanyarray(img.dataobj)
    # is its documented replacement and returns the same array.
    return np.asanyarray(img.dataobj)


def example_gen(vol_names, atlas_dir, pad_size, batch_size=1, return_segs=False, seg_dir=None):
    """
    Generate random (volume, atlas) batches forever.

    Each atlas ``atlas_<range>`` for the ranges 'L1-L3', 'L2-L4' and 'L3-L5'
    is loaded once from ``atlas_dir``. The ``.nii.gz`` file is preferred and
    ``.nii`` is the fallback. Atlases found in neither form are skipped; for
    example, the healthy data only has ``atlas_L1-L3.nii``. Every loaded atlas
    is median filtered, range-limited and normalized, then padded once.

    For each batch, volumes are drawn uniformly at random (with replacement)
    from ``vol_names``. Each volume is paired with the atlas whose range
    appears right before its file extension (``.nii.gz`` counts as one
    extension).

    Parameters:
        vol_names: a list or tuple of volume file names
        atlas_dir: location of the folder with the atlases
        pad_size: int, the wanted size of the padded volumes
        batch_size: number of volume/atlas pairs per batch (default: 1)
        return_segs, seg_dir: kept from the original VoxelMorph implementation
            for interface compatibility; not used here.

    Yields:
        tuple ``(X, A)``. Singleton leading and trailing axes are added to each
        per-sample array, and the arrays are concatenated along axis 0, so
        axis 0 has length ``batch_size``.

    Raises (on the first ``next()``):
        ValueError: if ``vol_names`` is empty, if ``batch_size < 1``, or if a
            volume name does not encode a known vertebra range.
        FileNotFoundError: if no atlas is found at all, or if a volume needs an
            atlas that is missing.
    """
    known_ranges = ('L1-L3', 'L2-L4', 'L3-L5')
    if len(vol_names) == 0:
        raise ValueError('vol_names must contain at least one volume file name')
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))

    #--------------load every available atlas once: median filter, limit range, normalize, pad-----
    composed_atlas = Compose([LimitRange(), MyNormalize()])
    limit_range = LimitRange()
    pad = Pad(pad_size)
    # range -> (padded atlas with added axes, mean, std of the unpadded atlas)
    atlases = {}
    for vertebrae in known_ranges:
        atlas_path = os.path.join(atlas_dir, 'atlas_' + vertebrae)
        A = None
        for ext in ('.nii.gz', '.nii'):
            try:
                A = load_volfile(atlas_path + ext)
                break
            except FileNotFoundError:
                pass
        if A is None:
            continue
        A = composed_atlas(median_filter(A, size=(3, 3, 3)))
        atlases[vertebrae] = (pad(A)[np.newaxis, ..., np.newaxis], np.mean(A), np.std(A))
    if not atlases:
        raise FileNotFoundError('no atlas_L*-L*.nii(.gz) files found in %s' % atlas_dir)

    #-------for each batch, find the matching atlas, preprocess the volume and stack the data------
    while True:
        idxes = np.random.randint(len(vol_names), size=batch_size)

        X_data = []
        A_data = []
        for idx in idxes:
            current_vol = vol_names[idx]

            # e.g. 'subject01_L1-L3.nii' -> 'L1-L3'
            if current_vol.endswith('.nii.gz'):
                stem = current_vol[:-len('.nii.gz')]
            else:
                stem = current_vol[:-4]
            vertebrae = stem[-5:]
            if vertebrae not in known_ranges:
                raise ValueError('wrong name: %s' % vertebrae)
            if vertebrae not in atlases:
                raise FileNotFoundError(
                    'no atlas_%s(.nii.gz|.nii) found in %s' % (vertebrae, atlas_dir))
            A, A_mean, A_std = atlases[vertebrae]

            # load the volume, median filter, limit range, adjust brightness to the atlas stats, pad
            X = load_volfile(current_vol)
            X = median_filter(X, size=(3, 3, 3))
            X = limit_range(X)
            X = mt.adjust_brightness(X, A_mean, A_std)
            X = pad(X)
            X_data.append(X[np.newaxis, ..., np.newaxis])
            A_data.append(A)

        # np.concatenate always returns new arrays, so callers never share the
        # cached atlas array; with batch_size == 1 the shapes are unchanged.
        yield np.concatenate(X_data, 0), np.concatenate(A_data, 0)

# Cell output: importing Jupyter notebook from myTransforms.ipynb
