## Import

In [2]:
from keras.layers import Input, Conv2D, MaxPooling2D, concatenate, Conv2DTranspose, Dropout
from keras.models import Model
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from keras import regularizers 
from tensorflow.keras.layers import BatchNormalization as bn
from keras.callbacks import ModelCheckpoint, TensorBoard
from keras.models import model_from_json

In [3]:
import keras.backend as K
import tensorflow as tf

In [4]:
import nibabel as nib
from glob import glob
import matplotlib.pyplot as plt
import numpy as np
from random import shuffle
import re

In [5]:
# Download from https://www.kaggle.com/datasets/andrewmvd/liver-tumor-segmentation

img_path = glob("data/volume_pt*/volume-*.nii")
mask_path = glob("data/segmentations/segmentation-*.nii")

if len(img_path) == 0 or len(mask_path) == 0 or len(img_path) > len(mask_path):
    raise Exception("Incorrect rata found ({} volumes and {} segmentations)!".format(len(img_path), len(mask_path)))

print("Number of images:", len(img_path))

Number of images: 51


In [6]:
def atoi(text):
    return int(text) if text.isdigit() else text

def natural_keys(text):
    '''
    alist.sort(key=natural_keys) sorts in human order
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)
    '''
    return [atoi(c) for c in re.split(r'(\d+)', text)]

In [7]:
img_path.sort(key=natural_keys)
mask_path.sort(key=natural_keys)

## Utils

In [8]:
def weighted_binary_crossentropy(y_true, y_pred):
    y_pred = tf.clip_by_value(y_pred, 10e-8, 1.-10e-8)
    loss = - (y_true * K.log(y_pred) * 0.90 + (1 - y_true) * K.log(1 - y_pred) * 0.10)
    
    return K.mean(loss)

In [9]:
# Reference : https://github.com/dk67604/LITS-Challenge-Liver-Segmentation/blob/master/experiments/keras_realtime_train.ipynb

smooth = 1.
def dice_coef(y_true, y_pred):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

## Model

In [10]:
input_shape = [64, 64, 1]
dropout_rate = 0.3
l2_lambda = 0.0002

