In [1]:
from google.colab import drive
drive.mount('/content/drive')
import sys
sys.path.append("/content/drive/MyDrive/Colab Notebooks")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [32]:
import numpy as np
from skimage import color, filters
import tensorflow as tf
from tensorflow.keras.layers import Layer, Input, Conv2D, MaxPooling2D, Add, SpatialDropout2D, concatenate, BatchNormalization, UpSampling2D, Activation, ReLU

In [3]:
# Helper function to apply 2D thresholding to each image in a batch
def apply_threshold_to_batch(inputs_np, threshold_func):
    """Applies a 2D thresholding function to each image in a batch of single-channel images."""
    # inputs_np is expected to be a (batch, H, W, 1) numpy array
    # The thresholding functions from skimage.filters expect a 2D array.
    outputs = np.zeros_like(inputs_np, dtype=np.float32)
    for i in range(inputs_np.shape[0]):
        img_2d = inputs_np[i, ..., 0] # Take the single channel for the current image
        thresh = threshold_func(img_2d)
        outputs[i, ..., 0] = (img_2d > thresh).astype(np.float32)
    return outputs


class OtsuThresholdLayer(Layer):
    def __init__(self, **kwargs):
        super(OtsuThresholdLayer, self).__init__(**kwargs)

    def _otsu_np(self, inputs_np):
        return apply_threshold_to_batch(inputs_np, filters.threshold_otsu)

    def call(self, inputs):
        inputs_float = tf.cast(inputs, tf.float32)
        outputs = tf.py_function(
            func=self._otsu_np,
            inp=[inputs_float],
            Tout=tf.float32
        )
        # Correctly set shape to (batch, H, W, 1)
        outputs.set_shape(inputs_float.get_shape()[:-1].concatenate([1]))
        return outputs

    def get_config(self):
        return super(OtsuThresholdLayer, self).get_config()

class MinimumThresholdLayer(Layer):
    def __init__(self, **kwargs):
        super(MinimumThresholdLayer, self).__init__(**kwargs)

    def _minimum_np(self, inputs_np):
        return apply_threshold_to_batch(inputs_np, filters.threshold_minimum)

    def call(self, inputs):
        inputs_float = tf.cast(inputs, tf.float32)
        outputs = tf.py_function(
            func=self._minimum_np,
            inp=[inputs_float],
            Tout=tf.float32
        )
        outputs.set_shape(inputs_float.get_shape()[:-1].concatenate([1]))
        return outputs

    def get_config(self):
        return super(MinimumThresholdLayer, self).get_config()

class IsodataThresholdLayer(Layer):
    def __init__(self, **kwargs):
        super(IsodataThresholdLayer, self).__init__(**kwargs)

    def _isodata_np(self, inputs_np):
        return apply_threshold_to_batch(inputs_np, filters.threshold_isodata)

    def call(self, inputs):
        inputs_float = tf.cast(inputs, tf.float32)
        outputs = tf.py_function(
            func=self._isodata_np,
            inp=[inputs_float],
            Tout=tf.float32
        )
        outputs.set_shape(inputs_float.get_shape()[:-1].concatenate([1]))
        return outputs

    def get_config(self):
        return super(IsodataThresholdLayer, self).get_config()

class YenThresholdLayer(Layer):
    def __init__(self, **kwargs):
        super(YenThresholdLayer, self).__init__(**kwargs)

    def _yen_np(self, inputs_np):
        return apply_threshold_to_batch(inputs_np, filters.threshold_yen)

    def call(self, inputs):
        inputs_float = tf.cast(inputs, tf.float32)
        outputs = tf.py_function(
            func=self._yen_np,
            inp=[inputs_float],
            Tout=tf.float32
        )
        outputs.set_shape(inputs_float.get_shape()[:-1].concatenate([1]))
        return outputs

    def get_config(self):
        return super(YenThresholdLayer, self).get_config()

class LiThresholdLayer(Layer):
    def __init__(self, **kwargs):
        super(LiThresholdLayer, self).__init__(**kwargs)

    def _li_np(self, inputs_np):
        return apply_threshold_to_batch(inputs_np, filters.threshold_li)

    def call(self, inputs):
        inputs_float = tf.cast(inputs, tf.float32)
        outputs = tf.py_function(
            func=self._li_np,
            inp=[inputs_float],
            Tout=tf.float32
        )
        outputs.set_shape(inputs_float.get_shape()[:-1].concatenate([1]))
        return outputs

    def get_config(self):
        return super(LiThresholdLayer, self).get_config()

