In [None]:
# https://github.com/qubvel/segmentation_models
!pip install segmentation-models

In [None]:
import segmentation_models as sm
# Switch segmentation_models to the tf.keras framework before any model is built.
sm.set_framework("tf.keras")
# Last expression of the cell: echoes the framework now in use as a check.
sm.framework()

In [None]:
# Report the GPU(s) visible to this runtime.
!nvidia-smi

In [None]:
# Report the CPU(s) of this runtime.
!cat /proc/cpuinfo

In [None]:
# Report the version of the `python` executable found on the shell PATH.
!python --version

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the project folder is entered from there.
drive.mount('/content/drive')

In [None]:
%cd drive/MyDrive/nuclei_segmentation

In [None]:
# https://github.com/dovahcrow/patchify.py
!pip install patchify

In [None]:
import os
import glob
import random
import shutil
import numpy as np

In [None]:
def create_path(path):
    """Create directory ``path`` (including missing parents) if it does not exist.

    Calling it again for an existing directory is a no-op.

    Args:
        path: Directory to create.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test.
    os.makedirs(path, exist_ok=True)

In [None]:
# Roots of the training and validation splits (relative to the working
# directory); each is expected to hold tissue_images/, binary_masks/,
# instance_masks/ and modified_masks/ sub-folders, as used by make_validation_set.
train_dir = "dataset/monuseg/stain_normalized/train"
val_dir = "dataset/monuseg/stain_normalized/validation"

In [None]:
def _move_sample(image_path, src_root, dst_root):
    """Move one tissue image and its three companion mask files between split roots.

    The companions share the image's base name and live under ``src_root`` as
    ``binary_masks/<stem>.png``, ``instance_masks/<stem>.npy`` and
    ``modified_masks/<stem>.png``. Paths are built from the base name, so
    "tif" or "tissue_images" appearing elsewhere in the path cannot corrupt them.
    """
    stem = os.path.splitext(os.path.basename(image_path))[0]
    shutil.move(image_path, os.path.join(dst_root, "tissue_images"))
    for subdir, ext in (("binary_masks", ".png"),
                        ("instance_masks", ".npy"),
                        ("modified_masks", ".png")):
        shutil.move(os.path.join(src_root, subdir, stem + ext),
                    os.path.join(dst_root, subdir))


def make_validation_set(val_dir, train_dir, val_size=6, seed=42, fold=None):
    """Rebuild the validation split by moving samples out of the training split.

    Samples currently in ``val_dir`` are first moved back to ``train_dir``
    (best effort; failures are reported and skipped). The training images are
    then shuffled with ``seed`` and ``val_size`` of them, together with their
    mask files, are moved to ``val_dir``.

    Args:
        val_dir: Root of the validation split; its sub-folders are created if missing.
        train_dir: Root of the training split; its sub-folders must already exist.
        val_size: Number of images per validation set / fold.
        seed: Seed for the shuffle (and for the random sample when ``fold`` is None).
        fold: 1-based fold number selecting the slice
            ``[(fold - 1) * val_size, fold * val_size)`` of the shuffled list,
            or None to draw a random sample instead.

    Raises:
        ValueError: If ``fold`` is given and smaller than 1.
    """
    if fold is not None and fold < 1:
        # fold=0 used to yield an empty slice silently; negative folds sliced from the end.
        raise ValueError(f"fold is 1-based and must be >= 1, got {fold}")

    for subdir in ("tissue_images", "binary_masks", "instance_masks", "modified_masks"):
        os.makedirs(os.path.join(val_dir, subdir), exist_ok=True)

    # Return the previous validation samples to the training split.
    for image_path in sorted(glob.glob(os.path.join(val_dir, "tissue_images", "*"))):
        try:
            _move_sample(image_path, val_dir, train_dir)
        except OSError as err:  # shutil.Error is an OSError subclass
            print(f"Warning: could not fully restore {os.path.basename(image_path)} "
                  f"to the training split: {err}")

    images_lst = sorted(glob.glob(os.path.join(train_dir, "tissue_images", "*")))
    np.random.seed(seed)
    np.random.shuffle(images_lst)
    if fold is None:
        random.seed(seed)
        val_lst = random.sample(images_lst, val_size)
    else:
        val_lst = images_lst[(fold - 1) * val_size: fold * val_size]

    for image_path in val_lst:
        _move_sample(image_path, train_dir, val_dir)

    print(f"Validation list: {[os.path.basename(i) for i in val_lst]}")

In [None]:
# Build the validation split for fold 1: 6 images (with their masks) are moved
# from the training split into the validation split on disk.
make_validation_set(val_dir, train_dir, val_size=6, fold=1)

In [None]:
# Number of files now in the validation modified_masks folder (compare with val_size).
len(os.listdir(os.path.join(val_dir, "modified_masks")))

In [None]:
import cv2
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from patchify import patchify, unpatchify

In [None]:
# NOTE: `train_dir` and `val_dir` are rebound here. Earlier in the notebook they
# name the split roots passed to make_validation_set; from this cell on they
# name the tissue_images sub-folders, so re-running the split cell after this
# one would pass the wrong directories.
train_dir = "dataset/monuseg/stain_normalized/train/tissue_images"
train_maskdir = "dataset/monuseg/stain_normalized/train/binary_masks"

val_dir   = "dataset/monuseg/stain_normalized/validation/tissue_images"
val_maskdir = "dataset/monuseg/stain_normalized/validation/binary_masks"

In [None]:
# Every image is resized to W x H before being cut into patches.
W = 1024
H = 1024

# 256x256 RGB patches taken every 128 px (50% overlap between neighbouring patches).
patch_size = (256, 256, 3)
all_img_patches = []
# Iterate and size the progress bar from the same file list.
train_image_paths = sorted(glob.glob(os.path.join(train_dir, "*")))
for x in tqdm(train_image_paths):
    single_img = cv2.imread(x, cv2.IMREAD_COLOR)
    if single_img is None:
        # cv2.imread returns None for unreadable files instead of raising.
        raise IOError(f"Could not read image file: {x}")
    single_img = cv2.cvtColor(single_img, cv2.COLOR_BGR2RGB)
    single_img = cv2.resize(single_img, (W, H), interpolation=cv2.INTER_LINEAR)
    # patchify
    single_img_patches = patchify(single_img, patch_size=patch_size, step=128)
    # squeeze
    single_img_patches = np.squeeze(single_img_patches)

    # Flatten the (rows, cols) grid of patches row by row.
    all_img_patches.extend(single_img_patches.reshape(-1, *patch_size))

train_images = np.array(all_img_patches)

In [None]:
# Shape of the stacked training patches: (n_patches, 256, 256, 3) given patch_size above.
train_images.shape

In [None]:
# Same patch extraction as for the training images, applied to the validation images.
all_img_patches = []
# Iterate and size the progress bar from the same file list.
val_image_paths = sorted(glob.glob(os.path.join(val_dir, "*")))
for x in tqdm(val_image_paths):
    single_img = cv2.imread(x, cv2.IMREAD_COLOR)
    if single_img is None:
        # cv2.imread returns None for unreadable files instead of raising.
        raise IOError(f"Could not read image file: {x}")
    single_img = cv2.cvtColor(single_img, cv2.COLOR_BGR2RGB)
    single_img = cv2.resize(single_img, (W, H), interpolation=cv2.INTER_LINEAR)
    # patchify
    single_img_patches = patchify(single_img, patch_size=patch_size, step=128)
    # squeeze
    single_img_patches = np.squeeze(single_img_patches)

    # Flatten the (rows, cols) grid of patches row by row.
    all_img_patches.extend(single_img_patches.reshape(-1, *patch_size))

val_images = np.array(all_img_patches)

In [None]:
# Shape of the stacked validation patches: (n_patches, 256, 256, 3) given patch_size above.
val_images.shape

In [None]:
# Cut the training masks into patches with the same grid as the images.
# Nearest-neighbour resizing keeps the mask values unblended.
all_mask_patches = []
train_mask_paths = sorted(glob.glob(os.path.join(train_maskdir, "*")))
for x in tqdm(train_mask_paths):
    single_mask = cv2.imread(x, cv2.IMREAD_GRAYSCALE)
    if single_mask is None:
        # cv2.imread returns None for unreadable files instead of raising.
        raise IOError(f"Could not read mask file: {x}")
    single_mask = cv2.resize(single_mask, (W, H), interpolation=cv2.INTER_NEAREST)
    # patchify
    single_mask_patches = patchify(single_mask, patch_size=(256, 256), step=128)
    # squeeze
    single_mask_patches = np.squeeze(single_mask_patches)

    # Flatten the (rows, cols) grid of patches row by row.
    all_mask_patches.extend(single_mask_patches.reshape(-1, 256, 256))

train_masks = np.array(all_mask_patches)
# Images and masks are paired purely by position, so the counts must agree.
assert len(train_masks) == len(train_images), "train image/mask patch counts differ"

In [None]:
# Shape of the stacked training mask patches: (n_patches, 256, 256).
train_masks.shape

In [None]:
# Cut the validation masks into patches with the same grid as the images.
# Nearest-neighbour resizing keeps the mask values unblended.
all_mask_patches = []
val_mask_paths = sorted(glob.glob(os.path.join(val_maskdir, "*")))
for x in tqdm(val_mask_paths):
    single_mask = cv2.imread(x, cv2.IMREAD_GRAYSCALE)
    if single_mask is None:
        # cv2.imread returns None for unreadable files instead of raising.
        raise IOError(f"Could not read mask file: {x}")
    single_mask = cv2.resize(single_mask, (W, H), interpolation=cv2.INTER_NEAREST)
    # patchify
    single_mask_patches = patchify(single_mask, patch_size=(256, 256), step=128)
    # squeeze
    single_mask_patches = np.squeeze(single_mask_patches)

    # Flatten the (rows, cols) grid of patches row by row.
    all_mask_patches.extend(single_mask_patches.reshape(-1, 256, 256))

val_masks = np.array(all_mask_patches)
# Images and masks are paired purely by position, so the counts must agree.
assert len(val_masks) == len(val_images), "validation image/mask patch counts differ"

In [None]:
# Shape of the stacked validation mask patches: (n_patches, 256, 256).
val_masks.shape

In [None]:
# Sanity check: show a randomly chosen training patch next to its mask.
rnd = np.random.randint(len(train_images))
# rnd = 222  # uncomment to look at a fixed patch instead

fig, ax = plt.subplots(1, 2, figsize=(12, 6))
for axi in ax.ravel():
    axi.set_axis_off()

panels = ((train_images[rnd], "Tissue Image"), (train_masks[rnd], "Mask"))
for axi, (panel, title) in zip(ax, panels):
    axi.imshow(panel)
    axi.set_title(title)

plt.tight_layout()
plt.show()

In [None]:
import albumentations as A
from keras.utils import Sequence

In [None]:
class DataGenerator(Sequence):
    """Keras ``Sequence`` serving batches of in-memory image and mask patches.

    Args:
        images: Indexable collection of image patches, each
            ``(img_size, img_size, n_channels)``.
        masks: Indexable collection of mask patches, each ``(img_size, img_size)``;
            values are divided by 255 when a batch is built.
        augmentations: Optional callable invoked as ``augmentations(image=..., mask=...)``
            that returns a dict with ``'image'`` and ``'mask'`` entries. When None,
            batches are returned without further processing.
        batch_size: Number of samples per batch (the last batch may be smaller).
        img_size: Side length of the square patches.
        n_channels: Number of image channels.
        shuffle: Whether to reshuffle the sample order at the end of every epoch.
    """

    def __init__(self, images, masks, augmentations=None, batch_size=8, img_size=256, n_channels=3, shuffle=True):
        """Store the data and settings and build the initial sample order."""
        # Let the Keras base class set up its own state first.
        super().__init__()
        self.batch_size = batch_size

        self.images = images
        self.masks = masks

        self.img_size = img_size

        self.n_channels = n_channels
        self.shuffle = shuffle
        self.augment = augmentations
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches per epoch (a final partial batch counts)."""
        return int(np.ceil(len(self.images) / self.batch_size))

    def __getitem__(self, index):
        """Return batch number ``index`` as an ``(images, masks)`` pair of arrays."""
        # Slicing clamps at the end of the array, so the last batch is simply shorter.
        start = index * self.batch_size
        indices = self.indices[start:start + self.batch_size]

        X, y = self.data_generation(indices)

        if self.augment is None:
            return X, y

        aug_images, aug_masks = [], []
        # Distinct loop names: the batch arrays X and y are not shadowed.
        for image, mask in zip(X, y):
            augmented = self.augment(image=image, mask=mask)
            aug_images.append(augmented['image'])
            aug_masks.append(augmented['mask'])

        return np.array(aug_images), np.array(aug_masks)

    def on_epoch_end(self):
        """Rebuild the sample order, shuffling it when ``shuffle`` is set."""
        self.indices = np.arange(len(self.images))
        if self.shuffle:
            np.random.shuffle(self.indices)

    def data_generation(self, indices):
        """Assemble the samples at ``indices`` into one batch.

        Returns:
            ``(X, y)`` where ``X`` is uint8 with shape
            ``(len(indices), img_size, img_size, n_channels)`` and ``y`` is float32
            with shape ``(len(indices), img_size, img_size, 1)``, holding the
            masks divided by 255.
        """
        # Allocate in the final dtypes directly instead of float64 scratch arrays.
        X = np.empty((len(indices), self.img_size, self.img_size, self.n_channels), dtype=np.uint8)
        y = np.empty((len(indices), self.img_size, self.img_size, 1), dtype=np.float32)
        for n, i in enumerate(indices):
            X[n] = self.images[i]
            y[n] = (self.masks[i] / 255.)[..., np.newaxis]

        return X, y

