In [7]:
import os
import yaml
import torch
import pandas as pd
import numpy as np
import pickle
import logging
from sklearn.model_selection import StratifiedKFold, train_test_split

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# Mount or configure Google Drive only if applicable
# ...

# Project paths
project_dir = '/home/diego/Escritorio/santiago/1st_paper/116ROIs'  # or your local path
csv_path = os.path.join(project_dir, 'DataBaseSubjects.csv')
tensor_data_dir = os.path.join(project_dir, 'TensorData')

# Load the YAML config. yaml.safe_load returns None for an empty file, so fall
# back to an empty dict to keep the .get() lookups below working.
config_path = os.path.join(project_dir, 'config.yaml')
with open(config_path, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f) or {}

# Seed NumPy and PyTorch RNGs; the same seed is passed to the sklearn splitters below.
seed = config.get('seed', 42)
np.random.seed(seed)
torch.manual_seed(seed)

# Subject metadata table (ResearchGroup, Sex, SubjectID, ...).
subjects_df = pd.read_csv(csv_path)

# (Optional) collapse every ResearchGroup other than AD and CN into a single
# 'Others' class (the same 'Others' label used for the tensor file names below).
# Uncomment the line after this function to apply it to the DataFrame.
def map_3class_group(x):
    """Map a ResearchGroup value to one of 'AD', 'CN' or 'Others'.

    'AD' and 'CN' are returned as-is; every other value (including NaN/None)
    becomes 'Others'.
    """
    for label in ('AD', 'CN'):
        if x == label:
            return label
    return 'Others'
#subjects_df['ResearchGroup'] = subjects_df['ResearchGroup'].apply(map_3class_group)

# Combined (ResearchGroup, Sex) label stored on the DataFrame.
# NOTE(review): this column is not read anywhere in this cell; the CV split
# below stratifies on the tensor_groups keys instead — confirm which is intended.
subjects_df['Group_Sex'] = subjects_df['ResearchGroup'].astype(str) + '_' + subjects_df['Sex'].astype(str)

# Build {"<label>_<sex>": [tensor file paths]} using the naming convention
# {label}_tensor_{SubjectID}.pt, where label is 'AD', 'CN', or 'Others' for
# any other ResearchGroup.
grouped_data = subjects_df.groupby(['ResearchGroup', 'Sex'])['SubjectID'].apply(list).to_dict()

tensor_dir_exists = os.path.exists(tensor_data_dir)
if not tensor_dir_exists:
    # The directory is the same for every subject, so check and warn once.
    logging.warning(f"Missing tensor data dir: {tensor_data_dir}")

tensor_groups = {}
for (group, sex), subject_ids in grouped_data.items():
    label = group if group in ('AD', 'CN') else 'Others'
    # Several ResearchGroups (e.g. MCI, EMCI, LMCI) share the key
    # 'Others_<sex>': extend the existing list instead of overwriting it,
    # otherwise every such group except the last one processed is dropped.
    file_paths = tensor_groups.setdefault(f"{label}_{sex}", [])
    if tensor_dir_exists:
        file_paths.extend(
            os.path.join(tensor_data_dir, f"{label}_tensor_{sid}.pt") for sid in subject_ids
        )


# ---------------------------------------------------------------------------------
# Funciones para cargar y preprocesar los tensores de cada sujeto:
# ---------------------------------------------------------------------------------

def zero_diagonals(tensor):
    """Zero the main diagonal of every matrix in *tensor*, in place.

    The last two dimensions are treated as the matrix axes. This accepts:

    - a single (H, W) matrix;
    - a (C, H, W) stack;
    - any number of leading batch dimensions;
    - non-square matrices, whose main diagonal has min(H, W) entries.

    Args:
        tensor: Tensor with at least 2 dimensions. Modified in place.

    Returns:
        The same tensor object, for chaining.
    """
    # diagonal() returns a view, so zero_() writes through to *tensor*.
    tensor.diagonal(dim1=-2, dim2=-1).zero_()
    return tensor

