<a href="https://colab.research.google.com/github/RAvila-bioeng/M.R.AI/blob/main/alzheimer's.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install kaggle



In [None]:

from google.colab import drive
drive.mount('/content/gdrive')

import os
os.environ['KAGGLE_CONFIG_DIR'] = "/content/gdrive/MyDrive/Kaggle_API"

# Descargar dataset
!kaggle datasets download -d ninadaithal/imagesoasis

# Descomprimir
!unzip -q imagesoasis.zip

import os
print("Contenido del directorio actual:", os.listdir("."))
print("Contenido de 'Data' si existe:", os.listdir("Data") if os.path.exists("Data") else "No existe 'Data'")


Mounted at /content/gdrive
Dataset URL: https://www.kaggle.com/datasets/ninadaithal/imagesoasis
License(s): apache-2.0
Downloading imagesoasis.zip to /content
 99% 1.22G/1.23G [00:05<00:00, 220MB/s]
100% 1.23G/1.23G [00:05<00:00, 224MB/s]
Contenido del directorio actual: ['.config', 'gdrive', 'Data', 'imagesoasis.zip', 'sample_data']
Contenido de 'Data' si existe: ['Moderate Dementia', 'Non Demented', 'Mild Dementia', 'Very mild Dementia']


In [None]:
import os
import shutil
import random
from collections import defaultdict

# Original dataset (after unzip)
original_dir = "./Data"
assert os.path.exists(original_dir), "Data folder not found. Check unzip step."

# Where we’ll put the reduced version
subset_dir = "./Data_subset"

# Start CLEAN so files from a previous run cannot leak into the new subset.
if os.path.exists(subset_dir):
    shutil.rmtree(subset_dir)
os.makedirs(subset_dir, exist_ok=True)

valid_exts = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
N_PATIENTS_NON_DEMENTED = 120   # 👈 you can change this if needed

# Seed for the Non Demented patient subsample. A private Random instance makes
# the selection reproducible without reseeding the global `random` module.
SUBSET_SEED = 42
subset_rng = random.Random(SUBSET_SEED)

# Sorted so the processing (and printed) order does not depend on the
# arbitrary order returned by os.listdir.
classes = sorted(d for d in os.listdir(original_dir)
                 if os.path.isdir(os.path.join(original_dir, d)))
print("Classes found:", classes)

for cls in classes:
    src_cls_dir = os.path.join(original_dir, cls)
    dst_cls_dir = os.path.join(subset_dir, cls)
    os.makedirs(dst_cls_dir, exist_ok=True)

    print(f"\nProcessing class: {cls}")

    # Image files of this class (sorted: os.listdir order is arbitrary).
    imgs = sorted(f for f in os.listdir(src_cls_dir)
                  if f.lower().endswith(valid_exts))

    if "non" not in cls.lower():   # any class that is NOT Non Demented
        print(f"  Copying ALL images ({len(imgs)})...")
        for fname in imgs:
            shutil.copy2(os.path.join(src_cls_dir, fname),
                         os.path.join(dst_cls_dir, fname))
        continue

    # Non Demented: down-sampled by PATIENT (not by image), so every kept
    # patient contributes all of their images.
    print(f"  Total images in {cls}: {len(imgs)}")
    print("  Example names:", imgs[:5])

    # Group by patient: OAS1_XXXX_MR... → patient_id = "OAS1_XXXX"
    images_by_patient = defaultdict(list)
    for fname in imgs:
        parts = fname.split('_')
        patient_id = "_".join(parts[:2]) if len(parts) >= 2 else parts[0]
        images_by_patient[patient_id].append(fname)

    print(f"  Total patients detected: {len(images_by_patient)}")

    # Sorting before sampling is what makes the seeded draw reproducible.
    patients = sorted(images_by_patient)
    n_keep = min(N_PATIENTS_NON_DEMENTED, len(patients))
    selected_patients = subset_rng.sample(patients, n_keep)

    print(f"  Patients we’ll keep in {cls}: {n_keep}")

    count_imgs = 0
    for pid in selected_patients:
        for fname in images_by_patient[pid]:
            shutil.copy2(os.path.join(src_cls_dir, fname),
                         os.path.join(dst_cls_dir, fname))
            count_imgs += 1

    print(f"  Images copied in {cls} (subset): {count_imgs}")

