In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import os
import glob
import cv2 as cv2
from tqdm import tqdm

from sklearn.model_selection import train_test_split

from tensorflow.keras.utils import Sequence
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Dropout
from tensorflow.keras.models import Model
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam

# Constants

In [2]:
# Hyper-parameters shared by the data pipeline and the model below.
BATCH_SIZE: int = 32          # samples per batch yielded by DataGenerator
DIM: int = 256                # images and masks are resized to DIM x DIM
CONV_BASE_FILTERS: int = 16   # filters of the first U-Net level (doubled per level)

# Data processing

In [3]:
# Dataset root. Keep the trailing slash: glob patterns are appended to it by plain string concatenation.
path: str = '/kaggle/input/lgg-mri-segmentation/kaggle_3m/'

In [4]:
# Collect every mask file together with its matching image and a 0/1 flag
# telling whether the mask contains any tumour pixel.
# Sorted so the row order (and therefore the seeded train/test split below)
# does not depend on the filesystem's directory-listing order.
masks = sorted(glob.glob(path + '*/*_mask.tif*'))

images = []
abnormalities = []

for mask in tqdm(masks):
    # '<name>_mask.tif' -> '<name>.tif' (also correct for a '.tiff' extension,
    # which the trailing '*' in the glob pattern allows).
    stem, ext = os.path.splitext(mask)
    if not stem.endswith('_mask'):
        raise ValueError('Unexpected mask file name: {}'.format(mask))
    images.append(stem[:-len('_mask')] + ext)

    mask_pixels = cv2.imread(mask)
    # cv2.imread signals failure by returning None rather than raising.
    if mask_pixels is None:
        raise IOError('Could not read mask file: {}'.format(mask))
    abnormalities.append(1 if np.max(mask_pixels) > 0 else 0)

print('\n Check if the lists are correctly built ...\n')
print('Amount of \n - images : {}\n - masks: {}\n - abnormalities: {} \n'.format(len(images), len(masks), len(abnormalities)))
print('Type in the lists of \n - images : {}\n - masks: {}\n - abnormalities: {} \n'.format(type(images[0]), type(masks[0]), type(abnormalities[0])))

In [5]:
# One row per slice: image path, mask path, and whether the mask is non-empty.
paths = pd.DataFrame(
    data={
        'image': images,
        'mask': masks,
        'abnormality': abnormalities,
    }
)
paths.head()

# Data analysis

In [6]:
# 1) Hold out a stratified 15 % test set.
train_val, test = train_test_split(paths, stratify=paths['abnormality'], test_size=0.15, random_state=0)
test = test.reset_index(drop=True)

# 2) Split only the *remaining* rows into train / validation. Splitting
#    `paths` again here (as before) put test rows back into training and
#    let validation overlap the test set.
train, validation = train_test_split(train_val, stratify=train_val['abnormality'], test_size=0.1, random_state=0)
train = train.reset_index(drop=True)
validation = validation.reset_index(drop=True)

# The three subsets must be disjoint and together cover every row.
assert set(train['image']).isdisjoint(test['image'])
assert set(validation['image']).isdisjoint(test['image'])
assert set(train['image']).isdisjoint(validation['image'])
assert len(train) + len(validation) + len(test) == len(paths)

print('Amount of samples used for : \n - training : {}\n - validation : {}\n - testing : {}'.format(train.shape[0], validation.shape[0], test.shape[0]))

In [7]:
SAMPLES_TO_OBSERVE = 2

random_with_abnormality = paths[paths['abnormality']==1].sample(SAMPLES_TO_OBSERVE)

# squeeze=False keeps `ax` 2-D even when SAMPLES_TO_OBSERVE == 1.
f, ax = plt.subplots(SAMPLES_TO_OBSERVE, 3, figsize=(15, 10), squeeze=False)
for axi in ax.ravel():
    axi.set_axis_off()
