<a href="https://colab.research.google.com/github/RAvila-bioeng/M.R.AI/blob/main/alzheimer's.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install kaggle



In [None]:

from google.colab import drive
drive.mount('/content/gdrive')

import os
os.environ['KAGGLE_CONFIG_DIR'] = "/content/gdrive/MyDrive/Kaggle_API"

# Descargar dataset
!kaggle datasets download -d ninadaithal/imagesoasis

# Descomprimir
!unzip -q imagesoasis.zip

import os
print("Contenido del directorio actual:", os.listdir("."))
print("Contenido de 'Data' si existe:", os.listdir("Data") if os.path.exists("Data") else "No existe 'Data'")


Mounted at /content/gdrive
Dataset URL: https://www.kaggle.com/datasets/ninadaithal/imagesoasis
License(s): apache-2.0
Downloading imagesoasis.zip to /content
 99% 1.22G/1.23G [00:08<00:00, 301MB/s]
100% 1.23G/1.23G [00:08<00:00, 152MB/s]
Contenido del directorio actual: ['.config', 'imagesoasis.zip', 'Data', 'Data_split_2cls', 'gdrive', 'sample_data']
Contenido de 'Data' si existe: ['Moderate Dementia', 'Non Demented', 'Mild Dementia', 'Very mild Dementia']


In [None]:
import os
import shutil
import random
from collections import defaultdict

# Build ./Data_subset: every image of the dementia classes plus the images of
# at most N_PATIENTS_NON_DEMENTED randomly chosen patients of the
# "Non Demented" class.
original_dir = "./Data"
assert os.path.exists(original_dir), "No existe ./Data. Revisa la celda 1."

subset_dir = "./Data_subset"

# Start from a clean output directory so files from a previous run never linger
if os.path.exists(subset_dir):
    shutil.rmtree(subset_dir)
os.makedirs(subset_dir, exist_ok=True)

valid_exts = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
N_PATIENTS_NON_DEMENTED = 120  # maximum number of healthy patients to keep
SUBSET_SEED = 42               # fixes which healthy patients are sampled

# Dedicated generator: the sampling is reproducible across runs and the
# module-level random state is left untouched.
subset_rng = random.Random(SUBSET_SEED)

classes = [d for d in os.listdir(original_dir)
           if os.path.isdir(os.path.join(original_dir, d))]
print("Clases encontradas en Data:", classes)

for cls in classes:
    src_cls_dir = os.path.join(original_dir, cls)
    dst_cls_dir = os.path.join(subset_dir, cls)
    os.makedirs(dst_cls_dir, exist_ok=True)

    print(f"\nProcesando clase: {cls}")

    # os.listdir returns entries in arbitrary order; sorting keeps the
    # patient sampling below independent of the filesystem.
    imgs = sorted(f for f in os.listdir(src_cls_dir)
                  if f.lower().endswith(valid_exts))

    if "non" not in cls.lower():   # every class except Non Demented
        print(f"  Copiando TODAS las imágenes ({len(imgs)})...")
        for fname in imgs:
            shutil.copy2(os.path.join(src_cls_dir, fname),
                         os.path.join(dst_cls_dir, fname))

    else:
        # Non Demented: reduce the class by number of patients, not images,
        # so each kept patient contributes all of their images.
        print(f"  Imágenes totales en {cls}: {len(imgs)}")
        print("  Ejemplos de nombres:", imgs[:5])

        images_by_patient = defaultdict(list)
        for fname in imgs:
            parts = fname.split('_')
            # OAS1_0097_MR1_mpr-3_127.jpg -> patient = "OAS1_0097"
            patient_id = "_".join(parts[:2]) if len(parts) >= 2 else parts[0]
            images_by_patient[patient_id].append(fname)

        print(f"  Pacientes totales detectados: {len(images_by_patient)}")

        patients = sorted(images_by_patient)
        n_keep = min(N_PATIENTS_NON_DEMENTED, len(patients))
        selected_patients = subset_rng.sample(patients, n_keep)

        print(f"  Pacientes que vamos a conservar en {cls}: {n_keep}")

        count_imgs = 0
        for pid in selected_patients:
            for fname in images_by_patient[pid]:
                shutil.copy2(os.path.join(src_cls_dir, fname),
                             os.path.join(dst_cls_dir, fname))
                count_imgs += 1

        print(f"  Imágenes copiadas en {cls} (subset): {count_imgs}")

print("\n✅ Data_subset creado en:", subset_dir)

