# In [1]:  (notebook cell marker left over from the Jupyter export)
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import mutual_info_regression
from nilearn.connectome import ConnectivityMeasure
from nilearn.glm.first_level import spm_hrf, glover_hrf
from scipy.signal import butter, filtfilt, deconvolve
from tqdm import tqdm
import os
import scipy.io as sio
from pathlib import Path
import psutil
import gc
import logging
import time
from typing import List, Tuple, Dict, Optional, Any
import multiprocessing
from concurrent.futures import ProcessPoolExecutor, as_completed
from sklearn.linear_model import MultiTaskLassoCV

# Import for OMST - ensure 'topological_filtering_networks' is installed
try:
    from omst.omst import orthogonal_minimum_spanning_tree
except ImportError:
    print("ERROR: The 'omst' package (topological_filtering_networks) is not installed. Please install it: pip install git+https://github.com/stdimitr/topological_filtering_networks.git")
    orthogonal_minimum_spanning_tree = None # Placeholder if not installed


# --- 0. Global Configuration and Constants ---

# Configure the root logger once, at import time: INFO level, and every record
# carries a timestamp, the level name, the source file name and the line number.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s')
# Module-level logger used by the functions defined in this file.
logger = logging.getLogger(__name__)

# --- Neuroscientific Context & Recommendations Summary (Updated for Luppi et al. 2024) ---
# This script aims to generate multi-modal brain connectivity matrices from fMRI time series
# using the AAL3 atlas (166 ROIs, further refined by QC).
# These matrices can serve as features for machine learning models to classify neurological conditions.
#
# CRITICAL FOR THESIS: This script should be run AFTER your QC pipeline.
# The `load_metadata` function is designed to read the output of your QC script
# to ensure only quality-controlled subjects are processed.
#
# KEY UPDATE (Luppi et al., 2024, Nature Communications):
# This version incorporates recommendations for constructing static functional connectivity (sFC)
# matrices that demonstrated high consistency and reliability in a systematic evaluation
# of 768 pipelines. The core recommended sFC pipeline is:
# 1. Pearson Correlation
# 2. Fisher R-to-Z transformation
# 3. Orthogonal Minimum Spanning Tree (OMST) filtering (on absolute Fisher-Z values)
# 4. Using weighted (not binary) matrices.
# This "Pearson_OMST_Weighted" channel is now the primary sFC modality.

### NEURO-ENHANCEMENT: Summary of evidence-based recommendations:
# | Aspect                      | Recommended Practice                                      | Rationale / References                                                                                                                                  |
# |-----------------------------|-----------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------|
# | Data Input                  | Use data that has passed rigorous QC                    | Ensures reliability and validity of connectivity findings. (Your QC script output is key).                                                              |
# | Atlas Consideration         | AAL3 (e.g., 131 ROIs after your QC filtering)           | Document the final ROI set used after all filtering (AAL3 base, missing ROIs, small volume ROIs).                                                       |
# | **Static FC (Primary)** | **Pearson -> Fisher-Z -> abs -> OMST -> Weighted** | **One of 9 optimal pipelines from Luppi et al. (2024) for consistent functional connectomics.** |
# | Band-pass Filtering         | Yes, 0.01–0.08 Hz before connectivity                   | Eliminates scanner drift (<0.01 Hz) & physiological noise (>0.08 Hz). (PMID: 17266107)                                                                   |
# | k in MI (k-NN) (Optional)   | k = 5 (e.g., KSG estimator)                             | k≈3–10 common; k=5 offers good bias-variance balance for N~171+ TPs. (PMID: 15042233)                                                                   |
# | Max VAR Order (Lag)         | Lag 1 preferred, especially for typical fMRI TS         | For TR~2–3s, optimal order often 1. Higher orders w/ short TS & many ROIs risk overfitting. (PMID: 21833123)                                           |
# | HRF Deconvolution           | Optional, for neuronal-level VAR. Use with caution.     | Attempts to recover neural signal from BOLD. Can improve EC interpretability but is sensitive to noise & HRF model. Validate carefully. (PMID: 17329248)|
# | Time Series Length          | Consistent length for analysis (e.g., `TARGET_LEN_TS`)    | Ensure `TARGET_LEN_TS` is appropriate for subjects passing QC.                                                                                          |


# --- Configurable Parameters ---
# Input/output locations. Everything is resolved relative to BASE_PATH_AAL3.
BASE_PATH_AAL3 = Path('/home/diego/Escritorio/AAL3')
QC_OUTPUT_DIR = BASE_PATH_AAL3 / 'qc_outputs_doctoral_v3.2_aal3_shrinkage_flexible_thresh_fix' # Points to v3.2 QC output
# NOTE(review): this file name mentions Schaefer2018 although the pipeline targets
# AAL3 -- confirm it is the intended subject table.
SUBJECT_METADATA_CSV_PATH = BASE_PATH_AAL3 / 'SubjctsDataAndTests_Schaefer2018_400Parcels_17Networks.csv'
QC_REPORT_CSV_PATH = QC_OUTPUT_DIR / 'report_qc_final_with_discard_flags_v3.2.csv'
ROI_SIGNALS_DIR_PATH_AAL3 = BASE_PATH_AAL3 / 'ROISignals_AAL3_NiftiPreprocessedAllBatchesNorm'
# Per-subject file name; `{subject_id}` is filled in with str.format.
ROI_FILENAME_TEMPLATE = 'ROISignals_{subject_id}.mat'

# Acquisition / band-pass settings. The sampling frequency used for filtering is
# derived as 1 / TR_SECONDS, and the cut-offs are normalised by the Nyquist frequency.
TR_SECONDS = 3.0
LOW_CUT_HZ = 0.01
HIGH_CUT_HZ = 0.08
FILTER_ORDER = 2

# Expected data geometry: number of ROI columns per subject and the common
# time-series length every subject is padded/truncated to.
N_ROIS_EXPECTED = 131 # Based on your QC script v3.2 output (ROIs_for_MV = 131)
TARGET_LEN_TS = 140   # Based on your QC script v3.2 (TP_THRESHOLD_VALUE = 140)

# Per-modality hyper-parameters.
N_NEIGHBORS_MI = 5            # k for the k-NN mutual information estimator
DFC_WIN_POINTS = 20           # sliding-window length, in time points
DFC_STEP = 10                 # sliding-window stride, in time points
LASSO_VAR_MAX_LAG = 1         # VAR model order
APPLY_HRF_DECONVOLUTION = False
HRF_MODEL = 'glover'          # the deconvolution helper accepts 'glover' or 'spm'
LASSO_MAX_ITER = 2500
LASSO_TOL = 1e-4

# The output directory name encodes the ROI count, the VAR lag and whether
# HRF deconvolution was applied.
deconv_str = "_deconv" if APPLY_HRF_DECONVOLUTION else ""
OUTPUT_CONNECTIVITY_DIR_NAME = f"AAL3_{N_ROIS_EXPECTED}_fmri_tensor_LuppiOMST_{LASSO_VAR_MAX_LAG}lag{deconv_str}_NeuroEnhanced_v6.1" # Version increment

# Candidate variable names under which the ROI signal array may be stored in
# the .mat files; they are tried in this order.
POSSIBLE_ROI_KEYS = ["signals", "ROISignals", "roi_signals", "ROIsignals_AAL3", "AAL3_signals", "roi_ts"]

# --- Define which connectivity channels to compute ---
PEARSON_OMST_CHANNEL_NAME = "Pearson_OMST_Weighted" # Luppi et al. (2024) recommended

# Flags to control optional channels, as per user instruction for thesis "core"
USE_MI_CHANNEL_FOR_THESIS = False
USE_DFC_CHANNEL_FOR_THESIS = True
USE_LASSO_VAR_FOR_THESIS = True

# Dynamically build the list of channel names to be processed.
# The order of this list is the order in which the channels are computed.
CONNECTIVITY_CHANNEL_NAMES = [PEARSON_OMST_CHANNEL_NAME] # Start with the primary channel
if USE_MI_CHANNEL_FOR_THESIS:
    CONNECTIVITY_CHANNEL_NAMES.append("MI_KNN")
if USE_DFC_CHANNEL_FOR_THESIS:
    CONNECTIVITY_CHANNEL_NAMES.append("dFC_AbsDiffMean")
