In [1]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import matplotlib.pyplot as plt

In [2]:
# Custom Self-Attention Layer
class SelfAttention(layers.Layer):
    def __init__(self, channels, **kwargs):
        super(SelfAttention, self).__init__(**kwargs)
        self.channels = channels
        self.attention_dense = layers.Dense(channels, activation='sigmoid')

    def call(self, inputs):
        attention_weights = self.attention_dense(inputs)
        return inputs * attention_weights

In [3]:
class F1Score(tf.keras.metrics.Metric):
    def __init__(self, name='f1_score', **kwargs):
        super(F1Score, self).__init__(name=name, **kwargs)
        self.precision = tf.keras.metrics.Precision()
        self.recall = tf.keras.metrics.Recall()

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Update the precision and recall for each batch
        self.precision.update_state(y_true, y_pred, sample_weight)
        self.recall.update_state(y_true, y_pred, sample_weight)

    def result(self):
        # Calculate F1 score as the harmonic mean of precision and recall
        precision = self.precision.result()
        recall = self.recall.result()
        return 2 * (precision * recall) / (precision + recall + tf.keras.backend.epsilon())

    def reset_states(self):
        self.precision.reset_states()
        self.recall.reset_states()

In [4]:
# Data preparation (same as in your code)
train_datagen = ImageDataGenerator(rescale=1./255)
training_set1 = train_datagen.flow_from_directory(
    "Monkeypox/archive (60)/Augmented Images/Augmented Images/FOLDS_AUG/fold4_AUG/Train/",
    target_size=(64, 64),
    batch_size=32,
    class_mode='categorical',  # Use 'sparse' if your labels are integers (not one-hot encoded)
)

Found 7336 images belonging to 6 classes.


In [6]:
test_datagen = ImageDataGenerator(rescale=1./255)
test_set1 = test_datagen.flow_from_directory(
    "Monkeypox/archive (60)/Augmented Images/Augmented Images/FOLDS_AUG/fold4_AUG/Train/",
    target_size=(64, 64),
    batch_size=32,
    class_mode='categorical',
)

Found 7336 images belonging to 6 classes.


In [7]:
# CNN model with Self-Attention and Bidirectional GRU
cnn = Sequential()

In [8]:
cnn.add(layers.Conv2D(filters=32, padding="same", kernel_size=3, activation='relu', strides=2, input_shape=[64, 64, 3]))
cnn.add(SelfAttention(channels=32))
cnn.add(layers.MaxPool2D(pool_size=2, strides=2))

cnn.add(layers.Conv2D(filters=32, padding='same', kernel_size=3, activation='relu'))
cnn.add(layers.MaxPool2D(pool_size=2, strides=2))

# Ensure the correct reshaping before GRU
cnn.add(layers.Reshape((-1, 32)))  # Reshape for GRU (batch_size, time_steps, features)

# Bidirectional GRU layer
cnn.add(layers.Bidirectional(layers.GRU(64, return_sequences=True)))

cnn.add(layers.Flatten())
cnn.add(layers.Dense(units=128, activation='relu'))
cnn.add(layers.Dense(6, activation='softmax'))

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)





In [9]:
# Compile the model with Precision, Recall, and custom F1 score
cnn.compile(
    optimizer='adam', 
    loss='categorical_crossentropy', 
    metrics=['accuracy', 'Precision', 'Recall', F1Score()]
)

In [10]:
# Model summary
cnn.summary()

In [11]:
# Train the model and capture the history
history = cnn.fit(
    training_set1,
    validation_data=test_set1,
    epochs=100
)

  self._warn_if_super_not_called()


Epoch 1/100
[1m230/230[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m193s[0m 717ms/step - Precision: 0.4440 - Recall: 0.0590 - accuracy: 0.3815 - f1_score: 0.1032 - loss: 1.5984 - val_Precision: 0.5781 - val_Recall: 0.3291 - val_accuracy: 0.4911 - val_f1_score: 0.4194 - val_loss: 1.3639
Epoch 2/100
[1m230/230[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m172s[0m 747ms/step - Precision: 0.6667 - Recall: 0.2995 - accuracy: 0.5207 - f1_score: 0.4127 - loss: 1.2747 - val_Precision: 0.6634 - val_Recall: 0.4740 - val_accuracy: 0.5748 - val_f1_score: 0.5529 - val_loss: 1.1149
Epoch 3/100
[1m230/230[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m160s[0m 693ms/step - Precision: 0.7320 - Recall: 0.4543 - accuracy: 0.6140 - f1_score: 0.5604 - loss: 1.0345 - val_Precision: 0.7644 - val_Recall: 0.6352 - val_accuracy: 0.6927 - val_f1_score: 0.6939 - val_loss: 0.8440
Epoch 4/100
[1m230/230[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m156s[0m 680ms/step - Precision: 0.8180 - Recall: 0.6254

In [13]:
# Save Loss plot
plt.figure()
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.legend()
plt.title("Loss")
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.savefig('loss_plot_bigru.png')  # Save the plot as a PNG file
plt.close()

# Save Accuracy plot
plt.figure()
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.legend()
plt.title("Accuracy")
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.savefig('accuracy_plot_bigru.png')  # Save the plot as a PNG file
plt.close()

# Save Precision plot (Correct key for precision)
plt.figure()
plt.plot(history.history['Precision'], label='Train Precision')  # Corrected for precision
plt.plot(history.history['val_Precision'], label='Validation Precision')  # Corrected for validation precision
plt.legend()
plt.title("Precision")
plt.xlabel('Epochs')
plt.ylabel('Precision')
plt.savefig('precision_plot_bigru.png')  # Save the plot as a PNG file
plt.close()

# Save Recall plot (Correct key for recall)
plt.figure()
plt.plot(history.history['Recall'], label='Train Recall')  # Corrected for recall
plt.plot(history.history['val_Recall'], label='Validation Recall')  # Corrected for validation recall
plt.legend()
plt.title("Recall")
plt.xlabel('Epochs')
plt.ylabel('Recall')
plt.savefig('recall_plot_bigru.png')  # Save the plot as a PNG file
plt.close()

# Save F1 Score plot (Correct key for F1 score)
plt.figure()
plt.plot(history.history['f1_score'], label='Train F1 Score')  # Corrected for F1 score
plt.plot(history.history['val_f1_score'], label='Validation F1 Score')  # Corrected for validation F1 score
plt.legend()
plt.title("F1 Score")
plt.xlabel('Epochs')
plt.ylabel('F1 Score')
plt.savefig('f1_score_plot_bigru.png')  # Save the plot as a PNG file
plt.close()