In [None]:
# Training-time augmentation pipeline. Each stage is applied with its own
# probability `p`; a OneOf stage applies one of the transforms it lists.
AUGMENTATIONS_TRAIN = A.Compose([
    # Geometric: rotation by any angle, then a horizontal or vertical flip.
    A.Rotate(limit=360, p=0.5),
    A.OneOf([
        A.HorizontalFlip(),
        A.VerticalFlip(),
        ], p=0.5),
    # Photometric: brightness/contrast, gamma or additive noise.
    A.OneOf([
        A.RandomBrightnessContrast(),
        A.RandomGamma(),
        A.GaussNoise()
         ], p=0.3),
    # Spatial distortions.
    A.OneOf([
        A.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
        # Alternative elastic settings kept for reference:
        # A.ElasticTransform(alpha=3, sigma=150, alpha_affine=150),
        # A.ElasticTransform(),
        A.Affine(translate_percent=0.2, shear=30, mode=cv2.BORDER_CONSTANT),
        A.GridDistortion(),
        A.OpticalDistortion(distort_limit=2, shift_limit=0.5),
        ], p=0.3),
    # Colour shifts or a light blur.
    A.OneOf([
        A.RGBShift(r_shift_limit=40, g_shift_limit=40,  b_shift_limit=40),
        A.ColorJitter(hue=0.1),
        A.Blur(blur_limit=3)
        ], p=0.3),
    # Final step: convert to float using 255 as the maximum value.
    A.ToFloat(max_value=255)
], p=1)