# Placeholder for predict and evaluate functions from predict_model.py
def predict(model, input_images, n_classes):
    predictions = model.predict(input_images)
    # Assuming predictions are (batch, H, W, n_classes) and need to be converted to class indices
    predicted_classes = np.argmax(predictions, axis=-1)
    return predicted_classes

def evaluate(model, test_inputs, test_labels, log_dir):
    loss, accuracy, mean_iou = model.evaluate(test_inputs, test_labels, verbose=1)
    print(f"Test Loss: {loss:.4f}")
    print(f"Test Accuracy: {accuracy:.4f}")
    print(f"Test MeanIoU: {mean_iou:.4f}")
    return loss, accuracy, mean_iou

In [4]:
!pip install tensorflow keras opencv-python tensorflow_io



In [5]:
import tensorflow as tf
from keras.models import Model
from keras.optimizers import Adam
from keras import metrics
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import cv2
from datetime import datetime
import numpy as np
import os
from matplotlib import pyplot as plt

In [6]:
class SelectChannelLayer(Layer):
    def __init__(self, channel_idx, **kwargs):
        super(SelectChannelLayer, self).__init__(**kwargs)
        self.channel_idx = channel_idx

    def call(self, inputs):
        return inputs[..., self.channel_idx]

    def get_config(self):
        config = super(SelectChannelLayer, self).get_config()
        config.update({"channel_idx": self.channel_idx})
        return config

In [7]:
class ExpandDimsLayer(Layer):
    def __init__(self, axis, **kwargs):
        super(ExpandDimsLayer, self).__init__(**kwargs)
        self.axis = axis

    def call(self, inputs):
        return tf.expand_dims(inputs, axis=self.axis)

    def get_config(self):
        config = super(ExpandDimsLayer, self).get_config()
        config.update({"axis": self.axis})
        return config

In [8]:
class RGBToHSVSkimageLayer(Layer):
    def __init__(self, **kwargs):
        super(RGBToHSVSkimageLayer, self).__init__(**kwargs)

    def _rgb_to_hsv_np(self, inputs_np):
        # skimage.color.rgb2hsv expects float values in range [0, 1]
        return color.rgb2hsv(inputs_np)

    def call(self, inputs):
        # Inputs are expected to be in range [0, 255] or normalized
        # Ensure float32 for skimage
        inputs_float = tf.cast(inputs, tf.float32)
        outputs = tf.py_function(
            func=self._rgb_to_hsv_np,
            inp=[inputs_float],
            Tout=tf.float32
        )
        outputs.set_shape(inputs_float.get_shape()[:-1] + [3]) # HSV has 3 channels
        return outputs

    def get_config(self):
        config = super(RGBToHSVSkimageLayer, self).get_config()
        return config

In [9]:
class RGBToCIELUVSkimageLayer(Layer):
    def __init__(self, **kwargs):
        super(RGBToCIELUVSkimageLayer, self).__init__(**kwargs)

    def _rgb_to_luv_np(self, inputs_np):
        # skimage.color.rgb2luv expects float values in range [0, 1]
        return color.rgb2luv(inputs_np)

    def call(self, inputs):
        # Inputs are expected to be in range [0, 255] or normalized
        # Ensure float32 for skimage
        inputs_float = tf.cast(inputs, tf.float32)
        outputs = tf.py_function(
            func=self._rgb_to_luv_np,
            inp=[inputs_float],
            Tout=tf.float32
        )
        outputs.set_shape(inputs_float.get_shape()[:-1] + [3]) # CIELUV has 3 channels
        return outputs

    def get_config(self):
        config = super(RGBToCIELUVSkimageLayer, self).get_config()
        return config

In [10]:
class RGBToCIELABSkimageLayer(Layer):
    def __init__(self, **kwargs):
        super(RGBToCIELABSkimageLayer, self).__init__(**kwargs)

    def _rgb_to_lab_np(self, inputs_np):
        # skimage.color.rgb2lab expects float values in range [0, 1]
        return color.rgb2lab(inputs_np)

    def call(self, inputs):
        # Inputs are expected to be in range [0, 255] or normalized
        # Ensure float32 for skimage
        inputs_float = tf.cast(inputs, tf.float32)
        outputs = tf.py_function(
            func=self._rgb_to_lab_np,
            inp=[inputs_float],
            Tout=tf.float32
        )
        outputs.set_shape(inputs_float.get_shape()[:-1] + [3]) # CIELAB has 3 channels
        return outputs

    def get_config(self):
        config = super(RGBToCIELABSkimageLayer, self).get_config()
        return config

