As introduced by Gamper et al. (2019) [1], the PanNuke dataset is an open pan-cancer histology dataset for nuclei instance segmentation and classification.

In Gamper et al. (2020) [2], the PanNuke dataset was extended with additional annotations and features to support more advanced applications of nuclei classification and segmentation.

---

## References

[1] J. Gamper, N. Alemi Koohbanani, K. Benes, A. Khurram, and N. Rajpoot, "PanNuke: an open pan-cancer histology dataset for nuclei instance segmentation and classification," in *European Congress on Digital Pathology*, Springer, 2019, pp. 11-19. doi: 10.1007/978-3-030-25970-8_2

[2] J. Gamper, N. Alemi Koohbanani, S. Graham, M. Jahanifar, S. A. Khurram, A. Azam, K. Hewitt, and N. Rajpoot, "PanNuke Dataset Extension, Insights and Baselines," *arXiv preprint arXiv:2003.10778*, 2020.


In [None]:
from google.colab import drive

# Mount Google Drive at /content/drive; the dataset archives extracted in the
# next cell are read from /content/drive/MyDrive.
drive.mount("/content/drive")

In [None]:
# Get the data: extract a dataset archive from the mounted Google Drive.
# The two commented-out commands extract other archives; only the
# "pannuke_instances" archive is extracted by this cell.

#!unzip "/content/drive/MyDrive/pannuke_processed.zip"
#!unzip "/content/drive/MyDrive/pannuke_over_times_3.zip" -d "/content/pannuke_over_times_3"
!unzip "/content/drive/MyDrive/pannuke_instances.zip" -d "/content/pannuke_over_times_3"

In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from skimage.transform import resize
import cv2
import os
from tqdm import tqdm
from keras.utils import to_categorical
from PIL import Image


#---------------------------------------------------------------#
# Same code as in hrnet_v2_ocr_pannuke_multi, but with different#
# training methods, data, and data preparation.                 #
#---------------------------------------------------------------#
class Data:
    """Namespace of stateless helpers for loading, resizing and encoding data.

    Every helper is a static method and is meant to be called on the class
    itself, e.g. ``Data.check(x, y)``.
    """

    @staticmethod
    def load_npy(path, r, size):
        """Load the first ``r`` entries of a ``.npy`` file, resized to ``size`` x ``size``.

        Returns:
            A float32 array holding the resized entries.
        """
        data = np.load(path)[:r]
        resized_data = [cv2.resize(d, (size, size)) for d in data]
        return np.array(resized_data, dtype=np.float32)

    @staticmethod
    def generator(path, r, size, batch_size):
        """Yield consecutive batches of a ``.npy`` file, resized to ``size`` x ``size``.

        The first batch is always yielded. Iteration stops as soon as the next
        batch would end past index ``r``, so a trailing partial batch is not
        produced.
        """
        # Load once instead of re-reading the whole file for every batch.
        data_ = np.load(path)
        start = 0
        end = batch_size
        while True:
            batch = data_[start:end]
            yield np.array([cv2.resize(d, (size, size)) for d in batch])
            start = end
            end += batch_size
            if end > r:
                break

    @staticmethod
    def load_numpys(directory_path):
        """Load every ``.npy`` file in ``directory_path`` into one float32 array.

        Files are read in sorted filename order so the result does not depend
        on the arbitrary ordering of ``os.listdir``.
        """
        data = [
            np.load(os.path.join(directory_path, filename))
            for filename in sorted(os.listdir(directory_path))
            if filename.endswith(".npy")
        ]
        return np.array(data, dtype=np.float32)

    @staticmethod
    def pipe(x, y, size, num_classes=6, r=None, normalize=True):
        """Load image and mask ``.npy`` files and return model-ready arrays.

        Args:
            x: Path to the image ``.npy`` file.
            y: Path to the mask ``.npy`` file.
            size: Target height and width.
            num_classes: Number of one-hot channels built for each mask.
            r: If given, only the first ``r`` samples are processed.
            normalize: If true (default), images are divided by 255.

        Returns:
            ``(images, masks)``; masks are float32 one-hot arrays built from
            channel 0 of the resized mask.
        """
        images = np.load(x)
        masks = np.load(y)
        if r is not None:
            # The slices used to be computed and thrown away; actually keep them.
            images = images[:r]
            masks = masks[:r]

        target = (size, size)
        class_ids = np.arange(num_classes)
        preprocessed_images = []
        preprocessed_masks = []
        for i in tqdm(range(len(images)), colour="#d13516"):
            resized_image = cv2.resize(images[i], target)
            # Nearest-neighbour keeps the mask values discrete; interpolated
            # values would not equal any class id in the comparison below.
            resized_mask = cv2.resize(
                masks[i], target, interpolation=cv2.INTER_NEAREST
            )

            # One-hot encode channel 0 of the mask against the class ids.
            labels = resized_mask[:, :, 0]
            one_hot_mask = (labels[..., None] == class_ids).astype(np.float32)

            if normalize:
                resized_image = resized_image / 255.0

            preprocessed_images.append(resized_image)
            preprocessed_masks.append(one_hot_mask)

        preprocessed_images = np.array(preprocessed_images)
        preprocessed_masks = np.array(preprocessed_masks)
        print("Preprocessed images shape:", preprocessed_images.shape)
        print("Preprocessed masks shape:", preprocessed_masks.shape)
        return preprocessed_images, preprocessed_masks

    @staticmethod
    def normalize(array):
        """Return ``array`` divided by 255 (the input is not modified)."""
        return array / 255.0

    @staticmethod
    def masking(masks):
        """Binarise ``masks`` in place: every value greater than 0 becomes 1."""
        for i in range(len(masks)):
            masks[i] = np.where(masks[i] > 0, 1, masks[i])

    @staticmethod
    def display(img, mask, de_norm=True):
        """Show an image next to the first three channels of its mask.

        Args:
            img: Image array; multiplied by 255 first when ``de_norm`` is true.
            mask: Mask array with at least three dimensions.
            de_norm: Whether ``img`` must be scaled back by 255.
        """
        if de_norm:
            # Scale a copy; ``img *= 255.0`` would modify the caller's array.
            img = img * 255.0
        img = img.astype(np.uint8)
        mask = mask.astype(np.uint8)[:, :, :3]
        plt.subplot(1, 2, 1)
        plt.imshow(img)
        plt.title("image")
        plt.subplot(1, 2, 2)
        plt.imshow(mask)
        plt.title("mask")
        plt.show()

    @staticmethod
    def split(x, y):
        """Split ``x``/``y`` 80/20 into train and validation sets (seed 42).

        Returns:
            ``X_train, X_val, y_train, y_val``.
        """
        # ``random_state`` was misspelled, which made every call raise TypeError.
        return train_test_split(x, y, test_size=0.2, random_state=42)

    @staticmethod
    def load_data(path, r=0):
        """Load all image files in ``path`` (skipping the first ``r``) as float32.

        Files are read in sorted order, matching ``load_masks``, so images and
        masks with corresponding filenames stay aligned.
        """
        data = []
        for file in tqdm(sorted(os.listdir(path))[r:]):
            with Image.open(os.path.join(path, file)) as image:
                data.append(np.array(image))
        return np.array(data, dtype=np.float32)

    @staticmethod
    def load_resize(path, resize=None, r=0):
        """Like ``load_data`` but optionally resizes to ``resize`` x ``resize``."""
        shape = (resize, resize) if resize is not None else None

        data = []
        for file in tqdm(sorted(os.listdir(path))[r:]):
            with Image.open(os.path.join(path, file)) as image:
                resized = image.resize(shape) if shape is not None else image
                data.append(np.array(resized))

        return np.array(data, dtype=np.float32)

    @staticmethod
    def resize_data(data, shape=(150, 150)):
        """Resize every entry of ``data`` to ``shape`` with anti-aliasing."""
        return np.array([resize(d, shape, anti_aliasing=True) for d in data])

    # computationally inefficient
    @staticmethod
    def one_hot(y, size, load=True, r=None):
        """Resize masks and re-encode them as 6-channel one-hot masks.

        Args:
            y: Path to a ``.npy`` file (``load=True``) or the mask array itself.
            size: Target height and width.
            load: Whether ``y`` is a path that must be loaded.
            r: If given, only the first ``r`` masks are processed.

        Returns:
            An array with one one-hot mask per input mask.
        """
        dataset = np.load(y) if load else y
        if r is not None:
            dataset = dataset[:r]
        output_size = (size, size)

        resized_masks = []
        for mask in dataset:
            # order=1 is bilinear interpolation (not Lanczos).
            resized_mask = resize(mask, output_size, order=1, anti_aliasing=False)
            # The strongest channel per pixel becomes the class of that pixel.
            one_hot_mask = tf.one_hot(np.argmax(resized_mask, axis=-1), depth=6)
            resized_masks.append(one_hot_mask)

        return np.array(resized_masks)

    @staticmethod
    def check(x, y):
        """Print shapes, value ranges and Python types of ``x`` and ``y``."""
        print(f"shapes: x: {x.shape}, y: {y.shape}")
        print(f"norms: x: {np.min(x), np.max(x)}, y: {np.min(y), np.max(y)}")
        print(f"types: x: {type(x)}, y; {type(y)}")

    @staticmethod
    def load_masks(path, r=0):
        """Load mask images in sorted order and one-hot encode them (6 classes).

        Returns:
            A float32 array with one one-hot mask per file.
        """
        masks = []
        for file in tqdm(sorted(os.listdir(path))[r:]):
            with Image.open(os.path.join(path, file)) as image:
                mask = np.array(image, dtype=np.uint8)
            masks.append(to_categorical(mask, num_classes=6))
        return np.array(masks, dtype=np.float32)


# Directories of the extracted PNG dataset, keyed by fold and content type
# (images, instance masks, semantic masks).
png_paths = {
    "fold_1_images": "/content/pannuke_processed/Fold 1/images",
    "fold_1_inst": "/content/pannuke_processed/Fold 1/inst_masks",
    "fold_1_seg": "/content/pannuke_processed/Fold 1/sem_masks",
    "fold_2_images": "/content/pannuke_processed/Fold 2/images",
    "fold_2_inst": "/content/pannuke_processed/Fold 2/inst_masks",
    "fold_2_seg": "/content/pannuke_processed/Fold 2/sem_masks",
    "fold_3_images": "/content/pannuke_processed/Fold 3/images",
    # These two used to point at "Fold 2" (copy-paste slip).
    "fold_3_inst": "/content/pannuke_processed/Fold 3/inst_masks",
    "fold_3_seg": "/content/pannuke_processed/Fold 3/sem_masks",
}

