In [9]:
import os
import shutil
import random
import matplotlib.pyplot as plt
from PIL import Image
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.efficientnet import EfficientNetB0, preprocess_input as eff_preprocess
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input as resnet_preprocess
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input as vgg_preprocess
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns


# Raw image folders, one per class (locations on the author's machine)
Brain_Tumor = r'C:/Users/sunda/Downloads/braincancer/Brain_Cancer raw MRI data/Brain_Cancer/brain_tumor'
Brain_Menin = r'C:/Users/sunda/Downloads/braincancer/Brain_Cancer raw MRI data/Brain_Cancer/brain_menin'
Brain_Glioma = r'C:/Users/sunda/Downloads/braincancer/Brain_Cancer raw MRI data/Brain_Cancer/brain_glioma'

# Working directories: per-class copies first, then the resized train/val/test tree
merged_folder = r'C:/Users/sunda/Downloads/merged_folder'
split_output = r'C:/Users/sunda/Downloads/dataset_split'
resize_dim = (224, 224)

# Class label -> folder the class images are copied from
class_sources = dict(
    brain_tumor=Brain_Tumor,
    brain_menin=Brain_Menin,
    brain_glioma=Brain_Glioma,
)

# A merge is required as soon as one class folder under merged_folder is
# missing or holds fewer than 1000 entries; the scan stops at the first such class.
for class_name in class_sources:
    class_folder = os.path.join(merged_folder, class_name)
    if os.path.exists(class_folder) and len(os.listdir(class_folder)) >= 1000:
        continue
    need_merge = True
    break
else:
    need_merge = False

if not need_merge:
    print(" Merged folders already exist and contain enough images. Skipping merge.")
else:
    print(" Performing merge of class folders...")
    os.makedirs(merged_folder, exist_ok=True)

    for class_name, src_folder in class_sources.items():
        dest_folder = os.path.join(merged_folder, class_name)
        os.makedirs(dest_folder, exist_ok=True)

        # Keep only image files and cap the class at 1000 of them
        images = list(filter(
            lambda f: f.lower().endswith(('.jpg', '.jpeg', '.png')),
            os.listdir(src_folder),
        ))[:1000]
        for img in images:
            src_file = os.path.join(src_folder, img)
            dest_file = os.path.join(dest_folder, img)
            shutil.copy2(src_file, dest_file)

    print(" Merging complete.")

# Check if dataset split folders already contain files
def is_split_done(split_path, classes=None):
    """Return True only when every class folder under *split_path* is populated.

    The previous implementation returned True as soon as the first class
    folder contained a file, so a split that was interrupted part-way
    (for example only the first class written) was reported as complete.

    Args:
        split_path: Directory of one split (e.g. ``<split_output>/train``).
        classes: Iterable of class folder names to check. Defaults to the
            keys of the module-level ``class_sources`` mapping.

    Returns:
        True if each class has a sub-directory of *split_path* containing at
        least one entry; False if any is missing, is not a directory, is
        empty, or if there are no classes to check.
    """
    if classes is None:
        classes = class_sources

    class_dirs = [os.path.join(split_path, cls) for cls in classes]
    if not class_dirs:
        # Nothing to verify: keep the original "not done" answer.
        return False

    # isdir (rather than exists) so a stray file with a class name is treated
    # as "not done" instead of making os.listdir raise.
    return all(os.path.isdir(d) and bool(os.listdir(d)) for d in class_dirs)

if all(is_split_done(os.path.join(split_output, split)) for split in ['train', 'val', 'test']):
    print(" Split folders already populated. Skipping split.")
else:
    print("\n Splitting and resizing images...")

    # Single source of truth for the per-class split sizes. The slice bounds,
    # the per-class cap and the final message are all derived from it, so
    # changing a number here cannot leave them out of sync.
    split_counts = {'train': 700, 'val': 150, 'test': 150}
    images_per_class = sum(split_counts.values())

    for split in split_counts:
        for cls in class_sources:
            os.makedirs(os.path.join(split_output, split, cls), exist_ok=True)

    def resize_and_save(src, dst):
        """Resize the image at *src* to ``resize_dim`` and write it to *dst*.

        Best-effort: any failure is printed and the image is skipped.

        Returns:
            True if the resized image was written, False otherwise.
        """
        try:
            with Image.open(src) as img:
                img.resize(resize_dim).save(dst)
        except Exception as e:
            print(f" Error resizing {src}: {e}")
            return False
        return True

    # os.listdir gives no ordering guarantee; sorting the names and shuffling
    # with a dedicated seeded generator makes the split reproducible without
    # touching the global random state.
    split_rng = random.Random(42)
    failed = 0

    for class_name in class_sources:
        src_folder = os.path.join(merged_folder, class_name)
        images = sorted(
            f for f in os.listdir(src_folder)
            if f.lower().endswith(('.jpg', '.jpeg', '.png'))
        )[:images_per_class]
        split_rng.shuffle(images)

        # Walk the shuffled list in consecutive windows: train, then val, then test.
        start = 0
        for split, count in split_counts.items():
            for img_name in images[start:start + count]:
                src_path = os.path.join(src_folder, img_name)
                dst_path = os.path.join(split_output, split, class_name, img_name)
                if not resize_and_save(src_path, dst_path):
                    failed += 1
            start += count

    print(
        f"\n Split complete: {split_counts['train']} train, "
        f"{split_counts['val']} val, {split_counts['test']} test per class."
    )
    if failed:
        print(f" {failed} image(s) could not be resized and were skipped.")

   
 # Path config
