In [None]:
import tensorflow as tf
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K


from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D, Conv2DTranspose, UpSampling2D, concatenate,
                                     BatchNormalization, Activation, SpatialDropout2D, Dropout, LeakyReLU, Cropping2D)



# Disabled code: the triple-quoted block below is a bare module-level string
# literal, so nothing inside it is executed or defined. It holds two earlier
# drafts of build_flexible_unet that accept either 3- or 4-channel input:
#   * draft 1 concatenates the extra-channel branch with one encoder stage (x1);
#   * draft 2 concatenates it with three encoder stages (x1, x2, x3).
'''
def build_flexible_unet(input_shape=(256, 256, 3), num_classes=6, freeze_rgb_encoder=True):
    from tensorflow.keras.applications import ResNet50
    from tensorflow.keras.layers import Input, Conv2D, UpSampling2D, MaxPooling2D, concatenate, Dropout
    from tensorflow.keras.models import Model

    # --- Input ---
    inputs = Input(shape=input_shape, name="model_input")

    # --- Split RGB and Elevation ---
    if input_shape[-1] == 4:
        rgb = inputs[..., :3]
        elev = inputs[..., 3:]
    else:
        rgb = inputs
        elev = None

    # --- ResNet50 Backbone ---
    base_model = ResNet50(include_top=False, weights='imagenet', input_tensor=rgb, name="encoder")
    if freeze_rgb_encoder:
        for layer in base_model.layers:
            layer.trainable = False

    # Encoder feature maps
    x1 = base_model.get_layer("conv1_relu").output       # 128x128
    x2 = base_model.get_layer("conv2_block3_out").output # 64x64
    x3 = base_model.get_layer("conv3_block4_out").output # 32x32
    x4 = base_model.get_layer("conv4_block6_out").output # 16x16
    x5 = base_model.get_layer("conv5_block3_out").output # 8x8

    # --- Elevation branch ---
    if elev is not None:
        e = Conv2D(16, 3, padding="same", activation="relu")(elev)     # 256x256
        e = MaxPooling2D()(e)                                          # 128x128
        e = Conv2D(16, 3, padding="same", activation="relu")(e)        # 128x128
        x1 = concatenate([x1, e])  # Match spatial shape with x1

    # --- Decoder ---
    d1 = UpSampling2D()(x5)
    d1 = concatenate([d1, x4])
    d1 = Conv2D(256, 3, padding="same", activation="relu")(d1)
    d1 = Dropout(0.2)(d1)

    d2 = UpSampling2D()(d1)
    d2 = concatenate([d2, x3])
    d2 = Conv2D(128, 3, padding="same", activation="relu")(d2)
    d2 = Dropout(0.2)(d2)

    d3 = UpSampling2D()(d2)
    d3 = concatenate([d3, x2])
    d3 = Conv2D(64, 3, padding="same", activation="relu")(d3)
    d3 = Dropout(0.2)(d3)

    d4 = UpSampling2D()(d3)
    d4 = concatenate([d4, x1])
    d4 = Conv2D(32, 3, padding="same", activation="relu")(d4)
    d4 = Dropout(0.2)(d4)

    d5 = UpSampling2D()(d4)
    d5 = Conv2D(32, 3, padding="same", activation="relu")(d5)

    outputs = Conv2D(num_classes, 1, activation="softmax")(d5)

    model = Model(inputs=inputs, outputs=outputs)
    return model, base_model



def build_flexible_unet(input_shape=(256, 256, 3), num_classes=6, freeze_rgb_encoder=True):
    from tensorflow.keras.applications import ResNet50
    from tensorflow.keras.layers import Input, Conv2D, UpSampling2D, MaxPooling2D, concatenate, Dropout
    from tensorflow.keras.models import Model

    # --- Input ---
    inputs = Input(shape=input_shape, name="model_input")

    # --- Split RGB and Elevation ---
    if input_shape[-1] == 4:
        rgb = inputs[..., :3]
        elev = inputs[..., 3:]
    else:
        rgb = inputs
        elev = None

    # --- ResNet50 Backbone ---
    base_model = ResNet50(include_top=False, weights='imagenet', input_tensor=rgb, name="encoder")
    if freeze_rgb_encoder:
        for layer in base_model.layers:
            layer.trainable = False

    # Encoder feature maps
    x1 = base_model.get_layer("conv1_relu").output       # 128x128
    x2 = base_model.get_layer("conv2_block3_out").output # 64x64
    x3 = base_model.get_layer("conv3_block4_out").output # 32x32
    x4 = base_model.get_layer("conv4_block6_out").output # 16x16
    x5 = base_model.get_layer("conv5_block3_out").output # 8x8

    # --- Elevation branch ---
    if elev is not None:
        # Downsample elevation to match encoder stages
        e1 = Conv2D(16, 3, padding="same", activation="relu")(elev)   # 256x256
        e2 = MaxPooling2D()(e1)                                        # 128x128
        e2 = Conv2D(16, 3, padding="same", activation="relu")(e2)
        e3 = MaxPooling2D()(e2)                                        # 64x64
        e3 = Conv2D(16, 3, padding="same", activation="relu")(e3)
        e4 = MaxPooling2D()(e3)                                        # 32x32
        e4 = Conv2D(16, 3, padding="same", activation="relu")(e4)

        # Concatenate at multiple levels
        x1 = concatenate([x1, e2])
        x2 = concatenate([x2, e3])
        x3 = concatenate([x3, e4])

    # --- Decoder ---
    d1 = UpSampling2D()(x5)
    d1 = concatenate([d1, x4])
    d1 = Conv2D(256, 3, padding="same", activation="relu")(d1)
    d1 = Dropout(0.2)(d1)

    d2 = UpSampling2D()(d1)
    d2 = concatenate([d2, x3])
    d2 = Conv2D(128, 3, padding="same", activation="relu")(d2)
    d2 = Dropout(0.2)(d2)

    d3 = UpSampling2D()(d2)
    d3 = concatenate([d3, x2])
    d3 = Conv2D(64, 3, padding="same", activation="relu")(d3)
    d3 = Dropout(0.2)(d3)

    d4 = UpSampling2D()(d3)
    d4 = concatenate([d4, x1])
    d4 = Conv2D(32, 3, padding="same", activation="relu")(d4)
    d4 = Dropout(0.2)(d4)

    d5 = UpSampling2D()(d4)
    d5 = Conv2D(32, 3, padding="same", activation="relu")(d5)

    outputs = Conv2D(num_classes, 1, activation="softmax")(d5)

    model = Model(inputs=inputs, outputs=outputs)
    return model, base_model

'''