In [None]:
import os
import numpy as np
import cv2
import albumentations as A
from PIL import Image

# Custom data generator class following the keras.utils.Sequence structure
class CustomDataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that reads image/mask files from two directories.

    Images and masks are paired by position in the sorted directory listings.
    Each batch holds RGB images scaled to [0, 1] and binarised masks that are
    one-hot encoded into ``num_classes`` channels.

    Args:
        x_path: Directory containing the image files.
        y_path: Directory containing the mask files.
        batch_size: Number of samples per batch.
        num_classes: Number of one-hot channels per mask.
        shuffle: Whether sample order is reshuffled at every epoch end.
        augment: Whether the transform registered through
            ``set_augmentation_transform`` is applied to each sample.
        target_size: Height and width every image and mask is resized to.
    """

    def __init__(
        self,
        x_path,
        y_path,
        batch_size,
        num_classes=6,
        shuffle=True,
        augment=False,
        target_size=256,
    ):
        super().__init__()
        self.x_path = x_path
        self.y_path = y_path
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.shuffle = shuffle
        self.augment = augment
        self.target_size = target_size
        # Set explicitly so a missing transform is detectable in __getitem__.
        self.augmentation_transform = None
        self.image_files = sorted(os.listdir(x_path))
        self.mask_files = sorted(os.listdir(y_path))
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches per epoch (last one may be partial)."""
        return int(np.ceil(len(self.image_files) / float(self.batch_size)))

    # load and preprocess the data
    def __getitem__(self, index):
        """Build batch ``index`` and return it as ``(x, y)``.

        Raises:
            RuntimeError: If ``augment`` is enabled but no transform was set.
        """
        if self.augment and self.augmentation_transform is None:
            raise RuntimeError(
                "augment=True but no augmentation transform is set; "
                "call set_augmentation_transform() first"
            )

        indexes = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
        size = (self.target_size, self.target_size)

        x = []
        y = []
        for i in indexes:
            # Context managers close the underlying files once pixels are read.
            with Image.open(os.path.join(self.x_path, self.image_files[i])) as raw:
                img = np.array(raw.convert("RGB").resize(size))
            with Image.open(os.path.join(self.y_path, self.mask_files[i])) as raw:
                # Nearest-neighbour keeps mask values discrete while resizing.
                mask = np.array(raw.resize(size, resample=Image.NEAREST))

            # Binarise: every non-zero pixel becomes foreground (1).
            mask = np.where(mask > 0, 1, 0)
            # The mask now only holds 0 and 1, so only channels 0 and 1 of the
            # one-hot encoding can be non-zero.
            mask_onehot = np.zeros((mask.shape[0], mask.shape[1], self.num_classes))
            for j in range(self.num_classes):
                mask_onehot[:, :, j] = (mask == j).astype(int)

            if self.augment:
                data = self.augmentation_transform(image=img, mask=mask_onehot)
                img = data["image"]
                mask_onehot = data["mask"]
            x.append(img)
            y.append(mask_onehot)

        x = np.array(x) / 255.0
        y = np.array(y)
        # Swap channel 0 and channel 5, moving the background (mask == 0)
        # plane to channel 5. The index is fixed, so num_classes must be >= 6.
        y[:, :, :, [0, 5]] = y[:, :, :, [5, 0]]
        return x, y

    def on_epoch_end(self):
        """Reset the sample order and reshuffle it when ``shuffle`` is set."""
        self.indexes = np.arange(len(self.image_files))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def set_augmentation_transform(self, augmentation_transform):
        """Register the callable used to augment samples.

        It is called as ``transform(image=..., mask=...)`` and must return a
        mapping with ``"image"`` and ``"mask"`` entries.
        """
        self.augmentation_transform = augmentation_transform


# Define the augmentation pipeline. It is handed to the TRAIN generators via
# set_augmentation_transform() inside create_gens(); each transform carries its
# own `p` argument. The commented-out entries are disabled transforms.
augmentation_transform = A.Compose(
    [
        A.RandomRotate90(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.25, brightness_limit=0.1, contrast_limit=0.1),
        A.GaussianBlur(blur_limit=(3, 7), p=0.3),
        #   A.CoarseDropout(max_holes=2, max_height=16, max_width=16, min_holes=1, p=0.1, mask_fill_value=0),
        #   A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=20, p=0.5),
        #   A.ColorJitter(p=0.1, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        #    A.RandomCrop(height=256, width=256, p=0.5),
    ]
)
# Default batch size; used as the default of create_gens() and when printing
# the approximate number of samples per generator.
batch_size = 8

# creating all generators at once
def create_gens(
    fold_x,
    fold_y,
    fold_x_2,
    fold_y_2,
    fold_x_3,
    fold_y_3,
    batch_size=batch_size,
    mode="VAL",
    shuffle=True,
    target_size=256,
):
    """Create one ``CustomDataGenerator`` per fold.

    Args:
        fold_x, fold_y: Image and mask directory of the first fold.
        fold_x_2, fold_y_2: Image and mask directory of the second fold.
        fold_x_3, fold_y_3: Image and mask directory of the third fold.
        batch_size: Batch size of every generator.
        mode: ``"VAL"`` for plain generators, ``"TRAIN"`` for generators that
            augment with the module-level ``augmentation_transform``.
        shuffle: Whether the generators shuffle their sample order.
        target_size: Height and width samples are resized to.

    Returns:
        A tuple with the three generators, in fold order.

    Raises:
        ValueError: If ``mode`` is neither ``"VAL"`` nor ``"TRAIN"``.
    """
    if mode not in ("VAL", "TRAIN"):
        raise ValueError(f"Provide either VAL or TRAIN as mode argument, not {mode}")

    # Only the training generators augment their samples.
    augment = mode == "TRAIN"
    folds = ((fold_x, fold_y), (fold_x_2, fold_y_2), (fold_x_3, fold_y_3))
    generators = tuple(
        CustomDataGenerator(
            images_dir,
            masks_dir,
            batch_size,
            augment=augment,
            shuffle=shuffle,
            target_size=target_size,
        )
        for images_dir, masks_dir in folds
    )
    if augment:
        for generator in generators:
            generator.set_augmentation_transform(augmentation_transform)
    return generators


# The three string literals below are disabled path configurations (they are
# bare strings, so they have no effect). Only the assignments that follow them
# are active.
# upsampled
""" image_path = "/content/pannuke_over/Fold 1/images"
mask_path = "/content/pannuke_over/Fold 1/sem_masks"
image_path1 = "/content/pannuke_over/Fold 2/images"
mask_path1 = "/content/pannuke_over/Fold 2/sem_masks"
image_path2 = "/content/pannuke_over/Fold 3/images"
mask_path2 = "/content/pannuke_over/Fold 3/sem_masks" """
# normal
""" image_path = "/content/pannuke_processed/Fold 1/images"
mask_path = "/content/pannuke_processed/Fold 1/sem_masks"
image_path1 = "/content/pannuke_processed/Fold 2/images"
mask_path1 = "/content/pannuke_processed/Fold 2/sem_masks"
image_path2 = "/content/pannuke_processed/Fold 3/images"
mask_path2 = "/content/pannuke_processed/Fold 3/sem_masks" """

""" image_path = "/content/pannuke_over_times_3/pannuke_over/Fold 1/images"
mask_path = "/content/pannuke_over_times_3/pannuke_over/Fold 1/sem_masks"
image_path1 = "/content/pannuke_over_times_3/pannuke_over/Fold 2/images"
mask_path1 = "/content/pannuke_over_times_3/pannuke_over/Fold 2/sem_masks"
image_path2 = "/content/pannuke_over_times_3/pannuke_over/Fold 3/images"
mask_path2 = "/content/pannuke_over_times_3/pannuke_over/Fold 3/sem_masks"  """

# Active configuration: images and instance masks of the three folds, read
# from the archive extracted into /content/pannuke_over_times_3.
image_path = "/content/pannuke_over_times_3/pannuke_instances/Fold 1/images"
mask_path = "/content/pannuke_over_times_3/pannuke_instances/Fold 1/inst_masks"
image_path1 = "/content/pannuke_over_times_3/pannuke_instances/Fold 2/images"
mask_path1 = "/content/pannuke_over_times_3/pannuke_instances/Fold 2/inst_masks"
image_path2 = "/content/pannuke_over_times_3/pannuke_instances/Fold 3/images"
mask_path2 = "/content/pannuke_over_times_3/pannuke_instances/Fold 3/inst_masks"

# Validation generators (no augmentation, shuffled by default).
vg_1, vg_2, vg_3 = create_gens(
    image_path, mask_path, image_path1, mask_path1, image_path2, mask_path2, mode="VAL"
)
# Training generators (augmentation enabled) over the same directories.
tg_1, tg_2, tg_3 = create_gens(
    image_path,
    mask_path,
    image_path1,
    mask_path1,
    image_path2,
    mask_path2,
    mode="TRAIN",
)
# non shuffle, only for quick test
# NOTE: despite the "_shuffle" suffix these generators are created with
# shuffle=False, i.e. they iterate in sorted file order.
vg_1_shuffle, vg_2_shuffle, vg_3_shuffle = create_gens(
    image_path,
    mask_path,
    image_path1,
    mask_path1,
    image_path2,
    mask_path2,
    mode="VAL",
    shuffle=False,
)

In [None]:
import matplotlib.pyplot as plt

# Fetch the first batch from the augmented training generator and from the
# shuffled / unshuffled validation generators, then print basic statistics.
batch_x, batch_y = tg_1[0]
batch_x1, batch_y1 = vg_1[0]
batch_x2, batch_y2 = vg_1_shuffle[0]
Data.check(batch_x, batch_y)
# Batches per generator times batch size (the last batch may be partial).
print(*(len(gen) * batch_size for gen in (tg_1, tg_2, tg_3)))
Data.check(batch_x1, batch_y1)
Data.check(batch_x2, batch_y2)

