In [None]:
from tensorflow.keras import backend as K
from tensorflow.keras.layers import (
    Input, Conv2D, MaxPooling2D, Conv2DTranspose,
    BatchNormalization, Activation, SpatialDropout2D, concatenate
)
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
from typing import Tuple




def enhanced_unet(input_shape=(256, 256, 3), num_classes=6, n_filters=64) -> Model:
    """Builds an enhanced U-Net with batch norm, spatial dropout and L2 regularisation.

    The network has four encoder stages, a bottleneck and four decoder stages.
    The filter count doubles at every encoder stage (from ``n_filters`` up to
    ``n_filters * 16`` in the bottleneck) and is mirrored back down in the
    decoder.

    Args:
        input_shape (tuple): Shape passed to ``Input``, e.g. ``(256, 256, 3)``.
            Four 2x2 poolings are undone by four stride-2 transposed
            convolutions, so height and width should be divisible by 16 for
            the skip concatenations to line up.
        num_classes (int): Number of output classes, i.e. channels of the
            final softmax layer.
        n_filters (int): Number of filters in the first encoder stage.
            Defaults to 64, the value that used to be hard-coded.

    Returns:
        tf.keras.Model: The functional model. It is NOT compiled; call
        ``model.compile(...)`` before training.
    """
    # NOTE(review): this file defines `enhanced_unet` a second time further
    # down; once the whole file has executed, that later definition is the one
    # bound to the name.

    def conv_block(x, filters, dropout_rate):
        """Applies two Conv-BatchNorm-ReLU stages followed by spatial dropout.

        Returns:
            tuple: ``(block_output, block_input)``. The second element is the
            tensor the decoder concatenates as its skip connection.
        """
        # NOTE(review): the skip tensor is the block *input* (the pooled output
        # of the previous stage, or the raw image for the first stage), not the
        # convolved features as in the classic U-Net. Kept as-is because
        # changing it alters weight shapes - confirm this is intended.
        x_skip = x
        for _ in range(2):
            x = Conv2D(filters, (3, 3), padding="same", kernel_initializer="he_normal",
                       kernel_regularizer=l2(1e-4))(x)
            x = BatchNormalization()(x)
            x = Activation("relu")(x)
        x = SpatialDropout2D(dropout_rate)(x)
        return x, x_skip

    def decoder_block(x, skip, filters, dropout_rate):
        """Upsamples by 2, concatenates the skip tensor, then applies conv_block."""
        x = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding="same",
                            kernel_regularizer=l2(1e-4))(x)
        x = concatenate([x, skip])
        # Same conv stack as the encoder; the returned skip tensor is unused here.
        x, _ = conv_block(x, filters, dropout_rate)
        return x

    inputs = Input(shape=input_shape)

    # --- Encoder: dropout grows with depth ---
    c1, s1 = conv_block(inputs, n_filters, dropout_rate=0.05)
    p1 = MaxPooling2D((2, 2))(c1)

    c2, s2 = conv_block(p1, n_filters * 2, dropout_rate=0.1)
    p2 = MaxPooling2D((2, 2))(c2)

    c3, s3 = conv_block(p2, n_filters * 4, dropout_rate=0.2)
    p3 = MaxPooling2D((2, 2))(c3)

    c4, s4 = conv_block(p3, n_filters * 8, dropout_rate=0.35)
    p4 = MaxPooling2D((2, 2))(c4)

    # --- Bottleneck: widest layer, heaviest dropout ---
    c5, _ = conv_block(p4, n_filters * 16, dropout_rate=0.5)

    # --- Decoder: dropout shrinks as resolution is restored ---
    u6 = decoder_block(c5, s4, n_filters * 8, dropout_rate=0.4)
    u7 = decoder_block(u6, s3, n_filters * 4, dropout_rate=0.3)
    u8 = decoder_block(u7, s2, n_filters * 2, dropout_rate=0.2)
    u9 = decoder_block(u8, s1, n_filters, dropout_rate=0.1)

    # 1x1 convolution maps features to per-pixel class probabilities; the
    # output dtype is pinned to float32 regardless of the global dtype policy.
    outputs = Conv2D(num_classes, (1, 1), activation="softmax", dtype="float32")(u9)

    return Model(inputs=inputs, outputs=outputs)






