In [1]:
####################################################
# Variables in this cell shall be assigned by user #
####################################################

# MODEL       - name of the neural network architecture to be used:
#               'Segnet' or 'Xunet' (any other value makes get_model() fall back to Segnet)
# WORKING_DIR - folder where the unlabeled data will be processed and results will be found;
#               must end with '/' because all paths below are built by plain string concatenation
# IMAGES_DIR  - folder containing the images to be processed (doesn't have to be inside WORKING_DIR);
#               only .tif and .png files located directly in it are used. Later cells hard-code a
#               4096x4096 canvas/mask size, so the images are assumed to be 4096x4096 pixels.
# WEIGHTS     - file with pre-trained weights for the architecture chosen in MODEL
#               (doesn't have to be inside WORKING_DIR either)
#
# OUTPUT_HALFCOLLAPSED_CARD - set to True if half-collapsed cardiomyocytes (class 3 of the
#                             segmentation map) should also be processed and cropped out
# NOTE(review): an OUTPUT_COLLAPSED_CARD option used to be documented here, but no such
#               variable is defined or read by any cell of this notebook.

MODEL = 'Segnet'  # Segnet or Xunet

WORKING_DIR = '/content/drive/MyDrive/Colab Notebooks/FinalSeg/'
IMAGES_DIR = WORKING_DIR + 'Images/'
WEIGHTS = WORKING_DIR + 'pretrained_weights/segnet_aug.h5'

OUTPUT_HALFCOLLAPSED_CARD = False # True or False

In [2]:
###################################
#  Importing necessary libraries  #
###################################

# from google.colab import drive
# drive.mount('/content/drive')
import os
import imageio
import cv2
import shutil
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
%pip install image_slicer
from image_slicer import slice, save_tiles
from scipy.ndimage import binary_closing

import tensorflow as tf
from tensorflow.keras.utils import normalize
from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, Conv2DTranspose, BatchNormalization, \
    Dropout, Lambda, ZeroPadding2D, LeakyReLU, Add
from keras.applications.xception import Xception



In [3]:
###################################
# Definitions of custom functions #
###################################

def load_images(orig_images_dir):
    """Read every image file found directly inside a directory with OpenCV.

    Args:
        orig_images_dir: directory to scan (not recursive).

    Returns:
        dict mapping file name -> array returned by ``cv2.imread`` with its
        default flags. Entries OpenCV cannot decode (``imread`` returns None)
        are skipped. Keys are inserted in sorted file-name order: the order of
        ``os.listdir`` is arbitrary, and callers iterate this dict to place
        tiles by position.
    """
    images = {}
    for filename in sorted(os.listdir(orig_images_dir)):
        img = cv2.imread(os.path.join(orig_images_dir, filename))
        if img is not None:
            images[filename] = img
    return images

def get_model():
    """Build the (uncompiled) network selected by the global MODEL setting.

    'Segnet' -> Segnet(), 'Xunet' -> Unet_Xception_ResNetBlock(). Any other
    value prints a notice and falls back to Segnet().
    """
    # `==`, not `is`: identity of equal strings is an interning detail, and
    # `is` against a literal triggers a SyntaxWarning on Python 3.8+.
    if MODEL == 'Segnet':
        return Segnet()
    if MODEL == 'Xunet':
        return Unet_Xception_ResNetBlock()
    print(f"Unknown MODEL value {MODEL!r}; the default Segnet architecture was chosen instead.")
    return Segnet()

