# Dementia MRI Classification (Professional Project Version)

این نسخه‌ی بازنویسی‌شده شامل این بهبودهاست:
- ساختار حرفه‌ای و پارامتری برای مسیر داده‌ها
- ساخت دیتاست با `image_dataset_from_directory`
- نرمال‌سازی و Data Augmentation
- مدل CNN با `L2` و `BatchNormalization`
- استفاده از Callbackها (`EarlyStopping`, `ReduceLROnPlateau`, `ModelCheckpoint`)
- گزارش کامل: نمودارها، Confusion Matrix و Classification Report
- تابع پیش‌بینی روی تصاویر جدید


In [None]:
import os
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, regularizers
from sklearn.metrics import confusion_matrix, classification_report

# One global seed, handed to Keras' seeding helper so that runs of this
# notebook are repeatable.
SEED = 42
keras.utils.set_random_seed(SEED)

# Record the framework version next to the results.
print('TensorFlow version:', tf.__version__)


## 1) Configuration


In [None]:
# Set the dataset root here. One sub-folder per class is expected:
#
#   DATA_DIR/
#     Non Demented/
#     Mild Dementia/
#     Moderate Dementia/
#     Very mild Dementia/

# --- Data location and input geometry ---
DATA_DIR = Path('data/train')
IMG_SIZE = (128, 128)

# --- Training hyper-parameters ---
BATCH_SIZE = 16
EPOCHS = 40

# --- Fractions of the data held out for validation and testing ---
VAL_SPLIT = 0.2
TEST_SPLIT = 0.1

# Let tf.data choose parallelism / prefetch depth.
AUTOTUNE = tf.data.AUTOTUNE

# Fail early, with the absolute path in the message, when the dataset root
# does not exist.
if not DATA_DIR.exists():
    missing_path = DATA_DIR.resolve()
    raise FileNotFoundError(
        f'Dataset path not found: {missing_path}\n'
        'Please create the folder or change DATA_DIR.'
    )


## 2) Load datasets


In [None]:
# Why the split is done on the file list instead of with take()/skip() on a
# shuffled dataset: a dataset that is shuffled with tf.data's default
# `reshuffle_each_iteration=True` yields a different sample order every time
# it is iterated, so positional take()/skip() splits are not stable and
# training samples can end up in the validation / test splits (data leakage).
# Here the files are shuffled exactly once with SEED, split, and only then
# turned into tf.data pipelines, so the three splits are disjoint by
# construction.

# Image formats that tf.io.decode_image can read.
IMAGE_EXTENSIONS = ('.bmp', '.gif', '.jpeg', '.jpg', '.png')

# Class names are the sorted sub-directory names of DATA_DIR; hidden folders
# (e.g. `.ipynb_checkpoints`) are skipped.
class_names = sorted(
    entry.name for entry in DATA_DIR.iterdir()
    if entry.is_dir() and not entry.name.startswith('.')
)
num_classes = len(class_names)
if num_classes == 0:
    raise ValueError(f'No class sub-directories found in: {DATA_DIR.resolve()}')

# Collect every image file (recursively) together with its class index.
image_paths, image_labels = [], []
for class_index, class_name in enumerate(class_names):
    for file_path in sorted((DATA_DIR / class_name).rglob('*')):
        if file_path.is_file() and file_path.suffix.lower() in IMAGE_EXTENSIONS:
            image_paths.append(str(file_path))
            image_labels.append(class_index)

num_samples = len(image_paths)
print(f'Found {num_samples} images belonging to {num_classes} classes.')
print('Classes:', class_names)

# Sample-level split sizes; validation and test always get at least one sample.
test_count = max(1, int(num_samples * TEST_SPLIT))
val_count = max(1, int(num_samples * VAL_SPLIT))
train_count = num_samples - test_count - val_count

if train_count < 1:
    raise ValueError('Dataset is too small for current split configuration.')

# Single, seeded shuffle of the file list (mixes the classes before slicing).
order = np.random.default_rng(SEED).permutation(num_samples)
image_paths = np.array(image_paths)[order]
image_labels = np.array(image_labels, dtype=np.int32)[order]


def _load_example(path, label):
    """Read one image file and return (float32 image resized to IMG_SIZE, one-hot label)."""
    image = tf.io.decode_image(tf.io.read_file(path), channels=3, expand_animations=False)
    # Default bilinear resize; the result is float32 with values still in [0, 255].
    image = tf.image.resize(image, IMG_SIZE)
    # decode_image leaves the static shape unknown; restore it for the model.
    image.set_shape((IMG_SIZE[0], IMG_SIZE[1], 3))
    return image, tf.one_hot(label, num_classes)


