<a href="https://colab.research.google.com/github/MahiKhan5360/Segmentation-using-Capsule-layers-and-CNN/blob/main/model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Model definition: capsule-based CNN for binary image segmentation

In [None]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, BatchNormalization, UpSampling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.metrics import MeanIoU


In [None]:
# Custom Capsule Layer Class
def create_capsule_layer():
    """Build and return the ``CapsuleLayer`` class.

    The class is defined inside this factory, so every call returns a new
    class object. Callers instantiate the returned class like any other
    Keras layer.

    Returns:
        The ``CapsuleLayer`` class (not an instance).
    """

    class CapsuleLayer(tf.keras.layers.Layer):
        """Per-pixel capsule layer with routing-by-agreement.

        For every spatial position the channel vector is projected into
        ``num_capsules`` vectors of size ``capsule_dim``. These are combined
        into a single squashed ``capsule_dim`` vector per position by
        iteratively re-weighting the capsules according to their agreement
        with the current output.

        Args:
            num_capsules: Number of capsule vectors produced per position.
            capsule_dim: Length of each capsule vector (and of the output's
                last axis).
            num_routing: Number of routing iterations; must be at least 1.
            **kwargs: Forwarded to ``tf.keras.layers.Layer``.

        Raises:
            ValueError: If ``num_routing`` is smaller than 1.
        """

        def __init__(self, num_capsules, capsule_dim, num_routing=3, **kwargs):
            super().__init__(**kwargs)
            # With zero iterations the routing loop never assigns the output
            # vector, so reject that configuration up front.
            if num_routing < 1:
                raise ValueError(
                    f"num_routing must be >= 1, got {num_routing}"
                )
            self.num_capsules = num_capsules
            self.capsule_dim = capsule_dim
            self.num_routing = num_routing

        def build(self, input_shape):
            """Record the static input shape and create the projection matrix."""
            self.input_height, self.input_width, self.input_channels = input_shape[1:4]
            # One matrix shared by all spatial positions: maps the channel
            # axis to the flattened (num_capsules * capsule_dim) axis.
            self.W = self.add_weight(
                shape=[self.input_channels, self.num_capsules * self.capsule_dim],
                initializer='glorot_uniform',
                trainable=True
            )

        def call(self, inputs):
            """Run the capsule transform and routing.

            Args:
                inputs: 4-D tensor ``(batch, height, width, channels)`` whose
                    static height, width and channels match those seen in
                    ``build``.

            Returns:
                Tensor of shape ``(batch, height, width, capsule_dim)``.
            """
            _, height, width, channels = inputs.shape
            tf.debugging.assert_equal(height, self.input_height, message="Height mismatch")
            tf.debugging.assert_equal(width, self.input_width, message="Width mismatch")
            tf.debugging.assert_equal(channels, self.input_channels, message="Channels mismatch")

            activated = tf.nn.relu(inputs)

            # Project every position's channel vector, then split the result
            # into num_capsules vectors of length capsule_dim.
            u_hat = tf.tensordot(activated, self.W, axes=[[3], [0]])
            u_hat = tf.reshape(
                u_hat, (-1, height, width, self.num_capsules, self.capsule_dim)
            )

            # Routing logits start at zero (uniform coupling). Match the dtype
            # of u_hat so the in-place update below never mixes dtypes.
            b = tf.zeros(
                shape=(tf.shape(inputs)[0], height, width, self.num_capsules),
                dtype=u_hat.dtype,
            )

            for i in range(self.num_routing):
                c = tf.nn.softmax(b, axis=3)  # coupling over the capsule axis
                s = tf.reduce_sum(c[..., None] * u_hat, axis=3)  # weighted sum of capsules
                v = self.squash(s)
                if i < self.num_routing - 1:
                    # Agreement: dot product of each capsule with the output.
                    b += tf.reduce_sum(u_hat * v[..., None, :], axis=4)

            return v

        def squash(self, s, axis=-1):
            """Scale ``s`` along ``axis`` so its norm is ``|s|^2 / (1 + |s|^2)``."""
            s_squared_norm = tf.reduce_sum(tf.square(s), axis=axis, keepdims=True)
            scale = s_squared_norm / (1 + s_squared_norm)
            # The small constant keeps the division finite for zero vectors.
            return scale * s / tf.sqrt(s_squared_norm + 1e-9)

        def get_config(self):
            """Return the constructor arguments so the layer can be serialised."""
            config = super().get_config()
            config.update({
                "num_capsules": self.num_capsules,
                "capsule_dim": self.capsule_dim,
                "num_routing": self.num_routing,
            })
            return config

    return CapsuleLayer