def load_tensor(fp):
    """Load a single tensor file and zero its diagonals.

    A NumPy array payload is converted to a torch.Tensor first.

    Returns None, after logging why, when:

    - the file does not exist;
    - the payload is neither a tensor nor a NumPy array;
    - loading or zero_diagonals raises.

    NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    objects; only use this on trusted, locally generated files.
    """
    if not os.path.exists(fp):
        logging.warning(f"Missing tensor file: {fp}")
        return None

    try:
        payload = torch.load(fp, weights_only=False)
        if isinstance(payload, np.ndarray):
            payload = torch.tensor(payload)
        if isinstance(payload, torch.Tensor):
            return zero_diagonals(payload)
        logging.warning(f"Unexpected data format: {type(payload)} in {fp}")
        return None
    except Exception as err:
        # Best-effort loader: one bad file must not abort the whole run.
        logging.error(f"Error loading {fp}: {err}")
        return None

def load_tensors(file_paths, batch_size=16):
    """Load every path with load_tensor() and stack the results along a new dim 0.

    Paths are visited in chunks of *batch_size*, which only affects iteration.
    Every tensor that loads successfully is kept. Paths for which load_tensor()
    returns None are skipped, so the result may have fewer rows than
    *file_paths*.

    Returns:
        torch.stack of the loaded tensors, or None if nothing loaded. The shape
        is (N, C, H, W) assuming each file holds a (C, H, W) tensor; this is
        not checked.

    Raises:
        ValueError: if batch_size is 0 (from range()).
        RuntimeError: from torch.stack if the loaded tensors differ in shape.
    """
    loaded = []
    for start in range(0, len(file_paths), batch_size):
        chunk = file_paths[start:start + batch_size]
        loaded += [t for t in map(load_tensor, chunk) if t is not None]
    return torch.stack(loaded) if loaded else None

def compute_mean_std_per_channel(dataset):
    """Return (mean, std) per index of dim 1, reducing over dims 0, 2 and 3.

    Expects a 4-D tensor. std uses torch's default (unbiased) estimator.
    Channels whose std is exactly zero get 1e-9 instead, to avoid dividing by
    zero when normalizing.
    """
    reduce_dims = (0, 2, 3)
    channel_std = dataset.std(dim=reduce_dims)
    channel_std = channel_std.masked_fill(channel_std == 0, 1e-9)
    return dataset.mean(dim=reduce_dims), channel_std

def normalize_dataset(dataset, mean, std):
    """Standardize *dataset* channel-wise: (x - mean) / (std + 1e-9).

    *mean* and *std* are 1-D per-channel vectors. They are broadcast over
    dim 1 of a 4-D *dataset*. A fresh tensor is returned; inputs are not
    modified.
    """
    eps = 1e-9

    def _per_channel(vec):
        # (C,) -> (1, C, 1, 1) so it lines up with dim 1 of the data.
        return vec[None, :, None, None]

    return (dataset - _per_channel(mean)) / (_per_channel(std) + eps)

# ---------------------------------------------------------------------------------
# Enfoque Cross-Validation:
# - Outer folds (p. ej., 5) para dividir en train+val vs. test
# - Dentro de cada train+val, hacemos un 80/20 para train vs. val
# - La media/desviación por canal se calcula sólo con train y se aplica a val y test
# ---------------------------------------------------------------------------------

# 1) Flat list of (file path, stratification label) pairs. The label is the
#    tensor_groups key, e.g. "AD_F", "CN_M" or "Others_F".
X = []
y = []
for gsex, paths in tensor_groups.items():
    for fp in paths:
        # Drop missing files before splitting. load_tensors() would otherwise
        # skip them silently, leaving the saved tensors out of step with their labels.
        if not os.path.exists(fp):
            logging.warning(f"Missing tensor file: {fp}")
            continue
        X.append(fp)
        y.append(gsex)

X = np.array(X)
y = np.array(y)

# 2) Outer StratifiedKFold: each fold splits the data into train+val and test.
outer_folds = 5
outer_cv = StratifiedKFold(n_splits=outer_folds, shuffle=True, random_state=seed)