In [11]:
def u_net(input_shape, dropout_rate, l2_lambda):
  
  # Encoder
  input = Input(shape = input_shape, name = "input")
  conv1_1 = Conv2D(32, (3, 3), padding = "same", activation='relu', kernel_regularizer=regularizers.l2(l2_lambda), name = "conv1_1")(input)
  conv1_1 = bn(name = "conv1_1_bn")(conv1_1)
  conv1_2 = Conv2D(32, (3, 3), padding = "same", activation='relu', kernel_regularizer=regularizers.l2(l2_lambda), name = "conv1_2")(conv1_1)
  conv1_2 = bn(name = "conv1_2_bn")(conv1_2)
  pool1 = MaxPooling2D(name = "pool1")(conv1_2)
  drop1 = Dropout(dropout_rate)(pool1)
  
  conv2_1 = Conv2D(64, (3, 3), padding = "same", activation='relu', kernel_regularizer=regularizers.l2(l2_lambda), name = "conv2_1")(pool1)
  conv2_1 = bn(name = "conv2_1_bn")(conv2_1)
  conv2_2 = Conv2D(64, (3, 3), padding = "same", activation='relu', kernel_regularizer=regularizers.l2(l2_lambda), name = "conv2_2")(conv2_1)
  conv2_2 = bn(name = "conv2_2_bn")(conv2_2)
  pool2 = MaxPooling2D(name = "pool2")(conv2_2)
  drop2 = Dropout(dropout_rate)(pool2)
  
  conv3_1 = Conv2D(128, (3, 3), padding = "same", activation='relu', kernel_regularizer=regularizers.l2(l2_lambda), name = "conv3_1")(pool2)
  conv3_1 = bn(name = "conv3_1_bn")(conv3_1)
  conv3_2 = Conv2D(128, (3, 3), padding = "same", activation='relu', kernel_regularizer=regularizers.l2(l2_lambda), name = "conv3_2")(conv3_1)
  conv3_2 = bn(name = "conv3_2_bn")(conv3_2)
  pool3 = MaxPooling2D(name = "pool3")(conv3_2)
  drop3 = Dropout(dropout_rate)(pool3)  

  conv4_1 = Conv2D(256, (3, 3), padding = "same", activation='relu', kernel_regularizer=regularizers.l2(l2_lambda), name = "conv4_1")(pool3)
  conv4_1 = bn(name = "conv4_1_bn")(conv4_1)
  conv4_2 = Conv2D(256, (3, 3), padding = "same", activation='relu', kernel_regularizer=regularizers.l2(l2_lambda), name = "conv4_2")(conv4_1)
  conv4_2 = bn(name = "conv4_2_bn")(conv4_2)
  pool4 = MaxPooling2D(name = "pool4")(conv4_2)
  drop4 = Dropout(dropout_rate)(pool4)  

  conv5_1 = Conv2D(512, (3, 3), padding = "same", activation='relu', kernel_regularizer=regularizers.l2(l2_lambda), name = "conv5_1")(pool4)
  conv5_1 = bn(name = "conv5_1_bn")(conv5_1)
  conv5_2 = Conv2D(512, (3, 3), padding = "same", activation='relu', kernel_regularizer=regularizers.l2(l2_lambda), name = "conv5_2")(conv5_1)
  conv5_2 = bn(name = "conv5_2_bn")(conv5_2)
  
  # Decoder
  upconv6 = Conv2DTranspose(256,(2, 2), strides=(2, 2), padding='same')(conv5_2)
  upconv6 = Dropout(dropout_rate)(upconv6)
  concat6 = concatenate([conv4_2, upconv6], name = "concat6")
  conv6_1 = Conv2D(256, (3, 3), padding = "same", kernel_regularizer=regularizers.l2(l2_lambda), name = "conv6_1")(concat6)
  conv6_1 = bn(name = "conv6_1_bn")(conv6_1)
  conv6_2 = Conv2D(256, (3, 3), padding = "same", kernel_regularizer=regularizers.l2(l2_lambda), name = "conv6_2")(conv6_1)
  conv6_2 = bn(name = "conv6_2_bn")(conv6_2)
    
  upconv7 = Conv2DTranspose(128,(2, 2), strides=(2, 2), padding='same')(conv6_2)
  upconv7 = Dropout(dropout_rate)(upconv7)
  concat7 = concatenate([conv3_2, upconv7], name = "concat7")
  conv7_1 = Conv2D(128, (3, 3), padding = "same", kernel_regularizer=regularizers.l2(l2_lambda), name = "conv7_1")(concat7)
  conv7_1 = bn(name = "conv7_1_bn")(conv7_1)
  conv7_2 = Conv2D(128, (3, 3), padding = "same", kernel_regularizer=regularizers.l2(l2_lambda), name = "conv7_2")(conv7_1)
  conv7_2 = bn(name = "conv7_2_bn")(conv7_2)

  upconv8 = Conv2DTranspose(64,(2, 2), strides=(2, 2), padding='same')(conv7_2)
  upconv8 = Dropout(dropout_rate)(upconv8)
  concat8 = concatenate([conv2_2, upconv8], name = "concat8")
  conv8_1 = Conv2D(64, (3, 3), padding = "same", kernel_regularizer=regularizers.l2(l2_lambda), name = "conv8_1")(concat8)
  conv8_1 = bn(name = "conv8_1_bn")(conv8_1)
  conv8_2 = Conv2D(64, (3, 3), padding = "same", kernel_regularizer=regularizers.l2(l2_lambda), name = "conv8_2")(conv8_1)
  conv8_2 = bn(name = "conv8_2_bn")(conv8_2)

  upconv9 = Conv2DTranspose(32,(2, 2), strides=(2, 2), padding='same')(conv8_2)
  upconv9 = Dropout(dropout_rate)(upconv9)
  concat9 = concatenate([conv1_2, upconv9], name = "concat9")
  conv9_1 = Conv2D(32, (3, 3), padding = "same", kernel_regularizer=regularizers.l2(l2_lambda), name = "conv9_1")(concat9)
  conv9_1 = bn(name = "conv9_1_bn")(conv9_1)
  conv9_2 = Conv2D(32, (3, 3), padding = "same", kernel_regularizer=regularizers.l2(l2_lambda), name = "conv9_2")(conv9_1)
  conv9_2 = bn(name = "conv9_2_bn")(conv9_2)
  dropout = Dropout(dropout_rate)(conv9_2)
  
  conv10 = Conv2D(1, (1, 1), padding = "same", activation = 'sigmoid', name = "conv10")(dropout)

 
  model = Model(input, conv10)
  
  return model

In [12]:
model = u_net(input_shape, dropout_rate, l2_lambda)

2023-05-19 10:32:49.919777: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.


In [13]:

tf.keras.utils.plot_model(
    model,
    to_file='model.png',
    show_shapes=False,
    show_layer_names=True,
    rankdir='TB',
    expand_nested=False,
    dpi=96
)


You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model to work.


## Training

