In [10]:
import tensorflow as tf
from tensorflow.keras import layers
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical

In [11]:
# Define the self-attention layer
class SelfAttention(layers.Layer):
    def __init__(self):
        super(SelfAttention, self).__init__()
        self.gamma = tf.Variable(initial_value=tf.zeros((1,), dtype='float32'), trainable=True)

    def build(self, input_shape):
        _, height, width, channels = input_shape
        self.query_conv = layers.Conv2D(channels // 8, kernel_size=1, use_bias=False)
        self.key_conv = layers.Conv2D(channels // 8, kernel_size=1, use_bias=False)
        self.value_conv = layers.Conv2D(channels, kernel_size=1, use_bias=False)

    def call(self, inputs):
        # Compute the query, key, and value matrices
        query = self.query_conv(inputs)
        key = self.key_conv(inputs)
        value = self.value_conv(inputs)

        # Compute the dot-product attention
        attention_logits = tf.matmul(query, key, transpose_b=True)
        attention_weights = tf.nn.softmax(attention_logits)
        attention_output = tf.matmul(attention_weights, value)

        # Apply the gamma scaling factor
        attention_output = self.gamma * attention_output

        # Add the attention output to the original input
        output = inputs + attention_output
        return output




In [12]:
# Define the CNN with one self-attention layer
def cnn_sa():
    input_shape = (32, 32, 3)
    num_classes = 10

    model = tf.keras.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', padding='same', input_shape=input_shape),
        layers.BatchNormalization(),
        layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        SelfAttention(),
        layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        SelfAttention(),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation='softmax')
    ])

    return model

In [13]:
# loading cifar 10 dataset
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = x_train / 255.0
x_test = x_test / 255.0

# Convert the labels to one-hot encoding
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

model = cnn_sa()

# Compile and training the model
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=64, epochs=10, validation_data=(x_test, y_test))

# Evaluate the model on the test set
loss, accuracy = model.evaluate(x_test, y_test)
print(f'Test loss: {loss:.4f}, Test accuracy: {accuracy:.4f}')

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Test loss: 0.6655, Test accuracy: 0.7888
