<a href="https://colab.research.google.com/github/TharinsaMudalige/Neuron-Brain_Tumor_Detection_Classification_with_XAI/blob/Detection-Classification-VIT/Tumor_Classification_Hybrid_Model_(ViT_and_CNN).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Importing Libraries

In [None]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from tensorflow.keras import layers, models

Load the preprocessed dataset

In [None]:
# Define the dataset directory
base_dir = "/content/drive/MyDrive/Colab Notebooks/Preprocessed_Dataset_classification"

EDA

In [None]:
# Exploratory Data Analysis (EDA)
def perform_eda(images, labels, class_names):
    """Perform EDA on the dataset."""
    import pandas as pd

    # 1. Dataset Overview
    print("\n--- Dataset Overview ---")
    print(f"Total Images: {len(images)}")
    print(f"Total Classes: {len(class_names)}")
    print(f"Class Distribution: {np.bincount(labels)}")
    print("\nClass Names with Distribution:")
    class_distribution = pd.DataFrame({"Class": class_names, "Count": np.bincount(labels)})
    print(class_distribution)

    # 2. Visualize Class Distribution
    plt.figure(figsize=(10, 6))
    sns.barplot(x=class_distribution['Class'], y=class_distribution['Count'], palette="viridis")
    plt.title('Class Distribution')
    plt.xlabel('Class')
    plt.ylabel('Number of Samples')
    plt.xticks(rotation=45)
    plt.show()

    # 3. Display Sample Images from Each Class
    print("\n--- Displaying Sample Images from Each Class ---")
    plt.figure(figsize=(24, 24))
    for i, class_name in enumerate(class_names):
        idx = labels.tolist().index(i)  # Get the first image index for this class
        plt.subplot(1, len(class_names), i + 1)
        plt.imshow(images[idx].squeeze(), cmap='gray')
        plt.title(class_name)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

    # 4. Image Shape and Value Distribution
    print("\n--- Image Shape and Value Distribution ---")
    print(f"Image Shape: {images[0].shape}")
    print(f"Pixel Intensity Range: Min = {np.min(images)}, Max = {np.max(images)}")
    plt.figure(figsize=(12, 5))
    plt.hist(images.flatten(), bins=50, color='blue', alpha=0.7)
    plt.title('Pixel Intensity Distribution')
    plt.xlabel('Pixel Intensity')
    plt.ylabel('Frequency')
    plt.show()

    # 5. Check for Class Imbalance
    print("\n--- Checking Class Imbalance ---")
    class_imbalance = np.bincount(labels)
    max_samples = max(class_imbalance)
    imbalance_ratio = max_samples / class_imbalance
    print(f"Imbalance Ratio: {imbalance_ratio}")
    if any(imbalance_ratio > 1.5):
        print("Warning: Dataset has class imbalance.")


Function to Load the dataset

In [None]:
def load_dataset(base_dir):
    """Load the dataset, preprocess images, convert grayscale to RGB, and return data and labels."""
    images, labels = [], []
    class_names = sorted([d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))])
    print(f"Class Names: {class_names}")

    for label, class_name in enumerate(class_names):
        class_dir = os.path.join(base_dir, class_name)
        for file in os.listdir(class_dir):
            if file.lower().endswith(('.jpg', '.jpeg', '.png')):
                file_path = os.path.join(class_dir, file)
                try:
                    # Load the image in grayscale
                    image = tf.io.read_file(file_path)
                    image = tf.image.decode_image(image, channels=1)  # Load as grayscale
                    image = tf.image.resize(image, [224, 224])  # Resize to 224x224
                    image = tf.image.grayscale_to_rgb(image)  # Convert grayscale to RGB
                    image = tf.cast(image, tf.float32) / 255.0  # Normalize to [0, 1]

                    images.append(image.numpy())
                    labels.append(label)
                except Exception as e:
                    print(f"[ERROR] Could not process {file_path}: {e}")

    images = np.array(images, dtype=np.float32)
    labels = np.array(labels, dtype=np.int32)

    # Display a sample image to confirm RGB conversion
    if len(images) > 0:
        plt.imshow(images[0])  # Matplotlib assumes (H, W, 3) as RGB
        plt.title(f"Sample Image (Class: {class_names[labels[0]]})")
        plt.axis("off")
        plt.show()

    return images, labels, class_names