for i in range(SAMPLES_TO_OBSERVE):
    img = cv2.imread(random_with_abnormality.iloc[i]['image'])
    msk = cv2.imread(random_with_abnormality.iloc[i]['mask'])
    blend = cv2.addWeighted(img, 0.5, msk, 0.5, 0.)
    # OpenCV loads images in BGR order while matplotlib expects RGB.
    ax[i, 0].imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    ax[i, 1].imshow(cv2.cvtColor(msk, cv2.COLOR_BGR2RGB))
    ax[i, 2].imshow(cv2.cvtColor(blend, cv2.COLOR_BGR2RGB))
    if i == 0:
        ax[i, 0].set_title('MRI')
        ax[i, 1].set_title('Mask of the tumor')
        ax[i, 2].set_title('Tumor on MRI')

# Data Generator

In [8]:
class DataGenerator(Sequence):
    """Keras ``Sequence`` that streams ``(images, masks)`` batches from disk.

    Each row of ``dataframe`` must provide an ``'image'`` and a ``'mask'``
    file path. Images are read with OpenCV (BGR channel order), scaled to
    [0, 1] and resized to ``dim x dim``. Masks are read as grayscale, scaled
    to [0, 1], resized with nearest-neighbour interpolation and given a
    trailing channel axis.

    Only full batches are produced: ``len(self)`` is
    ``len(dataframe) // batch_size``, so up to ``batch_size - 1`` rows are
    skipped each epoch. When ``shuffle`` is True, a different subset is skipped each time.
    """

    def __init__(self, dataframe, batch_size, dim, shuffle=True):
        # Keras 3's PyDataset (which `Sequence` aliases there) expects its
        # initializer to run. The older tf.keras Sequence has no custom one,
        # so the call is harmless there.
        super().__init__()
        self.dataframe = dataframe
        self.batch_size = batch_size
        self.dim = dim
        # Index labels of the frame, kept for compatibility. Rows are looked up
        # by *position*, so the frame does not need a 0..n-1 RangeIndex.
        self.indices = self.dataframe.index.tolist()
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        """Number of full batches per epoch (a trailing partial batch is dropped)."""
        return len(self.indices) // self.batch_size

    def __getitem__(self, index):
        """Return batch number ``index`` as ``(images, masks)`` arrays."""
        positions = self.index[index * self.batch_size:(index + 1) * self.batch_size]
        X, y = self.__get_data(positions)
        return X, y

    def on_epoch_end(self):
        """Rebuild the row-position order, shuffling it if requested."""
        self.index = np.arange(len(self.indices))
        if self.shuffle:
            np.random.shuffle(self.index)

    @staticmethod
    def _imread(file_path, flags):
        """``cv2.imread`` that raises instead of silently returning None."""
        pixels = cv2.imread(file_path, flags)
        if pixels is None:
            raise IOError('Could not read image file: {}'.format(file_path))
        return pixels

    def __get_data(self, list_IDs_temp):
        """Load and preprocess the rows at the given positions.

        Args:
            list_IDs_temp: integer *positions* into ``self.dataframe``.

        Returns:
            ``(batch_images, batch_masks)`` with shapes ``(n, dim, dim, 3)``
            and ``(n, dim, dim, 1)``, where ``n = len(list_IDs_temp)``.
        """
        n = len(list_IDs_temp)
        target_size = (self.dim, self.dim)
        batch_images = np.empty((n, self.dim, self.dim, 3))
        batch_masks = np.empty((n, self.dim, self.dim, 1))
        for i, position in enumerate(list_IDs_temp):
            row = self.dataframe.iloc[position]

            img = self._imread(row['image'], cv2.IMREAD_COLOR) / 255.
            # Compare both height and width, not just the width.
            if img.shape[:2] != target_size:
                img = cv2.resize(img, target_size, interpolation=cv2.INTER_LINEAR)
            batch_images[i] = img

            msk = self._imread(row['mask'], cv2.IMREAD_GRAYSCALE) / 255.
            if msk.shape[:2] != target_size:
                # Nearest-neighbour only reuses values already in the mask;
                # bilinear would invent fractional labels along region borders.
                msk = cv2.resize(msk, target_size, interpolation=cv2.INTER_NEAREST)
            batch_masks[i] = np.expand_dims(msk, axis=-1)
        return batch_images, batch_masks

