# Imports

In [2]:
import tensorflow as tf
import segmentation_models_3D as sm

from skimage import io
from patchify import patchify, unpatchify

import numpy as np
from matplotlib import pyplot as plt

import keras
from keras import backend as K
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split

from glob import glob
from tqdm import tqdm

import os
import imagecodecs

Segmentation Models: using `tf.keras` framework.


# Config

In [3]:
print(tf.__version__)
print(keras.__version__)
# Make sure the GPU is available.
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

2.8.0
2.8.0
Found GPU at: /device:GPU:0


In [4]:
# #Load input image and mask
# import imagecodecs
# #Here we load 256x256x256 pixel volume. We will break it into patches of 64x64x64 for training.
# image = io.imread('E:\\3D segmentation\\Fluo-N3DH-SIM+\\01\\t000.tif')
# print(image.shape)
# img_patches = patchify(image, (32, 32, 32), step=32)  #Step=64 for 64 patches means no overlap
# print(img_patches.shape)
# mask = io.imread('E:\\3D segmentation\\Fluo-N3DH-SIM+\\01_GT\\SEG\\man_seg000.tif')
# mask_patches = patchify(mask, (32, 32, 32), step=32)
# print(mask_patches.shape)

In [5]:
def normalize(vol):
    # stretch contrast to 0-255
    max_val = np.amax(vol)
    min_val = np.amin(vol)
    vol = (vol - min_val)/(max_val-min_val+1e-9)*255
    return np.uint8(vol)


# Load data

In [6]:
#Load input images and masks.
#base_dir = "E:\\3D segmentation\\Fluo-N3DH-SIM+\\01\\"
base_dir= "E:\\3D segmentation\\Fluo-N3DH-SIM+\\01_part\\"
images_paths = glob(os.path.join(base_dir, "*.tif"))
images_paths.sort()
mask_dir = "E:\\3D segmentation\\Fluo-N3DH-SIM+\\01_GT_part\\"
masks_paths = glob(os.path.join(mask_dir, "*.tif"))
masks_paths.sort()

images = []
masks = []

for image_path in tqdm(images_paths):
    image = io.imread(image_path)
    print(image.shape)
    image= np.transpose(image, (1, 2, 0))
    print(image.shape)
    image = patchify(image, (128,128,32), step=32)
    images.append(image)
    

for mask_path in tqdm(masks_paths):
    mask = io.imread(mask_path)
    print(mask.shape)
    mask= np.transpose(mask, (1, 2, 0))
    print(mask.shape)
    mask = patchify(mask, (128,128,32), step=32)
    masks.append(mask)
    #mask_path = image_path.replace("\\01", "\\01_GT\\SEG").replace("\\t", "\\man_seg")

images = np.array(images)
masks = np.array(masks)
print(images.shape)
print(masks.shape)

  0%|                                                                                            | 0/5 [00:00<?, ?it/s]

(59, 349, 639)

 40%|█████████████████████████████████▌                                                  | 2/5 [00:00<00:00,  6.67it/s]


(349, 639, 59)
(59, 349, 639)
(349, 639, 59)


 80%|███████████████████████████████████████████████████████████████████▏                | 4/5 [00:00<00:00,  6.71it/s]

(59, 349, 639)
(349, 639, 59)
(59, 349, 639)
(349, 639, 59)


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  6.51it/s]


(59, 349, 639)
(349, 639, 59)


 40%|█████████████████████████████████▌                                                  | 2/5 [00:00<00:00, 15.55it/s]

(59, 349, 639)
(349, 639, 59)
(59, 349, 639)
(349, 639, 59)


 80%|███████████████████████████████████████████████████████████████████▏                | 4/5 [00:00<00:00, 14.68it/s]

(59, 349, 639)
(349, 639, 59)
(59, 349, 639)
(349, 639, 59)


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 14.86it/s]

(59, 349, 639)
(349, 639, 59)





(5, 7, 16, 1, 128, 128, 32)
(5, 7, 16, 1, 128, 128, 32)


In [7]:
input_img = np.reshape(images, (-1, images.shape[4], images.shape[5], images.shape[6]))
input_mask = np.reshape(masks, (-1, masks.shape[4], masks.shape[5], masks.shape[6]))
input_mask = np.float32(input_mask)

del images
del masks
print(input_img.shape)
print(input_mask.shape)

(560, 128, 128, 32)
(560, 128, 128, 32)


# Split data

In [8]:
#Convert grey image to 3 channels by copying channel 3 times.
#We do this as our unet model expects 3 channel input.

train_img = np.stack((input_img,)*3, axis=-1)
train_mask = np.expand_dims(input_mask, axis=4)
X_train, X_test, y_train, y_test = train_test_split(train_img, train_mask, test_size = 0.15, random_state = 0)
del train_img
del train_mask

