# In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model, Input

# -------------------- Metrics --------------------

def dice_coef(y_true, y_pred, smooth: float = 1e-6):
    """Soft Dice coefficient between ``y_true`` and ``y_pred``.

    Both tensors are cast to float32 and flattened, so the score is computed
    over every element of the batch at once rather than averaged per sample.
    Predictions are used as-is (no thresholding).

    Args:
        y_true: Ground-truth mask tensor.
        y_pred: Predicted mask tensor, same number of elements as ``y_true``.
        smooth: Added to numerator and denominator; keeps the ratio defined
            (and equal to 1) when both tensors sum to zero.

    Returns:
        Scalar tensor ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    flat_true = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    flat_pred = tf.reshape(tf.cast(y_pred, tf.float32), [-1])
    overlap = tf.reduce_sum(flat_true * flat_pred)
    denominator = tf.reduce_sum(flat_true) + tf.reduce_sum(flat_pred) + smooth
    return (2.0 * overlap + smooth) / denominator


def iou_coef(y_true, y_pred, smooth: float = 1e-6):
    """Soft intersection-over-union (Jaccard index) of two tensors.

    Both inputs are cast to float32 and flattened, so the score covers the
    whole batch at once. Predictions are not thresholded.

    Args:
        y_true: Ground-truth mask tensor.
        y_pred: Predicted mask tensor, same number of elements as ``y_true``.
        smooth: Added to numerator and denominator so the result is defined
            (and equal to 1) when both tensors sum to zero.

    Returns:
        Scalar tensor ``(I + smooth) / (sum(t) + sum(p) - I + smooth)`` where
        ``I = sum(t * p)``.
    """
    flat_true = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    flat_pred = tf.reshape(tf.cast(y_pred, tf.float32), [-1])
    overlap = tf.reduce_sum(flat_true * flat_pred)
    total = tf.reduce_sum(flat_true) + tf.reduce_sum(flat_pred)
    union = total - overlap
    return (overlap + smooth) / (union + smooth)


# -------------------- Building blocks --------------------

def conv_bn_relu(x, filters: int, k: int = 3, name: str | None = None):
    """Apply a ``k``x``k`` 'same'-padded Conv2D, BatchNormalization and ReLU.

    Args:
        x: Input feature tensor.
        filters: Number of output channels of the convolution.
        k: Square kernel size.
        name: Optional prefix; when given, layers are named
            ``<name>_conv``, ``<name>_bn`` and ``<name>_relu``.

    Returns:
        The activated feature tensor.
    """
    def _tag(suffix):
        return None if name is None else name + suffix

    # The conv has no bias: the BatchNormalization right after it supplies
    # the per-channel offset.
    conv = layers.Conv2D(filters, k, padding="same", use_bias=False, name=_tag("_conv"))
    norm = layers.BatchNormalization(name=_tag("_bn"))
    act = layers.ReLU(name=_tag("_relu"))
    return act(norm(conv(x)))


def MSFA(x, out_filters: int, name: str | None = None):
    """Multi-Scale Feature Aggregation (1x1,3x3,5x5,7x7 -> concat).
    The sum of branch channels equals out_filters.
    """
    b = max(out_filters // 4, 1)
    p1 = conv_bn_relu(x, b, k=1, name=None if name is None else name+"_1x1")
    p3 = conv_bn_relu(x, b, k=3, name=None if name is None else name+"_3x3")
    p5 = conv_bn_relu(x, b, k=5, name=None if name is None else name+"_5x5")
    p7 = conv_bn_relu(x, b, k=7, name=None if name is None else name+"_7x7")
    x = layers.Concatenate(name=None if name is None else name+"_concat")([p1, p3, p5, p7])
    return x


def AttentionBlock(x, name: str | None = None):
    """Channel attention gate.

    Each channel of ``x`` is summarised by both global average pooling and
    global max pooling. The two descriptors are reshaped to 1x1 maps,
    concatenated, mapped back to one value per channel by a 1x1 convolution
    and squashed with a sigmoid. ``x`` is then multiplied by these weights.

    Args:
        x: Feature tensor whose last axis is the channel axis.
        name: Optional prefix for the created layer names.

    Returns:
        ``x`` rescaled channel-wise by the sigmoid gate.
    """
    def _tag(suffix):
        return None if name is None else name + suffix

    channels = x.shape[-1]
    as_map = (1, 1, channels)

    avg_desc = layers.GlobalAveragePooling2D(name=_tag("_gap"))(x)
    max_desc = layers.GlobalMaxPooling2D(name=_tag("_gmp"))(x)
    avg_desc = layers.Reshape(as_map, name=_tag("_gap_r"))(avg_desc)
    max_desc = layers.Reshape(as_map, name=_tag("_gmp_r"))(max_desc)

    gate = layers.Concatenate(axis=-1, name=_tag("_cat"))([avg_desc, max_desc])
    gate = layers.Conv2D(channels, 1, padding="same", name=_tag("_conv1x1"))(gate)
    gate = layers.Activation("sigmoid", name=_tag("_sigmoid"))(gate)
    return layers.Multiply(name=_tag("_scale"))([x, gate])


def DecoderBlock(x, skip, filters: int, name: str | None = None):
    """One decoder stage: upsample, attend, fuse the skip, refine.

    1. 2x2 transposed conv with stride 2 (doubles height and width), BN, ReLU.
    2. Channel attention (``AttentionBlock``) on the upsampled features.
    3. If ``skip`` is not None, pass it through ``MSFA`` with ``filters``
       output channels and concatenate it after the upsampled features.
    4. Two 3x3 conv-BN-ReLU layers, summed with a 1x1-conv projection of the
       step-3 tensor (residual connection).

    Args:
        x: Lower-resolution decoder input.
        skip: Encoder feature map at twice the resolution of ``x``, or None.
        filters: Channel width used throughout the stage.
        name: Optional prefix for the created layer names.

    Returns:
        The refined feature tensor at twice the spatial size of ``x``.
    """
    def _tag(suffix):
        return None if name is None else name + suffix

    up = layers.Conv2DTranspose(filters, 2, strides=2, padding="same", use_bias=False,
                                name=_tag("_up"))(x)
    up = layers.BatchNormalization(name=_tag("_up_bn"))(up)
    up = layers.ReLU(name=_tag("_up_relu"))(up)

    fused = AttentionBlock(up, name=_tag("_att"))

    if skip is not None:
        skip_feats = MSFA(skip, filters, name=_tag("_msfa"))
        fused = layers.Concatenate(name=_tag("_concat"))([fused, skip_feats])

    refined = conv_bn_relu(fused, filters, 3, name=_tag("_conv1"))
    refined = conv_bn_relu(refined, filters, 3, name=_tag("_conv2"))
    # 1x1 projection so the shortcut matches the refined branch's channels.
    shortcut = layers.Conv2D(filters, 1, padding="same", name=_tag("_proj"))(fused)
    return layers.Add(name=_tag("_add"))([refined, shortcut])


# -------------------- Model builder --------------------

def build_msfa_attention_unet(
    input_shape: tuple[int, int, int] = (256, 256, 3),
    num_classes: int = 1,
    backbone_trainable: bool = False,
) -> Model:
    """Build a U-Net-style segmentation model on an ImageNet MobileNetV2 encoder.

    Four intermediate MobileNetV2 activations serve as skip connections, and
    ``block_16_project`` is the bottleneck. The decoder has three parts:
    - an MSFA block on the bottleneck;
    - four ``DecoderBlock`` stages, each doubling the spatial size;
    - a final transposed conv that returns to the input resolution before a
      1x1 classification head.

    Args:
        input_shape: ``(height, width, channels)``. Height and width must be
            multiples of 32 (or None) so every decoder stage lines up with
            its skip tensor and the output matches the input resolution.
        num_classes: Number of output channels; 1 uses a sigmoid head,
            anything else a softmax head.
        backbone_trainable: Value assigned to ``backbone.trainable``.

    Returns:
        An uncompiled Keras ``Model``.

    Raises:
        ValueError: If a known height/width is not a multiple of 32.
    """
    # The encoder halves the resolution five times (taps at H/2 ... H/32) and
    # the decoder doubles it five times. The skip concatenations only line up
    # and the output only matches the input size when H and W divide by 32.
    for dim_name, dim in zip(("height", "width"), input_shape[:2]):
        if dim is not None and dim % 32 != 0:
            raise ValueError(
                f"input_shape {dim_name} must be a multiple of 32, got {dim} "
                f"(input_shape={input_shape})."
            )

    inputs = Input(shape=input_shape)

    # NOTE(review): no preprocessing layer is applied here. The ImageNet
    # MobileNetV2 weights expect inputs scaled by
    # tf.keras.applications.mobilenet_v2.preprocess_input (to [-1, 1]);
    # confirm the data pipeline does this.
    backbone = tf.keras.applications.MobileNetV2(
        include_top=False, input_tensor=inputs, weights="imagenet"
    )
    backbone.trainable = backbone_trainable

    # Encoder feature taps; sizes relative to an HxW input (256x256 shown).
    s1 = backbone.get_layer("block_1_expand_relu").output   # H/2  (128x128)
    s2 = backbone.get_layer("block_3_expand_relu").output   # H/4  (64x64)
    s3 = backbone.get_layer("block_6_expand_relu").output   # H/8  (32x32)
    s4 = backbone.get_layer("block_13_expand_relu").output  # H/16 (16x16)
    b  = backbone.get_layer("block_16_project").output      # H/32 (8x8) bottleneck

    # Bottleneck MSFA
    x = MSFA(b, 512, name="bottleneck_msfa")

    # Decoder (4 steps): each doubles H/W and fuses the matching skip tap.
    x = DecoderBlock(x, s4, 256, name="dec4")  # H/32 -> H/16 (8 -> 16)
    x = DecoderBlock(x, s3, 128, name="dec3")  # H/16 -> H/8  (16 -> 32)
    x = DecoderBlock(x, s2, 64,  name="dec2")  # H/8  -> H/4  (32 -> 64)
    x = DecoderBlock(x, s1, 32,  name="dec1")  # H/4  -> H/2  (64 -> 128)

    # Final upsample back to the input resolution (H/2 -> H).
    x = layers.Conv2DTranspose(32, 2, strides=2, padding="same", name="final_up")(x)

    # Output head: per-pixel sigmoid for a single class, softmax otherwise.
    activation = "sigmoid" if num_classes == 1 else "softmax"
    outputs = layers.Conv2D(num_classes, 1, activation=activation, name="mask")(x)

    return Model(inputs, outputs, name="MSFA_Attention_MobileNetV2")


# -------------------- Example usage --------------------
if __name__ == "__main__":
    # Demo: build the default binary-segmentation model, compile it with
    # Adam + binary cross-entropy and print its layer summary.
    demo_metrics = [dice_coef, iou_coef, "accuracy"]
    model = build_msfa_attention_unet(
        input_shape=(256, 256, 3),
        num_classes=1,
        backbone_trainable=False,
    )
    optimizer = tf.keras.optimizers.Adam(1e-3)
    model.compile(
        optimizer=optimizer,
        loss="binary_crossentropy",
        metrics=demo_metrics,
    )
    model.summary()