if USE_LASSO_VAR_FOR_THESIS:
    # The VAR channel name records whether it was estimated on deconvolved signals.
    var_channel_suffix_dynamic = "_Deconv_Influence" if APPLY_HRF_DECONVOLUTION else "_Influence"
    CONNECTIVITY_CHANNEL_NAMES.append(f"Lasso_VAR{var_channel_suffix_dynamic}")

N_CHANNELS = len(CONNECTIVITY_CHANNEL_NAMES)
logger.info(f"Connectivity channels to be computed: {CONNECTIVITY_CHANNEL_NAMES}")
logger.info(f"Total number of channels: {N_CHANNELS}")


# --- 1. Subject Metadata Loading and Merging ---
def _first_present_column(df: pd.DataFrame, candidates: List[str]) -> Optional[str]:
    """Return the first name in *candidates* that is a column of *df*, or None."""
    return next((name for name in candidates if name in df.columns), None)


def load_metadata(
    subject_meta_csv_path: Path,
    qc_report_csv_path: Path
) -> Optional[pd.DataFrame]:
    """Load subject metadata, merge it with the QC report and keep QC-passing subjects.

    Both CSV files are inner-joined on ``SubjectID`` (compared as stripped
    strings). Subjects whose ``ToDiscard_Overall`` flag is anything other than
    a false value are dropped.

    Args:
        subject_meta_csv_path: CSV with at least a ``SubjectID`` column and,
            ideally, a ``ResearchGroup`` column.
        qc_report_csv_path: QC report CSV with ``SubjectID``,
            ``ToDiscard_Overall`` and ``TimePoints`` columns.

    Returns:
        A DataFrame with the columns ``SubjectID``, ``Timepoints`` (integer,
        taken from the QC report; unparsable values become 0) and
        ``ResearchGroup`` (``'Unknown'`` when the metadata has no such column),
        or ``None`` if a file or required column is missing, no subject passes
        QC, or any unexpected error occurs (the error is logged).
    """
    logger.info("--- Starting Subject Metadata Loading and QC Integration ---")
    try:
        if not subject_meta_csv_path.exists():
            logger.critical(f"Subject metadata CSV file NOT found: {subject_meta_csv_path}")
            return None
        if not qc_report_csv_path.exists():
            logger.critical(f"QC report CSV file NOT found: {qc_report_csv_path}")
            return None

        subjects_db_df = pd.read_csv(subject_meta_csv_path)
        logger.info(f"Loaded main metadata from {subject_meta_csv_path}. Shape: {subjects_db_df.shape}")
        # Validate the column BEFORE using it, so a missing column produces the
        # dedicated message instead of a KeyError caught by the generic handler.
        if 'SubjectID' not in subjects_db_df.columns:
            logger.critical("Column 'SubjectID' missing in main metadata CSV.")
            return None
        subjects_db_df['SubjectID'] = subjects_db_df['SubjectID'].astype(str).str.strip()
        if 'ResearchGroup' not in subjects_db_df.columns:
            logger.warning("Column 'ResearchGroup' missing in main metadata CSV. May be needed for downstream VAE tasks.")

        qc_df = pd.read_csv(qc_report_csv_path)
        logger.info(f"Loaded QC report from {qc_report_csv_path}. Shape: {qc_df.shape}")
        # Ensure 'TimePoints' from QC script (not 'Timepoints_final' from previous version of this script)
        required_qc_cols = ['SubjectID', 'ToDiscard_Overall', 'TimePoints']
        if not all(col in qc_df.columns for col in required_qc_cols):
            logger.critical("Essential columns ('SubjectID', 'ToDiscard_Overall', 'TimePoints') missing in QC report CSV.")
            return None
        qc_df['SubjectID'] = qc_df['SubjectID'].astype(str).str.strip()

        merged_df = pd.merge(subjects_db_df, qc_df, on='SubjectID', how='inner', suffixes=('_meta', '_qc'))

        # A column present in BOTH inputs is suffixed by the merge; the QC
        # version ('*_qc') always takes priority over the unsuffixed name.
        timepoints_col = _first_present_column(merged_df, ['TimePoints_qc', 'TimePoints'])
        if timepoints_col is None:
            logger.critical("Definitive 'TimePoints' column from QC report could not be identified after merge.")
            return None
        merged_df['Timepoints_final_for_script'] = (
            pd.to_numeric(merged_df[timepoints_col], errors='coerce').fillna(0).astype(int)
        )

        discard_col = _first_present_column(merged_df, ['ToDiscard_Overall_qc', 'ToDiscard_Overall'])
        if discard_col is None:
            logger.critical("'ToDiscard_Overall' column from QC report could not be identified after merge.")
            return None
        # Compare on a normalised string form so that boolean, 0/1 and textual
        # "False"/"false" flags are all understood. Anything else (including
        # missing values) is treated as "discard", the conservative choice.
        discard_as_text = merged_df[discard_col].astype(str).str.strip().str.lower()
        keep_mask = discard_as_text.isin(['false', '0', '0.0'])

        initial_subject_count = len(merged_df)
        subjects_passing_qc_df = merged_df[keep_mask].copy()
        num_discarded = initial_subject_count - len(subjects_passing_qc_df)
        logger.info(f"Total subjects after merge: {initial_subject_count}")
        logger.info(f"Subjects discarded based on QC ('ToDiscard_Overall' == True): {num_discarded}")
        logger.info(f"Subjects passing QC and to be processed: {len(subjects_passing_qc_df)}")

        if subjects_passing_qc_df.empty:
            logger.warning("No subjects passed QC. Check your QC criteria and report.")
            return None

        min_tp_after_qc = subjects_passing_qc_df['Timepoints_final_for_script'].min()
        max_tp_after_qc = subjects_passing_qc_df['Timepoints_final_for_script'].max()
        logger.info(f"Timepoints for subjects passing QC (from QC report): Min={min_tp_after_qc}, Max={max_tp_after_qc}.")
        logger.info(f"These will be homogenized to TARGET_LEN_TS = {TARGET_LEN_TS} for connectivity calculation.")

        group_col = _first_present_column(
            subjects_passing_qc_df, ['ResearchGroup', 'ResearchGroup_meta', 'ResearchGroup_qc']
        )
        if group_col is None:
            logger.warning("Creating placeholder 'ResearchGroup' column as it was not found.")
            subjects_passing_qc_df['ResearchGroup'] = 'Unknown' # Ensure it exists
        elif group_col != 'ResearchGroup':
            subjects_passing_qc_df['ResearchGroup'] = subjects_passing_qc_df[group_col]

        # Select first, rename second: selecting by the pre-rename name after an
        # in-place rename raised KeyError and made this function return None.
        selected_df = subjects_passing_qc_df[['SubjectID', 'Timepoints_final_for_script', 'ResearchGroup']]
        return selected_df.rename(columns={'Timepoints_final_for_script': 'Timepoints'})

    except FileNotFoundError as e:
        logger.critical(f"CRITICAL Error loading CSV files: {e}")
        return None
    except ValueError as e:
        logger.critical(f"Value error in metadata processing: {e}")
        return None
    except Exception as e:
        logger.critical(f"Unexpected error during metadata loading/QC integration: {e}", exc_info=True)
        return None

# --- 2. Time Series Loading and Preprocessing Functions ---
def _load_signals_from_mat(mat_path: Path, possible_keys: List[str]) -> Optional[np.ndarray]:
    """Return the first suitable signal array stored in a MATLAB file.

    The candidate variable names are tried in the order given; the first one
    whose value is a numpy array with at least two dimensions wins.

    Args:
        mat_path: Location of the ``.mat`` file.
        possible_keys: Variable names to look for, in priority order.

    Returns:
        The array found, or ``None`` when the file cannot be read or none of
        the candidate names holds a suitable array (both cases are logged).
    """
    try:
        mat_contents = sio.loadmat(mat_path)
    except Exception as e_load:
        logger.error(f"Could not load .mat file: {mat_path}. Error: {e_load}")
        return None

    for candidate in possible_keys:
        value = mat_contents.get(candidate)
        if isinstance(value, np.ndarray) and value.ndim >= 2:
            logger.debug(f"Found signals under key '{candidate}' in {mat_path}.")
            return value

    logger.warning(f"No valid signal keys {possible_keys} found in {mat_path}. Keys present: {list(mat_contents.keys())}")
    return None