def build_flexible_unet(input_shape=(256, 256, 4), num_classes=6, freeze_rgb_encoder=True):
    """Build a U-Net style segmentation model with a ResNet50 RGB encoder and an elevation branch.

    The first three input channels are fed to a ResNet50 loaded with
    ``weights='imagenet'``. Every remaining channel goes through a small
    convolutional branch. At each of the five encoder resolutions the ResNet50
    feature map and the elevation feature map of the *same* spatial size are
    concatenated and mixed by a 3x3 convolution; the decoder then upsamples
    back to the input resolution.

    Args:
        input_shape: ``(height, width, channels)`` with at least 4 channels
            (3 RGB + 1 or more extra channels). Known heights and widths must
            be multiples of 32, because the encoder downsamples five times by
            a factor of two and the decoder doubles the size five times.
        num_classes: Number of channels of the final softmax output.
        freeze_rgb_encoder: When true, every layer of the ResNet50 backbone is
            marked non-trainable.

    Returns:
        A ``(model, base_model)`` tuple: the full segmentation model and the
        ResNet50 backbone it was built on.

    Raises:
        ValueError: If ``input_shape`` is not 3-dimensional, has fewer than
            4 channels, or has a known spatial size that is not a multiple
            of 32.
    """
    # Validate first so a bad shape fails with a clear message instead of an
    # opaque shape-mismatch error from deep inside the layer graph.
    if len(input_shape) != 3:
        raise ValueError(
            f"input_shape must be (height, width, channels), got {tuple(input_shape)!r}"
        )
    channels = input_shape[-1]
    if channels is None or channels < 4:
        raise ValueError(
            "input_shape needs at least 4 channels (3 RGB + elevation), "
            f"got {channels!r}"
        )
    for size in input_shape[:2]:
        if size is not None and size % 32 != 0:
            raise ValueError(
                f"height and width must be multiples of 32, got {tuple(input_shape[:2])!r}"
            )

    from tensorflow.keras.applications import ResNet50
    from tensorflow.keras.layers import Input, Conv2D, UpSampling2D, MaxPooling2D, concatenate, Dropout
    from tensorflow.keras.models import Model

    # --- Input ---
    inputs = Input(shape=input_shape, name="model_input")

    # --- Split RGB and Elevation ---
    rgb = inputs[..., :3]
    elev = inputs[..., 3:]

    # --- ResNet50 Backbone ---
    # NOTE(review): the `name=` keyword is presumably only accepted by newer
    # Keras releases of ResNet50 — confirm against the installed version.
    base_model = ResNet50(include_top=False, weights='imagenet', input_tensor=rgb, name="encoder")
    if freeze_rgb_encoder:
        for layer in base_model.layers:
            layer.trainable = False

    # Encoder feature maps, ordered from 1/2 down to 1/32 of the input size.
    skip_layer_names = (
        "conv1_relu",        # 1/2
        "conv2_block3_out",  # 1/4
        "conv3_block4_out",  # 1/8
        "conv4_block6_out",  # 1/16
        "conv5_block3_out",  # 1/32
    )
    encoder_feats = [base_model.get_layer(name).output for name in skip_layer_names]
    fusion_filters = (64, 128, 256, 512, 1024)

    # --- Elevation branch ---
    # Full-resolution stem; the branch is then pooled once BEFORE each fusion
    # so its feature map has the same spatial size as the encoder stage it is
    # concatenated with (the shallowest encoder skip is already at 1/2).
    e = Conv2D(32, 3, padding="same", activation="relu")(elev)

    # --- Concatenate elevation with encoder ---
    fused = []
    for feat, filters in zip(encoder_feats, fusion_filters):
        e = MaxPooling2D()(e)
        e = Conv2D(filters, 3, padding="same", activation="relu")(e)
        merged = concatenate([feat, e])
        fused.append(Conv2D(filters, 3, padding="same", activation="relu")(merged))
    x1, x2, x3, x4, x5 = fused

    # --- Decoder ---
    d = x5
    for skip, filters in zip((x4, x3, x2, x1), (512, 256, 128, 64)):
        d = UpSampling2D()(d)
        d = concatenate([d, skip])
        d = Conv2D(filters, 3, padding="same", activation="relu")(d)
        d = Dropout(0.2)(d)

    # Last upsampling brings the 1/2-resolution map back to the input size;
    # there is no encoder skip at full resolution.
    d = UpSampling2D()(d)
    d = Conv2D(32, 3, padding="same", activation="relu")(d)

    outputs = Conv2D(num_classes, 1, activation="softmax")(d)

    model = Model(inputs=inputs, outputs=outputs)
    return model, base_model




