In [12]:
from tensorflow.keras.applications import VGG19
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input,
    Dense,
    Flatten,
    Multiply,
    UpSampling2D,
    Activation,
    Conv2D,
)
import tensorflow as tf

# --- Hyper-parameters shared by every stage of the pipeline -----------------
IMAGE_SIZE = 224  # square input resolution fed to both ImageNet backbones
INPUT_SHAPE = (IMAGE_SIZE, IMAGE_SIZE, 3)
NUM_CLASSES = 36

# Attention backbone: ImageNet-pretrained DenseNet121 without its classifier
# head. This is a classification network, not a segmentation model; the
# "segmentation mask" built below is a learned soft attention map derived
# from its final spatial feature grid.
# NOTE(review): this backbone is left trainable while VGG19 below is frozen —
# confirm that fine-tuning it is intended (otherwise set
# `segmentation_model.trainable = False`).
segmentation_model = tf.keras.applications.DenseNet121(
    include_top=False,
    weights="imagenet",
    input_shape=INPUT_SHAPE,
    pooling=None,  # No global pooling, keep the spatial feature grid
)

# NOTE(review): no `preprocess_input` is applied in this cell, and the two
# backbones (DenseNet121, VGG19) use different ImageNet preprocessing modes;
# presumably images are normalized upstream — confirm.
input_layer = Input(shape=INPUT_SHAPE)

# Spatial feature map from the attention backbone.
segmentation_output = segmentation_model(input_layer)

# Project the backbone features to a 3-channel mask in (0, 1) with a learned
# 1x1 convolution (3 channels so it can gate the RGB input element-wise).
# No sigmoid is applied to the raw features first: the Keras DenseNet ends in
# a ReLU, so a sigmoid there would squash every feature into [0.5, 1) and
# saturate large activations before the projection. The conv's own sigmoid is
# the only one needed to produce the mask.
segmentation_mask = Conv2D(3, (1, 1), activation="sigmoid")(segmentation_output)

# Upsample the coarse mask back to the input resolution. The factor must be an
# exact integer; otherwise the upsampled mask would not match the input and
# the Multiply below would fail with a far less obvious shape error.
mask_height, mask_width = segmentation_mask.shape[1], segmentation_mask.shape[2]
if IMAGE_SIZE % mask_height or IMAGE_SIZE % mask_width:
    raise ValueError(
        f"Mask grid {mask_height}x{mask_width} does not evenly divide "
        f"the {IMAGE_SIZE}x{IMAGE_SIZE} input."
    )
segmentation_mask = UpSampling2D(
    size=(IMAGE_SIZE // mask_height, IMAGE_SIZE // mask_width)
)(segmentation_mask)

# Gate the input image with the mask (element-wise, per channel).
segmented_input = Multiply()([input_layer, segmentation_mask])

# Classification network: ImageNet VGG19 convolutional base, frozen, topped by
# a VGG-style fully-connected head that is trained from scratch.
base_model = VGG19(weights="imagenet", include_top=False, input_shape=INPUT_SHAPE)
# Setting `trainable` on the model freezes all of its layers recursively.
# It must be set before the model is compiled to affect training.
base_model.trainable = False

x = base_model(segmented_input)
x = Flatten(name="flatten")(x)
x = Dense(4096, activation="relu", name="fc1")(x)
x = Dense(4096, activation="relu", name="fc2")(x)
x = Dense(NUM_CLASSES, activation="softmax", name="predictions")(x)

# Final model: raw image in, class probabilities out.
model = Model(inputs=input_layer, outputs=x)

model.summary()