def Segnet(nClasses=4, input_height=512, input_width=512):
    """Encoder-decoder segmentation network (SegNet-like, without pooling indices).

    Four encoder stages with 64/128/256/512 filters each apply two
    (3x3 Conv2D + ReLU, BatchNorm) units followed by 2x2 max-pooling. They are
    mirrored by four decoder stages with 512/256/128/64 filters, each doing a 2x
    UpSampling2D followed by two such units. There are no skip connections.
    A final 1x1 Conv2D with softmax yields nClasses scores per position.
    Input height and width must be divisible by 16 for the output size to match
    the input size.

    Args:
        nClasses: number of output classes.
        input_height, input_width: size of the single-channel input image.

    Returns:
        Uncompiled keras Model.
    """
    # Layers are created in exactly the same order as before: weights loaded
    # from an HDF5 file via load_weights() are matched by network topology,
    # not by layer name.
    def conv_bn_x2(tensor, n_filters):
        # Two (3x3 conv + ReLU) -> BatchNorm units at the same width.
        for _ in range(2):
            tensor = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(tensor)
            tensor = BatchNormalization()(tensor)
        return tensor

    inputs = Input(shape=(input_height, input_width, 1))

    features = inputs
    # Encoder
    for n_filters in (64, 128, 256, 512):
        features = conv_bn_x2(features, n_filters)
        features = MaxPooling2D(pool_size=(2, 2))(features)

    # Decoder
    for n_filters in (512, 256, 128, 64):
        features = UpSampling2D(size=(2, 2))(features)
        features = conv_bn_x2(features, n_filters)

    outputs = Conv2D(nClasses, (1, 1), padding='same', activation='softmax')(features)
    return Model(inputs, outputs)


def convolution_block(x, filters, size, strides=(1,1), padding='same', activation=True):
    """Conv2D -> BatchNormalization, optionally followed by LeakyReLU(alpha=0.1).

    Args:
        x: input Keras tensor.
        filters: number of convolution filters.
        size: kernel size passed to Conv2D.
        strides: convolution strides.
        padding: Conv2D padding mode.
        activation: when truthy, append a LeakyReLU. residual_block passes False
            for its second convolution, so no activation precedes its Add.

    Returns:
        The output Keras tensor.
    """
    x = Conv2D(filters, size, strides=strides, padding=padding)(x)
    x = BatchNormalization()(x)
    if activation:
        x = LeakyReLU(alpha=0.1)(x)
    return x

def residual_block(blockInput, num_filters=16):
    """Residual unit: LeakyReLU + BN, then two convolution_blocks, added to a BN'd shortcut.

    Args:
        blockInput: input Keras tensor. Its channel count must equal num_filters
            so that the final Add receives tensors of matching shape.
        num_filters: filters of both 3x3 convolutions.

    Returns:
        Keras tensor `branch + shortcut`, with no activation applied after the sum.
    """
    activated = LeakyReLU(alpha=0.1)(blockInput)
    branch = BatchNormalization()(activated)
    # The identity path is only batch-normalised, never convolved.
    shortcut = BatchNormalization()(blockInput)
    # The second convolution skips its LeakyReLU, so the sum below is pre-activation.
    for use_activation in (True, False):
        branch = convolution_block(branch, num_filters, (3, 3), activation=use_activation)
    return Add()([branch, shortcut])


