In [None]:
%%bash
# Abort the cell as soon as any command fails; otherwise a failed `cd` would
# still run setup.py from whatever directory happens to be current.
set -e
# Go to the folder containing the pymialsrtk source code.
cd /app/mialsuperresolutiontoolkit/
# Install the pymialsrtk package into the active python/conda environment;
# --force reinstalls over any previously installed copy.
# NOTE(review): `setup.py install` is deprecated by setuptools; `pip install .`
# is the supported replacement.
python setup.py install --force

In [None]:
import os
import json
import shutil

# Imports from nipype
from nipype.interfaces.io import BIDSDataGrabber,DataGrabber, DataSink
from nipype.pipeline import Node, Workflow

# Import the implemented interface from pymialsrtk
import pymialsrtk.interfaces.preprocess as preprocess
import pymialsrtk.interfaces.reconstruction as reconstruction

import pymialsrtk.interfaces.postprocess as postprocess

# Copy result files
from shutil import copyfile
import glob

from nipype import config, logging

In [None]:
# Run configuration: dataset root plus the subject/session to reconstruct.
bids_dir = '/fetaldata'

subject = 'sub-01'
session = None  # optional session label (used as a folder/filename component)
# Run indices of the T2w stacks to use, passed as `stacksOrder` to the pipeline nodes.
stacksOrder = [1, 3, 5, 2, 4, 6]
    

In [None]:

import os

def run(command, env=None, cwd=None):
    """Run a shell command and capture its output.

    Args:
        command: Shell command line. It is executed with ``shell=True``, so never
            interpolate untrusted values into it.
        env: Optional mapping of extra environment variables. They are layered on
            top of a *copy* of the current process environment; ``os.environ``
            itself is left untouched.
        cwd: Working directory for the command. Defaults to the current working
            directory at call time (not at definition time).

    Returns:
        ``subprocess.CompletedProcess`` with ``stdout``/``stderr`` captured as bytes.
        A non-zero exit status does not raise; inspect ``returncode``.
    """
    import subprocess

    # Copy instead of aliasing os.environ, so that extra variables do not leak
    # into the kernel's own environment for every later call.
    merged_env = dict(os.environ)
    if env:
        merged_env.update(env)
    # Resolve the default here; a `cwd=os.getcwd()` default would be frozen at
    # the moment the function was defined.
    if cwd is None:
        cwd = os.getcwd()
    return subprocess.run(command, shell=True,
                          env=merged_env, cwd=cwd, capture_output=True)


In [None]:
## Workflow construction: create the processing nodes and link them
def _make_node(interface, name, bids_dir, **inputs):
    """Create a nipype ``Node`` wrapping ``interface`` and set its inputs.

    ``bids_dir`` is assigned to every node (all pymialsrtk steps of this workflow
    take it). Each extra keyword argument is assigned verbatim as
    ``node.inputs.<key> = value``.
    """
    node = Node(interface=interface, name=name)
    node.inputs.bids_dir = bids_dir
    for input_name, value in inputs.items():
        setattr(node.inputs, input_name, value)
    return node


def _build_sink_substitutions(sub_ses, stacks_order):
    """Return the DataSink ``(old_name, new_name)`` substitutions and print each mapping.

    The ``<n>V`` token in the intermediate file names is filled with the number of
    stacks in ``stacks_order``.
    """
    n_stacks = len(stacks_order)
    substitutions = []
    for stack in stacks_order:
        run_prefix = '{}_run-{}'.format(sub_ses, stack)
        substitutions.extend([
            (run_prefix + '_T2w_nlm_uni_bcorr_histnorm.nii.gz',
             run_prefix + '_T2w_preproc.nii.gz'),
            (run_prefix + '_T2w_nlm_uni_bcorr_histnorm_transform_{}V.txt'.format(n_stacks),
             run_prefix + '_T2w_from-origin_to-SDI_mode-image_xfm.txt'),
            (run_prefix + '_T2w_uni_bcorr_histnorm_LRmask.nii.gz',
             run_prefix + '_T2w_desc-LRmask.nii.gz'),
        ])
    substitutions.extend([
        ('SDI_{}_{}V_rad1.nii.gz'.format(sub_ses, n_stacks),
         sub_ses + '_rec-SDI_T2w.nii.gz'),
        ('SRTV_{}_{}V_rad1_gbcorr.nii.gz'.format(sub_ses, n_stacks),
         sub_ses + '_rec-SR_T2w.nii.gz'),
        # NOTE(review): an earlier log message announced "_desc-brain_mask" here,
        # while the applied rename has always been "_desc-SRmask". The log now
        # reports the rename that is actually applied.
        (sub_ses + '_T2w_uni_bcorr_histnorm_srMask.nii.gz',
         sub_ses + '_rec-SR_T2w_desc-SRmask.nii.gz'),
    ])
    for old_name, new_name in substitutions:
        print(old_name, '    --->     ', new_name)
    return substitutions


