In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from glob import glob
import tensorflow as tf
import tensorflow.keras as keras
import keras.backend as K
import tensorflow.keras.layers as L
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, MaxPool2D, Add, Dropout, Concatenate, Conv2DTranspose, Dense, Reshape, Flatten, Softmax, Lambda, UpSampling2D, AveragePooling2D, Activation, BatchNormalization, GlobalAveragePooling2D, SeparableConv2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import MeanIoU
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.metrics import BinaryAccuracy, Precision, Recall
from sklearn.model_selection import train_test_split
from tensorflow.keras.applications import DenseNet121
import pandas as pd
import tensorflow_probability as tfp

## Data Loader

In [None]:
def load_image(path, size, mask=False):
    """Load an image file, convert its colour mode and resize it to (size, size).

    Args:
        path: path of an image file readable by PIL.
        size: target side length in pixels; the result is size x size.
        mask: if True, convert to single-channel grayscale ('L') and resize with
            nearest-neighbour so label values are never blended at region
            borders; otherwise convert to 'RGB' and resize with PIL's default filter.

    Returns:
        numpy uint8 array of shape (size, size) for masks or (size, size, 3) otherwise.
    """
    # The context manager closes the file handle that Image.open keeps open lazily.
    with Image.open(path) as image:
        if mask:
            # Convert first, then resize with NEAREST: a smoothing filter would create
            # intermediate grey levels between background and foreground.
            image = image.convert('L').resize((size, size), Image.NEAREST)
        else:
            image = image.resize((size, size)).convert('RGB')
        return np.array(image)

def load_data(root_path, size):
    """Load every PNG image/mask pair found under ``root_path``.

    Each ``*.png`` file in the input folder is paired with the file of the same
    stem in the mask folder. Inputs are loaded as RGB and masks as grayscale
    (both resized to ``size`` x ``size`` by ``load_image``), then scaled from
    [0, 255] to [0, 1].

    Args:
        root_path: dataset directory containing the two sub-folders.
        size: side length in pixels passed to ``load_image``.

    Returns:
        (images, masks) as numpy float arrays ordered by sorted input path;
        images have shape (N, size, size, 3) and masks (N, size, size).

    Raises:
        FileNotFoundError: if the input folder contains no PNG files.
    """
    images = []
    masks = []

    # NOTE(review): model inputs are read from the 'masks' sub-folder and targets
    # from 'images'. This looks swapped relative to the variable names; confirm
    # against the dataset layout before changing it.
    image_folder = os.path.join(root_path, 'masks')
    mask_folder = os.path.join(root_path, 'images')

    image_paths = sorted(glob(os.path.join(image_folder, '*.png')))
    if not image_paths:
        # Fail early instead of returning empty arrays that only break later cells.
        raise FileNotFoundError(f'No PNG files found in {image_folder}')

    for image_path in image_paths:
        # splitext keeps dots inside the stem (e.g. 'slide.01.png' -> 'slide.01').
        img_id = os.path.splitext(os.path.basename(image_path))[0]
        mask_path = os.path.join(mask_folder, f'{img_id}.png')

        img = load_image(image_path, size) / 255.0
        mask = load_image(mask_path, size, mask=True) / 255.0

        images.append(img)
        masks.append(mask)

    return np.array(images), np.array(masks)

In [None]:
# Side length in pixels that every image and mask is resized to.
size = 512
# Kaggle dataset directory holding the image and mask sub-folders.
root_path = '/kaggle/input/tnbc-seg/TNBC_NucleiSegmentation/TNBC_NucleiSegmentation'
X_train, y_train = load_data(root_path=root_path, size=size)

In [None]:
print(f"X shape: {X_train.shape}     |  y shape: {y_train.shape}")

# Give the masks a trailing channel axis so they match the model's single-channel
# output. Guarded so that re-running this cell does not keep adding axes.
if y_train.ndim == 3:
    y_train = np.expand_dims(y_train, -1)

print(f"\nX shape: {X_train.shape}  |  y shape: {y_train.shape}")

In [None]:
# Hold out 10% for testing, then 20% of the remainder for validation.
# NOTE: this overwrites X_train/y_train; re-run the loading cells before re-running
# this one, otherwise the already-split training set is split again.
X_train, X_test, y_train, y_test = train_test_split(X_train, y_train, test_size=0.1, random_state=35)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=35)