# display some test images
def dis_gen(x, y):
    """Plot every image of a batch next to the first three channels of its mask.

    Args:
        x: Batch of images.
        y: Batch of masks, indexed in parallel with ``x``.
    """
    for idx, image in enumerate(x):
        # Left panel: the input image.
        plt.subplot(1, 2, 1)
        plt.imshow(image)
        plt.axis(False)
        plt.title("image")

        # Right panel: mask channels 0-2.
        plt.subplot(1, 2, 2)
        plt.imshow(y[idx][:, :, :3])
        plt.title("mask")
        plt.axis(False)

        plt.show()


# Visual check of the three sample batches fetched above: augmented training,
# shuffled validation, unshuffled validation.
dis_gen(batch_x, batch_y)
dis_gen(batch_x1, batch_y1)
dis_gen(batch_x2, batch_y2)

In [None]:
import keras.backend as K
import tensorflow as tf
import numpy as np
from sklearn.metrics import roc_auc_score
import keras.backend as K
from tabulate import tabulate


def dice_loss(y_true, y_pred):
    """Return ``1 - Dice`` over all flattened elements, smoothed by 1.0.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Prediction tensor with the same number of elements.
    """
    smooth = 1.0
    truth = K.flatten(y_true)
    prediction = K.flatten(y_pred)
    overlap = K.sum(truth * prediction)
    total = K.sum(truth) + K.sum(prediction)
    return 1.0 - (2.0 * overlap + smooth) / (total + smooth)

# metrics 
def dice_score(y_true, y_pred):
    """Return the Dice coefficient over all flattened elements, smoothed by 1."""
    truth = K.flatten(y_true)
    prediction = K.flatten(y_pred)
    overlap = K.sum(truth * prediction)
    numerator = 2.0 * overlap + 1
    denominator = K.sum(truth) + K.sum(prediction) + 1
    return numerator / denominator


def recall_m(y_true, y_pred):
    """Return recall: rounded true positives over rounded actual positives.

    ``K.epsilon()`` is added to the denominator to avoid division by zero.
    """
    hits = K.round(K.clip(y_true * y_pred, 0, 1))
    actual = K.round(K.clip(y_true, 0, 1))
    return K.sum(hits) / (K.sum(actual) + K.epsilon())


def precision_m(y_true, y_pred):
    """Return precision: rounded true positives over rounded predicted positives.

    ``K.epsilon()`` is added to the denominator to avoid division by zero.
    """
    hits = K.round(K.clip(y_true * y_pred, 0, 1))
    predicted = K.round(K.clip(y_pred, 0, 1))
    return K.sum(hits) / (K.sum(predicted) + K.epsilon())


def f1_m(y_true, y_pred):
    """Return the F1 score, the harmonic mean of ``precision_m`` and ``recall_m``."""
    p = precision_m(y_true, y_pred)
    r = recall_m(y_true, y_pred)
    harmonic_half = (p * r) / (p + r + K.epsilon())
    return 2 * harmonic_half


def jaccard_index(y_true, y_pred, smooth=1):
    """Return the smoothed intersection-over-union of ``y_true`` and ``y_pred``.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Prediction tensor of the same shape.
        smooth: Constant added to numerator and denominator.
    """
    overlap = tf.reduce_sum(y_true * y_pred)
    combined = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) - overlap
    return (overlap + smooth) / (combined + smooth)


def cross_entropy_balanced(y_true, y_pred):
    """Class-balanced sigmoid cross-entropy.

    Positives are weighted by ``beta / (1 - beta)`` where ``beta`` is the
    fraction of negative labels, and the mean cost is scaled by ``1 - beta``.
    ``y_pred`` is passed to ``weighted_cross_entropy_with_logits`` as logits.

    Args:
        y_true: Binary ground-truth tensor (cast to float32).
        y_pred: Logits tensor of the same shape.

    Returns:
        A scalar loss; 0.0 when ``y_true`` contains no positives.
    """
    y_true = tf.cast(y_true, tf.float32)

    count_neg = tf.reduce_sum(1.0 - y_true)
    count_pos = tf.reduce_sum(y_true)

    # divide_no_nan returns 0 instead of inf/NaN for a zero denominator. With
    # no positives beta is 1 and beta / (1 - beta) used to be infinite; the
    # tf.where below hid that in the forward pass but not in the gradient.
    beta = tf.math.divide_no_nan(count_neg, count_pos + count_neg)
    pos_weight = tf.math.divide_no_nan(beta, 1.0 - beta)

    cost = tf.nn.weighted_cross_entropy_with_logits(
        logits=y_pred, labels=y_true, pos_weight=pos_weight
    )
    cost = tf.reduce_mean(cost * (1.0 - beta))

    return tf.where(tf.equal(count_pos, 0.0), 0.0, cost)


def pixel_error(y_true, y_pred):
    """Return the fraction of elements whose thresholded prediction is wrong.

    Predictions are binarised at 0.5 and compared with ``y_true`` cast to int32.
    """
    predicted = tf.cast(tf.greater(y_pred, 0.5), tf.int32)
    expected = tf.cast(y_true, tf.int32)
    mismatches = tf.cast(tf.not_equal(predicted, expected), tf.float32)
    return tf.reduce_mean(mismatches)


def get_fast_pq(y_true, y_pred):
    """Return the mean per-class ``TP / (TP + 0.5 * FP + 0.5 * FN)`` of a batch.

    The counts are taken per pixel; no instance matching is performed.

    Args:
        y_true: Ground truth, either one-hot/probabilities with the same rank
            as ``y_pred``, or integer class ids (optionally with a trailing
            axis of size 1).
        y_pred: Per-class scores with the classes on the last axis; the size
            of that axis must be statically known.

    Returns:
        A float32 scalar, averaged over the classes that occur in ``y_true``
        or in the predictions. 0.0 if no class occurs at all.
    """
    y_true = tf.convert_to_tensor(y_true)
    y_pred = tf.convert_to_tensor(y_pred)

    # Read the class count BEFORE argmax removes the class axis; afterwards the
    # last axis is a spatial one.
    n_classes = y_pred.shape[-1]
    pred_labels = tf.argmax(y_pred, axis=-1, output_type=tf.int32)

    # Bring the ground truth to class ids as well, so both sides are comparable.
    if y_true.shape.rank == y_pred.shape.rank:
        if y_true.shape[-1] == 1:
            true_labels = tf.cast(tf.squeeze(y_true, axis=-1), tf.int32)
        else:
            true_labels = tf.argmax(y_true, axis=-1, output_type=tf.int32)
    else:
        true_labels = tf.cast(y_true, tf.int32)

    # One row per pixel, one column per class.
    true_1h = tf.reshape(tf.one_hot(true_labels, n_classes), [-1, n_classes])
    pred_1h = tf.reshape(tf.one_hot(pred_labels, n_classes), [-1, n_classes])

    tp = tf.reduce_sum(true_1h * pred_1h, axis=0)
    fp = tf.reduce_sum((1.0 - true_1h) * pred_1h, axis=0)
    fn = tf.reduce_sum(true_1h * (1.0 - pred_1h), axis=0)

    denominator = tp + 0.5 * fp + 0.5 * fn
    per_class = tf.math.divide_no_nan(tp, denominator)
    # Classes with a zero denominator occur nowhere; leave them out of the mean
    # instead of producing 0/0 = NaN.
    n_present = tf.reduce_sum(tf.cast(denominator > 0.0, tf.float32))
    return tf.math.divide_no_nan(tf.reduce_sum(per_class), n_present)