def _orient_and_validate_signals(sigs: np.ndarray, n_rois_expected_val: int, subject_id: str) -> Optional[np.ndarray]:
    """Bring a 2-D signal matrix into ``[timepoints, ROIs]`` orientation.

    The ROI axis is recognised by its length matching ``n_rois_expected_val``.
    A matrix whose first axis matches (and second does not) is transposed; a
    square matrix matching on both axes is kept as is, with a warning.

    Args:
        sigs: Raw signal matrix.
        n_rois_expected_val: Expected number of ROIs.
        subject_id: Subject identifier, used only in log messages.

    Returns:
        The (possibly transposed) matrix, or ``None`` when the input is not
        2-D or neither axis has the expected ROI count.
    """
    if sigs.ndim != 2:
        logger.warning(f"S {subject_id}: Signal matrix has incorrect dimensions {sigs.ndim} (expected 2).")
        return None

    n_rows, n_cols = sigs.shape
    rows_are_rois = n_rows == n_rois_expected_val
    cols_are_rois = n_cols == n_rois_expected_val

    if not rows_are_rois and not cols_are_rois:
        logger.warning(f"S {subject_id}: Neither dimension of signal matrix ({sigs.shape}) matches N_ROIS_EXPECTED ({n_rois_expected_val}).")
        return None

    if rows_are_rois and cols_are_rois:
        # Ambiguous case: orientation cannot be inferred, keep the input layout.
        logger.warning(f"S {subject_id}: Signal matrix is square ({sigs.shape}) and matches N_ROIS_EXPECTED. Assuming [Timepoints, ROIs]. Verify assumption if TPs can also be {n_rois_expected_val}.")
    elif rows_are_rois:
        logger.info(f"S {subject_id}: Transposing matrix from {sigs.shape} to ({n_cols}, {n_rows}).")
        sigs = sigs.T

    # Defensive re-check of the final layout.
    if sigs.shape[1] != n_rois_expected_val:
        logger.warning(f"S {subject_id}: After orientation, ROI count ({sigs.shape[1]}) != N_ROIS_EXPECTED ({n_rois_expected_val}).")
        return None
    return sigs

def _bandpass_filter_signals(sigs: np.ndarray, lowcut: float, highcut: float, fs: float, order: int, subject_id: str) -> np.ndarray:
    """Apply a zero-phase Butterworth band-pass filter to every column.

    Columns that are constant, or too short for ``filtfilt``'s default edge
    padding, are copied through unfiltered. The function never raises for
    filtering problems: on invalid cut-offs or any filtering error it logs the
    problem and returns the input unchanged.

    Args:
        sigs: Signal matrix, one column per ROI.
        lowcut: Lower cut-off frequency in Hz.
        highcut: Upper cut-off frequency in Hz.
        fs: Sampling frequency in Hz.
        order: Butterworth filter order.
        subject_id: Subject identifier, used only in log messages.

    Returns:
        The filtered matrix, or ``sigs`` itself when filtering was not possible.
    """
    nyquist_freq = 0.5 * fs
    low_norm = lowcut / nyquist_freq
    high_norm = highcut / nyquist_freq

    # Both normalised cut-offs must lie strictly inside (0, 1), in ascending order.
    band_is_valid = 0 < low_norm < high_norm < 1
    if not band_is_valid:
        logger.error(f"S {subject_id}: Invalid critical frequencies for filter. Skipping filtering.")
        return sigs

    try:
        b, a = butter(order, [low_norm, high_norm], btype='band', analog=False)
        output = np.zeros_like(sigs)
        # Same bound filtfilt uses for its default padding length.
        min_len_exclusive = 3 * max(len(a), len(b))

        for roi_idx in range(sigs.shape[1]):
            column = sigs[:, roi_idx]
            n_samples = len(column)
            reference_value = column[0] if n_samples > 0 else 0

            if np.all(np.isclose(column, reference_value)):
                logger.debug(f"S {subject_id}, ROI {roi_idx}: Signal constant. Skipping filter.")
                output[:, roi_idx] = column
                continue
            if n_samples <= min_len_exclusive:
                logger.warning(f"S {subject_id}, ROI {roi_idx}: Signal too short ({n_samples} pts) for filtfilt. Skipping filter.")
                output[:, roi_idx] = column
                continue
            output[:, roi_idx] = filtfilt(b, a, column)
        return output
    except Exception as e:
        logger.error(f"S {subject_id}: Error during bandpass filtering: {e}. Returning original.", exc_info=False)
        return sigs

def _hrf_deconvolution(sigs: np.ndarray, tr: float, hrf_model_type: str, subject_id: str) -> np.ndarray:
    """Deconvolve a canonical HRF from every column of a signal matrix.

    The HRF kernel is sampled at the repetition time (``oversampling=1``) and
    removed with ``scipy.signal.deconvolve``. The quotient is zero-padded at
    the end (or truncated) back to the original number of time points.
    Columns that are shorter than the kernel, or for which deconvolution
    raises, are copied through unchanged.

    NOTE(review): ``scipy.signal.deconvolve`` is a polynomial division; if the
    sampled kernel starts with a zero sample the call may fail for every ROI,
    in which case this function silently returns the original signals.
    Verify with the nilearn HRF helpers before enabling deconvolution.

    Args:
        sigs: Signal matrix, one column per ROI.
        tr: Repetition time in seconds.
        hrf_model_type: ``'glover'`` or ``'spm'``.
        subject_id: Subject identifier, used only in log messages.

    Returns:
        The deconvolved matrix, or ``sigs`` itself when the model name is
        unknown or the kernel is empty / all zeros.
    """
    logger.info(f"S {subject_id}: Attempting HRF deconvolution ({hrf_model_type} model, TR={tr}s).")

    kernel_builders = {'glover': glover_hrf, 'spm': spm_hrf}
    build_kernel = kernel_builders.get(hrf_model_type)
    if build_kernel is None:
        logger.error(f"S {subject_id}: Unknown HRF model '{hrf_model_type}'. Skipping deconv.")
        return sigs

    hrf_kernel = build_kernel(tr, oversampling=1)
    if len(hrf_kernel) == 0 or np.all(np.isclose(hrf_kernel, 0)):
        logger.error(f"S {subject_id}: HRF kernel empty/zero for {hrf_model_type}. Skipping deconv.")
        return sigs

    n_timepoints = sigs.shape[0]
    result = np.zeros_like(sigs)
    for roi_idx in range(sigs.shape[1]):
        roi_ts = sigs[:, roi_idx]
        if len(roi_ts) < len(hrf_kernel):
            logger.warning(f"S {subject_id}, ROI {roi_idx}: Signal shorter than HRF. Skipping deconv.")
            result[:, roi_idx] = roi_ts
            continue
        try:
            quotient, _ = deconvolve(roi_ts, hrf_kernel)
            # Fit the quotient back into the original length: keep its head,
            # leave the tail (if any) at zero.
            n_valid = min(len(quotient), n_timepoints)
            restored = np.zeros(n_timepoints)
            restored[:n_valid] = quotient[:n_valid]
            result[:, roi_idx] = restored
        except Exception as e_deconv:
            logger.error(f"S {subject_id}, ROI {roi_idx}: Deconvolution failed: {e_deconv}. Using original.", exc_info=False)
            result[:, roi_idx] = roi_ts

    logger.info(f"S {subject_id}: HRF deconvolution attempt finished.")
    return result