In [11]:
class RGBToCMYKLayer(Layer):
    def __init__(self, **kwargs):
        super(RGBToCMYKLayer, self).__init__(**kwargs)

    def _rgb_to_cmyk_np(self, inputs_np):
        # Convert EagerTensor to numpy array if it's not already one
        if not isinstance(inputs_np, np.ndarray):
            inputs_np = inputs_np.numpy()

        # skimage.color.rgb2cmyk expects float values in range [0, 1]
        # Create float
        bgr = inputs_np.astype(float)/255.

        # Extract channels
        with np.errstate(invalid='ignore', divide='ignore'):
            K = 1 - np.max(bgr, axis=3)  # Max along the color channel axis (axis=3)
            C = (1-bgr[..., 2] - K)/(1-K)
            M = (1-bgr[..., 1] - K)/(1-K)
            Y = (1-bgr[..., 0] - K)/(1-K)

        # Convert the input BGR image to CMYK colorspace
        # Returns float values in range [0, 1] as expected by Tout=tf.float32
        CMYK = np.dstack((C, M, Y, K))
        return CMYK

    def call(self, inputs):
        # Inputs are expected to be in range [0, 255] or normalized
        # Ensure float32 for skimage
        inputs_float = tf.cast(inputs, tf.float32)
        outputs = tf.py_function(
            func=self._rgb_to_cmyk_np,
            inp=[inputs_float],
            Tout=tf.float32
        )
        outputs.set_shape(inputs_float.get_shape()[:-1] + [4]) # CMYK has 4 channels
        return outputs

    def get_config(self):
        config = super(RGBToCMYKLayer, self).get_config()
        return config

In [12]:
dataset_path = os.path.join('/content/drive/MyDrive/Colab Notebooks', 'aml','segmentation', 'original')
labels_path = os.path.join('/content/drive/MyDrive/Colab Notebooks', 'aml','segmentation', 'GT')
stop_val = 40
i = 0
sub_img_width = 128
sub_img_height = 128
epochs = 500
bs=32
n_classes = 3
model_id= datetime.now().strftime("%Y%m%d_%H%M%S")
seed = 2

In [13]:
from utils_own_model import *
train_inputs, train_labels, val_inputs, val_labels, test_inputs, test_labels, file_names_train, file_names_val, file_names_test = prepare_data(csv_path='/content/drive/MyDrive/Colab Notebooks/aml/segmentation/images1.csv',input_path=dataset_path,label_path=labels_path,WIDTH=sub_img_width,HEIGHT=sub_img_height)

train loaded
val loaded
test loaded
Training on 694 images and labels
Validation on 180 images and labels
Test on 134 images and labels


In [14]:
def pooling(layer):
    return MaxPooling2D(pool_size=(2, 2))(layer)

In [15]:
def dropout(layer):
    return SpatialDropout2D(0.2)(layer)

In [16]:
def build_block(input_layer, filters, norm=True, k=(3, 3)):
    layer = Conv2D(filters, kernel_size=k, padding='same', use_bias=not norm, kernel_initializer='glorot_normal')(input_layer)
    if norm:
        layer = BatchNormalization()(layer)
    layer = Activation('relu')(layer)
    return layer

In [17]:
def residual_block(input_tensor, num_filters, kernel_size=(3, 3), stride=(1, 1)):
    # First convolution
    x = Conv2D(num_filters, kernel_size=kernel_size, strides=stride, padding='same')(input_tensor)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    # Second convolution
    x = Conv2D(num_filters, kernel_size=kernel_size, padding='same')(x)
    x = BatchNormalization()(x)

    # Shortcut connection (identity mapping)
    if stride != (1, 1) or input_tensor.shape[-1] != num_filters:
        shortcut = Conv2D(num_filters, kernel_size=(1, 1), strides=stride, padding='same')(input_tensor)
        shortcut = BatchNormalization()(shortcut)
    else:
        shortcut = input_tensor

    # Add the shortcut to the main path
    x = Add()([x, shortcut])
    x = ReLU()(x)

    return x