def _make_dataset(paths, labels, shuffle=False):
    """Build a batched (image, one-hot label) dataset from parallel path/label arrays."""
    ds = tf.data.Dataset.from_tensor_slices((paths, labels))
    if shuffle:
        # Shuffling happens on the (cheap) path strings, so a buffer covering
        # the whole split is affordable and gives a full reshuffle each epoch.
        ds = ds.shuffle(len(paths), seed=SEED, reshuffle_each_iteration=True)
    ds = ds.map(_load_example, num_parallel_calls=AUTOTUNE)
    return ds.batch(BATCH_SIZE)


train_end = train_count
val_end = train_count + val_count

# `full_ds` is kept for backward compatibility; it is lazy and costs nothing
# unless iterated.
full_ds = _make_dataset(image_paths, image_labels)
train_ds = _make_dataset(image_paths[:train_end], image_labels[:train_end], shuffle=True)
val_ds = _make_dataset(image_paths[train_end:val_end], image_labels[train_end:val_end])
test_ds = _make_dataset(image_paths[val_end:], image_labels[val_end:])

full_batches = tf.data.experimental.cardinality(full_ds).numpy()
train_batches = tf.data.experimental.cardinality(train_ds).numpy()
val_batches = tf.data.experimental.cardinality(val_ds).numpy()
test_batches = tf.data.experimental.cardinality(test_ds).numpy()

print(f'Train batches: {train_batches}')
print(f'Val batches: {val_batches}')
print(f'Test batches: {test_batches}')


## 3) Preprocessing + augmentation


In [None]:
# Random image transforms, listed in the order they are applied.
_augmentation_steps = [
    layers.RandomFlip('horizontal'),
    layers.RandomRotation(0.08),
    layers.RandomZoom(0.1),
    layers.RandomContrast(0.1),
]
data_augmentation = keras.Sequential(_augmentation_steps, name='data_augmentation')

# Multiplies pixel values by 1/255.
normalization = layers.Rescaling(scale=1.0 / 255)

def prepare(ds, training=False):
    """Attach preprocessing to a dataset of (images, labels) elements.

    Every element is passed through `normalization`; when `training` is
    true the images are additionally passed through `data_augmentation`.
    The returned dataset is prefetched.

    Args:
        ds: dataset whose elements are (images, labels) pairs.
        training: whether to append the augmentation step.

    Returns:
        The mapped and prefetched dataset.
    """
    def _rescale(images, labels):
        return normalization(images), labels

    def _augment(images, labels):
        # training=True forces the random layers to be active.
        return data_augmentation(images, training=True), labels

    pipeline = ds.map(_rescale, num_parallel_calls=AUTOTUNE)
    if training:
        pipeline = pipeline.map(_augment, num_parallel_calls=AUTOTUNE)
    return pipeline.prefetch(AUTOTUNE)

# Augmentation is requested for the training split only; validation and test
# use prepare()'s default (training=False).
train_ds, val_ds, test_ds = (
    prepare(train_ds, training=True),
    prepare(val_ds),
    prepare(test_ds),
)


## 4) Build CNN model


In [None]:
def build_model(input_shape=(128, 128, 3), num_classes=4):
    """Create and compile the CNN classifier.

    The network consists of three convolutional stages (32, 64 and 128
    filters) with batch normalisation, max pooling and dropout, followed by
    global average pooling, a 128-unit dense layer and a softmax output.
    All convolution and hidden dense kernels share one L2(1e-4) regulariser.

    Args:
        input_shape: shape of a single input image.
        num_classes: number of units in the softmax output layer.

    Returns:
        A `keras.Sequential` model named 'dementia_cnn', compiled with
        Adam(1e-3), categorical cross-entropy and the accuracy metric.
    """
    weight_decay = regularizers.L2(1e-4)

    def conv(filters):
        # 3x3 same-padded ReLU convolution with the shared regulariser.
        return layers.Conv2D(
            filters, 3,
            padding='same',
            activation='relu',
            kernel_regularizer=weight_decay,
        )

    stack = [layers.Input(shape=input_shape)]

    # Stage 1: two 32-filter convolutions.
    stack += [
        conv(32),
        layers.BatchNormalization(),
        conv(32),
        layers.MaxPooling2D(),
        layers.Dropout(0.25),
    ]

    # Stage 2: two 64-filter convolutions.
    stack += [
        conv(64),
        layers.BatchNormalization(),
        conv(64),
        layers.MaxPooling2D(),
        layers.Dropout(0.25),
    ]

    # Stage 3: a single 128-filter convolution.
    stack += [
        conv(128),
        layers.BatchNormalization(),
        layers.MaxPooling2D(),
        layers.Dropout(0.30),
    ]

    # Classification head.
    stack += [
        layers.GlobalAveragePooling2D(),
        layers.Dense(128, activation='relu', kernel_regularizer=weight_decay),
        layers.Dropout(0.40),
        layers.Dense(num_classes, activation='softmax'),
    ]

    model = keras.Sequential(stack, name='dementia_cnn')

    adam = keras.optimizers.Adam(learning_rate=1e-3)
    model.compile(
        optimizer=adam,
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model

# The model input is IMG_SIZE with a trailing dimension of 3 appended.
model = build_model(input_shape=tuple(IMG_SIZE) + (3,), num_classes=num_classes)
model.summary()


## 5) Train with callbacks


In [None]:
# Where the best model (by validation loss) is written during training.
checkpoint_path = 'artifacts/best_dementia_cnn.keras'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# All three callbacks watch the validation loss.
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=8,
    restore_best_weights=True,
)
lr_scheduler = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=4,
    min_lr=1e-6,
)
best_checkpoint = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    monitor='val_loss',
    save_best_only=True,
)
callbacks = [early_stopping, lr_scheduler, best_checkpoint]