def _preprocess_time_series(
    sigs: np.ndarray, target_len_ts_val: int, n_rois_expected_val: int,
    subject_id: str, lasso_var_max_lag_val: int,
    tr_seconds_val: float, low_cut_val: float, high_cut_val: float, filter_order_val: int,
    apply_hrf_deconv_val: bool, hrf_model_type_val: str
) -> Optional[np.ndarray]:
    """Filter, optionally deconvolve, z-score and length-homogenize one subject.

    Steps, in order: band-pass filtering, optional HRF deconvolution, a
    minimum-length check, per-column standardization and finally truncation
    or zero-padding along the time axis to ``target_len_ts_val`` rows.

    Args:
        sigs: Signal matrix in ``[timepoints, ROIs]`` orientation.
        target_len_ts_val: Number of time points of the returned matrix.
        n_rois_expected_val: Expected ROI count (kept for interface
            compatibility; padding uses the actual column count).
        subject_id: Subject identifier, used only in log messages.
        lasso_var_max_lag_val: VAR order; contributes ``lag + 10`` to the
            minimum required length.
        tr_seconds_val: Repetition time in seconds (sampling rate is ``1 / TR``).
        low_cut_val: Band-pass lower cut-off in Hz.
        high_cut_val: Band-pass upper cut-off in Hz.
        filter_order_val: Butterworth filter order.
        apply_hrf_deconv_val: Whether to run HRF deconvolution.
        hrf_model_type_val: HRF model name passed to the deconvolution helper.

    Returns:
        A ``float32`` matrix with ``target_len_ts_val`` rows, or ``None`` when
        the series is shorter than the minimum length required downstream.
    """
    original_length = sigs.shape[0]
    fs = 1.0 / tr_seconds_val
    logger.info(f"S {subject_id}: Preprocessing. Original TPs: {original_length}, TR: {tr_seconds_val}s, Target TPs: {target_len_ts_val}.")
    sigs_processed = _bandpass_filter_signals(sigs, low_cut_val, high_cut_val, fs, filter_order_val, subject_id)
    if apply_hrf_deconv_val:
        sigs_processed = _hrf_deconvolution(sigs_processed, tr_seconds_val, hrf_model_type_val, subject_id)
        if np.isnan(sigs_processed).any() or np.isinf(sigs_processed).any():
            logger.warning(f"S {subject_id}: NaNs/Infs after HRF deconvolution. Cleaning.")
            sigs_processed = np.nan_to_num(sigs_processed, nan=0.0, posinf=0.0, neginf=0.0)

    # The series must be long enough for every downstream modality.
    min_len_for_var = lasso_var_max_lag_val + 10
    min_len_for_dfc = DFC_WIN_POINTS if DFC_WIN_POINTS > 0 else 5
    min_overall_len = max(5, min_len_for_var, min_len_for_dfc)
    if sigs_processed.shape[0] < min_overall_len:
        logger.warning(f"S {subject_id}: TPs after processing ({sigs_processed.shape[0]}) < min required ({min_overall_len}). Skipping.")
        return None

    if np.isnan(sigs_processed).any():
        logger.warning(f"S {subject_id}: NaNs before scaling. Filling with 0.0.")
        sigs_processed = np.nan_to_num(sigs_processed, nan=0.0)

    try:
        sigs_normalized = StandardScaler().fit_transform(sigs_processed)
        if np.isnan(sigs_normalized).any():
            logger.warning(f"S {subject_id}: NaNs after StandardScaler. Filling with 0.0.")
            sigs_normalized = np.nan_to_num(sigs_normalized, nan=0.0, posinf=0.0, neginf=0.0)
    except ValueError:
        # Best-effort fallback: scale column by column so that one bad ROI
        # does not discard the whole subject.
        logger.warning(f"S {subject_id}: StandardScaler failed. Attempting column-wise scaling.")
        sigs_normalized = np.zeros_like(sigs_processed, dtype=np.float32)
        for i in range(sigs_processed.shape[1]):
            col_data = sigs_processed[:, i].reshape(-1, 1)
            if np.std(col_data) <= 1e-9:
                # (Near-)constant column: left at zero.
                continue
            try:
                sigs_normalized[:, i] = StandardScaler().fit_transform(col_data).flatten()
            except Exception as e_col:
                # Was a bare `except:`; keep the best-effort behavior but do not
                # swallow KeyboardInterrupt/SystemExit, and record the failure.
                logger.warning(f"S {subject_id}, ROI {i}: Column-wise scaling failed ({e_col}). Filling with 0.0.")
                sigs_normalized[:, i] = 0.0
        if np.isnan(sigs_normalized).any():
            sigs_normalized = np.nan_to_num(sigs_normalized, nan=0.0, posinf=0.0, neginf=0.0)

    current_length = sigs_normalized.shape[0]
    if current_length == target_len_ts_val:
        sigs_homogenized = sigs_normalized
    else:
        logger.info(f"S {subject_id}: Homogenizing length from {current_length} to {target_len_ts_val}.")
        if current_length < target_len_ts_val:
            # Pad with the data's own column count so vstack cannot fail on a
            # width mismatch. Zero rows equal the column mean after z-scoring.
            n_missing = target_len_ts_val - current_length
            padding = np.zeros((n_missing, sigs_normalized.shape[1]), dtype=np.float32)
            sigs_homogenized = np.vstack((sigs_normalized, padding))
        else:
            sigs_homogenized = sigs_normalized[:target_len_ts_val, :]
    return sigs_homogenized.astype(np.float32)

def load_and_preprocess_single_subject_series(
    subject_id: str, n_rois_expected_val: int, target_len_ts_val: int,
    current_roi_signals_dir_path: Path, current_roi_filename_template: str,
    possible_roi_keys_list: List[str], lasso_var_max_lag_val: int,
    tr_seconds_val: float, low_cut_val: float, high_cut_val: float, filter_order_val: int,
    apply_hrf_deconv_val: bool, hrf_model_type_val: str
) -> Tuple[Optional[np.ndarray], str, bool]:
    """Load one subject's ROI signals from disk and run the preprocessing chain.

    The file path is built from the directory and the file-name template
    (formatted with ``subject_id``). The signals are read, cast to float64,
    oriented/validated and preprocessed.

    Returns:
        A ``(series, message, success)`` triple. On success ``series`` is the
        preprocessed matrix and ``message`` summarises the original number of
        time points and the final shape; on any failure ``series`` is ``None``,
        ``success`` is ``False`` and ``message`` explains why. Exceptions are
        logged and reported through the message instead of being raised.
    """
    file_name = current_roi_filename_template.format(subject_id=subject_id)
    mat_path = current_roi_signals_dir_path / file_name
    if not mat_path.exists():
        return None, f"MAT file not found: {mat_path}", False

    try:
        raw_signals = _load_signals_from_mat(mat_path, possible_roi_keys_list)
        if raw_signals is None:
            return None, f"No valid signal keys or load error in {mat_path}", False

        raw_signals = np.asarray(raw_signals, dtype=np.float64)
        oriented = _orient_and_validate_signals(raw_signals, n_rois_expected_val, subject_id)
        # Drop the raw array as early as possible to limit peak memory.
        del raw_signals
        gc.collect()
        if oriented is None:
            return None, f"Incorrect shape or validation failed for S {subject_id}.", False

        original_tp_count = oriented.shape[0]
        preprocessed = _preprocess_time_series(
            oriented, target_len_ts_val, n_rois_expected_val, subject_id, lasso_var_max_lag_val,
            tr_seconds_val, low_cut_val, high_cut_val, filter_order_val,
            apply_hrf_deconv_val, hrf_model_type_val
        )
        del oriented
        gc.collect()
        if preprocessed is None:
            return None, f"Preprocessing failed for S {subject_id}. Original TPs: {original_tp_count}", False

        n_rows, n_cols = preprocessed.shape[0], preprocessed.shape[1]
        final_shape_str = f"({n_rows}, {n_cols})"
        success_msg = f"OK. Original TPs: {original_tp_count}, final shape: {final_shape_str}"
        return preprocessed, success_msg, True
    except Exception as e:
        logger.error(f"Unhandled exception during load_and_preprocess for S {subject_id} ({mat_path}): {e}", exc_info=True)
        return None, f"Exception processing {mat_path}: {str(e)}", False

# --- 3. Connectivity Calculation Functions ---

def fisher_r_to_z(r_matrix: np.ndarray, eps: float = 1e-7) -> np.ndarray:
    """Apply the Fisher r-to-z transform (``arctanh``) to a correlation matrix.

    NaN entries are treated as zero correlation, values are clipped to
    ``[-1 + eps, 1 - eps]`` so the transform stays finite, and the diagonal of
    the result is set to zero. The input array is not modified.

    Args:
        r_matrix: Correlation matrix (at least 2-D).
        eps: Margin kept from -1 and +1 when clipping.

    Returns:
        The z-transformed matrix as ``float32``.
    """
    lower_bound = -1.0 + eps
    upper_bound = 1.0 - eps

    r_values = r_matrix.astype(np.float32)
    r_values = np.nan_to_num(r_values, nan=0.0)
    r_values = np.clip(r_values, lower_bound, upper_bound)

    z_values = np.arctanh(r_values)
    np.fill_diagonal(z_values, 0.0)
    return z_values.astype(np.float32)