In [25]:
def rgb_encoder_bloc(image_input, n_filters=64):

    conv_1 = build_block(image_input, n_filters) #(32, 32, 64)
    conv_2 = build_block(conv_1, n_filters) #(32, 32, 64)
    conv_3 = build_block(conv_2, n_filters) #(32, 32, 64)
    pool_1 = pooling(conv_3)#(16, 16, 64)
    drop_1 = dropout(pool_1)

    conv_4 = build_block(drop_1, n_filters * 2)#(16, 16, 64)
    conv_5 = build_block(conv_4, n_filters * 2)#(16, 16, 64)
    conv_6 = build_block(conv_5, n_filters * 2)#(16, 16, 64)
    pool_2 = pooling(conv_6)#(8, 8, 64)
    drop_2 = dropout(pool_2)

    conv_7 = build_block(drop_2, n_filters * 4)#(8, 8, 256)
    conv_8 = build_block(conv_7, n_filters * 4) #(8, 8, 256)
    conv_9 = build_block(conv_8, n_filters * 4) #(8, 8, 256)
    pool_3 = pooling(conv_9)#(4, 4, 256)
    drop_3 = dropout(pool_3)

    conv_10 = build_block(drop_3, n_filters * 8)#(4, 4, 512)
    conv_11 = build_block(conv_10, n_filters * 8) #(4, 4, 512)
    conv_12 = build_block(conv_11, n_filters * 8) #(4, 4, 512)
    pool_4 = pooling(conv_12)#(4, 4, 512)

    return  pool_4, conv_3, conv_6, conv_9, conv_12

In [26]:
def nuc_encoder_bloc(image_input, n_filters=64):
    green = ExpandDimsLayer(axis=-1)(image_input[...,1])

    # mask1, mask2, mask3 = NucleusExtractionLayer()(image_input)
    cmyk = RGBToCMYKLayer()(image_input)
    magenta = ExpandDimsLayer(axis=-1)(SelectChannelLayer(channel_idx=1)(cmyk))

    lab = RGBToCIELABSkimageLayer()(image_input)
    luminance_lab = ExpandDimsLayer(axis=-1)(SelectChannelLayer(channel_idx=0)(lab))
    a = ExpandDimsLayer(axis=-1)(SelectChannelLayer(channel_idx=1)(lab))

    luv = RGBToCIELUVSkimageLayer()(image_input)
    luminance_luv = ExpandDimsLayer(axis=-1)(SelectChannelLayer(channel_idx=0)(luv))

    hsv = RGBToHSVSkimageLayer()(image_input)
    hue = ExpandDimsLayer(axis=-1)(SelectChannelLayer(channel_idx=0)(hsv))
    saturation = ExpandDimsLayer(axis=-1)(SelectChannelLayer(channel_idx=1)(hsv)) # Corrected from lab to hsv

    color_channels = concatenate([magenta, luminance_lab, a, hue, saturation, luminance_luv], axis=3)

    # Apply Otsu Thresholding to the v channel
    sat_otsu = OtsuThresholdLayer()(saturation)

    resid1 = residual_block(input_tensor=concatenate([color_channels, sat_otsu], axis=3), num_filters=n_filters)
    resid1 = pooling(resid1)#(16, 16, 64)

    conv_1 = build_block(color_channels, n_filters) #(32, 32, 64)
    conv_2 = build_block(conv_1, n_filters) #(32, 32, 64)
    conv_3 = build_block(conv_2, n_filters) #(32, 32, 64)
    pool_1 = pooling(conv_3)#(16, 16, 64)
    drop_1 = dropout(pool_1)
    output_tensor1 = Add()([resid1, drop_1])

    green_otsu = MinimumThresholdLayer()(green)
    green_otsu_pooled = pooling(green_otsu)

    resid2 = residual_block(input_tensor=concatenate([output_tensor1, green_otsu_pooled], axis=3), num_filters=n_filters*2)
    resid2 = pooling(resid2)#(16, 16, 64)

    conv_4 = build_block(output_tensor1, n_filters * 2)#(16, 16, 64)
    conv_5 = build_block(conv_4, n_filters * 2)#(16, 16, 64)
    conv_6 = build_block(conv_5, n_filters * 2)#(16, 16, 64)
    pool_2 = pooling(conv_6)#(8, 8, 64)
    drop_2 = dropout(pool_2)
    output_tensor2 = Add()([resid2, drop_2])

    magenta_isodata = IsodataThresholdLayer()(magenta)
    magenta_isodata_pooled = pooling(magenta_isodata)
    magenta_isodata_pooled = pooling(magenta_isodata_pooled)

    resid3 = residual_block(input_tensor=concatenate([output_tensor2, magenta_isodata_pooled], axis=3), num_filters=n_filters*4)
    resid3 = pooling(resid3)#(16, 16, 64)

    conv_7 = build_block(output_tensor2, n_filters * 4)#(8, 8, 256)
    conv_8 = build_block(conv_7, n_filters * 4) #(8, 8, 256)
    conv_9 = build_block(conv_8, n_filters * 4) #(8, 8, 256)
    pool_3 = pooling(conv_9)#(4, 4, 256)
    drop_3 = dropout(pool_3)
    output_tensor3 = Add()([resid3, drop_3])

    conv_10 = build_block(output_tensor3, n_filters * 8)#(4, 4, 512)
    conv_11 = build_block(conv_10, n_filters * 8) #(4, 4, 512)
    conv_12 = build_block(conv_11, n_filters * 8) #(4, 4, 512)
    pool_4 = pooling(conv_12)#(4, 4, 512)

    return  pool_4, conv_3, conv_6, conv_9, conv_12

