# Train U-Net 3+ on the MoNuSeg dataset

In [None]:
!nvidia-smi

In [None]:
!cat /proc/cpuinfo

In [None]:
!python --version

In [None]:
# Mount Google Drive into the Colab VM; the relative paths used below
# (dataset/, logs/, models/) are resolved inside the Drive folder.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
%cd drive/MyDrive/nuclei_segmentation

In [None]:
# https://github.com/dovahcrow/patchify.py
!pip install patchify

## Make validation set

In [None]:
import os
import glob
import random
import shutil
import numpy as np

In [None]:
def create_path(path):
    """Create directory ``path`` (including missing parents) if needed.

    The call is idempotent: an already existing directory is left untouched.
    ``exist_ok=True`` replaces the former "check, then create" sequence, which
    could fail if the directory appeared between the check and the creation.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)

In [None]:
# Roots of the train / validation splits. Each root is expected to hold the
# sub-folders tissue_images/, binary_masks/, instance_masks/ and
# modified_masks/ (see make_validation_set).
train_dir = "dataset/monuseg/stain_normalized/train"
val_dir = "dataset/monuseg/stain_normalized/validation"

In [None]:
def make_validation_set(val_dir, train_dir, val_size=6, seed=42, fold=None):
    """Rebuild the validation split by moving files between the two split roots.

    Steps:
      1. Make sure the four sub-folders exist under ``val_dir``.
      2. Move every sample currently in ``val_dir`` back into ``train_dir``.
      3. Shuffle the sorted training images with ``seed`` and pick the
         validation samples.
      4. Move each picked image and its three companion files into ``val_dir``.

    Companion files are located by file stem: ``<stem>.png`` in
    ``binary_masks`` and ``modified_masks``, ``<stem>.npy`` in
    ``instance_masks``.

    Args:
        val_dir: Root directory of the validation split.
        train_dir: Root directory of the training split.
        val_size: Number of images per validation set / fold.
        seed: Seed for the NumPy shuffle and, when ``fold`` is None, for
            ``random.sample``. The global ``numpy.random`` / ``random`` state
            is reseeded, exactly as before.
        fold: 1-based fold index. Fold ``k`` takes items
            ``[(k - 1) * val_size, k * val_size)`` of the shuffled list.
            ``None`` draws a random sample of ``val_size`` images instead.

    Raises:
        ValueError: If ``fold`` is None and ``val_size`` exceeds the number of
            training images (raised by ``random.sample``).
        OSError: If a selected image or one of its companion files cannot be
            moved into ``val_dir``.
    """
    image_subdir = "tissue_images"
    # companion sub-folder -> extension of the files stored in it
    companions = (
        ("binary_masks", ".png"),
        ("instance_masks", ".npy"),
        ("modified_masks", ".png"),
    )

    def move_sample(image_path, src_root, dst_root):
        """Move one image plus its companion files from src_root to dst_root."""
        # Build companion paths from the file stem. The previous
        # str.replace("tif", ...) rewrote *any* "tif" in the full path,
        # directory names included.
        stem = os.path.splitext(os.path.basename(image_path))[0]
        shutil.move(image_path, os.path.join(dst_root, image_subdir))
        for subdir, ext in companions:
            shutil.move(os.path.join(src_root, subdir, stem + ext),
                        os.path.join(dst_root, subdir))

    os.makedirs(os.path.join(val_dir, image_subdir), exist_ok=True)
    for subdir, _ in companions:
        os.makedirs(os.path.join(val_dir, subdir), exist_ok=True)

    # Return the previous validation samples to the training pool. This stays
    # best-effort, but only file-system errors are tolerated (shutil.Error is
    # an OSError) and they are reported instead of silently swallowed.
    for image_path in sorted(glob.glob(os.path.join(val_dir, image_subdir, "*"))):
        try:
            move_sample(image_path, val_dir, train_dir)
        except OSError as err:
            print(f"Could not restore {os.path.basename(image_path)}: {err}")

    images_lst = sorted(glob.glob(os.path.join(train_dir, image_subdir, "*")))
    # Sorted + seeded shuffle makes the permutation, and hence each fold,
    # reproducible.
    np.random.seed(seed)
    np.random.shuffle(images_lst)
    if fold is None:
        random.seed(seed)
        val_lst = random.sample(images_lst, val_size)
    else:
        val_lst = images_lst[(fold - 1) * val_size: fold * val_size]

    for image_path in val_lst:
        move_sample(image_path, train_dir, val_dir)

    print(f"Validation list: {[os.path.basename(i) for i in val_lst]}")

In [None]:
# Build fold 1: the first 6 images of the seed-42 shuffled training list.
# NOTE: this physically moves files between the train/ and validation/ folders.
make_validation_set(val_dir, train_dir, val_size=6, fold=1)

## Read the tissue images & ground-truth masks

In [None]:
import cv2
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from patchify import patchify, unpatchify

In [None]:
# Record the OpenCV version used for this run.
print(cv2.__version__)

In [None]:
# Per-split image / mask folders.
# NOTE: train_dir and val_dir are rebound here from the split roots to the
# tissue_images sub-folders, so re-running make_validation_set(val_dir,
# train_dir, ...) after this cell would receive the wrong directories.
train_dir = "dataset/monuseg/stain_normalized/train/tissue_images"
train_maskdir = "dataset/monuseg/stain_normalized/train/binary_masks"
train_mask2dir = "dataset/monuseg/stain_normalized/train/modified_masks"

val_dir   = "dataset/monuseg/stain_normalized/validation/tissue_images"
val_maskdir = "dataset/monuseg/stain_normalized/validation/binary_masks"
val_mask2dir = "dataset/monuseg/stain_normalized/validation/modified_masks"

In [None]:
# Every image is resized to W x H before being cut into patches.
W = 1024
H = 1024

patch_size = (256, 256, 3)
all_img_patches = []
# tqdm takes its total from the list itself, so the progress bar always
# matches what is actually iterated.
for x in tqdm(sorted(glob.glob(os.path.join(train_dir, "*")))):
    single_img = cv2.imread(x, cv2.IMREAD_COLOR)
    # cv2.imread signals failure by returning None instead of raising.
    if single_img is None:
        raise OSError(f"Could not read image: {x}")
    single_img = cv2.cvtColor(single_img, cv2.COLOR_BGR2RGB)
    single_img = cv2.resize(single_img, (W, H), interpolation=cv2.INTER_LINEAR)
    # patchify -> (rows, cols, 1, 256, 256, 3); step=128 gives 50% overlap
    single_img_patches = patchify(single_img, patch_size=patch_size, step=128)
    # Drop only the singleton channel-window axis. A bare np.squeeze would
    # also collapse rows/cols whenever the grid holds a single patch.
    single_img_patches = np.squeeze(single_img_patches, axis=2)

    for i in range(single_img_patches.shape[0]):
        for j in range(single_img_patches.shape[1]):
            all_img_patches.append(single_img_patches[i, j])

train_images = np.array(all_img_patches)

In [None]:
# Expected (n_train_images * 49, 256, 256, 3): a 1024x1024 image cut into
# 256x256 patches at step 128 yields a 7x7 grid.
train_images.shape

In [None]:
all_img_patches = []
# tqdm takes its total from the list itself, so the progress bar always
# matches what is actually iterated.
for x in tqdm(sorted(glob.glob(os.path.join(val_dir, "*")))):
    single_img = cv2.imread(x, cv2.IMREAD_COLOR)
    # cv2.imread signals failure by returning None instead of raising.
    if single_img is None:
        raise OSError(f"Could not read image: {x}")
    single_img = cv2.cvtColor(single_img, cv2.COLOR_BGR2RGB)
    single_img = cv2.resize(single_img, (W, H), interpolation=cv2.INTER_LINEAR)
    # patchify -> (rows, cols, 1, 256, 256, 3); step=128 gives 50% overlap
    single_img_patches = patchify(single_img, patch_size=patch_size, step=128)
    # Drop only the singleton channel-window axis. A bare np.squeeze would
    # also collapse rows/cols whenever the grid holds a single patch.
    single_img_patches = np.squeeze(single_img_patches, axis=2)

    for i in range(single_img_patches.shape[0]):
        for j in range(single_img_patches.shape[1]):
            all_img_patches.append(single_img_patches[i, j])

val_images = np.array(all_img_patches)

In [None]:
# Expected (n_val_images * 49, 256, 256, 3): 7x7 patches per resized image.
val_images.shape

In [None]:
all_mask_patches = []
# Sorted listing keeps the mask order aligned with the sorted image listing.
for x in tqdm(sorted(glob.glob(os.path.join(train_maskdir, "*")))):
    single_mask = cv2.imread(x, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None instead of raising.
    if single_mask is None:
        raise OSError(f"Could not read mask: {x}")
    # Nearest-neighbour resizing keeps the mask values discrete.
    single_mask = cv2.resize(single_mask, (W, H), interpolation=cv2.INTER_NEAREST)
    # patchify -> (rows, cols, 256, 256). A 2-D input has no singleton axis to
    # drop, so no squeeze is applied (a bare np.squeeze would collapse a
    # single-patch grid).
    single_mask_patches = patchify(single_mask, patch_size=(256, 256), step=128)

    for i in range(single_mask_patches.shape[0]):
        for j in range(single_mask_patches.shape[1]):
            all_mask_patches.append(single_mask_patches[i, j])

train_masks = np.array(all_mask_patches)

In [None]:
# Expected (n_train_images * 49, 256, 256); must match train_images.shape[:3].
train_masks.shape

In [None]:
all_mask_patches = []
# Sorted listing keeps the mask order aligned with the sorted image listing.
for x in tqdm(sorted(glob.glob(os.path.join(val_maskdir, "*")))):
    single_mask = cv2.imread(x, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None instead of raising.
    if single_mask is None:
        raise OSError(f"Could not read mask: {x}")
    # Nearest-neighbour resizing keeps the mask values discrete.
    single_mask = cv2.resize(single_mask, (W, H), interpolation=cv2.INTER_NEAREST)
    # patchify -> (rows, cols, 256, 256). A 2-D input has no singleton axis to
    # drop, so no squeeze is applied (a bare np.squeeze would collapse a
    # single-patch grid).
    single_mask_patches = patchify(single_mask, patch_size=(256, 256), step=128)

    for i in range(single_mask_patches.shape[0]):
        for j in range(single_mask_patches.shape[1]):
            all_mask_patches.append(single_mask_patches[i, j])

val_masks = np.array(all_mask_patches)

In [None]:
# Expected (n_val_images * 49, 256, 256); must match val_images.shape[:3].
val_masks.shape

In [None]:
all_mask_patches = []
# Sorted listing keeps the mask order aligned with the sorted image listing.
for x in tqdm(sorted(glob.glob(os.path.join(train_mask2dir, "*")))):
    single_mask = cv2.imread(x, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None instead of raising.
    if single_mask is None:
        raise OSError(f"Could not read mask: {x}")
    # Nearest-neighbour resizing keeps the class values discrete.
    single_mask = cv2.resize(single_mask, (W, H), interpolation=cv2.INTER_NEAREST)
    # patchify -> (rows, cols, 256, 256). A 2-D input has no singleton axis to
    # drop, so no squeeze is applied (a bare np.squeeze would collapse a
    # single-patch grid).
    single_mask_patches = patchify(single_mask, patch_size=(256, 256), step=128)

    for i in range(single_mask_patches.shape[0]):
        for j in range(single_mask_patches.shape[1]):
            # mask_with_boarders = generate_boarder(single_mask_patches[i, j])
            all_mask_patches.append(single_mask_patches[i, j])

train_masks2 = np.array(all_mask_patches)

In [None]:
# Expected (n_train_images * 49, 256, 256), same as train_masks.shape.
train_masks2.shape

In [None]:
all_mask_patches = []
# Sorted listing keeps the mask order aligned with the sorted image listing.
for x in tqdm(sorted(glob.glob(os.path.join(val_mask2dir, "*")))):
    single_mask = cv2.imread(x, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None instead of raising.
    if single_mask is None:
        raise OSError(f"Could not read mask: {x}")
    # Nearest-neighbour resizing keeps the class values discrete.
    single_mask = cv2.resize(single_mask, (W, H), interpolation=cv2.INTER_NEAREST)
    # patchify -> (rows, cols, 256, 256). A 2-D input has no singleton axis to
    # drop, so no squeeze is applied (a bare np.squeeze would collapse a
    # single-patch grid).
    single_mask_patches = patchify(single_mask, patch_size=(256, 256), step=128)

    for i in range(single_mask_patches.shape[0]):
        for j in range(single_mask_patches.shape[1]):
            # mask_with_boarders = generate_boarder(single_mask_patches[i, j])
            all_mask_patches.append(single_mask_patches[i, j])

val_masks2 = np.array(all_mask_patches)

In [None]:
# Expected (n_val_images * 49, 256, 256), same as val_masks.shape.
val_masks2.shape

In [None]:
# Sanity check: show one random training patch next to both of its masks to
# confirm that images and masks are still index-aligned.
rnd = np.random.randint(len(train_images))
# rnd = 222

fig, ax = plt.subplots(1, 3, figsize=(12, 6))
for axi in ax.ravel():
    axi.set_axis_off()

# (patch stack, panel title), in left-to-right display order
panels = (
    (train_images, "Tissue Image"),
    (train_masks, "Mask"),
    (train_masks2, "Mask2"),
)
for axi, (stack, title) in zip(ax, panels):
    axi.imshow(stack[rnd])
    axi.set_title(title)

plt.tight_layout()
plt.show()

## One-hot encoding the modified masks

In [None]:
# train masks
from sklearn.preprocessing import LabelEncoder
from keras.utils import to_categorical

n_classes = 3

# Map the raw pixel values of the modified masks to class ids 0..n_classes-1.
# LabelEncoder assigns the ids in sorted order of the distinct values.
label_encoder = LabelEncoder()
# Take the shape from the array that is actually encoded.
n, h, w = train_masks2.shape
train_masks_reshaped = train_masks2.reshape(-1,)
train_masks_reshaped_encoded = label_encoder.fit_transform(train_masks_reshaped)
# Fail early with a clear message. Too many values would break
# to_categorical, and too few would silently shift the class ids.
if len(label_encoder.classes_) != n_classes:
    raise ValueError(
        f"Expected {n_classes} distinct mask values, found "
        f"{len(label_encoder.classes_)}: {label_encoder.classes_}"
    )
train_masks_encoded_original_shape = train_masks_reshaped_encoded.reshape(n, h, w)

train_masks_cat = to_categorical(train_masks_encoded_original_shape, num_classes=n_classes)
train_masks_cat.shape

In [None]:
# val masks
from sklearn.preprocessing import LabelEncoder
from keras.utils import to_categorical

# Reuse the encoder fitted on the training masks, so a pixel value maps to the
# same class id in both splits. Refitting on the validation masks would yield
# different ids whenever a value is missing from them. transform() raises
# ValueError if the validation masks contain a value never seen in training.
# Take the shape from the array that is actually encoded.
n, h, w = val_masks2.shape
val_masks_reshaped = val_masks2.reshape(-1,)
val_masks_reshaped_encoded = label_encoder.transform(val_masks_reshaped)
val_masks_encoded_original_shape = val_masks_reshaped_encoded.reshape(n, h, w)

n_classes = 3
val_masks_cat = to_categorical(val_masks_encoded_original_shape, num_classes=n_classes)
val_masks_cat.shape

## Data augmentation using albumentations library

In [None]:
import albumentations as A
from keras.utils import Sequence

In [None]:
# Record the albumentations version; transform arguments used below (e.g.
# alpha_affine, mode) differ between releases.
print(A.__version__)

In [None]:
class DataGenerator(Sequence):
    """Keras ``Sequence`` serving batches of image patches with two mask targets.

    Batch layout:

    * ``augmentations is None``: ``(X, [y_binary, y_categorical])``, where
      ``X`` is uint8 and not rescaled.
    * otherwise: every sample goes through the albumentations pipeline (called
      with the targets ``image``, ``mask1`` and ``mask2``) and the batch is
      ``(X, [y_binary] * k + [y_categorical] * k)`` with
      ``k = n_outputs_per_task``. This gives one target per model output (the
      final head plus the deep-supervision heads of each decoder).

    Args:
        images: Indexable collection of ``img_size x img_size x n_channels``
            image patches.
        masks: Indexable collection of 2-D binary masks with values 0..255,
            rescaled to [0, 1] when served.
        masks_cat: Indexable collection of one-hot masks with 3 channels.
        augmentations: Albumentations pipeline, or None for raw batches.
        batch_size: Samples per batch; the last batch may be smaller.
        img_size: Side length of the square patches.
        n_channels: Number of image channels.
        shuffle: Reshuffle the sample order at construction time and whenever
            ``on_epoch_end`` is called.
        n_outputs_per_task: How often each mask is repeated in the augmented
            batch layout (default 5, the previously hard-coded value).
    """

    def __init__(self, images, masks, masks_cat, augmentations=None, batch_size=8, img_size=256,
                 n_channels=3, shuffle=True, n_outputs_per_task=5):
        """Store the data and settings, then build the initial sample order."""
        # Let the Keras base class set up its own state.
        super().__init__()
        self.batch_size = batch_size

        self.images = images
        self.masks = masks
        self.masks_cat = masks_cat

        self.img_size = img_size

        self.n_channels = n_channels
        self.shuffle = shuffle
        self.augment = augmentations
        self.n_outputs_per_task = n_outputs_per_task
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches per epoch (last one may be partial)."""
        return int(np.ceil(len(self.images) / self.batch_size))

    def __getitem__(self, index):
        """Build batch number ``index``."""
        start = index * self.batch_size
        # Slicing clamps at the end of the array, so no explicit min() is needed.
        indices = self.indices[start:start + self.batch_size]

        X, (y1, y2) = self.data_generation(indices)

        if self.augment is None:
            return X, [np.array(y1), np.array(y2)]

        aug_images, aug_binary, aug_categorical = [], [], []
        # Distinct loop names: the previous loop re-used y1 / y2 and thereby
        # shadowed the batch arrays.
        for image, binary_mask, categorical_mask in zip(X, y1, y2):
            augmented = self.augment(image=image, mask1=binary_mask, mask2=categorical_mask)
            aug_images.append(augmented['image'])
            aug_binary.append(augmented['mask1'])
            aug_categorical.append(augmented['mask2'])

        # Stack once and repeat the same array for every head of a task.
        binary_target = np.array(aug_binary)
        categorical_target = np.array(aug_categorical)
        targets = ([binary_target] * self.n_outputs_per_task
                   + [categorical_target] * self.n_outputs_per_task)
        return np.array(aug_images), targets

    def on_epoch_end(self):
        """Reset the sample order and reshuffle it if requested."""
        self.indices = np.arange(len(self.images))
        if self.shuffle:
            np.random.shuffle(self.indices)

    def data_generation(self, indices):
        """Collect the samples at ``indices`` into batch arrays.

        Returns:
            ``(X, [y1, y2])`` with ``X`` uint8 ``(n, img_size, img_size,
            n_channels)``, ``y1`` float32 ``(n, img_size, img_size, 1)`` in
            [0, 1] and ``y2`` float32 ``(n, img_size, img_size, 3)``.
        """
        batch_len = len(indices)
        # Allocate in the final dtypes instead of float64 scratch buffers
        # that are cast afterwards.
        X = np.empty((batch_len, self.img_size, self.img_size, self.n_channels), dtype=np.uint8)
        y1 = np.empty((batch_len, self.img_size, self.img_size, 1), dtype=np.float32)
        # 3 classes (Nuclei, Border, Background)
        y2 = np.empty((batch_len, self.img_size, self.img_size, 3), dtype=np.float32)

        for n, i in enumerate(indices):
            X[n] = self.images[i]
            # 0..255 mask -> [0, 1], with a trailing channel axis
            y1[n] = (self.masks[i] / 255.)[..., np.newaxis]
            y2[n] = self.masks_cat[i]

        return X, [y1, y2]

