In [1]:
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from pathlib import Path
import logging
import os
import gc
from typing import Tuple, Optional, Dict


# --- 0. Global Configuration and Constants ---
# Module-level settings for the preprocessing run: logging setup, input/output
# paths, cross-validation parameters, metadata column names and the label mapping.

# Configure logging: INFO level, timestamped records tagged with level and emitting module.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(module)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# --- Script Configuration ---
# Paths
# NOTE(review): the two input paths below are absolute and machine-specific;
# they must be adjusted before running on another machine.
BASE_OUTPUT_DIR = Path('./preprocessed_connectomes_for_dl') # Main directory to save preprocessed data (relative to the working directory)
TENSOR_DATA_PATH = Path('/home/diego/Escritorio/desde_cero/AAL_116/AAL_116_fmri_tensor_1lag_NeuroEnhanced_v4/GLOBAL_TENSOR_AAL116_352subs_4ch_1lag.npz') # Path to the global .npz tensor file
METADATA_CSV_PATH = Path('/home/diego/Escritorio/desde_cero/AAL_116/DataBaseSubjects.csv') # Path to the CSV with subject metadata

# Preprocessing Parameters
N_FOLDS = 5  # Number of folds for Stratified K-Fold Cross-Validation
RANDOM_STATE = 42  # Seed for reproducibility in shuffling and splitting
AGE_COLUMN = 'Age' # Column name for age in metadata CSV
SEX_COLUMN = 'Sex' # Column name for sex in metadata CSV (e.g., 'Male', 'Female')
DIAGNOSIS_COLUMN = 'ResearchGroup' # Column name for diagnosis in metadata CSV
SUBJECT_ID_COLUMN = 'SubjectID' # Column name for subject ID in metadata CSV

# Label Mapping and Grouping for MCI
# As per plan: CN:0, MCI_grouped:1, AD:2
# Any diagnosis string that is not a key of this dict is left unmapped by the encoder.
LABEL_MAPPING = {
    'CN': 0,
    'MCI': 1, 'LMCI': 1, 'EMCI': 1, # Grouping all MCI variants
    'AD': 2
}
# Human-readable class names; position i names encoded label i.
# Kept for reference only: not used by the code visible in this file.
TARGET_CLASSES = ['CN', 'MCI_grouped', 'AD'] # For reference

# --- Helper Functions ---

