# Alexander Xu (2024)

In [None]:
# Import libraries

import os
import shutil
import math
import random
import cv2

import numpy as np
import pandas as pd

from tqdm import tqdm
import matplotlib.pyplot as plt

from skimage import io, color, exposure, filters, morphology
from skimage.feature import peak_local_max
from skimage.segmentation import watershed

from scipy import ndimage as ndi
from scipy.ndimage import label, generate_binary_structure

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.optimizers import SGD

!pip install albumentations
import albumentations as A

random.seed(7)

**Utility functions**

In [None]:
# Augment all images in a folder
def augment_folder(input_image_folder,input_mask_folder,output_image_folder,output_mask_folder):
    """Write an unmodified copy plus AUG_REPS augmented copies of every image/mask pair.

    Images are listed from ``input_image_folder``; each mask is expected in
    ``input_mask_folder`` under the same filename. Outputs are written as
    ``aug_{filename}`` (unmodified copy) and ``aug{i}_{filename}`` for
    i in range(AUG_REPS), to the image and mask output folders respectively.
    The same random transform draw is applied to an image and its mask.
    Pairs that cv2.imread cannot read are skipped with a message.
    """
    # os.mkdir() has no exist_ok parameter (it raises TypeError);
    # os.makedirs() supports it and also creates missing parent folders.
    os.makedirs(output_image_folder, exist_ok=True)
    os.makedirs(output_mask_folder, exist_ok=True)

    # Augmentation settings via Albumentations library
    transform = A.Compose([
            A.CLAHE(p=0.2,clip_limit=6),
            A.Blur(p=0.2,blur_limit=10),
            A.ColorJitter(brightness=(0.5, 1), contrast=(0.5, 1), saturation=(0.5, 1), hue=(-0.5, 0.5), p=0.2),
            A.Sharpen(alpha=(0.2, 1.0), lightness=(0.2, 1.0), p=0.2),
            A.RGBShift(r_shift_limit=(-20, 20), g_shift_limit=(-20, 20), b_shift_limit=(-20, 20), p=0.2),
            A.Defocus(radius=(1, 20),alias_blur=(0.1, 1.0),p=0.2),
            A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=20, val_shift_limit=20, p=0.2),
            A.VerticalFlip(p=0.2),
            A.RandomRotate90(p=0.2),
            A.RandomGridShuffle(grid=(3, 3), p=0.2),
            A.Flip(p=0.2),
            A.HorizontalFlip(p=0.2),
            A.ToGray(p=0.2),
            A.ChannelShuffle(p=0.2),
            A.ChannelDropout(channel_drop_range=(1, 1), fill_value=0, p=0.2),
            A.MultiplicativeNoise(multiplier=[0.5, 1.5], elementwise=True, p=0.2),
            A.OneOf([A.OpticalDistortion(p=0.2),A.GridDistortion(p=0.2)], p=0.2),
            A.PixelDropout(dropout_prob=0.02,drop_value=0,p=0.2)])

    for filename in tqdm(os.listdir(input_image_folder)):
        image = cv2.imread(os.path.join(input_image_folder, filename))
        mask = cv2.imread(os.path.join(input_mask_folder, filename))
        if image is None or mask is None:
            # cv2.imread returns None for missing/undecodable files; skip the
            # pair instead of failing inside the transform.
            tqdm.write(f'Skipping {filename}: image or mask could not be read')
            continue

        # copy original images
        cv2.imwrite(os.path.join(output_image_folder,f'aug_{filename}'),image)
        cv2.imwrite(os.path.join(output_mask_folder,f'aug_{filename}'),mask)

        for i in range(AUG_REPS):
            augmented = transform(image=image, mask=mask)
            cv2.imwrite(os.path.join(output_image_folder,f'aug{i}_{filename}'),augmented['image'])
            cv2.imwrite(os.path.join(output_mask_folder,f'aug{i}_{filename}'),augmented['mask'])

