<a href="https://colab.research.google.com/github/TharinsaMudalige/Neuron-Brain_Tumor_Detection_Classification_with_XAI/blob/Detection-Classification-VIT/Brain_Tumour_Classification_Using_VIT_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!pip install keras



In [5]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns


In [6]:
def load_dataset(base_dir):
    """Load images and labels from the dataset."""
    images = []
    labels = []
    class_names = sorted(os.listdir(base_dir))  # Get class names from folder names

    print(f"Class Names: {class_names}")  # Debugging class names

    for label, class_name in enumerate(class_names):  # Iterate over each class
        class_dir = os.path.join(base_dir, class_name)  # Path to the class folder
        print(f"Processing class: {class_name}, Label: {label}, Path: {class_dir}")

        if not os.path.isdir(class_dir):
            print(f"Skipping {class_dir}, not a directory.")
            continue

        for file in os.listdir(class_dir):  # Iterate over each image file in the folder
            file_path = os.path.join(class_dir, file)  # Full path to the image file
            print(f"Processing file: {file_path}")  # Debugging file path

            try:
                # Load image as grayscale and resize to 224x224
                image = tf.keras.preprocessing.image.load_img(
                    file_path, color_mode='grayscale', target_size=(224, 224)
                )
                # Convert the image to a numpy array and normalize pixel values to [0, 1]
                image = tf.keras.preprocessing.image.img_to_array(image) / 255.0
                images.append(image)  # Add the image to the list
                labels.append(label)  # Add the label to the list
            except Exception as e:
                print(f"Error loading image: {file_path}, Error: {e}")

    print(f"Loaded {len(images)} images.")  # Debugging total images loaded
    return np.array(images), np.array(labels), class_names


In [7]:
# Define dataset directory
base_dir = "/content/drive/MyDrive/DSGP_BrainTumorDetection/Preprocessed_Dataset_classes_morepreprocess_techniques"  # Path to the dataset folder


In [9]:
def load_dataset(base_dir):
    """Load images and labels from the dataset."""
    images = []
    labels = []
    class_names = sorted(os.listdir(base_dir))  # Get class names from folder names

    print(f"Class Names: {class_names}")  # Debugging class names

    for label, class_name in enumerate(class_names):  # Iterate over each class
        class_dir = os.path.join(base_dir, class_name)  # Path to the class folder
        print(f"Processing class: {class_name}, Label: {label}, Path: {class_dir}")

        if not os.path.isdir(class_dir):
            print(f"Skipping {class_dir}, not a directory.")
            continue

        for file in os.listdir(class_dir):  # Iterate over each image file in the folder
            file_path = os.path.join(class_dir, file)  # Full path to the image file
            print(f"Processing file: {file_path}")  # Debugging file path

            try:
                # Load image as grayscale and resize to 224x224
                image = tf.keras.preprocessing.image.load_img(
                    file_path, color_mode='grayscale', target_size=(224, 224)
                )
                # Convert the image to a numpy array and normalize pixel values to [0, 1]
                image = tf.keras.preprocessing.image.img_to_array(image) / 255.0
                images.append(image)  # Add the image to the list
                labels.append(label)  # Add the label to the list
            except Exception as e:
                print(f"Error loading image: {file_path}, Error: {e}")

    print(f"Loaded {len(images)} images.")  # Debugging total images loaded

    # Check image shapes
    for i, img in enumerate(images):
        if img.shape != (224, 224, 1):
            print(f"Image {i} has shape {img.shape}, expected (224, 224, 1)")
            return None, None, None

    # Check memory usage
    import psutil
    memory = psutil.virtual_memory()
    print(f"Available Memory Before Conversion: {memory.available / (1024**3):.2f} GB")

    # Convert to numpy arrays in chunks if necessary
    images_np = np.zeros((len(images), 224, 224, 1), dtype=np.float32)
    for i, img in enumerate(images):
        images_np[i] = img

    print(f"Available Memory After Conversion: {psutil.virtual_memory().available / (1024**3):.2f} GB")
    return images_np, np.array(labels), class_names


In [None]:
# Load the dataset
images, labels, class_names = load_dataset(base_dir)  # Load images, labels, and class names


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Processing file: /content/drive/MyDrive/DSGP_BrainTumorDetection/Preprocessed_Dataset_classes_morepreprocess_techniques/meningioma/image_30004.png
Processing file: /content/drive/MyDrive/DSGP_BrainTumorDetection/Preprocessed_Dataset_classes_morepreprocess_techniques/meningioma/image_30005.png
Processing file: /content/drive/MyDrive/DSGP_BrainTumorDetection/Preprocessed_Dataset_classes_morepreprocess_techniques/meningioma/image_30006.png
Processing file: /content/drive/MyDrive/DSGP_BrainTumorDetection/Preprocessed_Dataset_classes_morepreprocess_techniques/meningioma/image_30007.png
Processing file: /content/drive/MyDrive/DSGP_BrainTumorDetection/Preprocessed_Dataset_classes_morepreprocess_techniques/meningioma/image_30008.png
Processing file: /content/drive/MyDrive/DSGP_BrainTumorDetection/Preprocessed_Dataset_classes_morepreprocess_techniques/meningioma/image_30009.png
Processing file: /content/drive/MyDrive/DSGP_BrainTum

