# Base methods and imports

In [None]:
import pyarrow.parquet as pq
import pyarrow as pa
import pandas as pd
import numpy as np
import os
import gc
from IPython.display import display

In [None]:
def load_parquet_as_df(file_path, columns=None):
    """
    Load a parquet file into a pandas DataFrame in a single read.

    The whole file (or the requested column subset) is read at once; there is
    no chunking. Failures are reported with ``print`` and turned into a
    ``None`` return instead of an exception, so callers must check the result.

    Args:
        file_path (str): Path to the parquet file.
        columns (list, optional): Subset of columns to read. ``None`` (the
            default) reads every column; reading fewer columns lowers peak
            memory for wide files.

    Returns:
        pandas.DataFrame or None: The loaded data, or ``None`` if the path does
        not exist or reading fails.
    """
    try:
        if not os.path.exists(file_path):
            print(f"Error: File not found at {file_path}")
            return None

        return pd.read_parquet(file_path, engine='pyarrow', columns=columns)

    except Exception as e:
        # Deliberate best-effort loader: report and hand back None.
        print(f"Error loading parquet file {file_path}: {e}")
        return None

    finally:
        gc.collect()


def save_df_as_parquet(dataframe, output_parquet_path, preserve_index=True):
    """
    Save a DataFrame as parquet using PyArrow.

    When the index is preserved, duplicate index values are reported (not
    removed) before writing.

    Args:
        dataframe (pd.DataFrame or None): DataFrame to save. ``None`` or an
            empty frame is reported and nothing is written.
        output_parquet_path (str): Path to save the parquet file
        preserve_index (bool): Whether to store the DataFrame index in the file

    Raises:
        Exception: Any error from the Arrow conversion or the write is printed
        and re-raised.
    """
    # Upstream helpers can hand back None (e.g. a failed load), so treat it
    # like "nothing to save" instead of crashing on `.empty`.
    if dataframe is None or dataframe.empty:
        print("No data to save!")
        return

    table = None
    try:
        if preserve_index:
            # Duplicates are only warned about; the file is still written.
            duplicate_mask = dataframe.index.duplicated()
            if duplicate_mask.any():
                print(f"Warning: Found {duplicate_mask.sum()} duplicate index values")
                duplicate_values = dataframe.index[duplicate_mask].unique()[:5]
                print(f"First few duplicates: {list(duplicate_values)}")

        table = pa.Table.from_pandas(dataframe, preserve_index=preserve_index)

        if preserve_index and dataframe.index.name:
            print(f"Preserving index: '{dataframe.index.name}' (dtype: {dataframe.index.dtype})")

        pq.write_table(table, output_parquet_path)

        print(f"Successfully saved {len(dataframe)} rows to parquet")
        if preserve_index:
            print(f"Total unique index values: {len(dataframe.index.unique())}")

    except Exception as e:
        print(f"Error saving DataFrame to parquet: {e}")
        raise

    finally:
        # Drop the Arrow copy of the data before collecting.
        del table
        gc.collect()


def transpose_df(dataframe, skip_rows=None, skip_columns=None, dtype='float32'):
    """
    Transpose a DataFrame with optional row/column filtering.

    The original column labels become the values of a leading ``sample_id``
    column in the result, and the original row labels become its columns.

    Args:
        dataframe (pd.DataFrame): DataFrame to transpose. ``None`` is treated
            like an empty frame.
        skip_rows: Rows to drop before transposing, or None. If the first
            entry is a string, entries are matched against index labels;
            otherwise they are treated as 0-based row positions.
        skip_columns: Column labels to drop before transposing, or None.
        dtype: Data type every remaining value is cast to (default: 'float32').

    Returns:
        pd.DataFrame: Transposed DataFrame with a ``sample_id`` column, or an
        empty DataFrame when there is nothing to transpose.

    Raises:
        Exception: Errors from the dtype conversion/transpose (e.g. non-numeric
        values) are printed and re-raised.
    """
    if dataframe is None or dataframe.empty:
        print("No data to transpose!")
        return pd.DataFrame()

    print("Starting DataFrame transpose...")

    original_rows, original_cols = dataframe.shape
    # No defensive copy needed: every step below builds a new frame and the
    # input is never mutated.
    df_filtered = dataframe

    if skip_columns is not None:
        # Boolean mask rather than a label list: selecting by a list of labels
        # would repeat columns whose label is duplicated.
        keep_columns = ~df_filtered.columns.isin(list(skip_columns))
        df_filtered = df_filtered.loc[:, keep_columns]
        print(f"Skipping columns: {skip_columns}")

    if skip_rows is not None:
        skip_list = list(skip_rows)
        if skip_list and isinstance(skip_list[0], str):
            keep_rows = ~df_filtered.index.isin(skip_list)
        else:
            skip_positions = set(skip_list)
            keep_rows = np.array(
                [pos not in skip_positions for pos in range(len(df_filtered))],
                dtype=bool,
            )
        # Same reason as above: `.loc[list_of_labels]` would multiply rows
        # whose index label is duplicated (e.g. repeated gene IDs).
        df_filtered = df_filtered[keep_rows]
        print(f"Skipping rows: {skip_rows}")

    filtered_rows, filtered_cols = df_filtered.shape

    print("Dataset dimensions:")
    print(f"    Original: {original_rows} rows x {original_cols} columns")
    print(f"    After filtering: {filtered_rows} rows x {filtered_cols} columns")
    print(f"    After transpose: {filtered_cols} rows x {filtered_rows} columns")

    try:
        print("Converting data types and transposing...")
        # Name the transposed index explicitly before reset_index(): that call
        # names the new column after the index name (if any), not 'index', so
        # renaming 'index' afterwards silently fails for a named columns axis.
        final_df = (
            df_filtered.astype(dtype)
            .T
            .rename_axis('sample_id')
            .reset_index()
        )

        print("Transposition completed successfully!")
        print(f"Final DataFrame shape: {final_df.shape}")

        return final_df

    except Exception as e:
        print(f"Error during transposition: {e}")
        raise

    finally:
        del df_filtered
        gc.collect()


# def convert_gene_ids_to_symbols(dataset_chunk_iterator, mygene_client, gene_column_prefix="ENSG"):
#     """
#     Convert gene IDs to gene symbols using MyGene API in a memory-efficient way.
#     Only calls MyGene API once for all gene columns from the first chunk.

#     Args:
#         dataset_chunk_iterator: Iterator yielding DataFrame chunks
#         mygene_client: MyGene client instance
#         gene_column_prefix (str): Prefix to identify gene columns

#     Returns:
#         pd.DataFrame: Complete dataset with gene symbols as column names
#     """
#     print("Converting gene IDs to symbols (chunked processing)...")

#     # Get first chunk to determine gene columns and create mapping
#     try:
#         first_chunk = next(dataset_chunk_iterator)
#         print(f"Processing first chunk: {first_chunk.shape}")
#     except StopIteration:
#         print("❌ Error: Dataset is empty")
#         return pd.DataFrame()

#     # Identify gene columns from first chunk
#     gene_columns = [col for col in first_chunk.columns
#                    if isinstance(col, str) and col.startswith(gene_column_prefix)]

#     if not gene_columns:
#         print("No gene columns found - returning original data")
#         # Concatenate all chunks and return
#         all_chunks = [first_chunk]
#         for chunk in dataset_chunk_iterator:
#             all_chunks.append(chunk)
#         return pd.concat(all_chunks, axis=0, ignore_index=False)

#     print(f"Found {len(gene_columns)} gene columns")

#     # Single MyGene API call for all gene IDs
#     gene_id_to_symbol = {}
#     symbols_found = 0

#     try:
#         print("Making single MyGene API call...")
#         import warnings
#         with warnings.catch_warnings():
#             warnings.simplefilter("ignore")
#             results = mygene_client.querymany(
#                 gene_columns,
#                 scopes='ensembl.gene',
#                 fields='symbol',
#                 species='human',
#                 verbose=False,
#                 silent=True
#             )

#         # Process API results
#         for result in results:
#             gene_id = result['query']
#             if 'symbol' in result and result['symbol']:
#                 gene_id_to_symbol[gene_id] = result['symbol']
#                 symbols_found += 1
#             else:
#                 gene_id_to_symbol[gene_id] = gene_id

#         print(f"✅ Successfully converted {symbols_found}/{len(gene_columns)} genes to symbols")

#     except Exception as e:
#         print(f"⚠️ MyGene API error: {e}. Using original gene IDs")
#         gene_id_to_symbol = {gene_id: gene_id for gene_id in gene_columns}

#     # Create final column mapping (preserve non-gene columns like 'condition')
#     final_column_mapping = {}
#     for col in first_chunk.columns:
#         if col == 'condition':
#             final_column_mapping[col] = col
#         elif col in gene_id_to_symbol:
#             final_column_mapping[col] = gene_id_to_symbol[col]
#         else:
#             final_column_mapping[col] = col

#     print("Applying gene symbol mapping to all chunks...")

#     # Process first chunk
#     renamed_first_chunk = first_chunk.rename(columns=final_column_mapping)
#     processed_chunks = [renamed_first_chunk]
#     chunk_count = 1

#     # Process remaining chunks with same mapping
#     for chunk in dataset_chunk_iterator:
#         renamed_chunk = chunk.rename(columns=final_column_mapping)
#         processed_chunks.append(renamed_chunk)
#         chunk_count += 1

#         if chunk_count % 3 == 0:
#             print(f"  ✓ Processed {chunk_count} chunks...")

#     print(f"Concatenating {chunk_count} processed chunks...")
#     final_dataset = pd.concat(processed_chunks, axis=0, ignore_index=False)

#     print(f"✅ Final dataset shape: {final_dataset.shape}")
#     print(f"   Symbols converted: {symbols_found}/{len(gene_columns)}")

#     return final_dataset


def add_condition_labels(dataframe, condition_label, dataset_name):
    """
    Return a copy of ``dataframe`` with an int8 ``condition`` column.

    The input frame is not modified. An empty input is returned as-is (same
    object) after a warning.

    Args:
        dataframe (pd.DataFrame): Samples to label
        condition_label (int): Label written to every row (0 = healthy,
            1 = unhealthy), stored as int8
        dataset_name (str): Name used only in the log messages

    Returns:
        pd.DataFrame: Labelled copy, or the original empty frame
    """
    if not dataframe.empty:
        result = dataframe.copy()
        result['condition'] = condition_label
        # Cast just the label column; every other column keeps its dtype.
        result = result.astype({'condition': np.int8})

        print(f"✅ Successfully added condition label to {len(result)} samples in {dataset_name} dataset")
        return result

    print(f"Warning: {dataset_name} dataset is empty!")
    return dataframe


def merge_labeled_datasets(healthy_dataframe, unhealthy_dataframe):
    """
    Stack the healthy and unhealthy datasets into one DataFrame.

    Args:
        healthy_dataframe (pd.DataFrame or None): Healthy samples
        unhealthy_dataframe (pd.DataFrame or None): Unhealthy samples

    Returns:
        pd.DataFrame: Healthy rows followed by unhealthy rows (original index
        kept). If one input is empty or None, a copy of the other is returned;
        if both are, an empty DataFrame is returned.
    """
    healthy_empty = healthy_dataframe is None or healthy_dataframe.empty
    unhealthy_empty = unhealthy_dataframe is None or unhealthy_dataframe.empty

    if healthy_empty and unhealthy_empty:
        print("Warning: Both datasets are empty!")
        return pd.DataFrame()

    # The warnings promise the non-empty side is returned, so actually return it.
    if healthy_empty:
        print("Warning: Healthy dataset is empty, returning only unhealthy data")
        merged_dataset = unhealthy_dataframe.copy()
    elif unhealthy_empty:
        print("Warning: Unhealthy dataset is empty, returning only healthy data")
        merged_dataset = healthy_dataframe.copy()
    else:
        merged_dataset = pd.concat([healthy_dataframe, unhealthy_dataframe], axis=0, ignore_index=False)

    print(f"✅ Merged dataset: {len(merged_dataset)} samples, {len(merged_dataset.columns)} features")

    if 'condition' in merged_dataset.columns:
        print(f"   Healthy (0): {(merged_dataset['condition'] == 0).sum()}")
        print(f"   Unhealthy (1): {(merged_dataset['condition'] == 1).sum()}")
    else:
        print("   Note: No 'condition' column found for condition counting")

    return merged_dataset


def clean_duplicate_nans(dataframe):
    """
    Drop every row containing a null, then every fully duplicated row.

    Duplicate detection compares row values only (the index is ignored), and
    the first occurrence is kept. Progress is reported with ``print``.

    Args:
        dataframe (pd.DataFrame): Frame to clean; it is not modified.

    Returns:
        pd.DataFrame: The cleaned frame. An empty input is returned as-is
        (same object).
    """
    print("Starting data cleaning (removing NaNs and duplicates)...")

    if dataframe.empty:
        print("Warning: DataFrame is empty!")
        return dataframe

    starting_count = len(dataframe)
    print(f"Original DataFrame shape: {dataframe.shape}")

    # Step 1: rows with any null value
    no_nan_df = dataframe.dropna()
    after_nan_count = len(no_nan_df)
    nan_removed = starting_count - after_nan_count
    print(f"Dropped {nan_removed} rows with null values" if nan_removed > 0 else "No null values found")

    if no_nan_df.empty:
        print("Warning: DataFrame is empty after dropping NaNs!")
        return no_nan_df

    # Step 2: exact duplicate rows
    deduped_df = no_nan_df.drop_duplicates()
    remaining = len(deduped_df)
    dup_removed = after_nan_count - remaining
    print(f"Dropped {dup_removed} duplicate rows" if dup_removed > 0 else "No duplicate rows found")

    if deduped_df.empty:
        print("Warning: DataFrame is empty after removing duplicates!")
        return deduped_df

    print(f"✅ Cleaning complete: {remaining} rows remaining ({starting_count - remaining} total rows removed)")
    print(f"Final DataFrame shape: {deduped_df.shape}")

    return deduped_df


