# **CNN 2.5D**

In [None]:
# ==== Standard Libraries ====
import os, time, warnings

# ==== Scientific & Data Handling ====
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ==== PyTorch Core ====
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as T
from torchvision.models import resnet18

# ==== ML & Evaluation ====
from sklearn.model_selection import train_test_split, GroupShuffleSplit
from sklearn.metrics import (
    confusion_matrix,
    classification_report,
    balanced_accuracy_score,
    roc_auc_score
)
from sklearn.utils import resample
from statsmodels.stats.proportion import proportions_ztest

# ==== Utilities ====
from tqdm import tqdm

# ==== Reproducibility ====
# One seed for the whole notebook. It is reused by the data-splitting code and
# is applied here to the NumPy and PyTorch global generators; declaring the
# constant alone does not make any random operation repeatable.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# ==== Device Handling (DirectML + fallback CPU) ====
# torch_directml is optional: if importing or initialising it fails for any
# reason, fall back to the CPU and print the cause.
try:
    import torch_directml
    DEVICE = torch_directml.device()
    print("Using DirectML device:", DEVICE)
except Exception as e:
    DEVICE = torch.device("cpu")
    print("DirectML no disponible. Usando CPU:", e)

# ==== Warnings ====
# NOTE(review): this silences every warning for the rest of the session,
# including deprecation notices from the libraries used below.
warnings.filterwarnings('ignore')

In [None]:
# Input locations: the tabular CSV and the folder holding the .npy volumes.
CSV_PATH = '../Data/Tabular.csv'
# NOTE(review): absolute, machine-specific Windows path; adjust when running elsewhere.
FOLDER_PATH = r"C:\Users\usuario\MRI\IMAGES_npy"

# Directory for model outputs; created up front (no error if it already exists).
OUTPUT_DIR = "../Models_Output"
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
# Image attributes: shape, mean, standard deviation and element count of
# every .npy file found in FOLDER_PATH.
shapes, means, stds, size = [], [], [], []

npy_names = [entry for entry in os.listdir(FOLDER_PATH) if entry.endswith(".npy")]
for f in npy_names:
    img = np.load(os.path.join(FOLDER_PATH, f))
    shapes.append(img.shape)
    means.append(img.mean())
    stds.append(img.std())
    size.append(img.size)

# Collapse to the distinct shapes / element counts seen across all files
shapes_unicos, size_unicos = set(shapes), set(size)

# Averages of the per-file statistics (not statistics of the pooled voxels)
mean_of_means = np.mean(means)
mean_of_stds = np.mean(stds)

print("Shapes únicos encontrados:", shapes_unicos)
print("Total de imágenes:", len(shapes))
print(f"Media global promedio: {mean_of_means:.4f}")
print(f"Desviación global promedio: {mean_of_stds:.4f}")
print(f"Size únicos: {size_unicos}")

Shapes únicos encontrados: {(160, 192, 192)}
Total de imágenes: 220
Media global promedio: 0.0000
Desviación global promedio: 1.0000
Size únicos: {5898240}


In [None]:
# LOAD CSV
# ============================
# Both ID columns are read as text: they are only used to build file names,
# and numeric parsing would drop leading zeros (or append ".0" when the
# column contains missing values), producing paths that do not exist.
df = pd.read_csv(CSV_PATH, dtype={'sujeto_id': str, 'imagen_id': str})
# Rows without a target value are discarded.
df = df.dropna(subset=['is_dementia'])

# Remove stray whitespace around the IDs before they are used in file names.
df['sujeto_id'] = df['sujeto_id'].astype(str).str.strip()
df['imagen_id'] = df['imagen_id'].astype(str).str.strip()

# FULL FILE PATH
def build_path(row):
    """Return the volume path for one row: FOLDER_PATH/<sujeto_id>_<imagen_id>.npy."""
    sujeto, imagen = row['sujeto_id'], row['imagen_id']
    return os.path.join(FOLDER_PATH, f"{sujeto}_{imagen}.npy")

# Resolve the volume path of every row, then keep only the three columns
# used downstream, exposing the target under the name 'label'.
df['file_path'] = df.apply(build_path, axis=1)

kept_columns = ['file_path', 'sujeto_id', 'is_dementia']
df = df[kept_columns]
df = df.rename(columns={'is_dementia': 'label'})


