In [4]:
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator



# Define constants
IMAGE_SIZE = (128, 128)   # (height, width) every image is resized to
BATCH_SIZE = 32
NUM_CLASSES = 2
EPOCHS = 30

# Dataset split directories; each is expected to hold one sub-folder per class.
TRAIN_PATH = r'D:\Mushroom dataset\Project database\cnn\resized\train'
TEST_PATH = r'D:\Mushroom dataset\Project database\cnn\resized\test'
VALIDATION_PATH = r'D:\Mushroom dataset\Project database\cnn\resized\validation'


def _flow_split(generator, directory):
    """Return a batched directory iterator over *directory* using *generator*.

    Every split shares the same resize target, batch size and 'binary'
    label mode, so the only things that vary are the generator and folder.
    """
    return generator.flow_from_directory(
        directory,
        target_size=IMAGE_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='binary',
    )


# Load and preprocess data: one independent generator per split, each of
# which only multiplies pixel values by 1/255.
train_data_gen, test_data_gen, validation_data_gen = (
    ImageDataGenerator(rescale=1. / 255) for _ in range(3)
)

# Build the iterators in train -> test -> validation order (this is the order
# of the "Found N images" messages printed below the cell).
train_data = _flow_split(train_data_gen, TRAIN_PATH)
test_data = _flow_split(test_data_gen, TEST_PATH)
validation_data = _flow_split(validation_data_gen, VALIDATION_PATH)

# Define Vision Transformer model
def create_vit_model(input_shape, num_classes, *, embed_dim=64, num_heads=2,
                     mlp_dim=128, dropout=0.1, num_blocks=3, num_patch_convs=3):
    """Build a small conv-stem + Transformer-encoder image classifier.

    Args:
        input_shape: Shape of one input image, e.g. ``(128, 128, 3)``.
        num_classes: Number of output classes of the softmax head.
        embed_dim: Channel width of the conv stem and of every TransformerBlock.
            These must agree because TransformerBlock adds its MLP output
            (``embed_dim`` wide) back onto its input.
        num_heads: Attention heads per TransformerBlock.
        mlp_dim: Hidden width of each TransformerBlock's MLP.
        dropout: Dropout rate inside each TransformerBlock.
        num_blocks: Number of stacked TransformerBlocks.
        num_patch_convs: Number of 3x3 / stride-3 convs in the patch stem.

    Returns:
        An uncompiled ``Model`` mapping images to per-class softmax scores.
        The defaults reproduce the original 3-conv / 3-block / 64-wide model.
    """
    inputs = layers.Input(shape=input_shape)

    # Patch embedding: repeated 3x3 / stride-3 convs with the default 'valid'
    # padding. For a 128x128 input and three convs the spatial grid shrinks
    # 128 -> 42 -> 14 -> 4, i.e. a 4x4 grid of embed_dim-wide tokens.
    x = inputs
    for _ in range(num_patch_convs):
        x = layers.Conv2D(embed_dim, kernel_size=3, strides=3, activation='relu')(x)

    # Transformer encoder. The blocks receive the 4D (batch, H, W, C) feature
    # map directly; Keras MultiHeadAttention with attention_axes unset attends
    # jointly over all non-batch, non-feature axes (the H x W grid).
    # NOTE(review): no positional embedding is added, so the encoder plus the
    # global average pool cannot tell where on the grid a token came from —
    # consider a learned position embedding if spatial layout matters.
    for _ in range(num_blocks):
        x = TransformerBlock(embed_dim=embed_dim, num_heads=num_heads,
                             mlp_dim=mlp_dim, dropout=dropout)(x)

    # Classification head.
    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    return Model(inputs, outputs)

# Define TransformerBlock layer
class TransformerBlock(layers.Layer):
    """Post-norm Transformer encoder block: self-attention and an MLP, each
    wrapped in dropout + residual add + LayerNormalization.

    Args:
        embed_dim: Output width of the MLP. Because the MLP output is added
            back onto the block input, this must equal the last dimension of
            the tensor the block is called on. Also used as the per-head
            ``key_dim`` of the attention layer.
        num_heads: Number of attention heads.
        mlp_dim: Hidden width of the two-layer MLP.
        dropout: Dropout rate applied to the attention output and MLP output.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        # Keep the constructor arguments so get_config() can rebuild the layer
        # when a saved model is reloaded.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.dropout_rate = dropout

        # key_dim is the size of EACH head, so the total projection width is
        # num_heads * embed_dim (the ViT paper uses embed_dim // num_heads).
        self.attention = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.mlp1 = layers.Dense(mlp_dim, activation='relu')
        self.mlp2 = layers.Dense(embed_dim)
        self.layer_norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layer_norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)

    def call(self, inputs, training=None):
        """Apply attention + MLP sub-blocks to *inputs*.

        ``training`` defaults to ``None`` so Keras resolves it from the
        enclosing call (True inside ``fit``, False inside ``evaluate`` /
        ``predict``). A default of ``True`` would keep dropout active whenever
        the model is called directly, e.g. ``model(x)``, without the flag.
        """
        # Self-attention sub-block (query = value = inputs).
        attn_output = self.attention(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layer_norm1(inputs + attn_output)

        # MLP sub-block.
        mlp_output = self.mlp1(out1)
        mlp_output = self.mlp2(mlp_output)
        mlp_output = self.dropout2(mlp_output, training=training)
        out2 = self.layer_norm2(out1 + mlp_output)

        return out2

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized.

        Reloading a saved model still requires passing
        ``custom_objects={'TransformerBlock': TransformerBlock}``.
        """
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "mlp_dim": self.mlp_dim,
            "dropout": self.dropout_rate,
        })
        return config

# Create and compile the model.
# Derive the input shape from IMAGE_SIZE so it cannot drift from the
# generators' target_size; the trailing 3 is the RGB channel count produced by
# flow_from_directory's default color_mode.
model = create_vit_model(input_shape=IMAGE_SIZE + (3,), num_classes=NUM_CLASSES)
# class_mode='binary' yields 0/1 labels, which sparse_categorical_crossentropy
# treats as class indices for the NUM_CLASSES-way softmax head.
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train the model, monitoring the held-out validation split each epoch.
history = model.fit(
    train_data,
    epochs=EPOCHS,
    validation_data=validation_data
)

# Evaluate the model on the test split (returns [loss, accuracy] in the
# order of the compiled loss and metrics).
loss, accuracy = model.evaluate(test_data)
print("Test Loss:", loss)
print("Test Accuracy:", accuracy)


Found 2208 images belonging to 2 classes.
Found 440 images belonging to 2 classes.
Found 298 images belonging to 2 classes.
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30
Test Loss: 0.1594657450914383
Test Accuracy: 0.949999988079071
