<a href="https://colab.research.google.com/github/amimulhasan/ML_project/blob/main/vit_gru_4classes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau


In [None]:
# Install the Kaggle API token for the `kaggle` CLI call in the next cell.
# kaggle.json is copied from the current working directory, so it must have
# been uploaded there before this cell runs.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
# Restrict the token file to owner read/write only.
!chmod 600 ~/.kaggle/kaggle.json

In [None]:
# Fetch the brain-tumor MRI dataset archive with the Kaggle CLI. The unzip
# cell below expects it at /content/brain-tumor-mri-dataset.zip.
!kaggle datasets download masoudnickparvar/brain-tumor-mri-dataset

Dataset URL: https://www.kaggle.com/datasets/masoudnickparvar/brain-tumor-mri-dataset
License(s): CC0-1.0
Downloading brain-tumor-mri-dataset.zip to /content
 56% 83.0M/149M [00:00<00:00, 858MB/s]
100% 149M/149M [00:00<00:00, 742MB/s] 


In [None]:
import zipfile

# Archive location and extraction target. The target is absolute (like
# zip_path) so the extracted files always land where the data-loading cell
# looks for them ("/content/brain_tumor_data/Training"), independent of the
# notebook's current working directory.
zip_path = '/content/brain-tumor-mri-dataset.zip'
extract_to = '/content/brain_tumor_data'

with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_to)

print("Unzipping completed!")

Unzipping completed!


In [None]:
# Dataset location and batching settings shared by both generators.
data_dir = "/content/brain_tumor_data/Training"  # update this path
img_size = 128  # smaller for ViT
batch_size = 32

# Training images are scaled by 1/255 and randomly augmented; 20% of the
# files are set aside through validation_split.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2,
    rotation_range=20,
    zoom_range=0.15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.15,
    horizontal_flip=True,
    fill_mode="nearest",
)

# Validation images get the same scaling and split fraction, but no augmentation.
valid_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

# Arguments that are identical for the training and validation iterators.
flow_options = dict(
    target_size=(img_size, img_size),
    batch_size=batch_size,
    class_mode='categorical',
)

train_generator = train_datagen.flow_from_directory(
    data_dir, subset='training', shuffle=True, **flow_options
)

# Validation order is kept fixed (shuffle=False).
valid_generator = valid_datagen.flow_from_directory(
    data_dir, subset='validation', shuffle=False, **flow_options
)


Found 4571 images belonging to 4 classes.
Found 1141 images belonging to 4 classes.


In [None]:
def create_vit_gru_model(
    input_shape=(128,128,3),
    num_classes=4,
    patch_size=8,
    projection_dim=64,
    transformer_layers=4,
    num_heads=4,
    gru_units=64,
    dropout_rate=0.3,
):
    """Build a patch-transformer encoder followed by a GRU classifier head.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches by a strided convolution, the patch embeddings pass through
    ``transformer_layers`` pre-norm attention + feed-forward blocks, and a
    GRU reads the resulting patch sequence before the softmax classifier.

    NOTE(review): no positional embedding is added to the patch sequence, so
    the attention blocks themselves see no patch-order information; only the
    GRU consumes the patches in order.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.
        num_classes: Number of softmax outputs.
        patch_size: Side length of each square patch (kernel size and stride).
        projection_dim: Embedding size of every patch.
        transformer_layers: Number of attention + feed-forward blocks.
        num_heads: Attention heads per block.
        gru_units: Hidden size of the GRU.
        dropout_rate: Dropout applied to the GRU output (default 0.3, the
            value that used to be hard-coded).

    Returns:
        An uncompiled ``keras.Model`` mapping images to class probabilities.

    Raises:
        ValueError: If ``patch_size`` is not positive or is larger than the
            image, i.e. no complete patch can be extracted.
    """
    # Validate before building any layer so the caller gets a clear message
    # instead of a ZeroDivisionError or an opaque shape error.
    if patch_size < 1:
        raise ValueError(f"patch_size must be a positive integer, got {patch_size}")

    # 'valid' padding drops any partial patch at the border, hence floor division.
    seq_len = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)
    if seq_len < 1:
        raise ValueError(
            f"patch_size={patch_size} does not fit into input_shape={input_shape}"
        )

    inputs = layers.Input(shape=input_shape)

    # Create patches: one projection_dim-sized embedding per patch.
    patches = layers.Conv2D(filters=projection_dim, kernel_size=patch_size, strides=patch_size, padding='valid')(inputs)
    x = layers.Reshape((seq_len, projection_dim))(patches)

    # Transformer blocks
    for _ in range(transformer_layers):
        # Layer norm + multi-head self-attention, with a residual connection.
        x1 = layers.LayerNormalization(epsilon=1e-6)(x)
        attention_output = layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim)(x1, x1)
        x2 = layers.Add()([x, attention_output])

        # Feed-forward network (2x expansion), with a residual connection.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        ff_output = layers.Dense(projection_dim*2, activation='relu')(x3)
        ff_output = layers.Dense(projection_dim)(ff_output)
        x = layers.Add()([x2, ff_output])

    # GRU layer on the sequence of patch embeddings; only the final state is kept.
    x = layers.GRU(gru_units, return_sequences=False)(x)
    x = layers.Dropout(dropout_rate)(x)

    outputs = layers.Dense(num_classes, activation='softmax')(x)

    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


