# In [None]:
import tensorflow as tf
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Conv2D, Conv2DTranspose, MaxPooling2D, Dropout, UpSampling2D, 
    BatchNormalization, LeakyReLU, concatenate, Cropping2D
    )

from tensorflow.keras import backend as K


# --- sanity ---
def test_models_sanity():
    """Print a fixed marker line so callers can confirm this module's code ran."""
    marker = "✅ from models.ipynb"
    print(marker)





def SegFormer_B0(input_shape, num_classes):
    """Assemble a segmentation model from the "mit_b0" configuration.

    The model chains a ``MixVisionTransformer`` encoder, a ``SegFormerHead``
    decoder, a ``ResizeLayer`` targeting ``(input_shape[0], input_shape[1])``
    and a final ``ops.softmax``.

    NOTE(review): ``keras``, ``ops``, ``MixVisionTransformer``,
    ``SegFormerHead``, ``ResizeLayer`` and ``MODEL_CONFIGS`` are not imported
    or defined in the visible part of this file; they must be provided
    elsewhere before this function is called.

    Args:
        input_shape: Shape passed to the input layer, without the batch
            dimension. Presumably ``(height, width, channels)`` - confirm
            against the caller. Only ``input_shape[1]`` is forwarded to the
            encoder as ``img_size``.
        num_classes: Number of classes forwarded to the decoder head.

    Returns:
        A ``keras.Model`` mapping the input layer to the softmax output.
    """
    image = keras.layers.Input(shape=input_shape)
    cfg = MODEL_CONFIGS["mit_b0"]

    # Encoder.
    backbone = MixVisionTransformer(
        img_size=input_shape[1],
        embed_dims=cfg["embed_dims"],
        depths=cfg["depths"],
    )
    features = backbone(image)

    # Decoder head.
    head = SegFormerHead(num_classes=num_classes, decode_dim=cfg["decode_dim"])
    decoded = head(features)

    # Resize to the first two input dimensions, then apply softmax.
    resized = ResizeLayer(input_shape[0], input_shape[1])(decoded)
    probabilities = ops.softmax(resized)
    return keras.Model(inputs=image, outputs=probabilities)




def build_unet_aux(input_shape=(256, 256, 3), num_classes=6):
    """Build a U-Net with an auxiliary segmentation output.

    The encoder has four pooled stages (16, 32, 64, 128 filters) followed by
    a 256-filter bottleneck; the decoder mirrors it with transposed
    convolutions and skip connections. Every stage is
    Conv -> BatchNorm -> LeakyReLU -> Dropout -> Conv -> BatchNorm -> LeakyReLU.

    The auxiliary head branches off the 64-filter decoder stage (1/4 of the
    input resolution), is upsampled 4x bilinearly and classified with a 1x1
    softmax convolution.

    Args:
        input_shape: Input shape without the batch dimension.
        num_classes: Number of channels in both softmax outputs.

    Returns:
        A Keras ``Model`` with outputs ``[main_output, aux_output]``.
    """

    def conv_stage(tensor, filters, drop_rate):
        # Two 3x3 convolutions with dropout after the first activation.
        tensor = Conv2D(filters, (3, 3), padding="same")(tensor)
        tensor = BatchNormalization()(tensor)
        tensor = LeakyReLU()(tensor)
        tensor = Dropout(drop_rate)(tensor)
        tensor = Conv2D(filters, (3, 3), padding="same")(tensor)
        tensor = BatchNormalization()(tensor)
        return LeakyReLU()(tensor)

    def up_stage(tensor, skip, filters, drop_rate):
        # Double the resolution, merge the encoder skip, then refine.
        upsampled = Conv2DTranspose(
            filters, (2, 2), strides=(2, 2), padding="same"
        )(tensor)
        merged = concatenate([upsampled, skip])
        return conv_stage(merged, filters, drop_rate)

    image = Input(shape=input_shape)

    # --- Encoder ---
    enc1 = conv_stage(image, 16, 0.1)
    enc2 = conv_stage(MaxPooling2D((2, 2))(enc1), 32, 0.2)
    enc3 = conv_stage(MaxPooling2D((2, 2))(enc2), 64, 0.2)
    enc4 = conv_stage(MaxPooling2D((2, 2))(enc3), 128, 0.3)
    bottleneck = conv_stage(MaxPooling2D((2, 2))(enc4), 256, 0.4)

    # --- Decoder ---
    dec6 = up_stage(bottleneck, enc4, 128, 0.3)
    dec7 = up_stage(dec6, enc3, 64, 0.2)

    # Auxiliary output (upsample to final output size)
    aux_features = UpSampling2D(size=(4, 4), interpolation='bilinear')(dec7)
    aux_head = Conv2D(
        num_classes, (1, 1), activation="softmax", name="aux_output"
    )(aux_features)

    dec8 = up_stage(dec7, enc2, 32, 0.2)
    dec9 = up_stage(dec8, enc1, 16, 0.1)

    main_head = Conv2D(
        num_classes, (1, 1), activation="softmax", name="main_output"
    )(dec9)

    return Model(inputs=[image], outputs=[main_head, aux_head])



