# Drosophila Heart OCM Image Segmentation (FlyNet 3.0) training code
Author: Xiangping Ouyang

Date: March 6th, 2024

Description: This notebook contains the training code used to build the FlyNet 3.0 model. A complete description of the model can be found in the manuscript "An Attention LSTM U-Net model for Drosophila melanogaster heart tube segmentation in optical coherence microscopy images". The data directory paths must be updated before running.

Requirements:

Python 3.9
Libraries: cudatoolkit=11.8.0 cudnn=8.9.2 tensorflow=2.10.1 scikit-image opencv-python

In [None]:
# Import necessary Tensorflow and Keras libraries 
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import ConvLSTM2D, TimeDistributed, Conv2D, Conv3D, BatchNormalization, \
    Activation, MaxPooling2D, Input, LeakyReLU, Conv2DTranspose, Concatenate, multiply, add
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras import backend as K
from tensorflow.keras.losses import binary_crossentropy
import loss_funcs as lf

# Import numpy and data management
import numpy as np
from numpy.random import choice
import pandas as pd
import os
import pickle
from random import shuffle
from random import randint
import skimage.io
import cv2

# Import tensorboard for data visualization (optional)
%load_ext tensorboard

# Set constants for training model
# File-name prefix of the images and masks to read: getImg/getMask look for
# "<type>_img.tiff" and "<type>_mask.tiff" (full: full size images).
# NOTE(review): this name shadows the Python builtin `type`.
type = "full"
# Path to the csv database listing the input sample folders (need to change).
# read_database reads its 'local_dir' and 'weight' columns.
database_path = "F:/Xiangping/flynet/final_tunning/training_01102024.csv"
# Name of your model, used as the prefix of the saved checkpoint files (subject to change)
model_name = "flynet3.0"
# Directory to save training log files and output models (need to change).
# main() joins it to model_name by plain string concatenation, so keep the trailing "/".
log_directory = "F:/Xiangping/flynet/final_tunning/log/"


In [None]:
# Helper functions for reading and processing input files 
def read_database(path):
    """
    Load the sample database and split it randomly into training and validation sets.

    Rows whose 'weight' is not strictly positive are dropped, and the filtered
    table is written to "testdatabase.csv" in the current working directory.

    Parameters:
    path (str): path to the csv file; its 'local_dir' and 'weight' columns are used

    Returns:
    data (np.array): 'local_dir' entry of every kept row
    train_ids (np.array): indices into data that form the training dataset
    val_ids (np.array): indices into data that form the validation dataset
    weights (np.array): weights of the training samples, ordered like train_ids
        and scaled so that they sum to 1
    """
    table = pd.read_csv(path)

    # Keep only the rows with a strictly positive weight
    table = table[table['weight'] > 0]

    # Save a copy of the filtered table (file name subject to change)
    table.to_csv("testdatabase.csv", index=False)

    data = np.array(table['local_dir'].tolist())
    sample_weights = np.array(table['weight'].tolist())

    # Shuffle the sample indices and reorder the weights in exactly the same way
    sample_ids = np.array(range(len(data)))
    order = np.random.permutation(len(sample_ids))
    sample_ids = sample_ids[order]
    sample_weights = sample_weights[order]

    # 95% of the samples go to training when there are at least 70 of them, otherwise 90%
    train_fraction = 0.95 if len(sample_ids) >= 70 else 0.9
    n_train = int(len(sample_ids) * train_fraction)
    train_ids, val_ids = sample_ids[:n_train], sample_ids[n_train:]

    # Turn the weights of the training samples into sampling probabilities
    train_weights = sample_weights[:n_train]
    train_weights = train_weights / np.sum(train_weights)

    return data, train_ids, val_ids, train_weights