In [None]:
# Training-time augmentation pipeline. Each A.OneOf group fires with its own
# probability p and then applies exactly one of its member transforms.
AUGMENTATIONS_TRAIN = A.Compose([
    # Arbitrary-angle rotation for half of the samples.
    A.Rotate(limit=360, p=0.5),
    # Flips
    A.OneOf([
        A.HorizontalFlip(),
        A.VerticalFlip(),
        ], p=0.5),
    # Intensity / noise perturbations
    A.OneOf([
        A.RandomBrightnessContrast(),
        A.RandomGamma(),
        A.GaussNoise()
         ], p=0.3),
    # Geometric distortions
    A.OneOf([
        A.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
        # A.ElasticTransform(alpha=3, sigma=150, alpha_affine=150),
        # A.ElasticTransform(),
        A.Affine(translate_percent=0.2, shear=30, mode=cv2.BORDER_CONSTANT),
        A.GridDistortion(),
        A.OpticalDistortion(distort_limit=2, shift_limit=0.5),
        ], p=0.3),
    # Colour shifts / blur
    A.OneOf([
        A.RGBShift(r_shift_limit=40, g_shift_limit=40,  b_shift_limit=40),
        A.ColorJitter(hue=0.1),
        A.Blur(blur_limit=3)
        ], p=0.3),
    # Final step: rescale the uint8 image from 0..255 to float in [0, 1].
    A.ToFloat(max_value=255)
], p=1,
# mask1 / mask2 are registered as extra mask targets so that both masks go
# through the same transform call as the image.
# NOTE(review): the 'image': 'image' entry looks redundant - confirm against
# the installed albumentations version.
additional_targets={'image': 'image', 'mask1': 'mask', 'mask2':'mask'})