def align_gene_columns(healthy_dataframe, unhealthy_dataframe, gene_column_prefix="ENSG"):
    """
    Align gene columns between healthy and unhealthy datasets.

    Gene columns are those whose (string) name starts with
    ``gene_column_prefix``. Everything after the first '.' is stripped as a
    version suffix (``ENSG0001.5`` -> ``ENSG0001``); when several columns of one
    dataset reduce to the same base ID, only the first is kept and the rest are
    dropped. Both outputs contain the healthy dataset's non-gene columns (in
    their original order) followed by the base IDs common to both datasets,
    sorted.

    Args:
        healthy_dataframe (pd.DataFrame): Healthy dataset
        unhealthy_dataframe (pd.DataFrame): Unhealthy dataset
        gene_column_prefix (str): Prefix to identify gene columns (default: "ENSG")

    Returns:
        tuple: (aligned_healthy_df, aligned_unhealthy_df) with identical
        columns, or two empty DataFrames if either input is empty, no genes are
        shared, or the column selection fails (e.g. the unhealthy dataset lacks
        one of the healthy dataset's non-gene columns).
    """
    print("Starting gene column alignment...")

    if healthy_dataframe.empty or unhealthy_dataframe.empty:
        print("ERROR: One or both datasets are empty!")
        return pd.DataFrame(), pd.DataFrame()

    print("\nProcessing gene columns and stripping version suffixes...")

    def is_gene_column(column):
        """True for string column names that start with the gene prefix."""
        return isinstance(column, str) and column.startswith(gene_column_prefix)

    def extract_gene_info(dataframe, dataset_name):
        """Map each kept base gene ID to the first original column that produced it."""
        print(f"Processing {dataset_name} dataset...")

        base_to_original = {}
        total_gene_columns = 0
        duplicate_count = 0

        for column in dataframe.columns:
            if not is_gene_column(column):
                continue
            total_gene_columns += 1
            base_gene_id = column.split('.')[0]

            if base_gene_id in base_to_original:
                duplicate_count += 1
                print(f"Warning: Duplicate base gene {base_gene_id} found, skipping {column}")
            else:
                base_to_original[base_gene_id] = column

        print(f"Total gene columns found: {total_gene_columns}")
        print(f"Unique base gene IDs: {len(base_to_original)}")
        if duplicate_count > 0:
            print(f"Duplicates removed: {duplicate_count}")

        return base_to_original

    healthy_base_to_original = extract_gene_info(healthy_dataframe, "HEALTHY")
    unhealthy_base_to_original = extract_gene_info(unhealthy_dataframe, "UNHEALTHY")

    common_base_genes = healthy_base_to_original.keys() & unhealthy_base_to_original.keys()

    if not common_base_genes:
        print("ERROR: No common genes found between datasets!")
        return pd.DataFrame(), pd.DataFrame()

    print(f"Common genes: {len(common_base_genes)}")
    print(f"Genes exclusive to healthy dataset: {len(healthy_base_to_original.keys() - common_base_genes)}")
    print(f"Genes exclusive to unhealthy dataset: {len(unhealthy_base_to_original.keys() - common_base_genes)}")

    healthy_non_gene_cols = [col for col in healthy_dataframe.columns if not is_gene_column(col)]
    unhealthy_non_gene_cols = [col for col in unhealthy_dataframe.columns if not is_gene_column(col)]

    if set(healthy_non_gene_cols) != set(unhealthy_non_gene_cols):
        print("Warning: Non-gene columns differ between datasets!")
        print(f"Healthy only: {set(healthy_non_gene_cols) - set(unhealthy_non_gene_cols)}")
        print(f"Unhealthy only: {set(unhealthy_non_gene_cols) - set(healthy_non_gene_cols)}")

    non_gene_cols = healthy_non_gene_cols
    common_gene_base_ids = sorted(common_base_genes)
    final_columns = non_gene_cols + common_gene_base_ids

    print("\nAligning datasets to common columns...")

    def select_and_relabel(dataframe, base_to_original):
        """Pick the kept columns by ORIGINAL name, then relabel them to final_columns."""
        # Renaming first and selecting afterwards would also pick up a skipped
        # duplicate whose original name equals a kept base ID (e.g. 'ENSG1'
        # next to 'ENSG1.5'), producing duplicated, misaligned columns.
        source_columns = non_gene_cols + [base_to_original[base] for base in common_gene_base_ids]
        return dataframe[source_columns].set_axis(final_columns, axis=1)

    try:
        aligned_healthy = select_and_relabel(healthy_dataframe, healthy_base_to_original)
        aligned_unhealthy = select_and_relabel(unhealthy_dataframe, unhealthy_base_to_original)
    except (KeyError, ValueError) as e:
        # KeyError: a healthy non-gene column is missing from the unhealthy data.
        # ValueError: duplicated source labels made the selection wider than final_columns.
        print(f"Error selecting columns: {e}")
        return pd.DataFrame(), pd.DataFrame()

    print("✅ Alignment complete!")
    print(f"   Final shape - Healthy: {aligned_healthy.shape}, Unhealthy: {aligned_unhealthy.shape}")
    print(f"   Final columns: {len(final_columns)} ({len(non_gene_cols)} non-gene + {len(common_gene_base_ids)} gene)")

    return aligned_healthy, aligned_unhealthy


# def prepare_metadata_generator(aligned_pair_iterator, output_metadata_path="data/merged_metadata.pq"):
#     """
#     Generator that creates metadata from aligned chunk pairs and saves to parquet.

#     Args:
#         aligned_chunk_pairs_generator: Generator yielding (healthy_chunk, unhealthy_chunk) tuples
#         output_metadata_path: Path to save metadata parquet file

#     Yields:
#         tuple: (healthy_chunk, unhealthy_chunk) - passes through the original chunks
#     """

#     all_metadata = []

#     for healthy_chunk, unhealthy_chunk in aligned_pair_iterator:
#         chunk_metadata = []

#         if healthy_chunk is not None:
#             healthy_sample_ids = [str(sid) for sid in healthy_chunk.index.tolist()]
#             healthy_metadata = [{'sample_id': sid, 'condition': 0} for sid in healthy_sample_ids]
#             chunk_metadata.extend(healthy_metadata)

#         if unhealthy_chunk is not None:
#             unhealthy_sample_ids = [str(sid) for sid in unhealthy_chunk.index.tolist()]
#             unhealthy_metadata = [{'sample_id': sid, 'condition': 1} for sid in unhealthy_sample_ids]
#             chunk_metadata.extend(unhealthy_metadata)

#         all_metadata.extend(chunk_metadata)

#         # Yield the chunks unchanged (pass-through)
#         yield healthy_chunk, unhealthy_chunk
#         del healthy_chunk, unhealthy_chunk

#     if all_metadata:
#         metadata_df = pd.DataFrame(all_metadata)
#         metadata_df.to_parquet(output_metadata_path, index=False)
#         print(f"Metadata saved to {output_metadata_path} with {len(metadata_df)} records")
#     else:
#         print("No metadata to save")

#     # Clean up
#     if 'all_metadata' in locals() or 'all_metadata' in globals():
#         del all_metadata
#     if 'metadata_df' in locals() or 'metadata_df' in globals():
#         del metadata_df
#     gc.collect()

# Unhealthy Preprocessing

In [None]:
unhealthy_dataset_file = 'data/unhealthy_data.pq'
unhealthy_output_file = 'data/unhealthy_data_preprocessed.pq'

print(f"Starting preprocessing for Unhealthy Dataset: {unhealthy_dataset_file}...")
unhealthy_df = load_parquet_as_df(unhealthy_dataset_file)
# load_parquet_as_df prints and returns None on failure; stop here with a clear
# message instead of failing later with a less obvious error.
if unhealthy_df is None:
    raise RuntimeError(f"Could not load {unhealthy_dataset_file}")

# Transpose so the original columns become rows, then index them by 'sample_id'.
unhealthy_df = transpose_df(unhealthy_df, dtype='float32')
unhealthy_df = unhealthy_df.set_index('sample_id', drop=True)
unhealthy_df = clean_duplicate_nans(unhealthy_df)
unhealthy_df = add_condition_labels(unhealthy_df, condition_label=1, dataset_name='Unhealthy')

save_df_as_parquet(unhealthy_df, unhealthy_output_file)

# Clean up
del unhealthy_df
gc.collect()

## Display unhealthy dataset info

In [None]:
unhealthy_df = load_parquet_as_df('data/unhealthy_data_preprocessed.pq')
# load_parquet_as_df prints and returns None on failure; fail with a clear
# message instead of an AttributeError on the first attribute access below.
if unhealthy_df is None:
    raise RuntimeError("Could not load 'data/unhealthy_data_preprocessed.pq'; run the unhealthy preprocessing cell first.")

print(f"Shape: {unhealthy_df.shape}")
print(f"Columns: {list(unhealthy_df.columns)}")
print(f"Data types:\n{unhealthy_df.dtypes}")
print(f"Memory usage: {unhealthy_df.memory_usage(deep=True).sum() / 1024**2:.2f} MB")

print("\nFirst 5 rows:")
display(unhealthy_df.head())

# Clean up
del unhealthy_df
gc.collect()

# Healthy Preprocessing

In [None]:
healthy_dataset_file = 'data/healthy_data.pq'
healthy_output_file = 'data/healthy_data_preprocessed.pq'
chunk_size = 1000  # NOTE(review): unused here (the file is loaded whole); kept in case later cells read it

print(f"Starting preprocessing for Healthy Dataset: {healthy_dataset_file}...")
healthy_df = load_parquet_as_df(healthy_dataset_file)
# load_parquet_as_df prints and returns None on failure; stop here with a clear
# message instead of failing later with a less obvious error.
if healthy_df is None:
    raise RuntimeError(f"Could not load {healthy_dataset_file}")

# Transpose so the original columns become rows, then index them by 'sample_id'.
healthy_df = transpose_df(healthy_df, dtype='float32')
healthy_df = healthy_df.set_index('sample_id', drop=True)
healthy_df = clean_duplicate_nans(healthy_df)
healthy_df = add_condition_labels(healthy_df, condition_label=0, dataset_name='Healthy')

save_df_as_parquet(healthy_df, healthy_output_file)

# Clean up
del healthy_df
gc.collect()

## Display healthy dataset info

In [None]:
healthy_df = load_parquet_as_df('data/healthy_data_preprocessed.pq')
# load_parquet_as_df prints and returns None on failure; fail with a clear
# message instead of an AttributeError on the first attribute access below.
if healthy_df is None:
    raise RuntimeError("Could not load 'data/healthy_data_preprocessed.pq'; run the healthy preprocessing cell first.")

print(f"Shape: {healthy_df.shape}")
print(f"Columns: {list(healthy_df.columns)}")
print(f"Data types:\n{healthy_df.dtypes}")
print(f"Memory usage: {healthy_df.memory_usage(deep=True).sum() / 1024**2:.2f} MB")

print("\nFirst 5 rows:")
display(healthy_df.head())

# Clean up
del healthy_df
gc.collect()

# Merge datasets

In [None]:
healthy_data_path = 'data/healthy_data_preprocessed.pq'
unhealthy_data_path = 'data/unhealthy_data_preprocessed.pq'
merged_data_path = 'data/merged_dataset.pq'

healthy_df = load_parquet_as_df(healthy_data_path)
unhealthy_df = load_parquet_as_df(unhealthy_data_path)
# load_parquet_as_df prints and returns None on failure; align_gene_columns
# cannot handle None, so stop with a clear message.
if healthy_df is None or unhealthy_df is None:
    raise RuntimeError("Could not load the preprocessed datasets; run the preprocessing cells first.")

healthy_df, unhealthy_df = align_gene_columns(healthy_df, unhealthy_df)
merged_df = merge_labeled_datasets(healthy_df, unhealthy_df)

# The merge can yield None or an empty frame (e.g. when alignment failed);
# only write a file when there is something to write.
if merged_df is None or merged_df.empty:
    print("Nothing to save: merged dataset is empty")
else:
    save_df_as_parquet(merged_df, merged_data_path)

# Clean up
del healthy_df, unhealthy_df, merged_df
gc.collect()

## Display basic merged dataset info

In [None]:
merged_df = load_parquet_as_df('data/merged_dataset.pq')
# load_parquet_as_df prints and returns None on failure; fail with a clear
# message instead of an AttributeError on the first attribute access below.
if merged_df is None:
    raise RuntimeError("Could not load 'data/merged_dataset.pq'; run the merge cell first.")

print(f"Shape: {merged_df.shape}")
print(f"Columns: {list(merged_df.columns)}")
print(f"Data types:\n{merged_df.dtypes}")
print(f"Memory usage: {merged_df.memory_usage(deep=True).sum() / 1024**2:.2f} MB")

