In [None]:
def channel_attention_1(input_feature, ratio=8, gate_activation='sigmoid'):
    """Re-weight the channels of a feature map with a learned per-channel gate.

    The feature map is squeezed with global average pooling, passed through a
    two-layer bottleneck MLP (Dense -> BatchNorm -> ReLU -> Dense), squashed by
    ``gate_activation`` and multiplied back onto ``input_feature``. Only the
    average-pooled descriptor is used; there is no max-pooled path.

    Args:
        input_feature: Keras tensor whose last axis is the channel axis.
            Assumes a 4-D channels-last map (GlobalAveragePooling2D is applied
            to it) -- TODO confirm against callers.
        ratio: Positive integer reduction factor of the bottleneck. The hidden
            width is ``channel // ratio`` but never less than 1.
        gate_activation: Activation applied to the per-channel gate. The
            default ``'sigmoid'`` bounds every weight to [0, 1]; pass
            ``'relu'`` to get the previous unbounded gate.

    Returns:
        Tensor with the same shape as ``input_feature``, scaled channel-wise.

    Raises:
        ValueError: If the channel dimension is not statically known or
            ``ratio`` is smaller than 1.
    """
    channel = input_feature.shape[-1]
    if channel is None:
        raise ValueError('channel_attention_1 needs a statically known channel dimension')
    if ratio < 1:
        raise ValueError('ratio must be a positive integer')

    # Guard against a zero-unit Dense layer when channel < ratio.
    hidden_units = max(1, channel // ratio)

    shared_layer_one = layers.Dense(hidden_units,
                                    kernel_initializer='he_normal',
                                    use_bias=True,
                                    bias_initializer='zeros')

    shared_layer_two = layers.Dense(channel,
                                    kernel_initializer='he_normal',
                                    use_bias=True,
                                    bias_initializer='zeros')

    # Squeeze: one descriptor per channel.
    x = layers.GlobalAveragePooling2D()(input_feature)

    # Excite: bottleneck MLP.
    x = shared_layer_one(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    x = shared_layer_two(x)
    # The gate must be bounded to act as an attention mask; a ReLU here would
    # let channels be amplified without limit.
    x = layers.Activation(gate_activation)(x)

    # Make the broadcast over the two spatial axes explicit instead of relying
    # on Multiply's implicit rank alignment.
    x = layers.Reshape((1, 1, channel))(x)

    return layers.Multiply()([input_feature, x])

In [1]:
def spatial_attention(input_feature):
    """Scale every spatial position of ``input_feature`` by a learned mask.

    The channel axis (last axis) is collapsed twice -- once by mean, once by
    max -- and the two single-channel maps are concatenated. A bias-free 7x7
    convolution with a sigmoid activation turns that pair into one
    single-channel mask, which is multiplied onto the input.

    Args:
        input_feature: Keras tensor whose last axis is the channel axis.

    Returns:
        The element-wise product of ``input_feature`` and the mask.
    """
    # Channel-wise statistics, each keeping a trailing axis of size 1.
    channel_mean = layers.Lambda(lambda t: tf.reduce_mean(t, axis=-1, keepdims=True))
    channel_max = layers.Lambda(lambda t: tf.reduce_max(t, axis=-1, keepdims=True))
    descriptors = [channel_mean(input_feature), channel_max(input_feature)]

    stacked = layers.Concatenate(axis=-1)(descriptors)

    # 'same' padding with stride 1 keeps the mask at the input's spatial size.
    mask_conv = layers.Conv2D(
        1,
        (7, 7),
        strides=1,
        padding='same',
        activation='sigmoid',
        kernel_initializer='he_normal',
        use_bias=False,
    )
    mask = mask_conv(stacked)

    return layers.Multiply()([input_feature, mask])

def cbam_block(input_feature, ratio=8):
    """Apply channel attention followed by spatial attention.

    Args:
        input_feature: Feature-map tensor handed to ``channel_attention_1``.
        ratio: Reduction ratio forwarded to ``channel_attention_1``.

    Returns:
        The output of ``spatial_attention`` applied to the channel-refined map.
    """
    # Order matters: channels are re-weighted first, then spatial positions.
    return spatial_attention(channel_attention_1(input_feature, ratio))

In [2]:
# ---------------------------------------------------------------------------
# Two-backbone model: Xception -> 40-way softmax, MobileNetV2 -> binary sigmoid.
# Both backbones are ImageNet-initialised and frozen.
# ---------------------------------------------------------------------------

# Backbone for the categorical output
base_model_cat = Xception(input_shape=input_shape, include_top=False, weights='imagenet')
base_model_cat.trainable = False

# Backbone for the binary output
base_model_bin = MobileNetV2(input_shape=input_shape, include_top=False, weights='imagenet')
base_model_bin.trainable = False

# A single input tensor feeds both backbones.
input_layer = tf.keras.Input(shape=input_shape)

# Categorical branch: backbone -> spatial attention -> pooled vector -> MLP head.
# training=False runs the frozen backbone in inference mode.
x_cat = base_model_cat(input_layer, training=False)
x_cat = spatial_attention(x_cat)
x_cat = layers.GlobalAveragePooling2D()(x_cat)
x_cat = layers.Dropout(0.4)(x_cat)
for head_units in (64, 32):
    # Dense -> ReLU -> Dropout, repeated for each hidden width.
    x_cat = layers.Dense(head_units, kernel_initializer='he_normal')(x_cat)
    x_cat = layers.Activation('relu')(x_cat)
    x_cat = layers.Dropout(0.4)(x_cat)
class_output = layers.Dense(40, activation='softmax', name='categorical_output')(x_cat)

# Binary branch: backbone -> pooled vector -> small MLP head.
x_bin = base_model_bin(input_layer, training=False)
x_bin = layers.GlobalAveragePooling2D()(x_bin)
x_bin = layers.Dropout(0.4)(x_bin)
x_bin = layers.Dense(16, activation='relu', kernel_initializer='he_normal')(x_bin)
x_bin = layers.Dropout(0.4)(x_bin)
binary_output = layers.Dense(1, activation='sigmoid', name='binary_output')(x_bin)

# Assemble the two-output model.
test_model = models.Model(inputs=input_layer, outputs=[class_output, binary_output])

# Losses and metrics are keyed by the output layer names defined above.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
test_model.compile(
    optimizer=optimizer,
    loss={
        'categorical_output': 'sparse_categorical_crossentropy',
        'binary_output': 'binary_crossentropy',
    },
    metrics={
        'categorical_output': 'accuracy',
        'binary_output': tf.keras.metrics.Precision(name='precision'),
    },
)

test_model.summary()

NameError: name 'Xception' is not defined

In [None]:
# ---------------------------------------------------------------------------
# Baseline two-backbone model without any attention block.
# NOTE: this cell rebinds `base_model_cat`, `base_model_bin`, `input_layer`,
# `test_model` and `optimizer`, which the previous cell also assigns.
# ---------------------------------------------------------------------------

# Frozen ImageNet Xception for the categorical output
base_model_cat = Xception(input_shape=input_shape, include_top=False, weights='imagenet')
base_model_cat.trainable = False

# Frozen ImageNet MobileNetV2 for the binary output
base_model_bin = MobileNetV2(input_shape=input_shape, include_top=False, weights='imagenet')
base_model_bin.trainable = False

# Shared input tensor
input_layer = tf.keras.Input(shape=input_shape)

# Categorical branch: pooled Xception features -> two hidden blocks -> softmax.
x_cat = base_model_cat(input_layer, training=False)
x_cat = layers.GlobalAveragePooling2D()(x_cat)
x_cat = layers.Dropout(0.4)(x_cat)
for head_units in (64, 32):
    # Each hidden block is Dense -> ReLU -> Dropout.
    x_cat = layers.Dense(head_units, kernel_initializer='he_normal')(x_cat)
    x_cat = layers.Activation('relu')(x_cat)
    x_cat = layers.Dropout(0.4)(x_cat)
class_output = layers.Dense(40, activation='softmax', name='categorical_output')(x_cat)

# Binary branch: pooled MobileNetV2 features -> one hidden layer -> sigmoid.
x_bin = base_model_bin(input_layer, training=False)
x_bin = layers.GlobalAveragePooling2D()(x_bin)
x_bin = layers.Dropout(0.4)(x_bin)
x_bin = layers.Dense(16, activation='relu', kernel_initializer='he_normal')(x_bin)
x_bin = layers.Dropout(0.4)(x_bin)
binary_output = layers.Dense(1, activation='sigmoid', name='binary_output')(x_bin)

# Two-output functional model
test_model = models.Model(inputs=input_layer, outputs=[class_output, binary_output])

# Per-output loss and metric, keyed by output layer name.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
test_model.compile(
    optimizer=optimizer,
    loss={
        'categorical_output': 'sparse_categorical_crossentropy',
        'binary_output': 'binary_crossentropy',
    },
    metrics={
        'categorical_output': 'accuracy',
        'binary_output': tf.keras.metrics.Precision(name='precision'),
    },
)

test_model.summary()

In [None]:
# Multi-task model: one frozen Xception backbone with channel attention feeding
# three heads (40-way categorical, binary, and a 10-way auxiliary output).
base_model = Xception(weights='imagenet', include_top=False, input_shape=input_shape)

# Freeze the backbone so only the attention block and the heads are trained.
base_model.trainable = False

input_layer = tf.keras.Input(shape=input_shape)

# training=False runs the frozen backbone in inference mode.
x = base_model(input_layer, training=False)

x = channel_attention_1(x)
x = layers.GlobalAveragePooling2D()(x)
# NOTE(review): both branches below start with their own Dropout(0.4), so the
# pooled features pass through two consecutive dropout layers -- confirm that
# this stacking is intended.
x = layers.Dropout(0.4)(x)

# Categorical head
class_branch = layers.Dropout(0.4)(x)
class_branch = layers.Dense(64, activation='relu', kernel_initializer='he_normal', kernel_regularizer=regularizers.l2(0.01))(class_branch)
class_branch = layers.Dropout(0.4)(class_branch)
class_branch = layers.Dense(32, activation='relu', kernel_initializer='he_normal', kernel_regularizer=regularizers.l2(0.01))(class_branch)
class_branch = layers.Dropout(0.4)(class_branch)
class_branch = layers.Dense(40, activation='softmax', name='categorical_output')(class_branch)

# Binary head; its 32-unit hidden activation (person_branch_1) is kept so the
# auxiliary head can branch off it.
person_branch = layers.Dropout(0.4)(x)
person_branch = layers.Dense(64, activation='relu', kernel_initializer='he_normal', kernel_regularizer=regularizers.l2(0.01))(person_branch)
person_branch = layers.Dropout(0.4)(person_branch)
person_branch_1 = layers.Dense(32, activation='relu', kernel_initializer='he_normal', kernel_regularizer=regularizers.l2(0.01))(person_branch)
person_branch = layers.Dropout(0.4)(person_branch_1)
person_branch = layers.Dense(1, activation='sigmoid', name='binary_output')(person_branch)

# Auxiliary head sharing the binary head's hidden representation
aux_branch = layers.Dropout(0.4)(person_branch_1)
aux_branch = layers.Dense(10, activation='softmax', name='aux_output')(aux_branch)

aug_model = models.Model(inputs=input_layer, outputs=[class_branch, person_branch, aux_branch])

optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

# Per-output weights applied to the individual losses when they are summed.
loss_weights = {
    'categorical_output': 1.0,  # default weight for multi-class output
    'binary_output': 2.0,       # higher weight for binary output to emphasize it
    'aux_output': 1.0           # weight for additional output (if needed)
}

# loss_weights must be handed to compile(); defining the dict alone has no
# effect and every output would be weighted 1.0.
aug_model.compile(optimizer=optimizer,
              loss={'categorical_output': 'sparse_categorical_crossentropy',
                    'binary_output': 'binary_crossentropy',
                    'aux_output': 'sparse_categorical_crossentropy'},
              loss_weights=loss_weights,
              metrics={'categorical_output': 'accuracy',
                       'binary_output': tf.keras.metrics.Precision(name='precision'),
                       'aux_output': 'accuracy'},)