# In [1]:  (Jupyter cell marker left over from notebook export)
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import mutual_info_regression
from nilearn.connectome import ConnectivityMeasure
from nilearn.glm.first_level import spm_hrf, glover_hrf # NEW: For HRF functions
from scipy.signal import butter, filtfilt, deconvolve # NEW: deconvolve
from tqdm import tqdm
import os
import scipy.io as sio
from pathlib import Path
import psutil
import gc
import logging
import time
from typing import List, Tuple, Dict, Optional, Any
import multiprocessing
from concurrent.futures import ProcessPoolExecutor, as_completed
from sklearn.linear_model import MultiTaskLassoCV

# --- 0. Global Configuration and Constants ---

# Root logging configuration: INFO and above, each record tagged with time, level and source location.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s',
    level=logging.INFO,
)
# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)

# --- Neuroscientific Context & Recommendations Summary ---
# This script aims to generate multi-modal brain connectivity matrices from fMRI time series.
# These matrices (Static FC, Dynamic FC, Effective Connectivity, Non-linear FC)
# can serve as features for machine learning models to classify neurological conditions,
# such as Alzheimer's Disease vs. Normal Controls.
# The choice of connectivity measures and preprocessing steps is guided by their potential
# to capture different aspects of brain network alterations in such conditions.

### NEURO-ENHANCEMENT: Summary of evidence-based recommendations:
# | Aspect             | Recommended Practice                      | Rationale / References                                                                                                                               |
# |--------------------|-------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------|
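# | Band-pass Filtering| Yes, 0.01–0.08 Hz before connectivity   | Attenuates scanner drift (<0.01 Hz) & higher-frequency noise (>0.08 Hz) (pmc.ncbi.nlm.nih.gov). At TR=3 s (Nyquist ≈0.167 Hz) cardiac/respiratory noise is aliased into lower frequencies, so it is reduced rather than eliminated. Focus on low frequencies improves AD vs HC discrimination. |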
# | k in MI (k-NN)     | k = 5 (e.g., KSG estimator)               | k≈3–10 common; k=5 offers good bias-variance balance for N~140 timepoints, stable results (pmc.ncbi.nlm.nih.gov).                                      |
# | Max VAR Order (Lag)| Lag 1 preferred, especially for short TS  | For TR~2–3s, optimal order often 1 (journals.plos.org). Higher orders w/ short TS & many ROIs risk overfitting, even w/ Lasso.                       |
# | HRF Deconvolution  | Optional, for neuronal-level VAR          | Attempts to recover neural signal from BOLD. Can improve EC interpretability but is sensitive to noise & HRF model. Validate carefully.             |


# --- Configurable Parameters ---
# Repetition time of the acquisition in seconds; the sampling frequency used for filtering is 1 / TR_SECONDS.
TR_SECONDS = 3.0
# Band-pass edges in Hz. Both must lie below the Nyquist frequency 0.5 / TR_SECONDS (≈0.167 Hz at TR=3 s).
LOW_CUT_HZ = 0.01
HIGH_CUT_HZ = 0.08
FILTER_ORDER = 2       # N passed to `butter`. With btype='band' scipy designs a band-pass of order 2*N (4th order here);
                       # filtfilt then runs it forward and backward (zero phase), doubling the effective order again.

# Number of ROIs per subject (AAL-116 atlas); signal matrices are oriented/validated against this.
N_ROIS_EXPECTED = 116
# Every subject's series is zero-padded or truncated to this many timepoints after z-scoring.
TARGET_LEN_TS = 140

# Number of neighbours for the k-NN mutual-information estimator (sklearn mutual_info_regression).
N_NEIGHBORS_MI = 5

DFC_WIN_POINTS = 20    # Sliding-window length in TRs (TR=3s => 60s window)
DFC_STEP = 10          # Step between consecutive windows in TRs

# --- VAR Model Parameters ---
LASSO_VAR_MAX_LAG = 1  # VAR model order p. Reduced to 1 as per user request and common practice for fMRI.
                       ### NEURO-ENHANCEMENT: VAR(1) is often preferred for fMRI to model direct, immediate influences,
                       ### especially with TRs around 2-3s and limited timepoints, to avoid overfitting.
APPLY_HRF_DECONVOLUTION = False # Set to True to attempt HRF deconvolution before the VAR step.
                               ### NEURO-ENHANCEMENT: If True, VAR is applied to an estimate of neural activity.
                               ### This is an advanced step; results depend heavily on HRF model accuracy and noise.
                               ### Deconvolution can be unstable. Use with caution and validate.
HRF_MODEL = 'glover'   # 'glover' or 'spm' (nilearn HRF functions), used only if deconvolution is applied.
                       ### Both canonical HRFs are differences of two gamma functions (response peak minus
                       ### post-stimulus undershoot); they differ in peak/undershoot timing and undershoot ratio.

# Convergence settings forwarded to sklearn MultiTaskLassoCV (max_iter, tol).
LASSO_MAX_ITER = 2500
LASSO_TOL = 1e-4

# --- Paths and Filenames ---
BASE_PATH = Path('/home/diego/Escritorio/desde_cero/AAL_116') # MODIFY AS NEEDED
# The two CSVs are inner-joined on 'SubjectID'; together they must provide 'ResearchGroup' and 'Timepoints'.
SUBJECT_CSV_PATH = BASE_PATH / 'DataBaseSubjects.csv'
STATS_CSV_PATH = BASE_PATH / 'subjects_roi_timepoints_stats.csv'
ROI_SIGNALS_DIR_PATH = BASE_PATH / 'ROISignals'
ROI_FILENAME_TEMPLATE = 'ROISignals_{subject_id}.mat'  # formatted with subject_id=<SubjectID>

# Output directory name encodes the VAR lag and deconvolution status, so runs with
# different settings write to different directories.
deconv_str = "_deconv" if APPLY_HRF_DECONVOLUTION else ""
OUTPUT_CONNECTIVITY_DIR_NAME = f"AAL_{N_ROIS_EXPECTED}_fmri_tensor_{LASSO_VAR_MAX_LAG}lag{deconv_str}_NeuroEnhanced_v4"

# --- Derived Constants ---
# Variable names tried, in this order, when reading the ROI signal matrix from each .mat file.
POSSIBLE_ROI_KEYS = ["ROISignals", "signals", "roi_signals", "ROIsignals", "AAL_signals", "roi_ts"]

# VAR channel name reflects whether deconvolution is enabled.
# The list order is positional: index 0 Pearson, 1 MI, 2 dFC, 3 Lasso-VAR, and it is also the
# order of the first axis of every saved subject tensor.
var_channel_suffix = "_Deconv_Influence" if APPLY_HRF_DECONVOLUTION else "_Influence"
CONNECTIVITY_CHANNEL_NAMES = ["Pearson_FisherZ", "MI_KNN", "dFC_AbsDiffMean", f"Lasso_VAR{var_channel_suffix}"]
N_CHANNELS = len(CONNECTIVITY_CHANNEL_NAMES)


# --- 1. Subject Metadata Loading and Merging ---
def load_metadata(subject_csv_path: Path, stats_csv_path: Path) -> Optional[pd.DataFrame]:
    """Read the subject database and statistics CSVs and inner-join them on 'SubjectID'.

    'SubjectID' is converted to stripped strings in both tables before the join, and
    'Timepoints' is coerced to int (values that cannot be parsed become 0).

    Args:
        subject_csv_path: CSV with subject metadata; must have a 'SubjectID' column.
        stats_csv_path: CSV with per-subject statistics; must have a 'SubjectID' column.

    Returns:
        The merged DataFrame, or None when a file is missing, 'ResearchGroup' or
        'Timepoints' is absent after the merge, or any other error occurs. Every
        failure is logged at CRITICAL level; nothing is raised to the caller.
    """
    logger.info("--- Starting Subject Metadata Loading ---")
    try:
        required_files = (
            (subject_csv_path, f"Subject CSV file NOT found: {subject_csv_path}"),
            (stats_csv_path, f"Statistics CSV file NOT found: {stats_csv_path}"),
        )
        for csv_path, missing_message in required_files:
            if not csv_path.exists():
                logger.critical(missing_message)
                return None

        stats_df = pd.read_csv(stats_csv_path)
        subjects_db_df = pd.read_csv(subject_csv_path)
        for frame in (stats_df, subjects_db_df):
            frame['SubjectID'] = frame['SubjectID'].astype(str).str.strip()
        merged_df = stats_df.merge(subjects_db_df, on='SubjectID', how='inner')

        if 'ResearchGroup' not in merged_df.columns:
            raise ValueError("Critical column 'ResearchGroup' is missing. This is essential for classification tasks (e.g., AD vs. NC).")
        if 'Timepoints' not in merged_df.columns:
            raise ValueError("Column 'Timepoints' is missing, needed for filtering or validation.")
        merged_df['Timepoints'] = (
            pd.to_numeric(merged_df['Timepoints'], errors='coerce').fillna(0).astype(int)
        )

        logger.info(f"Metadata DataFrame 'merged_df' loaded successfully. Shape: {merged_df.shape}")
        logger.info(f"Total unique subjects in merged_df: {merged_df['SubjectID'].nunique()}")
        return merged_df
    except FileNotFoundError as e:
        logger.critical(f"CRITICAL Error loading metadata CSV files: {e}")
        return None
    except ValueError as e:
        logger.critical(f"Value error in metadata: {e}")
        return None
    except Exception as e:
        logger.critical(f"Unexpected error during metadata loading: {e}", exc_info=True)
        return None

# --- 2. Time Series Loading and Preprocessing Functions ---
def _load_signals_from_mat(mat_path: Path, possible_keys: List[str]) -> Optional[np.ndarray]:
    """Loads signals from a .mat file, searching for possible keys."""
    try:
        data = sio.loadmat(mat_path)
    except Exception as e_load:
        logger.error(f"Could not load .mat file: {mat_path}. Error: {e_load}")
        return None
    for key in possible_keys:
        if key in data and isinstance(data[key], np.ndarray) and data[key].ndim >= 2:
            return data[key]
    logger.warning(f"No valid signal keys found in {mat_path}. Keys present: {list(data.keys())}")
    return None

def _orient_and_validate_signals(sigs: np.ndarray, n_rois_expected: int, subject_id: str) -> Optional[np.ndarray]:
    """Ensures signals are oriented as (Timepoints x ROIs) and validates dimensions."""
    if sigs.ndim != 2:
        logger.warning(f"Subject {subject_id}: Incorrect dimensions {sigs.ndim} != 2.")
        return None
    if sigs.shape[0] == n_rois_expected and sigs.shape[1] != n_rois_expected:
        sigs = sigs.T
    elif sigs.shape[1] == n_rois_expected and sigs.shape[0] != n_rois_expected:
        pass
    elif sigs.shape[0] == n_rois_expected and sigs.shape[1] == n_rois_expected:
         logger.warning(f"Subject {subject_id}: Signal matrix is square ({sigs.shape}). Assuming [Timepoints, ROIs]. Verify.")
         pass
    else:
        logger.warning(f"Subject {subject_id}: Neither dimension matches n_rois_expected ({n_rois_expected}). Shape: {sigs.shape}")
        return None
    if sigs.shape[1] != n_rois_expected:
        logger.warning(f"Subject {subject_id}: Final shape {sigs.shape} not [Timepoints, {n_rois_expected} ROIs].")
        return None
    return sigs

def _bandpass_filter_signals(sigs: np.ndarray, lowcut: float, highcut: float, fs: float, order: int, subject_id: str) -> np.ndarray:
    """Applies a Butterworth band-pass filter using filtfilt."""
    nyquist_freq = 0.5 * fs
    low = lowcut / nyquist_freq
    high = highcut / nyquist_freq
    if not (0 < low < 1 and 0 < high < 1 and low < high) : # Corrected condition
        logger.error(f"S {subject_id}: Invalid critical freqs for filter. LowN: {low:.4f}, HighN: {high:.4f}. Must be (0,1) & low<high. Skipping filtering.")
        return sigs
    try:
        b, a = butter(order, [low, high], btype='band', analog=False)
        filtered_sigs = np.zeros_like(sigs)
        for i in range(sigs.shape[1]):
            if np.all(np.isclose(sigs[:,i], sigs[0,i])):
                logger.warning(f"S {subject_id}, ROI {i}: Constant signal. Skipping filtering.")
                filtered_sigs[:, i] = sigs[:, i]
            elif len(sigs[:,i]) < (3 * (order + 1)) + 1 :
                logger.warning(f"S {subject_id}, ROI {i}: Signal too short ({len(sigs[:,i])}pts) for filtfilt order {order}. Skipping filtering.")
                filtered_sigs[:, i] = sigs[:, i]
            else:
                filtered_sigs[:, i] = filtfilt(b, a, sigs[:, i])
        return filtered_sigs
    except Exception as e:
        logger.error(f"S {subject_id}: Error during bandpass filtering: {e}. Returning original signals.", exc_info=True)
        return sigs

def _hrf_deconvolution(sigs: np.ndarray, tr: float, hrf_model_type: str, subject_id: str) -> np.ndarray:
    """
    Attempts HRF deconvolution on BOLD signals.
    WARNING: Simple deconvolution can be unstable and amplify noise. Use with caution.
    """
    logger.info(f"S {subject_id}: Attempting HRF deconvolution with {hrf_model_type} model.")
    if hrf_model_type == 'glover':
        hrf = glover_hrf(tr, oversampling=1) # No oversampling needed for deconvolution with BOLD TR
    elif hrf_model_type == 'spm':
        hrf = spm_hrf(tr, oversampling=1)
    else:
        logger.error(f"S {subject_id}: Unknown HRF model type '{hrf_model_type}'. Skipping deconvolution.")
        return sigs
    
    # Normalize HRF to have unit integral for deconvolution stability if it's not already.
    # However, scipy.signal.deconvolve doesn't assume this for the divisor.
    # It's more about the length and characteristics of the HRF.
    # Shorter HRFs or HRFs with zeros can cause issues.
    hrf = hrf / np.sum(hrf) # Normalize to sum to 1, common for some deconv contexts.

    deconvolved_sigs = np.zeros_like(sigs)
    for i in range(sigs.shape[1]):
        signal_roi = sigs[:, i]
        if len(signal_roi) < len(hrf):
            logger.warning(f"S {subject_id}, ROI {i}: Signal shorter than HRF. Skipping deconvolution for this ROI.")
            deconvolved_sigs[:, i] = signal_roi
            continue
        try:
            # `deconvolve` returns (quotient, remainder). We want the quotient.
            # This is a naive deconvolution; can be unstable.
            q, r = deconvolve(signal_roi, hrf)
            # The deconvolved signal `q` will be shorter. We might need to pad it or handle this.
            # For now, let's take the first part matching original length if q is longer, or pad if shorter.
            if len(q) >= sigs.shape[0]:
                deconvolved_sigs[:, i] = q[:sigs.shape[0]]
            else:
                padding = np.zeros(sigs.shape[0] - len(q)) # Pad with zeros at the end
                deconvolved_sigs[:, i] = np.concatenate([q, padding])

        except Exception as e_deconv:
            logger.error(f"S {subject_id}, ROI {i}: Deconvolution failed: {e_deconv}. Using original signal for this ROI.", exc_info=False)
            deconvolved_sigs[:, i] = signal_roi # Fallback to original signal for this ROI
            
    logger.info(f"S {subject_id}: HRF deconvolution attempt finished.")
    return deconvolved_sigs


def _impute_non_finite_with_column_mean(sigs: np.ndarray) -> np.ndarray:
    """Return a float64 copy of `sigs` with NaN/inf replaced by each column's finite mean (0.0 if none)."""
    finite = np.isfinite(sigs)
    filled = np.where(finite, sigs, 0.0).astype(np.float64)
    finite_counts = finite.sum(axis=0)
    column_means = np.divide(
        filled.sum(axis=0), finite_counts,
        out=np.zeros(sigs.shape[1], dtype=np.float64), where=finite_counts > 0,
    )
    return np.where(finite, filled, column_means)


def _preprocess_time_series(
    sigs: np.ndarray, target_len_ts: int, n_rois_expected: int,
    subject_id: str, lasso_var_max_lag: int,
    tr_seconds: float, low_cut: float, high_cut: float, filter_order: int, # filter_order is N for butter
    apply_hrf_deconv: bool, hrf_model_type: str # Deconvolution params
) -> Optional[np.ndarray]:
    """
    Clean, band-pass filter, optionally deconvolve, z-score and length-homogenize one subject's signals.

    Steps:
        0. Impute non-finite raw samples with the ROI's mean. This must happen before filtering,
           because filtfilt is an IIR filter run in both directions and one NaN would spread
           over the entire ROI series.
        1. Butterworth band-pass (_bandpass_filter_signals).
        2. Optional HRF deconvolution (_hrf_deconvolution).
        3. Minimum-length check for the VAR and dFC steps.
        4. Per-ROI z-scoring with StandardScaler, with a per-column fallback.
        5. Zero-pad or truncate to `target_len_ts` rows.

    Args:
        sigs: Signal matrix, assumed (timepoints, ROIs) as produced by the orientation step.
        target_len_ts: Number of rows of the returned matrix.
        n_rois_expected: Kept for interface compatibility; padding uses the actual column count of `sigs`.
        subject_id: Used only in log messages.
        lasso_var_max_lag: VAR order; sets part of the minimum required length.
        tr_seconds: Repetition time; the sampling frequency is 1 / tr_seconds.
        low_cut, high_cut: Band-pass edges in Hz.
        filter_order: N passed to butter.
        apply_hrf_deconv: Whether to run HRF deconvolution after filtering.
        hrf_model_type: HRF model name forwarded to the deconvolution step.

    Returns:
        A float32 array of shape (target_len_ts, n_columns), or None if the series is too short.
    """
    original_length = sigs.shape[0]
    fs = 1.0 / tr_seconds
    logger.info(f"S {subject_id}: Preprocessing. Orig len: {original_length}, TR: {tr_seconds}s.")

    # 0. Repair non-finite raw samples before any filtering (see docstring).
    finite_mask = np.isfinite(sigs)
    if not finite_mask.all():
        n_bad = int(finite_mask.size - np.count_nonzero(finite_mask))
        logger.warning(f"S {subject_id}: {n_bad} non-finite raw samples. Imputing with per-ROI mean before filtering.")
        sigs = _impute_non_finite_with_column_mean(sigs)

    # 1. Band-pass filtering
    logger.info(f"S {subject_id}: Applying band-pass filter ({low_cut}-{high_cut} Hz, Butterworth N={filter_order}, band-pass order={2*filter_order}, zero-phase via filtfilt).")
    sigs_processed = _bandpass_filter_signals(sigs, low_cut, high_cut, fs, filter_order, subject_id)
    if sigs_processed is sigs:  # the filter returns its input object when it skipped the whole matrix
        logger.warning(f"S {subject_id}: Filtering skipped/failed. Using original signals for next steps.")

    # 2. Optional HRF deconvolution
    if apply_hrf_deconv:
        sigs_processed = _hrf_deconvolution(sigs_processed, tr_seconds, hrf_model_type, subject_id)
        if np.isnan(sigs_processed).any() or np.isinf(sigs_processed).any():
            logger.warning(f"S {subject_id}: NaNs/Infs found after HRF deconvolution. Attempting to clean by setting to 0.")
            sigs_processed = np.nan_to_num(sigs_processed, nan=0.0, posinf=0.0, neginf=0.0)

    # 3. Minimum length required by the downstream VAR and dFC estimators
    min_len_for_var = lasso_var_max_lag + 10
    min_len_for_dfc = DFC_WIN_POINTS if DFC_WIN_POINTS > 0 else 5
    min_overall_len = max(5, min_len_for_var, min_len_for_dfc)
    if sigs_processed.shape[0] < min_overall_len:
        logger.warning(f"S {subject_id}: Timepoints ({sigs_processed.shape[0]}) insufficient post-processing (min: {min_overall_len}). Skipping.")
        return None

    # Safety net: anything non-finite that survived filtering/deconvolution would break StandardScaler.
    if not np.isfinite(sigs_processed).all():
        logger.warning(f"S {subject_id}: NaNs detected pre-scaling. Filling with 0.0.")
        sigs_processed = np.nan_to_num(sigs_processed, nan=0.0, posinf=0.0, neginf=0.0)

    # 4. Standardization (per ROI)
    try:
        sigs_normalized = StandardScaler().fit_transform(sigs_processed)
    except ValueError as e_scale:
        logger.error(f"Error StandardScaler S {subject_id} (shape {sigs_processed.shape}): {e_scale}.")
        sigs_normalized = np.zeros_like(sigs_processed, dtype=np.float32)
        for i in range(sigs_processed.shape[1]):
            col = sigs_processed[:, i]
            if col.std() > 1e-9:
                try:
                    sigs_normalized[:, i] = StandardScaler().fit_transform(col.reshape(-1, 1)).flatten()
                except Exception as e_cs:
                    logger.error(f"Err scaling col {i} S {subject_id}: {e_cs}. Setting to 0.")
                    sigs_normalized[:, i] = 0.0
            else:
                logger.warning(f"S {subject_id}: ROI {i} zero variance pre-scaling. Normalized to zeros.")
                sigs_normalized[:, i] = 0.0

    if np.isnan(sigs_normalized).any():
        logger.warning(f"S {subject_id}: NaNs AFTER StandardScaler. Filling with 0.0.")
        sigs_normalized = np.nan_to_num(sigs_normalized, nan=0.0, posinf=0.0, neginf=0.0)

    # 5. Length homogenization.
    # NOTE(review): padded zero rows are treated as real samples by the k-NN MI, dFC windows and the
    # VAR fit; the ROI x ROI connectivity matrices do not need equal lengths, so consider truncation only.
    current_length, n_columns = sigs_normalized.shape
    if current_length != target_len_ts:
        logger.info(f"S {subject_id}: Length {current_length} vs target {target_len_ts}. Padding/truncating.")
    if current_length < target_len_ts:
        padding = np.zeros((target_len_ts - current_length, n_columns), dtype=np.float32)
        sigs_homogenized = np.vstack((sigs_normalized, padding))
    else:
        sigs_homogenized = sigs_normalized[:target_len_ts, :]

    return sigs_homogenized.astype(np.float32)


def load_and_preprocess_single_subject_series(
    subject_id: str, n_rois_expected: int, target_len_ts: int,
    current_roi_signals_dir_path: Path, current_roi_filename_template: str,
    possible_roi_keys: List[str], lasso_var_max_lag: int,
    tr_seconds: float, low_cut: float, high_cut: float, filter_order: int,
    apply_hrf_deconv: bool, hrf_model_type: str # Deconvolution params added
) -> Tuple[Optional[np.ndarray], str, bool]:
    """Load one subject's ROI signals from its .mat file and run the preprocessing chain.

    Returns:
        ``(signals, message, ok)``. ``signals`` is whatever _preprocess_time_series returns,
        or None on any failure. ``message`` is a short human-readable outcome. ``ok`` is True
        only on success. Exceptions raised after the file-existence check are logged with a
        traceback and reported through the message; they are not re-raised.
    """
    mat_file = current_roi_signals_dir_path / current_roi_filename_template.format(subject_id=subject_id)
    if not mat_file.exists():
        return None, f"MAT file not found: {mat_file}", False

    try:
        raw = _load_signals_from_mat(mat_file, possible_roi_keys)
        if raw is None:
            return None, f"No valid signal keys/load error in {mat_file}", False

        raw = np.asarray(raw, dtype=np.float64)
        oriented = _orient_and_validate_signals(raw, n_rois_expected, subject_id)
        del raw
        gc.collect()  # the raw matrix may be large; release it before filtering allocates more
        if oriented is None:
            return None, f"Incorrect shape/validation failed S {subject_id}.", False

        n_tp_original = oriented.shape[0]
        processed = _preprocess_time_series(
            oriented, target_len_ts, n_rois_expected, subject_id, lasso_var_max_lag,
            tr_seconds, low_cut, high_cut, filter_order, apply_hrf_deconv, hrf_model_type,
        )
        del oriented
        gc.collect()
        if processed is None:
            return None, f"Preprocessing failed S {subject_id}. Orig len: {n_tp_original}", False
        return processed, f"OK. Orig len: {n_tp_original}, final len: {processed.shape[0]}", True
    except Exception as e:
        logger.error(f"Unhandled exception load_and_preprocess S {subject_id} ({mat_file}): {e}", exc_info=True)
        return None, f"Exception processing {mat_file}: {str(e)}", False

# --- 3. Connectivity Calculation Functions ---
# (Fisher, MI, dFC, Pearson remain unchanged from previous version)
# (Lasso-VAR also remains largely the same, but operates on potentially deconvolved data)

def fisher_r_to_z(r_matrix: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Fisher-transform a correlation matrix, z = arctanh(r), and zero its diagonal.

    Computation is done in float32. NaN entries become 0. +inf/-inf and any value
    outside (-1, 1) are clamped to +/-(1 - eps) so the result stays finite.

    Returns:
        A float32 array with the same shape as `r_matrix`.
    """
    lower, upper = -1.0 + eps, 1.0 - eps
    r32 = r_matrix.astype(np.float32)
    bounded = np.clip(np.nan_to_num(r32, nan=0.0, posinf=upper, neginf=lower), lower, upper)
    z_values = np.arctanh(bounded)
    np.fill_diagonal(z_values, 0.0)
    return z_values.astype(np.float32)

def calculate_mi_knn_connectivity(ts_subject: np.ndarray, n_neighbors: int, sid: str) -> Optional[np.ndarray]:
    """Build a symmetric ROI x ROI matrix of k-nearest-neighbour mutual-information estimates.

    Each unordered ROI pair (i < j) is estimated once with sklearn's mutual_info_regression
    (ROI i as the single feature, ROI j as the target, random_state=42), and the value is
    mirrored to (j, i). Pairs whose estimation raises are logged and left at 0. The diagonal is 0.

    Returns:
        float32 (ROIs x ROIs) matrix, or None when there are no timepoints or the number of
        timepoints does not exceed `n_neighbors`.
    """
    n_tp, n_rois = ts_subject.shape
    if n_tp == 0:
        logger.warning(f"MI_KNN (S {sid}): 0 timepoints.")
        return None
    if n_tp <= n_neighbors:
        logger.warning(f"MI_KNN (S {sid}): Insufficient TP ({n_tp}) for N_N={n_neighbors}. Needs > N_N. Skipping.")
        return None

    mi_matrix = np.zeros((n_rois, n_rois), dtype=np.float32)
    upper_pairs = ((i, j) for i in range(n_rois) for j in range(i + 1, n_rois))
    for i, j in upper_pairs:
        try:
            feature = ts_subject[:, i].reshape(-1, 1)
            target = ts_subject[:, j]
            mi_val = mutual_info_regression(feature, target, n_neighbors=n_neighbors, random_state=42, discrete_features=False)[0]
            mi_matrix[i, j] = mi_matrix[j, i] = mi_val
        except Exception as e:
            logger.error(f"MI_KNN err S {sid}, pair ({i},{j}): {e}", exc_info=False)
            mi_matrix[i, j] = mi_matrix[j, i] = 0.0
    return mi_matrix

def calculate_custom_dfc_abs_diff_mean(ts_subject: np.ndarray, win_points: int, step: int, sid: str) -> Optional[np.ndarray]:
    """Mean absolute change of |Pearson r| between consecutive sliding windows.

    Windows cover rows ``[k*step, k*step + win_points)`` for every k that fits inside the series.
    For each window the |r| matrix is computed (NaN -> 0, diagonal 0). The result averages
    ``| |r|_k - |r|_{k-1} |`` over consecutive window pairs, so high values mark ROI pairs whose
    coupling strength fluctuates over time.

    Args:
        ts_subject: Signal matrix, assumed (timepoints, ROIs).
        win_points: Window length in samples (must be >= 2).
        step: Shift between consecutive windows in samples (must be >= 1).
        sid: Subject ID, used only in log messages.

    Returns:
        float32 (ROIs x ROIs) matrix with zero diagonal, or None when the parameters are invalid,
        the series is shorter than one window, fewer than two windows fit, or no pair of
        consecutive windows could be evaluated.
    """
    n_tp, n_rois = ts_subject.shape
    if win_points < 2 or step < 1:
        # step == 0 would otherwise divide by zero below; a 1-sample window has no correlation.
        logger.warning(f"dFC (S {sid}): Invalid window ({win_points}) or step ({step}); need win >= 2 and step >= 1.")
        return None
    if n_tp < win_points:
        logger.warning(f"dFC (S {sid}): TP ({n_tp}) < win_len ({win_points}).")
        return None
    num_windows = (n_tp - win_points) // step + 1
    if num_windows < 2:
        logger.warning(f"dFC (S {sid}): Insufficient windows ({num_windows}).")
        return None

    sum_abs_diff_matrix = np.zeros((n_rois, n_rois), dtype=np.float64)
    n_diffs_calculated = 0
    prev_corr_matrix_abs = None
    for idx in range(num_windows):
        start_idx = idx * step
        window_ts = ts_subject[start_idx:start_idx + win_points, :]  # always a full window by construction
        try:
            # Zero-variance ROIs inside a window (e.g. zero-padded tails) legitimately produce NaN
            # correlations, which are mapped to 0 below; silence the per-window RuntimeWarnings.
            with np.errstate(divide='ignore', invalid='ignore'):
                corr_matrix_window = np.corrcoef(window_ts, rowvar=False)
            if corr_matrix_window.ndim < 2 or corr_matrix_window.shape != (n_rois, n_rois):
                logger.warning(f"dFC (S {sid}) win {idx}: Invalid corr shape {corr_matrix_window.shape}. Zeros.")
                corr_matrix_window = np.full((n_rois, n_rois), 0.0, dtype=np.float32)
            else:
                corr_matrix_window = np.nan_to_num(corr_matrix_window.astype(np.float32), nan=0.0)
            current_corr_matrix_abs = np.abs(corr_matrix_window)
            np.fill_diagonal(current_corr_matrix_abs, 0)
            if prev_corr_matrix_abs is not None:
                sum_abs_diff_matrix += np.abs(current_corr_matrix_abs - prev_corr_matrix_abs)
                n_diffs_calculated += 1
            prev_corr_matrix_abs = current_corr_matrix_abs
        except Exception as e:
            logger.error(f"dFC (S {sid}) win {idx} err: {e}")

    if n_diffs_calculated == 0:
        logger.warning(f"dFC (S {sid}): No valid diffs. None.")
        return None
    mean_abs_diff_matrix = (sum_abs_diff_matrix / n_diffs_calculated).astype(np.float32)
    np.fill_diagonal(mean_abs_diff_matrix, 0)
    return mean_abs_diff_matrix

def calculate_lasso_var_connectivity_matrix(ts: np.ndarray, max_lag: int, sid: str) -> Optional[np.ndarray]:
    """Directed ROI-to-ROI influence matrix from a Lasso-regularized VAR(max_lag) model.

    The lagged design ``[x(t-1), ..., x(t-p)]`` is regressed onto ``x(t)``. Both sides are
    z-scored with StandardScaler, and sklearn's MultiTaskLassoCV (5-fold CV over the penalty,
    random_state=42) is fitted on all target ROIs jointly. Per sklearn's docs, the
    multi-task penalty selects the same predictors for every target ROI. Entry
    [target, source] is the sum over lags of |coefficient|; the diagonal is zeroed.

    Returns:
        float32 (ROIs x ROIs) matrix, or None if there are too few timepoints, scaling
        raises, or the fit raises.
    """
    n_tp, n_rois = ts.shape
    p = max_lag
    if n_tp <= p:
        logger.warning(f"LassoVAR (S {sid}): Insufficient TP ({n_tp}) for lag={p}.")
        return None
    if n_tp < p * n_rois / 5:
        logger.warning(f"LassoVAR (S {sid}): TP ({n_tp}) very low for ROIs*lag ({n_rois*p}). Unstable model likely.")
    elif n_tp < p * n_rois:
        logger.info(f"LassoVAR (S {sid}): TP ({n_tp}) low for ROIs*lag ({n_rois*p}). Relies on Lasso.")

    targets = ts[p:, :].astype(np.float32)
    # Column block k (0-based) holds x(t - (k + 1)), so feature index = k * n_rois + source.
    design = np.hstack([ts[p - lag:n_tp - lag, :] for lag in range(1, p + 1)]).astype(np.float32)
    try:
        design_z = StandardScaler().fit_transform(design)
        targets_z = StandardScaler().fit_transform(targets)
    except ValueError as e:
        logger.error(f"LassoVAR (S {sid}): Scaling failed - {e}. Check zero-var ROIs.")
        return None

    try:
        model = MultiTaskLassoCV(cv=5, n_jobs=1, max_iter=LASSO_MAX_ITER, tol=LASSO_TOL,
                                 random_state=42, selection='random', verbose=False)
        model.fit(design_z, targets_z)
    except Exception as e:
        logger.error(f"LassoVAR (S {sid}): Fit failed – {e}", exc_info=False)
        return None

    # coef_ is (n_targets, p * n_rois); reshaping to (target, lag, source) lets one reduction sum over lags.
    per_lag_magnitude = np.abs(model.coef_).reshape(n_rois, p, n_rois)
    influence_matrix = per_lag_magnitude.sum(axis=1).astype(np.float32)
    np.fill_diagonal(influence_matrix, 0.0)
    return influence_matrix

def calculate_pearson_connectivity(ts_subject: np.ndarray, sid: str) -> Optional[np.ndarray]:
    """Fisher-z transformed Pearson correlation matrix between ROI time series.

    Uses the plain sample correlation (np.corrcoef over columns). nilearn's
    ConnectivityMeasure(kind='correlation') is deliberately not used: its default
    covariance estimator is Ledoit-Wolf shrinkage, which pulls correlations toward zero
    and would make the "Pearson" channel something other than Pearson r.
    Zero-variance ROIs give NaN correlations, which fisher_r_to_z maps to 0.

    Returns:
        float32 (ROIs x ROIs) Fisher-z matrix with zero diagonal, or None if there are fewer
        than 2 timepoints or the computation raises.
    """
    if ts_subject.shape[0] < 2:
        logger.warning(f"Pearson (S {sid}): <2 TPs ({ts_subject.shape[0]}).")
        return None
    try:
        # errstate: constant ROIs divide by a zero std inside corrcoef; the resulting NaNs are handled downstream.
        with np.errstate(divide='ignore', invalid='ignore'):
            corr_matrix = np.atleast_2d(np.corrcoef(ts_subject, rowvar=False))
        return fisher_r_to_z(corr_matrix)
    except Exception as e:
        logger.error(f"Err Pearson conn S {sid}: {e}", exc_info=True)
        return None

# --- 4. Function to Calculate All Connectivity Modalities ---
def calculate_all_connectivity_modalities_for_subject(
    subject_id: str, subject_ts_data: np.ndarray, n_neighbors_mi_param: int,
    dfc_win_points_param: int, dfc_step_param: int,
    lasso_lag_param: int, # Added apply_hrf_deconv and hrf_model_type implicitly through CONNECTIVITY_CHANNEL_NAMES
    connectivity_channel_names_list: List[str] # Pass the dynamic list
) -> Dict[str, Any]:
    """Compute every connectivity channel for one subject, isolating failures per channel.

    Channels are matched to estimators by POSITION in `connectivity_channel_names_list`:
    [0] Pearson / Fisher-z, [1] k-NN mutual information, [2] dFC mean absolute difference,
    [3] Lasso-VAR (the only timed step). An estimator that raises or returns None leaves
    its matrix as None and records a message under its channel name.

    Returns:
        {"matrices": {channel: matrix or None},
         "errors_conn_calc": {channel: message},
         "timings_conn_calc": {"lasso_var_time_sec": seconds}}
    """
    channel_names = connectivity_channel_names_list
    matrices = dict.fromkeys(channel_names)  # every channel starts as None
    errors_in_calculation = {}
    timings = {}

    # (log label, zero-argument estimator) for channels 0..2, in channel order.
    symmetric_estimators = (
        ("Pearson", lambda: calculate_pearson_connectivity(subject_ts_data, subject_id)),
        ("MI_KNN", lambda: calculate_mi_knn_connectivity(subject_ts_data, n_neighbors_mi_param, subject_id)),
        ("dFC", lambda: calculate_custom_dfc_abs_diff_mean(subject_ts_data, dfc_win_points_param, dfc_step_param, subject_id)),
    )
    for position, (label, estimate) in enumerate(symmetric_estimators):
        channel = channel_names[position]
        try:
            matrices[channel] = estimate()
            if matrices[channel] is None:
                errors_in_calculation[channel] = "Calc returned None."
        except Exception as e:
            errors_in_calculation[channel] = str(e)
            logger.error(f"Err {label} {subject_id}: {e}")

    # Lasso-VAR: channel name depends on the deconvolution setting, and its runtime is recorded.
    var_channel_name = channel_names[3]
    lasso_start_time = time.time()
    try:
        matrices[var_channel_name] = calculate_lasso_var_connectivity_matrix(subject_ts_data, lasso_lag_param, subject_id)
        if matrices[var_channel_name] is None:
            errors_in_calculation[var_channel_name] = "Calc returned None."
    except Exception as e:
        errors_in_calculation[var_channel_name] = str(e)
        logger.error(f"Err {var_channel_name} {subject_id}: {e}", exc_info=True)
    timings["lasso_var_time_sec"] = time.time() - lasso_start_time
    if matrices[var_channel_name] is not None:
        logger.info(f"{var_channel_name} {subject_id} took {timings['lasso_var_time_sec']:.2f}s.")
    else:
        logger.warning(f"{var_channel_name} {subject_id} failed. Took {timings['lasso_var_time_sec']:.2f}s.")

    num_successful = sum(matrix is not None for matrix in matrices.values())
    if num_successful < N_CHANNELS:
        logger.warning(f"Conn S {subject_id}: {num_successful}/{N_CHANNELS} computed. Err: {errors_in_calculation}")

    return {"matrices": matrices, "errors_conn_calc": errors_in_calculation, "timings_conn_calc": timings}


# --- 5. Per-Subject Processing Pipeline (for Multiprocessing) ---
def process_single_subject_pipeline(subject_row_tuple: Tuple[int, pd.Series]) -> Dict[str, Any]:
    """Run the whole per-subject pipeline: load + preprocess, compute all channels, save one tensor.

    Submitted to a ProcessPoolExecutor by main(); all configuration is read from the
    module-level constants rather than passed in.

    Args:
        subject_row_tuple: One ``(index, row)`` item from ``DataFrame.iterrows()``. Only
            ``row['SubjectID']`` is used; the index is unpacked but unused.

    Returns:
        A status dict with keys "id", "status_preprocessing", "detail_preprocessing",
        "status_connectivity_calc", "errors_connectivity_calc", "timings_connectivity_calc_sec",
        "path_saved_tensor", "status_overall", "ram_usage_mb_initial" and "ram_usage_mb_final".
        On success the channels are stacked along axis 0 in CONNECTIVITY_CHANNEL_NAMES order and
        saved as float32 to BASE_PATH/OUTPUT_CONNECTIVITY_DIR_NAME/individual_subject_tensors/
        tensor_{N_CHANNELS}ch_{subject_id}.npz.

    Raises:
        The outer ``try`` has only a ``finally`` clause. Any exception raised outside the
        tensor-saving ``try`` propagates to the caller instead of being recorded in the dict.
    """
    idx, subject_row = subject_row_tuple
    subject_id = str(subject_row['SubjectID']).strip()
    # Resident memory (MB) of this worker process, recorded at start and again in `finally`.
    process = psutil.Process(os.getpid()); ram_initial_mb = process.memory_info().rss/(1024**2)
    result = {"id":subject_id, "status_preprocessing":"PENDING", "detail_preprocessing":"",
              "status_connectivity_calc":"NOT_ATTEMPTED", "errors_connectivity_calc":{},
              "timings_connectivity_calc_sec":{}, "path_saved_tensor":None,
              "status_overall":"PENDING", "ram_usage_mb_initial":ram_initial_mb, "ram_usage_mb_final":-1.0}
    # Pre-bound so the `finally` clause can release whatever was allocated, whichever step stopped early.
    series_data=None; calculated_matrices=None; final_matrices_to_stack=None
    try:
        series_data, detail_msg_preproc, success_preproc = load_and_preprocess_single_subject_series(
            subject_id, N_ROIS_EXPECTED, TARGET_LEN_TS, ROI_SIGNALS_DIR_PATH,
            ROI_FILENAME_TEMPLATE, POSSIBLE_ROI_KEYS, LASSO_VAR_MAX_LAG,
            TR_SECONDS, LOW_CUT_HZ, HIGH_CUT_HZ, FILTER_ORDER,
            APPLY_HRF_DECONVOLUTION, HRF_MODEL # Pass deconv params
        )
        result["status_preprocessing"] = "SUCCESS" if success_preproc else "FAILED"
        result["detail_preprocessing"] = detail_msg_preproc
        if not success_preproc or series_data is None: result["status_overall"]="PREPROCESSING_FAILED"; return result

        # Use the globally defined CONNECTIVITY_CHANNEL_NAMES which is now dynamic
        connectivity_results = calculate_all_connectivity_modalities_for_subject(
            subject_id, series_data, N_NEIGHBORS_MI,
            DFC_WIN_POINTS, DFC_STEP, LASSO_VAR_MAX_LAG,
            CONNECTIVITY_CHANNEL_NAMES # Pass the dynamic list
        )
        del series_data; series_data=None; gc.collect()
        calculated_matrices = connectivity_results["matrices"]
        # Note: this aliases the dict returned by the connectivity step; the loop below appends to it in place.
        result["errors_connectivity_calc"]=connectivity_results["errors_conn_calc"]
        result["timings_connectivity_calc_sec"]=connectivity_results.get("timings_conn_calc",{})

        # Validate channels in CONNECTIVITY_CHANNEL_NAMES order; stop at the first missing or mis-shaped matrix.
        all_modalities_valid = True; final_matrices_to_stack = []
        for channel_name in CONNECTIVITY_CHANNEL_NAMES: # Iterate using dynamic list
            matrix = calculated_matrices.get(channel_name)
            if matrix is None: all_modalities_valid=False; err_msg=f"Modality '{channel_name}' not computed."; logger.error(f"S {subject_id}: {err_msg}"); result["errors_connectivity_calc"][channel_name]=result["errors_connectivity_calc"].get(channel_name,"")+" | "+err_msg; break
            elif matrix.shape!=(N_ROIS_EXPECTED,N_ROIS_EXPECTED): all_modalities_valid=False; err_msg=f"Modality '{channel_name}' incorrect shape {matrix.shape}."; logger.error(f"S {subject_id}: {err_msg}"); result["errors_connectivity_calc"][channel_name]=result["errors_connectivity_calc"].get(channel_name,"")+" | "+err_msg; break
            else: final_matrices_to_stack.append(matrix)

        if all_modalities_valid and len(final_matrices_to_stack)==N_CHANNELS:
            result["status_connectivity_calc"]="SUCCESS_ALL_MODALITIES_VALID"
            try:
                subject_tensor = np.stack(final_matrices_to_stack,axis=0).astype(np.float32)
                del final_matrices_to_stack, calculated_matrices; final_matrices_to_stack,calculated_matrices=None,None; gc.collect()
                output_dir_ind_tensors = BASE_PATH/OUTPUT_CONNECTIVITY_DIR_NAME/"individual_subject_tensors"
                # The file name depends only on subject_id: duplicated SubjectIDs in the metadata
                # would make workers overwrite each other's output.
                output_path = output_dir_ind_tensors/f"tensor_{N_CHANNELS}ch_{subject_id}.npz"
                np.savez_compressed(output_path, tensor_data=subject_tensor, subject_id=subject_id, channel_names=np.array(CONNECTIVITY_CHANNEL_NAMES,dtype=str))
                result["path_saved_tensor"]=str(output_path); result["status_overall"]="SUCCESS_ALL_PROCESSED_AND_SAVED"
                del subject_tensor; subject_tensor=None; gc.collect()
            except Exception as e: logger.error(f"Err saving tensor S {subject_id}: {e}",exc_info=True); result["errors_connectivity_calc"]["save_error"]=str(e); result["status_overall"]="FAILURE_DURING_TENSOR_SAVING"; result["status_connectivity_calc"]="FAILURE_DURING_SAVING"
        else:
            result["status_overall"]="FAILURE_IN_CONNECTIVITY_CALC_OR_VALIDATION"; result["status_connectivity_calc"]="FAILURE_MISSING_OR_INVALID_MODALITIES"
            if not all_modalities_valid: logger.error(f"S {subject_id}: Not all modalities valid. Tensor not saved. Err: {result['errors_connectivity_calc']}")
    finally:
        # Runs on normal completion, early return and exceptions alike: drop large arrays, record final RSS.
        if series_data is not None: del series_data; gc.collect()
        if calculated_matrices is not None: del calculated_matrices; gc.collect()
        if final_matrices_to_stack is not None: del final_matrices_to_stack; gc.collect()
        result["ram_usage_mb_final"] = process.memory_info().rss/(1024**2)
    return result

# --- 6. Main Script Execution Flow ---
def _run_subjects_in_parallel(
    subject_rows_to_process: List[Tuple[Any, pd.Series]],
    max_workers: int,
) -> List[Dict[str, Any]]:
    """Run ``process_single_subject_pipeline`` for every subject row in a process pool.

    Args:
        subject_rows_to_process: ``(index, row)`` tuples as produced by
            ``DataFrame.iterrows()``; each row must carry a ``'SubjectID'`` column.
        max_workers: Number of worker processes.

    Returns:
        One result dict per submitted subject. A worker that raises is recorded
        as a ``"CRITICAL_WORKER_EXCEPTION"`` entry instead of aborting the run.
    """
    all_subject_results: List[Dict[str, Any]] = []
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        # Map each future back to its subject ID so failures can be attributed.
        future_to_s_id = {
            executor.submit(process_single_subject_pipeline, s_tuple): str(s_tuple[1]['SubjectID']).strip()
            for s_tuple in subject_rows_to_process
        }
        for future in tqdm(as_completed(future_to_s_id), total=len(subject_rows_to_process), desc="Processing Subjects"):
            s_id_log = future_to_s_id[future]
            try:
                all_subject_results.append(future.result())
            except Exception as exc:
                logger.critical(f"CRITICAL WORKER EXCEPTION S {s_id_log}: {exc}", exc_info=True)
                all_subject_results.append({
                    "id": s_id_log,
                    "status_overall": "CRITICAL_WORKER_EXCEPTION",
                    "detail_preprocessing": str(exc),
                    "errors_connectivity_calc": {"worker_exception": str(exc)},
                })
    return all_subject_results


def _assemble_global_tensor(
    successful_subject_entries: List[Dict[str, Any]],
    output_main_directory: Path,
    output_individual_tensors_dir: Path,
) -> None:
    """Stack the saved per-subject tensors into one compressed global ``.npz``.

    Tensors whose shape is not ``(N_CHANNELS, N_ROIS_EXPECTED, N_ROIS_EXPECTED)``,
    or that cannot be loaded, are logged and skipped. ``MemoryError`` and any other
    exception during assembly are logged as critical and not re-raised; the
    individual tensors remain on disk in ``output_individual_tensors_dir``.
    """
    logger.info(f"Assembling global tensor for {len(successful_subject_entries)} subjects.")
    global_conn_tensor_list: List[np.ndarray] = []
    final_s_ids_in_tensor: List[str] = []
    expected_shape = (N_CHANNELS, N_ROIS_EXPECTED, N_ROIS_EXPECTED)
    try:
        for s_entry in tqdm(successful_subject_entries, desc="Assembling Global Tensor"):
            s_id = s_entry["id"]
            tensor_path = Path(s_entry["path_saved_tensor"])
            try:
                # NpzFile keeps its file open; the context manager guarantees it is
                # closed even if the key lookup fails. Indexing reads the array into
                # memory, so it stays valid after the file is closed.
                with np.load(tensor_path) as loaded_npz:
                    s_tensor_data = loaded_npz['tensor_data']
            except Exception as e:
                logger.error(f"Err loading tensor S {s_id} from {tensor_path}: {e}. Skipping.")
                continue
            if s_tensor_data.shape == expected_shape:
                global_conn_tensor_list.append(s_tensor_data)
                final_s_ids_in_tensor.append(s_id)
            else:
                logger.error(f"Tensor S {s_id} at {tensor_path} incorrect shape: {s_tensor_data.shape}. Skipping.")

        if not global_conn_tensor_list:
            logger.warning("No valid individual tensors for global assembly.")
            return

        # The worker already saves float32 tensors, so copy=False avoids
        # materialising a second full-size copy of the global tensor.
        global_conn_tensor = np.stack(global_conn_tensor_list, axis=0).astype(np.float32, copy=False)
        del global_conn_tensor_list
        gc.collect()

        # NOTE(review): deconv_str is not defined in this function; presumably a
        # module-level config string defined elsewhere in the file — verify.
        global_tensor_fname = f"GLOBAL_TENSOR_AAL{N_ROIS_EXPECTED}_{len(final_s_ids_in_tensor)}subs_{N_CHANNELS}ch_{LASSO_VAR_MAX_LAG}lag{deconv_str}.npz"
        global_tensor_path = output_main_directory / global_tensor_fname
        np.savez_compressed(
            global_tensor_path,
            global_tensor_data=global_conn_tensor,
            subject_ids=np.array(final_s_ids_in_tensor, dtype=str),
            channel_names=np.array(CONNECTIVITY_CHANNEL_NAMES, dtype=str),
        )
        logger.info(f"Global tensor saved: {global_tensor_path}. Shape: {global_conn_tensor.shape}")
    except MemoryError:
        logger.critical(f"MEMORY ERROR global tensor assembly. Individual tensors in: {output_individual_tensors_dir}.")
    except Exception as e:
        logger.critical(f"Unexpected error global assembly: {e}", exc_info=True)


def main():
    """Run the whole connectivity pipeline end to end.

    Steps:
      1. Check the input directories.
      2. Load subject metadata.
      3. Create the output directories.
      4. Process every subject in parallel via ``process_single_subject_pipeline``.
      5. Write a per-subject CSV log.
      6. Stack the successfully saved per-subject tensors into one global ``.npz``.

    Fatal setup problems are logged as critical and cause an early return; nothing
    is raised to the caller.
    """
    script_start_time = time.time()
    main_process_info = psutil.Process(os.getpid())
    logger.info(f"Main RAM start: {main_process_info.memory_info().rss / (1024**2):.2f} MB")
    logger.info(f"--- Starting fMRI Pipeline: {OUTPUT_CONNECTIVITY_DIR_NAME} ---")

    if not BASE_PATH.exists() or not ROI_SIGNALS_DIR_PATH.exists():
        logger.critical(f"CRITICAL: Base or ROI signals dir not found (base: {BASE_PATH}, roi: {ROI_SIGNALS_DIR_PATH}). Aborting.")
        return
    subject_metadata_df = load_metadata(SUBJECT_CSV_PATH, STATS_CSV_PATH)
    if subject_metadata_df is None or subject_metadata_df.empty:
        logger.critical("Metadata failed or no subjects. Aborting.")
        return

    output_main_directory = BASE_PATH / OUTPUT_CONNECTIVITY_DIR_NAME
    output_individual_tensors_dir = output_main_directory / "individual_subject_tensors"
    try:
        output_main_directory.mkdir(parents=True, exist_ok=True)
        output_individual_tensors_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Output dir: {output_main_directory}")
    except OSError as e:
        logger.critical(f"Could not create output dirs: {e}. Aborting.")
        return

    total_cpu_cores = multiprocessing.cpu_count()
    MAX_WORKERS = 3  # deliberately small: each worker holds a subject's full data in RAM
    logger.info(f"CPUs: {total_cpu_cores}. Max workers: {MAX_WORKERS}.")
    available_ram_gb = psutil.virtual_memory().available / (1024**3)
    logger.warning(f"Available RAM: {available_ram_gb:.2f} GB. Monitor usage.")

    subject_rows_to_process = list(subject_metadata_df.iterrows())
    num_subjects_to_process = len(subject_rows_to_process)
    if num_subjects_to_process == 0:
        logger.critical("No subjects to process. Aborting.")
        return
    logger.info(f"Starting parallel processing for {num_subjects_to_process} subjects.")
    all_subject_results = _run_subjects_in_parallel(subject_rows_to_process, MAX_WORKERS)

    processing_log_df = pd.DataFrame(all_subject_results)
    log_file_path = output_main_directory / f"pipeline_log_{output_main_directory.name}.csv"
    try:
        processing_log_df.to_csv(log_file_path, index=False)
        logger.info(f"Processing log: {log_file_path}")
    except Exception as e:
        logger.error(f"Failed to save log: {e}")

    # Trust only results that report success AND whose tensor file actually exists.
    successful_subject_entries = [
        res for res in all_subject_results
        if res.get("status_overall") == "SUCCESS_ALL_PROCESSED_AND_SAVED"
        and res.get("path_saved_tensor")
        and Path(res["path_saved_tensor"]).exists()
    ]
    num_successful_subjects = len(successful_subject_entries)
    logger.info("--- Summary ---")
    logger.info(f"Total processed: {num_subjects_to_process}")
    logger.info(f"Successful & saved: {num_successful_subjects}")
    if num_successful_subjects < num_subjects_to_process:
        logger.warning(f"{num_subjects_to_process - num_successful_subjects} subjects failed. Check log.")

    if num_successful_subjects > 0:
        _assemble_global_tensor(successful_subject_entries, output_main_directory, output_individual_tensors_dir)

    total_time_min = (time.time() - script_start_time) / 60
    logger.info("--- Pipeline Finished ---")
    logger.info(f"Total time: {total_time_min:.2f} minutes.")
    logger.info(f"Final main RAM: {main_process_info.memory_info().rss / (1024**2):.2f} MB")
    logger.info(f"Outputs in: {output_main_directory}")

if __name__ == "__main__":
    # Script entry point. The guard stops worker processes that re-import this
    # module (e.g. under the "spawn" start method) from re-running main().
    # freeze_support() only has an effect in frozen Windows executables and must
    # be the first call in the main block, so it stays before main().
    multiprocessing.freeze_support()
    main()

2025-05-25 00:37:11,900 - INFO - 316480251.py:529 - Main RAM start: 239.30 MB
2025-05-25 00:37:11,900 - INFO - 316480251.py:530 - --- Starting fMRI Pipeline: AAL_116_fmri_tensor_1lag_NeuroEnhanced_v4 ---
2025-05-25 00:37:11,901 - INFO - 316480251.py:94 - --- Starting Subject Metadata Loading ---
2025-05-25 00:37:11,908 - INFO - 316480251.py:117 - Metadata DataFrame 'merged_df' loaded successfully. Shape: (352, 27)
2025-05-25 00:37:11,908 - INFO - 316480251.py:118 - Total unique subjects in merged_df: 352
2025-05-25 00:37:11,909 - INFO - 316480251.py:541 - Output dir: /home/diego/Escritorio/desde_cero/AAL_116/AAL_116_fmri_tensor_1lag_NeuroEnhanced_v4
2025-05-25 00:37:11,909 - INFO - 316480251.py:545 - CPUs: 12. Max workers: 3.
2025-05-25 00:37:11,920 - INFO - 316480251.py:550 - Starting parallel processing for 352 subjects.
Processing Subjects:   0%|          | 0/352 [00:00<?, ?it/s]2025-05-25 00:37:12,043 - INFO - 316480251.py:247 - S 002_S_0295: Preprocessing. Orig len: 140, TR: 3.0s.