### Patch Utils

In [14]:
patch_ratio = []
for i in range(16 + 1):
  patch_ratio.append(32 * i)

def patch_sampling(img, mask, patch_ratio, pos_neg_ratio, threshold):
  
  temp_mask = mask
  
  temp_mask[temp_mask == 1] = 0
  temp_mask[temp_mask == 2] = 1
  
  positive_patch = []
  positive_mask = []
  
  negative_patch = []
  negative_mask = []
  
  negative_set = []
  
  
  for i in range(temp_mask.shape[2]):
    for x_bin in range(2, len(patch_ratio)):
        for y_bin in range(2, len(patch_ratio)):
          img_patch = img[patch_ratio[x_bin-2] : patch_ratio[x_bin], patch_ratio[y_bin - 2] : patch_ratio[y_bin], i]
          mask_patch = temp_mask[patch_ratio[x_bin-2] : patch_ratio[x_bin], patch_ratio[y_bin - 2] : patch_ratio[y_bin], i]
          _, count = np.unique(mask_patch, return_counts = True)
          
          if len(count) == 2:
            mask_percentage = count[1] / sum(count) * 100
          
            if threshold < mask_percentage :
              positive_patch.append(img_patch)
              positive_mask.append(mask_patch)
          
          
          elif len(count) == 1:
            
            temp_list = []
            temp_list.append(img_patch)
            temp_list.append(mask_patch)
            
            negative_set.append(temp_list)
  
  shuffle(negative_set)
  
  negative_set_to_use = negative_set[:len(positive_patch) * pos_neg_ratio]
  for negative_set in negative_set_to_use:
    negative_patch.append(negative_set[0])
    negative_mask.append(negative_set[1])
  
  negative_set_to_use = []
  
  return positive_patch, positive_mask, negative_patch, negative_mask


def getTotals(from_percent, to_percent, unique_suffix=None):
    from_position = int(len(img_path) * from_percent)
    to_position = int(len(img_path) * to_percent)
    total_patch = []
    total_mask = []
    for i in range(from_position, to_position):
        img_3D = nib.load(img_path[i]).get_fdata()
        mask_3D = nib.load(mask_path[i]).get_fdata()

        pos_patch, pos_mask, neg_patch, neg_mask = patch_sampling(img_3D, mask_3D, patch_ratio, 3, 3.0)
        total_patch += (pos_patch + neg_patch)
        total_mask += (pos_mask + neg_mask)

        print("Image {0}/{1}: # of patches = {2} | # of total images = {3}".format(
            format(i+1, '>2'), 
                   len(img_path),
                   format(len(pos_patch) + len(neg_patch), '>5'),
                   format(len(total_patch), '>5')))
    total_patch = np.array(total_patch).reshape((len(total_patch), 64, 64, 1))
    total_mask = np.array(total_mask).reshape((len(total_mask), 64, 64, 1))
    
    if unique_suffix is not None:
        np.save("model/total_patch_{}.npy".format(unique_suffix), total_patch)
        np.save("model/total_mask_{}.npy".format(unique_suffix), total_mask)
    
    return total_patch, total_mask

### Create Adam optimizer

In [15]:
adam = Adam(learning_rate = 0.0001)

### Compile, fit and save to disk

*Note*: 12 volumes is aroung 30 GB RAM, so using 24% is the limit in my case 

In [None]:
total_patch_train, total_mask_train = getTotals(0, 0.24, "train")

In [21]:
total_patch_train = np.load("model/total_patch_train.npy")
total_mask_train = np.load("model/total_mask_train.npy")

In [22]:
total_mask_train.shape

(27096, 64, 64, 1)

In [None]:
model.compile(optimizer = adam, loss = weighted_binary_crossentropy, metrics = [dice_coef])

# Fitting takes a long time, 1-2 hours
model.fit(total_patch, total_mask, batch_size = 512, epochs = 10)

# Save model to JSON
model_json = model.to_json()
with open("model/model.json", "w") as json_file:
    json_file.write(model_json)
    
# Save model weights to HDF5
model.save_weights("model/model_weights.h5")

print("Saved model to disk")

## Evaluations