def _write_sr_sidecar(output_dir, sub_ses, stacks_order, deltatTV, lambdaTV, primal_dual_loops):
    """Write the JSON sidecar describing the TV super-resolution run.

    The file is ``<output_dir>/anat/<sub_ses>_rec-SR.json``. The ``anat`` directory
    is created if needed, because this runs before the DataSink has written anything.
    Returns the path of the written file.
    """
    anat_dir = os.path.join(output_dir, 'anat')
    os.makedirs(anat_dir, exist_ok=True)
    sidecar = {
        "Description": "Isotropic high-resolution image reconstructed using the Total-Variation Super-Resolution algorithm provided by MIALSRTK",
        "Input sources run order": stacks_order,
        "CustomMetaData": {
            "Number of scans used": str(len(stacks_order)),
            "TV regularization weight lambda": lambdaTV,
            "Optimization time step": deltatTV,
            "Primal/dual loops": primal_dual_loops,
        },
    }
    output_json = os.path.join(anat_dir, ''.join([sub_ses, '_rec-SR.json']))
    with open(output_json, 'w', encoding='utf8') as outfile:
        json.dump(sidecar, outfile, indent=4)
    return output_json


def create_workflow(bids_dir, process_dir, subject, p_stacksOrder, session=None, deltatTV=0.01, lambdaTV=0.75, primal_dual_loops=10):
    """Build (but do not run) the MIALSRTK super-resolution nipype workflow.

    Two parallel branches process the low-resolution T2w stacks. One starts from
    NLM-denoised stacks (the ``*_nlm`` nodes); the other starts from the raw stacks.
    The slice-by-slice N4 bias fields are estimated on the denoised branch and
    applied to the raw branch. The denoised branch feeds the image reconstruction
    step (transforms + SDI). Its outputs then drive the TV super-resolution of the
    raw branch, followed by SR-mask refinement and N4 correction.

    Args:
        bids_dir: Root of the BIDS dataset. Raw stacks are read from
            ``<subject>[/<session>]/anat``, manual masks from
            ``derivatives/manual_masks/<subject>[/<session>]/anat``. Outputs are
            written to ``derivatives/mialsrtk-py``.
        process_dir: Root of the nipype working tree. ``<subject>[/<session>]`` is
            appended, and the resulting directory is created if missing.
        subject: Subject label (e.g. ``'sub-01'``), used as folder name and
            filename prefix.
        p_stacksOrder: Run indices of the stacks to use. They are forwarded as
            ``stacksOrder`` to the nodes and used to build the output renaming rules.
        session: Optional session label (folder name and filename component).
        deltatTV: TV optimization time step (``in_deltat``).
        lambdaTV: TV regularization weight (``in_lambda``).
        primal_dual_loops: Number of primal/dual loops (``in_loop``).

    Returns:
        The configured ``Workflow``; call ``.run()`` to execute it.

    Side effects:
        - Creates the working directory.
        - May delete a stale log file in it.
        - Updates nipype's shared config and logging.
        - Writes the SR JSON sidecar into ``<bids_dir>/derivatives/mialsrtk-py/anat``.
    """
    # --- Working directory / logging -----------------------------------------
    # Per-subject (and per-session) directory, used both as the nipype base dir
    # and as the log directory.
    if session is None:
        process_dir = os.path.join(process_dir, subject)
    else:
        process_dir = os.path.join(process_dir, subject, session)
    wf_base_dir = process_dir
    os.makedirs(process_dir, exist_ok=True)
    print("Process directory: {}".format(wf_base_dir))

    wf = Workflow(name="srr_nipype", base_dir=wf_base_dir)

    # Remove a log left over from a previous run.
    # NOTE(review): nipype's file logger is expected to write "pypeline.log" in
    # log_directory, not "pypeline_<subject>.log"; confirm which file is meant.
    previous_log = os.path.join(process_dir, "pypeline_" + subject + ".log")
    if os.path.isfile(previous_log):
        os.unlink(previous_log)

    # nipype's shared config: log to a file in process_dir, keep all
    # intermediate outputs, stop on the first crash, write text crash files,
    # and enable monitoring.
    config.update_config({'logging': {'log_directory': process_dir, 'log_to_file': True},
                          'execution': {
                              'remove_unnecessary_outputs': False,
                              'stop_on_first_crash': True,
                              'stop_on_first_rerun': False,
                              'crashfile_format': "txt",
                              'write_provenance': False},
                          'monitoring': {'enabled': True}
                          })
    logging.update_logging(config)
    iflogger = logging.getLogger('nipype.interface')
    iflogger.info("**** Processing ****")

    # Filename prefix shared by all outputs: "<subject>" or "<subject>_<session>".
    sub_ses = subject if session is None else '_'.join([subject, session])

    # --- Input grabbing -------------------------------------------------------
    dg = Node(interface=DataGrabber(outfields=['T2ws', 'masks']), name='data_grabber')
    dg.inputs.base_directory = bids_dir
    dg.inputs.template = '*'
    dg.inputs.raise_on_empty = False
    dg.inputs.sort_filelist = True
    if session is None:
        dg.inputs.field_template = dict(
            T2ws=os.path.join(subject, 'anat', subject + '*_run-*_T2w.nii.gz'),
            masks=os.path.join('derivatives', 'manual_masks', subject, 'anat',
                               subject + '*_run-*_*mask.nii.gz'))
    else:
        dg.inputs.field_template = dict(
            T2ws=os.path.join(subject, session, 'anat',
                              '_'.join([subject, session, '*run-*', '*T2w.nii.gz'])),
            masks=os.path.join('derivatives', 'manual_masks', subject, session, 'anat',
                               '_'.join([subject, session, '*run-*', '*mask.nii.gz'])))

    # --- Pre-processing nodes -------------------------------------------------
    nlmDenoise = _make_node(preprocess.MultipleBtkNLMDenoising(), 'nlmDenoise', bids_dir,
                            stacksOrder=p_stacksOrder)

    # First slice-intensity correction, on the denoised and on the raw stacks.
    # (An earlier French note here was left unfinished; both nodes do receive
    # the manual masks below.)
    srtkCorrectSliceIntensity01_nlm = _make_node(preprocess.MultipleMialsrtkCorrectSliceIntensity(),
                                                 'srtkCorrectSliceIntensity01_nlm', bids_dir,
                                                 stacksOrder=p_stacksOrder, out_postfix='_uni')
    srtkCorrectSliceIntensity01 = _make_node(preprocess.MultipleMialsrtkCorrectSliceIntensity(),
                                             'srtkCorrectSliceIntensity01', bids_dir,
                                             stacksOrder=p_stacksOrder, out_postfix='_uni')

    # Bias fields are estimated on the denoised stacks and applied to the raw stacks.
    srtkSliceBySliceN4BiasFieldCorrection = _make_node(preprocess.MultipleMialsrtkSliceBySliceN4BiasFieldCorrection(),
                                                       'srtkSliceBySliceN4BiasFieldCorrection', bids_dir,
                                                       stacksOrder=p_stacksOrder)
    srtkSliceBySliceCorrectBiasField = _make_node(preprocess.MultipleMialsrtkSliceBySliceCorrectBiasField(),
                                                  'srtkSliceBySliceCorrectBiasField', bids_dir,
                                                  stacksOrder=p_stacksOrder)

    srtkCorrectSliceIntensity02_nlm = _make_node(preprocess.MultipleMialsrtkCorrectSliceIntensity(),
                                                 'srtkCorrectSliceIntensity02_nlm', bids_dir,
                                                 stacksOrder=p_stacksOrder)
    srtkCorrectSliceIntensity02 = _make_node(preprocess.MultipleMialsrtkCorrectSliceIntensity(),
                                             'srtkCorrectSliceIntensity02', bids_dir,
                                             stacksOrder=p_stacksOrder)

    srtkIntensityStandardization01 = _make_node(preprocess.MialsrtkIntensityStandardization(),
                                                'srtkIntensityStandardization01', bids_dir)
    srtkIntensityStandardization01_nlm = _make_node(preprocess.MialsrtkIntensityStandardization(),
                                                    'srtkIntensityStandardization01_nlm', bids_dir)

    srtkHistogramNormalization = _make_node(preprocess.MialsrtkHistogramNormalization(),
                                            'srtkHistogramNormalization', bids_dir,
                                            stacksOrder=p_stacksOrder)
    srtkHistogramNormalization_nlm = _make_node(preprocess.MialsrtkHistogramNormalization(),
                                                'srtkHistogramNormalization_nlm', bids_dir,
                                                stacksOrder=p_stacksOrder)

    srtkIntensityStandardization02 = _make_node(preprocess.MialsrtkIntensityStandardization(),
                                                'srtkIntensityStandardization02', bids_dir)
    srtkIntensityStandardization02_nlm = _make_node(preprocess.MialsrtkIntensityStandardization(),
                                                    'srtkIntensityStandardization02_nlm', bids_dir)

    srtkMaskImage01 = _make_node(preprocess.MultipleMialsrtkMaskImage(), 'srtkMaskImage01', bids_dir,
                                 stacksOrder=p_stacksOrder)

    # --- Reconstruction nodes -------------------------------------------------
    srtkImageReconstruction = _make_node(reconstruction.MialsrtkImageReconstruction(),
                                         'srtkImageReconstruction', bids_dir,
                                         stacksOrder=p_stacksOrder, sub_ses=sub_ses)
    srtkTVSuperResolution = _make_node(reconstruction.MialsrtkTVSuperResolution(),
                                       'srtkTVSuperResolution', bids_dir,
                                       stacksOrder=p_stacksOrder, sub_ses=sub_ses,
                                       in_loop=primal_dual_loops, in_deltat=deltatTV,
                                       in_lambda=lambdaTV)

    # --- Post-processing nodes ------------------------------------------------
    srtkRefineHRMaskByIntersection = _make_node(postprocess.MialsrtkRefineHRMaskByIntersection(),
                                                'srtkRefineHRMaskByIntersection', bids_dir,
                                                stacksOrder=p_stacksOrder)
    srtkN4BiasFieldCorrection = _make_node(postprocess.MialsrtkN4BiasFieldCorrection(),
                                           'srtkN4BiasFieldCorrection', bids_dir)
    # NOTE(review): srtkMaskImage02's output is not sent to the DataSink.
    srtkMaskImage02 = _make_node(preprocess.MialsrtkMaskImage(), 'srtkMaskImage02', bids_dir)

    # --- Linking --------------------------------------------------------------
    # Every per-stack step except the denoiser consumes the manual masks (the
    # denoiser is deliberately left unmasked, "to match docker process").
    for masked_node in (srtkCorrectSliceIntensity01_nlm, srtkCorrectSliceIntensity01,
                        srtkSliceBySliceN4BiasFieldCorrection, srtkSliceBySliceCorrectBiasField,
                        srtkCorrectSliceIntensity02, srtkCorrectSliceIntensity02_nlm,
                        srtkHistogramNormalization, srtkHistogramNormalization_nlm,
                        srtkMaskImage01, srtkImageReconstruction, srtkTVSuperResolution,
                        srtkRefineHRMaskByIntersection):
        wf.connect(dg, "masks", masked_node, "input_masks")

    # Denoised ("_nlm") branch, ending in the image reconstruction.
    wf.connect(dg, "T2ws", nlmDenoise, "input_images")
    wf.connect(nlmDenoise, "output_images", srtkCorrectSliceIntensity01_nlm, "input_images")
    wf.connect(srtkCorrectSliceIntensity01_nlm, "output_images", srtkSliceBySliceN4BiasFieldCorrection, "input_images")
    wf.connect(srtkSliceBySliceN4BiasFieldCorrection, "output_images", srtkCorrectSliceIntensity02_nlm, "input_images")
    wf.connect(srtkCorrectSliceIntensity02_nlm, "output_images", srtkIntensityStandardization01_nlm, "input_images")
    wf.connect(srtkIntensityStandardization01_nlm, "output_images", srtkHistogramNormalization_nlm, "input_images")
    wf.connect(srtkHistogramNormalization_nlm, "output_images", srtkIntensityStandardization02_nlm, "input_images")
    wf.connect(srtkIntensityStandardization02_nlm, "output_images", srtkMaskImage01, "input_images")
    wf.connect(srtkMaskImage01, "output_images", srtkImageReconstruction, "input_images")

    # Raw branch; the bias fields estimated on the denoised stacks are applied here.
    wf.connect(dg, "T2ws", srtkCorrectSliceIntensity01, "input_images")
    wf.connect(srtkCorrectSliceIntensity01, "output_images", srtkSliceBySliceCorrectBiasField, "input_images")
    wf.connect(srtkSliceBySliceN4BiasFieldCorrection, "output_fields", srtkSliceBySliceCorrectBiasField, "input_fields")
    wf.connect(srtkSliceBySliceCorrectBiasField, "output_images", srtkCorrectSliceIntensity02, "input_images")
    wf.connect(srtkCorrectSliceIntensity02, "output_images", srtkIntensityStandardization01, "input_images")
    wf.connect(srtkIntensityStandardization01, "output_images", srtkHistogramNormalization, "input_images")
    wf.connect(srtkHistogramNormalization, "output_images", srtkIntensityStandardization02, "input_images")

    # TV super-resolution of the raw branch, using the reconstruction's transforms and SDI.
    wf.connect(srtkIntensityStandardization02, "output_images", srtkTVSuperResolution, "input_images")
    wf.connect(srtkImageReconstruction, "output_transforms", srtkTVSuperResolution, "input_transforms")
    wf.connect(srtkImageReconstruction, "output_sdi", srtkTVSuperResolution, "input_sdi")

    # Post-processing of the SR image.
    wf.connect(srtkIntensityStandardization02, "output_images", srtkRefineHRMaskByIntersection, "input_images")
    wf.connect(srtkImageReconstruction, "output_transforms", srtkRefineHRMaskByIntersection, "input_transforms")
    wf.connect(srtkTVSuperResolution, "output_sr", srtkRefineHRMaskByIntersection, "input_sr")
    wf.connect(srtkTVSuperResolution, "output_sr", srtkN4BiasFieldCorrection, "input_image")
    wf.connect(srtkRefineHRMaskByIntersection, "output_SRmask", srtkN4BiasFieldCorrection, "input_mask")
    wf.connect(srtkTVSuperResolution, "output_sr", srtkMaskImage02, "in_file")
    wf.connect(srtkRefineHRMaskByIntersection, "output_SRmask", srtkMaskImage02, "in_mask")

    # --- Saving files ---------------------------------------------------------
    # NOTE(review): the sink directory is not subject/session specific, so
    # several subjects share derivatives/mialsrtk-py/<anat|xfm|...>.
    output_dir = os.path.join(bids_dir, "derivatives", "mialsrtk-py")
    datasink = Node(DataSink(), name='sinker')
    datasink.inputs.base_directory = output_dir
    # Use the stack order passed in, not the notebook-global `stacksOrder`.
    datasink.inputs.substitutions = _build_sink_substitutions(sub_ses, p_stacksOrder)

    wf.connect(srtkMaskImage01, "output_images", datasink, 'preproc')
    wf.connect(srtkImageReconstruction, "output_transforms", datasink, 'xfm')
    wf.connect(srtkRefineHRMaskByIntersection, "output_LRmasks", datasink, 'postproc')
    wf.connect(srtkImageReconstruction, "output_sdi", datasink, 'anat')
    wf.connect(srtkN4BiasFieldCorrection, "output_image", datasink, 'anat.@SR')
    wf.connect(srtkRefineHRMaskByIntersection, "output_SRmask", datasink, 'postproc.@SRmask')

    # JSON sidecar for the SR-TV image; named with sub_ses so that sessions do
    # not overwrite each other.
    _write_sr_sidecar(output_dir, sub_ses, p_stacksOrder, deltatTV, lambdaTV, primal_dual_loops)

    return wf

In [None]:
# Build the workflow for the configured subject/session, render its graph, then
# execute it. nipype working files go under <bids_dir>/derivatives/tmp_proc, so
# changing bids_dir in the config cell also moves the working tree.
m_wf = create_workflow(bids_dir,
                       process_dir=os.path.join(bids_dir, 'derivatives', 'tmp_proc'),
                       subject=subject,
                       p_stacksOrder=stacksOrder,
                       session=session)
m_wf.write_graph()
aa = m_wf.run()  # result of the run, kept for interactive inspection