print('X_train shape:',X_train.shape)
print('y_train shape:',y_train.shape)
print('X_val shape:',X_val.shape)
print('y_val shape:',y_val.shape)
print('X_test shape:',X_test.shape)
print('y_test shape:',y_test.shape)

#### Sample image/mask pair

In [None]:
image = X_test[0]
mask = y_test[0]

fig, axes = plt.subplots(1, 2, figsize=(5, 2))
# The inputs are RGB arrays scaled to [0, 1]; imshow renders them directly
# (a colormap would be ignored for 3-channel data).
axes[0].imshow(image)
axes[0].axis('off')
axes[0].set_title('Image')

# Masks are scaled to [0, 1] with a trailing channel axis: drop the axis and show
# them on that same fixed range so intermediate values are not saturated to white.
axes[1].imshow(np.squeeze(mask), cmap='gray', vmin=0, vmax=1)
axes[1].axis('off')
axes[1].set_title('Mask')

plt.tight_layout()
plt.show()

## Metrics and Losses

#### Metrics

In [None]:
def dice_score(y_true, y_pred):
    """Soft Dice coefficient between ground truth and predicted probabilities.

    Both tensors are flattened over every axis (batch included), so a single
    score is computed for the whole batch. K.epsilon() is added to numerator
    and denominator so empty masks do not divide by zero.
    """
    eps = K.epsilon()
    truth = K.flatten(K.cast(y_true, 'float32'))
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    denominator = K.sum(truth) + K.sum(pred) + eps
    return (2. * overlap + eps) / denominator

def iou(y_true, y_pred):
    """Soft intersection-over-union (Jaccard index) over the flattened batch.

    The union is |A| + |B| - |A ∩ B|. K.epsilon() smooths both numerator and
    denominator.
    """
    eps = K.epsilon()
    truth = K.flatten(K.cast(y_true, 'float32'))
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    union = K.sum(truth) + K.sum(pred) - overlap + eps
    return (overlap + eps) / union

def recall(y_true, y_pred):
    """Pixel recall TP / (TP + FN) on rounded predictions, over the flattened batch."""
    eps = K.epsilon()
    # Clip into [0, 1] and round to the nearest integer to obtain hard labels.
    hard_pred = K.flatten(K.round(K.clip(y_pred, 0, 1)))
    truth = K.flatten(K.cast(y_true, 'float32'))
    true_pos = K.sum(truth * hard_pred)
    false_neg = K.sum(truth * (1 - hard_pred))
    return (true_pos + eps) / (true_pos + false_neg + eps)

def precision(y_true, y_pred):
    """Pixel precision TP / (TP + FP) on rounded predictions, over the flattened batch."""
    eps = K.epsilon()
    # Clip into [0, 1] and round to the nearest integer to obtain hard labels.
    hard_pred = K.flatten(K.round(K.clip(y_pred, 0, 1)))
    truth = K.flatten(K.cast(y_true, 'float32'))
    true_pos = K.sum(truth * hard_pred)
    false_pos = K.sum((1 - truth) * hard_pred)
    return (true_pos + eps) / (true_pos + false_pos + eps)

#### Losses

In [None]:
def dice_loss(y_true, y_pred):
    """Dice loss: one minus the soft Dice coefficient."""
    return 1 - dice_score(y_true, y_pred)

def iou_loss(y_true, y_pred):
    """Jaccard loss: one minus the soft IoU."""
    return 1 - iou(y_true, y_pred)
    
