<a href="https://colab.research.google.com/github/HazenDeveloper/Attn-CNN-Model/blob/main/Atten-CNN-v-00.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive


In [None]:
#attCNN.py
import tensorflow

from tensorflow.python.keras.layers import Input, Conv2D, Dense, Flatten, Activation, Dropout, MaxPool2D, Multiply, Add
from tensorflow.python.keras.models import Model
from keras.preprocessing.image import ImageDataGenerator


In [None]:
# Dataset split locations, each read with flow_from_directory
# (one sub-folder per class inside each split).
train_data_dir, test_data_dir, valid_data_dir = (
    './Dataset/train',
    './Dataset/test',
    './Dataset/val',
)

# Number of target classes the classifier predicts.
num_classes = 3

# Images are resized to 32x32 (height, width) with 3 channels.
input_shape = (32, 32, 3)

# Mini-batch size shared by the train/validation/test generators.
batch_size = 64


In [None]:
# Training data: rescale pixel values and apply light augmentation.
# Other ImageDataGenerator augmentations that can be switched on here include
# width_shift_range, height_shift_range, shear_range, zoom_range and
# horizontal_flip / vertical_flip.
train_datagen = ImageDataGenerator(
    rescale=1./255,      # Normalize 0-255 pixel values to [0, 1]
    rotation_range=15,   # Randomly rotate images by up to 15 degrees either way
)

# Validation and test data are only rescaled, never augmented.
valid_test_datagen = ImageDataGenerator(rescale=1./255)

# Load and augment the training data. Labels are one-hot vectors with one
# column per class sub-folder of train_data_dir.
train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=input_shape[:2],
    batch_size=batch_size,
    class_mode='categorical',   # Use categorical mode for multi-class classification
)

# Validation/test batches are served in a fixed order (shuffle=False) so that
# model predictions can be lined up with generator.classes / .filenames.
# Aggregate metrics from fit()/evaluate() do not depend on batch order.
valid_generator = valid_test_datagen.flow_from_directory(
    valid_data_dir,
    target_size=input_shape[:2],
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=False,
)

# Load the testing data
test_generator = valid_test_datagen.flow_from_directory(
    test_data_dir,
    target_size=input_shape[:2],
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=False,
)


In [None]:
# Attention Mechanism
def attention_block(input_tensor, input_channels):
    """Gate a feature map with an attention-style mask and add it back.

    Three independent 1x1 convolutions (``input_channels`` filters each,
    'same' padding) project ``input_tensor`` into query, key and value maps.
    The element-wise product query*key goes through a softmax over the last
    axis. The result multiplies the value map, and that gated map is added
    residually onto ``input_tensor``.

    NOTE(review): this is not scaled dot-product attention. The softmax
    normalises across the last axis (channels, assuming the default
    channels-last layout) independently at every spatial position.

    Args:
        input_tensor: feature map fed to Conv2D. Its last dimension must equal
            ``input_channels`` for the residual ``Add`` to succeed (assuming
            channels-last).
        input_channels: number of filters in each 1x1 projection.

    Returns:
        ``input_tensor`` plus the gated value map.
    """
    # Built in the same order as before: query, key, value.
    query, key, value = (
        Conv2D(input_channels, (1, 1), padding='same')(input_tensor)
        for _ in range(3)
    )

    scores = Multiply()([query, key])
    weights = Activation('softmax')(scores)
    gated = Multiply()([value, weights])

    return Add()([input_tensor, gated])
#%%
# CNN model; attention blocks are optional and disabled by default.
def attention_cnn(input_shape, num_classes, use_attention=False):
    """Build a small two-stage CNN classifier, optionally with attention.

    Each stage is Conv2D(3x3, relu, 'same') -> [attention_block] ->
    MaxPool2D -> Dropout(0.25). The two stages use 32 and 64 filters. The
    result is flattened and fed to a softmax Dense layer.

    Args:
        input_shape: shape of one input image without the batch axis,
            e.g. ``(32, 32, 3)``.
        num_classes: number of units in the softmax output layer.
        use_attention: if True, insert ``attention_block`` after each
            convolution. The default (False) builds the plain CNN with the
            same layer sequence as before this flag existed.

    Returns:
        An uncompiled Keras ``Model`` mapping images to class probabilities.
    """
    inputs = Input(shape=input_shape)

    x = inputs
    for filters in (32, 64):
        x = Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
        if use_attention:
            # Channel count matches the conv above so the residual Add inside
            # attention_block lines up.
            x = attention_block(x, filters)
        x = MaxPool2D()(x)
        x = Dropout(0.25)(x)

    x = Flatten()(x)
    outputs = Dense(num_classes, activation='softmax')(x)

    return Model(inputs=inputs, outputs=outputs)


In [None]:
# Size the output layer from the data rather than a hard-coded constant, and
# make sure every split maps class folders to the same label indices.
# flow_from_directory infers classes per directory, so a missing or renamed
# folder in val/test would otherwise silently shift the labels.
num_classes = train_generator.num_classes
for split_name, split_generator in (('validation', valid_generator),
                                    ('test', test_generator)):
    if split_generator.class_indices != train_generator.class_indices:
        raise ValueError(
            f'{split_name} classes {split_generator.class_indices} do not '
            f'match training classes {train_generator.class_indices}'
        )
#%%
model = attention_cnn(input_shape, num_classes)
#%%
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Train the model; keep the History object for plotting loss/accuracy curves.
history = model.fit(
    train_generator,
    epochs=10,
    validation_data=valid_generator
)

# Evaluate the model on the testing data
test_loss, test_accuracy = model.evaluate(test_generator)
print(f'Test loss: {test_loss}')
print(f'Test accuracy: {test_accuracy}')