# Train U-Net++ on the MoNuSeg dataset

In [None]:
!nvidia-smi

In [None]:
!cat /proc/cpuinfo

In [None]:
!python --version

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
%cd drive/MyDrive/nuclei_segmentation

In [None]:
# https://github.com/dovahcrow/patchify.py
!pip install patchify

## Make the validation set

In [None]:
import os
import glob
import random
import shutil
import numpy as np

In [None]:
def create_path(path):
    """Create directory *path* (and any missing parents) if it does not exist.

    ``os.makedirs(..., exist_ok=True)`` makes the call idempotent without the
    check-then-create race of an ``os.path.exists`` guard.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: if *path* already exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)

In [None]:
# Root folders of the stain-normalized MoNuSeg training split and of the
# validation split that `make_validation_set` carves out of it. Each root holds
# tissue_images/, binary_masks/, instance_masks/ and modified_masks/.
train_dir = "dataset/monuseg/stain_normalized/train"
val_dir = "dataset/monuseg/stain_normalized/validation"

In [None]:
# Sub-folders present in every split root.
_SPLIT_SUBDIRS = ("tissue_images", "binary_masks", "instance_masks", "modified_masks")
# (sub-folder, extension) of the files that accompany each tissue image.
_MASK_FILES = (("binary_masks", ".png"), ("instance_masks", ".npy"), ("modified_masks", ".png"))


def _sample_files(image_path):
    """Return ``(path, sub_folder)`` for a tissue image and each of its mask files."""
    root = os.path.dirname(os.path.dirname(image_path))
    stem = os.path.splitext(os.path.basename(image_path))[0]
    files = [(image_path, "tissue_images")]
    for sub_dir, ext in _MASK_FILES:
        files.append((os.path.join(root, sub_dir, stem + ext), sub_dir))
    return files


def _move_sample(image_path, dest_root, strict):
    """Move a tissue image and its masks into the matching sub-folders of *dest_root*.

    With ``strict=False`` a file that cannot be moved is reported and skipped,
    and the remaining files of the sample are still moved. With ``strict=True``
    the error propagates.
    """
    for src, sub_dir in _sample_files(image_path):
        try:
            shutil.move(src, os.path.join(dest_root, sub_dir))
        except OSError as err:  # shutil.Error is a subclass of OSError
            if strict:
                raise
            print(f"Warning: could not move {src!r}: {err}")


def make_validation_set(val_dir, train_dir, val_size=6, seed=42, fold=None):
    """Move a reproducible subset of the training samples into *val_dir*.

    Any samples already in *val_dir* are first moved back to *train_dir*, so
    calling this repeatedly with different folds never loses data.

    Args:
        val_dir: Root folder of the validation split (created if missing).
        train_dir: Root folder of the training split.
        val_size: Number of tissue images to hold out.
        seed: Seed for the shuffle (and for ``random.sample`` when ``fold`` is None).
        fold: 1-based fold index. Fold k takes images ``[(k-1)*val_size, k*val_size)``
            of the seeded shuffle. None draws ``val_size`` images at random instead.

    Returns:
        The basenames of the tissue images moved into the validation split.

    Raises:
        ValueError: if ``fold`` is < 1 or selects no image.
    """
    if fold is not None and fold < 1:
        raise ValueError(f"fold must be >= 1 (folds are 1-indexed), got {fold}")

    for sub_dir in _SPLIT_SUBDIRS:
        os.makedirs(os.path.join(val_dir, sub_dir), exist_ok=True)

    # Return a previous validation split to the training set first.
    for image_path in sorted(glob.glob(os.path.join(val_dir, "tissue_images", "*"))):
        _move_sample(image_path, train_dir, strict=False)

    images_lst = sorted(glob.glob(os.path.join(train_dir, "tissue_images", "*")))
    # Seeded shuffle of a sorted list -> the same fold assignment on every run.
    np.random.seed(seed)
    np.random.shuffle(images_lst)
    if fold is None:
        random.seed(seed)
        val_lst = random.sample(images_lst, val_size)
    else:
        start = (fold - 1) * val_size
        if start >= len(images_lst):
            raise ValueError(
                f"fold {fold} with val_size={val_size} is out of range for "
                f"{len(images_lst)} images")
        val_lst = images_lst[start:fold * val_size]

    for image_path in val_lst:
        _move_sample(image_path, val_dir, strict=True)

    val_names = [os.path.basename(i) for i in val_lst]
    print(f"Validation list: {val_names}")
    return val_names

In [None]:
# Hold out fold 1 (the first 6 images of the seeded shuffle) as the validation
# split, after first returning any previously held-out images to train_dir.
make_validation_set(val_dir, train_dir, val_size=6, fold=1)

## Read the tissue images and ground-truth masks

In [None]:
import cv2
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from patchify import patchify, unpatchify

In [None]:
print(cv2.__version__)

In [None]:
# Per-split folders of tissue images, binary masks and modified (multi-class)
# masks. NOTE: this rebinds train_dir / val_dir from the split roots used above
# to their tissue_images sub-folders.
train_dir = "dataset/monuseg/stain_normalized/train/tissue_images"
train_maskdir = "dataset/monuseg/stain_normalized/train/binary_masks"
train_mask2dir = "dataset/monuseg/stain_normalized/train/modified_masks"

val_dir   = "dataset/monuseg/stain_normalized/validation/tissue_images"
val_maskdir = "dataset/monuseg/stain_normalized/validation/binary_masks"
val_mask2dir = "dataset/monuseg/stain_normalized/validation/modified_masks"

In [None]:
# Every tissue image is resized to W x H and tiled into 256x256 RGB patches
# with a 128-px stride (50% overlap): (1024 - 256) / 128 + 1 = 7 -> 7x7 patches.
W = 1024
H = 1024

patch_size = (256, 256, 3)
all_img_patches = []
train_image_paths = sorted(glob.glob(os.path.join(train_dir, "*")))
for x in tqdm(train_image_paths):
    single_img = cv2.imread(x, cv2.IMREAD_COLOR)
    if single_img is None:
        raise OSError(f"Could not read image: {x}")
    single_img = cv2.cvtColor(single_img, cv2.COLOR_BGR2RGB)
    single_img = cv2.resize(single_img, (W, H), interpolation=cv2.INTER_LINEAR)
    single_img_patches = patchify(single_img, patch_size=patch_size, step=128)
    # Flatten the patch grid row-major (same order as a nested row/col loop).
    # Unlike np.squeeze, this cannot accidentally drop a grid axis of size 1.
    all_img_patches.extend(single_img_patches.reshape(-1, *patch_size))

train_images = np.array(all_img_patches)

In [None]:
train_images.shape

In [None]:
# Same tiling as for the training images (uses W, H and patch_size defined above).
all_img_patches = []
val_image_paths = sorted(glob.glob(os.path.join(val_dir, "*")))
for x in tqdm(val_image_paths):
    single_img = cv2.imread(x, cv2.IMREAD_COLOR)
    if single_img is None:
        raise OSError(f"Could not read image: {x}")
    single_img = cv2.cvtColor(single_img, cv2.COLOR_BGR2RGB)
    single_img = cv2.resize(single_img, (W, H), interpolation=cv2.INTER_LINEAR)
    single_img_patches = patchify(single_img, patch_size=patch_size, step=128)
    # Flatten the patch grid row-major; safer than np.squeeze + nested loops.
    all_img_patches.extend(single_img_patches.reshape(-1, *patch_size))

val_images = np.array(all_img_patches)

In [None]:
val_images.shape

In [None]:
# Binary masks: same 256x256 / stride-128 tiling as the images, so patch k of
# train_masks lines up with patch k of train_images (both lists are sorted).
all_mask_patches = []
train_mask_paths = sorted(glob.glob(os.path.join(train_maskdir, "*")))
for x in tqdm(train_mask_paths):
    single_mask = cv2.imread(x, cv2.IMREAD_GRAYSCALE)
    if single_mask is None:
        raise OSError(f"Could not read mask: {x}")
    # nearest-neighbour resizing keeps only the label values already in the mask
    single_mask = cv2.resize(single_mask, (W, H), interpolation=cv2.INTER_NEAREST)
    single_mask_patches = patchify(single_mask, patch_size=(256, 256), step=128)
    all_mask_patches.extend(single_mask_patches.reshape(-1, 256, 256))

train_masks = np.array(all_mask_patches)

In [None]:
train_masks.shape

In [None]:
# Validation binary masks, tiled exactly like the validation images.
all_mask_patches = []
val_mask_paths = sorted(glob.glob(os.path.join(val_maskdir, "*")))
for x in tqdm(val_mask_paths):
    single_mask = cv2.imread(x, cv2.IMREAD_GRAYSCALE)
    if single_mask is None:
        raise OSError(f"Could not read mask: {x}")
    # nearest-neighbour resizing keeps only the label values already in the mask
    single_mask = cv2.resize(single_mask, (W, H), interpolation=cv2.INTER_NEAREST)
    single_mask_patches = patchify(single_mask, patch_size=(256, 256), step=128)
    all_mask_patches.extend(single_mask_patches.reshape(-1, 256, 256))

val_masks = np.array(all_mask_patches)

In [None]:
val_masks.shape

In [None]:
# Modified (multi-class) training masks, tiled exactly like the training images.
all_mask_patches = []
train_mask2_paths = sorted(glob.glob(os.path.join(train_mask2dir, "*")))
for x in tqdm(train_mask2_paths):
    single_mask = cv2.imread(x, cv2.IMREAD_GRAYSCALE)
    if single_mask is None:
        raise OSError(f"Could not read mask: {x}")
    # nearest-neighbour resizing keeps only the label values already in the mask
    single_mask = cv2.resize(single_mask, (W, H), interpolation=cv2.INTER_NEAREST)
    single_mask_patches = patchify(single_mask, patch_size=(256, 256), step=128)
    all_mask_patches.extend(single_mask_patches.reshape(-1, 256, 256))

train_masks2 = np.array(all_mask_patches)

In [None]:
train_masks2.shape

In [None]:
# Modified (multi-class) validation masks, tiled exactly like the validation images.
all_mask_patches = []
val_mask2_paths = sorted(glob.glob(os.path.join(val_mask2dir, "*")))
for x in tqdm(val_mask2_paths):
    single_mask = cv2.imread(x, cv2.IMREAD_GRAYSCALE)
    if single_mask is None:
        raise OSError(f"Could not read mask: {x}")
    # nearest-neighbour resizing keeps only the label values already in the mask
    single_mask = cv2.resize(single_mask, (W, H), interpolation=cv2.INTER_NEAREST)
    single_mask_patches = patchify(single_mask, patch_size=(256, 256), step=128)
    all_mask_patches.extend(single_mask_patches.reshape(-1, 256, 256))

val_masks2 = np.array(all_mask_patches)

In [None]:
val_masks2.shape

In [None]:
# sanity check: show a random training patch next to both of its masks
rnd = np.random.randint(len(train_images))

fig, ax = plt.subplots(1, 3, figsize=(12, 6))
for axi in ax.ravel():
    axi.set_axis_off()

panels = [
    (train_images[rnd], "Tissue Image"),
    (train_masks[rnd], "Mask"),
    (train_masks2[rnd], "Mask2"),
]
for axi, (img, title) in zip(ax, panels):
    axi.imshow(img)
    axi.set_title(title)

plt.tight_layout()
plt.show()

## One-hot encode the modified masks

In [None]:
# train masks: map the grey levels of the modified masks to class ids
# 0..n_classes-1 and one-hot encode them.
from sklearn.preprocessing import LabelEncoder
from keras.utils import to_categorical

n_classes = 3
label_encoder = LabelEncoder()
# Take the shape from the array that is actually encoded (train_masks2).
n, h, w = train_masks2.shape
train_masks_reshaped = train_masks2.reshape(-1,)
train_masks_reshaped_encoded = label_encoder.fit_transform(train_masks_reshaped)
# Fewer values than n_classes would silently shift class meanings; more would
# overflow to_categorical.
if len(label_encoder.classes_) != n_classes:
    raise ValueError(
        f"Expected {n_classes} label values in the modified masks, "
        f"found {len(label_encoder.classes_)}: {label_encoder.classes_}")
train_masks_encoded_original_shape = train_masks_reshaped_encoded.reshape(n, h, w)

train_masks_cat = to_categorical(train_masks_encoded_original_shape, num_classes=n_classes)
train_masks_cat.shape

In [None]:
# val masks: reuse `label_encoder` fitted on the training masks in the previous
# cell. Re-fitting on validation data would assign different class ids whenever
# a grey level is missing from the validation masks.
from sklearn.preprocessing import LabelEncoder
from keras.utils import to_categorical

n, h, w = val_masks2.shape
val_masks_reshaped = val_masks2.reshape(-1,)
# transform (not fit_transform) raises ValueError on grey levels unseen in training
val_masks_reshaped_encoded = label_encoder.transform(val_masks_reshaped)
val_masks_encoded_original_shape = val_masks_reshaped_encoded.reshape(n, h, w)

n_classes = 3
val_masks_cat = to_categorical(val_masks_encoded_original_shape, num_classes=n_classes)
val_masks_cat.shape

## Data augmentation using the albumentations library

In [None]:
import albumentations as A
from keras.utils import Sequence

In [None]:
print(A.__version__)

In [None]:
class DataGenerator(Sequence):
    """Keras ``Sequence`` yielding ``(images, targets)`` batches for the dual-head model.

    ``targets`` is a list holding ``n_copies`` references to the binary-mask
    batch followed by ``n_copies`` references to the one-hot mask batch, i.e.
    one entry per model output. The default of 4 matches the deep-supervision
    U-Net++ built in this notebook (3 supervision heads + 1 final head per
    task). Use ``n_copies=1`` for a model with one output per task. Both the
    augmented and the un-augmented paths return the same target layout.

    Args:
        images: indexable of RGB patches, each (img_size, img_size, n_channels),
            presumably uint8 in [0, 255].
        masks: indexable of binary masks, each (img_size, img_size). They are
            divided by 255, so presumably stored as 0/255.
        masks_cat: indexable of one-hot masks, each (img_size, img_size, 3).
        augmentations: optional albumentations ``Compose`` accepting ``image``,
            ``mask1`` and ``mask2`` targets. When None, images are returned
            unchanged as uint8 (no scaling to [0, 1]).
        batch_size: patches per batch; the last batch may be smaller.
        img_size: spatial size of the square patches.
        n_channels: number of image channels.
        shuffle: reshuffle the sample order at the end of every epoch.
        n_copies: how many times each target batch is repeated in ``targets``.
    """

    def __init__(self, images, masks, masks_cat, augmentations=None, batch_size=8, img_size=256,
                 n_channels=3, shuffle=True, n_copies=4):
        super().__init__()
        self.batch_size = batch_size

        self.images = images
        self.masks = masks
        self.masks_cat = masks_cat

        self.img_size = img_size

        self.n_channels = n_channels
        self.shuffle = shuffle
        self.augment = augmentations
        self.n_copies = n_copies
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch (the last one may be partial)."""
        return int(np.ceil(len(self.images) / self.batch_size))

    def __getitem__(self, index):
        """Return batch *index* as ``(X, [y1] * n_copies + [y2] * n_copies)``."""
        start = index * self.batch_size
        # slicing clips at the end, so the last batch is simply shorter
        indices = self.indices[start:start + self.batch_size]

        X, (y1, y2) = self.data_generation(indices)

        if self.augment is not None:
            images, masks1, masks2 = [], [], []
            for image, mask1, mask2 in zip(X, y1, y2):
                # one call per sample so image and both masks get the same spatial transform
                augmented = self.augment(image=image, mask1=mask1, mask2=mask2)
                images.append(augmented['image'])
                masks1.append(augmented['mask1'])
                masks2.append(augmented['mask2'])
            X, y1, y2 = np.array(images), np.array(masks1), np.array(masks2)

        return X, [y1] * self.n_copies + [y2] * self.n_copies

    def on_epoch_end(self):
        """Reset (and optionally shuffle) the sample order."""
        self.indices = np.arange(len(self.images))
        if self.shuffle:
            np.random.shuffle(self.indices)

    def data_generation(self, indices):
        """Gather the samples at *indices*.

        Returns:
            ``(X, [y1, y2])``: X uint8 (n, img_size, img_size, n_channels), y1
            float32 (n, img_size, img_size, 1) = mask / 255, and y2 float32
            (n, img_size, img_size, 3).
        """
        n_samples = len(indices)
        X = np.empty((n_samples, self.img_size, self.img_size, self.n_channels), dtype=np.uint8)
        y1 = np.empty((n_samples, self.img_size, self.img_size, 1), dtype=np.float32)
        y2 = np.empty((n_samples, self.img_size, self.img_size, 3), dtype=np.float32)  # 3 classes (Nuclei, Border, Background)
        for n, i in enumerate(indices):
            X[n] = self.images[i]
            y1[n] = (self.masks[i] / 255.)[..., np.newaxis]
            y2[n] = self.masks_cat[i]

        return X, [y1, y2]