# Verify the per-class image counts of the subset
for cls in os.listdir(subset_dir):
    cls_path = os.path.join(subset_dir, cls)
    if os.path.isdir(cls_path):
        n = len([f for f in os.listdir(cls_path)
                 if f.lower().endswith(valid_exts)])
        print(f"{cls}: {n} imágenes en Data_subset")


Clases encontradas en Data: ['Moderate Dementia', 'Non Demented', 'Mild Dementia', 'Very mild Dementia']

Procesando clase: Moderate Dementia
  Copiando TODAS las imágenes (488)...

Procesando clase: Non Demented
  Imágenes totales en Non Demented: 67222
  Ejemplos de nombres: ['OAS1_0076_MR1_mpr-1_114.jpg', 'OAS1_0368_MR2_mpr-4_135.jpg', 'OAS1_0348_MR1_mpr-3_130.jpg', 'OAS1_0001_MR1_mpr-2_140.jpg', 'OAS1_0145_MR2_mpr-1_111.jpg']
  Pacientes totales detectados: 266
  Pacientes que vamos a conservar en Non Demented: 120
  Imágenes copiadas en Non Demented (subset): 30134

Procesando clase: Mild Dementia
  Copiando TODAS las imágenes (5002)...

Procesando clase: Very mild Dementia
  Copiando TODAS las imágenes (13725)...

✅ Data_subset creado en: ./Data_subset
Moderate Dementia: 488 imágenes en Data_subset
Non Demented: 30134 imágenes en Data_subset
Mild Dementia: 5002 imágenes en Data_subset
Very mild Dementia: 13725 imágenes en Data_subset


In [None]:
import os
import shutil
import random
from collections import defaultdict

# Input directory (the subset with a reduced number of healthy patients)
subset_dir = "./Data_subset"

# Output directory for the binary train/val/test split
split_root = "./Data_split_2cls"
# Start clean. The subset can contain different patients from one run to the
# next, so files left over from an earlier split could otherwise put the same
# patient's images in more than one of train/val/test.
if os.path.exists(split_root):
    shutil.rmtree(split_root)
os.makedirs(split_root, exist_ok=True)

# Map from original class folder -> binary logical class
CLASS_MAP = {
    "Non Demented": "Non Demented",
    "Very mild Dementia": "Demented",
    "Mild Dementia": "Demented",
    "Moderate Dementia": "Demented",
}

logical_classes = sorted(set(CLASS_MAP.values()))
print("Logical classes:", logical_classes)

# Create the folder layout: <split_root>/<train|val|test>/<logical class>
for split in ["train", "val", "test"]:
    for cls in logical_classes:
        os.makedirs(os.path.join(split_root, split, cls), exist_ok=True)

# Group image paths by (logical class, patient)
class_patient_images = {cls: defaultdict(list) for cls in logical_classes}

# Listings are sorted because os.listdir order is arbitrary; a fixed order
# makes the grouping (and anything derived from it) repeatable.
for orig_class in sorted(os.listdir(subset_dir)):
    orig_class_path = os.path.join(subset_dir, orig_class)
    if not os.path.isdir(orig_class_path):
        continue

    if orig_class not in CLASS_MAP:
        print(f"Saltando clase desconocida: {orig_class}")
        continue

    logical_class = CLASS_MAP[orig_class]

    for fname in sorted(os.listdir(orig_class_path)):
        fpath = os.path.join(orig_class_path, fname)
        if not os.path.isfile(fpath):
            continue

        # Keep image files only
        if not fname.lower().endswith((".jpg", ".jpeg", ".png")):
            continue

        # Patient id = first two "_"-separated fields of names that start
        # with "OAS1" (e.g. "OAS1_0097"); otherwise fall back to the first field.
        parts = fname.split("_")
        if len(parts) >= 2 and parts[0].startswith("OAS1"):
            patient_id = parts[0] + "_" + parts[1]
        else:
            patient_id = parts[0]

        class_patient_images[logical_class][patient_id].append(fpath)