In [9]:
train_generator = DataGenerator(train, BATCH_SIZE, DIM)
# Validation is not shuffled. DataGenerator drops the trailing partial batch,
# so with shuffling a different subset of validation rows could be left out on
# each evaluation. That adds noise to val_loss, which the callbacks below
# monitor by default.
validation_generator = DataGenerator(validation, BATCH_SIZE, DIM, shuffle=False)

In [10]:
# Sanity-check one training batch: print the distinct values of every mask.
# A separate name is used so the `masks` list of file paths is not clobbered.
sample_masks = train_generator[1][1]
for batch in range(len(sample_masks)):
    # Index with the loop variable (previously a stale global `i` was used,
    # so the same mask was printed on every iteration).
    print('Values of masks number {} : {}'.format(batch+1, np.unique(sample_masks[batch])))
    print('\n ----- \n')

# UNet : construction

In [11]:
def DoubleConv2D(input_tensor, filters, kernel_size):
    """Apply two ``Conv2D -> BatchNormalization -> ReLU`` stages in a row.

    Both convolutions use ``filters`` output channels, ``kernel_size``
    kernels, He-normal initialisation and ``'same'`` padding, so the spatial
    size of ``input_tensor`` is preserved.
    """
    features = input_tensor
    for _ in range(2):
        features = Conv2D(filters=filters, kernel_size=kernel_size,
                          kernel_initializer='he_normal', padding='same')(features)
        features = BatchNormalization()(features)
        features = Activation('relu')(features)
    return features

def EncoderCell(input_tensor, filters, dropout_rate, kernel_size):
    """One contracting step of the U-Net.

    Returns ``(skip, downsampled)``. ``skip`` is the DoubleConv2D feature map,
    later concatenated by the matching decoder step. ``downsampled`` is that
    map after 2x2 max-pooling and dropout, and feeds the next level.
    """
    skip = DoubleConv2D(input_tensor, filters, kernel_size)
    pooled = MaxPool2D((2, 2))(skip)
    downsampled = Dropout(dropout_rate)(pooled)
    return skip, downsampled

def DecoderCell(VerticalConvInput, HorizontalConvInput, filters, dropout_rate, kernel_size):
    """One expanding step of the U-Net.

    Upsamples ``VerticalConvInput`` 2x with a strided ``Conv2DTranspose`` and
    concatenates it with the skip tensor ``HorizontalConvInput``. It then applies
    dropout and a DoubleConv2D with ``filters`` channels.

    Args:
        VerticalConvInput: output of the level below (decoder or bottleneck).
        HorizontalConvInput: skip connection from the matching encoder level.
        filters: number of output channels.
        dropout_rate: dropout applied after the concatenation.
        kernel_size: kernel size of the transpose conv and of both convs.

    Returns:
        The decoded feature map.
    """
    x = Conv2DTranspose(filters, kernel_size, strides = (2, 2), padding = 'same')(VerticalConvInput)
    x = Concatenate()([x, HorizontalConvInput])
    x = Dropout(dropout_rate)(x)
    # Honour the caller's kernel_size (it was previously hard-coded to 3 here).
    VerticalConvOutput = DoubleConv2D(x, filters, kernel_size = kernel_size)

    return VerticalConvOutput
  
def UNet(input_shape, n_filters = 16, dropout_rate = 0.1, kernel_size = 3, depth = 4):
    """Build a U-Net that outputs a single-channel sigmoid segmentation map.

    Args:
        input_shape: shape of one input sample, e.g. ``(DIM, DIM, 3)``.
        n_filters: channels of the first encoder level. Level ``k`` uses
            ``n_filters * 2**k``, and the bottleneck uses ``n_filters * 2**depth``.
        dropout_rate: dropout used in every encoder and decoder step.
        kernel_size: convolution kernel size (previously fixed at 3).
        depth: number of encoder/decoder levels (previously fixed at 4).
            The spatial input size should presumably be divisible by
            ``2**depth`` so that upsampled maps line up with their skips.

    Returns:
        An uncompiled ``Model`` mapping ``input_shape`` to a 1-channel map.
    """
    if depth < 0:
        raise ValueError('depth must be >= 0, got {}'.format(depth))

    inputs = Input(input_shape)

    # Contracting path: remember each level's pre-pool features for the skips.
    skips = []
    x = inputs
    for level in range(depth):
        conv_output, x = EncoderCell(x, filters = n_filters * 2 ** level, kernel_size = kernel_size, dropout_rate = dropout_rate)
        skips.append(conv_output)

    # Bottleneck.
    x = DoubleConv2D(x, filters = n_filters * 2 ** depth, kernel_size = kernel_size)

    # Expanding path: walk the skips back from deepest to shallowest.
    for level in reversed(range(depth)):
        x = DecoderCell(x, skips[level], filters = n_filters * 2 ** level, kernel_size = kernel_size, dropout_rate = dropout_rate)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)
    model = Model(inputs=[inputs], outputs=[outputs])
    return model