# Validation pipeline: no random augmentation, only the same float conversion.
AUGMENTATIONS_VAL = A.Compose([
    A.ToFloat(max_value=255)
], p=1)

In [None]:
# One tissue image as its 7x7 grid of 256x256 tiles (50% overlap), no augmentation.
a = DataGenerator(train_images, train_masks, batch_size=49, augmentations=None, shuffle=False)
images, masks = a[0]

max_images = 49
grid_width = 7
grid_height = max_images // grid_width
fig, axs = plt.subplots(grid_height, grid_width, figsize=(grid_width*2, grid_height*2))

for i, (im, mask1) in enumerate(zip(images, masks)):
    row, col = divmod(i, grid_width)
    ax = axs[row, col]
    ax.imshow(im)
    # Mask drawn semi-transparently on top of the tile.
    ax.imshow(mask1.squeeze(), alpha=0.4)
    ax.axis('off')

print(mask1.shape)
plt.tight_layout()
plt.show()

In [None]:
# The same 49 tiles, this time passed through the training augmentations.
a = DataGenerator(train_images, train_masks, batch_size=49, augmentations=AUGMENTATIONS_TRAIN, shuffle=False)
images, masks = a[0]

max_images = 49
grid_width = 7
grid_height = max_images // grid_width
fig, axs = plt.subplots(grid_height, grid_width, figsize=(grid_width*2, grid_height*2))