class PanopticQuality(tf.keras.metrics.Metric):
    """Streaming, sample-weighted mean of ``get_fast_pq`` over all batches seen."""

    def __init__(self, name="pq", **kwargs):
        super().__init__(name=name, **kwargs)
        # Sum of (batch score * batch size) and the total number of samples.
        self.pq_score = self.add_weight(name="pq_score", initializer="zeros")
        self.num_samples = self.add_weight(name="num_samples", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the score of one batch (``sample_weight`` is ignored)."""
        y_true = tf.convert_to_tensor(y_true)
        y_pred = tf.convert_to_tensor(y_pred)
        batch = tf.cast(tf.shape(y_true)[0], tf.float32)
        pq = get_fast_pq(y_true, y_pred)
        # get_fast_pq returns one value per batch. Weight it by the batch size
        # so that dividing by the sample count in result() yields a mean score
        # rather than the score divided by the batch size.
        self.pq_score.assign_add(tf.reduce_sum(pq) * batch)
        self.num_samples.assign_add(batch)

    def result(self):
        """Return the mean score, or 0 before any batch was seen."""
        return tf.math.divide_no_nan(self.pq_score, self.num_samples)

    def reset_state(self):
        """Clear the accumulated score and sample count."""
        self.pq_score.assign(0.0)
        self.num_samples.assign(0.0)

    def reset_states(self):
        """Legacy alias of ``reset_state`` for older Keras versions."""
        self.reset_state()

In [None]:
import keras.backend as K
from keras.models import Model
from keras.layers import *
import keras.layers as kl
import tensorflow as tf
from tensorflow.keras.utils import get_custom_objects

def attention_block(input_tensor, filters, name, dr=0.0):
    """Gate ``input_tensor`` with a sigmoid map computed by two 3x3 convolutions.

    Args:
        input_tensor: Tensor to be gated.
        filters: Filter count of both convolutions.
        name: Suffix inserted into the layer names.
        dr: Dropout rate applied to the sigmoid map before the multiplication.

    Returns:
        The element-wise product of ``input_tensor`` and the gate.
    """
    conv_options = dict(
        strides=(1, 1),
        padding="same",
        kernel_initializer="he_normal",
    )

    # First conv -> batch norm -> Mish.
    gate = Conv2D(filters, (3, 3), name=f"conv_att{name}_1", **conv_options)(
        input_tensor
    )
    gate = BatchNormalization(name=f"batch_att{name}_1")(gate)
    gate = Activation("Mish", name=f"ac_att{name}_1")(gate)

    # Second conv -> batch norm -> sigmoid.
    gate = Conv2D(filters, (3, 3), name=f"conv_att{name}_2", **conv_options)(gate)
    gate = BatchNormalization(name=f"batch_att{name}_2")(gate)
    gate = Activation("sigmoid", name=f"ac_att{name}_2")(gate)

    gate = Dropout(dr)(gate)
    return Multiply()([input_tensor, gate])

# using Mish as activation instead of relu or other
class Mish(Activation):
    """Activation layer wrapper for the Mish function.

    Mish is defined as ``mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^x))``.

    Input:
        Arbitrary shape. Pass ``input_shape`` (a tuple of integers without the
        sample axis) when this is the first layer of a model.
    Output:
        Same shape as the input.

    Example:
        >>> X = Activation('Mish', name="conv1_act")(X_input)
    """

    def __init__(self, activation, **kwargs):
        super().__init__(activation, **kwargs)
        # Expose a function-like name on the instance.
        self.__name__ = "Mish"


def mish(inputs):
    """Return ``inputs * tanh(softplus(inputs))`` (the Mish activation)."""
    gate = tf.math.tanh(tf.math.softplus(inputs))
    return inputs * gate


# Register a Mish layer instance under the name "Mish" so that the
# Activation("Mish") calls in the model code can be resolved by name.
get_custom_objects().update({"Mish": Mish(mish)})

# Value passed to every GaussianNoise layer in the blocks below.
GNOISE = 0.001


def conv3x3(x, out_filters, strides=(1, 1)):
    """Apply a bias-free, same-padded 3x3 convolution with He-normal init.

    Args:
        x: Input tensor.
        out_filters: Number of output filters.
        strides: Convolution strides.
    """
    layer = Conv2D(
        out_filters,
        3,
        padding="same",
        strides=strides,
        use_bias=False,
        kernel_initializer="he_normal",
    )
    return layer(x)


def basic_Block(
    input, out_filters, strides=(1, 1), with_conv_shortcut=False, dropout=0.0
):
    """Residual block made of two 3x3 convolutions.

    Args:
        input: Input tensor.
        out_filters: Filter count of both convolutions.
        strides: Strides of the first convolution (and of the shortcut conv).
        with_conv_shortcut: If true, the shortcut is a 1x1 convolution followed
            by batch norm; otherwise ``input`` is added unchanged.
        dropout: Dropout rate applied to the block output when greater than 0.

    Returns:
        The block output tensor.
    """
    # Main path: conv -> BN -> noise -> Mish -> conv -> BN -> noise.
    main = conv3x3(input, out_filters, strides)
    main = BatchNormalization(axis=3)(main)
    main = kl.GaussianNoise(GNOISE)(main)  # NCFC
    main = Activation("Mish")(main)

    main = conv3x3(main, out_filters)
    main = BatchNormalization(axis=3)(main)
    main = kl.GaussianNoise(GNOISE)(main)  # NCFC

    # Shortcut path: identity unless a projection was requested.
    shortcut = input
    if with_conv_shortcut:
        shortcut = Conv2D(
            out_filters,
            1,
            strides=strides,
            use_bias=False,
            kernel_initializer="he_normal",
        )(input)
        shortcut = BatchNormalization(axis=3)(shortcut)

    merged = add([main, shortcut])
    merged = Activation("Mish")(merged)
    if dropout > 0.0:
        return Dropout(dropout)(merged)
    return merged


def bottleneck_Block(
    input, out_filters, strides=(1, 1), with_conv_shortcut=False, dropout=0.0
):
    """Residual bottleneck block: 1x1 reduce, 3x3, 1x1 expand.

    The two inner convolutions use ``out_filters / 4`` filters.

    Args:
        input: Input tensor.
        out_filters: Filter count of the block output.
        strides: Strides of the 3x3 convolution (and of the shortcut conv).
        with_conv_shortcut: If true, the shortcut is a 1x1 convolution followed
            by batch norm; otherwise ``input`` is added unchanged.
        dropout: Dropout rate applied to the block output when greater than 0.

    Returns:
        The block output tensor.
    """
    expansion = 4
    inner_filters = int(out_filters / expansion)

    # 1x1 reduction.
    main = Conv2D(inner_filters, 1, use_bias=False, kernel_initializer="he_normal")(
        input
    )
    main = BatchNormalization(axis=3)(main)
    main = kl.GaussianNoise(GNOISE)(main)
    main = Activation("Mish")(main)

    # 3x3 convolution carrying the stride.
    main = Conv2D(
        inner_filters,
        3,
        strides=strides,
        padding="same",
        use_bias=False,
        kernel_initializer="he_normal",
    )(main)
    main = BatchNormalization(axis=3)(main)
    main = kl.GaussianNoise(GNOISE)(main)
    main = Activation("Mish")(main)

    # 1x1 expansion back to out_filters.
    main = Conv2D(out_filters, 1, use_bias=False, kernel_initializer="he_normal")(
        main
    )
    main = BatchNormalization(axis=3)(main)
    main = kl.GaussianNoise(GNOISE)(main)

    # Shortcut path: identity unless a projection was requested.
    shortcut = input
    if with_conv_shortcut:
        shortcut = Conv2D(
            out_filters,
            1,
            strides=strides,
            use_bias=False,
            kernel_initializer="he_normal",
        )(input)
        shortcut = BatchNormalization(axis=3)(shortcut)

    merged = add([main, shortcut])
    merged = Activation("Mish")(merged)
    if dropout > 0.0:
        return Dropout(dropout)(merged)
    return merged


def stem_net(input):
    """Network stem: a strided 3x3 convolution followed by four bottleneck blocks.

    The convolution uses 64 filters and stride 2. All bottleneck blocks output
    256 filters; only the first one uses a convolutional shortcut.
    """
    stem = Conv2D(
        64,
        3,
        strides=(2, 2),
        padding="same",
        use_bias=False,
        kernel_initializer="he_normal",
    )(input)
    stem = BatchNormalization(axis=3)(stem)
    stem = kl.GaussianNoise(GNOISE)(stem)
    stem = Activation("Mish")(stem)

    # Idea noted by the author: turn more shortcuts into conv shortcuts and
    # expose more parameters.
    for use_conv_shortcut in (True, False, False, False):
        stem = bottleneck_Block(stem, 256, with_conv_shortcut=use_conv_shortcut)

    return stem


def transition_layer1(x, out_filters_list=[32, 64], dropout=0.0):
    """Split one tensor into two branches of different resolution.

    Args:
        x: Input tensor.
        out_filters_list: Filter counts of the two branches (read-only).
        dropout: Dropout rate applied to each output branch when greater
            than 0.

    Returns:
        ``[x0, x1]``: ``x0`` keeps the input resolution, ``x1`` is produced by
        a stride-2 convolution.
    """

    def _conv_bn_mish(tensor, filters, strides=(1, 1)):
        # 3x3 conv -> batch norm -> Mish.
        tensor = Conv2D(
            filters,
            3,
            strides=strides,
            padding="same",
            use_bias=False,
            kernel_initializer="he_normal",
        )(tensor)
        tensor = BatchNormalization(axis=3)(tensor)
        return Activation("Mish")(tensor)

    x0 = _conv_bn_mish(x, out_filters_list[0])
    x1 = _conv_bn_mish(x, out_filters_list[1], strides=(2, 2))

    outputs = [x0, x1]
    if dropout > 0.0:
        # The dropout result used to be assigned to an unused variable, so the
        # argument had no effect. Apply it to the block outputs, as
        # basic_Block and bottleneck_Block do.
        outputs = [Dropout(dropout)(branch) for branch in outputs]
    return outputs


def make_branch1_0(x, out_filters=32):
    """Stack four ``basic_Block`` units; only the first uses a conv shortcut."""
    for use_conv_shortcut in (True, False, False, False):
        x = basic_Block(x, out_filters, with_conv_shortcut=use_conv_shortcut)
    return x


def make_branch1_1(x, out_filters=64):
    """Stack four ``basic_Block`` units; only the first uses a conv shortcut."""
    for use_conv_shortcut in (True, False, False, False):
        x = basic_Block(x, out_filters, with_conv_shortcut=use_conv_shortcut)
    return x


def fuse_layer1(x):
    """Exchange information between two branches of different resolution.

    Args:
        x: Sequence whose first two entries are the higher- and the
            lower-resolution branch.

    Returns:
        ``[x0, x1]``: each branch summed with the resampled other branch.
    """
    high_res = x[0]
    low_res = x[1]

    # Low -> high: 1x1 conv to 32 filters, batch norm, 2x upsampling.
    lifted = Conv2D(32, 1, use_bias=False, kernel_initializer="he_normal")(low_res)
    lifted = BatchNormalization(axis=3)(lifted)
    lifted = UpSampling2D(size=(2, 2))(lifted)
    fused_high = add([high_res, lifted])

    # High -> low: stride-2 3x3 conv to 64 filters, batch norm.
    reduced = Conv2D(
        64,
        3,
        strides=(2, 2),
        padding="same",
        use_bias=False,
        kernel_initializer="he_normal",
    )(high_res)
    reduced = BatchNormalization(axis=3)(reduced)
    fused_low = add([reduced, low_res])

    return [fused_high, fused_low]


def transition_layer2(x, out_filters_list=[32, 64, 128], dropout=0.0):
    """Turn two branches into three by adding a lower-resolution branch.

    Args:
        x: Sequence with the two input branches.
        out_filters_list: Filter counts of the three output branches
            (read-only).
        dropout: Dropout rate applied to each output branch when greater
            than 0.

    Returns:
        ``[x0, x1, x2]``: ``x0`` and ``x1`` are derived from ``x[0]`` and
        ``x[1]``; ``x2`` is produced from ``x[1]`` by a stride-2 convolution.
    """

    def _conv_bn_mish(tensor, filters, strides=(1, 1)):
        # 3x3 conv -> batch norm -> Mish.
        tensor = Conv2D(
            filters,
            3,
            strides=strides,
            padding="same",
            use_bias=False,
            kernel_initializer="he_normal",
        )(tensor)
        tensor = BatchNormalization(axis=3)(tensor)
        return Activation("Mish")(tensor)

    x0 = _conv_bn_mish(x[0], out_filters_list[0])
    x1 = _conv_bn_mish(x[1], out_filters_list[1])
    x2 = _conv_bn_mish(x[1], out_filters_list[2], strides=(2, 2))

    outputs = [x0, x1, x2]
    if dropout > 0.0:
        # Dropout used to be called on the input list and its result was
        # discarded. Apply it to every output branch instead.
        outputs = [Dropout(dropout)(branch) for branch in outputs]
    return outputs


def make_branch2_0(x, out_filters=32):
    """Stack four ``basic_Block`` units; only the first uses a conv shortcut."""
    for use_conv_shortcut in (True, False, False, False):
        x = basic_Block(x, out_filters, with_conv_shortcut=use_conv_shortcut)
    return x


def make_branch2_1(x, out_filters=64):
    """Stack four ``basic_Block`` units; only the first uses a conv shortcut."""
    for use_conv_shortcut in (True, False, False, False):
        x = basic_Block(x, out_filters, with_conv_shortcut=use_conv_shortcut)
    return x


def make_branch2_2(x, out_filters=128):
    """Stack four ``basic_Block`` units; only the first uses a conv shortcut."""
    for use_conv_shortcut in (True, False, False, False):
        x = basic_Block(x, out_filters, with_conv_shortcut=use_conv_shortcut)
    return x


def fuse_layer2(x):
    """Exchange information between three branches of different resolution.

    Args:
        x: Sequence whose first three entries are the branches, ordered from
            highest to lowest resolution.

    Returns:
        ``[x0, x1, x2]``: each branch summed with the resampled other two.
    """

    def _lift(tensor, filters, factor):
        # 1x1 conv -> batch norm -> upsampling by `factor`.
        tensor = Conv2D(filters, 1, use_bias=False, kernel_initializer="he_normal")(
            tensor
        )
        tensor = BatchNormalization(axis=3)(tensor)
        return UpSampling2D(size=(factor, factor))(tensor)

    def _reduce(tensor, filters):
        # Stride-2 3x3 conv -> batch norm.
        tensor = Conv2D(
            filters,
            3,
            strides=(2, 2),
            padding="same",
            use_bias=False,
            kernel_initializer="he_normal",
        )(tensor)
        return BatchNormalization(axis=3)(tensor)

    high, middle, low = x[0], x[1], x[2]

    # Highest resolution: upsample the two lower branches to 32 filters.
    middle_to_high = _lift(middle, 32, 2)
    low_to_high = _lift(low, 32, 4)
    fused_high = add([high, middle_to_high, low_to_high])

    # Middle resolution: downsample the high branch, upsample the low branch.
    high_to_middle = _reduce(high, 64)
    low_to_middle = _lift(low, 64, 2)
    fused_middle = add([high_to_middle, middle, low_to_middle])

    # Lowest resolution: the high branch is reduced twice with a Mish in
    # between, the middle branch once.
    high_to_low = _reduce(high, 32)
    high_to_low = Activation("Mish")(high_to_low)
    high_to_low = _reduce(high_to_low, 128)
    middle_to_low = _reduce(middle, 128)
    fused_low = add([high_to_low, middle_to_low, low])

    return [fused_high, fused_middle, fused_low]


def transition_layer3(x, out_filters_list=[32, 64, 128, 256], dropout=0.0):
    """Turn three branches into four by adding a lower-resolution branch.

    Args:
        x: Sequence with the three input branches.
        out_filters_list: Filter counts of the four output branches
            (read-only).
        dropout: Dropout rate applied to each output branch when greater
            than 0.

    Returns:
        ``[x0, x1, x2, x3]``: the first three are derived from ``x[0]``,
        ``x[1]`` and ``x[2]``; ``x3`` is produced from ``x[2]`` by a stride-2
        convolution.
    """

    def _conv_bn_mish(tensor, filters, strides=(1, 1)):
        # 3x3 conv -> batch norm -> Mish.
        tensor = Conv2D(
            filters,
            3,
            strides=strides,
            padding="same",
            use_bias=False,
            kernel_initializer="he_normal",
        )(tensor)
        tensor = BatchNormalization(axis=3)(tensor)
        return Activation("Mish")(tensor)

    x0 = _conv_bn_mish(x[0], out_filters_list[0])
    x1 = _conv_bn_mish(x[1], out_filters_list[1])
    x2 = _conv_bn_mish(x[2], out_filters_list[2])
    x3 = _conv_bn_mish(x[2], out_filters_list[3], strides=(2, 2))

    outputs = [x0, x1, x2, x3]
    if dropout > 0.0:
        # Dropout used to be called on the input list and its result was
        # discarded. Apply it to every output branch instead.
        outputs = [Dropout(dropout)(branch) for branch in outputs]
    return outputs


def make_branch3_0(x, out_filters=32):
    """Run ``x`` through four consecutive ``basic_Block`` calls.

    Only the first block is built with ``with_conv_shortcut=True``; the
    remaining three use ``with_conv_shortcut=False``.
    """
    for use_conv_shortcut in (True, False, False, False):
        x = basic_Block(x, out_filters, with_conv_shortcut=use_conv_shortcut)
    return x


def make_branch3_1(x, out_filters=64):
    """Run ``x`` through four consecutive ``basic_Block`` calls.

    Only the first block is built with ``with_conv_shortcut=True``; the
    remaining three use ``with_conv_shortcut=False``.
    """
    for use_conv_shortcut in (True, False, False, False):
        x = basic_Block(x, out_filters, with_conv_shortcut=use_conv_shortcut)
    return x


def make_branch3_2(x, out_filters=128):
    """Run ``x`` through four consecutive ``basic_Block`` calls.

    Only the first block is built with ``with_conv_shortcut=True``; the
    remaining three use ``with_conv_shortcut=False``.
    """
    for use_conv_shortcut in (True, False, False, False):
        x = basic_Block(x, out_filters, with_conv_shortcut=use_conv_shortcut)
    return x


def make_branch3_3(x, out_filters=256):
    """Run ``x`` through four consecutive ``basic_Block`` calls.

    Only the first block is built with ``with_conv_shortcut=True``; the
    remaining three use ``with_conv_shortcut=False``.
    """
    for use_conv_shortcut in (True, False, False, False):
        x = basic_Block(x, out_filters, with_conv_shortcut=use_conv_shortcut)
    return x


def fuse_layer3(x):
    """Fuse four branches into one tensor at the spatial size of ``x[0]``.

    ``x[0]`` is used as is.  ``x[1]``, ``x[2]`` and ``x[3]`` each go through
    a bias-free 1x1 convolution to 32 channels, batch norm, and an
    upsampling by 2, 4 and 8 respectively.  The four results are
    concatenated along the last axis.
    """
    fused_inputs = [x[0]]
    for branch_idx, factor in ((1, 2), (2, 4), (3, 8)):
        branch = Conv2D(32, 1, use_bias=False, kernel_initializer="he_normal")(
            x[branch_idx]
        )
        branch = BatchNormalization(axis=3)(branch)
        branch = UpSampling2D(size=(factor, factor))(branch)
        fused_inputs.append(branch)
    return concatenate(fused_inputs, axis=-1)


# NCFC: Added naming convention to support finetuning
def final_layer(x, classes=1, layernameprefix="model", activation="softmax"):
    """Classification head: 2x upsampling, 1x1 conv, batch norm, activation.

    The conv, batch-norm and activation layers are named
    ``<layernameprefix>_conv2d``, ``<layernameprefix>_bnclass`` and
    ``<layernameprefix>_classification`` so they can be addressed by name.
    """
    upsampled = UpSampling2D(size=(2, 2))(x)

    class_conv = Conv2D(
        classes,
        1,
        use_bias=False,
        kernel_initializer="he_normal",
        name=layernameprefix + "_conv2d",
    )
    class_bn = BatchNormalization(axis=3, name=layernameprefix + "_bnclass")
    class_act = Activation(activation, name=layernameprefix + "_classification")

    return class_act(class_bn(class_conv(upsampled)))


def object_attention_context(feature, prob, scale=1):
    """Aggregate ``feature`` into one context vector per channel of ``prob``.

    Both inputs are flattened over their spatial axes.  ``prob`` (multiplied
    by ``scale``) is softmax-normalised over the flattened positions and used
    as weights to sum the feature vectors.

    Returns a tensor laid out as batch x feature ch x 1 x prob ch.
    """
    backend = tf.keras.backend
    feature_channels = backend.int_shape(feature)[-1]
    prob_channels = backend.int_shape(prob)[-1]

    # batch x positions x feature ch  ->  batch x feature ch x positions
    flat_feature = tf.keras.layers.Reshape([-1, feature_channels])(feature)
    flat_feature = tf.keras.layers.Permute([2, 1])(flat_feature)

    # batch x positions x prob ch, normalised over the positions axis
    flat_prob = tf.keras.layers.Reshape([-1, prob_channels])(prob)
    position_weights = tf.keras.activations.softmax(flat_prob * scale, axis=-2)

    # batch x feature ch x prob ch
    pooled = tf.keras.layers.Dot([2, 1])([flat_feature, position_weights])

    # batch x feature ch x 1 x prob ch
    return backend.expand_dims(pooled, axis=-2)


def object_attention(feature, prob, n_feature=256, scale=1, **kwargs):
    """Attend from every position of ``feature`` to the object context.

    The query is computed from ``feature``; key and value are computed from
    the output of ``object_attention_context(feature, prob, scale)``.  The
    attended result is reshaped back to the spatial size of ``feature`` and
    projected to the channel count of ``feature``.

    When ``scale`` is greater than 1, ``feature`` is max-pooled by ``scale``
    first and the result is upsampled by ``scale`` at the end.
    ``**kwargs`` is accepted but not used in this function.
    """
    layers = tf.keras.layers
    int_shape = tf.keras.backend.int_shape

    def pointwise_block(tensor, filters):
        # bias-free 1x1 conv -> batch norm -> Mish
        tensor = layers.Conv2D(
            filters,
            1,
            use_bias=False,
            kernel_initializer="he_normal",
            bias_initializer="zeros",
        )(tensor)
        tensor = layers.BatchNormalization(axis=-1, momentum=0.9, epsilon=1e-5)(
            tensor
        )
        return layers.Activation("Mish")(tensor)

    def flatten_leading(tensor):
        # collapse everything but the channel axis into one axis
        return layers.Reshape([-1, int_shape(tensor)[-1]])(tensor)

    downscaled = scale > 1
    if downscaled:
        feature = layers.MaxPooling2D((scale, scale), padding="same")(feature)
    proxy_context = object_attention_context(feature, prob, scale)

    # query: two pointwise blocks on the pixel features
    query = pointwise_block(feature, n_feature)
    query = pointwise_block(query, n_feature)
    query = flatten_leading(query)

    # key: two pointwise blocks on the context, then transposed for the dot
    key = pointwise_block(proxy_context, n_feature)
    key = pointwise_block(key, n_feature)
    key = flatten_leading(key)
    key = layers.Permute([2, 1])(key)

    # value: a single pointwise block on the context
    value = pointwise_block(proxy_context, n_feature)
    value = flatten_leading(value)

    # scaled dot-product similarity, softmax over the last axis
    sim = layers.Dot([2, 1])([query, key])
    sim = sim * (n_feature**-0.5)
    sim = tf.keras.activations.softmax(sim)

    # weighted sum of values, restored to the spatial layout of `feature`
    context = layers.Dot([2, 1])([sim, value])
    spatial_shape = int_shape(feature)[-3:-1] + (n_feature,)
    context = layers.Reshape(spatial_shape)(context)
    context = pointwise_block(context, int_shape(feature)[-1])

    if downscaled:
        context = layers.UpSampling2D((scale, scale))(context)
    return context


def ocr_module(
    feature,
    prob,
    n_feature=512,
    n_attention_feature=256,
    dropout_rate=0.05,
    scale=1,
    **kwargs,
):
    """Concatenate ``feature`` with its object-attention context and project it.

    The context comes from ``object_attention``; the concatenation is passed
    through a bias-free 1x1 convolution to ``n_feature`` channels, batch
    norm, Mish and dropout with rate ``dropout_rate``.
    """
    layers = tf.keras.layers

    attended = object_attention(feature, prob, n_attention_feature, scale, **kwargs)
    merged = layers.Concatenate(axis=-1)([feature, attended])

    projection = layers.Conv2D(
        n_feature,
        1,
        use_bias=False,
        kernel_initializer="he_normal",
        bias_initializer="zeros",
    )
    norm = layers.BatchNormalization(axis=-1, momentum=0.9, epsilon=1e-5)
    mish = layers.Activation("Mish")
    drop = layers.Dropout(dropout_rate)

    return drop(mish(norm(projection(merged))))


def seg_hrnet(
    height,
    width,
    channel,
    classes,
    layername="model",
    last_activation="softmax",
    dropout_rate=0.1,
    dropout=True,
    mode="RES",
    **kwargs,
):
    """
    Build the multi-branch segmentation network with an OCR side branch and return it as a Keras Model.

    @param height: height of the input, used in the ``Input`` shape
    @param width: width of the input, used in the ``Input`` shape
    @param channel: number of input channels, used in the ``Input`` shape
    @param classes: number of output channels passed to ``final_layer``
    @param layername: name prefix passed to ``final_layer`` for the head layers
    @param last_activation: activation passed to ``final_layer``
    @param dropout_rate: rate of the Dropout layers added after the stem and before the head
    @param dropout: when falsy, those two Dropout layers are not added
    @param mode: RES meaning the end operation of the OCR module is computed using residual style operations with other operations following that
    @param mode: UP meaning the end operation of the OCR module is upsampling2d, less computationally expensive
    @param kwargs: accepted but not used in this function
    @return: ``Model`` mapping the input to the output of ``final_layer``
    """
    assert mode in ["RES", "UP"]
    inputs = Input(shape=(height, width, channel))  # NCFC: Removed fixed batch size

    x = stem_net(inputs)
    if dropout:
        x = Dropout(dropout_rate)(x)

    # Stage 1: two branches
    x = transition_layer1(x)
    x0 = make_branch1_0(x[0])
    x1 = make_branch1_1(x[1])

    x = fuse_layer1([x0, x1])

    # Stage 2: three branches
    x = transition_layer2(x)
    x0 = make_branch2_0(x[0])
    x1 = make_branch2_1(x[1])
    x2 = make_branch2_2(x[2])

    x = fuse_layer2([x0, x1, x2])

    # Add OCR module branch here
    ocr_input = x[
        -1
    ]  # use output of last layer in HRNet backbone as input to OCR module
    # NOTE(review): the same tensor x[-1] is passed as both `feature` and `prob` - confirm this is intended
    out = ocr_module(ocr_input, prob=x[-1])  # apply OCR module
    if mode == "RES":
        erosion_size = 2
        dilation_size = 2
        out = Conv2D(filters=128, kernel_size=3, padding="same")(out)
        out = BatchNormalization(axis=-1, momentum=0.9, epsilon=1e-5)(out)

        # Define the residual block
        residual = Conv2D(filters=128, kernel_size=3, padding="same")(out)
        residual = BatchNormalization(axis=-1, momentum=0.9, epsilon=1e-5)(residual)
        residual = Activation("Mish")(residual)

        # Add "erosion" layer
        residual = MaxPooling2D(pool_size=(erosion_size, erosion_size), padding="same")(
            residual
        )

        # Add dilation layer
        residual = Conv2D(
            filters=128, kernel_size=3, padding="same", dilation_rate=dilation_size
        )(residual)
        residual = BatchNormalization(axis=-1, momentum=0.9, epsilon=1e-5)(residual)
        residual = Activation("Mish")(residual)

        # Add "erosion" layer
        residual = MaxPooling2D(pool_size=(erosion_size, erosion_size), padding="same")(
            residual
        )

        # Add dilation layer
        residual = Conv2D(
            filters=128, kernel_size=3, padding="same", dilation_rate=dilation_size
        )(residual)
        residual = BatchNormalization(axis=-1, momentum=0.9, epsilon=1e-5)(residual)

        # normal convolution
        # out = Conv2D(filters=128, kernel_size=3, padding='same')(out)
        # out = BatchNormalization(axis=-1, momentum=0.9, epsilon=1e-5)(out)

        # Add skip connection
        # The two 2x2 poolings above are undone by one 4x upsampling so that
        # `residual` can be added to `out`.
        residual = UpSampling2D(size=(4, 4))(residual)
        out = Add()([out, residual])
        out = attention_block(out, 128, name="ocr_attention")
        out = Activation("Mish")(out)

        # if dropout:
        #     cr_output = Dropout(dropout_rate//4)(out)

        # Learned 4x upsampling of the OCR branch before it is merged with the backbone
        ocr_output = Conv2DTranspose(
            filters=128, kernel_size=4, strides=4, padding="same"
        )(out)

    else:
        # mode == "UP": parameter-free 4x upsampling of the OCR branch
        ocr_output = UpSampling2D(size=(4, 4))(out)

    # Stage 3: four branches, fused into a single tensor by fuse_layer3
    x = transition_layer3(x)
    x0 = make_branch3_0(x[0])
    x1 = make_branch3_1(x[1])
    x2 = make_branch3_2(x[2])
    x3 = make_branch3_3(x[3])

    x = fuse_layer3([x0, x1, x2, x3])

    # Merge outputs of OCR module and main backbone
    x = Concatenate()([x, ocr_output])
    if dropout:
        x = Dropout(dropout_rate)(x)
    out = final_layer(
        x, classes=classes, layernameprefix=layername, activation=last_activation
    )

    model = Model(inputs=inputs, outputs=out)

    return model

In [None]:
!pip install git+https://github.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models.git
from tensorflow.keras.losses import Loss, categorical_crossentropy
import tensorflow_advanced_segmentation_models as tasm
import tensorflow as tf
from keras.optimizers import optimizer
import numpy as np
from typing import Tuple, Optional, Callable
from tensorflow.keras import optimizers

""" # number of nuclei in each class
nuclei_counts = np.array([32276, 50585, 77403, 26572, 2908])

