In [1]:
## Boiler plate code common to many notebooks.  See the TestFilesCommonCode.ipynb for details
from __future__ import print_function
%run TestFilesCommonCode.ipynb

SimpleITK Version: 0.9.1
Compiled: Sep 28 2015 10:07:41



In [2]:
import os
import glob
import sys

#\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/
#####################################################################################
#     Prepend the shell environment search paths
PROGRAM_PATHS = '/scratch/NAMICExternalProjects/release-20160523/bin'
#PROGRAM_PATHS = '/scratch/BS/release-BSR/bin'
PROGRAM_PATHS = PROGRAM_PATHS.split(':')
PROGRAM_PATHS.extend(os.environ['PATH'].split(':'))
os.environ['PATH'] = ':'.join(PROGRAM_PATHS)

CUSTOM_ENVIRONMENT=dict()

# Platform specific information
#     Prepend the python search paths
#PYTHON_AUX_PATHS = '/scratch/BS/BRAINSTools/AutoWorkup'
PYTHON_AUX_PATHS = '/scratch/SuperResolution/BRAINSSuperResolution/HCPWorkflows/:/scratch/wmql/tract_querier/tract_querier/nipype/'
PYTHON_AUX_PATHS = PYTHON_AUX_PATHS.split(':')
PYTHON_AUX_PATHS.extend(sys.path)
sys.path = PYTHON_AUX_PATHS

import SimpleITK as sitk
import nipype
from nipype.interfaces.base import CommandLine, CommandLineInputSpec, TraitedSpec, File, Directory
from nipype.interfaces.base import traits, isdefined, BaseInterface
from nipype.interfaces.utility import Merge, Split, Function, Rename, IdentityInterface
import nipype.interfaces.io as nio   # Data i/oS
import nipype.pipeline.engine as pe  # pypeline engine
from nipype.interfaces import ants
from nipype.interfaces.semtools import *

In [3]:
# Subject-specific inputs. All subject files live under one directory named after
# the anonymized subject ID; only the atlas definition comes from a shared cache.
_SUBJECT_ID = 'ANON0FEFVB1Q0'
_SUBJECT_DIR = '/raid0/homes/aghayoor/Desktop/invicro_eyeTrackingMR/' + _SUBJECT_ID

input_dwi = _SUBJECT_DIR + '/dwi_nrrd/' + _SUBJECT_ID + '_DWI.nrrd'
input_t1 = _SUBJECT_DIR + '/structure_aligned/' + _SUBJECT_ID + '_t13D_aligned.mhd'
input_t2 = _SUBJECT_DIR + '/structure_aligned/' + _SUBJECT_ID + '_t2_aligned.mhd'
brainstem_box_template_mask = _SUBJECT_DIR + '/template_image/template_brainstem_box_labelmap.mhd'
atlasDefinition = '/Shared/sinapse/CACHE/20160502_AliHCP_BAW_base_CACHE/Atlas/ExtendedAtlasDefinition.xml'

In [4]:
# Name of the nipype workflow built below.
WFname = 'CorrectionWF'
# Workflow base_dir and DataSink base_directory.
BASE_DIR = '/raid0/homes/aghayoor/Desktop/invicro_eyeTrackingMR/ANON0FEFVB1Q0/workflow_output'

In [8]:
###### UTILITY FUNCTIONS #######
#\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/#
# 
def MakeResamplerInFileList(inputT2, inputLabelMap):
    """Pack the T2 volume and the label map into a two-element list.

    The order matters: ResampleToB0Space iterates this list together with its
    ``pixelType`` and ``outputVolume`` lists, so index 0 must be the T2 image and
    index 1 the label map.

    :param inputT2: path to the T2 image.
    :param inputLabelMap: path to the label-map image.
    :return: ``[inputT2, inputLabelMap]`` (a new list on every call).
    """
    return [inputT2, inputLabelMap]

# Create registration mask for ANTs from resampled label map image
def CreateAntsRegistrationMask(brainMask):
    """Build the fixed-image mask used by antsRegistration.

    Every voxel whose label is > 0 becomes foreground. The foreground is then
    enlarged with a binary dilation of kernel radius 6.

    :param brainMask: path to a label-map image; must exist.
    :return: absolute path of ``registrationMask.nrrd`` in the current directory.
    """
    import os
    import SimpleITK as sitk
    assert os.path.exists(brainMask), "File not found: %s" % brainMask
    foreground = sitk.ReadImage(brainMask) > 0
    dilater = sitk.BinaryDilateImageFilter()
    dilater.SetKernelRadius(6)
    grown = dilater.Execute(foreground)
    mask_path = os.path.realpath('registrationMask.nrrd')
    sitk.WriteImage(grown, mask_path)
    return mask_path

# This function helps to pick desirable output from the output list
def pickFromList(inlist, item):
    """Return ``inlist[item]``.

    Used inside ``CorrectionWF.connect`` tuples to select one element of a
    MapNode's list output. It references no imports or module globals, like the
    other helpers here (which import their dependencies locally), presumably so
    nipype can evaluate it in isolation.
    """
    chosen = inlist[item]
    return chosen

# Save direction cosine for the input volume
def SaveDirectionCosineToMatrix(inputVolume):
    """Return the direction cosines of an image file.

    :param inputVolume: path to an existing image.
    :return: the tuple reported by SimpleITK's ``Image.GetDirection()``.
    :raises AssertionError: if ``inputVolume`` does not exist.
    """
    import os
    import SimpleITK as sitk
    assert os.path.exists(inputVolume), "File not found: %s" % inputVolume
    return sitk.ReadImage(inputVolume).GetDirection()