def getImg(path):
    """
    Find the full-size image file inside a sample folder.

    The first directory entry whose name contains "<type>_img.tiff" is used,
    where `type` is the module-level file-type constant.

    Parameters:
    path (str): path to the input file folder directory

    Returns:
    file_path (str or None): path to the image file (with forward slashes), or
        None when the folder holds no matching file; in that case the folder
        path and a message are printed
    """
    # Get a list of files in the directory
    files = os.listdir(path)
    # The module-level `type` is only read here, so no `global` statement is needed
    search_string = type + "_img.tiff"

    for file_name in files:
        if search_string in file_name:
            return os.path.join(path, file_name).replace("\\", "/")

    # Report which folder is incomplete; this function looks for the full-size
    # image, not the resized one, so the message says so
    print(path)
    print("No image file found")
    return None

def getMask(path):
    """
    Find the full-size mask file inside a sample folder.

    The first directory entry whose name contains "<type>_mask.tiff" is used,
    where `type` is the module-level file-type constant.

    Parameters:
    path (str): path to the input file folder directory

    Returns:
    file_path (str or None): path to the mask file (with forward slashes), or
        None (after printing a message) when no entry matches
    """
    entries = os.listdir(path)
    search_string = type + "_mask.tiff"

    # Lazily scan the entries and stop at the first match
    first_match = next((name for name in entries if search_string in name), None)
    if first_match is None:
        print("No mask file found")
        return None
    return os.path.join(path, first_match).replace("\\", "/")

def getResizeImg(path):
    """
    Find the resized image file inside a sample folder.

    Parameters:
    path (str): path to the input file folder directory

    Returns:
    file_path (str or None): path (with forward slashes) of the first entry
        whose name contains "resize_img.tiff", or None (after printing a
        message) when no entry matches
    """
    marker = "resize_img.tiff"
    # Collect every entry carrying the marker, in directory-listing order
    candidates = [entry for entry in os.listdir(path) if marker in entry]
    if not candidates:
        print("No resize file found")
        return None
    return os.path.join(path, candidates[0]).replace("\\", "/")

def getResizeMask(path):
    """
    Find the resized mask file inside a sample folder.

    Parameters:
    path (str): path to the input file folder directory

    Returns:
    file_path (str or None): path (with forward slashes) of the first entry
        whose name contains "resize_mask.tiff", or None (after printing a
        message) when no entry matches
    """
    target = "resize_mask.tiff"
    for entry in os.listdir(path):
        # Skip everything that is not the resized mask
        if target not in entry:
            continue
        return os.path.join(path, entry).replace("\\", "/")
    print("No mask file found")

def centerVideo(video):
    """
    Standardize a video with the mean and standard deviation of all its pixels.

    The statistics are computed over the whole array (all frames together),
    not per frame.

    Parameters:
    video (np.array): array of video frames

    Returns:
    centered_video (np.array): (video - mean) / std; when the video is constant
        (std is 0) only the mean is subtracted, so the result is all zeros
        instead of NaN/inf
    """
    mean = np.mean(video)
    std = np.std(video)
    centered_video = video - mean
    # A constant video has zero spread; dividing would fill the clip with NaN
    if std > 0:
        centered_video = centered_video / std
    return centered_video

