<a href="https://colab.research.google.com/github/MahiKhan5360/Skin-Lesion-Segmentation-using-Capsule-Layer-and-CNN/blob/main/model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, BatchNormalization, UpSampling2D, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.metrics import MeanIoU
import gc

In [None]:
# Custom Capsule Layer Class
def create_capsule_layer():
    """Build and return the ``CapsuleLayer`` Keras layer class.

    The class is defined inside this factory, so each call returns a distinct
    class object. Instantiate the returned class to obtain a layer.

    Returns:
        A ``tf.keras.layers.Layer`` subclass that applies a per-location
        capsule projection followed by iterative routing.
    """
    class CapsuleLayer(tf.keras.layers.Layer):
        """Capsule layer applied independently at every spatial location.

        ``call`` expects a rank-4 input ``[batch, height, width, channels]``
        and returns ``[batch, height, width, capsule_dim]``.

        Args:
            num_capsules: Number of capsule prediction vectors per location.
            capsule_dim: Length of each capsule vector (output channel count).
            num_routing: Number of routing iterations; must be at least 1.
            **kwargs: Forwarded to ``tf.keras.layers.Layer``.

        Raises:
            ValueError: If ``num_routing`` is smaller than 1.
        """

        def __init__(self, num_capsules=4, capsule_dim=8, num_routing=2, **kwargs):
            super(CapsuleLayer, self).__init__(**kwargs)
            # With zero iterations the routing loop never assigns an output,
            # so fail here with a clear message instead of inside call().
            if num_routing < 1:
                raise ValueError(
                    f"num_routing must be >= 1, got {num_routing}"
                )
            self.num_capsules = num_capsules
            self.capsule_dim = capsule_dim
            self.num_routing = num_routing

        def build(self, input_shape):
            """Create the projection matrix from input channels to capsules."""
            self.input_channels = input_shape[-1]
            if self.input_channels is None:
                raise ValueError(
                    "The channel (last) dimension of the input to "
                    "CapsuleLayer must be defined."
                )

            # One shared matrix maps each location's channel vector to all
            # num_capsules * capsule_dim prediction components at once.
            self.W = self.add_weight(
                name="W",
                shape=[self.input_channels, self.num_capsules * self.capsule_dim],
                initializer='glorot_uniform',
                trainable=True
            )
            super().build(input_shape)

        def compute_output_shape(self, input_shape):
            """Return ``(batch, height, width, capsule_dim)``."""
            return (input_shape[0], input_shape[1], input_shape[2], self.capsule_dim)

        def call(self, inputs):
            """Project the inputs to capsule predictions and route them.

            Args:
                inputs: Tensor of shape ``[batch, height, width, channels]``.

            Returns:
                Tensor of shape ``[batch, height, width, capsule_dim]`` holding
                the squashed routed capsule vector for every location.
            """
            dynamic_shape = tf.shape(inputs)
            batch_size = dynamic_shape[0]
            # Prefer static spatial sizes so downstream shape inference is
            # unchanged; fall back to dynamic sizes for variable-size inputs,
            # where the static dimensions are None and reshape would fail.
            static_height, static_width = inputs.shape[1], inputs.shape[2]
            height = static_height if static_height is not None else dynamic_shape[1]
            width = static_width if static_width is not None else dynamic_shape[2]

            inputs_activated = tf.nn.relu(inputs)

            # Flatten all locations into rows, project, then restore layout.
            u_hat = tf.matmul(
                tf.reshape(inputs_activated, [-1, self.input_channels]),
                self.W
            )
            u_hat = tf.reshape(
                u_hat,
                [batch_size, height, width, self.num_capsules, self.capsule_dim]
            )

            # Routing logits start at zero. Deriving them from u_hat keeps the
            # dtype aligned with the layer's compute dtype (e.g. float16 under
            # a mixed-precision policy) instead of a hard-coded float32.
            b = tf.zeros_like(u_hat[..., 0])  # [batch, height, width, num_capsules]

            for i in range(self.num_routing):
                # Softmax over capsules dimension
                c = tf.nn.softmax(b, axis=-1)  # [batch, height, width, num_capsules]

                # Expand c to match u_hat dimensions for element-wise multiplication
                c_expanded = tf.expand_dims(c, axis=-1)  # [..., num_capsules, 1]

                # Weighted sum over the num_capsules axis
                s = tf.reduce_sum(c_expanded * u_hat, axis=3)  # [..., capsule_dim]

                # Apply squashing
                v = self.squash(s)

                # Update routing coefficients if not last iteration
                if i < self.num_routing - 1:
                    v_expanded = tf.expand_dims(v, axis=3)  # [..., 1, capsule_dim]

                    # Agreement: dot product between v and each u_hat
                    agreement = tf.reduce_sum(u_hat * v_expanded, axis=-1)
                    b = b + agreement

            return v

        def squash(self, s, axis=-1, epsilon=1e-9):
            """Scale ``s`` along ``axis`` to norm ``|s|^2 / (1 + |s|^2)``.

            ``epsilon`` guards both divisions against a zero norm.
            """
            s_squared_norm = tf.reduce_sum(tf.square(s), axis=axis, keepdims=True)
            scale = s_squared_norm / (1 + s_squared_norm + epsilon)
            return scale * s / tf.sqrt(s_squared_norm + epsilon)

        def get_config(self):
            """Return the constructor arguments so the layer can be re-created.

            NOTE: reloading a saved model presumably still requires passing
            this class through ``custom_objects`` -- confirm against the
            loading code.
            """
            config = super().get_config()
            config.update({
                'num_capsules': self.num_capsules,
                'capsule_dim': self.capsule_dim,
                'num_routing': self.num_routing,
            })
            return config

    return CapsuleLayer

# Memory-efficient Model Creation
def create_efficient_capsule_segmentation_model(input_shape=(256, 256, 3), num_capsules=4, capsule_dim=8):
    CapsuleLayer = create_capsule_layer()

    # Input layer
    input_layer = Input(shape=input_shape)


    # Encoder:
    # Block 1
    x = Conv2D(filters=32, kernel_size=3, activation='relu', padding='same', kernel_regularizer=l2(0.0001))(input_layer)
    x = BatchNormalization()(x)
    x = Conv2D(filters=32, kernel_size=3, activation='relu', padding='same')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Dropout(0.2)(x)  # Reduced dropout

    # Block 2
    x = Conv2D(filters=64, kernel_size=3, activation='relu', padding='same', kernel_regularizer=l2(0.0001))(x)
    x = BatchNormalization()(x)
    x = Conv2D(filters=64, kernel_size=3, activation='relu', padding='same')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Dropout(0.2)(x)

    # Block 3 - Additional downsampling to reduce spatial dimension
    x = Conv2D(filters=128, kernel_size=3, activation='relu', padding='same', kernel_regularizer=l2(0.0001))(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Dropout(0.2)(x)

    # Capsule layer with reduced parameters
    x = CapsuleLayer(num_capsules=num_capsules, capsule_dim=capsule_dim)(x)
