In [None]:
######################################


# -*- coding: utf-8 -*-
"""SegNet model for Keras.
# Reference:
- [Segnet: A deep convolutional encoder-decoder architecture for image segmentation](https://arxiv.org/pdf/1511.00561.pdf)
"""

from __future__ import absolute_import
from __future__ import print_function
import os
import numpy as np
from keras.utils import np_utils
from keras.applications import imagenet_utils


########################
from keras.models import Model
from keras.layers import Input
from keras.layers.core import Activation, Reshape
from keras.layers import BatchNormalization
import tensorflow as tf
from torch.nn import MaxUnpool3d
from keras.layers import Conv3D, MaxPooling3D, concatenate, UpSampling3D


def SegNet(input_shape, classes):
    kernel=(3, 3, 3)
    pool_size=(2, 2, 2)
    output_mode="softmax"
    
    img_input = Input(shape=input_shape)
    x = img_input
    # Encoder
    x = Conv3D(64, kernel, padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    pool_1 = MaxPooling3D(pool_size=pool_size)(x)
    
    x = Conv3D(128, kernel, padding="same")(pool_1)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    pool_2 = MaxPooling3D(pool_size=pool_size)(x)
    
    x = Conv3D(256, kernel, padding="same")(pool_2)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    pool_3 = MaxPooling3D(pool_size=pool_size)(x)
    
    x = Conv3D(512,kernel, padding="same")(pool_3)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    
    # Decoder
    x = Conv3D(512, kernel, padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    
    x = UpSampling3D(size=pool_size)(x)
    x = Conv3D(256, kernel, padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    
    x = UpSampling3D(size=pool_size)(x)
    x = Conv3D(128, kernel, padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    
    x = UpSampling3D(size=pool_size)(x)
    x = Conv3D(64, kernel, padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    
    x = Conv3D(classes, 1, 1, padding="valid")(x)
    #x = Reshape((input_shape[0]*input_shape[1]*input_shape[2], classes))(x)
    x = Activation("softmax")(x)
    model = Model(img_input, x)


    return model



model = SegNet(input_shape=(128,128,128,3), classes=4)

model.summary()
print(model.input_shape)
print(model.output_shape)



In [2]:
import numpy as np
import nibabel as nib
import glob
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt
from tifffile import imsave

In [None]:
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()
TRAIN_DATASET_PATH = '/content/drive/MyDrive/TrainingData/'
####### Custom Data Generation ####################

import os
import numpy as np


def load_img(img_dir, img_list):
    images=[]
    for i, image_name in enumerate(img_list):    
        if (image_name.split('.')[1] == 'npy'):
            
            image = np.load(img_dir+image_name)
                      
            images.append(image)
    images = np.array(images)
    
    return(images)




def imageLoader(img_dir, img_list, mask_dir, mask_list, batch_size):

    L = len(img_list)

    #keras needs the generator infinite, so we will use while true  
    while True:

        batch_start = 0
        batch_end = batch_size

        while batch_start < L:
            limit = min(batch_end, L)
                       
            X = load_img(img_dir, img_list[batch_start:limit])
            Y = load_img(mask_dir, mask_list[batch_start:limit])

            yield (X,Y) #a tuple with two numpy arrays with batch_size samples     

            batch_start += batch_size   
            batch_end += batch_size

############################################

#Test the generator

from matplotlib import pyplot as plt
import random

train_img_dir = "/content/drive/MyDrive/input_data_validation/train/images/"
train_mask_dir = "/content/drive/MyDrive/input_data_validation/train/masks/"
train_img_list=os.listdir(train_img_dir)
train_mask_list = os.listdir(train_mask_dir)

batch_size = 2

train_img_datagen = imageLoader(train_img_dir, train_img_list, 
                                train_mask_dir, train_mask_list, batch_size)

#Verify generator.... In python 3 next() is renamed as __next__()
img, msk = train_img_datagen.__next__()


img_num = random.randint(0,img.shape[0]-1)
test_img=img[img_num]
test_mask=msk[img_num]
test_mask=np.argmax(test_mask, axis=3)

n_slice=random.randint(0, test_mask.shape[2])
plt.figure(figsize=(12, 8))

plt.subplot(221)
plt.imshow(test_img[:,:,n_slice, 0], cmap='gray')
plt.title('Image flair')
plt.subplot(222)
plt.imshow(test_img[:,:,n_slice, 1], cmap='gray')
plt.title('Image t1ce')
plt.subplot(223)
plt.imshow(test_img[:,:,n_slice, 2], cmap='gray')
plt.title('Image t2')
plt.subplot(224)
plt.imshow(test_mask[:,:,n_slice])
plt.title('Mask')
plt.show()


In [5]:
train_img_dir = "/content/drive/MyDrive/input_data_validation/train/images/"
train_mask_dir = "/content/drive/MyDrive/input_data_validation/train/masks/"

val_img_dir = "/content/drive/MyDrive/input_data_validation/val/images/"
val_mask_dir = "/content/drive/MyDrive/input_data_validation/val/masks/"

train_img_list=os.listdir(train_img_dir)
train_mask_list = os.listdir(train_mask_dir)

val_img_list=os.listdir(val_img_dir)
val_mask_list = os.listdir(val_mask_dir)
##################################

########################################################################
batch_size = 1

train_img_datagen = imageLoader(train_img_dir, train_img_list, 
                                train_mask_dir, train_mask_list, batch_size)

val_img_datagen = imageLoader(val_img_dir, val_img_list, 
                                val_mask_dir, val_mask_list, batch_size)


In [6]:
!pip install keras_applications
!pip install classification-models-3D
!pip install efficientnet-3D
!pip install segmentation-models-3D

Collecting keras_applications
  Downloading Keras_Applications-1.0.8-py3-none-any.whl (50 kB)
[?25l[K     |██████▌                         | 10 kB 29.3 MB/s eta 0:00:01[K     |█████████████                   | 20 kB 18.8 MB/s eta 0:00:01[K     |███████████████████▍            | 30 kB 11.3 MB/s eta 0:00:01[K     |█████████████████████████▉      | 40 kB 9.1 MB/s eta 0:00:01[K     |████████████████████████████████| 50 kB 3.1 MB/s 
Installing collected packages: keras-applications
Successfully installed keras-applications-1.0.8
Collecting classification-models-3D
  Downloading classification_models_3D-1.0.2-py3-none-any.whl (45 kB)
[K     |████████████████████████████████| 45 kB 1.7 MB/s 
Installing collected packages: classification-models-3D
Successfully installed classification-models-3D-1.0.2
Collecting efficientnet-3D
  Downloading efficientnet_3D-1.0.1-py3-none-any.whl (14 kB)
Installing collected packages: efficientnet-3D
Successfully installed efficientnet-3D-1.0.1
Coll

In [7]:
import segmentation_models_3D as sm
wt0, wt1, wt2, wt3 = 0.25,0.25,0.25,0.25
#import segmentation_models_3D as sm
dice_loss = sm.losses.DiceLoss(class_weights=np.array([wt0, wt1, wt2, wt3])) 
focal_loss = sm.losses.CategoricalFocalLoss()
total_loss = dice_loss + (1 * focal_loss)

metrics = ['accuracy', sm.metrics.IOUScore(threshold=0.5)]

LR = 0.0001

from tensorflow.keras.optimizers import Adam

#optim = tensorflow.keras.optimizers.Adam(LR)

optim = Adam(LR)
#######################################################################
#Fit the model 

steps_per_epoch = len(train_img_list)//batch_size
val_steps_per_epoch = len(val_img_list)//batch_size
val_steps_per_epoch

Segmentation Models: using `keras` framework.


In [14]:
val_steps_per_epoch

7

In [12]:
#gpu_options = tf.GPUOptions(allow_growth=True)
##session = tf.InteractiveSession(config=tf.ConfigProto(gpu_options=gpu_options))
tf.compat.v1.GPUOptions(allow_growth=True)

allow_growth: true

In [18]:
model = SegNet(input_shape=(128,128,128,3), classes=4)

model.compile(optimizer = optim, loss=total_loss, metrics=metrics)
#print(model.summary())

#print(model.input_shape)
#print(model.output_shape)

history=model.fit(train_img_datagen,
          epochs=30,
          validation_data=val_img_datagen,
          validation_steps=2,
          )

Epoch 1/30


ResourceExhaustedError: ignored