<a href="https://colab.research.google.com/github/JinYuTong03/Wids-2025/blob/main/Wids2025.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# First mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Save uploaded files to Google Drive
import shutil
import os

# Create directory to store data files
!mkdir -p /content/drive/MyDrive/wids_datathon

# Copy uploaded files to Google Drive
# Competition data files expected in the current working directory,
# keyed by a short role name.
input_files = {
    'train_cat': 'TRAIN_CATEGORICAL_METADATA_new.xlsx',
    'train_quant': 'TRAIN_QUANTITATIVE_METADATA_new.xlsx',
    'solutions': 'TRAINING_SOLUTIONS.xlsx',
    'test_cat': 'TEST_CATEGORICAL.xlsx',
    'test_quant': 'TEST_QUANTITATIVE_METADATA.xlsx',
    'train_connectome': 'TRAIN_FUNCTIONAL_CONNECTOME_MATRICES_new_36P_Pearson.csv',
    'test_connectome': 'TEST_FUNCTIONAL_CONNECTOME_MATRICES.csv',
}

# Destination folder on the mounted Drive (created by the mkdir step above).
drive_data_dir = '/content/drive/MyDrive/wids_datathon'

# Only the file names are needed for copying, so iterate over the values.
for filename in input_files.values():
    if os.path.exists(filename):
        shutil.copy(filename, os.path.join(drive_data_dir, filename))
        print(f"File {filename} has been saved to Google Drive")
    else:
        # Report skipped files so a missing upload is noticed here rather
        # than later, when loading the data fails.
        print(f"Warning: {filename} not found in the working directory, skipped")

In [None]:
# ====================================
# WiDS Datathon 2025 - ADHD and Gender Prediction with Multimodal Data
# Final Version (Simplified Connectome + No Extra Interaction Features + Hyperparameter Tuning + Single Best Model)
# ====================================

!pip install neuroCombat
!pip install neuroHarmonize
!pip install optuna
from neuroHarmonize import harmonizationLearn

# === 1. ML Setup and Library Imports ===
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns # Optional
import os
import shutil
import math
from time import time
import warnings
import optuna

# Preprocessing & Feature Engineering
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import SimpleImputer, IterativeImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder # Keep OneHotEncoder in case of future categorical features
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.feature_selection import VarianceThreshold
from scipy import stats
import networkx as nx
from joblib import Parallel, delayed
from sklearn.linear_model import LassoCV

# Modeling & Evaluation
from sklearn.model_selection import StratifiedKFold, RandomizedSearchCV, RepeatedStratifiedKFold
from sklearn.metrics import roc_auc_score, f1_score, precision_recall_curve, auc, confusion_matrix, classification_report
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
import xgboost as xgb
import lightgbm as lgb
from sklearn.base import clone
from scipy.stats import uniform, randint

warnings.filterwarnings('ignore')

SEED = 42

# ====================================
# 2. Data Loading
# ====================================
def _read_table(path, label, allow_excel=True):
    """Read one tabular file into a DataFrame, or return None on any failure.

    Args:
        path: File path. Paths ending in ``.xlsx`` are read with
            ``pd.read_excel`` when ``allow_excel`` is true; everything else
            is read with ``pd.read_csv``.
        label: Human-readable name used in the ``"<label> shape: ..."`` message.
        allow_excel: Set to False for files that are always read as CSV.

    Returns:
        The loaded DataFrame, or None if reading raised (the error is printed).
    """
    try:
        if allow_excel and path.endswith('.xlsx'):
            frame = pd.read_excel(path)
        else:
            frame = pd.read_csv(path)
        print(f"{label} shape: {frame.shape}")
        return frame
    except Exception as e:
        # Best-effort loading: report the problem and let the caller decide
        # what to do with a missing table.
        print(f"Error loading {path}: {e}")
        return None


def load_data(categorical_path, quantitative_path, solutions_path=None,
              test_categorical_path=None, test_quantitative_path=None,
              train_connectome_path=None, test_connectome_path=None):
    """Load training and test data.

    Every file is loaded independently; a file that is missing or fails to
    parse yields None in the corresponding slot and a printed message, never
    an exception.

    Args:
        categorical_path: Training categorical metadata (.xlsx or CSV).
        quantitative_path: Training quantitative metadata (.xlsx or CSV).
        solutions_path: Optional training targets (.xlsx or CSV).
        test_categorical_path: Optional test categorical metadata.
        test_quantitative_path: Optional test quantitative metadata.
        train_connectome_path: Optional training connectome CSV.
        test_connectome_path: Optional test connectome CSV.

    Returns:
        Tuple ``(train_categorical, train_quantitative, train_connectome,
        solutions, test_categorical, test_quantitative, test_connectome)``.
    """
    print("Loading data files...")
    train_connectome, solutions = None, None
    test_categorical, test_quantitative, test_connectome = None, None, None

    train_categorical = _read_table(categorical_path, "Training categorical data")
    train_quantitative = _read_table(quantitative_path, "Training quantitative data")

    if train_connectome_path and os.path.exists(train_connectome_path):
        train_connectome = _read_table(train_connectome_path, "Training connectome", allow_excel=False)
    else:
        print(f"Warning: Training connectome file not found: {train_connectome_path}")

    if solutions_path and os.path.exists(solutions_path):
        solutions = _read_table(solutions_path, "Target variable data")
    else:
        print(f"Warning: Training solutions file not found: {solutions_path}")

    # The test set is only loaded when BOTH metadata files exist; the test
    # connectome is looked for only in that case.
    test_metadata_available = (
        test_categorical_path and test_quantitative_path
        and os.path.exists(test_categorical_path)
        and os.path.exists(test_quantitative_path)
    )
    if test_metadata_available:
        test_categorical = _read_table(test_categorical_path, "Test categorical data")
        test_quantitative = _read_table(test_quantitative_path, "Test quantitative data")
        if test_connectome_path and os.path.exists(test_connectome_path):
            test_connectome = _read_table(test_connectome_path, "Test connectome", allow_excel=False)
        else:
            print(f"Warning: Test connectome file not found: {test_connectome_path}")
    else:
        print("Warning: Test file paths are incomplete or files do not exist.")

    return (train_categorical, train_quantitative, train_connectome, solutions,
            test_categorical, test_quantitative, test_connectome)


# ====================================
# 3. Data Preprocessing and Merging
# ====================================
def preprocess_and_merge_data(train_categorical, train_quantitative, solutions=None,
                              test_categorical=None, test_quantitative=None):
    """Join the metadata tables on ``participant_id``.

    The categorical and quantitative tables are inner-joined (overlapping
    column names get ``_cat`` / ``_quant`` suffixes); for the training set the
    targets in ``solutions`` are then inner-joined as well when provided.

    Returns:
        ``(train_merged, test_merged)``. ``(None, None)`` when training inputs
        are missing or the training merge fails; ``test_merged`` is None when
        test inputs are missing or the test merge fails.
    """
    print("\nPreprocessing and merging data...")

    if train_categorical is None or train_quantitative is None:
        print("Error: Training data is missing")
        return None, None

    join_options = {'on': 'participant_id', 'how': 'inner', 'suffixes': ('_cat', '_quant')}

    try:
        cat_frame = train_categorical.copy()
        quant_frame = train_quantitative.copy()

        # Columns present in both tables (other than the join key) will be suffixed.
        overlapping = list(set(cat_frame.columns) & (set(quant_frame.columns) - {'participant_id'}))
        if overlapping:
            print(f"Warning: Common columns: {overlapping}")

        train_merged = cat_frame.merge(quant_frame, **join_options)
        print(f"Merged training data shape: {train_merged.shape}")

        if solutions is None:
            print("Warning: No target variables provided")
        else:
            train_merged = train_merged.merge(solutions, on='participant_id', how='inner')
            print(f"Training data with targets shape: {train_merged.shape}")
    except Exception as e:
        print(f"Error merging training data: {e}")
        return None, None

    test_merged = None
    if test_categorical is not None and test_quantitative is not None:
        try:
            test_cat_frame = test_categorical.copy()
            test_quant_frame = test_quantitative.copy()
            test_merged = test_cat_frame.merge(test_quant_frame, **join_options)
            print(f"Merged test data shape: {test_merged.shape}")
        except Exception as e:
            print(f"Error merging test data: {e}")

    return train_merged, test_merged


# ====================================
# 2.5. Improved Neuroimaging Data Processing
# ====================================

# --- Necessary Imports ---
import pandas as pd
import numpy as np
import networkx as nx
from scipy import stats
from joblib import Parallel, delayed
from sklearn.feature_selection import VarianceThreshold
import math # For sqrt and isclose
from time import time