# UNet : training

In [12]:
# 3-channel DIM x DIM inputs; per-pixel binary cross-entropy against the masks.
model = UNet(input_shape=(DIM, DIM, 3), n_filters=CONV_BASE_FILTERS, dropout_rate=0.05)
model.compile(
    loss='binary_crossentropy',
    optimizer=Adam(),
    metrics=['accuracy'],
)

In [13]:
# Print the layer-by-layer architecture (output shapes and parameter counts).
model.summary()

In [14]:
callbacks = [
    # restore_best_weights=True: when early stopping triggers, the in-memory
    # model is rolled back to the best-val_loss epoch. Otherwise the
    # predictions further down would use weights from up to `patience`
    # epochs past the best model.
    EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True, verbose=1),
    ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_lr=0.00001, verbose=1),
    # Write weights to disk only when val_loss improves.
    # NOTE(review): Keras 3 requires save_weights_only paths to end in
    # '.weights.h5'. Rename the file if this runs on Keras 3.
    ModelCheckpoint('unet-brain-mri.h5', monitor='val_loss', verbose=1, save_best_only=True, save_weights_only=True)
]

In [15]:
# Train for at most 50 epochs; the callbacks may stop early or lower the LR.
results = model.fit(
    x=train_generator,
    validation_data=validation_generator,
    epochs=50,
    callbacks=callbacks,
)

# Results analysis

In [18]:
loss_history = results.history
best_epoch = np.argmin(loss_history['val_loss'])

lc_fig, lc_ax = plt.subplots(figsize=(8, 8))
lc_ax.set_title('Learning curve')
lc_ax.plot(loss_history['loss'], label='loss')
lc_ax.plot(loss_history['val_loss'], label='val_loss')
# Mark the epoch with the lowest validation loss.
lc_ax.plot(best_epoch, loss_history['val_loss'][best_epoch], marker='x', color='r', label='best model')
lc_ax.set_xlabel('Epochs')
lc_ax.set_ylabel('log_loss')
lc_ax.legend()

In [24]:
N_PREDICTIONS = 30

for _ in range(N_PREDICTIONS):
    # randint's upper bound is exclusive, so any row 0..len(test)-1 can be
    # drawn. The previous lower bound of 1 never showed the first test row.
    index = np.random.randint(0, len(test.index))

    # Same preprocessing as DataGenerator: BGR, scaled to [0, 1], DIM x DIM.
    img = cv2.imread(test['image'].iloc[index]) / 255.
    if img.shape[:2] != (DIM, DIM):
        img = cv2.resize(img, (DIM, DIM), interpolation=cv2.INTER_LINEAR)
    pred = model.predict(np.expand_dims(img, axis=0), verbose=0)
    msk = cv2.imread(test['mask'].iloc[index], cv2.IMREAD_GRAYSCALE)

    fig, axes = plt.subplots(1, 3, figsize=(12, 12))
    # Flip BGR -> RGB for display only; the model was fed BGR as in training.
    axes[0].imshow(img[..., ::-1])
    axes[0].set_title('Image')
    axes[1].imshow(msk, cmap='gray')
    axes[1].set_title('Mask')
    axes[2].imshow(np.squeeze(pred) > .5, cmap='gray')
    axes[2].set_title('Prediction')
    plt.show()
    # Release the figure so many iterations do not accumulate open figures.
    plt.close(fig)