In [27]:
def leukocyte_encoder_bloc(image_input, n_filters=64):
    green = ExpandDimsLayer(axis=-1)(image_input[...,1])

    # mask1, mask2, mask3 = NucleusExtractionLayer()(image_input)
    cmyk = RGBToCMYKLayer()(image_input)
    cyan = ExpandDimsLayer(axis=-1)(SelectChannelLayer(channel_idx=0)(cmyk))
    yellow = ExpandDimsLayer(axis=-1)(SelectChannelLayer(channel_idx=2)(cmyk))

    lab = RGBToCIELABSkimageLayer()(image_input)
    # luminance_lab = ExpandDimsLayer(axis=-1)(SelectChannelLayer(channel_idx=0)(lab))
    b = ExpandDimsLayer(axis=-1)(SelectChannelLayer(channel_idx=2)(lab))

    luv = RGBToCIELUVSkimageLayer()(image_input)
    v = ExpandDimsLayer(axis=-1)(SelectChannelLayer(channel_idx=2)(luv))

    hsv = RGBToHSVSkimageLayer()(image_input)
    hue = ExpandDimsLayer(axis=-1)(SelectChannelLayer(channel_idx=0)(hsv))
    saturation = ExpandDimsLayer(axis=-1)(SelectChannelLayer(channel_idx=1)(hsv)) # Corrected from lab to hsv

    color_channels = concatenate([cyan, yellow, b, hue, saturation, v], axis=3)

    # Apply Otsu Thresholding to the v channel
    yellow_li = LiThresholdLayer()(yellow)

    resid1 = residual_block(input_tensor=concatenate([color_channels, yellow_li], axis=3), num_filters=n_filters)
    resid1 = pooling(resid1)#(16, 16, 64)

    conv_1 = build_block(color_channels, n_filters) #(32, 32, 64)
    conv_2 = build_block(conv_1, n_filters) #(32, 32, 64)
    conv_3 = build_block(conv_2, n_filters) #(32, 32, 64)
    pool_1 = pooling(conv_3)#(16, 16, 64)
    drop_1 = dropout(pool_1)
    output_tensor1 = Add()([resid1, drop_1])

    v_otsu = OtsuThresholdLayer()(v)
    v_otsu_pooled = pooling(v_otsu)

    resid2 = residual_block(input_tensor=concatenate([output_tensor1, v_otsu_pooled], axis=3), num_filters=n_filters*2)
    resid2 = pooling(resid2)#(16, 16, 64)

    conv_4 = build_block(output_tensor1, n_filters * 2)#(16, 16, 64)
    conv_5 = build_block(conv_4, n_filters * 2)#(16, 16, 64)
    conv_6 = build_block(conv_5, n_filters * 2)#(16, 16, 64)
    pool_2 = pooling(conv_6)#(8, 8, 64)
    drop_2 = dropout(pool_2)
    output_tensor2 = Add()([resid2, drop_2])

    v_yen = YenThresholdLayer()(v)
    v_yen_pooled = pooling(v_yen)
    v_yen_pooled = pooling(v_yen_pooled)

    resid3 = residual_block(input_tensor=concatenate([output_tensor2, v_yen_pooled], axis=3), num_filters=n_filters*4)
    resid3 = pooling(resid3)#(16, 16, 64)

    conv_7 = build_block(output_tensor2, n_filters * 4)#(8, 8, 256)
    conv_8 = build_block(conv_7, n_filters * 4) #(8, 8, 256)
    conv_9 = build_block(conv_8, n_filters * 4) #(8, 8, 256)
    pool_3 = pooling(conv_9)#(4, 4, 256)
    drop_3 = dropout(pool_3)
    output_tensor3 = Add()([resid3, drop_3])

    conv_10 = build_block(output_tensor3, n_filters * 8)#(4, 4, 512)
    conv_11 = build_block(conv_10, n_filters * 8) #(4, 4, 512)
    conv_12 = build_block(conv_11, n_filters * 8) #(4, 4, 512)
    pool_4 = pooling(conv_12)#(4, 4, 512)

    return  pool_4, conv_3, conv_6, conv_9, conv_12