# Validation pipeline: no augmentation, only the same 0..255 -> [0, 1] rescale.
AUGMENTATIONS_VAL = A.Compose([
    A.ToFloat(max_value=255)
], p=1,
additional_targets={'image': 'image', 'mask1': 'mask', 'mask2':'mask'})

### Testing data generator

In [None]:
# Preview the 49 patches (256*256 tiles, 50% overlap) of the first tissue
# image exactly as the generator serves them without augmentation.
a = DataGenerator(train_images, train_masks, train_masks_cat, batch_size=49, augmentations=None, shuffle=False)
images, masks = a[0]

max_images = 49
grid_width = 7
grid_height = max_images // grid_width
fig, axs = plt.subplots(grid_height, grid_width, figsize=(grid_width*2, grid_height*2))

# Without augmentation the targets are [binary, categorical].
for i, (im, mask1, mask2) in enumerate(zip(images, masks[0], masks[1])):
    row, col = divmod(i, grid_width)
    ax = axs[row, col]
    ax.imshow(im)
    # ax.imshow(mask1.squeeze(), alpha=0.4)
    ax.imshow(mask2, alpha=0.4)
    ax.axis('off')

# Shapes of the last sample's two masks
print(mask1.shape)
print(mask2.shape)
plt.tight_layout()
plt.show()