def focal_loss(y_true, y_pred, gamma=2.0, alpha=0.25):
    """Binary focal loss (Lin et al., 2017) on predicted probabilities.

    FL = alpha_t * (1 - p_t)**gamma * BCE, averaged over all elements, where
    p_t = y*p + (1-y)*(1-p) and alpha_t = y*alpha + (1-y)*(1-alpha). That is,
    positives are weighted by ``alpha`` and negatives by ``1 - alpha``. The
    products (rather than an equality test against 1) also accept soft labels
    in [0, 1]. For hard labels, BCE reduces to -log(p_t).

    Args:
        y_true: ground-truth values in [0, 1].
        y_pred: predicted probabilities (sigmoid outputs).
        gamma: focusing parameter; 0 gives alpha-weighted BCE.
        alpha: weight of the positive class.
    """
    epsilon = tf.keras.backend.epsilon()
    # Keep log() finite.
    y_pred = tf.clip_by_value(y_pred, epsilon, 1.0 - epsilon)
    y_true = tf.cast(y_true, tf.float32)
    pt = y_true * y_pred + (1.0 - y_true) * (1.0 - y_pred)
    alpha_t = y_true * alpha + (1.0 - y_true) * (1.0 - alpha)
    bce = -(y_true * tf.math.log(y_pred) + (1.0 - y_true) * tf.math.log(1.0 - y_pred))
    loss = tf.reduce_mean(alpha_t * tf.pow(1.0 - pt, gamma) * bce)
    return loss

def bce_loss(y_true, y_pred, from_logits=False):
    """Mean binary cross-entropy over all elements.

    Args:
        y_true: ground-truth values in [0, 1].
        y_pred: predicted probabilities (e.g. a sigmoid layer's output) when
            ``from_logits`` is False, raw logits when it is True.
        from_logits: set True only when ``y_pred`` has not gone through a sigmoid.

    Returns:
        Scalar tensor with the mean loss.
    """
    y_true = tf.cast(y_true, tf.float32)
    if from_logits:
        per_element = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_true, logits=y_pred)
    else:
        # The segmentation head ends in a sigmoid, so y_pred is already a probability.
        # Treating it as logits would apply the sigmoid twice. Clip so log() stays finite.
        epsilon = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, epsilon, 1.0 - epsilon)
        per_element = -(y_true * tf.math.log(y_pred) + (1.0 - y_true) * tf.math.log(1.0 - y_pred))
    return tf.reduce_mean(per_element)

def combined_loss(y_true, y_pred):
    """Training objective: unweighted sum of the Dice loss and binary cross-entropy."""
    return dice_loss(y_true, y_pred) + bce_loss(y_true, y_pred)

## Model

In [None]:
# NOTE(review): the triple-quoted list below is an unused string expression (it is
# not assigned to anything). It presumably records the phrases behind
# text_labels.csv; confirm against the CSV's provenance.
'''texts = [
    "tumor epithelial tissue",
    "necrotic tissue",
    "lymphocytic tissue",
    "tumor associated stromal tissue",
    "coagulative necrosis",
    "liquefactive necrosis",
    "desmoplasia",
    "granular and non granular leukocytes",
    "perinuclear halo",
    "interstitial space",
    "neutrophils",
    "macrophages",
    "collagen",
    "fibronectin",
    "hyperplasia",
    "dysplasia"
]'''

# Load the pre-computed text-label matrix (the CSV has no header row) as a 2-D float32 tensor.
text = pd.read_csv('/kaggle/input/tnbc-seg/text_labels.csv', header=None)
text = tf.convert_to_tensor(np.asarray(text), dtype=tf.float32)

# Project both axes to 32 features: first the columns (R, C) -> (R, 32), then,
# after transposing, the rows (32, R) -> (32, 32).
# NOTE(review): these Dense layers are built and applied eagerly at module scope and
# only their output tensor reaches the model, so their weights are presumably a fixed
# random projection that model.fit never updates. Confirm this is intended.
text = Dense(32, activation='relu')(text)
text = Dense(32, activation='relu')(tf.transpose(text, perm=[1,0]))
# Add batch and channel axes: (32, 32) -> (1, 32, 32, 1), an image-like tensor for DistributionModel.
text = tf.expand_dims(text, axis=0)
text = tf.expand_dims(text, axis=-1)

