In [None]:
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import os
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
from tensorflow.keras.preprocessing import image_dataset_from_directory
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
DATASET_PATH = 'path_to_dataset'  # Update this with your dataset path
BATCH_SIZE = 32
# Side length (pixels) of the square images fed to the model. Kept as a single
# int so the loader and the model defaults share one definition instead of
# rebinding the same name from a tuple to an int.
IMG_SIZE = 224

# NOTE(review): NUM_CLASSES is hard-coded and is not checked against the
# classes found under DATASET_PATH - confirm it matches the dataset.
NUM_CLASSES = 8
EPOCHS = 60

# Model hyperparameters (these values match the ViT-Base/16 configuration).
PATCH_SIZE = 16
D_MODEL = 768
NUM_HEADS = 12
NUM_LAYERS = 12
MLP_DIM = 3072
DROPOUT_RATE = 0.1

# Optimizer hyperparameters.
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 1e-4

# ---------------------------------------------------------------------------
# Datasets
# ---------------------------------------------------------------------------
# Both subsets are carved out of the same directory; they use an identical
# split fraction and seed, differing only in `subset`.
_LOADER_KWARGS = dict(
    validation_split=0.2,
    seed=123,
    image_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
)

train_dataset = image_dataset_from_directory(
    DATASET_PATH, subset="training", **_LOADER_KWARGS
)

test_dataset = image_dataset_from_directory(
    DATASET_PATH, subset="validation", **_LOADER_KWARGS
)
class PatchEmbedding(layers.Layer):
    """Cut an image batch into square patches and embed each patch.

    Each non-overlapping ``patch_size x patch_size`` patch is flattened,
    passed through a shared ``Dense(d_model)`` projection and summed with a
    trainable positional encoding.

    Args:
        img_size: Side length of the (square) input images, in pixels.
        patch_size: Side length of each square patch, in pixels.
        d_model: Width of the embedding produced for every patch.
    """

    def __init__(self, img_size=IMG_SIZE, patch_size=PATCH_SIZE, d_model=D_MODEL):
        super().__init__()
        # Remember the patch size so `call` honours the constructor argument
        # rather than the module-level PATCH_SIZE constant.
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.projection = layers.Dense(d_model)
        # `name` is passed by keyword: the position of `name` in
        # `add_weight` differs between Keras versions, the keyword does not.
        # The initializer is left at the Keras default, as before.
        self.positional_encoding = self.add_weight(
            name="pos_enc", shape=(1, self.num_patches, d_model)
        )

    def call(self, x):
        """Return patch embeddings for a batch of images.

        Args:
            x: Image batch; indexed as (batch, height, width, channels) by
                ``tf.image.extract_patches``.

        Returns:
            Tensor reshaped to (batch, num_patches, d_model) after projection
            and addition of the positional encoding.
        """
        size = self.patch_size
        window = [1, size, size, 1]
        # Stride equals the window, so patches tile the image without overlap.
        patches = tf.image.extract_patches(
            images=x,
            sizes=window,
            strides=window,
            rates=[1, 1, 1, 1],
            padding='VALID',
        )
        # Use the static channel count when it is known; fall back to the
        # previous 3-channel (RGB) assumption otherwise.
        channels = x.shape[-1] if x.shape[-1] is not None else 3
        patches = tf.reshape(
            patches, [-1, self.num_patches, size * size * channels]
        )
        # The (1, num_patches, d_model) encoding broadcasts over the batch.
        return self.projection(patches) + self.positional_encoding
class TransformerBlock(layers.Layer):
    """Encoder block: two stacked attention passes followed by an MLP.

    Layer normalization is applied *after* each residual sum. Note that,
    despite its name, ``cross_attn`` is invoked with the first attention
    output as both query and value, so it acts as a second self-attention.

    Args:
        d_model: Width of the token representations.
        num_heads: Number of attention heads; each head gets
            ``d_model // num_heads`` key dimensions.
        mlp_dim: Hidden width of the feed-forward sub-network.
        dropout_rate: Dropout rate used after attention and inside the MLP.
    """

    def __init__(self, d_model, num_heads, mlp_dim, dropout_rate):
        super().__init__()
        per_head_dim = d_model // num_heads

        # Sub-layers are created in a fixed order so that automatic layer
        # naming and weight ordering stay stable.
        self.norm1 = layers.LayerNormalization()
        self.attn = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=per_head_dim
        )
        self.cross_attn = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=per_head_dim
        )
        self.dropout1 = layers.Dropout(dropout_rate)
        self.norm2 = layers.LayerNormalization()

        feed_forward = [
            layers.Dense(mlp_dim, activation='gelu'),
            layers.Dropout(dropout_rate),
            layers.Dense(d_model),
        ]
        self.mlp = keras.Sequential(feed_forward)

    def call(self, x):
        """Apply attention and MLP sub-blocks, each with a residual and norm."""
        first_pass = self.attn(x, x)
        second_pass = self.cross_attn(first_pass, first_pass)

        # Residual around the attention stack, then normalize.
        attended = self.norm1(x + self.dropout1(second_pass))

        # Residual around the MLP, then normalize.
        return self.norm2(attended + self.mlp(attended))