In [None]:
# Preview the same 49 patches after the training augmentations.
a = DataGenerator(train_images, train_masks, train_masks_cat, batch_size=49, augmentations=AUGMENTATIONS_TRAIN, shuffle=False)
images, masks = a[0]

max_images = 49
grid_width = 7
grid_height = max_images // grid_width
fig, axs = plt.subplots(grid_height, grid_width, figsize=(grid_width*2, grid_height*2))

# With augmentation the targets are 5 binary copies followed by 5 categorical
# copies, so index 5 is the first categorical mask.
for i, (im, mask1, mask2) in enumerate(zip(images, masks[0], masks[5])):
    row, col = divmod(i, grid_width)
    ax = axs[row, col]
    ax.imshow(im)
    # ax.imshow(mask1.squeeze(), alpha=0.4)
    ax.imshow(mask2, alpha=0.4)
    ax.axis('off')

# Shapes of the last sample's two masks
print(mask1.shape)
print(mask2.shape)
plt.tight_layout()
plt.show()

## Defining the model

In [None]:
from keras.models import Model
from keras.applications import VGG19
from keras.layers import Input, Activation, UpSampling2D, concatenate, MaxPool2D
from keras.layers import Conv2D, BatchNormalization
from keras.regularizers import l2
from keras.initializers import he_normal

In [None]:
# https://github.com/hamidriasat/UNet-3-Plus/blob/unet3p_lits/models/unet3plus_deep_supervision.py

def conv_block(x, kernels, kernel_size=(3, 3), strides=(1, 1), padding='same',
               is_bn=True, is_relu=True, n=2):
    """Apply ``n`` Conv2D layers in sequence and return the resulting tensor.

    Every convolution has ``kernels`` filters, L2 weight regularisation
    (1e-4) and a He-normal initialiser with a fixed seed. Each one is followed
    by BatchNormalization when ``is_bn`` is set and by a ReLU activation when
    ``is_relu`` is set.
    """
    out = x
    for _ in range(n):
        conv_layer = Conv2D(
            filters=kernels,
            kernel_size=kernel_size,
            padding=padding,
            strides=strides,
            kernel_regularizer=l2(1e-4),
            kernel_initializer=he_normal(seed=5),
        )
        out = conv_layer(out)
        if is_bn:
            out = BatchNormalization()(out)
        if is_relu:
            out = Activation("relu")(out)

    return out

