In [None]:
from google.colab import drive
import os
import shutil
import numpy as np
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.image import ImageDataGenerator

## Load Data

In [None]:
#drive.mount('/content/drive')
#!unzip /content/drive/MyDrive/archive.zip -d /content/data

## Define Data Splitting Function

1. Split the data 70/15/15 into train/validation/test sets (stratified by class)
2. Define robust per-image normalization
3. Create `ImageDataGenerator` objects to load the data efficiently

In [None]:
# ----------------------------
# 1. Paths
# ----------------------------
# Folder holding one sub-folder of images per class (read by split_dataset).
SOURCE_DIR = '/content/data/combined_images'
# The train/, validation/ and test/ split folders are created under this root.
OUTPUT_DIR = '/content/data'

# Class sub-folder names; each name is also used as the label string.
CLASSES = ['MildDemented', 'ModerateDemented', 'NonDemented', 'VeryMildDemented']

# ----------------------------
# 2. Dataset split (run once)
# ----------------------------
def split_dataset(source_dir=None, output_dir=None, classes=None,
                  val_size=0.15, test_size=0.15, seed=42):
    """Copy every image into stratified train/validation/test class folders.

    Images are read from ``<source_dir>/<class_name>/``. Copies are written to
    ``<output_dir>/<split>/<class_name>/``, where ``split`` is ``train``,
    ``validation`` or ``test``. This is the one-sub-folder-per-class layout
    expected by ``ImageDataGenerator.flow_from_directory``.

    Args:
        source_dir: Folder with one sub-folder per class. Defaults to the
            notebook-level ``SOURCE_DIR`` (looked up at call time).
        output_dir: Destination root. Defaults to ``OUTPUT_DIR``.
        classes: Class sub-folder names, also used as labels. Defaults to
            ``CLASSES``.
        val_size: Fraction of all images assigned to the validation split.
        test_size: Fraction of all images assigned to the test split.
        seed: ``random_state`` passed to both ``train_test_split`` calls.

    Returns:
        None. The size of each split is printed.

    Raises:
        ValueError: If ``val_size`` or ``test_size`` is not positive, or if
            they leave no data for training.

    NOTE(review): the split is made per image file. If several files come
    from the same subject or scan, they can land in different splits and
    inflate test scores. Confirm whether subject IDs exist for a grouped split.
    """
    # Resolve the defaults lazily so later edits to the config cell are honoured.
    source_dir = SOURCE_DIR if source_dir is None else source_dir
    output_dir = OUTPUT_DIR if output_dir is None else output_dir
    classes = CLASSES if classes is None else classes

    if not (val_size > 0 and test_size > 0 and val_size + test_size < 1):
        raise ValueError(
            f"val_size and test_size must be > 0 and sum to < 1, "
            f"got {val_size} and {test_size}"
        )

    all_images = []
    all_labels = []
    for class_name in classes:
        class_dir = os.path.join(source_dir, class_name)
        # Sort the listing because os.listdir order is arbitrary and depends on
        # the filesystem; without sorting, the same seed can give different
        # splits on different machines.
        # Skip sub-folders (e.g. .ipynb_checkpoints, which shutil.copy cannot
        # copy) and hidden files such as .DS_Store.
        images = [
            os.path.join(class_dir, name)
            for name in sorted(os.listdir(class_dir))
            if not name.startswith('.')
            and os.path.isfile(os.path.join(class_dir, name))
        ]
        all_images.extend(images)
        all_labels.extend([class_name] * len(images))

    holdout_size = val_size + test_size

    # First split: train vs. everything held out (30% with the defaults).
    train_imgs, temp_imgs, train_labels, temp_labels = train_test_split(
        all_images,
        all_labels,
        test_size=holdout_size,
        stratify=all_labels,
        random_state=seed,
    )

    # Second split: divide the held-out part into validation and test.
    # With the defaults this is 0.15 / 0.30 = 0.5, i.e. 15% / 15% of all images.
    val_imgs, test_imgs, val_labels, test_labels = train_test_split(
        temp_imgs,
        temp_labels,
        test_size=test_size / holdout_size,
        stratify=temp_labels,
        random_state=seed,
    )

    for split_name, images, labels in [
        ('train', train_imgs, train_labels),
        ('validation', val_imgs, val_labels),
        ('test', test_imgs, test_labels),
    ]:
        for img_path, label in zip(images, labels):
            dest_dir = os.path.join(output_dir, split_name, label)
            os.makedirs(dest_dir, exist_ok=True)
            shutil.copy(img_path, dest_dir)

    print(f"Train: {len(train_imgs)}")
    print(f"Val:   {len(val_imgs)}")
    print(f"Test:  {len(test_imgs)}")

# The split only needs to be materialised once per runtime. Skip it when all
# three split folders already exist, so Restart & Run All does not re-copy the
# whole dataset. Delete those folders to force a fresh split.
_SPLIT_NAMES = ('train', 'validation', 'test')
if all(os.path.isdir(os.path.join(OUTPUT_DIR, name)) for name in _SPLIT_NAMES):
    print(f"Split folders already present under {OUTPUT_DIR}; skipping split_dataset().")
else:
    split_dataset()