# --- ConnectomeProcessor Class Definition (including all discussed metrics) ---
class ConnectomeProcessor:
    """Turn per-participant functional-connectome vectors into summary features.

    Each input row is treated as the strict upper triangle of a symmetric
    connectivity matrix. For every participant the processor derives
    distribution statistics of the raw values plus graph metrics computed with
    networkx, then removes near-constant features with a ``VarianceThreshold``
    fitted on the training data.

    Attributes set by ``fit_transform`` (absent before a successful fit):
        feature_columns: names of all extracted features, before selection.
        medians_: per-feature training medians used to fill missing values.
        selector: fitted ``VarianceThreshold``, or None if selection failed.
        final_feature_columns_: names of the features kept after selection.
    """

    # Keys of the distribution statistics, in output column order.
    _STAT_KEYS = (
        'mean_connectivity', 'std_connectivity', 'positive_connectivity_ratio',
        'strong_pos_conn_ratio', 'strong_neg_conn_ratio', 'skewness', 'kurtosis',
    )

    def __init__(self, n_jobs=-1):
        """Create an unfitted processor; ``n_jobs`` is passed to joblib."""
        self.n_jobs = n_jobs
        self.selector = None
        self.feature_columns = None
        # Value reported for each graph metric when it cannot be computed.
        self.network_metric_defaults = {
            'avg_degree': 0, 'max_degree': 0, 'min_degree': 0,
            'degree_variance': 0,
            'avg_betweenness': 0,
            'avg_closeness': 0,
            'pos_density': 0, 'neg_density': 0,
            'avg_clustering': 0,
            'global_efficiency': 0,
            'components_0.3': 0, 'components_0.5': 0,
            'largest_cc_0.3': 0, 'largest_cc_0.5': 0
        }

    @staticmethod
    def _infer_n_regions(n_values):
        """Return N with N*(N-1)/2 == n_values, or 0 if no such N exists."""
        if n_values <= 0:
            return 0
        candidate = int(round((1 + math.sqrt(1 + 8 * n_values)) / 2))
        # The exact integer check is what decides validity; the float estimate
        # above only proposes a candidate.
        return candidate if candidate * (candidate - 1) // 2 == n_values else 0

    def extract_participant_features(self, participant_id, corr_values):
        """Extract distribution statistics and graph metrics for one participant.

        Args:
            participant_id: Identifier stored under ``'participant_id'``.
            corr_values: Sequence of connectivity values (NaNs allowed).

        Returns:
            Dict of features, or None when ``corr_values`` cannot be converted
            to floats. Graph metrics keep their defaults when the number of
            values is not a triangular number or a computation fails.
        """
        metrics = {'participant_id': participant_id}
        try:
            corr_values = np.array(corr_values, dtype=float).flatten()
        except (ValueError, TypeError):
            # TypeError covers unconvertible objects (e.g. dicts), ValueError
            # covers unparsable strings.
            print(f"Error: P:{participant_id}'s corr_values contain invalid values. Skipping.")
            return None

        # --- 1. Basic connectivity statistics ---
        valid_corr = corr_values[~np.isnan(corr_values)]
        if len(valid_corr) == 0:
            # Empty or all-NaN input: report zeros everywhere.
            metrics.update(dict.fromkeys(self._STAT_KEYS, 0))
            metrics.update(self.network_metric_defaults)
            return metrics

        metrics['mean_connectivity'] = np.nanmean(corr_values)
        metrics['std_connectivity'] = np.nanstd(corr_values)
        metrics['positive_connectivity_ratio'] = np.mean(valid_corr > 0)
        metrics['strong_pos_conn_ratio'] = np.mean(valid_corr > 0.5)
        metrics['strong_neg_conn_ratio'] = np.mean(valid_corr < -0.5)
        if len(valid_corr) > 2:
            metrics['skewness'] = stats.skew(valid_corr)
            metrics['kurtosis'] = stats.kurtosis(valid_corr)
        else:
            metrics['skewness'], metrics['kurtosis'] = 0, 0

        # --- 2. Network features ---
        n_regions = self._infer_n_regions(len(corr_values))
        metrics.update(self.network_metric_defaults)
        # Until the graph is analysed, every region counts as its own
        # component (0 when the vector length is not a valid triangle size).
        metrics['components_0.3'] = n_regions
        metrics['components_0.5'] = n_regions
        if n_regions == 0:
            return metrics

        try:
            self._add_network_metrics(metrics, participant_id, corr_values, n_regions)
        except Exception as e:
            # Metrics computed before the failure are kept; the rest stay at defaults.
            print(f"Error calculating network metrics for P:{participant_id}: {e}")
        return metrics

    def _add_network_metrics(self, metrics, participant_id, corr_values, n_regions):
        """Fill ``metrics`` in place with graph measures of the connectivity matrix."""
        # Rebuild the symmetric matrix from its upper triangle (NaN -> 0).
        matrix = np.zeros((n_regions, n_regions))
        matrix[np.triu_indices(n_regions, k=1)] = np.nan_to_num(corr_values, nan=0.0)
        matrix = matrix + matrix.T
        matrix_abs = np.abs(matrix)

        G_full = nx.from_numpy_array(matrix_abs)
        G_unweighted = nx.from_numpy_array((matrix_abs > 1e-6).astype(int))

        degrees = [d for _, d in G_full.degree(weight='weight')]
        if degrees:
            metrics['avg_degree'] = np.mean(degrees)
            metrics['max_degree'] = np.max(degrees)
            metrics['min_degree'] = np.min(degrees)
            metrics['degree_variance'] = np.var(degrees)

        # Each of these may fail independently; a failure only leaves that
        # one metric at its default value.
        optional_metrics = (
            ('avg_betweenness', 'Betweenness',
             lambda: np.mean(list(nx.betweenness_centrality(G_unweighted, normalized=True).values()))),
            ('avg_closeness', 'Closeness',
             lambda: np.mean(list(nx.closeness_centrality(G_unweighted).values()))),
            ('avg_clustering', 'Clustering',
             lambda: nx.average_clustering(G_full, weight='weight')),
            ('global_efficiency', 'Efficiency',
             lambda: nx.global_efficiency(G_unweighted)),
        )
        for key, label, compute in optional_metrics:
            try:
                metrics[key] = compute()
            except Exception as err:
                print(f"  P:{participant_id} {label} failed: {err}")

        # --- Density of the positive-only and negative-only sub-networks ---
        G_pos = nx.from_numpy_array(np.where(matrix > 0, matrix, 0))
        G_neg = nx.from_numpy_array(np.where(matrix < 0, -matrix, 0))
        metrics['pos_density'] = nx.density(G_pos)
        metrics['neg_density'] = nx.density(G_neg)

        # --- Threshold analysis: keep edges with |weight| above the threshold ---
        for threshold in (0.3, 0.5):
            G_thresh = nx.Graph()
            G_thresh.add_nodes_from(range(n_regions))
            G_thresh.add_weighted_edges_from(
                (u, v, w)
                for u, v, w in G_full.edges(data='weight', default=0)
                if w > threshold
            )
            # Compute the components once and derive both metrics from them.
            components = list(nx.connected_components(G_thresh))
            metrics[f'components_{threshold}'] = len(components)
            if components:
                metrics[f'largest_cc_{threshold}'] = len(max(components, key=len)) / n_regions

    def _get_corr_cols(self, df):
        """Return the names of the connectome value columns in ``df``.

        Columns whose name contains ``'throw_'`` are used. If there are none
        and more than 80% of the columns are numeric, all numeric columns
        except ``participant_id`` are used instead.

        Raises:
            ValueError: if no connectome columns can be identified.
        """
        # Non-string labels (e.g. integer headers) cannot contain 'throw_';
        # skip them instead of raising TypeError on the `in` test.
        corr_cols = [col for col in df.columns if isinstance(col, str) and 'throw_' in col]
        if corr_cols:
            return corr_cols

        numeric_cols = df.select_dtypes(include=np.number).columns
        if len(numeric_cols) <= 0.8 * len(df.columns):
            raise ValueError("Connectome data columns not found ('throw_' or majority numeric columns).")
        potential_cols = numeric_cols.drop('participant_id', errors='ignore').tolist()
        if not potential_cols:
            raise ValueError("Connectome data columns not found (auto-detection failed).")
        print("Warning: 'throw_' columns not found...")
        return potential_cols

    def _extract_feature_frame(self, connectome_df):
        """Run per-participant extraction in parallel.

        Returns:
            DataFrame indexed by ``participant_id`` with one column per
            feature, or None when no participant produced features.

        Raises:
            ValueError: if ``participant_id`` or the connectome columns are missing.
        """
        if 'participant_id' not in connectome_df.columns:
            raise ValueError("'participant_id' column is missing")
        corr_cols = self._get_corr_cols(connectome_df)
        participant_ids = connectome_df['participant_id'].tolist()
        corr_values_list = connectome_df[corr_cols].values.tolist()
        features = Parallel(n_jobs=self.n_jobs)(
            delayed(self.extract_participant_features)(pid, vals)
            for pid, vals in zip(participant_ids, corr_values_list)
        )
        valid_features = [f for f in features if f is not None]
        if not valid_features:
            return None
        return pd.DataFrame(valid_features).set_index('participant_id')

    def fit_transform(self, connectome_df):
        """Extract features from training data and fit the variance filter.

        Returns:
            DataFrame with ``participant_id`` plus the selected feature
            columns. If variance selection fails or keeps nothing, all
            extracted features are returned and ``selector`` is set to None.
        """
        print("  Executing connectome fit_transform...")
        features_df = self._extract_feature_frame(connectome_df)
        if features_df is None:
            print("Warning: No valid features")
            return pd.DataFrame(columns=['participant_id'])

        self.feature_columns = features_df.columns.tolist()
        if not self.feature_columns:
            print("Warning: No feature columns extracted")
            return features_df.reset_index()[['participant_id']]

        self.medians_ = features_df.median()
        feature_matrix = features_df.fillna(self.medians_)
        self.selector = VarianceThreshold(threshold=0.01)
        try:
            self.selector.fit(feature_matrix)
            selected_features = feature_matrix.columns[self.selector.get_support()]
            processed_matrix = self.selector.transform(feature_matrix)
            print(f"  Number of features remaining after VarianceThreshold: {len(selected_features)}")
            if len(selected_features) == 0:
                print("Warning: No features remaining")
                selected_features = feature_matrix.columns
                processed_matrix = feature_matrix.values
                self.selector = None
        except ValueError as ve:
            print(f"Warning: VarianceThreshold failed: {ve}")
            selected_features = feature_matrix.columns
            processed_matrix = feature_matrix.values
            self.selector = None

        processed_df = pd.DataFrame(
            processed_matrix, columns=selected_features, index=features_df.index
        ).reset_index()
        self.final_feature_columns_ = selected_features.tolist()
        return processed_df

    def transform(self, connectome_df):
        """Extract features from new data using the state learned in ``fit_transform``.

        Raises:
            ValueError: if the processor has not been fitted, or the input
                lacks ``participant_id`` / connectome columns.
        """
        print("  Executing connectome transform...")
        if self.feature_columns is None:
            raise ValueError("Processor has not been fitted")
        features_df = self._extract_feature_frame(connectome_df)
        if features_df is None:
            print("Warning: No valid features in test data")
            final_cols = self.final_feature_columns_ if hasattr(self, 'final_feature_columns_') else []
            return pd.DataFrame(columns=['participant_id'] + final_cols)

        if not hasattr(self, 'medians_'):
            print("Warning: Training medians not stored")
            self.medians_ = 0
        # Align to the training feature layout before filling and selecting.
        feature_matrix = features_df.reindex(columns=self.feature_columns).fillna(self.medians_)

        if self.selector is not None and hasattr(self, 'final_feature_columns_'):
            try:
                processed_matrix = self.selector.transform(feature_matrix)
                selected_features = self.final_feature_columns_
            except ValueError as ve:
                print(f"Warning: VT transform failed: {ve}.")
                processed_matrix = feature_matrix.values
                selected_features = self.feature_columns
            except Exception as e:
                print(f"Error in VT transform: {e}")
                raise
        else:
            processed_matrix = feature_matrix.values
            selected_features = self.feature_columns

        processed_df = pd.DataFrame(
            processed_matrix, columns=selected_features, index=features_df.index
        ).reset_index()
        return processed_df


# --- process_connectome_data function (no changes needed) ---
def process_connectome_data(train_connectome, test_connectome=None):
    """Extract connectome features for the training set and, optionally, the test set.

    A ``ConnectomeProcessor`` is fitted on ``train_connectome`` and reused to
    transform ``test_connectome``.

    Returns:
        ``(train_connectome_features, test_connectome_features)``;
        ``(None, None)`` when no training connectome is given. If transforming
        the test data raises, the test result is an empty DataFrame carrying
        the expected column names.
    """
    if train_connectome is None:
        print("Error: Training connectome data not provided.")
        return None, None

    print("\nProcessing neuroimaging data (functional connectome)...")
    started_at = time()
    processor = ConnectomeProcessor(n_jobs=-1)

    train_connectome_features = processor.fit_transform(train_connectome)
    train_usable = (
        train_connectome_features is not None
        and 'participant_id' in train_connectome_features.columns
    )
    if train_usable:
        print(f"Extracted {train_connectome_features.shape[1] - 1} connectome features for training")
    else:
        print("Warning: Training feature extraction failed or did not return participant_id.")

    test_connectome_features = None
    if test_connectome is not None:
        try:
            test_connectome_features = processor.transform(test_connectome)
            test_usable = (
                test_connectome_features is not None
                and 'participant_id' in test_connectome_features.columns
            )
            if test_usable:
                print(f"Extracted {test_connectome_features.shape[1] - 1} connectome features for testing")
            else:
                print("Warning: Test feature extraction failed or did not return participant_id.")
        except Exception as e:
            print(f"Error: Failed to process test connectome data: {e}")
            # Keep the downstream schema intact with an empty, correctly-named frame.
            fallback_cols = getattr(processor, 'final_feature_columns_', [])
            test_connectome_features = pd.DataFrame(columns=['participant_id'] + fallback_cols)

    elapsed_time = time() - started_at
    print(f"Connectome processing completed in {elapsed_time:.2f} seconds")
    return train_connectome_features, test_connectome_features