In [16]:
json_file = open('model/model.json', 'r')
loaded_model_json = json_file.read()
json_file.close()
loaded_model = model_from_json(loaded_model_json)
# load weights into new model
loaded_model.load_weights("model/model_weights.h5")
loaded_model.summary()
print("Loaded model from disk")

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input (InputLayer)             [(None, 64, 64, 1)]  0           []                               
                                                                                                  
 conv1_1 (Conv2D)               (None, 64, 64, 32)   320         ['input[0][0]']                  
                                                                                                  
 conv1_1_bn (BatchNormalization  (None, 64, 64, 32)  128         ['conv1_1[0][0]']                
 )                                                                                                
                                                                                                  
 conv1_2 (Conv2D)               (None, 64, 64, 32)   9248        ['conv1_1_bn[0][0]']         

                                                                                                  
 dropout_5 (Dropout)            (None, 16, 16, 128)  0           ['conv2d_transpose_1[0][0]']     
                                                                                                  
 concat7 (Concatenate)          (None, 16, 16, 256)  0           ['conv3_2_bn[0][0]',             
                                                                  'dropout_5[0][0]']              
                                                                                                  
 conv7_1 (Conv2D)               (None, 16, 16, 128)  295040      ['concat7[0][0]']                
                                                                                                  
 conv7_1_bn (BatchNormalization  (None, 16, 16, 128)  512        ['conv7_1[0][0]']                
 )                                                                                                
          

### Conversion utils

In [17]:
def slice_to_patch(slice, patch_ratio):
  
  slice[slice == 1] = 0
  slice[slice == 2] = 1
  
  patch_list = []
  
  for x_bin in range(2, len(patch_ratio)):
    for y_bin in range(2, len(patch_ratio)):
      patch = slice[patch_ratio[x_bin-2] : patch_ratio[x_bin], patch_ratio[y_bin - 2] : patch_ratio[y_bin]]
      patch = patch.reshape(patch.shape + (1,))
      patch_list.append(patch)
  
  return np.array(patch_list)

In [18]:
def patch_to_slice(patch, patch_ratio, input_shape, conf_threshold):
  
  slice = np.zeros((512, 512, 1))
  row_idx = 0
  col_idx = 0
  
  for i in range(len(patch)):
    
    slice[patch_ratio[row_idx]:patch_ratio[row_idx + 2], patch_ratio[col_idx]:patch_ratio[col_idx + 2]][patch[i] > conf_threshold] = 1
    
    col_idx += 1
    
    if i != 0 and (i+1) % 15 == 0:
      row_idx += 1
      col_idx = 0
  
  return slice

### Get test totals,save them to disk

In [None]:
total_test_patch, total_test_mask = getTotals(0.30, 0.45, "test")

### Print model accuracy, MSE, dice coef.

In [19]:
total_test_patch = np.load("model/total_patch_test.npy")
total_test_mask = np.load("model/total_mask_test.npy")

loaded_model.compile(optimizer = adam, loss = weighted_binary_crossentropy, metrics = [dice_coef, "accuracy", "mse", tf.keras.metrics.FalseNegatives(thresholds=0), tf.keras.metrics.FalsePositives()])

evaluations = loaded_model.evaluate(total_test_patch, total_test_mask, batch_size = 512)
print(list(zip(loaded_model.metrics_names, evaluations)))

[('loss', 0.1577272266149521), ('dice_coef', 0.09807351231575012), ('accuracy', 0.9387029409408569), ('mse', 0.078289695084095), ('false_negatives', 0.0), ('false_positives', 1729087.0)]


### Exaple visualization of model predictions

In [None]:
# Find false positive rate.
fpr = fp / (fp + tn)

image_index = 13
img_ex = nib.load(img_path[image_index]).get_fdata()
mask_ex = nib.load(mask_path[image_index]).get_fdata()

mask_ex[mask_ex == 1] = 0

for i in range(mask_ex.shape[2]):
    _, count = np.unique(mask_ex[:, :, i], return_counts=True)
    if len(count) > 1 and count[1] > 300:
        
        patch_ex = slice_to_patch(img_ex[:, :, i], patch_ratio)
        prediction = loaded_model.predict(patch_ex)
        prediction_mask = patch_to_slice(prediction, patch_ratio, input_shape, conf_threshold = 0.97)
        
        fig, (ax1,ax2,ax3) = plt.subplots(1, 3, figsize = ((15, 15)))
        
        ax1.imshow(np.rot90(img_ex[:, :, i], 3), cmap = 'bone')
        ax1.set_title("Image", fontsize = "x-large")
        ax1.grid(False)
        ax2.imshow(np.rot90(mask_ex[:, :, i], 3), cmap = 'bone')
        ax2.set_title("Mask (Actual)", fontsize = "x-large")
        ax2.grid(False)
        ax3.imshow(np.rot90(prediction_mask.reshape((512, 512)), 3), cmap = 'bone')
        ax3.set_title("Mask (Prediction)", fontsize = "x-large")
        ax3.grid(False)
        plt.show()