def unet3plus_deepsup(input_shape, deep_supervision=False):
    """Build a two-decoder UNet 3+ on an ImageNet-pretrained VGG19 encoder.

    Both decoders read the same encoder features. Decoder ``*_1`` ends in a
    1-channel sigmoid head ("binary_final"), decoder ``*_2`` in a 3-channel
    softmax head ("multi_final"). Every decoder stage concatenates five
    16-channel branches (one per scale) into an 80-channel feature map.

    Shape comments assume a 256*256*3 input and the standard Keras VGG19
    feature-map sizes (block1..block5 convs at 1, 1/2, 1/4, 1/8 and 1/16 of
    the input resolution with 64/128/256/512/512 channels).

    NOTE(review): the data generators feed images scaled to [0, 1]; ImageNet
    VGG19 weights are normally paired with the Keras VGG ``preprocess_input``
    step - confirm that this difference is intended.

    Args:
        input_shape: Shape of one input image, e.g. [256, 256, 3]. The pooling
            and upsampling factors used below go up to 16.
        deep_supervision: If True, auxiliary heads are added on the
            intermediate decoder stages and on the bottleneck.

    Returns:
        A Keras ``Model`` named 'UNet3Plus_DeepSup' (same name in both modes).
        With ``deep_supervision`` it has 10 outputs, in this order:
        binary_final, binary_sup1..4, multi_final, multi_sup1..4.
        Otherwise it has 2 outputs: binary_final, multi_final.
    """
    # filters = [64, 128, 256, 512, 1024]
    # Only filters[0] (channels per branch) and filters[4] (bottleneck) are read below.
    filters = [16, 32, 64, 128, 256]

    input_layer = Input(shape=input_shape, name="input_layer")  # 256*256*3

    """ Encoder"""
    skip_connections = []

    # VGG19 without its classifier head, wired to our own input tensor.
    model = VGG19(include_top=False, weights="imagenet", input_tensor=input_layer)
    # The last conv layer of VGG blocks 1-4 provides the skip connections.
    names = ["block1_conv2", "block2_conv2", "block3_conv4", "block4_conv4"]
    for name in names:
        skip_connections.append(model.get_layer(name).output)

    output = model.get_layer("block5_conv4").output

    # block 5
    # bottleneck layer
    e5 = conv_block(output, filters[4])  # 16*16*512 --> 16*16*256

    """ Decoder """
    cat_channels = filters[0]  # 16 channels per branch
    cat_blocks = len(filters)  # 5 branches per decoder stage
    upsample_channels = cat_blocks * cat_channels  # 80 channels after fusion

    """ d4_1 """
    e1_d4_1 = MaxPool2D(pool_size=(8, 8))(skip_connections[0])  # 256*256*64 --> 32*32*64
    e1_d4_1 = conv_block(e1_d4_1, cat_channels, n=1)  # 32*32*64 --> 32*32*16

    e2_d4_1 = MaxPool2D(pool_size=(4, 4))(skip_connections[1])  # 128*128*128 --> 32*32*128
    e2_d4_1 = conv_block(e2_d4_1, cat_channels, n=1)  # 32*32*128 --> 32*32*16

    e3_d4_1 = MaxPool2D(pool_size=(2, 2))(skip_connections[2])  # 64*64*256 --> 32*32*256
    e3_d4_1 = conv_block(e3_d4_1, cat_channels, n=1)  # 32*32*256 --> 32*32*16

    e4_d4_1 = conv_block(skip_connections[3], cat_channels, n=1)  # 32*32*512 --> 32*32*16

    e5_d4_1 = UpSampling2D(size=(2, 2), interpolation='bilinear')(e5)  # 16*16*256 --> 32*32*256
    e5_d4_1 = conv_block(e5_d4_1, cat_channels, n=1)  # 32*32*256 --> 32*32*16

    d4_1 = concatenate([e1_d4_1, e2_d4_1, e3_d4_1, e4_d4_1, e5_d4_1])
    d4_1 = conv_block(d4_1, upsample_channels, n=1)  # 32*32*80 --> 32*32*80

    """ d3_1 """
    e1_d3_1 = MaxPool2D(pool_size=(4, 4))(skip_connections[0])  # 256*256*64 --> 64*64*64
    e1_d3_1 = conv_block(e1_d3_1, cat_channels, n=1)  # 64*64*64 --> 64*64*16

    e2_d3_1 = MaxPool2D(pool_size=(2, 2))(skip_connections[1])  # 128*128*128 --> 64*64*128
    e2_d3_1 = conv_block(e2_d3_1, cat_channels, n=1)  # 64*64*128 --> 64*64*16

    e3_d3_1 = conv_block(skip_connections[2], cat_channels, n=1)  # 64*64*256 --> 64*64*16

    # Named e4_* but fed from the decoder stage d4_1, not from the encoder.
    e4_d3_1 = UpSampling2D(size=(2, 2), interpolation='bilinear')(d4_1)  # 32*32*80 --> 64*64*80
    e4_d3_1 = conv_block(e4_d3_1, cat_channels, n=1)  # 64*64*80 --> 64*64*16

    e5_d3_1 = UpSampling2D(size=(4, 4), interpolation='bilinear')(e5)  # 16*16*256 --> 64*64*256
    e5_d3_1 = conv_block(e5_d3_1, cat_channels, n=1)  # 64*64*256 --> 64*64*16

    d3_1 = concatenate([e1_d3_1, e2_d3_1, e3_d3_1, e4_d3_1, e5_d3_1])
    d3_1 = conv_block(d3_1, upsample_channels, n=1)  # 64*64*80 --> 64*64*80

    """ d2_1 """
    e1_d2_1 = MaxPool2D(pool_size=(2, 2))(skip_connections[0])  # 256*256*64 --> 128*128*64
    e1_d2_1 = conv_block(e1_d2_1, cat_channels, n=1)  # 128*128*64 --> 128*128*16

    e2_d2_1 = conv_block(skip_connections[1], cat_channels, n=1)  # 128*128*128 --> 128*128*16

    d3_d2_1 = UpSampling2D(size=(2, 2), interpolation='bilinear')(d3_1)  # 64*64*80 --> 128*128*80
    d3_d2_1 = conv_block(d3_d2_1, cat_channels, n=1)  # 128*128*80 --> 128*128*16

    d4_d2_1 = UpSampling2D(size=(4, 4), interpolation='bilinear')(d4_1)  # 32*32*80 --> 128*128*80
    d4_d2_1 = conv_block(d4_d2_1, cat_channels, n=1)  # 128*128*80 --> 128*128*16

    e5_d2_1 = UpSampling2D(size=(8, 8), interpolation='bilinear')(e5)  # 16*16*256 --> 128*128*256
    e5_d2_1 = conv_block(e5_d2_1, cat_channels, n=1)  # 128*128*256 --> 128*128*16

    d2_1 = concatenate([e1_d2_1, e2_d2_1, d3_d2_1, d4_d2_1, e5_d2_1])
    d2_1 = conv_block(d2_1, upsample_channels, n=1)  # 128*128*80 --> 128*128*80

    """ d1_1 """
    e1_d1_1 = conv_block(skip_connections[0], cat_channels, n=1)  # 256*256*64 --> 256*256*16

    d2_d1_1 = UpSampling2D(size=(2, 2), interpolation='bilinear')(d2_1)  # 128*128*80 --> 256*256*80
    d2_d1_1 = conv_block(d2_d1_1, cat_channels, n=1)  # 256*256*80 --> 256*256*16

    d3_d1_1 = UpSampling2D(size=(4, 4), interpolation='bilinear')(d3_1)  # 64*64*80 --> 256*256*80
    d3_d1_1 = conv_block(d3_d1_1, cat_channels, n=1)  # 256*256*80 --> 256*256*16

    d4_d1_1 = UpSampling2D(size=(8, 8), interpolation='bilinear')(d4_1)  # 32*32*80 --> 256*256*80
    d4_d1_1 = conv_block(d4_d1_1, cat_channels, n=1)  # 256*256*80 --> 256*256*16

    e5_d1_1 = UpSampling2D(size=(16, 16), interpolation='bilinear')(e5)  # 16*16*256 --> 256*256*256
    e5_d1_1 = conv_block(e5_d1_1, cat_channels, n=1)  # 256*256*256 --> 256*256*16

    d1_1 = concatenate([e1_d1_1, d2_d1_1, d3_d1_1, d4_d1_1, e5_d1_1, ])
    d1_1 = conv_block(d1_1, upsample_channels, n=1)  # 256*256*80 --> 256*256*80

    # last layer does not have batch norm and relu
    d1_1 = conv_block(d1_1, 1, n=1, is_bn=False, is_relu=False)  # 256*256*80 --> 256*256*1
    d1_1 = Activation("sigmoid", name="binary_final")(d1_1)

    # Second decoder: same wiring as above with its own layers; it ends in the
    # 3-class softmax head.
    """ d4_2 """
    e1_d4_2 = MaxPool2D(pool_size=(8, 8))(skip_connections[0])  # 256*256*64 --> 32*32*64
    e1_d4_2 = conv_block(e1_d4_2, cat_channels, n=1)  # 32*32*64 --> 32*32*16

    e2_d4_2 = MaxPool2D(pool_size=(4, 4))(skip_connections[1])  # 128*128*128 --> 32*32*128
    e2_d4_2 = conv_block(e2_d4_2, cat_channels, n=1)  # 32*32*128 --> 32*32*16

    e3_d4_2 = MaxPool2D(pool_size=(2, 2))(skip_connections[2])  # 64*64*256 --> 32*32*256
    e3_d4_2 = conv_block(e3_d4_2, cat_channels, n=1)  # 32*32*256 --> 32*32*16

    e4_d4_2 = conv_block(skip_connections[3], cat_channels, n=1)  # 32*32*512 --> 32*32*16

    e5_d4_2 = UpSampling2D(size=(2, 2), interpolation='bilinear')(e5)  # 16*16*256 --> 32*32*256
    e5_d4_2 = conv_block(e5_d4_2, cat_channels, n=1)  # 32*32*256 --> 32*32*16

    d4_2 = concatenate([e1_d4_2, e2_d4_2, e3_d4_2, e4_d4_2, e5_d4_2])
    d4_2 = conv_block(d4_2, upsample_channels, n=1)  # 32*32*80 --> 32*32*80

    """ d3_2 """
    e1_d3_2 = MaxPool2D(pool_size=(4, 4))(skip_connections[0])  # 256*256*64 --> 64*64*64
    e1_d3_2 = conv_block(e1_d3_2, cat_channels, n=1)  # 64*64*64 --> 64*64*16

    e2_d3_2 = MaxPool2D(pool_size=(2, 2))(skip_connections[1])  # 128*128*128 --> 64*64*128
    e2_d3_2 = conv_block(e2_d3_2, cat_channels, n=1)  # 64*64*128 --> 64*64*16

    e3_d3_2 = conv_block(skip_connections[2], cat_channels, n=1)  # 64*64*256 --> 64*64*16

    # Named e4_* but fed from the decoder stage d4_2, not from the encoder.
    e4_d3_2 = UpSampling2D(size=(2, 2), interpolation='bilinear')(d4_2)  # 32*32*80 --> 64*64*80
    e4_d3_2 = conv_block(e4_d3_2, cat_channels, n=1)  # 64*64*80 --> 64*64*16

    e5_d3_2 = UpSampling2D(size=(4, 4), interpolation='bilinear')(e5)  # 16*16*256 --> 64*64*256
    e5_d3_2 = conv_block(e5_d3_2, cat_channels, n=1)  # 64*64*256 --> 64*64*16

    d3_2 = concatenate([e1_d3_2, e2_d3_2, e3_d3_2, e4_d3_2, e5_d3_2])
    d3_2 = conv_block(d3_2, upsample_channels, n=1)  # 64*64*80 --> 64*64*80

    """ d2_2 """
    e1_d2_2 = MaxPool2D(pool_size=(2, 2))(skip_connections[0])  # 256*256*64 --> 128*128*64
    e1_d2_2 = conv_block(e1_d2_2, cat_channels, n=1)  # 128*128*64 --> 128*128*16

    e2_d2_2 = conv_block(skip_connections[1], cat_channels, n=1)  # 128*128*128 --> 128*128*16

    d3_d2_2 = UpSampling2D(size=(2, 2), interpolation='bilinear')(d3_2)  # 64*64*80 --> 128*128*80
    d3_d2_2 = conv_block(d3_d2_2, cat_channels, n=1)  # 128*128*80 --> 128*128*16

    d4_d2_2 = UpSampling2D(size=(4, 4), interpolation='bilinear')(d4_2)  # 32*32*80 --> 128*128*80
    d4_d2_2 = conv_block(d4_d2_2, cat_channels, n=1)  # 128*128*80 --> 128*128*16

    e5_d2_2 = UpSampling2D(size=(8, 8), interpolation='bilinear')(e5)  # 16*16*256 --> 128*128*256
    e5_d2_2 = conv_block(e5_d2_2, cat_channels, n=1)  # 128*128*256 --> 128*128*16

    d2_2 = concatenate([e1_d2_2, e2_d2_2, d3_d2_2, d4_d2_2, e5_d2_2])
    d2_2 = conv_block(d2_2, upsample_channels, n=1)  # 128*128*80 --> 128*128*80

    """ d1_2 """
    e1_d1_2 = conv_block(skip_connections[0], cat_channels, n=1)  # 256*256*64 --> 256*256*16

    d2_d1_2 = UpSampling2D(size=(2, 2), interpolation='bilinear')(d2_2)  # 128*128*80 --> 256*256*80
    d2_d1_2 = conv_block(d2_d1_2, cat_channels, n=1)  # 256*256*80 --> 256*256*16

    d3_d1_2 = UpSampling2D(size=(4, 4), interpolation='bilinear')(d3_2)  # 64*64*80 --> 256*256*80
    d3_d1_2 = conv_block(d3_d1_2, cat_channels, n=1)  # 256*256*80 --> 256*256*16

    d4_d1_2 = UpSampling2D(size=(8, 8), interpolation='bilinear')(d4_2)  # 32*32*80 --> 256*256*80
    d4_d1_2 = conv_block(d4_d1_2, cat_channels, n=1)  # 256*256*80 --> 256*256*16

    e5_d1_2 = UpSampling2D(size=(16, 16), interpolation='bilinear')(e5)  # 16*16*256 --> 256*256*256
    e5_d1_2 = conv_block(e5_d1_2, cat_channels, n=1)  # 256*256*256 --> 256*256*16

    d1_2 = concatenate([e1_d1_2, d2_d1_2, d3_d1_2, d4_d1_2, e5_d1_2, ])
    d1_2 = conv_block(d1_2, upsample_channels, n=1)  # 256*256*80 --> 256*256*80

    # last layer does not have batch norm and relu
    d1_2 = conv_block(d1_2, 3, n=1, is_bn=False, is_relu=False)  # 256*256*80 --> 256*256*3
    d1_2 = Activation("softmax", name="multi_final")(d1_2)

    """ Deep Supervision Part"""
    if deep_supervision:
        # Binary supervision: project each intermediate stage (and the
        # bottleneck) to 1 channel, upsample to input resolution, apply sigmoid.
        d2_1 = conv_block(d2_1, 1, n=1, is_bn=False, is_relu=False)
        d3_1 = conv_block(d3_1, 1, n=1, is_bn=False, is_relu=False)
        d4_1 = conv_block(d4_1, 1, n=1, is_bn=False, is_relu=False)
        e5_1 = conv_block(e5, 1, n=1, is_bn=False, is_relu=False)

        # d1_1 = no need for up sampling
        d2_1 = UpSampling2D(size=(2, 2), interpolation='bilinear')(d2_1)
        d3_1 = UpSampling2D(size=(4, 4), interpolation='bilinear')(d3_1)
        d4_1 = UpSampling2D(size=(8, 8), interpolation='bilinear')(d4_1)
        e5_1 = UpSampling2D(size=(16, 16), interpolation='bilinear')(e5_1)

        d2_1 = Activation("sigmoid", name="binary_sup1")(d2_1)
        d3_1 = Activation("sigmoid", name="binary_sup2")(d3_1)
        d4_1 = Activation("sigmoid", name="binary_sup3")(d4_1)
        e5_1 = Activation("sigmoid", name="binary_sup4")(e5_1)

        # Multi-class supervision: same scheme with 3 channels and softmax.
        d2_2 = conv_block(d2_2, 3, n=1, is_bn=False, is_relu=False)
        d3_2 = conv_block(d3_2, 3, n=1, is_bn=False, is_relu=False)
        d4_2 = conv_block(d4_2, 3, n=1, is_bn=False, is_relu=False)
        e5_2 = conv_block(e5, 3, n=1, is_bn=False, is_relu=False)

        # d1_2 = no need for up sampling
        d2_2 = UpSampling2D(size=(2, 2), interpolation='bilinear')(d2_2)
        d3_2 = UpSampling2D(size=(4, 4), interpolation='bilinear')(d3_2)
        d4_2 = UpSampling2D(size=(8, 8), interpolation='bilinear')(d4_2)
        e5_2 = UpSampling2D(size=(16, 16), interpolation='bilinear')(e5_2)

        d2_2 = Activation("softmax", name="multi_sup1")(d2_2)
        d3_2 = Activation("softmax", name="multi_sup2")(d3_2)
        d4_2 = Activation("softmax", name="multi_sup3")(d4_2)
        e5_2 = Activation("softmax", name="multi_sup4")(e5_2)

    # Output order matters: list-style loss weights and the data generator's
    # target list are matched to the outputs by position.
    if deep_supervision:
        return Model(inputs=input_layer, outputs=[d1_1, d2_1, d3_1, d4_1, e5_1,
                                                  d1_2, d2_2, d3_2, d4_2, e5_2], name='UNet3Plus_DeepSup')
    else:
        return Model(inputs=input_layer, outputs=[d1_1, d1_2], name='UNet3Plus_DeepSup')