def Unet_Xception_ResNetBlock(nClasses=4, input_height=512, input_width=512):
    """U-Net-style segmentation model: Xception encoder + residual decoder blocks.

    Args:
        nClasses: number of output classes (channels of the final softmax).
        input_height, input_width: spatial size of the single-channel input.

    Returns:
        Uncompiled keras Model whose output is a 1x1-conv softmax over nClasses.

    The Xception backbone is built with weights=None (random initialisation, no
    ImageNet weights) and include_top=False. Encoder features for the skip
    connections are taken from hard-coded positions in backbone.layers
    (121, 31, 21, 11). Each decoder stage upsamples 2x with Conv2DTranspose,
    concatenates a skip feature map (except the last stage), and refines the
    result with a Conv2D plus two residual_blocks.

    NOTE(review): the layer indices depend on the exact layer list of the
    installed Keras Xception implementation; confirm with backbone.summary()
    after library upgrades.
    NOTE(review): the "a -> b" stage comments below appear to describe a smaller
    input than the default 512x512; confirm the actual sizes with model.summary().
    """
    
    backbone = Xception(input_shape=(input_height, input_width, 1), weights=None, include_top=False)
    
    inputs = backbone.input

    # Deepest skip connection, taken from inside the backbone.
    conv4 = backbone.layers[121].output
    conv4 = LeakyReLU(alpha=0.1)(conv4)
    pool4 = MaxPooling2D((2, 2))(conv4)
    pool4 = Dropout(0.1)(pool4)
    
    # Middle (bottleneck): 512 filters, two residual blocks
    convm = Conv2D(16*32, (3, 3), activation=None, padding="same")(pool4)
    convm = residual_block(convm, 16*32)
    convm = residual_block(convm, 16*32)
    convm = LeakyReLU(alpha=0.1)(convm)
    
    # 8 -> 16
    deconv4 = Conv2DTranspose(16*16, (3, 3), strides=(2, 2), padding="same")(convm)
    uconv4 = concatenate([deconv4, conv4])
    uconv4 = Dropout(0.1)(uconv4)
    
    uconv4 = Conv2D(16*16, (3, 3), activation=None, padding="same")(uconv4)
    uconv4 = residual_block(uconv4, 16 * 16)
    uconv4 = residual_block(uconv4, 16*16)
    uconv4 = LeakyReLU(alpha=0.1)(uconv4)
    
    # 16 -> 32
    deconv3 = Conv2DTranspose(16*8, (3, 3), strides=(2, 2), padding="same")(uconv4)
    conv3 = backbone.layers[31].output
    uconv3 = concatenate([deconv3, conv3])    
    uconv3 = Dropout(0.1)(uconv3)
    
    uconv3 = Conv2D(16*8, (3, 3), activation=None, padding="same")(uconv3)
    uconv3 = residual_block(uconv3, 16*8)
    uconv3 = residual_block(uconv3, 16*8)
    uconv3 = LeakyReLU(alpha=0.1)(uconv3)

    # 32 -> 64
    deconv2 = Conv2DTranspose(16*4, (3, 3), strides=(2, 2), padding="same")(uconv3)
    conv2 = backbone.layers[21].output
    # Asymmetric (top/left) zero padding, presumably so this odd-sized backbone
    # map matches deconv2's spatial size for concatenate -- TODO confirm.
    conv2 = ZeroPadding2D(((1,0),(1,0)))(conv2)
    uconv2 = concatenate([deconv2, conv2])
        
    uconv2 = Dropout(0.1)(uconv2)
    uconv2 = Conv2D(16*4, (3, 3), activation=None, padding="same")(uconv2)
    uconv2 = residual_block(uconv2, 16*4)
    uconv2 = residual_block(uconv2, 16*4)
    uconv2 = LeakyReLU(alpha=0.1)(uconv2)
    
    # 64 -> 128
    deconv1 = Conv2DTranspose(16*2, (3, 3), strides=(2, 2), padding="same")(uconv2)
    conv1 = backbone.layers[11].output
    # Same alignment trick as above, with 3 pixels of top/left padding -- TODO confirm.
    conv1 = ZeroPadding2D(((3,0),(3,0)))(conv1)
    uconv1 = concatenate([deconv1, conv1])
    
    uconv1 = Dropout(0.1)(uconv1)
    uconv1 = Conv2D(16*2, (3, 3), activation=None, padding="same")(uconv1)
    uconv1 = residual_block(uconv1, 16*2)
    uconv1 = residual_block(uconv1, 16*2)
    uconv1 = LeakyReLU(alpha=0.1)(uconv1)
    
    # 128 -> 256 (last upsampling stage; no skip connection is concatenated here)
    uconv0 = Conv2DTranspose(16*1, (3, 3), strides=(2, 2), padding="same")(uconv1)   
    uconv0 = Dropout(0.1)(uconv0)
    uconv0 = Conv2D(16*1, (3, 3), activation=None, padding="same")(uconv0)
    uconv0 = residual_block(uconv0, 16*1)
    uconv0 = residual_block(uconv0, 16*1)
    uconv0 = LeakyReLU(alpha=0.1)(uconv0)
    
    uconv0 = Dropout(0.1/2)(uconv0)

    # 1x1 convolution -> per-position softmax over nClasses
    outputs = Conv2D(nClasses, (1, 1), padding='same', activation='softmax')(uconv0)
    model = Model(inputs, outputs)
    return model

In [4]:
############################################################
# Load images, create temp cut dir, cut them into 64 tiles #
############################################################
# Every .tif/.png file located directly in IMAGES_DIR is split into 64 tiles by
# image_slicer. The tiles are written as PNGs to TEMP_cut_images/<image file name>/.
# The temporary folder is recreated from scratch on every run.

orig_images_dir = IMAGES_DIR
cut_images_dir = WORKING_DIR + 'TEMP_cut_images/'
if os.path.isdir(cut_images_dir):
    shutil.rmtree(cut_images_dir)
os.mkdir(cut_images_dir)