# Force DC to ID
def ForceDCtoID(inputVolume):
    """Write a copy of an image whose direction cosine matrix is the identity.

    The identity is built for the image's own dimension. For a 3-D image this is
    exactly ``(1,0,0, 0,1,0, 0,0,1)``, but images of other dimensions no longer
    fail in ``SetDirection``. Origin, spacing and voxel data are left unchanged.

    :param inputVolume: path to an existing image.
    :return: absolute path of ``IDDC_<basename of inputVolume>`` in the current
        directory.
    :raises AssertionError: if ``inputVolume`` does not exist (same check as the
        other helpers in this notebook).
    """
    import os
    import SimpleITK as sitk
    assert os.path.exists(inputVolume), "File not found: %s" % inputVolume
    inImage = sitk.ReadImage(inputVolume)
    dim = inImage.GetDimension()
    # Flattened dim x dim identity matrix.
    identity = tuple(1.0 if row == col else 0.0
                     for row in range(dim) for col in range(dim))
    inImage.SetDirection(identity)
    outputVolume = os.path.realpath('IDDC_' + os.path.basename(inputVolume))
    sitk.WriteImage(inImage, outputVolume)
    return outputVolume

def RestoreDCFromSavedMatrix(inputVolume, inputDirectionCosine):
    """Re-apply a saved direction cosine to an image and write it out.

    Counterpart of ForceDCtoID. Only the direction is replaced; everything else
    in the image is written back unchanged.

    :param inputVolume: path to the image to fix up.
    :param inputDirectionCosine: direction tuple, e.g. as returned by
        SaveDirectionCosineToMatrix.
    :return: absolute path of ``CorrectedDWI.nrrd`` in the current directory.
    """
    import os
    import SimpleITK as sitk
    restored = sitk.ReadImage(inputVolume)
    restored.SetDirection(inputDirectionCosine)
    destination = os.path.realpath('CorrectedDWI.nrrd')
    sitk.WriteImage(restored, destination)
    return destination

def GetRigidTransformInverse(inputTransform):
    """Invert a rigid transform file and write the inverse to the working dir.

    The loaded transform's fixed parameters and parameters are copied into a
    fresh ``VersorRigid3DTransform``, whose inverse is written out.
    NOTE(review): this assumes the file's parameters are laid out as a versor
    rigid transform expects — confirm for the BRAINSFit output it receives.

    :param inputTransform: path to the transform file.
    :return: absolute path of ``Inverse_<basename of inputTransform>`` in the
        current directory.
    """
    import os
    import SimpleITK as sitk
    loaded = sitk.ReadTransform(inputTransform)
    rigid = sitk.VersorRigid3DTransform()
    rigid.SetFixedParameters(loaded.GetFixedParameters())
    rigid.SetParameters(loaded.GetParameters())
    destination = os.path.realpath('Inverse_' + os.path.basename(inputTransform))
    sitk.WriteTransform(rigid.GetInverse(), destination)
    return destination

def MakeForceDCFilesList(inputB0, inputT2, inputLabelMap):
    """Check that the three ForceDCtoID inputs exist and return them as a list.

    The order ``[B0, T2, label map]`` is relied on downstream, where the
    ForceDCtoID MapNode outputs are picked by index 0, 1 and 2.

    :raises AssertionError: for the first path (in argument order) that is missing.
    """
    import os
    for candidate in (inputB0, inputT2, inputLabelMap):
        assert os.path.exists(candidate), "File not found: %s" % candidate
    return [inputB0, inputT2, inputLabelMap]

def DownsampleStructralMR(inputMR):
    """Low-pass filter and 2x decimate a structural MR volume for BRAINSABC.

    The Gaussian smoothing is applied *before* the subsampling. ``sitk.Shrink``
    keeps every n-th voxel, so smoothing only afterwards cannot remove aliasing
    that the decimation has already introduced. The smoothing variance (0.98)
    and the shrink factors (2, 2, 2) are unchanged.
    NOTE(review): DiscreteGaussian uses image spacing by default, so the variance
    is presumably in physical units (mm^2) — confirm against the SimpleITK version
    in use.

    :param inputMR: path to the structural MR image; must exist.
    :return: one-element list with the absolute path of ``structuralMR_lr.nrrd``
        in the current directory (a list because BRAINSABC's ``inputVolumes``
        takes a list).
    """
    import os
    import SimpleITK as sitk
    assert os.path.exists(inputMR), "File not found: %s" % inputMR
    mrimage = sitk.ReadImage(inputMR)
    mrimage_smoothed = sitk.DiscreteGaussian(mrimage, 0.98)  # anti-aliasing low-pass
    mrimage_lr = sitk.Shrink(mrimage_smoothed, [2, 2, 2])   # downsampling
    mrimage_lr_fn = os.path.realpath('structuralMR_lr.nrrd')
    sitk.WriteImage(mrimage_lr, mrimage_lr_fn)
    return [mrimage_lr_fn]

