<img src="https://miro.medium.com/max/2652/1*eTkBMyqdg9JodNcG_O4-Kw.jpeg" width="100%">
[Image Source](https://medium.com/stanford-ai-for-healthcare/its-a-no-brainer-deep-learning-for-brain-mr-images-f60116397472)

# Brain Tumor Auto-Segmentation for Magnetic Resonance Imaging (MRI)

Welcome to the final part of course 1 of "Artificial Intelligence for Medicine"!

You will learn how to build a neural network to automatically segment tumor regions in the brain, using [MRI (Magnetic Resonance Imaging)](https://en.wikipedia.org/wiki/Magnetic_resonance_imaging) scans.

The MRI scan is one of the most common imaging modalities that we encounter in the radiology field. Other data modalities include:
- [Computed Tomography (CT)](https://en.wikipedia.org/wiki/CT_scan)
- [Ultrasound](https://en.wikipedia.org/wiki/Ultrasound)
- [X-Rays](https://en.wikipedia.org/wiki/X-ray)

In this assignment we will focus on MRIs, but much of what we learn applies to the other modalities mentioned above as well. We'll walk you through some of the steps of training a deep learning model for segmentation.

**You will learn:**
- What is in an MR image
- Standard data preparation techniques for MRI datasets
- Metrics and loss functions for segmentation
- Visualizing and evaluating segmentation models

## Table of Contents
- [0. Packages](#0)
- [1. Dataset](#1)
  - [1.1 What is an MRI?](#1-1)
  - [1.2 MRI Data Processing](#1-2)
  - [1.3 Exploring the Dataset](#1-3)
  - [1.4 Data Preprocessing using Patches](#1-4)
    - [Exercise 1 - get_sub_volume](#ex-1)
    - [Exercise 2 - standardization](#ex-2)
- [2. 3D U-Net Model](#2)
- [3. Metrics](#3)
  - [3.1 Dice Coefficient](#3-1)
    - [Exercise 3 - single_class_dice_coefficient](#ex-3)
    - [3.1.1 Dice Coefficient for Multiple Classes](#3-1-1)
      - [Exercise 4 - dice_coefficient](#ex-4)
  - [3.2 Soft Dice Loss](#3-2)
    - [3.2.1 Multi-Class Soft Dice Loss](#3-2-1)
      - [Exercise 5 - soft_dice_loss](#ex-5)
- [4. Create and Train the Model](#4)
  - [4.1 Training on a Large Dataset](#4-1)
  - [4.2 Loading a Pre-Trained Model](#4-2)
- [5. Evaluation](#5)
  - [5.1 Overall Performance](#5-1)
  - [5.2 Patch-level Predictions](#5-2)
    - [5.2.1 Sensitivity and Specificity](#5-2-1)
      - [Exercise 6 - compute_class_sens_spec](#ex-6)
  - [5.3 Running on Entire Scans](#5-3)

## 0. Packages <a name="0"></a>
We'll use keras, nibabel, numpy, pandas, matplotlib, tensorflow.keras.backend (K), and some provided utilities.


In [None]:
import keras
import json
import numpy as np
import pandas as pd
import nibabel as nib
import matplotlib.pyplot as plt
from tensorflow.keras import backend as K 

import util
from public_tests import *
from test_utils import *

import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

## 1. Dataset <a name="1"></a>
### 1.1 What is an MRI? <a name="1-1"></a>
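MRI is a 3D imaging technique: a scan is a volume of voxels rather than a flat image. Each voxel can be imaged with multiple sequences, which we treat as channels. We'll segment three classes of abnormality: edemas, non-enhancing tumors, and enhancing tumors.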
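
To make the data layout concrete, here is a minimal sketch using a synthetic, scaled-down volume instead of a real scan. The shapes are assumptions: the image is taken to be `(x, y, z, sequences)` with 4 sequences per voxel, and the label is `(x, y, z)` with integer classes 0 to 3 (0 for background, then the three abnormality classes). Real cases loaded with `load_case` are much larger.

```python
import numpy as np

NUM_SEQUENCES = 4  # assumed number of MR sequences (channels) per voxel
NUM_CLASSES = 4    # background + edema + non-enhancing + enhancing tumor

# Synthetic stand-in for one case (toy size, random values)
rng = np.random.default_rng(0)
image = rng.normal(size=(24, 24, 15, NUM_SEQUENCES))
label = rng.integers(0, NUM_CLASSES, size=(24, 24, 15))

# One voxel carries one intensity per sequence
voxel = image[12, 12, 7]
print("image:", image.shape, "label:", label.shape, "voxel:", voxel.shape)

# Number of voxels assigned to each class in the label volume
counts = np.bincount(label.ravel(), minlength=NUM_CLASSES)
print("voxels per class:", counts)
```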

In [None]:
HOME_DIR = "data/BraTS-Data/"
DATA_DIR = HOME_DIR

def load_case(image_nifty_file, label_nifty_file):
    image = np.array(nib.load(image_nifty_file).get_fdata())
    label = np.array(nib.load(label_nifty_file).get_fdata())
    return image, label

In [None]:
# Visualize a case
image, label = load_case(DATA_DIR + "imagesTr/BRATS_003.nii.gz", DATA_DIR + "labelsTr/BRATS_003.nii.gz")
image = util.get_labeled_image(image, label)
util.plot_image_grid(image)

In [None]:
# Animated gif of MRI
image, label = load_case(DATA_DIR + "imagesTr/BRATS_003.nii.gz", DATA_DIR + "labelsTr/BRATS_003.nii.gz")
util.visualize_data_gif(util.get_labeled_image(image, label))

### 1.4 Data Preprocessing using Patches <a name="1-4"></a>
We'll extract random sub-volumes (patches) from each scan and standardize their intensities.
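
A patch is only useful for training if it is not almost entirely background. The sketch below shows one way to measure the background fraction of a label patch. The helper name `background_ratio` is hypothetical, and it one-hot encodes with `np.eye` rather than `keras.utils.to_categorical`, so treat it as an illustration of the idea and not as the graded solution.

```python
import numpy as np

def background_ratio(label_patch, num_classes=4):
    """Fraction of voxels in an integer label patch that belong to class 0."""
    # Indexing an identity matrix one-hot encodes: shape (..., num_classes)
    one_hot = np.eye(num_classes)[label_patch.astype(int)]
    return one_hot[..., 0].sum() / label_patch.size

# Toy 4 x 4 x 2 label patch: 32 voxels, 3 of them non-background
patch = np.zeros((4, 4, 2), dtype=int)
patch[0, 0, :] = 1  # two edema voxels
patch[1, 1, 0] = 3  # one enhancing-tumor voxel

ratio = background_ratio(patch)
print(f"background ratio: {ratio:.4f}")
print("accepted:", ratio < 0.95)
```

With a threshold of 0.95, this toy patch would be accepted because 29 of its 32 voxels (about 91%) are background.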

In [None]:
# UNQ_C1
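def get_sub_volume(image, label, orig_x=240, orig_y=240, orig_z=155, output_x=160, output_y=160, output_z=16, num_classes=4, max_tries=1000, background_threshold=0.95):
    # image: (orig_x, orig_y, orig_z, num_channels); label: (orig_x, orig_y, orig_z)
    X = None
    y = None
    tries = 0
    while tries < max_tries:
        # Sample the patch corner so that the whole patch fits inside the volume
        start_x = np.random.randint(0, orig_x - output_x + 1)
        start_y = np.random.randint(0, orig_y - output_y + 1)
        start_z = np.random.randint(0, orig_z - output_z + 1)
        # Crop the label and one-hot encode it: (output_x, output_y, output_z, num_classes)
        y = label[start_x:start_x+output_x, start_y:start_y+output_y, start_z:start_z+output_z]
        y = keras.utils.to_categorical(y, num_classes=num_classes)
        # Fraction of voxels in the patch that are background (class 0)
        bgrd_ratio = np.sum(y[:, :, :, 0]) / (output_x * output_y * output_z)
        tries += 1
        # Accept the patch only if it is not almost entirely background
        if bgrd_ratio < background_threshold:
            X = np.copy(image[start_x:start_x+output_x, start_y:start_y+output_y, start_z:start_z+output_z, :])
            # Move channels (X) and classes (y) to the first axis
            X = np.transpose(X, (3, 0, 1, 2))
            y = np.transpose(y, (3, 0, 1, 2))
            # Drop the background class, keeping only the abnormality classes
            y = y[1:, :, :, :]
            return X, y
    # No acceptable patch was found; the function falls through and returns None
    print(f"Tried {tries} times to find a sub-volume. Giving up...")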

In [None]:
get_sub_volume_test(get_sub_volume)

In [None]:
image, label = load_case(DATA_DIR + "imagesTr/BRATS_001.nii.gz", DATA_DIR + "labelsTr/BRATS_001.nii.gz")
X, y = get_sub_volume(image, label)
util.visualize_patch(X[0, :, :, :], y[2])

In [None]:
# UNQ_C2
def standardize(image):
    # image: (num_channels, dim_x, dim_y, dim_z)
    standardized_image = np.zeros_like(image)
    for c in range(image.shape[0]):
        for z in range(image.shape[3]):
            # Standardize each (channel, z) plane to zero mean and unit variance
            image_slice = image[c, :, :, z]
            centered = image_slice - np.mean(image_slice)
            # Constant slices have zero std; leave them centered to avoid dividing by zero
            if np.std(centered) != 0:
                centered_scaled = centered / np.std(centered)
            else:
                centered_scaled = centered
            standardized_image[c, :, :, z] = centered_scaled
    return standardized_image

In [None]:
standardize_test(standardize, X)

In [None]:
X_norm = standardize(X)
util.visualize_patch(X_norm[0, :, :, :], y[2])

## 2. 3D U-Net Model <a name="2"></a>
We'll use util.unet_model_3d(loss_function) to build the model.

## 3. Metrics <a name="3"></a>
### 3.1 Dice Similarity Coefficient <a name="3-1"></a>
#### Exercise 3 - single_class_dice_coefficient <a name="ex-3"></a>
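
For two binary masks $A$ (ground truth) and $B$ (prediction), the Dice coefficient is $2|A \cap B| / (|A| + |B|)$, so it is 1 for a perfect match and close to 0 when the masks do not overlap. Here is a small worked example in plain NumPy. The helper name `dice_numpy` is hypothetical and only illustrates the arithmetic; the exercise itself uses the Keras backend.

```python
import numpy as np

def dice_numpy(y_true, y_pred, epsilon=1e-5):
    """Dice coefficient for two binary masks of the same shape."""
    intersection = np.sum(y_true * y_pred)
    # epsilon keeps the ratio defined when both masks are empty
    return (2 * intersection + epsilon) / (np.sum(y_true) + np.sum(y_pred) + epsilon)

y_true = np.array([[1, 1, 0],
                   [0, 1, 0]])
y_pred = np.array([[1, 0, 0],
                   [0, 1, 1]])

# Overlap is 2 voxels; each mask has 3 positive voxels: 2 * 2 / (3 + 3)
score = dice_numpy(y_true, y_pred)
print(f"Dice: {score:.4f}")
```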

In [None]:
# UNQ_C3
def single_class_dice_coefficient(y_true, y_pred, axis=(0, 1, 2), epsilon=0.00001):
    dice_numerator = 2 * K.sum(y_true * y_pred, axis=axis) + epsilon
    dice_denominator = K.sum(y_true, axis=axis) + K.sum(y_pred, axis=axis) + epsilon
    dice_coefficient = dice_numerator / dice_denominator
    return dice_coefficient

In [None]:
epsilon = 1
sess = K.get_session()
single_class_dice_coefficient_test(single_class_dice_coefficient, epsilon, sess)

#### 3.1.1 Dice Coefficient for Multiple Classes <a name="3-1-1"></a>
#### Exercise 4 - dice_coefficient <a name="ex-4"></a>
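
With several classes, the Dice coefficient is computed per class and then averaged. The following sketch shows that in plain NumPy, assuming inputs of shape `(num_classes, x, y, z)`. The helper name `mean_dice_numpy` is hypothetical.

```python
import numpy as np

def mean_dice_numpy(y_true, y_pred, axis=(1, 2, 3), epsilon=1e-5):
    """Average of per-class Dice coefficients; axis 0 indexes the classes."""
    numerator = 2 * np.sum(y_true * y_pred, axis=axis) + epsilon
    denominator = np.sum(y_true, axis=axis) + np.sum(y_pred, axis=axis) + epsilon
    return np.mean(numerator / denominator)

# Two classes on a 2 x 2 x 1 volume
y_true = np.zeros((2, 2, 2, 1))
y_true[0, 0, :, 0] = 1   # class 0: two positive voxels
y_true[1, 1, 1, 0] = 1   # class 1: one positive voxel

y_pred = np.zeros_like(y_true)
y_pred[0, 0, 0, 0] = 1   # class 0: only one of the two voxels is found
y_pred[1, 1, 1, 0] = 1   # class 1: exact match

# Class 0 scores 2 * 1 / (2 + 1); class 1 scores 2 * 1 / (1 + 1)
mean_score = mean_dice_numpy(y_true, y_pred)
print(f"Mean Dice: {mean_score:.4f}")
```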

In [None]:
# UNQ_C4
def dice_coefficient(y_true, y_pred, axis=(1, 2, 3), epsilon=0.00001):
    dice_numerator = 2 * K.sum(y_true * y_pred, axis=axis) + epsilon
    dice_denominator = K.sum(y_true, axis=axis) + K.sum(y_pred, axis=axis) + epsilon
    dice_coefficient = K.mean(dice_numerator / dice_denominator)
    return dice_coefficient

In [None]:
epsilon = 1
sess = K.get_session()
dice_coefficient_test(dice_coefficient, epsilon, sess)

### 3.2 Soft Dice Loss <a name="3-2"></a>
#### Exercise 5 - soft_dice_loss <a name="ex-5"></a>
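
The soft Dice loss works directly on predicted probabilities instead of thresholded masks, which keeps it differentiable. It squares the terms in the denominator and subtracts the mean per-class ratio from 1, so lower is better. The sketch below mirrors that arithmetic in plain NumPy. The helper name `soft_dice_loss_numpy` is hypothetical, and the inputs are assumed to have shape `(num_classes, x, y, z)`.

```python
import numpy as np

def soft_dice_loss_numpy(y_true, y_pred, axis=(1, 2, 3), epsilon=1e-5):
    """1 minus the mean per-class soft Dice ratio."""
    numerator = 2 * np.sum(y_true * y_pred, axis=axis) + epsilon
    denominator = np.sum(y_true ** 2, axis=axis) + np.sum(y_pred ** 2, axis=axis) + epsilon
    return 1 - np.mean(numerator / denominator)

# One class on a 2 x 2 x 1 volume with two positive voxels
y_true = np.zeros((1, 2, 2, 1))
y_true[0, 0, :, 0] = 1

confident = y_true * 0.9 + (1 - y_true) * 0.1  # mostly right, high confidence
unsure = np.full_like(y_true, 0.5)             # 0.5 everywhere

loss_perfect = soft_dice_loss_numpy(y_true, y_true)
loss_confident = soft_dice_loss_numpy(y_true, confident)
loss_unsure = soft_dice_loss_numpy(y_true, unsure)
print(f"perfect: {loss_perfect:.4f}, confident: {loss_confident:.4f}, unsure: {loss_unsure:.4f}")
```

The loss shrinks as the predicted probabilities move toward the ground truth, which is the behavior we want from a training objective.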

In [None]:
# UNQ_C5
def soft_dice_loss(y_true, y_pred, axis=(1, 2, 3), epsilon=0.00001):
    dice_numerator = 2 * K.sum(y_true * y_pred, axis=axis) + epsilon
    dice_denominator = K.sum(y_true ** 2, axis=axis) + K.sum(y_pred ** 2, axis=axis) + epsilon
    dice_loss = 1 - K.mean(dice_numerator / dice_denominator)
    return dice_loss

In [None]:
epsilon = 1
sess = K.get_session()
soft_dice_loss_test(soft_dice_loss, epsilon, sess)

## 4. Create and Train the Model <a name="4"></a>
We'll use util.unet_model_3d(loss_function=soft_dice_loss, metrics=[dice_coefficient])

In [None]:
model = util.unet_model_3d(loss_function=soft_dice_loss, metrics=[dice_coefficient])

In [None]:
base_dir = HOME_DIR + "processed/"
with open(base_dir + "config.json") as json_file:
    config = json.load(json_file)
train_generator = util.VolumeDataGenerator(config["train"], base_dir + "train/", batch_size=3, dim=(160, 160, 16), verbose=0)
valid_generator = util.VolumeDataGenerator(config["valid"], base_dir + "valid/", batch_size=3, dim=(160, 160, 16), verbose=0)

In [None]:
model.load_weights(HOME_DIR + "model_pretrained.hdf5")

In [None]:
model.summary()

## 5. Evaluation <a name="5"></a>
### 5.2 Patch-level Predictions <a name="5-2"></a>

In [None]:
util.visualize_patch(X_norm[0, :, :, :], y[2])

In [None]:
X_norm_with_batch_dimension = np.expand_dims(X_norm, axis=0)
patch_pred = model.predict(X_norm_with_batch_dimension)

In [None]:
threshold = 0.5
patch_pred[patch_pred > threshold] = 1.0
patch_pred[patch_pred <= threshold] = 0.0

In [None]:
print("Patch and ground truth")
util.visualize_patch(X_norm[0, :, :, :], y[2])
plt.show()
print("Patch and prediction")
util.visualize_patch(X_norm[0, :, :, :], patch_pred[0, 2, :, :, :])
plt.show()

#### 5.2.1 Sensitivity and Specificity <a name="5-2-1"></a>
##### Exercise 6 - compute_class_sens_spec <a name="ex-6"></a>
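
Sensitivity is the fraction of truly positive voxels that the model marks as positive, TP / (TP + FN). Specificity is the fraction of truly negative voxels that the model marks as negative, TN / (TN + FP). Here is a worked example on a tiny one-dimensional "volume", using made-up values chosen only for illustration.

```python
import numpy as np

label = np.array([1, 1, 1, 0, 0, 0, 0, 0])  # ground truth for one class
pred = np.array([1, 1, 0, 1, 0, 0, 0, 0])   # thresholded prediction

tp = np.sum((label == 1) & (pred == 1))  # correctly detected positives
tn = np.sum((label == 0) & (pred == 0))  # correctly rejected negatives
fp = np.sum((label == 0) & (pred == 1))  # false alarms
fn = np.sum((label == 1) & (pred == 0))  # missed positives

sensitivity = tp / (tp + fn)
specificity = tn / (tn + fp)
print(f"TP={tp} TN={tn} FP={fp} FN={fn}")
print(f"Sensitivity: {sensitivity:.4f}, Specificity: {specificity:.4f}")
```

Note that both ratios are undefined when their denominator is zero, for example when a patch contains no positive voxels for the class.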

In [None]:
# UNQ_C6
def compute_class_sens_spec(pred, label, class_num):
    class_pred = pred[class_num]
    class_label = label[class_num]
    tp = np.sum((class_label == 1) & (class_pred == 1))
    tn = np.sum((class_label == 0) & (class_pred == 0))
    fp = np.sum((class_label == 0) & (class_pred == 1))
    fn = np.sum((class_label == 1) & (class_pred == 0))
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)
    return sensitivity, specificity

In [None]:
compute_class_sens_spec_test(compute_class_sens_spec)

In [None]:
sensitivity, specificity = compute_class_sens_spec(patch_pred[0], y, 2)
print(f"Sensitivity: {sensitivity:.4f}")
print(f"Specificity: {specificity:.4f}")

In [None]:
def get_sens_spec_df(pred, label):
    patch_metrics = pd.DataFrame(
        columns=['Edema', 'Non-Enhancing Tumor', 'Enhancing Tumor'],
        index=['Sensitivity', 'Specificity'])
    for i, class_name in enumerate(patch_metrics.columns):
        sens, spec = compute_class_sens_spec(pred, label, i)
        patch_metrics.loc['Sensitivity', class_name] = round(sens, 4)
        patch_metrics.loc['Specificity', class_name] = round(spec, 4)
    return patch_metrics

In [None]:
df = get_sens_spec_df(patch_pred[0], y)
print(df)

### 5.3 Running on Entire Scans <a name="5-3"></a>
To run on whole scans, see util.predict_and_viz.
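
The general idea of whole-scan inference is to tile the volume into patches, predict each patch, and stitch the results back together. The sketch below illustrates that idea only; it is not the implementation of `util.predict_and_viz`. The names `predict_volume` and `dummy_predict_patch` are hypothetical, the image is assumed to be channels-first `(channels, x, y, z)`, and a trivial stand-in replaces the trained model so the example runs on its own.

```python
import numpy as np

def predict_volume(image, predict_patch, patch_shape, num_classes=3):
    """Tile a channels-first volume into non-overlapping patches and stitch the predictions."""
    channels = image.shape[0]
    dim_x, dim_y, dim_z = image.shape[1:]
    px, py, pz = patch_shape
    output = np.zeros((num_classes, dim_x, dim_y, dim_z))
    for x in range(0, dim_x, px):
        for y in range(0, dim_y, py):
            for z in range(0, dim_z, pz):
                crop = image[:, x:x + px, y:y + py, z:z + pz]
                sx, sy, sz = crop.shape[1:]
                # Zero-pad patches at the border up to the full patch size
                patch = np.zeros((channels, px, py, pz))
                patch[:, :sx, :sy, :sz] = crop
                pred = predict_patch(patch)
                # Keep only the part of the prediction that lies inside the volume
                output[:, x:x + sx, y:y + sy, z:z + sz] = pred[:, :sx, :sy, :sz]
    return output

def dummy_predict_patch(patch):
    """Stand-in for a model: marks voxels whose first channel is positive, for every class."""
    mask = (patch[0] > 0).astype(float)
    return np.stack([mask] * 3)

image = np.random.default_rng(0).normal(size=(4, 5, 5, 3))
full_pred = predict_volume(image, dummy_predict_patch, patch_shape=(4, 4, 2))
print(full_pred.shape)
```

With the trained model, `predict_patch` could be something like `lambda p: model.predict(np.expand_dims(p, axis=0))[0]`, using the same batch-dimension handling as the patch-level prediction earlier in this section.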

# That's all for now!
Congratulations on finishing this challenging assignment! You now know the basics of building a neural network for automatic segmentation of MR images.