def load_and_align_data(tensor_path: Path, metadata_path: Path, subject_id_col: str) -> Tuple[Optional[np.ndarray], Optional[pd.DataFrame], Optional[np.ndarray]]:
    """
    Loads the global connectivity tensor and subject metadata, aligning them by subject ID.

    Subject IDs from both sources are compared as stripped strings. The metadata is
    left-joined onto the tensor's subject IDs, so the returned DataFrame has exactly
    one row per tensor entry, in tensor order. Tensor subjects without a metadata row
    are kept (their metadata columns are NaN) and reported with a warning. If the
    metadata contains several rows for the same subject ID, only the first one is
    used (with a warning), because duplicates would otherwise multiply rows in the
    join and shift metadata against the tensor.

    Args:
        tensor_path (Path): Path to the .npz file containing the global tensor
            (key 'global_tensor_data') and subject IDs (key 'subject_ids').
        metadata_path (Path): Path to the CSV file with subject metadata.
        subject_id_col (str): Column name for subject IDs in the metadata CSV.

    Returns:
        Tuple[Optional[np.ndarray], Optional[pd.DataFrame], Optional[np.ndarray]]:
            - Global tensor as stored in the file (first axis indexes subjects).
            - Aligned metadata DataFrame (row i describes tensor entry i).
            - Array of subject IDs exactly as stored in the tensor file.
        Returns (None, None, None) if loading or alignment fails.
    """
    logger.info(f"Loading global tensor from: {tensor_path}")
    if not tensor_path.exists():
        logger.error(f"Tensor file not found: {tensor_path}")
        return None, None, None
    try:
        # np.load on an .npz returns a lazily-read archive that holds the file open;
        # the context manager closes it once both arrays have been materialized.
        with np.load(tensor_path) as tensor_data_npz:
            global_tensor = tensor_data_npz['global_tensor_data'] # Key used during saving in previous script
            tensor_subject_ids = tensor_data_npz['subject_ids']
        logger.info(f"Global tensor loaded. Shape: {global_tensor.shape}, Subjects in tensor: {len(tensor_subject_ids)}")
    except Exception as e:
        logger.error(f"Error loading tensor data from {tensor_path}: {e}", exc_info=True)
        return None, None, None

    if global_tensor.shape[0] != len(tensor_subject_ids):
        logger.error(f"Tensor file is inconsistent: {global_tensor.shape[0]} tensor entries vs {len(tensor_subject_ids)} subject IDs.")
        return None, None, None

    logger.info(f"Loading metadata from: {metadata_path}")
    if not metadata_path.exists():
        logger.error(f"Metadata file not found: {metadata_path}")
        return None, None, None
    try:
        metadata_df = pd.read_csv(metadata_path)
        metadata_df[subject_id_col] = metadata_df[subject_id_col].astype(str).str.strip()
        logger.info(f"Metadata loaded. Shape: {metadata_df.shape}, Unique subjects in metadata: {metadata_df[subject_id_col].nunique()}")
    except Exception as e:
        logger.error(f"Error loading metadata from {metadata_path}: {e}", exc_info=True)
        return None, None, None

    # A left merge emits one output row per matching metadata row, so duplicated
    # subject IDs in the metadata would break the 1:1 correspondence with the tensor.
    duplicated_rows = metadata_df[subject_id_col].duplicated(keep='first')
    if duplicated_rows.any():
        logger.warning(f"{int(duplicated_rows.sum())} duplicated subject ID rows found in metadata. Keeping only the first row per subject.")
        metadata_df = metadata_df.loc[~duplicated_rows]

    # Align metadata with tensor data based on subject IDs
    logger.info("Aligning tensor data with metadata...")
    # Normalize tensor IDs the same way as the metadata IDs (stripped strings).
    normalized_tensor_ids = pd.Series(tensor_subject_ids).astype(str).str.strip()
    tensor_sids_df = pd.DataFrame({subject_id_col: normalized_tensor_ids})

    # Merge to get metadata for subjects present in the tensor, in the tensor's order
    aligned_metadata_df = pd.merge(tensor_sids_df, metadata_df, on=subject_id_col, how='left')

    # Subjects present in the tensor but absent from the metadata keep NaN metadata columns.
    missing_metadata_count = int((~normalized_tensor_ids.isin(metadata_df[subject_id_col])).sum())
    if missing_metadata_count > 0:
        logger.warning(f"{missing_metadata_count} subjects from the tensor file do not have corresponding entries in the metadata file.")

    # Final check: row i of the aligned metadata must describe tensor entry i.
    # Compare against the normalized IDs (the raw array may differ in dtype or padding),
    # and fail instead of returning data whose rows do not correspond.
    ids_match = np.array_equal(
        aligned_metadata_df[subject_id_col].to_numpy(dtype=object),
        normalized_tensor_ids.to_numpy(dtype=object),
    )
    if not ids_match:
        logger.error("CRITICAL: Subject ID order mismatch between tensor and aligned metadata after merge. This should not happen with a left merge on tensor IDs.")
        return None, None, None

    logger.info(f"Data alignment complete. Final aligned subjects: {len(aligned_metadata_df)}")
    return global_tensor, aligned_metadata_df, tensor_subject_ids