class VisionTransformer(keras.Model):
    """Vision Transformer classifier built from the layers defined above.

    The forward pass embeds image patches, prepends a trainable class token,
    runs the sequence through ``num_layers`` ``TransformerBlock`` layers,
    layer-normalizes the class-token output and maps it to class
    probabilities.

    The head applies ``softmax``, so the model returns probabilities rather
    than logits; pair it with a loss configured for probabilities.

    NOTE(review): inputs are used as given - no pixel rescaling or
    normalization happens inside the model. Confirm the input pipeline
    provides values in the intended range.

    Args:
        img_size: Side length of the square input images, in pixels.
        patch_size: Side length of each square patch, in pixels.
        d_model: Width of the token representations.
        num_layers: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_dim: Hidden width of each block's MLP.
        num_classes: Number of output classes.
        dropout_rate: Dropout rate used inside each block.
    """

    def __init__(self, img_size=IMG_SIZE, patch_size=PATCH_SIZE, d_model=D_MODEL, num_layers=NUM_LAYERS, num_heads=NUM_HEADS, mlp_dim=MLP_DIM, num_classes=NUM_CLASSES, dropout_rate=DROPOUT_RATE):
        super().__init__()
        self.patch_embedding = PatchEmbedding(img_size, patch_size, d_model)
        self.transformer_blocks = [
            TransformerBlock(d_model, num_heads, mlp_dim, dropout_rate)
            for _ in range(num_layers)
        ]
        # `name` is passed by keyword: the position of `name` in
        # `add_weight` differs between Keras versions, the keyword does not.
        self.class_token = self.add_weight(
            name="class_token", shape=(1, 1, d_model)
        )
        self.norm = layers.LayerNormalization()
        self.head = layers.Dense(num_classes, activation='softmax')

    def call(self, x):
        """Return per-class probabilities for a batch of images."""
        batch_size = tf.shape(x)[0]
        x = self.patch_embedding(x)

        # Replicate the single learned class token for every example and put
        # it at sequence position 0. It is concatenated after the patch
        # embedding, so it does not receive a positional encoding.
        class_token = tf.tile(self.class_token, [batch_size, 1, 1])
        x = tf.concat([class_token, x], axis=1)

        for block in self.transformer_blocks:
            x = block(x)

        # Only the class-token position is used for classification.
        x = self.norm(x[:, 0])
        return self.head(x)
# Instantiate the model with the module-level hyperparameters as defaults.
vit_model = VisionTransformer()

# AdamW optimizer from TensorFlow Addons.
adamw = tfa.optimizers.AdamW(
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

# Integer class labels + softmax outputs -> sparse categorical cross-entropy.
vit_model.compile(
    optimizer=adamw,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Train on the training subset only; no validation data is passed to fit().
vit_model.fit(train_dataset, epochs=EPOCHS)

# Report loss and accuracy on the held-out subset.
loss, accuracy = vit_model.evaluate(test_dataset)
print(f"Test Accuracy: {accuracy * 100:.2f}%")
# Class names reported by the directory loader, and the integer label ids
# 0..len-1 that index into them.
class_names = train_dataset.class_names
label_ids = list(range(len(class_names)))

# Gather labels and predictions in the same pass over the dataset so each
# prediction stays paired with its own label, whatever order batches arrive in.
y_true = []
y_pred = []
for images, labels in test_dataset:
    # predict_on_batch runs a single forward pass; calling Model.predict()
    # once per batch inside a loop adds per-call overhead and prints a
    # progress bar for every batch.
    batch_probs = vit_model.predict_on_batch(images)
    y_true.extend(labels.numpy())
    y_pred.extend(np.argmax(batch_probs, axis=1))

# Passing `labels` explicitly keeps the report rows and matrix axes aligned
# with `class_names` even if a class is absent from this split.
print("Classification Report:")
print(
    classification_report(
        y_true, y_pred, labels=label_ids, target_names=class_names
    )
)

conf_matrix = confusion_matrix(y_true, y_pred, labels=label_ids)
plt.figure(figsize=(8, 6))
sns.heatmap(
    conf_matrix,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
)
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Confusion Matrix")
plt.show()