# Names of all original images, in os.listdir() order.
list_of_origs = []
for fname in os.listdir(IMAGES_DIR):
    src_path = IMAGES_DIR + fname
    if not (os.path.isfile(src_path) and fname.endswith((".tif", ".png"))):
        continue
    list_of_origs.append(fname)
    tile_dir = cut_images_dir + fname
    os.mkdir(tile_dir + '/')
    tiles = slice(src_path, 64, save=False)
    # fname[:-4] strips the 4-character ".tif"/".png" extension for the tile prefix.
    save_tiles(tiles=tiles, directory=tile_dir, prefix=fname[:-4], format='png')


In [5]:
################################################
# Load cut images, create segmaps folder, make #
# predictions with trained model, save segmaps #
################################################
# Each tile is segmented on its own. The per-pixel argmax over the model's class
# scores (class ids 0..3) is written as an 8-bit PNG under the tile's own file
# name, so the next cell can stitch the tiles back together.
predictions_dir = WORKING_DIR + 'TEMP_predictions/'
if os.path.isdir(predictions_dir):
    shutil.rmtree(predictions_dir)
os.mkdir(predictions_dir)

model = get_model()
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=[tf.keras.metrics.MeanIoU(num_classes=4)])
model.load_weights(WEIGHTS)

for img_idx, orig_name in enumerate(list_of_origs, start=1):
    print(f'\nSegmenting image {img_idx}/{len(list_of_origs)} ({orig_name}):\n')
    tiles_dir = cut_images_dir + orig_name + '/'
    tile_preds_dir = predictions_dir + orig_name + '/'
    os.mkdir(tile_preds_dir)

    # Read each tile once, directly in grayscale (the model has one input
    # channel). Entries OpenCV cannot decode are skipped.
    tile_names = sorted(os.listdir(tiles_dir))
    n_segmented = 0
    for tile_name in tile_names:
        tile = cv2.imread(tiles_dir + tile_name, cv2.IMREAD_GRAYSCALE)
        if tile is None:
            continue

        # (H, W) uint8 -> (1, H, W, 1) float32 scaled to [0, 1], followed by the
        # same Keras normalize(axis=1) step as before -- presumably the
        # preprocessing the weights were trained with; keep the two in sync.
        x_batch = np.asarray(tile, np.float32)[np.newaxis, :, :, np.newaxis] / 255.
        x_batch = normalize(x_batch, axis=1)

        prediction = model.predict(x_batch)
        # Class ids fit in uint8; write them as an explicit 8-bit image.
        predicted_img = np.argmax(prediction, axis=3)[0].astype(np.uint8)

        cv2.imwrite(tile_preds_dir + tile_name, predicted_img)
        n_segmented += 1
        print(f'\rSegmented {n_segmented}/{len(tile_names)} tiles', end='')


Segmenting 1th image:

Segmented 64/64 tiles
Segmenting 2th image:

Segmented 64/64 tiles
Segmenting 3th image:

Segmented 64/64 tiles
Segmenting 4th image:

Segmented 64/64 tiles
Segmenting 5th image:

Segmented 64/64 tiles
Segmenting 6th image:

Segmented 64/64 tiles
Segmenting 7th image:

Segmented 64/64 tiles
Segmenting 8th image:

Segmented 64/64 tiles
Segmenting 9th image:

Segmented 64/64 tiles
Segmenting 10th image:

Segmented 64/64 tiles

In [6]:
################################################
# Join segmaps into final segmap of orig. dim. #
################################################
# The predicted tiles of each image are pasted back, row-major, onto a
# 4096x4096 canvas (8 x 8 tiles of 512 px). Each tile's grid cell is derived
# from its index in the *sorted* list of tile names. os.listdir() order is
# arbitrary, so without sorting, tiles could land in the wrong cell.
# NOTE(review): assumes image_slicer's zero-padded tile names sort in row-major
# order -- confirm for the installed version.
TILES_PER_ROW = 8
TILE_SIZE = 512

results_dir = WORKING_DIR + 'RESULT/'
if os.path.isdir(results_dir):
    shutil.rmtree(results_dir)
os.mkdir(results_dir)