def CreateBrainMaskFromLabels(inputLabelMap):
    """Binarize a label map: every voxel with a label > 0 becomes foreground.

    :param inputLabelMap: path to a label-map image; must exist.
    :return: absolute path of ``structuralMR_brainMask.nrrd`` in the current
        directory.
    """
    import os
    import SimpleITK as sitk
    assert os.path.exists(inputLabelMap), "File not found: %s" % inputLabelMap
    mask = sitk.ReadImage(inputLabelMap) > 0
    mask_path = os.path.realpath('structuralMR_brainMask.nrrd')
    sitk.WriteImage(mask, mask_path)
    return mask_path

# remove the skull from the structural MR volume
def ExtractBRAINFromHead(RawScan, BrainLabels):
    """Skull-strip a structural MR volume using a dilated brain label map.

    Voxels whose label is > 0 are kept after the foreground is grown by a binary
    dilation of kernel radius 4. All other voxels become 0.

    The binary mask is cast to the head image's own pixel type before the
    voxel-wise product. The result therefore keeps the input's pixel type,
    instead of casting both images to int16, which truncates floating-point
    intensities and cannot represent values above 32767.
    NOTE(review): the voxel-wise product presumably requires ``BrainLabels`` to
    share the grid of ``RawScan`` — confirm for the inputs wired in.

    :param RawScan: path to the head (non-stripped) image; must exist.
    :param BrainLabels: path to the label map; must exist.
    :return: absolute path of ``structral_MR_stripped.nrrd`` in the current
        directory.
    """
    import os
    import SimpleITK as sitk
    assert os.path.exists(RawScan), "File not found: %s" % RawScan
    assert os.path.exists(BrainLabels), "File not found: %s" % BrainLabels
    headImage = sitk.ReadImage(RawScan)
    labelsMap = sitk.ReadImage(BrainLabels)
    label_mask = labelsMap > 0
    dilateFilter = sitk.BinaryDilateImageFilter()
    dilateFilter.SetKernelRadius(4)
    label_mask = dilateFilter.Execute(label_mask)
    # Match the mask's pixel type to the head image so the multiply is lossless.
    brainImage = headImage * sitk.Cast(label_mask, headImage.GetPixelID())
    outputVolume = os.path.realpath('structral_MR_stripped.nrrd')
    sitk.WriteImage(brainImage, outputVolume)
    return outputVolume
#################################
#\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\#

# Build the nipype workflow; its base_dir is BASE_DIR.
CorrectionWF = pe.Workflow(name=WFname)
CorrectionWF.base_dir = BASE_DIR

# Entry node: an IdentityInterface that hands the input paths defined above to
# the processing nodes.
inputsSpec = pe.Node(interface=IdentityInterface(fields=['inputT1',
                                                         'inputT2',
                                                         'inputDWI',
                                                         'inputBrainStemMask',
                                                         'atlasDefinition'
                                                        ]),
                         name='inputsSpec')
inputsSpec.inputs.inputT1 = input_t1 # T1 is used for creating brain mask
inputsSpec.inputs.inputT2 = input_t2 # T2 is used for artifact correction
inputsSpec.inputs.inputDWI = input_dwi
inputsSpec.inputs.inputBrainStemMask = brainstem_box_template_mask
inputsSpec.inputs.atlasDefinition = atlasDefinition

# Collection node for the final products; every field is forwarded to
# DWIDataSink at the end of this cell.
outputsSpec = pe.Node(interface=IdentityInterface(fields=['CorrectedDWI_in_StructralMRSpace',
                                                          'DWIBrainMask',
                                                          'DWIBrainstemMask',
                                                          'ukfTracks',
                                                          'tensor_image',
                                                          'DWI_B0',
                                                          'FAImage',
                                                          'MDImage',
                                                          'ADImage',
                                                          'RDImage',
                                                          'FrobeniusNormImage',
                                                          'Lambda2Image',
                                                          'Lambda3Image'
                                                         ]),
                      name='outputsSpec')

# First run BABC on downsampled structural MR (res ~1 mm^3)
DownsampleStructralMRNode = pe.Node(Function(function=DownsampleStructralMR,
                                                input_names=['inputMR'],
                                                output_names=['downsampledMR']),
                                       name="DownsampleStructralMRNode")
CorrectionWF.connect([(inputsSpec,DownsampleStructralMRNode,[('inputT1','inputMR')])])

# run BABC to create brainmask
# BRAINSABC segments the downsampled T1. Two of its outputs are used below:
#   - 'outputLabels' (brain_label_seg.nii.gz): skull stripping, resampling into
#     B0 space (Step 2) and the DWI brain mask (Step 9);
#   - 'implicitOutputs' (t1_average_BRAINSABC.nii.gz): skull stripping and the
#     fixed image of the Step 8 tuning registration.
BABCext = pe.Node(interface=BRAINSABC(), name="BABC")
CorrectionWF.connect(DownsampleStructralMRNode, 'downsampledMR', BABCext, 'inputVolumes')
BABCext.inputs.inputVolumeTypes = ['T1']
BABCext.inputs.outputVolumes = ['t1_corrected.nii.gz']
BABCext.inputs.debuglevel = 0
BABCext.inputs.useKNN = False
BABCext.inputs.maxIterations = 2
# NOTE(review): maxBiasDegree = 0 and filterMethod = 'None' presumably disable
# bias-field correction and denoising — confirm with the BRAINSABC docs.
BABCext.inputs.maxBiasDegree = 0
BABCext.inputs.filterIteration = 3
BABCext.inputs.filterMethod = 'None'
BABCext.inputs.atlasToSubjectTransformType = 'SyN'
BABCext.inputs.outputFormat = "NIFTI"
BABCext.inputs.outputLabels = "brain_label_seg.nii.gz"
BABCext.inputs.outputDirtyLabels = "volume_label_seg.nii.gz"
BABCext.inputs.posteriorTemplate = "POSTERIOR_%s.nii.gz"
BABCext.inputs.atlasToSubjectTransform = "atlas_to_subject.h5"
BABCext.inputs.interpolationMode = 'Linear'
BABCext.inputs.outputDir = './'
BABCext.inputs.implicitOutputs = ['t1_average_BRAINSABC.nii.gz']
CorrectionWF.connect(inputsSpec, 'atlasDefinition', BABCext, 'atlasDefinition')

