In [1]:
# Pin networkx to 2.6.3: the extraction cell lists "networkx == 2.6.3 (due to dyconnmap)" as a requirement.
# NOTE(review): the pip output of this cell reports that scikit-image 0.20.0 requires networkx>=2.8,
# so this pin conflicts with the installed scikit-image. Consider `%pip` (targets the kernel env) instead of `!pip`.
!pip uninstall networkx -y
!pip install networkx==2.6.3

Found existing installation: networkx 2.6.3
Uninstalling networkx-2.6.3:
  Successfully uninstalled networkx-2.6.3
Collecting networkx==2.6.3
  Using cached networkx-2.6.3-py3-none-any.whl.metadata (5.0 kB)
Using cached networkx-2.6.3-py3-none-any.whl (1.9 MB)
Installing collected packages: networkx
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
scikit-image 0.20.0 requires networkx>=2.8, but you have networkx 2.6.3 which is incompatible.[0m[31m
[0mSuccessfully installed networkx-2.6.3


In [2]:
# dyconnmap provides threshold_omst_global_cost_efficiency, imported by the extraction cell for the OMST channel.
# NOTE(review): unpinned, while the extraction cell's requirements list dyconnmap >= 1.0.4 — consider pinning.
!pip install dyconnmap



In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Utilitario para Jupyter Notebook: Extracción de Características de Conectividad fMRI
Versión: v6.5.9_ParallelTuned (Adaptado de v6.5.8)

Cambios Principales:
- Optimizada la lógica de paralelización interna para MI y Granger (n_jobs_mi, n_jobs_granger)
  para un mejor uso de los cores disponibles en función de MAX_WORKERS.
- Mantenido el reemplazo de ElasticNet VAR por Causalidad de Granger.
- Mantenidas las optimizaciones previas.

Requisitos Clave:
- dyconnmap >= 1.0.4
- networkx == 2.6.3 (debido a dyconnmap)
- scikit-learn >= 1.0 
- nilearn >= 0.9 
- statsmodels (para Causalidad de Granger)
- numpy, pandas, scipy, tqdm, psutil, joblib
"""
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler, RobustScaler
from sklearn.feature_selection import mutual_info_regression
from nilearn.glm.first_level import spm_hrf, glover_hrf 
from scipy.signal import butter, filtfilt, deconvolve, windows
from scipy.interpolate import interp1d
from tqdm import tqdm
import os
import scipy.io as sio
from pathlib import Path
import psutil
import gc
import logging
import time
from typing import List, Tuple, Dict, Optional, Any
import multiprocessing
from concurrent.futures import ProcessPoolExecutor, as_completed
from statsmodels.tsa.stattools import grangercausalitytests 
import networkx as nx 
import warnings 
from sklearn.exceptions import ConvergenceWarning 
from joblib import Parallel, delayed 

# --- Logger configuration ---
# Root logging at INFO level; every record carries the file name and line number of its call site.
# NOTE(review): logging.basicConfig does nothing if the root logger already has handlers
# (e.g. when this cell is re-executed in the same kernel).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s')
logger = logging.getLogger(__name__)

# --- OMST import via dyconnmap ---
# Defaults assume the import fails: no OMST function, and the Pearson channel name is the
# fallback ("Pearson_Full_FisherZ_Signed"). A successful import switches both below.
OMST_PYTHON_LOADED = False
orthogonal_minimum_spanning_tree = None
PEARSON_OMST_CHANNEL_NAME_PRIMARY = "Pearson_OMST_GCE_Signed_Weighted" 
PEARSON_OMST_FALLBACK_NAME = "Pearson_Full_FisherZ_Signed" 
PEARSON_OMST_CHANNEL_NAME = PEARSON_OMST_FALLBACK_NAME 

try:
    # The dyconnmap function is exposed under the alias `orthogonal_minimum_spanning_tree`,
    # which the initialization code checks for None.
    from dyconnmap.graphs.threshold import threshold_omst_global_cost_efficiency
    orthogonal_minimum_spanning_tree = threshold_omst_global_cost_efficiency 
    logger.info("Successfully imported 'threshold_omst_global_cost_efficiency' from 'dyconnmap.graphs.threshold' and aliased as 'orthogonal_minimum_spanning_tree'.")
    OMST_PYTHON_LOADED = True
    PEARSON_OMST_CHANNEL_NAME = PEARSON_OMST_CHANNEL_NAME_PRIMARY 
except ImportError:
    # dyconnmap (or the function) is not installed: keep the fallback channel name.
    logger.error("ERROR: Dyconnmap module or 'threshold_omst_global_cost_efficiency' not found. "
                 f"Channel '{PEARSON_OMST_FALLBACK_NAME}' will be used as fallback. "
                 "Please ensure dyconnmap is installed: pip install dyconnmap")
except Exception as e_import: 
    # Any other import-time failure (e.g. an incompatible dependency) also keeps the fallback.
    logger.error(f"ERROR during dyconnmap import: {e_import}. "
                 f"Channel '{PEARSON_OMST_FALLBACK_NAME}' will be used as fallback.")


# --- 0. Global Configuration and Constants ---

# --- Configurable Parameters ---
# Root of the AAL3 data tree. It can be overridden with the AAL3_BASE_PATH environment
# variable so the notebook runs on another machine without editing this cell; the
# default is the original location.
BASE_PATH_AAL3 = Path(os.environ.get('AAL3_BASE_PATH', '/home/diego/Escritorio/AAL3'))
QC_OUTPUT_DIR = BASE_PATH_AAL3 / 'qc_outputs_doctoral_v3.2_aal3_shrinkage_flexible_thresh_fix'
# Metadata CSV used by the QC script; only compared against SUBJECT_METADATA_CSV_PATH (a mismatch is warned about).
SUBJECT_METADATA_CSV_PATH_QC = BASE_PATH_AAL3 / 'SubjctsDataAndTests_Schaefer2018_400Parcels_17Networks.csv' 
SUBJECT_METADATA_CSV_PATH = BASE_PATH_AAL3 / 'SubjectsData_Schaefer2018_400ROIs.csv' 
QC_REPORT_CSV_PATH = QC_OUTPUT_DIR / 'report_qc_final_with_discard_flags_v3.2.csv'
ROI_SIGNALS_DIR_PATH_AAL3 = BASE_PATH_AAL3 / 'ROISignals_AAL3_NiftiPreprocessedAllBatchesNorm'
ROI_FILENAME_TEMPLATE = 'ROISignals_{subject_id}.mat'
# Tab-separated AAL3 ROI table; the initializer requires the columns nom_c, color and vol_vox.
AAL3_META_PATH = BASE_PATH_AAL3 / 'ROI_MNI_V7_vol.txt' 

# --- Acquisition and temporal filtering ---
TR_SECONDS = 3.0   # repetition time in seconds; the sampling rate used for filtering is 1 / TR
LOW_CUT_HZ = 0.01  # band-pass lower cutoff (Hz)
HIGH_CUT_HZ = 0.08 # band-pass upper cutoff (Hz)
FILTER_ORDER = 2   # Butterworth filter order
TAPER_ALPHA = 0.1  # Tukey-window alpha applied to each ROI signal before filtering

# --- ROI layout ---
RAW_DATA_EXPECTED_COLUMNS = 170                 # ROI columns expected in each raw .mat signal matrix
AAL3_MISSING_INDICES_1BASED = [35, 36, 81, 82]  # AAL3 ROIs removed from every subject (1-based)
EXPECTED_ROIS_AFTER_AAL3_MISSING_REMOVAL = RAW_DATA_EXPECTED_COLUMNS - len(AAL3_MISSING_INDICES_1BASED) 
SMALL_ROI_VOXEL_THRESHOLD = 100                 # ROIs whose vol_vox is below this value are dropped

N_ROIS_EXPECTED = 131  # placeholder; replaced by the ROI count derived from the AAL3 metadata
TARGET_LEN_TS = 140    # every time series is interpolated or truncated to this many time points

N_NEIGHBORS_MI = 5     # presumably the k of the k-NN mutual-information estimator (used outside this excerpt)
DFC_WIN_POINTS = 30    # dynamic-FC window length in time points; also the minimum accepted series length
DFC_STEP = 5           # presumably the dFC sliding-window step in time points (used outside this excerpt)
APPLY_HRF_DECONVOLUTION = False 
HRF_MODEL = 'glover'   # 'glover' or 'spm'

# Granger-causality parameters
USE_GRANGER_CHANNEL = True
GRANGER_MAX_LAG = 1 

# Placeholder output-directory name; rebuilt by _initialize_aal3_roi_processing_info().
deconv_str = "_deconv" if APPLY_HRF_DECONVOLUTION else ""
granger_suffix = f"GrangerLag{GRANGER_MAX_LAG}" if USE_GRANGER_CHANNEL else "NoEffConn"
OUTPUT_CONNECTIVITY_DIR_NAME = f"AAL3_dynamicROIs_fmri_tensor_NeuroEnhanced_v6.5.9_{granger_suffix}"  # updated version

# Variable names tried, in order, when reading ROI signals from a .mat file.
POSSIBLE_ROI_KEYS = ["signals", "ROISignals", "roi_signals", "ROIsignals_AAL3", "AAL3_signals", "roi_ts"] 

# Channel switches.
USE_PEARSON_OMST_CHANNEL = True 
USE_PEARSON_FULL_SIGNED_CHANNEL = True 
USE_MI_CHANNEL_FOR_THESIS = True 
USE_DFC_ABS_DIFF_MEAN_CHANNEL = True 
USE_DFC_STDDEV_CHANNEL = True 

# Filled in by _initialize_aal3_roi_processing_info().
CONNECTIVITY_CHANNEL_NAMES: List[str] = [] 
N_CHANNELS = 0 

# --- Global MAX_WORKERS for ProcessPoolExecutor ---
# Half of the logical cores when there are more than 2, otherwise a single worker.
try:
    TOTAL_CPU_CORES = multiprocessing.cpu_count()
    MAX_WORKERS = max(1, TOTAL_CPU_CORES // 2 if TOTAL_CPU_CORES > 2 else 1)
except NotImplementedError:
    logger.warning("multiprocessing.cpu_count() no está implementado en esta plataforma. Usando MAX_WORKERS = 1.")
    TOTAL_CPU_CORES = 1
    MAX_WORKERS = 1
logger.info(f"Global MAX_WORKERS for ProcessPoolExecutor set to: {MAX_WORKERS} (based on {TOTAL_CPU_CORES} total cores)")


# --- Global AAL3 ROI Processing Variables ---
# All of these stay None until _initialize_aal3_roi_processing_info() populates them.
VALID_AAL3_ROI_INFO_DF_166: Optional[pd.DataFrame] = None
AAL3_MISSING_INDICES_0BASED: Optional[List[int]] = None
INDICES_OF_SMALL_ROIS_TO_DROP_FROM_166: Optional[List[int]] = None
FINAL_N_ROIS_EXPECTED: Optional[int] = None 

def _initialize_aal3_roi_processing_info():
    """Derive the AAL3 ROI-reduction indices, the output directory name and the channel list.

    Reads the tab-separated AAL3 table at ``AAL3_META_PATH`` (columns ``nom_c``,
    ``color``, ``vol_vox`` required) and sets these module globals:

    - ``AAL3_MISSING_INDICES_0BASED``: 0-based versions of ``AAL3_MISSING_INDICES_1BASED``.
    - ``VALID_AAL3_ROI_INFO_DF_166``: metadata rows whose ``color`` is not one of the
      missing 1-based indices, sorted by ``color`` and re-indexed from 0.
    - ``INDICES_OF_SMALL_ROIS_TO_DROP_FROM_166``: positions in that table of ROIs with
      ``vol_vox < SMALL_ROI_VOXEL_THRESHOLD``.
    - ``FINAL_N_ROIS_EXPECTED`` / ``N_ROIS_EXPECTED``: ROI count after both reductions.
    - ``OUTPUT_CONNECTIVITY_DIR_NAME``, ``granger_suffix``, ``CONNECTIVITY_CHANNEL_NAMES``
      and ``N_CHANNELS``.

    If the metadata file is missing or cannot be processed, ``FINAL_N_ROIS_EXPECTED``
    falls back to the placeholder ``N_ROIS_EXPECTED`` and the directory name gets an
    ``_ERR`` / ``_ERROR_INIT`` suffix. The channel list is built in every case.

    Returns:
        True if the AAL3 metadata was processed successfully, False otherwise.
    """
    global VALID_AAL3_ROI_INFO_DF_166, AAL3_MISSING_INDICES_0BASED, \
           INDICES_OF_SMALL_ROIS_TO_DROP_FROM_166, FINAL_N_ROIS_EXPECTED, \
           N_ROIS_EXPECTED, OUTPUT_CONNECTIVITY_DIR_NAME, CONNECTIVITY_CHANNEL_NAMES, N_CHANNELS, \
           granger_suffix

    logger.info("--- Initializing AAL3 ROI Processing Information ---")

    # One availability flag for everything, so the directory suffix and the channel
    # name cannot disagree about whether OMST is in use.
    omst_available = OMST_PYTHON_LOADED and orthogonal_minimum_spanning_tree is not None
    use_omst = USE_PEARSON_OMST_CHANNEL and omst_available
    omst_suffix_for_dir = "OMST_GCE_Signed" if use_omst else "PearsonFullSigned"
    current_pearson_channel_to_use_as_base = PEARSON_OMST_CHANNEL_NAME_PRIMARY if use_omst else PEARSON_OMST_FALLBACK_NAME
    # Recompute from the current settings (they may have been edited after the config cell ran).
    granger_suffix = f"GrangerLag{GRANGER_MAX_LAG}" if USE_GRANGER_CHANNEL else "NoEffConn"
    roi_info_ok = True

    if not AAL3_META_PATH.exists():
        logger.error(f"AAL3 metadata file NOT found: {AAL3_META_PATH}. Cannot perform ROI reduction. "
                     f"Using placeholder N_ROIS_EXPECTED = {N_ROIS_EXPECTED}.")
        FINAL_N_ROIS_EXPECTED = N_ROIS_EXPECTED 
        OUTPUT_CONNECTIVITY_DIR_NAME = f"AAL3_{N_ROIS_EXPECTED}ROIs_fmri_tensor_{omst_suffix_for_dir}_{granger_suffix}{deconv_str}_NeuroEnhanced_v6.5.9_ERR"
        roi_info_ok = False
    else:
        try:
            meta_aal3_df = pd.read_csv(AAL3_META_PATH, sep='\t')
            # Validate the schema before touching 'color'; otherwise a missing column
            # surfaces as a bare KeyError instead of this explicit message.
            required_cols = ['nom_c', 'color', 'vol_vox']
            missing_cols = [col for col in required_cols if col not in meta_aal3_df.columns]
            if missing_cols:
                raise ValueError(f"AAL3 metadata must contain 'nom_c', 'color', 'vol_vox'. Missing: {missing_cols}")

            # 'color' is matched against the 1-based AAL3 indices below, so it must be an int.
            meta_aal3_df['color'] = pd.to_numeric(meta_aal3_df['color'], errors='coerce')
            meta_aal3_df.dropna(subset=['color'], inplace=True)
            meta_aal3_df['color'] = meta_aal3_df['color'].astype(int)

            AAL3_MISSING_INDICES_0BASED = [idx - 1 for idx in AAL3_MISSING_INDICES_1BASED]
            VALID_AAL3_ROI_INFO_DF_166 = meta_aal3_df[~meta_aal3_df['color'].isin(AAL3_MISSING_INDICES_1BASED)].copy()
            VALID_AAL3_ROI_INFO_DF_166.sort_values(by='color', inplace=True)
            # Row positions (0..n-1) now line up with the signal columns left after removing the missing ROIs.
            VALID_AAL3_ROI_INFO_DF_166.reset_index(drop=True, inplace=True)

            if len(VALID_AAL3_ROI_INFO_DF_166) != EXPECTED_ROIS_AFTER_AAL3_MISSING_REMOVAL:
                logger.warning(f"Expected {EXPECTED_ROIS_AFTER_AAL3_MISSING_REMOVAL} ROIs in AAL3 meta after filtering known missing, "
                               f"but found {len(VALID_AAL3_ROI_INFO_DF_166)}. Check AAL3_META_PATH content and AAL3_MISSING_INDICES_1BASED.")

            small_rois_mask_on_166 = VALID_AAL3_ROI_INFO_DF_166['vol_vox'] < SMALL_ROI_VOXEL_THRESHOLD
            INDICES_OF_SMALL_ROIS_TO_DROP_FROM_166 = VALID_AAL3_ROI_INFO_DF_166[small_rois_mask_on_166].index.tolist()

            FINAL_N_ROIS_EXPECTED = EXPECTED_ROIS_AFTER_AAL3_MISSING_REMOVAL - len(INDICES_OF_SMALL_ROIS_TO_DROP_FROM_166)
            N_ROIS_EXPECTED = FINAL_N_ROIS_EXPECTED 

            OUTPUT_CONNECTIVITY_DIR_NAME = f"AAL3_{N_ROIS_EXPECTED}ROIs_fmri_tensor_{omst_suffix_for_dir}_{granger_suffix}{deconv_str}_NeuroEnhanced_v6.5.9_ParallelTuned"

            logger.info("AAL3 ROI processing info initialized:")
            logger.info(f"  Indices of {len(AAL3_MISSING_INDICES_0BASED)} AAL3 systemically missing ROIs (0-based, from {RAW_DATA_EXPECTED_COLUMNS}): {AAL3_MISSING_INDICES_0BASED}")
            logger.info(f"  Number of ROIs in AAL3 meta after excluding systemically missing: {len(VALID_AAL3_ROI_INFO_DF_166)} (Expected: {EXPECTED_ROIS_AFTER_AAL3_MISSING_REMOVAL})")
            logger.info(f"  Indices of small ROIs to drop (from the {len(VALID_AAL3_ROI_INFO_DF_166)} set, 0-based): {INDICES_OF_SMALL_ROIS_TO_DROP_FROM_166}")
            logger.info(f"  Number of small ROIs to drop: {len(INDICES_OF_SMALL_ROIS_TO_DROP_FROM_166)}")
            logger.info(f"  FINAL_N_ROIS_EXPECTED for connectivity analysis: {FINAL_N_ROIS_EXPECTED} (This should be 131 if matching QC script)")
        except Exception as e:
            logger.error(f"Error initializing AAL3 ROI processing info: {e}", exc_info=True)
            FINAL_N_ROIS_EXPECTED = N_ROIS_EXPECTED 
            OUTPUT_CONNECTIVITY_DIR_NAME = f"AAL3_{N_ROIS_EXPECTED}ROIs_fmri_tensor_{omst_suffix_for_dir}_{granger_suffix}{deconv_str}_NeuroEnhanced_v6.5.9_ERROR_INIT"
            roi_info_ok = False

    temp_channels = []
    if USE_PEARSON_OMST_CHANNEL:
        temp_channels.append(current_pearson_channel_to_use_as_base)
        if not omst_available:
            logger.info(f"OMST function from dyconnmap not loaded or is None. Using '{PEARSON_OMST_FALLBACK_NAME}' for the Pearson-based channel.")

    # Added whenever enabled: when OMST is unavailable this duplicates the base channel,
    # and the de-duplication below keeps a single copy. (Gating on the base channel
    # name dropped this channel entirely when the OMST channel was switched off.)
    if USE_PEARSON_FULL_SIGNED_CHANNEL:
        temp_channels.append(PEARSON_OMST_FALLBACK_NAME)

    if USE_MI_CHANNEL_FOR_THESIS: temp_channels.append("MI_KNN_Symmetric")
    if USE_DFC_ABS_DIFF_MEAN_CHANNEL: temp_channels.append("dFC_AbsDiffMean")
    if USE_DFC_STDDEV_CHANNEL: temp_channels.append("dFC_StdDev") 

    if USE_GRANGER_CHANNEL:
        # The Granger F-statistic channel name carries only the lag (no deconvolution suffix).
        temp_channels.append(f"Granger_F_lag{GRANGER_MAX_LAG}")

    # Order-preserving de-duplication.
    CONNECTIVITY_CHANNEL_NAMES = list(dict.fromkeys(temp_channels)) 
    N_CHANNELS = len(CONNECTIVITY_CHANNEL_NAMES)
    return roi_info_ok


# Run the ROI/channel initialization once when this cell executes; it rewrites the module-level
# ROI and channel globals. The warning is logged only if the initializer reports failure.
if not _initialize_aal3_roi_processing_info():
    logger.warning("ROI processing info could not be initialized properly. Pipeline may use placeholder values or fail for some operations.")

# Summary of the resolved configuration.
logger.info(f"Final N_ROIS_EXPECTED after initialization: {N_ROIS_EXPECTED}")
logger.info(f"Final OUTPUT_CONNECTIVITY_DIR_NAME: {OUTPUT_CONNECTIVITY_DIR_NAME}")
logger.info(f"Connectivity channels to be computed: {CONNECTIVITY_CHANNEL_NAMES}") 
logger.info(f"Total number of channels (for VAE): {N_CHANNELS}")

# The subject metadata CSV used here is compared (as a string path) with the one the QC script used;
# a difference is only warned about, since subject IDs must still match across both files.
if str(SUBJECT_METADATA_CSV_PATH) != str(SUBJECT_METADATA_CSV_PATH_QC):
    logger.warning(f"El path del CSV de metadatos de sujetos en este script ({SUBJECT_METADATA_CSV_PATH.name}) "
                   f"difiere del usado en el script de QC ({SUBJECT_METADATA_CSV_PATH_QC.name}). "
                   "Asegúrate de que esto sea intencional y que los IDs de sujeto sean consistentes.")

# --- 1. Subject Metadata Loading and Merging ---
def load_metadata(
    subject_meta_csv_path: Path,
    qc_report_csv_path: Path
) -> Optional[pd.DataFrame]:
    """Load subject metadata, join it with the QC report and keep the subjects that pass QC.

    The two CSVs are inner-joined on ``SubjectID`` (a ``Subject`` column in the QC
    report is renamed first). Columns present in both files get the suffixes
    ``_meta`` / ``_qc``; for ``TimePoints`` and ``ToDiscard_Overall`` the QC report's
    copy is used.

    Args:
        subject_meta_csv_path: CSV with at least a ``SubjectID`` column
            (``ResearchGroup`` is expected for downstream tasks).
        qc_report_csv_path: QC report CSV with ``Subject`` or ``SubjectID``,
            ``ToDiscard_Overall`` and ``TimePoints`` columns.

    Returns:
        DataFrame with columns ``SubjectID``, ``Timepoints`` (int; non-numeric values
        become 0) and ``ResearchGroup`` (``'Unknown'`` if absent) for the subjects whose
        ``ToDiscard_Overall`` is False. Returns ``None`` if a file or required column is
        missing, no subject passes QC, or any error occurs (all logged as critical/warning).
    """
    logger.info("--- Starting Subject Metadata Loading and QC Integration ---")
    try:
        if not subject_meta_csv_path.exists():
            logger.critical(f"Subject metadata CSV file NOT found: {subject_meta_csv_path}")
            return None
        if not qc_report_csv_path.exists():
            logger.critical(f"QC report CSV file NOT found: {qc_report_csv_path}")
            return None

        subjects_db_df = pd.read_csv(subject_meta_csv_path)
        logger.info(f"Loaded main metadata from {subject_meta_csv_path}. Shape: {subjects_db_df.shape}")
        # Check before use: accessing a missing column raised a KeyError that ended up
        # reported as an "Unexpected error" instead of this message.
        if 'SubjectID' not in subjects_db_df.columns:
            logger.critical("Column 'SubjectID' missing in main metadata CSV.")
            return None
        subjects_db_df['SubjectID'] = subjects_db_df['SubjectID'].astype(str).str.strip()
        if 'ResearchGroup' not in subjects_db_df.columns:
            logger.warning("Column 'ResearchGroup' missing in main metadata CSV. May be needed for downstream VAE tasks.")

        qc_df = pd.read_csv(qc_report_csv_path)
        logger.info(f"Loaded QC report from {qc_report_csv_path}. Shape: {qc_df.shape}")

        if 'Subject' in qc_df.columns and 'SubjectID' not in qc_df.columns:
            logger.info("Found 'Subject' column in QC report, renaming to 'SubjectID'.")
            qc_df.rename(columns={'Subject': 'SubjectID'}, inplace=True)

        if 'SubjectID' in qc_df.columns:
            qc_df['SubjectID'] = qc_df['SubjectID'].astype(str).str.strip()
        else:
            logger.critical("Neither 'Subject' nor 'SubjectID' column found in QC report CSV.")
            return None

        essential_qc_cols = ['SubjectID', 'ToDiscard_Overall', 'TimePoints']
        if not all(col in qc_df.columns for col in essential_qc_cols):
            logger.critical(f"Essential columns ({essential_qc_cols}) missing in QC report CSV.")
            return None

        merged_df = pd.merge(subjects_db_df, qc_df, on='SubjectID', how='inner', suffixes=('_meta', '_qc'))

        # Columns shared by both CSVs are suffixed by the merge; the QC report is authoritative.
        if 'TimePoints_qc' in merged_df.columns:
            merged_df['Timepoints_final_for_script'] = merged_df['TimePoints_qc']
        elif 'TimePoints' in merged_df.columns:
            merged_df['Timepoints_final_for_script'] = merged_df['TimePoints']
        else:
            logger.critical("Definitive 'TimePoints' column from QC report could not be identified after merge.")
            return None

        merged_df['Timepoints_final_for_script'] = pd.to_numeric(merged_df['Timepoints_final_for_script'], errors='coerce').fillna(0).astype(int)

        # Same suffix issue for the discard flag: without this, a metadata CSV that also
        # has 'ToDiscard_Overall' made the filter below raise a KeyError.
        discard_col = 'ToDiscard_Overall_qc' if 'ToDiscard_Overall_qc' in merged_df.columns else 'ToDiscard_Overall'

        initial_subject_count = len(merged_df)
        subjects_passing_qc_df = merged_df[merged_df[discard_col].eq(False)].copy()
        num_discarded = initial_subject_count - len(subjects_passing_qc_df)

        logger.info(f"Total subjects after merge: {initial_subject_count}")
        logger.info(f"Subjects discarded based on QC ('ToDiscard_Overall' == True): {num_discarded}")
        logger.info(f"Subjects passing QC and to be processed: {len(subjects_passing_qc_df)}")

        if subjects_passing_qc_df.empty:
            logger.warning("No subjects passed QC. Check your QC criteria and report.")
            return None

        min_tp_after_qc = subjects_passing_qc_df['Timepoints_final_for_script'].min()
        max_tp_after_qc = subjects_passing_qc_df['Timepoints_final_for_script'].max()
        logger.info(f"Timepoints for subjects passing QC (from QC report): Min={min_tp_after_qc}, Max={max_tp_after_qc}.")
        logger.info(f"These will be homogenized to TARGET_LEN_TS = {TARGET_LEN_TS} for connectivity calculation.")

        subjects_passing_qc_df.rename(columns={'Timepoints_final_for_script': 'Timepoints'}, inplace=True)
        final_cols_to_keep = ['SubjectID', 'Timepoints']

        if 'ResearchGroup_meta' in subjects_passing_qc_df.columns:
            subjects_passing_qc_df.rename(columns={'ResearchGroup_meta': 'ResearchGroup'}, inplace=True)
        elif 'ResearchGroup_qc' in subjects_passing_qc_df.columns and 'ResearchGroup' not in subjects_passing_qc_df.columns:
            subjects_passing_qc_df.rename(columns={'ResearchGroup_qc': 'ResearchGroup'}, inplace=True)

        if 'ResearchGroup' not in subjects_passing_qc_df.columns:
            logger.warning("Creating placeholder 'ResearchGroup' column as it was not found. This is important for classification.")
            subjects_passing_qc_df['ResearchGroup'] = 'Unknown'
        final_cols_to_keep.append('ResearchGroup')

        final_cols_to_keep = list(dict.fromkeys(final_cols_to_keep))
        return subjects_passing_qc_df[final_cols_to_keep]

    except FileNotFoundError as e:
        logger.critical(f"CRITICAL Error loading CSV files: {e}")
        return None
    except ValueError as e:
        logger.critical(f"Value error in metadata processing: {e}")
        return None
    except Exception as e:
        logger.critical(f"Unexpected error during metadata loading/QC integration: {e}", exc_info=True)
        return None

# --- 2. Time Series Loading and Preprocessing Functions ---
def _load_signals_from_mat(mat_path: Path, possible_keys: List[str]) -> Optional[np.ndarray]:
    """Read the ROI signal matrix from a MATLAB ``.mat`` file.

    Candidate variable names are tried in the order given; the first one that exists
    in the file and holds a NumPy array with at least two dimensions wins.

    Args:
        mat_path: path to the ``.mat`` file.
        possible_keys: candidate variable names, in priority order.

    Returns:
        The matching array cast to float64, or None if the file cannot be loaded or
        no candidate key matches (both cases are logged).
    """
    try:
        mat_contents = sio.loadmat(mat_path)
    except Exception as e_load:
        logger.error(f"Could not load .mat file: {mat_path}. Error: {e_load}")
        return None

    def _is_signal_matrix(name: str) -> bool:
        # Present, a real ndarray, and at least 2-D.
        return name in mat_contents and isinstance(mat_contents[name], np.ndarray) and mat_contents[name].ndim >= 2

    matching_keys = [name for name in possible_keys if _is_signal_matrix(name)]
    if not matching_keys:
        logger.warning(f"No valid signal keys {possible_keys} found in {mat_path.name}. Keys present: {list(mat_contents.keys())}")
        return None

    chosen_key = matching_keys[0]
    signal_matrix = mat_contents[chosen_key]
    logger.debug(f"Found signals under key '{chosen_key}' in {mat_path.name}. Shape: {signal_matrix.shape}")
    return signal_matrix.astype(np.float64)

def _orient_and_reduce_rois(
    raw_sigs: np.ndarray, 
    subject_id: str,
    initial_expected_cols: int, 
    aal3_missing_0based: Optional[List[int]], 
    small_rois_indices_from_166: Optional[List[int]], 
    final_expected_rois: Optional[int] 
) -> Optional[np.ndarray]:
    """Orient a raw ROI signal matrix as (timepoints, ROIs) and drop excluded ROI columns.

    Steps:
      1. Transpose if needed so the axis of length ``initial_expected_cols`` becomes the
         column (ROI) axis. A square matrix of that size is kept as-is (with a warning).
      2. Delete the columns in ``aal3_missing_0based``. Skipped (with a log message) if the
         indices are None or out of range.
      3. If the result has exactly ``EXPECTED_ROIS_AFTER_AAL3_MISSING_REMOVAL`` columns
         (a module-level constant), delete the columns in ``small_rois_indices_from_166``,
         which index into that reduced matrix. Otherwise this step is skipped.
      4. Compare the final column count with ``final_expected_rois``; a mismatch is only
         logged, never treated as a failure here.

    Args:
        raw_sigs: signal matrix as loaded from the .mat file, in either orientation.
        subject_id: used only in log messages.
        initial_expected_cols: number of ROI columns expected in the raw matrix.
        aal3_missing_0based: 0-based column indices removed in step 2, or None.
        small_rois_indices_from_166: 0-based column indices (after step 2) removed in step 3, or None.
        final_expected_rois: expected ROI count after both steps, or None to skip validation.

    Returns:
        The reduced (timepoints, ROIs) matrix, or None if ``raw_sigs`` is not 2-D or neither
        dimension equals ``initial_expected_cols``. ``raw_sigs`` itself is not modified.
    """
    if raw_sigs.ndim != 2:
        logger.warning(f"S {subject_id}: Raw signal matrix has incorrect dimensions {raw_sigs.ndim} (expected 2). Skipping.")
        return None
    
    # Work on a copy so the caller's array is never touched.
    oriented_sigs = raw_sigs.copy()
    if oriented_sigs.shape[0] == initial_expected_cols and oriented_sigs.shape[1] != initial_expected_cols:
        # Stored as (ROIs, TPs): transpose to (TPs, ROIs).
        logger.info(f"S {subject_id}: Transposing raw matrix from {oriented_sigs.shape} to ({oriented_sigs.shape[1]}, {oriented_sigs.shape[0]}) to match (TPs, ROIs_initial).")
        oriented_sigs = oriented_sigs.T
    elif oriented_sigs.shape[1] == initial_expected_cols and oriented_sigs.shape[0] != initial_expected_cols:
        logger.debug(f"S {subject_id}: Raw matrix already (TPs, ROIs_initial): {oriented_sigs.shape}.") 
    elif oriented_sigs.shape[0] == initial_expected_cols and oriented_sigs.shape[1] == initial_expected_cols:
         # Ambiguous square matrix: orientation cannot be inferred, so it is left unchanged.
         logger.warning(f"S {subject_id}: Raw signal matrix is square ({oriented_sigs.shape}) and matches initial_expected_cols. Assuming [Timepoints, ROIs_initial]. Careful if TPs also equals initial_expected_cols.")
    else: 
        logger.warning(f"S {subject_id}: Neither dimension of raw signal matrix ({oriented_sigs.shape}) matches initial_expected_cols ({initial_expected_cols}). Skipping.")
        return None

    # Defensive re-check; after the branches above the column count should already match.
    if oriented_sigs.shape[1] != initial_expected_cols: 
        logger.warning(f"S {subject_id}: After orientation, raw ROI count ({oriented_sigs.shape[1]}) != initial_expected_cols ({initial_expected_cols}). Skipping.")
        return None
    
    # Step 2: remove the ROIs missing from every subject.
    if aal3_missing_0based is None:
        logger.warning(f"S {subject_id}: AAL3 missing ROI indices (0-based) not available. Skipping AAL3 known missing ROI removal. Using {oriented_sigs.shape[1]} ROIs for next step.")
        sigs_after_known_missing_removed = oriented_sigs 
    else:
        try:
            sigs_after_known_missing_removed = np.delete(oriented_sigs, aal3_missing_0based, axis=1)
            logger.info(f"S {subject_id}: Removed {len(aal3_missing_0based)} known missing AAL3 ROIs. Shape {oriented_sigs.shape} -> {sigs_after_known_missing_removed.shape}")
            if sigs_after_known_missing_removed.shape[1] != EXPECTED_ROIS_AFTER_AAL3_MISSING_REMOVAL:
                 logger.warning(f"S {subject_id}: After removing known missing ROIs, shape is {sigs_after_known_missing_removed.shape}, but expected (..., {EXPECTED_ROIS_AFTER_AAL3_MISSING_REMOVAL}).")
        except IndexError as e:
            logger.error(f"S {subject_id}: IndexError removing known missing AAL3 ROIs (indices: {aal3_missing_0based}) from matrix of shape {oriented_sigs.shape}. Error: {e}. Using original {oriented_sigs.shape[1]} ROIs for next step.")
            sigs_after_known_missing_removed = oriented_sigs 
            
    # Step 3: remove small ROIs. The indices are only meaningful for the reduced column
    # set, so the step is skipped when the column count is not the expected one.
    if small_rois_indices_from_166 is None:
        logger.warning(f"S {subject_id}: Small ROI indices (from 166-set) not available. Skipping small ROI removal. Using {sigs_after_known_missing_removed.shape[1]} ROIs.")
        sigs_final_rois = sigs_after_known_missing_removed
    elif sigs_after_known_missing_removed.shape[1] != EXPECTED_ROIS_AFTER_AAL3_MISSING_REMOVAL:
        logger.warning(f"S {subject_id}: Cannot remove small ROIs because the matrix (shape {sigs_after_known_missing_removed.shape}) does not have the expected {EXPECTED_ROIS_AFTER_AAL3_MISSING_REMOVAL} columns after first reduction step. Using current ROIs ({sigs_after_known_missing_removed.shape[1]}).")
        sigs_final_rois = sigs_after_known_missing_removed
    else:
        try:
            sigs_final_rois = np.delete(sigs_after_known_missing_removed, small_rois_indices_from_166, axis=1)
            logger.info(f"S {subject_id}: Removed {len(small_rois_indices_from_166)} small ROIs. Shape {sigs_after_known_missing_removed.shape} -> {sigs_final_rois.shape}")
        except IndexError as e:
            logger.error(f"S {subject_id}: IndexError removing small ROIs (indices: {small_rois_indices_from_166}) from matrix of shape {sigs_after_known_missing_removed.shape}. Error: {e}. Using {sigs_after_known_missing_removed.shape[1]} ROIs.")
            sigs_final_rois = sigs_after_known_missing_removed 

    # Step 4: validation is informational only; the caller decides whether a mismatch is fatal.
    if final_expected_rois is not None and sigs_final_rois.shape[1] != final_expected_rois:
        logger.warning(f"S {subject_id}: Final ROI count ({sigs_final_rois.shape[1]}) != FINAL_N_ROIS_EXPECTED ({final_expected_rois}). "
                       "This may indicate issues in AAL3 metadata or reduction logic. Proceeding with current matrix.")
    elif final_expected_rois is None:
        logger.warning(f"S {subject_id}: FINAL_N_ROIS_EXPECTED is None. Cannot validate final ROI count. Proceeding with {sigs_final_rois.shape[1]} ROIs.")
    else:
        logger.info(f"S {subject_id}: Final ROI count {sigs_final_rois.shape[1]} matches FINAL_N_ROIS_EXPECTED {final_expected_rois}.")
        
    return sigs_final_rois


def _bandpass_filter_signals(sigs: np.ndarray, lowcut: float, highcut: float, fs: float, order: int, subject_id: str, taper_alpha: float = 0.1) -> np.ndarray:
    nyquist_freq = 0.5 * fs
    low_norm = lowcut / nyquist_freq
    high_norm = highcut / nyquist_freq

    if not (0 < low_norm < 1 and 0 < high_norm < 1 and low_norm < high_norm):
        logger.error(f"S {subject_id}: Invalid critical frequencies for filter (low_norm={low_norm}, high_norm={high_norm}). Nyquist={nyquist_freq}. Skipping filtering.")
        return sigs
    try:
        b, a = butter(order, [low_norm, high_norm], btype='band', analog=False)
        filtered_sigs = np.zeros_like(sigs)
        padlen_required = 3 * (max(len(a), len(b))) 
        
        for i in range(sigs.shape[1]): 
            roi_signal = sigs[:, i].copy() 
            
            if len(roi_signal) > padlen_required: 
                try:
                    tukey_window = windows.tukey(len(roi_signal), alpha=taper_alpha)
                    roi_signal_tapered = roi_signal * tukey_window
                except Exception as e_taper:
                    logger.warning(f"S {subject_id}, ROI {i}: Error applying Tukey window: {e_taper}. Proceeding without taper.")
                    roi_signal_tapered = roi_signal 
            else:
                roi_signal_tapered = roi_signal 

            if np.all(np.isclose(roi_signal_tapered, roi_signal_tapered[0] if len(roi_signal_tapered)>0 else 0.0)): 
                logger.debug(f"S {subject_id}, ROI {i}: Signal is constant (possibly after taper). Skipping filter.")
                filtered_sigs[:, i] = roi_signal_tapered 
            elif len(roi_signal_tapered) <= padlen_required :
                logger.warning(f"S {subject_id}, ROI {i}: Signal too short ({len(roi_signal_tapered)} pts, need > {padlen_required}) for filtfilt. Skipping filter for this ROI.")
                filtered_sigs[:, i] = roi_signal_tapered
            else:
                filtered_sigs[:, i] = filtfilt(b, a, roi_signal_tapered)
        return filtered_sigs
    except Exception as e:
        logger.error(f"S {subject_id}: Error during bandpass filtering: {e}. Returning original signals.", exc_info=False)
        return sigs

def _hrf_deconvolution(sigs: np.ndarray, tr: float, hrf_model_type: str, subject_id: str) -> np.ndarray:
    """Deconvolve each ROI time series with a canonical HRF sampled at the TR.

    The kernel comes from nilearn's ``glover_hrf`` or ``spm_hrf`` with ``oversampling=1``.
    ``scipy.signal.deconvolve`` performs polynomial division via ``lfilter`` with the
    kernel as denominator, which rejects a kernel whose first coefficient is 0. The
    nilearn kernels sampled from t=0 can start with such a (near-)zero sample (the
    response has not started yet). Those leading samples are therefore stripped, and the
    signal is advanced by the same number of samples before dividing, which keeps the
    recovered signal time-aligned with the input.

    Each quotient has ``len(signal) - len(kernel) + 1`` samples and is zero-padded at
    the end back to the input length. A ROI keeps its original signal if it is shorter
    than the kernel, if deconvolution raises, or if the result contains NaN/Inf (the
    inverse filter can be unstable).

    Args:
        sigs: 2-D array, time along axis 0, one ROI per column.
        tr: repetition time in seconds.
        hrf_model_type: ``'glover'`` or ``'spm'``.
        subject_id: used only in log messages.

    Returns:
        Array with the shape of ``sigs``. ``sigs`` itself is returned for an unknown
        model type or an empty/all-zero kernel.
    """
    logger.info(f"S {subject_id}: Attempting HRF deconvolution (Model: {hrf_model_type}, TR: {tr}s).")
    if hrf_model_type == 'glover':
        hrf_kernel = glover_hrf(tr, oversampling=1)
    elif hrf_model_type == 'spm':
        hrf_kernel = spm_hrf(tr, oversampling=1)
    else:
        logger.error(f"S {subject_id}: Unknown HRF model type '{hrf_model_type}'. Skipping deconvolution.")
        return sigs

    if len(hrf_kernel) == 0 or np.all(np.isclose(hrf_kernel, 0)):
        logger.error(f"S {subject_id}: HRF kernel is empty or all zeros for model '{hrf_model_type}'. Skipping deconvolution.")
        return sigs

    # First kernel sample that is not negligible relative to the peak. The all-zero case
    # was rejected above, so at least the peak sample qualifies.
    abs_kernel = np.abs(hrf_kernel)
    lead = int(np.flatnonzero(abs_kernel > 1e-8 * abs_kernel.max())[0])
    kernel_eff = hrf_kernel[lead:]
    if lead:
        logger.debug(f"S {subject_id}: Dropping {lead} leading near-zero HRF sample(s) before deconvolution.")

    n_timepoints = sigs.shape[0]
    deconvolved_sigs = np.zeros_like(sigs)
    for i in range(sigs.shape[1]):
        signal_roi = sigs[:, i]
        if len(signal_roi) < len(hrf_kernel):
            logger.warning(f"S {subject_id}, ROI {i}: Signal length ({len(signal_roi)}) is shorter than HRF kernel length ({len(hrf_kernel)}). Skipping deconvolution for this ROI.")
            deconvolved_sigs[:, i] = signal_roi
            continue
        try:
            # y[lead + m] == (kernel_eff * s)[m], so dividing the advanced signal by
            # kernel_eff recovers s aligned with the input's time axis.
            quotient, _ = deconvolve(signal_roi[lead:], kernel_eff)
            if not np.all(np.isfinite(quotient)):
                raise FloatingPointError("non-finite values in deconvolved signal (unstable inverse filter)")
            if len(quotient) < n_timepoints:
                deconvolved_sigs[:, i] = np.concatenate([quotient, np.zeros(n_timepoints - len(quotient))])
            else:
                deconvolved_sigs[:, i] = quotient[:n_timepoints]
        except Exception as e_deconv:
            logger.error(f"S {subject_id}, ROI {i}: Deconvolution failed: {e_deconv}. Using original signal for this ROI.", exc_info=False)
            deconvolved_sigs[:, i] = signal_roi

    logger.info(f"S {subject_id}: HRF deconvolution attempt finished.")
    return deconvolved_sigs

def _preprocess_time_series(
    sigs: np.ndarray, 
    target_len_ts_val: int, 
    subject_id: str, 
    var_max_lag_val: int, 
    tr_seconds_val: float, low_cut_val: float, high_cut_val: float, filter_order_val: int,
    apply_hrf_deconv_val: bool, hrf_model_type_val: str,
    taper_alpha_val: float 
) -> Optional[np.ndarray]:
    """Filter, optionally deconvolve, z-score and length-homogenize one subject's ROI signals.

    Pipeline:
      1. band-pass filtering with ``_bandpass_filter_signals`` (sampling rate = 1 / TR);
      2. optional HRF deconvolution, then NaN/Inf replaced by 0;
      3. minimum-length guard: at least max(5, lag + 10, dFC window) time points;
      4. column-wise z-scoring with ``StandardScaler`` (per-column fallback on ValueError);
      5. linear interpolation (if shorter) or truncation (if longer) to
         ``target_len_ts_val`` time points.

    Args:
        sigs: 2-D array (timepoints, ROIs).
        target_len_ts_val: number of time points in the output.
        subject_id: used only in log messages.
        var_max_lag_val: lag of the effective-connectivity model (used in the length guard).
        tr_seconds_val, low_cut_val, high_cut_val, filter_order_val: filtering parameters.
        apply_hrf_deconv_val, hrf_model_type_val: HRF deconvolution switch and model.
        taper_alpha_val: Tukey-window alpha passed to the filter.

    Returns:
        float32 array (target_len_ts_val, ROIs), or None if the series is too short.
    """
    n_timepoints_in, n_rois_in = sigs.shape
    sampling_rate_hz = 1.0 / tr_seconds_val

    logger.info(f"S {subject_id}: Preprocessing. Input TPs: {n_timepoints_in}, ROIs: {n_rois_in} (should be {FINAL_N_ROIS_EXPECTED}), TR: {tr_seconds_val}s. Target TPs for output: {target_len_ts_val}.")

    work = _bandpass_filter_signals(sigs, low_cut_val, high_cut_val, sampling_rate_hz, filter_order_val, subject_id, taper_alpha=taper_alpha_val)

    if apply_hrf_deconv_val:
        work = _hrf_deconvolution(work, tr_seconds_val, hrf_model_type_val, subject_id)
        if not np.isfinite(work).all():
            logger.warning(f"S {subject_id}: NaNs/Infs detected after HRF deconvolution. Cleaning by replacing with 0.0.")
            work = np.nan_to_num(work, nan=0.0, posinf=0.0, neginf=0.0)

    # Enough points for the lagged effective-connectivity model and for one dFC window.
    dfc_floor = DFC_WIN_POINTS if DFC_WIN_POINTS > 0 else 5
    required_len = max(5, var_max_lag_val + 10, dfc_floor)
    if work.shape[0] < required_len:
        logger.warning(f"S {subject_id}: Timepoints after processing ({work.shape[0]}) are less than minimum required ({required_len}) for all connectivity measures. Skipping subject.")
        return None

    if np.isnan(work).any():
        logger.warning(f"S {subject_id}: NaNs detected in signals before scaling. Filling with 0.0. This might affect results.")
        work = np.nan_to_num(work, nan=0.0)

    try:
        zscored = StandardScaler().fit_transform(work)
        if np.isnan(zscored).any():
            logger.warning(f"S {subject_id}: NaNs detected after StandardScaler. Filling with 0.0. This is unusual.")
            zscored = np.nan_to_num(zscored, nan=0.0, posinf=0.0, neginf=0.0)
    except ValueError as e_scale:
        logger.warning(f"S {subject_id}: StandardScaler failed (e.g. all-zero data after processing): {e_scale}. Attempting column-wise scaling or zeroing.")
        # Columns without variance (or whose scaling fails) stay at zero.
        zscored = np.zeros_like(work, dtype=np.float32)
        for col_idx in range(work.shape[1]):
            column = work[:, col_idx].reshape(-1, 1)
            if np.std(column) > 1e-9:
                try:
                    zscored[:, col_idx] = StandardScaler().fit_transform(column).flatten()
                except Exception as e_col_scale:
                    logger.error(f"S {subject_id}, ROI {col_idx}: Column-wise scaling failed: {e_col_scale}. Setting to zero.")
                    zscored[:, col_idx] = 0.0
        if np.isnan(zscored).any():
            zscored = np.nan_to_num(zscored, nan=0.0, posinf=0.0, neginf=0.0)

    n_tp, n_rois = zscored.shape
    if n_tp == target_len_ts_val:
        resized = zscored
    else:
        logger.info(f"S {subject_id}: Homogenizing time series length from {n_tp} to {target_len_ts_val}.")
        if n_tp > target_len_ts_val:
            resized = zscored[:target_len_ts_val, :]
        else:
            logger.debug(f"S {subject_id}: Interpolating from {n_tp} to {target_len_ts_val} points.")
            resized = np.zeros((target_len_ts_val, n_rois), dtype=np.float32)
            if n_tp > 1:
                # Both series are mapped onto [0, 1], so interpolation stretches the time axis.
                grid_in = np.linspace(0, 1, n_tp)
                grid_out = np.linspace(0, 1, target_len_ts_val)
                for col_idx in range(n_rois):
                    interpolator = interp1d(grid_in, zscored[:, col_idx], kind='linear', fill_value="extrapolate")
                    resized[:, col_idx] = interpolator(grid_out)
            elif n_tp == 1:
                resized[:, :] = zscored[0, :]

            if np.isnan(resized).any():
                logger.warning(f"S {subject_id}: NaNs found after interpolation/length adjustment. Filling with 0.0.")
                resized = np.nan_to_num(resized, nan=0.0)

    return resized.astype(np.float32)

def load_and_preprocess_single_subject_series(
    subject_id: str,
    target_len_ts_val: int,
    current_roi_signals_dir_path: Path, current_roi_filename_template: str,
    possible_roi_keys_list: List[str],
    eff_conn_max_lag_val: int,  # lag for effective connectivity (Granger or VAR); forwarded to preprocessing
    tr_seconds_val: float, low_cut_val: float, high_cut_val: float, filter_order_val: int,
    apply_hrf_deconv_val: bool, hrf_model_type_val: str,
    taper_alpha_val: float
) -> Tuple[Optional[np.ndarray], str, bool]:
    """Load one subject's ROI signals from its .mat file and preprocess them.

    Steps: resolve the file path from the template -> load raw signals ->
    orient/reduce ROIs -> validate the ROI count against
    FINAL_N_ROIS_EXPECTED -> ``_preprocess_time_series`` -> re-validate the
    ROI count of the processed signals.

    Args:
        subject_id: subject identifier, substituted into the filename template.
        target_len_ts_val: target number of timepoints after preprocessing.
        current_roi_signals_dir_path: directory holding the .mat files.
        current_roi_filename_template: ``str.format`` template with ``{subject_id}``.
        possible_roi_keys_list: candidate variable names inside the .mat file.
        eff_conn_max_lag_val, tr_seconds_val, low_cut_val, high_cut_val,
        filter_order_val, apply_hrf_deconv_val, hrf_model_type_val,
        taper_alpha_val: forwarded unchanged to ``_preprocess_time_series``.

    Returns:
        ``(signals, message, ok)``: the processed array (``None`` on any
        failure), a human-readable status message, and a success flag.
    """
    mat_file = current_roi_signals_dir_path / current_roi_filename_template.format(subject_id=subject_id)
    if not mat_file.exists():
        return None, f"MAT file not found: {mat_file.name}", False

    try:
        raw_signals = _load_signals_from_mat(mat_file, possible_roi_keys_list)
        if raw_signals is None:
            return None, f"No valid signal keys or load error in {mat_file.name}", False

        reduced_signals = _orient_and_reduce_rois(
            raw_signals,
            subject_id,
            RAW_DATA_EXPECTED_COLUMNS,
            AAL3_MISSING_INDICES_0BASED,
            INDICES_OF_SMALL_ROIS_TO_DROP_FROM_166,
            FINAL_N_ROIS_EXPECTED,
        )
        # Free the raw array before the next allocation-heavy step.
        del raw_signals
        gc.collect()
        if reduced_signals is None:
            return None, f"ROI orientation, reduction, or validation failed for S {subject_id}.", False

        n_reduced_rois = reduced_signals.shape[1]
        if FINAL_N_ROIS_EXPECTED is None:
            logger.warning(f"S {subject_id}: FINAL_N_ROIS_EXPECTED is None, cannot strictly validate ROI count. Proceeding with {n_reduced_rois} ROIs.")
        elif n_reduced_rois != FINAL_N_ROIS_EXPECTED:
            mismatch_msg = (
                f"S {subject_id}: Post-reduction ROI count ({n_reduced_rois}) "
                f"does not match FINAL_N_ROIS_EXPECTED ({FINAL_N_ROIS_EXPECTED}). This is unexpected. "
                "Check AAL3 metadata and ROI reduction logic."
            )
            logger.error(mismatch_msg)
            return None, mismatch_msg, False

        n_original_tps = reduced_signals.shape[0]

        processed = _preprocess_time_series(
            reduced_signals, target_len_ts_val,
            subject_id, eff_conn_max_lag_val,
            tr_seconds_val, low_cut_val, high_cut_val, filter_order_val,
            apply_hrf_deconv_val, hrf_model_type_val,
            taper_alpha_val=taper_alpha_val,
        )
        del reduced_signals
        gc.collect()
        if processed is None:
            return None, f"Preprocessing (filtering, scaling, or length adjustment) failed for S {subject_id}. Original TPs: {n_original_tps}", False

        n_out_tps, n_out_rois = processed.shape[0], processed.shape[1]
        shape_str = f"({n_out_tps}, {n_out_rois})"
        if FINAL_N_ROIS_EXPECTED is not None and n_out_rois != FINAL_N_ROIS_EXPECTED:
            mismatch_msg = (
                f"S {subject_id}: Processed signal ROI count ({n_out_rois}) "
                f"mismatches FINAL_N_ROIS_EXPECTED ({FINAL_N_ROIS_EXPECTED})."
            )
            logger.error(mismatch_msg)
            return None, mismatch_msg, False

        logger.info(f"S {subject_id}: Successfully loaded and preprocessed. Original TPs: {n_original_tps}, Final Shape for conn: {shape_str}")
        return processed, f"OK. Original TPs: {n_original_tps}, final shape for conn: {shape_str}", True

    except Exception as e:
        logger.error(f"Unhandled exception during load_and_preprocess for S {subject_id} ({mat_file.name}): {e}", exc_info=True)
        return None, f"Exception processing {mat_file.name}: {e}", False

# --- 3. Connectivity Calculation Functions ---
def fisher_r_to_z(r_matrix: np.ndarray, eps: float = 1e-7) -> np.ndarray:
    """Fisher r-to-z transform (element-wise arctanh) of a correlation matrix.

    NaNs become 0 and values are clipped into ``[-(1 - eps), 1 - eps]`` so that
    arctanh stays finite; the diagonal is set to 0.

    Args:
        r_matrix: square correlation matrix (any float dtype; cast to float32).
        eps: margin kept away from +/-1 before applying arctanh.

    Returns:
        float32 matrix of z-values with a zero diagonal.
    """
    bound = 1.0 - eps
    r_safe = np.nan_to_num(r_matrix.astype(np.float32), nan=0.0)
    z = np.arctanh(np.clip(r_safe, -bound, bound))
    np.fill_diagonal(z, 0.0)
    return z.astype(np.float32)

def calculate_pearson_full_fisher_z_signed(ts_subject: np.ndarray, sid: str) -> Optional[np.ndarray]:
    """Full (unthresholded) Pearson connectivity, Fisher z-transformed, sign kept.

    Args:
        ts_subject: 2-D array; columns are the variables (ROIs) and rows the
            observations (``np.corrcoef(..., rowvar=False)``).
        sid: subject identifier, used only in log messages.

    Returns:
        (n_rois, n_rois) float32 z-matrix with zero diagonal; an all-zero
        matrix when ``np.corrcoef`` degenerates to a scalar; ``None`` for fewer
        than 2 timepoints or on error.
    """
    n_timepoints = ts_subject.shape[0]
    if n_timepoints < 2:
        logger.warning(f"Pearson_Full_FisherZ_Signed (S {sid}): Insufficient timepoints ({ts_subject.shape[0]} < 2).")
        return None

    try:
        r = np.corrcoef(ts_subject, rowvar=False).astype(np.float32)
        if r.ndim != 0:
            z = fisher_r_to_z(r)
            logger.info(f"Pearson_Full_FisherZ_Signed (S {sid}): Successfully calculated.")
            return z

        # Scalar result from np.corrcoef (e.g. a single ROI column).
        logger.warning(f"Pearson_Full_FisherZ_Signed (S {sid}): Correlation resulted in a scalar. Input shape: {ts_subject.shape}.")
        n_rois = ts_subject.shape[1]
        if n_rois > 0:
            return np.zeros((n_rois, n_rois), dtype=np.float32)
        return None
    except Exception as e:
        logger.error(f"Error calculating Pearson_Full_FisherZ_Signed for S {sid}: {e}", exc_info=True)
        return None

def calculate_pearson_omst_signed_weighted(ts_subject: np.ndarray, sid: str) -> Optional[np.ndarray]:
    """Signed, weighted Pearson connectivity sparsified with dyconnmap's OMST.

    Steps:
      1. Pearson correlation across columns (``rowvar=False``), then Fisher z.
      2. The OMST routine is given absolute weights |z| with a zeroed diagonal.
      3. The second element of the OMST output (used as the filtered adjacency
         matrix) is multiplied by sign(z); zero z-values get sign +1.

    Args:
        ts_subject: 2-D time series, rows = timepoints, columns = ROIs.
        sid: subject identifier, used only in log messages.

    Returns:
        (n_rois, n_rois) float32 signed matrix with zero diagonal; an all-zero
        matrix when the correlation degenerates to a scalar or every |z| is ~0;
        ``None`` when OMST is unavailable, there are fewer than 2 timepoints,
        the OMST output has an unexpected form, or any error occurs.
    """
    if not OMST_PYTHON_LOADED or orthogonal_minimum_spanning_tree is None:
        logger.error(f"Pearson_OMST_GCE_Signed_Weighted (S {sid}): Dyconnmap OMST function not available. Cannot calculate.")
        return None

    if ts_subject.shape[0] < 2:
        logger.warning(f"Pearson_OMST_GCE_Signed_Weighted (S {sid}): Insufficient timepoints ({ts_subject.shape[0]} < 2).")
        return None

    try:
        with warnings.catch_warnings():
            # Constant ROIs make np.corrcoef divide by a zero std; the resulting
            # NaNs are zeroed by fisher_r_to_z, so these warnings are noise.
            warnings.filterwarnings("ignore", message="divide by zero encountered in divide", category=RuntimeWarning)
            warnings.filterwarnings("ignore", message="invalid value encountered in divide", category=RuntimeWarning)

            corr_matrix = np.corrcoef(ts_subject, rowvar=False).astype(np.float32)

            if corr_matrix.ndim == 0:
                # np.corrcoef degenerates to a scalar (e.g. a single ROI column).
                logger.warning(f"Pearson_OMST_GCE_Signed_Weighted (S {sid}): Correlation resulted in a scalar. Input shape: {ts_subject.shape}. Returning zero matrix.")
                num_rois = ts_subject.shape[1]
                return np.zeros((num_rois, num_rois), dtype=np.float32) if num_rois > 0 else None

            z_transformed_matrix = fisher_r_to_z(corr_matrix)
            # OMST receives absolute weights; the sign is re-applied afterwards.
            weights_for_omst_gce = np.abs(z_transformed_matrix)
            np.fill_diagonal(weights_for_omst_gce, 0.0)

            if np.all(np.isclose(weights_for_omst_gce, 0)):
                logger.warning(f"Pearson_OMST_GCE_Signed_Weighted (S {sid}): All input weights for OMST GCE are zero. Returning zero matrix.")
                # Return exact zeros as the message states (|z| may be ~1e-8, not 0).
                return np.zeros_like(z_transformed_matrix, dtype=np.float32)

            logger.info(f"S {sid}: Calling dyconnmap.threshold_omst_global_cost_efficiency with ABSOLUTE weights shape {weights_for_omst_gce.shape}")

            omst_outputs = orthogonal_minimum_spanning_tree(weights_for_omst_gce, n_msts=None)

            if isinstance(omst_outputs, tuple) and len(omst_outputs) >= 2:
                # Index 1 is taken as the filtered adjacency matrix (CIJtree).
                omst_adjacency_matrix = np.asarray(omst_outputs[1]).astype(np.float32)
                logger.debug(f"S {sid}: dyconnmap.threshold_omst_global_cost_efficiency returned multiple outputs. Using the second one (CIJtree) as omst_adjacency_matrix.")
            else:
                logger.error(f"S {sid}: dyconnmap.threshold_omst_global_cost_efficiency returned an unexpected type or insufficient outputs: {type(omst_outputs)}. Cannot extract OMST matrix.")
                return None

            signs = np.sign(z_transformed_matrix)
            signs[signs == 0] = 1  # zero z-values keep a +1 sign
            signed_weighted_omst_matrix = np.multiply(omst_adjacency_matrix, signs)

            np.fill_diagonal(signed_weighted_omst_matrix, 0.0)

            logger.info(f"Pearson_OMST_GCE_Signed_Weighted (S {sid}): Successfully calculated. Matrix density: {np.count_nonzero(signed_weighted_omst_matrix) / signed_weighted_omst_matrix.size:.4f}")
            return signed_weighted_omst_matrix.astype(np.float32)

    except AttributeError as ae:
        if 'from_numpy_matrix' in str(ae).lower() or 'from_numpy_array' in str(ae).lower():
            # Resolve the NetworkX version locally so this handler cannot itself
            # fail with NameError when no module-level ``nx`` alias exists.
            try:
                import networkx as _nx_local
                nx_version = _nx_local.__version__
            except Exception:
                nx_version = "unknown"
            logger.error(f"Error calculating Pearson_OMST_GCE_Signed_Weighted (dyconnmap) for S {sid}: NetworkX version incompatibility. "
                         f"Dyconnmap (v1.0.4) may be using a deprecated NetworkX function. "
                         f"Your NetworkX version: {nx_version}. Consider using NetworkX 2.x (e.g., 'pip install networkx==2.6.3'). Original error: {ae}", exc_info=False)
        else:
            logger.error(f"AttributeError calculating Pearson_OMST_GCE_Signed_Weighted (dyconnmap) for S {sid}: {ae}", exc_info=True)
        return None
    except Exception as e:
        logger.error(f"Error calculating Pearson_OMST_GCE_Signed_Weighted (dyconnmap) connectivity for S {sid}: {e}", exc_info=True)
        return None

def _calculate_mi_for_pair(X_i_reshaped, y_j, n_neighbors_val):
    """Helper function to calculate MI for a single pair, for parallelization."""
    try:
        mi_val = mutual_info_regression(X_i_reshaped, y_j, n_neighbors=n_neighbors_val, random_state=42, discrete_features=False)[0]
        return mi_val
    except Exception:
        return 0.0 

def calculate_mi_knn_connectivity(ts_subject: np.ndarray, n_neighbors_val: int, sid: str) -> Optional[np.ndarray]:
    """Symmetric mutual-information connectivity via sklearn's k-NN estimator.

    For every ROI pair (i < j), MI is estimated in both directions (column i
    as feature / column j as target, and vice versa) by
    ``_calculate_mi_for_pair``. The matrix entry is the mean of the two
    estimates. Pairs run through joblib, with a serial fallback if the
    parallel run raises.

    Args:
        ts_subject: 2-D array, rows = timepoints, columns = ROIs.
        n_neighbors_val: k for the k-NN estimator; must be < number of timepoints.
        sid: subject identifier for logging.

    Returns:
        (n_rois, n_rois) float32 symmetric matrix with zero diagonal, or
        ``None`` when there are no timepoints or too few for ``n_neighbors_val``.
    """
    n_tp, n_rois = ts_subject.shape
    if n_tp == 0:
        logger.warning(f"MI_KNN (S {sid}): 0 Timepoints provided. Cannot calculate MI.")
        return None
    if n_tp <= n_neighbors_val:
        logger.warning(f"MI_KNN (S {sid}): Timepoints ({n_tp}) <= n_neighbors ({n_neighbors_val}). Skipping MI calculation.")
        return None

    mi_matrix = np.zeros((n_rois, n_rois), dtype=np.float32)

    # Upper-triangle pairs; each contributes two directed jobs, in the order
    # (i as feature, j as target) then (j as feature, i as target).
    pairs = [(i, j) for i in range(n_rois) for j in range(i + 1, n_rois)]
    tasks = []
    for i, j in pairs:
        tasks.append((ts_subject[:, i].reshape(-1, 1), ts_subject[:, j], n_neighbors_val))
        tasks.append((ts_subject[:, j].reshape(-1, 1), ts_subject[:, i], n_neighbors_val))

    # Share cores with the subject-level pool: all but one core when subjects
    # run serially, otherwise an even split across subject workers.
    if MAX_WORKERS == 1:
        n_jobs_mi = TOTAL_CPU_CORES - 1 if TOTAL_CPU_CORES > 1 else 1
    else:
        n_jobs_mi = TOTAL_CPU_CORES // MAX_WORKERS
    n_jobs_mi = max(1, n_jobs_mi)
    logger.debug(f"MI_KNN (S {sid}): Using n_jobs={n_jobs_mi} for joblib.Parallel. Global MAX_WORKERS for subjects: {MAX_WORKERS}")

    try:
        results = Parallel(n_jobs=n_jobs_mi)(delayed(_calculate_mi_for_pair)(X, y, nn) for X, y, nn in tasks)
    except Exception as e_parallel:
        logger.error(f"MI_KNN (S {sid}): Error during parallel MI calculation: {e_parallel}. Falling back to serial.")
        results = [_calculate_mi_for_pair(X, y, nn) for X, y, nn in tasks]

    for k, (i, j) in enumerate(pairs):
        mi_val_ij, mi_val_ji = results[2 * k], results[2 * k + 1]
        mi_matrix[i, j] = mi_matrix[j, i] = (mi_val_ij + mi_val_ji) / 2.0
        if mi_val_ij == 0.0 and mi_val_ji == 0.0:
            logger.debug(f"MI_KNN (S {sid}): MI for pair ({i},{j}) resulted in 0.0 (possibly due to error or true zero MI).")

    logger.info(f"MI_KNN_Symmetric (S {sid}): Successfully calculated.")
    return mi_matrix

def calculate_custom_dfc_abs_diff_mean(ts_subject: np.ndarray, win_points_val: int, step_val: int, sid: str) -> Optional[np.ndarray]:
    """Mean absolute change of |r| between consecutive sliding windows.

    For each window (length ``win_points_val``, hop ``step_val``), the
    ROI x ROI Pearson matrix is computed. NaN becomes 0, the absolute value is
    taken, and the diagonal is zeroed. The result is the element-wise mean of
    |C_k - C_{k-1}| over consecutive successfully processed windows.

    Args:
        ts_subject: 2-D array, rows = timepoints, columns = ROIs.
        win_points_val: window length in timepoints (must be >= 2).
        step_val: hop between window starts in timepoints (must be >= 1).
        sid: subject identifier for logging.

    Returns:
        (n_rois, n_rois) float32 matrix with zero diagonal, or ``None`` when the
        window/step are invalid, fewer than 2 windows fit, or no difference
        could be computed.
    """
    n_tp, n_rois = ts_subject.shape
    # Reject parameters that would divide by zero (step 0) or can never yield a
    # correlation (window < 2).
    if win_points_val < 2 or step_val < 1:
        logger.warning(f"dFC_AbsDiffMean (S {sid}): Invalid window/step (win={win_points_val}, step={step_val}); need win >= 2 and step >= 1. Skipping.")
        return None
    if n_tp < win_points_val:
        logger.warning(f"dFC_AbsDiffMean (S {sid}): Timepoints ({n_tp}) < window length ({win_points_val}). Skipping.")
        return None

    num_windows = (n_tp - win_points_val) // step_val + 1
    if num_windows < 2:
        logger.warning(f"dFC_AbsDiffMean (S {sid}): Fewer than 2 windows ({num_windows}) can be formed. Skipping.")
        return None

    # Accumulate in float64 to limit rounding error across many windows.
    sum_abs_diff_matrix = np.zeros((n_rois, n_rois), dtype=np.float64)
    n_diffs_calculated = 0
    prev_corr_matrix_abs: Optional[np.ndarray] = None

    for idx in range(num_windows):
        start_idx = idx * step_val
        window_ts = ts_subject[start_idx:start_idx + win_points_val, :]

        if window_ts.shape[0] < 2:
            continue

        try:
            corr_matrix_window = np.corrcoef(window_ts, rowvar=False)
            if corr_matrix_window.ndim < 2 or corr_matrix_window.shape != (n_rois, n_rois):
                logger.warning(f"dFC_AbsDiffMean (S {sid}), Window {idx}: corrcoef returned unexpected shape {corr_matrix_window.shape}. Using zeros for this window.")
                corr_matrix_window = np.full((n_rois, n_rois), 0.0, dtype=np.float32)
            else:
                # ROIs constant within a window produce NaN correlations; treat as 0.
                corr_matrix_window = np.nan_to_num(corr_matrix_window.astype(np.float32), nan=0.0)

            current_corr_matrix_abs = np.abs(corr_matrix_window)
            np.fill_diagonal(current_corr_matrix_abs, 0)

            if prev_corr_matrix_abs is not None:
                sum_abs_diff_matrix += np.abs(current_corr_matrix_abs - prev_corr_matrix_abs)
                n_diffs_calculated += 1
            prev_corr_matrix_abs = current_corr_matrix_abs
        except Exception as e:
            logger.error(f"dFC_AbsDiffMean (S {sid}), Window {idx}: Error calculating/processing correlation: {e}")
            # Break the chain so the next valid window is not differenced
            # against a non-adjacent one.
            prev_corr_matrix_abs = None

    if n_diffs_calculated == 0:
        logger.warning(f"dFC_AbsDiffMean (S {sid}): No valid differences between windowed correlations were calculated. Returning None.")
        return None

    mean_abs_diff_matrix = (sum_abs_diff_matrix / n_diffs_calculated).astype(np.float32)
    np.fill_diagonal(mean_abs_diff_matrix, 0)
    logger.info(f"dFC_AbsDiffMean (S {sid}): Successfully calculated from {n_diffs_calculated} consecutive-window differences.")
    return mean_abs_diff_matrix

def calculate_dfc_std_dev(ts_subject: np.ndarray, win_points_val: int, step_val: int, sid: str) -> Optional[np.ndarray]:
    """Element-wise standard deviation of sliding-window Pearson matrices.

    For each window (length ``win_points_val``, hop ``step_val``), the
    ROI x ROI Pearson matrix is computed, NaN becomes 0, and the diagonal is
    zeroed. The result is the population standard deviation (``np.std``,
    ddof=0) across windows.

    Args:
        ts_subject: 2-D array, rows = timepoints, columns = ROIs.
        win_points_val: window length in timepoints (must be >= 2).
        step_val: hop between window starts in timepoints (must be >= 1).
        sid: subject identifier for logging.

    Returns:
        (n_rois, n_rois) float32 matrix with zero diagonal, or ``None`` when the
        window/step are invalid, fewer than 2 windows fit, or fewer than 2
        windowed matrices could be computed.
    """
    n_tp, n_rois = ts_subject.shape
    # Reject parameters that would divide by zero (step 0) or can never yield a
    # correlation (window < 2).
    if win_points_val < 2 or step_val < 1:
        logger.warning(f"dFC_StdDev (S {sid}): Invalid window/step (win={win_points_val}, step={step_val}); need win >= 2 and step >= 1. Skipping.")
        return None
    if n_tp < win_points_val:
        logger.warning(f"dFC_StdDev (S {sid}): Timepoints ({n_tp}) < window length ({win_points_val}). Skipping.")
        return None

    num_windows = (n_tp - win_points_val) // step_val + 1
    if num_windows < 2:
        logger.warning(f"dFC_StdDev (S {sid}): Fewer than 2 windows ({num_windows}) can be formed. StdDev would be trivial. Skipping.")
        return None

    window_corr_matrices_list = []

    for idx in range(num_windows):
        start_idx = idx * step_val
        window_ts = ts_subject[start_idx:start_idx + win_points_val, :]

        if window_ts.shape[0] < 2:
            continue

        try:
            corr_matrix_window = np.corrcoef(window_ts, rowvar=False)
            if corr_matrix_window.ndim < 2 or corr_matrix_window.shape != (n_rois, n_rois):
                logger.warning(f"dFC_StdDev (S {sid}), Window {idx}: corrcoef returned unexpected shape {corr_matrix_window.shape}. Skipping this window for StdDev.")
                continue
            # ROIs constant within a window produce NaN correlations; treat as 0.
            corr_matrix_window = np.nan_to_num(corr_matrix_window.astype(np.float32), nan=0.0)
            np.fill_diagonal(corr_matrix_window, 0)
            window_corr_matrices_list.append(corr_matrix_window)
        except Exception as e:
            logger.error(f"dFC_StdDev (S {sid}), Window {idx}: Error calculating/processing correlation: {e}")

    if len(window_corr_matrices_list) < 2:
        logger.warning(f"dFC_StdDev (S {sid}): Fewer than 2 valid windowed correlation matrices were calculated. Cannot compute StdDev. Returning None.")
        return None

    stacked_corr_matrices = np.stack(window_corr_matrices_list, axis=0)
    std_dev_matrix = np.std(stacked_corr_matrices, axis=0).astype(np.float32)
    np.fill_diagonal(std_dev_matrix, 0)

    logger.info(f"dFC_StdDev (S {sid}): Successfully calculated from {len(window_corr_matrices_list)} windows.")
    return std_dev_matrix

def _granger_pair(ts1, ts2, maxlag, sid, i, j): 
    """F de Granger para ts1 → ts2 y viceversa con maxlag."""
    try:
        res_ij_data = np.column_stack([ts2, ts1])
        if np.any(np.std(res_ij_data, axis=0) < 1e-6): 
             logger.debug(f"S {sid}: GC pair ({i},{j}) - ts1->ts2: Datos con varianza casi nula. Saltando.")
             res_ij = 0.0
        else:
            res_ij = grangercausalitytests(res_ij_data, maxlag=maxlag, verbose=False)[maxlag][0]['ssr_ftest'][0]
        
        res_ji_data = np.column_stack([ts1, ts2])
        if np.any(np.std(res_ji_data, axis=0) < 1e-6):
            logger.debug(f"S {sid}: GC pair ({i},{j}) - ts2->ts1: Datos con varianza casi nula. Saltando.")
            res_ji = 0.0
        else:
            res_ji = grangercausalitytests(res_ji_data, maxlag=maxlag, verbose=False)[maxlag][0]['ssr_ftest'][0]
            
        return res_ij, res_ji
    except Exception as e:
        logger.debug(f"S {sid}: GC pair ({i},{j}) failed: {e}")
        return 0.0, 0.0
        
def calculate_granger_f_matrix(ts_subject: np.ndarray, maxlag: int, sid: str) -> Optional[np.ndarray]:
    """Symmetric pairwise Granger-causality F matrix at lag ``maxlag``.

    For each ROI pair (i < j), ``_granger_pair`` returns the F statistics for
    i -> j and j -> i. Both [i, j] and [j, i] are set to their mean, so
    directionality is averaged out. Pairs are evaluated with joblib, with a
    serial fallback if the parallel run raises.

    Args:
        ts_subject: 2-D array, rows = timepoints, columns = ROIs.
        maxlag: Granger lag (>= 1); the F statistic at exactly this lag is used.
        sid: subject identifier for logging.

    Returns:
        (n_rois, n_rois) float32 matrix with zero diagonal, or ``None`` if the
        lag is invalid or there are too few timepoints for that lag.
    """
    n_tp, n_rois = ts_subject.shape
    if maxlag < 1:
        logger.warning(f"Granger (S {sid}): Invalid lag ({maxlag}); must be >= 1.")
        return None
    # statsmodels' grangercausalitytests (default addconst=True) raises
    # ValueError unless n_obs > 3 * maxlag + 1. Without this guard every pair
    # would fail inside _granger_pair, be mapped to 0.0, and an all-zero matrix
    # would be returned as if it were a valid result.
    min_tp_required = 3 * maxlag + 2
    if n_tp < min_tp_required:
        logger.warning(f"Granger (S {sid}): Too few TPs ({n_tp}) for lag {maxlag} and {n_rois} ROIs. Need >= {min_tp_required}.")
        return None

    gc_mat = np.zeros((n_rois, n_rois), dtype=np.float32)
    tasks = [
        (ts_subject[:, i], ts_subject[:, j], maxlag, sid, i, j)
        for i in range(n_rois)
        for j in range(i + 1, n_rois)
    ]

    # Share cores with the subject-level pool (MAX_WORKERS / TOTAL_CPU_CORES are
    # module-level settings, only read here): all but one core when subjects run
    # serially, otherwise an even split across subject workers.
    if MAX_WORKERS == 1:
        n_jobs_granger = max(1, TOTAL_CPU_CORES - 1 if TOTAL_CPU_CORES > 1 else 1)
    else:
        n_jobs_granger = max(1, TOTAL_CPU_CORES // MAX_WORKERS)
    logger.debug(f"Granger (S {sid}): Using n_jobs={n_jobs_granger} for joblib.Parallel. Global MAX_WORKERS for subjects: {MAX_WORKERS}")

    try:
        results = Parallel(n_jobs=n_jobs_granger)(
            delayed(_granger_pair)(*args) for args in tasks
        )
    except Exception as e_parallel_granger:
        logger.error(f"Granger (S {sid}): Error during parallel Granger calculation: {e_parallel_granger}. Falling back to serial.")
        results = [_granger_pair(*args) for args in tasks]

    for (_, _, _, _, i, j), (f_ij, f_ji) in zip(tasks, results):
        gc_mat[i, j] = gc_mat[j, i] = (f_ij + f_ji) / 2.0

    np.fill_diagonal(gc_mat, 0)
    logger.info(f"Granger_F_lag{maxlag} (S {sid}): done.")
    return gc_mat


# --- 4. Function to Calculate All Connectivity Modalities ---
def calculate_all_connectivity_modalities_for_subject(
    subject_id: str, subject_ts_data: np.ndarray,
    n_neighbors_mi_param: int,
    dfc_win_points_param: int, dfc_step_param: int,
    granger_lag_param: int # Renamed from var_lag_param. NOTE(review): not read below; the Granger lag is parsed from the channel name instead.
) -> Dict[str, Any]:
    """Compute every connectivity modality listed in CONNECTIVITY_CHANNEL_NAMES for one subject.

    Dispatches on each channel name to the matching calculator:
      - OMST-filtered signed Pearson (primary);
      - full Pearson Fisher-z (fallback);
      - k-NN MI;
      - the two sliding-window dFC summaries;
      - Granger F.
    Exceptions raised while computing a channel are caught and recorded for
    that channel, so one failing modality does not stop the others.

    Args:
        subject_id: subject identifier, used in log and error messages.
        subject_ts_data: preprocessed time series, passed unchanged to each
            calculator (its ``.shape`` is logged).
        n_neighbors_mi_param: k for the MI k-NN estimator.
        dfc_win_points_param: sliding-window length (points) for the dFC channels.
        dfc_step_param: sliding-window step (points) for the dFC channels.
        granger_lag_param: accepted for interface compatibility; unused here.

    Returns:
        dict with keys:
            "matrices": channel name -> matrix or None (one entry per channel),
            "errors_conn_calc": channel name -> error message (only failures),
            "timings_conn_calc": "<channel>_time_sec" -> elapsed wall-clock seconds.
    """
    matrices: Dict[str, Optional[np.ndarray]] = {name: None for name in CONNECTIVITY_CHANNEL_NAMES}
    errors_in_calculation: Dict[str, str] = {}
    timings: Dict[str, float] = {}

    for channel_name in CONNECTIVITY_CHANNEL_NAMES:
        logger.info(f"Calculating {channel_name} for S {subject_id} (TS shape: {subject_ts_data.shape})...")
        start_time_channel = time.time()
        matrix_result: Optional[np.ndarray] = None
        error_msg: Optional[str] = None

        try:
            # Dispatch on channel name. The OMST primary/fallback branches are
            # not gated by a USE_* flag; the remaining branches are.
            if channel_name == PEARSON_OMST_CHANNEL_NAME_PRIMARY:
                if OMST_PYTHON_LOADED and orthogonal_minimum_spanning_tree is not None:
                    matrix_result = calculate_pearson_omst_signed_weighted(subject_ts_data, subject_id)
                    if matrix_result is None:
                        error_msg = f"Primary OMST GCE (signed) calculation failed for S {subject_id}."
                        logger.error(error_msg)
                else:
                    error_msg = f"OMST function not loaded, cannot calculate '{PEARSON_OMST_CHANNEL_NAME_PRIMARY}' for S {subject_id}."
                    logger.error(error_msg)
                    if PEARSON_OMST_FALLBACK_NAME in CONNECTIVITY_CHANNEL_NAMES:
                        logger.info(f"Will attempt fallback '{PEARSON_OMST_FALLBACK_NAME}' later if it's in the channel list.")
                    else:
                         matrices[channel_name] = None

            elif channel_name == PEARSON_OMST_FALLBACK_NAME:
                 matrix_result = calculate_pearson_full_fisher_z_signed(subject_ts_data, subject_id)

            elif channel_name == "MI_KNN_Symmetric" and USE_MI_CHANNEL_FOR_THESIS:
                matrix_result = calculate_mi_knn_connectivity(subject_ts_data, n_neighbors_mi_param, subject_id)

            elif channel_name == "dFC_AbsDiffMean" and USE_DFC_ABS_DIFF_MEAN_CHANNEL:
                matrix_result = calculate_custom_dfc_abs_diff_mean(subject_ts_data, dfc_win_points_param, dfc_step_param, subject_id)

            elif channel_name == "dFC_StdDev" and USE_DFC_STDDEV_CHANNEL:
                matrix_result = calculate_dfc_std_dev(subject_ts_data, dfc_win_points_param, dfc_step_param, subject_id)

            elif channel_name.startswith("Granger_F_lag") and USE_GRANGER_CHANNEL:
                # The lag is the text after the last "lag" up to the next "_", e.g. "Granger_F_lag1" -> 1.
                current_lag = int(channel_name.split("lag")[-1].split("_")[0])
                matrix_result = calculate_granger_f_matrix(subject_ts_data, current_lag, subject_id)

            # Catch-all: the channel produced no matrix and no specific error
            # (e.g. its USE_* flag is off, its name matched no branch, or the
            # calculator returned None).
            if matrix_result is None and error_msg is None and channel_name in CONNECTIVITY_CHANNEL_NAMES :
                if matrices.get(channel_name) is None :
                    error_msg = f"'{channel_name}' was in CONNECTIVITY_CHANNEL_NAMES but not calculated or its function returned None without specific error."
                    logger.warning(error_msg)

        except Exception as e:
            error_msg = str(e)
            logger.error(f"Unexpected error while attempting to calculate {channel_name} for S {subject_id}: {e}", exc_info=True)

        # A failed primary OMST channel with the fallback also requested is
        # neither stored nor recorded as an error here (the fallback channel is
        # computed separately). NOTE(review): matrices[primary] then keeps its
        # initial None — confirm downstream validation expects that.
        if not (channel_name == PEARSON_OMST_CHANNEL_NAME_PRIMARY and error_msg and PEARSON_OMST_FALLBACK_NAME in CONNECTIVITY_CHANNEL_NAMES):
            matrices[channel_name] = matrix_result

        if error_msg and channel_name not in errors_in_calculation :
            if not (channel_name == PEARSON_OMST_CHANNEL_NAME_PRIMARY and PEARSON_OMST_FALLBACK_NAME in CONNECTIVITY_CHANNEL_NAMES):
                 errors_in_calculation[channel_name] = error_msg

        timings[f"{channel_name}_time_sec"] = time.time() - start_time_channel
        if matrix_result is not None:
            logger.info(f"{channel_name} for S {subject_id} calculated. Shape: {matrix_result.shape}. Took {timings[f'{channel_name}_time_sec']:.2f}s.")
        elif channel_name in CONNECTIVITY_CHANNEL_NAMES and not (channel_name == PEARSON_OMST_CHANNEL_NAME_PRIMARY and PEARSON_OMST_FALLBACK_NAME in CONNECTIVITY_CHANNEL_NAMES and error_msg):
            logger.warning(f"{channel_name} for S {subject_id} failed or returned None. Took {timings[f'{channel_name}_time_sec']:.2f}s. Error: {errors_in_calculation.get(channel_name, 'N/A')}")

    # Summary: count channels that ended with a non-None matrix.
    num_modalities_expected = len(CONNECTIVITY_CHANNEL_NAMES)
    num_successful = sum(1 for mat_name in CONNECTIVITY_CHANNEL_NAMES if matrices.get(mat_name) is not None)

    if num_successful < num_modalities_expected:
        logger.warning(f"Connectivity for S {subject_id}: {num_successful}/{num_modalities_expected} selected modalities computed. Errors: {errors_in_calculation}")
    else:
        logger.info(f"Connectivity for S {subject_id}: All {num_successful}/{num_modalities_expected} selected modalities computed successfully.")

    return {"matrices": matrices, "errors_conn_calc": errors_in_calculation, "timings_conn_calc": timings}


# --- 5. Per-Subject Processing Pipeline (for Multiprocessing) ---
def process_single_subject_pipeline(subject_row_tuple: Tuple[int, pd.Series]) -> Dict[str, Any]:
    """Run the per-subject pipeline: load/preprocess -> connectivity -> save tensor.

    Exceptions raised inside the main body are caught and reported through the
    returned dict instead of propagating.

    Args:
        subject_row_tuple: ``(index, row)`` pair (presumably from
            ``DataFrame.iterrows()``); only ``row['SubjectID']`` is read.

    Returns:
        Status dict with keys "id", "status_preprocessing", "detail_preprocessing",
        "status_connectivity_calc", "errors_connectivity_calc",
        "timings_connectivity_calc_sec", "path_saved_tensor" (str path of the
        saved .npz or None), "status_overall", "ram_usage_mb_initial" and
        "ram_usage_mb_final" (RSS of the current process, in MB).
    """
    idx, subject_row = subject_row_tuple
    subject_id = str(subject_row['SubjectID']).strip()
    process = psutil.Process(os.getpid())
    ram_initial_mb = process.memory_info().rss / (1024**2)

    result: Dict[str, Any] = {
        "id": subject_id,
        "status_preprocessing": "PENDING", "detail_preprocessing": "",
        "status_connectivity_calc": "NOT_ATTEMPTED", "errors_connectivity_calc": {},
        "timings_connectivity_calc_sec": {}, "path_saved_tensor": None,
        "status_overall": "PENDING",
        "ram_usage_mb_initial": ram_initial_mb, "ram_usage_mb_final": -1.0
    }
    series_data: Optional[np.ndarray] = None

    try:
        # Use GRANGER_MAX_LAG for preprocessing when Granger is active; otherwise a generic lag (for VAR, if reintroduced).
        eff_conn_lag_for_preprocess = GRANGER_MAX_LAG if USE_GRANGER_CHANNEL else 1
        series_data, detail_msg_preproc, success_preproc = load_and_preprocess_single_subject_series(
            subject_id,
            TARGET_LEN_TS,
            ROI_SIGNALS_DIR_PATH_AAL3, ROI_FILENAME_TEMPLATE, POSSIBLE_ROI_KEYS,
            eff_conn_lag_for_preprocess,
            TR_SECONDS, LOW_CUT_HZ, HIGH_CUT_HZ, FILTER_ORDER,
            APPLY_HRF_DECONVOLUTION, HRF_MODEL,
            taper_alpha_val=TAPER_ALPHA
        )
        result["status_preprocessing"] = "SUCCESS" if success_preproc else "FAILED"
        result["detail_preprocessing"] = detail_msg_preproc

        if not success_preproc or series_data is None:
            result["status_overall"] = "PREPROCESSING_FAILED"
            return result

        # The lag for effective-connectivity calculation is GRANGER_MAX_LAG.
        connectivity_results = calculate_all_connectivity_modalities_for_subject(
            subject_id, series_data, N_NEIGHBORS_MI,
            DFC_WIN_POINTS, DFC_STEP,
            GRANGER_MAX_LAG
        )
        # Release the time series as soon as connectivity is done.
        del series_data; series_data = None; gc.collect()

        calculated_matrices_dict = connectivity_results["matrices"]
        result["errors_connectivity_calc"] = connectivity_results["errors_conn_calc"]
        result["timings_connectivity_calc_sec"] = connectivity_results.get("timings_conn_calc", {})

        all_modalities_valid_and_present = True
        final_matrices_to_stack_list = []

        current_expected_rois_for_matrices = FINAL_N_ROIS_EXPECTED if FINAL_N_ROIS_EXPECTED is not None else N_ROIS_EXPECTED
        if current_expected_rois_for_matrices is None:
            logger.critical(f"S {subject_id}: CRITICAL - current_expected_rois_for_matrices is None. Cannot validate matrix shapes.")
            result["status_overall"] = "FAILURE_CRITICAL_ROI_COUNT_UNSET"
            result["status_connectivity_calc"] = "FAILURE_CRITICAL_ROI_COUNT_UNSET"
            return result

        expected_matrix_shape = (current_expected_rois_for_matrices, current_expected_rois_for_matrices)

        # Every selected channel must be present with the expected square shape;
        # the first missing or mis-shaped channel stops validation.
        for channel_name in CONNECTIVITY_CHANNEL_NAMES:
            matrix = calculated_matrices_dict.get(channel_name)
            if matrix is None:
                all_modalities_valid_and_present = False
                err_msg = f"Modality '{channel_name}' result is None (check calculation logs for S {subject_id})."
                logger.error(f"S {subject_id}: {err_msg}")
                if channel_name not in result["errors_connectivity_calc"]:
                    result["errors_connectivity_calc"][channel_name] = err_msg
                break
            elif matrix.shape != expected_matrix_shape:
                all_modalities_valid_and_present = False
                err_msg = f"Modality '{channel_name}' shape {matrix.shape} != expected {expected_matrix_shape}."
                logger.error(f"S {subject_id}: {err_msg}")
                result["errors_connectivity_calc"][channel_name] = result["errors_connectivity_calc"].get(channel_name, "") + " | " + err_msg
                break
            else:
                final_matrices_to_stack_list.append(matrix)

        if all_modalities_valid_and_present and len(final_matrices_to_stack_list) == N_CHANNELS:
            result["status_connectivity_calc"] = "SUCCESS_ALL_MODALITIES_VALID"
            try:
                # Stack channels along axis 0: tensor shape (n_channels, n_rois, n_rois).
                subject_tensor = np.stack(final_matrices_to_stack_list, axis=0).astype(np.float32)
                del final_matrices_to_stack_list, calculated_matrices_dict; gc.collect()

                output_dir_individual_tensors = BASE_PATH_AAL3 / OUTPUT_CONNECTIVITY_DIR_NAME / "individual_subject_tensors"
                output_dir_individual_tensors.mkdir(parents=True, exist_ok=True)

                output_path = output_dir_individual_tensors / f"tensor_{N_CHANNELS}ch_{current_expected_rois_for_matrices}rois_{subject_id}.npz"
                np.savez_compressed(output_path,
                                    tensor_data=subject_tensor,
                                    subject_id=subject_id,
                                    channel_names=np.array(CONNECTIVITY_CHANNEL_NAMES, dtype=str),
                                    rois_count = current_expected_rois_for_matrices,
                                    target_len_ts = TARGET_LEN_TS
                                    )
                result["path_saved_tensor"] = str(output_path)
                result["status_overall"] = "SUCCESS_ALL_PROCESSED_AND_SAVED"
                logger.info(f"S {subject_id}: Successfully processed. Tensor saved to {output_path.name}")
                del subject_tensor; gc.collect()
            except Exception as e_save:
                logger.error(f"Error saving tensor for S {subject_id}: {e_save}", exc_info=True)
                result["errors_connectivity_calc"]["save_error"] = str(e_save)
                result["status_overall"] = "FAILURE_DURING_TENSOR_SAVING"
                result["status_connectivity_calc"] = "FAILURE_DURING_SAVING"
        else:
            # Also reached when validation passed but the number of matrices
            # differs from N_CHANNELS (no error is logged in that case).
            result["status_overall"] = "FAILURE_IN_CONNECTIVITY_CALC_OR_VALIDATION"
            result["status_connectivity_calc"] = "FAILURE_MISSING_OR_INVALID_MODALITIES"
            if not all_modalities_valid_and_present:
                logger.error(f"S {subject_id}: Not all connectivity modalities were valid or present. Tensor not saved. Errors: {result['errors_connectivity_calc']}")

    except Exception as e_pipeline:
        logger.critical(f"CRITICAL UNHANDLED EXCEPTION for S {subject_id} in pipeline: {e_pipeline}", exc_info=True)
        result["status_overall"] = "CRITICAL_PIPELINE_EXCEPTION"
        result["detail_preprocessing"] = result.get("detail_preprocessing","") + " | Pipeline Exc: " + str(e_pipeline)
        result["errors_connectivity_calc"]["pipeline_exception"] = str(e_pipeline)

    finally:
        # Always drop the time series (if still held) and record final RSS.
        if series_data is not None: del series_data; gc.collect()
        result["ram_usage_mb_final"] = process.memory_info().rss / (1024**2)
    return result

# --- 6. Main Script Execution Flow ---
def main():
    """Run the complete fMRI connectivity pipeline end to end.

    Steps: log environment diagnostics, validate the AAL3 input paths, load the
    QC-filtered subject metadata, process every subject in parallel with
    ``process_single_subject_pipeline``, write a CSV log with every per-subject
    result and, if at least one subject succeeded, assemble and save the global
    (Subjects, Channels, ROIs, ROIs) tensor.

    All configuration is read from module-level constants. Returns ``None``;
    fatal conditions are logged at CRITICAL level and end the run early.
    """
    # --- NetworkX version diagnostic (dyconnmap is pinned to networkx 2.6.3) ---
    try:
        import networkx as nx_runtime
        logger.info(f"RUNTIME NetworkX version being used: {nx_runtime.__version__}")
    except ImportError:
        logger.error("RUNTIME: NetworkX is not installed or importable.")

    # --- Seed the global NumPy RNG for reproducibility ---
    # NOTE(review): this seeds the parent process; whether worker processes see
    # the same RNG state depends on the multiprocessing start method — confirm.
    np.random.seed(42)

    script_start_time = time.time()
    main_process_info = psutil.Process(os.getpid())
    logger.info(f"Main process RAM at start: {main_process_info.memory_info().rss / (1024**2):.2f} MB")
    # Version tag kept in sync with the module header so run logs record the
    # code revision that actually produced the outputs.
    logger.info("--- Starting fMRI Connectivity Pipeline (Version for Doctoral Thesis with Dyconnmap v6.5.9_ParallelTuned) ---")

    logger.info(f"--- Final Expected ROIs for Connectivity Matrices: {N_ROIS_EXPECTED} (should be 131) ---")
    logger.info(f"--- Target Homogenized Time Series Length: {TARGET_LEN_TS} ---")
    logger.info(f"--- Output Directory Name: {OUTPUT_CONNECTIVITY_DIR_NAME} ---")
    logger.info(f"--- Selected Connectivity Channels for VAE: {CONNECTIVITY_CHANNEL_NAMES} ({N_CHANNELS} channels) ---")
    if USE_PEARSON_OMST_CHANNEL and not (OMST_PYTHON_LOADED and orthogonal_minimum_spanning_tree is not None):
        logger.warning(f"Note: OMST from dyconnmap could not be loaded. '{PEARSON_OMST_FALLBACK_NAME}' will be used instead of '{PEARSON_OMST_CHANNEL_NAME_PRIMARY}' if enabled.")

    # --- Input validation ---
    if not BASE_PATH_AAL3.exists() or not ROI_SIGNALS_DIR_PATH_AAL3.exists():
        logger.critical(f"CRITICAL: Base AAL3 path ({BASE_PATH_AAL3}) or ROI signals directory ({ROI_SIGNALS_DIR_PATH_AAL3}) not found. Aborting.")
        return

    subject_metadata_df = load_metadata(SUBJECT_METADATA_CSV_PATH, QC_REPORT_CSV_PATH)
    if subject_metadata_df is None or subject_metadata_df.empty:
        logger.critical("Metadata loading failed or no subjects passed QC to process. Aborting.")
        return

    # --- Output directories ---
    output_main_directory = BASE_PATH_AAL3 / OUTPUT_CONNECTIVITY_DIR_NAME
    output_individual_tensors_dir = output_main_directory / "individual_subject_tensors"
    try:
        output_main_directory.mkdir(parents=True, exist_ok=True)
        output_individual_tensors_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Main output directory created/exists: {output_main_directory}")
    except OSError as e:
        logger.critical(f"Could not create output directories: {e}. Aborting.")
        return

    # MAX_WORKERS is defined at module level.
    logger.info(f"Total CPU cores available: {TOTAL_CPU_CORES}. Using MAX_WORKERS = {MAX_WORKERS} for ProcessPoolExecutor.")
    available_ram_gb = psutil.virtual_memory().available / (1024**3)
    logger.warning(f"Available system RAM at start of parallel processing: {available_ram_gb:.2f} GB. Monitor usage closely.")

    subject_rows_to_process = list(subject_metadata_df.iterrows())
    num_subjects_to_process = len(subject_rows_to_process)
    if num_subjects_to_process == 0:
        logger.critical("No subjects to process after metadata loading and QC filtering. Aborting.")
        return
    logger.info(f"Starting parallel processing for {num_subjects_to_process} subjects.")

    # --- Parallel per-subject processing ---
    all_subject_results_list = []
    with ProcessPoolExecutor(max_workers=MAX_WORKERS) as executor:
        future_to_subject_id_map = {
            executor.submit(process_single_subject_pipeline, subject_tuple): str(subject_tuple[1]['SubjectID']).strip()
            for subject_tuple in subject_rows_to_process
        }
        for future in tqdm(as_completed(future_to_subject_id_map), total=num_subjects_to_process, desc="Processing Subjects"):
            subject_id_for_log = future_to_subject_id_map[future]
            try:
                subject_result = future.result()
                all_subject_results_list.append(subject_result)
            except Exception as exc:
                # A crashed worker still gets a log row so the CSV covers every subject.
                logger.critical(f"CRITICAL WORKER EXCEPTION for S {subject_id_for_log}: {exc}", exc_info=True)
                all_subject_results_list.append({
                    "id": subject_id_for_log, "status_overall": "CRITICAL_WORKER_EXCEPTION",
                    "detail_preprocessing": f"Worker process crashed: {str(exc)}",
                    "errors_connectivity_calc": {"worker_exception": str(exc)}
                })

    # --- Per-subject CSV log ---
    processing_log_df = pd.DataFrame(all_subject_results_list)
    log_file_path = output_main_directory / f"pipeline_log_{output_main_directory.name}.csv"
    try:
        processing_log_df.to_csv(log_file_path, index=False)
        logger.info(f"Detailed processing log saved to: {log_file_path}")
    except Exception as e_log_save:
        logger.error(f"Failed to save detailed processing log: {e_log_save}")

    # Only subjects whose tensor file really exists on disk count as successful.
    successful_subject_entries_list = [
        res for res in all_subject_results_list
        if res.get("status_overall") == "SUCCESS_ALL_PROCESSED_AND_SAVED"
        and res.get("path_saved_tensor") and Path(res["path_saved_tensor"]).exists()
    ]
    num_successful_subjects_for_tensor = len(successful_subject_entries_list)

    logger.info(f"--- Overall Processing Summary ---")
    logger.info(f"Total subjects attempted: {num_subjects_to_process}")
    logger.info(f"Successfully processed and individual tensors saved: {num_successful_subjects_for_tensor}")
    if num_successful_subjects_for_tensor < num_subjects_to_process:
        num_failed = num_subjects_to_process - num_successful_subjects_for_tensor
        logger.warning(f"{num_failed} subjects failed at some stage. Check the detailed log: {log_file_path}")

    # --- Global tensor assembly ---
    if num_successful_subjects_for_tensor > 0:
        logger.info(f"Attempting to assemble global tensor for {num_successful_subjects_for_tensor} successfully processed subjects.")

        current_expected_rois_for_assembly = FINAL_N_ROIS_EXPECTED if FINAL_N_ROIS_EXPECTED is not None else N_ROIS_EXPECTED
        if current_expected_rois_for_assembly is None:
            logger.critical("Cannot assemble global tensor: FINAL_N_ROIS_EXPECTED is None.")
        else:
            logger.warning("Assembling global tensor in a single preallocated in-RAM array. This may be memory-intensive for large datasets. "
                           "Consider np.memmap, Zarr, or HDF5 for more scalable solutions.")
            try:
                global_conn_tensor, final_subject_ids_in_tensor = _assemble_global_tensor(
                    successful_subject_entries_list,
                    N_CHANNELS,
                    current_expected_rois_for_assembly,
                    log=logger,
                    progress=lambda entries: tqdm(entries, desc="Assembling Global Tensor"),
                )

                if global_conn_tensor is not None:
                    current_lag_for_filename = GRANGER_MAX_LAG if USE_GRANGER_CHANNEL else 1  # Default lag if Granger not used
                    # NOTE(review): `deconv_str` is not defined in main(); it must be a
                    # module-level name. If it is not, the resulting NameError is only
                    # logged by the generic handler below and the global tensor is not
                    # written — confirm it is defined globally.
                    global_tensor_fname = f"GLOBAL_TENSOR_AAL3_{current_expected_rois_for_assembly}ROIs_{len(final_subject_ids_in_tensor)}subs_{N_CHANNELS}ch_{current_lag_for_filename}lag{deconv_str}.npz"
                    global_tensor_path = output_main_directory / global_tensor_fname

                    np.savez_compressed(global_tensor_path,
                                        global_tensor_data=global_conn_tensor,
                                        subject_ids=np.array(final_subject_ids_in_tensor, dtype=str),
                                        channel_names=np.array(CONNECTIVITY_CHANNEL_NAMES, dtype=str),
                                        rois_count = current_expected_rois_for_assembly,
                                        target_len_ts = TARGET_LEN_TS,
                                        tr_seconds = TR_SECONDS,
                                        filter_low_hz = LOW_CUT_HZ,
                                        filter_high_hz = HIGH_CUT_HZ,
                                        hrf_deconvolution_applied = APPLY_HRF_DECONVOLUTION,
                                        hrf_model = HRF_MODEL if APPLY_HRF_DECONVOLUTION else "N/A"
                                        )
                    logger.info(f"Global tensor successfully assembled and saved: {global_tensor_path.name}")
                    logger.info(f"Global tensor shape: {global_conn_tensor.shape} (Subjects, Channels, ROIs, ROIs)")
                    del global_conn_tensor
                    gc.collect()
                else:
                    logger.warning("No valid individual tensors were loaded for global assembly. Global tensor not created.")
            except MemoryError:
                logger.critical("MEMORY ERROR during global tensor assembly. The dataset might be too large to stack in RAM.")
            except Exception as e_global:
                logger.critical(f"An unexpected error occurred during global tensor assembly: {e_global}", exc_info=True)

    total_time_min = (time.time() - script_start_time) / 60
    logger.info(f"--- fMRI Connectivity Pipeline Finished ---")
    logger.info(f"Total execution time: {total_time_min:.2f} minutes.")
    logger.info(f"Final main process RAM: {main_process_info.memory_info().rss / (1024**2):.2f} MB")
    logger.info(f"All outputs, logs, and tensors should be in: {output_main_directory}")
    logger.info("Reminder for Thesis: Document all parameters, QC steps, subject selection criteria, and the precise AAL3 ROI definitions used (131 ROIs). Also cite Dyconnmap if used.")


def _assemble_global_tensor(
    subject_entries: List[Dict[str, Any]],
    n_channels: int,
    n_rois: int,
    log: Optional[logging.Logger] = None,
    progress: Optional[Any] = None,
) -> Tuple[Optional[np.ndarray], List[str]]:
    """Load per-subject connectivity tensors and pack them into one float32 array.

    The output is allocated once and filled subject by subject. Peak memory is
    therefore about one global tensor plus one subject tensor. Building a list,
    calling ``np.stack`` and then ``astype`` would hold up to three copies.

    Args:
        subject_entries: Result dicts; each must provide ``"id"`` and
            ``"path_saved_tensor"`` (an ``.npz`` file with a ``tensor_data`` key).
        n_channels: Expected number of connectivity channels per subject.
        n_rois: Expected ROI count; each matrix must be ``n_rois x n_rois``.
        log: Logger for per-subject problems; defaults to this module's logger.
        progress: Optional callable that wraps the iteration (e.g. a tqdm
            factory) to display progress.

    Returns:
        ``(tensor, subject_ids)`` where ``tensor`` has shape
        ``(len(subject_ids), n_channels, n_rois, n_rois)`` and dtype float32.
        Returns ``(None, [])`` when no subject tensor could be loaded.

    Raises:
        MemoryError: if the output array cannot be allocated.
    """
    log = log if log is not None else logging.getLogger(__name__)
    expected_shape = (n_channels, n_rois, n_rois)
    entries = list(subject_entries)
    packed = np.empty((len(entries),) + expected_shape, dtype=np.float32)
    kept_ids: List[str] = []

    iterable = progress(entries) if progress is not None else entries
    for entry in iterable:
        s_id = entry["id"]
        tensor_path_str = entry["path_saved_tensor"]
        try:
            with np.load(tensor_path_str) as loaded_npz:
                s_tensor_data = loaded_npz["tensor_data"]
            if s_tensor_data.shape != expected_shape:
                log.error(f"Tensor for S {s_id} has shape mismatch: {s_tensor_data.shape}. "
                          f"Expected: ({n_channels}, {n_rois}, {n_rois}). Skipping this subject for global tensor.")
                continue
            # Assignment casts to float32 exactly like the former .astype(np.float32).
            packed[len(kept_ids)] = s_tensor_data
            kept_ids.append(s_id)
        except Exception as e_load_ind_tensor:
            log.error(f"Error loading individual tensor for S {s_id} from {tensor_path_str}: {e_load_ind_tensor}. Skipping for global tensor.")

    if not kept_ids:
        return None, []
    # Slicing the leading axis of a C-contiguous array is a contiguous view (no copy).
    return packed[:len(kept_ids)], kept_ids

if __name__ == "__main__":
    # Entry point. Inside a Jupyter/IPython cell __name__ is "__main__", so
    # executing this cell starts the full pipeline immediately.
    # freeze_support() only has an effect for frozen Windows executables and is a
    # no-op otherwise. The multiprocessing docs ask for it right after this guard.
    multiprocessing.freeze_support() 
    main()


2025-06-01 20:25:25,120 - INFO - 3006399536.py:59 - Successfully imported 'threshold_omst_global_cost_efficiency' from 'dyconnmap.graphs.threshold' and aliased as 'orthogonal_minimum_spanning_tree'.
2025-06-01 20:25:25,121 - INFO - 3006399536.py:130 - Global MAX_WORKERS for ProcessPoolExecutor set to: 6 (based on 12 total cores)
2025-06-01 20:25:25,123 - INFO - 3006399536.py:145 - --- Initializing AAL3 ROI Processing Information ---
2025-06-01 20:25:25,128 - INFO - 3006399536.py:183 - AAL3 ROI processing info initialized:
2025-06-01 20:25:25,128 - INFO - 3006399536.py:184 -   Indices of 4 AAL3 systemically missing ROIs (0-based, from 170): [34, 35, 80, 81]
2025-06-01 20:25:25,129 - INFO - 3006399536.py:185 -   Number of ROIs in AAL3 meta after excluding systemically missing: 166 (Expected: 166)
2025-06-01 20:25:25,129 - INFO - 3006399536.py:186 -   Indices of small ROIs to drop (from the 166 set, 0-based): [108, 116, 117, 118, 119, 120, 121, 126, 127, 128, 129, 132, 133, 134, 135, 136,