def calculate_pearson_omst(ts_subject: np.ndarray, sid: str) -> Optional[np.ndarray]:
    """Build the Pearson -> Fisher-Z -> abs -> OMST weighted connectivity matrix.

    The column-wise Pearson correlation of ``ts_subject`` is Fisher-Z
    transformed, its absolute values are used as edge weights, and the
    weights are passed through the OMST filter with the ``'efficiency_cost'``
    cost function. The diagonal of the result is zero.

    NOTE(review): assumes ``orthogonal_minimum_spanning_tree`` returns a single
    2-D array -- confirm against the installed package.

    Args:
        ts_subject: Time series, rows are time points and columns are ROIs.
        sid: Subject identifier, used only in log messages.

    Returns:
        The filtered ``float32`` matrix; an all-zero matrix when every weight
        is (close to) zero or the correlation collapses to a scalar; ``None``
        when the OMST package is unavailable, fewer than two time points are
        given, or any error occurs (logged).
    """
    if orthogonal_minimum_spanning_tree is None:
        logger.error(f"Pearson_OMST (S {sid}): OMST package not available. Cannot calculate this modality.")
        return None

    n_timepoints = ts_subject.shape[0]
    if n_timepoints < 2:
        logger.warning(f"Pearson_OMST (S {sid}): Insufficient timepoints ({n_timepoints} < 2).")
        return None

    try:
        pearson_r = np.corrcoef(ts_subject, rowvar=False).astype(np.float32)
        if pearson_r.ndim == 0:
            logger.warning(f"Pearson_OMST (S {sid}): Correlation resulted in a scalar. Input shape: {ts_subject.shape}. Returning zero matrix.")
            n_cols = ts_subject.shape[1]
            if n_cols > 0:
                return np.zeros((n_cols, n_cols), dtype=np.float32)
            return None

        np.fill_diagonal(pearson_r, 0.0)
        fisher_z = fisher_r_to_z(pearson_r)
        omst_input = np.abs(fisher_z)
        np.fill_diagonal(omst_input, 0.0)

        if np.all(np.isclose(omst_input, 0)):
            # Nothing to filter: an empty graph is returned as is.
            logger.warning(f"Pearson_OMST (S {sid}): All weights for OMST are zero after abs(Fisher-Z). Returning zero matrix.")
            return omst_input.astype(np.float32)

        filtered = orthogonal_minimum_spanning_tree(omst_input, cost_function='efficiency_cost')
        np.fill_diagonal(filtered, 0.0) # Ensure diagonal is zero after OMST
        density = np.count_nonzero(filtered) / filtered.size
        logger.info(f"Pearson_OMST (S {sid}): Successfully calculated. Matrix density: {density:.4f}")
        return filtered.astype(np.float32)
    except Exception as e:
        logger.error(f"Error calculating Pearson_OMST connectivity for S {sid}: {e}", exc_info=True)
        return None

def calculate_mi_knn_connectivity(ts_subject: np.ndarray, n_neighbors_val: int, sid: str) -> Optional[np.ndarray]:
    """Estimate pairwise mutual information between ROI time series.

    For every ROI pair ``(i, j)`` with ``i < j`` the k-NN estimator of
    ``mutual_info_regression`` is evaluated once (column ``i`` as feature,
    column ``j`` as target) and the value is mirrored to ``(j, i)``. A pair
    whose estimation raises is logged and set to 0. The diagonal is zero.

    Args:
        ts_subject: Time series, rows are time points and columns are ROIs.
        n_neighbors_val: Number of neighbours for the estimator.
        sid: Subject identifier, used only in log messages.

    Returns:
        A symmetric ``float32`` matrix, or ``None`` when there are no time
        points or not more time points than neighbours.
    """
    n_tp, n_rois = ts_subject.shape
    if n_tp == 0:
        logger.warning(f"MI_KNN (S {sid}): 0 TPs.")
        return None
    if n_tp <= n_neighbors_val:
        logger.warning(f"MI_KNN (S {sid}): TPs ({n_tp}) <= n_neighbors ({n_neighbors_val}). Skipping.")
        return None

    mi_matrix = np.zeros((n_rois, n_rois), dtype=np.float32)
    # Strict upper triangle, visited row by row.
    upper_rows, upper_cols = np.triu_indices(n_rois, k=1)
    for i, j in zip(upper_rows.tolist(), upper_cols.tolist()):
        try:
            feature_column = ts_subject[:, i].reshape(-1, 1)
            target_column = ts_subject[:, j]
            estimates = mutual_info_regression(
                feature_column, target_column,
                n_neighbors=n_neighbors_val, random_state=42, discrete_features=False
            )
            pair_mi = estimates[0]
        except Exception as e:
            logger.error(f"MI_KNN (S {sid}), pair ({i},{j}): {e}", exc_info=False)
            pair_mi = 0.0
        mi_matrix[i, j] = pair_mi
        mi_matrix[j, i] = pair_mi
    return mi_matrix

def calculate_custom_dfc_abs_diff_mean(ts_subject: np.ndarray, win_points_val: int, step_val: int, sid: str) -> Optional[np.ndarray]:
    """Summarise dynamic FC as the mean absolute change between sliding windows.

    The series is cut into windows of ``win_points_val`` time points that
    start every ``step_val`` points. For each window the absolute Pearson
    correlation matrix is computed (NaNs become 0, diagonal zeroed). The
    result is the element-wise mean of the absolute differences between
    consecutively computed window matrices.

    Args:
        ts_subject: Time series, rows are time points and columns are ROIs.
        win_points_val: Window length in time points (must be >= 1).
        step_val: Window stride in time points (must be >= 1).
        sid: Subject identifier, used only in log messages.

    Returns:
        A ``float32`` matrix with a zero diagonal, or ``None`` when the
        window/step are not positive, fewer than two windows fit into the
        series, or no difference could be computed.
    """
    n_tp, n_rois = ts_subject.shape
    # A zero step would divide by zero below; a non-positive window is meaningless.
    if win_points_val < 1 or step_val < 1:
        logger.warning(f"dFC (S {sid}): Non-positive window ({win_points_val}) or step ({step_val}). Skipping.")
        return None
    if n_tp < win_points_val:
        logger.warning(f"dFC (S {sid}): TPs < window. Skipping.")
        return None
    num_windows = (n_tp - win_points_val) // step_val + 1
    if num_windows < 2:
        logger.warning(f"dFC (S {sid}): <2 windows. Skipping.")
        return None

    sum_abs_diff_matrix = np.zeros((n_rois, n_rois), dtype=np.float64)
    n_diffs_calculated = 0
    prev_corr_matrix_abs: Optional[np.ndarray] = None
    for idx in range(num_windows):
        # Two separate statements: in a single tuple assignment the right-hand
        # side is evaluated first, so `start_idx` would be read before being bound.
        start_idx = idx * step_val
        end_idx = start_idx + win_points_val
        window_ts = ts_subject[start_idx:end_idx, :]
        if window_ts.shape[0] < 2:
            continue
        try:
            corr_matrix_window = np.corrcoef(window_ts, rowvar=False)
            if corr_matrix_window.ndim < 2 or corr_matrix_window.shape != (n_rois, n_rois):
                # Degenerate result (e.g. a single ROI): treat as no correlation.
                corr_matrix_window = np.full((n_rois, n_rois), 0.0, dtype=np.float32)
            else:
                corr_matrix_window = np.nan_to_num(corr_matrix_window.astype(np.float32), nan=0.0)
            current_corr_matrix_abs = np.abs(corr_matrix_window)
            np.fill_diagonal(current_corr_matrix_abs, 0)
            if prev_corr_matrix_abs is not None:
                sum_abs_diff_matrix += np.abs(current_corr_matrix_abs - prev_corr_matrix_abs)
                n_diffs_calculated += 1
            prev_corr_matrix_abs = current_corr_matrix_abs
        except Exception as e:
            logger.error(f"dFC (S {sid}), Window {idx}: Error: {e}")

    if n_diffs_calculated == 0:
        logger.warning(f"dFC (S {sid}): No valid diffs. Returning None.")
        return None
    mean_abs_diff_matrix = (sum_abs_diff_matrix / n_diffs_calculated).astype(np.float32)
    np.fill_diagonal(mean_abs_diff_matrix, 0)
    return mean_abs_diff_matrix