for i, (im, mask1) in enumerate(zip(images, masks)):
    row, col = divmod(i, grid_width)
    ax = axs[row, col]
    ax.imshow(im)
    # Mask drawn semi-transparently on top of the tile.
    ax.imshow(mask1.squeeze(), alpha=0.4)
    ax.axis('off')

print(mask1.shape)
plt.tight_layout()
plt.show()

In [None]:
import tensorflow as tf
import keras
print(tf.__version__)
print(keras.__version__)

In [None]:
# Build a segmentation_models U-Net with a VGG19 backbone for 256x256 RGB patches.
# All other arguments (number of classes, output activation, encoder weights, ...)
# are left at the library defaults — TODO confirm they suit a single-channel
# binary mask target.
backbone = "vgg19"
input_shape = (256, 256, 3)
model = sm.Unet(
             backbone_name=backbone,
             input_shape=input_shape,
             )

In [None]:
# Print the layer-by-layer summary of the model built above.
model.summary()

In [None]:
from keras.utils import plot_model

In [None]:
# Save an architecture diagram in the working directory. The file name
# previously began with a stray, invisible combining character (U+064F).
plot_model(model, show_shapes=True, dpi=330, to_file="StandardUnet.png")

In [None]:
import keras.backend as K
from keras.optimizers import Adam