In [None]:
# Build the network from the same settings the data pipeline uses, so that
# changing img_size or the class folders cannot silently desynchronise the
# model from the generators. The channel count 3 matches the generators,
# which do not pass color_mode (Keras default: RGB).
model = create_vit_gru_model(
    input_shape=(img_size, img_size, 3),
    num_classes=len(train_generator.class_indices),
)
model.summary()


In [None]:
# Configure training: Adam at a learning rate of 1e-4, categorical
# cross-entropy loss (the generators use class_mode='categorical'), and
# accuracy as the reported metric.
model.compile(
    loss='categorical_crossentropy',
    metrics=['accuracy'],
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
)


In [None]:
# All three callbacks watch the validation loss.
# Stop after 5 epochs without improvement and restore the best weights.
stop_early = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
# Multiply the learning rate by 0.2 after 3 epochs without improvement.
lower_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3)
# Save the model to disk only when val_loss improves.
keep_best = ModelCheckpoint("best_vit_gru_model.h5", monitor='val_loss', save_best_only=True)

callbacks = [stop_early, lower_lr, keep_best]

# Train for at most 30 epochs, validating on the held-out subset each epoch.
history = model.fit(
    train_generator,
    validation_data=valid_generator,
    epochs=30,
    callbacks=callbacks,
)


  self._warn_if_super_not_called()


Epoch 1/30
[1m 44/143[0m [32m━━━━━━[0m[37m━━━━━━━━━━━━━━[0m [1m3:16[0m 2s/step - accuracy: 0.2540 - loss: 1.5526

KeyboardInterrupt: 

In [None]:
def plot_history(hist):
    """Plot per-epoch accuracy and loss curves side by side.

    Args:
        hist: Either an object with a ``history`` dict (such as the value
            returned by ``model.fit``) or that dict itself. It must contain
            per-epoch lists under ``'accuracy'`` and ``'loss'``. The
            ``'val_accuracy'`` and ``'val_loss'`` entries are optional; when
            one is absent (e.g. training ran without validation data) that
            validation curve is skipped instead of raising ``KeyError``.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    metrics = getattr(hist, 'history', hist)

    # (history key, legend suffix, panel title) for each subplot.
    panels = (
        ('accuracy', 'acc', 'Accuracy'),
        ('loss', 'loss', 'Loss'),
    )

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    for ax, (key, short_name, title) in zip(axes, panels):
        train_values = metrics[key]
        ax.plot(range(len(train_values)), train_values, 'b', label=f'Training {short_name}')

        # Validation metrics exist only when fit() was given validation data.
        val_values = metrics.get(f'val_{key}')
        if val_values is not None:
            ax.plot(range(len(val_values)), val_values, 'r', label=f'Validation {short_name}')

        ax.set_xlabel('Epoch')
        ax.set_ylabel(title)
        ax.set_title(title)
        ax.legend()

    plt.show()

# Draw the accuracy/loss curves for the run stored in `history`.
plot_history(hist=history)


In [None]:
# Score the model on the validation generator; with the single 'accuracy'
# metric configured in compile(), evaluate() yields loss first, then accuracy.
eval_results = model.evaluate(valid_generator)
loss, acc = eval_results[0], eval_results[1]

print(f"Validation Loss: {loss:.4f}")
print(f"Validation Accuracy: {acc:.4f}")