#
# Disabled alternative: build a binary brain mask with CreateBrainMaskFromLabels
# (that helper is currently unused; the raw BABC label map is used instead).
# CreateBrainMaskFromLabelsNode = pe.Node(Function(function=CreateBrainMaskFromLabels,
#                                                 input_names=['inputLabelMap'],
#                                                 output_names=['output_brainMask']),
#                                        name="CreateBrainMaskFromLabels")
# CorrectionWF.connect([(BABCext,CreateBrainMaskFromLabelsNode,[('outputLabels','inputLabelMap')])])

#### NOW CORRECTION STEPS #####
# remove the skull from the structural MR volume
# NOTE(review): this node's 'outputVolume' is not connected to any other node in
# this cell (its BFit_T2toB0 / MakeResamplerInFilesListNode connections are
# commented out), so it runs but its result is currently unused.
ExtractBRAINFromHeadNode = pe.Node(interface=Function(function = ExtractBRAINFromHead,
                                                      input_names=['RawScan','BrainLabels'],
                                                      output_names=['outputVolume']),
                                   name="ExtractBRAINFromHead")
#CorrectionWF.connect(BABCext, 'outputT1AverageImage', ExtractBRAINFromHeadNode, 'RawScan')
CorrectionWF.connect(BABCext, 'implicitOutputs', ExtractBRAINFromHeadNode, 'RawScan')
CorrectionWF.connect(BABCext, 'outputLabels', ExtractBRAINFromHeadNode, 'BrainLabels')


# extract b0 from DWI
# NOTE(review): vectorIndex 0 assumes the first DWI component is the b0
# (unweighted) volume — confirm for each acquisition.
EXTRACT_B0 = pe.Node(interface=extractNrrdVectorIndex(),name="EXTRACT_B0")
EXTRACT_B0.inputs.vectorIndex = 0
EXTRACT_B0.inputs.outputVolume = 'B0_Image.nrrd'
CorrectionWF.connect(inputsSpec,'inputDWI',EXTRACT_B0,'inputVolume')

# Step1: Register T2 to B0 space using BRAINSFit
# Fixed image: the extracted B0; moving image: the input T2. The transform in
# 'strippedOutputTransform' (T2ToB0_RigidTransform.h5) is reused by
# ResampleToB0Space (Step 2) and, inverted, by Step 7.
# NOTE(review): useAffine is also enabled; the "stripped" output is presumably the
# rigid part only (GetRigidTransformInverse reads it as a VersorRigid3DTransform)
# — confirm.
BFit_T2toB0 = pe.Node(interface=BRAINSFit(), name="BFit_T2toB0")
BFit_T2toB0.inputs.costMetric = "MMI"
BFit_T2toB0.inputs.numberOfSamples = 100000
BFit_T2toB0.inputs.numberOfIterations = [1500]
BFit_T2toB0.inputs.numberOfHistogramBins = 50
BFit_T2toB0.inputs.maximumStepLength = 0.2
BFit_T2toB0.inputs.minimumStepLength = [0.00005]
BFit_T2toB0.inputs.useRigid = True
BFit_T2toB0.inputs.useAffine = True
BFit_T2toB0.inputs.maskInferiorCutOffFromCenter = 65
BFit_T2toB0.inputs.maskProcessingMode = "ROIAUTO"
BFit_T2toB0.inputs.ROIAutoDilateSize = 13
BFit_T2toB0.inputs.backgroundFillValue = 0.0
BFit_T2toB0.inputs.initializeTransformMode = 'useCenterOfHeadAlign'
BFit_T2toB0.inputs.strippedOutputTransform = "T2ToB0_RigidTransform.h5"
BFit_T2toB0.inputs.writeOutputTransformInFloat = True
CorrectionWF.connect(EXTRACT_B0, 'outputVolume', BFit_T2toB0, 'fixedVolume') # B0 image
#CorrectionWF.connect(ExtractBRAINFromHeadNode, 'outputVolume', BFit_T2toB0, 'movingVolume')
#CorrectionWF.connect(BABCext, 'implicitOutputs', BFit_T2toB0, 'movingVolume') # structral MR
CorrectionWF.connect(inputsSpec, 'inputT2', BFit_T2toB0, 'movingVolume') # structural MR

# Step2: Use T_rigid to "resample" T2 and label map images to B0 image space
MakeResamplerInFilesListNode = pe.Node(Function(function=MakeResamplerInFileList,
                                                input_names=['inputT2','inputLabelMap'],
                                                output_names=['imagesList']),
                                       name="MakeResamplerInFilesListNode")