In [None]:
# Training-time augmentation. `mask1` (binary mask) and `mask2` (one-hot mask)
# are registered as 'mask' targets in additional_targets, so every spatial
# transform is applied to them together with the image. The last step,
# ToFloat(max_value=255), scales the image by 1/255.
# NOTE(review): several arguments below are version-sensitive in albumentations
# (ElasticTransform `alpha_affine`, Affine `mode`, OpticalDistortion
# `shift_limit` were deprecated/renamed in later releases). Check them against
# the version printed above. The 'image': 'image' entry duplicates a default
# target and is presumably redundant.
AUGMENTATIONS_TRAIN = A.Compose([
    # random rotation (applied to half of the samples)
    A.Rotate(limit=360, p=0.5),
    # horizontal or vertical flip (50%)
    A.OneOf([
        A.HorizontalFlip(),
        A.VerticalFlip(),
        ], p=0.5),
    # one photometric change (30%)
    A.OneOf([
        A.RandomBrightnessContrast(),
        A.RandomGamma(),
        A.GaussNoise()
         ], p=0.3),
    # one geometric distortion (30%)
    A.OneOf([
        A.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
        # A.ElasticTransform(alpha=3, sigma=150, alpha_affine=150),
        # A.ElasticTransform(),
        A.Affine(translate_percent=0.2, shear=30, mode=cv2.BORDER_CONSTANT),
        A.GridDistortion(),
        A.OpticalDistortion(distort_limit=2, shift_limit=0.5),
        ], p=0.3),
    # one colour shift or blur (30%)
    A.OneOf([
        A.RGBShift(r_shift_limit=40, g_shift_limit=40,  b_shift_limit=40),
        A.ColorJitter(hue=0.1),
        A.Blur(blur_limit=3)
        ], p=0.3),
    A.ToFloat(max_value=255)
], p=1,
additional_targets={'image': 'image', 'mask1': 'mask', 'mask2':'mask'})