def build_multi_unet(input_shape=(256, 256, 3), num_classes=6):
    """Build a lightweight multi-class U-Net.

    The encoder runs four pooled stages (16, 32, 64, 128 filters) into a
    256-filter bottleneck; the decoder upsamples with transposed convolutions
    and concatenates the matching encoder features. Each stage is
    Conv(relu, he_normal) -> Dropout(0.2) -> Conv(relu, he_normal).

    Args:
        input_shape: Input shape without the batch dimension.
        num_classes: Number of channels in the softmax output.

    Returns:
        A Keras ``Model`` with a single per-pixel softmax output.
    """
    # Settings shared by every 3x3 convolution in the network.
    conv_opts = dict(activation="relu", kernel_initializer="he_normal", padding="same")

    def conv_pair(tensor, filters):
        # Two 3x3 convolutions with dropout in between.
        tensor = Conv2D(filters, (3, 3), **conv_opts)(tensor)
        tensor = Dropout(0.2)(tensor)
        return Conv2D(filters, (3, 3), **conv_opts)(tensor)

    def upsample_merge(tensor, skip, filters, axis=-1):
        # Transposed-conv upsampling followed by a skip concatenation.
        upsampled = Conv2DTranspose(
            filters, (2, 2), strides=(2, 2), padding="same"
        )(tensor)
        merged = concatenate([upsampled, skip], axis=axis)
        return conv_pair(merged, filters)

    image = Input(shape=input_shape)

    # Encoder: remember each stage's output for the skip connections.
    skips = []
    tensor = image
    for filters in (16, 32, 64, 128):
        tensor = conv_pair(tensor, filters)
        skips.append(tensor)
        tensor = MaxPooling2D((2, 2))(tensor)

    # Bottleneck.
    tensor = conv_pair(tensor, 256)

    # Decoder: walk back up, pairing with the deepest remaining skip first.
    for filters, skip in zip((128, 64, 32), reversed(skips[1:])):
        tensor = upsample_merge(tensor, skip, filters)
    # The final merge names the channel axis explicitly, as the original did.
    tensor = upsample_merge(tensor, skips[0], 16, axis=3)

    probabilities = Conv2D(num_classes, (1, 1), activation="softmax")(tensor)

    return Model(inputs=[image], outputs=[probabilities])
     

def build_unet(input_shape=(256, 256, 3), num_classes=6):
    """Build a full-width U-Net (64 to 1024 filters) for segmentation.

    The contracting path has four stages of two 3x3 ReLU convolutions, each
    followed by 2x2 max pooling, and a 1024-filter bottleneck. The expansive
    path upsamples with ``UpSampling2D`` and concatenates the matching
    encoder features before two more convolutions.

    This function uses the layer classes and ``Model`` imported at the top of
    the file. The previous ``layers.*`` / ``models.Model`` references were
    never imported there and raised ``NameError`` when called.

    Args:
        input_shape: Input shape without the batch dimension. Known spatial
            dimensions should be divisible by 16 so that the four
            pool/upsample pairs produce shapes matching the skip connections.
        num_classes: Number of channels in the softmax output.

    Returns:
        A Keras ``Model`` with a single per-pixel softmax output.
    """
    print("🧪 build_unet called with input_shape =", input_shape)
    inputs = Input(shape=input_shape)
    print("🧪 Input layer constructed with shape:", inputs.shape)

    def double_conv(tensor, filters):
        # Two consecutive 3x3 ReLU convolutions with `filters` channels.
        tensor = Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
        return Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)

    def up_block(tensor, skip, filters):
        # Upsample 2x, merge with the encoder skip, then refine.
        upsampled = UpSampling2D((2, 2))(tensor)
        merged = concatenate([upsampled, skip])
        return double_conv(merged, filters)

    # --- Contracting Path ---
    c1 = double_conv(inputs, 64)
    p1 = MaxPooling2D((2, 2))(c1)

    c2 = double_conv(p1, 128)
    p2 = MaxPooling2D((2, 2))(c2)

    c3 = double_conv(p2, 256)
    p3 = MaxPooling2D((2, 2))(c3)

    c4 = double_conv(p3, 512)
    p4 = MaxPooling2D((2, 2))(c4)

    # --- Bottleneck ---
    c5 = double_conv(p4, 1024)

    # --- Expansive Path ---
    c6 = up_block(c5, c4, 512)
    c7 = up_block(c6, c3, 256)
    c8 = up_block(c7, c2, 128)
    c9 = up_block(c8, c1, 64)

    # --- Output Layer ---
    outputs = Conv2D(num_classes, (1, 1), activation='softmax')(c9)

    model = Model(inputs=[inputs], outputs=[outputs])
    print("✅ U-Net model built successfully.")
    print("🧪 Final model.input_shape =", model.input_shape)

    return model