def enhanced_unet(input_shape=(256, 256, 4), num_classes=6, dropout_rate=0.2):
    """Build a four-level U-Net with batch normalisation and spatial dropout.

    Each encoder/decoder stage applies two (3x3 Conv -> BatchNorm -> ReLU)
    units followed by ``SpatialDropout2D(dropout_rate)``. Decoder stages
    upsample with a stride-2 transposed convolution before concatenating
    their skip tensor.

    Args:
        input_shape: Shape passed to the Keras ``Input`` layer.
        num_classes: Number of channels of the final 1x1 softmax convolution.
        dropout_rate: Rate used by every ``SpatialDropout2D`` layer.

    Returns:
        The assembled Keras ``Model``.
    """

    def _double_conv(tensor, filters):
        # Two Conv-BN-ReLU units, then one spatial dropout.
        for _ in range(2):
            tensor = Conv2D(filters, (3, 3), padding="same", kernel_initializer="he_normal")(tensor)
            tensor = BatchNormalization()(tensor)
            tensor = Activation("relu")(tensor)
        return SpatialDropout2D(dropout_rate)(tensor)

    inputs = Input(shape=input_shape)

    # Encoder
    # NOTE(review): the skip tensor of each stage is the tensor ENTERING the
    # stage (the raw input for the first stage, the pooled output of the
    # previous stage afterwards), not the stage's convolved output — confirm
    # this is intended before changing it, since trained weights depend on it.
    skips = []
    current = inputs
    for filters in (32, 64, 128, 256):
        skips.append(current)
        current = _double_conv(current, filters)
        current = MaxPooling2D((2, 2))(current)

    # Bottleneck
    current = _double_conv(current, 512)

    # Decoder: consume the skips deepest-first.
    for filters in (256, 128, 64, 32):
        current = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding="same")(current)
        current = concatenate([current, skips.pop()])
        current = _double_conv(current, filters)

    outputs = Conv2D(num_classes, (1, 1), activation="softmax")(current)

    return Model(inputs=[inputs], outputs=[outputs])



