In [1]:
# =======================
# Brain Tumor Training
# =======================

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
import pickle

# Paths
train_dir = "../data/brain_tumor_dataset/training"
test_dir = "../data/brain_tumor_dataset/testing"

# Image Data Generator (with augmentation)
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode="nearest"
)

test_datagen = ImageDataGenerator(rescale=1./255)

# Load dataset
train_data = train_datagen.flow_from_directory(
    train_dir,
    target_size=(150, 150),
    batch_size=32,
    class_mode="categorical"
)

test_data = test_datagen.flow_from_directory(
    test_dir,
    target_size=(150, 150),
    batch_size=32,
    class_mode="categorical"
)

# CNN Model
model = Sequential([
    Conv2D(32, (3,3), activation="relu", input_shape=(150,150,3)),
    MaxPooling2D(pool_size=(2,2)),
    
    Conv2D(64, (3,3), activation="relu"),
    MaxPooling2D(pool_size=(2,2)),

    Conv2D(128, (3,3), activation="relu"),
    MaxPooling2D(pool_size=(2,2)),

    Flatten(),
    Dense(128, activation="relu"),
    Dropout(0.5),
    Dense(4, activation="softmax")  # 4 classes: glioma, meningioma, notumor, pituitary
])

# Compile
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

# Train-
history = model.fit(train_data, validation_data=test_data, epochs=1)

# Save Model
model.save("../model/brain_tumor_model.h5")

# Also save label classes for prediction
with open("../model/brain_tumor_classes.pkl", "wb") as f:
    pickle.dump(train_data.class_indices, f)


Found 5712 images belonging to 4 classes.
Found 1311 images belonging to 4 classes.


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  self._warn_if_super_not_called()


[1m179/179[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m130s[0m 714ms/step - accuracy: 0.5067 - loss: 1.0708 - val_accuracy: 0.5088 - val_loss: 1.2191