# Model Creation for Binary Image Segmentation
def create_capsule_segmentation_model(input_shape=(256, 256, 3), num_capsules=8, capsule_dim=16):
    """Build and compile the capsule/CNN network for binary segmentation.

    The network is two convolutional encoder blocks (each halving the
    spatial size with 2x2 max pooling), one capsule layer, and two
    upsample-plus-convolution decoder blocks (each doubling the spatial
    size), followed by a 1x1 sigmoid convolution that yields one
    probability per pixel.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.
        num_capsules: Number of capsules per position in the capsule layer.
        capsule_dim: Length of each capsule vector.

    Returns:
        A compiled ``Model`` using Adam (learning rate 0.0005), binary
        cross-entropy loss, and accuracy plus IoU metrics.
    """
    CapsuleLayer = create_capsule_layer()

    # Input layer
    input_layer = Input(shape=input_shape)

    # Encoder: Convolutional block 1
    x = Conv2D(filters=64, kernel_size=3, activation='relu', padding='same', kernel_regularizer=l2(0.0002))(input_layer)
    x = Conv2D(filters=64, kernel_size=3, activation='relu', padding='same')(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Dropout(0.3)(x)

    # Encoder: Convolutional block 2
    x = Conv2D(filters=128, kernel_size=3, activation='relu', padding='same', kernel_regularizer=l2(0.0002))(x)
    x = Conv2D(filters=128, kernel_size=3, activation='relu', padding='same')(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Dropout(0.3)(x)

    # Capsule layer
    x = CapsuleLayer(num_capsules=num_capsules, capsule_dim=capsule_dim)(x)

    # Decoder: Upsampling and convolutional layers
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(filters=128, kernel_size=3, activation='relu', padding='same', kernel_regularizer=l2(0.0002))(x)
    x = BatchNormalization()(x)
    x = Dropout(0.3)(x)

    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(filters=64, kernel_size=3, activation='relu', padding='same', kernel_regularizer=l2(0.0002))(x)
    x = BatchNormalization()(x)
    x = Dropout(0.3)(x)

    # Output layer: Single channel with sigmoid for binary segmentation
    output_layer = Conv2D(filters=1, kernel_size=1, activation='sigmoid', padding='same')(x)

    # Create model
    model = Model(inputs=input_layer, outputs=output_layer)

    # The output is a sigmoid probability, so the IoU metric must threshold
    # it before counting classes. MeanIoU expects integer class labels and
    # would treat nearly every probability as class 0. BinaryIoU over both
    # class ids gives the mean IoU of background and lesion at a 0.5 cut-off.
    # The name is pinned so logged metric keys stay "mean_io_u".
    mean_iou = tf.keras.metrics.BinaryIoU(
        target_class_ids=[0, 1],
        threshold=0.5,
        name='mean_io_u',
    )

    # Compile model with additional IoU metric
    model.compile(
        optimizer=Adam(learning_rate=0.0005),
        loss='binary_crossentropy',
        metrics=['accuracy', mean_iou]  # Binary segmentation: background + lesion
    )

    return model

if __name__ == "__main__":
    # Build the compiled network for 256x256 RGB inputs and print its layers.
    model = create_capsule_segmentation_model(input_shape=(256, 256, 3))
    model.summary()

    # Callbacks, all watching validation loss: stop early and restore the
    # best weights, keep only the best checkpoint on disk, and shrink the
    # learning rate when progress stalls.
    callbacks = [
        EarlyStopping(
            monitor='val_loss',
            patience=10,
            restore_best_weights=True,
        ),
        ModelCheckpoint(
            filepath='/content/drive/MyDrive/ISIC2018/best_capsule_model.keras',
            monitor='val_loss',
            save_best_only=True,
        ),
        ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.2,
            patience=5,
            min_lr=1e-7,
        ),
    ]