for orig_name in list_of_origs:
    os.mkdir(results_dir + orig_name + '/')
    tile_preds_dir = predictions_dir + orig_name + '/'

    # The fill colour only remains where no tile gets pasted.
    grid = Image.new('RGBA', size=(4096, 4096), color=(153, 153, 255))
    # load_images() is used to keep only the files OpenCV can read.
    prediction_images_dict = load_images(tile_preds_dir)
    for tile_idx, tile_name in enumerate(sorted(prediction_images_dict)):
        row, col = divmod(tile_idx, TILES_PER_ROW)
        # The context manager closes the file handle PIL keeps open.
        with Image.open(tile_preds_dir + tile_name) as tile:
            grid.paste(tile, box=(col * TILE_SIZE, row * TILE_SIZE))

    # Pasted tiles are converted to RGBA, so channel 0 carries the class ids.
    mask_first_channel = grid.getchannel(0)
    mask_first_channel.save(results_dir + orig_name + "/final_mask.png")

In [7]:
########################################
# Plot final segmap and original image #
########################################
# For every original image this cell:
#   1. shows and saves the stitched segmap next to the original image,
#   2. binarises it (class 2 -> cardiomyocytes of interest, class 3 ->
#      half-collapsed ones),
#   3. applies binary closing and finds connected components >= AREA_THRESH px,
#   4. saves bounding-box overlays, per-box crops and the box coordinates.

AREA_THRESH = 10000        # minimum component area (pixels) that gets a bounding box
MARGIN = 30                # padding (pixels) added on every side of a box of interest
MARGIN_HALF = 50           # same, for half-collapsed cardiomyocytes
CLOSING_ITERATIONS = 7
str_el = np.ones((5, 5), np.int8)  # 5x5 square structuring element for the closing
BOX_COLOR = (255, 70, 0)   # BGR colour used by cv2.rectangle


def close_mask(binary_mask, temp_path):
    """Binary-close (dilate then erode) a 0/1 mask and round-trip it through a PNG.

    Returns (closed, closed_u8). `closed` holds 0/255 values and is used for
    plotting. `closed_u8` is the same mask read back from `temp_path` as 8-bit
    grayscale, which is what cv2.connectedComponentsWithStats needs. The PNG is
    left on disk for the cleanup cell to delete.
    """
    closed = 255 * binary_closing(binary_mask, iterations=CLOSING_ITERATIONS, structure=str_el)
    cv2.imwrite(temp_path, closed)
    return closed, cv2.imread(temp_path, cv2.IMREAD_GRAYSCALE)


def find_bounding_boxes(mask_u8, margin, area_thresh=AREA_THRESH):
    """Bounding boxes of the large foreground components of an 8-bit mask.

    Each box is grown by `margin` pixels on every side and clipped to the mask.
    Returns four parallel lists of Python ints (xs, ys, widths, heights), giving
    the upper-left corner and the size of each box.
    """
    img_h, img_w = mask_u8.shape[:2]
    stats = cv2.connectedComponentsWithStats(mask_u8)[2]
    xs, ys, widths, heights = [], [], [], []
    # Label 0 is OpenCV's background label; skip it explicitly rather than
    # guessing it from the largest area (a large cell could out-size it).
    for x, y, w, h, area in stats[1:].tolist():
        if area < area_thresh:
            continue
        x0, y0 = max(x - margin, 0), max(y - margin, 0)
        x1, y1 = min(x + w + margin, img_w), min(y + h + margin, img_h)
        xs.append(x0)
        ys.append(y0)
        widths.append(x1 - x0)
        heights.append(y1 - y0)
    return xs, ys, widths, heights


def save_boxes_and_crops(image_path, boxes, overlay_path, crops_dir):
    """Draw `boxes` on the image at `image_path` and save the overlay to `overlay_path`.

    Also saves one crop per box, taken from the undrawn image, to `crops_dir`
    as crop_<n>.png. Returns the overlay image (BGR, as read by cv2.imread).
    """
    clean = cv2.imread(image_path)
    overlay = clean.copy()
    for x, y, w, h in zip(*boxes):
        cv2.rectangle(overlay, (x, y), (x + w, y + h), BOX_COLOR, 3)
    cv2.imwrite(overlay_path, overlay)

    os.mkdir(crops_dir)
    for num, (x, y, w, h) in enumerate(zip(*boxes), start=1):
        cv2.imwrite(crops_dir + f'crop_{num}.png', clean[y:y + h, x:x + w])
    return overlay