# ====================================
# 5. Multimodal Data Fusion
# ====================================
def multimodal_fusion(train_merged, train_connectome_features, test_merged=None, test_connectome_features=None):
    """Attach connectome features to the merged metadata tables.

    Both sides are inner-joined on ``participant_id``. When the connectome
    frame has no feature columns the metadata is returned as a copy. For the
    training side, missing/empty connectome features yield None; for the test
    side they yield a copy of ``test_merged``.

    Returns:
        ``(train_multimodal, test_multimodal)``.
    """
    print("\nPerforming multimodal data fusion...")

    # ---- Training side ----
    train_multimodal = None
    train_inputs_ok = (
        train_merged is not None
        and train_connectome_features is not None
        and not train_connectome_features.empty
    )
    if not train_inputs_ok:
        print("Error: Cannot perform training data fusion")
    else:
        train_feature_names = [c for c in train_connectome_features.columns if c != 'participant_id']
        if train_feature_names:
            train_multimodal = train_merged.merge(train_connectome_features, on='participant_id', how='inner')
            print(f"Shape after training data fusion: {train_multimodal.shape}")
            print(f"Added {len(train_feature_names)} connectome features")
            dropped = len(train_merged) - len(train_multimodal)
            if dropped > 0:
                print(f"Warning: {dropped} participants lost during training fusion")
        else:
            print("Warning: No data in connectome features, skipping fusion.")
            train_multimodal = train_merged.copy()

    # ---- Test side ----
    test_multimodal = None
    if test_merged is not None:
        if test_connectome_features is not None and not test_connectome_features.empty:
            test_feature_names = [c for c in test_connectome_features.columns if c != 'participant_id']
            if test_feature_names:
                test_multimodal = test_merged.merge(test_connectome_features, on='participant_id', how='inner')
                print(f"Shape after test data fusion: {test_multimodal.shape}")
                dropped = len(test_merged) - len(test_multimodal)
                if dropped > 0:
                    print(f"Warning: {dropped} participants lost during test fusion")
            else:
                print("Warning: No data in test connectome features, skipping fusion.")
                test_multimodal = test_merged.copy()
        else:
            print("Warning: Failed to fuse test data.")
            test_multimodal = test_merged.copy()

    return train_multimodal, test_multimodal

# ====================================
# 7. Feature Engineering and Preprocessing
# ====================================
def _align_test_features(test_merged, train_cols):
    """Return test features restricted to, and ordered like, ``train_cols``.

    Columns absent from the training features (other than ``participant_id``)
    are dropped with a warning; training columns absent from the test data
    are created and filled with NaN so the imputers can handle them.
    """
    aligned = test_merged.copy()
    train_col_set = set(train_cols)

    extra_cols = [col for col in aligned.columns
                  if col not in train_col_set and col != 'participant_id']
    if extra_cols:
        print(f"Warning: Extra columns in test set will be dropped: {extra_cols}")
        aligned.drop(columns=extra_cols, errors='ignore', inplace=True)

    missing_cols = [col for col in train_cols if col not in aligned.columns]
    if missing_cols:
        print(f"Warning: Test set is missing columns from the training set: {missing_cols}. These columns will be created and filled with NaN.")
        for col in missing_cols:
            aligned[col] = np.nan

    return aligned[train_cols]


def _fallback_feature_names(preprocessor):
    """Collect output names transformer by transformer, or None if that fails.

    Used only when ``preprocessor.get_feature_names_out()`` itself raised.
    For a Pipeline the names are requested from its last step.
    """
    names = []
    try:
        for _, trans, cols in preprocessor.transformers_:
            if isinstance(trans, str):
                # 'passthrough' keeps its input columns; 'drop' contributes nothing.
                if trans == 'passthrough':
                    names.extend(cols)
                continue
            final_step = trans.steps[-1][1] if isinstance(trans, Pipeline) else trans
            if hasattr(final_step, 'get_feature_names_out'):
                names.extend(final_step.get_feature_names_out(cols))
    except Exception as err:
        print(f"   Fallback: could not collect per-transformer feature names: {err}")
        return None
    return names


def feature_engineering(train_merged, test_merged=None):
    """Split targets from features and apply the preprocessing pipeline.

    Numerical features go through ``IterativeImputer`` (LassoCV estimator)
    followed by ``StandardScaler``; the listed categorical features go through
    most-frequent imputation followed by one-hot encoding. The pipeline is
    fitted on the training features only and then applied to the test features.

    Args:
        train_merged: Training table; ``participant_id`` and the target
            columns ``ADHD_Outcome`` / ``Sex_F`` are removed from the features.
        test_merged: Optional test table, aligned to the training columns.

    Returns:
        ``(X_train_processed, y_train, X_test_processed, preprocessor,
        feature_names_out)``. All five are None when there is no training
        data; ``(None, y_train, None, None, None)`` when there are no features
        or preprocessing raises.
    """
    print("\nPerforming feature engineering...")
    X_train_processed, y_train, X_test_processed, preprocessor, feature_names_out = None, None, None, None, None
    if train_merged is None or train_merged.empty:
        print("Error: Training data is missing")
        return X_train_processed, y_train, X_test_processed, preprocessor, feature_names_out

    # --- Separate targets from features ---
    target_cols = [col for col in ['ADHD_Outcome', 'Sex_F'] if col in train_merged.columns]
    if not target_cols:
        print("Warning: Target variables not found")
        y_train = None
    else:
        y_train = train_merged[target_cols]
        print(f"Training targets shape: {y_train.shape}")

    drop_cols = ['participant_id'] + target_cols
    X_train = train_merged.drop(columns=drop_cols, errors='ignore')

    # --- Align the test set to the training columns ---
    X_test = None
    if test_merged is not None:
        X_test = _align_test_features(test_merged, X_train.columns.tolist())
        print(f"Test features shape after alignment: {X_test.shape}")

    # --- Identify feature types ---
    # Columns treated as categorical; every other feature column is treated
    # as numerical.
    categorical_features_list = [
        'Basic_Demos_Gender',
        'Barratt_Barratt_P1_Edu',
        'Barratt_Barratt_P1_Occ',
        'Barratt_Barratt_P2_Edu',
        'Barratt_Barratt_P2_Occ',
    ]
    categorical_features = [col for col in categorical_features_list if col in X_train.columns]
    numeric_features = [col for col in X_train.columns if col not in categorical_features]

    print(f"Identified numerical features: {len(numeric_features)}, Identified categorical features: {len(categorical_features)}")

    if not numeric_features and not categorical_features:
        print("Error: No features found")
        return None, y_train, None, None, None
    if not numeric_features:
        print("Warning: No numerical features found, IterativeImputer may not work. Please check the feature list and data.")

    # --- Build the preprocessing pipeline ---
    numeric_transformer = Pipeline(steps=[
        ('imputer', IterativeImputer(estimator=LassoCV(random_state=SEED), max_iter=5, random_state=SEED)),
        ('scaler', StandardScaler())
    ])
    # handle_unknown='ignore' tolerates test-set categories unseen in training.
    categorical_transformer = Pipeline(steps=[
        ('imputer', SimpleImputer(strategy='most_frequent')),
        ('onehot', OneHotEncoder(handle_unknown='ignore'))
    ])

    # At least one of the two lists is non-empty here (checked above).
    transformers_list = []
    if numeric_features:
        transformers_list.append(('num', numeric_transformer, numeric_features))
    if categorical_features:
        transformers_list.append(('cat', categorical_transformer, categorical_features))

    preprocessor = ColumnTransformer(
        transformers=transformers_list,
        remainder='passthrough'
    )

    # --- Apply preprocessing ---
    print("Applying preprocessing pipeline...")
    try:
        # Fit on the training data only.
        X_train_preprocessed = preprocessor.fit_transform(X_train)

        try:
            feature_names_out = preprocessor.get_feature_names_out()
        except Exception as name_e:
            print(f"Warning: Could not get feature names: {name_e}. Feature names will be the default names generated by ColumnTransformer.")
            # Do not retry the call that just failed; gather names per
            # transformer instead and let the length check below validate them.
            feature_names_out = _fallback_feature_names(preprocessor)

        # Last resort: positional names, so a naming problem never aborts
        # the preprocessing itself.
        if feature_names_out is None or len(feature_names_out) != X_train_preprocessed.shape[1]:
            print("Warning: Could not reliably get feature names, using default generated names.")
            feature_names_out = [f'feature_{i}' for i in range(X_train_preprocessed.shape[1])]

        X_train_processed = pd.DataFrame(X_train_preprocessed, index=X_train.index, columns=feature_names_out)
        print(f"Training features shape: {X_train_processed.shape}")

        if X_test is not None:
            # Reuse the fitted preprocessor and the training feature names.
            X_test_preprocessed = preprocessor.transform(X_test)
            X_test_processed = pd.DataFrame(X_test_preprocessed, index=X_test.index, columns=feature_names_out)
            print(f"Test features shape: {X_test_processed.shape}")
        else:
            X_test_processed = None

    except Exception as e:
        print(f"Error in feature engineering/preprocessing: {e}")
        return None, y_train, None, None, None

    return X_train_processed, y_train, X_test_processed, preprocessor, feature_names_out


# ====================================
# 4. Feature Selection
# ====================================
def select_features(X_train, y_train, n_features=30):
    """Select the top features per target using Random Forest importances.

    One ``RandomForestClassifier`` is fitted per column of ``y_train``; the
    ``n_features`` columns of ``X_train`` with the highest impurity-based
    importance are kept for that target.

    Args:
        X_train: DataFrame of training features (may be None).
        y_train: DataFrame with one column per target (may be None).
        n_features: Maximum number of features to keep per target.

    Returns:
        A 2-tuple ``(all_selected, selected_features_dict)``.

        * ``all_selected``: list with the union of the per-target selections,
          in order of first appearance (so the result is reproducible).
        * ``selected_features_dict``: ``{target: [feature, ...]}``.

        When the inputs are unusable the function returns ``(None, {})`` if
        ``X_train`` is None, otherwise ``(all column names, {})``.
    """
    print("\nPerforming feature selection...")
    if X_train is None or y_train is None or y_train.shape[1] == 0:
        print("Warning: Missing training data or target variables, cannot perform feature selection.")
        # Always return a 2-tuple so callers can unpack it exactly like the
        # normal exit path below.
        fallback_features = None if X_train is None else X_train.columns.tolist()
        return fallback_features, {}

    selected_features_dict = {}
    # A dict is used as an insertion-ordered set: iterating a plain set of
    # strings yields a different order from run to run (hash randomisation).
    all_selected_ordered = {}

    for target in y_train.columns:
        print(f"\nSelecting features for {target}:")
        if X_train.empty:
            print(f"Warning: X_train is empty, cannot select features for {target}.")
            continue

        try:
            rf = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
            rf.fit(X_train, y_train[target])

            importances = pd.DataFrame({
                'feature': X_train.columns,
                'importance': rf.feature_importances_,
            }).sort_values('importance', ascending=False)

            # Never ask for more features than are available.
            actual_n_features = min(n_features, len(X_train.columns))
            top_rows = importances.head(actual_n_features)
            top_features = top_rows['feature'].tolist()
            selected_features_dict[target] = top_features
            all_selected_ordered.update(dict.fromkeys(top_features))

            print(f"Top {len(top_features)} features selected for {target}")
            print("Top 5 important features:")
            top_five = zip(top_features[:5], top_rows['importance'].tolist()[:5])
            for rank, (feat, imp) in enumerate(top_five, 1):
                print(f"{rank}. {feat}: {imp:.4f}")
        except Exception as e:
            # Best effort: keep every feature for this target rather than abort.
            print(f"Error selecting features for {target}: {e}")
            every_feature = X_train.columns.tolist()
            selected_features_dict[target] = every_feature
            all_selected_ordered.update(dict.fromkeys(every_feature))

    all_selected = list(all_selected_ordered)
    print(f"\nSelected a total of {len(all_selected)} unique features")

    return all_selected, selected_features_dict