In [17]:
# Generate training data
def generateData(data, available_ids, batch_size, ws):
    """
    Endlessly yield randomly cropped, resized and augmented training batches.

    For every batch three distinct samples are drawn. From each one a random
    window of consecutive frames is taken, cropped to a random box that
    contains the whole mask, resized to 128 x 128 and randomly flipped.

    Parameters:
    data (np.array): array of paths to individual samples
    available_ids (np.array): array of available indices into data
    batch_size (int): number of consecutive frames taken from each drawn sample
    ws (np.array): sampling probability of each entry of available_ids

    Yield:
    outputX (np.array): array of normalized training clips (float32)
    outputY (np.array): array of binary training masks (float32)
    """
    augment = True
    # Number of samples drawn for every yielded batch
    # INCREASE NUMBER OF SAMPLES HERE IF GPU MEMORY ALLOWS
    samples_per_step = 3
    while True:
        # Draw distinct sample IDs according to their sampling probabilities
        s = choice(available_ids, size=samples_per_step, replace=False, p=ws)
        outputX = []
        outputY = []
        for sample_id in s:
            # Read the image at that ID and convert it to a numpy array
            dir_path = data[sample_id]
            mask_path = getMask(dir_path)
            img_path = getImg(dir_path)
            img = np.squeeze(np.array(skimage.io.imread(img_path)))

            # Read the mask file and binarize it; a mask with more than
            # three dimensions keeps only its first channel
            img_mask = np.array(skimage.io.imread(mask_path))
            if img_mask.ndim > 3:
                img_mask = (img_mask[:, :, :, 0] > 0.5) * 1.0
            else:
                img_mask = (img_mask > 0.5) * 1.0

            # Add a singleton dimension so that all images have a color channel
            train = img[..., np.newaxis]
            y = img_mask[..., np.newaxis]

            # Keep a random window of batch_size consecutive frames
            last_start = train.shape[0] - batch_size
            start_loc = randint(0, last_start)
            end_loc = start_loc + batch_size
            train = train[start_loc:end_loc]
            y = y[start_loc:end_loc]

            # The shape of y is [frame, height, width, 1]
            # y is a mask with value 0 or 1; find its bounding box over the window
            mask_coords = np.where(y == 1)
            max_x = np.max(mask_coords[2])
            min_x = np.min(mask_coords[2])
            max_y = np.max(mask_coords[1])
            min_y = np.min(mask_coords[1])

            # Pad the bounding box by 12 pixels while staying inside the frame.
            # NOTE(review): the limits below are hard-coded; they look like they
            # assume frames 600 pixels high and 128 pixels wide - confirm.
            min_x = 3 if min_x <= 15 else min_x - 12
            min_y = 3 if min_y <= 15 else min_y - 12
            max_x = 124 if max_x >= 112 else max_x + 12
            max_y = 596 if max_y >= 584 else max_y + 12

            # Allow up to 30 extra rows above and below the padded box
            low = min_y - 30 if min_y - 30 > 1 else 1
            high = max_y + 30 if max_y + 30 < y.shape[1] - 1 else y.shape[1] - 1

            # Now select the boundaries for the crop
            # Make sure the whole mask is in the frame
            crop_x_min = randint(2, min_x)
            crop_x_max = randint(max_x, 125)
            crop_y_min = randint(low, min_y)
            crop_y_max = randint(max_y, high)

            # Crop the image and mask
            train_roi = train[:, crop_y_min:crop_y_max, crop_x_min:crop_x_max]
            y_roi = y[:, crop_y_min:crop_y_max, crop_x_min:crop_x_max]

            # Resize every cropped frame to 128 x 128. The image is interpolated
            # bicubically; the mask uses nearest-neighbour so that the labels
            # stay exactly 0 or 1 (bicubic interpolation produces fractional
            # and out-of-range values on a binary mask).
            train = np.array([
                cv2.resize(frame, (128, 128), interpolation=cv2.INTER_CUBIC)
                for frame in train_roi
            ])
            y = np.array([
                cv2.resize(frame, (128, 128), interpolation=cv2.INTER_NEAREST)
                for frame in y_roi
            ])

            if augment:
                # Apply a random flip: 0 keeps the clip, 1 flips axis 1 and
                # 2 flips axis 2, each with equal chance.
                # random.randint is inclusive on both ends, so the value 3
                # (flip along axis 0, the frame axis) that the original code
                # also handled could never be drawn; it is left out here.
                aug = randint(0, 2)
                if aug == 1:
                    train = np.flip(train, 1)
                    y = np.flip(y, 1)
                elif aug == 2:
                    train = np.flip(train, 2)
                    y = np.flip(y, 2)

            # Cast to float32 before yield
            train = train.astype('float32')
            y = y.astype('float32')

            # Normalize the image
            train = centerVideo(train)

            outputX.append(train)
            outputY.append(y)

        outputX = np.array(outputX)
        outputY = np.array(outputY)

        yield (outputX, outputY)