# total number of nuclei
total_nuclei = np.sum(nuclei_counts)

# inverse frequency of each class
class_weights = total_nuclei / (len(nuclei_counts) * nuclei_counts)

# normalize the class weights
class_weights = class_weights / np.sum(class_weights)

print(class_weights) """

# different loss functions for different use cases
# These are ready-made loss objects from tensorflow_advanced_segmentation_models
# (tasm) plus one plain Keras loss. The `+` combinations rely on tasm's loss
# objects supporting addition.
# NOTE(review): in the visible part of this notebook only the CombinedLoss
# defined further down is passed to model.compile - confirm whether these
# presets are still needed.
CatFocDiceLoss = (
    tasm.losses.CategoricalFocalLoss(alpha=0.25, gamma=2.0) + tasm.losses.DiceLoss()
)
TverskyFocalLoss = tasm.losses.FocalTverskyLoss()
CELoss = tasm.losses.CategoricalCELoss()
tf_CELoss = tf.keras.losses.CategoricalCrossentropy()
CatFocLoss = tasm.losses.CategoricalFocalLoss()
TverskyDice = tasm.losses.DiceLoss() + tasm.losses.FocalTverskyLoss()
CeDice = tasm.losses.CategoricalCELoss() + tasm.losses.DiceLoss()
JacardLoss = tasm.losses.JaccardLoss()


def average(x, class_weights=None):
    """Return the mean of ``x``, multiplied by ``class_weights`` first if given."""
    weighted = x if class_weights is None else x * class_weights
    return K.mean(weighted)


def gather_channels(*xs):
    """Return the positional arguments unchanged, packed as a tuple."""
    gathered = xs
    return gathered


def round_if_needed(x, threshold):
    """Binarise ``x`` against ``threshold``; return ``x`` untouched if it is None.

    With a threshold, the result is ``x > threshold`` cast to the backend
    float type.
    """
    if threshold is None:
        return x
    return K.cast(K.greater(x, threshold), K.floatx())


def dice_coefficient(
    y_true, y_pred, beta=1.0, class_weights=1.0, smooth=1e-5, threshold=None
):
    """Return the (optionally class-weighted) mean F-beta / Dice score.

    Args:
        y_true: ground-truth tensor.
        y_pred: prediction tensor, same layout as ``y_true``.
        beta: F-beta coefficient; ``1.0`` gives the Dice score.
        class_weights: factor multiplied onto the per-sample, per-channel
            scores before averaging (``None`` disables weighting).
        smooth: constant added to numerator and denominator to avoid 0/0.
        threshold: if not None, ``y_pred`` is binarised against it first.

    Returns:
        Scalar mean score.

    The true-positive term is weighted by ``1 + beta**2`` as in the F-beta
    definition.  It was previously ``1 + beta``, which is only correct for
    ``beta == 1``; results for the default ``beta`` are unchanged.
    """
    y_true, y_pred = gather_channels(y_true, y_pred)
    y_pred = round_if_needed(y_pred, threshold)
    # reduce over the two spatial axes, keeping batch and channel
    axes = [1, 2] if K.image_data_format() == "channels_last" else [2, 3]

    tp = K.sum(y_true * y_pred, axis=axes)
    fp = K.sum(y_pred, axis=axes) - tp
    fn = K.sum(y_true, axis=axes) - tp

    beta_sq = beta**2.0
    weighted_tp = (1.0 + beta_sq) * tp
    score = (weighted_tp + smooth) / (weighted_tp + beta_sq * fn + fp + smooth)

    return average(score, class_weights)


# class_weights = [2.8309667, 15.249543, 7.3735404, 417.3982, 11.119026, 2.8309667]
#  normalize the class weights
# 0.06959552 0.04440575 0.02902039 0.08453504 0.77244329
# class_weights /= np.sum(class_weights)
# print(class_weights)
class DiceLoss(Loss):
    """Dice loss: ``1 - dice_coefficient`` computed over the whole batch.

    Args:
        beta: F-beta coefficient forwarded to ``dice_coefficient``.
        class_weights: per-class weights; ``None`` is stored as ``1.0``
            (no weighting).
        smooth: smoothing constant forwarded to ``dice_coefficient``.
    """

    def __init__(self, beta=1.0, class_weights=None, smooth=1e-5):
        super().__init__(name="dice_loss")
        self.beta = beta
        self.class_weights = class_weights if class_weights is not None else 1.0
        self.smooth = smooth

    def call(self, y_true, y_pred):
        """Return the scalar dice loss for one batch."""
        return 1.0 - dice_coefficient(
            y_true,
            y_pred,
            beta=self.beta,
            class_weights=self.class_weights,
            smooth=self.smooth,
            threshold=None,
        )

    def __call__(self, y_true, y_pred, sample_weight=None):
        """Return the scalar dice loss, bypassing the base-class reduction.

        ``sample_weight`` is accepted so the object can be used directly as
        a compiled Keras loss (Keras passes it on every call; without the
        parameter that raised ``TypeError``).  It is NOT applied: the dice
        score is already aggregated over the batch, so there is no
        per-sample value left to weight.
        """
        return self.call(y_true, y_pred)

    def get_config(self):
        return {
            "beta": self.beta,
            "class_weights": self.class_weights,
            "smooth": self.smooth,
        }

    @classmethod
    def from_config(cls, config):
        return cls(**config)


# combination loss with dice and cross entropy incorporating class weights
class CombinedLoss(Loss):
    """Categorical cross-entropy plus class-weighted dice loss.

    Args:
        class_weights: per-class weights forwarded to ``DiceLoss`` (the
            cross-entropy term is not weighted).
        **kwargs: forwarded to the Keras ``Loss`` base class (e.g. ``name``,
            ``reduction``).  They were previously swallowed silently.
    """

    def __init__(self, class_weights=None, **kwargs):
        # Configs written by earlier versions of this class carry a
        # "cross_entropy" entry; drop it so such configs still load.
        kwargs.pop("cross_entropy", None)
        super().__init__(**kwargs)
        self.cross_entropy = categorical_crossentropy
        self.class_weights = class_weights
        # Built once instead of on every call().
        self.dice_loss = DiceLoss(class_weights=class_weights)

    def call(self, y_true, y_pred):
        """Return cross-entropy with the scalar dice loss added to every element."""
        dice_loss = self.dice_loss(y_true, y_pred)
        cross_entropy = self.cross_entropy(y_true, y_pred)
        return cross_entropy + dice_loss

    def get_config(self):
        # Start from the base config (name / reduction) so it round-trips
        # through from_config.  The cross-entropy function is fixed in
        # __init__ and is no longer emitted: a raw function object is not a
        # serialisable config value.
        config = super().get_config()
        config["class_weights"] = self.class_weights
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)

# functions


def exists(val):
    """Tell whether ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True