# ====================================
# 8. Inverse Probability Weighting
# ====================================
def calculate_optimized_weights(y_train,
                                female_adhd_boost=2.0,
                                female_non_adhd_boost=1.5,
                                male_non_adhd_boost=1.2,
                                male_adhd_boost=1.0):
    """Compute boosted inverse-frequency sample weights over ADHD x Sex groups.

    Each sample is assigned to the joint class ``ADHD_Outcome * 2 + Sex_F``
    (0: Non-ADHD Male, 1: Non-ADHD Female, 2: ADHD Male, 3: ADHD Female).
    Its raw weight is ``boost * n_samples / class_count``; the weights are
    then divided by their mean so that the average weight is 1.

    Args:
        y_train: DataFrame containing the ``ADHD_Outcome`` and ``Sex_F``
            columns. Integer, boolean and float-typed labels are accepted.
        female_adhd_boost: Multiplier for class 3.
        female_non_adhd_boost: Multiplier for class 1.
        male_non_adhd_boost: Multiplier for class 0.
        male_adhd_boost: Multiplier for class 2.

    Returns:
        1-D ``numpy.ndarray`` with one weight per row of ``y_train``, or
        ``None`` when the weights cannot be computed (missing frame/columns,
        no rows, or missing label values).
    """
    print("\nCalculating inverse probability weighting sample weights (adjustable factors)...")

    if y_train is None or 'ADHD_Outcome' not in y_train.columns or 'Sex_F' not in y_train.columns:
        print("Warning: Cannot calculate weights")
        return None

    total_samples = len(y_train)
    if total_samples == 0:
        print("Warning: Number of samples is 0")
        return None

    if y_train[['ADHD_Outcome', 'Sex_F']].isna().any().any():
        # NaN cannot be mapped to a class; refuse instead of casting garbage.
        print("Warning: Target columns contain missing values, cannot calculate weights")
        return None

    # np.bincount only accepts integer input, but labels loaded from Excel are
    # frequently float-typed (0.0 / 1.0), so cast explicitly.
    adhd_outcome = y_train['ADHD_Outcome'].to_numpy().astype(int)
    sex_f = y_train['Sex_F'].to_numpy().astype(int)

    combined_target = adhd_outcome * 2 + sex_f
    class_counts = np.bincount(combined_target, minlength=4)

    print("Class distribution:")
    class_desc_map = {
        0: "Non-ADHD Male",
        1: "Non-ADHD Female",
        2: "ADHD Male",
        3: "ADHD Female"
    }

    boost_map = {
        0: male_non_adhd_boost,
        1: female_non_adhd_boost,
        2: male_adhd_boost,
        3: female_adhd_boost
    }

    for class_val, count in enumerate(class_counts):
        print(f"  Class {class_val} ({class_desc_map.get(class_val, '?')}): {count} samples")

    weights = np.ones(total_samples)
    for class_val, count in enumerate(class_counts):
        if count == 0:
            continue  # no sample belongs to this class, nothing to assign
        boost_factor = boost_map.get(class_val, 1.0)
        weights[combined_target == class_val] = boost_factor * total_samples / count

    mean_weight = np.mean(weights)

    if mean_weight > 0:
        weights = weights / mean_weight
        print("\nWeights have been normalized")
    else:
        print("\nWarning: Cannot normalize weights")

    print(f"Total samples: {total_samples}")
    print(f"Number of female ADHD samples: {class_counts[3]}")

    print("Using weight coefficients:")
    print(f"  - ADHD Female(3): {female_adhd_boost:.2f}x")
    print(f"  - Non-ADHD Female(1): {female_non_adhd_boost:.2f}x")
    print(f"  - Non-ADHD Male(0): {male_non_adhd_boost:.2f}x")
    print(f"  - ADHD Male(2): {male_adhd_boost:.2f}x")

    print("\nFinal average weights:")
    for class_val in range(len(class_counts)):
        mask = (combined_target == class_val)
        if np.sum(mask) > 0:
            avg_w = np.mean(weights[mask])
            print(f"  Class {class_val} ({class_desc_map.get(class_val, '?')}): {avg_w:.4f}")
        else:
            print(f"  Class {class_val} ({class_desc_map.get(class_val, '?')}): N/A")

    return weights



# --- Optuna Objective Functions ---
from sklearn.model_selection import StratifiedKFold # Already imported, but reiterated here to ensure visibility
from sklearn.model_selection import cross_val_score # cross_val_score is no longer used directly in the function, but the import can be kept

def objective_lr(trial, X, y, sample_weights):
    """Optuna objective for Logistic Regression: mean 3-fold CV ROC AUC.

    Args:
        trial: Optuna trial used to sample ``penalty``, ``solver`` and ``C``.
        X: Feature DataFrame (rows are selected positionally with ``.iloc``).
        y: Binary target Series aligned with ``X``.
        sample_weights: Optional per-row weights aligned positionally with
            ``X``; ``None`` disables weighting.

    Returns:
        Mean ROC AUC over the folds that could be scored. ``0.0`` is returned
        for an unsupported penalty/solver pair, when no fold could be scored,
        or when training/evaluation raised (the error is printed).
    """
    # penalty and solver are sampled independently so the search space stays
    # static; unsupported pairs are rejected below with a low score.
    penalty = trial.suggest_categorical('penalty', ['l1', 'l2'])
    solver = trial.suggest_categorical('solver', ['newton-cg', 'lbfgs', 'liblinear', 'sag', 'saga'])
    C = trial.suggest_float('C', 0.01, 10, log=True)

    # Only 'liblinear' and 'saga' implement the L1 penalty; every solver in
    # the list supports L2.
    if penalty == 'l1' and solver not in ('liblinear', 'saga'):
        return 0.0

    # NOTE: class_weight='balanced' is multiplied with any sample_weight
    # passed to fit(), so both re-weightings are active at the same time.
    model = LogisticRegression(
        C=C,
        penalty=penalty,
        solver=solver,
        random_state=42,
        max_iter=5000,  # generous budget so sag/saga can converge
        class_weight='balanced'
    )

    # Coerce once so that weights[train_idx] is positional even for a Series.
    weights = None if sample_weights is None else np.asarray(sample_weights)
    inner_cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)
    cv_scores = []

    try:
        for fold, (train_idx, val_idx) in enumerate(inner_cv.split(X, y)):
            y_val_fold = y.iloc[val_idx]
            if len(np.unique(y_val_fold)) < 2:
                # AUC is undefined for a single-class validation fold, so do
                # not spend time fitting a model that cannot be scored.
                continue

            fold_model = clone(model)
            fold_model.random_state = 42 + fold  # distinct seed per fold

            fit_params_fold = {}
            if weights is not None:
                fit_params_fold['sample_weight'] = weights[train_idx]

            fold_model.fit(X.iloc[train_idx], y.iloc[train_idx], **fit_params_fold)
            y_pred_proba = fold_model.predict_proba(X.iloc[val_idx])[:, 1]
            cv_scores.append(roc_auc_score(y_val_fold, y_pred_proba))
    except Exception as e:
        # Report the failure instead of hiding it behind a silent 0.0 score.
        print(f"Warning: LR trial failed during CV with params {trial.params}: {e}")
        return 0.0

    if not cv_scores:
        return 0.0
    return np.mean(cv_scores)

def objective_rf(trial, X, y, sample_weights):
    """Optuna objective for RandomForestClassifier: mean 3-fold CV ROC AUC.

    Args:
        trial: Optuna trial used to sample the forest hyperparameters.
        X: Feature DataFrame (rows are selected positionally with ``.iloc``).
        y: Binary target Series aligned with ``X``.
        sample_weights: Optional per-row weights aligned positionally with
            ``X``; ``None`` disables weighting.

    Returns:
        Mean ROC AUC over the folds that could be scored, or ``0.0`` when no
        fold could be scored or training/evaluation raised (the error is
        printed).
    """
    n_estimators = trial.suggest_int('n_estimators', 100, 500)
    max_depth = trial.suggest_int('max_depth', 3, 20)
    min_samples_split = trial.suggest_int('min_samples_split', 2, 20)
    min_samples_leaf = trial.suggest_int('min_samples_leaf', 1, 20)
    max_features = trial.suggest_categorical('max_features', ['sqrt', 'log2', None, 0.6, 0.8, 1.0])

    # NOTE: class_weight='balanced' is applied on top of any sample_weight
    # passed to fit().
    model = RandomForestClassifier(
        n_estimators=n_estimators,
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        min_samples_leaf=min_samples_leaf,
        max_features=max_features,
        random_state=42,
        n_jobs=-1,
        class_weight='balanced'
    )

    # Coerce once so that weights[train_idx] is positional even for a Series.
    weights = None if sample_weights is None else np.asarray(sample_weights)
    inner_cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)
    cv_scores = []

    try:
        for fold, (train_idx, val_idx) in enumerate(inner_cv.split(X, y)):
            y_val_fold = y.iloc[val_idx]
            if len(np.unique(y_val_fold)) < 2:
                # AUC is undefined for a single-class validation fold, so do
                # not spend time fitting a model that cannot be scored.
                continue

            fold_model = clone(model)
            fold_model.random_state = 42 + fold  # distinct seed per fold

            fit_params_fold = {}
            if weights is not None:
                fit_params_fold['sample_weight'] = weights[train_idx]

            fold_model.fit(X.iloc[train_idx], y.iloc[train_idx], **fit_params_fold)
            y_pred_proba = fold_model.predict_proba(X.iloc[val_idx])[:, 1]
            cv_scores.append(roc_auc_score(y_val_fold, y_pred_proba))
    except Exception as e:
        # Report the failure instead of hiding it behind a silent 0.0 score.
        print(f"Warning: RF trial failed during CV with params {trial.params}: {e}")
        return 0.0

    if not cv_scores:
        return 0.0
    return np.mean(cv_scores)


def objective_xgb(trial, X, y, sample_weights):
    """Optuna objective for XGBClassifier: mean 3-fold CV ROC AUC.

    Args:
        trial: Optuna trial used to sample the booster hyperparameters.
        X: Feature DataFrame (rows are selected positionally with ``.iloc``).
        y: Binary target Series aligned with ``X``.
        sample_weights: Optional per-row weights aligned positionally with
            ``X``; ``None`` disables weighting.

    Returns:
        Mean ROC AUC over the folds that could be scored, or ``0.0`` when no
        fold could be scored or training/evaluation raised (the error is
        printed).
    """
    params = {
        'n_estimators': trial.suggest_int('n_estimators', 100, 1500),
        'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.2, log=True),
        'max_depth': trial.suggest_int('max_depth', 3, 10),
        'subsample': trial.suggest_float('subsample', 0.6, 1.0),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.6, 1.0),
        'gamma': trial.suggest_float('gamma', 0.0, 0.5),
        'reg_alpha': trial.suggest_float('reg_alpha', 0.0, 1.0),
        'reg_lambda': trial.suggest_float('reg_lambda', 0.0, 1.0),
        # Class-imbalance handling: weight applied to the positive class
        # (a common heuristic is count(negative) / count(positive)); here it
        # is left to Optuna to tune.
        'scale_pos_weight': trial.suggest_float('scale_pos_weight', 1.0, 10.0, log=True)
    }

    model = xgb.XGBClassifier(
        objective='binary:logistic',
        eval_metric='auc',
        use_label_encoder=False,
        random_state=42,
        n_jobs=-1,
        **params
    )

    # Coerce once so that weights[train_idx] is positional even for a Series.
    weights = None if sample_weights is None else np.asarray(sample_weights)
    inner_cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)
    cv_scores = []

    try:
        for fold, (train_idx, val_idx) in enumerate(inner_cv.split(X, y)):
            y_val_fold = y.iloc[val_idx]
            if len(np.unique(y_val_fold)) < 2:
                # AUC is undefined for a single-class validation fold, so do
                # not spend time fitting a model that cannot be scored.
                continue

            fold_model = clone(model)
            fold_model.random_state = 42 + fold  # distinct seed per fold

            # Early stopping is deliberately not used here to keep the manual
            # CV loop simple; sample_weight is the only extra fit argument.
            fit_params_fold = {}
            if weights is not None:
                fit_params_fold['sample_weight'] = weights[train_idx]

            fold_model.fit(X.iloc[train_idx], y.iloc[train_idx], **fit_params_fold)
            y_pred_proba = fold_model.predict_proba(X.iloc[val_idx])[:, 1]
            cv_scores.append(roc_auc_score(y_val_fold, y_pred_proba))
    except Exception as e:
        # Report the failure instead of hiding it behind a silent 0.0 score.
        print(f"Warning: XGB trial failed during CV with params {trial.params}: {e}")
        return 0.0

    if not cv_scores:
        return 0.0
    return np.mean(cv_scores)