#CorrectionWF.connect(ExtractBRAINFromHeadNode, 'outputVolume', MakeResamplerInFilesListNode, 'inputT2')
CorrectionWF.connect(inputsSpec, 'inputT2', MakeResamplerInFilesListNode, 'inputT2')
#CorrectionWF.connect(CreateBrainMaskFromLabelsNode, 'output_brainMask', MakeResamplerInFilesListNode, 'inputLabelMap')
CorrectionWF.connect(BABCext, 'outputLabels', MakeResamplerInFilesListNode, 'inputLabelMap')

# MapNode over [T2, label map]: element i of inputVolume is paired with element
# i of pixelType and outputVolume. This yields, in order, T2toB0.nrrd (ushort)
# and T2MaskToB0.nrrd (binary), both on the B0 grid.
ResampleToB0Space = pe.MapNode(interface=BRAINSResample(), name="ResampleToB0Space",
                               iterfield=['inputVolume', 'pixelType', 'outputVolume'])
ResampleToB0Space.inputs.interpolationMode = 'Linear'
ResampleToB0Space.inputs.outputVolume = ['T2toB0.nrrd','T2MaskToB0.nrrd']
ResampleToB0Space.inputs.pixelType = ['ushort','binary']
CorrectionWF.connect(BFit_T2toB0,'strippedOutputTransform',ResampleToB0Space,'warpTransform')
CorrectionWF.connect(EXTRACT_B0,'outputVolume',ResampleToB0Space,'referenceVolume')
CorrectionWF.connect(MakeResamplerInFilesListNode,'imagesList',ResampleToB0Space,'inputVolume')

# Step3: Create registration mask from resampled label map image
# (index 1 of ResampleToB0Space's outputs is T2MaskToB0.nrrd)
CreateRegistrationMask = pe.Node(interface=Function(function = CreateAntsRegistrationMask,
                                                    input_names=['brainMask'],
                                                    output_names=['registrationMask']),
                                 name="CreateAntsRegistrationMask")
CorrectionWF.connect(ResampleToB0Space, ('outputVolume', pickFromList, 1),
                    CreateRegistrationMask, 'brainMask')

# Step4: Save direction cosine for the resampled T2 image
# (index 0 is T2toB0.nrrd). It is re-applied to the corrected DWI in Step 7.
SaveDirectionCosineToMatrixNode = pe.Node(interface=Function(function = SaveDirectionCosineToMatrix,
                                                             input_names=['inputVolume'],
                                                             output_names=['directionCosine']),
                                          name="SaveDirectionCosineToMatrix")
CorrectionWF.connect(ResampleToB0Space, ('outputVolume', pickFromList, 0),
                     SaveDirectionCosineToMatrixNode, 'inputVolume')

# Step5: Force DC to ID
# ForceDCtoID runs once per list element, so its outputs are ordered
# [B0, T2 in B0 space, registration mask]; Step 6 picks them by index.
MakeForceDCFilesListNode = pe.Node(Function(function=MakeForceDCFilesList,
                                            input_names=['inputB0','inputT2','inputLabelMap'],
                                            output_names=['imagesList']),
                                   name="MakeForceDCFilesListNode")
CorrectionWF.connect([(EXTRACT_B0,MakeForceDCFilesListNode,[('outputVolume','inputB0')]),
                      (ResampleToB0Space,MakeForceDCFilesListNode,[(('outputVolume', pickFromList, 0),'inputT2')]),
                      (CreateRegistrationMask,MakeForceDCFilesListNode,[('registrationMask','inputLabelMap')])])

ForceDCtoIDNode = pe.MapNode(interface=Function(function = ForceDCtoID,
                                                input_names=['inputVolume'],
                                                output_names=['outputVolume']),
                             name="ForceDCtoID",
                             iterfield=['inputVolume'])
CorrectionWF.connect(MakeForceDCFilesListNode, 'imagesList', ForceDCtoIDNode, 'inputVolume')