class DistributionModel(tf.keras.Model):
    """Predicts four per-pixel Normal distributions used to modulate the decoder.

    The input feature map is upsampled four times (x2 each) by Conv2DTranspose
    layers. At every scale a 1-channel mean map and stddev map are predicted and
    then shifted / widened by statistics of a copy of the constant ``text``
    tensor that has been upsampled the same number of times. ``call`` returns the
    four distributions ordered from the coarsest to the finest scale.
    """
    def __init__(self, text):
        # text: constant tensor stored as-is (no trainable text layers here);
        # presumably shaped (1, H, W, 1) as built in this notebook. TODO confirm.
        super(DistributionModel, self).__init__()
        # 1x1 separable convs predicting per-pixel mean / stddev; softplus keeps both positive.
        self.mean_layer = tf.keras.layers.SeparableConv2D(filters=1, kernel_size=1, padding='same', activation='softplus')
        self.stddev_layer = tf.keras.layers.SeparableConv2D(filters=1, kernel_size=1, padding='same', activation='softplus')
        # Splits the concatenated [mean, stddev] channels into Normal(loc, scale).
        self.distribution_layer = tfp.layers.DistributionLambda(
            lambda t: tfp.distributions.Normal(loc=t[..., :1], scale=t[..., 1:])
        )
        self.concat = tf.keras.layers.Concatenate(axis=-1)
        # Feature-path upsamplers: each doubles height and width (kernel 2, stride 2).
        self.conv_t1 = tf.keras.layers.Conv2DTranspose(256, 2, strides=2, padding="same", activation='relu')
        self.conv_t2 = tf.keras.layers.Conv2DTranspose(256, 2, strides=2, padding="same", activation='relu')
        self.conv_t3 = tf.keras.layers.Conv2DTranspose(256, 2, strides=2, padding="same", activation='relu')
        self.conv_t4 = tf.keras.layers.Conv2DTranspose(256, 2, strides=2, padding="same", activation='relu')
        self.text = text
        # Text-path upsamplers (1 output channel) applied to self.text in the same x2 steps.
        self.conv_tt1 = tf.keras.layers.Conv2DTranspose(1, 2, strides=2, padding="same", activation='relu')
        self.conv_tt2 = tf.keras.layers.Conv2DTranspose(1, 2, strides=2, padding="same", activation='relu')
        self.conv_tt3 = tf.keras.layers.Conv2DTranspose(1, 2, strides=2, padding="same", activation='relu')
        self.conv_tt4 = tf.keras.layers.Conv2DTranspose(1, 2, strides=2, padding="same", activation='relu')

    def distribution(self, x, text):
        """Return a Normal distribution built from feature map ``x`` and ``text`` statistics.

        mean   = mean_layer(x) + mean(text)
        stddev = sqrt(stddev_layer(x)**2 + var(text))
        The text mean and variance are reduced over the last three axes and
        broadcast onto every pixel of the 1-channel maps.
        """
        mean = self.mean_layer(x)+tf.math.reduce_mean(text, axis=[-1,-2,-3])
        # Variances add, so the predicted std and the text variance are combined in quadrature.
        stddev = tf.math.sqrt(tf.math.square(self.stddev_layer(x))+tf.math.reduce_variance(text, axis=[-1,-2,-3]))

        # Concatenate mean and standard deviation
        parameters = self.concat([mean, stddev])

        # Generate distribution
        distribution = self.distribution_layer(parameters)
        return distribution

    def distribution_attn(self, x):
        """Return four distributions at 2x, 4x, 8x and 16x the spatial size of ``x``."""
        x1 = self.conv_t1(x)
        x2 = self.conv_t2(x1)
        x3 = self.conv_t3(x2)
        x4 = self.conv_t4(x3)
        # The text tensor is upsampled step by step alongside the feature path.
        text = self.conv_tt1(self.text)
        dis1 = self.distribution(x1, text)
        text = self.conv_tt2(text)
        dis2 = self.distribution(x2, text)
        text = self.conv_tt3(text)
        dis3 = self.distribution(x3, text)
        text = self.conv_tt4(text)
        dis4 = self.distribution(x4, text)
        return dis1, dis2, dis3, dis4

    def call(self, inputs):
        # Returns the tuple (dis1, dis2, dis3, dis4) from distribution_attn.
        return self.distribution_attn(inputs)

def conv_block(x, num_filters, kernel_size, padding="same", act=True):
    """Bias-free Conv2D followed by BatchNorm and, when ``act`` is True, ReLU."""
    out = Conv2D(num_filters, kernel_size, padding=padding, use_bias=False)(x)
    out = BatchNormalization()(out)
    return Activation("relu")(out) if act else out

