<a href="https://colab.research.google.com/github/IAmSuyogJadhav/3d-mri-brain-tumor-segmentation-using-autoencoder-regularization/blob/master/NVDLMED_Implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 3D MRI Brain Tumor Segmentation using autoencoder regularization

# Colab Stuff

In [3]:
from google.colab import drive
drive.mount('/gdrive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /gdrive


In [0]:
import zipfile
zf = zipfile.ZipFile('/gdrive/My Drive/MICCAI_BraTS_2018_Data_Training.zip')
zf.extractall()

In [5]:
# Implementation og GroupNorm from https://github.com/titu1994/Keras-Group-Normalization/blob/master/group_norm.py
!wget https://raw.githubusercontent.com/titu1994/Keras-Group-Normalization/master/group_norm.py

--2019-04-09 08:44:01--  https://raw.githubusercontent.com/titu1994/Keras-Group-Normalization/master/group_norm.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 7952 (7.8K) [text/plain]
Saving to: ‘group_norm.py’


2019-04-09 08:44:02 (86.5 MB/s) - ‘group_norm.py’ saved [7952/7952]



# Imports and helper functions

In [6]:
import keras
from keras.layers import *
from keras.models import Model, Sequential
from group_norm import GroupNormalization

Using TensorFlow backend.


In [0]:
def green_block(inp, filters, name=None):
    
  inp_res = Conv3D(filters=filters, kernel_size=(1, 1, 1), strides=1, data_format='channels_first', name=f'Res_{name}' if name else None)(inp)
  
  # axis=1, because we have channels_first data format
  # No. of groups = 8, as given in the paper
  x = GroupNormalization(groups=8, axis=1, name=f'GroupNorm_1_{name}' if name else None)(inp)
  x = Activation('relu', name=f'Relu_1_{name}' if name else None)(x)
  x = Conv3D(filters=filters, kernel_size=(3, 3, 3), strides=1, padding='same', data_format='channels_first', name=f'Conv3D_1_{name}' if name else None)(x)
  
  x = GroupNormalization(groups=8, axis=1, name=f'GroupNorm_2_{name}' if name else None)(x)
  x = Activation('relu', name=f'Relu_2_{name}' if name else None)(x)
  x = Conv3D(filters=filters, kernel_size=(3, 3, 3), strides=1, padding='same', data_format='channels_first', name=f'Conv3D_2_{name}' if name else None)(x)
  
  out = Add(name=f'Out_{name}' if name else None)([x, inp_res])
  
  return out

In [0]:
# From keras/examples/variational_autoencoder.py
def sampling(args):
    """Reparameterization trick by sampling from an isotropic unit Gaussian.
    # Arguments
        args (tensor): mean and log of variance of Q(z|X)
    # Returns
        z (tensor): sampled latent vector
    """

    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    # by default, random_normal has mean = 0 and std = 1.0
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

# Model

## TODO: 
- Decoder- VAE Part (Lower Branch)
- Loss function - <br />
L = L<sub>dice</sub> + 0.1 ∗ L<sub>L2</sub> + 0.1 ∗ L<sub>KL</sub> 

![The model architecture](https://suyogjadhav.com/images/misc/brats2018_sota_model.png)
<center><b>The Model Architecture</b><br />Source: https://arxiv.org/pdf/1810.11654.pdf</center>

<hr />
![The Decoder Structure](https://dev.suyogjadhav.com/images/misc/brats2018_sota_decoder.png)
<center><b>The Decoder Structure</b><br />Source: https://arxiv.org/pdf/1810.11654.pdf</center>

In [29]:
input_shape = (4, 160, 192, 128)
# ----------------------------------------------------------------------------

# Encoder

# ----------------------------------------------------------------------------

## Input
inp = Input(input_shape)

## Initial Blue Block
x = Conv3D(
    filters=32,
    kernel_size=(3, 3, 3),
    strides=1,
    padding='same',
    data_format='channels_first',
    name='Input_x1')(inp)

## Dropout (0.2)
x = Dropout(0.2)(x)

## Green Blocks x1 (output filters = 32)
x1 = green_block(x, 32, name='x1')
x = Conv3D(
    filters=32,
    kernel_size=(3, 3, 3),
    strides=2,
    padding='same',
    data_format='channels_first',
    name='Enc_DownSample_32')(x1)

## Green Blocks x2 (output filters = 64)
x = green_block(x, 64, name='Enc_64_1')
x2 = green_block(x, 64, name='x2')
x = Conv3D(
    filters=64,
    kernel_size=(3, 3, 3),
    strides=2,
    padding='same',
    data_format='channels_first',
    name='Enc_DownSample_64')(x2)

## Green Blocks x2 (output filters = 128)
x = green_block(x, 128, name='Enc_128_1')
x3 = green_block(x, 128, name='x3')
x = Conv3D(
    filters=128,
    kernel_size=(3, 3, 3),
    strides=2,
    padding='same',
    data_format='channels_first',
    name='Enc_DownSample_128')(x3)

## Green Blocks x4 (output filters = 256)
x = green_block(x, 256, name='Enc_256_1')
x = green_block(x, 256, name='Enc_256_2')
x = green_block(x, 256, name='Enc_256_3')
x4 = green_block(x, 256, name='x4')

# ----------------------------------------------------------------------------

# Decoder

# ----------------------------------------------------------------------------

## GT (Groud Truth) Part
# ----------------------------------------------------------------------------

### Green Block x1 (output filters=128)
x = Conv3D(
    filters=128,
    kernel_size=(1, 1, 1),
    strides=1,
    data_format='channels_first',
    name='Dec_GT_ReduceDepth_128')(x4)
x = UpSampling3D(
    size=2,
    data_format='channels_first',
    name='Dec_GT_UpSample_128')(x)
x = Add(name='Input_Dec_GT_128')([x, x3])
x = green_block(x, 128, name='Dec_GT_128')

### Green Block x1 (output filters=64)
x = Conv3D(
    filters=64,
    kernel_size=(1, 1, 1),
    strides=1,
    data_format='channels_first',
    name='Dec_GT_ReduceDepth_64')(x)
x = UpSampling3D(
    size=2,
    data_format='channels_first',
    name='Dec_GT_UpSample_64')(x)
x = Add(name='Input_Dec_GT_64')([x, x2])
x = green_block(x, 64, name='Dec_GT_64')

### Green Block x1 (output filters=32)
x = Conv3D(
    filters=32,
    kernel_size=(1, 1, 1),
    strides=1,
    data_format='channels_first',
    name='Dec_GT_ReduceDepth_32')(x)
x = UpSampling3D(
    size=2,
    data_format='channels_first',
    name='Dec_GT_UpSample_32')(x)
x = Add(name='Input_Dec_GT_32')([x, x1])
x = green_block(x, 32, name='Dec_GT_32')

### Blue Block x1 (output filters=32)
x = Conv3D(
    filters=32,
    kernel_size=(3, 3, 3),
    strides=1,
    padding='same',
    data_format='channels_first',
    name='Input_Dec_GT_Output')(x)

### Output Block
out1 = Conv3D(
    filters=3,  # No. of tumor classes is 3
    kernel_size=(1, 1, 1),
    strides=1,
    data_format='channels_first',
    activation='sigmoid',
    name='Dec_GT_Output')(x)

# ----------------------------------------------------------------------------

## VAE (Variational Auto Encoder) Part
# ----------------------------------------------------------------------------

### VD Block
x = GroupNormalization(groups=8, axis=1, name='Dec_VAE_VD')(x4)
x = Activation('relu')(x)
x = Conv3D(
    filters=16,
    kernel_size=(3, 3, 3),
    strides=2,
    padding='same',
    data_format='channels_first')(x)

### Something missing here. The author has been contacted, awaiting reply...

x = Dense(256)(x)

### VDraw Block (Sampling)
z_mean = Dense(128)

out = [out1, x]
Model(inp, out).summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_11 (InputLayer)           (None, 4, 160, 192,  0                                            
__________________________________________________________________________________________________
Input_x1 (Conv3D)               (None, 32, 160, 192, 3488        input_11[0][0]                   
__________________________________________________________________________________________________
dropout_11 (Dropout)            (None, 32, 160, 192, 0           Input_x1[0][0]                   
__________________________________________________________________________________________________
GroupNorm_1_x1 (GroupNormalizat (None, 32, 160, 192, 64          dropout_11[0][0]                 
__________________________________________________________________________________________________
Relu_1_x1 

In [30]:
out[1].shape

TensorShape([Dimension(None), Dimension(16), Dimension(10), Dimension(12), Dimension(256)])