In [3]:
# Split the dataset equally for each class
train_images, test_images, train_labels, test_labels = [], [], [], []  # Initialize empty lists for train/test splits

for label in range(len(class_names)):  # Iterate over each class label
    class_indices = np.where(labels == label)[0]  # Get indices of all images belonging to the current class
    class_images = images[class_indices]  # Extract images for the current class
    class_labels = labels[class_indices]  # Extract labels for the current class
    # Split the class data into training and testing sets (80% train, 20% test)
    train_x, test_x, train_y, test_y = train_test_split(class_images, class_labels, test_size=0.3, random_state=42)
    train_images.extend(train_x)  # Add training images to the train list
    test_images.extend(test_x)  # Add testing images to the test list
    train_labels.extend(train_y)  # Add training labels to the train list
    test_labels.extend(test_y)  # Add testing labels to the test list

# Convert lists to numpy arrays
train_images = np.array(train_images)
test_images = np.array(test_images)
train_labels = np.array(train_labels)
test_labels = np.array(test_labels)


NameError: name 'class_names' is not defined

In [None]:
# Step 2: Create the Vision Transformer (ViT) model
def create_vit_model(input_shape, num_classes):
    """Create a Vision Transformer model."""
    inputs = layers.Input(shape=input_shape)  # Input layer with the specified shape

    # Patch embedding
    patch_size = 16  # Size of each image patch
    num_patches = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)  # Calculate total number of patches
    patches = layers.Conv2D(filters=64, kernel_size=patch_size, strides=patch_size, padding='valid')(inputs)  # Extract patches
    patches = layers.Reshape((num_patches, -1))(patches)  # Reshape patches into 2D array for transformer

    # Positional embedding
    positional_embedding = layers.Embedding(input_dim=num_patches, output_dim=64)(tf.range(num_patches))  # Add positional info
    x = patches + positional_embedding  # Add positional embeddings to the patch embeddings

    # Transformer encoder layers
    for _ in range(8):  # Add 8 transformer encoder layers
        # Multi-head self-attention
        attention_output = layers.MultiHeadAttention(num_heads=4, key_dim=64)(x, x)  # Apply attention mechanism
        x = layers.Add()([x, attention_output])  # Add residual connection
        x = layers.LayerNormalization()(x)  # Normalize the output

        # Feed-forward network
        ff_output = layers.Dense(128, activation='relu')(x)  # First dense layer with ReLU activation
        ff_output = layers.Dense(64)(ff_output)  # Second dense layer without activation
        x = layers.Add()([x, ff_output])  # Add residual connection
        x = layers.LayerNormalization()(x)  # Normalize the output

    # Classification head
    x = layers.GlobalAveragePooling1D()(x)  # Global average pooling to reduce dimensions
    outputs = layers.Dense(num_classes, activation='softmax')(x)  # Output layer with softmax activation

    model = models.Model(inputs, outputs)  # Create the model
    return model  # Return the model


In [None]:
# Create the ViT model
input_shape = (224, 224, 1)  # Input shape for grayscale images
num_classes = len(class_names)  # Number of classes (tumor types + no tumor)
model = create_vit_model(input_shape, num_classes)  # Build the model


In [None]:
# Compile the model
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])  # Compile with Adam optimizer

# Step 3: Train the model
history = model.fit(train_images, train_labels, validation_data=(test_images, test_labels), epochs=20, batch_size=32)  # Train the model

# Step 4: Evaluate the model
predictions = model.predict(test_images)  # Predict on the test set
predicted_labels = np.argmax(predictions, axis=1)  # Convert probabilities to class labels


In [None]:
# Classification report
print("Classification Report:")
print(classification_report(test_labels, predicted_labels, target_names=class_names))  # Print detailed metrics

# Confusion matrix
cm = confusion_matrix(test_labels, predicted_labels)  # Compute confusion matrix
plt.figure(figsize=(12, 10))  # Set figure size
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)  # Plot heatmap
plt.xlabel('Predicted Labels')  # Label for x-axis
plt.ylabel('True Labels')  # Label for y-axis
plt.title('Confusion Matrix')  # Title of the plot
plt.show()  # Display the plot

# Step 5: Plot training and validation metrics
def plot_metrics(history):
    """Plot accuracy and loss curves."""
    plt.figure(figsize=(12, 5))  # Set figure size

    # Accuracy
    plt.subplot(1, 2, 1)  # Create a subplot for accuracy
    plt.plot(history.history['accuracy'], label='Training Accuracy')  # Plot training accuracy
    plt.plot(history.history['val_accuracy'], label='Validation Accuracy')  # Plot validation accuracy
    plt.title('Accuracy Over Epochs')  # Title for accuracy plot
    plt.xlabel('Epochs')  # Label for x-axis
    plt.ylabel('Accuracy')  # Label for y-axis
    plt.legend()  # Add legend

    # Loss
    plt.subplot(1, 2, 2)  # Create a subplot for loss
    plt.plot(history.history['loss'], label='Training Loss')  # Plot training loss
    plt.plot(history.history['val_loss'], label='Validation Loss')  # Plot validation loss
    plt.title('Loss Over Epochs')  # Title for loss plot
    plt.xlabel('Epochs')  # Label for x-axis
    plt.ylabel('Loss')  # Label for y-axis
    plt.legend()  # Add legend

    plt.show()  # Display the plots

plot_metrics(history)  # Call the function to plot metrics