def build_unet_aux(input_shape=(256, 256, 3), num_classes=6):
    """Build a U-Net with an auxiliary softmax head taken from the decoder.

    Every stage is Conv -> BatchNorm -> LeakyReLU -> Dropout -> Conv ->
    BatchNorm -> LeakyReLU. After the second decoder stage the features are
    bilinearly upsampled by 4 and classified by a 1x1 convolution named
    ``"aux_output"``; the final decoder stage feeds the 1x1 convolution named
    ``"main_output"``.

    Args:
        input_shape: Shape passed to the Keras ``Input`` layer.
        num_classes: Number of channels of both softmax heads.

    Returns:
        A Keras ``Model`` whose outputs are ``[main_output, aux_output]``.
    """

    def _stage(tensor, filters, drop_rate):
        # Conv-BN-LeakyReLU twice, with dropout between the two units.
        tensor = Conv2D(filters, (3, 3), padding="same")(tensor)
        tensor = BatchNormalization()(tensor)
        tensor = LeakyReLU()(tensor)
        tensor = Dropout(drop_rate)(tensor)
        tensor = Conv2D(filters, (3, 3), padding="same")(tensor)
        tensor = BatchNormalization()(tensor)
        return LeakyReLU()(tensor)

    def _up_stage(tensor, skip, filters, drop_rate):
        # Stride-2 transposed convolution, skip concatenation, then a stage.
        tensor = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding="same")(tensor)
        tensor = concatenate([tensor, skip])
        return _stage(tensor, filters, drop_rate)

    inputs = Input(shape=input_shape)

    # --- Encoder ---
    encoder_outputs = []
    current = inputs
    for filters, drop_rate in ((16, 0.1), (32, 0.2), (64, 0.2), (128, 0.3)):
        current = _stage(current, filters, drop_rate)
        encoder_outputs.append(current)
        current = MaxPooling2D((2, 2))(current)
    enc1, enc2, enc3, enc4 = encoder_outputs

    bottleneck = _stage(current, 256, 0.4)

    # --- Decoder ---
    dec4 = _up_stage(bottleneck, enc4, 128, 0.3)
    dec3 = _up_stage(dec4, enc3, 64, 0.2)

    # Auxiliary output (upsample to final output size): dec3 sits two
    # poolings below the input, hence the factor of 4.
    aux_features = UpSampling2D(size=(4, 4), interpolation='bilinear')(dec3)
    aux_out = Conv2D(num_classes, (1, 1), activation="softmax", name="aux_output")(aux_features)

    dec2 = _up_stage(dec3, enc2, 32, 0.2)
    dec1 = _up_stage(dec2, enc1, 16, 0.1)

    outputs = Conv2D(num_classes, (1, 1), activation="softmax", name="main_output")(dec1)

    return Model(inputs=[inputs], outputs=[outputs, aux_out])



def build_multi_unet(input_shape=(256, 256, 3), num_classes=6):
    """Build a lightweight four-level U-Net for multi-class segmentation.

    Each stage is two 3x3 ReLU convolutions (``he_normal`` initialiser) with
    ``Dropout(0.2)`` between them. The encoder uses 16/32/64/128 filters with
    2x2 max pooling, the bottleneck uses 256, and the decoder mirrors the
    encoder using stride-2 transposed convolutions and skip concatenations.

    Args:
        input_shape: Shape passed to the Keras ``Input`` layer.
        num_classes: Number of channels of the final 1x1 softmax convolution.

    Returns:
        The assembled Keras ``Model``.
    """

    def _conv(tensor, filters):
        return Conv2D(
            filters,
            (3, 3),
            activation="relu",
            kernel_initializer="he_normal",
            padding="same",
        )(tensor)

    def _stage(tensor, filters):
        # conv -> dropout -> conv
        tensor = _conv(tensor, filters)
        tensor = Dropout(0.2)(tensor)
        return _conv(tensor, filters)

    def _up_stage(tensor, skip, filters, axis=-1):
        tensor = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding="same")(tensor)
        tensor = concatenate([tensor, skip], axis=axis)
        return _stage(tensor, filters)

    inputs = Input(shape=input_shape)

    # Contracting path: remember each stage output for the skip connections.
    stage_outputs = []
    current = inputs
    for filters in (16, 32, 64, 128):
        current = _stage(current, filters)
        stage_outputs.append(current)
        current = MaxPooling2D((2, 2))(current)
    enc1, enc2, enc3, enc4 = stage_outputs

    current = _stage(current, 256)

    # Expanding path.
    current = _up_stage(current, enc4, 128)
    current = _up_stage(current, enc3, 64)
    current = _up_stage(current, enc2, 32)
    # The last concatenation names the channel axis explicitly (3), exactly
    # as in the layer configuration this model has always been built with.
    current = _up_stage(current, enc1, 16, axis=3)

    outputs = Conv2D(num_classes, (1, 1), activation="softmax")(current)

    return Model(inputs=[inputs], outputs=[outputs])
     