def enhanced_unet(input_shape: Tuple[int, int, int] = (256, 256, 3), num_classes: int = 6, n_filters: int = 32) -> Model:
    """Builds an enhanced U-Net architecture for semantic segmentation.

    This U-Net implementation includes several enhancements over a basic U-Net:
    -   Uses 'he_normal' kernel initializer for the 3x3 convolutions.
    -   Applies L2 regularization (1e-4) to convolutional kernels.
    -   Includes Batch Normalization after each 3x3 convolution and before
        its activation.
    -   Uses ReLU activation functions throughout the encoder and decoder.
    -   Incorporates Spatial Dropout 2D layers, with increasing dropout rates
        in deeper encoder layers and decreasing rates in shallower decoder
        layers.
    -   The final output layer uses a 1x1 convolution with softmax activation
        to produce class probabilities for each pixel.

    Args:
        input_shape: A `tuple` of three integers `(height, width, channels)`
            specifying the input shape of the image tiles. Defaults to
            (256, 256, 3). Four 2x2 poolings are undone by four stride-2
            transposed convolutions, so height and width should be divisible
            by 16 for the skip concatenations to line up.
        num_classes: An `int` specifying the number of output classes for
            segmentation. Defaults to 6.
        n_filters: An `int` specifying the base number of filters for the
            first convolutional block. It is multiplied by 2, 4, 8 and 16 in
            deeper stages. Defaults to 32.

    Returns:
        A `tf.keras.Model` instance representing the enhanced U-Net. The
        model is not compiled.
    """

    # Tensor arguments of the helpers are intentionally left unannotated:
    # annotations on nested functions are evaluated every time the outer
    # function runs, and `tf` is not imported in this module.
    def _conv_block(x, filters: int, dropout_rate: float):
        """Applies two Conv-BatchNorm-ReLU stages followed by spatial dropout.

        Args:
            x: The input Keras tensor.
            filters: Number of output channels of both 3x3 convolutions.
            dropout_rate: Rate passed to `SpatialDropout2D`.

        Returns:
            A `tuple` `(block_output, block_input)`. The second element is
            the unmodified input `x`, which the decoder uses as its skip
            tensor.
        """
        # NOTE(review): the skip tensor is the block *input* (the pooled output
        # of the previous stage, or the raw image for the first stage), not the
        # convolved features as in the classic U-Net. Kept as-is because
        # changing it alters weight shapes - confirm this is intended.
        x_skip = x

        for _ in range(2):
            x = Conv2D(
                filters,
                (3, 3),
                padding="same",
                kernel_initializer="he_normal",
                kernel_regularizer=l2(1e-4),
            )(x)
            x = BatchNormalization()(x)
            x = Activation("relu")(x)

        # Spatial dropout drops whole feature maps rather than single units.
        x = SpatialDropout2D(dropout_rate)(x)
        return x, x_skip

    def _decoder_block(x, skip_connection, filters: int, dropout_rate: float):
        """Upsamples by 2, concatenates the skip tensor, then applies `_conv_block`.

        Args:
            x: The Keras tensor from the previous (lower-resolution) stage.
            skip_connection: The encoder tensor to concatenate after
                upsampling; its spatial size must match the upsampled `x`.
            filters: Number of filters for the transposed convolution and for
                the convolutions of the internal conv block.
            dropout_rate: Rate passed to `SpatialDropout2D` in the conv block.

        Returns:
            The output Keras tensor of the decoder block.
        """
        # Stride-2 transposed convolution doubles height and width.
        x = Conv2DTranspose(
            filters,
            (2, 2),
            strides=(2, 2),
            padding="same",
            kernel_regularizer=l2(1e-4),
        )(x)

        x = concatenate([x, skip_connection])

        # Same conv stack as the encoder; the returned skip tensor is unused.
        x, _ = _conv_block(x, filters, dropout_rate)
        return x

    inputs = Input(shape=input_shape)

    # --- Encoder Path ---
    # Each stage is a conv block followed by 2x2 max pooling; dropout grows
    # with depth. `n_filters` comes from the caller and is not overridden.
    c1, s1 = _conv_block(inputs, n_filters, dropout_rate=0.05)
    p1 = MaxPooling2D((2, 2))(c1)

    c2, s2 = _conv_block(p1, n_filters * 2, dropout_rate=0.1)
    p2 = MaxPooling2D((2, 2))(c2)

    c3, s3 = _conv_block(p2, n_filters * 4, dropout_rate=0.2)
    p3 = MaxPooling2D((2, 2))(c3)

    c4, s4 = _conv_block(p3, n_filters * 8, dropout_rate=0.35)
    p4 = MaxPooling2D((2, 2))(c4)

    # --- Bottleneck Layer ---
    # Deepest stage: highest filter count and highest dropout rate.
    c5, _ = _conv_block(p4, n_filters * 16, dropout_rate=0.5)

    # --- Decoder Path ---
    # Dropout rates decrease as the resolution is restored.
    u6 = _decoder_block(c5, s4, n_filters * 8, dropout_rate=0.4)
    u7 = _decoder_block(u6, s3, n_filters * 4, dropout_rate=0.3)
    u8 = _decoder_block(u7, s2, n_filters * 2, dropout_rate=0.2)
    u9 = _decoder_block(u8, s1, n_filters, dropout_rate=0.1)

    # Output layer: 1x1 convolution mapping feature channels to `num_classes`
    # with softmax over classes. `dtype="float32"` pins the output dtype
    # regardless of the global dtype policy.
    outputs = Conv2D(num_classes, (1, 1), activation="softmax", dtype="float32")(u9)

    model = Model(inputs=[inputs], outputs=[outputs])
    return model

