<a href="https://colab.research.google.com/github/ZaryabRahman/Stability-and-Expression-The-dual-mechanism-of-normalization-in-deep-learning/blob/main/NIH_chestxray.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MobileNet with a spatial attention mechanism


In [None]:
from keras.applications.mobilenet import MobileNet
from keras.layers import Input, Conv2D, BatchNormalization, multiply, GlobalAveragePooling2D, Dropout, Dense
from keras.models import Model
import tensorflow as tf

def create_attention_and_classifier(input_features, num_labels):
    """Build the spatial-attention head followed by the final classifier.

    The incoming feature maps are batch-normalised, a single-channel sigmoid
    attention map is derived from them through a stack of 1x1 convolutions,
    and that map is multiplied back onto the normalised features. The masked
    features are globally average-pooled and fed through a dropout/dense
    classifier with a sigmoid output.

    Args:
      input_features: Feature-map tensor produced by the base model.
      num_labels: Number of units in the final sigmoid output layer.

    Returns:
      A ``(classification_output, attention_map)`` tuple of tensors; the
      layers producing them are named ``'classification_output'`` and
      ``'attention_map'``.
    """
    normalized = BatchNormalization()(input_features)

    # Two ELU 1x1 convolutions (64 then 32 filters) precede the gate itself.
    gate = normalized
    for n_filters in (64, 32):
        gate = Conv2D(
            n_filters,
            kernel_size=(1, 1),
            padding='same',
            activation='elu',
        )(gate)

    # One sigmoid channel per spatial location acts as the attention mask.
    attention_map = Conv2D(
        1,
        kernel_size=(1, 1),
        padding='valid',
        activation='sigmoid',
        name='attention_map',
    )(gate)

    # Apply the mask to the normalised features, then pool over space.
    masked = multiply([attention_map, normalized])
    pooled = GlobalAveragePooling2D()(masked)

    # Classifier: dropout -> 512-unit ELU dense -> dropout -> sigmoid output.
    hidden = Dropout(0.5)(pooled)
    hidden = Dense(512, activation='elu')(hidden)
    hidden = Dropout(0.5)(hidden)
    classification_output = Dense(
        num_labels,
        activation='sigmoid',
        name='classification_output',
    )(hidden)

    return classification_output, attention_map

# Input shape is taken from a sample batch, dropping its leading batch axis.
input_tensor = Input(shape=t_x.shape[1:])

# MobileNet backbone without its classification top; weights=None means no
# pretrained weights are loaded.
base_mobilenet_model = MobileNet(
    input_tensor=input_tensor,
    include_top=False,
    weights=None,
)
base_model_output = base_mobilenet_model.output

# Attach the attention/classifier head to the backbone features.
final_output, attention_map_output = create_attention_and_classifier(
    base_model_output,
    len(all_labels),
)

# Two outputs: the class predictions and the attention map.
attention_mobilenet_model = Model(
    inputs=input_tensor,
    outputs=[final_output, attention_map_output],
)

# Only the classification output gets a loss; the attention map is given a
# loss weight of 0.0.
attention_mobilenet_model.compile(
    optimizer='adam',
    loss={'classification_output': 'binary_crossentropy'},
    loss_weights={'classification_output': 1.0, 'attention_map': 0.0},
    metrics={'classification_output': ['binary_accuracy', 'mae']},
)

attention_mobilenet_model.summary()

# Callback Definition

In [None]:
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras import optimizers

# File that receives the best weights seen during training.
weight_path = "{}_mobilenet_attention_weights.best.hdf5".format('xray_class')

# Write weights only, and only when the monitored validation loss improves.
checkpoint = ModelCheckpoint(
    filepath=weight_path,
    monitor='val_classification_output_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=True,
)

# Stop once the same validation loss has not improved for 5 epochs.
early = EarlyStopping(
    monitor="val_classification_output_loss",
    mode="min",
    patience=5,
)

callbacks_list = [early, checkpoint]

# Model Training

In [None]:
import matplotlib.pyplot as plt
from keras import optimizers

# Candidate optimizers as (legend label, optimizer instance) pairs.
optimizers_list = [
    ('adam', optimizers.Adam())
]

plt.figure(figsize=(20, 5))

# NOTE(review): the same model instance is recompiled and trained further on
# every iteration, so weights carry over between entries. With more than one
# optimizer in the list the curves are not an independent comparison.
for optimizer_name, optimizer_instance in optimizers_list:
    attention_mobilenet_model.compile(
        optimizer=optimizer_instance,
        loss={'classification_output': 'binary_crossentropy'},
        loss_weights={'classification_output': 1.0, 'attention_map': 0.0},
        metrics={'classification_output': ['binary_accuracy', 'mae']}
    )

    # NOTE(review): fit_generator is deprecated in recent TensorFlow/Keras
    # releases in favour of fit(); confirm against the installed version.
    history = attention_mobilenet_model.fit_generator(train_gen,
                                  steps_per_epoch=1000,
                                  validation_data=(test_X, test_Y),
                                  epochs=50,
                                  callbacks=callbacks_list)

    # One validation-loss curve per optimizer, labelled for the legend.
    plt.plot(history.history['val_classification_output_loss'],
             label=optimizer_name)

plt.legend(loc='upper right')
plt.title('model validation loss')
plt.ylabel('loss')
plt.xlabel('epoch')
# Save before show(): with inline/non-interactive backends the figure is
# released once shown, so saving afterwards writes an empty image.
plt.savefig('optimizer_selection_attention.png', bbox_inches='tight')
plt.show()

# Attention Mechanism Visualization

In [None]:
import numpy as np
import cv2
import matplotlib.pyplot as plt

# Load the weights stored at weight_path before running inference.
attention_mobilenet_model.load_weights(weight_path)

sample_image = test_X[0]
# The model expects a batch, so add a leading batch axis of size 1.
image_for_prediction = np.expand_dims(sample_image, axis=0)

# The model has two outputs: class predictions and the attention map.
predictions, attention_map = attention_mobilenet_model.predict(image_for_prediction)
attention_heatmap = np.squeeze(attention_map)

# cv2.resize takes the target size as (width, height).
resized_heatmap = cv2.resize(attention_heatmap, (sample_image.shape[1], sample_image.shape[0]))

# imshow accepts (H, W), (H, W, 3) or (H, W, 4) arrays only. Drop a trailing
# single channel so an (H, W, 1) grayscale input can be displayed; arrays of
# any other shape pass through untouched.
if sample_image.ndim == 3 and sample_image.shape[-1] == 1:
    display_image = sample_image[..., 0]
else:
    display_image = sample_image

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))

# cmap only affects 2-D (single-channel) data; it is ignored for RGB(A).
ax1.imshow(display_image, cmap='gray')
ax1.set_title('Original Image')
ax1.axis('off')

ax2.imshow(resized_heatmap, cmap='jet')
ax2.set_title('Attention Heatmap')
ax2.axis('off')

# Overlay: the image underneath, the semi-transparent heatmap on top.
ax3.imshow(display_image, cmap='gray')
ax3.imshow(resized_heatmap, cmap='jet', alpha=0.5)
ax3.set_title('Image with Attention Overlay')
ax3.axis('off')

plt.show()