In [22]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils import class_weight

In [4]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("ninadaithal/imagesoasis")

print("Path to dataset files:", path)

Path to dataset files: /kaggle/input/imagesoasis


In [33]:
# Constants
DATA_DIR = path+"/Data"
IMG_HEIGHT = 248  # Original image dimensions from paper's model
IMG_WIDTH = 496
BATCH_SIZE = 32
EPOCHS = 10
SEED = 42


In [6]:
# Load and prepare data
file_paths = []
labels = []
for class_name in os.listdir(DATA_DIR):
    class_dir = os.path.join(DATA_DIR, class_name)
    if os.path.isdir(class_dir):
        images = [os.path.join(class_dir, img) for img in os.listdir(class_dir)]
        file_paths.extend(images)
        labels.extend([class_name]*len(images))

df = pd.DataFrame({'file_path': file_paths, 'label': labels})

In [7]:
df

Unnamed: 0,file_path,label
0,/kaggle/input/imagesoasis/Data/Non Demented/OA...,Non Demented
1,/kaggle/input/imagesoasis/Data/Non Demented/OA...,Non Demented
2,/kaggle/input/imagesoasis/Data/Non Demented/OA...,Non Demented
3,/kaggle/input/imagesoasis/Data/Non Demented/OA...,Non Demented
4,/kaggle/input/imagesoasis/Data/Non Demented/OA...,Non Demented
...,...,...
86432,/kaggle/input/imagesoasis/Data/Mild Dementia/O...,Mild Dementia
86433,/kaggle/input/imagesoasis/Data/Mild Dementia/O...,Mild Dementia
86434,/kaggle/input/imagesoasis/Data/Mild Dementia/O...,Mild Dementia
86435,/kaggle/input/imagesoasis/Data/Mild Dementia/O...,Mild Dementia


In [19]:
train_df, test_df = train_test_split(df, test_size=0.2,
                                   stratify=df['label'], random_state=SEED)
train_df, val_df = train_test_split(train_df, test_size=0.2,
                                  stratify=train_df['label'], random_state=SEED)

In [9]:
# Calculate class weights
class_weights = class_weight.compute_class_weight(
    'balanced',
    classes=np.unique(train_df['label']),
    y=train_df['label']
)
class_weights = dict(enumerate(class_weights))

In [20]:
# Data generators with paper-specified augmentation
train_datagen = ImageDataGenerator(
    rescale=1./255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True
)

test_datagen = ImageDataGenerator(rescale=1./255)

In [27]:
model = models.Sequential([
    layers.Conv2D(32, (3,3), activation='relu',
                 input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
    layers.MaxPooling2D(2,2),

    layers.Conv2D(64, (3,3), activation='relu'),
    layers.MaxPooling2D(2,2),

    layers.Conv2D(128, (3,3), activation='relu'),
    layers.MaxPooling2D(2,2),

    layers.Conv2D(256, (3,3), activation='relu'),
    layers.MaxPooling2D(2,2),

    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(64, activation='relu'),
    layers.Dense(4, activation='softmax')
])

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [28]:
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

In [30]:
train_generator = train_datagen.flow_from_dataframe(
    train_df,
    x_col='file_path',
    y_col='label',
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical'
)

val_generator = test_datagen.flow_from_dataframe(
    val_df,
    x_col='file_path',
    y_col='label',
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical'
)

Found 55319 validated image filenames belonging to 4 classes.
Found 13830 validated image filenames belonging to 4 classes.


In [None]:
# Training
history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=EPOCHS,
    class_weight=class_weights
)

Epoch 1/10
[1m1729/1729[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1659s[0m 960ms/step - accuracy: 0.0781 - loss: 1.4237 - val_accuracy: 0.0578 - val_loss: 1.3463
Epoch 2/10
[1m1729/1729[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1657s[0m 958ms/step - accuracy: 0.2018 - loss: 1.3612 - val_accuracy: 0.0056 - val_loss: 1.3956
Epoch 3/10
[1m1729/1729[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1659s[0m 959ms/step - accuracy: 0.0485 - loss: 1.4099 - val_accuracy: 0.7777 - val_loss: 1.3685
Epoch 4/10
[1m1729/1729[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1657s[0m 958ms/step - accuracy: 0.4659 - loss: 1.3728 - val_accuracy: 0.0578 - val_loss: 1.3956
Epoch 5/10
[1m1729/1729[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1639s[0m 948ms/step - accuracy: 0.0955 - loss: 1.3908 - val_accuracy: 0.0056 - val_loss: 1.3922
Epoch 6/10
[1m 835/1729[0m [32m━━━━━━━━━[0m[37m━━━━━━━━━━━[0m [1m13:16[0m 891ms/step - accuracy: 0.0825 - loss: 1.4017

In [None]:
test_generator = test_datagen.flow_from_dataframe(
    test_df,
    x_col='file_path',
    y_col='label',
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False
)

In [None]:
# Metrics
test_loss, test_acc = model.evaluate(test_generator)
print(f"\nTest Accuracy: {test_acc*100:.2f}%")
print(f"Test Loss: {test_loss:.4f}")

In [None]:
# Predictions
y_pred = model.predict(test_generator).argmax(axis=1)
y_true = test_generator.classes


In [None]:
# Classification report
print("\nClassification Report:")
print(classification_report(y_true, y_pred,
                           target_names=test_generator.class_indices.keys()))

In [None]:
# Confusion matrix
plt.figure(figsize=(12,8))
sns.heatmap(confusion_matrix(y_true, y_pred),
           annot=True, fmt='d',
           cmap='Blues',
           xticklabels=test_generator.class_indices.keys(),
           yticklabels=test_generator.class_indices.keys())
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()