In [28]:
def decoder(n_classes=4, n_filters=64, conv_2=None, conv_4=None, conv_6=None, conv_8=None, conv_10=None):
    upsp_1 = UpSampling2D(size=(2, 2))(conv_10) #(-1, 8, 8, 64)
    upsp_1 = concatenate([upsp_1, conv_8]) #(-1, 8, 8, 192)
    conv_11 = build_block(upsp_1, n_filters * 8) #(-1, 8, 8, 64)
    conv_12 = build_block(conv_11, n_filters * 8)
    conv_13 = build_block(conv_12, n_filters * 8)
    drop_7 = dropout(conv_13)

    upsp_3 = UpSampling2D(size=(2, 2))(drop_7) #(-1, 16, 16, 64)
    upsp_3 = concatenate([upsp_3, conv_6]) #(-1, 16, 16, 192)
    conv_14 = build_block(upsp_3, n_filters * 4) #(-1, 16, 16, 64)
    conv_15 = build_block(conv_14, n_filters * 4)
    conv_16 = build_block(conv_15, n_filters * 4)
    drop_8 = dropout(conv_16)

    upsp_4 = UpSampling2D(size=(2, 2))(drop_8) #(-1, 32, 32, 64)
    upsp_4 = concatenate([upsp_4, conv_4])#(-1, 32, 32, 92)
    conv_17 = build_block(upsp_4, n_filters * 2)#(-1, 32, 32, 32)
    conv_18 = build_block(conv_17, n_filters * 2)
    conv_19 = build_block(conv_18, n_filters * 2)
    drop_9 = dropout(conv_19)

    upsp_4 = UpSampling2D(size=(2, 2))(conv_18) #(-1, 32, 32, 64)
    upsp_4 = concatenate([upsp_4, conv_2])#(-1, 32, 32, 92)
    conv_20 = build_block(upsp_4, n_filters)#(-1, 32, 32, 32)
    conv_21 = build_block(conv_20, n_filters)
    conv_22 = build_block(conv_21, n_filters)
    drop_10 = dropout(conv_22)

    output = Conv2D(n_classes, (1, 1), kernel_initializer='glorot_normal', activation='softmax')(drop_10)
    return output

In [33]:
def build_unet(input_shape=(128, 128, 3), n_filters=64, n_classes=4):
    inputs = Input(input_shape)
    module_1, conv_2_1, conv_4_1, conv_6_1, conv_8_1 = rgb_encoder_bloc(image_input=inputs, n_filters=n_filters)
    module_2, conv_2_2, conv_4_2, conv_6_2, conv_8_2 = nuc_encoder_bloc(image_input=inputs, n_filters=n_filters)
    module_3, conv_2_3, conv_4_3, conv_6_3, conv_8_3 = leukocyte_encoder_bloc(image_input=inputs, n_filters=n_filters)
    encoder = concatenate([module_1, module_2, module_3])
    conv_9 = build_block(encoder, n_filters * 16)#(4, 4, 64)
    conv_10 = build_block(conv_9, n_filters * 16) #(4, 4, 64)
    output = decoder(conv_2=concatenate([conv_2_1, conv_2_2, conv_2_3]),
                     conv_4=concatenate([conv_4_1, conv_4_2, conv_4_3]),
                     conv_6=concatenate([conv_6_1, conv_6_2, conv_6_3]),
                     conv_8=concatenate([conv_8_1, conv_8_2, conv_8_3]),
                     conv_10=conv_10, n_classes=n_classes, n_filters=n_filters)
    model = Model(inputs=inputs, outputs=output)
    learning_rate = 1e-4
    optimizer = Adam(learning_rate=learning_rate)

    model.compile(
        loss="categorical_crossentropy",
        optimizer=optimizer,
        metrics=['accuracy', metrics.MeanIoU(num_classes=4)]
    )
    model.summary()

    return model