# Step6: Run antsRegistration in one direction
# SyN registration of the identity-DC B0 (moving, ForceDCtoID index 0) to the
# identity-DC T2-in-B0-space (fixed, index 1), restricted to the dilated
# registration mask (index 2).
antsReg_B0ToTransformedT2 = pe.Node(interface=ants.Registration(), name="antsReg_B0ToTransformedT2")
antsReg_B0ToTransformedT2.inputs.interpolation = 'Linear'
antsReg_B0ToTransformedT2.inputs.dimension = 3
antsReg_B0ToTransformedT2.inputs.transforms = ["SyN"]
antsReg_B0ToTransformedT2.inputs.transform_parameters = [(0.25, 3.0, 0.0)]
antsReg_B0ToTransformedT2.inputs.metric = ['MI']
antsReg_B0ToTransformedT2.inputs.sampling_strategy = [None]
antsReg_B0ToTransformedT2.inputs.sampling_percentage = [1.0]
antsReg_B0ToTransformedT2.inputs.metric_weight = [1.0]
antsReg_B0ToTransformedT2.inputs.radius_or_number_of_bins = [32]
antsReg_B0ToTransformedT2.inputs.number_of_iterations = [[70, 50, 40]]
antsReg_B0ToTransformedT2.inputs.convergence_threshold = [1e-6]
antsReg_B0ToTransformedT2.inputs.convergence_window_size = [10]
antsReg_B0ToTransformedT2.inputs.use_histogram_matching = [True]
antsReg_B0ToTransformedT2.inputs.shrink_factors = [[3, 2, 1]]
antsReg_B0ToTransformedT2.inputs.smoothing_sigmas = [[2, 1, 0]]
antsReg_B0ToTransformedT2.inputs.sigma_units = ["vox"]
antsReg_B0ToTransformedT2.inputs.use_estimate_learning_rate_once = [False]
antsReg_B0ToTransformedT2.inputs.write_composite_transform = True
antsReg_B0ToTransformedT2.inputs.collapse_output_transforms = False
antsReg_B0ToTransformedT2.inputs.initialize_transforms_per_stage = False
antsReg_B0ToTransformedT2.inputs.output_transform_prefix = 'Tsyn'
antsReg_B0ToTransformedT2.inputs.winsorize_lower_quantile = 0.01
antsReg_B0ToTransformedT2.inputs.winsorize_upper_quantile = 0.99
antsReg_B0ToTransformedT2.inputs.float = True
# NOTE(review): -1 presumably lets ANTs use all available threads — confirm for
# the installed nipype version.
antsReg_B0ToTransformedT2.inputs.num_threads = -1
# NOTE(review): per the antsRegistration docs, '--restrict-deformation 0x1x0'
# confines the SyN update to the second image axis — presumably the
# phase-encoding direction, where EPI distortion occurs. This is presumably also
# why all inputs were forced to identity direction cosines first (so image axis 1
# coincides with physical axis 1). Confirm for each acquisition.
antsReg_B0ToTransformedT2.inputs.args = '--restrict-deformation 0x1x0'
CorrectionWF.connect(ForceDCtoIDNode, ('outputVolume', pickFromList, 1), antsReg_B0ToTransformedT2, 'fixed_image') # structural MR
CorrectionWF.connect(ForceDCtoIDNode, ('outputVolume', pickFromList, 2), antsReg_B0ToTransformedT2, 'fixed_image_mask')
CorrectionWF.connect(ForceDCtoIDNode, ('outputVolume', pickFromList, 0), antsReg_B0ToTransformedT2, 'moving_image') # DWI B0

# Step7: Now, all necessary transforms are acquired. It's time to
#        transform the input DWI image into T2 image space
# {DWI} --> ForceDCtoID --> gtractResampleDWIInPlace(using SyN transform)
# --> Restore DirectionCosine From Saved Matrix --> gtractResampleDWIInPlace(inverse of T_rigid from BFit)
# --> {CorrectedDWI_in_StructralMRSpace}
DWI_ForceDCtoIDNode = pe.Node(interface=Function(function = ForceDCtoID,
                                                 input_names=['inputVolume'],
                                                 output_names=['outputVolume']),
                              name='DWI_ForceDCtoIDNode')
CorrectionWF.connect(inputsSpec,'inputDWI',DWI_ForceDCtoIDNode,'inputVolume')

## Apply the composite SyN transform from Step 6 to the identity-DC DWI.
gtractResampleDWI_SyN = pe.Node(interface=gtractResampleDWIInPlace(),
                                name="gtractResampleDWI_SyN")
CorrectionWF.connect(DWI_ForceDCtoIDNode,'outputVolume',
                     gtractResampleDWI_SyN,'inputVolume')
CorrectionWF.connect(antsReg_B0ToTransformedT2,'composite_transform',
                     gtractResampleDWI_SyN,'warpDWITransform')
CorrectionWF.connect(ForceDCtoIDNode,('outputVolume', pickFromList, 1),
                     gtractResampleDWI_SyN,'referenceVolume') # fixed image of antsRegistration
gtractResampleDWI_SyN.inputs.outputVolume = 'IDDC_correctedDWI.nrrd'

## Put back the direction cosine saved in Step 4. T2toB0.nrrd was resampled onto
## the B0 grid, so presumably its direction equals the DWI's original one — confirm.
RestoreDCFromSavedMatrixNode = pe.Node(interface=Function(function = RestoreDCFromSavedMatrix,
                                                          input_names=['inputVolume','inputDirectionCosine'],
                                                          output_names=['outputVolume']),
                                       name='RestoreDCFromSavedMatrix')
CorrectionWF.connect(gtractResampleDWI_SyN,'outputVolume',RestoreDCFromSavedMatrixNode,'inputVolume')
CorrectionWF.connect(SaveDirectionCosineToMatrixNode,'directionCosine',RestoreDCFromSavedMatrixNode,'inputDirectionCosine')
#CorrectionWF.connect(RestoreDCFromSavedMatrixNode,'outputVolume', outputsSpec, 'CorrectedDWI')

## Invert the Step 1 T2->B0 transform.
GetRigidTransformInverseNode = pe.Node(interface=Function(function = GetRigidTransformInverse,
                                                          input_names=['inputTransform'],
                                                          output_names=['inverseTransform']),
                                       name='GetRigidTransformInverse')
CorrectionWF.connect(BFit_T2toB0,'strippedOutputTransform',GetRigidTransformInverseNode,'inputTransform')

## Apply the inverse rigid transform to the DC-restored DWI; also writes its B0.
gtractResampleDWIInPlace_Trigid = pe.Node(interface=gtractResampleDWIInPlace(),
                                          name="gtractResampleDWIInPlace_Trigid")
CorrectionWF.connect(RestoreDCFromSavedMatrixNode,'outputVolume',
                     gtractResampleDWIInPlace_Trigid,'inputVolume')
CorrectionWF.connect(GetRigidTransformInverseNode,'inverseTransform',
                     gtractResampleDWIInPlace_Trigid,'inputTransform') #Inverse of rigid transform from BFit