def multires_block(x, num_filters, alpha=1.67):
    """MultiRes block: three chained 3x3 convs concatenated, plus a 1x1 shortcut.

    The block width is W = num_filters * alpha. The chained convs receive
    int(W*0.167), int(W*0.333) and int(W*0.5) filters, and the 1x1 shortcut
    matches their total so the two paths can be added.
    """
    width = num_filters * alpha
    filter_counts = [int(width * frac) for frac in (0.167, 0.333, 0.5)]

    # Each 3x3 conv consumes the previous one's output; all three are kept for the concat.
    branch_outputs = []
    current = x
    for count in filter_counts:
        current = conv_block(current, count, 3)
        branch_outputs.append(current)
    merged = Concatenate()(branch_outputs)
    merged = BatchNormalization()(merged)

    shortcut = conv_block(x, sum(filter_counts), 1, act=False)

    out = Activation("relu")(merged + shortcut)
    return BatchNormalization()(out)

def res_path(x, num_filters, length):
    """Densely connected residual skip path with a per-sample spatial gate.

    Applies ``length`` residual units. Each unit computes
    BN(ReLU(conv3x3(in) + conv1x1(in))) using bias-free, batch-normalised convs.
    The first unit's input is ``x``. Every later unit's input is the channel
    concatenation of the previous unit's output, ``x`` and all earlier
    concatenated inputs. The result is multiplied by a sigmoid gate predicted
    per pixel from ``x`` by a Dense(1) layer.

    Args:
        x: input feature map, presumably (B, H, W, C). TODO confirm.
        num_filters: channels of every conv in the path.
        length: number of residual units (every caller in this notebook passes 4).

    Returns:
        Gated feature map with ``num_filters`` channels.
    """
    # Gate of shape (B, H, W, 1), computed independently for each sample. Averaging it
    # over the batch axis would make a sample's output depend on the other samples in
    # its batch (and so on batch_size).
    gate = L.Dense(1, activation='sigmoid')(x)

    path_inputs = [x]
    unit_in = x
    out = x
    for step in range(length):
        if step > 0:
            # Dense connectivity: latest output + path input + every earlier concatenation.
            unit_in = Concatenate()([out] + path_inputs)
            path_inputs.append(unit_in)
        conv3 = conv_block(unit_in, num_filters, 3, act=False)
        conv1 = conv_block(unit_in, num_filters, 1, act=False)
        out = Activation("relu")(conv3 + conv1)
        out = BatchNormalization()(out)
    return out * gate

def encoder_block(x, num_filters, length):
    """Encoder stage.

    Returns (skip, pooled): the res_path-refined skip tensor and the 2x2
    max-pooled features passed to the next stage.
    """
    features = multires_block(x, num_filters)
    skip = res_path(features, num_filters, length)
    pooled = MaxPooling2D(pool_size=(2, 2))(features)
    return skip, pooled

def decoder_block(x, skip, num_filters):
    """Decoder stage: x2 transposed-conv upsampling, concat with the skip, MultiRes block."""
    upsampled = Conv2DTranspose(num_filters, kernel_size=2, strides=2, padding="same")(x)
    merged = Concatenate()([upsampled, skip])
    return multires_block(merged, num_filters)

def build_multiresunet(shape, text):
    """ Input """
    inputs = Input(shape)

    """ Encoder """
    p0 = inputs
    s1, p1 = encoder_block(p0, 32, 4)
    s2, p2 = encoder_block(p1, 64, 4)
    s3, p3 = encoder_block(p2, 128, 4)
    s4, p4 = encoder_block(p3, 256, 4)

    """ Bridge """
    b1 = multires_block(p4, 512)
    dis1, dis2, dis3, dis4 = DistributionModel(text)(b1)

    """ Decoder """
    d1 = decoder_block(b1, s4, 256)
    d1 = d1*dis1
    d2 = decoder_block(d1, s3, 128)
    d2 = d2*dis2
    d3 = decoder_block(d2, s2, 64)
    d3 = d3*dis3
    d4 = decoder_block(d3, s1, 32)
    d4 = d4*dis4
    
    """ Output """
    outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4)

    """ Model """
    model = Model(inputs, outputs, name="MultiResUNET")

    return model
    
