# Initialize

## libraries

In [None]:
import os
import glob

import pandas as pd
import pickle as pkl
import numpy as np
import datetime as dt

import matplotlib.pyplot as plt
import matplotlib as mpl

from tqdm import tqdm
from time import sleep

import albumentations as A
from skimage.io import imread

from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, LearningRateScheduler, ReduceLROnPlateau
from tensorflow.keras.layers import Conv2D, MaxPooling2D, concatenate, Conv2DTranspose, Concatenate, Multiply, Resizing, DepthwiseConv2D, Lambda, Softmax
from tensorflow.keras.layers import BatchNormalization, Input, Activation, Add, GlobalAveragePooling2D, Reshape, Dense, multiply, Permute, maximum, Layer
from keras.initializers import he_normal
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow import keras
from keras.optimizers import Adam

## paths

In [None]:
# Root of the xBD dataset on the mounted Google Drive.
ROOT = '/content/drive/MyDrive/xBD_dataset'

# Per-split sub-directories, relative to ROOT.
test, train, hold, tier3 = (
    'xView2_test/test/',
    'xView2_train/train/',
    'xView2_hold/hold/',
    'xView2_tier3/tier3/',
)

# Folders present inside every split directory.
SUB = ['images', 'labels', 'targets']

## database

In [None]:
# Load the per-image metadata table. It is presumably a pandas DataFrame with columns such as
# 'Group', 'Pre_Post', 'img_name' and per-damage building counts — TODO confirm.
# NOTE: pickle.load can execute arbitrary code; only load files you created yourself.
with open(f'{ROOT}/All_Data_Props.pkl', 'rb') as f:
    db = pkl.load(f)

## data generator

### Generator for localization
Use this data generator for LOCALIZATION: it yields random patches of the pre-disaster images together with the matching single-channel building masks.

In [None]:
class xBD_DataGenerator_(keras.utils.Sequence):
    """Keras ``Sequence`` yielding ``(x_pre, y)`` batches for building localization.

    ``x_pre`` is a ``float32`` array of shape ``(batch_size, patch_size, patch_size, 3)``
    holding random crops of the pre-disaster images. ``y`` has shape
    ``(batch_size, patch_size, patch_size, 1)`` and holds the crops of the
    localization target masks, taken at the same window as the image.

    Args:
        path_to_jsons: post-disaster label paths, presumably without file
            extension (``'.png'`` / ``'_target.png'`` is appended). Image and
            target paths are derived from them by string substitution.
        batch_size: samples per batch; a trailing partial batch is dropped.
        patch_size: side length of the square random crop.
        shuffle: reshuffle the sample order at the end of every epoch.
    """

    def __init__(self, path_to_jsons, batch_size=5, patch_size=256, shuffle=True):
        # keras.utils.Sequence (PyDataset in Keras 3) expects its initializer to run.
        super().__init__()
        self.path_to_jsons = path_to_jsons
        self.pre_img_paths = []
        self.loc_target_paths = []

        for json_path in self.path_to_jsons:
            # Target masks of the tier3 split are stored without the '_target' suffix.
            target_suffix = '.png' if 'tier3' in json_path else '_target.png'
            self.pre_img_paths.append(self._derive_path(json_path, 'images', pre=True))
            self.loc_target_paths.append(
                self._derive_path(json_path, 'targets', pre=True, suffix=target_suffix))

        self.batch_size = batch_size
        self.patch_size = patch_size
        self.shuffle = shuffle
        self.on_epoch_end()  # initializes (and optionally shuffles) self.indices

    @staticmethod
    def _derive_path(json_path, folder, pre=False, suffix='.png'):
        """Map a label path onto ``folder`` ('images'/'targets') of the local Drive mount."""
        path = (json_path.replace('Shareddrives/UnlimitedDrive/Thesis', 'MyDrive')
                         .replace('/labels/', f'/{folder}/'))
        if pre:
            path = path.replace('_post_', '_pre_')
        return path + suffix

    def _crop_window(self, shape):
        """Return ``(row_slice, col_slice)`` of a uniformly random patch inside an image of ``shape``.

        Every valid offset ``0 .. size - patch_size`` (inclusive) can be drawn.
        Rows and columns are drawn independently, so non-square images also work.

        Raises:
            ValueError: if the image is smaller than ``patch_size`` along either axis.
        """
        max_row = shape[0] - self.patch_size
        max_col = shape[1] - self.patch_size
        if max_row < 0 or max_col < 0:
            raise ValueError(
                f'image of size {shape[0]}x{shape[1]} is smaller than patch_size={self.patch_size}')
        # randint's upper bound is exclusive, hence the +1.
        row = np.random.randint(0, max_row + 1)
        col = np.random.randint(0, max_col + 1)
        return slice(row, row + self.patch_size), slice(col, col + self.patch_size)

    def __len__(self):
        """Number of full batches per epoch (steps per epoch)."""
        return len(self.pre_img_paths) // self.batch_size

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(x_pre, y)``.

        Raises:
            IndexError: if ``idx`` is outside ``range(len(self))``.
        """
        if not 0 <= idx < len(self):
            raise IndexError(f'batch index {idx} out of range for {len(self)} batches')
        batch_indices = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]

        spatial = (self.batch_size, self.patch_size, self.patch_size)
        x_pre_batch = np.empty(spatial + (3,), dtype='float32')
        y_target_batch = np.empty(spatial + (1,), dtype='float32')
        for b, sample in enumerate(batch_indices):
            pre_image = imread(self.pre_img_paths[sample])
            target_image = imread(self.loc_target_paths[sample])
            # Same crop window for image and mask so they stay aligned.
            rows, cols = self._crop_window(pre_image.shape)
            x_pre_batch[b] = pre_image[rows, cols, :]
            y_target_batch[b, :, :, 0] = target_image[rows, cols]

        return x_pre_batch, y_target_batch

    def on_epoch_end(self):
        """Rebuild the sample order, shuffling it when ``self.shuffle`` is set."""
        self.indices = np.arange(len(self.pre_img_paths))
        if self.shuffle:
            np.random.shuffle(self.indices)

### Generator for classification
Use this generator for CLASSIFICATION. It returns both the PRE- and POST-disaster images, and the damage MASK in one-hot format (one class per channel).

In [None]:
class xBD_DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` yielding ``((x_pre, x_post), y)`` batches for xBD segmentation.

    ``x_pre`` and ``x_post`` are ``float32`` arrays of shape
    ``(batch_size, patch_size, patch_size, 3)``, cropped at the same random
    window. ``y`` depends on ``classification``:

    * ``True``: shape ``(..., 5)``, a one-hot encoding of damage labels 0-4.
      Label 5 ("unclassified") is folded into 1 (no-damage).
    * ``False``: shape ``(..., 1)``, the pre-disaster localization mask.

    Args:
        path_to_jsons: post-disaster label paths, presumably without file
            extension (``'.png'`` / ``'_target.png'`` is appended). Image and
            target paths are derived from them by string substitution.
        batch_size: samples per batch; a trailing partial batch is dropped.
        patch_size: side length of the square random crop.
        shuffle: reshuffle the sample order at the end of every epoch.
        classification: produce damage-class targets (True) or localization
            targets (False).
    """

    def __init__(self, path_to_jsons, batch_size=5, patch_size=256,
                 shuffle=True, classification=True):
        # keras.utils.Sequence (PyDataset in Keras 3) expects its initializer to run.
        super().__init__()
        self.path_to_jsons = path_to_jsons
        self.pre_img_paths = []
        self.post_img_paths = []
        self.loc_target_paths = []
        self.cls_target_paths = []

        for json_path in self.path_to_jsons:
            # Target masks of the tier3 split are stored without the '_target' suffix.
            target_suffix = '.png' if 'tier3' in json_path else '_target.png'
            self.pre_img_paths.append(self._derive_path(json_path, 'images', pre=True))
            self.post_img_paths.append(self._derive_path(json_path, 'images'))
            self.loc_target_paths.append(
                self._derive_path(json_path, 'targets', pre=True, suffix=target_suffix))
            self.cls_target_paths.append(
                self._derive_path(json_path, 'targets', suffix=target_suffix))

        self.batch_size = batch_size
        self.patch_size = patch_size
        self.shuffle = shuffle
        self.classification = classification    # classification (True) or localization (False) targets
        self.on_epoch_end()  # initializes (and optionally shuffles) self.indices

    @staticmethod
    def _derive_path(json_path, folder, pre=False, suffix='.png'):
        """Map a label path onto ``folder`` ('images'/'targets') of the local Drive mount."""
        path = (json_path.replace('Shareddrives/UnlimitedDrive/Thesis', 'MyDrive')
                         .replace('/labels/', f'/{folder}/'))
        if pre:
            path = path.replace('_post_', '_pre_')
        return path + suffix

    @staticmethod
    def _one_hot_damage(mask, n_channels=5):
        """One-hot encode a 2-D damage mask into ``n_channels`` float32 channels.

        Label 5 ("unclassified") is treated as 1 (no-damage). Labels outside
        ``0 .. n_channels-1`` produce an all-zero vector.
        """
        mask = np.where(mask == 5, 1, mask)
        return (mask[..., np.newaxis] == np.arange(n_channels)).astype('float32')

    def _crop_window(self, shape):
        """Return ``(row_slice, col_slice)`` of a uniformly random patch inside an image of ``shape``.

        Every valid offset ``0 .. size - patch_size`` (inclusive) can be drawn.
        Rows and columns are drawn independently, so non-square images also work.

        Raises:
            ValueError: if the image is smaller than ``patch_size`` along either axis.
        """
        max_row = shape[0] - self.patch_size
        max_col = shape[1] - self.patch_size
        if max_row < 0 or max_col < 0:
            raise ValueError(
                f'image of size {shape[0]}x{shape[1]} is smaller than patch_size={self.patch_size}')
        # randint's upper bound is exclusive, hence the +1.
        row = np.random.randint(0, max_row + 1)
        col = np.random.randint(0, max_col + 1)
        return slice(row, row + self.patch_size), slice(col, col + self.patch_size)

    def __len__(self):
        """Number of full batches per epoch (steps per epoch)."""
        return len(self.pre_img_paths) // self.batch_size

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``((x_pre, x_post), y)``.

        Raises:
            IndexError: if ``idx`` is outside ``range(len(self))``.
        """
        if not 0 <= idx < len(self):
            raise IndexError(f'batch index {idx} out of range for {len(self)} batches')
        batch_indices = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]

        spatial = (self.batch_size, self.patch_size, self.patch_size)
        target_channels = 5 if self.classification else 1
        x_pre_batch = np.empty(spatial + (3,), dtype='float32')
        x_post_batch = np.empty(spatial + (3,), dtype='float32')
        y_target_batch = np.empty(spatial + (target_channels,), dtype='float32')

        for b, sample in enumerate(batch_indices):
            pre_image = imread(self.pre_img_paths[sample])
            post_image = imread(self.post_img_paths[sample])
            if self.classification:
                target_image = self._one_hot_damage(imread(self.cls_target_paths[sample]))
            else:
                # 2-D mask -> add a channel axis so it matches y_target_batch[b].
                target_image = imread(self.loc_target_paths[sample])[..., np.newaxis]

            # AUGMENTATION TO BE ADDED #

            # One crop window shared by pre, post and target keeps them aligned.
            rows, cols = self._crop_window(pre_image.shape)
            x_pre_batch[b] = pre_image[rows, cols, :]
            x_post_batch[b] = post_image[rows, cols, :]
            y_target_batch[b] = target_image[rows, cols, :]

        return (x_pre_batch, x_post_batch), y_target_batch

    def on_epoch_end(self):
        """Rebuild the sample order, shuffling it when ``self.shuffle`` is set."""
        self.indices = np.arange(len(self.pre_img_paths))
        if self.shuffle:
            np.random.shuffle(self.indices)

## data split

In [None]:
# Every split keeps post-disaster rows ('Pre_Post' == 'post') of one or more dataset
# groups whose damaged-building count (minor + major + destroyed) exceeds a threshold.
# NOTE(review): each split also defines a `cond1` filter that is NOT part of the
# selection mask below — confirm whether that is intended before relying on it.

# --- Training: Train + Tier3 groups, more than 10 damaged buildings.
cond00 = db['Group'] == 'Train'                # images from the Train folder
cond01 = db['Group'] == 'Tier3'                # images from the Tier3 folder
cond1 = db['buildings#'] > 20                  # more than 20 buildings (not applied)
cond2 = db['Pre_Post'] == 'post'               # post-disaster rows only
cond3 = db['destroyed#'] + db['minor-damage#'] + db['major-damage#'] > 10
# Use (cond00 & cond2 & cond3) instead to train on the Train group only.
training_files = list(db[(cond00 | cond01) & cond2 & cond3]['img_name'])
print(len(training_files))

# --- Testing: Test group, more than 5 damaged buildings.
cond0 = db['Group'] == 'Test'                  # images from the Test folder
cond1 = db['destroyed#'] > 0                   # at least one destroyed building (not applied)
cond2 = db['Pre_Post'] == 'post'               # post-disaster rows only
cond3 = db['destroyed#'] + db['major-damage#'] + db['minor-damage#'] > 5
testing_files = list(db[cond0 & cond2 & cond3]['img_name'])
print(len(testing_files))

# --- Validation: Hold group, more than 20 damaged buildings.
cond0 = db['Group'] == 'Hold'                  # images from the Hold folder
cond1 = db['buildings#'] > 30                  # more than 30 buildings (not applied)
cond2 = db['Pre_Post'] == 'post'               # post-disaster rows only
cond3 = db['destroyed#'] + db['minor-damage#'] + db['major-damage#'] > 20
validation_files = list(db[cond0 & cond2 & cond3]['img_name'])
print(len(validation_files))

1310
292
155


# Models

## base UNet

In [None]:
def base_unet(filters, output_channels, width=None, height=None, input_channels=1, conv_layers=2,
              output_activation='sigmoid'):
    """Build a plain U-Net: four 2x down-sampling levels, a bottleneck and a mirrored decoder.

    Args:
        filters: channels at the first level; doubled at each deeper level
            (``filters * 16`` at the bottleneck).
        output_channels: channels of the final 1x1 convolution.
        width: input width in pixels, or None for a size-agnostic model.
        height: input height in pixels, or None for a size-agnostic model.
            Known sizes must be divisible by 16 so that the up-sampled decoder
            maps line up with the encoder skip connections.
        input_channels: channels of the input image.
        conv_layers: Conv-BN-ReLU repetitions at every level.
        output_activation: activation of the final 1x1 convolution
            (``'sigmoid'`` by default; ``'softmax'`` for mutually exclusive classes).

    Returns:
        An uncompiled ``keras.Model``. Its last layer is the 1x1 output
        convolution; the layer before it is the decoder's final ReLU.
    """
    def conv2d(layer_input, filters, conv_layers=2):
        """``conv_layers`` x (3x3 Conv -> BatchNorm -> ReLU), keeping the spatial size."""
        d = Conv2D(filters, kernel_size=(3, 3), strides=(1, 1), padding='same')(layer_input)
        d = BatchNormalization()(d)
        d = Activation('relu')(d)

        for _ in range(conv_layers - 1):
            d = Conv2D(filters, kernel_size=(3, 3), strides=(1, 1), padding='same')(d)
            d = BatchNormalization()(d)
            d = Activation('relu')(d)

        return d

    def deconv2d(layer_input, filters):
        """2x up-sampling with a strided transposed conv -> BatchNorm -> ReLU."""
        u = Conv2DTranspose(filters, 2, strides=(2, 2), padding='same')(layer_input)
        u = BatchNormalization()(u)
        u = Activation('relu')(u)
        return u

    # Keras image tensors are (rows, cols, channels) = (height, width, channels).
    inputs = Input(shape=(height, width, input_channels))

    # Encoder.
    conv1 = conv2d(inputs, filters, conv_layers=conv_layers)
    pool1 = MaxPooling2D((2, 2))(conv1)

    conv2 = conv2d(pool1, filters * 2, conv_layers=conv_layers)
    pool2 = MaxPooling2D((2, 2))(conv2)

    conv3 = conv2d(pool2, filters * 4, conv_layers=conv_layers)
    pool3 = MaxPooling2D((2, 2))(conv3)

    conv4 = conv2d(pool3, filters * 8, conv_layers=conv_layers)
    pool4 = MaxPooling2D((2, 2))(conv4)

    # Bottleneck.
    conv5 = conv2d(pool4, filters * 16, conv_layers=conv_layers)

    # Decoder: up-sample, concatenate the matching encoder features, convolve.
    up6 = deconv2d(conv5, filters * 8)
    up6 = Concatenate()([up6, conv4])
    conv6 = conv2d(up6, filters * 8, conv_layers=conv_layers)

    up7 = deconv2d(conv6, filters * 4)
    up7 = Concatenate()([up7, conv3])
    conv7 = conv2d(up7, filters * 4, conv_layers=conv_layers)

    up8 = deconv2d(conv7, filters * 2)
    up8 = Concatenate()([up8, conv2])
    conv8 = conv2d(up8, filters * 2, conv_layers=conv_layers)

    up9 = deconv2d(conv8, filters)
    up9 = Concatenate()([up9, conv1])
    conv9 = conv2d(up9, filters, conv_layers=conv_layers)

    outputs = Conv2D(output_channels, kernel_size=(1, 1), strides=(1, 1), activation=output_activation)(conv9)

    model = Model(inputs=inputs, outputs=outputs)

    return model

## residual-UNet

In [None]:
def residual_unet(filters, output_channels, width=None, height=None, input_channels=1, conv_layers=2,
                  output_activation='sigmoid'):
    """Build a U-Net whose levels are residual blocks.

    Args:
        filters: channels at the first level; doubled at each deeper level.
        output_channels: channels of the final 1x1 convolution.
        width: input width in pixels, or None for a size-agnostic model.
        height: input height in pixels, or None for a size-agnostic model.
            Known sizes must be divisible by 16 (four 2x poolings).
        input_channels: channels of the input image.
        conv_layers: Conv-BN-ReLU layers per residual block (the first one
            projects to ``filters`` channels and feeds the shortcut).
        output_activation: activation applied after the 1x1 output
            convolution (``'sigmoid'`` by default).

    Returns:
        An uncompiled ``keras.Model``. Its last two layers are the 1x1 output
        convolution (no activation) and the output ``Activation`` layer.
    """
    def residual_block(x, filters, conv_layers=2):
        """Conv-BN-ReLU projection followed by ``conv_layers - 1`` Conv-BN-ReLU layers added back onto it."""
        x = Conv2D(filters, kernel_size=(3, 3), strides=(1, 1), padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

        if conv_layers < 2:
            # No residual branch: Add([x, x]) would merely double the activations.
            return x

        d = x
        for _ in range(conv_layers - 1):
            d = Conv2D(filters, kernel_size=(3, 3), strides=(1, 1), padding='same')(d)
            d = BatchNormalization()(d)
            d = Activation('relu')(d)

        return Add()([d, x])

    def deconv2d(layer_input, filters):
        """2x up-sampling with a strided transposed conv -> BatchNorm -> ReLU."""
        u = Conv2DTranspose(filters, 2, strides=(2, 2), padding='same')(layer_input)
        u = BatchNormalization()(u)
        u = Activation('relu')(u)
        return u

    # Keras image tensors are (rows, cols, channels) = (height, width, channels).
    inputs = Input(shape=(height, width, input_channels))

    # Encoder.
    conv1 = residual_block(inputs, filters, conv_layers=conv_layers)
    pool1 = MaxPooling2D((2, 2))(conv1)

    conv2 = residual_block(pool1, filters * 2, conv_layers=conv_layers)
    pool2 = MaxPooling2D((2, 2))(conv2)

    conv3 = residual_block(pool2, filters * 4, conv_layers=conv_layers)
    pool3 = MaxPooling2D((2, 2))(conv3)

    conv4 = residual_block(pool3, filters * 8, conv_layers=conv_layers)
    pool4 = MaxPooling2D((2, 2))(conv4)

    # Bottleneck.
    conv5 = residual_block(pool4, filters * 16, conv_layers=conv_layers)

    # Decoder: up-sample, concatenate the matching encoder features, residual block.
    conv6 = deconv2d(conv5, filters * 8)
    up6 = concatenate([conv6, conv4])
    up6 = residual_block(up6, filters * 8, conv_layers=conv_layers)

    conv7 = deconv2d(up6, filters * 4)
    up7 = concatenate([conv7, conv3])
    up7 = residual_block(up7, filters * 4, conv_layers=conv_layers)

    conv8 = deconv2d(up7, filters * 2)
    up8 = concatenate([conv8, conv2])
    up8 = residual_block(up8, filters * 2, conv_layers=conv_layers)

    conv9 = deconv2d(up8, filters)
    up9 = concatenate([conv9, conv1])
    up9 = residual_block(up9, filters, conv_layers=conv_layers)

    output_layer_noActi = Conv2D(output_channels, (1, 1), padding="same", activation=None)(up9)
    outputs = Activation(output_activation)(output_layer_noActi)

    model = Model(inputs=inputs, outputs=outputs)

    return model

## attention-UNet

In [None]:
def attention_unet(filters, output_channels, width=None, height=None, input_channels=1, conv_layers=2,
                   output_activation='sigmoid'):
    """Build a U-Net whose skip connections pass through additive attention gates.

    Args:
        filters: channels at the first level; doubled at each deeper level.
        output_channels: channels of the final 1x1 convolution.
        width: input width in pixels, or None for a size-agnostic model.
        height: input height in pixels, or None for a size-agnostic model.
            Known sizes must be divisible by 16 (four 2x poolings).
        input_channels: channels of the input image.
        conv_layers: Conv-BN-ReLU repetitions at every level.
        output_activation: activation of the final 1x1 convolution
            (``'sigmoid'`` by default).

    Returns:
        An uncompiled ``keras.Model``. Its last layer is the 1x1 output
        convolution; the layer before it is the decoder's final ReLU.
    """
    def conv2d(layer_input, filters, conv_layers=2):
        """``conv_layers`` x (3x3 Conv -> BatchNorm -> ReLU), keeping the spatial size."""
        d = Conv2D(filters, kernel_size=(3, 3), strides=(1, 1), padding='same')(layer_input)
        d = BatchNormalization()(d)
        d = Activation('relu')(d)

        for _ in range(conv_layers - 1):
            d = Conv2D(filters, kernel_size=(3, 3), strides=(1, 1), padding='same')(d)
            d = BatchNormalization()(d)
            d = Activation('relu')(d)

        return d

    def deconv2d(layer_input, filters):
        """2x up-sampling with a strided transposed conv -> BatchNorm -> ReLU."""
        u = Conv2DTranspose(filters, 2, strides=(2, 2), padding='same')(layer_input)
        u = BatchNormalization()(u)
        u = Activation('relu')(u)
        return u

    def attention_block(F_g, F_l, F_int):
        """Additive attention gate: scale skip features ``F_l`` by a [0, 1] map.

        The map is ``sigmoid(1x1conv(relu(BN(1x1conv(F_g)) + BN(1x1conv(F_l)))))``
        with a single channel, computed from the gating signal ``F_g`` and ``F_l``.
        """
        g = Conv2D(F_int, kernel_size=(1, 1), strides=(1, 1), padding='valid')(F_g)
        g = BatchNormalization()(g)
        x = Conv2D(F_int, kernel_size=(1, 1), strides=(1, 1), padding='valid')(F_l)
        x = BatchNormalization()(x)
        psi = Add()([g, x])
        psi = Activation('relu')(psi)

        psi = Conv2D(1, kernel_size=(1, 1), strides=(1, 1), padding='valid')(psi)
        psi = Activation('sigmoid')(psi)

        return Multiply()([F_l, psi])

    # Keras image tensors are (rows, cols, channels) = (height, width, channels).
    inputs = Input(shape=(height, width, input_channels))

    # Encoder.
    conv1 = conv2d(inputs, filters, conv_layers=conv_layers)
    pool1 = MaxPooling2D((2, 2))(conv1)

    conv2 = conv2d(pool1, filters * 2, conv_layers=conv_layers)
    pool2 = MaxPooling2D((2, 2))(conv2)

    conv3 = conv2d(pool2, filters * 4, conv_layers=conv_layers)
    pool3 = MaxPooling2D((2, 2))(conv3)

    conv4 = conv2d(pool3, filters * 8, conv_layers=conv_layers)
    pool4 = MaxPooling2D((2, 2))(conv4)

    # Bottleneck.
    conv5 = conv2d(pool4, filters * 16, conv_layers=conv_layers)

    # Decoder: up-sample, gate the encoder features with the up-sampled signal,
    # concatenate both, convolve.
    up6 = deconv2d(conv5, filters * 8)
    att6 = attention_block(up6, conv4, filters * 8)
    merged6 = Concatenate()([up6, att6])
    conv6 = conv2d(merged6, filters * 8, conv_layers=conv_layers)

    up7 = deconv2d(conv6, filters * 4)
    att7 = attention_block(up7, conv3, filters * 4)
    merged7 = Concatenate()([up7, att7])
    conv7 = conv2d(merged7, filters * 4, conv_layers=conv_layers)

    up8 = deconv2d(conv7, filters * 2)
    att8 = attention_block(up8, conv2, filters * 2)
    merged8 = Concatenate()([up8, att8])
    conv8 = conv2d(merged8, filters * 2, conv_layers=conv_layers)

    up9 = deconv2d(conv8, filters)
    att9 = attention_block(up9, conv1, filters)
    merged9 = Concatenate()([up9, att9])
    conv9 = conv2d(merged9, filters, conv_layers=conv_layers)

    outputs = Conv2D(output_channels, kernel_size=(1, 1), strides=(1, 1), activation=output_activation)(conv9)

    model = Model(inputs=inputs, outputs=outputs)

    return model

# Training

## parameters

In [None]:
NAME = 'UNET_WCCE_G32_K3'   # Custom run name; becomes part of the checkpoint file name.
BATCH = 5            # Batch size.
PATCH = 320          # Side length of the square random crops.
EPOCHS = 150         # Maximum number of training epochs.
# CLASSES = ['buildings']                                                 # Localization
CLASSES = ['no-damage', 'minor-damage', 'major-damage', 'destroyed']    # Classification
# One output channel for binary localization; otherwise one channel per class plus background.
n_classes = 1 if len(CLASSES) == 1 else (len(CLASSES) + 1)

TASK = 'localization' if n_classes == 1 else 'classification'

# Where to save the model checkpoints.
save_name = f'/content/drive/MyDrive/model_{NAME}_{TASK}_{BATCH}_{PATCH}_{dt.datetime.now().strftime("%d%m%Y-%H%M%S")}.h5'

# The generators must emit targets for the same task the model is built for;
# otherwise a localization run would receive 5-channel one-hot damage targets.
parameters = {
    'shuffle': True,
    'batch_size': BATCH,
    'patch_size': PATCH,
    'classification': TASK == 'classification',
}

train_generator = xBD_DataGenerator(path_to_jsons=training_files, **parameters)
test_generator = xBD_DataGenerator(path_to_jsons=testing_files, **parameters)
valid_generator = xBD_DataGenerator(path_to_jsons=validation_files, **parameters)

## loss functions

https://github.com/shruti-jadon/Semantic-Segmentation-Loss-Functions

In [None]:
from keras.losses import binary_crossentropy, BinaryCrossentropy, CategoricalCrossentropy, categorical_crossentropy

# Module-level hyper-parameters looked up by Semantic_loss_functions at call time.
beta = 0.25      # NOTE(review): not referenced by any loss visible in this notebook
alpha = 0.25     # focal loss: weight of the positive (target == 1) term
gamma = 2        # focal loss: focusing exponent on (1 - p) / p
epsilon = 1e-5   # NOTE(review): not referenced; the losses use K.epsilon() instead
smooth = 1       # additive smoothing in tversky_index


class Semantic_loss_functions(object):
    """Segmentation losses and overlap metrics usable as Keras ``loss``/``metrics`` callables.

    Every method takes ``(y_true, y_pred)``; ``y_pred`` is assumed to hold
    probabilities in [0, 1] (TODO confirm for new callers). Besides ``K``, the
    methods look up the module-level names ``tf``, ``alpha``, ``gamma`` and
    ``smooth`` when called, so those must be defined before training starts.
    """

    def __init__(self):
        print("semantic loss functions initialized")

    def dice_coef(self, y_true, y_pred):
        """Soft Dice coefficient over every element of the batch (1.0 when both inputs are empty)."""
        y_true_f = K.flatten(y_true)
        y_pred_f = K.flatten(y_pred)
        intersection = K.sum(y_true_f * y_pred_f)
        return (2. * intersection + K.epsilon()) / (K.sum(y_true_f) + K.sum(y_pred_f) + K.epsilon())

    def dice_loss(self, y_true, y_pred):
        """``1 - dice_coef``."""
        return 1 - self.dice_coef(y_true, y_pred)

    def focal_loss_with_logits(self, logits, targets, alpha, gamma, y_pred):
        """Element-wise focal loss.

        ``log1p(exp(-|logits|)) + relu(-logits)`` is a numerically stable
        ``-log(sigmoid(logits))``. With ``p = y_pred = sigmoid(logits)``
        (as ``focal_loss`` arranges), the result is
        ``alpha * (1 - p)**gamma * -log(p)`` where ``targets == 1`` and
        ``(1 - alpha) * p**gamma * -log(1 - p)`` where ``targets == 0``.
        """
        weight_a = alpha * (1 - y_pred) ** gamma * targets
        weight_b = (1 - alpha) * y_pred ** gamma * (1 - targets)
        return (tf.math.log1p(tf.exp(-tf.abs(logits))) + tf.nn.relu(-logits)) * (weight_a + weight_b) + logits * weight_b

    def focal_loss(self, y_true, y_pred):
        """Mean focal loss, using the module-level ``alpha`` and ``gamma``."""
        # Clip so the logit transform below stays finite.
        y_pred = tf.clip_by_value(y_pred, tf.keras.backend.epsilon(), 1 - tf.keras.backend.epsilon())
        logits = tf.math.log(y_pred / (1 - y_pred))
        loss = self.focal_loss_with_logits(logits=logits, targets=y_true, alpha=alpha, gamma=gamma, y_pred=y_pred)
        return tf.reduce_mean(loss)

    def tversky_index(self, y_true, y_pred):
        """Tversky index with false negatives weighted 0.7 and false positives 0.3 (``smooth``-ed)."""
        y_true_pos = K.flatten(y_true)
        y_pred_pos = K.flatten(y_pred)
        true_pos = K.sum(y_true_pos * y_pred_pos)
        false_neg = K.sum(y_true_pos * (1 - y_pred_pos))
        false_pos = K.sum((1 - y_true_pos) * y_pred_pos)
        # Local weight; deliberately distinct from the module-level focal-loss `alpha`.
        tversky_alpha = 0.7
        return (true_pos + smooth) / (true_pos + tversky_alpha * false_neg + (1 - tversky_alpha) * false_pos + smooth)

    def tversky_loss(self, y_true, y_pred):
        """``1 - tversky_index``."""
        return 1 - self.tversky_index(y_true, y_pred)

    def jacard_similarity(self, y_true, y_pred):
        """Soft Intersection-over-Union (Jaccard index) over every element of the batch.

        Smoothed with ``K.epsilon()`` like ``dice_coef`` so that a batch where
        both masks are empty yields 1.0 instead of 0/0 = NaN.
        """
        y_true_f = K.flatten(y_true)
        y_pred_f = K.flatten(y_pred)

        intersection = K.sum(y_true_f * y_pred_f)
        union = K.sum((y_true_f + y_pred_f) - (y_true_f * y_pred_f))
        return (intersection + K.epsilon()) / (union + K.epsilon())

    def jacard_loss(self, y_true, y_pred):
        """``1 - jacard_similarity``."""
        return 1 - self.jacard_similarity(y_true, y_pred)

    def ssim_loss(self, y_true, y_pred):
        """Per-image ``1 - SSIM`` (single-scale SSIM, ``max_val=1``)."""
        return 1 - tf.image.ssim(y_true, y_pred, max_val=1)

    def unet3p_hybrid_loss(self, y_true, y_pred):
        """
        Hybrid loss proposed in UNET 3+ (https://arxiv.org/ftp/arxiv/papers/2004/2004.08790.pdf):
        focal + SSIM + Jaccard, combining pixel-, patch- and map-level terms.
        Note that plain SSIM is used here rather than the paper's MS-SSIM.
        """
        focal_term = self.focal_loss(y_true, y_pred)
        ssim_term = self.ssim_loss(y_true, y_pred)
        jacard_term = self.jacard_loss(y_true, y_pred)
        return focal_term + ssim_term + jacard_term

    def basnet_hybrid_loss(self, y_true, y_pred):
        """
        Hybrid loss proposed in BASNET (https://arxiv.org/pdf/2101.04704.pdf):
        binary cross-entropy + SSIM + Jaccard (pixel-, patch- and map-level terms).
        """
        bce_term = BinaryCrossentropy(from_logits=False)(y_true, y_pred)
        ssim_term = self.ssim_loss(y_true, y_pred)
        jacard_term = self.jacard_loss(y_true, y_pred)
        return bce_term + ssim_term + jacard_term


def weighted_categorical_crossentropy(weights):
    """Build a class-weighted categorical cross-entropy loss for Keras.

    Args:
        weights: sequence/array of shape ``(C,)``, one weight per class. It should
            match the last axis of ``y_true``/``y_pred``; otherwise it broadcasts.

    Returns:
        ``loss(y_true, y_pred)``, returning ``-sum_c w_c * y_true_c * log(p_c)`` over
        the last axis, where ``p`` is ``y_pred`` renormalized to sum to 1 and clipped
        to ``[eps, 1 - eps]``.

    Example:
        loss = weighted_categorical_crossentropy(np.array([0.5, 2, 10]))
        model.compile(loss=loss, optimizer='adam')
    """
    class_weights = K.variable(weights)

    def loss(y_true, y_pred):
        # Renormalize so the class scores along the last axis sum to one.
        probs = y_pred / K.sum(y_pred, axis=-1, keepdims=True)
        # Keep log() finite.
        probs = K.clip(probs, K.epsilon(), 1 - K.epsilon())
        weighted_log_likelihood = y_true * K.log(probs) * class_weights
        return -K.sum(weighted_log_likelihood, -1)

    return loss

In [None]:
# Shared instance; its bound methods are used as losses/metrics (e.g. jacard_similarity).
semantic_loss = Semantic_loss_functions()
# loss=semantic_loss.unet3p_hybrid_loss
# learning_rate_reduction = ReduceLROnPlateau(monitor='val_dice_coef', patience=10, verbose=1, factor=0.5, min_lr=0.0000001)

semantic loss functions initialized


## metrics

In [None]:
# Different metrics
def recall_m(y_true, y_pred):
    """Return batch recall: matched positives / (ground-truth positives + eps).

    An element counts as matched when ``clip(y_true * y_pred, 0, 1)`` rounds to 1.
    It counts as a ground-truth positive when ``clip(y_true, 0, 1)`` rounds to 1.
    Counts are summed over every element of the batch, across all channels.
    """
    matched = K.round(K.clip(y_true * y_pred, 0, 1))
    actual = K.round(K.clip(y_true, 0, 1))
    return K.sum(matched) / (K.sum(actual) + K.epsilon())

def precision_m(y_true, y_pred):
    """Return batch precision: matched positives / (predicted positives + eps).

    An element counts as matched when ``clip(y_true * y_pred, 0, 1)`` rounds to 1.
    It counts as a predicted positive when ``clip(y_pred, 0, 1)`` rounds to 1.
    Counts are summed over every element of the batch, across all channels.
    """
    matched = K.round(K.clip(y_true * y_pred, 0, 1))
    predicted = K.round(K.clip(y_pred, 0, 1))
    return K.sum(matched) / (K.sum(predicted) + K.epsilon())

def f1_m(y_true, y_pred):
    """Return the eps-guarded harmonic mean of ``precision_m`` and ``recall_m``."""
    p = precision_m(y_true, y_pred)
    r = recall_m(y_true, y_pred)
    return 2 * ((p * r) / (p + r + K.epsilon()))

## callbacks

In [None]:
def decay_schedule(epoch, lr):
    """Learning-rate schedule for ``LearningRateScheduler``.

    Returns 1e-4 for epochs 0-4 and 1e-5 from epoch 5 on; the incoming ``lr`` is
    ignored. From epoch 5 on it also sets the module-level
    ``localization_branch.trainable`` to True.

    NOTE(review): Keras only applies ``trainable`` changes on the next
    ``compile()``. Flipping the flag from inside a running ``fit`` therefore
    probably does not unfreeze the branch — confirm before relying on it.
    """
    if epoch < 5:
        return 0.0001
    localization_branch.trainable = True
    return 0.00001

# Loss: class-weighted categorical cross-entropy, one weight per output channel.
# Other losses that can be tested instead:
#   'binary_crossentropy', 'categorical_crossentropy',
#   semantic_loss.dice_loss + CategoricalCrossentropy(from_logits=True),
#   [semantic_loss.focal_loss, semantic_loss.dice_loss]
CLASS_WEIGHTS = [0.5, 2, 8, 7, 8]
# A length mismatch does not fail loudly: for a 1-channel (localization) model the
# weight vector would silently broadcast across the channel axis. Fail early instead.
if len(CLASS_WEIGHTS) != n_classes:
    raise ValueError(
        f'CLASS_WEIGHTS has {len(CLASS_WEIGHTS)} entries but the model has {n_classes} output channels')
total_loss = weighted_categorical_crossentropy(CLASS_WEIGHTS)

lr_scheduler = LearningRateScheduler(decay_schedule)
# NOTE: this initial learning rate is replaced by decay_schedule at the start of every epoch.
optimizer = keras.optimizers.AdamW(learning_rate=0.001)     # AdamW
metrics = ['accuracy', precision_m, recall_m, f1_m, semantic_loss.jacard_similarity]

my_callbacks = [
    # Keep only the weights with the lowest validation loss.
    keras.callbacks.ModelCheckpoint(save_name, monitor='val_loss', save_weights_only=True, save_best_only=True, mode='min'),
    keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, start_from_epoch=35, patience=10, restore_best_weights=True),
    # keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.33, patience=5, mode='min'),
    lr_scheduler,
]

## localization branch

In [None]:
import tensorflow as tf
# NOTE(review): under TF2 eager execution this resets only the TF1-style default graph.
# It presumably does not reset Keras global state such as auto-generated layer names
# (keras.backend.clear_session() does that) — confirm if layer names matter.
tf.compat.v1.reset_default_graph()

# Pretrained localization U-Net (1 output channel). The checkpoint name records the
# batch/patch it was trained with (10/352). Building it at PATCH=320 is fine because
# the network is fully convolutional.
# NOTE(review): .h5 weight loading is presumably order-based, so the architecture
# arguments must match the ones used when the checkpoint was saved.
localization_model = base_unet(32, 1, width=PATCH, height=PATCH, input_channels=3, conv_layers=4)
localization_model.load_weights('/content/drive/MyDrive/model_BASE-UNET_localization_10_352_25082023-154620.h5')

# localization_model = attention_unet(32, 1, width=PATCH, height=PATCH, input_channels=3, conv_layers=4)
# localization_model.load_weights('/content/drive/MyDrive/model_ATTENTION-UNET_localization_10_352_26082023-212616.h5')

# localization_model = residual_unet(32, 1, width=PATCH, height=PATCH, input_channels=3, conv_layers=4)
# localization_model.load_weights('/content/drive/MyDrive/model_RESIDUAL-UNET_localization_10_352_25082023-174256.h5')

# Define the localization branch: the pretrained model truncated before its final layer.
# For base_unet / attention_unet, layers[-2] is the decoder's last ReLU (`filters` channels).
# NOTE(review): for residual_unet, layers[-2] is the 1x1 output conv (1-channel logits),
# not a feature map — pick the layer by name instead when using that variant.
# localization_branch = Model(inputs=localization_model.inputs, outputs=localization_model.get_layer('activation_47').output)
localization_branch = Model(inputs=localization_model.inputs, outputs=localization_model.layers[-2].output)
# Frozen; the same branch object is later applied to both pre- and post-disaster inputs (shared weights).
localization_branch.trainable = False   # SHARED WEIGHTS!!!

## ⚖ Selective Kernel Module (SKConv)

In [None]:
def SKConv(M=2, r=16, L=32, G=32, name='skconv'):
    """Selective Kernel convolution (SKNet) as a functional-API wrapper.

    Returns ``wrapper(inputs)``, which maps a 4-D tensor ``(batch, h, w, C)`` with
    static ``h``, ``w`` and ``C`` to a tensor of the same shape:

    1. *Split*: ``M`` branches of 3x3 convs with dilation ``m + 1`` (grouped
       into ``G`` groups when ``G > 1``), each followed by BatchNorm + ReLU.
    2. *Fuse*: branches are summed, global-average-pooled and squeezed to
       ``d = max(C // r, L)`` channels.
    3. *Select*: per-channel softmax weights over the ``M`` branches,
       applied by ``Axpby``.

    Args:
        M: number of branches.
        r: reduction ratio of the squeeze step.
        L: minimum width of the squeeze step.
        G: number of convolution groups (1 = ordinary ``Conv2D``).
        name: prefix for all layer names; must be unique within a model.

    Raises:
        ValueError: if ``G > 1`` and ``C`` is not a positive multiple of ``G``.
    """
    def wrapper(inputs):
        # Static spatial size / channel count; the Reshape layers below need concrete values.
        h, w = inputs.shape[1], inputs.shape[2]
        filters = inputs.shape[-1]
        if G > 1 and (filters < G or filters % G != 0):
            raise ValueError(
                f'{name}: input channels ({filters}) must be a positive multiple of G ({G})')
        d = max(filters // r, L)      # width of the squeeze (fc_z) layer

        branches = []
        for m in range(M):
            if G == 1:
                _x = Conv2D(filters, kernel_size=3, dilation_rate=m+1, padding='same', use_bias=False,
                            name=f'{name}_conv{m}')(inputs)
            else:
                c = filters // G   # channels per group
                # Grouped conv emulated with a depthwise conv. Output channel k*c + q
                # is input channel k convolved with kernel q. Reshaping to
                # (G, c_in, c_out) and summing over the *input* axis (-2) mixes the
                # c inputs of each group into c outputs. Summing over -1 instead
                # would collapse the kernels into a plain depthwise conv.
                _x = DepthwiseConv2D(kernel_size=3, dilation_rate=m+1, padding='same', use_bias=False,
                                     depth_multiplier=c, name=f'{name}_conv{m}')(inputs)
                _x = Reshape([h, w, G, c, c], name=f'{name}_conv{m}_reshape1')(_x)
                # Lambda's output_shape excludes the batch dimension.
                _x = Lambda(lambda t: tf.reduce_sum(t, axis=-2), output_shape=(h, w, G, c),
                            name=f'{name}_conv{m}_sum')(_x)
                _x = Reshape([h, w, filters], name=f'{name}_conv{m}_reshape2')(_x)

            _x = BatchNormalization(name=f'{name}_conv{m}_bn')(_x)
            _x = Activation('relu', name=f'{name}_conv{m}_relu')(_x)
            branches.append(_x)

        # Fuse: sum the branches and squeeze to a per-channel descriptor.
        U = Add(name=name+'_add')(branches)
        s = Lambda(lambda t: tf.reduce_mean(t, axis=[1, 2], keepdims=True), output_shape=(1, 1, filters),
                   name=name+'_gap')(U)

        z = Conv2D(d, 1, name=name+'_fc_z')(s)
        z = BatchNormalization(name=name+'_fc_z_bn')(z)
        z = Activation('relu', name=name+'_fc_z_relu')(z)

        # Select: one softmax over the M branches per channel.
        x = Conv2D(filters*M, 1, name=name+'_fc_x')(z)
        x = Reshape([1, 1, filters, M], name=name+'_reshape')(x)
        scale = Softmax(name=name+'_softmax')(x)

        stacked = Lambda(lambda t: tf.stack(t, axis=-1), output_shape=(h, w, filters, M),
                         name=name+'_stack')(branches)   # (b, h, w, C, M)
        return Axpby(name=name+'_axpby')([scale, stacked])
    return wrapper


class Axpby(Layer):
    """Fuse stacked SK branches with a per-channel weighted sum over the branch axis.

    Expects ``inputs`` to be the pair ``[scale, x]``:

    * ``scale``: branch weights, shape [B, 1, 1, C, M] (in SKConv these come
      from a ``Softmax`` over the last axis, i.e. over the M branches).
    * ``x``: stacked branch outputs, shape [B, H, W, C, M].

    ``scale`` broadcasts over H and W. The element-wise product is summed over
    the last (branch) axis, so the output has shape [B, H, W, C].
    """

    def call(self, inputs):
        scale, x = inputs
        weighted = tf.multiply(scale, x, name='product')
        return tf.reduce_sum(weighted, axis=-1, name='sum')

    def compute_output_shape(self, input_shape):
        # `input_shape` is the list [scale_shape, x_shape]. Slicing that list
        # (input_shape[0:4]) returns both input shapes, not an output shape.
        # The output is x's shape with the trailing branch axis (M) reduced away.
        x_shape = input_shape[1]
        return tuple(x_shape)[:-1]

## 🥂 Construct the siamese network + segmentation head

In [None]:
# Siamese damage classifier. The pre- and post-disaster images each go through
# `localization_branch`, the two feature maps are fused, and a small
# segmentation head predicts a per-pixel distribution over `n_classes`.
# NOTE(review): the two branches share weights only if `localization_branch` is a
# single Model/Layer instance (calling one instance twice reuses its weights).
# If it is a builder function, each call creates new layers. Confirm.
# Keep the statement order: unnamed layers get auto-generated names from their
# creation order, and HDF5 `load_weights` (by_name=False) matches weights by topology.
input_pre = Input(shape=(PATCH, PATCH, 3), name="pre_input")
output_pre = localization_branch(input_pre)

input_post = Input(shape=(PATCH, PATCH, 3), name="post_input")
output_post = localization_branch(input_post)


# Segmentation head (configurable). Fuse both branches by concatenation along
# the channel axis (Keras `Concatenate` defaults to axis=-1).
head = Concatenate()([output_pre, output_post])

# Selective Kernel (SK) block on the fused features: M=2 dilated 3x3 branches,
# grouped depthwise path (G=32), middle width max(channels // r, L) with r=16, L=32.
# NOTE(review): the G > 1 path reshapes to [h, w, G, c, c] with c = channels // G,
# so the concatenated channel count presumably must be divisible by 32. Confirm.
# NOTE(review): the old marker here read "WITHOUT", yet SKConv IS applied.
# Confirm which variant produced the saved weights.
head = SKConv(M=2, r=16, L=32, G=32)(head)

# Refinement: BN -> ReLU -> 3x3 conv (8 filters) -> BN -> ReLU.
head = BatchNormalization()(head)
head = Activation("relu")(head)
head = Conv2D(8, (3, 3), padding='same', kernel_initializer=he_normal(), name='middle_conv')(head)
head = BatchNormalization()(head)
head = Activation("relu")(head)
# Per-pixel class scores: one output channel per class.
head = Conv2D(n_classes, (3, 3), padding='same', kernel_initializer=he_normal(), name='class_conv')(head)
output = Activation("softmax")(head)    # softmax over the last (class) axis; the 1st-place xView2 winner used sigmoid here instead.

classification_model = Model([input_pre, input_post], output)                   # CLASSIFICATION MODEL: inputs [pre, post] -> per-pixel class probabilities

* If your model’s output classes are NOT mutually exclusive and you can choose many of them at the same time, use a `sigmoid` function on the network’s raw outputs.
* If your model’s output classes are mutually exclusive and you can only choose one, then use a `softmax` function on the network’s raw outputs.

## compile

In [None]:
# Compile the two-input classifier. `optimizer`, `total_loss` and `metrics`
# are defined elsewhere in the notebook.
classification_model.compile(
    optimizer=optimizer,
    loss=total_loss,
    metrics=metrics,
)

## Fit

For imbalanced classes, see [class_weight](https://datascience.stackexchange.com/questions/13490/how-to-set-class-weights-for-imbalanced-classes-in-keras). Note that `class_weight` is not passed to `fit` below.

In [None]:
# Train on the training generator, validating on the validation generator
# after every epoch.
# NOTE(review): `workers` / `use_multiprocessing` are accepted by Keras 2
# (tf.keras) `Model.fit`. Keras 3 moved these options to the dataset
# (`PyDataset`) constructor. Confirm the installed version.
history = classification_model.fit(
    train_generator,
    epochs=EPOCHS,
    validation_data=valid_generator,
    callbacks=my_callbacks,
    workers=6,
    use_multiprocessing=True,
)

## show classification outputs

In [None]:
# Restore trained classifier weights from a saved .h5 checkpoint.
classification_model.load_weights(
    filepath='/content/drive/MyDrive/model_BASE-UNET_classification_4_256_31082023-215407.h5',
)

In [None]:
# Fetch test batch 20 (inputs are the (pre, post) image pair) and run the
# two-input classifier on it.
(x_pre, x_post), y = test_generator[20]
print(x_pre.shape, y.shape)

a = classification_model.predict((x_pre, x_post))
print(a.shape)

(5, 320, 320, 3) (5, 320, 320, 5)
(5, 320, 320, 5)


In [None]:
# Qualitative check of the classifier on one test batch. Each row is one sample:
# pre image, post image, ground-truth class map, then one probability map per
# predicted class channel.
n_rows = min(5, len(x_pre))     # at most the first 5 samples of the batch
n_pred = a.shape[-1]            # number of predicted class channels
n_cols = 3 + n_pred

plt.figure(figsize=(20, 20))
for i in range(n_rows):
    # Mask channel 0 on a copy. Writing `y[i, :, :, 0] = 0` would silently
    # alter the ground truth that later cells reuse.
    gt = y[i].copy()
    gt[:, :, 0] = 0
    panels = [
        (x_pre[i].astype(int), 'pre image', {}),
        (x_post[i].astype(int), 'post image', {}),
        (np.argmax(gt, axis=-1), 'GT', {'vmin': 0, 'vmax': y.shape[-1] - 1}),
    ]
    # The model ends in softmax, so the outputs lie in [0, 1]. Use a fixed
    # colour range so channels are comparable instead of each map being
    # stretched to its own min/max.
    panels += [(a[i, :, :, k], '--', {'vmin': 0, 'vmax': 1}) for k in range(n_pred)]
    for col, (img, title, norm) in enumerate(panels, start=1):
        plt.subplot(n_rows, n_cols, n_cols * i + col)
        plt.imshow(img, **norm)
        plt.xticks([])
        plt.yticks([])
        plt.title(title)
plt.tight_layout()
plt.show()

In [None]:
# Side-by-side view per sample: pre image, post image, GT class map, predicted
# class map (argmax), then probability maps for class channels 1-3.
n_rows = min(5, len(x_pre))     # at most the first 5 samples of the batch
# GT and prediction share one fixed class range so a colour means the same
# class in both panels. Without vmin/vmax, each map is scaled to its own min/max.
class_norm = {'vmin': 0, 'vmax': a.shape[-1] - 1}
prob_norm = {'vmin': 0, 'vmax': 1}

plt.figure(figsize=(20, 9))
for i in range(n_rows):
    panels = [
        (x_pre[i].astype(int), 'pre image', {}),
        (x_post[i].astype(int), 'post image', {}),
        (np.argmax(y[i], axis=-1), 'GT', class_norm),
        (np.argmax(a[i], axis=-1), '--', class_norm),
    ]
    # NOTE(review): only probability channels 1-3 are plotted (channel 0 and
    # any higher channels are not). This keeps the original selection; confirm.
    panels += [(a[i, :, :, k], '--', prob_norm) for k in (1, 2, 3)]
    for col, (img, title, norm) in enumerate(panels, start=1):
        plt.subplot(n_rows, 7, 7 * i + col)
        plt.imshow(img, **norm)
        plt.xticks([])
        plt.yticks([])
        plt.title(title)
plt.tight_layout()
plt.show()

## show localization outputs
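To run these cells, the localization data generator must be defined first, and `test_generator` must be re-created from it. These cells unpack a single input `x`, not the `[x_pre, x_post]` pair used above.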

In [None]:
# Build the three localization U-Net variants with identical hyper-parameters,
# then restore each one's trained weights from its .h5 checkpoint.
# NOTE(review): the positional args (32, 1) are presumably base filters and
# output channels. Confirm against the builder definitions.
_unet_kwargs = dict(width=PATCH, height=PATCH, input_channels=3, conv_layers=4)

localization_model_base_unet = base_unet(32, 1, **_unet_kwargs)
localization_model_residual_unet = residual_unet(32, 1, **_unet_kwargs)
localization_model_attention_unet = attention_unet(32, 1, **_unet_kwargs)

localization_model_base_unet.load_weights(
    '/content/drive/MyDrive/model_BASE-UNET_localization_10_352_25082023-154620.h5')
localization_model_residual_unet.load_weights(
    '/content/drive/MyDrive/model_RESIDUAL-UNET_localization_10_352_25082023-174256.h5')
localization_model_attention_unet.load_weights(
    '/content/drive/MyDrive/model_ATTENTION-UNET_localization_10_352_26082023-212616.h5')

In [None]:
# Run all three localization models on test batch 23 and report array shapes.
x, y = test_generator[23]
print(x.shape, y.shape)

a = localization_model_base_unet.predict(x)
print(a.shape)

b = localization_model_residual_unet.predict(x)
print(b.shape)

c = localization_model_attention_unet.predict(x)
print(c.shape)

In [None]:
# Compare the three localization models on one batch. Each row shows the pre
# image, the GT mask (channel 0), then each model's channel-0 output.
# Cap the rows at the batch size: the localization generator defaults to
# batch_size=5, and a fixed range(10) raises IndexError on a smaller batch.
n_rows = min(10, len(x))

plt.figure(figsize=(9, 20))
for i in range(n_rows):
    panels = [
        (x[i].astype(int), 'pre image'),
        (y[i, :, :, 0], 'GT'),
        (a[i, :, :, 0], 'UNet'),
        (b[i, :, :, 0], 'ResUNet'),
        (c[i, :, :, 0], 'a-UNet'),
    ]
    for col, (img, title) in enumerate(panels, start=1):
        plt.subplot(n_rows, 5, 5 * i + col)
        plt.imshow(img)
        plt.xticks([])
        plt.yticks([])
        plt.title(title)
plt.tight_layout()
plt.show()