def calculate_lasso_var_connectivity_matrix(ts: np.ndarray, max_lag_val: int, sid: str) -> Optional[np.ndarray]:
    """Estimate a directed influence matrix from a Lasso-regularised VAR model.

    Every ROI at time ``t`` is regressed on all ROIs at lags ``1..max_lag_val``
    with a cross-validated multi-task Lasso fitted on standardised data.
    Entry ``[target, source]`` of the result is the sum over lags of the
    absolute coefficient of ``source`` in the equation of ``target``; the
    diagonal is set to zero.

    NOTE(review): the model is fitted with ``n_jobs=-1``. If this function is
    called from several worker processes at once, this may oversubscribe the
    CPU -- confirm against the caller.

    Args:
        ts: Time series, rows are time points and columns are ROIs.
        max_lag_val: VAR order (must be >= 1).
        sid: Subject identifier, used only in log messages.

    Returns:
        A ``float32`` matrix of shape ``(n_rois, n_rois)``, or ``None`` when
        the lag is invalid, there are not more time points than lags, or
        scaling / model fitting fails (logged).
    """
    n_tp, n_rois = ts.shape
    p = max_lag_val
    # With p < 1 there are no lagged blocks and np.hstack would raise.
    if p < 1:
        logger.warning(f"LassoVAR (S {sid}): Lag order must be >= 1 (got {p}). Skipping.")
        return None
    if n_tp <= p:
        logger.warning(f"LassoVAR (S {sid}): TPs <= lag. Skipping.")
        return None
    if n_tp < p * n_rois / 5:
        logger.warning(f"LassoVAR (S {sid}): TPs very low for ROIs*lag. Model might be unstable.")
    elif n_tp < p * n_rois:
        logger.info(f"LassoVAR (S {sid}): TPs < ROIs*lag. Relying on Lasso.")

    # Targets are the samples from index p on; the predictor block for lag k
    # holds the same rows shifted back by k. Blocks are stacked lag by lag,
    # so predictor column `source + (k - 1) * n_rois` is ROI `source` at lag k.
    Y_data = ts[p:, :].astype(np.float32)
    lagged_blocks = [ts[p - lag_idx: n_tp - lag_idx, :] for lag_idx in range(1, p + 1)]
    X_data = np.hstack(lagged_blocks).astype(np.float32)

    try:
        X_scaled = StandardScaler().fit_transform(X_data)
        Y_scaled = StandardScaler().fit_transform(Y_data)
    except ValueError as e:
        logger.error(f"LassoVAR (S {sid}): Scaling failed - {e}. Skipping.")
        return None

    try:
        model = MultiTaskLassoCV(cv=5, n_jobs=-1, max_iter=LASSO_MAX_ITER, tol=LASSO_TOL, random_state=42, selection='random', verbose=False)
        model.fit(X_scaled, Y_scaled)
    except Exception as e:
        logger.error(f"LassoVAR (S {sid}): Model fitting failed – {e}", exc_info=False)
        return None

    # coef_ is (n_targets, n_features) = (n_rois, p * n_rois). Viewing it as
    # (target, lag, source) and summing |coef| over the lag axis is the same
    # as the per-pair loop over lags, without the Python-level double loop.
    abs_coeffs = np.abs(np.asarray(model.coef_))
    influence_matrix = abs_coeffs.reshape(n_rois, p, n_rois).sum(axis=1).astype(np.float32)
    np.fill_diagonal(influence_matrix, 0.0)
    return influence_matrix

# --- 4. Function to Calculate All Connectivity Modalities ---
def calculate_all_connectivity_modalities_for_subject(
    subject_id: str, subject_ts_data: np.ndarray,
    n_neighbors_mi_param: int,
    dfc_win_points_param: int, dfc_step_param: int,
    lasso_lag_param: int
) -> Dict[str, Any]:
    """Compute every configured connectivity channel for one subject.

    Channels are processed in the order of ``CONNECTIVITY_CHANNEL_NAMES``.
    A failing channel (exception or ``None`` result) never aborts the others:
    its matrix is stored as ``None`` and the reason is recorded.

    Args:
        subject_id: Subject identifier, used in log messages.
        subject_ts_data: Preprocessed time series of the subject.
        n_neighbors_mi_param: Neighbour count for the MI channel.
        dfc_win_points_param: Window length for the dFC channel.
        dfc_step_param: Window stride for the dFC channel.
        lasso_lag_param: VAR order for the Lasso-VAR channel.

    Returns:
        A dict with three entries: ``"matrices"`` (channel name -> matrix or
        ``None``), ``"errors_conn_calc"`` (channel name -> error text, failed
        channels only) and ``"timings_conn_calc"``
        (``"<channel>_time_sec"`` -> elapsed seconds).
    """

    def _run_channel(name: str) -> Tuple[Optional[np.ndarray], Optional[str]]:
        """Dispatch one channel; returns (matrix, error text for unknown names)."""
        if name == PEARSON_OMST_CHANNEL_NAME:
            return calculate_pearson_omst(subject_ts_data, subject_id), None
        if name == "MI_KNN":
            return calculate_mi_knn_connectivity(subject_ts_data, n_neighbors_mi_param, subject_id), None
        if name == "dFC_AbsDiffMean":
            return calculate_custom_dfc_abs_diff_mean(subject_ts_data, dfc_win_points_param, dfc_step_param, subject_id), None
        if name.startswith("Lasso_VAR"):
            return calculate_lasso_var_connectivity_matrix(subject_ts_data, lasso_lag_param, subject_id), None
        unknown_msg = f"Unknown connectivity channel: {name}"
        logger.error(unknown_msg)
        return None, unknown_msg

    matrices: Dict[str, Optional[np.ndarray]] = dict.fromkeys(CONNECTIVITY_CHANNEL_NAMES)
    errors_in_calculation: Dict[str, str] = {}
    timings: Dict[str, float] = {}

    for channel_name in CONNECTIVITY_CHANNEL_NAMES:
        logger.info(f"Calculating {channel_name} for S {subject_id}...")
        started_at = time.time()

        try:
            matrix_result, error_msg = _run_channel(channel_name)
            if matrix_result is None and error_msg is None: # Function returned None without raising
                error_msg = f"{channel_name} calculation returned None."
        except Exception as e:
            matrix_result, error_msg = None, str(e)
            # Full traceback only for exception types other than ValueError/TypeError.
            logger.error(f"Error calculating {channel_name} for S {subject_id}: {e}", exc_info=not isinstance(e, (ValueError, TypeError)))

        matrices[channel_name] = matrix_result
        if error_msg:
            errors_in_calculation[channel_name] = error_msg

        elapsed_sec = time.time() - started_at
        timings[f"{channel_name}_time_sec"] = elapsed_sec
        if matrix_result is None:
            logger.warning(f"{channel_name} for S {subject_id} failed or returned None. Took {elapsed_sec:.2f}s. Error: {errors_in_calculation.get(channel_name, 'N/A')}")
        else:
            logger.info(f"{channel_name} for S {subject_id} took {elapsed_sec:.2f}s.")

    num_modalities_expected = len(CONNECTIVITY_CHANNEL_NAMES)
    num_successful = sum(matrices.get(name) is not None for name in CONNECTIVITY_CHANNEL_NAMES)
    if num_successful >= num_modalities_expected:
        logger.info(f"Connectivity for S {subject_id}: All {num_successful}/{num_modalities_expected} modalities computed successfully.")
    else:
        logger.warning(f"Connectivity for S {subject_id}: {num_successful}/{num_modalities_expected} modalities computed. Errors: {errors_in_calculation}")

    return {"matrices": matrices, "errors_conn_calc": errors_in_calculation, "timings_conn_calc": timings}