In [None]:
# Must match the 256x256 RGB patches produced by the data generator.
INPUT_SHAPE = [256, 256, 3]

# deep_supervision=True -> 10 outputs (5 binary heads followed by 5 multi-class heads).
model = unet3plus_deepsup(INPUT_SHAPE, deep_supervision=True)
model.summary()

In [None]:
from keras.utils import plot_model

In [None]:
# Write the architecture diagram (with tensor shapes) to a PNG file in the
# current working directory.
plot_model(model, show_shapes=True, dpi=330, to_file="U-Net3+_DD.png")

### Defining loss functions for the binary & multi-class segmentation tasks

In [None]:
import keras.backend as K
from keras.optimizers import Adam

In [None]:
# https://github.com/yingkaisha/keras-unet-collection
!pip install keras_unet_collection

In [None]:
from keras_unet_collection import losses

def binary_loss(y_true, y_pred):
    """Loss for the binary heads: focal Tversky (alpha=0.5, gamma=4/3) plus IoU."""
    tversky_term = losses.focal_tversky(y_true, y_pred, alpha=0.5, gamma=4/3)
    iou_term = losses.iou_seg(y_true, y_pred)

    # An additional MS-SSIM term is disabled (marked "(x)" by the author):
    # loss_ssim = losses.ms_ssim(y_true, y_pred, max_val=1.0, filter_size=4)
    total = tversky_term + iou_term
    return total