# Tile all images in a folder
def tile_folder(image_folder,output_folder):
    """Cut every readable image in ``image_folder`` into TILE_SIZE x TILE_SIZE tiles.

    Each tile is written to ``output_folder`` as ``{row}_{col}_{filename}``.
    Pixels beyond the last full tile in each direction are discarded.
    Files that cv2.imread cannot decode are skipped with a message.
    """
    # os.mkdir() has no exist_ok parameter (it raises TypeError);
    # os.makedirs() supports it and also creates missing parent folders.
    os.makedirs(output_folder, exist_ok=True)

    for filename in tqdm(os.listdir(image_folder)):
        image = cv2.imread(os.path.join(image_folder, filename))
        if image is None:
            tqdm.write(f'Skipping {filename}: could not be read as an image')
            continue

        # Floor division keeps only whole tiles, so slice bounds never exceed
        # the image and no clamping is needed.
        n_rows = image.shape[0] // TILE_SIZE
        n_cols = image.shape[1] // TILE_SIZE
        for i in range(n_rows):
            for j in range(n_cols):
                tiled_img = image[TILE_SIZE * i:TILE_SIZE * (i + 1), TILE_SIZE * j:TILE_SIZE * (j + 1)]
                cv2.imwrite(os.path.join(output_folder, f'{i}_{j}_{filename}'), tiled_img)

# Load all training images and masks
def load_images_and_masks(image_folder,mask_folder):
    """Read each image with its same-named mask and split the pairs by mask content.

    A pair is "negative" when every mask pixel equals 0 and "positive"
    otherwise.

    Returns:
        (pos_images, pos_masks, neg_images, neg_masks): four lists holding
        whatever cv2.imread returned, each in os.listdir order.
    """
    positive = ([], [])
    negative = ([], [])

    for name in os.listdir(image_folder):
        img = cv2.imread(os.path.join(image_folder, name))
        msk = cv2.imread(os.path.join(mask_folder, name))

        bucket_images, bucket_masks = negative if np.all(msk == 0) else positive
        bucket_images.append(img)
        bucket_masks.append(msk)

    return positive[0], positive[1], negative[0], negative[1]

**Data Preprocessing**

In [None]:
# Training-set generation parameters.
# TILE_SIZE: side length (pixels) of the square training/inference tiles; it must
#   be divisible by 2**4 because the U-Net halves the resolution four times.
# AUG_REPS: augmented copies written per tile, on top of one unmodified copy.
TILE_SIZE, AUG_REPS = 64, 5

In [None]:
# Tile raw images/masks, augment the tiles, then build the balanced X/Y arrays.
tile_folder('/data/images','/tiled_images')
tile_folder('/data/masks','/tiled_masks')
augment_folder('/tiled_images','/tiled_masks','/aug_images','/aug_masks')

pos_images, pos_masks, neg_images, neg_masks = load_images_and_masks('/aug_images','/aug_masks')

# Balance classes: as many 'negatives' (tiles without droplets) as 'positives'
# (tiles with droplets). Capped at the number of negatives available, because
# random.sample raises ValueError when asked for more than the population.
n_matched = min(len(neg_images), len(pos_images))
matched_indices = random.sample(range(len(neg_images)), n_matched)
lipid_images = pos_images + [neg_images[i] for i in matched_indices]
lipid_masks = pos_masks + [neg_masks[i] for i in matched_indices]

# Shuffle image/mask pairs together. Keras' validation_split takes the LAST
# fraction of samples *before* shuffling, so with positives followed by
# negatives the validation set would consist almost entirely of negatives.
order = list(range(len(lipid_images)))
random.shuffle(order)
lipid_images = [lipid_images[i] for i in order]
lipid_masks = [lipid_masks[i] for i in order]

# both train and val are augmented due to dataset size
X = np.zeros((len(lipid_images), TILE_SIZE, TILE_SIZE, 3))
Y = np.zeros((len(lipid_masks), TILE_SIZE, TILE_SIZE, 1))

for idx,image in enumerate(lipid_images):
    X[idx]=image

# Only the first channel of each (3-channel, as read by cv2.imread) mask is kept.
# NOTE(review): binary_crossentropy expects targets in [0, 1]; if the mask files
# encode foreground as 255, divide Y by 255 here — confirm the mask encoding.
for idx,mask in enumerate(lipid_masks):
    Y[idx]=mask[:, :, :1]

**U-Net Model training**

