# Skull Stripping Segmentation NFBS Data

My plan is to train a 3D UNet model on the Neurofeedback Skull-stripped (NFBS) dataset, then use the trained model for automated skull stripping on the ATLAS 2.0 dataset later in our lesion segmentation notebook and source code. Skull stripping is a preprocessing step that comes before lesion segmentation.

Reference: https://www.analyticsvidhya.com/blog/2021/06/introduction-to-skull-stripping-image-segmentation-on-3d-mri-images/

MRI Brain Dataset with skull stripped labels: http://preprocessed-connectomes-project.org/NFB_skullstripped/

## NFBS Dataset

- 125 participants, 21 to 45 years old, with a variety of clinical and sub-clinical psychiatric symptoms.
- Structural T1-weighted, anonymized (de-faced) images with a single channel.
- The brain mask is the ground truth. It was obtained with the BEaST (brain extraction based on nonlocal segmentation technique) method and then manually edited by domain experts to remove non-brain tissue.
- The skull-stripped image is the brain portion extracted from the T1-weighted image. It is equivalent to applying the brain mask to the original image.
- The resolution is 1 mm³ and each file is in NIfTI (.nii.gz) format.

## Outline

- Prepare NFBS Data
- 3D UNet Model Architecture
- Train 3D UNet Model
- Deploy 3D UNet for Skull Stripping

## Prepare NFBS Data

In [10]:
import nibabel as nib
import numpy as np
import pandas as pd
import os
from matplotlib import pyplot as plt

In [3]:
nfbs_path = "/media/james/My Passport/NFBS_Dataset"
# voxel is a 3D MRI volume that can be split into 2D MRI slice images
voxel = nib.load("{}/A00028185/sub-A00028185_ses-NFB3_T1w.nii.gz".format(nfbs_path))

print("Shape of voxel=", voxel.shape)

Shape of voxel= (256, 256, 192)


The 3D MRI volume can be thought of as a stack of 192 2D slice images, each 256*256 pixels, placed on top of each other.
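To make the "stack of slices" picture concrete, here is a minimal sketch that slices a volume along its last axis. It uses a synthetic zero-filled array as a stand-in for the real data, so `volume`, `slices`, and `middle_slice` are illustrative names and not part of the notebook.

```python
import numpy as np

# Synthetic stand-in with the same shape as the NFBS volume loaded above.
volume = np.zeros((256, 256, 192), dtype=np.float32)

# Indexing the last axis yields one 2D image per slice position.
slices = [volume[:, :, k] for k in range(volume.shape[2])]
middle_slice = volume[:, :, volume.shape[2] // 2]

print(len(slices), middle_slice.shape)
```

With the real file, the array would come from `voxel.get_fdata()`, and a single slice could be displayed with `plt.imshow(middle_slice, cmap="gray")`.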

We will create a data frame that contains the locations of the raw images and their corresponding brain masks and skull-stripped images.

In [4]:
# store the address of 3 types of files
brain_mask = list()
brain = list()
raw = list()

for subdir, dirs, files in os.walk(nfbs_path):
    for file in files:
        # print(os.path.join(subdir, file))
        filepath = subdir + os.sep + file
        
        if filepath.endswith(".gz"):
            if "_brainmask." in filepath:
                brain_mask.append(filepath)
            elif "_brain." in filepath:
                brain.append(filepath)
            else:
                raw.append(filepath)

In [11]:
nfbs_df = pd.DataFrame(
    {"brain_mask": brain_mask,
     "brain": brain,
     "raw": raw
    }
)

In [13]:
nfbs_df.head(5)

| | brain_mask | brain | raw |
|---|---|---|---|
| 0 | /media/james/My Passport/NFBS_Dataset/A0004372... | /media/james/My Passport/NFBS_Dataset/A0004372... | /media/james/My Passport/NFBS_Dataset/A0004372... |
| 1 | /media/james/My Passport/NFBS_Dataset/A0005609... | /media/james/My Passport/NFBS_Dataset/A0005609... | /media/james/My Passport/NFBS_Dataset/A0005609... |
| 2 | /media/james/My Passport/NFBS_Dataset/A0006025... | /media/james/My Passport/NFBS_Dataset/A0006025... | /media/james/My Passport/NFBS_Dataset/A0006025... |
| 3 | /media/james/My Passport/NFBS_Dataset/A0002818... | /media/james/My Passport/NFBS_Dataset/A0002818... | /media/james/My Passport/NFBS_Dataset/A0002818... |
| 4 | /media/james/My Passport/NFBS_Dataset/A0002835... | /media/james/My Passport/NFBS_Dataset/A0002835... | /media/james/My Passport/NFBS_Dataset/A0002835... |
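Because the three columns are filled from separate lists, it may be worth checking that every row really refers to one subject. The sketch below assumes each subject's files live in their own folder (as the paths above suggest). `subject_id` and `rows_are_aligned` are hypothetical helpers, and the demo paths are made up for illustration.

```python
import os
import pandas as pd

def subject_id(path):
    """Return the name of the folder that directly contains the file."""
    return os.path.basename(os.path.dirname(path))

def rows_are_aligned(df):
    """True when brain_mask, brain and raw in every row share one subject folder."""
    ids = df[["brain_mask", "brain", "raw"]].apply(lambda col: col.map(subject_id))
    return bool((ids.nunique(axis=1) == 1).all())

# Made-up paths for two hypothetical subjects, S01 and S02.
demo = pd.DataFrame({
    "brain_mask": ["/data/S01/S01_T1w_brainmask.nii.gz", "/data/S02/S02_T1w_brainmask.nii.gz"],
    "brain": ["/data/S01/S01_T1w_brain.nii.gz", "/data/S02/S02_T1w_brain.nii.gz"],
    "raw": ["/data/S01/S01_T1w.nii.gz", "/data/S02/S02_T1w.nii.gz"],
})
print(rows_are_aligned(demo))
```

On the real data the call would be `rows_are_aligned(nfbs_df)`. A `False` result would mean the lists need to be sorted or matched by subject before training.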


The next preparation steps we'll perform are:

- Bias field correction: the bias field is a smooth, low-frequency signal that corrupts MRI images, especially those produced by older MRI machines. It is therefore corrected before the images are passed to image processing algorithms such as segmentation, texture analysis, or classification.
- Cropping and resizing: because of the computational limits of fitting whole volumes into the model, we reduce the MRI volume size from `256*256*192` to `96*128*160`. The target size is chosen so that most of the skull is still captured, and cropping and resizing also has a centering effect on the images.
- Intensity normalization: in general, this shifts and scales an image so that the voxels have zero mean and unit variance. The `intensity_normalization` method below uses a min-max rescale to the range 0 to 255 instead. Either way, putting all volumes on a common intensity scale helps the model converge faster.
- Data generator: feeds the preprocessed data to the model.
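The zero-mean, unit-variance normalization mentioned in the list can be sketched in a few lines of NumPy. This is an illustration, not the notebook's own method: `zscore_normalize` and its `eps` guard are assumptions, and the input is random data standing in for an MRI volume.

```python
import numpy as np

def zscore_normalize(volume, eps=1e-8):
    """Shift and scale a volume to zero mean and unit variance.

    eps (an assumed safeguard) avoids division by zero on constant volumes.
    """
    volume = volume.astype(np.float32)
    return (volume - volume.mean()) / (volume.std() + eps)

# Random intensities standing in for a small MRI volume.
rng = np.random.default_rng(0)
demo_volume = rng.uniform(0, 4000, size=(8, 8, 8))

normalized = zscore_normalize(demo_volume)
print(float(normalized.mean()), float(normalized.std()))
```

The printed mean should be close to 0 and the standard deviation close to 1. A common variant computes the statistics over brain voxels only, which would need a mask as an extra input.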

In [14]:
import SimpleITK as sitk
from nilearn.image import resample_img
from nipype.interfaces.ants import N4BiasFieldCorrection
from tqdm import tqdm

class preprocessing():
    def __init__(self, df):
        self.data = df
        self.raw_index = list()
        self.mask_index = list()

    def bias_correction(self):
        # run N4 bias field correction (ANTs via nipype) on every raw volume
        bias_dir = "bias_correction"
        bias_path = os.path.join(nfbs_path, bias_dir)
        os.makedirs(bias_path, exist_ok=True)
        n4 = N4BiasFieldCorrection()
        n4.inputs.dimension = 3
        n4.inputs.shrink_factor = 3
        n4.inputs.n_iterations = [20, 10, 10, 5]
        index_corr = list()
        for i in tqdm(range(len(self.data))):
            out_file = os.path.join(bias_path, str(i) + ".nii.gz")
            n4.inputs.input_image = self.data.raw.iloc[i]
            n4.inputs.output_image = out_file
            n4.run()
            index_corr.append(out_file)
        # keep the corrected file paths next to the original ones
        self.data["bias_corr"] = index_corr
        print("bias corrected voxels stored at : {}/".format(bias_path))

    def resize_crop(self):
        # reducing the size of image due to memory constraints
        self.rcrop_dir = "resized"
        self.rcrop_path = os.path.join(nfbs_path, self.rcrop_dir)
        os.makedirs(self.rcrop_path, exist_ok=True)
        # reducing size of image from 256*256*192 to 96*128*160
        target_shape = np.array((96, 128, 160))
        new_resolution = [2,]*3
        new_affine = np.zeros((4, 4))
        new_affine[:3, :3] = np.diag(new_resolution)
        # putting point 0,0,0 in the middle of the new volume - this
        # could be refined in the future
        new_affine[:3, 3] = target_shape*new_resolution/2.*-1
        new_affine[3, 3] = 1.
        # resizing both image and mask and storing in folder
        for i in range(len(self.data)):
            raw_file = os.path.join(self.rcrop_path, "raw" + str(i) + ".nii.gz")
            downsampled_and_cropped_nii = resample_img(
                self.data.bias_corr.iloc[i], target_affine=new_affine,
                target_shape=target_shape, interpolation="nearest")
            downsampled_and_cropped_nii.to_filename(raw_file)
            self.raw_index.append(raw_file)

            mask_file = os.path.join(self.rcrop_path, "mask" + str(i) + ".nii.gz")
            downsampled_and_cropped_nii = resample_img(
                self.data.brain_mask.iloc[i], target_affine=new_affine,
                target_shape=target_shape, interpolation="nearest")
            downsampled_and_cropped_nii.to_filename(mask_file)
            self.mask_index.append(mask_file)
        return self.raw_index, self.mask_index

    def intensity_normalization(self):
        # rescale every resized raw volume to the range [0, 255] in place
        for i in self.raw_index:
            image = sitk.ReadImage(i)
            rescale_filter = sitk.RescaleIntensityImageFilter()
            rescale_filter.SetOutputMaximum(255)
            rescale_filter.SetOutputMinimum(0)
            image = rescale_filter.Execute(image)
            sitk.WriteImage(image, i)
        print("Normalization done. Voxels stored at: {}/".format(self.rcrop_path))
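The target affine built in `resize_crop` can be checked in isolation. The sketch below repeats the same arithmetic with NumPy only and confirms that the center voxel index of the new grid maps to world coordinate (0, 0, 0). `center_index` and `center_world` are illustrative names that do not appear in the notebook.

```python
import numpy as np

target_shape = np.array((96, 128, 160))
new_resolution = np.array([2.0, 2.0, 2.0])  # voxel size in mm

# Same construction as in resize_crop: scaling on the diagonal,
# translation chosen so the grid is centered on the origin.
new_affine = np.zeros((4, 4))
new_affine[:3, :3] = np.diag(new_resolution)
new_affine[:3, 3] = -target_shape * new_resolution / 2.0
new_affine[3, 3] = 1.0

# Map the center voxel index (in homogeneous coordinates) to world space.
center_index = np.append(target_shape / 2.0, 1.0)
center_world = new_affine @ center_index

print(new_affine[:3, 3])
print(center_world[:3])
```

The translation column works out to (-96, -128, -160) mm, so the resampled grid covers 192 x 256 x 320 mm around the origin at 2 mm per voxel.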

## 3D UNet Model Architecture

Now that we've built our preprocessing class, we can begin modeling. First we split the data into training and test sets. Then we use a custom data generator to feed the input MRIs into the model.
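As a rough, framework-agnostic illustration of those two steps, the sketch below shuffles sample indices into train and test sets and yields batches with a trailing channel axis. Everything here is an assumption for illustration: `train_test_split_indices`, `batch_generator`, the 80/20 split, the batch size, and `fake_load_pair`, which returns small synthetic arrays in place of real volumes.

```python
import numpy as np

def train_test_split_indices(n_samples, test_fraction=0.2, seed=42):
    """Shuffle sample indices and split them into train and test sets."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(n_samples)
    n_test = int(round(n_samples * test_fraction))
    return order[n_test:], order[:n_test]

def batch_generator(load_pair, indices, batch_size=2):
    """Yield (images, masks) batches, each with a trailing channel axis."""
    for start in range(0, len(indices), batch_size):
        batch = indices[start:start + batch_size]
        pairs = [load_pair(i) for i in batch]
        images = np.stack([p[0] for p in pairs])[..., np.newaxis]
        masks = np.stack([p[1] for p in pairs])[..., np.newaxis]
        yield images, masks

def fake_load_pair(i):
    """Hypothetical loader: returns a tiny synthetic (image, mask) pair."""
    shape = (12, 16, 20)
    return np.full(shape, i, dtype=np.float32), np.zeros(shape, dtype=np.float32)

train_idx, test_idx = train_test_split_indices(125)
images, masks = next(batch_generator(fake_load_pair, train_idx))
print(len(train_idx), len(test_idx), images.shape)
```

With the real data, the loader could read the resized files, for example with `nib.load(path).get_fdata()`, and the generator would typically be wrapped in whatever data-feeding class the training framework expects.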