# update functions


@tf.function
def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
    """Apply one in-place Lion step to ``p`` and its momentum ``exp_avg``.

    Args:
        p: variable to update (assigned in place).
        grad: gradient for ``p``.
        exp_avg: momentum variable for ``p`` (assigned in place).
        lr: learning rate.
        wd: decoupled weight-decay coefficient.
        beta1: weight of the momentum when forming the sign update.
        beta2: decay rate of the momentum running average.

    The update direction now interpolates momentum and gradient with
    ``beta1``.  Before, the weight was hard-coded to 1.0 (via
    ``LinSpace(...)[0]``), so ``beta1`` was ignored and the gradient never
    entered the sign update.
    """
    # stepweight decay
    p.assign(p * (1 - lr * wd))

    # weight update: sign of the beta1-interpolation of momentum and gradient
    update = exp_avg * beta1 + grad * (1 - beta1)
    p.assign_add(tf.sign(update) * -lr)

    # decay the momentum running average coefficient
    exp_avg.assign(exp_avg * beta2 + grad * (1 - beta2))


def lerp(start, end, weight):
    """Linearly interpolate from ``start`` (weight 0) to ``end`` (weight 1)."""
    span = end - start
    return start + weight * span


def sparse_lerp(start, end, weight):
    """Linear interpolation written so that ``end`` is only ever subtracted.

    Mathematically the same as ``lerp``; a dense Tensor cannot be subtracted
    from sparse IndexedSlices, so the difference is taken as
    ``start - end`` and negated.
    """
    reversed_span = start - end
    return start + weight * -reversed_span