def objective_lgbm(trial, X, y, sample_weights):
    """Optuna objective for LGBMClassifier: mean 3-fold CV ROC AUC.

    Samples the LightGBM hyperparameters from ``trial``, fits one cloned
    model per stratified fold (optionally with ``sample_weights``) and
    averages the validation AUC. Folds whose validation labels contain a
    single class are not scored. Returns ``0.0`` when no fold was scored or
    when any step raised (the error is printed).
    """
    # Keyword arguments are evaluated left to right, so the sampling order of
    # the trial parameters is the order written here.
    # reg_alpha / reg_lambda use a strictly positive lower bound because they
    # are sampled on a log scale.
    # NOTE(review): LightGBM documents that `subsample` only takes effect when
    # `subsample_freq` > 0; it is not set here - confirm whether bagging is
    # actually intended.
    search_space = dict(
        n_estimators=trial.suggest_int('n_estimators', 100, 1500),
        learning_rate=trial.suggest_float('learning_rate', 0.01, 0.3, log=True),
        num_leaves=trial.suggest_int('num_leaves', 20, 150),
        max_depth=trial.suggest_int('max_depth', 3, 12),
        min_child_samples=trial.suggest_int('min_child_samples', 5, 100),
        subsample=trial.suggest_float('subsample', 0.6, 1.0),
        colsample_bytree=trial.suggest_float('colsample_bytree', 0.6, 1.0),
        reg_alpha=trial.suggest_float('reg_alpha', 0.001, 10.0, log=True),
        reg_lambda=trial.suggest_float('reg_lambda', 0.001, 10.0, log=True),
        min_split_gain=trial.suggest_float('min_split_gain', 0.0, 0.5),
        is_unbalance=True,  # built-in handling of the class imbalance
    )

    template = lgb.LGBMClassifier(
        objective='binary',
        metric='auc',
        random_state=42,
        n_jobs=-1,
        verbose=-1,
        **search_space
    )

    splitter = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)
    auc_per_fold = []

    try:
        for fold_no, (fit_rows, eval_rows) in enumerate(splitter.split(X, y, groups=y)):
            X_fit, y_fit = X.iloc[fit_rows], y.iloc[fit_rows]
            X_eval, y_eval = X.iloc[eval_rows], y.iloc[eval_rows]

            extra_fit_args = {}
            if sample_weights is not None:
                extra_fit_args['sample_weight'] = sample_weights[fit_rows]

            candidate = clone(template)
            if hasattr(candidate, 'random_state'):
                candidate.random_state = 42 + fold_no  # distinct seed per fold

            candidate.fit(X_fit, y_fit, **extra_fit_args)
            positive_proba = candidate.predict_proba(X_eval)[:, 1]

            # AUC needs both classes in the validation labels.
            if len(np.unique(y_eval)) >= 2:
                auc_per_fold.append(roc_auc_score(y_eval, positive_proba))

        return np.mean(auc_per_fold) if auc_per_fold else 0.0

    except Exception as e:
        print(f"LGBM error: {e}")
        return 0.0

# Registry of Optuna objective functions keyed by model name.
# train_models_with_cv looks up the objective for each model it tunes via
# objective_map.get(name); a model without an entry here is skipped.
objective_map = {
    'logistic': objective_lr,
    'rf': objective_rf,
    'xgb': objective_xgb,
    'lgbm': objective_lgbm
}


# ====================================
# 9. Model Training and Evaluation
# ====================================

# --- Helper Function: Find Optimal Threshold [Unchanged] ---
def find_optimal_threshold(y_true, y_pred_proba, metric='f1'):
    """Scan 100 cut-offs in [0.01, 0.99] and return the best one for ``metric``.

    Only ``metric='f1'`` is evaluated; for any other value the default
    threshold 0.5 is reported and returned. When no threshold yields a
    non-zero F1 score a warning is printed and 0.5 is returned.
    """
    candidate_cutoffs = np.linspace(0.01, 0.99, 100)
    best_score = -1
    best_threshold = 0.5

    if metric == 'f1':
        labels = np.array(y_true).astype(int)
        scores = []
        with np.errstate(divide='ignore', invalid='ignore'):
            for cutoff in candidate_cutoffs:
                predicted = (y_pred_proba >= cutoff).astype(int)
                scores.append(f1_score(labels, predicted, zero_division=0))

        nothing_usable = np.all(np.isnan(scores)) or not np.any(scores)
        if nothing_usable:
            print("  Warning: Could not calculate a valid F1 score")
            return 0.5

        winner = np.nanargmax(scores)
        best_score = scores[winner]
        best_threshold = candidate_cutoffs[winner]

    print(f"  Optimal threshold ({metric.upper()}): {best_threshold:.4f} (Score: {best_score:.4f})")
    return best_threshold

# --- Main Training Function with Tuning and Repeated CV for OOF ---
def _instantiate_best_model(model_name, best_params):
    """Build an unfitted estimator for ``model_name`` from Optuna's best params.

    The fixed (non-tuned) settings mirror the ones used by the matching
    ``objective_*`` function so that the final model is the same kind of
    model that was scored during tuning. Returns ``None`` for unknown names.
    """
    if model_name == 'logistic':
        # max_iter matches the tuning objective (5000) so sag/saga converge.
        fixed = {'random_state': SEED, 'max_iter': 5000, 'class_weight': 'balanced'}
        return LogisticRegression(**{**fixed, **best_params})
    if model_name == 'rf':
        fixed = {'random_state': SEED, 'n_jobs': -1, 'class_weight': 'balanced'}
        return RandomForestClassifier(**{**fixed, **best_params})
    if model_name == 'xgb':
        fixed = {
            'objective': 'binary:logistic',
            'eval_metric': 'auc',
            'use_label_encoder': False,
            'random_state': SEED,
            'n_jobs': -1,
        }
        return xgb.XGBClassifier(**{**fixed, **best_params})
    if model_name == 'lgbm':
        fixed = {
            'objective': 'binary',
            'metric': 'auc',
            'verbose': -1,
            'random_state': SEED,
            'n_jobs': -1,
            # Fixed (not sampled) in objective_lgbm, hence absent from
            # study.best_params; it has to be restored explicitly.
            'is_unbalance': True,
        }
        return lgb.LGBMClassifier(**{**fixed, **best_params})
    return None


def train_models_with_cv(X_train, y_train, selected_features_dict,
                         target_to_process=None,
                         sample_weights=None,
                         n_splits=5,
                         n_repeats=5,
                         n_tuning_iter=50,
                         inner_cv_folds=3):
    """Tune, cross-validate and fit one best model per target.

    For every target the function
      1. tunes each model registered in ``objective_map`` with Optuna and
         keeps the model with the highest tuning AUC,
      2. produces averaged out-of-fold (OOF) probabilities with
         ``RepeatedStratifiedKFold``,
      3. derives an F1-optimal decision threshold from the OOF probabilities,
      4. refits the chosen model on the full training data.

    Args:
        X_train: DataFrame of preprocessed training features.
        y_train: DataFrame with one binary column per target.
        selected_features_dict: ``{target: [feature, ...]}``; targets without
            an entry use every column of ``X_train``.
        target_to_process: Optional list restricting which targets are run.
        sample_weights: Optional per-row weights aligned positionally with
            ``X_train``. They are used for tuning, OOF folds and the final
            fit of every model type.
        n_splits: Number of folds of the outer CV.
        n_repeats: Number of repetitions of the outer CV.
        n_tuning_iter: Number of Optuna trials per model.
        inner_cv_folds: Currently unused; the objective functions hard-code
            3 inner folds. Kept for interface compatibility.

    Returns:
        ``(tuning_results, final_models, oof_predictions)`` where
        ``tuning_results[target][model_name]`` holds ``best_score`` /
        ``best_params``, ``final_models[target]`` holds ``model``,
        ``features``, ``threshold``, ``auc`` and ``model_name``, and
        ``oof_predictions[target]`` is the averaged OOF probability array.
        Three empty dicts are returned when none of the requested targets
        exists in ``y_train``.
    """
    print(f"\nStarting model training with {n_repeats} repeats of {n_splits}-Fold CV and Tuning ({n_tuning_iter} iterations)...")
    if target_to_process:
        print(f"  Processing specific targets: {target_to_process}")

    # --- Determine targets first so invalid requests fail fast ---
    targets_in_data = y_train.columns.tolist()
    if target_to_process:
        targets_to_run = [t for t in target_to_process if t in targets_in_data]
        if not targets_to_run:
            print("Error: Specified targets are not in the data.")
            return {}, {}, {}
    else:
        targets_to_run = targets_in_data

    final_models = {}
    tuning_results = {}
    oof_predictions = {}  # final averaged OOF predictions per target

    selected_features_dict = selected_features_dict or {}
    # Coerce once so that weights[idx] is positional even for a pandas Series.
    weights = None if sample_weights is None else np.asarray(sample_weights)
    # Tuning order matters only for ties (the first model reaching the best
    # AUC wins).
    models_to_tune = ['rf', 'logistic', 'xgb', 'lgbm']

    for target in targets_to_run:
        print(f"\n--- Processing Target: {target} ---")

        # --- Feature subset for this target ---
        requested_features = selected_features_dict.get(target)
        if not requested_features:
            features_for_target = X_train.columns.tolist()
            print(f"Warning: No features for {target}, using all {len(features_for_target)}.")
        else:
            features_for_target = [f for f in requested_features if f in X_train.columns]

        if not features_for_target:
            print(f"Error: No valid features for {target}.")
            continue

        X_target_full = X_train[features_for_target]
        y_target = y_train[target]
        print(f"  Using {len(features_for_target)} features for {target}.")

        # --- Hyperparameter tuning with Optuna ---
        print(f"  Tuning hyperparameters using Optuna ({n_tuning_iter} trials)...")
        best_overall_auc = -1
        best_model_name = None
        best_params = None
        target_tuning_results = {}

        print(f"    Models to tune: {models_to_tune}")

        for name in models_to_tune:
            print(f"    Tuning {name}...")
            objective_func = objective_map.get(name)

            if objective_func is None:
                print(f"    Error: No Optuna objective function defined for {name}. Skipping.")
                target_tuning_results[name] = {'best_score': 0, 'best_params': None}
                continue

            try:
                study = optuna.create_study(
                    direction='maximize',
                    study_name=f"{target}_{name}_tuning",
                    load_if_exists=False
                )
                study.optimize(
                    lambda trial: objective_func(trial, X_target_full, y_target, weights),
                    n_trials=n_tuning_iter,
                    show_progress_bar=True
                )

                current_best_auc = study.best_value
                current_best_params = study.best_params

                print(f"      Best AUC for {name}: {current_best_auc:.4f}")
                target_tuning_results[name] = {'best_score': current_best_auc, 'best_params': current_best_params}

                if current_best_auc > best_overall_auc:
                    best_overall_auc = current_best_auc
                    best_model_name = name
                    best_params = current_best_params

            except Exception as e:
                print(f"    Error during Optuna tuning for {name}: {e}")
                target_tuning_results[name] = {'best_score': 0.0, 'best_params': None}

        tuning_results[target] = target_tuning_results

        if best_model_name is None or best_params is None:
            print(f"Error: Optuna tuning failed for {target}. Could not find a best model.")
            continue

        print(f"\n  Selected best model: {best_model_name} (Tuning AUC: {best_overall_auc:.4f})")

        best_estimator = _instantiate_best_model(best_model_name, best_params)
        if best_estimator is None:
            print(f"Error: Unknown best model type {best_model_name}. Cannot instantiate.")
            # No OOF run happened for this target, so record neutral defaults
            # instead of reading variables that were never assigned.
            final_models[target] = {'model': None, 'features': features_for_target, 'threshold': 0.5, 'auc': 0.0, 'model_name': best_model_name}
            continue

        # --- Outer CV to get OOF predictions (RepeatedStratifiedKFold) ---
        print(f"  Generating OOF predictions ({n_repeats} repeats of {n_splits}-fold CV)...")
        n_samples = len(X_target_full)
        oof_preds_sum = np.zeros(n_samples)
        oof_counts = np.zeros(n_samples)

        outer_cv = RepeatedStratifiedKFold(n_splits=n_splits, n_repeats=n_repeats, random_state=42)
        fold_val_auc_scores = []  # AUCs of the folds that completed

        for fold, (train_idx, val_idx) in enumerate(outer_cv.split(X_target_full, y_target)):
            X_train_fold, X_val_fold = X_target_full.iloc[train_idx], X_target_full.iloc[val_idx]
            y_train_fold, y_val_fold = y_target.iloc[train_idx], y_target.iloc[val_idx]

            fold_model = clone(best_estimator)
            if hasattr(fold_model, 'random_state'):
                fold_model.random_state = 42 + fold

            # All four model types accept sample_weight in fit(); use it for
            # every one of them, exactly as the tuning objectives do.
            fit_params_outer = {}
            if weights is not None:
                fit_params_outer['sample_weight'] = weights[train_idx]

            try:
                fold_model.fit(X_train_fold, y_train_fold, **fit_params_outer)
                fold_pred_proba = fold_model.predict_proba(X_val_fold)[:, 1]

                # Each sample is validated once per repeat; accumulate so the
                # repeats can be averaged afterwards.
                oof_preds_sum[val_idx] += fold_pred_proba
                oof_counts[val_idx] += 1

                fold_val_auc_scores.append(roc_auc_score(y_val_fold, fold_pred_proba))
            except Exception as e:
                print(f"    Error OOF fold {fold+1}: {e}")

        valid_oof_mask = oof_counts > 0
        oof_preds_final = np.zeros(n_samples)
        oof_preds_final[valid_oof_mask] = oof_preds_sum[valid_oof_mask] / oof_counts[valid_oof_mask]

        if not np.all(valid_oof_mask):
            print(f"Warning: {np.sum(~valid_oof_mask)} samples have 0 OOF predictions.")

        oof_predictions[target] = oof_preds_final

        final_cv_mean_auc = np.mean(fold_val_auc_scores) if fold_val_auc_scores else 0.0
        final_cv_std_auc = np.std(fold_val_auc_scores) if fold_val_auc_scores else 0.0
        print(f"  Outer {n_repeats}x{n_splits}-Fold CV Mean AUC for {best_model_name}: {final_cv_mean_auc:.4f} (+/- {final_cv_std_auc:.4f})")

        # --- Threshold and OOF diagnostics ---
        print(f"  Finding optimal threshold using averaged OOF predictions...")
        optimal_threshold = find_optimal_threshold(y_target, oof_preds_final, metric='f1')

        print(f"\n  Calculating OOF metrics using threshold: {optimal_threshold:.4f}")
        oof_preds_binary = (oof_preds_final >= optimal_threshold).astype(int)

        try:
            cm = confusion_matrix(y_target, oof_preds_binary)
            print("  OOF Confusion Matrix:\n", cm)
            tn, fp, fn, tp = cm.ravel()
            print(f"  TN: {tn}, FP: {fp}, FN: {fn}, TP: {tp}")
        except Exception as e_cm:
            print(f"  Error OOF CM: {e_cm}")

        try:
            report = classification_report(y_target, oof_preds_binary)
            print("\n  OOF Classification Report:\n", report)
        except Exception as e_report:
            print(f"  Error OOF Report: {e_report}")

        # --- Final model on the full training data ---
        print(f"\n  Training final {best_model_name} model on full data...")
        final_model_instance = clone(best_estimator)
        if hasattr(final_model_instance, 'random_state'):
            final_model_instance.random_state = 42

        fit_params_full = {}
        if weights is not None:
            fit_params_full['sample_weight'] = weights

        try:
            final_model_instance.fit(X_target_full, y_target, **fit_params_full)
        except Exception as e:
            print(f"Error training final model {best_model_name}: {e}")
            final_models[target] = {'model': None, 'features': features_for_target, 'threshold': optimal_threshold, 'auc': final_cv_mean_auc, 'model_name': best_model_name}
            continue

        final_models[target] = {'model': final_model_instance, 'features': features_for_target, 'threshold': optimal_threshold, 'auc': final_cv_mean_auc, 'model_name': best_model_name}
        print(f"  Stored final model for {target}.")

    print("\nFinished training and tuning.")
    return tuning_results, final_models, oof_predictions