In [None]:
def downsample_block(n_filters,prev_layer):
    """One U-Net encoder stage: two 3x3 ReLU convolutions (dropout 0.1 between) then 2x2 max pooling.

    Returns:
        (features, pooled): the pre-pooling feature map, which the caller uses
        as the skip connection, and its 2x2 max-pooled version.
    """
    conv_kwargs = dict(activation='relu', kernel_initializer='he_normal', padding='same')

    features = tf.keras.layers.Conv2D(n_filters, (3, 3), **conv_kwargs)(prev_layer)
    features = tf.keras.layers.Dropout(0.1)(features)
    features = tf.keras.layers.Conv2D(n_filters, (3, 3), **conv_kwargs)(features)

    pooled = tf.keras.layers.MaxPooling2D((2, 2))(features)
    return features, pooled

def upsample_block(n_filters,prev_layer,skip_layer):
    """One U-Net decoder stage.

    The stage applies a 2x2 stride-2 transposed convolution, concatenates the
    result with ``skip_layer``, then applies two 3x3 ReLU convolutions with
    dropout 0.1 between them.

    Returns:
        (features, merged): the stage output and the concatenated tensor.
    """
    conv_kwargs = dict(activation='relu', kernel_initializer='he_normal', padding='same')

    upsampled = tf.keras.layers.Conv2DTranspose(n_filters, (2, 2), strides=(2, 2), padding='same')(prev_layer)
    merged = tf.keras.layers.concatenate([upsampled, skip_layer])

    features = tf.keras.layers.Conv2D(n_filters, (3, 3), **conv_kwargs)(merged)
    features = tf.keras.layers.Dropout(0.1)(features)
    features = tf.keras.layers.Conv2D(n_filters, (3, 3), **conv_kwargs)(features)
    return features, merged

# Creates U-Net model
def build_unet_model():
    """Build the symbolic tensors of a 4-level U-Net.

    Layers are created in a fixed order, because the auto-generated layer
    names depend on it: encoder stages 16/32/64/128, a 256-filter bottleneck,
    then decoder stages 128/64/32/16.

    Returns:
        (inputs, outputs): the Input of shape (TILE_SIZE, TILE_SIZE, 3) and the
        1-channel sigmoid map at the same spatial size, for tf.keras.Model.
    """
    encoder_filters = (16, 32, 64, 128)
    conv_kwargs = dict(activation='relu', kernel_initializer='he_normal', padding='same')

    inputs = tf.keras.layers.Input((TILE_SIZE, TILE_SIZE, 3))
    # Scale raw 0-255 pixel values to [0, 1] inside the graph.
    x = tf.keras.layers.Lambda(lambda t: t / 255)(inputs)

    # Contraction path: remember each pre-pooling map for the skip connections.
    skips = []
    for n_filters in encoder_filters:
        skip, x = downsample_block(n_filters, x)
        skips.append(skip)

    # Bottleneck
    x = tf.keras.layers.Conv2D(256, (3, 3), **conv_kwargs)(x)
    x = tf.keras.layers.Dropout(0.2)(x)
    x = tf.keras.layers.Conv2D(256, (3, 3), **conv_kwargs)(x)

    # Expansion path: mirror the encoder, consuming skips deepest-first.
    for n_filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x, _ = upsample_block(n_filters, x, skip)

    # Output layer for density map
    outputs = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid')(x)
    return inputs, outputs