gtractResampleDWIInPlace_Trigid.inputs.outputVolume = 'CorrectedDWI_in_StructralMRSpace_estimate.nrrd'
gtractResampleDWIInPlace_Trigid.inputs.outputResampledB0 = 'CorrectedDWI_in_StructralMRSpace_estimate_B0.nrrd'

# Step8: An extra registration step to tune the alignment between the CorrectedDWI_in_StructralMRSpace image and the structural MR.
# BRAINSFit settings are identical to BFit_T2toB0 (Step 1).
# NOTE(review): the fixed volume connected below is BABCext 'implicitOutputs'
# (t1_average_BRAINSABC.nii.gz, computed from the downsampled T1), not the T2 the
# earlier comment named. Confirm which reference image is intended.
BFit_TuneRegistration = pe.Node(interface=BRAINSFit(), name="BFit_TuneRegistration")
BFit_TuneRegistration.inputs.costMetric = "MMI"
BFit_TuneRegistration.inputs.numberOfSamples = 100000
BFit_TuneRegistration.inputs.numberOfIterations = [1500]
BFit_TuneRegistration.inputs.numberOfHistogramBins = 50
BFit_TuneRegistration.inputs.maximumStepLength = 0.2
BFit_TuneRegistration.inputs.minimumStepLength = [0.00005]
BFit_TuneRegistration.inputs.useRigid = True
BFit_TuneRegistration.inputs.useAffine = True
BFit_TuneRegistration.inputs.maskInferiorCutOffFromCenter = 65
BFit_TuneRegistration.inputs.maskProcessingMode = "ROIAUTO"
BFit_TuneRegistration.inputs.ROIAutoDilateSize = 13
BFit_TuneRegistration.inputs.backgroundFillValue = 0.0
BFit_TuneRegistration.inputs.initializeTransformMode = 'useCenterOfHeadAlign'
BFit_TuneRegistration.inputs.strippedOutputTransform = "CorrectedB0inStructralMRSpace_to_T2_RigidTransform.h5"
BFit_TuneRegistration.inputs.writeOutputTransformInFloat = True
CorrectionWF.connect(BABCext, 'implicitOutputs', BFit_TuneRegistration, 'fixedVolume') # BABC T1 average (see NOTE above)
CorrectionWF.connect(gtractResampleDWIInPlace_Trigid, 'outputResampledB0', BFit_TuneRegistration, 'movingVolume') # CorrectedB0_in_StructralMRSpace

## Apply the tuning transform to the whole DWI; this is the final corrected DWI.
gtractResampleDWIInPlace_TuneRigidTx = pe.Node(interface=gtractResampleDWIInPlace(),
                                               name="gtractResampleDWIInPlace_TuneRigidTx")
CorrectionWF.connect(gtractResampleDWIInPlace_Trigid,'outputVolume',gtractResampleDWIInPlace_TuneRigidTx,'inputVolume')
CorrectionWF.connect(BFit_TuneRegistration,'strippedOutputTransform',gtractResampleDWIInPlace_TuneRigidTx,'inputTransform')
gtractResampleDWIInPlace_TuneRigidTx.inputs.outputVolume = 'CorrectedDWI_in_StructralMRSpace.nrrd'
gtractResampleDWIInPlace_TuneRigidTx.inputs.outputResampledB0 = 'CorrectedDWI_in_StructralMRSpace_B0.nrrd'

# Finally we pass the outputs of the gtractResampleDWIInPlace_TuneRigidTx to the outputsSpec
CorrectionWF.connect(gtractResampleDWIInPlace_TuneRigidTx, 'outputVolume', outputsSpec, 'CorrectedDWI_in_StructralMRSpace')

# Step9: Create the DWI brain mask by resampling the BABC label map onto the grid
# of the final corrected B0 (pixelType 'binary').
DWIBRAINMASK = pe.Node(interface=BRAINSResample(), name='DWIBRAINMASK')
DWIBRAINMASK.inputs.interpolationMode = 'Linear'
DWIBRAINMASK.inputs.outputVolume = 'BrainMaskForDWI.nrrd'
DWIBRAINMASK.inputs.pixelType = 'binary'
CorrectionWF.connect(gtractResampleDWIInPlace_TuneRigidTx,'outputResampledB0',DWIBRAINMASK,'referenceVolume')
CorrectionWF.connect(BABCext,'outputLabels',DWIBRAINMASK,'inputVolume')
CorrectionWF.connect(DWIBRAINMASK, 'outputVolume', outputsSpec, 'DWIBrainMask')


##### RUN UKF Processing #######
# make tractography brainstem mask: the brainstem-box template resampled onto
# the grid of the final corrected B0.
DWIBrainstemMASK = pe.Node(interface=BRAINSResample(), name='DWIBrainstemMASK')
DWIBrainstemMASK.inputs.interpolationMode = 'Linear'
DWIBrainstemMASK.inputs.outputVolume = 'BrainstemMaskForDWI.nrrd'
DWIBrainstemMASK.inputs.pixelType = 'binary'
CorrectionWF.connect(gtractResampleDWIInPlace_TuneRigidTx,'outputResampledB0',DWIBrainstemMASK,'referenceVolume')
CorrectionWF.connect(inputsSpec,'inputBrainStemMask',DWIBrainstemMASK,'inputVolume')
CorrectionWF.connect(DWIBrainstemMASK, 'outputVolume', outputsSpec, 'DWIBrainstemMask')