print("\nFirst and Last 5 rows:")
display(merged_df)

# Clean up
del merged_df
gc.collect()

# Train/Test Split

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import os

In [None]:
def create_train_test_split(merged_dataset_path, output_dir="data/splits",
                            test_size=0.2, random_state=42, create_validation=False,
                            val_size=0.25, visualize=True):
    """
    Create stratified train/test (and optionally validation) splits from the
    merged dataset (before feature selection) and save them as parquet files.

    Args:
        merged_dataset_path (str): Path to merged dataset (with condition column)
        output_dir (str): Directory to save split datasets (created if missing)
        test_size (float): Proportion of all samples put in the test set (default: 0.2)
        random_state (int): Random state for reproducibility
        create_validation (bool): Whether to carve a validation set out of the training data
        val_size (float): Proportion of the *training* samples moved to the
            validation set when ``create_validation`` is True (default: 0.25,
            i.e. 20% of all samples when ``test_size=0.2``)
        visualize (bool): Whether to draw the PCA split plot via
            ``create_split_visualization`` (default: True)

    Returns:
        dict: Split name -> saved file path. Keys are 'X_train', ('X_val'),
        'X_test', 'y_train', ('y_val'), 'y_test'; the '_val' entries exist only
        with ``create_validation``. Feature files are ``<name>_raw.pq`` and
        label files ``<name>.pq``.

    Raises:
        ValueError: If the merged dataset has no 'condition' column.
    """
    print("Loading merged dataset for early train-test splitting...")

    df = pd.read_parquet(merged_dataset_path)

    print(f"Dataset shape: {df.shape}")

    if 'condition' not in df.columns:
        raise ValueError("Condition column not found in merged dataset!")

    print(f"Condition distribution: {df['condition'].value_counts()}")

    os.makedirs(output_dir, exist_ok=True)

    # Features are every column except the label.
    X = df.drop('condition', axis=1)
    y = df['condition']
    del df  # X and y hold everything still needed
    n_total = len(X)

    print(f"Features shape: {X.shape}")
    print(f"Labels shape: {y.shape}")

    # Stratify so each split keeps the overall healthy/unhealthy ratio.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y,
        test_size=test_size,
        random_state=random_state,
        stratify=y
    )

    print("\nTrain-Test Split Results:")
    print(f"Training set: {len(X_train)} samples ({len(X_train)/n_total*100:.1f}%)")
    print(f"Test set: {len(X_test)} samples ({len(X_test)/n_total*100:.1f}%)")
    print(f"Training condition distribution: {y_train.value_counts()}")
    print(f"Test condition distribution: {y_test.value_counts()}")

    X_val = y_val = None
    if create_validation:
        # The validation set is taken from the training split only, so the
        # test split stays untouched.
        X_train, X_val, y_train, y_val = train_test_split(
            X_train, y_train,
            test_size=val_size,
            random_state=random_state,
            stratify=y_train
        )

        print("\nWith Validation Set:")
        print(f"Final training: {len(X_train)} samples ({len(X_train)/n_total*100:.1f}%)")
        print(f"Validation: {len(X_val)} samples ({len(X_val)/n_total*100:.1f}%)")
        print(f"Test: {len(X_test)} samples ({len(X_test)/n_total*100:.1f}%)")

    # Save features first, then labels, each in train/(val)/test order.
    split_order = ['train', 'val', 'test'] if create_validation else ['train', 'test']
    features = {'train': X_train, 'val': X_val, 'test': X_test}
    labels = {'train': y_train, 'val': y_val, 'test': y_test}

    file_paths = {}
    for split in split_order:
        x_path = os.path.join(output_dir, f"X_{split}_raw.pq")
        features[split].to_parquet(x_path, index=True)
        file_paths[f"X_{split}"] = x_path
    for split in split_order:
        y_path = os.path.join(output_dir, f"y_{split}.pq")
        labels[split].to_frame('condition').to_parquet(y_path, index=True)
        file_paths[f"y_{split}"] = y_path

    split_description = "train/validation/test" if create_validation else "train/test"
    print(f"\nSaved {split_description} splits to {output_dir}/")

    if visualize:
        create_split_visualization(X_train, X_test, y_train, y_test, X_val, y_val)

    return file_paths


def create_split_visualization(X_train, X_test, y_train, y_test, X_val=None, y_val=None):
    """
    Plot a 2-component PCA projection of all split samples.

    Features from every split are standardised and projected together, so all
    splits share the same axes. The left panel colours points by split; the
    right panel colours them by condition (0 = Healthy, 1 = Unhealthy).

    Args:
        X_train, X_test (pd.DataFrame): Feature matrices of the splits.
        y_train, y_test (pd.Series): Condition labels aligned with the features.
        X_val, y_val: Optional validation features/labels; the validation split
            is plotted only when both are given.
    """
    print("\nCreating split visualization...")

    # Assemble the parts in a fixed order so labels, colours and legend entries
    # are identical on every run (iterating a set would not guarantee that).
    feature_parts = [X_train]
    label_parts = [y_train]
    split_names = ['Train']
    if X_val is not None and y_val is not None:
        feature_parts.append(X_val)
        label_parts.append(y_val)
        split_names.append('Validation')
    feature_parts.append(X_test)
    label_parts.append(y_test)
    split_names.append('Test')

    X_combined = pd.concat(feature_parts)
    y_combined = pd.concat(label_parts)
    split_labels = np.repeat(split_names, [len(part) for part in feature_parts])

    # Fit scaler + PCA on all splits together so they share one projection;
    # this is for display only and returns nothing to the caller.
    X_scaled = StandardScaler().fit_transform(X_combined)
    # random_state pins the randomized SVD solver PCA may select for wide data.
    pca = PCA(n_components=2, random_state=0)
    pca_result = pca.fit_transform(X_scaled)
    pc1_label = f'PC1 ({pca.explained_variance_ratio_[0]:.2%} variance)'
    pc2_label = f'PC2 ({pca.explained_variance_ratio_[1]:.2%} variance)'

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Plot 1: coloured by split (each split keeps the same colour across runs)
    split_colors = {'Train': '#1f77b4', 'Test': '#ff7f0e', 'Validation': '#2ca02c'}
    for split_type in split_names:
        mask = split_labels == split_type
        ax1.scatter(pca_result[mask, 0], pca_result[mask, 1],
                    c=split_colors[split_type], label=split_type, alpha=0.7, s=50)

    ax1.set_xlabel(pc1_label)
    ax1.set_ylabel(pc2_label)
    ax1.set_title('Data Split Visualization')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Plot 2: coloured by condition
    condition_colors = {0: '#2E8B57', 1: '#DC143C'}  # Green for healthy, red for unhealthy
    condition_values = y_combined.to_numpy()
    for condition, label in ((0, 'Healthy'), (1, 'Unhealthy')):
        mask = condition_values == condition
        ax2.scatter(pca_result[mask, 0], pca_result[mask, 1],
                    c=condition_colors[condition], label=label, alpha=0.7, s=50)

    ax2.set_xlabel(pc1_label)
    ax2.set_ylabel(pc2_label)
    ax2.set_title('Condition Distribution Across Splits')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print("Split visualization complete!")

In [None]:
# Split the merged dataset into stratified train/test sets (no validation set)
# and list the files that were written.
# NOTE(review): this reads 'data/merged_dataset_tpm_normalized.pq', but the
# "Merge datasets" cell in this notebook writes 'data/merged_dataset.pq'.
# Presumably a normalisation step elsewhere produces this file — confirm
# before running.
split_paths = create_train_test_split(
    merged_dataset_path="data/merged_dataset_tpm_normalized.pq",
    output_dir="data/splits",
    test_size=0.2,
    random_state=42,
    create_validation=False
)

print("Split files created:")
for split_name, path in split_paths.items():
    print(f"  {split_name}: {path}")

# Log2 Transformation

In [None]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
def apply_log_transformation_to_splits(split_paths, output_dir="data/splits"):
    """
    Apply a log2(x + 1) transform to every feature split and save the results.

    Entries whose key starts with 'X_' are read, transformed and written to
    ``output_dir``; a 'condition' column, if present, is copied through
    untransformed and ends up as the last column. All other entries (e.g. the
    label files) are passed through unchanged.

    Output files are named after their input: ``<name>_raw.pq`` becomes
    ``<name>_log2.pq``, and any other ``<stem><ext>`` becomes
    ``<stem>_log2<ext>``, so a transformed file never overwrites its input.

    Args:
        split_paths (dict): Split name -> path of the raw split file
        output_dir (str): Directory to save transformed splits (created if missing)

    Returns:
        dict: Same keys as ``split_paths``; 'X_*' entries point at the new files.
    """
    os.makedirs(output_dir, exist_ok=True)

    transformed_paths = {}

    for split_name, file_path in split_paths.items():
        if not split_name.startswith('X_'):
            transformed_paths[split_name] = file_path
            continue

        print(f"Transforming {split_name}...")

        df = pd.read_parquet(file_path)

        # np.log2 on a DataFrame returns a DataFrame with the same index/columns.
        if 'condition' in df.columns:
            transformed_df = np.log2(df.drop('condition', axis=1) + 1)
            transformed_df['condition'] = df['condition']
        else:
            transformed_df = np.log2(df + 1)

        # Always derive a name that differs from the input: a plain
        # str.replace would leave names without '_raw.pq' unchanged and
        # overwrite the raw file when output_dir is its directory.
        base_name = os.path.basename(file_path)
        if base_name.endswith('_raw.pq'):
            filename = base_name[:-len('_raw.pq')] + '_log2.pq'
        else:
            stem, ext = os.path.splitext(base_name)
            filename = f"{stem}_log2{ext}"
        output_path = os.path.join(output_dir, filename)

        transformed_df.to_parquet(output_path, index=True)
        transformed_paths[split_name] = output_path

        print(f"  Saved to: {output_path}")

        del df, transformed_df

    print("✅ Log transformation complete!")
    return transformed_paths


def visualize_log_transformation(raw_split_paths, sample_size=100000, random_state=None):
    """
    Plot histograms of raw vs. log2(x + 1)-transformed training values.

    Individual matrix entries of the numeric feature columns of ``X_train``
    are sampled (without replacement), transformed, and shown side by side;
    the mean and standard deviation of both samples are then printed.

    Args:
        raw_split_paths (dict): Paths to raw split files; must contain 'X_train'.
        sample_size (int): Maximum number of values to sample for plotting
        random_state (int | None): Seed for the sampling RNG. ``None`` (the
            default) keeps sampling unseeded, as before.
    """
    print("Creating log transformation visualization...")

    # Load training data for visualization
    X_train_raw = pd.read_parquet(raw_split_paths['X_train'])

    print(f"Loaded training data: {X_train_raw.shape}")

    # Only numeric expression features belong in the value distribution: a
    # 'condition' label column (kept untransformed by the log step) or any
    # non-numeric column would otherwise be mixed in or break the transform.
    features = X_train_raw.select_dtypes(include='number')
    if 'condition' in features.columns:
        features = features.drop(columns='condition')
    raw_values = features.to_numpy(dtype=float).ravel()

    # Generator.choice(replace=False) avoids permuting the entire flattened
    # matrix, which the legacy np.random.choice does.
    rng = np.random.default_rng(random_state)
    if raw_values.size > sample_size:
        raw_sample = rng.choice(raw_values, size=sample_size, replace=False)
    else:
        raw_sample = raw_values

    # Histogram binning cannot handle NaN/inf, so drop them from the sample.
    raw_sample = raw_sample[np.isfinite(raw_sample)]
    if raw_sample.size == 0:
        print("No finite numeric values to plot.")
        del X_train_raw, features, raw_values, raw_sample
        gc.collect()
        return

    # Apply log transformation
    log_sample = np.log2(raw_sample + 1)

    print(f"Sampled {len(raw_sample)} values for visualization")

    # Create plots
    plt.style.use('default')
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))

    # Raw data distribution (log y-axis: raw TPM values are heavily skewed)
    axes[0].hist(raw_sample, bins=50, alpha=0.7, color='skyblue', edgecolor='black')
    axes[0].set_title('Distribution of Raw TPM Values (Training Data)')
    axes[0].set_xlabel('Raw TPM Value')
    axes[0].set_ylabel('Frequency')
    axes[0].set_yscale('log')
    axes[0].grid(True, alpha=0.3)

    # Log-transformed data distribution
    axes[1].hist(log_sample, bins=50, alpha=0.7, color='lightcoral', edgecolor='black')
    axes[1].set_title('Distribution of Log2(TPM + 1) Values (Training Data)')
    axes[1].set_xlabel('Log2(TPM + 1) Value')
    axes[1].set_ylabel('Frequency')
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Print summary statistics
    print("\nSummary Statistics:")
    print(f"Raw data - Mean: {np.mean(raw_sample):.2f}, Std: {np.std(raw_sample):.2f}")
    print(f"Log-transformed - Mean: {np.mean(log_sample):.2f}, Std: {np.std(log_sample):.2f}")

    # Clean up
    del X_train_raw, features, raw_values, raw_sample, log_sample
    gc.collect()

In [None]:
# Plot a 50k-value sample of raw vs. log2(x + 1) training values first.
visualize_log_transformation(split_paths, sample_size=50000)