def encode_labels(metadata_df: pd.DataFrame, diagnosis_col: str, label_mapping: Dict[str, int]) -> Optional[np.ndarray]:
    """
    Encodes diagnostic labels based on the provided mapping, grouping MCI subtypes.

    Diagnosis values are compared as stripped strings. The input DataFrame is not
    modified. Subjects whose diagnosis is not a key of `label_mapping` (including
    missing values) receive NaN; nothing is dropped here, so the caller must filter
    those entries (and the matching tensor rows) before stratified splitting.

    Args:
        metadata_df (pd.DataFrame): DataFrame containing subject metadata.
        diagnosis_col (str): Name of the column with diagnostic labels.
        label_mapping (Dict[str, int]): Dictionary to map string labels to integers.

    Returns:
        Optional[np.ndarray]: NumPy array with one encoded label per row of
            `metadata_df`; unmapped entries are NaN. Returns None if the diagnosis
            column is missing.
    """
    logger.info("Encoding diagnostic labels...")
    if diagnosis_col not in metadata_df.columns:
        logger.error(f"Diagnosis column '{diagnosis_col}' not found in metadata.")
        return None

    # Normalize into a separate Series so the caller's DataFrame is left untouched.
    diagnoses = metadata_df[diagnosis_col].astype(str).str.strip()

    # Apply the mapping; values without a key in label_mapping become NaN.
    encoded_labels = diagnoses.map(label_mapping)

    # Report unmapped labels (NaNs after mapping). StratifiedKFold cannot handle
    # NaN labels, so the caller is expected to remove these subjects.
    unmapped_mask = encoded_labels.isnull()
    if unmapped_mask.any():
        unmapped_values = diagnoses[unmapped_mask].unique()
        logger.warning(f"Unmapped diagnosis values found: {unmapped_values}. These subjects will have NaN labels.")
        logger.warning(f"{int(unmapped_mask.sum())} subjects have unmappable diagnostic labels and must be filtered out before splitting.")

    logger.info(f"Label encoding complete. Value counts:\n{encoded_labels.value_counts(dropna=False)}")
    return encoded_labels.to_numpy()


