# In [None]:
import tensorflow as tf
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K


from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D, Conv2DTranspose, UpSampling2D, concatenate,
                                     BatchNormalization, Activation, SpatialDropout2D, Dropout, LeakyReLU, Cropping2D)




def build_flexible_unet(input_shape=(256, 256, 3), num_classes=6, freeze_rgb_encoder=True):
    """Build a U-Net with a ResNet50 RGB encoder and an optional elevation branch.

    Two input layouts are handled:

    - **4 channels:** channels 0-2 feed the ResNet50 encoder and channel 3 is
      treated as an elevation band. That band passes through a small conv
      stem, is resized to each of the four shallower encoder skip resolutions,
      and is concatenated onto those skips.
    - **Any other channel count:** the whole input is fed to ResNet50 and no
      elevation branch is built. ImageNet weights expect 3 channels, so in
      practice this presumably means a 3-channel input.

    Args:
        input_shape: ``(height, width, channels)`` of the model input. Height
            and width should be multiples of 32. The deepest skip is at 1/32
            resolution, and the decoder upsamples by exact factors of 2 before
            concatenating with the encoder skips.
        num_classes: Number of softmax output channels.
        freeze_rgb_encoder: If True, every ResNet50 layer is marked
            non-trainable.

    Returns:
        ``(model, base_model)``: the full segmentation model and the ResNet50
        backbone (e.g. so the caller can change its ``trainable`` state later).

    NOTE(review): no ``preprocess_input`` is applied inside the model. Confirm
    that the input pipeline applies the preprocessing the ImageNet weights
    expect.
    """
    from tensorflow.keras.applications import ResNet50
    from tensorflow.keras.layers import (
        Input, Conv2D, UpSampling2D, concatenate, Resizing,
        BatchNormalization, Activation, SpatialDropout2D
    )
    from tensorflow.keras.models import Model

    inputs = Input(shape=input_shape, name="model_input")

    # --- Split RGB and elevation ---
    if input_shape[-1] == 4:
        rgb = inputs[..., :3]
        elev = inputs[..., 3:]
    else:
        rgb = inputs
        elev = None

    # --- ResNet50 backbone (RGB encoder) ---
    # NOTE(review): some older tf.keras releases reject unknown kwargs such as
    # `name` in ResNet50(); confirm against the installed Keras version.
    base_model = ResNet50(include_top=False, weights='imagenet', input_tensor=rgb, name="encoder")
    if freeze_rgb_encoder:
        for layer in base_model.layers:
            layer.trainable = False

    # Skip features. The sizes in the comments assume a 256x256 input.
    x1 = base_model.get_layer("conv1_relu").output        # 1/2  -> 128x128
    x2 = base_model.get_layer("conv2_block3_out").output  # 1/4  -> 64x64
    x3 = base_model.get_layer("conv3_block4_out").output  # 1/8  -> 32x32
    x4 = base_model.get_layer("conv4_block6_out").output  # 1/16 -> 16x16
    x5 = base_model.get_layer("conv5_block3_out").output  # 1/32 -> 8x8

    # --- Elevation path ---
    if elev is not None:
        def elev_block(elev_input, skip, fallback_size, filters):
            """Resize elevation features to ``skip``'s spatial size and refine them."""
            # Resize to the skip tensor's own static size so the concatenation
            # below lines up for any input resolution, not only 256x256. Only
            # when that size is unknown at build time do we fall back to the
            # fixed size of the original 256x256 design.
            height, width = skip.shape[1], skip.shape[2]
            if height is None or width is None:
                height = width = fallback_size
            x = Resizing(int(height), int(width))(elev_input)
            x = Conv2D(filters, 3, padding="same", activation=None)(x)
            x = BatchNormalization()(x)
            x = Activation("relu")(x)
            x = SpatialDropout2D(0.1)(x)
            return x

        # Full-resolution elevation stem shared by every merge point.
        e = Conv2D(32, 3, padding="same", activation=None)(elev)
        e = BatchNormalization()(e)
        e = Activation("relu")(e)
        e = SpatialDropout2D(0.1)(e)

        # Merge elevation features into the four shallower encoder stages.
        x1 = concatenate([x1, elev_block(e, x1, 128, 64)])
        x2 = concatenate([x2, elev_block(e, x2, 64, 128)])
        x3 = concatenate([x3, elev_block(e, x3, 32, 256)])
        x4 = concatenate([x4, elev_block(e, x4, 16, 256)])

    # --- Decoder path ---
    def decoder_block(x, skip, filters, drop_rate=0.2):
        # 2x upsample, merge the encoder skip, then one conv-BN-ReLU stage.
        x = UpSampling2D()(x)
        x = concatenate([x, skip])
        x = Conv2D(filters, 3, padding="same", activation=None)(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
        x = SpatialDropout2D(drop_rate)(x)
        return x

    d1 = decoder_block(x5, x4, 256)
    d2 = decoder_block(d1, x3, 256)
    d3 = decoder_block(d2, x2, 128, 0.3)
    d4 = decoder_block(d3, x1, 64, 0.3)

    # Final 2x upsample back to input resolution (no skip at full resolution).
    d5 = UpSampling2D()(d4)
    d5 = Conv2D(32, 3, padding="same", activation=None)(d5)
    d5 = BatchNormalization()(d5)
    d5 = Activation("relu")(d5)

    outputs = Conv2D(num_classes, 1, activation="softmax")(d5)

    model = Model(inputs=inputs, outputs=outputs)
    return model, base_model





from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, BatchNormalization, Activation, SpatialDropout2D, concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2

def enhanced_unet(input_shape=(256, 256, 3), num_classes=6, dropout=0.05):
    """Build a 4-level U-Net with L2-regularised double-conv blocks and per-level spatial dropout.

    Each level runs two 3x3 conv -> BatchNorm -> ReLU stages (He-normal init,
    L2(1e-4) kernel penalty) followed by ``SpatialDropout2D``. Each encoder
    level's output, taken before max-pooling, is concatenated into the decoder
    at the matching resolution, as in a standard U-Net.

    Args:
        input_shape: ``(height, width, channels)``. Height and width should be
            divisible by 16, so that four 2x2 max-pools followed by four
            stride-2 transposed convs line up with the skip tensors.
        num_classes: Number of softmax output channels. The output layer is
            forced to float32 (e.g. to stay numerically safe under a
            mixed-precision policy).
        dropout: Default rate for the inner block helpers.
            NOTE(review): every block call below passes an explicit per-level
            rate, so this argument currently has no effect.

    Returns:
        An uncompiled Keras ``Model``.

    Note:
        Earlier versions of this function concatenated each block's *input*
        into the decoder instead of its output. That input was the previous
        level's pooled tensor, or the raw image at the top level. The
        concatenated channel counts differ, so weights saved from that version
        will not load into this architecture.
    """
    def conv_bn_relu(x, filters):
        # One 3x3 conv -> BN -> ReLU stage with He-normal init and L2 decay.
        x = Conv2D(filters, (3, 3), padding="same", kernel_initializer="he_normal",
                   kernel_regularizer=l2(1e-4))(x)
        x = BatchNormalization()(x)
        return Activation("relu")(x)

    def conv_block(x, filters, dropout=dropout):
        # Two conv stages followed by channel-wise (spatial) dropout.
        x = conv_bn_relu(x, filters)
        x = conv_bn_relu(x, filters)
        return SpatialDropout2D(dropout)(x)

    def decoder_block(x, skip, filters, dropout=dropout):
        # Learned 2x upsampling, merge with the encoder skip, then refine.
        x = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding="same",
                            kernel_regularizer=l2(1e-4))(x)
        x = concatenate([x, skip])
        return conv_block(x, filters, dropout=dropout)

    inputs = Input(shape=input_shape)

    n_filters = 32

    # Encoder: c1..c4 (pre-pooling features) are the skip tensors.
    c1 = conv_block(inputs, n_filters, dropout=0.0)
    p1 = MaxPooling2D((2, 2))(c1)

    c2 = conv_block(p1, n_filters * 2, dropout=0.05)
    p2 = MaxPooling2D((2, 2))(c2)

    c3 = conv_block(p2, n_filters * 4, dropout=0.25)
    p3 = MaxPooling2D((2, 2))(c3)

    c4 = conv_block(p3, n_filters * 8, dropout=0.40)
    p4 = MaxPooling2D((2, 2))(c4)

    # Bottleneck: heaviest dropout.
    c5 = conv_block(p4, n_filters * 16, dropout=0.55)

    # Decoder: each stage merges the encoder features at its resolution.
    u6 = decoder_block(c5, c4, n_filters * 8, dropout=0.45)
    u7 = decoder_block(u6, c3, n_filters * 4, dropout=0.35)
    u8 = decoder_block(u7, c2, n_filters * 2, dropout=0.15)
    u9 = decoder_block(u8, c1, n_filters, dropout=0.0)

    outputs = Conv2D(num_classes, (1, 1), activation="softmax", dtype="float32")(u9)

    model = Model(inputs=[inputs], outputs=[outputs])
    return model