for orig_name in list_of_origs:
    image_results_dir = results_dir + orig_name + '/'
    orig_path = IMAGES_DIR + orig_name

    # --- original vs. stitched prediction -------------------------------------
    original_img = cv2.imread(orig_path, 0)
    prediction = cv2.imread(image_results_dir + 'final_mask.png', cv2.IMREAD_GRAYSCALE)

    plt.figure(figsize=(32, 32))
    plt.subplot(231)
    plt.title('Original')
    plt.imshow(original_img, cmap='gray')
    plt.subplot(232)
    plt.title('Prediction')
    plt.imshow(prediction, cmap='jet')
    plt.savefig(image_results_dir + 'orig_and_pred.png', bbox_inches='tight')
    plt.show()

    # --- binarise: 1 where the class was predicted, 0 elsewhere (vectorised) ----
    pred_bin = (prediction == 2).astype(np.int8)       # cardiomyocytes of interest
    pred_bin_half = (prediction == 3).astype(np.int8)  # half-collapsed cardiomyocytes

    plt.figure(figsize=(8, 8))
    plt.imshow(pred_bin)
    plt.show()

    # --- closing + connected components ---------------------------------------
    temp_bin_pred_img = results_dir + 'temp_bin_pred_img.png'
    closed_pred_bin, closed_u8 = close_mask(pred_bin, temp_bin_pred_img)

    plt.figure(figsize=(8, 8))
    plt.imshow(closed_pred_bin)
    plt.show()

    boxes = find_bounding_boxes(closed_u8, MARGIN)
    saved_bb_x, saved_bb_y, saved_bb_x_width, saved_bb_y_height = boxes

    # --- bounding-box overlay and crops ---------------------------------------
    im = save_boxes_and_crops(orig_path, boxes, image_results_dir + 'orig_with_BBs.png',
                              image_results_dir + 'crops/')
    plt.figure(figsize=(16, 16))
    plt.imshow(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))  # matplotlib expects RGB
    plt.show()

    if OUTPUT_HALFCOLLAPSED_CARD:
        temp_bin_pred_img_half = results_dir + 'temp_bin_pred_img_half.png'
        _, closed_u8_half = close_mask(pred_bin_half, temp_bin_pred_img_half)
        boxes_half = find_bounding_boxes(closed_u8_half, MARGIN_HALF)
        im_half = save_boxes_and_crops(orig_path, boxes_half,
                                       image_results_dir + 'orig_with_halfcolapsed_BBs.png',
                                       image_results_dir + 'crops_halfcolapsed/')
        # Show the half-collapsed overlay in its own figure.
        plt.figure(figsize=(16, 16))
        plt.imshow(cv2.cvtColor(im_half, cv2.COLOR_BGR2RGB))
        plt.show()

    # Bounding boxes of the cardiomyocytes of interest, written per image so
    # that no image's results overwrite another's.
    with open(image_results_dir + 'Centroids_coords.txt', 'w') as f:
        f.write('Cordinates of found bounding boxes of cardiomyocytes of interest are:\n')
        f.write('\nHorizontal coordinates of upper left corner:')
        f.write(str(saved_bb_x))
        f.write('\nVertical coordinates of upper left corners:')
        f.write(str(saved_bb_y))
        f.write('\nWidths:')
        f.write(str(saved_bb_x_width))
        f.write('\nHeight:')
        f.write(str(saved_bb_y_height))

Output hidden; open in https://colab.research.google.com to view.

In [8]:
# CLEANUP
# Remove the temporary tile/prediction folders and the temporary binary-mask
# PNGs. The PNG paths are rebuilt from results_dir (the same names the
# post-processing cell uses), and only paths that exist are removed. This lets
# the cell succeed when no image was processed, when the half-collapsed option
# was off, or when it is run twice.
for temp_dir in (predictions_dir, cut_images_dir):
    if os.path.isdir(temp_dir):
        shutil.rmtree(temp_dir)

for temp_png in ('temp_bin_pred_img.png', 'temp_bin_pred_img_half.png'):
    temp_path = results_dir + temp_png
    if os.path.isfile(temp_path):
        os.remove(temp_path)