In [None]:
# 1. Basic distribution of the target (`is_dementia`, renamed to `label` above)

class_distribution = df['label'].value_counts().sort_index()
N = len(df)
count_no_dementia = class_distribution.get(0, 0)
count_dementia = class_distribution.get(1, 0)

print("=== DISTRIBUCIÓN DE CLASES ===")
print(f"Clase 0 (No Dementia): {count_no_dementia} imágenes - {count_no_dementia/N*100:.2f}%")
print(f"Clase 1 (Dementia): {count_dementia} imágenes - {count_dementia/N*100:.2f}%")

# Two-sided one-sample z-test of the class-1 proportion against H0: p = 0.5
expected_proportion = 0.5
z_stat, p_value = proportions_ztest(
    count=count_dementia,
    nobs=N,
    value=expected_proportion,
    alternative='two-sided',
)
print(f"Estadístico Z: {z_stat:.4f}")
print(f"Valor p: {p_value:.4e}")

# Conclusion at the 5% significance level
alpha = 0.05
h0_rejected = p_value < alpha
if h0_rejected:
    print(f"\nConclusión: Se rechaza H0. La proporción de la Clase 1 ({count_dementia/N*100:.2f}%) es SIGNIFICATIVAMENTE diferente de {expected_proportion*100:.0f}%.")
else:
    print(f"\nConclusión: No se rechaza H0. No hay evidencia suficiente para decir que la proporción es diferente de {expected_proportion*100:.0f}%.")

=== DISTRIBUCIÓN DE CLASES ===
Clase 0 (No Dementia): 158 imágenes - 71.82%
Clase 1 (Dementia): 62 imágenes - 28.18%
Estadístico Z: -7.1933
Valor p: 6.3244e-13

Conclusión: Se rechaza H0. La proporción de la Clase 1 (28.18%) es SIGNIFICATIVAMENTE diferente de 50%.


In [None]:
# [PASO 1] Group stratified split (returns DataFrames)
# ---------------------------

def group_stratified_split(records, seed=None):
    """
    Split records into train/val/test (70% / 15% / 15% of the subjects) so that:
        - every image of a subject ends up in exactly one subset, and
        - both splits are stratified on the subject-level label.

    The split is performed over unique subjects, so grouping is respected by
    construction. The previous implementation used GroupShuffleSplit for the
    70/30 step, which ignores the labels and therefore did not stratify it.

    Args:
        records (list of dict): Rows containing at least the keys 'sujeto_id'
            and 'label' (e.g. the result of df.to_dict("records")).
        seed (int, optional): Random seed used for both splits. Defaults to
            the module-level SEED.

    Returns:
        train_df, val_df, test_df (pd.DataFrame): Rows of each subset, with a
        fresh 0..n-1 index.

    Note:
        Stratification needs enough subjects per class in each step;
        train_test_split is expected to raise ValueError otherwise.
    """
    if seed is None:
        seed = SEED

    df_local = pd.DataFrame(records)

    # One label per subject: the mean of its image labels rounded to an int.
    # An exact 0.5 mean rounds to 0 (round-half-to-even).
    subj_lab = df_local.groupby("sujeto_id")["label"].agg(lambda x: int(round(x.mean())))
    subjects = subj_lab.index.to_list()   # unique subject IDs
    y = subj_lab.to_list()                # subject-level labels, same order

    # Train (70%) vs. rest (30%), stratified on the subject label
    train_subj, rest_subj = train_test_split(
        subjects,
        train_size=0.7,
        random_state=seed,
        stratify=y
    )

    # Rest -> validation / test in equal parts, again stratified
    rest_labels = [subj_lab.loc[s] for s in rest_subj]
    val_subj, test_subj = train_test_split(
        rest_subj,
        test_size=0.5,
        random_state=seed,
        stratify=rest_labels
    )

    def pick(subjects_list):
        # All rows belonging to the given subjects
        return df_local[df_local["sujeto_id"].isin(subjects_list)].reset_index(drop=True)

    return pick(train_subj), pick(val_subj), pick(test_subj)


In [None]:
# [PASO 2] Construcción del DataSet MRI 2.5D
# ---------------------------