# Custom loss function
def FocalTverskyLoss(targets, inputs, alpha=None, beta=None, gamma=None, smooth=1e-6):
    """Focal Tversky loss computed over every pixel of the batch.

    T = (TP + smooth) / (TP + alpha*FP + beta*FN + smooth), where TP/FP/FN are
    soft counts summed over all elements. The loss is (1 - T) ** gamma.

    Args:
        targets: ground-truth mask tensor (Keras ``y_true``).
        inputs: predicted probabilities (Keras ``y_pred``); assumed to have the
            same number of elements as ``targets``.
        alpha: weight of false positives; None means the notebook's ALPHA.
        beta: weight of false negatives; None means the notebook's BETA.
        gamma: focal exponent; None means the notebook's GAMMA.
        smooth: additive smoothing that keeps the ratio finite for empty masks.

    Returns:
        Scalar loss tensor.
    """
    # Resolve defaults at call time. ALPHA/BETA/GAMMA are defined in a later
    # cell, so using them as def-time defaults raised NameError.
    alpha = ALPHA if alpha is None else alpha
    beta = BETA if beta is None else beta
    gamma = GAMMA if gamma is None else gamma

    # Flatten everything; summing over the whole batch is the same as the
    # previous per-sample Flatten followed by a global sum.
    # (The old code also used an undefined `K` backend alias.)
    inputs = tf.reshape(tf.cast(inputs, tf.float32), [-1])
    targets = tf.reshape(tf.cast(targets, tf.float32), [-1])

    tp = tf.reduce_sum(inputs * targets)
    fp = tf.reduce_sum((1 - targets) * inputs)
    fn = tf.reduce_sum(targets * (1 - inputs))

    tversky = (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)
    return tf.pow(1 - tversky, tf.cast(gamma, tversky.dtype))

In [None]:
# Focal Tversky loss hyper-parameters:
# ALPHA weights false positives, BETA weights false negatives, GAMMA is the focal exponent.
ALPHA, BETA, GAMMA = 0.5, 0.5, 1

# U-Net training hyper-parameters
EPOCHS, BATCH_SIZE = 40, 32
VAL_SPLIT = 0.2        # fraction of samples Keras holds out for validation
LEARNING_RATE = 5e-4   # Adam learning rate
PATIENCE = 3           # epochs without val_loss improvement before early stopping

In [None]:
# Model initialization
inputs, outputs=build_unet_model()
model = tf.keras.Model(inputs=[inputs], outputs=[outputs])
model.compile(optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),loss='binary_crossentropy',metrics=["accuracy"])

# Optional checkpointing (the path previously contained a stray leading space).
# checkpointer = tf.keras.callbacks.ModelCheckpoint('/lipid_model.keras', verbose=1, save_best_only=True)
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='logs'),
    # restore_best_weights: when early stopping triggers, roll back to the
    # epoch with the best val_loss. Otherwise the model keeps the weights
    # from PATIENCE epochs later.
    tf.keras.callbacks.EarlyStopping(patience=PATIENCE, monitor='val_loss', restore_best_weights=True),
]

# Model train (validation_split holds out the last VAL_SPLIT fraction of X/Y)
history = model.fit(X, Y, validation_split=VAL_SPLIT, batch_size=BATCH_SIZE, epochs=EPOCHS, callbacks=callbacks)

In [None]:
# Train/val curves: one figure for accuracy, then one for loss.
for metric_key, label_text in (('accuracy', 'Accuracy'), ('loss', 'Loss')):
    plt.plot(history.history[metric_key])
    plt.plot(history.history[f'val_{metric_key}'])
    plt.title(label_text)
    plt.ylabel(label_text)
    plt.xlabel('Epoch')
    plt.legend(['train', 'val'], loc='upper left')
    plt.show()

**Density map postprocessing**

In [None]:
def normalize_image(image):
    """Return ``image`` as a float32 tensor divided by 255."""
    as_float = tf.cast(image, tf.float32)
    return tf.math.divide(as_float, 255.0)

def to_binary(image):
    """Threshold a prediction map: values above 0.2 become 1, all others 0.

    Uses cv2.THRESH_BINARY; the result is returned as float32.
    """
    thresholded = cv2.threshold(image, 0.2, 1, cv2.THRESH_BINARY)[1]
    return thresholded.astype(np.float32)

# Tile input image for inference
def tile_image(image, tile_size=None):
    """Split ``image`` into non-overlapping square tiles in row-major order.

    Only whole tiles are produced. Rows/columns beyond the last full multiple
    of the tile size are dropped, matching how training tiles were cut.

    Args:
        image: array-like indexed as [row, col, ...] (numpy array or tensor).
        tile_size: tile side length in pixels; None means the notebook's
            TILE_SIZE.

    Returns:
        List of tiles (slices of ``image``), left-to-right then top-to-bottom.
    """
    size = TILE_SIZE if tile_size is None else tile_size
    n_rows = image.shape[0] // size
    n_cols = image.shape[1] // size
    # Floor division guarantees every slice lies inside the image, so no
    # clamping of the upper bounds is needed.
    return [
        image[r * size:(r + 1) * size, c * size:(c + 1) * size]
        for r in range(n_rows)
        for c in range(n_cols)
    ]