for fold_idx, (train_val_idx, test_idx) in enumerate(outer_cv.split(X, y), start=1):
    logging.info(f"\n======= FOLD {fold_idx}/{outer_folds} =======")

    # Paths and labels for train+val and test
    X_train_val, X_test = X[train_val_idx], X[test_idx]
    y_train_val, y_test = y[train_val_idx], y[test_idx]

    # 3) Inner 80/20 train/val split, stratified on the same labels.
    train_idx, val_idx = train_test_split(
        np.arange(len(X_train_val)),
        test_size=0.2,
        random_state=seed,
        stratify=y_train_val,
    )
    X_train, X_val = X_train_val[train_idx], X_train_val[val_idx]
    y_train, y_val = y_train_val[train_idx], y_train_val[val_idx]

    # 4) Load tensors (one row per successfully loaded path, in path order).
    train_data = load_tensors(list(X_train))
    val_data = load_tensors(list(X_val))
    test_data = load_tensors(list(X_test))

    if train_data is None or val_data is None or test_data is None:
        logging.warning("Algún split no pudo cargar tensores. Se omite este fold.")
        continue

    # A file that exists but fails to load is dropped by load_tensors(). The
    # rows would then no longer match the labels, so skip the fold instead.
    if (len(train_data), len(val_data), len(test_data)) != (len(X_train), len(X_val), len(X_test)):
        logging.warning(
            f"Fold {fold_idx}: some tensors failed to load; labels would be misaligned. Skipping this fold."
        )
        continue

    # 5) Per-channel stats from train only, applied to all three splits so no
    #    val/test information leaks into the normalization.
    train_mean, train_std = compute_mean_std_per_channel(train_data)
    train_data = normalize_dataset(train_data, train_mean, train_std)
    val_data = normalize_dataset(val_data, train_mean, train_std)
    test_data = normalize_dataset(test_data, train_mean, train_std)

    # 6) Save each fold under project_dir/fold_<k>.
    fold_dir = os.path.join(project_dir, f'fold_{fold_idx}')
    os.makedirs(fold_dir, exist_ok=True)

    torch.save(train_data, os.path.join(fold_dir, 'train_data.pt'))
    torch.save(val_data,   os.path.join(fold_dir, 'val_data.pt'))
    torch.save(test_data,  os.path.join(fold_dir, 'test_data.pt'))

    # Save labels and source paths in row order. Without them the tensors above
    # cannot be matched to a class or a subject. They are saved as plain lists
    # of str so they can be read back with torch.load(..., weights_only=True).
    for split_name, split_labels, split_paths in (
        ('train', y_train, X_train),
        ('val', y_val, X_val),
        ('test', y_test, X_test),
    ):
        torch.save(split_labels.tolist(), os.path.join(fold_dir, f'{split_name}_labels.pt'))
        torch.save(split_paths.tolist(), os.path.join(fold_dir, f'{split_name}_paths.pt'))

    logging.info(f"Fold {fold_idx} guardado en {fold_dir}")

logging.info("Cross-Validation finalizado.")


2025-03-27 18:11:50,814 - INFO - 
2025-03-27 18:11:52,278 - INFO - Fold 1 guardado en /home/diego/Escritorio/santiago/1st_paper/116ROIs/fold_1
2025-03-27 18:11:52,278 - INFO - 
2025-03-27 18:11:53,406 - INFO - Fold 2 guardado en /home/diego/Escritorio/santiago/1st_paper/116ROIs/fold_2
2025-03-27 18:11:53,406 - INFO - 
2025-03-27 18:11:54,340 - INFO - Fold 3 guardado en /home/diego/Escritorio/santiago/1st_paper/116ROIs/fold_3
2025-03-27 18:11:54,341 - INFO - 
2025-03-27 18:11:55,264 - INFO - Fold 4 guardado en /home/diego/Escritorio/santiago/1st_paper/116ROIs/fold_4
2025-03-27 18:11:55,265 - INFO - 
2025-03-27 18:11:56,212 - INFO - Fold 5 guardado en /home/diego/Escritorio/santiago/1st_paper/116ROIs/fold_5
2025-03-27 18:11:56,213 - INFO - Cross-Validation finalizado.