# ====================================
# 11. Feature Importance Visualization
# ====================================
def plot_feature_importance(model, features, target, top_n=20):
    """Save a horizontal bar chart of a model's most important features.

    The chart is written to ``{target}_feature_importance.png`` in the current
    working directory. Nothing is returned; every problem is reported with
    ``print`` instead of being raised.

    Args:
        model: Fitted estimator. It must expose ``feature_importances_``;
            otherwise a message is printed and no plot is produced.
        features: Names matching ``model.feature_importances_`` position by
            position. When missing, empty, or of the wrong length, the names
            stored on the model (``feature_names_in_`` or ``feature_name_``)
            are used instead, if their length matches.
        target: Label used in the plot title and in the output file name.
        top_n: Maximum number of features to display.
    """
    if model is None:
        print(f"Cannot plot importance for {target}, model is None.")
        return
    if not hasattr(model, 'feature_importances_'):
        print(f"Warning: Model {model.__class__.__name__} for {target} does not support feature importance.")
        return

    # Only close a figure that this call opened, so skipping the plot never
    # closes an unrelated figure that happens to be active.
    figure_opened = False
    try:
        importances = np.asarray(model.feature_importances_)
        # Normalise to a list so None, pandas Index and ndarray inputs all work.
        names = [] if features is None else list(features)

        if not names or len(names) != len(importances):
            print(f"Warning: Feature list length mismatch for {target} ({len(names)} vs {len(importances)}). Plot may be incorrect.")
            internal_names = None
            if hasattr(model, 'feature_names_in_'):
                internal_names = np.asarray(model.feature_names_in_).tolist()
            elif hasattr(model, 'feature_name_'):
                # LightGBM's sklearn wrapper exposes this as a property, while
                # other objects may offer a method: accept both.
                internal_names = model.feature_name_
                if callable(internal_names):
                    internal_names = internal_names()
                internal_names = list(internal_names)
            if internal_names and len(internal_names) == len(importances):
                print("  Using model's internal feature names.")
                names = internal_names
            else:
                print("  Skipping plot.")
                return

        plot_n = min(top_n, len(names))
        # Descending importance, truncated to the requested number of bars.
        top_indices = np.argsort(importances)[::-1][:plot_n]

        figure_opened = True
        plt.figure(figsize=(10, max(6, plot_n * 0.3)))
        plt.title(f'{target} - Top {plot_n} Feature Importance ({model.__class__.__name__})')
        # Reverse so the most important feature is drawn at the top of the chart.
        plt.barh(range(len(top_indices)), importances[top_indices][::-1], align='center')
        plt.yticks(range(len(top_indices)), np.array(names)[top_indices][::-1])
        plt.xlabel('Relative Importance')
        plt.tight_layout()
        plt.savefig(f'{target}_feature_importance.png')
    except Exception as e_plot:
        print(f"Error plotting importance for {target}: {e_plot}")
    finally:
        if figure_opened:
            plt.close()


# ====================================
# 12. Test Set Prediction
# ====================================

# --- Helper Function: Calibrate Threshold ---
def calibrate_threshold(y_proba, target_ratio=0.685):
    """Pick a cut-off so that roughly ``target_ratio`` of samples score at or above it.

    Args:
        y_proba: Sequence or array of predicted probabilities.
        target_ratio: Desired fraction of samples to be labelled positive.

    Returns:
        0.5 for empty input; the shared value when all probabilities are
        (nearly) equal; a value just above the maximum / just below the
        minimum when the target count is zero / covers every sample; otherwise
        the midpoint between the last wanted and first unwanted probability
        (or just below them when they tie), clipped to [0, 1].
    """
    probs = y_proba if isinstance(y_proba, np.ndarray) else np.array(y_proba)
    total = len(probs)

    if total == 0:
        print("Warning (calibrate): Input is empty")
        return 0.5
    if np.all(np.isclose(probs, probs[0])):
        return probs[0]

    wanted = int(round(total * target_ratio))
    if wanted <= 0:
        # Nothing should be positive: sit just above the largest score.
        return np.max(probs) + 1e-7
    if wanted >= total:
        # Everything should be positive: sit just below the smallest score.
        return np.min(probs) - 1e-7

    descending = np.sort(probs)[::-1]
    first_excluded = descending[wanted]
    last_included = descending[wanted - 1]

    if np.isclose(last_included, first_excluded):
        # Tie at the boundary: go just below it so the tied scores count as positive.
        cut = first_excluded - 1e-7
    else:
        cut = (last_included + first_excluded) / 2.0
    return np.clip(cut, 0.0, 1.0)