# Root of the train/val/test tree read by get_datagens (same literal path as split_output)
data_dir = r'C:/Users/sunda/Downloads/dataset_split'
# target_size handed to every flow_from_directory call
img_size = (224, 224)
# Batch size for the train and val iterators; the test iterator uses batch_size=1
batch_size = 32
# NOTE(review): not referenced in the visible code; both fit() calls pass epochs=10
epochs = 20

def get_datagens(preprocess_func):
    """Build the train, validation and test directory iterators.

    All three generators apply *preprocess_func*; only the training
    generator adds augmentation (rotation, shifts, zoom, horizontal flip).
    Images are read from the ``train``, ``val`` and ``test`` sub-folders of
    ``data_dir`` at ``img_size`` with categorical labels.

    Args:
        preprocess_func: Callable passed as ``preprocessing_function`` to
            every ``ImageDataGenerator``.

    Returns:
        Tuple ``(train_gen, val_gen, test_gen)``. The test iterator yields
        one image per batch and is not shuffled.
    """
    augmentation = dict(
        rotation_range=20,
        width_shift_range=0.1,
        height_shift_range=0.1,
        zoom_range=0.2,
        horizontal_flip=True,
    )
    augmented = ImageDataGenerator(preprocessing_function=preprocess_func, **augmentation)
    plain_val = ImageDataGenerator(preprocessing_function=preprocess_func)
    plain_test = ImageDataGenerator(preprocessing_function=preprocess_func)

    def flow(datagen, subset, **overrides):
        # Shared flow_from_directory settings; overrides win on conflict.
        options = dict(
            target_size=img_size,
            batch_size=batch_size,
            class_mode='categorical',
        )
        options.update(overrides)
        return datagen.flow_from_directory(os.path.join(data_dir, subset), **options)

    return (
        flow(augmented, 'train'),
        flow(plain_val, 'val'),
        flow(plain_test, 'test', batch_size=1, shuffle=False),
    )

# MODEL BUILD + FINE-TUNE
def build_model(base_model_class, name, num_classes):
    """Assemble a frozen ImageNet backbone with a small classification head.

    Args:
        base_model_class: Keras application constructor (called with
            ``weights='imagenet'``, ``include_top=False`` and a 224x224x3 input).
        name: Name given to the resulting model.
        num_classes: Width of the final softmax layer.

    Returns:
        Tuple ``(model, base_model)``: the compiled classifier and the
        backbone it was built on, with the backbone marked non-trainable.
    """
    base_model = base_model_class(
        include_top=False,
        weights='imagenet',
        input_shape=(224, 224, 3),
    )
    base_model.trainable = False

    # Head: global average pooling -> 256-unit ReLU layer -> softmax over classes
    pooled = GlobalAveragePooling2D()(base_model.output)
    hidden = Dense(256, activation='relu')(pooled)
    predictions = Dense(num_classes, activation='softmax')(hidden)

    model = Model(inputs=base_model.input, outputs=predictions, name=name)
    model.compile(
        loss='categorical_crossentropy',
        optimizer=Adam(learning_rate=1e-4),
        metrics=['accuracy'],
    )
    return model, base_model