def multi_loss(y_true, y_pred):
    """Loss for the 3-class heads: focal Tversky (alpha=0.7, gamma=4/3) plus IoU."""
    tversky_term = losses.focal_tversky(y_true, y_pred, alpha=0.7, gamma=4/3)
    iou_term = losses.iou_seg(y_true, y_pred)

    # An additional MS-SSIM term is disabled (marked "(x)" by the author):
    # loss_ssim = losses.ms_ssim(y_true, y_pred, max_val=1.0, filter_size=4)
    total = tversky_term + iou_term
    return total

In [None]:
def dice_coef(y_true, y_pred, smooth=1):
    """Return the smoothed Dice coefficient over all elements of the batch.

    Both tensors are flattened first, so the score is computed over the whole
    batch (and all channels) at once. ``smooth`` is added to numerator and
    denominator.
    """
    truth = K.flatten(y_true)
    prediction = K.flatten(y_pred)
    overlap = K.sum(truth * prediction)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth) + K.sum(prediction) + smooth
    return numerator / denominator

### Compiling the model

In [None]:
# Output names per task, in the order in which the model returns them.
binary_heads = ["binary_final", "binary_sup1", "binary_sup2", "binary_sup3", "binary_sup4"]
multi_heads = ["multi_final", "multi_sup1", "multi_sup2", "multi_sup3", "multi_sup4"]