class Lion(optimizer.Optimizer):
    """Optimizer that implements the Lion algorithm.

    Lion was published in the paper "Symbolic Discovery of Optimization Algorithms"
    which is available at https://arxiv.org/abs/2302.06675

    Args:
        learning_rate: the learning rate; passed to the base class's
            ``_build_learning_rate``. Defaults to 1e-4.
        beta_1: A float value or a constant float tensor. Factor used to
            interpolate the current gradient and the momentum.
            Defaults to 0.9.
        beta_2: A float value or a constant float tensor. The exponential
            decay rate for the momentum. Defaults to 0.99.
        weight_decay, clipnorm, clipvalue, global_clipnorm, jit_compile,
        name, **kwargs: forwarded unchanged to the base optimizer.

    Notes:
        Sparse gradients (``tf.IndexedSlices``) go through ``sparse_lerp``,
        which applies the same formula as the dense path to the whole
        momentum accumulator, so slices that did not appear in the gradient
        still have momentum applied and decayed.
    """

    def __init__(
        self,
        learning_rate=1e-4,
        beta_1=0.9,
        beta_2=0.99,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        jit_compile=True,
        name="Lion",
        **kwargs
    ):
        super().__init__(
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            jit_compile=jit_compile,
            **kwargs
        )
        self._learning_rate = self._build_learning_rate(learning_rate)
        self.beta_1 = beta_1
        self.beta_2 = beta_2

    def build(self, var_list):
        """Initialize optimizer variables.

        var_list: list of model variables to build Lion variables on.
        """
        super().build(var_list)
        if hasattr(self, "_built") and self._built:
            return
        self._built = True
        # one momentum ("ema") slot per model variable, in var_list order
        self._emas = []
        for var in var_list:
            self._emas.append(
                self.add_variable_from_reference(
                    model_variable=var, variable_name="ema"
                )
            )

    def update_step(self, gradient, variable):
        """Update step given gradient and the associated model variable."""
        lr = tf.cast(self.learning_rate, variable.dtype)
        # Scalars in the variable's dtype.  The previous float32 tensors of
        # shape (1,) turned the update of a scalar variable into shape (1,)
        # (rejected by assign_sub) and did not match non-float32 variables.
        beta_1 = tf.cast(self.beta_1, variable.dtype)
        beta_2 = tf.cast(self.beta_2, variable.dtype)

        var_key = self._var_key(variable)
        ema = self._emas[self._index_dict[var_key]]

        if isinstance(gradient, tf.IndexedSlices):
            # Sparse gradients.
            lerp_fn = sparse_lerp
        else:
            # Dense gradients.
            lerp_fn = lerp

        # step along the sign of the beta_1-interpolated momentum/gradient
        update = lerp_fn(ema, gradient, 1 - beta_1)
        update = tf.sign(update)
        variable.assign_sub(update * lr)

        # momentum running average, decayed with beta_2
        ema.assign(lerp_fn(ema, gradient, 1 - beta_2))

    def get_config(self):
        config = super().get_config()

        config.update(
            {
                "learning_rate": self._serialize_hyperparameter(self._learning_rate),
                "beta_1": self.beta_1,
                "beta_2": self.beta_2,
            }
        )
        return config


# class_weights = [0.4794289668401082, 2.489686000012794, 1.202780949162393, 17.671771098807746, 1.8562545727590145, 0.4794289668401082]
# Three class-weight vectors with six entries each (the model below is built
# with classes=6); each vector sums to 1 up to rounding.
# NOTE(review): f1/f2/f3 presumably refer to the three folds - confirm.
f3_class_weights = [
    0.00748884,
    0.04104853,
    0.019121,
    0.89586747,
    0.02898535,
    0.00748884,
]
f2_class_weights = [
    0.00569605,
    0.0317299,
    0.01508198,
    0.9191305,
    0.02266551,
    0.00569605,
]
f1_class_weights = [
    0.00619736,
    0.03338325,
    0.01614165,
    0.91373944,
    0.024341,
    0.00619736,
]

# combination loss out of dice and categorical crossentropy with class weights
# (the weights are handed to CombinedLoss, which forwards them to its dice term)
loss = CombinedLoss(class_weights=f1_class_weights)
# loss = CombinedLoss()


# opt = tf.keras.optimizers.experimental.AdamW(1e-3, weight_decay=1e-7)
opt = Lion(learning_rate=5e-4)

# Metrics reported during fit/evaluate. dice_score, pixel_error, recall_m,
# precision_m, f1_m and jaccard_index are not defined in this cell; they must
# already exist in the notebook namespace.
metric = [
    dice_score,
    pixel_error,
    recall_m,
    precision_m,
    f1_m,
    tf.keras.metrics.categorical_accuracy,
    tasm.metrics.IOUScore(threshold=0.5),
    jaccard_index,
]
# 256x256 three-channel input, six softmax output channels, "RES" OCR head
model = seg_hrnet(
    height=256,
    width=256,
    channel=3,
    classes=6,
    last_activation="softmax",
    dropout=True,
    mode="RES",
)
model.compile(loss=loss, optimizer=opt, metrics=metric)

In [None]:
import tensorflow as tf
import random
import numpy as np
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.optimizers.experimental import AdamW
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Dropout, Flatten
from keras.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    TensorBoard,
    ReduceLROnPlateau,
)

from keras.callbacks import LearningRateScheduler