# Patient-level train/val/test split
def split_patients(patients, train_ratio=0.7, val_ratio=0.15, seed=42):
    """Shuffle patient ids and cut them into train / val / test lists.

    Args:
        patients: iterable of patient identifiers.
        train_ratio: fraction of patients assigned to train (at least 1 patient).
        val_ratio: fraction of patients assigned to val (at least 1 patient).
        seed: seed of the shuffle; the same seed and input order always give
            the same split.

    Returns:
        A tuple ``(train_patients, val_patients, test_patients)`` of disjoint
        lists that together contain every input patient exactly once.
        An empty input yields three empty lists. With very few patients some
        lists can be empty (1 patient -> train only; 2 patients -> train and
        val, empty test).
    """
    # Private generator: yields the same permutation as seeding the module-level
    # RNG with `seed`, but does not reset the global random state for callers.
    rng = random.Random(seed)
    patients = list(patients)
    rng.shuffle(patients)

    n = len(patients)
    if n == 0:
        return [], [], []

    n_train = max(1, int(train_ratio * n))
    n_val = max(1, int(val_ratio * n))

    # If train + val would consume every patient, free one slot for test:
    # take it from val when val has more than one patient, else from train.
    if n - n_train - n_val <= 0:
        if n_val > 1:
            n_val -= 1
        else:
            n_train = max(1, n_train - 1)

    train_patients = patients[:n_train]
    val_patients = patients[n_train:n_train + n_val]
    # Test receives whatever remains after train and val.
    test_patients = patients[n_train + n_val:]

    return train_patients, val_patients, test_patients

# Split each logical class by patient and copy the images into
# <split_root>/<split>/<logical class>/
for logical_class in logical_classes:
    patients_dict = class_patient_images[logical_class]
    # Sort the ids: the dict's order follows directory-listing order, which is
    # arbitrary, and the seeded shuffle is only repeatable for a fixed input order.
    patients = sorted(patients_dict)
    print(f"\nClase '{logical_class}': {len(patients)} pacientes")

    train_p, val_p, test_p = split_patients(patients)

    print(f"  Train: {len(train_p)} pacientes")
    print(f"  Val:   {len(val_p)} pacientes")
    print(f"  Test:  {len(test_p)} pacientes")

    for split_name, split_pat_list in [("train", train_p),
                                       ("val", val_p),
                                       ("test", test_p)]:
        # All images of one patient go to the same split.
        split_cls_dir = os.path.join(split_root, split_name, logical_class)
        for pid in split_pat_list:
            for src_path in patients_dict[pid]:
                dst_path = os.path.join(split_cls_dir, os.path.basename(src_path))
                shutil.copy2(src_path, dst_path)

print("\nCopiado terminado. Comprobando número de imágenes por split y clase...\n")

# Final check: number of images per split and class
for split in ["train", "val", "test"]:
    print(f"Split: {split}")
    for cls in logical_classes:
        cls_dir = os.path.join(split_root, split, cls)
        if not os.path.isdir(cls_dir):
            print(f"  {cls}: 0 imágenes (directorio no encontrado)")
            continue
        n_images = sum(
            1 for f in os.listdir(cls_dir)
            if f.lower().endswith((".jpg", ".jpeg", ".png"))
        )
        print(f"  {cls}: {n_images} imágenes")
    print()


Logical classes: ['Demented', 'Non Demented']

Clase 'Demented': 81 pacientes
  Train: 56 pacientes
  Val:   12 pacientes
  Test:  13 pacientes

Clase 'Non Demented': 120 pacientes
  Train: 84 pacientes
  Val:   18 pacientes
  Test:  18 pacientes

Copiado terminado. Comprobando número de imágenes por split y clase...

Split: train
  Demented: 13298 imágenes
  Non Demented: 21289 imágenes

Split: val
  Demented: 2867 imágenes
  Non Demented: 4331 imágenes

Split: test
  Demented: 3050 imágenes
  Non Demented: 4514 imágenes



In [None]:
# Report, for every split, how many image files each class folder on disk holds.
# Relies on `split_root` and `valid_exts` defined by earlier cells.
for split in ("train", "val", "test"):
    print(f"\n=== {split.upper()} ===")
    split_dir = os.path.join(split_root, split)
    for cls in os.listdir(split_dir):
        cls_path = os.path.join(split_dir, cls)
        if not os.path.isdir(cls_path):
            continue  # ignore stray files next to the class folders
        n = sum(1 for entry in os.listdir(cls_path)
                if entry.lower().endswith(valid_exts))
        print(f"  {cls}: {n} imágenes")



=== TRAIN ===
  Non Demented: 21289 imágenes
  Demented: 13298 imágenes

=== VAL ===
  Non Demented: 4331 imágenes
  Demented: 2867 imágenes

=== TEST ===
  Non Demented: 4514 imágenes
  Demented: 3050 imágenes