def build_unet(input_shape=(256, 256, 3), num_classes=6):
    """Build a classic four-level U-Net (64 to 1024 filters) for segmentation.

    The contracting path applies two 3x3 ReLU convolutions per level followed
    by 2x2 max pooling; the expansive path upsamples by 2, concatenates the
    matching contracting-path output and applies two 3x3 ReLU convolutions.
    Progress messages are printed while the model is built.

    Args:
        input_shape: Shape passed to the Keras ``Input`` layer. Known heights
            and widths must be multiples of 16, because the input is pooled
            four times and upsampled four times.
        num_classes: Number of channels of the final 1x1 softmax convolution.

    Returns:
        The assembled Keras ``Model``.

    Raises:
        ValueError: If a known spatial size is not a multiple of 16.
    """
    print("🧪 build_unet called with input_shape =", input_shape)

    # Fail early with a clear message; otherwise the mismatch only surfaces
    # as a shape error at one of the decoder concatenations.
    for size in tuple(input_shape)[:2]:
        if size is not None and size % 16 != 0:
            raise ValueError(
                f"height and width must be multiples of 16, got {tuple(input_shape)[:2]!r}"
            )

    # Imported here because this function uses the `layers.` / `models.`
    # module prefixes, so it must bring both names into scope itself.
    from tensorflow.keras import layers, models

    inputs = layers.Input(shape=input_shape)
    print("🧪 Input layer constructed with shape:", inputs.shape)

    # --- Contracting Path ---
    c1 = layers.Conv2D(64, (3,3), activation='relu', padding='same')(inputs)
    c1 = layers.Conv2D(64, (3,3), activation='relu', padding='same')(c1)
    p1 = layers.MaxPooling2D((2,2))(c1)

    c2 = layers.Conv2D(128, (3,3), activation='relu', padding='same')(p1)
    c2 = layers.Conv2D(128, (3,3), activation='relu', padding='same')(c2)
    p2 = layers.MaxPooling2D((2,2))(c2)

    c3 = layers.Conv2D(256, (3,3), activation='relu', padding='same')(p2)
    c3 = layers.Conv2D(256, (3,3), activation='relu', padding='same')(c3)
    p3 = layers.MaxPooling2D((2,2))(c3)

    c4 = layers.Conv2D(512, (3,3), activation='relu', padding='same')(p3)
    c4 = layers.Conv2D(512, (3,3), activation='relu', padding='same')(c4)
    p4 = layers.MaxPooling2D((2,2))(c4)

    # --- Bottleneck ---
    c5 = layers.Conv2D(1024, (3,3), activation='relu', padding='same')(p4)
    c5 = layers.Conv2D(1024, (3,3), activation='relu', padding='same')(c5)

    # --- Expansive Path ---
    u6 = layers.UpSampling2D((2,2))(c5)
    u6 = layers.concatenate([u6, c4])
    c6 = layers.Conv2D(512, (3,3), activation='relu', padding='same')(u6)
    c6 = layers.Conv2D(512, (3,3), activation='relu', padding='same')(c6)

    u7 = layers.UpSampling2D((2,2))(c6)
    u7 = layers.concatenate([u7, c3])
    c7 = layers.Conv2D(256, (3,3), activation='relu', padding='same')(u7)
    c7 = layers.Conv2D(256, (3,3), activation='relu', padding='same')(c7)

    u8 = layers.UpSampling2D((2,2))(c7)
    u8 = layers.concatenate([u8, c2])
    c8 = layers.Conv2D(128, (3,3), activation='relu', padding='same')(u8)
    c8 = layers.Conv2D(128, (3,3), activation='relu', padding='same')(c8)

    u9 = layers.UpSampling2D((2,2))(c8)
    u9 = layers.concatenate([u9, c1])
    c9 = layers.Conv2D(64, (3,3), activation='relu', padding='same')(u9)
    c9 = layers.Conv2D(64, (3,3), activation='relu', padding='same')(c9)

    # --- Output Layer ---
    outputs = layers.Conv2D(num_classes, (1,1), activation='softmax')(c9)

    model = models.Model(inputs=[inputs], outputs=[outputs])
    print("✅ U-Net model built successfully.")
    print("🧪 Final model.input_shape =", model.input_shape)

    return model