In [19]:
# Generate validation data 
def readFiles(data, ids):
    """
    Load the validation samples and cut them into 32-frame clips.

    The resized image and mask of every sample are read. The video is walked
    in steps of 64 frames and the first 32 frames of each step form one clip,
    so floor(number of frames / 64) clips are produced per sample.

    Parameters:
    data (np.array): array of paths to individual samples
    ids (np.array): array of available indices

    Returns:
    image (np.array): array of normalized validation clips (float32)
    y (np.array): array of validation masks (float32)
    """
    clips = []
    clip_masks = []
    for sample_id in ids:
        folder = data[sample_id]
        mask_path = getResizeMask(folder)
        img_path = getResizeImg(folder)
        frames = np.squeeze(np.array(skimage.io.imread(img_path)))

        # Read the mask file and binarize it; a mask with more than three
        # dimensions keeps only its first channel
        raw_mask = np.array(skimage.io.imread(mask_path))
        if len(raw_mask.shape) > 3:
            raw_mask = raw_mask[:, :, :, 0]
        binary_mask = (raw_mask > 0.5) * 1.0

        # Add a singleton channel dimension and trim the mask to the video length
        n_frames = len(frames)
        frames = frames[:, :, :, np.newaxis]
        binary_mask = binary_mask[:n_frames, :, :, np.newaxis]

        # One 32-frame clip at the start of every complete 64-frame stretch
        for start in range(0, 64 * (n_frames // 64), 64):
            stop = start + 32
            clips.append(centerVideo(frames[start:stop, ...].astype('float32')))
            clip_masks.append(binary_mask[start:stop, ...].astype('float32'))

    # Convert the lists into numpy arrays combining the first dimension
    image = np.array(clips)
    y = np.array(clip_masks)

    return image, y

In [None]:
# Define the different metrics for measuring the performance of the model
def dice_coeff(y_true, y_pred):
    """
    Calculate the smoothed dice coefficient between the true and predicted masks

    Both inputs are flattened, so a single score is computed over everything
    that is passed in.

    Parameters:
    y_true (tensor): true masks (cast to float32)
    y_pred (tensor): predicted masks

    Returns:
    score (tensor): scalar dice coefficient, (2 * intersection + 1) / (sum(y_true) + sum(y_pred) + 1)
    """
    smooth = 1.
    truth = tf.cast(K.flatten(y_true), tf.float32)
    prediction = K.flatten(y_pred)
    overlap = K.sum(truth * prediction)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth) + K.sum(prediction) + smooth
    return numerator / denominator

def dice_loss(y_true, y_pred):
    """
    Calculate the dice loss, i.e. one minus the dice coefficient

    Parameters:
    y_true (tensor): true masks
    y_pred (tensor): predicted masks

    Returns:
    loss (tensor): scalar dice loss
    """
    return 1 - dice_coeff(y_true, y_pred)

def bce_dice_loss(y_true, y_pred):
    """
    Calculate the sum of the binary cross entropy and the dice loss

    Parameters:
    y_true (tensor): true masks
    y_pred (tensor): predicted masks

    Returns:
    loss (tensor): binary cross entropy plus dice loss
    """
    cross_entropy = binary_crossentropy(y_true, y_pred)
    dice_term = dice_loss(y_true, y_pred)
    return cross_entropy + dice_term

In [None]:
# Define the model structure for each block; a more detailed explanation can
# be found in the "2.3 General network structure" section of the manuscript
def attention_block(input, num_filters, skip_features):
    """
    Gate the skip-connection features with an attention map.

    The gating input and the skip features are each projected to
    num_filters // 2 channels, summed and passed through LeakyReLU. The result
    is reduced to a single-channel sigmoid map that multiplies the skip features.

    Parameters:
    input (tensor): gating signal (the upsampled decoder features)
    num_filters (int): filter count of the calling block; half of it is used
        for the two projections
    skip_features (tensor): encoder features to be gated

    Returns:
    out (tensor): skip_features multiplied by the attention map
    """
    # Integer division keeps the filter count an int instead of a float
    inter_filters = num_filters // 2

    g1 = Conv2D(inter_filters, (2, 2), padding="same")(input)
    g1 = BatchNormalization()(g1)

    x1 = Conv2D(inter_filters, (1, 1), padding="same")(skip_features)
    x1 = BatchNormalization()(x1)

    # Combine both projections and squash them into one attention coefficient per pixel
    psi = LeakyReLU(alpha=0.2)(add([g1, x1]))
    psi = Conv2D(1, 1, padding="same")(psi)
    psi = BatchNormalization()(psi)
    psi = Activation('sigmoid')(psi)

    out = multiply([skip_features, psi])
    return out

def conv_block(input, num_filters):
    """
    Apply a time-distributed 5x5 convolution, batch normalization and LeakyReLU.

    Parameters:
    input (tensor): sequence of feature maps
    num_filters (int): number of convolution filters

    Returns:
    x (tensor): activated feature maps
    """
    features = TimeDistributed(Conv2D(num_filters, 5, padding="same"))(input)
    normalized = BatchNormalization()(features)
    return LeakyReLU(alpha=0.2)(normalized)

def init_block(input, num_filters):
    """
    Apply a 5x5 ConvLSTM2D that returns the full sequence, then batch normalization.

    Parameters:
    input (tensor): sequence of feature maps
    num_filters (int): number of ConvLSTM filters

    Returns:
    x (tensor): normalized recurrent feature maps, one per time step
    """
    recurrent = ConvLSTM2D(num_filters, 5, padding="same", return_sequences=True)(input)
    return BatchNormalization()(recurrent)

def first_block(input, num_filters):
    """
    Chain a recurrent init_block and a conv_block with the same filter count.

    Parameters:
    input (tensor): sequence of feature maps
    num_filters (int): number of filters used by both sub-blocks

    Returns:
    x (tensor): output of the conv_block
    """
    recurrent = init_block(input, num_filters)
    return conv_block(recurrent, num_filters)

def encoder_block(input, num_filters):
    """
    Chain two conv_blocks with the same filter count.

    Parameters:
    input (tensor): sequence of feature maps
    num_filters (int): number of filters used by both conv_blocks

    Returns:
    x (tensor): output of the second conv_block
    """
    first_pass = conv_block(input, num_filters)
    return conv_block(first_pass, num_filters)

def decoder_block(input, num_filters, skip_features):
    """
    Upsample by 2, merge with the attention-gated skip features and convolve.

    Parameters:
    input (tensor): feature maps from the deeper level
    num_filters (int): number of filters of this level
    skip_features (tensor): encoder features of the same level

    Returns:
    x (tensor): output of the encoder_block applied to the merged features
    """
    # Time-distributed transposed convolution with stride 2 doubles height and width
    upsampled = TimeDistributed(
        Conv2DTranspose(num_filters, (2, 2), strides=(2,2), padding="same")
    )(input)
    gated_skip = attention_block(upsampled, num_filters, skip_features)
    merged = Concatenate()([upsampled, gated_skip])
    return encoder_block(merged, num_filters)

def last_block(input, num_filters, skip_features):
    """
    Final decoder level: like decoder_block, but the merged features go
    through first_block (ConvLSTM followed by a conv_block) instead of encoder_block.

    Parameters:
    input (tensor): feature maps from the deeper level
    num_filters (int): number of filters of this level
    skip_features (tensor): encoder features of the same level

    Returns:
    x (tensor): output of the first_block applied to the merged features
    """
    # Time-distributed transposed convolution with stride 2 doubles height and width
    upsampled = TimeDistributed(
        Conv2DTranspose(num_filters, (2, 2), strides=(2,2), padding="same")
    )(input)
    gated_skip = attention_block(upsampled, num_filters, skip_features)
    merged = Concatenate()([upsampled, gated_skip])
    return first_block(merged, num_filters)

def create_model(input_shape=(None, 128, 128, 1)):
    """
    Build the attention LSTM U-Net.

    The encoder has three levels (32, 64 and 128 filters), each followed by a
    time-distributed 2x2 max pooling, and a 256-filter bottleneck. The decoder
    mirrors it with attention-gated skip connections. A 1x1x1 Conv3D with a
    sigmoid activation produces the single-channel output.

    Parameters:
    input_shape (tuple): shape of one input sample without the batch
        dimension, given as (frames, height, width, channels)

    Returns:
    model (Model): the uncompiled Keras model
    """
    inputs = Input(shape=input_shape)

    # Encoder: the first level is recurrent, the deeper ones are convolutional
    skip1 = first_block(inputs, 32)
    down1 = TimeDistributed(MaxPooling2D((2, 2)))(skip1)

    skip2 = encoder_block(down1, 64)
    down2 = TimeDistributed(MaxPooling2D((2, 2)))(skip2)

    skip3 = encoder_block(down2, 128)
    down3 = TimeDistributed(MaxPooling2D((2, 2)))(skip3)

    # Bottleneck
    bottleneck = encoder_block(down3, 256)

    # Decoder: every level is merged with the matching encoder output
    up3 = decoder_block(bottleneck, 128, skip3)
    up2 = decoder_block(up3, 64, skip2)
    up1 = last_block(up2, 32, skip1)

    # Per-pixel probability for every frame
    classify = Conv3D(1, (1, 1, 1), padding="same", activation='sigmoid')(up1)

    return Model(inputs=inputs, outputs=classify)

In [None]:
# Define the main function that will train the model
def main():
    """
    Train the model on the samples listed in the csv database.

    The database is split into training and validation sets, a new model is
    created and compiled with the log-cosh dice loss and the Adam optimizer,
    and it is trained for 150 epochs. A checkpoint is saved after every epoch
    as log_directory + model_name + "<epoch>.h5".

    Returns:
    data (np.array): array of paths to all the samples
    val_ids (np.array): array of indices of the validation samples
    """
    # read the input files
    data, train_ids, val_ids, ws = read_database(database_path)
    print(f"Length of training set: {len(train_ids)}, Length of validation set: {len(val_ids)}")

    # prepare training
    K.set_image_data_format('channels_last')
    batch_size = 32
    # at least one step per epoch, even for a very small training set
    steps_per_epoch = max(1, len(train_ids)*10//batch_size)
    val, y_val = readFiles(data, val_ids)

    # create a new model for training or load pre-trained model
    model = create_model()
    # model = load_model("pre_trained_model.h5", custom_objects = {'dice_coeff': dice_coeff, 'bce_dice_loss': bce_dice_loss, "focal_tversky": lf.Semantic_loss_functions().focal_tversky, "log_cosh_dice_loss": lf.Semantic_loss_functions().log_cosh_dice_loss})

    # set model parameters
    lr = 1e-4
    loss_function = lf.Semantic_loss_functions().log_cosh_dice_loss
    model.compile(
        loss=loss_function,
        optimizer=Adam(lr),
        metrics=[
            # The model outputs sigmoid probabilities. MeanIoU expects class
            # indices and would truncate every probability below 1 to class 0,
            # so the predictions are thresholded at 0.5 first. The IoU is still
            # averaged over both classes and keeps its previous log name.
            tf.keras.metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.5, name="mean_io_u"),
            tf.keras.metrics.Recall(),
            tf.keras.metrics.Precision(),
            dice_coeff
        ]
    )

    # start training
    model_checkpoint = ModelCheckpoint(log_directory + model_name + "{epoch:02d}.h5", monitor='val_loss', save_best_only=False)
    callbacks = [model_checkpoint]
    # set this to True to use tensorboard (optional)
    use_tensorboard = False
    if use_tensorboard:
        callbacks.append(tf.keras.callbacks.TensorBoard(log_dir=log_directory, histogram_freq=1))
    model.fit(generateData(data, train_ids, batch_size, ws), steps_per_epoch=steps_per_epoch, epochs=150, verbose=1, validation_data=(val, y_val), callbacks=callbacks)

    return data, val_ids

# Run the training and keep the returned sample paths and validation indices
data, val = main()