# In [None]:
from tensorflow.keras import layers, models 
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Conv2DTranspose, Dropout, concatenate, Input
import tensorflow as tf
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.metrics import MeanIoU
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.model_selection import train_test_split

from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, concatenate, UpSampling2D
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2DTranspose, Cropping2D


def test_models_sanity():
    """Print a fixed marker line; useful to confirm this module's code ran."""
    marker = "‚úÖ from models.ipynb"
    print(marker)


def _unet_conv_block(x, filters, dropout_rate):
    """Apply Conv2D -> Dropout -> Conv2D, each convolution with ``filters`` channels.

    Both convolutions are 3x3, ReLU-activated, He-normal initialised and use
    'same' padding.
    """
    x = Conv2D(filters, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(x)
    x = Dropout(dropout_rate)(x)
    return Conv2D(filters, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(x)


def _match_spatial_size(upsampled, skip, up_name, skip_name):
    """Crop or zero-pad ``upsampled`` so axes 1 and 2 have the same size as in ``skip``.

    Sizes are read from the static ``.shape`` of the symbolic tensors, so they
    are plain Python ints that can drive ordinary ``if`` statements and be
    passed to ``Cropping2D`` / ``ZeroPadding2D``. (Using ``tf.shape`` here
    yields symbolic tensors, which cannot be used in a Python ``if``.)

    Axes 1 and 2 are treated as height and width, i.e. channels-last layout is
    assumed -- NOTE(review): confirm if channels-first is ever used.

    Args:
        upsampled: Decoder tensor produced by the upsampling step.
        skip: Encoder tensor it will be concatenated with.
        up_name: Label for ``upsampled`` used in the informational message.
        skip_name: Label for ``skip`` used in the informational message.

    Returns:
        ``upsampled``, cropped and/or zero-padded as needed. Returned
        unchanged when the sizes already match or any size is unknown (None).
    """
    up_h, up_w = upsampled.shape[1], upsampled.shape[2]
    skip_h, skip_w = skip.shape[1], skip.shape[2]
    if any(dim is None for dim in (up_h, up_w, skip_h, skip_w)):
        # Dynamic spatial dims cannot be compared at build time.
        return upsampled

    diff_h = up_h - skip_h
    diff_w = up_w - skip_w

    # Upsampled map is larger: crop it (as evenly as possible on both sides).
    crop_h = max(0, diff_h)
    crop_w = max(0, diff_w)
    if crop_h > 0 or crop_w > 0:
        cropping_amount = ((crop_h // 2, crop_h - crop_h // 2),
                           (crop_w // 2, crop_w - crop_w // 2))
        upsampled = Cropping2D(cropping=cropping_amount)(upsampled)
        print(f"‚ÑπÔ∏è Cropping {up_name} by {cropping_amount} to match {skip_name} shape for concatenation.")

    # Upsampled map is smaller: this is what happens when a 2x2 max-pool
    # floors an odd size (e.g. 37 -> 18 -> 36). Zero-pad back up to the skip
    # size so the concatenation is valid and the resolution is not reduced.
    pad_h = max(0, -diff_h)
    pad_w = max(0, -diff_w)
    if pad_h > 0 or pad_w > 0:
        padding_amount = ((pad_h // 2, pad_h - pad_h // 2),
                          (pad_w // 2, pad_w - pad_w // 2))
        upsampled = layers.ZeroPadding2D(padding=padding_amount)(upsampled)

    return upsampled


def _unet_up_block(x, skip, filters, dropout_rate, up_name, skip_name):
    """Upsample ``x`` 2x, align it with ``skip``, concatenate, then run a conv block."""
    x = UpSampling2D((2, 2))(x)
    x = Conv2D(filters, (2, 2), activation='relu', kernel_initializer='he_normal', padding='same')(x)
    x = _match_spatial_size(x, skip, up_name, skip_name)
    x = concatenate([x, skip])
    return _unet_conv_block(x, filters, dropout_rate)


def multi_unet_model(input_shape, num_classes):
    """Build a U-Net style Keras model with a per-pixel softmax output.

    The encoder has four conv blocks (64, 128, 256, 512 filters), each followed
    by 2x2 max pooling, and a 1024-filter bottleneck. The decoder mirrors it
    with UpSampling2D + 2x2 Conv2D, concatenating the matching encoder output
    at every level. Before each concatenation the upsampled tensor is cropped
    or zero-padded to the encoder tensor's static height/width, so spatial
    sizes that are not divisible by 16 still produce a valid graph.

    Args:
        input_shape: Shape passed to ``Input`` (without the batch dimension).
        num_classes: Number of filters of the final 1x1 softmax convolution.

    Returns:
        An uncompiled ``Model`` mapping ``inputs`` to the softmax output.
    """
    inputs = Input(input_shape)

    # Downsampling path
    c1 = _unet_conv_block(inputs, 64, 0.1)
    p1 = MaxPooling2D((2, 2))(c1)

    c2 = _unet_conv_block(p1, 128, 0.1)
    p2 = MaxPooling2D((2, 2))(c2)

    c3 = _unet_conv_block(p2, 256, 0.2)
    p3 = MaxPooling2D((2, 2))(c3)

    c4 = _unet_conv_block(p3, 512, 0.2)
    p4 = MaxPooling2D((2, 2))(c4)

    # Bottleneck
    c5 = _unet_conv_block(p4, 1024, 0.3)

    # Upsampling path; the name strings only label the informational message.
    c6 = _unet_up_block(c5, c4, 512, 0.2, "u6", "c4")
    c7 = _unet_up_block(c6, c3, 256, 0.2, "u7", "c3")
    c8 = _unet_up_block(c7, c2, 128, 0.1, "u8", "c2")
    c9 = _unet_up_block(c8, c1, 64, 0.1, "u9", "c1")

    outputs = Conv2D(num_classes, (1, 1), activation='softmax')(c9)

    model = Model(inputs=[inputs], outputs=[outputs])
    return model

def multi_unet_model_old(input_shape=(512, 512, 4), num_classes=6):
  """Build the earlier, lighter U-Net variant (16..256 filters, Conv2DTranspose upsampling).

  Four encoder stages (16, 32, 64, 128 filters) each end in 2x2 max pooling,
  followed by a 256-filter bottleneck. The decoder upsamples with strided
  transposed convolutions and concatenates the matching encoder output
  directly, with no cropping or padding, so the spatial sizes have to line up
  on their own (e.g. height/width that are multiples of 16).

  Args:
    input_shape: Shape passed to ``Input`` (without the batch dimension).
    num_classes: Number of filters of the final 1x1 softmax convolution.

  Returns:
    An uncompiled ``Model``.
  """
  conv_kwargs = dict(activation="relu", kernel_initializer="he_normal", padding="same")

  def double_conv(tensor, filters):
    # Conv -> Dropout(0.2) -> Conv, all with the same filter count.
    tensor = Conv2D(filters, (3,3), **conv_kwargs)(tensor)
    tensor = Dropout(0.2)(tensor)
    return Conv2D(filters, (3,3), **conv_kwargs)(tensor)

  inputs = Input(shape=input_shape)

  # Encoder: remember each stage's output for the skip connections.
  skips = []
  x = inputs
  for filters in (16, 32, 64, 128):
    x = double_conv(x, filters)
    skips.append(x)
    x = MaxPooling2D((2,2))(x)

  # Bottleneck
  x = double_conv(x, 256)

  # Decoder: skips are consumed deepest-first. The last stage concatenates on
  # axis 3 explicitly; the others use the default last axis (-1).
  for filters, concat_axis in ((128, -1), (64, -1), (32, -1), (16, 3)):
    x = Conv2DTranspose(filters, (2,2), strides=(2,2), padding="same")(x)
    x = concatenate([x, skips.pop()], axis=concat_axis)
    x = double_conv(x, filters)

  outputs = Conv2D(num_classes, (1,1), activation="softmax")(x)

  return Model(inputs=[inputs], outputs=[outputs])
     

def build_unet(input_shape=(512, 512, 4), num_classes=6):
    """Build a U-Net (64..1024 filters, UpSampling2D decoder) and print progress info.

    Each encoder stage is two 3x3 ReLU convolutions followed by 2x2 max
    pooling; the decoder upsamples by 2, concatenates the matching encoder
    output and applies two more convolutions. Skip tensors are concatenated
    without cropping or padding, so spatial sizes must line up on their own
    (e.g. height/width that are multiples of 16).

    Args:
        input_shape: Shape passed to ``layers.Input`` (without the batch dimension).
        num_classes: Number of filters of the final 1x1 softmax convolution.

    Returns:
        An uncompiled ``models.Model``.
    """
    print("üß™ build_unet called with input_shape =", input_shape)
    inputs = layers.Input(shape=input_shape)
    print("üß™ Input layer constructed with shape:", inputs.shape)

    def conv_pair(tensor, width):
        # Two consecutive 3x3 'same' ReLU convolutions with `width` filters.
        for _ in range(2):
            tensor = layers.Conv2D(width, (3,3), activation='relu', padding='same')(tensor)
        return tensor

    # --- Contracting Path ---
    skips = []
    x = inputs
    for width in (64, 128, 256, 512):
        x = conv_pair(x, width)
        skips.append(x)
        x = layers.MaxPooling2D((2,2))(x)

    # --- Bottleneck ---
    x = conv_pair(x, 1024)

    # --- Expansive Path --- (skip tensors are consumed deepest-first)
    for width in (512, 256, 128, 64):
        x = layers.UpSampling2D((2,2))(x)
        x = layers.concatenate([x, skips.pop()])
        x = conv_pair(x, width)

    # --- Output Layer ---
    outputs = layers.Conv2D(num_classes, (1,1), activation='softmax')(x)

    model = models.Model(inputs=[inputs], outputs=[outputs])
    print("‚úÖ U-Net model built successfully.")
    print("üß™ Final model.input_shape =", model.input_shape)

    return model

def build_segformer(input_shape=(512, 512, 4), num_classes=6):
    """Placeholder for a SegFormer builder; currently always raises.

    Args:
        input_shape: Unused for now.
        num_classes: Unused for now.

    Raises:
        NotImplementedError: Always, until a real implementation is added.
    """
    # TODO: replace with actual SegFormer loading (Hugging Face transformers).
    message = "SegFormer model building not yet implemented."
    raise NotImplementedError(message)