# Validation: no augmentation, only the same ToFloat scaling as in training.
AUGMENTATIONS_VAL = A.Compose([
    A.ToFloat(max_value=255)
], p=1,
additional_targets={'image': 'image', 'mask1': 'mask', 'mask2':'mask'})

### Testing the data generator

In [None]:
# Preview one batch of 49 = 7x7 training patches. With shuffle=False these are
# the 256x256 tiles (50% overlap) of the first tissue image. The batch goes
# through the *training* augmentations, and the one-hot multi-class mask is
# overlaid on each patch.
max_images = 49
grid_width = 7
grid_height = int(np.ceil(max_images / grid_width))

a = DataGenerator(train_images, train_masks, train_masks_cat, batch_size=max_images,
                  augmentations=AUGMENTATIONS_TRAIN, shuffle=False)
images, masks = a[0]
# targets are [binary copies..., multi-class copies...]; take the first of each half
mask1_batch = masks[0]
mask2_batch = masks[len(masks) // 2]

fig, axs = plt.subplots(grid_height, grid_width, figsize=(grid_width*2, grid_height*2), squeeze=False)
for ax in axs.ravel():
    ax.axis('off')

for i, (im, mask2) in enumerate(zip(images, mask2_batch)):
    ax = axs[i // grid_width, i % grid_width]
    ax.imshow(im)
    ax.imshow(mask2, alpha=0.4)

# per-sample shapes of the binary and multi-class targets
print(mask1_batch.shape[1:])
print(mask2_batch.shape[1:])
plt.tight_layout()
plt.show()

## Defining the model

In [None]:
# import tensorflow as tf
# import keras
# print(tf.__version__)
# print(keras.__version__)

In [None]:
# https://github.com/yingkaisha/keras-unet-collection
!pip install keras_unet_collection

In [None]:
# -*- coding: utf-8 -*-
"""unet_plusplus_2d.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1khOW6qBIJD-HY7LDQxJScICMMDzFbCGG
"""

from keras_unet_collection.layer_utils import *
from keras_unet_collection.activations import GELU, Snake
from keras_unet_collection._backbone_zoo import backbone_zoo, bach_norm_checker
from keras_unet_collection._model_unet_2d import UNET_left, UNET_right

from keras.layers import Input
from keras.models import Model

import warnings

def _unet_plus_nested_decoder(X_nest_skip, filter_num, stack_num_up, activation, unpool,
                              batch_norm, name, skip_name_fmt, extra_name_fmt):
    '''
    Build one nested (U-net++) decoder in place.

    X_nest_skip: list of `len(filter_num)` lists. X_nest_skip[0] must already hold
                 the encoder feature maps (shallowest first). Levels 1.. are appended to.
    skip_name_fmt / extra_name_fmt: `str.format` templates taking (name, nest_lev-1, level).
                 They name the densely-skipped upsampling blocks and the extra upsampling
                 blocks that are added when the encoder provides too few levels.
    '''
    depth_ = len(filter_num)

    for nest_lev in range(1, depth_):

        # depth difference between the deepest nest skip and the current upsampling
        depth_lev = depth_-nest_lev

        # number of available encoded tensors
        depth_decode = len(X_nest_skip[nest_lev-1])

        # loop over individual upsampling levels
        for i in range(1, depth_decode):

            # collecting all previous outputs with the same feature-map size
            previous_skip = [X_nest_skip[previous_lev][i-1] for previous_lev in range(nest_lev)]

            # upsampling block that concatenates all available (same feature map size) down-/upsampling outputs
            X_nest_skip[nest_lev].append(
                UNET_right(X_nest_skip[nest_lev-1][i], previous_skip, filter_num[i-1],
                           stack_num=stack_num_up, activation=activation, unpool=unpool,
                           batch_norm=batch_norm, concat=True,
                           name=skip_name_fmt.format(name, nest_lev-1, i-1)))

        # not enough encoded tensors at this level: add upsampling blocks without skips
        if depth_decode < depth_lev+1:

            X = X_nest_skip[nest_lev-1][-1]

            for j in range(depth_lev-depth_decode+1):
                j_real = j + depth_decode
                X = UNET_right(X, None, filter_num[j_real-1],
                               stack_num=stack_num_up, activation=activation, unpool=unpool,
                               batch_norm=batch_norm, concat=True,
                               name=extra_name_fmt.format(name, nest_lev-1, j_real-1))
                X_nest_skip[nest_lev].append(X)


def unet_plus_2d_base(input_tensor, filter_num, stack_num_down=2, stack_num_up=2,
                      activation='ReLU', batch_norm=False, pool=True, unpool=True, deep_supervision=False, 
                      backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='xnet'):
    '''
    The base of a two-decoder U-net++ with an optional ImageNet-trained backbone.

    One encoder (plain conv stacks or a backbone, evaluated once) feeds two
    independent nested decoders that share its feature maps.

    unet_plus_2d_base(input_tensor, filter_num, stack_num_down=2, stack_num_up=2,
                      activation='ReLU', batch_norm=False, pool=True, unpool=True, deep_supervision=False, 
                      backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='xnet')
    
    ----------
    Zhou, Z., Siddiquee, M.M.R., Tajbakhsh, N. and Liang, J., 2018. Unet++: A nested u-net architecture 
    for medical image segmentation. In Deep Learning in Medical Image Analysis and Multimodal Learning 
    for Clinical Decision Support (pp. 3-11). Springer, Cham.
    
    Input
    ----------
        input_tensor: the input tensor of the base, e.g., `keras.layers.Input((None, None, 3))`.
        filter_num: a list that defines the number of filters for each \
                    down- and upsampling levels. e.g., `[64, 128, 256, 512]`.
                    The depth is expected as `len(filter_num)`.
        stack_num_down: number of convolutional layers per downsampling level/block. 
        stack_num_up: number of convolutional layers (after concatenation) per upsampling level/block.
        activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
        batch_norm: True for batch normalization.
        pool: True or 'max' for MaxPooling2D.
              'ave' for AveragePooling2D.
              False for strided conv + batch norm + activation.
        unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
                'nearest' for Upsampling2D with nearest interpolation.
                False for Conv2DTranspose + batch norm + activation.   
        deep_supervision: True for a model that supports deep supervision. Details see Zhou et al. (2018).
        name: prefix of the created keras model and its layers.
        
        ---------- (keywords of backbone options) ----------
        backbone: the backbone model name. Should be one of the `tensorflow.keras.applications` class.
                  None (default) means no backbone. 
                  Currently supported backbones are:
                  (1) VGG16, VGG19
                  (2) ResNet50, ResNet101, ResNet152
                  (3) ResNet50V2, ResNet101V2, ResNet152V2
                  (4) DenseNet121, DenseNet169, DenseNet201
                  (5) EfficientNetB[0-7]
        weights: one of None (random initialization), 'imagenet' (pre-training on ImageNet), 
                 or the path to the weights file to be loaded.
        freeze_backbone: True for a frozen backbone.
        freeze_batch_norm: False for not freezing batch normalization layers.
        
    Output
    ----------
        A pair `(out_1, out_2)`, one entry per decoder.
        If deep_supervision = False, each entry is a tensor.
        If deep_supervision = True, each entry is a list of tensors
            with the first tensor obtained from the first downsampling level (for checking the input/output shapes only),
            the second to the `depth-1`-th tensors obtained from each intermediate upsampling levels (deep supervision tensors),
            and the last tensor obtained from the end of the base.
    
    '''
    depth_ = len(filter_num)

    # encoder feature maps, shallowest first, shared by both decoders
    encoder_outputs = []

    # no backbone cases
    if backbone is None:

        # downsampling blocks (same as in 'unet_2d')
        X = CONV_stack(input_tensor, filter_num[0], stack_num=stack_num_down, activation=activation, 
                       batch_norm=batch_norm, name='{}_down0'.format(name))
        encoder_outputs.append(X)
        for i, f in enumerate(filter_num[1:]):
            X = UNET_left(X, f, stack_num=stack_num_down, activation=activation, 
                          pool=pool, batch_norm=batch_norm, name='{}_down{}'.format(name, i+1))
            encoder_outputs.append(X)

    # backbone cases
    else:
        # the backbone is called ONCE; calling it per decoder would duplicate its forward pass
        if 'VGG' in backbone:
            # VGG16 / VGG19 provide `depth_` feature maps
            backbone_ = backbone_zoo(backbone, weights, input_tensor, depth_, freeze_backbone, freeze_batch_norm)
            encoder_outputs += backbone_([input_tensor,])
            depth_encode = len(encoder_outputs)
        else:
            backbone_ = backbone_zoo(backbone, weights, input_tensor, depth_-1, freeze_backbone, freeze_batch_norm)
            encoder_outputs += backbone_([input_tensor,])
            depth_encode = len(encoder_outputs) + 1

        # extra conv2d blocks are applied
        # if downsampling levels of a backbone < user-specified downsampling levels
        if depth_encode < depth_:

            # begins at the deepest available tensor
            X = encoder_outputs[-1]

            for i in range(depth_-depth_encode):
                i_real = i + depth_encode
                X = UNET_left(X, filter_num[i_real], stack_num=stack_num_down, activation=activation, pool=pool, 
                              batch_norm=batch_norm, name='{}_down{}'.format(name, i_real+1))
                encoder_outputs.append(X)

    # nested lists for collecting output tensors; level 0 is the (shared) encoder
    X_nest_skip_1 = [list(encoder_outputs)] + [[] for _ in range(depth_-1)]
    X_nest_skip_2 = [list(encoder_outputs)] + [[] for _ in range(depth_-1)]

    # NOTE: decoder 1's extra upsampling blocks carry no "_1" tag in their names.
    # Renaming them would presumably break loading weights by layer name.
    _unet_plus_nested_decoder(X_nest_skip_1, filter_num, stack_num_up, activation, unpool, batch_norm,
                              name, '{}_up{}_1_from{}', '{}_up{}_from{}')
    _unet_plus_nested_decoder(X_nest_skip_2, filter_num, stack_num_up, activation, unpool, batch_norm,
                              name, '{}_up{}_2_from{}', '{}_up{}_2_from{}')

    # output
    if deep_supervision:
        X_list_1 = [X_nest_skip_1[i][0] for i in range(depth_)]
        X_list_2 = [X_nest_skip_2[i][0] for i in range(depth_)]
        return X_list_1, X_list_2

    return X_nest_skip_1[-1][0], X_nest_skip_2[-1][0]

def _deep_supervision_heads(X_list, task, n_out, out_activation, filter_num, backbone,
                            activation, unpool, batch_norm, name):
    '''
    Build the deep-supervision output heads of one task (`task` is 1 or 2).

    Every layer name carries the task tag, so the two tasks never collide.
    Returns the output tensors (intermediate supervision heads first, final head
    last) and prints the name of each model output.
    '''
    depth_ = len(filter_num)
    outputs = []

    # no backbone or VGG backbones
    # depth_ > 2 is expected (at least two downsampling blocks)
    if (backbone is None) or 'VGG' in backbone:
        for i in range(1, depth_-1):
            print('\t{}_{}_output_sup{}_activation'.format(name, task, i))
            outputs.append(CONV_output(X_list[i], n_out, kernel_size=1, activation=out_activation,
                                       name='{}_{}_output_sup{}'.format(name, task, i)))
    # other backbones
    else:
        for i in range(1, depth_-1):
            print('\t{}_{}_output_sup{}_activation'.format(name, task, i-1))
            # an extra upsampling for creating full resolution feature maps
            X = decode_layer(X_list[i], filter_num[i], 2, unpool, activation=activation,
                             batch_norm=batch_norm, name='{}_{}_sup{}_up'.format(name, task, i-1))
            outputs.append(CONV_output(X, n_out, kernel_size=1, activation=out_activation,
                                       name='{}_{}_output_sup{}'.format(name, task, i-1)))

    print('\t{}_{}_output_final_activation'.format(name, task))
    outputs.append(CONV_output(X_list[-1], n_out, kernel_size=1, activation=out_activation,
                               name='{}_{}_output_final'.format(name, task)))
    return outputs


def unet_plus_2d(input_size, filter_num, n_labels, stack_num_down=2, stack_num_up=2,
                 activation='ReLU', output_activation='Softmax', batch_norm=False, pool=True, unpool=True, deep_supervision=False, 
                 backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='xnet'):
    '''
    Dual-head U-net++ with an optional ImageNet-trained backbone.

    Head 1 is a 1-channel Sigmoid (binary) segmentation; head 2 is a 3-channel
    Softmax (multi-class) segmentation. Each head sits on its own nested decoder.
    
    unet_plus_2d(input_size, filter_num, n_labels, stack_num_down=2, stack_num_up=2,
                 activation='ReLU', output_activation='Softmax', batch_norm=False, pool=True, unpool=True, deep_supervision=False, 
                 backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='xnet')
    
    ----------
    Zhou, Z., Siddiquee, M.M.R., Tajbakhsh, N. and Liang, J., 2018. Unet++: A nested u-net architecture 
    for medical image segmentation. In Deep Learning in Medical Image Analysis and Multimodal Learning 
    for Clinical Decision Support (pp. 3-11). Springer, Cham.
    
    Input
    ----------
        input_size: the size/shape of network input, e.g., `(128, 128, 3)`.
        filter_num: a list that defines the number of filters for each \
                    down- and upsampling levels. e.g., `[64, 128, 256, 512]`.
                    The depth is expected as `len(filter_num)`.
        n_labels: kept for signature compatibility; NOT used. The heads are fixed
                  at 1 (Sigmoid) and 3 (Softmax) channels.
        stack_num_down: number of convolutional layers per downsampling level/block. 
        stack_num_up: number of convolutional layers (after concatenation) per upsampling level/block.
        activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
        output_activation: kept for signature compatibility; NOT used (see n_labels).
        batch_norm: True for batch normalization.
        pool: True or 'max' for MaxPooling2D.
              'ave' for AveragePooling2D.
              False for strided conv + batch norm + activation.
        unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
                'nearest' for Upsampling2D with nearest interpolation.
                False for Conv2DTranspose + batch norm + activation.   
        deep_supervision: True for a model that supports deep supervision. Details see Zhou et al. (2018).
        name: prefix of the created keras model and its layers.
        
        ---------- (keywords of backbone options) ----------
        backbone: the backbone model name. Should be one of the `tensorflow.keras.applications` class.
                  None (default) means no backbone. 
                  Currently supported backbones are:
                  (1) VGG16, VGG19
                  (2) ResNet50, ResNet101, ResNet152
                  (3) ResNet50V2, ResNet101V2, ResNet152V2
                  (4) DenseNet121, DenseNet169, DenseNet201
                  (5) EfficientNetB[0-7]
        weights: one of None (random initialization), 'imagenet' (pre-training on ImageNet), 
                 or the path to the weights file to be loaded.
        freeze_backbone: True for a frozen backbone.
        freeze_batch_norm: False for not freezing batch normalization layers.
        
    Output
    ----------
        model: a keras model. With deep_supervision=True its outputs are all task-1
               heads (supervision heads, then final), followed by all task-2 heads.
               Otherwise the outputs are [Binary_segmentation, MultiClass_segmentation].
    
    '''
    if backbone is not None:
        bach_norm_checker(backbone, batch_norm)
    
    IN = Input(input_size)
    # base
    X_1, X_2 = unet_plus_2d_base(IN, filter_num, stack_num_down=stack_num_down, stack_num_up=stack_num_up,
                          activation=activation, batch_norm=batch_norm, pool=pool, unpool=unpool, deep_supervision=deep_supervision, 
                          backbone=backbone, weights=weights, freeze_backbone=freeze_backbone, freeze_batch_norm=freeze_batch_norm, name=name)
    
    # output
    if deep_supervision:
        
        if (backbone is not None) and freeze_backbone:
            backbone_warn = '\n\nThe shallowest U-net++ deep supervision branch directly connects to a frozen backbone.\nTesting your configurations on `keras_unet_collection.base.unet_plus_2d_base` is recommended.'
            warnings.warn(backbone_warn)

        print('----------\ndeep_supervision = True\nnames of output tensors are listed as follows ("sup0" is the shallowest supervision layer;\n"final" is the final output layer):\n')

        # the base returns one list of tensors per decoder
        OUT_list = (_deep_supervision_heads(X_1, 1, 1, "Sigmoid", filter_num, backbone,
                                            activation, unpool, batch_norm, name)
                    + _deep_supervision_heads(X_2, 2, 3, "Softmax", filter_num, backbone,
                                              activation, unpool, batch_norm, name))
        
    else:
        OUT_1 = CONV_output(X_1, 1, kernel_size=1, activation='Sigmoid', name='Binary_segmentation')
        OUT_2 = CONV_output(X_2, 3, kernel_size=1, activation='Softmax', name='MultiClass_segmentation')
        OUT_list = [OUT_1, OUT_2]
        
    # model
    model = Model(inputs=[IN,], outputs=OUT_list, name='{}_model'.format(name))
    
    return model

In [None]:
# Dual-head U-Net++ on 256x256 RGB patches with a trainable VGG19 encoder
# (weights default to 'imagenet') and deep supervision. With 5 filter levels,
# each head has 3 supervision outputs + 1 final output (8 outputs in total).
# NOTE: `n_labels` is not used by the VGG / no-backbone heads of `unet_plus_2d`,
# which hard-code 1 (Sigmoid) and 3 (Softmax) output channels.
model = unet_plus_2d(input_size=(256, 256, 3),
                     filter_num=[16, 32, 64, 128, 256],
                     n_labels=1,
                     batch_norm=True,
                     deep_supervision=True,
                     backbone="VGG19",
                     freeze_backbone=False,
                     freeze_batch_norm=False)

In [None]:
model.summary()

In [None]:
from keras.utils import plot_model

In [None]:
# Save a diagram of the architecture (including layer shapes) to
# U-Net++_DD.png at 330 dpi.
# NOTE: keras.utils.plot_model needs pydot and Graphviz to be installed.
plot_model(model, show_shapes=True, dpi=330, to_file="U-Net++_DD.png")

### Defining loss functions for the binary & multi-class segmentation tasks

In [None]:
import keras.backend as K
from keras.optimizers import Adam

In [None]:
from keras_unet_collection import losses

def _focal_tversky_iou_loss(y_true, y_pred, alpha, gamma=4/3):
    """Return focal Tversky loss + IoU loss for one model output.

    Shared by the binary and multi-class losses so both heads use the same
    composition and differ only in their Tversky ``alpha``.

    Args:
        y_true: Ground-truth tensor for the output.
        y_pred: Predicted tensor for the output.
        alpha: Tversky ``alpha`` forwarded to ``losses.focal_tversky``.
        gamma: Focal exponent forwarded to ``losses.focal_tversky``.

    Returns:
        ``losses.focal_tversky(...) + losses.iou_seg(...)``.
    """
    loss_focal = losses.focal_tversky(y_true, y_pred, alpha=alpha, gamma=gamma)
    loss_iou = losses.iou_seg(y_true, y_pred)
    # A multi-scale SSIM term,
    # losses.ms_ssim(y_true, y_pred, max_val=1.0, filter_size=4),
    # is currently disabled.
    return loss_focal + loss_iou


def binary_loss(y_true, y_pred):
    """Loss for the binary segmentation head: focal Tversky (alpha=0.5) + IoU."""
    return _focal_tversky_iou_loss(y_true, y_pred, alpha=0.5)


def multi_loss(y_true, y_pred):
    """Loss for the multi-class segmentation head: focal Tversky (alpha=0.7) + IoU."""
    return _focal_tversky_iou_loss(y_true, y_pred, alpha=0.7)

In [None]:
def dice_coef(y_true, y_pred, smooth=1):
    """Soft Dice coefficient between ``y_true`` and ``y_pred``.

    Both tensors are flattened first, so every element of the batch (all
    samples and channels) is pooled into one score. ``smooth`` is added to
    the numerator and the denominator so empty masks do not divide by zero.
    """
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    overlap = K.sum(truth * guess)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth) + K.sum(guess) + smooth
    return numerator / denominator

### Compiling the model

In [None]:
def _build_compile_targets():
    """Return ``(loss, loss_weights, metrics)`` dicts keyed by output name.

    Outputs of head ``xnet_1`` are trained with ``binary_loss`` and outputs of
    head ``xnet_2`` with ``multi_loss``. Every output is tracked with
    ``dice_coef``. The loss weights are keyed by output name, like the losses
    and metrics, instead of being a positional list. A positional list only
    pairs correctly if the model's output order matches this listing.
    """
    # Deep-supervision weights: the final output dominates, and later side
    # outputs count more than earlier ones.
    supervision_weights = {"sup1": 0.25, "sup2": 0.35, "sup3": 0.45, "final": 1.0}

    loss, loss_weights, metrics = {}, {}, {}
    for head, head_loss in (("xnet_1", binary_loss), ("xnet_2", multi_loss)):
        for suffix, weight in supervision_weights.items():
            output_name = f"{head}_output_{suffix}_activation"
            loss[output_name] = head_loss
            loss_weights[output_name] = weight
            metrics[output_name] = dice_coef
    return loss, loss_weights, metrics


output_losses, output_loss_weights, output_metrics = _build_compile_targets()
model.compile(optimizer=Adam(1e-4),
              loss=output_losses,
              loss_weights=output_loss_weights,
              metrics=output_metrics)

## Train the model

### Defining some useful callbacks 

In [None]:
from keras.callbacks import ModelCheckpoint, CSVLogger, ReduceLROnPlateau, EarlyStopping, Callback, LearningRateScheduler
from tensorflow.keras.optimizers.schedules import CosineDecay
import datetime

from skimage import filters
from scipy.ndimage import measurements
from skimage.segmentation import watershed, mark_boundaries

# Visualize training 
class loss_history(Callback):
    """Show a qualitative prediction for one training sample at each epoch start.

    The figure has five panels:

    1. The tissue image.
    2. Its ground-truth mask.
    3. The thresholded prediction of the first (binary) head.
    4. The thresholded marker map of the second head.
    5. A marker-seeded watershed of the binary prediction, drawn over the image.

    Reads the notebook globals ``train_images``, ``train_masks`` and ``plt``.

    Args:
        x: Index into ``train_images`` / ``train_masks`` of the sample to show.
    """

    def __init__(self, x=4):
        # Initialise the Keras base class so its own attributes are set up.
        super().__init__()
        self.x = x

    def on_epoch_begin(self, epoch, logs=None):
        # Public `scipy.ndimage.label`; the `scipy.ndimage.measurements`
        # namespace is deprecated since SciPy 1.8.
        from scipy import ndimage

        fig, ax = plt.subplots(1, 5, figsize=(18, 12))
        for axis in ax.ravel():
            axis.set_axis_off()

        image = train_images[self.x]

        ax[0].imshow(image)
        ax[0].set_title("Tissue Image")

        ax[1].imshow(train_masks[self.x], cmap="gray")
        ax[1].set_title("Ground Truth")

        # Scale to [0, 1] float32 and add a batch axis before predicting.
        model_sample_input = image.astype("float32") / 255.
        pred = self.model.predict(np.expand_dims(model_sample_input, axis=0), verbose=0)

        # `pred` lists the first head's outputs followed by the second head's.
        # Assumes an even split between the heads: 4 + 4 with deep
        # supervision, 1 + 1 without. Average the outputs of each head.
        n_per_head = len(pred) // 2
        preds_train1 = np.mean(pred[:n_per_head], axis=0)
        preds_train2 = np.mean(pred[n_per_head:], axis=0)

        nuclei_mask = np.squeeze(preds_train1[0]) >= 0.5
        ax[2].imshow(nuclei_mask, cmap="gray")
        ax[2].set_title("Nuclei prediction")

        # Marker map: channel 2 minus channel 1 of the averaged second-head
        # output, thresholded at 0.5. The channel meaning follows the
        # multi-class mask encoding (confirm).
        marker_mask = preds_train2[0][:, :, 2] - preds_train2[0][:, :, 1] >= 0.5
        ax[3].imshow(marker_mask, cmap="gray")
        ax[3].set_title("Nuclei marker prediction")

        # Watershed on the Scharr edge map of the nuclei mask. It is seeded by
        # the connected components of the markers that lie inside that mask.
        grad = filters.scharr(nuclei_mask)
        markers, _ = ndimage.label(nuclei_mask * marker_mask)
        proced_pred = watershed(grad, markers, mask=nuclei_mask)
        ax[4].imshow(mark_boundaries(image, proced_pred, color=(0, 0, 1)))
        ax[4].set_title("Result")

        plt.tight_layout()
        plt.show()
        # A new figure is created every epoch; release it explicitly.
        plt.close(fig)

# Alternative cosine learning-rate schedule (disabled):
# lr_scheduler = CosineDecay(
#     1e-4, 50, alpha=0.0, name=None
# )
# reduce_lr = LearningRateScheduler(schedule=lr_scheduler)
# NOTE(review): LearningRateScheduler calls its schedule as schedule(epoch, lr),
# while a CosineDecay instance is called with a single step. Confirm or wrap it
# before re-enabling.

# Per-epoch CSV log of the training metrics.
create_path("logs/Unet++")
csv_log = CSVLogger('logs/Unet++/Unet++_fold1_log00.csv', separator=',')

# The checkpoint file name carries a timestamp, so every run writes a new file.
model_name = "Unet++_fold1_v00_" + datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
create_path("models/Unet++")
path_to_save_model = f"models/Unet++/{model_name}.h5"
checkpointer = ModelCheckpoint(path_to_save_model, verbose=1, save_best_only=True)

# Divide the LR by 10 after 10 stagnant epochs; stop after 20.
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.1,
    patience=10,
    min_lr=1e-7,
    verbose=1,
)
early_stop = EarlyStopping(
    monitor="val_loss",
    patience=20,
    restore_best_weights=False,
)

print(model_name)

In [None]:
# Callbacks passed to model.fit. `csv_log` is not included, so no CSV history
# is written; add it here to enable CSV logging.
callbacks = [loss_history(), checkpointer, reduce_lr, early_stop]

In [None]:
batch_size = 8
epochs = 200

# Both generators use the same batch size; only the augmentation pipelines differ.
training_generator = DataGenerator(
    train_images, train_masks, train_masks_cat,
    augmentations=AUGMENTATIONS_TRAIN,
    batch_size=batch_size,
)
validation_generator = DataGenerator(
    val_images, val_masks, val_masks_cat,
    augmentations=AUGMENTATIONS_VAL,
    batch_size=batch_size,
)

history = model.fit(
    training_generator,
    validation_data=validation_generator,
    epochs=epochs,
    verbose=1,
    callbacks=callbacks,
)