# `history` is read afterwards to plot the learning curves.
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=callbacks,
    verbose=1,
)


## 6) Learning curves


In [None]:
# Two side-by-side panels: loss on the left, accuracy on the right.
fig, ax = plt.subplots(1, 2, figsize=(12, 4))

# (train key, validation key, train label, validation label, panel title)
curve_specs = [
    ('loss', 'val_loss', 'Train Loss', 'Val Loss', 'Loss Curve'),
    ('accuracy', 'val_accuracy', 'Train Acc', 'Val Acc', 'Accuracy Curve'),
]

for panel, (train_key, val_key, train_label, val_label, title) in zip(ax, curve_specs):
    panel.plot(history.history[train_key], label=train_label)
    panel.plot(history.history[val_key], label=val_label)
    panel.set_title(title)
    panel.set_xlabel('Epoch')
    panel.legend()

plt.tight_layout()
plt.show()


## 7) Evaluation (test set)


In [None]:
# Aggregate loss / accuracy on the held-out test split.
test_loss, test_acc = model.evaluate(test_ds, verbose=0)
print(f'Test Loss: {test_loss:.4f}')
print(f'Test Accuracy: {test_acc:.4f}')

# Collect per-sample true and predicted class indices.
y_true, y_pred = [], []
for x_batch, y_batch in test_ds:
    # A direct forward pass is used instead of model.predict(): predict() is
    # designed for whole datasets and adds per-call overhead when invoked
    # once per batch inside a Python loop.
    probs = model(x_batch, training=False)
    y_true.extend(np.argmax(np.asarray(y_batch), axis=1))
    y_pred.extend(np.argmax(np.asarray(probs), axis=1))

y_true = np.array(y_true)
y_pred = np.array(y_pred)

# Pass the full label list explicitly. Without it, a class that happens to be
# absent from the test split (likely for rare classes) shrinks the confusion
# matrix, misaligns the tick labels, and makes classification_report raise
# because the number of labels no longer matches `target_names`.
label_ids = np.arange(len(class_names))

cm = confusion_matrix(y_true, y_pred, labels=label_ids)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix (Test Set)')
plt.show()

# zero_division=0 reports 0.0 (instead of warning) for classes with no
# predicted or no true samples.
print(classification_report(
    y_true,
    y_pred,
    labels=label_ids,
    target_names=class_names,
    digits=4,
    zero_division=0,
))


## 8) Inference on new images


In [None]:
def predict_image(image_path, model, class_names, img_size=(128, 128)):
    """Classify a single image file and display it with the predicted label.

    Args:
        image_path: path of the image file to classify.
        model: trained model whose `predict` returns one probability per class.
        class_names: class labels, indexed by the model's output position.
        img_size: (height, width) the image is resized to before inference.

    Returns:
        A `(label, confidence)` tuple where `label` is the predicted entry of
        `class_names` and `confidence` is its probability in percent (0-100).
    """
    # load_img resizes with nearest-neighbour by default, whereas Keras'
    # directory loader used for training resizes bilinearly by default.
    # Request bilinear explicitly so inference inputs are preprocessed like
    # the training inputs.
    img = keras.utils.load_img(
        image_path,
        target_size=img_size,
        interpolation='bilinear',
    )

    # Scale to [0, 1] and add the batch dimension expected by the model.
    x = keras.utils.img_to_array(img) / 255.0
    x = np.expand_dims(x, axis=0)

    probs = model.predict(x, verbose=0)[0]
    idx = int(np.argmax(probs))
    confidence = float(probs[idx]) * 100

    # Show the (resized) image with the prediction as its title.
    plt.figure(figsize=(4, 4))
    plt.imshow(img)
    plt.axis('off')
    plt.title(f'{class_names[idx]} ({confidence:.2f}%)')
    plt.show()

    return class_names[idx], confidence

# Example:
# label, conf = predict_image('data/new_cases/sample_1.jpg', model, class_names, IMG_SIZE)
# print(label, conf)


## 9) Save final model


In [None]:
# Save the model as it stands after training; the best-by-validation-loss
# checkpoint written during fit() lives at `checkpoint_path`.
final_model_path = 'artifacts/dementia_cnn_final.keras'
model.save(final_model_path)

for caption, location in (
    ('Model saved to:', final_model_path),
    ('Best checkpoint:', checkpoint_path),
):
    print(caption, location)