# UKF tractography on the corrected DWI. maskFile is the brainstem mask; the
# commented-out line below would use the whole-brain mask instead.
UKFNode = pe.Node(interface=UKFTractography(), name= "UKFRunRecordStates")
UKFNode.inputs.tracts = "ukfTracts.vtp"
UKFNode.inputs.numTensor = '2'
UKFNode.inputs.freeWater = True ## default False
UKFNode.inputs.minFA = 0.06
UKFNode.inputs.minGA = 0.06
UKFNode.inputs.seedFALimit = 0.06
UKFNode.inputs.Ql = 70
UKFNode.inputs.recordLength = 1
UKFNode.inputs.recordTensors = True
UKFNode.inputs.recordFreeWater = True
UKFNode.inputs.recordFA = True
UKFNode.inputs.recordTrace = True
UKFNode.inputs.seedsPerVoxel = 2
CorrectionWF.connect(gtractResampleDWIInPlace_TuneRigidTx, 'outputVolume', UKFNode, 'dwiFile')
#CorrectionWF.connect(DWIBRAINMASK, 'outputVolume', UKFNode, 'maskFile') # whole brain mask
CorrectionWF.connect(DWIBrainstemMASK, 'outputVolume', UKFNode, 'maskFile') # brainstem mask
CorrectionWF.connect(UKFNode,'tracts',outputsSpec,'ukfTracks')

### Calculate DTI and RISs #####
# 1: DTI estimation (weighted least squares) on the corrected DWI inside the
#    DWI brain mask; also writes the B0 image.
DTIEstim = pe.Node(interface=dtiestim(), name="DTIEstim")
DTIEstim.inputs.method = 'wls'
DTIEstim.inputs.threshold = 0
DTIEstim.inputs.correction = 'nearest'
DTIEstim.inputs.B0 = 'DWI_B0.nrrd'
DTIEstim.inputs.tensor_output = 'DTI_Output.nrrd'
CorrectionWF.connect(gtractResampleDWIInPlace_TuneRigidTx, 'outputVolume', DTIEstim, 'dwi_image')
CorrectionWF.connect(DWIBRAINMASK, 'outputVolume', DTIEstim, 'brain_mask')
CorrectionWF.connect(DTIEstim, 'tensor_output', outputsSpec, 'tensor_image')
CorrectionWF.connect(DTIEstim, 'B0', outputsSpec, 'DWI_B0')

# 2: Calculate RISs (scalar maps derived from the tensor image)
DTIProcess = pe.Node(interface=dtiprocess(), name='DTIProcess')
DTIProcess.inputs.fa_output = 'FA.nrrd'
DTIProcess.inputs.md_output = 'MD.nrrd'
DTIProcess.inputs.RD_output = 'RD.nrrd'
DTIProcess.inputs.frobenius_norm_output = 'frobenius_norm_output.nrrd'
DTIProcess.inputs.lambda1_output = 'lambda1_output.nrrd'
DTIProcess.inputs.lambda2_output = 'lambda2_output.nrrd'
DTIProcess.inputs.lambda3_output = 'lambda3_output.nrrd'
DTIProcess.inputs.scalar_float = True
DTIProcess.inputs.correction = 'nearest'
CorrectionWF.connect(DTIEstim, 'tensor_output', DTIProcess, 'dti_image')
CorrectionWF.connect(DTIProcess, 'fa_output', outputsSpec, 'FAImage')
CorrectionWF.connect(DTIProcess, 'md_output', outputsSpec, 'MDImage')
CorrectionWF.connect(DTIProcess, 'RD_output', outputsSpec, 'RDImage')
CorrectionWF.connect(DTIProcess, 'frobenius_norm_output', outputsSpec, 'FrobeniusNormImage')
# Axial diffusivity is taken to be the lambda1 map.
CorrectionWF.connect(DTIProcess, 'lambda1_output', outputsSpec, 'ADImage')
CorrectionWF.connect(DTIProcess, 'lambda2_output', outputsSpec, 'Lambda2Image')
CorrectionWF.connect(DTIProcess, 'lambda3_output', outputsSpec, 'Lambda3Image')


#########
# Copy every outputsSpec field into BASE_DIR under the 'Outputs' container, then
# draw the workflow graph and run the workflow.
DWIDataSink = pe.Node(interface=nio.DataSink(), name='DWIDataSink')
DWIDataSink.overwrite = True
DWIDataSink.inputs.base_directory = BASE_DIR
# Connection order matches the outputsSpec fields. Per the nipype DataSink docs,
# the '@' in 'Outputs.@<field>' places the file directly in 'Outputs' rather than
# in a per-field subfolder.
for _sink_field in ('CorrectedDWI_in_StructralMRSpace',
                    'DWIBrainMask',
                    'DWIBrainstemMask',
                    'ukfTracks',
                    'tensor_image',
                    'DWI_B0',
                    'FAImage',
                    'MDImage',
                    'RDImage',
                    'ADImage',
                    'FrobeniusNormImage',
                    'Lambda2Image',
                    'Lambda3Image'):
    CorrectionWF.connect(outputsSpec, _sink_field, DWIDataSink, 'Outputs.@' + _sink_field)
del _sink_field

CorrectionWF.write_graph()
CorrectionWF.run()

<networkx.classes.digraph.DiGraph at 0x124d54e90>