# Recombine tiles
def stitch_image(tiles, image, tile_size=None):
    """Reassemble row-major ``tiles`` into one array laid out like ``image``.

    The inverse of cutting ``image`` into whole tiles. The grid has
    image.shape[0] // tile_size rows and image.shape[1] // tile_size columns,
    so any remainder strip of ``image`` is absent from the result.

    Args:
        tiles: sequence of equally-shaped tiles, left-to-right then
            top-to-bottom.
        image: the original image; only its shape is used.
        tile_size: tile side length; None means the notebook's TILE_SIZE.

    Returns:
        numpy array of the stitched tiles.

    Raises:
        ValueError: from np.concatenate when the grid is empty.
    """
    size = TILE_SIZE if tile_size is None else tile_size
    n_rows = image.shape[0] // size
    n_cols = image.shape[1] // size
    rows = [
        np.concatenate(tiles[r * n_cols:(r + 1) * n_cols], axis=1)
        for r in range(n_rows)
    ]
    return np.concatenate(rows, axis=0)

# Recombine centroid predictions
def stitch_centroids(centroids, image, tile_size=None):
    """Convert per-tile centroids into whole-image (row, col) coordinates.

    Args:
        centroids: one sequence of (row, col) points per tile, with tiles in
            row-major order over a grid that is image.shape[1] // tile_size
            tiles wide.
        image: the original image; only its shape is used.
        tile_size: tile side length; None means the notebook's TILE_SIZE.

    Returns:
        Flat list of (row, col) tuples offset by each tile's position.
    """
    size = TILE_SIZE if tile_size is None else tile_size
    n_cols = image.shape[1] // size
    return [
        (point[0] + (tile_idx // n_cols) * size,
         point[1] + (tile_idx % n_cols) * size)
        for tile_idx, tile_points in enumerate(centroids)
        for point in tile_points
    ]

# Extract cell background
def extract_stain(image):
    """Return a binary mask of the stained regions of ``image``.

    Pipeline: grayscale -> CLAHE -> bilateral mean filter (disk radius 30)
    -> non-local-means denoising -> inversion -> unsharp mask -> Otsu
    threshold -> dilation by a radius-1 disk -> removal of connected objects
    smaller than 200 pixels.

    Args:
        image: colour image as returned by cv2.imread (converted with
            COLOR_BGR2GRAY, so BGR channel order is assumed).

    Returns:
        float32 array of 0.0/1.0 with the image's height and width; 1.0 marks
        pixels kept as stain.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # Adaptive histogram equalization (CLAHE); skimage returns a float image.
    gray = exposure.equalize_adapthist(gray, clip_limit=0.03)
    # NOTE(review): rank filters work on integer images, so skimage presumably
    # converts the float input to uint8 here (possibly with a warning) — confirm.
    gray = filters.rank.mean_bilateral(gray, morphology.disk(30))
    gray = cv2.fastNlMeansDenoising(gray.astype(np.uint8), h=10)
    # Invert intensities; presumably so dark stained areas become the bright
    # foreground selected by the `>` Otsu comparison below.
    gray = cv2.bitwise_not(gray)
    gray = filters.unsharp_mask(gray, radius=1, amount=1)
    
    binary = gray > filters.threshold_otsu(gray)   
    binary = morphology.dilation(binary, morphology.disk(1))
    # Drop specks: connected objects under 200 pixels are removed.
    binary = morphology.remove_small_objects(binary.astype(bool), min_size=200)
    
    return binary.astype(np.float32)

def detect_centroids(feature_map):
    """Return the centre of mass of every connected component of ``feature_map``.

    Components come from scipy.ndimage.label with its default structuring
    element. Each centroid is weighted by the map's values.

    Returns:
        List with one coordinate tuple per component, in label order
        (e.g. (row, col) for a 2-D map). The list is empty when there are no
        components.
    """
    components, n_components = label(feature_map)
    component_ids = range(1, n_components + 1)
    return ndi.center_of_mass(feature_map, labels=components, index=component_ids)

# Exclude centroids not present on stains
def filter_centroids(centroids,stain):
    """Keep the centroids whose truncated pixel position is truthy in ``stain``.

    Args:
        centroids: iterable of (row, col) coordinates (floats allowed; they
            are truncated with int()).
        stain: mask indexed as stain[row][col].

    Returns:
        List of the surviving centroids, in input order.
    """
    def _lies_on_stain(point):
        row, col = int(point[0]), int(point[1])
        return stain[row][col]

    return list(filter(_lies_on_stain, centroids))

# Main function to make predictions on entire folder
def predict_folder(image_folder,model,verbose=False):
    """Run ``predict_image`` on every entry of ``image_folder``.

    Returns:
        dict mapping each filename (in os.listdir order) to the count that
        predict_image returned. With ``verbose``, each result is also printed
        as "<filename>: <count>".
    """
    counts = {}
    for name in os.listdir(image_folder):
        count = predict_image(os.path.join(image_folder, name), model, verbose=verbose)
        counts[name] = count
        if verbose:
            print(f'{name}: {count}')
    return counts
        
# Make prediction on a single file
def predict_image(image_path,model,verbose=False):
    """Count droplets in one image with the U-Net.

    Steps:
    1. Cut the image into TILE_SIZE tiles; remainders are dropped.
    2. Segment each tile with ``model`` and threshold the predictions.
    3. Reduce the predictions to connected-component centroids in
       whole-image coordinates.
    4. Keep only centroids that fall on the extracted stain.

    Args:
        image_path: file read with cv2.imread.
        model: Keras model from build_unet_model. It divides its input by 255
            itself, so it must receive raw 0-255 pixel values.
        verbose: if True, plot the image with the tile grid and the kept
            centroids.

    Returns:
        Number of centroids kept after stain filtering (0 when the image is
        smaller than one tile).

    Raises:
        FileNotFoundError: if cv2.imread cannot read ``image_path``.
    """
    # The 2nd argument of cv2.imread is a read flag, not a colour-conversion
    # code. IMREAD_COLOR gives the same 3-channel BGR layout that the training
    # tiles were read with.
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:
        raise FileNotFoundError(f'Could not read image: {image_path}')

    # Image pre processing
    stain = extract_stain(image)
    # Feed raw pixel values. The model's Lambda layer already divides by 255
    # and training used unscaled tiles, so scaling here as well would squeeze
    # the inputs into [0, 1/255].
    tiles = tile_image(image.astype(np.float32))
    if not tiles:
        return 0  # image smaller than a single tile: nothing to predict on

    # U-Net prediction
    preds = model.predict(np.array(tiles), verbose=False)
    # pred_image = stitch_image(preds, image)  # predictions can be visualized if needed
    binaries = [to_binary(pred) for pred in preds]

    # Centroid identification
    centroids = [detect_centroids(binary) for binary in binaries]
    scaled_centroids = stitch_centroids(centroids,image)
    all_centroids = filter_centroids(scaled_centroids,stain)

    # Debug
    if verbose:
        grid_color = [0,0,0]
        image[:,::TILE_SIZE,:] = grid_color
        image[::TILE_SIZE,:,:] = grid_color

        plt.figure(figsize=(16,10))
        plt.grid(False)
        plt.axis('off')
        plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

        # Unpacking .T of an empty array fails, so only scatter when there is
        # something to draw.
        if all_centroids:
            x, y = np.array(all_centroids).T
            plt.scatter(y, x, s = 10, c = 'w')
        plt.show()

    return len(all_centroids)

**Pipeline testing**

In [None]:
# Rebuild the U-Net architecture and load previously trained weights into it.
inputs,outputs=build_unet_model()
model = tf.keras.Model(inputs=[inputs], outputs=[outputs])
# NOTE(review): no code shown in this notebook saves weights to this path —
# confirm where '/unetmodel.weights.h5' is produced.
model.load_weights('/unetmodel.weights.h5')

# Count droplets for every file in the test folder; verbose=True also plots
# each image and prints "<filename>: <count>".
predict_folder('/test/images',model,verbose=True)