# One loss per head: binary heads first, then multi-class heads.
head_losses = {name: binary_loss for name in binary_heads}
head_losses.update({name: multi_loss for name in multi_heads})

# The Dice coefficient is tracked for every head.
head_metrics = {name: dice_coef for name in binary_heads + multi_heads}

# loss_weights follow the output order: weight 1.0 for each final head and
# 0.25 for each deep-supervision head.
model.compile(optimizer=Adam(1e-4),
              loss=head_losses,
              loss_weights=[1.0, 0.25, 0.25, 0.25, 0.25, 1.0, 0.25, 0.25, 0.25, 0.25],
              metrics=head_metrics)

## Train the model

### Defining some useful callbacks 

In [None]:
from keras.callbacks import ModelCheckpoint, CSVLogger, ReduceLROnPlateau, EarlyStopping, Callback, LearningRateScheduler
# from tensorflow.keras.optimizers.schedules import CosineDecay
import datetime

from skimage import filters
from scipy.ndimage import measurements
from skimage.segmentation import watershed, mark_boundaries

# Visualize training 
class loss_history(Callback):
    """Show the current predictions for one training patch before every epoch.

    Five panels are drawn: the tissue patch, its binary ground truth, the
    thresholded nuclei prediction, the thresholded marker prediction and the
    marker-controlled watershed result outlined on the patch.

    The callback reads the module-level ``train_images`` / ``train_masks``
    arrays. It expects the attached model to return at least six outputs, with
    the binary head at index 0 and the 3-class head at index 5.

    Args:
        x: Index of the training patch to visualise.
    """

    def __init__(self, x=4):
        # Initialise the Keras Callback base class before adding own state.
        super().__init__()
        self.x = x

    def on_epoch_begin(self, epoch, logs=None):
        """Predict on the chosen patch and plot the intermediate results."""
        # Imported locally so the block is self-contained. ``label`` is taken
        # from scipy.ndimage directly because the
        # scipy.ndimage.measurements namespace is deprecated.
        from scipy.ndimage import label

        fig, ax = plt.subplots(1, 5, figsize=(18, 12))
        for axi in ax.ravel():
            axi.set_axis_off()

        ax[0].imshow(train_images[self.x])
        ax[0].set_title("Tissue Image")

        ax[1].imshow(train_masks[self.x], cmap="gray")
        ax[1].set_title("Ground Truth")

        # Same 0..255 -> [0, 1] scaling that the generators apply.
        model_sample_input = train_images[self.x].astype("float32") / 255.
        pred = self.model.predict(np.expand_dims(model_sample_input, axis=0), verbose=0)
        preds_train1 = pred[0]  # binary head
        preds_train2 = pred[5]  # 3-class head
        preds1 = np.squeeze(preds_train1[0]) >= 0.5
        ax[2].imshow(preds1, cmap="gray")
        ax[2].set_title("Nuclei prediction")
        # Channel 2 minus channel 1 - presumably nucleus interior minus border
        # probability; confirm against the class encoding of the masks.
        preds2 = (preds_train2[0][:, :, 2] - preds_train2[0][:, :, 1]) >= 0.5
        ax[3].imshow(preds2, cmap="gray")
        ax[3].set_title("Nuclei marker prediction")

        # Marker-controlled watershed: the seeds are the connected components
        # of (nuclei AND marker); flooding is restricted to the nuclei mask.
        grad = filters.scharr(preds1)
        marker = preds1 * preds2
        marker = label(marker)[0]
        proced_pred = watershed(grad, marker, mask=preds1)
        ax[4].imshow(mark_boundaries(train_images[self.x], proced_pred, color=(0, 0, 1)))
        ax[4].set_title("Result")

        plt.tight_layout()
        plt.show()

# lr_scheduler = CosineDecay(
#     1e-4, 50, alpha=0.0, name=None
# )

# Output locations for this run (fold 1): CSV training log and model checkpoint.
create_path("logs/Unet3+")
# NOTE: csv_log is created here but is commented out of the `callbacks` list
# below, so no CSV log is written unless it is re-enabled there.
csv_log = CSVLogger('logs/Unet3+/Unet3+_fold1_log00.csv', separator=',')
# A timestamp keeps repeated runs from overwriting each other's checkpoints.
model_name = f"Unet3+_fold1_v00_{datetime.datetime.now().strftime('%Y%m%d-%H%M%S')}"
create_path("models/Unet3+")
path_to_save_model = "models/Unet3+/" + model_name + ".h5"
# save_best_only=True: the file is only overwritten when the monitored
# quantity improves (no `monitor` is passed, so the Keras default applies).
checkpointer = ModelCheckpoint(path_to_save_model, verbose=1, save_best_only=True)
# Divide the learning rate by 10 after 10 epochs without val_loss improvement.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, min_lr=1e-7, verbose=1)
# reduce_lr = LearningRateScheduler(schedule=lr_scheduler)
# Stop after 20 epochs without val_loss improvement. restore_best_weights=False
# leaves the in-memory model at its last-epoch weights; the best weights are
# the ones in the checkpoint file.
early_stop   = EarlyStopping(monitor="val_loss", patience=20, restore_best_weights=False)

print(model_name)

In [None]:
# Callbacks passed to model.fit; CSV logging is currently disabled.
callbacks = [
             loss_history(), 
             checkpointer, 
             reduce_lr, 
            #  csv_log, 
             early_stop
             ]

In [None]:
batch_size = 8
epochs = 200  # upper bound; EarlyStopping may end the run sooner

# Generators
# Both generators pass an augmentation pipeline, so every batch carries the
# 10 targets (5 binary + 5 categorical) that match the model's 10 outputs.
training_generator = DataGenerator(train_images, train_masks, train_masks_cat, augmentations=AUGMENTATIONS_TRAIN, batch_size=batch_size)
# AUGMENTATIONS_VAL only rescales to [0, 1]. shuffle keeps its default (True)
# for this generator as well.
validation_generator = DataGenerator(val_images, val_masks, val_masks_cat, augmentations=AUGMENTATIONS_VAL, batch_size=batch_size)

history = model.fit(training_generator,
                    validation_data=validation_generator,                        
                    epochs=epochs,
                    verbose=1,
                    callbacks=callbacks)