# --- 5. Per-Subject Processing Pipeline (for Multiprocessing) ---
def process_single_subject_pipeline(subject_row_tuple: Tuple[int, pd.Series]) -> Dict[str, Any]:
    """Preprocess one subject, compute all connectivity channels and save the tensor.

    Intended as the unit of work submitted to the process pool in ``main``.

    Steps:
      1. Load and preprocess the subject's ROI time series.
      2. Compute every channel listed in ``CONNECTIVITY_CHANNEL_NAMES``.
      3. Check that each channel matrix exists and has shape
         ``(N_ROIS_EXPECTED, N_ROIS_EXPECTED)``; validation stops at the first
         invalid channel.
      4. Stack the matrices into a float32 array (channels on axis 0) and write
         it as a compressed ``.npz`` under ``individual_subject_tensors``.

    Args:
        subject_row_tuple: ``(index, row)`` pair as produced by
            ``DataFrame.iterrows()``; the row must contain ``'SubjectID'``.

    Returns:
        A status dictionary with the subject id, preprocessing status/detail,
        connectivity status, per-channel errors and timings, the saved tensor
        path (``None`` when nothing was saved), the overall status, and the
        process RSS in MB at the start and end of the call.

    Note:
        Exceptions raised by preprocessing or connectivity computation are not
        caught here; they propagate to the caller after the ``finally`` clause
        has recorded the final RAM usage. Only tensor-saving errors are
        converted into a failure status.
    """
    _, subject_row = subject_row_tuple
    subject_id = str(subject_row['SubjectID']).strip()
    process = psutil.Process(os.getpid())
    ram_initial_mb = process.memory_info().rss / (1024**2)
    result: Dict[str, Any] = {
        "id": subject_id, "status_preprocessing": "PENDING", "detail_preprocessing": "",
        "status_connectivity_calc": "NOT_ATTEMPTED", "errors_connectivity_calc": {},
        "timings_connectivity_calc_sec": {}, "path_saved_tensor": None,
        "status_overall": "PENDING",
        "ram_usage_mb_initial": ram_initial_mb, "ram_usage_mb_final": -1.0
    }

    def _record_validation_error(channel: str, message: str) -> None:
        """Log a validation failure and append it to the channel's error entry."""
        logger.error(f"S {subject_id}: {message}")
        errors = result["errors_connectivity_calc"]
        previous = errors.get(channel, "")
        # Only insert the separator when there is an earlier message to separate
        # from; otherwise the entry would start with a dangling " | ".
        errors[channel] = f"{previous} | {message}" if previous else message

    series_data: Optional[np.ndarray] = None
    try:
        series_data, detail_msg_preproc, success_preproc = load_and_preprocess_single_subject_series(
            subject_id, N_ROIS_EXPECTED, TARGET_LEN_TS,
            ROI_SIGNALS_DIR_PATH_AAL3, ROI_FILENAME_TEMPLATE, POSSIBLE_ROI_KEYS,
            LASSO_VAR_MAX_LAG, TR_SECONDS, LOW_CUT_HZ, HIGH_CUT_HZ, FILTER_ORDER,
            APPLY_HRF_DECONVOLUTION, HRF_MODEL
        )
        result["status_preprocessing"] = "SUCCESS" if success_preproc else "FAILED"
        result["detail_preprocessing"] = detail_msg_preproc
        if not success_preproc or series_data is None:
            result["status_overall"] = "PREPROCESSING_FAILED"
            return result

        connectivity_results = calculate_all_connectivity_modalities_for_subject(
            subject_id, series_data, N_NEIGHBORS_MI,
            DFC_WIN_POINTS, DFC_STEP, LASSO_VAR_MAX_LAG
        )
        # The time series is no longer needed once the matrices exist; release it
        # early to keep the worker's memory footprint down.
        del series_data
        series_data = None
        gc.collect()
        calculated_matrices_dict = connectivity_results["matrices"]
        result["errors_connectivity_calc"] = connectivity_results["errors_conn_calc"]
        result["timings_connectivity_calc_sec"] = connectivity_results.get("timings_conn_calc", {})

        all_modalities_valid = True
        final_matrices_to_stack_list = []
        expected_matrix_shape = (N_ROIS_EXPECTED, N_ROIS_EXPECTED)
        for channel_name in CONNECTIVITY_CHANNEL_NAMES:  # Use the global, dynamically built list
            matrix = calculated_matrices_dict.get(channel_name)
            if matrix is None:
                validation_error = f"Modality '{channel_name}' None."
            elif matrix.shape != expected_matrix_shape:
                validation_error = f"Modality '{channel_name}' shape {matrix.shape} != {expected_matrix_shape}."
            else:
                final_matrices_to_stack_list.append(matrix)
                continue
            # First invalid channel aborts validation: an incomplete tensor is never saved.
            all_modalities_valid = False
            _record_validation_error(channel_name, validation_error)
            break

        if all_modalities_valid and len(final_matrices_to_stack_list) == N_CHANNELS:  # N_CHANNELS is global
            result["status_connectivity_calc"] = "SUCCESS_ALL_MODALITIES_VALID"
            try:
                subject_tensor = np.stack(final_matrices_to_stack_list, axis=0).astype(np.float32)
                del final_matrices_to_stack_list, calculated_matrices_dict
                gc.collect()
                output_dir_individual_tensors = BASE_PATH_AAL3 / OUTPUT_CONNECTIVITY_DIR_NAME / "individual_subject_tensors"
                output_dir_individual_tensors.mkdir(parents=True, exist_ok=True)
                output_path = output_dir_individual_tensors / f"tensor_{N_CHANNELS}ch_{subject_id}.npz"
                np.savez_compressed(output_path, tensor_data=subject_tensor, subject_id=subject_id, channel_names=np.array(CONNECTIVITY_CHANNEL_NAMES, dtype=str))
                result["path_saved_tensor"] = str(output_path)
                result["status_overall"] = "SUCCESS_ALL_PROCESSED_AND_SAVED"
                del subject_tensor
                gc.collect()
            except Exception as e_save:
                logger.error(f"Error saving tensor S {subject_id}: {e_save}", exc_info=True)
                result["errors_connectivity_calc"]["save_error"] = str(e_save)
                result["status_overall"] = "FAILURE_DURING_TENSOR_SAVING"
                result["status_connectivity_calc"] = "FAILURE_DURING_SAVING"
        else:
            result["status_overall"] = "FAILURE_IN_CONNECTIVITY_CALC_OR_VALIDATION"
            result["status_connectivity_calc"] = "FAILURE_MISSING_OR_INVALID_MODALITIES"
            if not all_modalities_valid:
                logger.error(f"S {subject_id}: Not all modalities valid. Tensor not saved. Errors: {result['errors_connectivity_calc']}")
    finally:
        # Runs on every exit path (early return, normal completion, exception).
        if series_data is not None:
            del series_data
            gc.collect()
        result["ram_usage_mb_final"] = process.memory_info().rss / (1024**2)
    return result