# Write log2-transformed copies of the X_* splits; other entries (labels)
# keep their original paths.
transformed_paths = apply_log_transformation_to_splits(split_paths)

for split_name, path in transformed_paths.items():
    print(f"  {split_name}: {path}")

# Batch effect correction

In [None]:
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from inmoose.pycombat import pycombat_norm
import pandas as pd
import numpy as np
import os

In [None]:
def apply_batch_correction_to_splits(transformed_paths, output_dir="data/splits"):
    """
    Apply batch effect correction using pycombat_norm to train/test splits.

    No real batch metadata is available, so each sample's split membership
    ('train' / 'test') is used as its batch label. Both splits are corrected
    jointly and then separated again. Only 'X_train' and 'X_test' are
    corrected; every other entry of ``transformed_paths`` (labels, or an
    'X_val' split if present) is passed through unchanged.

    NOTE(review): with the split itself as the batch variable, ComBat also
    estimates adjustments from the test samples, and any systematic
    train/test difference is removed by construction. Confirm this
    transductive use is intended before relying on test-set metrics.

    Args:
        transformed_paths (dict): Paths to log-transformed split files; must
            contain 'X_train' and 'X_test'.
        output_dir (str): Directory to save batch-corrected splits

    Returns:
        dict: Paths to batch-corrected split files, or ``transformed_paths``
        itself if the correction raised an exception.
    """
    print("=== BATCH EFFECT CORRECTION ===")
    print("Using pycombat_norm without covariates")

    os.makedirs(output_dir, exist_ok=True)

    # Load training and test data
    X_train = pd.read_parquet(transformed_paths['X_train'])
    X_test = pd.read_parquet(transformed_paths['X_test'])

    print(f"Training data shape: {X_train.shape}")
    print(f"Test data shape: {X_test.shape}")

    # Combine train and test so both splits receive a consistent correction.
    combined_data = pd.concat([X_train, X_test], axis=0)
    print(f"Combined data shape: {combined_data.shape}")

    # Keep a 'condition' label column (if the X files carry one) out of
    # ComBat so the labels are not "corrected" like expression values.
    if 'condition' in combined_data.columns:
        label_col = combined_data['condition']
        features = combined_data.drop(columns='condition')
    else:
        label_col = None
        features = combined_data

    # One batch label per row, in the same order as the concatenation above.
    batch_labels = ['train'] * len(X_train) + ['test'] * len(X_test)
    n_train = len(X_train)

    corrected_data = None
    corrected_combined = None
    try:
        print("\n🔄 Applying pycombat batch correction...")
        print("   Method: Combat without covariates")
        print("   This may take a few minutes for large datasets...")

        # pycombat expects features as rows and samples as columns.
        # Covariates are simply omitted (the default is none).
        # NOTE(review): current inmoose releases name the covariate keyword
        # ``covar_mod``; the former ``mod=None`` was at best ignored and could
        # raise a TypeError that the handler below turned into a silent skip.
        corrected_data = pycombat_norm(features.T, batch=batch_labels)

        # Back to samples x features. Converting positionally avoids label
        # alignment, which breaks if sample IDs repeat across splits.
        corrected_combined = pd.DataFrame(
            np.asarray(corrected_data).T,
            index=features.index,
            columns=features.columns,
        )
        if label_col is not None:
            corrected_combined['condition'] = label_col.to_numpy()

        n_missing = int(corrected_combined.isna().to_numpy().sum())
        if n_missing:
            print(f"   ⚠️  Batch correction produced {n_missing:,} missing values "
                  "(possibly from features that are constant within a batch)")

        print("✅ Batch correction completed successfully!")

        # Split back into train and test
        corrected_splits = {
            'X_train': corrected_combined.iloc[:n_train],
            'X_test': corrected_combined.iloc[n_train:],
        }

        print(f"   Corrected training shape: {corrected_splits['X_train'].shape}")
        print(f"   Corrected test shape: {corrected_splits['X_test'].shape}")

        # Save corrected splits; everything else (y_train, y_test, ...) is kept.
        split_descriptions = {'X_train': 'training', 'X_test': 'test'}
        batch_corrected_paths = {}

        for split_name, file_path in transformed_paths.items():
            if split_name not in corrected_splits:
                batch_corrected_paths[split_name] = file_path
                continue

            # Always append a suffix so the input file can never be
            # overwritten (str.replace left non-"_log2.pq" names unchanged).
            root, ext = os.path.splitext(os.path.basename(file_path))
            if root.endswith('_log2'):
                root = root[:-len('_log2')]
            output_path = os.path.join(output_dir, f"{root}_batch_corrected{ext or '.pq'}")

            corrected_splits[split_name].to_parquet(output_path, index=True)
            batch_corrected_paths[split_name] = output_path
            print(f"   Saved corrected {split_descriptions[split_name]} data to: {output_path}")

        # Save batch correction info
        save_batch_correction_info(output_dir, method="pycombat_norm", covariates=None)

        return batch_corrected_paths

    except Exception as e:
        # Deliberate best-effort step: report and fall back to uncorrected data.
        print(f"❌ Batch correction failed: {e}")
        print("   Returning original paths without batch correction")
        return transformed_paths

    finally:
        # Release the large intermediate frames before returning.
        del combined_data, features, corrected_data, corrected_combined
        gc.collect()


def save_batch_correction_info(output_dir, method="pycombat_norm", covariates=None):
    """
    Write a short plain-text summary of the batch correction that was run.

    The summary goes to ``<output_dir>/batch_correction_info.txt``, replacing
    any existing file, and the resulting path is printed.

    Args:
        output_dir (str): Existing directory that receives the summary file.
        method (str): Name of the correction method to record.
        covariates: Covariates that were used; any falsy value is recorded
            as ``None``.
    """
    covariate_text = covariates if covariates else 'None'
    summary_lines = [
        "Batch Effect Correction Summary\n",
        "==============================\n\n",
        f"Method: {method}\n",
        f"Covariates: {covariate_text}\n",
        "Batch definition: Simple train/test split\n",
        "Note: Applied to remove potential technical batch effects\n",
        "      between training and test data\n",
    ]

    target = os.path.join(output_dir, 'batch_correction_info.txt')
    with open(target, 'w') as handle:
        handle.writelines(summary_lines)

    print(f"   📄 Batch correction info saved to: {target}")


def visualize_batch_effects(before_paths, after_paths, sample_size=1000, random_state=None):
    """
    Compare PCA projections of the data before and after batch correction.

    The batch variable used by apply_batch_correction_to_splits is the
    train/test split, so plotting a single split cannot show whether the
    batches were brought together. 'X_train' and, when present in both path
    dicts, 'X_test' are therefore plotted together and coloured by split.
    Each panel is standardised and projected onto its own first two
    principal components.

    Args:
        before_paths (dict): Split paths before correction; must contain 'X_train'.
        after_paths (dict): Split paths after correction; must contain 'X_train'.
        sample_size (int): Maximum number of samples drawn from each split.
        random_state (int | None): Seed for row sampling; ``None`` keeps the
            sampling unseeded, as before.
    """
    print("📊 Creating batch correction visualization...")
    rng = np.random.default_rng(random_state)

    splits = ['X_train']
    if 'X_test' in before_paths and 'X_test' in after_paths:
        splits.append('X_test')

    before_parts, after_parts, split_labels = [], [], []
    for split in splits:
        before = pd.read_parquet(before_paths[split])
        after = pd.read_parquet(after_paths[split])
        # A label column, if present, is not an expression feature.
        if 'condition' in before.columns:
            before = before.drop(columns='condition')
        if 'condition' in after.columns:
            after = after.drop(columns='condition')

        if len(before) > sample_size:
            # Take the same row positions from both versions so the two
            # panels show the same samples.
            idx = np.sort(rng.choice(len(before), size=sample_size, replace=False))
            before, after = before.iloc[idx], after.iloc[idx]

        before_parts.append(before)
        after_parts.append(after)
        split_labels.extend([split] * len(before))

    before_all = pd.concat(before_parts, axis=0)
    after_all = pd.concat(after_parts, axis=0)
    split_labels = np.asarray(split_labels)

    # Create comparison plots
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))
    split_colors = {'X_train': 'tab:blue', 'X_test': 'tab:orange'}
    panels = [
        (axes[0], before_all, 'Before Batch Correction'),
        (axes[1], after_all, 'After Batch Correction'),
    ]

    for ax, data, title in panels:
        pca = PCA(n_components=2)
        coords = pca.fit_transform(StandardScaler().fit_transform(data))
        for split in splits:
            mask = split_labels == split
            ax.scatter(coords[mask, 0], coords[mask, 1], alpha=0.6, s=30,
                       color=split_colors[split], label=split)
        ax.set_title(title)
        ax.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%} variance)')
        ax.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%} variance)')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Print summary statistics
    print("\nBatch Correction Summary:")
    print(f"   Data points visualized: {len(before_all):,}")
    print(f"   Features: {before_all.shape[1]:,}")

    # Clean up
    del before_parts, after_parts, before_all, after_all
    gc.collect()

In [None]:
# Apply batch correction
print("=== APPLYING BATCH EFFECT CORRECTION ===")

# On failure apply_batch_correction_to_splits returns its input dict
# unchanged, so comparing the two dicts below tells whether it ran.
batch_corrected_paths = apply_batch_correction_to_splits(transformed_paths)

# Visualize batch effects after correction (if correction was applied)
if batch_corrected_paths != transformed_paths:
    print("\n📊 Batch correction was applied - creating visualization...")
    visualize_batch_effects(transformed_paths, batch_corrected_paths)
else:
    print("\n⚠️  Batch correction was skipped - using original transformed data")

print("\nBatch-corrected (or original) file paths:")
for split_name, path in batch_corrected_paths.items():
    print(f"  {split_name}: {path}")

# Feature Selection

In [None]:
from sklearn.feature_selection import SelectKBest, f_classif, VarianceThreshold
import numpy as np
import pandas as pd
import joblib

In [None]:
def apply_comprehensive_feature_selection(transformed_paths, variance_threshold=0.01,
                                        k_best=5000, output_dir="data/splits"):
    """
    Apply variance-based and then k-best (ANOVA F) feature selection.

    Both selectors are fit on the training split only; the resulting column
    list is then applied unchanged to every ``X_*`` split. The fitted
    selectors are saved with joblib to ``output_dir`` together with a text
    report. SelectKBest is skipped when the variance filter already leaves
    ``k_best`` features or fewer.

    Args:
        transformed_paths (dict): Paths to split files; must contain
            'X_train' and 'y_train' (the latter with a 'condition' column).
        variance_threshold (float): Minimum variance threshold (default: 0.01)
        k_best (int): Number of top features to keep after variance filtering
        output_dir (str): Directory to save feature-selected splits

    Returns:
        tuple: (selected_paths, comprehensive_feature_info)
    """
    os.makedirs(output_dir, exist_ok=True)

    print("=== COMPREHENSIVE FEATURE SELECTION ===")
    print("Step 1: Variance Threshold (Unsupervised)")
    print("Step 2: SelectKBest (Supervised)")
    print()

    # Load training data and labels
    X_train = pd.read_parquet(transformed_paths['X_train'])
    y_train = pd.read_parquet(transformed_paths['y_train'])['condition']

    # Never offer the label as a candidate feature: if the X files carry a
    # 'condition' column, SelectKBest would rank the target itself as the
    # most informative "feature" (target leakage).
    if 'condition' in X_train.columns:
        X_train = X_train.drop(columns='condition')
    n_original = X_train.shape[1]

    print(f"Original features: {n_original:,}")

    # Step 1: Variance thresholding (unsupervised)
    print(f"\n🔄 Step 1: Applying variance threshold ({variance_threshold})...")
    variance_selector = VarianceThreshold(threshold=variance_threshold)
    variance_selector.fit(X_train)

    variance_mask = variance_selector.get_support()
    features_after_variance = X_train.columns[variance_mask].tolist()
    variance_removed = X_train.columns[~variance_mask].tolist()

    print("✅ Variance filtering complete:")
    print(f"   Kept: {len(features_after_variance):,} features")
    print(f"   Removed: {len(variance_removed):,} features")

    joblib.dump(variance_selector, os.path.join(output_dir, 'variance_selector.pkl'))

    # Step 2: SelectKBest (supervised), only when it would remove features
    if len(features_after_variance) > k_best:
        print(f"\n🔄 Step 2: Applying SelectKBest (k={k_best:,})...")

        k_best_selector = SelectKBest(f_classif, k=k_best)
        k_best_selector.fit(X_train[features_after_variance], y_train)

        kbest_mask = k_best_selector.get_support()
        candidates = np.array(features_after_variance)
        final_features = candidates[kbest_mask].tolist()
        kbest_removed = candidates[~kbest_mask].tolist()

        print("✅ SelectKBest complete:")
        print(f"   Final features: {len(final_features):,}")
        print(f"   Additional removed: {len(kbest_removed):,}")

        joblib.dump(k_best_selector, os.path.join(output_dir, 'kbest_selector.pkl'))
    else:
        print(f"\n⚠️  Skipping SelectKBest: Only {len(features_after_variance):,} features after variance filtering (less than k={k_best:,})")
        final_features = features_after_variance
        k_best_selector = None
        kbest_removed = []

    used_kbest = k_best_selector is not None
    n_final = len(final_features)
    # Guard against an empty feature matrix to avoid dividing by zero.
    reduction_pct = (n_original - n_final) / n_original * 100 if n_original else 0.0

    print("\n📊 FINAL RESULTS:")
    print(f"   Original features: {n_original:,}")
    print(f"   Final features: {n_final:,}")
    print(f"   Total reduction: {reduction_pct:.1f}%")

    # Apply the training-derived selection to every feature split
    print("\n🔄 Applying feature selection to all splits...")
    selected_paths = {}

    for split_name, file_path in transformed_paths.items():
        if not split_name.startswith('X_'):
            selected_paths[split_name] = file_path
            continue

        print(f"   Processing {split_name}...")

        df = pd.read_parquet(file_path)
        df_selected = df[final_features]

        # Always append a suffix so the input can never be overwritten
        # (str.replace left names such as "*_batch_corrected.pq" unchanged).
        root, ext = os.path.splitext(os.path.basename(file_path))
        if root.endswith('_log2'):
            root = root[:-len('_log2')]
        output_path = os.path.join(output_dir, f"{root}_selected{ext or '.pq'}")

        df_selected.to_parquet(output_path, index=True)
        selected_paths[split_name] = output_path

        print(f"     {df.shape} -> {df_selected.shape}")

        del df, df_selected

    comprehensive_feature_info = {
        'final_selected_features': final_features,
        'variance_removed_features': variance_removed,
        'kbest_removed_features': kbest_removed,
        'variance_threshold': variance_threshold,
        'k_best': k_best if used_kbest else None,
        'n_features_original': n_original,
        'n_features_after_variance': len(features_after_variance),
        'n_features_final': n_final,
        'total_reduction_pct': reduction_pct,
        'used_kbest': used_kbest,
    }

    # Save a human-readable report alongside the selectors
    with open(os.path.join(output_dir, 'comprehensive_feature_selection_info.txt'), 'w') as f:
        f.write("Comprehensive Feature Selection Results\n")
        f.write("======================================\n\n")
        f.write(f"Original features: {n_original:,}\n")
        f.write(f"After variance threshold ({variance_threshold}): {len(features_after_variance):,}\n")
        if used_kbest:
            f.write(f"After SelectKBest ({k_best:,}): {n_final:,}\n")
        f.write(f"Final features: {n_final:,}\n")
        f.write(f"Total reduction: {reduction_pct:.1f}%\n\n")
        f.write("Feature Selection Steps:\n")
        f.write("1. Variance Threshold (Unsupervised)\n")
        if used_kbest:
            f.write("2. SelectKBest with F-classification (Supervised)\n")
        else:
            f.write("2. SelectKBest skipped (insufficient features)\n")

    print("✅ Comprehensive feature selection complete!")
    print(f"📁 Results saved to: {output_dir}/")

    return selected_paths, comprehensive_feature_info