class MRI2p5DDataset(Dataset):
    """
    Dataset that loads a volume from a .npy file and returns it as a 2.5D
    tensor: `n_slices` evenly spaced slices along axis 0, each rescaled and
    resized, stacked as channels.
    """

    def __init__(self, df, n_slices=32, target_size=(224,224), augment=False):
        """
        df: DataFrame with columns 'file_path' and 'label'
        n_slices: number of slices taken from each volume
        target_size: (H, W) of every output slice
        augment: apply random rotation + random resized crop
        """
        self.records = df.to_dict("records")
        self.n_slices = n_slices
        self.target_size = target_size
        self.augment = augment

        # The two random transforms are kept as separate objects so their
        # parameters can be sampled once per volume (see __getitem__).
        self.rotation_tf = T.RandomRotation(10)
        self.crop_tf = T.RandomResizedCrop(target_size, scale=(0.9,1.0))
        # Kept so existing code that reads `augment_tf` keeps working.
        self.augment_tf = T.Compose([self.rotation_tf, self.crop_tf])
        self.base_tf = T.Compose([
            T.Resize(target_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.records)

    def _sample_augmentation(self, img):
        """Draw one rotation angle and one crop box (top, left, height, width) for `img`."""
        angle = T.RandomRotation.get_params(self.rotation_tf.degrees)
        crop_box = T.RandomResizedCrop.get_params(
            img, self.crop_tf.scale, self.crop_tf.ratio
        )
        return angle, crop_box

    def _apply_augmentation(self, img, angle, crop_box):
        """Rotate `img` by `angle`, then crop `crop_box` and resize to the crop transform's size."""
        from torchvision.transforms import functional as TF

        rotated = TF.rotate(img, angle)
        top, left, height, width = crop_box
        return TF.resized_crop(rotated, top, left, height, width, self.crop_tf.size)

    def __getitem__(self, idx):
        """Return (tensor of shape (n_slices, H, W), float32 label) for record `idx`."""
        rec = self.records[idx]
        vol = np.load(rec["file_path"])  # expected (D, H, W); axis 0 is the slice axis

        # Pick n_slices indices spread uniformly over the slice axis
        depth = vol.shape[0]
        idxs = np.linspace(0, depth - 1, self.n_slices).astype(int)

        # Augmentation parameters are sampled ONCE per volume and reused for
        # every slice. Sampling them per slice (as before) rotates/crops each
        # channel differently and misaligns the stack.
        aug_params = None

        imgs = []
        for s in vol[idxs]:
            # Per-slice min-max rescale to [0, 255]; the epsilon avoids a
            # division by zero on constant slices.
            s_img = ((s - s.min())/(s.max()-s.min()+1e-6) * 255).astype(np.uint8)
            img = Image.fromarray(s_img).convert("L")  # single channel
            if self.augment:
                if aug_params is None:
                    aug_params = self._sample_augmentation(img)
                img = self._apply_augmentation(img, *aug_params)
            imgs.append(self.base_tf(img))  # (1, H, W)

        # Stack slices as channels -> (n_slices, H, W)
        input_tensor = torch.cat(imgs, dim=0)

        label = torch.tensor(rec["label"], dtype=torch.float32)
        return input_tensor, label

In [None]:
# [PASO 3] Modelo CNN 2.5D
# ---------------------------

class CNN2p5D(nn.Module):
    """
    Plain CNN for 2.5D volumes: the `n_slices` slices are the input channels.

    Four Conv-BatchNorm-ReLU-MaxPool stages (32, 64, 128, 256 filters) are
    followed by a two-layer classifier that outputs one logit per sample.
    The first linear layer alone holds 256*14*14*256 (about 12.8M) weights.

    The classifier expects a 14x14 feature map, which is what a 224x224 input
    produces. An adaptive average pool now enforces that size, so other input
    resolutions work as well. For 224x224 inputs the pool is an identity, and
    since it has no parameters existing checkpoints load unchanged.
    """
    def __init__(self, n_slices=32):
        super().__init__()

        self.features = nn.Sequential(
            nn.Conv2d(n_slices, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),         # 224 -> 112

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),         # 112 -> 56

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),         # 56 -> 28

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2),         # 28 -> 14
        )

        # Kept outside `features`/`classifier` so their layer indices (and
        # therefore the state_dict keys) stay exactly as before.
        self.pool = nn.AdaptiveAvgPool2d((14, 14))

        # Final classifier
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 14 * 14, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        """Map a (batch, n_slices, H, W) tensor to a (batch,) tensor of logits."""
        x = self.features(x)
        x = self.pool(x)
        x = self.classifier(x)
        return x.squeeze(1)


