# In [None]:
from tensorflow.keras import backend as K
from tensorflow.keras.layers import (
    Input, Conv2D, MaxPooling2D, Conv2DTranspose,
    BatchNormalization, Activation, SpatialDropout2D, concatenate
)
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
from typing import Tuple




def build_flexible_unet(input_shape=(256, 256, 3), num_classes=6, freeze_rgb_encoder=True):
    """Build a U-Net-style segmentation model on top of a ResNet50 encoder.

    When the last entry of ``input_shape`` is 4, the first three channels are
    fed to the ResNet50 encoder and the remaining channel is treated as an
    elevation band: it is processed by a small convolutional stem, resized,
    and concatenated onto the four shallowest encoder skip tensors. For any
    other channel count the whole input goes to the encoder and no elevation
    path is built.

    Args:
        input_shape: ``(height, width, channels)`` of the model input. The
            elevation branch resizes to the fixed sizes 128/64/32/16, which
            match the encoder skips only for a 256x256 input.
        num_classes: Number of channels of the final 1x1 softmax convolution.
        freeze_rgb_encoder: If true, every layer of the ResNet50 backbone is
            marked non-trainable.

    Returns:
        A ``(model, base_model)`` tuple: the full segmentation model and the
        ResNet50 backbone it was built on.
    """
    from tensorflow.keras.applications import ResNet50
    # These two layers are used below but are missing from the file-level
    # import list; importing them here keeps the function self-contained.
    from tensorflow.keras.layers import Resizing, UpSampling2D

    inputs = Input(shape=input_shape, name="model_input")

    # --- Split RGB and Elevation ---
    if input_shape[-1] == 4:
        rgb = inputs[..., :3]
        elev = inputs[..., 3:]
    else:
        rgb = inputs
        elev = None

    # --- ResNet50 Backbone (RGB encoder) ---
    # NOTE(review): the `name` keyword is not accepted by the ResNet50
    # constructor of every Keras release - confirm against the installed one.
    base_model = ResNet50(include_top=False, weights='imagenet', input_tensor=rgb, name="encoder")
    if freeze_rgb_encoder:
        for layer in base_model.layers:
            layer.trainable = False

    # Skip tensors; the sizes in the comments hold for a 256x256 input.
    x1 = base_model.get_layer("conv1_relu").output        # 128x128
    x2 = base_model.get_layer("conv2_block3_out").output  # 64x64
    x3 = base_model.get_layer("conv3_block4_out").output  # 32x32
    x4 = base_model.get_layer("conv4_block6_out").output  # 16x16
    x5 = base_model.get_layer("conv5_block3_out").output  # 8x8

    # --- Elevation Path ---
    if elev is not None:
        def elev_block(elev_input, size, filters):
            """Resize elevation features to ``size`` x ``size`` and refine them."""
            x = Resizing(size, size)(elev_input)
            x = Conv2D(filters, 3, padding="same", activation=None)(x)
            x = BatchNormalization()(x)
            x = Activation("relu")(x)
            x = SpatialDropout2D(0.1)(x)
            return x

        # Initial processing at full input resolution.
        e = Conv2D(32, 3, padding="same", activation=None)(elev)
        e = BatchNormalization()(e)
        e = Activation("relu")(e)
        e = SpatialDropout2D(0.1)(e)

        # Merge elevation at multiple encoder stages (x5 is left RGB-only).
        x1 = concatenate([x1, elev_block(e, 128, 64)])
        x2 = concatenate([x2, elev_block(e, 64, 128)])
        x3 = concatenate([x3, elev_block(e, 32, 256)])
        x4 = concatenate([x4, elev_block(e, 16, 256)])

    # --- Decoder Path ---
    def decoder_block(x, skip, filters, drop_rate=0.2):
        """Upsample ``x`` by 2, concatenate ``skip``, then conv-BN-ReLU-dropout."""
        x = UpSampling2D()(x)
        x = concatenate([x, skip])
        x = Conv2D(filters, 3, padding="same", activation=None)(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
        x = SpatialDropout2D(drop_rate)(x)
        return x

    d1 = decoder_block(x5, x4, 256)
    d2 = decoder_block(d1, x3, 256)
    d3 = decoder_block(d2, x2, 128, 0.3)
    d4 = decoder_block(d3, x1, 64, 0.3)

    # Final upsampling stage has no skip connection.
    d5 = UpSampling2D()(d4)
    d5 = Conv2D(32, 3, padding="same", activation=None)(d5)
    d5 = BatchNormalization()(d5)
    d5 = Activation("relu")(d5)

    outputs = Conv2D(num_classes, 1, activation="softmax")(d5)

    model = Model(inputs=inputs, outputs=outputs)
    return model, base_model






def enhanced_unet(input_shape=(256, 256, 3), num_classes=6, dropout=0.05):
    """Build a four-level U-Net with batch norm, L2 weight decay and dropout.

    Args:
        input_shape: Shape passed to the model's ``Input`` layer.
        num_classes: Number of channels of the final 1x1 softmax convolution.
        dropout: Kept for interface compatibility. Every stage below is given
            an explicit dropout rate, so this value does not affect the model.

    Returns:
        The assembled (uncompiled) Keras ``Model``.
    """

    def conv_bn_relu(tensor, width):
        # 3x3 convolution -> batch norm -> ReLU.
        tensor = Conv2D(width, (3, 3), padding="same", kernel_initializer="he_normal",
                        kernel_regularizer=l2(1e-4))(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation("relu")(tensor)

    def double_conv(tensor, width, rate):
        # Two conv-BN-ReLU units followed by spatial dropout.
        for _ in range(2):
            tensor = conv_bn_relu(tensor, width)
        return SpatialDropout2D(rate)(tensor)

    def up_stage(tensor, skip, width, rate):
        # Learned 2x upsampling, merge with the skip tensor, then double conv.
        tensor = Conv2DTranspose(width, (2, 2), strides=(2, 2), padding="same",
                                 kernel_regularizer=l2(1e-4))(tensor)
        tensor = concatenate([tensor, skip])
        return double_conv(tensor, width, rate)

    base_width = 64
    encoder_plan = (
        (base_width, 0.0),
        (base_width * 2, 0.00),
        (base_width * 4, 0.25),
        (base_width * 8, 0.4),
    )
    decoder_plan = (
        (base_width * 8, 0.4),
        (base_width * 4, 0.3),
        (base_width * 2, 0.05),
        (base_width, 0.0),
    )

    inputs = Input(shape=input_shape)

    # Encoder. The skip recorded for each stage is the tensor ENTERING the
    # stage (the raw input for the first stage, the previous pooled output
    # afterwards), not the stage's own convolution output.
    # NOTE(review): a conventional U-Net would skip from the stage output;
    # this is kept as-is because changing it alters the architecture.
    skips = []
    features = inputs
    for width, rate in encoder_plan:
        skips.append(features)
        features = double_conv(features, width, rate)
        features = MaxPooling2D((2, 2))(features)

    # Bottleneck with the heaviest dropout.
    features = double_conv(features, base_width * 16, 0.5)

    # Decoder: consume the skips from deepest to shallowest.
    for (width, rate), skip in zip(decoder_plan, reversed(skips)):
        features = up_stage(features, skip, width, rate)

    outputs = Conv2D(num_classes, (1, 1), activation="softmax", dtype="float32")(features)

    model = Model(inputs=[inputs], outputs=[outputs])
    return model


from tensorflow.keras.layers import (
    Input, Conv2D, MaxPooling2D, Conv2DTranspose,
    BatchNormalization, Activation, SpatialDropout2D, concatenate
)
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2


def enhanced_unet(input_shape=(256, 256, 3), num_classes=6) -> Model:
    """Assemble a regularised U-Net (batch norm, L2 weight decay, dropout).

    Args:
        input_shape (tuple): Shape handed to the ``Input`` layer,
            e.g. (256, 256, 3).
        num_classes (int): Channel count of the final softmax convolution.

    Returns:
        tf.keras.Model: The assembled model. It is not compiled here.
    """

    def apply_stack(tensor, stack):
        """Feed ``tensor`` through each layer of ``stack`` in order."""
        for layer in stack:
            tensor = layer(tensor)
        return tensor

    def double_conv_stack(filters, dropout_rate):
        """Create the layers of two conv-BN-ReLU units plus spatial dropout."""
        stack = []
        for _ in range(2):
            stack.append(Conv2D(filters, (3, 3), padding="same",
                                kernel_initializer="he_normal",
                                kernel_regularizer=l2(1e-4)))
            stack.append(BatchNormalization())
            stack.append(Activation("relu"))
        stack.append(SpatialDropout2D(dropout_rate))
        return stack

    def encode(tensor, filters, dropout_rate):
        """Return ``(block_output, block_input)`` for one encoder stage."""
        return apply_stack(tensor, double_conv_stack(filters, dropout_rate)), tensor

    def decode(tensor, skip, filters, dropout_rate):
        """Upsample by 2, concatenate ``skip`` and apply a double conv."""
        upsampled = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding="same",
                                    kernel_regularizer=l2(1e-4))(tensor)
        merged = concatenate([upsampled, skip])
        return apply_stack(merged, double_conv_stack(filters, dropout_rate))

    width = 64
    inputs = Input(shape=input_shape)

    # --- Encoder ---
    # Each skip is the tensor that entered the stage, not the stage output.
    # NOTE(review): differs from the textbook U-Net skip; kept unchanged.
    down1, skip1 = encode(inputs, width, 0.05)
    down2, skip2 = encode(MaxPooling2D((2, 2))(down1), width * 2, 0.1)
    down3, skip3 = encode(MaxPooling2D((2, 2))(down2), width * 4, 0.2)
    down4, skip4 = encode(MaxPooling2D((2, 2))(down3), width * 8, 0.35)

    # --- Bottleneck (heaviest dropout) ---
    bottom, _ = encode(MaxPooling2D((2, 2))(down4), width * 16, 0.5)

    # --- Decoder ---
    up = decode(bottom, skip4, width * 8, 0.4)
    up = decode(up, skip3, width * 4, 0.3)
    up = decode(up, skip2, width * 2, 0.2)
    up = decode(up, skip1, width, 0.1)

    outputs = Conv2D(num_classes, (1, 1), activation="softmax", dtype="float32")(up)

    return Model(inputs=inputs, outputs=outputs)






def enhanced_unet(input_shape: Tuple[int, int, int] = (256, 256, 3), num_classes: int = 6) -> Model:
    """Builds an enhanced U-Net architecture for semantic segmentation.

    Compared with a basic U-Net this implementation:

    -   Uses the 'he_normal' kernel initializer on the 3x3 convolutions.
    -   Applies L2 regularization (1e-4) to the convolution kernels.
    -   Places Batch Normalization after each 3x3 convolution and before the
        ReLU activation.
    -   Adds Spatial Dropout 2D at the end of every block, with rates that
        increase in deeper encoder stages and decrease in shallower decoder
        stages.
    -   Ends with a 1x1 convolution with softmax activation that produces
        per-pixel class probabilities.

    Args:
        input_shape: A tuple of three integers `(height, width, channels)`
            specifying the input shape. Defaults to (256, 256, 3).
        num_classes: Number of output classes for segmentation. Defaults to 6.

    Returns:
        A `tf.keras.Model` instance representing the enhanced U-Net. The model
        is not compiled here.
    """
    # The helper annotations below reference `tf.Tensor`; `tf` is not part of
    # the file-level imports, so bring it into scope here. Without this the
    # nested `def` statements raise NameError when this function is called.
    import tensorflow as tf

    def _conv_block(x: tf.Tensor, filters: int, dropout_rate: float) -> Tuple[tf.Tensor, tf.Tensor]:
        """Applies two conv-BN-ReLU units followed by spatial dropout.

        Args:
            x: The input tensor to the convolutional block.
            filters: Number of filters (output channels) of both convolutions.
            dropout_rate: Dropout rate for `SpatialDropout2D`.

        Returns:
            A tuple of two tensors:
            - The output of the convolutional block.
            - The original input of the block (`x_skip`), which the encoder
              hands to the decoder as the skip connection.
        """
        # NOTE(review): the skip is the block INPUT, not its output, so the
        # decoder never sees the pre-pooling conv features directly (the
        # shallowest skip is the raw model input). A conventional U-Net skips
        # from the block output; kept as-is because changing it alters the
        # architecture and any saved weights.
        x_skip = x

        for _ in range(2):
            x = Conv2D(
                filters,
                (3, 3),
                padding="same",
                kernel_initializer="he_normal",  # He initializer, paired with ReLU.
                kernel_regularizer=l2(1e-4),
            )(x)
            x = BatchNormalization()(x)
            x = Activation("relu")(x)

        # Spatial dropout drops entire feature maps rather than single units.
        x = SpatialDropout2D(dropout_rate)(x)
        return x, x_skip

    def _decoder_block(x: tf.Tensor, skip_connection: tf.Tensor, filters: int, dropout_rate: float) -> tf.Tensor:
        """Upsamples, concatenates with a skip connection, and applies a conv block.

        Args:
            x: The tensor from the previous (lower-resolution) stage.
            skip_connection: The encoder tensor to concatenate after upsampling.
            filters: Number of filters for the transposed convolution and the
                following conv block.
            dropout_rate: Dropout rate used inside the conv block.

        Returns:
            The output tensor of the decoder block.
        """
        # Transposed convolution with stride 2 doubles the spatial size.
        x = Conv2DTranspose(
            filters,
            (2, 2),
            strides=(2, 2),
            padding="same",
            kernel_regularizer=l2(1e-4),
        )(x)

        # Merge with the encoder features along the channel axis.
        x = concatenate([x, skip_connection])

        # Same layer sequence as the encoder block; its skip output is unused.
        x, _ = _conv_block(x, filters, dropout_rate)
        return x

    # Define input layer for the model.
    inputs = Input(shape=input_shape)

    # Base number of filters for the first encoder block.
    n_filters = 64

    # --- Encoder Path ---
    # Each stage is a _conv_block followed by 2x2 max pooling; dropout grows
    # with depth. cN is the block output, sN the tensor kept as skip connection.
    c1, s1 = _conv_block(inputs, n_filters, dropout_rate=0.05)
    p1 = MaxPooling2D((2, 2))(c1)

    c2, s2 = _conv_block(p1, n_filters * 2, dropout_rate=0.1)
    p2 = MaxPooling2D((2, 2))(c2)

    c3, s3 = _conv_block(p2, n_filters * 4, dropout_rate=0.2)
    p3 = MaxPooling2D((2, 2))(c3)

    c4, s4 = _conv_block(p3, n_filters * 8, dropout_rate=0.35)
    p4 = MaxPooling2D((2, 2))(c4)

    # --- Bottleneck Layer ---
    # Deepest stage: highest filter count and dropout rate.
    c5, _ = _conv_block(p4, n_filters * 16, dropout_rate=0.5)

    # --- Decoder Path ---
    # Dropout decreases as the resolution increases.
    u6 = _decoder_block(c5, s4, n_filters * 8, dropout_rate=0.4)
    u7 = _decoder_block(u6, s3, n_filters * 4, dropout_rate=0.3)
    u8 = _decoder_block(u7, s2, n_filters * 2, dropout_rate=0.2)
    u9 = _decoder_block(u8, s1, n_filters, dropout_rate=0.1)

    # Output layer: 1x1 convolution mapping the features to `num_classes`
    # channels with softmax. `dtype="float32"` pins the output dtype, which
    # matters when a mixed-precision policy is active.
    outputs = Conv2D(num_classes, (1, 1), activation="softmax", dtype="float32")(u9)

    # Create the Keras Model.
    model = Model(inputs=[inputs], outputs=[outputs])
    return model







def build_multi_unet(input_shape=(256, 256, 3), num_classes=6):
    """Build a lightweight four-level U-Net for multi-class segmentation.

    Every stage is conv(3x3, ReLU) -> Dropout(0.2) -> conv(3x3, ReLU). The
    encoder uses 16/32/64/128 filters with 2x2 max pooling between stages, the
    bottleneck uses 256 filters, and the decoder mirrors the encoder with
    stride-2 transposed convolutions and skip connections taken from the
    encoder stage outputs.

    Args:
        input_shape: Shape passed to the model's ``Input`` layer.
        num_classes: Number of channels of the final 1x1 softmax convolution.

    Returns:
        The assembled (uncompiled) Keras ``Model``.
    """
    # `Dropout` is used below but is missing from the file-level import list.
    from tensorflow.keras.layers import Dropout

    def _double_conv(x, filters):
        # conv -> dropout -> conv, all at the same filter count.
        x = Conv2D(filters, (3, 3), activation="relu",
                   kernel_initializer="he_normal", padding="same")(x)
        x = Dropout(0.2)(x)
        return Conv2D(filters, (3, 3), activation="relu",
                      kernel_initializer="he_normal", padding="same")(x)

    def _up_block(x, skip, filters, concat_axis=-1):
        # Learned 2x upsampling, merge with the encoder skip, then double conv.
        x = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding="same")(x)
        x = concatenate([x, skip], axis=concat_axis)
        return _double_conv(x, filters)

    inputs = Input(shape=input_shape)

    # Encoder: keep each stage output (pre-pooling) as the skip tensor.
    skips = []
    x = inputs
    for filters in (16, 32, 64, 128):
        x = _double_conv(x, filters)
        skips.append(x)
        x = MaxPooling2D((2, 2))(x)

    # Bottleneck.
    x = _double_conv(x, 256)

    # Decoder: deepest skip first.
    x = _up_block(x, skips[3], 128)
    x = _up_block(x, skips[2], 64)
    x = _up_block(x, skips[1], 32)
    # axis=3 is what the original passed for this merge; it addresses the same
    # (channel) axis as the default -1 for these 4-D tensors.
    x = _up_block(x, skips[0], 16, concat_axis=3)

    outputs = Conv2D(num_classes, (1, 1), activation="softmax")(x)

    model = Model(inputs=[inputs], outputs=[outputs])
    return model
     

