![brain_baner](http://www.mf-data-science.fr/images/projects/brain_baner.jpg)

<h1 style="color:#0b0a2d; font-size:24px; text-transform: uppercase; font-weight:bold">Context</h1>

The goal of this competition, initiated by the **Radiological Society of North America *(RSNA)*** in partnership with the **Medical Image Computing and Computer Assisted Intervention Society *(the MICCAI Society)***, is to predict the methylation status of the **MGMT promoter**, an important genetic biomarker for the treatment of brain tumors.

These predictions will be based on a database of **MRI *(magnetic resonance imaging)*** scans of several hundred patients.

<h1 style="color:#0b0a2d; font-size:24px; text-transform: uppercase; font-weight:bold">Data</h1>

Each independent case has a dedicated folder identified by a five-digit number. Within each of these “case” folders, there are four sub-folders, each of them corresponding to each of the structural multi-parametric MRI (mpMRI) scans, in DICOM format. The exact mpMRI scans included are:

- Fluid Attenuated Inversion Recovery (FLAIR)
- T1-weighted pre-contrast (T1w)
- T1-weighted post-contrast (T1Gd)
- T2-weighted (T2)

| ![brain_baner](http://www.mf-data-science.fr/images/projects/brain_tumor_types.png) | 
|:--:| 
| *Examples of the four MR sequence types included in this work* |

<h1 style="color:#0b0a2d; font-size:24px; text-transform: uppercase; font-weight:bold">Acknowledgement</h1>

This notebook is inspired by the work of *Ammar Alhaj Ali*:
- [🧠Brain Tumor 3D [Training]](https://www.kaggle.com/ammarnassanalhajali/brain-tumor-3d-training)
- [🧠Brain Tumor 3D [Inference]](https://www.kaggle.com/ammarnassanalhajali/brain-tumor-3d-inference)

<h1 style="color:#0b0a2d; font-size:24px; text-transform: uppercase; font-weight:bold">Summary</h1>

1. [Exploratory data analysis (EDA)](#section_1)      
    1.1. [Submission sample & train.csv](#section_1_1)      
    1.2. [MRI train data](#section_1_2)      
    1.3. [Data cleaning](#section_1_3)      

2. [Preprocessing](#section_2)      
    2.1. [Crop and resize the images](#section_2_1)      
    2.2. [Equalization CLAHE](#section_2_2)      
    2.3. [Denoising filter](#section_2_3)      
    2.4. [Global preprocessing function](#section_2_4)      
    
3. [Development of supervised models](#section_3)      
    3.1. [Multimodal inputs CNN from scratch](#section_3_1)      
    3.2. [Define loaders for images sequences 4 MRI types](#section_3_2)      
    3.3. [Define folds](#section_3_3)      
    3.4. [Keras custom data generator](#section_3_4)      
    3.5. [Define CNN Multi-inputs model](#section_3_5)      
    
4. [Test of trained final model](#section_4)     
5. [Try another approach: Transfer Learning](#section_5)

# <span style="color:#0b0a2d; font-size:24px; text-transform: uppercase; font-weight:bold" id="section_1">Exploratory data analysis (EDA)</span>

First, we load the useful Python libraries:

In [None]:
# Import Python libraries
import os
import glob
import shutil
import re
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import animation, rc
import seaborn as sns
from tqdm.notebook import tqdm
import pydicom  as dicom
import cv2
from PIL import Image

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Conv3D, MaxPool3D, GlobalAveragePooling3D, Dense, Dropout, BatchNormalization, concatenate
from tensorflow.keras.models import Model
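# Use the tensorflow.keras namespace everywhere to avoid mixing two Keras installs
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.utils import plot_model
from tensorflow.keras.metrics import AUC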
from sklearn.model_selection import train_test_split

## <span style="color:#3c99dc; font-size:18px; text-transform: uppercase; font-weight:bold" id="section_1_1">Submission sample & train.csv</span>

We first look at the example **submission file** to see exactly what we need to predict in the end.

In [None]:
# Paths of the dataset
input_path = "../input/rsna-miccai-brain-tumor-radiogenomic-classification/"
sample_file = "sample_submission.csv"

# Load the submission sample dataset
sample_submission = pd.read_csv(
    input_path + sample_file)
sample_submission.head(5)

All of the subjects in this dataset appear to have a brain tumor. A label of 0 refers to patients without MGMT promoter methylation, and a label of 1 to patients with it.

The competition expects **the predicted probability of methylation in the `MGMT_value`** column. `BraTS21ID` is the patient identifier.

Now, let's have a look at the **train_labels.csv** file:

In [None]:
train_labels_file = "train_labels.csv"
train_labels = pd.read_csv(
    input_path + train_labels_file)
train_labels.head(5)

We can see that the data structure is the same as for the submission file, except that here **`MGMT_value` is a binary label (0 or 1) and no longer a probability**.

Let's look at the distribution of the values of this variable in the train set :

In [None]:
sns.set_style("whitegrid")
fig = plt.figure(figsize=(8,6))
# Countplot with Seaborn
ax = sns.countplot(data=train_labels,
                   x="MGMT_value")
# Annotating bars
for p in ax.patches:
    ax.annotate(
        format(p.get_height(), '.0f'), 
               (p.get_x() + p.get_width() / 2., p.get_height()),
        ha = 'center', va = 'center', 
        xytext = (0, 10), 
        textcoords = 'offset points')

sns.despine(left=True, bottom=True)
plt.title("MGMT value distribution in train labels\n",
          fontsize=18, color="#0b0a2d")
plt.show()

The two classes are almost balanced; class 1 *(MGMT promoter methylated)* is slightly more frequent. A total of **585 patients are labeled in the train set**.
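
As a rough cross-check of the plot, the class proportions can also be computed numerically. The sketch below uses a small hypothetical label table (`toy_labels` is an assumption, not the competition data); on the real data, the same call would be applied to `train_labels`.

```python
import pandas as pd

# Hypothetical labels standing in for train_labels (illustration only)
toy_labels = pd.DataFrame({
    "BraTS21ID": [0, 2, 3, 5, 6, 8, 9],
    "MGMT_value": [1, 1, 0, 1, 1, 0, 0],
})

# Share of each class: normalize=True returns proportions instead of counts
class_share = toy_labels["MGMT_value"].value_counts(normalize=True).sort_index()
print(class_share)
```

A share close to 0.5 for each class suggests that plain accuracy is not badly distorted by class imbalance.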

## <span style="color:#3c99dc; font-size:18px; text-transform: uppercase; font-weight:bold" id="section_1_2">MRI train data</span>

Train data contains one folder per patient. For each patient, four sub-folders are available *(FLAIR, T1w, T1wCE and T2w)* in which the MRI image sequences are distributed.

![data_structure](http://www.mf-data-science.fr/images/projects/data_structure.jpg)

We are going to take a look at what an MRI image looks like :

In [None]:
def plot_examples(row = 0, cat = 'FLAIR'):
    '''
    This function displays MRI images of the category 
    given as a parameter for a given patient. 
    The first 5 files listed in the folder are displayed
    (os.listdir order, which is not guaranteed to be sorted).
    ***************************************************
    PARAMETERS
    ***************************************************
    - row : integer
        Line number corresponding to a patient 
        in the train_labels.csv file
    - cat : string
        MR sequence types to display. Can be in :
        * FLAIR
        * T1w
        * T1wCE
        * T2w
    '''
    folder = str(train_labels.loc[row, 'BraTS21ID']).zfill(5)
    path_file = ''.join([input_path, 'train/', folder, '/', cat, '/'])
    images = os.listdir(path_file)
    
    fig, axs = plt.subplots(1, 5, figsize=(30, 10))
    fig.subplots_adjust(hspace = .2, wspace=.2)
    axs = axs.ravel()
    
    for num in range(5):
        data_file = dicom.dcmread(path_file+images[num])
        img = data_file.pixel_array
        axs[num].imshow(img, cmap='gray')
        axs[num].set_title(cat+' '+images[num])
        axs[num].set_xticklabels([])
        axs[num].set_yticklabels([])
        axs[num].grid(False)
    
    plt.suptitle("MRI "+cat+" Scan for patient "+folder,
                 fontsize=18, color="#0b0a2d",
                 x=.5, y=.8)

In [None]:
# Example of FLAIR scans
row = 3
plot_examples(row = row, cat = 'FLAIR')

In [None]:
# Example of T1wCE scans
row = 12
plot_examples(row = row, cat = 'T1wCE')

To better understand the representation of these MRI images, we can also **create an animation to visualize the sequence of images** of a certain category for a given patient.

In [None]:
rc('animation', html='jshtml')

def create_animation(row = 0, cat = 'FLAIR'):
    '''
    This function returns an animation of the MRI images 
    of a category entered as a parameter for a given patient.
    ***************************************************
    PARAMETERS
    ***************************************************
    - row : integer
        Line number corresponding to a patient 
        in the train_labels.csv file
    - cat : string
        MR sequence types to display. Can be in :
        * FLAIR
        * T1w
        * T1wCE
        * T2w
    '''
    folder = str(train_labels.loc[row, 'BraTS21ID']).zfill(5)
    path_file = ''.join([input_path, 'train/', folder, '/', cat])
    t_paths = sorted(
        glob.glob(os.path.join(path_file, "*")), 
        key=lambda x: int(x[:-4].split("-")[-1]),
    )
    images = []
    for filename in t_paths:
        data_file = dicom.dcmread(filename)
        data = data_file.pixel_array
        data = data - np.min(data)
        if np.max(data) != 0:
            data = data / np.max(data)
        data = (data * 255).astype(np.uint8)
        if data.max() == 0:
            continue
        images.append(data)
    
    fig = plt.figure(figsize=(8, 8))
    plt.axis('off')
    im = plt.imshow(images[0], cmap="gray", animated=True)

    def animate_func(i):
        im.set_array(images[i])
        return [im]

    animated = animation.FuncAnimation(
        fig, animate_func, 
        frames = len(images), 
        interval = 1000//24)
    return animated

In [None]:
ani_1 = create_animation(row = 3, cat = 'FLAIR')

In [None]:
ani_1

We are now going to count the DCM files to see whether their number is the same for each category and for each patient. For that, we will complete a copy of train_labels.csv with the calculated information:

In [None]:
scan_categories = ["FLAIR","T1w","T1wCE","T2w"]
train_dataset = train_labels.copy()

for scan in scan_categories:
    train_dataset[scan + "_count"] = [len(os.listdir(input_path + "train/" 
                                                      + str(p).zfill(5) 
                                                      + "/" + scan))
                                       for p in train_dataset.BraTS21ID]

In [None]:
fig = plt.figure(figsize = (25,40))
for i, scan in enumerate(scan_categories):
    ax = plt.subplot(4,1,i+1)
    plt.xticks(rotation=70)
    sns.countplot(x=train_dataset[scan + "_count"], ax=ax)
    ax.set_title("Distribution of number of DCM file in {} scans".format(scan),
             fontsize=18, color="#0b0a2d")
plt.show()

Note that some values for each scan category are over-represented. On the other hand, the counts span a wide range. This may be due, for example, to the use of different MRI scanners or acquisition protocols.

However, if we only keep patients whose counts match the most frequent value for every scan type, the amount of data available may be too low for a complex machine learning algorithm.

In [None]:
# mode() returns a Series; [0] picks the most frequent count
train_dataset[(train_dataset["FLAIR_count"] == train_dataset["FLAIR_count"].mode()[0])
              & (train_dataset["T1w_count"] == train_dataset["T1w_count"].mode()[0])
              & (train_dataset["T1wCE_count"] == train_dataset["T1wCE_count"].mode()[0])
              & (train_dataset["T2w_count"] == train_dataset["T2w_count"].mode()[0])]

There are only 151 patients whose scans all contain the most common number of DCM images, out of the 585 at the start. We therefore keep the entire dataset for the moment. Let's have a look at the **test dataset**:

In [None]:
test_dataset = pd.DataFrame(columns=["BraTS21ID",
                                     "FLAIR_count","T1w_count",
                                     "T1wCE_count","T2w_count"])
test_dataset["BraTS21ID"] = os.listdir(input_path + "test/")
for scan in scan_categories:
    test_dataset[scan + "_count"] = [len(os.listdir(input_path + "test/" 
                                                      + str(p).zfill(5) 
                                                      + "/" + scan))
                                       for p in test_dataset.BraTS21ID]

In [None]:
test_dataset.head(5)

In [None]:
fig = plt.figure(figsize = (25,40))
for i, scan in enumerate(scan_categories):
    ax = plt.subplot(4,1,i+1)
    plt.xticks(rotation=70)
    sns.countplot(x=test_dataset[scan + "_count"], ax=ax)
    ax.set_title("Distribution of number of DCM file in TEST {} scans".format(scan),
             fontsize=18, color="#0b0a2d")
plt.show()

## <span style="color:#3c99dc; font-size:18px; text-transform: uppercase; font-weight:bold" id="section_1_3">Data cleaning</span>

In the animation of the scans shown above, we notice that a number of images at the beginning and at the end of the sequence are mostly black.

These images carry little information for our models and may even encourage overfitting.
**When creating the image sequences, we will therefore start from the central image of each folder and take the same number of images before and after it**.
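
The selection rule described above can be sketched as a small helper. `central_window` is a hypothetical name introduced here for illustration; it only computes the index range and assumes the files are already sorted by slice number.

```python
def central_window(n_files, num_imgs):
    """Return (start, stop) indices of the slices centered on the middle file."""
    middle = n_files // 2
    half = num_imgs // 2
    start = max(0, middle - half)
    stop = min(n_files, middle + half)
    return start, stop

# A folder with 100 files keeps the 24 central slices
print(central_window(100, 24))
# A folder with fewer files than requested keeps everything it has
print(central_window(10, 24))
```

Clamping with `max` and `min` keeps the range valid for short sequences, which then need padding to reach the requested length.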

Let's take an example from a single image :

In [None]:
# Path to sample image
sample_img_path = ''.join([input_path, 'train/00005/FLAIR/Image-80.dcm'])
sample_img = dicom.dcmread(sample_img_path)
sample_img = sample_img.pixel_array
plt.imshow(sample_img)
print("% of colored pixels : {:.2f}".format(
    np.sum(np.where(sample_img!=0,1,0)
           /(sample_img.shape[0]*sample_img.shape[1]))*100))

In this example, only 7% of the image has colored pixels.      
We will now look at the distribution of the **colorization rates of the images in the full FLAIR folder** :

In [None]:
img_colored = []
g_path = ''.join([input_path, 'train/00005/FLAIR/'])
for imgs in os.listdir(g_path):
    img_path = ''.join([g_path, imgs])
    temp_img = dicom.dcmread(img_path)
    temp_img = temp_img.pixel_array
    colored_zone = round(np.sum(np.where(temp_img!=0,1,0)
                           /(temp_img.shape[0]*temp_img.shape[1])),2)
    img_colored.append(colored_zone)

fig = plt.figure(figsize=(12,8))
sns.countplot(x=img_colored)
plt.xlabel("Rate of colored pixels")
plt.xticks(rotation=70)
plt.title("Colored pixel rate distribution in 00005 FLAIR folder\n",
          fontsize=18, color="#0b0a2d")
plt.show()

It can be seen that the majority of the images are completely black.
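
The rate of colored pixels used above is simply the share of non-zero values in the pixel matrix. A minimal sketch on a synthetic image follows; `nonzero_ratio` and `toy_img` are hypothetical names, and no DICOM file is read here.

```python
import numpy as np

def nonzero_ratio(img):
    """Fraction of pixels that are not zero (i.e. not black)."""
    return np.count_nonzero(img) / img.size

# Synthetic 4x4 image with a 2x2 bright patch (illustration only)
toy_img = np.zeros((4, 4), dtype=np.uint16)
toy_img[1:3, 1:3] = 100
print(nonzero_ratio(toy_img))
```

The same function could be applied to the `pixel_array` of each DICOM slice to decide whether a slice is worth keeping.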

# <span style="color:#0b0a2d; font-size:24px; text-transform: uppercase; font-weight:bold" id="section_2">Preprocessing</span>
We will be using several preprocessing techniques on our images.
- Crop images to reduce black areas.
- Resize images.
- Application of a denoising filter.

**We will not apply image equalization** as the different types of scans already have voluntary contrast variations.

## <span style="color:#3c99dc; font-size:18px; text-transform: uppercase; font-weight:bold" id="section_2_1">Crop and resize the images</span>

In [None]:
def crop_resize_img(img, scale=1.0, dim=(244,244)): 
    '''
    Crop function to keep a central part of the 2D image 
    according to a defined scale and resize it to 
    specific dimension.
    ****************************************************
    PARAMETERS
    ****************************************************
    - img : 2D array
        2D array of pixels in the image
    - scale : float
        Desired scale of the cropped image
    - dim : Tuple
        Tuple of integer with desired final width, height
    '''
    # Crop image
    center_x, center_y = img.shape[1] / 2, img.shape[0] / 2
    width_scaled, height_scaled = img.shape[1] * scale, img.shape[0] * scale
    left_x, right_x = center_x - width_scaled / 2, center_x + width_scaled / 2
    top_y, bottom_y = center_y - height_scaled / 2, center_y + height_scaled / 2
    img_cropped = img[int(top_y):int(bottom_y), int(left_x):int(right_x)]
    
    # Resize
    img_cropped = cv2.resize(img_cropped, dim, interpolation = cv2.INTER_AREA)
    return img_cropped

In [None]:
fig = plt.figure(figsize=(12,6))
ax = plt.subplot(1,2,1)
ax.imshow(sample_img)
ax.set_title("Original image")
ax1 = plt.subplot(1,2,2)
img_cropped = crop_resize_img(sample_img, 0.7, (244,244))
ax1.imshow(img_cropped)
ax1.set_title("Cropped (244,244) image with scale = 0.7")
plt.show()
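
To make the crop arithmetic concrete, the sketch below reproduces the bound computation of `crop_resize_img` without touching any image. `crop_bounds` is a hypothetical helper, and the 512 x 512 input size is an assumption chosen for illustration.

```python
def crop_bounds(height, width, scale):
    """Row and column bounds of the centered crop, as used for slicing."""
    center_x, center_y = width / 2, height / 2
    half_w, half_h = width * scale / 2, height * scale / 2
    top, bottom = int(center_y - half_h), int(center_y + half_h)
    left, right = int(center_x - half_w), int(center_x + half_w)
    return top, bottom, left, right

# Assumed 512 x 512 slice cropped with scale = 0.7
top, bottom, left, right = crop_bounds(512, 512, 0.7)
print(top, bottom, left, right)
print("cropped size:", bottom - top, "x", right - left)
```

Because the bounds are truncated with `int`, the cropped size can differ by one pixel from `scale * size`; the subsequent resize hides this difference.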

## <span style="color:#3c99dc; font-size:18px; text-transform: uppercase; font-weight:bold" id="section_2_4">Global preprocessing function</span>

We can now create a global preprocessing function that will be applied to our MRI images before entering the neural network models. This function will take over the various treatments seen previously.

In [None]:
def mri_preprocessor(
    img, scale=.8, 
    dim=(240,240)):
    
    # Apply all preprocess
    x = crop_resize_img(img,scale,dim)
    
    return x

In [None]:
fig = plt.figure(figsize=(12,6))
ax = plt.subplot(1,2,1)
ax.imshow(sample_img)
ax.set_title("Original image")
ax1 = plt.subplot(1,2,2)
img_pre = mri_preprocessor(sample_img, 
                           scale=.8, 
                           dim=(240,240))
ax1.imshow(img_pre)
ax1.set_title("Preprocess image")
plt.show()

# <span style="color:#0b0a2d; font-size:24px; text-transform: uppercase; font-weight:bold" id="section_3">Development of supervised models</span>

We will test several neural network models and compare their main metrics to determine the best one.

- **CNN** from scratch (Baseline).
- **Transfer learning**.
- **Fine Tuning**.

The metric monitored will be the **accuracy** on the training and validation sets.
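
Since the competition itself is scored with the area under the ROC curve, it may be useful to see how accuracy and AUC can differ on the same predictions. The following is a minimal sketch on made-up values (`y_true` and `y_prob` are assumptions, not model outputs), using scikit-learn's `accuracy_score` and `roc_auc_score`.

```python
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score

# Hypothetical labels and predicted probabilities (illustration only)
y_true = np.array([0, 0, 1, 1, 1])
y_prob = np.array([0.2, 0.6, 0.4, 0.7, 0.9])

# Accuracy needs hard labels, so probabilities are thresholded at 0.5
acc = accuracy_score(y_true, (y_prob >= 0.5).astype(int))
# ROC AUC works directly on the probabilities and ignores the threshold
auc = roc_auc_score(y_true, y_prob)
print(f"accuracy = {acc:.3f}, ROC AUC = {auc:.3f}")
```

Accuracy depends on the chosen threshold, whereas AUC only depends on how the predictions are ranked.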

## <span style="color:#3c99dc; font-size:18px; text-transform: uppercase; font-weight:bold" id="section_3_1">Multimodal inputs CNN from scratch</span>

As part of this modeling, we need to build a model branch for each of the 4 types of scans available for each patient:

- Fluid Attenuated Inversion Recovery (FLAIR).
- T1-weighted pre-contrast (T1w).
- T1-weighted post-contrast (T1Gd).
- T2-weighted (T2).

As a reminder, these scan types are stored in the variable:
```Python
scan_categories = ["FLAIR", "T1w", "T1wCE", "T2w"]
```

The arrays of the patient images will be loaded into the variable X, the labels of MGMT_value into the variable y. To obtain a binary classification, we will apply a final **sigmoid activation layer** to the combined outputs of our 4 branches *(classifier layer)*.

We will feed the models **a sequence of images** *(as for a video)*, taking the **middle 24 images** of each patient for each type of scan, since these are the least likely to be entirely black:

In [None]:
# Initial parameters
IMAGE_SIZE = 120
NUM_IMAGES = 24
BATCH_SIZE = 4

num_folds = 5
Selected_fold = 1 #1,2,3,4,5 

In [None]:
train_df = train_labels.copy()  # copy so that train_labels is left untouched
train_df['BraTS21ID5'] = [format(x, '05d') for x in train_df.BraTS21ID]
train_df["Fold"]="train"
train_df.head(3)

In [None]:
test = sample_submission.copy()  # copy so that sample_submission is left untouched
test['BraTS21ID5'] = [format(x, '05d') for x in test.BraTS21ID]
test.head(3)

## <span style="color:#3c99dc; font-size:18px; text-transform: uppercase; font-weight:bold" id="section_3_2">Define loaders for images sequences 4 MRI types</span>

In [None]:
def load_dicom_image(path, img_size=IMAGE_SIZE, preproc=True):
    # Read the DICOM file and extract the pixel matrix
    data = dicom.dcmread(path)
    data = data.pixel_array

    if preproc:
        # Crop the central area, then resize
        data = mri_preprocessor(
            data,
            scale=.8,
            dim=(img_size,img_size))
    else:
        # Plain resize without cropping
        data = cv2.resize(data, (img_size, img_size))

    return data

In [None]:
sample_img = dicom.dcmread(
    input_path+"train/00046/FLAIR/Image-90.dcm").pixel_array
preproc_img = load_dicom_image(input_path+"train/00046/FLAIR/Image-90.dcm",
                               preproc=True)

fig = plt.figure(figsize=(12,8))
ax1 = plt.subplot(1,2,1)
ax1.imshow(sample_img, cmap="gray")
ax1.set_title(f"Original image shape = {sample_img.shape}")
ax2 = plt.subplot(1,2,2)
ax2.imshow(preproc_img, cmap="gray")
ax2.set_title(f"Preproc image shape = {preproc_img.shape}")
plt.show()

In [None]:
def load_dicom_images_3d(scan_id, mri_type="FLAIR", num_imgs=NUM_IMAGES, img_size=IMAGE_SIZE, split="train"):
    # Sort the DICOM files in natural order (Image-2 before Image-10)
    files = sorted(glob.glob(f"{input_path}{split}/{scan_id}/{mri_type}/*.dcm"), 
                   key=lambda var:[int(x) if x.isdigit() else x for x in re.findall(r'[^0-9]|[0-9]+', var)])

    # Keep num_imgs slices centered on the middle of the sequence
    middle = len(files)//2
    num_imgs2 = num_imgs//2
    p1 = max(0, middle - num_imgs2)
    p2 = min(len(files), middle + num_imgs2)
    img3d = np.stack([load_dicom_image(f, img_size=img_size) for f in files[p1:p2]], axis=-1)

    # Pad with empty slices when the folder holds fewer than num_imgs files
    if img3d.shape[-1] < num_imgs:
        n_zero = np.zeros((img_size, img_size, num_imgs - img3d.shape[-1]))
        img3d = np.concatenate((img3d,  n_zero), axis = -1)

    # Min-max normalization to [0, 1] (skipped for constant volumes)
    if np.min(img3d) < np.max(img3d):
        img3d = img3d - np.min(img3d)
        img3d = img3d / np.max(img3d)

    return img3d

In [None]:
a = load_dicom_images_3d("00046")
print(np.array(a).shape)
image = a[:, :, 12]
print("Dimension of the MRI slice is:", image.shape)
plt.imshow(np.squeeze(image), cmap="gray")

Thanks to this function, we can obtain **4 ordered sequences *(one for each scan type)* of 24 images of dimension 120 x 120**, with the preprocessing already applied.
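
The padding and min-max normalization steps of the loader can be checked in isolation on synthetic data. In the sketch below, `pad_and_normalize` and `toy_volume` are hypothetical names used only for illustration; no DICOM file is involved.

```python
import numpy as np

def pad_and_normalize(volume, num_imgs):
    """Zero-pad the last axis up to num_imgs slices, then scale to [0, 1]."""
    missing = num_imgs - volume.shape[-1]
    if missing > 0:
        padding = np.zeros(volume.shape[:-1] + (missing,))
        volume = np.concatenate((volume, padding), axis=-1)
    # Skip the scaling for constant volumes to avoid dividing by zero
    if volume.min() < volume.max():
        volume = volume - volume.min()
        volume = volume / volume.max()
    return volume

# Synthetic volume: 3 slices of 4 x 4 pixels with values between 10 and 57
toy_volume = np.arange(10, 58, dtype=float).reshape(4, 4, 3)
out = pad_and_normalize(toy_volume, num_imgs=5)
print(out.shape, out.min(), out.max())
```

Note that padding happens before normalization, so the padded zeros become the minimum of the volume whenever padding is needed.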

To couple our four types of scans, we will use a **multi-modal approach** to create our model. We will integrate 4 different inputs for a single final classifier.

## <span style="color:#3c99dc; font-size:18px; text-transform: uppercase; font-weight:bold" id="section_3_3">Define folds</span>

In [None]:
from sklearn.model_selection import KFold,StratifiedKFold
sfolder = StratifiedKFold(n_splits=5,random_state=13,shuffle=True)
X = train_df[['BraTS21ID']]
y = train_df[['MGMT_value']]

fold_no = 1
for train, valid in sfolder.split(X,y):
    if fold_no==Selected_fold:
        train_df.loc[valid, "Fold"] = "valid"
    fold_no += 1

In [None]:
df_train=train_df[train_df.Fold=="train"]
df_valid=train_df[train_df.Fold=="valid"].iloc[:-1,:]
print("df_train=",len(df_train),"-- df_valid=",len(df_valid))
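
As a sanity check of what stratification provides, the sketch below runs `StratifiedKFold` on a small hypothetical label vector (`toy_y` is an assumption, not the competition labels) and counts the positives in each validation fold.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Hypothetical labels: 12 positives and 8 negatives (illustration only)
toy_y = np.array([1] * 12 + [0] * 8)
toy_X = np.arange(len(toy_y)).reshape(-1, 1)

skf = StratifiedKFold(n_splits=4, shuffle=True, random_state=13)
fold_stats = []
for fold_no, (train_idx, valid_idx) in enumerate(skf.split(toy_X, toy_y), start=1):
    # Count the positives that land in the validation part of this fold
    n_pos = int(toy_y[valid_idx].sum())
    fold_stats.append((len(valid_idx), n_pos))
    print(f"fold {fold_no}: {len(valid_idx)} validation samples, {n_pos} positives")
```

Each validation fold keeps roughly the same class ratio as the full set, which keeps the validation metrics comparable from one fold to another.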

## <span style="color:#3c99dc; font-size:18px; text-transform: uppercase; font-weight:bold" id="section_3_4">Keras custom data generator</span>
Thanks to the `Sequence` class of the Keras library, we are going to create a custom data generator. This avoids building NumPy arrays or tensors containing all the sequences, which would quickly overload the memory.

In [None]:
from tensorflow.keras.utils import Sequence
class Dataset(Sequence):
    def __init__(self,df,is_train=True,batch_size=BATCH_SIZE,shuffle=False):
        self.idx = df["BraTS21ID"].values
        self.paths = df["BraTS21ID5"].values
        self.y =  df["MGMT_value"].values
        self.is_train = is_train
        self.batch_size = batch_size
        self.shuffle = shuffle
    def __len__(self):
        return math.ceil(len(self.idx)/self.batch_size)
   
    def __getitem__(self,ids):
        id_path= self.paths[ids]
        batch_paths = self.paths[ids * self.batch_size:(ids + 1) * self.batch_size]
        
        if self.y is not None:
            batch_y = self.y[ids * self.batch_size: (ids + 1) * self.batch_size]
        
        if self.is_train:
            self.full_X_mri = []
            for mri_type in scan_categories:
                list_x =  [load_dicom_images_3d(x,split="train",mri_type=mri_type) for x in batch_paths]
                batch_X = np.stack(list_x, axis=0)
                self.full_X_mri.append(np.array(batch_X))
            return self.full_X_mri,np.array(batch_y)
        else:
            self.full_X_mri = []
            for mri_type in scan_categories:
                list_x =  load_dicom_images_3d(id_path,split="test",mri_type=mri_type)
                batch_X = np.stack(np.expand_dims(list_x, axis=0))
                self.full_X_mri.append(np.array(batch_X))
            return self.full_X_mri
    
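    def on_epoch_end(self):
        # Reshuffle patients between epochs, keeping ids, paths and labels aligned
        if self.shuffle and self.is_train:
            order = np.random.permutation(len(self.idx))
            self.idx = self.idx[order]
            self.paths = self.paths[order]
            self.y = self.y[order]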

In [None]:
train_dataset = Dataset(df_train,batch_size=BATCH_SIZE)
valid_dataset = Dataset(df_valid,batch_size=BATCH_SIZE)
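
The generator only ever holds one batch of volumes in memory. The index arithmetic behind `__len__` and `__getitem__` can be sketched as follows; `batch_bounds` is a hypothetical helper and the sample count is made up for illustration.

```python
import math

def batch_bounds(n_samples, batch_size):
    """Yield the (start, stop) slice of each batch, like Dataset.__getitem__."""
    n_batches = math.ceil(n_samples / batch_size)
    for batch_no in range(n_batches):
        start = batch_no * batch_size
        stop = min((batch_no + 1) * batch_size, n_samples)
        yield start, stop

# Hypothetical set of 10 patients served in batches of 4
bounds = list(batch_bounds(10, 4))
print(bounds)
```

The last batch is smaller when the number of patients is not a multiple of the batch size, which is why `__len__` rounds up with `math.ceil`.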

Once the generators are created, we can display an image from each scan type as a check:

In [None]:
fig = plt.figure(figsize=(20,8))

for i in range(1):
    images, label = train_dataset[i]
    print("Number of scan type is:", len(images))
    print("Dimension of the FLAIR batch is:", images[0].shape)
    print("label=",label)
    print("\n")
    
    ax1 = plt.subplot(1,4,1)
    ax1.imshow(images[0][1][:,:,15], cmap="gray")
    ax1.set_title("FLAIR, 15th image")
    ax2 = plt.subplot(1,4,2)
    ax2.imshow(images[1][1][:,:,15], cmap="gray")
    ax2.set_title("T1w, 15th image")
    ax3 = plt.subplot(1,4,3)
    ax3.imshow(images[2][1][:,:,15], cmap="gray")
    ax3.set_title("T1wCE, 15th image")
    ax4 = plt.subplot(1,4,4)
    ax4.imshow(images[3][1][:,:,15], cmap="gray")
    ax4.set_title("T2w, 15th image")

## <span style="color:#3c99dc; font-size:18px; text-transform: uppercase; font-weight:bold" id="section_3_5">Define CNN Multi-inputs model</span>

In [None]:
def get_model(inputs):
    """Build a 3D convolutional neural network model."""
     
    x = Conv3D(filters=16, kernel_size=3, activation="relu", padding="same")(inputs)
    x = MaxPool3D(pool_size=2)(x)
    x = BatchNormalization()(x)
    
    x = Conv3D(filters=32, kernel_size=3, activation="relu", padding="same")(x)
    x = MaxPool3D(pool_size=2)(x)
    x = BatchNormalization()(x)
    x = Dropout(0.1)(x)
    
    x = Conv3D(filters=32, kernel_size=3, activation="relu", padding="same")(x)
    x = MaxPool3D(pool_size=2)(x)
    x = BatchNormalization()(x)
    x = Dropout(0.1)(x)
    
    x = Conv3D(filters=64, kernel_size=3, activation="relu", padding="same")(x)
    x = MaxPool3D(pool_size=2)(x)
    x = BatchNormalization()(x)
    x = Dropout(0.2)(x)

    return x

In [None]:
# Create the four modal input
FLAIR_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE, NUM_IMAGES, 1))
FLAIR_model = get_model(FLAIR_input)

T1w_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE, NUM_IMAGES, 1))
T1w_model = get_model(T1w_input)

T1wCE_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE, NUM_IMAGES, 1))
T1wCE_model = get_model(T1wCE_input)

T2w_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE, NUM_IMAGES, 1))
T2w_model = get_model(T2w_input)

# Concat models and add final block
cnn = concatenate(
    [FLAIR_model, T1w_model, T1wCE_model, T2w_model])

cnn = GlobalAveragePooling3D()(cnn)
cnn = Dense(units=256, activation="relu")(cnn)
cnn = Dropout(0.3)(cnn)

# Output layer
output_layer = Dense(units = 1, activation = 'sigmoid')(cnn)

# Final model
cnn_model = Model(inputs=[FLAIR_input, 
                          T1w_input, 
                          T1wCE_input, 
                          T2w_input], 
                  outputs=[output_layer],
                  name="multi3dcnn")

# Compile final model
cnn_model.compile(loss='binary_crossentropy',
                  optimizer=keras.optimizers.Adam(learning_rate=0.001),
                  metrics=["accuracy"])

# Define callbacks.
model_save = ModelCheckpoint('Brain_3d_multimodal.h5', 
                             save_best_only = True, 
                             monitor = 'val_loss', 
                             mode = 'min', verbose = 1)
early_stop = EarlyStopping(monitor = 'val_loss', 
                           patience = 10, mode = 'min', verbose = 1,
                           restore_best_weights = True)

In [None]:
# Plot model diagram
plot_model(cnn_model, show_shapes=True, show_layer_names=True)
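
The shapes shown in the diagram can be anticipated by hand. The sketch below is plain arithmetic, not Keras code: it assumes that the `padding="same"` convolutions leave spatial sizes unchanged and that each `MaxPool3D(pool_size=2)` halves every axis, rounding down. `pooled_size` is a hypothetical helper.

```python
def pooled_size(size, n_pools, pool_size=2):
    """Size of one axis after n_pools max-pooling layers with 'valid' padding."""
    for _ in range(n_pools):
        size = size // pool_size
    return size

IMAGE_SIZE, NUM_IMAGES = 120, 24   # values used in this notebook
N_BRANCHES, LAST_FILTERS = 4, 64   # four scan types, 64 filters in the last block

side = pooled_size(IMAGE_SIZE, n_pools=4)
depth = pooled_size(NUM_IMAGES, n_pools=4)
channels = N_BRANCHES * LAST_FILTERS

print("per-branch feature map:", (side, side, depth, LAST_FILTERS))
print("after concatenation:", (side, side, depth, channels))
print("after global average pooling:", (channels,))
```

With four pooling stages, the depth axis of 24 slices shrinks to a single position, so a deeper network would need more input slices or fewer pooling layers along that axis.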

Then we **train the multi-input model** with the network defined above:

In [None]:
#tf.config.experimental_run_functions_eagerly(True)
epochs = 50
cnn_model.fit(
    train_dataset,
    validation_data=valid_dataset,
    epochs=epochs, 
    shuffle=False,
    #verbose=2,
    callbacks=[model_save, early_stop])

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(20, 7))
ax = ax.ravel()

for i, metric in enumerate(["accuracy","loss"]):
    ax[i].plot(cnn_model.history.history[metric])
    ax[i].plot(cnn_model.history.history["val_" + metric])
    ax[i].set_title("Model {}".format(metric))
    ax[i].set_xlabel("epochs")
    ax[i].set_ylabel(metric)
    ax[i].legend(["train", "val"])

# <span style="color:#0b0a2d; font-size:24px; text-transform: uppercase; font-weight:bold" id="section_4">Test of trained final model</span>

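The best weights were saved to `Brain_3d_multimodal.h5` during training, and `restore_best_weights = True` puts them back into `cnn_model` when early stopping triggers. We will therefore use this model to predict on the test images listed in the submission file.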

In [None]:
test_df = sample_submission.copy()
test_df['BraTS21ID5'] = [format(x, '05d') for x in test_df.BraTS21ID]

In [None]:
test_dataset = Dataset(test_df,is_train=False,batch_size=1)
print(len(test_dataset))
print(np.array(test_dataset).shape)

In [None]:
for i in range(1):
    images= test_dataset[i]
    print("Number of scan type is:", len(images))
    print("Dimension of the FLAIR scan is:", images[0].shape)
    plt.imshow(images[0][0,:,:,15], cmap="gray")
    plt.show()

In [None]:
preds = cnn_model.predict(test_dataset)
preds = preds.reshape(-1)

In [None]:
print(preds.shape)
print(preds)

In [None]:
submission = pd.DataFrame({'BraTS21ID':sample_submission['BraTS21ID'],'MGMT_value':preds})
submission.head(5)

In [None]:
submission.to_csv('submission.csv',index=False)

The submission.csv file will be used for the competition *(evaluated with the area under the ROC curve, AUC)*. We can also look at the **distribution of the predicted probabilities**:

In [None]:
plt.figure(figsize=(8, 8))
plt.hist(submission["MGMT_value"])
plt.title("Predicted probabilities distribution on test set", 
          fontsize=18, color="#0b0a2d")
plt.show()

# <span style="color:#0b0a2d; font-size:24px; text-transform: uppercase; font-weight:bold" id="section_5">Try another approach: Transfer Learning</span>

In order to complete these models, we will try another approach using **Transfer Learning methods**. We will use a pre-trained deep model to detect the features *(like EfficientNet ...)* and an LSTM layer for the final classification on the matrices obtained. This approach is available in the Notebook :

<span style="font-size:18px">[🧠Brain Tumor - Transfert Learning MRI - All MRI](https://www.kaggle.com/michaelfumery/brain-tumor-transfert-learning-mri-all-mri/)</span>

<span style="color:red; font-size:18px">Don't forget to **upvote** if this Notebook helped you!</span>