# --- Main Prediction Function ---
def make_predictions(X_test, final_models, test_multimodal):
    """Predict every target in ``final_models`` on the test set and build a submission.

    Args:
        X_test: Test feature DataFrame. If it is not a DataFrame, an all-zero
            submission is returned (when participant ids are available).
        final_models: Mapping ``target -> info`` where ``info`` is a dict with
            the keys ``'model'`` (object exposing ``predict_proba``),
            ``'features'`` (list of column names) and optionally
            ``'threshold'``. ``None`` or an empty mapping yields zero
            predictions for the default targets ``ADHD_Outcome`` and ``Sex_F``.
        test_multimodal: DataFrame providing the ``participant_id`` column of
            the submission.

    Returns:
        ``(submission, predictions)``. ``submission`` has ``participant_id``
        plus one 0/1 column per target; ``predictions`` maps each processed
        target to ``{'probability': ..., 'prediction': ...}``. Returns
        ``(None, None)`` when participant ids are unavailable.

    Notes:
        ``ADHD_Outcome`` is thresholded with ``calibrate_threshold`` at a
        target ratio of 0.685; every other target uses its stored threshold.
        Any per-target failure is reported with ``print`` and replaced by zeros.
    """
    print("\nGenerating test set predictions...")
    models = final_models if final_models else {}
    expected_targets = list(models.keys()) if models else ['ADHD_Outcome', 'Sex_F']
    has_ids = test_multimodal is not None and 'participant_id' in test_multimodal.columns

    if X_test is None or not isinstance(X_test, pd.DataFrame):
        print("Error: Valid test data needed.")
        if not has_ids:
            print("Cannot create submission.")
            return None, None
        submission = pd.DataFrame({'participant_id': test_multimodal['participant_id']})
        # Fill every expected target before returning, not just the first one.
        for target in expected_targets:
            submission[target] = 0
        print("Returning default 0 submission.")
        return submission, {}
    if not has_ids:
        print("Error: test_multimodal required.")
        return None, None

    n_rows = len(X_test)

    def zero_prediction():
        # Placeholder used whenever a target cannot be predicted.
        return {'probability': np.zeros(n_rows), 'prediction': np.zeros(n_rows, dtype=int)}

    predictions = {}
    for target, model_info in models.items():
        print(f"\n--- Predicting for Target: {target} ---")
        model = model_info.get('model')
        stored_features = model_info.get('features')
        if model is None or stored_features is None or len(stored_features) == 0:
            print(f"Warning: Invalid model/features for {target}. Skipping.")
            predictions[target] = zero_prediction()
            continue

        # Ask the model which columns it expects and in which order.
        try:
            if hasattr(model, 'feature_name_'):
                # Property on LightGBM's sklearn wrapper; tolerate a method too.
                expected = model.feature_name_
                if callable(expected):
                    expected = expected()
                expected = list(expected)
            elif hasattr(model, 'feature_names_in_'):
                expected = model.feature_names_in_.tolist()
            else:
                print(f"Warning: Cannot get expected features for {target}. Using stored list.")
                expected = stored_features
            if not expected:
                print(f"Error: Model {target} has no expected features.")
                expected = stored_features
        except Exception as debug_e:
            print(f"Warning: Error getting expected features: {debug_e}.")
            expected = stored_features

        missing_features = [f for f in expected if f not in X_test.columns]
        if missing_features and expected is not stored_features:
            # Names reported by the model can differ from the DataFrame columns
            # (e.g. if the library sanitised them); the stored list holds the
            # original column names, so prefer it when it fully resolves.
            stored_resolves = all(f in X_test.columns for f in stored_features)
            if stored_resolves and len(stored_features) == len(expected):
                print(f"Warning: Model-reported feature names not found in X_test for {target}. Using stored list.")
                expected = stored_features
                missing_features = []
        if missing_features:
            print(f"Error: Features missing in X_test: {missing_features}. Skipping.")
            predictions[target] = zero_prediction()
            continue

        try:
            X_test_features = X_test[expected]
            print(f"  Prepared test data {X_test_features.shape} in model order.")
        except Exception as e_reorder:
            print(f"Error selecting/reordering cols: {e_reorder}. Skipping.")
            predictions[target] = zero_prediction()
            continue

        try:
            y_pred_proba = model.predict_proba(X_test_features)[:, 1]
            final_threshold = 0.5
            if target == 'ADHD_Outcome':
                try:
                    final_threshold = calibrate_threshold(y_pred_proba, target_ratio=0.685)
                    print(f"  Using calibrated threshold: {final_threshold:.4f}")
                except Exception as cal_e:
                    print(f"  Warning: Calibration error: {cal_e}. Using 0.5.")
            else:
                final_threshold = model_info.get('threshold', 0.5)
                if final_threshold is None:
                    # A stored None would break formatting and comparison below.
                    final_threshold = 0.5
                print(f"  Using optimal threshold: {final_threshold:.4f}")
            y_pred = (y_pred_proba >= final_threshold).astype(int)
            predictions[target] = {'probability': y_pred_proba, 'prediction': y_pred}
            print(f"  Prediction successful. Positive ratio: {np.mean(y_pred):.2f}")
        except Exception as e:
            print(f"Error during predict_proba for {target}: {e}")
            print(f"    Data shape: {X_test_features.shape}")
            predictions[target] = zero_prediction()

    print("\nCreating submission file...")
    submission = pd.DataFrame({'participant_id': test_multimodal['participant_id']})
    for target in expected_targets:
        pred_val = predictions.get(target, {}).get('prediction', np.zeros(len(submission), dtype=int))
        if len(pred_val) != len(submission):
            print(f"Warning: Prediction length mismatch for target {target}!")
            submission[target] = 0
        else:
            submission[target] = pred_val
    return submission, predictions



# ====================================
def _expected_feature_order(model, stored_features):
    """Return the column order a fitted model reports, else the stored feature list."""
    try:
        names = model.feature_names_in_.tolist() if hasattr(model, 'feature_names_in_') else None
    except Exception:
        # Best-effort lookup only: any failure falls back to the stored list.
        names = None
    return names or stored_features


def _harmonize_connectomes(train_con, test_con, train_cat, test_cat, batch_col):
    """Harmonize train and test connectome features across ``batch_col`` with neuroHarmonize.

    Returns ``(train_harmonized, test_harmonized)`` (the second item is None when
    ``test_con`` is None), or None when ``batch_col`` is missing from the
    categorical data. Raises ImportError if neuroHarmonize is not installed;
    any other failure propagates to the caller.
    """
    from neuroHarmonize import harmonizationLearn

    print("  Applying Harmonization data harmonization...")
    if batch_col not in train_cat.columns or (test_cat is not None and batch_col not in test_cat.columns):
        print(f"Warning: Cannot find batch variable '{batch_col}' in categorical data, skipping Harmonization.")
        return None

    # One covariate row per participant, indexed by participant_id.
    train_covars = train_cat[['participant_id', batch_col]].copy()
    if test_cat is not None:
        test_covars = test_cat[['participant_id', batch_col]].copy()
    else:
        test_covars = pd.DataFrame(columns=['participant_id', batch_col])
    all_covars = (
        pd.concat([train_covars, test_covars], ignore_index=True)
        .drop_duplicates(subset=['participant_id'])
        .set_index('participant_id')
    )

    # Stack train and test connectomes so they are harmonized together.
    all_conn = train_con.copy().set_index('participant_id')
    if test_con is not None:
        all_conn = pd.concat([all_conn, test_con.copy().set_index('participant_id')], axis=0)

    # Keep only participants present in both tables, in the same row order.
    common_ids = all_covars.index.intersection(all_conn.index)
    print(f"  Found {len(common_ids)} common participants for Harmonization.")
    all_covars = all_covars.loc[common_ids]
    all_conn_aligned = all_conn.loc[common_ids]
    connectome_features = all_conn_aligned.columns.tolist()

    if all_covars[batch_col].isnull().any():
        mode_site = all_covars[batch_col].mode()[0]
        print(f"  Filling missing values in '{batch_col}' with mode: {mode_site}")
        # Assign the result back: in-place fillna on a column selection is
        # chained assignment and does not reliably update the parent frame.
        all_covars[batch_col] = all_covars[batch_col].fillna(mode_site)
    all_covars[batch_col] = all_covars[batch_col].astype(str)

    print(f"  Executing Harmonization using '{batch_col}' as the batch variable...")
    # neuroHarmonize expects the batch column to be named 'SITE'.
    covars_df = all_covars[[batch_col]].rename(columns={batch_col: 'SITE'})
    _, data_harmonized = harmonizationLearn(all_conn_aligned.values, covars_df, eb=True)
    print(f"  Harmonization complete.")

    harmonized_df = (
        pd.DataFrame(data_harmonized, index=common_ids, columns=connectome_features)
        .rename_axis('participant_id')
        .reset_index()
    )
    train_ids = train_con['participant_id'].unique().tolist()
    train_con_harmonized = harmonized_df[harmonized_df['participant_id'].isin(train_ids)]
    test_con_harmonized = None
    if test_con is not None:
        test_ids = test_con['participant_id'].unique().tolist()
        test_con_harmonized = harmonized_df[harmonized_df['participant_id'].isin(test_ids)]
    return train_con_harmonized, test_con_harmonized


