In [1]:
import os
import pandas as pd
import torch
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
from tqdm import tqdm
import logging
from sklearn.model_selection import StratifiedKFold, train_test_split

# --- Project configuration ---
# Absolute path of the project root (hard-coded; adjust when running on another machine).
project_dir = '/home/diego/Escritorio/santiago/1st_paper/116ROIs'
# Switch the process working directory to the project root; any relative path
# opened after this point resolves against it.
os.chdir(project_dir)

# Load the subject table (DataBaseSubjects.csv, located in the project root) into a DataFrame.
csv_path = os.path.join(project_dir, 'DataBaseSubjects.csv')
subjects_df = pd.read_csv(csv_path)

# Collapse the research groups into three classes: AD, CN and MCI.
def map_3class_group(x):
    """Map a research-group label onto one of 'AD', 'CN' or 'MCI'.

    'AD' and 'CN' are returned unchanged. Every other value is reported as
    'MCI', which is how the MCI variants (MCI, EMCI, LMCI) get merged.

    NOTE(review): the fall-through is a catch-all, so values that are not an
    MCI variant at all (missing values, unexpected labels) also come back as
    'MCI' -- confirm the input column only holds AD/CN/MCI-type labels.
    """
    # Labels that keep their own class; the comparison is exact (case-sensitive).
    for kept_label in ('AD', 'CN'):
        if x == kept_label:
            return kept_label
    # Anything else (MCI, EMCI, LMCI, ...) is unified under 'MCI'.
    return 'MCI'

# Derive the three-class label (AD / CN / MCI) from the original research group.
subjects_df['ThreeClassLabel'] = subjects_df['ResearchGroup'].apply(map_3class_group)

# Combine sex and the new label into a single key ("<Sex>_<ThreeClassLabel>")
# so that one stratification pass balances both attributes at once.
subjects_df['Gender_ThreeClass'] = (
    subjects_df['Sex'].astype(str)
    + '_'
    + subjects_df['ThreeClassLabel'].astype(str)
)

print("\nRecuento de clases tras unificar (AD, CN, MCI):")
print(subjects_df['ThreeClassLabel'].value_counts())

print("\nEjemplo de cómo quedaría la estratificación en 5 folds...")

# Example: 5 shuffled, stratified outer folds keyed on the combined sex + class field.
outer_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
fold_splits = outer_cv.split(subjects_df, subjects_df['Gender_ThreeClass'])

for outer_fold, (train_val_idx, test_idx) in enumerate(fold_splits):
    # The split yields positional indices, hence .iloc.
    train_val_df = subjects_df.iloc[train_val_idx]
    test_df = subjects_df.iloc[test_idx]

    # Carve a validation set (20%) out of the non-test rows, stratified on the
    # same combined key; the fold number doubles as the seed so each fold is
    # reproducible but uses a different shuffle.
    train_df, val_df = train_test_split(
        train_val_df,
        test_size=0.2,
        stratify=train_val_df['Gender_ThreeClass'],
        random_state=outer_fold,
    )

    print(f"\n=== Outer Fold {outer_fold + 1} ===")
    # Report the class distribution of every partition of this fold.
    fold_reports = (
        ("Train set distribution (ThreeClassLabel):\n", train_df),
        ("Val set distribution (ThreeClassLabel):\n", val_df),
        ("Test set distribution (ThreeClassLabel):\n", test_df),
    )
    for header, split_df in fold_reports:
        print(header, split_df['ThreeClassLabel'].value_counts())






Recuento de clases tras unificar (AD, CN, MCI):
ThreeClassLabel
MCI    168
AD      95
CN      89
Name: count, dtype: int64

Ejemplo de cómo quedaría la estratificación en 5 folds...

=== Outer Fold 1 ===
Train set distribution (ThreeClassLabel):
 ThreeClassLabel
MCI    107
AD      61
CN      56
Name: count, dtype: int64
Val set distribution (ThreeClassLabel):
 ThreeClassLabel
MCI    27
CN     15
AD     15
Name: count, dtype: int64
Test set distribution (ThreeClassLabel):
 ThreeClassLabel
MCI    34
AD     19
CN     18
Name: count, dtype: int64

=== Outer Fold 2 ===
Train set distribution (ThreeClassLabel):
 ThreeClassLabel
MCI    107
AD      61
CN      56
Name: count, dtype: int64
Val set distribution (ThreeClassLabel):
 ThreeClassLabel
MCI    27
AD     15
CN     15
Name: count, dtype: int64
Test set distribution (ThreeClassLabel):
 ThreeClassLabel
MCI    34
AD     19
CN     18
Name: count, dtype: int64

=== Outer Fold 3 ===
Train set distribution (ThreeClassLabel):
 ThreeClassLabel
MC