print("\n✅ Reduced dataset created at:", subset_dir)


Classes found: ['Moderate Dementia', 'Non Demented', 'Mild Dementia', 'Very mild Dementia']

Processing class: Moderate Dementia
  Copying ALL images (488)...

Processing class: Non Demented
  Total images in Non Demented: 67222
  Example names: ['OAS1_0097_MR1_mpr-3_127.jpg', 'OAS1_0136_MR1_mpr-4_144.jpg', 'OAS1_0317_MR1_mpr-2_155.jpg', 'OAS1_0212_MR1_mpr-2_106.jpg', 'OAS1_0162_MR1_mpr-1_128.jpg']
  Total patients detected: 266
  Patients we’ll keep in Non Demented: 120
  Images copied in Non Demented (subset): 29585

Processing class: Mild Dementia
  Copying ALL images (5002)...

Processing class: Very mild Dementia
  Copying ALL images (13725)...

✅ Reduced dataset created at: ./Data_subset


In [None]:
# Sanity check: number of valid image files in each class folder of the subset.
for class_name in os.listdir(subset_dir):
    class_path = os.path.join(subset_dir, class_name)
    if not os.path.isdir(class_path):
        continue
    n_images = sum(1 for f in os.listdir(class_path)
                   if f.lower().endswith(valid_exts))
    print(f"{class_name}: {n_images} images in subset")


Moderate Dementia: 488 images in subset
Non Demented: 29585 images in subset
Mild Dementia: 5002 images in subset
Very mild Dementia: 13725 images in subset


In [None]:
import os
import shutil
import random
from collections import defaultdict

subset_dir = "./Data_subset"
split_root = "./Data_split"

# Clean split dir so files from a previous run cannot end up in the wrong split.
if os.path.exists(split_root):
    shutil.rmtree(split_root)
os.makedirs(split_root, exist_ok=True)

splits = ["train", "val", "test"]
for s in splits:
    os.makedirs(os.path.join(split_root, s), exist_ok=True)

valid_exts = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')

# Sorted so processing order (and thus the RNG draw order) is reproducible.
classes = sorted(d for d in os.listdir(subset_dir)
                 if os.path.isdir(os.path.join(subset_dir, d)))
print("Classes in subset:", classes)

# Ratios for split (test receives whatever train/val leave over)
TRAIN_RATIO = 0.7
VAL_RATIO = 0.15
TEST_RATIO = 0.15
assert abs(TRAIN_RATIO + VAL_RATIO + TEST_RATIO - 1.0) < 1e-9, "split ratios must sum to 1"

# Seed for the patient shuffle; a private Random instance keeps the split
# reproducible without reseeding the global `random` module.
SPLIT_SEED = 42
split_rng = random.Random(SPLIT_SEED)


def _split_counts(n_total, train_ratio=TRAIN_RATIO, val_ratio=VAL_RATIO):
    """Return (n_train, n_val, n_test) patient counts that sum to n_total.

    With n_total >= 3 every split gets at least one patient. With exactly two
    patients, train and test get one each and val is empty, so the class can
    still be evaluated on the test set. A single patient goes to train.
    """
    if n_total <= 0:
        return 0, 0, 0
    if n_total == 1:
        return 1, 0, 0
    if n_total == 2:
        return 1, 0, 1
    n_train = max(1, int(train_ratio * n_total))
    n_val = max(1, int(val_ratio * n_total))
    # Hand patients back from the larger of train/val until test has one.
    # Terminates because n_total >= 3 leaves room for 1/1/1.
    while n_total - n_train - n_val < 1:
        if n_train >= n_val and n_train > 1:
            n_train -= 1
        else:
            n_val -= 1
    return n_train, n_val, n_total - n_train - n_val


