# Stroke Segmentation ATLAS2.0 DL Pipeline

We will design and develop a deep learning (DL) pipeline that prepares the ATLAS 2.0 stroke dataset for training lesion segmentation DNN models, in particular our own implementation of DAGMNet, and then deploys the trained model to segment lesions in 3D stroke MRI scans from the test set that it has not seen. Each section of the DL pipeline is based on the design methodology described in Dr. Faria's publication, "Deep learning-based detection and segmentation of diffusion abnormalities in acute ischemic stroke". However, since the clinical 3D and 4D stroke dataset used there is not yet publicly available on ICPSR, we may need to adapt their design to work with the ATLAS 2.0 stroke dataset.

C.-F. Liu, J. Hsu, X. Xu, S. Ramachandran, V. Wang, M. I. Miller, A. E. Hillis, A. V. Faria, et al., "Deep learning-based detection and segmentation of diffusion abnormalities in acute ischemic stroke," Communications Medicine, vol. 1, no. 1, 2021.

## Outline

- Prepare ATLAS 2.0 Stroke Data 
- Stroke Segmentation Model Architecture
- Train Stroke Segmentation ML/DL Models
- Evaluate Stroke Segmentation ML/DL Models Quantitatively
- Evaluate Stroke Segmentation ML/DL Models Qualitatively
- Deploy Stroke Segmentation DL Model for Inference

## Prepare ATLAS Stroke Dataset

- 1\. Resample DWI, B0 and ADC images to 1x1x1 mm^3
    - In ATLAS 2.0, we have T1-weighted MRIs instead.
- 2\. Skull-strip with an in-house "UNet BrainMask Network"
- 3\. Use "in-plane" (IP) linear transformations to map the images to the standard MNI (Montreal Neurological Institute) template
- 4\. Normalize the DWI intensity to reduce variability and increase comparability among subjects
- 5\. "Down-sample" (DS) the images to reduce the memory requirements of the next steps (e.g. DL networks)
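Step 1 changes the voxel grid. The following is a minimal sketch of isotropic resampling with `scipy.ndimage.zoom`, assuming the voxel spacing in mm is already known (e.g., read from the NIfTI header). The function name `resample_to_spacing` is hypothetical, and a full pipeline would also update the image affine, which is not shown here.

```python
import numpy as np
from scipy import ndimage


def resample_to_spacing(volume, spacing, new_spacing=(1.0, 1.0, 1.0), order=1):
    """Resample a 3D volume from `spacing` to `new_spacing` (both in mm per voxel).

    order=1 is trilinear interpolation for intensity images; use order=0
    (nearest neighbour) for masks and labels so they stay binary.
    """
    zoom_factors = [old / new for old, new in zip(spacing, new_spacing)]
    return ndimage.zoom(volume, zoom_factors, order=order)


# A 2x2x2 mm^3 volume resampled to 1x1x1 mm^3 doubles the voxel count per axis
vol = np.random.rand(10, 12, 8).astype(np.float32)
iso = resample_to_spacing(vol, spacing=(2.0, 2.0, 2.0))
print(iso.shape)  # (20, 24, 16)
```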

Note: The gold-standard brain masks were generated as follows: first, all DWI and B0 images were resampled to 1 × 1 × 1 mm^3 and skull-stripped by a level-set algorithm (available under ROIStudio, in MRIStudio [38]), with W5 = 1.2 and 4, respectively (see the explanation of the parameter choice on the MRIStudio website [38]). [...] To train our "UNet BrainMask Network", all images are mapped to MNI and downsampled to 4 × 4 × 4 mm^3. The final brain mask inferred by the network was then post-processed by morphological closing and the `binary_fill_holes` function from the Python scipy module, upsampled to 1 × 1 × 1 mm^3, and dilated by one voxel with image smoothing.
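The brain-mask post-processing above can be sketched with `scipy.ndimage`. This is an illustration only: the function name is hypothetical, the 4 mm to 1 mm upsampling uses nearest-neighbour `zoom`, and "dilated by one voxel with image smoothing" is interpreted here (as an assumption) as a one-voxel binary dilation followed by a Gaussian blur re-thresholded at 0.5.

```python
import numpy as np
from scipy import ndimage


def postprocess_brain_mask(mask_4mm, upsample=4):
    """Post-process a binary brain mask predicted in 4x4x4 mm^3 MNI space.

    Returns a binary mask on a 1x1x1 mm^3 grid when upsample=4.
    """
    mask = ndimage.binary_closing(mask_4mm.astype(bool))  # close small gaps
    mask = ndimage.binary_fill_holes(mask)                  # fill interior holes
    # Nearest-neighbour upsampling (4 mm -> 1 mm) keeps the mask binary
    mask = ndimage.zoom(mask.astype(np.uint8), upsample, order=0).astype(bool)
    # Dilate by one voxel
    mask = ndimage.binary_dilation(mask, iterations=1)
    # Assumption: "image smoothing" = Gaussian blur, then re-threshold at 0.5
    smoothed = ndimage.gaussian_filter(mask.astype(np.float32), sigma=1.0)
    return smoothed > 0.5
```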

Note: **B0** (image acquired in the absence of a diffusion gradient)

Note: The purpose of the **UNet BrainMask Network** is to perform skull stripping in 19 seconds, compared with 4.3 minutes for the gold-standard level-set algorithm in ROIStudio (in MRIStudio). This makes automated skull stripping suitable for large-scale, fast processing.

Note: **DWI intensity normalization** fits a bimodal Gaussian function to the DWI intensity histogram to cluster voxels into two groups: "brain tissue" (highest peak) and "non-brain tissue" (lowest peak). DWI intensities are then rescaled so that the "brain tissue" intensities have zero mean and unit standard deviation. After normalization, the lesion contrast should be preserved, appearing as a minor peak at high intensities in brains with ischemic lesions. Normalization also reduces intensity differences between magnetic field strengths and scanner manufacturers.
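A minimal sketch of the bimodal normalization follows. It assumes the two Gaussians are fitted with a plain expectation-maximization (EM) loop on the voxels selected by a mask, and that "highest peak" means the component with the largest peak density (weight / std). The function names, the percentile initialization and the iteration count are illustrative assumptions, not the paper's implementation.

```python
import numpy as np


def fit_two_gaussians(values, n_iter=100):
    """Fit a two-component 1D Gaussian mixture with EM; returns (weights, means, stds)."""
    values = np.asarray(values, dtype=np.float64)
    # Initialise the components at the 10th and 90th percentiles (assumption)
    means = np.percentile(values, [10, 90]).astype(np.float64)
    stds = np.full(2, values.std() + 1e-8)
    weights = np.array([0.5, 0.5])
    for _ in range(n_iter):
        # E-step: responsibility of each component for each voxel
        pdf = np.exp(-0.5 * ((values[:, None] - means) / stds) ** 2) / (stds * np.sqrt(2 * np.pi))
        resp = weights * pdf
        resp /= resp.sum(axis=1, keepdims=True) + 1e-12
        # M-step: update weights, means and standard deviations
        nk = resp.sum(axis=0) + 1e-12
        weights = nk / len(values)
        means = (resp * values[:, None]).sum(axis=0) / nk
        stds = np.sqrt((resp * (values[:, None] - means) ** 2).sum(axis=0) / nk) + 1e-8
    return weights, means, stds


def normalize_dwi(dwi, mask):
    """Rescale DWI so the 'brain tissue' Gaussian has zero mean and unit std."""
    weights, means, stds = fit_two_gaussians(dwi[mask])
    peak_heights = weights / stds                # proportional to each Gaussian's peak
    brain = int(np.argmax(peak_heights))         # "brain tissue" = highest peak
    return (dwi - means[brain]) / stds[brain]
```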

In [None]:
# prepare_atlas.py

# The UNet BrainMask Network performs skull stripping in 19 seconds, compared with
# 4.3 minutes for the gold-standard level-set algorithm in ROIStudio (MRIStudio),
# which makes automated skull stripping suitable for large-scale, fast processing.

## Stroke Segmentation Model Architecture

Unsupervised Lesion Segmentation

DL DAGMNet Lesion Segmentation Architecture:

- UNet3+ intraskip connections
- Fused multiscale contextual information block
- Deep supervision
- L1 regularization on final predictions
- Dual attention gate (DAG)
    - Spatial attention gate (sAG)
    - Channel attention gate (cAG)
- Self-normalizing activation (SELU)
- Batch normalization

As an MNet, DAGMNet is designed to capture semantic features directly from the input images at different receptive-field scales.

With deep supervision, DAGMNet is also designed to segment lesions of various volumes with consistent efficiency at each level.

UNet3+'s interskip connections between layers help the model share and reuse features across receptive-field scales, with lower computational complexity than DenseUNet and UNet++.

DAGMNet's final fuse block combines semantic features from all scales (covering small to large lesion volumes) to generate the final prediction.

L1 regularization at the end of DAGMNet's fuse block penalizes the total predicted lesion probability, which discourages false-positive predictions.

DAGMNet uses DAG to overcome the high variability in lesion volume, shape and location by conditioning the networks to emphasize the most informative components (spatial and channel-wise) from encoders' semantic features at each level prior to decoders, which increases the sensitivity to small lesions or lesions with subtle contrast.

sAG in DAG is used to spatially excite the receptive fields toward the most abnormal voxels, such as DWI hyperintensities or predictions from the third channel ("IS").

cAG was included to emphasize the semantic/morphological features that distinguish ischemic lesions from artifacts.

Including the classical method's output ("IS") as the third channel aims to help the networks focus on abnormal voxels, even in small clusters (small lesions).

Batch normalization and the SELU activation function help self-normalize the networks, mitigating the vanishing-gradient problem commonly faced in 3D networks.
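To make the two gates more concrete, here is a simplified NumPy sketch for a single feature map of shape (X, Y, Z, C). It assumes a squeeze-and-excitation style channel gate (global average pooling, a ReLU bottleneck and a sigmoid) and a 1x1x1-convolution spatial gate; summing the two gated maps is also an assumption. DAGMNet's actual DAG layers may differ, and all function and weight names are hypothetical.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def channel_attention(feat, w1, w2):
    """cAG sketch: reweight channels. feat: (X, Y, Z, C); w1: (C, C//r); w2: (C//r, C)."""
    squeezed = feat.mean(axis=(0, 1, 2))                  # global average pool -> (C,)
    excite = sigmoid(np.maximum(squeezed @ w1, 0) @ w2)   # (C,) gate values in (0, 1)
    return feat * excite                                  # broadcast over all voxels


def spatial_attention(feat, w_s):
    """sAG sketch: a 1x1x1 convolution (per-voxel linear map, w_s: (C, 1)) as a voxel gate."""
    gate = sigmoid(feat @ w_s)                            # (X, Y, Z, 1) values in (0, 1)
    return feat * gate


def dual_attention_gate(feat, w1, w2, w_s):
    # Assumption: the channel- and spatially-gated features are summed
    return channel_attention(feat, w1, w2) + spatial_attention(feat, w_s)
```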

In [None]:
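# DAGMNet is not a linear stack of layers (it has multi-scale skip connections and
# side outputs), so it is built with the Keras functional API, not keras.Sequential.
from tensorflow import keras

inputs = keras.Input(shape=(96, 112, 48, 3))  # IP-MNI DS volume with 3 channels

# TODO: replace the placeholder head below with the DAGMNet graph:
# - UNet3+ intraskip connections
# - Fused multiscale contextual information block
# - Deep supervision (side outputs at each decoder level)
# - L1 regularization on final predictions
# - Dual attention gate (DAG)
#     - Spatial attention gate (sAG)
#     - Channel attention gate (cAG)
# - Self-normalizing activation (SELU)
# - Batch normalization
outputs = keras.layers.Conv3D(1, kernel_size=1, activation="sigmoid")(inputs)  # placeholder

dagmnet = keras.Model(inputs=inputs, outputs=outputs, name="dagmnet")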

## Train Stroke Segmentation ML/DL Models

Unsupervised Lesion Segmentation

DL DAGMNet Lesion Segmentation

- A hybrid loss function was used to train DAGMNet to mitigate the class imbalance between lesion and non-lesion voxels and to regularize the false-positive rate of the networks' predictions

$$L_{final} = L_{fuse} + \sum_{i=1}^{4} L_{side}^{i}$$

- $L_{fuse}$ is the loss function supervised at the final output of the fusion block
- $L_{side}^{i}$ is the loss function supervised at the side output of the decoder at each level $X_{de}^{i}$

$$L_{fuse} = w_{gds} L_{gds} + w_{bbc} L_{bbc} + w_{r} L_{1}(p)$$

$$L_{side}^{i} = w_{gds} L_{gds} + w_{bbc} L_{bbc}$$

- $L_{gds}$: the generalized Dice loss
- $L_{bbc}$: the balanced cross entropy
- $L_{1}(p)$: the L1 regularization over all predicted voxels

$$L_{1}(p) = \sum_{x,y,z} |p_{x,y,z}|$$

- $p_{x,y,z}$: the network's prediction at coordinates (x, y, z)
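The hybrid loss can be sketched in NumPy for a binary lesion mask `y` and a probability map `p` of the same shape. Assumptions: the generalized Dice loss uses the common two-class form with weights 1/(class volume)^2, the balanced cross entropy weights lesion voxels by the fraction of non-lesion voxels, side outputs are already upsampled to the label resolution, and the default weights `w_gds`, `w_bbc`, `w_r` are placeholders, not the values used in the paper.

```python
import numpy as np

EPS = 1e-7


def generalized_dice_loss(y, p):
    """Two-class (lesion, background) generalized Dice loss, weights 1 / (class volume)^2."""
    r = np.stack([y, 1.0 - y]).reshape(2, -1)
    q = np.stack([p, 1.0 - p]).reshape(2, -1)
    w = 1.0 / (r.sum(axis=1) ** 2 + EPS)
    numerator = 2.0 * (w * (r * q).sum(axis=1)).sum()
    denominator = (w * (r + q).sum(axis=1)).sum()
    return 1.0 - numerator / (denominator + EPS)


def balanced_cross_entropy(y, p):
    """Cross entropy with lesion voxels weighted by beta = fraction of non-lesion voxels."""
    p = np.clip(p, EPS, 1.0 - EPS)
    beta = 1.0 - y.mean()
    loss = -(beta * y * np.log(p) + (1.0 - beta) * (1.0 - y) * np.log(1.0 - p))
    return loss.mean()


def l1_regularization(p):
    return np.abs(p).sum()


def fuse_loss(y, p, w_gds=1.0, w_bbc=1.0, w_r=1e-5):
    return (w_gds * generalized_dice_loss(y, p) + w_bbc * balanced_cross_entropy(y, p)
            + w_r * l1_regularization(p))


def side_loss(y, p, w_gds=1.0, w_bbc=1.0):
    return w_gds * generalized_dice_loss(y, p) + w_bbc * balanced_cross_entropy(y, p)


def final_loss(y, p_fuse, side_preds):
    """L_final = L_fuse + sum of L_side over the four decoder side outputs."""
    return fuse_loss(y, p_fuse) + sum(side_loss(y, p_i) for p_i in side_preds)
```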

- The Adam optimizer with `learning rate = 3E-4` minimizes the loss function during training
- The learning rate is reduced by a factor of 5 when the loss function plateaus for `5 epochs`, down to a minimum `learning rate = 1E-5`
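The plateau rule can be simulated in plain Python to see how the learning rate evolves. This sketch reads "reduced by a factor of 5" as multiplying the learning rate by 0.2 and monitors the validation loss; both are assumptions. In Keras, a comparable setup would be `keras.optimizers.Adam(learning_rate=3e-4)` together with `keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.2, patience=5, min_lr=1e-5)`.

```python
def plateau_schedule(val_losses, lr=3e-4, factor=0.2, patience=5, min_lr=1e-5):
    """Return the learning rate used at each epoch under a reduce-on-plateau rule."""
    best = float("inf")
    wait = 0
    lrs = []
    for loss in val_losses:
        lrs.append(lr)                 # learning rate used during this epoch
        if loss < best:                # improvement: reset the patience counter
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience:       # plateau: divide by 5, never below min_lr
                lr = max(lr * factor, min_lr)
                wait = 0
    return lrs
```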

Dimensions of the inputs and predictions during training and inference:

- The networks were trained and run for inference in IP-MNI DS space, with volumes of 96x112x48 voxels
- All images (DWI, ADC, IS) in IP-MNI space were downsampled (DS) along the x, y and z axes with strides of (2,2,4) into 16 smaller volumes of 96x112x48x3 voxels for the 3-channel models
- The input shape of the 3-channel networks is 96x112x48x3
- During training, 1 of the 16 downsampled volumes is randomly selected as the input of the 3-channel networks for each subject in a batch (batch size = 4)

Note: This re-sampling aims to increase the network's robustness to the image's spatial shifting and inhomogeneity
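The strided split can be written with NumPy slicing. This is a sketch that assumes the IP-MNI grid is 192x224x192 (the size implied by 96x2, 112x2 and 48x4) and that channels are on the last axis; the function names are hypothetical.

```python
import numpy as np

STRIDES = (2, 2, 4)  # along x, y, z: 2 * 2 * 4 = 16 sub-volumes


def split_into_subvolumes(volume, strides=STRIDES):
    """Split an (X, Y, Z, C) volume into strided sub-volumes.

    Sub-volume (i, j, k) takes every stride-th voxel starting at offset (i, j, k),
    so every voxel lands in exactly one sub-volume.
    """
    sx, sy, sz = strides
    return [volume[i::sx, j::sy, k::sz]
            for i in range(sx) for j in range(sy) for k in range(sz)]


def sample_training_input(volume, rng):
    """During training, pick one of the 16 sub-volumes at random for a subject."""
    subvolumes = split_into_subvolumes(volume)
    return subvolumes[rng.integers(len(subvolumes))]
```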

- To make backpropagation efficient during training, the downsampled volumes were standardized to zero mean and unit variance within the brain mask region for the DWI and ADC channels
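A minimal sketch of this per-channel standardization within the brain mask is shown below. Setting voxels outside the mask to zero is an assumption that the text above does not state, and the function name is hypothetical.

```python
import numpy as np


def zscore_in_mask(channel, mask, eps=1e-8):
    """Standardize one channel (e.g. DWI or ADC) using only voxels inside the brain mask."""
    brain = channel[mask]
    out = (channel - brain.mean()) / (brain.std() + eps)
    return np.where(mask, out, 0.0)  # assumption: voxels outside the mask are set to 0
```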

Hyperparameters:

- For the hyperparameter search (e.g., loss-function weights and network structure), 20% of the subjects in the training set were randomly selected as a validation set, using the same random state for all experimental models.

- In each experiment, once the loss functions converged on the validation set over the training epochs (at most 200 epochs, stopping early at 100 epochs if the training and validation losses converged early), we selected the best model on the validation set from snapshots saved every 10 epochs.

- We chose 200 as the maximum number of training epochs since most experiments with benchmark models converged after 80 to 120 epochs.

- For each experiment, we trained networks of the same type independently at least twice, with different training sets and differently resampled validation sets, to check that similar performance was achieved and to avoid overfitting.

- Once the network parameters (including weights for the loss or regularization, network layers, depth, etc.) were finalized according to their best Dice performance on the validation sets, we used the whole training set, including the validation set, to train the final deployed models, making the loss function and Dice scores converge to a level similar to the previous experiments.

- This allowed us to fully use all training set and capture the population variation.
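The reproducible validation split and snapshot selection might look like this. The seed value, the function names and the dict-of-validation-Dice input are hypothetical; only the 20% fraction and the 10-epoch snapshot interval come from the text.

```python
import numpy as np


def split_train_val(subject_ids, val_fraction=0.2, seed=42):
    """Hold out val_fraction of subjects; the fixed seed gives the same split for every model."""
    rng = np.random.default_rng(seed)
    ids = np.array(subject_ids)
    perm = rng.permutation(len(ids))
    n_val = int(round(val_fraction * len(ids)))
    return ids[perm[n_val:]].tolist(), ids[perm[:n_val]].tolist()


def best_snapshot(val_dice_by_epoch, every=10):
    """Among epochs with a saved snapshot (every `every` epochs), return the best by validation Dice."""
    snapshot_epochs = [e for e in val_dice_by_epoch if e % every == 0]
    return max(snapshot_epochs, key=lambda e: val_dice_by_epoch[e])
```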


In [1]:
# train_stroke_seg.py

## Evaluate Stroke Segmentation ML/DL Models Quantitatively

### Performance Metrics:

- Dice = 2TP/(2TP + FN + FP)
- Precision = TP/(TP + FP)
- Sensitivity = TP/(TP + FN)
- SDR = (num subjects detected with lesions)/(num total subjects in dataset)
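These metrics can be computed per subject from binary masks. In this sketch, the handling of empty denominators is an assumption, and "detected" for SDR is assumed to mean that at least one lesion voxel is predicted; a stricter definition could require overlap with the annotation. The function names are hypothetical.

```python
import numpy as np


def segmentation_metrics(pred, truth):
    """Voxel-wise Dice, precision, sensitivity, FP and FN for one subject's binary masks."""
    pred, truth = pred.astype(bool), truth.astype(bool)
    tp = np.sum(pred & truth)
    fp = np.sum(pred & ~truth)
    fn = np.sum(~pred & truth)
    denom = 2 * tp + fn + fp
    dice = 2 * tp / denom if denom > 0 else 1.0           # both masks empty: perfect agreement
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    return {"dice": float(dice), "precision": float(precision),
            "sensitivity": float(sensitivity), "fp": int(fp), "fn": int(fn)}


def subject_detection_rate(preds):
    """SDR: fraction of subjects for which the model predicts any lesion voxel."""
    return sum(bool(np.any(p)) for p in preds) / len(preds)
```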

### Compare 3D DAGMNet against SoA Models:

- FCN
- UNet
- DeepMedic

"CH3" models, which use 3-channel inputs (DWI+ADC+IS), performed better than "CH2" models (DWI+ADC).

### Compare Models Dice's Probability Density using Lesion Sizes

Compare the distribution of **Dice** scores across lesion sizes. We can do the same with the ATLAS 2.0 dataset.

- All lesions, n = 459
- Small lesions, n = 152, volume < 1.7 ml
- Medium lesions, n = 144, 1.7 ml <= volume < 14 ml
- Large lesions, n = 163, volume >= 14 ml
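Assigning each case to a size group only needs the lesion volume. This is a sketch that assumes masks are on a 1x1x1 mm^3 grid (one voxel is 1 mm^3, so 1 ml = 1000 voxels) and half-open group boundaries; the function names are hypothetical.

```python
import numpy as np


def lesion_volume_ml(mask, voxel_volume_mm3=1.0):
    """Lesion volume in ml (1 ml = 1000 mm^3)."""
    return float(np.count_nonzero(mask)) * voxel_volume_mm3 / 1000.0


def lesion_size_group(volume_ml):
    """Small: < 1.7 ml; medium: 1.7 ml to < 14 ml; large: >= 14 ml."""
    if volume_ml < 1.7:
        return "small"
    if volume_ml < 14.0:
        return "medium"
    return "large"
```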

### Compare Models Metrics (Dice, Prec, Sen, FP, FN) using Lesion Sizes

Compare Dice, precision, sensitivity, FP and FN across the same lesion-size groups (small, medium and large lesions). We can do the same with the ATLAS 2.0 dataset.

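### Compare Models Metrics (Dice, Prec, Sen, FP, FN) in Boxplot using Lesion Location

- ACA (anterior cerebral artery territory)
- MCA (middle cerebral artery territory)
- PCA (posterior cerebral artery territory)
- VB (vertebrobasilar territory)

I will have to check whether ATLAS provides these lesion locations.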

### Compare Models Metrics (Dice, Prec, Sen, FP, FN) in Boxplot using Lesion Hemisphere

- Left
- Right
- Bilateral

### Compare Models Metrics (Dice, Prec, Sen, FP, FN) in Boxplot using Sex

- Female
- Male

### Compare Models Metrics (Dice, Prec, Sen, FP, FN) in Boxplot using Demographic

- African American
- Caucasian

### Compare Models Metrics (Dice, Prec, Sen, FP, FN) in Boxplot using MRI Manufacturer

- Siemens
- FE

### Compare Models Metrics (Dice, Prec, Sen, FP, FN) in Boxplot using MRI Magnetic Field

- 1.5T
- 3.0T

### Compare Models Metrics (Dice, Prec, Sen, FP, FN) in Boxplot using Symptom Onset Time

- Onset time < 6h
- Onset time >= 6h


### Compare Models Metrics (Dice, Prec, Sen, FP, FN) in Boxplot using ATLAS Train & Test

We could later use ATLAS the way Dr. Faria's group used the external STIR dataset, comparing it with their ICPSR clinical 3D stroke dataset.

- ATLAS 2.0 Train
- ATLAS 2.0 Test

In [None]:
# evaluate_brain_models.py

## Evaluate Stroke Segmentation ML/DL Models Qualitatively

### Compare Models Illustrative Segmentations on Various Stroke Cases

- Col Names: DWI, ADC, Annotation, DeepMedic, DAGMNet_CH3, UNet_CH3, FCN_CH3

Columns from left to right: DWI, ADC, and overlays on DWI of the manual delineation (red), DeepMedic predictions (yellow), DAGMNet_CH3 (proposed model) predictions (blue), UNet_CH3 (green) and FCN_CH3 (purple).

Each row shows a different stroke case:

- Row A: Typical lesion
- Row B: Case with inhomogeneity between DWI slices; note the high agreement of proposed model with manual annotation
- Row C: Multifocal lesions
- Row D: Small cortical lesion, detected exclusively by DeepMedic and our proposed attention model
- Rows E-G: Typical false positives (arrows) of other models (DeepMedic in particular) in areas of:
    - E: "physiological" high DWI intensities;
    - F: DWI artifacts at tissue interfaces; in addition to the cortical areas, this was frequently observed in the basal brain, along the sinus interfaces, and in the choroid plexus;
    - G: possible chronic microvascular white matter lesions
- Rows H, I: Cases in which the retrospective analysis favored the automated prediction, rather than the human evaluation for: 
    - H: lesion delineation 
    - I: lesion prediction (this case was initially categorized by evaluators as “not visible" lesion, but the small lesion predicted by our model was confirmed by follow-up). 
- Row J: Lesion with a high-intensity core but subtle boundary contrast, which reduces the discriminative power of all 3D networks

In [None]:
# evaluate_brain_models.py

## Deploy Stroke Segmentation DL Model for Inference


- During inference, the 16 lesion predictions from the networks in IP-MNI DS space were stacked (interleaved) according to how their input volumes were downsampled from the original coordinates, to construct the final prediction in IP-MNI space.
- The lesion mask was then morphologically "closed" with a connectivity of 1 voxel.
- The predicted binary lesion mask (predicted_value > 0.5) in IP-MNI space was mapped back to the individual's original space.
- We refined the final prediction by removing clusters with fewer than 5 pixels in each slice, which is the smallest lesion size defined by human evaluators.
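The deployment steps can be sketched with NumPy and `scipy.ndimage`. Assumptions: the 16 predictions are ordered by the same (x, y, z) offsets used to split the input, thresholding happens before the binary closing, slices run along the last (axial) axis, and the mapping back to native space (which needs the stored registration) is omitted, so cluster removal is shown in IP-MNI space. The function names are hypothetical.

```python
import numpy as np
from scipy import ndimage

STRIDES = (2, 2, 4)


def reassemble(sub_preds, full_shape, strides=STRIDES):
    """Interleave the 16 DS predictions back into the full IP-MNI grid."""
    full = np.zeros(full_shape, dtype=np.float32)
    sx, sy, sz = strides
    offsets = [(i, j, k) for i in range(sx) for j in range(sy) for k in range(sz)]
    for (i, j, k), pred in zip(offsets, sub_preds):
        full[i::sx, j::sy, k::sz] = pred
    return full


def postprocess(prob, threshold=0.5, min_pixels=5):
    """Threshold, close with a connectivity-1 structure, then drop small per-slice clusters."""
    mask = prob > threshold
    mask = ndimage.binary_closing(mask, structure=ndimage.generate_binary_structure(3, 1))
    for z in range(mask.shape[2]):
        labels, _ = ndimage.label(mask[:, :, z])
        sizes = np.bincount(labels.ravel())
        small = sizes < min_pixels
        small[0] = False                      # label 0 is background
        mask[:, :, z][small[labels]] = False  # view of mask, so this edits it in place
    return mask
```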

In [None]:
# deploy_stroke_seg.py