In [None]:
# https://github.com/yingkaisha/keras-unet-collection
!pip install keras_unet_collection

In [None]:
from keras_unet_collection import losses

def binary_loss(y_true, y_pred):
    """Return the sum of the focal Tversky loss and the IoU loss.

    Both terms come from ``keras_unet_collection.losses``; the focal Tversky
    term is evaluated with ``alpha=0.5`` and ``gamma=4/3``.
    """
    focal_tversky_term = losses.focal_tversky(y_true, y_pred, alpha=0.5, gamma=4/3)
    iou_term = losses.iou_seg(y_true, y_pred)

    # (x) disabled multi-scale SSIM term:
    # loss_ssim = losses.ms_ssim(y_true, y_pred, max_val=1.0, filter_size=4)

    return focal_tversky_term + iou_term  # + loss_ssim

In [None]:
def dice_coef(y_true, y_pred, smooth=1):
    """Dice coefficient over all elements of ``y_true`` and ``y_pred`` flattened together.

    Computes ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``; the
    ``smooth`` term keeps the ratio defined when both inputs sum to zero.
    """
    truth = K.flatten(y_true)
    prediction = K.flatten(y_pred)
    overlap = K.sum(truth * prediction)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth) + K.sum(prediction) + smooth
    return numerator / denominator

In [None]:
# Adam at 1e-4 with the combined loss defined above. `metrics` is passed as a
# list, the form documented for Model.compile.
model.compile(optimizer=Adam(1e-4),
              loss=binary_loss,
              metrics=[dice_coef])

In [None]:
from keras.callbacks import ModelCheckpoint, CSVLogger, ReduceLROnPlateau, EarlyStopping, Callback, LearningRateScheduler
from tensorflow.keras.optimizers.schedules import PolynomialDecay
import datetime

from skimage import filters
from scipy.ndimage import measurements
from skimage.segmentation import watershed, mark_boundaries

# Visualize training
class loss_history(Callback):
    """Callback that plots the model's current prediction for one training patch.

    At the start of every epoch it shows, side by side: the patch
    ``train_images[x]``, its mask ``train_masks[x]``, the prediction thresholded
    at 0.5, and the patch with the predicted boundaries drawn on it. The patch
    and mask are read from the notebook-level ``train_images`` / ``train_masks``.

    Args:
        x: Index of the training patch to display.
    """

    def __init__(self, x=4):
        # Let the Keras base class set up its own state first.
        super().__init__()
        self.x = x

    def on_epoch_begin(self, epoch, logs=None):
        """Draw the four-panel figure for the current model state."""
        fig, ax = plt.subplots(1, 4, figsize=(18, 12))
        for axi in ax.ravel():
            axi.set_axis_off()

        ax[0].imshow(train_images[self.x])
        ax[0].set_title("Tissue Image")

        ax[1].imshow(train_masks[self.x], cmap="gray")
        ax[1].set_title("Ground Truth")

        # Scale to [0, 1] floats, matching the ToFloat(max_value=255) step of the generators.
        model_sample_input = train_images[self.x].astype("float32") / 255.
        preds_train = self.model.predict(np.expand_dims(model_sample_input, axis=0), verbose=0)
        preds1 = np.squeeze(preds_train[0]) >= 0.5
        ax[2].imshow(preds1, cmap="gray")
        ax[2].set_title("Nuclei prediction")

        ax[3].imshow(mark_boundaries(train_images[self.x], preds1, color=(0, 0, 1)))
        ax[3].set_title("Result")

        plt.tight_layout()
        plt.show()
        # Release the figure so one is not kept alive per epoch.
        plt.close(fig)