Split the dataset

In [None]:
# Load and split the dataset
images, labels, class_names = load_dataset(base_dir) # Assign class_names
train_images, test_images, train_labels, test_labels = train_test_split(
    images, labels, test_size=0.3, random_state=42, stratify=labels
)

In [None]:
# Further split training data into training (80%) and validation (20%)
train_images, val_images, train_labels, val_labels = train_test_split(
    train_images, train_labels, test_size=0.2, random_state=42, stratify=train_labels
)

In [None]:
# Print dataset shapes
print(f"Training Images Shape: {train_images.shape}")
print(f"Training Labels Shape: {train_labels.shape}")
print(f"Validation Images Shape: {val_images.shape}")
print(f"Validation Labels Shape: {val_labels.shape}")
print(f"Test Images Shape: {test_images.shape}")
print(f"Test Labels Shape: {test_labels.shape}")

Visualize EDA

In [None]:
# Perform EDA on the dataset
perform_eda(images, labels, class_names)

Create the Hybrid Model

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models

def create_hybrid_vit_cnn_model(input_shape, num_classes, patch_size=32, embed_dim=64, num_heads=4, transformer_layers=4):
    """Build a Hybrid CNN-ViT Model."""

    inputs = layers.Input(shape=input_shape)

    # CNN Feature Extractor
    x = layers.Conv2D(32, (3,3), activation='relu', padding='same')(inputs)
    x = layers.MaxPooling2D(pool_size=(2,2))(x)
    x = layers.Conv2D(64, (3,3), activation='relu', padding='same')(x)
    x = layers.MaxPooling2D(pool_size=(2,2))(x)
    x = layers.Conv2D(128, (3,3), activation='relu', padding='same')(x)
    x = layers.MaxPooling2D(pool_size=(2,2))(x)

    # Patch Embedding
    x = layers.Conv2D(embed_dim, kernel_size=1, activation='relu')(x)  # Reduce depth
    h, w = x.shape[1], x.shape[2]  # Get spatial dimensions
    num_patches = (h * w)  # Compute correct number of patches
    patches = layers.Reshape((num_patches, embed_dim))(x)  # Reshape CNN output

    # Learnable Positional Embedding (Fix shape mismatch)
    positional_embedding = tf.Variable(tf.random.normal([1, num_patches, embed_dim]), trainable=True)
    x = patches + tf.broadcast_to(positional_embedding, tf.shape(patches))  # Ensure correct shape

    # Transformer Encoder Layers
    for _ in range(transformer_layers):
        x_norm1 = layers.LayerNormalization(epsilon=1e-6)(x)
        attention_output = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim // num_heads)(x_norm1, x_norm1)
        x = layers.Add()([x, attention_output])  # Residual connection

        x_norm2 = layers.LayerNormalization(epsilon=1e-6)(x)
        ff_output = layers.Dense(embed_dim * 2, activation='relu')(x_norm2)
        ff_output = layers.Dense(embed_dim)(ff_output)
        x = layers.Add()([x, ff_output])  # Residual connection

    # Classification Head
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = layers.GlobalAveragePooling1D()(x)  # Properly pool across patches
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    # Create Model
    model = models.Model(inputs=inputs, outputs=outputs)
    return model

In [None]:
# Example Usage
input_shape = (224, 224, 3)  # Adjust according to dataset
num_classes = 4  # Modify as needed

hybrid_model = create_hybrid_vit_cnn_model(input_shape, num_classes)
hybrid_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
hybrid_model.summary()