In [None]:
# [PASO 4] Entrenamiento por epoca y evaluacion
# ---------------------------

def train_epoch(model, loader, criterion, optimizer, device):
    """
    Run one optimisation pass over `loader` and return the mean loss per
    sample (each batch loss is weighted by the batch size).
    """
    model.train()  # training mode
    running_loss, seen = 0, 0
    progress = tqdm(loader, desc="Train", leave=False)

    for inputs, targets in progress:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        n_in_batch = targets.size(0)
        loss_value = batch_loss.item()
        running_loss += loss_value * n_in_batch
        seen += n_in_batch
        progress.set_postfix({"batch_loss": loss_value})  # live per-batch loss

    return running_loss / seen


def evaluate(model, loader, device):
    """
    Evaluate `model` on `loader` without computing gradients.

    Returns:
        (acc, bal_acc, auc): accuracy and balanced accuracy of the
        probabilities rounded to 0/1, and ROC AUC of the probabilities.
        `auc` is NaN when roc_auc_score raises ValueError.
    """
    model.eval()  # evaluation mode
    prob_batches, label_batches = [], []

    with torch.no_grad():
        for X, y in tqdm(loader, desc="Eval", leave=False):
            logits = model(X.to(device))
            prob_batches.append(torch.sigmoid(logits).cpu())  # logit -> probability
            label_batches.append(y.cpu())

    # Concatenate all batches
    probs = torch.cat(prob_batches)
    labels = torch.cat(label_batches)
    hard_preds = probs.round()  # 0/1 decisions

    # Explicit NumPy arrays for scikit-learn
    y_true = labels.numpy()
    y_hard = hard_preds.numpy()
    y_score = probs.numpy()

    acc = (hard_preds == labels).float().mean().item()
    bal_acc = balanced_accuracy_score(y_true, y_hard)
    try:
        auc = roc_auc_score(y_true, y_score)
    except ValueError:
        auc = float('nan')

    return acc, bal_acc, auc

In [None]:
# [PASO 5] DataLoaders
# ---------------------------
def prepare_dataloaders(df, batch_size, n_slices=32, target_size=(224,224)):
    """
    Split `df` by subject, oversample the minority class of the training
    subset, and wrap the three subsets in DataLoaders.

    Args:
        df: DataFrame with columns 'file_path', 'sujeto_id' and 'label' (0/1).
        batch_size: batch size used by all three loaders.
        n_slices, target_size: forwarded to MRI2p5DDataset.

    Returns:
        train_loader, val_loader, test_loader. Only the training set is
        oversampled, augmented and shuffled.
    """
    records = df.to_dict("records")
    train_df, val_df, test_df = group_stratified_split(records)

    # Show the original distribution
    print("=== Distribución original en Train ===")
    print(train_df['label'].value_counts())

    # Oversample whichever class is smaller. Class 1 is treated as the
    # minority on ties, which matches the previous fixed assignment.
    n_pos = int((train_df['label'] == 1).sum())
    n_neg = int((train_df['label'] == 0).sum())
    minority, majority = (1, 0) if n_pos <= n_neg else (0, 1)

    if n_pos > 0 and n_neg > 0:
        df_min = train_df[train_df['label'] == minority]
        df_maj = train_df[train_df['label'] == majority]

        df_min_upsampled = resample(
            df_min,
            replace=True,
            n_samples=len(df_maj),
            random_state=SEED
        )
        train_df_bal = pd.concat([df_maj, df_min_upsampled]).sample(frac=1, random_state=SEED)
    else:
        # With a single class there is nothing to resample from; keep the
        # data as is instead of failing inside resample().
        print("Oversampling omitido: el conjunto de entrenamiento solo contiene una clase.")
        train_df_bal = train_df.sample(frac=1, random_state=SEED)

    # Show the distribution after oversampling
    print("=== Distribución después de oversampling en Train ===")
    print(train_df_bal['label'].value_counts())

    # Build datasets
    train_dataset = MRI2p5DDataset(train_df_bal,
                                   n_slices=n_slices, target_size=target_size, augment=True)
    val_dataset   = MRI2p5DDataset(val_df, n_slices=n_slices, target_size=target_size, augment=False)
    test_dataset  = MRI2p5DDataset(test_df, n_slices=n_slices, target_size=target_size, augment=False)

    # Build DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader   = DataLoader(val_dataset, batch_size=batch_size)
    test_loader  = DataLoader(test_dataset, batch_size=batch_size)

    # Report the size of each loader
    print(f"\nTrain Loader: {len(train_loader.dataset)} ejemplos, {len(train_loader)} batches")
    print(f"Validation Loader: {len(val_loader.dataset)} ejemplos, {len(val_loader)} batches")
    print(f"Test Loader: {len(test_loader.dataset)} ejemplos, {len(test_loader)} batches")

    return train_loader, val_loader, test_loader