# X_train, X_test, y_train, y_test = train_test_split(input_img, input_mask, test_size = 0.15, random_state = 0)
# del input_img
# del input_mask

In [9]:
X_train.shape

(476, 128, 128, 32, 3)

# Loss functions

In [10]:
# Loss Function and coefficients to be used during training:
def dice_coefficient(y_true, y_pred):
    smoothing_factor = 1
    flat_y_true = K.flatten(y_true)
    flat_y_pred = K.flatten(y_pred)
    return (2. * K.sum(flat_y_true * flat_y_pred) + smoothing_factor) / (K.sum(flat_y_true) + K.sum(flat_y_pred) + smoothing_factor)

def dice_coefficient_loss(y_true, y_pred):
    return 1 - dice_coefficient(y_true, y_pred)

In [11]:
encoder_weights = 'imagenet'
BACKBONE = 'vgg16'  #Try vgg16, efficientnetb7, inceptionv3, resnet50
activation = 'sigmoid'
patch_size = 32
n_classes = 1
channels=3

LR = 0.0001
optim = tf.keras.optimizers.Adam(LR)

# Segmentation models losses can be combined together by '+' and scaled by integer or float factor
# set class weights for dice_loss (car: 1.; pedestrian: 2.; background: 0.5;)
dice_loss = sm.losses.DiceLoss()
focal_loss = sm.losses.CategoricalFocalLoss()
total_loss = dice_loss + (1 * focal_loss)
#loss = DiceLoss()

# actulally total_loss can be imported directly from library, above example just show you how to manipulate with losses
# total_loss = sm.losses.binary_focal_dice_loss # or sm.losses.categorical_focal_dice_loss

metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]

In [12]:
preprocess_input = sm.get_preprocessing(BACKBONE)

In [13]:
#Preprocess input data - otherwise you end up with garbage resutls
# and potentially model that does not converge.
X_train_prep = preprocess_input(X_train)
X_test_prep = preprocess_input(X_test)

In [14]:
del X_train
del X_test

In [15]:
X_train_prep.dtype

dtype('float32')

In [16]:
y_train.dtype

dtype('float32')

In [17]:
#Define the model. Here we use Unet but we can also use other model architectures from the library.
model = sm.Unet(BACKBONE, classes=n_classes,
                input_shape=(128,128,32,channels),
                encoder_weights=encoder_weights,
                activation=activation)

model.compile(optimizer = optim, loss=dice_loss, metrics=metrics)
print(model.summary())

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 128, 128, 3  0           []                               
                                2, 3)]                                                            
                                                                                                  
 block1_conv1 (Conv3D)          (None, 128, 128, 32  5248        ['input_1[0][0]']                
                                , 64)                                                             
                                                                                                  
 block1_conv2 (Conv3D)          (None, 128, 128, 32  110656      ['block1_conv1[0][0]']           
                                , 64)                                                         

In [18]:
# Fit the model
history=model.fit(X_train_prep,
          y_train,
          batch_size=2, epochs=20,
          validation_data=(X_test_prep, y_test))

InternalError: Failed copying input tensor from /job:localhost/replica:0/task:0/device:CPU:0 to /job:localhost/replica:0/task:0/device:GPU:0 in order to run _EagerConst: Dst tensor is not initialized.

In [None]:
#plot the training and validation IoU and loss at each epoch
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(loss) + 1)
plt.plot(epochs, loss, 'y', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

acc = history.history['iou_score']
val_acc = history.history['val_iou_score']

plt.plot(epochs, acc, 'y', label='Training IOU')
plt.plot(epochs, val_acc, 'r', label='Validation IOU')
plt.title('Training and validation IOU')
plt.xlabel('Epochs')
plt.ylabel('IOU')
plt.legend()
plt.show()
     

In [None]:
y_pred=model.predict(X_test)

In [None]:
#Test some random images
import random
test_img_number = random.randint(0, len(X_test))
test_img = X_test[test_img_number]
ground_truth=y_test[test_img_number]

test_img_input=np.expand_dims(test_img, 0)
test_img_input1 = preprocess_input(test_img_input)

test_pred1 = model.predict(test_img_input1)
#test_prediction1 = np.argmax(test_pred1, axis=4)[0,:,:,:]
print(test_pred1.shape)

In [None]:
#Plot individual slices from test predictions for verification
slice = 10
plt.figure(figsize=(12, 8))
plt.subplot(231)
plt.title('Testing Image')
plt.imshow(test_img[slice,:,:,0], cmap='gray')
plt.subplot(232)
plt.title('Testing Label')
plt.imshow(ground_truth[slice,:,:])
plt.subplot(233)
plt.title('Prediction on test image')
plt.imshow(test_pred1[slice,:,:])
plt.show()