# In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, Concatenate, BatchNormalization, Activation
from tensorflow.keras.models import Model


def build_vgg16_unet(
    input_shape=(256, 256, 6),
    num_classes=1,
    freeze_encoder=True,
    use_imagenet_weights=True
):
    """
    Build a U-Net segmentation model with a VGG16 encoder.

    If ``input_shape[-1] != 3``, a trainable 1x1 Conv2D "band adapter" maps
    C -> 3 channels before the VGG16 layers. This lets you use multi-band
    (e.g. 6-band) inputs while still using VGG16 (optionally pretrained).

    VGG16 is instantiated on its own 3-channel input, so ImageNet weights load
    cleanly. Its layers are then re-applied to the adapter output. They appear
    directly in the returned model, e.g. ``model.get_layer("block1_conv2")``.

    Args:
        input_shape: ``(height, width, channels)`` of one sample. Height and
            width may be ``None``. Otherwise each must be a multiple of 16,
            because four 2x max-pools sit between the first skip and the
            bridge. Each must also be at least 32, a VGG16 requirement.
        num_classes: Number of output channels. ``1`` gives a sigmoid mask.
            ``> 1`` gives a per-pixel softmax over ``num_classes`` channels.
        freeze_encoder: If True, the VGG16 layers are set non-trainable. The
            band adapter and the decoder always stay trainable.
        use_imagenet_weights: Load ImageNet weights into VGG16 if True,
            otherwise initialise it randomly.

    Returns:
        tf.keras.Model named "VGG16_UNet_6band" whose output layer is "mask".

    Raises:
        ValueError: If ``input_shape`` is not rank 3, if a spatial dimension is
            invalid (see above), or if ``num_classes < 1``.

    Note:
        No ``tf.keras.applications.vgg16.preprocess_input`` is applied here.
        With ImageNet weights and 3-band input, callers presumably need to
        preprocess themselves. TODO: confirm against the data pipeline.
    """
    if len(input_shape) != 3:
        raise ValueError(
            f"input_shape must be (height, width, channels), got {input_shape!r}"
        )
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes!r}")
    height, width, channels = input_shape
    # Each decoder stage upsamples x2 and concatenates with the matching
    # encoder skip. The 4 max-pools before block5_conv3 must therefore divide
    # the spatial size evenly, or the concatenations get mismatched shapes.
    for axis_name, size in (("height", height), ("width", width)):
        if size is not None and (size < 32 or size % 16 != 0):
            raise ValueError(
                f"input {axis_name} must be a multiple of 16 and >= 32, "
                f"got {size!r}"
            )

    inputs = Input(input_shape, name="input_6band")

    # Map C-band -> 3-band so the tensor matches VGG16's 3 input channels.
    if channels != 3:
        x = Conv2D(3, (1, 1), padding="same", name="band_adapter")(inputs)
    else:
        x = inputs

    # Build VGG16 on its own 3-channel Input. Passing `input_tensor=x` instead
    # makes Keras trace the VGG16 model back to `inputs`, which pulls
    # band_adapter into it. The ImageNet .h5 load would then see one extra
    # weighted layer, and the freeze loop would freeze the random adapter.
    weights = "imagenet" if use_imagenet_weights else None
    vgg = tf.keras.applications.VGG16(
        include_top=False,
        weights=weights,
        input_shape=(height, width, 3)
    )

    if freeze_encoder:
        # vgg.layers now holds only VGG16's own layers, not the adapter.
        for layer in vgg.layers:
            layer.trainable = False

    # VGG16 is a plain chain (InputLayer -> block1_conv1 -> ... -> block5_pool).
    # Re-applying its layers in order reproduces it on `x` while sharing the
    # (possibly pretrained / frozen) weights.
    features = {}
    y = x
    for layer in vgg.layers:
        if isinstance(layer, tf.keras.layers.InputLayer):
            continue
        y = layer(y)
        features[layer.name] = y
        if layer.name == "block5_conv3":
            break  # Bridge reached; block5_pool is not used by the U-Net.

    # Skip connections (spatial sizes shown for a 256x256 input)
    s1 = features["block1_conv2"]  # 256x256x64
    s2 = features["block2_conv2"]  # 128x128x128
    s3 = features["block3_conv3"]  # 64x64x256
    s4 = features["block4_conv3"]  # 32x32x512

    # Bridge: block5_conv3 (16x16x512 for a 256x256 input), taken before
    # block5_pool so that one x2 upsample matches the block4 skip.
    b = features["block5_conv3"]

    def dec_block(x, skip, filters, name):
        # Upsample x2, fuse with the encoder skip, then two conv-BN-ReLU stages.
        x = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding="same", name=f"{name}_up")(x)
        x = Concatenate(name=f"{name}_concat")([x, skip])
        x = Conv2D(filters, (3, 3), padding="same", name=f"{name}_conv1")(x)
        x = BatchNormalization(name=f"{name}_bn1")(x)
        x = Activation("relu", name=f"{name}_relu1")(x)
        x = Conv2D(filters, (3, 3), padding="same", name=f"{name}_conv2")(x)
        x = BatchNormalization(name=f"{name}_bn2")(x)
        x = Activation("relu", name=f"{name}_relu2")(x)
        return x

    # Decoder (reverse order of the encoder skips)
    d1 = dec_block(b,  s4, 512, "dec4")   # 32x32
    d2 = dec_block(d1, s3, 256, "dec3")   # 64x64
    d3 = dec_block(d2, s2, 128, "dec2")   # 128x128
    d4 = dec_block(d3, s1, 64,  "dec1")   # 256x256

    # Name this layer so Grad-CAM can target it reliably
    cam_feat = Conv2D(64, (3, 3), padding="same", name="decoder_cam_conv")(d4)
    cam_feat = Activation("relu", name="decoder_cam_relu")(cam_feat)

    if num_classes == 1:
        outputs = Conv2D(1, (1, 1), activation="sigmoid", name="mask")(cam_feat)
    else:
        outputs = Conv2D(num_classes, (1, 1), activation="softmax", name="mask")(cam_feat)

    model = Model(inputs=inputs, outputs=outputs, name="VGG16_UNet_6band")
    return model