In [None]:
# [PASO 6] Modulo de entrenamiento
# ---------------------------
def train_model(
        model, train_loader, val_loader, device, 
        epochs, lr, weight_decay, early_stopping_patience,
        pos_weight=None):
    """
    Train `model` with early stopping on validation balanced accuracy.

    Args:
        model, train_loader, val_loader, device: objects used for training
            and for the per-epoch validation.
        epochs: maximum number of epochs.
        lr, weight_decay: Adam hyper-parameters.
        early_stopping_patience: consecutive epochs without improvement in
            validation balanced accuracy before stopping.
        pos_weight: weight of the positive class in BCEWithLogitsLoss. When
            None it is derived from the labels the training loader actually
            serves (neg / pos). Those labels may already be oversampled, so
            the fixed full-dataset ratio 158/62 used before would weight the
            positive class twice. Pass 158 / 62 to reproduce the old loss.

    Returns:
        train_losses, val_bal_accs, val_aucs, model_path, where `model_path`
        is the checkpoint of the best epoch. A checkpoint is always written,
        so the returned path can always be loaded.
    """
    print(f"\n⏺️ Entrenando modelo: \nÉpocas {epochs} | LR-WD {lr}-{ weight_decay} | Patience {early_stopping_patience}")

    if pos_weight is None:
        pos_weight = _infer_pos_weight(train_loader)
    # Keep the weight on the same device as the logits
    pos_weight = torch.as_tensor(pos_weight, dtype=torch.float32).to(device)

    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)

    train_losses, val_bal_accs, val_aucs = [], [], []
    # -inf (not 0) so the first evaluated epoch is always saved; with 0 a run
    # whose balanced accuracy stayed at 0 never assigned `model_path`.
    best_bal_acc = float("-inf")
    model_path = None
    patience_counter = 0

    print("▶️ Start train")
    time_all = time.time()

    for ep in range(epochs):
        time_ep = time.time()
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        time_train = time.time() - time_ep
        acc, bal_acc, auc = evaluate(model, val_loader, device)

        train_losses.append(train_loss)
        val_bal_accs.append(bal_acc)
        val_aucs.append(auc)

        print(f"Epoch {ep+1}/{epochs} | TrainLoss={train_loss:.4f} || Validation:  Acc={acc:.4f} | Balance={bal_acc:.4f} | AUC={auc:.4f} || Time: {time_train:.1f}")

        # The scheduler halves the LR when the training loss plateaus
        scheduler.step(train_loss)

        # Early stopping + checkpoint of the best model
        if bal_acc > best_bal_acc:
            best_bal_acc = bal_acc
            patience_counter = 0
            # Save under OUTPUT_DIR, the directory created at start-up
            model_path = os.path.join(
                OUTPUT_DIR, f"model_{epochs}_{ep}_{early_stopping_patience}.pth")
            torch.save(model.state_dict(), model_path)
            print(">> Mejor modelo guardado")
        else:
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                print(">> Early stopping activado")
                break

    if model_path is None:
        # No epoch produced a checkpoint (e.g. epochs <= 0): persist the
        # current weights so callers still receive a loadable path.
        model_path = os.path.join(
            OUTPUT_DIR, f"model_{epochs}_final_{early_stopping_patience}.pth")
        torch.save(model.state_dict(), model_path)

    print(f"Tiempo total de entrenamiento: {(time.time() - time_all)/60:.2f} minutos")
    return train_losses, val_bal_accs, val_aucs, model_path