def build_unet(input_shape=(256, 256, 3), num_classes=6):
    """Build a plain four-level U-Net and print diagnostics while doing so.

    The contracting path uses two 3x3 ReLU convolutions per stage with
    64/128/256/512 filters and 2x2 max pooling, the bottleneck uses 1024
    filters, and the expansive path uses ``UpSampling2D`` followed by
    concatenation with the matching encoder output and two more convolutions.

    Args:
        input_shape: Shape passed to the model's ``Input`` layer.
        num_classes: Number of channels of the final 1x1 softmax convolution.

    Returns:
        The assembled (uncompiled) Keras ``Model``.
    """
    # `layers` and `models` are used below but are missing from the file-level
    # import list; importing them here keeps the function self-contained.
    from tensorflow.keras import layers, models

    def _double_conv(x, filters):
        # Two consecutive 3x3 ReLU convolutions with "same" padding.
        x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
        return layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)

    print("🧪 build_unet called with input_shape =", input_shape)
    inputs = layers.Input(shape=input_shape)
    print("🧪 Input layer constructed with shape:", inputs.shape)

    # --- Contracting Path ---
    # Each stage output (before pooling) is kept for the skip connection.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        x = _double_conv(x, filters)
        skips.append(x)
        x = layers.MaxPooling2D((2, 2))(x)

    # --- Bottleneck ---
    x = _double_conv(x, 1024)

    # --- Expansive Path ---
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = layers.UpSampling2D((2, 2))(x)
        x = layers.concatenate([x, skip])
        x = _double_conv(x, filters)

    # --- Output Layer ---
    outputs = layers.Conv2D(num_classes, (1, 1), activation='softmax')(x)

    model = models.Model(inputs=[inputs], outputs=[outputs])
    print("✅ U-Net model built successfully.")
    print("🧪 Final model.input_shape =", model.input_shape)

    return model