In [34]:
stopper = EarlyStopping(monitor = 'val_accuracy',
                            min_delta=0,
                            patience=stop_val,
                            verbose=1,
                            mode='max')


model_filename = 'tf_model_cyto_'+model_id+'.h5'

input_shape = (sub_img_width, sub_img_height, 3)

result_path = os.path.join('/content/drive/MyDrive/Colab Notebooks', 'results', 'Unet 4_1', model_id)
if os.path.exists(result_path) == False:
    os.makedirs(result_path)

In [35]:
final_model = build_unet(input_shape=input_shape, n_classes=n_classes)
final_model.summary()

In [36]:
checkpoint = ModelCheckpoint(os.path.join(result_path, model_filename),
                            monitor='val_accuracy',
                            verbose=1,
                            save_best_only= True,
                            mode='max')
datagen = ImageDataGenerator(horizontal_flip=True, vertical_flip=True, width_shift_range=0.1, zoom_range=[0.5,1.0])
datagen.fit(train_inputs, augment=True, seed=seed)

generator = datagen.flow(train_inputs, train_labels, batch_size=bs)
validation_data = ImageDataGenerator().flow(x=val_inputs, y=val_labels, batch_size=1)
log_dir = os.path.join(".",result_path, 'Graph', 'Adam')
tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1, write_graph=True, write_images=True)

callback_list = [checkpoint, tensorboard_callback, stopper]

/usr/local/lib/python3.12/dist-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py:121: UserWarning: Your `PyDataset` class should call `super().__init__(**kwargs)` in its constructor. `**kwargs` can include `workers`, `use_multiprocessing`, `max_queue_size`. Do not pass these arguments to `fit()`, as they will be ignored.
  self._warn_if_super_not_called()

In [38]:
history = final_model.fit(
  x=generator,
  steps_per_epoch= int(np.ceil(train_labels.shape[0]/bs)),
  epochs=epochs,
  validation_data=validation_data,
  validation_steps= int(np.ceil(val_labels.shape[0]/bs)),
  verbose=1,
  callbacks = callback_list
)
# list all data in history
print(history.history.keys())

  self._warn_if_super_not_called()


Epoch 1/500


InvalidArgumentError: Graph execution error:

Detected at node functional_1/concatenate_4_1/concat defined at (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main

  File "<frozen runpy>", line 88, in _run_code

  File "/usr/local/lib/python3.12/dist-packages/colab_kernel_launcher.py", line 37, in <module>

  File "/usr/local/lib/python3.12/dist-packages/traitlets/config/application.py", line 992, in launch_instance

  File "/usr/local/lib/python3.12/dist-packages/ipykernel/kernelapp.py", line 712, in start

  File "/usr/local/lib/python3.12/dist-packages/tornado/platform/asyncio.py", line 211, in start

  File "/usr/lib/python3.12/asyncio/base_events.py", line 645, in run_forever

  File "/usr/lib/python3.12/asyncio/base_events.py", line 1999, in _run_once

  File "/usr/lib/python3.12/asyncio/events.py", line 88, in _run

  File "/usr/local/lib/python3.12/dist-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue

  File "/usr/local/lib/python3.12/dist-packages/ipykernel/kernelbase.py", line 499, in process_one

  File "/usr/local/lib/python3.12/dist-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell

  File "/usr/local/lib/python3.12/dist-packages/ipykernel/kernelbase.py", line 730, in execute_request

  File "/usr/local/lib/python3.12/dist-packages/ipykernel/ipkernel.py", line 383, in do_execute

  File "/usr/local/lib/python3.12/dist-packages/ipykernel/zmqshell.py", line 528, in run_cell

  File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 2975, in run_cell

  File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3030, in _run_cell

  File "/usr/local/lib/python3.12/dist-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner

  File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async

  File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3473, in run_ast_nodes

  File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code

  File "/tmp/ipython-input-656844309.py", line 1, in <cell line: 0>

  File "/usr/local/lib/python3.12/dist-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler

  File "/usr/local/lib/python3.12/dist-packages/keras/src/backend/tensorflow/trainer.py", line 377, in fit

  File "/usr/local/lib/python3.12/dist-packages/keras/src/backend/tensorflow/trainer.py", line 220, in function

  File "/usr/local/lib/python3.12/dist-packages/keras/src/backend/tensorflow/trainer.py", line 133, in multi_step_on_iterator

  File "/usr/local/lib/python3.12/dist-packages/keras/src/backend/tensorflow/trainer.py", line 114, in one_step_on_data

  File "/usr/local/lib/python3.12/dist-packages/keras/src/backend/tensorflow/trainer.py", line 58, in train_step

  File "/usr/local/lib/python3.12/dist-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler

  File "/usr/local/lib/python3.12/dist-packages/keras/src/layers/layer.py", line 936, in __call__

  File "/usr/local/lib/python3.12/dist-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler

  File "/usr/local/lib/python3.12/dist-packages/keras/src/ops/operation.py", line 58, in __call__

  File "/usr/local/lib/python3.12/dist-packages/keras/src/utils/traceback_utils.py", line 156, in error_handler

  File "/usr/local/lib/python3.12/dist-packages/keras/src/models/functional.py", line 183, in call

  File "/usr/local/lib/python3.12/dist-packages/keras/src/ops/function.py", line 177, in _run_through_graph

  File "/usr/local/lib/python3.12/dist-packages/keras/src/models/functional.py", line 648, in call

  File "/usr/local/lib/python3.12/dist-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler

  File "/usr/local/lib/python3.12/dist-packages/keras/src/layers/layer.py", line 936, in __call__

  File "/usr/local/lib/python3.12/dist-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler

  File "/usr/local/lib/python3.12/dist-packages/keras/src/ops/operation.py", line 58, in __call__

  File "/usr/local/lib/python3.12/dist-packages/keras/src/utils/traceback_utils.py", line 156, in error_handler

  File "/usr/local/lib/python3.12/dist-packages/keras/src/layers/merging/base_merge.py", line 225, in call

  File "/usr/local/lib/python3.12/dist-packages/keras/src/layers/merging/concatenate.py", line 102, in _merge_function

  File "/usr/local/lib/python3.12/dist-packages/keras/src/ops/numpy.py", line 1846, in concatenate

  File "/usr/local/lib/python3.12/dist-packages/keras/src/backend/tensorflow/numpy.py", line 1119, in concatenate

ConcatOp : Expected concatenating dimensions in the range [-3, 3), but got 3
	 [[{{node functional_1/concatenate_4_1/concat}}]] [Op:__inference_multi_step_on_iterator_49940]

In [None]:
# summarize history for accuracy
plt.ioff()
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.savefig(os.path.join(result_path, 'accuracy.jpg'))
plt.close()
# summarize history for loss
plt.ioff()
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.savefig(os.path.join(result_path, 'loss.jpg'))
plt.close()
final_model.save(os.path.join(result_path, model_filename))

In [None]:
print('evaluate model')
evaluate(final_model, test_inputs, test_labels, log_dir=log_dir)

In [None]:
# prediction = predict(final_model, val_inputs, n_classes)
# plot_prediction(prediction, test_path, (sub_img_width, sub_img_height), result_path, dataset_path, labels_path)

In [None]:
size = (128, 128)
path=['BAS_0001.tiff']
path2=['BAS_0001.jpg']
for i in range(len(path)):
    file_name = path[i]
    img = cv2.imread(os.path.join(dataset_path, file_name[:3], file_name))
    img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, size, interpolation = cv2.INTER_LINEAR)
    prediction = predict(final_model, np.expand_dims(img, axis=0), n_classes) # Explicitly add batch dimension
    y = cv2.imread(os.path.join(labels_path, file_name[:3], path2[i]))
    y = cv2.cvtColor(y,cv2.COLOR_BGR2RGB)
    y = cv2.resize(y, size, interpolation = cv2.INTER_LINEAR)
    plt.ioff()
    plt.figure(figsize=(16, 4), dpi = 100)
    plt.subplot(1, 3, 1)
    plt.imshow(img, cmap='gray')
    plt.xlabel('original', fontsize=10)
    plt.yticks([])
    plt.xticks([])
    plt.subplot(1, 3, 2)
    plt.imshow(y, cmap='gray')
    plt.xlabel('Ground Truth', fontsize=10)
    plt.yticks([])
    plt.xticks([])
    plt.subplot(1, 3, 3)
    plt.imshow(prediction[0], cmap='gray') # Access the first (and only) image in the batch
    plt.xlabel('Segmentation', fontsize=10)
    plt.yticks([])
    plt.xticks([])
    plt.show()
    # plt.savefig(os.path.join(result_dir, path[i]))
    # plt.close()