def visualize_comprehensive_feature_selection(comprehensive_feature_info, X_train_path):
    """
    Plot a 2x2 summary of the variance + K-best feature selection run.

    Panels:
        1. Histogram of per-feature training variances with the threshold.
        2. Feature counts at each selection stage.
        3. Percentage of features removed by each stage.
        4. Pie chart of selected vs. removed features.

    A text summary, including the features-per-sample ratio, is printed
    afterwards.

    Args:
        comprehensive_feature_info (dict): Results from comprehensive feature selection
        X_train_path (str): Path to training data before selection
    """
    print("Creating comprehensive feature selection visualizations...")
    info = comprehensive_feature_info

    # Load training data to analyze variances
    X_train = pd.read_parquet(X_train_path)
    n_samples = X_train.shape[0]
    # Only numeric expression features: DataFrame.var() fails on non-numeric
    # columns, and a 'condition' label column is not a feature. All-NaN
    # columns yield NaN variances, which histogram binning cannot handle.
    feature_frame = X_train.drop(columns='condition') if 'condition' in X_train.columns else X_train
    variances = feature_frame.var(numeric_only=True).dropna()

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    # Plot 1: Variance distribution with threshold
    axes[0, 0].hist(variances, bins=100, alpha=0.7, color='skyblue', edgecolor='black')
    axes[0, 0].axvline(info['variance_threshold'],
                       color='red', linestyle='--', linewidth=2,
                       label=f"Threshold = {info['variance_threshold']}")
    axes[0, 0].set_xlabel('Feature Variance')
    axes[0, 0].set_ylabel('Number of Features')
    axes[0, 0].set_title('Step 1: Variance Distribution')
    axes[0, 0].legend()
    axes[0, 0].set_yscale('log')
    axes[0, 0].grid(True, alpha=0.3)

    # Plot 2: Feature selection pipeline
    if info['used_kbest']:
        stages = ['Original', 'After\nVariance', 'Final\n(K-Best)']
        counts = [
            info['n_features_original'],
            info['n_features_after_variance'],
            info['n_features_final'],
        ]
        colors = ['#FF6B6B', '#4ECDC4', '#45B7D1']
    else:
        stages = ['Original', 'Final\n(Variance Only)']
        counts = [info['n_features_original'], info['n_features_final']]
        colors = ['#FF6B6B', '#45B7D1']

    bars = axes[0, 1].bar(stages, counts, color=colors, alpha=0.8, edgecolor='black')
    axes[0, 1].set_ylabel('Number of Features')
    axes[0, 1].set_title('Feature Selection Pipeline')
    axes[0, 1].grid(True, alpha=0.3, axis='y')

    # Add count labels on bars
    for bar, count in zip(bars, counts):
        height = bar.get_height()
        axes[0, 1].text(bar.get_x() + bar.get_width() / 2., height + height * 0.01,
                        f'{count:,}', ha='center', va='bottom', fontweight='bold')

    # Plot 3: Reduction percentages (each stage relative to its own input)
    if info['used_kbest']:
        variance_reduction = ((info['n_features_original'] - info['n_features_after_variance'])
                              / info['n_features_original'] * 100)
        kbest_reduction = ((info['n_features_after_variance'] - info['n_features_final'])
                           / info['n_features_after_variance'] * 100)

        reduction_stages = ['Variance\nFiltering', 'K-Best\nSelection']
        reduction_pcts = [variance_reduction, kbest_reduction]
        reduction_colors = ['#FF9F43', '#10AC84']
    else:
        reduction_stages = ['Variance\nFiltering']
        reduction_pcts = [info['total_reduction_pct']]
        reduction_colors = ['#10AC84']

    bars = axes[1, 0].bar(reduction_stages, reduction_pcts,
                          color=reduction_colors, alpha=0.8, edgecolor='black')
    axes[1, 0].set_ylabel('Reduction Percentage (%)')
    axes[1, 0].set_title('Feature Reduction by Stage')
    axes[1, 0].grid(True, alpha=0.3, axis='y')

    # Add percentage labels
    for bar, pct in zip(bars, reduction_pcts):
        height = bar.get_height()
        axes[1, 0].text(bar.get_x() + bar.get_width() / 2., height + height * 0.01,
                        f'{pct:.1f}%', ha='center', va='bottom', fontweight='bold')

    # Plot 4: Final summary pie chart
    if info['used_kbest']:
        labels = ['Selected', 'Removed by\nVariance', 'Removed by\nK-Best']
        sizes = [
            info['n_features_final'],
            len(info['variance_removed_features']),
            len(info['kbest_removed_features']),
        ]
        colors = ['#2E8B57', '#DC143C', '#FF8C00']
        explode = (0.05, 0, 0)
    else:
        labels = ['Selected', 'Removed by\nVariance']
        sizes = [info['n_features_final'], len(info['variance_removed_features'])]
        colors = ['#2E8B57', '#DC143C']
        explode = (0.05, 0)

    wedges, texts, autotexts = axes[1, 1].pie(sizes, labels=labels, colors=colors,
                                              autopct='%1.1f%%', explode=explode,
                                              shadow=True, startangle=90)
    axes[1, 1].set_title('Final Feature Distribution')

    # Make percentage text bold
    for autotext in autotexts:
        autotext.set_color('white')
        autotext.set_fontweight('bold')

    plt.tight_layout()
    plt.show()

    # Print detailed summary
    print(f"\n{'='*60}")
    print("COMPREHENSIVE FEATURE SELECTION SUMMARY")
    print(f"{'='*60}")
    print(f"📊 Original features: {info['n_features_original']:,}")
    print(f"🔄 After variance filtering: {info['n_features_after_variance']:,}")
    if info['used_kbest']:
        print(f"🎯 After K-Best selection: {info['n_features_final']:,}")
    print(f"✅ Final features: {info['n_features_final']:,}")
    print(f"📉 Total reduction: {info['total_reduction_pct']:.1f}%")
    # Previously this line printed the feature count under a "ratio" label.
    ratio = info['n_features_final'] / n_samples if n_samples else float('nan')
    print(f"🧬 Features per sample ratio: {ratio:.2f} "
          f"({info['n_features_final']:,} features / {n_samples:,} samples)")

    if info['used_kbest']:
        print("\n🔬 Selection Methods Used:")
        print(f"   1. Variance Threshold ({info['variance_threshold']})")
        print(f"   2. SelectKBest (k={info['k_best']:,})")
    else:
        print("\n🔬 Selection Method Used:")
        print(f"   • Variance Threshold only ({info['variance_threshold']})")

    # Clean up
    del X_train, feature_frame, variances
    gc.collect()

In [None]:
# Variance filter at 0.005, then keep at most the 50 highest-F features.
# NOTE(review): this selects from `transformed_paths` (log2 data), not
# `batch_corrected_paths`, so the batch-corrected files produced in the
# previous section are not used here or by anything built on
# `selected_paths_conservative` — confirm that skipping them is intentional.
selected_paths_conservative, feature_info_conservative = apply_comprehensive_feature_selection(
    transformed_paths,
    variance_threshold=0.005,
    k_best=50,
    output_dir="data/splits"
)

# Visualize results (variances are read from the pre-selection training file)
visualize_comprehensive_feature_selection(feature_info_conservative, transformed_paths['X_train'])

print("\n" + "="*50)
print("FINAL PROCESSED FILES:")
for split_name, path in selected_paths_conservative.items():
    print(f"  {split_name}: {path}")

# Handle Class Imbalance

In [None]:
from sklearn.utils.class_weight import compute_class_weight
import pickle

In [None]:
def analyze_basic_class_imbalance(selected_paths):
    """
    Summarise the binary class balance and compute class weights.

    Labels are read from the 'condition' column of the 'y_train' and
    'y_test' files and must use 0 for Healthy and 1 for Unhealthy. Balanced
    class weights (sklearn) and a LightGBM ``scale_pos_weight`` are computed
    from the training labels only, and bar charts are drawn via
    create_simple_class_plots.

    Args:
        selected_paths (dict): Paths to processed split files

    Returns:
        dict: ``class_weight_dict``, ``scale_pos_weight``, ``imbalance_ratio``
        (healthy / unhealthy in training), ``minority_percentage`` (share of
        the smaller training class), ``train_counts`` and ``test_counts``.

    Raises:
        ValueError: If the training labels are not exactly the classes 0 and 1.
    """
    print("=== BASIC CLASS IMBALANCE ANALYSIS ===")

    # Load labels
    y_train = pd.read_parquet(selected_paths['y_train'])['condition']
    y_test = pd.read_parquet(selected_paths['y_test'])['condition']

    # Everything below indexes classes 0 and 1 explicitly; fail loudly rather
    # than silently mapping weights to the wrong classes.
    classes = np.unique(y_train)
    if set(classes.tolist()) != {0, 1}:
        raise ValueError(
            "Expected binary labels 0 (Healthy) and 1 (Unhealthy) in y_train, "
            f"got {classes.tolist()}"
        )

    # reindex keeps both classes present even if a split lacks one of them.
    train_counts = y_train.value_counts().reindex([0, 1], fill_value=0)
    test_counts = y_test.value_counts().reindex([0, 1], fill_value=0)

    healthy_train = train_counts[0]
    unhealthy_train = train_counts[1]

    # imbalance_ratio is healthy:unhealthy; `severity` is majority:minority
    # so the warning below also fires when Unhealthy is the majority class.
    imbalance_ratio = healthy_train / unhealthy_train
    minority_count = min(healthy_train, unhealthy_train)
    severity = max(healthy_train, unhealthy_train) / minority_count
    minority_pct = (minority_count / len(y_train)) * 100

    print("📊 Training Set:")
    print(f"   Healthy: {healthy_train:,} ({healthy_train/len(y_train)*100:.1f}%)")
    print(f"   Unhealthy: {unhealthy_train:,} ({unhealthy_train/len(y_train)*100:.1f}%)")
    print(f"   Imbalance ratio: {imbalance_ratio:.2f}:1")

    print("\n📊 Test Set:")
    print(f"   Healthy: {test_counts[0]:,} ({test_counts[0]/len(y_test)*100:.1f}%)")
    print(f"   Unhealthy: {test_counts[1]:,} ({test_counts[1]/len(y_test)*100:.1f}%)")

    # Compute class weights (sklearn balanced method)
    class_weights = compute_class_weight(
        class_weight='balanced',
        classes=classes,
        y=y_train
    )
    class_weight_dict = {int(label): weight for label, weight in zip(classes, class_weights)}

    print("\n  Class Weights (sklearn 'balanced'):")
    print(f"   Healthy (0): {class_weight_dict[0]:.3f}")
    print(f"   Unhealthy (1): {class_weight_dict[1]:.3f}")
    print(f"   Weight ratio: {class_weight_dict[1]/class_weight_dict[0]:.3f}")

    # LightGBM scale_pos_weight with Unhealthy (1) as the positive class:
    # the usual negatives / positives heuristic.
    scale_pos_weight = healthy_train / unhealthy_train
    print(f"\n🔧 LightGBM scale_pos_weight: {scale_pos_weight:.3f}")

    # Simple recommendation
    if severity > 2.0:
        print("   ⚠️  Significant imbalance detected!")
        print("   💡 Recommendation: Use class weights")
    else:
        print("   ✅ Relatively balanced dataset")
        print("   💡 Class weights optional but can help")

    # Create simple visualization
    create_simple_class_plots(train_counts, test_counts, class_weight_dict)

    return {
        'class_weight_dict': class_weight_dict,
        'scale_pos_weight': scale_pos_weight,
        'imbalance_ratio': imbalance_ratio,
        'minority_percentage': minority_pct,
        'train_counts': train_counts,
        'test_counts': test_counts
    }