def normalize_connectivity_tensors_per_fold(
    X_train: np.ndarray, X_val: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Normalizes connectivity tensors (Z-score) per channel.
    Statistics (mean, std) are computed ONLY from X_train and applied to both X_train and X_val.

    For each channel a single scalar mean and a single scalar std are computed over
    all training subjects and all matrix entries of that channel. A std below 1e-8
    is replaced by 1e-8 to avoid division by zero. The inputs are not modified.

    Args:
        X_train (np.ndarray): Training data tensors (N_train, N_channels, N_ROIs, N_ROIs).
        X_val (np.ndarray): Validation data tensors (N_val, N_channels, N_ROIs, N_ROIs).

    Returns:
        Tuple[np.ndarray, np.ndarray]: Normalized X_train_norm, X_val_norm (float32 arrays
            with the same shapes as the inputs).
    """
    logger.info("Normalizing tensors per fold (Z-score per channel, stats from training set)...")
    N_channels = X_train.shape[1]
    X_train_norm = np.zeros_like(X_train, dtype=np.float32)
    X_val_norm = np.zeros_like(X_val, dtype=np.float32)
    epsilon = 1e-8 # To prevent division by zero if std is too small

    # The debug messages need extra full-array reductions; f-strings are evaluated
    # even when the record is discarded, so only build them when DEBUG is active.
    debug_enabled = logger.isEnabledFor(logging.DEBUG)

    for ch_idx in range(N_channels):
        # Calculate mean and std from the training data for the current channel
        train_channel_data = X_train[:, ch_idx, :, :]
        mean_ch = np.mean(train_channel_data)
        std_ch = np.std(train_channel_data)
        if debug_enabled:
            logger.debug(f"Normalizing channel {ch_idx}...")
            logger.debug(f"Channel {ch_idx} - Train Mean: {mean_ch:.4f}, Train Std: {std_ch:.4f}")

        if std_ch < epsilon:
            logger.warning(f"Channel {ch_idx} has std deviation close to zero ({std_ch:.2e}) in training data. Normalization might be unstable or result in NaNs/Infs. Using epsilon for division.")
            std_ch_safe = epsilon # Avoid division by zero
        else:
            std_ch_safe = std_ch

        # Normalize training data for the current channel
        X_train_norm[:, ch_idx, :, :] = (train_channel_data - mean_ch) / std_ch_safe

        # Normalize validation data for the current channel using training statistics
        val_channel_data = X_val[:, ch_idx, :, :]
        X_val_norm[:, ch_idx, :, :] = (val_channel_data - mean_ch) / std_ch_safe

        if debug_enabled:
            # Verify normalization on training data (should be approx mean=0, std=1)
            logger.debug(f"Channel {ch_idx} - Normalized Train Mean: {np.mean(X_train_norm[:, ch_idx, :, :]):.4f}, Std: {np.std(X_train_norm[:, ch_idx, :, :]):.4f}")
            # Validation data will not necessarily have mean=0, std=1 after this process
            logger.debug(f"Channel {ch_idx} - Normalized Val Mean: {np.mean(X_val_norm[:, ch_idx, :, :]):.4f}, Std: {np.std(X_val_norm[:, ch_idx, :, :]):.4f}")

    logger.info("Tensor normalization complete for this fold.")
    return X_train_norm, X_val_norm

# --- Main Preprocessing Script ---
def main_preprocess():
    """
    Orchestrates the entire preprocessing pipeline:
    1. Load data and metadata.
    2. Encode labels and drop subjects whose diagnosis could not be mapped.
    3. Perform stratified K-fold splitting (stratified by diagnosis only).
    4. For each fold:
        a. Normalize tensors with statistics from the training split.
        b. Save preprocessed data for PyTorch.

    Each fold is written to BASE_OUTPUT_DIR / "fold_<k>" / "fold_<k>_preprocessed_data.pt",
    where <k> is the 0-based fold index (log messages use 1-based fold numbers).
    The saved object is a dict with keys 'X_train', 'y_train', 'train_subject_ids',
    'X_val', 'y_val' and 'val_subject_ids'.

    Returns None. Failures are reported through the logger; the function returns
    early if loading, alignment or label encoding fails.
    """
    logger.info("--- Starting Connectivity Tensor Preprocessing for Deep Learning ---")

    # Create base output directory if it doesn't exist
    BASE_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    logger.info(f"Output directory: {BASE_OUTPUT_DIR}")

    # 1. Load and Align Data
    global_tensor, metadata_df, aligned_subject_ids = load_and_align_data(TENSOR_DATA_PATH, METADATA_CSV_PATH, SUBJECT_ID_COLUMN)
    if global_tensor is None or metadata_df is None or aligned_subject_ids is None:
        logger.critical("Failed to load or align data. Aborting.")
        return

    # Everything below indexes tensor, metadata and subject IDs with the same
    # positions, so they must describe the same subjects one-to-one.
    if not (len(global_tensor) == len(metadata_df) == len(aligned_subject_ids)):
        logger.critical(
            f"Tensor ({len(global_tensor)}), metadata ({len(metadata_df)}) and subject IDs "
            f"({len(aligned_subject_ids)}) differ in length. Aborting."
        )
        return

    # 2. Encode Labels
    raw_labels = encode_labels(metadata_df, DIAGNOSIS_COLUMN, LABEL_MAPPING)
    if raw_labels is None:
        logger.critical("Failed to encode labels. Aborting.")
        return

    # Filter out subjects with NaN labels (unmapped diagnoses) from tensor and metadata.
    # pd.isna also works if the label array has object dtype, where np.isnan would raise.
    valid_label_indices = ~pd.isna(raw_labels)
    if not np.all(valid_label_indices):
        num_dropped = np.sum(~valid_label_indices)
        logger.warning(f"Dropping {num_dropped} subjects due to missing/unmappable labels before splitting.")
        global_tensor = global_tensor[valid_label_indices]
        metadata_df = metadata_df.loc[valid_label_indices].reset_index(drop=True)
        encoded_labels = raw_labels[valid_label_indices].astype(int) # Now safe to convert to int
        final_subject_ids = aligned_subject_ids[valid_label_indices]
        logger.info(f"Data filtered. New tensor shape: {global_tensor.shape}, New metadata shape: {metadata_df.shape}")
    else:
        encoded_labels = raw_labels.astype(int)
        final_subject_ids = aligned_subject_ids

    if len(np.unique(encoded_labels)) < 2:
        logger.critical(f"Not enough classes after label encoding ({len(np.unique(encoded_labels))} found). Need at least 2 for stratified splitting. Aborting.")
        return

    # Fixed length for the per-fold label histograms so they are comparable across folds.
    n_label_bins = int(encoded_labels.max()) + 1

    # Extract Age and Sex for analysis of the secondary balance of each fold.
    # Ages are coerced to float (unparseable or missing values become NaN) so that
    # the NaN-aware statistics below neither crash nor turn into NaN because of
    # individual subjects without an age.
    ages = pd.to_numeric(metadata_df[AGE_COLUMN], errors='coerce').to_numpy(dtype=float) if AGE_COLUMN in metadata_df else None
    sexes = metadata_df[SEX_COLUMN].values if SEX_COLUMN in metadata_df else None

    if ages is None:
        logger.warning(f"Age column '{AGE_COLUMN}' not found in metadata. Age-based stratification/analysis will not be possible.")
    if sexes is None:
        logger.warning(f"Sex column '{SEX_COLUMN}' not found in metadata. Sex-based stratification/analysis will not be possible.")

    # 3. Stratified K-Fold Splitting
    logger.info(f"Performing Stratified {N_FOLDS}-Fold Cross-Validation. Random state: {RANDOM_STATE}")
    skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=RANDOM_STATE)

    failed_folds = []  # 1-based numbers of folds whose output could not be written

    # The `split` method yields (train_indices, val_indices)
    for fold_idx, (train_indices, val_indices) in enumerate(skf.split(global_tensor, encoded_labels)):
        logger.info(f"--- Processing Fold {fold_idx + 1}/{N_FOLDS} ---")

        X_train, X_val = global_tensor[train_indices], global_tensor[val_indices]
        y_train, y_val = encoded_labels[train_indices], encoded_labels[val_indices]

        train_sids = final_subject_ids[train_indices]
        val_sids = final_subject_ids[val_indices]

        logger.info(f"Train set: {len(X_train)} samples, Val set: {len(X_val)} samples")
        logger.info(f"Train labels distribution: {np.bincount(y_train, minlength=n_label_bins)}")
        logger.info(f"Val labels distribution: {np.bincount(y_val, minlength=n_label_bins)}")

        # Log distribution of Age and Sex in this fold (if available)
        if ages is not None:
            logger.info(f"Fold {fold_idx+1} Train Age Stats: Mean={np.nanmean(ages[train_indices]):.2f}, Std={np.nanstd(ages[train_indices]):.2f}")
            logger.info(f"Fold {fold_idx+1} Val Age Stats:   Mean={np.nanmean(ages[val_indices]):.2f}, Std={np.nanstd(ages[val_indices]):.2f}")
        if sexes is not None:
            train_sex_counts = pd.Series(sexes[train_indices]).value_counts().to_dict()
            val_sex_counts = pd.Series(sexes[val_indices]).value_counts().to_dict()
            logger.info(f"Fold {fold_idx+1} Train Sex Dist: {train_sex_counts}")
            logger.info(f"Fold {fold_idx+1} Val Sex Dist:   {val_sex_counts}")
        # NOTE: For more rigorous multi-label stratification (diagnosis, age, sex),
        # one might need iterative stratification techniques or create composite labels.
        # This script primarily stratifies by diagnosis. The logs above help verify secondary balance.

        # 4. Normalization (per fold, stats from train set)
        X_train_norm, X_val_norm = normalize_connectivity_tensors_per_fold(X_train, X_val)

        # Convert to PyTorch Tensors
        X_train_tensor = torch.from_numpy(X_train_norm).float()
        y_train_tensor = torch.from_numpy(y_train).long() # long for CrossEntropyLoss
        X_val_tensor = torch.from_numpy(X_val_norm).float()
        y_val_tensor = torch.from_numpy(y_val).long()

        # 5. Save Preprocessed Data for PyTorch
        # NOTE(review): the subject IDs are stored as NumPy arrays inside the dict, not
        # as tensors; confirm the loading code is set up to unpickle non-tensor objects.
        fold_data = {
            'X_train': X_train_tensor,
            'y_train': y_train_tensor,
            'train_subject_ids': train_sids,
            'X_val': X_val_tensor,
            'y_val': y_val_tensor,
            'val_subject_ids': val_sids
        }

        fold_output_dir = BASE_OUTPUT_DIR / f"fold_{fold_idx}"
        fold_output_dir.mkdir(parents=True, exist_ok=True)

        output_file_path = fold_output_dir / f"fold_{fold_idx}_preprocessed_data.pt"
        try:
            torch.save(fold_data, output_file_path)
            logger.info(f"Saved preprocessed data for fold {fold_idx + 1} to: {output_file_path}")
        except Exception as e:
            # Keep processing the remaining folds, but remember the failure so the
            # final summary does not claim that everything was written.
            logger.error(f"Error saving data for fold {fold_idx + 1}: {e}", exc_info=True)
            failed_folds.append(fold_idx + 1)

        # Clean up to save memory before next fold
        del X_train, X_val, y_train, y_val, X_train_norm, X_val_norm
        del X_train_tensor, y_train_tensor, X_val_tensor, y_val_tensor, fold_data
        gc.collect()

    if failed_folds:
        logger.error(f"--- Processing finished, but saving failed for fold(s): {failed_folds}. ---")
    else:
        logger.info("--- All folds processed and data saved. ---")
    logger.info(f"Preprocessed data for each fold saved in subdirectories under: {BASE_OUTPUT_DIR}")

if __name__ == '__main__':
    # Ensure PyTorch is available.
    # NOTE: torch is also imported at the top of this file, so a missing install
    # normally fails there first; this guard only takes effect if that import is
    # ever removed or made lazy.
    try:
        import torch
        logger.info(f"PyTorch version: {torch.__version__}")
    except ImportError:
        logger.critical("PyTorch is not installed. Please install PyTorch to run this script.")
        # exit() is a helper injected by the `site` module for interactive use and
        # reports success (status 0); raise SystemExit with a non-zero status instead.
        raise SystemExit(1) from None

    main_preprocess()

2025-05-25 21:02:34 - INFO - 3328762466 - PyTorch version: 2.4.0
2025-05-25 21:02:34 - INFO - 3328762466 - --- Starting Connectivity Tensor Preprocessing for Deep Learning ---
2025-05-25 21:02:34 - INFO - 3328762466 - Output directory: preprocessed_connectomes_for_dl
2025-05-25 21:02:34 - INFO - 3328762466 - Loading global tensor from: /home/diego/Escritorio/desde_cero/AAL_116/AAL_116_fmri_tensor_1lag_NeuroEnhanced_v4/GLOBAL_TENSOR_AAL116_352subs_4ch_1lag.npz
2025-05-25 21:02:34 - INFO - 3328762466 - Global tensor loaded. Shape: (352, 4, 116, 116), Subjects in tensor: 352
2025-05-25 21:02:34 - INFO - 3328762466 - Loading metadata from: /home/diego/Escritorio/desde_cero/AAL_116/DataBaseSubjects.csv
2025-05-25 21:02:34 - INFO - 3328762466 - Metadata loaded. Shape: (352, 24), Unique subjects in metadata: 352
2025-05-25 21:02:34 - INFO - 3328762466 - Aligning tensor data with metadata...
2025-05-25 21:02:34 - INFO - 3328762466 - Data alignment complete. Final aligned subjects: 352
2025-05-