def fine_tune(model, base_model, layers_to_unfreeze=30, lr=1e-5):
    """Unfreeze the tail of the backbone and recompile for fine-tuning.

    Args:
        model: Classifier built on top of *base_model*.
        base_model: Backbone whose last layers are made trainable.
        layers_to_unfreeze: Number of trailing backbone layers to consider.
            Zero or a negative value leaves the backbone untouched.
        lr: Learning rate of the Adam optimizer used for the recompile.

    Returns:
        The same *model*, recompiled with categorical cross-entropy and an
        accuracy metric.
    """
    # Guard the slice: layers[-0:] is the whole list, so asking for zero
    # layers would otherwise unfreeze the entire backbone.
    tail = base_model.layers[-layers_to_unfreeze:] if layers_to_unfreeze > 0 else []

    for layer in tail:
        # Keep BatchNormalization frozen, as the Keras fine-tuning guidance
        # recommends: making it trainable lets it update its moving statistics
        # on the new data, which can undo what the pretrained backbone learned.
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            continue
        layer.trainable = True

    # Changes to `trainable` only take effect after the model is compiled again.
    model.compile(optimizer=Adam(learning_rate=lr),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model

def plot_confusion_matrix(y_true, y_pred, class_names, model_name):
    """Show an annotated confusion-matrix heatmap for one model.

    Args:
        y_true: True class indices.
        y_pred: Predicted class indices.
        class_names: Labels used for the tick marks on both axes.
        model_name: Name inserted into the figure title.
    """
    matrix = confusion_matrix(y_true, y_pred)
    heatmap_options = dict(
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
    )

    plt.figure(figsize=(6, 5))
    sns.heatmap(matrix, **heatmap_options)
    plt.title(f"Confusion Matrix - {model_name}")
    plt.ylabel("True")
    plt.xlabel("Predicted")
    plt.show()

# MODELS TO TRY
# Backbones are trained in the order listed; each entry pairs the model
# constructor with the preprocess_input function from its own module.
models_to_try = dict(
    ResNet50=(ResNet50, resnet_preprocess),
    VGG16=(VGG16, vgg_preprocess),
    EfficientNetB0=(EfficientNetB0, eff_preprocess),
)

# Filled by the training loop: test accuracy and merged training history per model name
results = {}
histories = {}

# TRAINING LOOP

for name, (model_class, preprocess_func) in models_to_try.items():
    print(f"\n Training {name}...")

    # Several large models are built in this loop; clear Keras' global state
    # first so the previous backbone does not keep accumulating memory.
    tf.keras.backend.clear_session()

    train_gen, val_gen, test_gen = get_datagens(preprocess_func)
    class_names = list(test_gen.class_indices.keys())

    model, base_model = build_model(model_class, name, train_gen.num_classes)

    # The same callback instances are passed to both fit() calls below.
    checkpoint_path = f"{name}_best_model.h5"
    checkpoint = ModelCheckpoint(checkpoint_path, save_best_only=True, monitor='val_accuracy', mode='max')
    earlystop = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

    # Stage 1: frozen training
    history = model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=10,
        callbacks=[checkpoint, earlystop]
    )

    # Stage 2: fine-tuning
    print(f"\n Fine-tuning {name}...")
    model = fine_tune(model, base_model, layers_to_unfreeze=30, lr=1e-5)
    fine_tune_history = model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=10,
        callbacks=[checkpoint, earlystop]
    )

    # Merge histories: append the stage-2 curves to the stage-1 curves.
    # setdefault tolerates a metric key that only shows up in stage 2.
    for k, values in fine_tune_history.history.items():
        history.history.setdefault(k, []).extend(values)

    # Evaluate using the weights stored in the checkpoint file
    print(f"\n Evaluating {name} on test set...")
    model.load_weights(checkpoint_path)
    loss, acc = model.evaluate(test_gen)
    y_pred = model.predict(test_gen)
    y_pred_classes = y_pred.argmax(axis=1)

    print(classification_report(test_gen.classes, y_pred_classes, target_names=class_names))

    # Plot confusion matrix immediately
    plot_confusion_matrix(test_gen.classes, y_pred_classes, class_names, name)

    # Plot training curves immediately (both stages, back to back)
    plt.figure(figsize=(12, 4))
    plt.suptitle(f"{name} Training Curves", fontsize=16)

    plt.subplot(1, 2, 1)
    plt.plot(history.history['accuracy'], label='Train')
    plt.plot(history.history['val_accuracy'], label='Val')
    plt.title('Accuracy')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(history.history['loss'], label='Train')
    plt.plot(history.history['val_loss'], label='Val')
    plt.title('Loss')
    plt.legend()

    plt.show()

    # Save results
    results[name] = acc
    histories[name] = history

# FINAL SUMMARY
# Print each model's test accuracy, then name the model with the highest one
# (the first such model in insertion order when accuracies tie).

print("\nModel Accuracy on Test Set:")
for model_name in results:
    accuracy = results[model_name]
    print(f"{model_name}: {accuracy:.4f}")

best_model_name, _best_accuracy = max(results.items(), key=lambda item: item[1])
print(f"\n Best Model: {best_model_name}")

 Merged folders already exist and contain enough images. Skipping merge.
 Split folders already populated. Skipping split.

🚀 Training ResNet50...
Found 2100 images belonging to 3 classes.
Found 450 images belonging to 3 classes.
Found 450 images belonging to 3 classes.


  self._warn_if_super_not_called()


Epoch 1/10
[1m16/66[0m [32m━━━━[0m[37m━━━━━━━━━━━━━━━━[0m [1m1:38[0m 2s/step - accuracy: 0.4912 - loss: 1.0604

KeyboardInterrupt: 