# Preprocessing all at once
This notebook is meant to run the preprocessing code found in the scripts directory all at once\
If you'd rather run the code step by step go to the `02_preprocessing_step_by_step.ipynb` file\
The preprocessing steps include:
1. convert dicom to nifti using [Dcm2niix](https://github.com/rordenlab/dcm2niix)
2. reorient images
3. extract brain using [HD-BET](https://github.com/MIC-DKFZ/HD-BET)
4. N4 bias correction
5. Coregister images
6. Resample images
7. Z-score normalize images

## Import necessary libraries

In [None]:
import sys
sys.path.append(r"/Users/LennartPhilipp/Desktop/Uni/Prowiss/Code/Brain_Mets_Classification")
sys.path.append(r"/Users/LennartPhilipp/Desktop/Uni/Prowiss/Code/HD-BET/HD_BET")
import brain_mets_classification.config as config
import brain_mets_classification.custom_funcs as funcs

import tqdm

import pandas as pd
import os
import pathlib
import ants
from typing import Union, List, Tuple
import multiprocessing
import SimpleITK as sitk
from nipype.interfaces.dcm2nii import Dcm2niix
import numpy as np
#from HD_BET.run import run_hd_bet
#import HD_BET
from nipype.interfaces import fsl
from intensity_normalization.normalize.zscore import ZScoreNormalize

## Helper Functions

In [9]:
# Number of worker threads: leave one core free for the rest of the system,
# but never go below one (cpu_count() - 1 would be 0 on a single-core machine).
N_PROC = max(1, multiprocessing.cpu_count() - 1)

def dcm_to_nifti_conversion(
        path_to_folder: Union[str, pathlib.Path],
        path_to_output: Union[str, pathlib.Path],
        out_filename: str):
    """
    Converts the dicom files of one sequence folder to nifti via nipype's Dcm2niix interface

    Keyword Arguments:
    path_to_folder: Union[str, pathlib.Path] = folder handed to Dcm2niix as source_dir
    path_to_output: Union[str, pathlib.Path] = folder handed to Dcm2niix as output_dir
    out_filename: str = output file name handed to Dcm2niix,
        intended naming scheme: {patientID}_{sequence}_{preprocessingStep}
    """

    dcm2niix_node = Dcm2niix()
    node_inputs = dcm2niix_node.inputs

    node_inputs.source_dir = path_to_folder
    node_inputs.compress = "y"  # "y" = yes, write compressed output
    node_inputs.merge_imgs = True
    node_inputs.out_filename = out_filename
    node_inputs.output_dir = path_to_output

    dcm2niix_node.run()
    

def extract_brain(
    path_to_input_image: Union[str, pathlib.Path],
    path_to_output_image: Union[str, pathlib.Path],
):
    """
    Applies FSL.Reorient2Std() to the input brain scan and then runs HD-BET
    brain extraction on the reoriented file.

    The reoriented image is written to path_to_output_image and HD-BET is then
    called with that same path as both input and output. Nothing is returned.

    Keyword Arguments:
    path_to_input_image: Union[str, pathlib.Path] = file path to input image (brain scan)
    path_to_output_image: Union[str, pathlib.Path] = location to store brain extracted image

    Raises:
    ImportError if the HD_BET package is not installed
    """

    # Imported locally: without this import run_hd_bet is an undefined name
    # and the call below raises NameError. Importing first also fails fast
    # before any file is written.
    from HD_BET.run import run_hd_bet

    # Normalize to plain strings so str and pathlib.Path inputs behave the same
    # (NOTE(review): HD-BET appears to manipulate file names as strings - confirm).
    input_file = os.fspath(path_to_input_image)
    output_file = os.fspath(path_to_output_image)

    # Alternative: torchio tocanonical
    reorient = fsl.Reorient2Std()
    reorient.inputs.in_file = input_file
    reorient.inputs.out_file = output_file
    reorient.run()

    # brain extraction works in place on the reoriented file
    run_hd_bet(mri_fnames=output_file, output_fnames=output_file)


def fill_holes(
    binary_image: sitk.Image,
    radius: int = 3,
) -> sitk.Image:
    """
    Closes holes in a binary segmentation with a binary morphological closing

    Keyword Arguments:
    - binary_image: sitk.Image = binary brain segmentation
    - radius: int = kernel radius passed to the closing filter

    Returns:
    - sitk.Image = binary brain segmentation after the closing was applied
    """

    closer = sitk.BinaryMorphologicalClosingImageFilter()
    closer.SetKernelRadius(radius)

    return closer.Execute(binary_image)


def binary_segment_brain(
    image: sitk.Image,
) -> sitk.Image:
    """
    Creates a binary brain mask from a brain-extracted scan via otsu thresholding
    and closes the holes of that mask with fill_holes (default radius)

    Keyword Arguments:
    - image: sitk.Image = brain-extracted scan

    Returns:
    - sitk.Image = binary segmentation of brain scan with filled holes
    """

    thresholder = sitk.OtsuThresholdImageFilter()
    # the filter's "inside" voxels are labelled 0, its "outside" voxels 1
    thresholder.SetInsideValue(0)
    thresholder.SetOutsideValue(1)
    raw_mask = thresholder.Execute(image)

    filled_mask = fill_holes(raw_mask)
    return filled_mask


def get_bounding_box(
    image: sitk.Image,
) -> Tuple[int]:
    """
    Computes the bounding box of label 1 in the binary brain mask of a brain-extracted scan

    Keyword Arguments:
    - image: sitk.Image = brain-extracted scan

    Returns
    - bounding box (startX, startY, startZ, sizeX, sizeY, sizeZ),
      wrapped in a numpy array
    """

    brain_mask = binary_segment_brain(image)

    shape_stats = sitk.LabelShapeStatisticsImageFilter()
    shape_stats.Execute(brain_mask)

    # label 1 is the value binary_segment_brain was called to produce for the brain
    return np.array(shape_stats.GetBoundingBox(1))


def apply_bounding_box(
    image: sitk.Image,
    bounding_box: Tuple[int],
) -> sitk.Image:
    """
    Crops an image to a bounding box by slicing its first three axes

    Keyword Arguments:
    - image: sitk.Image = image
    - bounding_box: Tuple(int) = bounding box of kind (startX, startY, startZ, sizeX, sizeY, sizeZ)

    Returns
    - sitk.Image = cropped image
    """

    # one slice per axis: start index up to (size + start)
    crop_window = tuple(
        slice(bounding_box[axis], bounding_box[axis + 3] + bounding_box[axis])
        for axis in range(3)
    )

    return image[crop_window]


def apply_bias_correction(
    image: sitk.Image,
) -> sitk.Image:
    """applies N4 bias field correction to image but keeps background at zero

    The brain mask from binary_segment_brain restricts the correction and is
    afterwards used to set everything outside of the mask to zero.

    Keyword Arguments:
    image: sitk.Image = image to apply bias correction to

    Returns:
    image_corrected_masked: sitk.Image = N4 bias field corrected image
        (float32 if the input was not already a float32/float64 image)
    """

    mask_image = binary_segment_brain(image)

    # The N4 filter only accepts real-valued pixel types; scans stored as
    # integers would make Execute() fail, so cast them first.
    if image.GetPixelID() not in (sitk.sitkFloat32, sitk.sitkFloat64):
        image = sitk.Cast(image, sitk.sitkFloat32)

    corrector = sitk.N4BiasFieldCorrectionImageFilter()
    image_corrected = corrector.Execute(image, mask_image)

    # re-apply the mask so that the background stays exactly zero
    mask_filter = sitk.MaskImageFilter()
    mask_filter.SetOutsideValue(0)
    image_corrected_masked = mask_filter.Execute(image_corrected, mask_image)

    return image_corrected_masked


def coregister_antspy(
    fixed_path: Union[str, pathlib.Path],
    moving_path: Union[str, pathlib.Path],
    out_path: Union[str, pathlib.Path],
    num_threads=N_PROC,
) -> ants.core.ants_image.ANTsImage:
    """
    Registers the moving image to the fixed image with ants.registration
    (transform type "antsRegistrationSyNQuick[s]"), writes the warped moving
    image to disk and returns it.

    Keyword Arguments:
    fixed_path: path to fixed image
    moving_path: path to moving image
    out_path: path to save warped image to
    num_threads: value written to the ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS
        environment variable before the registration starts
    """

    os.environ["ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS"] = str(num_threads)

    fixed_image = ants.image_read(fixed_path)
    moving_image = ants.image_read(moving_path)

    # all registration settings in one place
    registration_options = dict(
        type_of_transform="antsRegistrationSyNQuick[s]",  # or "SyNRA"
        initial_transform=None,
        outprefix="",
        mask=None,
        moving_mask=None,
        mask_all_stages=False,
        grad_step=0.2,
        flow_sigma=3,
        total_sigma=0,
        aff_metric="mattes",
        aff_sampling=32,
        aff_random_sampling_rate=0.2,
        syn_metric="mattes",
        syn_sampling=32,
        reg_iterations=(40, 20, 0),
        aff_iterations=(2100, 1200, 1200, 10),
        aff_shrink_factors=(6, 4, 2, 1),
        aff_smoothing_sigmas=(3, 2, 1, 0),
        write_composite_transform=False,
        random_seed=None,
        verbose=False,
        multivariate_extras=None,
        restrict_transformation=None,
        smoothing_in_mm=False,
    )

    registration_result = ants.registration(
        fixed=fixed_image,
        moving=moving_image,
        **registration_options,
    )

    warped_image = registration_result["warpedmovout"]
    ants.image_write(warped_image, out_path)

    return warped_image


def resample(
    itk_image: sitk.Image,
    out_spacing: Tuple[float, ...],
    is_mask: bool,
) -> sitk.Image:
    """
    Resamples a sitk image to the requested voxel spacing while keeping
    its origin and direction

    Keyword Arguments:
    itk_image: sitk.Image = image to resample
    out_spacing: Tuple = target spacing, one value per axis
    is_mask: bool = True if input image is label mask -> NN-interpolation,
        otherwise B-spline interpolation is used

    Returns
    sitk.Image = image resampled to out_spacing
    """

    size_in = itk_image.GetSize()
    spacing_in = itk_image.GetSpacing()

    # keep the physical extent: new voxel count = old count * old spacing / new spacing
    out_size = []
    for n_voxels, old_sp, new_sp in zip(size_in, spacing_in, out_spacing):
        out_size.append(int(round(n_voxels * old_sp / new_sp)))

    # alternatives for non-mask images: sitk.sitkLinear, sitk.sitkNearestNeighbor
    interpolator = sitk.sitkNearestNeighbor if is_mask else sitk.sitkBSpline

    resampler = sitk.ResampleImageFilter()
    resampler.SetOutputSpacing(out_spacing)
    resampler.SetSize(out_size)
    resampler.SetOutputDirection(itk_image.GetDirection())
    resampler.SetOutputOrigin(itk_image.GetOrigin())
    resampler.SetTransform(sitk.Transform())
    resampler.SetDefaultPixelValue(0)
    resampler.SetInterpolator(interpolator)

    return resampler.Execute(itk_image)


def zscore_normalize(image: sitk.Image) -> sitk.Image:
    """
    Z score normalizes a brain scan, using the mask from binary_segment_brain
    as the mask argument of ZScoreNormalize

    Keyword Arguments:
    image: sitk.Image = input brain scan

    Returns:
    sitk.Image = normalized brain scan carrying the meta information
        (copied via CopyInformation) of the input image
    """

    mask = binary_segment_brain(image)
    normalizer = ZScoreNormalize()

    image_array = sitk.GetArrayFromImage(image)
    mask_array = sitk.GetArrayFromImage(mask)
    normalized_array = normalizer(image_array, mask_array)

    # converting from an array drops the meta information, so copy it back
    result = sitk.GetImageFromArray(normalized_array)
    result.CopyInformation(image)

    return result


## Run preprocessing
Input arguments:\
path_to_patients: str,\
path_to_output: str,

The `path_to_patients` variable should point to a directory that contains one directory per patient; each patient directory in turn contains one directory of DICOM files per sequence.

In [6]:
# directory that contains one folder per patient (fill in before running)
path_to_patients = ""
# directory in which the folder with the preprocessed files gets created (fill in before running)
path_to_output = ""

# leave one core free, but never go below one worker
# (cpu_count() - 1 would be 0 on a single-core machine)
N_PROC = max(1, multiprocessing.cpu_count() - 1)

In [None]:
# create folder at path to output called Rgb_Brain_Mets_preprocessed
# (exist_ok so that re-running the cell does not fail on the existing folder)
path_to_preprocessed_files = f"{path_to_output}/Rgb_Brain_Mets_preprocessed"
os.makedirs(path_to_preprocessed_files, exist_ok=True)

# gets only the folders at path and puts them in an array
patient_folders = [
    folder for folder in os.listdir(path_to_patients)
    if os.path.isdir(os.path.join(path_to_patients, folder))
]

# the file does `import tqdm`, so the progress bar class is tqdm.tqdm
# (calling the module itself raises a TypeError)
for patient in tqdm.tqdm(patient_folders):

    # ignores the ds_folders
    if config.dsStore in patient:
        continue

    patientID = patient
    # separate name so that the user-defined path_to_output is not overwritten
    patient_dir = os.path.join(path_to_patients, patient)

    # get the different sequences (stored in folders) for each patient and put them in an array
    dicomSequences = [
        sequenceFolder for sequenceFolder in os.listdir(patient_dir)
        if os.path.isdir(os.path.join(patient_dir, sequenceFolder))
    ]

    print("Starting nifti conversion...")
    # loop through the dicom sequences
    for dicomSequenceFolder in dicomSequences:
        # the sequence type is the second "_"-separated part of the folder name
        name_parts = dicomSequenceFolder.split("_")
        if len(name_parts) < 2:
            print(f"Warning: skipping folder without sequence type in its name: {dicomSequenceFolder}")
            continue

        sequenceType = name_parts[1]
        new_file_name = f"{patientID}_{sequenceType}"

        # turn each sequence into nifti file
        # save the nifti file in each patient folder
        dcm_to_nifti_conversion(
            path_to_folder=os.path.join(patient_dir, dicomSequenceFolder),
            path_to_output=patient_dir,
            out_filename=new_file_name,
        )

    print("Finished nifti conversion")

    # get the newly converted nifti files
    niftiSequences = [
        sequence for sequence in os.listdir(patient_dir) if sequence.endswith(".nii.gz")
    ]

    if not niftiSequences:
        print("Warning: no nifti files found")

    # loop through the nifti sequences
    # TODO: the steps below are placeholders, only the log messages exist so far
    for niftiSequence in niftiSequences:
        print("Starting reorientation")
        # reorient images
        print("Finished reorientation")
        print("Starting brain extraction")
        # use brain extraction
        print("Finished brain extraction")

    # get the brain extracted files
    brainExtractedFiles = [
        sequence for sequence in os.listdir(patient_dir) if "brainextracted" in sequence
    ]

    referenceSequenceForCoregistration = ""

    # loop through the brain extracted sequences
    # TODO: the steps below are placeholders, only the log messages exist so far
    for brainExtractedSequence in brainExtractedFiles:
        print("Starting bounding box")
        # get and apply a bounding box both to the brain as well as the mask
        print("Finished bounding box")
        print("Starting n4 bias correction")
        # n4 bias correction
        print("Finished n4 bias correction")
        print("Starting coregistration of sequences")
        # coregister images
        print("Finished coregistration of sequences")
        print("Stating resampling of images")
        # resample images
        print("Finished resampling of images")
        print("Starting z score normalization of sequence")
        # z score normalize images
        print("Finished z score normalization of sequence")

    # clean up created files (nifti files and brain extracted images)
    # TODO: clean up is not implemented yet
    print("Cleaned up unnecessary files")

print("Finished preprocessing images")