# --- 6. Main Script Execution Flow ---
def main():
    """Run the full connectivity pipeline for every QC-approved subject.

    Workflow:
      1. Verify the base and ROI-signal directories exist and load the subject
         metadata via ``load_metadata``; abort (return) if either step fails.
      2. Create the output directories.
      3. Process subjects in parallel with a ``ProcessPoolExecutor`` using half
         of the reported CPU cores (at least one worker). A worker that raises
         is recorded with status ``CRITICAL_WORKER_EXCEPTION`` instead of
         stopping the run.
      4. Write the per-subject results to a CSV processing log.
      5. Reload each successfully saved per-subject tensor, keep those with
         shape ``(N_CHANNELS, N_ROIS_EXPECTED, N_ROIS_EXPECTED)``, stack them
         along a new first axis and save the result as one compressed ``.npz``
         together with the subject ids and channel names.

    All configuration comes from module-level names (paths, ``N_CHANNELS``,
    ``LASSO_VAR_MAX_LAG``, ``deconv_str``, ...), which are defined outside this
    function. Returns ``None``; outcomes are reported through the logger and
    the files written.
    """
    script_start_time = time.time()
    main_process_info = psutil.Process(os.getpid())
    logger.info(f"Main process RAM at start: {main_process_info.memory_info().rss / (1024**2):.2f} MB")
    logger.info(f"--- Starting fMRI Connectivity Pipeline (Luppi et al. 2024 informed) ---")
    logger.info(f"--- Output Directory: {OUTPUT_CONNECTIVITY_DIR_NAME} ---")
    logger.info(f"--- N_ROIS_EXPECTED (final for connectivity): {N_ROIS_EXPECTED} ---")
    logger.info(f"--- TARGET_LEN_TS (homogenized time series length): {TARGET_LEN_TS} ---")

    if not BASE_PATH_AAL3.exists() or not ROI_SIGNALS_DIR_PATH_AAL3.exists():
        logger.critical("CRITICAL: Base AAL3 path or ROI signals directory not found. Aborting.")
        return

    subject_metadata_df = load_metadata(SUBJECT_METADATA_CSV_PATH, QC_REPORT_CSV_PATH)
    if subject_metadata_df is None or subject_metadata_df.empty:
        logger.critical("Metadata loading failed or no subjects passed QC. Aborting.")
        return

    output_main_directory = BASE_PATH_AAL3 / OUTPUT_CONNECTIVITY_DIR_NAME
    output_individual_tensors_dir = output_main_directory / "individual_subject_tensors"
    try:
        output_main_directory.mkdir(parents=True, exist_ok=True)
        output_individual_tensors_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Main output directory: {output_main_directory}")
    except OSError as e:
        logger.critical(f"Could not create output directories: {e}. Aborting.")
        return

    # Use half the cores: each worker holds its own time series and matrices.
    total_cpu_cores = multiprocessing.cpu_count()
    max_workers = max(1, total_cpu_cores // 2)
    logger.info(f"Total CPU cores: {total_cpu_cores}. Using MAX_WORKERS = {max_workers}.")
    available_ram_gb = psutil.virtual_memory().available / (1024**3)
    logger.warning(f"Available system RAM: {available_ram_gb:.2f} GB. Monitor usage.")

    subject_rows_to_process = list(subject_metadata_df.iterrows())
    num_subjects_to_process = len(subject_rows_to_process)
    if num_subjects_to_process == 0:
        logger.critical("No subjects to process. Aborting.")
        return
    logger.info(f"Starting parallel processing for {num_subjects_to_process} subjects.")

    all_subject_results_list = []
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        # Map each future back to its subject id so failures can be attributed.
        future_to_subject_id_map = {
            executor.submit(process_single_subject_pipeline, subject_tuple): str(subject_tuple[1]['SubjectID']).strip()
            for subject_tuple in subject_rows_to_process
        }
        for future in tqdm(as_completed(future_to_subject_id_map), total=num_subjects_to_process, desc="Processing Subjects"):
            subject_id_for_log = future_to_subject_id_map[future]
            try:
                subject_result = future.result()
                all_subject_results_list.append(subject_result)
            except Exception as exc:
                # Keep going: one crashed worker must not lose the other subjects' results.
                logger.critical(f"CRITICAL WORKER EXCEPTION S {subject_id_for_log}: {exc}", exc_info=True)
                all_subject_results_list.append({
                    "id": subject_id_for_log, "status_overall": "CRITICAL_WORKER_EXCEPTION",
                    "detail_preprocessing": str(exc), "errors_connectivity_calc": {"worker_exception": str(exc)}
                })

    processing_log_df = pd.DataFrame(all_subject_results_list)
    log_file_path = output_main_directory / f"pipeline_log_{output_main_directory.name}.csv"
    try:
        processing_log_df.to_csv(log_file_path, index=False)
        logger.info(f"Processing log saved: {log_file_path}")
    except Exception as e_log_save:
        logger.error(f"Failed to save processing log: {e_log_save}")

    # A subject counts as successful only if its tensor file is actually on disk.
    successful_subject_entries_list = [
        res for res in all_subject_results_list
        if res.get("status_overall") == "SUCCESS_ALL_PROCESSED_AND_SAVED" and \
           res.get("path_saved_tensor") and Path(res["path_saved_tensor"]).exists()
    ]
    num_successful_subjects = len(successful_subject_entries_list)
    logger.info(f"--- Processing Summary ---")
    logger.info(f"Total subjects attempted: {num_subjects_to_process}")
    logger.info(f"Successfully processed and individual tensors saved: {num_successful_subjects}")
    if num_successful_subjects < num_subjects_to_process:
        logger.warning(f"{num_subjects_to_process - num_successful_subjects} subjects failed. Check log.")

    if num_successful_subjects > 0:
        logger.info(f"Assembling global tensor for {num_successful_subjects} subjects.")
        global_conn_tensor_list = []
        final_subject_ids_in_tensor = []
        try:
            for s_entry in tqdm(successful_subject_entries_list, desc="Assembling Global Tensor"):
                s_id, tensor_path_str = s_entry["id"], s_entry["path_saved_tensor"]
                try:
                    with np.load(tensor_path_str) as loaded_npz:
                        s_tensor_data = loaded_npz['tensor_data']
                        if s_tensor_data.shape == (N_CHANNELS, N_ROIS_EXPECTED, N_ROIS_EXPECTED):
                            global_conn_tensor_list.append(s_tensor_data)
                            final_subject_ids_in_tensor.append(s_id)
                        else:
                            logger.error(f"Tensor S {s_id} shape mismatch. Skipping.")
                    del s_tensor_data
                    gc.collect()
                except Exception as e_load_ind_tensor:
                    logger.error(f"Error loading S {s_id} tensor: {e_load_ind_tensor}. Skipping.")

            if global_conn_tensor_list:
                # copy=False: when the stacked array is already float32, astype
                # returns it as-is instead of allocating a second full-size copy.
                global_conn_tensor = np.stack(global_conn_tensor_list, axis=0).astype(np.float32, copy=False)
                del global_conn_tensor_list
                gc.collect()
                global_tensor_fname = f"GLOBAL_TENSOR_AAL3_{N_ROIS_EXPECTED}_{len(final_subject_ids_in_tensor)}subs_{N_CHANNELS}ch_{LASSO_VAR_MAX_LAG}lag{deconv_str}.npz"
                global_tensor_path = output_main_directory / global_tensor_fname
                np.savez_compressed(global_tensor_path, global_tensor_data=global_conn_tensor, subject_ids=np.array(final_subject_ids_in_tensor, dtype=str), channel_names=np.array(CONNECTIVITY_CHANNEL_NAMES, dtype=str))
                logger.info(f"Global tensor saved: {global_tensor_path} (Shape: {global_conn_tensor.shape})")
                del global_conn_tensor
                gc.collect()
            else:
                logger.warning("No valid individual tensors for global assembly.")
        except MemoryError:
            logger.critical("MEMORY ERROR during global tensor assembly.")
        except Exception as e_global:
            logger.critical(f"Error during global tensor assembly: {e_global}", exc_info=True)

    total_time_min = (time.time() - script_start_time) / 60
    logger.info(f"--- fMRI Connectivity Pipeline Finished ---")
    logger.info(f"Total execution time: {total_time_min:.2f} minutes.")
    logger.info(f"Final main process RAM: {main_process_info.memory_info().rss / (1024**2):.2f} MB")
    logger.info(f"Outputs in: {output_main_directory}")
    logger.info("THESIS REMINDER: Document parameters, QC, subject selection, AAL3 ROI definitions.")

# Script entry point: only run the pipeline when executed directly, not on import
# (worker processes may re-import this module).
if __name__ == "__main__":
    # Per the stdlib multiprocessing docs, freeze_support() only matters for
    # frozen Windows executables and has no effect otherwise; it is called
    # before main() creates the process pool.
    multiprocessing.freeze_support() 
    main()


2025-05-29 20:40:46,293 - INFO - 3047307419.py:115 - Connectivity channels to be computed: ['Pearson_OMST_Weighted', 'dFC_AbsDiffMean', 'Lasso_VAR_Influence']
2025-05-29 20:40:46,294 - INFO - 3047307419.py:116 - Total number of channels: 3
2025-05-29 20:40:46,298 - INFO - 3047307419.py:624 - Main process RAM at start: 239.98 MB
2025-05-29 20:40:46,298 - INFO - 3047307419.py:625 - --- Starting fMRI Connectivity Pipeline (Luppi et al. 2024 informed) ---
2025-05-29 20:40:46,299 - INFO - 3047307419.py:626 - --- Output Directory: AAL3_131_fmri_tensor_LuppiOMST_1lag_NeuroEnhanced_v6.1 ---
2025-05-29 20:40:46,299 - INFO - 3047307419.py:627 - --- N_ROIS_EXPECTED (final for connectivity): 131 ---
2025-05-29 20:40:46,299 - INFO - 3047307419.py:628 - --- TARGET_LEN_TS (homogenized time series length): 140 ---
2025-05-29 20:40:46,300 - INFO - 3047307419.py:125 - --- Starting Subject Metadata Loading and QC Integration ---
2025-05-29 20:40:46,308 - INFO - 3047307419.py:136 - Loaded main metadata fr

ERROR: The 'omst' package (topological_filtering_networks) is not installed. Please install it: pip install git+https://github.com/stdimitr/topological_filtering_networks.git