def create_simple_class_plots(train_counts, test_counts, class_weights):
    """
    Draw side-by-side bar charts of class counts and class weights.

    Left panel: grouped training/test sample counts for class 0 (Healthy)
    and class 1 (Unhealthy), each bar annotated with its count. Right panel:
    the weight of each class, annotated to three decimals.

    Args:
        train_counts: Training counts indexable by class labels 0 and 1.
        test_counts: Test counts indexable by class labels 0 and 1.
        class_weights (dict): Weight per class label 0 and 1.
    """
    class_names = ['Healthy', 'Unhealthy']
    positions = range(len(class_names))
    bar_width = 0.35

    _fig, (count_ax, weight_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: training bars shifted left, test bars shifted right.
    series_specs = [
        ('Training', [train_counts[0], train_counts[1]], -bar_width / 2, '#4CAF50'),
        ('Test', [test_counts[0], test_counts[1]], bar_width / 2, '#FF9800'),
    ]
    count_containers = []
    for series_label, heights, offset, shade in series_specs:
        shifted = [p + offset for p in positions]
        count_containers.append(
            count_ax.bar(shifted, heights, bar_width,
                         label=series_label, color=shade, alpha=0.8)
        )

    count_ax.set_xlabel('Class')
    count_ax.set_ylabel('Number of Samples')
    count_ax.set_title('Class Distribution')
    count_ax.set_xticks(positions)
    count_ax.set_xticklabels(class_names)
    count_ax.legend()
    count_ax.grid(True, alpha=0.3, axis='y')

    for container in count_containers:
        for rect in container:
            h = rect.get_height()
            count_ax.text(rect.get_x() + rect.get_width() / 2., h + h * 0.01,
                          f'{int(h):,}', ha='center', va='bottom', fontweight='bold')

    # Right panel: one bar per class weight.
    weight_values = [class_weights[0], class_weights[1]]
    weight_rects = weight_ax.bar(class_names, weight_values,
                                 color=['#2E8B57', '#DC143C'], alpha=0.8)

    weight_ax.set_xlabel('Class')
    weight_ax.set_ylabel('Class Weight')
    weight_ax.set_title('Computed Class Weights')
    weight_ax.grid(True, alpha=0.3, axis='y')

    for rect, value in zip(weight_rects, weight_values):
        h = rect.get_height()
        weight_ax.text(rect.get_x() + rect.get_width() / 2., h + h * 0.01,
                       f'{value:.3f}', ha='center', va='bottom', fontweight='bold')

    plt.tight_layout()
    plt.show()


def save_simple_class_weights(class_info, output_dir="data/splits"):
    """
    Save class weights as a pickle plus a human-readable text summary.

    Writes ``class_weights.pkl`` (``class_weight_dict``, ``scale_pos_weight``
    and ``imbalance_ratio``) and ``class_weights.txt`` into ``output_dir``,
    creating the directory if it does not exist yet.

    Args:
        class_info (dict): Class information from analyze_basic_class_imbalance
        output_dir (str): Directory to save weights

    Returns:
        str: Path to saved weights file
    """
    # Create the directory like the other pipeline steps do; writing into a
    # missing directory previously failed with FileNotFoundError.
    os.makedirs(output_dir, exist_ok=True)

    weights_path = os.path.join(output_dir, 'class_weights.pkl')

    # Save just the essential info
    weights_data = {
        'class_weight_dict': class_info['class_weight_dict'],
        'scale_pos_weight': class_info['scale_pos_weight'],
        'imbalance_ratio': class_info['imbalance_ratio']
    }

    with open(weights_path, 'wb') as f:
        pickle.dump(weights_data, f)

    print(f"✅ Class weights saved to: {weights_path}")

    # Also save as simple text for reference
    txt_path = os.path.join(output_dir, 'class_weights.txt')
    with open(txt_path, 'w') as f:
        f.write("Class Weights Summary\n")
        f.write("===================\n\n")
        f.write("sklearn class_weight='balanced':\n")
        f.write(f"  Healthy (0): {class_info['class_weight_dict'][0]:.6f}\n")
        f.write(f"  Unhealthy (1): {class_info['class_weight_dict'][1]:.6f}\n\n")
        f.write(f"LightGBM scale_pos_weight: {class_info['scale_pos_weight']:.6f}\n")
        f.write(f"Imbalance ratio: {class_info['imbalance_ratio']:.2f}:1\n")

    return weights_path


def load_simple_class_weights(weights_path="data/splits/class_weights.pkl"):
    """
    Read back the class weights written by save_simple_class_weights.

    Prints the stored per-class weights and the LightGBM scale_pos_weight.

    Args:
        weights_path (str): Location of the pickled weights file.

    Returns:
        dict | None: The stored ``class_weight_dict`` (usable as sklearn's
        ``class_weight``), or ``None`` after printing a message when the
        file does not exist.
    """
    # NOTE: pickle.load can execute code embedded in the file; only load
    # weight files this pipeline produced itself.
    try:
        with open(weights_path, 'rb') as handle:
            stored = pickle.load(handle)
    except FileNotFoundError:
        print(f"❌ Weights file not found: {weights_path}")
        return None

    per_class = stored['class_weight_dict']
    print(f"✅ Loaded class weights:")
    print(f"   Healthy (0): {per_class[0]:.3f}")
    print(f"   Unhealthy (1): {per_class[1]:.3f}")
    print(f"   LightGBM scale_pos_weight: {stored['scale_pos_weight']:.3f}")
    return per_class

In [None]:
# Analyze class imbalance on the feature-selected splits (also draws the
# class-distribution and class-weight bar charts)
class_info = analyze_basic_class_imbalance(selected_paths_conservative)

# Save weights for later use (class_weights.pkl plus a class_weights.txt
# summary in data/splits)
weights_path = save_simple_class_weights(class_info)

# Hyperparameter Tuning

In [None]:
import optuna
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score
import lightgbm as lgb
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
def optimize_lightgbm_with_workflow(selected_paths, class_weights_path="data/splits/class_weights.pkl",
                                  n_trials=100, cv_folds=5, random_state=42,
                                  output_dir="data/splits"):
    """
    Tune LightGBM hyperparameters with Optuna on the processed training split.

    Each trial is scored by the mean ROC-AUC over a shuffled, stratified
    K-fold CV of the training data. Every fold trains up to 1000 rounds with
    early stopping (patience 50) on its validation fold. The best
    configuration is then trained on the full training split and scored on
    the test split via ``evaluate_final_model``; plots, the model, a results
    pickle and a text summary are written to ``output_dir``.

    Args:
        selected_paths (dict): Paths to processed split parquet files. 'X_train'
            and 'y_train' are read here; the final evaluation also reads
            'X_test' and 'y_test'. The label column is 'condition'.
        class_weights_path (str): Pickle with 'scale_pos_weight' and
            'imbalance_ratio' keys (as saved by save_simple_class_weights).
            If the file is missing, both fall back to 1.0.
        n_trials (int): Number of Optuna trials.
        cv_folds (int): Number of stratified cross-validation folds.
        random_state (int): Seed for the CV shuffle, the TPE sampler and LightGBM.
        output_dir (str): Directory to save optimization results.

    Returns:
        dict: 'best_params' (LightGBM params of the best trial plus the fixed
        settings and, when available, 'num_iterations' = mean early-stopped
        round count over that trial's folds), 'best_cv_score', 'study',
        'n_trials', 'cv_folds', 'scale_pos_weight', 'imbalance_ratio',
        'test_results' and 'feature_count'.
    """
    print("=== LIGHTGBM HYPERPARAMETER OPTIMIZATION ===")
    print("Integrated with class weights and feature selection")

    # Load training data
    print("\n📊 Loading processed training data...")
    X_train = pd.read_parquet(selected_paths['X_train'])
    y_train = pd.read_parquet(selected_paths['y_train'])['condition']

    print(f"   Training features shape: {X_train.shape}")
    print(f"   Training labels shape: {y_train.shape}")
    print(f"   Class distribution: {y_train.value_counts().to_dict()}")

    # Load class weights
    print("\n⚖️  Loading class imbalance information...")
    try:
        with open(class_weights_path, 'rb') as f:
            weights_data = pickle.load(f)
        scale_pos_weight = weights_data['scale_pos_weight']
        imbalance_ratio = weights_data['imbalance_ratio']
        print(f"   Scale_pos_weight: {scale_pos_weight:.3f}")
        print(f"   Imbalance ratio: {imbalance_ratio:.2f}:1")
    except FileNotFoundError:
        print("   ⚠️  Class weights not found, using default (1.0)")
        scale_pos_weight = 1.0
        imbalance_ratio = 1.0

    # Fixed settings shared by every trial and by the final model.
    base_params = {
        'objective': 'binary',
        'metric': 'binary_logloss',
        'boosting_type': 'gbdt',
        'scale_pos_weight': scale_pos_weight,  # Use computed class weight
        'random_state': random_state,
        'verbose': -1,
        'force_col_wise': True,  # Better for many features
    }

    # The shuffled stratified split is fully determined by random_state, so
    # the fold indices are identical for every trial: compute them once.
    cv = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=random_state)
    fold_indices = list(cv.split(X_train, y_train))

    # Define optimization objective
    def objective(trial):
        """Optuna objective: mean stratified-CV ROC-AUC of one sampled configuration."""
        params = {
            **base_params,

            # Core hyperparameters to optimize
            'num_leaves': trial.suggest_int('num_leaves', 10, 300),
            'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.3, log=True),
            'feature_fraction': trial.suggest_float('feature_fraction', 0.4, 1.0),
            'bagging_fraction': trial.suggest_float('bagging_fraction', 0.4, 1.0),
            'bagging_freq': trial.suggest_int('bagging_freq', 1, 7),
            'reg_alpha': trial.suggest_float('reg_alpha', 1e-8, 10.0, log=True),
            'reg_lambda': trial.suggest_float('reg_lambda', 1e-8, 10.0, log=True),
            # 'min_child_samples' is a LightGBM alias of 'min_data_in_leaf'; when
            # both are passed the alias is ignored, so sampling it only added a
            # dead search dimension. Tune the canonical name alone.
            'min_data_in_leaf': trial.suggest_int('min_data_in_leaf', 1, 50),
            'max_depth': trial.suggest_int('max_depth', 3, 15),
        }

        cv_scores = []
        best_iterations = []

        for train_idx, val_idx in fold_indices:
            X_fold_train, X_fold_val = X_train.iloc[train_idx], X_train.iloc[val_idx]
            y_fold_train, y_fold_val = y_train.iloc[train_idx], y_train.iloc[val_idx]

            # Create LightGBM datasets
            train_data = lgb.Dataset(X_fold_train, label=y_fold_train)
            val_data = lgb.Dataset(X_fold_val, label=y_fold_val, reference=train_data)

            # Train with early stopping
            model = lgb.train(
                params,
                train_data,
                valid_sets=[val_data],
                valid_names=['eval'],
                num_boost_round=1000,
                callbacks=[
                    lgb.early_stopping(stopping_rounds=50, verbose=False),
                    lgb.log_evaluation(0)  # Silent
                ]
            )

            # Remember how many rounds early stopping kept, so the final model
            # can be trained for a comparable number of rounds.
            best_iterations.append(int(model.best_iteration or model.current_iteration()))

            # Predict and evaluate
            val_pred = model.predict(X_fold_val, num_iteration=model.best_iteration)
            cv_scores.append(roc_auc_score(y_fold_val, val_pred))

        trial.set_user_attr('best_iterations', best_iterations)
        return np.mean(cv_scores)

    # Run optimization
    print(f"\n🔧 Starting optimization with {n_trials} trials and {cv_folds}-fold CV...")
    print("This may take several minutes depending on data size...")

    # Create study with reproducible sampler
    sampler = optuna.samplers.TPESampler(seed=random_state)
    study = optuna.create_study(direction='maximize', sampler=sampler)

    # Add progress callback
    def progress_callback(study, trial):
        if trial.number % 10 == 0:
            print(f"   Trial {trial.number:3d}: Current best AUC = {study.best_value:.4f}")

    study.optimize(objective, n_trials=n_trials, callbacks=[progress_callback])

    # Prepare final best parameters (fixed settings override sampled ones).
    best_params = {**study.best_params, **base_params}

    # The CV models were early-stopped, but the final model has no validation
    # set. lgb.train lets 'num_iterations' in params take precedence over its
    # num_boost_round argument, so pinning it here makes the final model train
    # for the rounds the best trial actually needed instead of a fixed 1000.
    fold_iterations = study.best_trial.user_attrs.get('best_iterations')
    if fold_iterations:
        best_params['num_iterations'] = max(1, int(round(float(np.mean(fold_iterations)))))

    # Print results
    print(f"\n{'='*60}")
    print("OPTIMIZATION RESULTS")
    print(f"{'='*60}")
    print(f"🎯 Best CV AUC Score: {study.best_value:.4f}")
    print(f"📊 Number of trials: {len(study.trials)}")
    print(f"⚖️  Used scale_pos_weight: {scale_pos_weight:.3f}")

    print(f"\n🔧 Best Hyperparameters:")
    for param, value in study.best_params.items():
        if isinstance(value, float):
            print(f"   {param}: {value:.6f}")
        else:
            print(f"   {param}: {value}")
    if 'num_iterations' in best_params:
        print(f"   Final model boosting rounds (mean CV best_iteration): {best_params['num_iterations']}")

    # Create visualizations
    create_optimization_visualizations(study, output_dir)

    # Evaluate final model on test set
    print(f"\n📈 Evaluating final model on test set...")
    test_results = evaluate_final_model(
        best_params, selected_paths, X_train, y_train, output_dir
    )

    # Save optimization results
    optimization_results = {
        'best_params': best_params,
        'best_cv_score': study.best_value,
        'study': study,
        'n_trials': n_trials,
        'cv_folds': cv_folds,
        'scale_pos_weight': scale_pos_weight,
        'imbalance_ratio': imbalance_ratio,
        'test_results': test_results,
        'feature_count': X_train.shape[1]
    }

    # Save results as pickle
    results_path = os.path.join(output_dir, 'lightgbm_optimization_results.pkl')
    with open(results_path, 'wb') as f:
        pickle.dump(optimization_results, f)

    # Save readable summary
    save_optimization_summary(optimization_results, output_dir)

    print(f"\n✅ Optimization complete!")
    print(f"📁 Results saved to: {output_dir}/")

    return optimization_results