def build_multi_unet(input_shape=(256, 256, 3), num_classes=6):
    """Build a plain 4-level U-Net (16 -> 256 filters) for multi-class segmentation.

    Every level is a pair of 3x3 ReLU convolutions (He-normal init, same
    padding) with ``Dropout(0.2)`` between them. Downsampling uses 2x2
    max-pooling. Upsampling uses 2x2 stride-2 transposed convolutions, each
    followed by concatenation with the encoder output at that resolution.

    Args:
        input_shape: ``(height, width, channels)`` of the model input.
        num_classes: Number of softmax output channels.

    Returns:
        An uncompiled Keras ``Model``.
    """
    def double_conv(tensor, n_filters):
        # conv -> dropout -> conv, sharing one set of conv settings.
        conv_kwargs = dict(activation="relu", kernel_initializer="he_normal", padding="same")
        tensor = Conv2D(n_filters, (3, 3), **conv_kwargs)(tensor)
        tensor = Dropout(0.2)(tensor)
        return Conv2D(n_filters, (3, 3), **conv_kwargs)(tensor)

    inputs = Input(shape=input_shape)

    encoder_widths = (16, 32, 64, 128)
    skips = []
    features = inputs
    # Contracting path: remember each level's features before pooling.
    for width in encoder_widths:
        features = double_conv(features, width)
        skips.append(features)
        features = MaxPooling2D((2, 2))(features)

    # Bottleneck.
    features = double_conv(features, 256)

    # Expansive path: mirror the encoder, deepest skip first.
    for width, skip in zip(reversed(encoder_widths), reversed(skips)):
        features = Conv2DTranspose(width, (2, 2), strides=(2, 2), padding="same")(features)
        features = concatenate([features, skip], axis=3)
        features = double_conv(features, width)

    outputs = Conv2D(num_classes, (1, 1), activation="softmax")(features)

    return Model(inputs=[inputs], outputs=[outputs])
     