# Polynomial learning-rate decay from 1e-4 down to 1e-7 over 40 steps.
# Not wired into training here: its only use is the commented-out
# LearningRateScheduler line further down in this cell.
lr_scheduler = PolynomialDecay(
    initial_learning_rate=1e-4,
    end_learning_rate=1e-7,
    decay_steps=40
)

# Alternative cosine schedule kept for reference:
# lr_scheduler = tf.keras.optimizers.schedules.CosineDecay(
#     1e-4, 100, alpha=0.0, name=None
# )

class ADJUSTLR(keras.callbacks.Callback):
    """Multiply the optimizer's learning rate by ``factor`` every ``freq`` epochs.

    Args:
        model: Model whose optimizer is adjusted.
        freq: Number of epochs between two adjustments.
        factor: Multiplier applied to the current learning rate.
        verbose: When 1, print a message each time the rate is changed.
    """

    def __init__(self, model, freq, factor, verbose):
        super().__init__()
        # Register the model through the Callback API rather than assigning
        # `self.model`, which fails in Keras versions where `model` is a
        # read-only property.
        self.set_model(model)
        self.freq = freq
        self.factor = factor
        self.verbose = verbose
        self.adj_epoch = freq  # 1-based number of the next epoch to adjust after

    def on_epoch_end(self, epoch, logs=None):
        """Scale the learning rate once the scheduled epoch has finished."""
        if epoch + 1 == self.adj_epoch:
            optimizer = self.model.optimizer
            # `learning_rate` is the documented attribute; `lr` is a legacy alias.
            lr = float(tf.keras.backend.get_value(optimizer.learning_rate))
            new_lr = lr * self.factor
            self.adj_epoch += self.freq
            if self.verbose == 1:
                print('\non epoch ',epoch + 1, ' lr was adjusted from ', lr, ' to ', new_lr)
            tf.keras.backend.set_value(optimizer.learning_rate, new_lr)


# CSV logger for the per-epoch training log. It is created here but commented
# out of the `callbacks` list defined in the next cell.
create_path("logs/Unet")
csv_log = CSVLogger('logs/Unet/SimpleUnet_fold1_log00.csv', separator=',')
# The timestamp in the name means each run writes to a new checkpoint file.
model_name = f"SimpleUnet_fold1_v00_{datetime.datetime.now().strftime('%Y%m%d-%H%M%S')}"
create_path("models/Unet")
path_to_save_model = "models/Unet/" + model_name + ".h5"
# save_best_only=True: a checkpoint is written only when the monitored quantity
# improves (no `monitor` is passed, so the Keras default applies).
checkpointer = ModelCheckpoint(path_to_save_model, verbose=1, save_best_only=True)
# Multiply the learning rate by 0.1 after 10 epochs without val_loss improvement, never below 1e-7.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, min_lr=1e-7, verbose=1)
# Alternative: drive the learning rate from `lr_scheduler` instead.
# reduce_lr = LearningRateScheduler(schedule=lr_scheduler)
# Stop after 20 epochs without val_loss improvement and restore the best weights.
early_stop   = EarlyStopping(monitor="val_loss", patience=20, restore_best_weights=True)

print(model_name)

In [None]:
# Callbacks passed to model.fit: per-epoch prediction preview, best-model
# checkpointing, learning-rate reduction on plateau and early stopping.
# The CSV logger is currently disabled.
callbacks = [
             loss_history(), 
             checkpointer, 
             reduce_lr, 
            #  csv_log, 
             early_stop
             ]

In [None]:
batch_size = 8
epochs = 120

# Generators
training_generator = DataGenerator(train_images, train_masks, augmentations=AUGMENTATIONS_TRAIN, batch_size=batch_size)
# Validation batches keep a fixed order. The loss and Dice metric are computed
# per batch, so reshuffling each epoch would change batch composition and add
# noise to the val_loss that ReduceLROnPlateau and EarlyStopping monitor.
validation_generator = DataGenerator(val_images, val_masks, augmentations=AUGMENTATIONS_VAL, batch_size=batch_size, shuffle=False)

history = model.fit(training_generator,
                    validation_data=validation_generator,
                    epochs=epochs,
                    verbose=1,
                    callbacks=callbacks)