def evaluate_final_model(best_params, selected_paths, X_train, y_train, output_dir,
                         num_boost_round=1000, threshold=0.5):
    """
    Train final model with best parameters and evaluate on test set.

    Args:
        best_params (dict): LightGBM parameters passed to ``lgb.train``. If they
            contain 'num_iterations' (or one of its aliases), LightGBM uses
            that value instead of ``num_boost_round``.
        selected_paths (dict): Must contain 'X_test' and 'y_test' parquet paths;
            the label column is 'condition'.
        X_train (pd.DataFrame): Training features; their column order defines
            the feature order used for the test set.
        y_train (pd.Series): Training labels.
        output_dir (str): Directory where the model is saved as
            'best_lightgbm_model.txt'.
        num_boost_round (int): Boosting rounds when ``best_params`` does not
            fix them. Defaults to 1000, the previous hard-coded value.
        threshold (float): Predicted probability above which a test sample is
            labelled 1. Defaults to 0.5.

    Returns:
        dict: 'auc', 'accuracy', 'precision', 'recall', 'f1', 'model_path' and
        'predictions' (lists 'y_true', 'y_pred_proba', 'y_pred_binary').

    Raises:
        KeyError: If the test features lack a column present in ``X_train``.
    """
    # Load test data
    X_test = pd.read_parquet(selected_paths['X_test'])
    y_test = pd.read_parquet(selected_paths['y_test'])['condition']

    # LightGBM maps DataFrame columns to features by position, not by name,
    # so put the test columns in exactly the training order.
    X_test = X_test[X_train.columns]

    print(f"   Test data shape: {X_test.shape}")

    # Train final model on full training set
    train_data = lgb.Dataset(X_train, label=y_train)

    final_model = lgb.train(
        best_params,
        train_data,
        num_boost_round=num_boost_round,
        callbacks=[lgb.log_evaluation(0)]
    )

    # Predict on test set
    test_pred_proba = final_model.predict(X_test)
    test_pred_binary = (test_pred_proba > threshold).astype(int)

    # Calculate metrics. zero_division=0 states explicitly what happens when
    # no sample is predicted (or present) as positive, instead of warning.
    test_metrics = {
        'auc': roc_auc_score(y_test, test_pred_proba),
        'accuracy': accuracy_score(y_test, test_pred_binary),
        'precision': precision_score(y_test, test_pred_binary, zero_division=0),
        'recall': recall_score(y_test, test_pred_binary, zero_division=0),
        'f1': f1_score(y_test, test_pred_binary, zero_division=0)
    }

    # Print test results
    print(f"   🎯 Test AUC: {test_metrics['auc']:.4f}")
    print(f"   📊 Test Accuracy: {test_metrics['accuracy']:.4f}")
    print(f"   🎯 Test Precision: {test_metrics['precision']:.4f}")
    print(f"   📊 Test Recall: {test_metrics['recall']:.4f}")
    print(f"   🎯 Test F1-Score: {test_metrics['f1']:.4f}")

    # Save model
    model_path = os.path.join(output_dir, 'best_lightgbm_model.txt')
    final_model.save_model(model_path)
    print(f"   💾 Model saved to: {model_path}")

    test_metrics['model_path'] = model_path
    test_metrics['predictions'] = {
        'y_true': y_test.tolist(),
        'y_pred_proba': test_pred_proba.tolist(),
        'y_pred_binary': test_pred_binary.tolist()
    }

    return test_metrics


def create_optimization_visualizations(study, output_dir):
    """
    Create comprehensive visualizations for the optimization process.

    Draws a 2x2 figure and saves it to
    ``<output_dir>/lightgbm_optimization_plots.png`` (300 dpi):
        [0, 0] objective value per trial with the best value marked;
        [0, 1] Optuna hyperparameter importances (best-effort);
        [1, 0] learning_rate vs num_leaves coloured by CV AUC;
        [1, 1] histogram of CV AUC scores.
    Trials whose value is None (e.g. failed trials) are left out of every panel.

    Args:
        study (optuna.study.Study): Study with at least one completed trial
            (``study.best_value`` is read).
        output_dir (str): Existing directory the PNG is written to.
    """
    print("   📊 Creating optimization visualizations...")

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # Keep each value paired with its own trial number: truncating the full
    # list of numbers to len(values) misaligns them as soon as a trial in the
    # middle of the run has no value.
    completed = [t for t in study.trials if t.value is not None]
    trial_numbers = [t.number for t in completed]
    trial_values = [t.value for t in completed]

    # Plot 1: Optimization history
    axes[0, 0].plot(trial_numbers, trial_values, alpha=0.7)
    axes[0, 0].axhline(y=study.best_value, color='red', linestyle='--',
                       label=f'Best: {study.best_value:.4f}')
    axes[0, 0].set_xlabel('Trial Number')
    axes[0, 0].set_ylabel('CV AUC Score')
    axes[0, 0].set_title('Optimization History')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # Plot 2: Parameter importance
    try:
        importance = optuna.importance.get_param_importances(study)
        if importance:
            params = list(importance.keys())
            values = list(importance.values())

            axes[0, 1].barh(params, values, alpha=0.8, color='skyblue')
            axes[0, 1].set_xlabel('Importance')
            axes[0, 1].set_title('Parameter Importance')
            axes[0, 1].grid(True, alpha=0.3)
        else:
            axes[0, 1].text(0.5, 0.5, 'Not enough trials\nfor importance',
                            ha='center', va='center', transform=axes[0, 1].transAxes)
    except Exception:
        # Best-effort panel: importance evaluation can fail (e.g. with too few
        # completed trials). Catching Exception rather than using a bare
        # except keeps KeyboardInterrupt/SystemExit propagating.
        axes[0, 1].text(0.5, 0.5, 'Parameter importance\nnot available',
                        ha='center', va='center', transform=axes[0, 1].transAxes)

    # Plot 3: Learning rate vs num_leaves (if both exist)
    if 'learning_rate' in study.best_params and 'num_leaves' in study.best_params:
        points = [
            (t.params['learning_rate'], t.params['num_leaves'], t.value)
            for t in completed
            if 'learning_rate' in t.params and 'num_leaves' in t.params
        ]

        if points:
            lr_values, nl_values, scores = zip(*points)
            scatter = axes[1, 0].scatter(lr_values, nl_values, c=scores,
                                         cmap='viridis', alpha=0.7)
            axes[1, 0].set_xlabel('Learning Rate')
            axes[1, 0].set_ylabel('Num Leaves')
            axes[1, 0].set_title('Learning Rate vs Num Leaves')
            axes[1, 0].set_xscale('log')
            plt.colorbar(scatter, ax=axes[1, 0], label='CV AUC')

    # Plot 4: Score distribution
    if trial_values:
        # hist() rejects bins=0, which len // 2 yields for a single trial.
        n_bins = max(1, min(20, len(trial_values) // 2))
        axes[1, 1].hist(trial_values, bins=n_bins,
                        alpha=0.7, color='lightcoral', edgecolor='black')
        axes[1, 1].axvline(study.best_value, color='red', linestyle='--',
                           label=f'Best: {study.best_value:.4f}')
        axes[1, 1].set_xlabel('CV AUC Score')
        axes[1, 1].set_ylabel('Frequency')
        axes[1, 1].set_title('Score Distribution')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'lightgbm_optimization_plots.png'),
                dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)


def save_optimization_summary(results, output_dir):
    """
    Write a plain-text report of a LightGBM tuning run.

    The report lists the run configuration, the best CV score, every numeric
    test metric (the 'model_path' and 'predictions' entries are skipped),
    the saved model path and the final hyperparameters. Floats are written
    with six decimals.

    Args:
        results (dict): Results dict with 'n_trials', 'cv_folds',
            'feature_count', 'scale_pos_weight', 'imbalance_ratio',
            'best_cv_score', 'test_results' (including 'model_path') and
            'best_params'.
        output_dir (str): Directory receiving 'lightgbm_optimization_summary.txt'
            (overwritten if present).
    """
    summary_path = os.path.join(output_dir, 'lightgbm_optimization_summary.txt')
    non_metric_keys = ('model_path', 'predictions')

    with open(summary_path, 'w') as out:
        write = out.write

        write("LightGBM Hyperparameter Optimization Summary\n")
        write("=" * 50 + "\n\n")

        # Run configuration
        write("Optimization Configuration:\n")
        write(f"  Number of trials: {results['n_trials']}\n")
        write(f"  CV folds: {results['cv_folds']}\n")
        write(f"  Features used: {results['feature_count']:,}\n")
        write(f"  Scale pos weight: {results['scale_pos_weight']:.6f}\n")
        write(f"  Imbalance ratio: {results['imbalance_ratio']:.2f}:1\n\n")

        # Cross-validation outcome
        write("Best Results:\n")
        write(f"  CV AUC Score: {results['best_cv_score']:.6f}\n\n")

        # Held-out metrics, skipping the bookkeeping entries
        write("Test Set Performance:\n")
        out.writelines(
            f"  {name.upper()}: {score:.6f}\n"
            for name, score in results['test_results'].items()
            if name not in non_metric_keys
        )
        write(f"\nModel saved to: {results['test_results']['model_path']}\n\n")

        # Final parameter set
        write("Best Hyperparameters:\n")
        for name, setting in results['best_params'].items():
            rendered = f"{setting:.6f}" if isinstance(setting, float) else f"{setting}"
            write(f"  {name}: {rendered}\n")

    print(f"   📄 Summary saved to: {summary_path}")


def load_optimization_results(results_path="data/splits/lightgbm_optimization_results.pkl"):
    """
    Unpickle a previously saved optimization-results dict and summarize it.

    Prints the source path, the best CV AUC ('best_cv_score') and the test
    AUC ('test_results'['auc']).

    Args:
        results_path (str): Location of the results pickle.

    Returns:
        dict | None: The unpickled results, or None if the file does not exist.
    """
    try:
        with open(results_path, 'rb') as handle:
            loaded = pickle.load(handle)
    except FileNotFoundError:
        print(f"❌ Results file not found: {results_path}")
        return None

    print(f"✅ Loaded optimization results from: {results_path}")
    print(f"   Best CV AUC: {loaded['best_cv_score']:.4f}")
    print(f"   Test AUC: {loaded['test_results']['auc']:.4f}")

    return loaded

In [None]:
# Run the Optuna search (50 trials, 5-fold stratified CV) on the conservative
# splits, train/evaluate the final model on the test split, and write the
# plots, model, results pickle and text summary to data/splits.
optimization_results = optimize_lightgbm_with_workflow(
    selected_paths_conservative,
    class_weights_path="data/splits/class_weights.pkl",
    n_trials=50,
    cv_folds=5,
    random_state=42,
    output_dir="data/splits"
)

# Headline numbers from the returned results dict.
print(f"\n🎉 Optimization completed!")
print(f"📊 Best CV AUC: {optimization_results['best_cv_score']:.4f}")
print(f"🎯 Test AUC: {optimization_results['test_results']['auc']:.4f}")
print(f"📁 All results saved to: data/splits/")