# Define the callbacks
# All four watch val_loss (TensorBoard aside):
#   - ModelCheckpoint writes "best_model.h5" only when val_loss improves
#   - EarlyStopping stops after 4 epochs without improvement
#   - TensorBoard logs to "logs"
#   - ReduceLROnPlateau multiplies the learning rate by 0.1 after 2 epochs
#     without improvement
# NOTE(review): the patience settings only matter if a single model.fit call
# runs for more than one epoch - confirm how fit is invoked.
callbacks = [
    ModelCheckpoint("best_model.h5", save_best_only=True, monitor="val_loss"),
    EarlyStopping(monitor="val_loss", patience=4),
    TensorBoard(log_dir="logs"),
    ReduceLROnPlateau(monitor="val_loss", patience=2, factor=0.1),
]

In [None]:
STAGES = True

import csv
from keras.models import load_model


def save_eval_results(eval_results, filename, metric_names=None):
    """Write evaluation results to ``filename`` as a two-row CSV.

    Args:
        eval_results: values returned by ``model.evaluate`` - a sequence, or
            a single scalar when the model reports only one value.
        filename: path of the CSV file to (over)write.
        metric_names: header row.  Defaults to ``model.metrics_names`` of the
            notebook-global ``model``, which is what the function always
            used before this parameter existed.
    """
    if metric_names is None:
        metric_names = model.metrics_names

    # csv.writer needs an iterable row; wrap a bare scalar instead of failing.
    try:
        values = list(eval_results)
    except TypeError:
        values = [eval_results]

    # Open the file in write mode and write the evaluation results to it
    with open(filename, "w", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(metric_names)
        writer.writerow(values)

# quick test operation to determine the current state of the model
def quick_test(model_path, test_gen, call="fold_3", random_batch=True, load=False):
    """Predict one batch and save/show image, mask and prediction side by side.

    Args:
        model_path: either an in-memory model (anything with ``predict``) or
            a path to a saved model file.
        test_gen: indexable batch generator; ``test_gen[i]`` must return
            ``(batch_x, batch_y)``.
        call: tag appended to the saved figure names.
        random_batch: pick a random batch index instead of batch 0.
        load: force loading ``model_path`` from disk.  A string or path-like
            ``model_path`` is now always loaded; before, passing a path
            without ``load=True`` crashed at ``.predict``.

    Figures are written to ``./preds/preds_<i>_<call>`` and closed after
    being shown so that repeated calls do not accumulate open figures.
    """
    os.makedirs("preds", exist_ok=True)

    if load or isinstance(model_path, (str, os.PathLike)):
        net = load_model(
            f"{model_path}",
            custom_objects={
                "Lion": Lion,
                "CombinedLoss": CombinedLoss(),
                "mish": Mish(mish),
                "f1_m": f1_m,
                "precision_m": precision_m,
                "recall_m": recall_m,
                "pixel_error": pixel_error,
                "dice_score": dice_score,
                "iou_score": tasm.metrics.IOUScore(threshold=0.5),
                "focal_loss": tasm.losses.CategoricalFocalLoss(),
                # The dict used to list "dice_loss" twice; this later entry
                # was the one that took effect, so it is the one kept.
                "dice_loss": tasm.losses.DiceLoss(),
                "focal_loss_plus_dice_loss": tasm.losses.CategoricalFocalLoss()
                + tasm.losses.DiceLoss(),
                "jaccard_index": jaccard_index,
            },
        )
    else:
        net = model_path

    batch_index = np.random.randint(0, len(test_gen)) if random_batch else 0
    batch_x, batch_y = test_gen[batch_index]
    preds = net.predict(batch_x)

    for i, (image, mask, pred) in enumerate(zip(batch_x, batch_y, preds)):
        fig, ax = plt.subplots(1, 5, figsize=(15, 5))
        ax[0].imshow(image)
        ax[0].set_title("Image")
        # argmax over the last axis collapses the channels to one label map
        ax[1].imshow(np.argmax(mask, axis=-1))
        ax[1].set_title("Mask")
        ax[2].imshow(np.argmax(pred, axis=-1))
        ax[2].set_title("Prediction")
        # first three channels shown directly as an RGB image
        ax[3].imshow(mask[:, :, :3])
        ax[3].set_title("Mask (RGB)")
        ax[4].imshow(pred[:, :, :3])
        ax[4].set_title("Prediction (RGB)")
        for a in ax:
            a.axis("off")
        plt.savefig(f"./preds/preds_{i}_{call}")
        plt.show()
        plt.close(fig)


def plot_history(history, call=None):
    """Plot every series in ``history.history`` in its own subplot and save it.

    Args:
        history: object with a ``history`` dict mapping a metric name to its
            list of per-epoch values (as returned by ``model.fit``).
        call: tag used in the output file name ``history_<call>.png``.

    Changes from the earlier version: the file name no longer embeds
    ``str(history)`` (an object repr containing a memory address), a history
    with a single metric no longer fails on ``axs[i]``, an empty history is
    skipped instead of raising, and the figure is closed after being shown.
    """
    series = list(history.history.items())
    if not series:
        return

    # squeeze=False keeps `axs` two-dimensional even for a single subplot
    fig, axs = plt.subplots(
        nrows=1, ncols=len(series), figsize=(25, 5), squeeze=False
    )

    for axis, (metric, values) in zip(axs[0], series):
        axis.plot(values)
        axis.set_title(metric)
        axis.set_xlabel("Epoch")
        axis.set_ylabel(metric)
    plt.savefig(f"history_{call}.png")
    plt.show()
    plt.close(fig)


# ----------------------------------------------------------------------------
# Three-fold training.  Every fold trains `epochs` times for one epoch each,
# saves the model and shows a sample prediction after each epoch, then
# evaluates, stores the scores and plots the last history.
#
# Fixes compared with the earlier version of this cell:
#   * folds 2 and 3 built a brand-new model inside the epoch loop, so only
#     the final single epoch survived; the model is now built once per fold
#   * each newly built model gets its own Lion instance instead of reusing
#     an optimizer that was already built for another model's variables
#   * saved-model paths are passed to quick_test with load=True (a bare path
#     has no .predict), and fold 1 reloads the file it actually saved
#
# NOTE(review): `callbacks` is shared by all folds and each fit call runs a
# single epoch - confirm whether the checkpoint / early-stopping state is
# meant to carry over between calls and folds.
# ----------------------------------------------------------------------------

opt = Lion(learning_rate=5e-4)

model.compile(loss=loss, optimizer=opt, metrics=metric)
# original split as described in here: https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke
epochs = 5
# increase dropout to 0.5 and opt by .1 / .2
for i in range(epochs):
    t_2_v_1_t_3 = model.fit(tg_1, callbacks=callbacks, validation_data=(vg_2))
    model.save("model_F1_2.h5")
    quick_test(model, vg_3, call="CombinedLoss1")


eval_1 = model.evaluate(vg_3)
save_eval_results(eval_1, "t_1_v_2_t_3")
quick_test("model_F1_2.h5", vg_3_shuffle, call="CombinedLoss1", load=True)
plot_history(t_2_v_1_t_3, call="1")
quick_test("best_model.h5", vg_3_shuffle, load=True)

# training loops
# fold 2: train on tg_2, validate on vg_1, evaluate on vg_3
model = seg_hrnet(
    height=256,
    width=256,
    channel=3,
    classes=6,
    last_activation="softmax",
    dropout=True,
    mode="RES",
)
opt = Lion(learning_rate=5e-4)
model.compile(loss=loss, optimizer=opt, metrics=metric)
for i in range(epochs):
    t_2_v_1_t_3 = model.fit(tg_2, callbacks=callbacks, validation_data=(vg_1))
    model.save("model1.h5")
    quick_test(model, vg_3_shuffle, call="CombinedLoss2")

eval_2 = model.evaluate(vg_3)
save_eval_results(eval_2, "t_2_v_1_t_3")
quick_test("model1.h5", vg_3_shuffle, load=True)
plot_history(t_2_v_1_t_3, call="2")
quick_test("best_model.h5", vg_3_shuffle, load=True)

# fold 3: train on tg_3, validate on vg_2, evaluate on vg_1
model = seg_hrnet(
    height=256,
    width=256,
    channel=3,
    classes=6,
    last_activation="softmax",
    dropout=True,
    mode="RES",
)
opt = Lion(learning_rate=5e-4)
model.compile(loss=loss, optimizer=opt, metrics=metric)
for i in range(epochs):
    t_3_v_2_t_1 = model.fit(tg_3, callbacks=callbacks, validation_data=(vg_2))
    model.save("model2.h5")
    quick_test(model, vg_3_shuffle, call="CombinedLoss3")

eval_3 = model.evaluate(vg_1)
save_eval_results(eval_3, "t_3_v_2_t_1")
quick_test("model2.h5", vg_1_shuffle, load=True)
plot_history(t_3_v_2_t_1, call="3")
quick_test("best_model.h5", vg_3_shuffle, load=True)

In [None]:
import shutil
# moving files to drive
# Set MOVE to False to keep the files in /content.
MOVE = True

if MOVE:
    # models
    # best checkpoint -> inst_mod_best.h5, last saved fold-1 model -> inst_mod.h5
    shutil.move("/content/best_model.h5", "/content/drive/MyDrive/inst_mod_best.h5")
    shutil.move("/content/model_F1_2.h5", "/content/drive/MyDrive/inst_mod.h5")
# The remaining moves (fold 2/3 models and the eval score files) are disabled.
# shutil.move("/content/model1.h5", "/content/drive/MyDrive/fold2_model_combi_f.h5")
# shutil.move("/content/model2.h5", "/content/drive/MyDrive/fold3_model_combi_f.h5")
# eval scores
# shutil.move("/content/t_1_v_2_t_3", "/content/drive/MyDrive/t_1_v_2_t_3_f")
# shutil.move("/content/t_2_v_1_t_3", "/content/drive/MyDrive/t_2_v_1_t_3_f")
# shutil.move("/content/t_3_v_2_t_1", "/content/drive/MyDrive/t_3_v_2_t_1_f")

In [None]:
import shutil
# moving other collected data to drive
# make_archive appends the ".zip" extension itself, so it gets the bare base
# name (passing "preds.zip" produced "preds.zip.zip").  The path it returns
# is moved, so the file name is not spelled out a second time.
archive_path = shutil.make_archive("preds", "zip", "/content/preds")
shutil.move(archive_path, "/content/drive/MyDrive/")