def _infer_pos_weight(train_loader, fallback=158 / 62):
    """
    Return neg/pos computed from `train_loader.dataset.records`.

    Falls back to `fallback` (the previous hard-coded ratio) when the dataset
    does not expose label records, and to 1.0 when a class is missing.
    """
    records = getattr(getattr(train_loader, "dataset", None), "records", None)
    try:
        labels = [rec["label"] for rec in records]
    except (TypeError, KeyError):
        return fallback

    pos = sum(1 for value in labels if value == 1)
    neg = len(labels) - pos
    if pos == 0 or neg == 0:
        return 1.0
    return neg / pos


In [None]:
# [PASO 7] Modulo de evaluación final
# ---------------------------
def evaluate_final(model, test_loader, device, model_path):
    """
    Load the checkpoint at `model_path` into `model` and report test metrics.

    AUC is computed from the predicted probabilities. It was previously
    computed from the thresholded 0/1 predictions, which reduces the ROC
    curve to a single operating point and misreports the ranking quality.

    Returns:
        acc, bal_acc, auc, cm, report: accuracy, balanced accuracy, ROC AUC
        (NaN when roc_auc_score raises ValueError), the 2x2 confusion matrix
        and the precision/recall/F1 text report.
    """
    model.load_state_dict(torch.load(model_path))
    model.eval()

    prob_batches, label_batches = [], []

    with torch.no_grad():
        for X, y in test_loader:
            logits = model(X.to(device))
            prob_batches.append(torch.sigmoid(logits).cpu())
            label_batches.append(y.cpu())

    all_probs = torch.cat(prob_batches)
    all_labels = torch.cat(label_batches)
    all_preds = (all_probs > 0.5).float()  # hard decisions at threshold 0.5

    # Explicit NumPy arrays for scikit-learn
    y_true = all_labels.numpy()
    y_pred = all_preds.numpy()
    y_score = all_probs.numpy()

    acc = (all_preds == all_labels).float().mean().item()
    bal_acc = balanced_accuracy_score(y_true, y_pred)
    try:
        auc = roc_auc_score(y_true, y_score)
    except ValueError:
        auc = float('nan')

    print(f"Test Acc={acc:.4f} | Test Balanced Acc={bal_acc:.4f} | Test AUC={auc:.4f}")

    # Confusion matrix; fixing the labels keeps it 2x2 even when one class
    # is absent from the labels and predictions.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    print("Matriz de confusión:")
    print(cm)

    # Precision / Recall / F1 report; `labels` keeps the two target names
    # valid when only one class appears.
    report = classification_report(
        y_true, y_pred, labels=[0, 1], target_names=["No Dementia", "Dementia"])
    print("Reporte de clasificación:")
    print(report)

    return acc, bal_acc, auc, cm, report


In [None]:
# ---------------------------
# MAIN
# ---------------------------
def main(df, device, EPOCHS, LR, WEIGHT, EARLY, BATCH):
    """
    Full pipeline: build the loaders, train a 32-slice CNN2p5D, plot the
    training curves and evaluate the best checkpoint on the test set.
    """
    loaders = prepare_dataloaders(df, BATCH)
    train_loader, val_loader, test_loader = loaders

    model = CNN2p5D(n_slices=32).to(device)

    history = train_model(
        model, train_loader, val_loader, device,
        EPOCHS, LR, WEIGHT, EARLY)
    train_losses, val_bal_accs, val_aucs, model_path = history

    # Curves: training loss on the left, validation scores on the right
    panels = [
        ('Loss', [(train_losses, 'Train Loss')]),
        ('Score', [(val_bal_accs, 'Val Balanced Acc'), (val_aucs, 'Val AUC')]),
    ]
    plt.figure(figsize=(12,5))
    for position, (y_label, curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for values, curve_name in curves:
            plt.plot(values, label=curve_name)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()
        plt.grid(True)
    plt.show()

    # Final evaluation
    evaluate_final(model, test_loader, device, model_path)