def build_unet(input_shape=(256, 256, 3), num_classes=6):
    """Build the classic 4-level U-Net (64 -> 1024 filters) without normalisation or dropout.

    Each level is two 3x3 ReLU convolutions with same padding. Downsampling
    uses 2x2 max-pooling. Upsampling uses ``UpSampling2D`` (no learned
    weights), followed by concatenation with the encoder output at the same
    resolution. Progress messages are printed to stdout while building.

    Args:
        input_shape: ``(height, width, channels)``. Height and width should be
            divisible by 16 so the upsampled tensors line up with the skips.
        num_classes: Number of softmax output channels.

    Returns:
        An uncompiled Keras ``Model``.
    """
    # The module-level imports visible in this file only bring in individual
    # layer classes, not the `layers` / `models` modules used below. Import
    # them here; without this the function raises NameError when called.
    from tensorflow.keras import layers, models

    print("🧪 build_unet called with input_shape =", input_shape)
    inputs = layers.Input(shape=input_shape)
    print("🧪 Input layer constructed with shape:", inputs.shape)

    def double_conv(x, filters):
        # Two 3x3 same-padded ReLU convolutions.
        x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
        return layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)

    # --- Contracting path: keep each level's pre-pooling output as a skip ---
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        x = double_conv(x, filters)
        skips.append(x)
        x = layers.MaxPooling2D((2, 2))(x)

    # --- Bottleneck ---
    x = double_conv(x, 1024)

    # --- Expansive path: upsample, merge the matching skip, refine ---
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = layers.UpSampling2D((2, 2))(x)
        x = layers.concatenate([x, skip])
        x = double_conv(x, filters)

    # --- Output layer: per-pixel class probabilities ---
    outputs = layers.Conv2D(num_classes, (1, 1), activation='softmax')(x)

    model = models.Model(inputs=[inputs], outputs=[outputs])
    print("✅ U-Net model built successfully.")
    print("🧪 Final model.input_shape =", model.input_shape)

    return model




