<a href="https://colab.research.google.com/github/IAmSuyogJadhav/3d-mri-brain-tumor-segmentation-using-autoencoder-regularization/blob/master/NVDLMED_Implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 3D MRI Brain Tumor Segmentation using autoencoder regularization

# Colab Setup

In [0]:
# Expose Google Drive under /gdrive so the BraTS archive stored there can be read.
# On first run this prints an authorization URL and waits for the pasted code.
from google.colab import drive

drive.mount(mountpoint='/gdrive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /gdrive


In [0]:
# Unpack the BraTS 2018 training archive from the mounted Drive into the
# current working directory. The context manager closes the archive's file
# handle once extraction finishes (the previous version leaked it).
import zipfile

BRATS_ZIP = '/gdrive/My Drive/MICCAI_BraTS_2018_Data_Training.zip'

with zipfile.ZipFile(BRATS_ZIP) as zf:
    zf.extractall()

In [0]:
# Implementation of GroupNorm from https://github.com/titu1994/Keras-Group-Normalization/blob/master/group_norm.py
# NOTE(review): this downloads from the `master` branch, so the code is not pinned to a
# known revision. If group_norm.py already exists, re-running this cell presumably saves
# a numbered copy instead of overwriting it. Consider `wget -nc` or a commit-pinned URL.
!wget https://raw.githubusercontent.com/titu1994/Keras-Group-Normalization/master/group_norm.py

--2019-04-08 13:39:00--  https://raw.githubusercontent.com/titu1994/Keras-Group-Normalization/master/group_norm.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 7952 (7.8K) [text/plain]
Saving to: ‘group_norm.py’


2019-04-08 13:39:00 (113 MB/s) - ‘group_norm.py’ saved [7952/7952]



# Imports and helper functions

In [0]:
import keras
from keras.layers import *
from keras.models import Model, Sequential
from group_norm import GroupNormalization

In [0]:
def green_block(inp, filters, name=None, groups=32, data_format='channels_first'):
    """Residual "green block" from the architecture diagram.

    Main path: two stages of GroupNorm -> ReLU -> 3x3x3 Conv3D ('same'
    padding, stride 1). Residual path: a 1x1x1 Conv3D projection of `inp`
    to `filters` channels. The two paths are summed.

    Args:
        inp: Input Keras tensor.
        filters: Number of output channels of every convolution in the block.
        name: Optional suffix used to build layer names such as
            'Conv3D_1_<name>'. When falsy, Keras auto-generates names.
        groups: Number of groups for both GroupNormalization layers. The
            default of 32 matches the value used before this parameter
            existed (the group_norm.py default). NOTE(review): GroupNorm
            typically needs this to divide the normalized channel count
            (the input's channels for the first GN, `filters` for the second).
        data_format: 'channels_first' (default) or 'channels_last'. Forwarded
            to every Conv3D, and selects the GroupNorm channel axis.

    Returns:
        Keras tensor with `filters` channels and the same spatial size as `inp`.
    """
    def layer_name(prefix):
        # Build '<prefix>_<name>' only when a block name was supplied.
        return f'{prefix}_{name}' if name else None

    # GroupNorm must normalize over the channel axis of the chosen layout.
    channel_axis = 1 if data_format == 'channels_first' else -1

    # Residual path: 1x1x1 projection so its channel count matches the main path.
    inp_res = Conv3D(filters=filters, kernel_size=(1, 1, 1), strides=1,
                     data_format=data_format, name=layer_name('Res'))(inp)

    # Main path, stage 1.
    x = GroupNormalization(groups=groups, axis=channel_axis,
                           name=layer_name('GroupNorm_1'))(inp)
    x = Activation('relu', name=layer_name('Relu_1'))(x)
    x = Conv3D(filters=filters, kernel_size=(3, 3, 3), strides=1, padding='same',
               data_format=data_format, name=layer_name('Conv3D_1'))(x)

    # Main path, stage 2.
    x = GroupNormalization(groups=groups, axis=channel_axis,
                           name=layer_name('GroupNorm_2'))(x)
    x = Activation('relu', name=layer_name('Relu_2'))(x)
    x = Conv3D(filters=filters, kernel_size=(3, 3, 3), strides=1, padding='same',
               data_format=data_format, name=layer_name('Conv3D_2'))(x)

    # Additive skip connection.
    return Add(name=layer_name('Out'))([x, inp_res])

# Model

## TODO:
- Decoder: VAE part (lower branch)
- Loss function: <br />
$L = L_{dice} + 0.1 \cdot L_{L2} + 0.1 \cdot L_{KL}$

![The model architecture](https://suyogjadhav.com/images/misc/brats2018_sota_model.png)
<center><b>The Model Architecture</b><br />Source: https://arxiv.org/pdf/1810.11654.pdf</center>

<hr />
![The Decoder Structure](https://dev.suyogjadhav.com/images/misc/brats2018_sota_decoder.png)
<center><b>The Decoder Structure</b><br />Source: https://arxiv.org/pdf/1810.11654.pdf</center>

In [86]:
# Builds the encoder and the GT (segmentation) decoder of the network shown
# above. The VAE branch is still TODO.

input_shape = (4, 160, 192, 128)  # channels_first: (channels, dim1, dim2, dim3)

# The paper's GT decoder ends with a 1x1x1 convolution into 3 channels followed
# by a sigmoid. NOTE(review): presumably one channel per tumour sub-region;
# confirm against how the labels are encoded for training.
output_channels = 3


def _encoder_downsample(x, filters):
    """Halve every spatial dimension of `x` with a strided 3x3x3 convolution."""
    return Conv3D(
        filters=filters,
        kernel_size=(3, 3, 3),
        strides=2,
        padding='same',
        data_format='channels_first',
        name=f'Enc_DownSample_{filters}')(x)


def _gt_decoder_level(x, skip, filters):
    """Build one level of the GT decoder and return its output tensor.

    Reduces `x` to `filters` channels (1x1x1 conv), doubles its spatial size,
    adds the same-level encoder output `skip`, then applies one green block.
    Layer names are derived from `filters`, so every level needs a distinct
    filter count to keep layer names unique within the model.
    """
    x = Conv3D(
        filters=filters,
        kernel_size=(1, 1, 1),
        strides=1,
        data_format='channels_first',
        name=f'Dec_GT_ReduceDepth_{filters}')(x)
    x = UpSampling3D(
        size=2,
        data_format='channels_first',
        name=f'Dec_GT_UpSample_{filters}')(x)
    x = Add(name=f'Input_Dec_GT_{filters}')([x, skip])
    return green_block(x, filters, name=f'Dec_GT_{filters}')


# ----------------------------------------------------------------------------

# Encoder

# ----------------------------------------------------------------------------

## Input
inp = Input(input_shape)

## Initial Blue Block
x = Conv3D(
    filters=32,
    kernel_size=(3, 3, 3),
    strides=1,
    padding='same',
    data_format='channels_first',
    name='Input_x1')(inp)  # Stride unknown; 1 keeps the input resolution

## Green Blocks x1 (output filters = 32)
x1 = green_block(x, 32, name='x1')
x = _encoder_downsample(x1, 32)

## Green Blocks x2 (output filters = 64)
x = green_block(x, 64, name='Enc_64_1')
x2 = green_block(x, 64, name='x2')
x = _encoder_downsample(x2, 64)

## Green Blocks x2 (output filters = 128)
x = green_block(x, 128, name='Enc_128_1')
x3 = green_block(x, 128, name='x3')
x = _encoder_downsample(x3, 128)

## Green Blocks x4 (output filters = 256)
x = green_block(x, 256, name='Enc_256_1')
x = green_block(x, 256, name='Enc_256_2')
x = green_block(x, 256, name='Enc_256_3')
x4 = green_block(x, 256, name='x4')

# ----------------------------------------------------------------------------

# Decoder

# ----------------------------------------------------------------------------

## GT (Ground Truth) Part
# ----------------------------------------------------------------------------

# One green block per level, with additive skips from x3, x2, x1.
x = _gt_decoder_level(x4, x3, 128)
x = _gt_decoder_level(x, x2, 64)
x = _gt_decoder_level(x, x1, 32)

### Output Block
out1 = Conv3D(
    filters=output_channels,
    kernel_size=(1, 1, 1),
    strides=1,
    data_format='channels_first',
    activation='sigmoid',
    name='Output_1')(x)

# ----------------------------------------------------------------------------

## VAE (Variational Auto Encoder Part)
# ----------------------------------------------------------------------------

# TODO

out = out1
# Keep a handle on the model so later cells can compile/train it.
model = Model(inp, out)
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_51 (InputLayer)           (None, 4, 160, 192,  0                                            
__________________________________________________________________________________________________
Input_x1 (Conv3D)               (None, 32, 160, 192, 3488        input_51[0][0]                   
__________________________________________________________________________________________________
GroupNorm_1_x1 (GroupNormalizat (None, 32, 160, 192, 64          Input_x1[0][0]                   
__________________________________________________________________________________________________
Relu_1_x1 (Activation)          (None, 32, 160, 192, 0           GroupNorm_1_x1[0][0]             
__________________________________________________________________________________________________
Conv3D_1_x