# Build for size x size RGB inputs so the model always matches the arrays from load_data.
model = build_multiresunet((size, size, 3), text)
optimizer = Adam(learning_rate=0.0001)
# The metric order here fixes the order of the values reported by model.evaluate.
model.compile(loss=combined_loss, metrics=["accuracy", dice_score, recall, precision, iou], optimizer=optimizer)
model.summary()

## Training

In [None]:
# Keep only the weights from the epoch with the highest validation Dice score.
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath='/kaggle/working/model.h5',
    monitor='val_dice_score',
    mode='max',
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)

history = model.fit(
    X_train,
    y_train,
    validation_data=(X_val, y_val),
    epochs=100,
    batch_size=2,
    shuffle=True,
    callbacks=[model_checkpoint_callback],
    verbose=1,
)

In [None]:
def Train_Val_Plot(loss, val_loss, dice_score, val_dice_score, iou, val_iou, recall, val_recall, precision, val_precision, accuracy, val_accuracy):
    """Plot training vs. validation curves for six metrics on a 2x3 grid.

    Every argument is a per-epoch sequence (e.g. ``history.history['loss']``).
    The x axis counts epochs starting at 1.
    """
    # (title, y-axis label, training curve, validation curve), in row-major panel order.
    panels = (
        ('History of Loss', 'Loss', loss, val_loss),
        ('History of Dice Coefficient', 'Dice Coefficient', dice_score, val_dice_score),
        ('History of IOU', 'IOU', iou, val_iou),
        ('History of Recall', 'Recall', recall, val_recall),
        ('History of Precision', 'Precision', precision, val_precision),
        ('History of Accuracy', 'Accuracy', accuracy, val_accuracy),
    )

    fig, axs = plt.subplots(2, 3, figsize=(20, 10))
    fig.suptitle("MODEL'S METRICS VISUALIZATION")

    for ax, (title, ylabel, train_curve, val_curve) in zip(axs.flat, panels):
        for curve in (train_curve, val_curve):
            ax.plot(range(1, len(curve) + 1), curve)
        ax.set_title(title)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(ylabel)
        ax.legend(['training', 'validation'])

    plt.tight_layout()
    plt.show()

# Pass the per-epoch histories in the positional order Train_Val_Plot expects.
Train_Val_Plot(*(history.history[key] for key in (
    'loss', 'val_loss',
    'dice_score', 'val_dice_score',
    'iou', 'val_iou',
    'recall', 'val_recall',
    'precision', 'val_precision',
    'accuracy', 'val_accuracy',
)))

## Testing

In [None]:
# Restore the best checkpoint (highest val_dice_score) before evaluating.
model.load_weights("/kaggle/working/model.h5")
# return_dict=True yields {metric_name: value}. This avoids unpacking into module-level
# names such as recall / precision / iou, which would shadow the metric functions of the
# same names and break a later re-run of the compile cell.
test_results = model.evaluate(X_test, y_test, batch_size = 4, verbose = 0, return_dict=True)
for label, key in (('loss', 'loss'), ('accuracy', 'accuracy'), ('dice', 'dice_score'),
                   ('recall', 'recall'), ('precision', 'precision'), ('iou', 'iou')):
    print(f'{label}:', test_results[key])

In [None]:
# Probe model exposing three intermediate activations for visualisation.
# The names look like Keras auto-generated op-layer names, so they depend on how many
# `+` ops were created before them; re-check with model.summary() if the architecture changes.
# NOTE(review): model.layers[-3] is simply the third-from-last entry of model.layers;
# confirm it is the intended "decoder last layer".
modeller = Model(
    inputs=model.input,
    outputs=[
        model.get_layer(name="tf.__operators__.add_17").output,
        model.get_layer(name="tf.__operators__.add_20").output,
        model.layers[-3].output,
    ],
)