def main():
    """Runs the complete WiDS Datathon 2025 ADHD and Gender Prediction pipeline (two-stage prediction version).

    Stage 1 trains a ``Sex_F`` model; its out-of-fold (train) and predicted
    (test) probabilities are added as a feature, together with interaction
    terms, before stage 2 trains the ``ADHD_Outcome`` model. The submission CSV
    and feature-importance charts are written under the Google Drive data
    directory.

    Returns:
        ``{'final_models': ..., 'submission': ...}`` on success, or ``None``
        when a required input or an essential stage fails.
    """
    print("WiDS Datathon 2025 - ADHD and Gender Prediction Model using Multimodal Data (Two-Stage Prediction Version)")
    print("======================================================================")

    # --- [File Path Setup and Check] ---
    drive_dir = '/content/drive/MyDrive/wids_datathon'
    if not os.path.exists(drive_dir):
        if IN_COLAB:
            print("Mounting Google Drive...")
            drive.mount('/content/drive')
        else:
            print(f"Error: Drive directory does not exist: {drive_dir}")
            return None
    train_cat_path = os.path.join(drive_dir, "TRAIN_CATEGORICAL_METADATA_new.xlsx")
    train_quant_path = os.path.join(drive_dir, "TRAIN_QUANTITATIVE_METADATA_new.xlsx")
    solutions_path = os.path.join(drive_dir, "TRAINING_SOLUTIONS.xlsx")
    test_cat_path = os.path.join(drive_dir, "TEST_CATEGORICAL.xlsx")
    test_quant_path = os.path.join(drive_dir, "TEST_QUANTITATIVE_METADATA.xlsx")
    train_connectome_path = os.path.join(drive_dir, "TRAIN_FUNCTIONAL_CONNECTOME_MATRICES_new_36P_Pearson.csv")
    test_connectome_path = os.path.join(drive_dir, "TEST_FUNCTIONAL_CONNECTOME_MATRICES.csv")
    file_paths = [
        train_cat_path, train_quant_path, solutions_path, test_cat_path,
        test_quant_path, train_connectome_path, test_connectome_path,
    ]
    missing = [p for p in file_paths if not os.path.exists(p)]
    if missing:
        print("Error: Some files do not exist.")
        print(f"Missing: {missing}")
        return None
    print("All files found in Google Drive. Starting analysis...")

    # --- 1. Load Data ---
    train_cat, train_quant, train_con, solutions, test_cat, test_quant, test_con = load_data(
        train_cat_path, train_quant_path, solutions_path, test_cat_path,
        test_quant_path, train_connectome_path, test_connectome_path
    )
    if train_cat is None or train_quant is None or solutions is None:
        print("Error: Core training data failed to load.")
        return None

    # --- 2. Merge ---
    train_merged, test_merged = preprocess_and_merge_data(train_cat, train_quant, solutions, test_cat, test_quant)
    if train_merged is None:
        print("Error: Training data merge failed.")
        return None

    # --- 3. Connectome Features (with Harmonization) ---
    print("\n--- Preparing Data for Harmonization (if applicable) ---")
    batch_col = 'MRI_Track_Scan_Location'  # Batch (site) variable used for harmonization
    apply_harmonization = True  # Set to False to skip harmonization
    # Raw data is used unless harmonization succeeds.
    train_con_to_process = train_con
    test_con_to_process = test_con

    if apply_harmonization and train_con is not None and train_cat is not None:
        try:
            harmonized = _harmonize_connectomes(train_con, test_con, train_cat, test_cat, batch_col)
            if harmonized is not None:
                train_con_to_process, test_con_to_process = harmonized
        except ImportError:
            print("Warning: neuroHarmonize not installed, skipping Harmonization")
        except Exception as e_harm:
            print(f"Warning: Harmonization failed: {e_harm}. Continuing with raw data...")
    elif apply_harmonization:
        print("\nInfo: Harmonization not performed (due to missing necessary data or apply_harmonization=False).")

    # --- Call process_connectome_data (using raw or harmonized data) ---
    train_connectome_features, test_connectome_features = process_connectome_data(
        train_con_to_process,
        test_con_to_process
    )
    if train_connectome_features is None:
        print("Warning: Failed to extract training connectome features.")

    # --- 4. Multimodal Fusion ---
    train_multimodal, test_multimodal = multimodal_fusion(
        train_merged, train_connectome_features, test_merged, test_connectome_features
    )
    if train_multimodal is None:
        print("Error: Training data fusion failed.")
        return None
    if test_multimodal is None and test_merged is not None:
        test_multimodal = test_merged[['participant_id']].copy()
    if test_multimodal is None:
        print("Error: Could not get test IDs")
        return None

    # --- 5. Initial Feature Engineering (no extra interaction features) ---
    # NOTE(review): the location column is dropped whenever harmonization is
    # enabled, even if it failed and raw connectome data is used - confirm
    # that this is intended.
    if apply_harmonization:
        if batch_col in train_multimodal.columns:
            train_multimodal = train_multimodal.drop(columns=[batch_col], errors='ignore')
        if test_multimodal is not None and batch_col in test_multimodal.columns:
            test_multimodal = test_multimodal.drop(columns=[batch_col], errors='ignore')
        print(f"Info: Removed location feature '{batch_col}' (assuming it was handled by Harmonization).")

    X_train_base, y_train, X_test_base, preprocessor, feature_names = feature_engineering(train_multimodal, test_multimodal)
    if X_train_base is None or y_train is None:
        print("Error: Feature engineering failed.")
        return None

    # --- 6. Initial Feature Selection ---
    print("\n--- Initial Feature Selection (using base features) ---")
    _, selected_features_dict = select_features(X_train_base, y_train, n_features=30)
    if not selected_features_dict:
        print("Warning: Initial feature selection failed.")
        selected_features_dict = {t: X_train_base.columns.tolist() for t in y_train.columns}

    # --- 7. Calculate Sample Weights ---
    sample_weights = calculate_optimized_weights(
        y_train, female_adhd_boost=2.0, female_non_adhd_boost=1.2,
        male_non_adhd_boost=1.5, male_adhd_boost=1.5
    )
    if sample_weights is None:
        print("Warning: Sample weight calculation failed.")

    # --- 8. Stage 1: Train Sex_F Model and Get Predictions ---
    print("\n===== Stage 1: Training Sex_F Model =====")
    sex_features_list = selected_features_dict.get('Sex_F', X_train_base.columns.tolist())
    _, final_models_sex, oof_predictions_sex = train_models_with_cv(
        X_train_base, y_train[['Sex_F']], {'Sex_F': sex_features_list},
        target_to_process=['Sex_F'], sample_weights=sample_weights,
        n_splits=5, n_repeats=5, n_tuning_iter=50, inner_cv_folds=5
    )
    if 'Sex_F' not in final_models_sex or final_models_sex['Sex_F']['model'] is None:
        print("Error: Sex model training failed.")
        return None
    oof_sex_proba = oof_predictions_sex.get('Sex_F')
    if oof_sex_proba is None:
        print("Error: Failed to get OOF sex predictions.")
        return None

    test_sex_proba = None
    if X_test_base is not None:
        final_sex_model = final_models_sex['Sex_F']['model']
        features_expected_sex = _expected_feature_order(final_sex_model, final_models_sex['Sex_F']['features'])
        missing_test_sex_feats = [f for f in features_expected_sex if f not in X_test_base.columns]
        if missing_test_sex_feats:
            print(f"  Test set is missing features required by the Sex model: {missing_test_sex_feats}")
        else:
            try:
                test_sex_proba = final_sex_model.predict_proba(X_test_base[features_expected_sex])[:, 1]
                print("  Successfully predicted Sex_F probabilities for the test set.")
            except Exception as e_pred_sex:
                print(f"  Error predicting Sex_F probabilities for test set: {e_pred_sex}")
    if test_sex_proba is None:
        print("Warning: Failed to generate Sex_F prediction probabilities for the test set.")

    # --- 9. Augment Feature Set (add sex_proba and interaction terms) ---
    print("\n--- Augmenting Features with Sex Predictions ---")
    X_train_aug = X_train_base.copy()
    X_train_aug['sex_proba_oof'] = oof_sex_proba
    X_test_aug = None
    if X_test_base is not None:
        X_test_aug = X_test_base.copy()
        if test_sex_proba is not None:
            X_test_aug['sex_proba_pred'] = test_sex_proba
        else:
            X_test_aug['sex_proba_pred'] = np.mean(oof_sex_proba)
            print("  Warning: Filling test set's sex_proba_pred with the training set's average probability")

    # Column names may carry prefixes, so match base features by substring and
    # use the first matching column of each.
    interaction_base_features_prefixed = []
    for base_name in ('MRI_Track_Age_at_Scan', 'SDQ_SDQ_Hyperactivity', 'skewness', 'min_degree'):
        found_names = [c for c in X_train_base.columns if base_name in c]
        if found_names:
            interaction_base_features_prefixed.append(found_names[0])
        else:
            print(f"Warning: Could not find base feature '{base_name}' for interaction.")
    print(f"  The following features will be used to create interaction terms with sex probability: {interaction_base_features_prefixed}")

    # --- Create Interaction Features ---
    added_train_interactions = []
    added_test_interactions = []
    for feat_prefixed in interaction_base_features_prefixed:
        interaction_name = f"{feat_prefixed}_x_sexproba"
        X_train_aug[interaction_name] = X_train_aug[feat_prefixed] * X_train_aug['sex_proba_oof']
        added_train_interactions.append(interaction_name)
        if X_test_aug is not None and feat_prefixed in X_test_aug.columns and 'sex_proba_pred' in X_test_aug.columns:
            X_test_aug[interaction_name] = X_test_aug[feat_prefixed] * X_test_aug['sex_proba_pred']
            added_test_interactions.append(interaction_name)
    print(f"  Added {len(added_train_interactions)} interaction features to the training set.")
    if X_test_aug is not None:
        print(f"  Added {len(added_test_interactions)} interaction features to the test set.")

    # --- Column Alignment ---
    if X_test_aug is not None:
        expected_aug_cols = X_train_aug.columns.tolist()
        # The test-side probability column takes the training column's name so
        # both frames expose the same feature names.
        X_test_aug_renamed = X_test_aug.rename(columns={'sex_proba_pred': 'sex_proba_oof'})
        available_cols = set(X_test_aug_renamed.columns)
        final_test_cols = [col for col in expected_aug_cols if col in available_cols]
        missing_final = [col for col in expected_aug_cols if col not in available_cols]
        if missing_final:
            print(f"Warning: Test set is still missing columns after final alignment: {missing_final}")
        X_test_aug = X_test_aug_renamed[final_test_cols]
        if X_test_aug.empty and not X_test_base.empty:
            print("Warning: Test set became empty after alignment!")

    # --- 10. Feature Selection Again (only for ADHD, using augmented feature set) ---
    print("\n--- Feature Selection for ADHD (using augmented features) ---")
    _, selected_features_dict_adhd = select_features(X_train_aug, y_train[['ADHD_Outcome']], n_features=20)
    if not selected_features_dict_adhd:
        print("Warning: ADHD feature selection failed.")
        selected_features_dict_adhd = {'ADHD_Outcome': X_train_aug.columns.tolist()}

    # --- 11. Stage 2: Train ADHD Model ---
    print("\n===== Stage 2: Training ADHD_Outcome Model =====")
    adhd_feature_dict_for_train = {
        'ADHD_Outcome': selected_features_dict_adhd.get('ADHD_Outcome', X_train_aug.columns.tolist())
    }
    _, final_models_adhd, _ = train_models_with_cv(
        X_train_aug, y_train[['ADHD_Outcome']], adhd_feature_dict_for_train,
        target_to_process=['ADHD_Outcome'], sample_weights=sample_weights,
        n_splits=5, n_repeats=5, n_tuning_iter=50, inner_cv_folds=5
    )
    if 'ADHD_Outcome' not in final_models_adhd or final_models_adhd['ADHD_Outcome']['model'] is None:
        print("Error: ADHD model training failed.")
        return None

    # --- 12. Generate Final Predictions ---
    print("\n--- Generating Final Predictions (Two-Stage) ---")
    final_predictions = {}
    # Sex_F prediction
    if test_sex_proba is not None:
        sex_threshold = final_models_sex['Sex_F'].get('threshold', 0.5)
        final_predictions['Sex_F'] = (test_sex_proba >= sex_threshold).astype(int)
        print(f"Sex_F prediction generated (Thr: {sex_threshold:.4f}). Ratio: {np.mean(final_predictions['Sex_F']):.2f}")
    else:
        final_predictions['Sex_F'] = np.zeros(len(test_multimodal), dtype=int)
        print("Warning: Sex_F prediction used defaults.")
    # ADHD prediction
    if X_test_aug is not None:
        final_adhd_model = final_models_adhd['ADHD_Outcome']['model']
        features_expected_adhd = _expected_feature_order(final_adhd_model, final_models_adhd['ADHD_Outcome']['features'])
        missing_test_adhd_feats = [f for f in features_expected_adhd if f not in X_test_aug.columns]
        if missing_test_adhd_feats:
            print(f"  Error: Test set missing features for ADHD: {missing_test_adhd_feats}")
        else:
            try:
                test_adhd_proba = final_adhd_model.predict_proba(X_test_aug[features_expected_adhd])[:, 1]
                adhd_threshold = calibrate_threshold(test_adhd_proba, target_ratio=0.685)
                final_predictions['ADHD_Outcome'] = (test_adhd_proba >= adhd_threshold).astype(int)
                print(f"ADHD prediction generated (Thr: {adhd_threshold:.4f}). Ratio: {np.mean(final_predictions['ADHD_Outcome']):.2f}")
            except Exception as e_pred_adhd:
                print(f"  Error predicting ADHD: {e_pred_adhd}")
    if 'ADHD_Outcome' not in final_predictions:
        print("Warning: ADHD prediction used defaults.")
        final_predictions['ADHD_Outcome'] = np.zeros(len(test_multimodal), dtype=int)

    # --- 13. Create and Save Submission File ---
    print("\nCreating final submission file (Two-Stage)...")
    submission = pd.DataFrame({'participant_id': test_multimodal['participant_id']})
    submission['ADHD_Outcome'] = final_predictions.get('ADHD_Outcome', 0)
    submission['Sex_F'] = final_predictions.get('Sex_F', 0)
    submission_path = os.path.join(drive_dir, "wids_datathon_submission_two_stage_final.csv")
    try:
        submission.to_csv(submission_path, index=False)
        print(f"\nSubmission file saved: {submission_path}")
        print("\nSubmission file preview:")
        print(submission.head())
    except Exception as e_save:
        print(f"\nError: Failed to save submission file: {e_save}")

    # --- 14. Visualization and Saving Charts ---
    print("\nGenerating feature importance plots...")
    charts_dir = os.path.join(drive_dir, "charts")
    if not os.path.exists(charts_dir):
        os.makedirs(charts_dir)
        print(f"Created charts directory: {charts_dir}")

    # Merge both stages so they can be handled in one loop.
    final_models_for_plot = {**final_models_sex, **final_models_adhd}

    for target, model_info in final_models_for_plot.items():
        model = model_info.get('model')
        features_list_for_model = model_info.get('features')
        if model is not None and features_list_for_model:
            print(f"  Generating plot for {target} (Model: {model_info.get('model_name', '?')})...")
            plot_feature_importance(model, features_list_for_model, f"{target}_final", top_n=20)

    # The plots are saved in the working directory; copy them next to the data.
    print(f"\nCopying charts to Google Drive ({charts_dir})...")
    for target in final_models_for_plot:
        plot_file = f"{target}_final_feature_importance.png"
        try:
            if os.path.exists(plot_file):
                shutil.copy(plot_file, os.path.join(charts_dir, plot_file))
                print(f"  Copied {plot_file} to charts directory")
            else:
                print(f"  Warning: {plot_file} not found")
        except Exception as e:
            print(f"  Error copying {plot_file}: {e}")

    # --- 15. Download Option ---
    if IN_COLAB:
        try:
            # Only ask, and only read the answer, when there is a file to download.
            if os.path.exists(submission_path):
                print("\nDownload submission file? (y/n):")
                download_response = input()
                if download_response.lower() == 'y':
                    files.download(submission_path)
        except Exception as e_download:
            print(f"Download prompt error: {e_download}")
    else:
        print("\nNot in Colab. Skipping download prompt.")

    # --- 16. Return Results ---
    print("\nPipeline finished (Two-Stage).")
    return {'final_models': final_models_for_plot, 'submission': submission}


# ====================================
# Script Execution Entry Point
# ====================================
# Run the pipeline only when this file is executed directly (not on import).
# main() returns a dict with 'final_models' and 'submission', or None when a
# stage fails; the result is kept at module level for later inspection.
if __name__ == "__main__":
    result_objects = main()