def _copy_patients(images_by_patient, pat_list, src_dir, cls, split_name):
    """Copy every image of the patients in pat_list into split_root/split_name/cls.

    Prints and returns the number of images copied.
    """
    dst_base = os.path.join(split_root, split_name, cls)
    count = 0
    for pid in pat_list:
        for fname in images_by_patient[pid]:
            shutil.copy2(os.path.join(src_dir, fname),
                         os.path.join(dst_base, fname))
            count += 1
    print(f"    {split_name}: {count} images")
    return count


for cls in classes:
    cls_dir = os.path.join(subset_dir, cls)
    print(f"\nSplitting class: {cls}")

    # Group images by patient (OAS1_XXXX_MR... -> "OAS1_XXXX") so all images
    # of one patient land in the same split.
    images_by_patient = defaultdict(list)
    for fname in os.listdir(cls_dir):
        if not fname.lower().endswith(valid_exts):
            continue
        parts = fname.split('_')
        patient_id = "_".join(parts[:2]) if len(parts) >= 2 else parts[0]
        images_by_patient[patient_id].append(fname)

    # Sort before shuffling: os.listdir order is arbitrary, so the seeded
    # shuffle is only reproducible from a canonical starting order.
    patients = sorted(images_by_patient)
    split_rng.shuffle(patients)
    n_total = len(patients)

    n_train, n_val, n_test = _split_counts(n_total)
    train_patients = patients[:n_train]
    val_patients = patients[n_train:n_train + n_val]
    test_patients = patients[n_train + n_val:]
    assert len(test_patients) == n_test

    print(f"  Patients: total={n_total}, train={len(train_patients)}, val={len(val_patients)}, test={len(test_patients)}")

    # Create class dirs inside each split (even empty ones, for a consistent layout)
    for s in splits:
        os.makedirs(os.path.join(split_root, s, cls), exist_ok=True)

    _copy_patients(images_by_patient, train_patients, cls_dir, cls, "train")
    _copy_patients(images_by_patient, val_patients, cls_dir, cls, "val")
    _copy_patients(images_by_patient, test_patients, cls_dir, cls, "test")

print("\n✅ Patient-wise train/val/test split created in:", split_root)


Classes in subset: ['Moderate Dementia', 'Non Demented', 'Mild Dementia', 'Very mild Dementia']

Splitting class: Moderate Dementia
  Patients: total=2, train=1, val=1, test=0
    train: 244 images
    val: 244 images
    test: 0 images

Splitting class: Non Demented
  Patients: total=120, train=84, val=18, test=18
    train: 21045 images
    val: 4148 images
    test: 4392 images

Splitting class: Mild Dementia
  Patients: total=21, train=14, val=3, test=4
    train: 3416 images
    val: 671 images
    test: 915 images

Splitting class: Very mild Dementia
  Patients: total=58, train=40, val=8, test=10
    train: 9394 images
    val: 1952 images
    test: 2379 images

✅ Patient-wise train/val/test split created in: ./Data_split


In [None]:
# Per-split, per-class image counts. A class with 0 images in a split is
# flagged explicitly: it cannot be validated/evaluated on that split.
for split in ["train", "val", "test"]:
    print(f"\n=== {split.upper()} ===")
    split_dir = os.path.join(split_root, split)
    # Sorted for a stable, comparable listing across splits.
    for cls in sorted(os.listdir(split_dir)):
        cls_path = os.path.join(split_dir, cls)
        if os.path.isdir(cls_path):
            n = len([f for f in os.listdir(cls_path)
                     if f.lower().endswith(valid_exts)])
            warning = "  <-- WARNING: no images for this class" if n == 0 else ""
            print(f"  {cls}: {n} images{warning}")



=== TRAIN ===
  Moderate Dementia: 244 images
  Non Demented: 21045 images
  Mild Dementia: 3416 images
  Very mild Dementia: 9394 images

=== VAL ===
  Moderate Dementia: 244 images
  Non Demented: 4148 images
  Mild Dementia: 671 images
  Very mild Dementia: 1952 images

=== TEST ===
  Moderate Dementia: 0 images
  Non Demented: 4392 images
  Mild Dementia: 915 images
  Very mild Dementia: 2379 images