In [None]:
def _find_correlated_sample_pairs(test_rows, train_rows, threshold):
    """Return (test_idx, train_idx, r) for every row pair whose Pearson r exceeds threshold, ordered by test_idx then train_idx."""

    def _center_and_normalize(rows):
        # After centering each row and scaling it to unit L2 norm, the dot
        # product of two rows equals their Pearson correlation coefficient.
        rows = np.asarray(rows, dtype=float)
        centered = rows - rows.mean(axis=1, keepdims=True)
        return centered / np.linalg.norm(centered, axis=1, keepdims=True)

    # Constant rows have zero norm and give NaN correlations (as np.corrcoef
    # does); NaN never compares greater than the threshold.
    with np.errstate(invalid='ignore', divide='ignore'):
        corr = _center_and_normalize(test_rows) @ _center_and_normalize(train_rows).T
        hits = np.argwhere(corr > threshold)
    return [(int(i), int(j), float(corr[i, j])) for i, j in hits]


def _find_perfectly_separating_features(X, y, columns):
    """Return the given columns of X whose class-0 and class-1 value ranges do not overlap."""
    subset = X[columns]
    healthy = subset[y == 0]
    unhealthy = subset[y == 1]
    # Perfect separation if no overlap
    separated = (healthy.max() < unhealthy.min()) | (unhealthy.max() < healthy.min())
    return [col for col, is_separated in separated.items() if is_separated]


def investigate_suspicious_results(optimization_results, selected_paths_conservative,
                                   n_test_samples=10, corr_threshold=0.99,
                                   n_features_checked=100):
    """
    Investigate potential data leakage, batch effects, or other issues causing unrealistic performance.

    Prints a series of diagnostics on the train/test splits:
      1. train/test sample pairs whose Pearson correlation exceeds
         ``corr_threshold`` (the first ``n_test_samples`` test rows are
         compared against every training row);
      2. differences between train and test per-feature means;
      3. train vs test class proportions;
      4. features among the first ``n_features_checked`` whose class-0 and
         class-1 training value ranges do not overlap;
      5. majority-class baseline accuracy vs the model's test accuracy;
      6. training sample-to-feature ratio;
      7. gap between best CV AUC and test AUC.

    Args:
        optimization_results (dict): Results of the LightGBM optimization;
            'best_cv_score' and test_results 'accuracy' / 'auc' are read.
        selected_paths_conservative (dict): Parquet paths for 'X_train',
            'X_test', 'y_train', 'y_test' (label column 'condition').
        n_test_samples (int): Test rows compared in check 1. Default 10.
        corr_threshold (float): Correlation above which a pair is flagged.
            Default 0.99.
        n_features_checked (int): Leading feature columns scanned in check 4.
            Default 100.

    Returns:
        dict: 'identical_samples' (flagged pair count), 'perfect_features'
        (list of column names), 'sample_feature_ratio', 'performance_gap' and
        'issues_detected' (bool).
    """
    from sklearn.dummy import DummyClassifier

    print("=== INVESTIGATING SUSPICIOUS RESULTS ===")
    print("Checking for potential data leakage, batch effects, or other issues...")

    # Load the data to investigate
    X_train = pd.read_parquet(selected_paths_conservative['X_train'])
    X_test = pd.read_parquet(selected_paths_conservative['X_test'])
    y_train = pd.read_parquet(selected_paths_conservative['y_train'])['condition']
    y_test = pd.read_parquet(selected_paths_conservative['y_test'])['condition']

    print(f"Training set: {X_train.shape[0]} samples, {X_train.shape[1]} features")
    print(f"Test set: {X_test.shape[0]} samples, {X_test.shape[1]} features")

    # 1. Check for identical samples between train/test
    print("\n🔍 1. CHECKING FOR IDENTICAL SAMPLES BETWEEN TRAIN/TEST")

    # One vectorised correlation matrix instead of a Python-level
    # np.corrcoef call per (test, train) pair.
    correlated_pairs = _find_correlated_sample_pairs(
        X_test.values[:n_test_samples], X_train.values, corr_threshold
    )
    identical_samples = len(correlated_pairs)
    sample_correlations = [r for _, _, r in correlated_pairs]

    for i, j, correlation in correlated_pairs[:3]:  # Show first few
        print(f"   High correlation ({correlation:.6f}) between test sample {i} and train sample {j}")

    if identical_samples > 0:
        print(f"   🚨 FOUND {identical_samples} highly correlated sample pairs!")
        print(f"   Average correlation: {np.mean(sample_correlations):.6f}")
    else:
        print("   ✅ No identical samples found between train/test")

    # 2. Check feature distributions
    print("\n🔍 2. CHECKING FEATURE DISTRIBUTIONS")

    # Compare feature means between train/test
    mean_differences = np.abs(X_train.mean() - X_test.mean())

    print(f"   Mean absolute difference in feature means: {mean_differences.mean():.6f}")
    print(f"   Max difference in feature means: {mean_differences.max():.6f}")
    print(f"   Features with >0.1 difference: {(mean_differences > 0.1).sum()}")

    # 3. Check class distributions
    print("\n🔍 3. CHECKING CLASS DISTRIBUTIONS")

    train_class_dist = y_train.value_counts(normalize=True).sort_index()
    test_class_dist = y_test.value_counts(normalize=True).sort_index()

    print(f"   Train class distribution: {train_class_dist.to_dict()}")
    print(f"   Test class distribution: {test_class_dist.to_dict()}")

    # A class absent from one split counts as proportion 0 instead of NaN.
    class_diff = train_class_dist.sub(test_class_dist, fill_value=0).abs()
    print(f"   Class distribution difference: {class_diff.to_dict()}")

    # 4. Check for perfect separation features
    print("\n🔍 4. CHECKING FOR PERFECT SEPARATION FEATURES")

    # Per-class min/max computed once over the checked columns, rather than
    # re-filtering the whole feature frame for every column.
    perfect_features = _find_perfectly_separating_features(
        X_train, y_train, X_train.columns[:n_features_checked]
    )

    if perfect_features:
        print(f"   🚨 FOUND {len(perfect_features)} features with perfect class separation!")
        print(f"   First few: {perfect_features[:5]}")
    else:
        print("   ✅ No features with perfect class separation found")

    # 5. Simple baseline check
    print("\n🔍 5. SIMPLE BASELINE PERFORMANCE CHECK")

    # Majority class baseline
    dummy_maj = DummyClassifier(strategy='most_frequent')
    dummy_maj.fit(X_train, y_train)
    dummy_pred = dummy_maj.predict(X_test)
    dummy_acc = (dummy_pred == y_test).mean()

    print(f"   Majority class baseline accuracy: {dummy_acc:.4f}")
    print(f"   Our model accuracy: {optimization_results['test_results']['accuracy']:.4f}")
    print(f"   Improvement over baseline: {optimization_results['test_results']['accuracy'] - dummy_acc:.4f}")

    # 6. Check sample sizes vs feature count
    print("\n🔍 6. SAMPLE SIZE vs FEATURE COUNT")

    n_samples = len(X_train)
    n_features = X_train.shape[1]
    ratio = n_samples / n_features

    print(f"   Training samples: {n_samples}")
    print(f"   Features: {n_features}")
    print(f"   Sample-to-feature ratio: {ratio:.2f}")

    if ratio < 2:
        print("   🚨 WARNING: Very low sample-to-feature ratio! Risk of overfitting!")
    elif ratio < 5:
        print("   ⚠️  Low sample-to-feature ratio. Consider more regularization.")
    else:
        print("   ✅ Reasonable sample-to-feature ratio")

    # 7. Cross-validation consistency check
    print("\n🔍 7. CV vs TEST PERFORMANCE GAP")

    cv_auc = optimization_results['best_cv_score']
    test_auc = optimization_results['test_results']['auc']
    performance_gap = abs(cv_auc - test_auc)

    print(f"   CV AUC: {cv_auc:.6f}")
    print(f"   Test AUC: {test_auc:.6f}")
    print(f"   Performance gap: {performance_gap:.6f}")

    if performance_gap > 0.05:
        print("   🚨 Large performance gap! Possible overfitting or data leakage!")
    elif performance_gap > 0.02:
        print("   ⚠️  Moderate performance gap. Monitor for overfitting.")
    else:
        print("   ✅ Good CV-Test consistency")

    # Summary and recommendations
    print("\n" + "="*60)
    print("INVESTIGATION SUMMARY")
    print("="*60)

    issues_found = []

    if identical_samples > 0:
        issues_found.append("Identical/highly similar samples between train/test")

    if len(perfect_features) > 0:
        issues_found.append("Features with perfect class separation")

    if ratio < 2:
        issues_found.append("Extremely low sample-to-feature ratio")

    if performance_gap > 0.05:
        issues_found.append("Large CV-Test performance gap")

    if not issues_found:
        print("✅ No obvious data leakage issues detected")
        print("   The high performance might be legitimate for this dataset")
        print("   Consider:")
        print("   • Validating on external datasets")
        print("   • Checking biological plausibility of selected features")
        print("   • Running additional cross-validation schemes")
    else:
        print("🚨 POTENTIAL ISSUES DETECTED:")
        for issue in issues_found:
            print(f"   • {issue}")
        print("\n💡 RECOMMENDATIONS:")
        print("   • Re-examine data preprocessing pipeline")
        print("   • Check for sample ID overlaps in original data")
        print("   • Consider more aggressive feature selection")
        print("   • Validate with external datasets")
        print("   • Use nested cross-validation")

    return {
        'identical_samples': identical_samples,
        'perfect_features': perfect_features,
        'sample_feature_ratio': ratio,
        'performance_gap': performance_gap,
        'issues_detected': len(issues_found) > 0
    }


def check_original_sample_ids():
    """
    Report sample IDs (index values) present in both source datasets.

    Reads 'data/healthy_data_preprocessed.pq' and
    'data/unhealthy_data_preprocessed.pq', prints their row counts, and
    compares their indexes.

    Returns:
        bool | None: True if any index value appears in both files, False if
        none do, None if loading or comparing failed (the error is printed).
    """
    print("\n🔍 CHECKING ORIGINAL SAMPLE IDs")

    try:
        # Load original processed data to check sample IDs
        healthy_frame = pd.read_parquet('data/healthy_data_preprocessed.pq')
        unhealthy_frame = pd.read_parquet('data/unhealthy_data_preprocessed.pq')

        print(f"Healthy samples: {len(healthy_frame)}")
        print(f"Unhealthy samples: {len(unhealthy_frame)}")

        shared_ids = set(healthy_frame.index) & set(unhealthy_frame.index)

        if not shared_ids:
            print("✅ No overlapping sample IDs found")
            return False

        print(f"🚨 FOUND {len(shared_ids)} overlapping sample IDs!")
        print(f"First few overlapping IDs: {list(shared_ids)[:5]}")
        return True

    except Exception as err:
        print(f"❌ Could not check sample IDs: {err}")
        return None


def simple_sanity_check(selected_paths=None, results=None):
    """
    Perform a simple sanity check with a basic model.

    Fits an untuned logistic regression on the training split and compares
    its test ROC-AUC with the tuned LightGBM's. If even this simple model
    exceeds 0.95 AUC, a warning is printed, because that suggests possible
    data issues.

    Args:
        selected_paths (dict, optional): Parquet paths for 'X_train', 'X_test',
            'y_train', 'y_test' (label column 'condition'). Defaults to the
            notebook global ``selected_paths_conservative``, as before.
        results (dict, optional): LightGBM optimization results; only
            ``results['test_results']['auc']`` is read. Defaults to the
            notebook global ``optimization_results``, as before.

    Returns:
        float: Test ROC-AUC of the logistic regression.
    """
    print("\n🔍 SIMPLE SANITY CHECK WITH BASIC MODEL")

    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import roc_auc_score

    # Fall back to the notebook globals the function originally relied on,
    # so existing zero-argument calls keep working.
    if selected_paths is None:
        selected_paths = selected_paths_conservative
    if results is None:
        results = optimization_results

    # Load the data
    X_train = pd.read_parquet(selected_paths['X_train'])
    X_test = pd.read_parquet(selected_paths['X_test'])
    y_train = pd.read_parquet(selected_paths['y_train'])['condition']
    y_test = pd.read_parquet(selected_paths['y_test'])['condition']

    # Simple logistic regression
    lr = LogisticRegression(random_state=42, max_iter=1000)
    lr.fit(X_train, y_train)

    lr_pred_proba = lr.predict_proba(X_test)[:, 1]
    lr_auc = roc_auc_score(y_test, lr_pred_proba)

    print(f"   Simple Logistic Regression AUC: {lr_auc:.4f}")
    print(f"   LightGBM AUC: {results['test_results']['auc']:.4f}")

    if lr_auc > 0.95:
        print("   🚨 Even simple model achieves very high performance!")
        print("   This suggests the problem might be too easy (possible data issues)")
    else:
        print("   ✅ Simple model shows more realistic performance")

    return lr_auc

In [None]:
# Run comprehensive investigation: train/test sample correlation, feature and
# class distributions, perfectly separating features, majority baseline,
# sample-to-feature ratio and CV-vs-test AUC gap.
investigation_results = investigate_suspicious_results(optimization_results, selected_paths_conservative)

# Check sample ID overlaps between the healthy and unhealthy source files
# (True = overlap found, False = none, None = files could not be checked).
sample_overlap = check_original_sample_ids()

# Simple sanity check: a plain logistic regression's test AUC vs LightGBM's.
simple_sanity_check()