<a href="https://colab.research.google.com/github/amimulhasan/Machine_learning_project/blob/main/vit_GRU_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.model_selection import KFold
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# -----------------------
# Hyperparameters
# -----------------------
# --- Data geometry ---
# Image tensor shape: two spatial axes of 128 pixels and 3 channels.
input_shape = (128, 128, 3)
# Number of target classes; sizes the final softmax layer.
num_classes = 4

# --- Patch tokenisation ---
# Side length, in pixels, of each square patch cut from the image.
patch_size = 16
# Patches per image: whole patches along each spatial axis, multiplied.
num_patches = (
    (input_shape[0] // patch_size)
    * (input_shape[1] // patch_size)
)

# --- Transformer encoder ---
# Width of each patch embedding (also used as the attention key dimension).
projection_dim = 64
# Number of stacked attention blocks.
transformer_layers = 4
# Attention heads per block.
num_heads = 4

# -----------------------
# Patches layer
# -----------------------
class Patches(layers.Layer):
    """Cut a batch of images into non-overlapping square patches.

    The patch size is used both as the window size and as the stride, so
    patches tile the image without overlap. Each patch is flattened, giving
    an output of shape ``(batch, number_of_patches, patch_dims)``.

    Args:
        patch_size: Side length, in pixels, of each square patch.
        **kwargs: Standard ``Layer`` arguments (e.g. ``name``, ``dtype``),
            forwarded to the base class so the layer can be rebuilt from its
            saved config.
    """

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        """Return the flattened patches extracted from ``images``."""
        # The batch size is read dynamically because it is not known when the
        # graph is built.
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # The last axis holds the flattened content of one patch; collapse the
        # patch-grid axes into a single sequence axis.
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer."""
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config

# -----------------------
# Patch Encoder layer
# -----------------------
class PatchEncoder(layers.Layer):
    """Project flattened patches linearly and add a learned position embedding.

    Args:
        num_patches: Number of patches per image; sizes the position
            embedding table.
        projection_dim: Output width of both the linear projection and the
            position embedding.
        **kwargs: Standard ``Layer`` arguments (e.g. ``name``, ``dtype``),
            forwarded to the base class so the layer can be rebuilt from its
            saved config.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        # Kept only so get_config() can report it.
        self.projection_dim = projection_dim
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patches):
        """Return ``projection(patches)`` plus the per-position embedding."""
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        # The (num_patches, projection_dim) embedding broadcasts over the
        # batch axis of the projected patches.
        encoded = self.projection(patches) + self.position_embedding(positions)
        return encoded

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer."""
        config = super().get_config()
        config.update(
            {
                "num_patches": self.num_patches,
                "projection_dim": self.projection_dim,
            }
        )
        return config

# -----------------------
# Build the combined model
# -----------------------
def build_hybrid_model():
    """Build the two-branch CNN + ViT/GRU image classifier.

    One branch is a three-stage Conv2D/MaxPooling stack whose output is
    flattened. The other branch splits the image into patches, encodes them,
    passes them through ``transformer_layers`` self-attention blocks and
    summarises the resulting patch sequence with a GRU. The two feature
    vectors are concatenated and fed to a dense softmax head.

    The architecture is configured by the module-level hyperparameters
    (``input_shape``, ``patch_size``, ``num_patches``, ``projection_dim``,
    ``transformer_layers``, ``num_heads``, ``num_classes``).

    Returns:
        An uncompiled ``keras.Model`` mapping images of shape ``input_shape``
        to ``num_classes`` softmax probabilities.
    """
    inputs = layers.Input(shape=input_shape)

    # ---------------------
    # DCNN Branch
    # ---------------------
    # NOTE(review): this branch consumes `inputs` directly while the ViT
    # branch rescales by 1/255 -- confirm the pixel range expected here.
    x_cnn = inputs
    for filters in (32, 64, 128):
        x_cnn = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x_cnn)
        x_cnn = layers.MaxPooling2D((2, 2))(x_cnn)
    x_cnn = layers.Flatten()(x_cnn)

    # ---------------------
    # ViT + GRU Branch
    # ---------------------
    x_vit = layers.Rescaling(1./255)(inputs)
    patches = Patches(patch_size)(x_vit)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    for _ in range(transformer_layers):
        # Pre-norm self-attention with a residual connection.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.2
        )(x1, x1)
        x2 = layers.Add()([attention_output, encoded_patches])
        # NOTE(review): there is no feed-forward (MLP) sub-layer between this
        # normalisation and the residual add -- confirm that is intentional.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        encoded_patches = layers.Add()([x3, x2])

    x_vit = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    x_vit = layers.Dropout(0.2)(x_vit)
    # Feed the (batch, num_patches, projection_dim) token sequence straight to
    # the GRU. Flattening and reshaping to (batch, 1, features) first would
    # leave the GRU a single timestep, so it would never recur over patches.
    x_vit = layers.GRU(256)(x_vit)

    # ---------------------
    # Concatenate DCNN and ViT+GRU
    # ---------------------
    x = layers.concatenate([x_cnn, x_vit])
    x = layers.Dense(256, activation='relu')(x)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    # Build model
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model

# -----------------------
# 10-Fold Cross Validation
# -----------------------
# Load the full dataset; each fold is sliced from these arrays.
X = np.load("/kaggle/input/brain-tumor-dataset/X.npy")  # Your data
y = np.load("/kaggle/input/brain-tumor-dataset/y.npy")  # Your labels

kfold = KFold(n_splits=10, shuffle=True, random_state=42)  # Now using 10-fold
# Fixed label order so every confusion matrix is num_classes x num_classes,
# even when a validation fold happens to miss a class.
class_labels = np.arange(num_classes)
all_metrics = []       # per-fold classification_report dicts
accuracy_history = []  # per-fold list of training accuracy per epoch
loss_history = []      # per-fold list of training loss per epoch

for fold, (train_idx, val_idx) in enumerate(kfold.split(X, y), start=1):
    X_train_fold, X_val_fold = X[train_idx], X[val_idx]
    y_train_fold, y_val_fold = y[train_idx], y[val_idx]

    # Drop graph state left over from the previous fold so memory does not
    # grow with every freshly built model.
    keras.backend.clear_session()

    # Build the model (fresh weights for every fold)
    model = build_hybrid_model()

    # Compile the model
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    # Train the model
    history = model.fit(X_train_fold, y_train_fold, epochs=5, batch_size=32, validation_data=(X_val_fold, y_val_fold))

    # Save training history for plotting
    accuracy_history.append(history.history['accuracy'])
    loss_history.append(history.history['loss'])

    # Predict on the validation set
    y_pred_fold = model.predict(X_val_fold)
    y_pred_fold = np.argmax(y_pred_fold, axis=1)  # Convert softmax output to class label

    # Calculate metrics
    metrics = classification_report(y_val_fold, y_pred_fold, output_dict=True)
    all_metrics.append(metrics)

    # Compute confusion matrix over the full label set so its shape always
    # matches the tick labels used below.
    cm = confusion_matrix(y_val_fold, y_pred_fold, labels=class_labels)

    # Plot confusion matrix
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_labels, yticklabels=class_labels)
    plt.title(f"Confusion Matrix for Fold {fold}")
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.show()

    # Print fold results
    print(f"Fold {fold} Classification Report:\n", classification_report(y_val_fold, y_pred_fold))

# Compute average performance across folds
avg_accuracy = np.mean([metrics['accuracy'] for metrics in all_metrics])
avg_f1 = np.mean([metrics['weighted avg']['f1-score'] for metrics in all_metrics])

print(f"Average Accuracy across all folds: {avg_accuracy:.4f}")
print(f"Average F1-score across all folds: {avg_f1:.4f}")

# -----------------------
# Plot Accuracy and Loss Curves
# -----------------------

def _plot_fold_curves(histories, metric_name):
    """Plot one per-epoch training curve per fold for a single metric.

    Args:
        histories: Sequence with one entry per fold, each a sequence of
            per-epoch values.
        metric_name: Label used for the y-axis and in the figure title.
    """
    plt.figure(figsize=(8, 6))
    # Label folds by position; looking a curve up by value would give two
    # folds with identical histories the same label.
    for fold_number, curve in enumerate(histories, start=1):
        plt.plot(curve, label=f'Fold {fold_number}')
    # Take the fold count from the data so the title cannot drift from the
    # cross-validation setup.
    plt.title(f'{metric_name} Curve for {len(histories)}-Fold Cross-Validation')
    plt.xlabel('Epochs')
    plt.ylabel(metric_name)
    plt.legend()
    plt.show()


# Accuracy Curve
_plot_fold_curves(accuracy_history, 'Accuracy')

# Loss Curve
_plot_fold_curves(loss_history, 'Loss')