In [None]:
# Show up to 15 test samples: input, ground truth, thresholded prediction and channel 0
# of each probe output. min() keeps the loop inside X_test when the test split is smaller
# than 15 samples (otherwise X_test[z] raises IndexError).
for z in range(min(15, len(X_test))):
    test_image = X_test[z]
    test_mask = y_test[z]

    # Add a leading batch axis: (H, W, C) -> (1, H, W, C).
    test_image = np.reshape(test_image, (1,) + test_image.shape)

    # Predict the segmentation mask and the probe activations for this image.
    predicted_mask = model.predict(test_image)[0]
    feature_map1, feature_map2, feature_map3 = modeller.predict(test_image)

    # Threshold the probabilities at 0.5 and scale to {0, 255} for display.
    predicted_mask_binary = np.where(predicted_mask > 0.5, 1, 0) * 255

    # (activations, panel title); each activation array has one entry per batch item.
    fmap_sets = [
        (feature_map1, 'Encoder last layer'),
        (feature_map2, 'Bottleneck'),
        (feature_map3, 'Decoder last layer'),
    ]
    n_panels = 3 + sum(len(fmaps) for fmaps, _ in fmap_sets)
    fig, axes = plt.subplots(1, n_panels, figsize=(20, 4))

    for ax, panel, title in (
        (axes[0], test_image[0], 'Test Image'),
        (axes[1], test_mask, 'Ground Truth Mask'),
        (axes[2], predicted_mask_binary, 'Predicted Mask'),
    ):
        ax.imshow(panel, cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    col = 3
    for fmaps, title in fmap_sets:
        for fmap in fmaps:
            axes[col].imshow(fmap[:, :, 0], cmap='jet')
            axes[col].set_title(title)
            axes[col].axis('off')
            col += 1

    plt.tight_layout()
    plt.show()
    # Release the figure so the loop does not accumulate open figures.
    plt.close(fig)

In [None]:
# Probe model exposing four multiplication outputs for visualisation.
# NOTE(review): res_path's `x*check` is built (in the encoder) before the decoder's
# `d*dis` products and uses the same operator, so these auto-generated names may refer
# to the encoder skip gates rather than the decoder stages the plots label them as.
# Confirm with model.summary().
modeller = Model(
    inputs=model.input,
    outputs=[
        model.get_layer(name="tf.math.multiply").output,
        model.get_layer(name="tf.math.multiply_1").output,
        model.get_layer(name="tf.math.multiply_2").output,
        model.get_layer(name="tf.math.multiply_3").output,
    ],
)

In [None]:
# Show up to 15 test samples with the channel-averaged probe activations. min() keeps
# the loop inside X_test when the test split is smaller than 15 samples.
for z in range(min(15, len(X_test))):
    test_image = X_test[z]
    test_mask = y_test[z]

    # Add a leading batch axis: (H, W, C) -> (1, H, W, C).
    test_image = np.reshape(test_image, (1,) + test_image.shape)

    # Predict the segmentation mask and the probe activations for this image.
    predicted_mask = model.predict(test_image)[0]
    probe_outputs = modeller.predict(test_image)
    # Average over channels (keeping the axis) so each activation becomes one heat map.
    feature_map1, feature_map2, feature_map3, feature_map4 = (
        np.mean(fmap, axis=-1, keepdims=True) for fmap in probe_outputs
    )

    # Threshold the probabilities at 0.5 and scale to {0, 255} for display.
    predicted_mask_binary = np.where(predicted_mask > 0.5, 1, 0) * 255

    # (activations, panel title); each activation array has one entry per batch item.
    fmap_sets = [
        (feature_map1, 'Decoder 1'),
        (feature_map2, 'Decoder 2'),
        (feature_map3, 'Decoder 3'),
        (feature_map4, 'Decoder 4'),
    ]
    n_panels = 3 + sum(len(fmaps) for fmaps, _ in fmap_sets)
    fig, axes = plt.subplots(1, n_panels, figsize=(20, 4))

    for ax, panel, title in (
        (axes[0], test_image[0], 'Test Image'),
        (axes[1], test_mask, 'Ground Truth Mask'),
        (axes[2], predicted_mask_binary, 'Predicted Mask'),
    ):
        ax.imshow(panel, cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    col = 3
    for fmaps, title in fmap_sets:
        for fmap in fmaps:
            axes[col].imshow(fmap[:, :, 0], cmap='jet')
            axes[col].set_title(title)
            axes[col].axis('off')
            col += 1

    plt.tight_layout()
    plt.show()
    # Release the figure so the loop does not accumulate open figures.
    plt.close(fig)