# Base functions and imports

In [None]:
from typing import Iterator, List, Optional, Union
import pyarrow.parquet as pq
import pyarrow as pa
import pandas as pd
import numpy as np
import tempfile
import mygene
import os
import gc
from IPython.display import display

def load_csv_in_chunks(file_path, chunk_size=8000, **kwargs):
    """
    Open a CSV file for chunked reading so it never has to fit in memory at once.

    Args:
        file_path (str): Location of the CSV file on disk.
        chunk_size (int): Number of rows in each DataFrame produced by the reader.
        **kwargs: Extra keyword arguments forwarded unchanged to pd.read_csv()
                  (e.g., sep=',', header=0, index_col=None, usecols=None).

    Returns:
        A pandas TextFileReader that yields DataFrame chunks, or None when the
        file does not exist or pd.read_csv() raises while opening it.
    """
    if os.path.exists(file_path) is False:
        print(f"Error: File not found at {file_path}")
        return None

    print(f"Preparing to load {file_path} in chunks of size {chunk_size}...")

    reader = None
    try:
        reader = pd.read_csv(file_path, chunksize=chunk_size, **kwargs)
    except Exception as exc:
        # Opening failed: report it and fall through to return None.
        print(f"Error loading CSV file {file_path}: {exc}")
    finally:
        # A full garbage collection runs after every open attempt.
        gc.collect()
    return reader


def load_parquet_in_chunks(file_path, chunk_size=8000):
    """
    Stream a parquet file as pandas DataFrames without loading it all at once.

    This is a generator: nothing happens, including the messages below, until
    the caller starts iterating.

    Args:
        file_path (str): Location of the parquet file on disk.
        chunk_size (int): Maximum number of rows per yielded DataFrame.

    Yields:
        pandas.DataFrame: Consecutive record batches converted to pandas.

    If the file is missing, a message is printed and nothing is yielded. If an
    Exception is raised while reading, it is printed and the generator stops.
    In that case the caller may receive only part of the data.
    """
    if not os.path.exists(file_path):
        print(f"Error: File not found at {file_path}")
        return

    print(f"Preparing to load {file_path} in chunks of size {chunk_size}...")

    try:
        source = pq.ParquetFile(file_path)
        print(f"Total rows in file: {source.metadata.num_rows}")

        for record_batch in source.iter_batches(batch_size=chunk_size):
            yield record_batch.to_pandas()
    except Exception as exc:
        print(f"Error loading parquet file {file_path}: {exc}")
    finally:
        # Runs on normal completion, on error, and when the consumer closes early.
        gc.collect()


def save_data_as_parquet(chunk_iterator, output_parquet_path, preserve_index=True):
    """
    Stream DataFrame chunks into a single parquet file using PyArrow.

    The first non-empty chunk fixes the file schema. Every later chunk must
    convert to the same Arrow schema, or ParquetWriter.write_table() raises.
    Empty chunks are skipped, and no file is created if every chunk is empty.

    When ``preserve_index`` is True and a chunk's index is named, its index
    values are accumulated across chunks so duplicates can be reported.
    This set grows with the number of rows written.

    Args:
        chunk_iterator: Iterable yielding pandas DataFrame chunks.
        output_parquet_path (str): Destination path of the parquet file.
        preserve_index (bool): Whether to store the DataFrame index in the file.
    """
    writer = None
    total_rows = 0
    all_index_values = set()
    # Only named indexes are tracked; remember whether any were, so the summary
    # does not report a misleading "0 unique index values".
    index_tracked = False

    try:
        for chunk_idx, chunk in enumerate(chunk_iterator):
            if chunk.empty:
                continue

            if preserve_index and chunk.index.name is not None:
                index_tracked = True
                chunk_index_values = set(chunk.index)
                duplicates = all_index_values & chunk_index_values
                if duplicates:
                    print(f"Warning: Found {len(duplicates)} duplicate index values in chunk {chunk_idx}")
                    print(f"First few duplicates: {list(duplicates)[:5]}")
                all_index_values |= chunk_index_values

            table = pa.Table.from_pandas(chunk, preserve_index=preserve_index)

            if writer is None:
                # Lazily open the writer so the schema comes from real data.
                writer = pq.ParquetWriter(output_parquet_path, table.schema)
                if preserve_index and chunk.index.name:
                    print(f"Preserving index: '{chunk.index.name}' (dtype: {chunk.index.dtype})")

            writer.write_table(table)
            total_rows += chunk.shape[0]
            del table

    finally:
        # Always close the writer so the parquet footer is written / handle freed.
        if writer is not None:
            writer.close()
        gc.collect()

    if total_rows > 0:
        print(f"Successfully saved {total_rows} rows to parquet")
        if index_tracked:
            print(f"Total unique index values: {len(all_index_values)}")
    else:
        print("No data to save!")


def drop_dataframe_chunks(
    chunk_generator: Iterator[pd.DataFrame],
    drop_rows: Optional[Union[List[Union[int, str]], List[int], List[str]]] = None,
    drop_columns: Optional[Union[List[Union[int, str]], List[int], List[str]]] = None
) -> Iterator[pd.DataFrame]:
    """
    Drop the given rows and/or columns from every DataFrame chunk in a stream.

    Designed to be chained with other chunk generators.

    Args:
        chunk_generator: Iterable yielding pandas DataFrame chunks.
        drop_rows: Rows to remove, or None to keep all rows.
            - If every entry is an ``int``, entries are positions *within each
              chunk*. The same relative rows are dropped from every chunk, and
              positions outside ``[0, len(chunk))`` are ignored.
            - Otherwise, entries and index labels are compared as strings.
        drop_columns: Column labels to remove, or None to keep all columns.
            Labels not present in a chunk are ignored.

    Yields:
        pd.DataFrame: The filtered chunk. Empty input chunks are passed through
        unchanged. Chunks that become empty after filtering are not yielded.
    """
    for chunk in chunk_generator:
        if chunk.empty:
            yield chunk
            continue

        processed_chunk = chunk.copy()

        if drop_columns is not None:
            columns_to_drop = [col for col in drop_columns if col in processed_chunk.columns]
            if columns_to_drop:
                processed_chunk = processed_chunk.drop(columns=columns_to_drop)

        if drop_rows is not None:
            if drop_rows and not all(isinstance(row, int) for row in drop_rows):
                drop_rows_str = [str(row) for row in drop_rows]
                drop_mask = processed_chunk.index.astype(str).isin(drop_rows_str)
            else:
                n_rows = len(processed_chunk)
                drop_mask = np.zeros(n_rows, dtype=bool)
                positions = [idx for idx in drop_rows if 0 <= idx < n_rows]
                drop_mask[np.asarray(positions, dtype=np.intp)] = True

            if drop_mask.any():
                # Mask by position: drop(index=labels) would also remove every
                # other row sharing a duplicated index label.
                processed_chunk = processed_chunk.loc[~drop_mask]

        if not processed_chunk.empty:
            yield processed_chunk


def keep_dataframe_chunks(
    chunk_generator: Iterator[pd.DataFrame],
    keep_rows: Optional[Union[List[Union[int, str]], List[int], List[str]]] = None,
    keep_columns: Optional[Union[List[Union[int, str]], List[int], List[str]]] = None
) -> Iterator[pd.DataFrame]:
    """
    Keep only the given rows and/or columns of every DataFrame chunk in a stream.

    Designed to be chained with other chunk generators.

    Args:
        chunk_generator: Iterable yielding pandas DataFrame chunks.
        keep_rows: Rows to keep, or None to keep all rows.
            - If every entry is an ``int``, entries are positions *within each
              chunk*. The same relative rows are kept from every chunk, in the
              order given, and out-of-range positions are ignored.
            - Otherwise, entries and index labels are compared as strings, and
              matching rows keep their original order.
        keep_columns: Column labels to keep (in the order given), or None to
            keep all columns. Labels not present in a chunk are ignored.

    Yields:
        pd.DataFrame: The filtered chunk. Empty input chunks are passed through
        unchanged. Chunks that end up empty are not yielded.
    """
    for chunk in chunk_generator:
        if chunk.empty:
            yield chunk
            continue

        processed_chunk = chunk.copy()

        if keep_columns is not None:
            columns_to_keep = [col for col in keep_columns if col in processed_chunk.columns]
            if columns_to_keep:
                processed_chunk = processed_chunk[columns_to_keep]
            else:
                processed_chunk = processed_chunk.iloc[:0]

        if keep_rows is not None:
            if keep_rows and not all(isinstance(row, int) for row in keep_rows):
                keep_rows_str = [str(row) for row in keep_rows]
                keep_mask = processed_chunk.index.astype(str).isin(keep_rows_str)
                # Select with the boolean mask, not a list of labels. `.loc` with
                # a label list returns every row for each listed label, which
                # would multiply rows whose index label is duplicated.
                processed_chunk = processed_chunk.loc[keep_mask]
            else:
                valid_indices = [idx for idx in keep_rows if 0 <= idx < len(processed_chunk)]
                if valid_indices:
                    processed_chunk = processed_chunk.iloc[valid_indices]
                else:
                    processed_chunk = processed_chunk.iloc[:0]

        if not processed_chunk.empty:
            yield processed_chunk


def _transpose_filter_chunk(chunk, filtered_cols, skip_columns, skip_row_set, skip_by_label):
    """Apply transpose_dataframe_chunks' column/row skip rules to one chunk."""
    if skip_columns is not None:
        chunk = chunk[filtered_cols]
    if skip_row_set:
        if skip_by_label:
            keep = np.fromiter(
                (label not in skip_row_set for label in chunk.index),
                dtype=bool, count=len(chunk),
            )
        else:
            keep = np.fromiter(
                (position not in skip_row_set for position in range(len(chunk))),
                dtype=bool, count=len(chunk),
            )
        # A boolean mask removes exactly the skipped rows; selecting with a list
        # of labels would repeat rows whose index label is duplicated.
        chunk = chunk.loc[keep]
    return chunk


def transpose_dataframe_chunks(
    chunk_generator,
    skip_rows=None,
    skip_columns=None,
    output_batch_size=1000,
    temp_dir="/tmp",
    dtype='uint32'
):
    """
    Collect DataFrame chunks, transpose the combined data, and yield it in batches.

    All chunks from ``chunk_generator`` are materialized in memory first. The
    filtered values are then written, transposed, into a temporary
    ``numpy.memmap`` file, and output batches are read from it. The file is
    removed when the generator finishes, fails, or is closed. The column list
    is taken from the first chunk; later chunks are expected to share it.

    Args:
        chunk_generator: Iterable yielding pandas DataFrame chunks.
        skip_rows: Rows to drop before transposing, or None. If the first entry
            is a ``str``, all entries are index labels; otherwise they are
            positions *within each chunk*. An empty list skips nothing.
        skip_columns: Column labels to drop before transposing, or None.
        output_batch_size: Number of transposed rows per yielded batch.
        temp_dir: Directory for the temporary memory-mapped file.
        dtype: dtype of the memory-mapped array. Every value is cast with
            ``astype(dtype)``. The default 'uint32' presumes non-negative
            integer data (e.g. counts); TODO confirm for other inputs.

    Yields:
        pd.DataFrame: Batches of the transposed data. Each batch has a
        'sample_id' column holding the original column labels, followed by
        one column per kept original row, named by the original index label.
    """
    chunk_list = list(chunk_generator)
    if not chunk_list:
        print("No chunks received from generator. Exiting.")
        return

    first_chunk = chunk_list[0]
    # Count columns before any are skipped so "Original" is really the original.
    original_cols = len(first_chunk.columns)

    if skip_columns is not None:
        skip_col_set = set(skip_columns)
        filtered_cols = [col for col in first_chunk.columns if col not in skip_col_set]
        print(f"Skipping columns: {skip_columns}")
    else:
        filtered_cols = first_chunk.columns.tolist()

    skip_row_set = None
    skip_by_label = False
    if skip_rows is not None:
        skip_rows_list = list(skip_rows)
        # An empty skip list means "skip nothing" (indexing [0] on it would fail).
        if skip_rows_list:
            skip_row_set = set(skip_rows_list)
            skip_by_label = isinstance(skip_rows_list[0], str)
        print(f"Skipping rows: {skip_rows}")

    original_rows = sum(len(chunk) for chunk in chunk_list)
    n_output_rows = len(filtered_cols)

    all_row_indices = []
    for chunk in chunk_list:
        kept = _transpose_filter_chunk(chunk, filtered_cols, skip_columns, skip_row_set, skip_by_label)
        all_row_indices.extend(kept.index.tolist())
    n_output_cols = len(all_row_indices)

    print(f"Dataset dimensions:")
    print(f"    Original: {original_rows} rows x {original_cols} columns")
    print(f"    After filtering: {len(all_row_indices)} rows x {len(filtered_cols)} columns")
    print(f"    After transpose: {n_output_rows} rows x {n_output_cols} columns")

    # A zero-sized memmap cannot be created; nothing would be yielded anyway.
    if n_output_rows == 0 or n_output_cols == 0:
        print("Nothing to transpose after filtering. Exiting.")
        return

    temp_filename = None
    mmap_array = None
    try:
        temp_file = tempfile.NamedTemporaryFile(
            dir=temp_dir,
            delete=False,
            suffix='.mmap'
        )
        temp_filename = temp_file.name
        temp_file.close()

        # Array in transposed orientation: [original columns, original rows].
        mmap_array = np.memmap(
            temp_filename,
            dtype=dtype,
            mode='w+',
            shape=(n_output_rows, n_output_cols)
        )
        print(f"Memory map created successfully")

        print(f"Filling memory map with data...")
        current_feature_idx = 0
        n_chunks = len(chunk_list)

        for chunk_idx, chunk in enumerate(chunk_list):
            kept = _transpose_filter_chunk(chunk, filtered_cols, skip_columns, skip_row_set, skip_by_label)
            chunk_feature_count = kept.shape[0]

            mmap_array[:, current_feature_idx:current_feature_idx + chunk_feature_count] = (
                kept.values.T.astype(dtype)
            )
            current_feature_idx += chunk_feature_count

            mmap_array.flush()
            del kept

            print(f"    Processed chunk {chunk_idx + 1}/{n_chunks}")

        # The data now lives in the memmap; release the in-memory chunks.
        del chunk_list, first_chunk, chunk
        gc.collect()

        print(f"Memory map filled successfully")

        print(f"Yielding transposed batches (size: {output_batch_size})...")
        total_batches = (n_output_rows + output_batch_size - 1) // output_batch_size

        for batch_idx in range(total_batches):
            start_row = batch_idx * output_batch_size
            end_row = min(start_row + output_batch_size, n_output_rows)

            # Copy so the yielded frame does not keep the memmap alive.
            batch_data = mmap_array[start_row:end_row, :].copy()
            batch_sample_ids = filtered_cols[start_row:end_row]

            batch_df = pd.DataFrame(
                data=batch_data,
                index=batch_sample_ids,
                columns=all_row_indices
            )

            batch_df = batch_df.reset_index().rename(columns={'index': 'sample_id'})

            yield batch_df

            del batch_data, batch_df
            gc.collect()

            print(f"    Yielded batch {batch_idx + 1}/{total_batches}")

        print(f"Transposition completed successfully!")

    except Exception as e:
        print(f"Error during memory-mapped transposition: {e}")
        raise

    finally:
        print(f"Cleaning up temporary files...")
        try:
            # Drop the memmap reference before unlinking its backing file.
            mmap_array = None
            if temp_filename is not None and os.path.exists(temp_filename):
                os.unlink(temp_filename)
                print(f"Temporary file removed: {temp_filename}")
        except Exception as cleanup_error:
            print(f"Warning: Could not clean up temp file: {cleanup_error}")


def convert_gene_ids_to_symbols(dataset_chunk_iterator, mygene_client, gene_column_prefix="ENSG"):
    """
    Rename gene-ID columns to gene symbols using a single MyGene query.

    Gene columns are the string column names of the first chunk that start
    with ``gene_column_prefix``. They are looked up once via
    ``mygene_client.querymany`` (scope 'ensembl.gene', species 'human'), and
    the resulting mapping is applied to every chunk.
    - IDs without a symbol keep their original name.
    - If the query raises, all original names are kept.
    - When a query returns several hits, the first symbol is used.
    - A column named 'condition' is never renamed.

    Note: all chunks are concatenated, so the full dataset ends up in memory.

    Args:
        dataset_chunk_iterator: Iterable (list, generator, ...) of DataFrame chunks.
        mygene_client: Object exposing a MyGene-style ``querymany`` method.
        gene_column_prefix (str): Prefix identifying gene columns.

    Returns:
        pd.DataFrame: All chunks concatenated with symbols as column names, or
        an empty DataFrame when no chunks are supplied.
    """
    print("Converting gene IDs to symbols (chunked processing)...")

    # Accept lists as well as iterators; next() needs an iterator.
    chunk_iter = iter(dataset_chunk_iterator)

    try:
        first_chunk = next(chunk_iter)
        print(f"Processing first chunk: {first_chunk.shape}")
    except StopIteration:
        print("❌ Error: Dataset is empty")
        return pd.DataFrame()

    gene_columns = [col for col in first_chunk.columns
                    if isinstance(col, str) and col.startswith(gene_column_prefix)]

    if not gene_columns:
        print("No gene columns found - returning original data")
        return pd.concat([first_chunk, *chunk_iter], axis=0, ignore_index=False)

    print(f"Found {len(gene_columns)} gene columns")

    # Only IDs that actually received a symbol are stored here.
    symbol_by_id = {}
    try:
        print("Making single MyGene API call...")
        import warnings
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            results = mygene_client.querymany(
                gene_columns,
                scopes='ensembl.gene',
                fields='symbol',
                species='human',
                verbose=False,
                silent=True
            )

        for result in results:
            gene_id = result['query']
            symbol = result.get('symbol')
            # querymany may return several hits for one query: keep the first.
            if symbol and gene_id not in symbol_by_id:
                symbol_by_id[gene_id] = symbol

        # Count genes, not API hits, so the total can never exceed the gene count.
        symbols_found = sum(1 for gene_id in gene_columns if gene_id in symbol_by_id)
        print(f"✅ Successfully converted {symbols_found}/{len(gene_columns)} genes to symbols")

    except Exception as e:
        print(f"⚠️ MyGene API error: {e}. Using original gene IDs")
        symbol_by_id = {}
        symbols_found = 0

    final_column_mapping = {
        col: col if col == 'condition' else symbol_by_id.get(col, col)
        for col in first_chunk.columns
    }

    # Different IDs can share a symbol, which would yield duplicate column names.
    mapped_symbols = [final_column_mapping[col] for col in dict.fromkeys(gene_columns)]
    collisions = len(mapped_symbols) - len(set(mapped_symbols))
    if collisions:
        print(f"⚠️ {collisions} gene columns map to an already-used symbol; column names will not be unique")

    print("Applying gene symbol mapping to all chunks...")

    processed_chunks = [first_chunk.rename(columns=final_column_mapping)]
    for chunk in chunk_iter:
        processed_chunks.append(chunk.rename(columns=final_column_mapping))
        if len(processed_chunks) % 3 == 0:
            print(f"  ✓ Processed {len(processed_chunks)} chunks...")

    chunk_count = len(processed_chunks)
    print(f"Concatenating {chunk_count} processed chunks...")
    final_dataset = pd.concat(processed_chunks, axis=0, ignore_index=False)

    print(f"✅ Final dataset shape: {final_dataset.shape}")
    print(f"   Symbols converted: {symbols_found}/{len(gene_columns)}")

    return final_dataset


def add_condition_labels_to_chunks(chunk_iterator, condition_label, dataset_name):
    """
    Tag every chunk with a constant 'condition' column and collect them in a list.

    The column is assigned on each chunk object itself, so the caller's
    DataFrames are modified in place. The whole stream is materialized.

    Args:
        chunk_iterator: Iterable yielding DataFrame chunks.
        condition_label (int): Binary label (0 for healthy, 1 for unhealthy).
        dataset_name (str): Name used only in the log messages.

    Returns:
        list: The same chunk objects, each now carrying the 'condition' column.
    """
    print(f"Adding label '{condition_label}' to {dataset_name} dataset...")

    def _tag(frame):
        frame['condition'] = condition_label
        return frame

    labeled_chunks = list(map(_tag, chunk_iterator))

    print(f"✅ Completed {len(labeled_chunks)} {dataset_name} chunks")
    return labeled_chunks


def merge_labeled_datasets(healthy_chunks, unhealthy_chunks):
    """
    Merge healthy and unhealthy dataset chunks into one DataFrame.

    Healthy rows come first, and the original indexes are kept.

    Args:
        healthy_chunks: Iterable (list, generator, ...) of healthy DataFrame chunks.
        unhealthy_chunks: Iterable of unhealthy DataFrame chunks.

    Returns:
        pd.DataFrame: The merged dataset, or an empty DataFrame if both inputs
        are empty.
    """
    print("Merging datasets...")

    # Unpacking accepts any iterables; `list + generator` would raise TypeError.
    all_chunks = [*healthy_chunks, *unhealthy_chunks]
    if not all_chunks:
        print("No chunks to merge - returning empty DataFrame")
        return pd.DataFrame()

    merged_dataset = pd.concat(all_chunks, axis=0, ignore_index=False)

    print(f"✅ Merged dataset: {len(merged_dataset)} samples, {len(merged_dataset.columns)} features")
    if 'condition' in merged_dataset.columns:
        print(f"   Healthy (0): {(merged_dataset['condition'] == 0).sum()}")
        print(f"   Unhealthy (1): {(merged_dataset['condition'] == 1).sum()}")
    else:
        print("   Warning: no 'condition' column found")

    return merged_dataset


def clean_duplicate_nans(chunk_iterator):
    """
    Drop rows containing NaN, then exact duplicate rows, from each chunk.

    Both steps work per chunk, so duplicates split across different chunks are
    not detected. Chunks left empty by either step are skipped.

    Args:
        chunk_iterator: Iterable yielding pandas DataFrame chunks.

    Yields:
        pandas.DataFrame: Each non-empty cleaned chunk.
    """
    kept = 0
    for position, frame in enumerate(chunk_iterator, start=1):
        before_nan = len(frame)
        frame = frame.dropna()
        nan_dropped = before_nan - len(frame)
        if nan_dropped > 0:
            print(f"    Chunk {position}: Dropped {nan_dropped} rows with null values.")

        if frame.empty:
            print(f"    Chunk {position}: Empty after dropping NaNs, skipping...", flush=True)
            continue

        before_dedup = len(frame)
        frame = frame.drop_duplicates()
        dup_dropped = before_dedup - len(frame)
        if dup_dropped > 0:
            print(f"    Chunk {position}: Dropped {dup_dropped} duplicate rows.")

        if frame.empty:
            print(f"    Chunk {position}: Empty after deduplication, skipping...")
            continue

        kept += 1
        print(f"    Chunk {position}: Yielded chunk with shape {frame.shape}")
        yield frame

    print(f"Finished processing. {kept} non-empty chunks processed.")


def rename_index(chunk_iterator, index_name):
    """
    Set the index name of every chunk and pass the chunk along.

    The name is assigned on each chunk's existing Index object, so the caller's
    DataFrame (and any other object sharing that Index) sees the change too.

    Args:
        chunk_iterator: Iterable yielding pandas DataFrame chunks.
        index_name: Name to assign to each chunk's index.

    Yields:
        pandas.DataFrame: The same chunk objects with the index renamed.
    """
    def _with_name(frame):
        frame.index.name = index_name
        return frame

    yield from map(_with_name, chunk_iterator)


def filter_rows(chunk_iterator):
    """
    Drop rows whose values sum to zero, streaming chunk by chunk.

    Every column except 'sample_id' takes part in the row sum, so those columns
    are expected to be summable (numeric). Empty chunks are skipped. Chunks with
    no column besides 'sample_id' pass through unchanged.

    Args:
        chunk_iterator: Iterable yielding pandas DataFrame chunks.

    Yields:
        pandas.DataFrame: Each chunk restricted to rows with a non-zero sum.
    """
    seen_rows = 0
    dropped_rows = 0
    chunk_no = 0

    for frame in chunk_iterator:
        if frame.empty:
            continue

        chunk_no += 1
        row_count = len(frame)
        value_columns = [name for name in frame.columns if name != 'sample_id']

        if not value_columns:
            print(f"    Warning: Chunk {chunk_no} has no numeric columns, keeping all rows")
            yield frame
            continue

        try:
            kept = frame[frame[value_columns].sum(axis=1) != 0]
        except Exception as err:
            # Typically a non-numeric column; show dtypes to aid debugging.
            print(f"    Error in chunk {chunk_no}: {err}")
            print(f"    Column types: {frame.dtypes}")
            raise

        removed = row_count - len(kept)
        seen_rows += row_count
        dropped_rows += removed

        if removed > 0:
            print(f"    Chunk {chunk_no}: Removed {removed}/{row_count} rows with zero sum")

        yield kept

    print(f"Filtering complete: {dropped_rows}/{seen_rows} rows removed")


def get_healthy_whole_blood_samples(metadata_path: str, *, min_rin: float = 7.0,
                                    tissue: str = 'Whole Blood') -> list[str]:
    """
    Return the IDs of whole-blood samples that pass an RNA-integrity cut-off.

    The metadata CSV is read with its first column as the index (presumably the
    GTEx SAMPID; confirm against the file). A sample is kept when:
    - its 'SMTSD' value equals ``tissue``, and
    - its 'SMRIN' (RNA integrity) is at least ``min_rin``.
    Samples with a missing SMRIN are excluded.

    Args:
        metadata_path (str): Path to the sample-metadata CSV file.
        min_rin (float): Minimum SMRIN for a sample to be kept (default 7.0).
        tissue (str): Required 'SMTSD' value (default 'Whole Blood').

    Returns:
        list[str]: Index values (sample IDs) of qualifying samples, in file order.
    """
    metadata_df = pd.read_csv(metadata_path, index_col=0)

    is_tissue = metadata_df['SMTSD'] == tissue
    # Filter on RNA integrity to drop low-quality samples. Comparisons with
    # NaN are False, so samples without an SMRIN value are dropped too.
    passes_rin = metadata_df['SMRIN'] >= min_rin

    # Apply the mask to the index: the comparison result by itself is a
    # True/False Series over *all* samples, not the passing subset.
    return metadata_df.index[(is_tissue & passes_rin).to_numpy()].tolist()

def align_gene_columns_generator(healthy_chunk_generator, unhealthy_chunk_generator,
                                gene_column_prefix="ENSG"):
    """
    Align the gene columns of a healthy and an unhealthy chunk stream.

    Gene columns are string column names starting with ``gene_column_prefix``.
    - Their version suffix (everything after the first '.') is stripped.
    - Only base IDs present in the first chunk of *both* datasets are kept,
      in sorted order.
    - Non-gene columns present in both first chunks are kept ahead of the
      genes, in the healthy dataset's order.

    The layout is derived from the first chunks only; later chunks are assumed
    to have the same columns (a missing column raises KeyError).

    Args:
        healthy_chunk_generator: Iterable yielding healthy dataset chunks.
        unhealthy_chunk_generator: Iterable yielding unhealthy dataset chunks.
        gene_column_prefix (str): Prefix identifying gene columns (default "ENSG").

    Yields:
        tuple: (aligned_healthy_chunk, aligned_unhealthy_chunk) per step. Once
        one stream is exhausted its slot is None until the other is exhausted
        too. Nothing is yielded if either stream is empty or the streams share
        no genes.
    """
    print("Starting gene column alignment...")

    healthy_iter = iter(healthy_chunk_generator)
    unhealthy_iter = iter(unhealthy_chunk_generator)

    try:
        first_healthy_chunk = next(healthy_iter)
        first_unhealthy_chunk = next(unhealthy_iter)
    except StopIteration:
        print("ERROR: One or both datasets are empty!")
        return

    def is_gene_column(column):
        return isinstance(column, str) and column.startswith(gene_column_prefix)

    print("\nProcessing gene columns and stripping version suffixes...")
    def extract_gene_info(chunk, dataset_name):
        """Map each kept gene column to its base ID (first occurrence wins)."""
        print(f"Processing {dataset_name} dataset...")

        original_to_base = {}
        base_gene_ids = set()
        total_gene_columns = 0
        duplicate_count = 0

        for column in chunk.columns:
            if is_gene_column(column):
                total_gene_columns += 1
                base_gene_id = column.split('.')[0]

                if base_gene_id not in base_gene_ids:
                    original_to_base[column] = base_gene_id
                    base_gene_ids.add(base_gene_id)
                else:
                    duplicate_count += 1
                    print(f"Warning: Duplicate base gene {base_gene_id} found, skipping {column}")

        print(f"Total gene columns found: {total_gene_columns}")
        print(f"Unique base gene IDs: {len(base_gene_ids)}")
        if duplicate_count > 0:
            print(f"Duplicates removed: {duplicate_count}")

        return original_to_base, base_gene_ids

    healthy_rename_mapping, healthy_base_genes = extract_gene_info(first_healthy_chunk, "HEALTHY")
    unhealthy_rename_mapping, unhealthy_base_genes = extract_gene_info(first_unhealthy_chunk, "UNHEALTHY")

    common_base_genes = healthy_base_genes & unhealthy_base_genes

    if not common_base_genes:
        print("ERROR: No common genes found between datasets!")
        return

    print(f"Common genes: {len(common_base_genes)}")
    print(f"Genes exclusive to healthy dataset: {len(healthy_base_genes - common_base_genes)}")
    print(f"Genes exclusive to unhealthy dataset: {len(unhealthy_base_genes - common_base_genes)}")

    def non_gene_columns(chunk, rename_mapping):
        # Use the post-rename names without materializing a renamed data copy.
        renamed = (rename_mapping.get(col, col) for col in chunk.columns)
        return [col for col in renamed if not is_gene_column(col)]

    healthy_non_gene_cols = non_gene_columns(first_healthy_chunk, healthy_rename_mapping)
    unhealthy_non_gene_cols = non_gene_columns(first_unhealthy_chunk, unhealthy_rename_mapping)

    if set(healthy_non_gene_cols) != set(unhealthy_non_gene_cols):
        print("Warning: Non-gene columns differ between datasets!")
        print(f"Healthy only: {set(healthy_non_gene_cols) - set(unhealthy_non_gene_cols)}")
        print(f"Unhealthy only: {set(unhealthy_non_gene_cols) - set(healthy_non_gene_cols)}")
        print("Only non-gene columns present in both datasets will be kept")

    # Keep only shared non-gene columns (healthy order, no repeats). Selecting a
    # column missing from one dataset would raise KeyError.
    unhealthy_non_gene_set = set(unhealthy_non_gene_cols)
    non_gene_cols = list(dict.fromkeys(
        col for col in healthy_non_gene_cols if col in unhealthy_non_gene_set
    ))
    common_gene_base_ids = sorted(common_base_genes)
    final_columns = non_gene_cols + common_gene_base_ids

    print("\nProcessing and yielding aligned chunks...")
    def process_chunk(chunk, rename_mapping):
        """Rename gene columns to base IDs and select the aligned column set."""
        if chunk is None:
            return None

        renamed_chunk = chunk.rename(columns=rename_mapping)
        # Stripping versions can leave two columns with the same name (e.g.
        # 'ENSG1.2' -> 'ENSG1' next to an existing 'ENSG1'). Keep the first so
        # each selected column appears exactly once.
        renamed_chunk = renamed_chunk.loc[:, ~renamed_chunk.columns.duplicated()]
        return renamed_chunk[final_columns].copy()

    yield (process_chunk(first_healthy_chunk, healthy_rename_mapping),
           process_chunk(first_unhealthy_chunk, unhealthy_rename_mapping))
    # The first chunks are no longer needed; free them before streaming the rest.
    del first_healthy_chunk, first_unhealthy_chunk

    chunk_count = 1
    healthy_exhausted = False
    unhealthy_exhausted = False

    while not (healthy_exhausted and unhealthy_exhausted):
        chunk_count += 1

        healthy_chunk = None
        if not healthy_exhausted:
            try:
                healthy_chunk = next(healthy_iter)
            except StopIteration:
                healthy_exhausted = True
                print(f"Healthy dataset exhausted after {chunk_count-1} chunks")

        unhealthy_chunk = None
        if not unhealthy_exhausted:
            try:
                unhealthy_chunk = next(unhealthy_iter)
            except StopIteration:
                unhealthy_exhausted = True
                print(f"Unhealthy dataset exhausted after {chunk_count-1} chunks")

        if healthy_chunk is None and unhealthy_chunk is None:
            break

        yield (process_chunk(healthy_chunk, healthy_rename_mapping),
               process_chunk(unhealthy_chunk, unhealthy_rename_mapping))

    print(f"\nAlignment complete!")


def merge_datasets_generator(aligned_pair_iterator):
    """
    Concatenate each (healthy_chunk, unhealthy_chunk) pair into one DataFrame.

    Args:
        aligned_pair_iterator: Iterable of (healthy_chunk, unhealthy_chunk)
            tuples. Either element may be None (e.g. once one side of the
            pairing is exhausted).

    Yields:
        pandas.DataFrame: Row-wise concatenation of the non-None members of each
        pair, healthy rows first. A pair where both members are None yields
        nothing.
    """
    for healthy, unhealthy in aligned_pair_iterator:
        present = [part for part in (healthy, unhealthy) if part is not None]
        if not present:
            continue
        yield pd.concat(present)

    # Full collection once the stream has been fully consumed.
    gc.collect()


def set_index_column(chunk_generator, column_name, drop=True):
    """
    Use a given column as the index of every DataFrame chunk in a stream.

    Chunks that are empty, or that do not contain ``column_name``, are passed
    through untouched.

    Args:
        chunk_generator: Iterable yielding pandas DataFrame chunks.
        column_name: Label of the column to promote to the index.
        drop: Remove the column from the data after indexing (default: True).

    Yields:
        pd.DataFrame: Each chunk, re-indexed where possible.
    """
    for frame in chunk_generator:
        passthrough = frame.empty or column_name not in frame.columns
        yield frame if passthrough else frame.set_index(column_name, drop=drop)


def prepare_metadata_generator(aligned_pair_iterator, output_metadata_path="data/merged_metadata.pq"):
    """
    Pass aligned chunk pairs through unchanged while recording per-sample labels.

    For every pair, one metadata record ``{'sample_id': str(index_value),
    'condition': 0}`` is collected per healthy row and one with ``condition``
    1 per unhealthy row (healthy records first). Once the input is exhausted,
    the records are written to ``output_metadata_path`` as parquet (without
    the DataFrame index). The file is therefore only written if the consumer
    iterates this generator to the end.

    Args:
        aligned_pair_iterator: Iterable yielding ``(healthy_chunk, unhealthy_chunk)``
            tuples of DataFrames indexed by sample ID; either element may be ``None``.
        output_metadata_path: Path of the parquet file to write.

    Yields:
        tuple: ``(healthy_chunk, unhealthy_chunk)`` exactly as received.
    """
    all_metadata = []

    for healthy_chunk, unhealthy_chunk in aligned_pair_iterator:
        for chunk, condition in ((healthy_chunk, 0), (unhealthy_chunk, 1)):
            if chunk is not None:
                all_metadata.extend(
                    {'sample_id': str(sid), 'condition': condition}
                    for sid in chunk.index.tolist()
                )
        chunk = None  # don't keep the last chunk alive via the loop variable

        # Yield the chunks unchanged (pass-through)
        yield healthy_chunk, unhealthy_chunk
        del healthy_chunk, unhealthy_chunk

    if all_metadata:
        metadata_df = pd.DataFrame(all_metadata)
        metadata_df.to_parquet(output_metadata_path, index=False)
        print(f"Metadata saved to {output_metadata_path} with {len(metadata_df)} records")
        del metadata_df
    else:
        print("No metadata to save")

    # Only function locals are released here; checking globals() for these
    # names (as before) could trigger `del` on an unbound local.
    del all_metadata
    gc.collect()


def _collect_dataset_sample_ids(chunks):
    """Return (set of stringified index values, total row count) over all *chunks*."""
    sample_ids = set()
    row_count = 0
    for chunk in chunks:
        sample_ids.update(str(sid) for sid in chunk.index.tolist())
        row_count += len(chunk)
    return sample_ids, row_count


def _raise_if_any(sample_ids, message):
    """Raise ``Exception`` listing up to 10 of *sample_ids* when the set is non-empty."""
    if sample_ids:
        raise Exception(f"{message}: {list(sample_ids)[:10]}...")


def check_metadata(healthy_generator, unhealthy_generator, merged_metadata_generator):
    """
    Validate the merged metadata against the healthy and unhealthy datasets.

    All three iterables are consumed. Dataset sample IDs are taken from each
    chunk's index (stringified); metadata rows supply ``sample_id`` and
    ``condition``. The checks are, in order:

    1. metadata row count equals healthy + unhealthy row count;
    2. every dataset sample is in the metadata and vice versa;
    3. condition 0 is exactly the healthy samples and condition 1 exactly the
       unhealthy samples.

    Duplicate sample IDs in the metadata only produce a printed warning.

    Args:
        healthy_generator: Iterable of healthy DataFrame chunks indexed by sample ID.
        unhealthy_generator: Iterable of unhealthy DataFrame chunks indexed by sample ID.
        merged_metadata_generator: Iterable of metadata DataFrame chunks with
            ``sample_id`` and ``condition`` columns.

    Raises:
        Exception: If a metadata row has a condition other than 0 or 1, or if
            any of the checks above fails.
    """
    print("Collecting sample IDs from healthy dataset...")
    healthy_sample_ids, healthy_dataset_line_count = _collect_dataset_sample_ids(healthy_generator)

    print("Collecting sample IDs from unhealthy dataset...")
    unhealthy_sample_ids, unhealthy_dataset_line_count = _collect_dataset_sample_ids(unhealthy_generator)

    metadata_sample_ids = set()
    condition_0_ids = set()
    condition_1_ids = set()
    metadata_line_count = 0

    print("Validating metadata...")
    for chunk in merged_metadata_generator:
        metadata_line_count += len(chunk)
        if chunk.empty:
            continue

        # Vectorised per-chunk processing instead of a row-by-row iterrows loop.
        sample_ids = chunk['sample_id'].map(str)
        conditions = chunk['condition']
        is_condition_0 = (conditions == 0).to_numpy()
        is_condition_1 = (conditions == 1).to_numpy()
        invalid = ~(is_condition_0 | is_condition_1)  # NaN counts as invalid
        if invalid.any():
            first_bad = int(invalid.argmax())
            raise Exception(
                f"Invalid condition value: {conditions.iloc[first_bad]} "
                f"for sample {sample_ids.iloc[first_bad]}"
            )

        metadata_sample_ids.update(sample_ids)
        condition_0_ids.update(sample_ids[is_condition_0])
        condition_1_ids.update(sample_ids[is_condition_1])

    print("\n" + "=" * 60)
    print("METADATA VALIDATION RESULTS")
    print("=" * 60)

    # 1. Row counts
    total_dataset_count = healthy_dataset_line_count + unhealthy_dataset_line_count
    print("Dataset row counts:")
    print(f"  Healthy: {healthy_dataset_line_count}")
    print(f"  Unhealthy: {unhealthy_dataset_line_count}")
    print(f"  Total: {total_dataset_count}")
    print(f"  Metadata: {metadata_line_count}")

    if total_dataset_count != metadata_line_count:
        raise Exception(f"Row count mismatch! Dataset total: {total_dataset_count}, Metadata: {metadata_line_count}")

    # 2. Sample ID completeness
    dataset_sample_ids = healthy_sample_ids | unhealthy_sample_ids
    print("\nSample ID validation:")
    print(f"  Unique healthy samples in dataset: {len(healthy_sample_ids)}")
    print(f"  Unique unhealthy samples in dataset: {len(unhealthy_sample_ids)}")
    # Union, so IDs present in both datasets are not counted twice.
    print(f"  Total unique samples in datasets: {len(dataset_sample_ids)}")
    print(f"  Unique samples in metadata: {len(metadata_sample_ids)}")

    _raise_if_any(healthy_sample_ids - metadata_sample_ids, "Missing healthy samples in metadata")
    _raise_if_any(unhealthy_sample_ids - metadata_sample_ids, "Missing unhealthy samples in metadata")
    _raise_if_any(metadata_sample_ids - dataset_sample_ids, "Extra samples in metadata not in datasets")

    # 3. Condition mapping
    print("\nCondition validation:")
    print(f"  Condition 0 (healthy) samples: {len(condition_0_ids)}")
    print(f"  Condition 1 (unhealthy) samples: {len(condition_1_ids)}")

    _raise_if_any(healthy_sample_ids - condition_0_ids, "Healthy samples with wrong condition in metadata")
    _raise_if_any(condition_0_ids - healthy_sample_ids, "Condition 0 samples not in healthy dataset")
    _raise_if_any(unhealthy_sample_ids - condition_1_ids, "Unhealthy samples with wrong condition in metadata")
    _raise_if_any(condition_1_ids - unhealthy_sample_ids, "Condition 1 samples not in unhealthy dataset")

    # 4. Duplicates in metadata (warning only)
    if len(metadata_sample_ids) != metadata_line_count:
        print("\nWARNING: Duplicate sample IDs detected in metadata!")
        print(f"  Unique sample IDs: {len(metadata_sample_ids)}")
        print(f"  Total metadata rows: {metadata_line_count}")
        print(f"  Duplicates: {metadata_line_count - len(metadata_sample_ids)}")

    # Summary
    print("\n" + "=" * 60)
    print("✅ VALIDATION PASSED!")
    print("=" * 60)
    print(f"✅ Row counts match: {metadata_line_count} total")
    print("✅ All sample IDs present and correct")
    print(f"✅ Conditions correctly mapped: {len(condition_0_ids)} healthy, {len(condition_1_ids)} unhealthy")
    print("✅ No sample ID mismatches detected")


def get_healthy_whole_blood_samples(metadata_path: str, min_rin: float = 7.0) -> list[str]:
    """
    Get high-quality whole blood sample IDs from a sample-attributes CSV.

    A sample is kept when its ``SMTSD`` column equals ``'Whole Blood'`` and its
    RNA integrity number (``SMRIN``) is at least ``min_rin``. Samples with a
    missing ``SMRIN`` are excluded.

    Args:
        metadata_path (str): Path to the metadata CSV; its first column is used
            as the index (the sample IDs returned).
        min_rin (float): Minimum SMRIN value to keep a sample (default 7.0).

    Returns:
        list[str]: Index values (sample IDs) of the selected samples, in file order.
    """
    metadata_df = pd.read_csv(metadata_path, index_col=0)

    is_whole_blood = metadata_df['SMTSD'] == 'Whole Blood'
    # Filter on RNA integrity to remove low-quality samples. The mask must be
    # applied to the rows; taking the index of the mask itself would return
    # every whole-blood sample regardless of RIN.
    has_good_rin = metadata_df['SMRIN'] >= min_rin

    return metadata_df.index[is_whole_blood & has_good_rin].tolist()

# GTEx and GDC

## Unhealthy Preprocessing

In [None]:
unhealthy_dataset_file = 'data/unhealthy_data.pq'
unhealthy_output_file = 'data/unhealthy_data_preprocessed.pq'
chunk_size = 8000

chunk_iterator = load_parquet_in_chunks(unhealthy_dataset_file, chunk_size=chunk_size)
print(f"Starting preprocessing for Unhealthy Dataset: {unhealthy_dataset_file}...")

# Each stage wraps the iterator produced by the previous one, in this order.
for stage in (
    filter_rows,
    lambda it: transpose_dataframe_chunks(it, output_batch_size=chunk_size, dtype='float32'),
    lambda it: rename_index(it, 'sample_id'),
    lambda it: set_index_column(it, 'sample_id'),
    clean_duplicate_nans,
):
    chunk_iterator = stage(chunk_iterator)

save_data_as_parquet(chunk_iterator, unhealthy_output_file)

# Clean up
del chunk_iterator, stage

gc.collect()

In [None]:
unhealthy_iterator = load_parquet_in_chunks('data/unhealthy_data_preprocessed.pq')

# If load_parquet_in_chunks is a generator (its docstring says it yields
# chunks), the returned object is always truthy, so the `if` alone cannot
# detect a missing/empty file; next(..., None) avoids a StopIteration.
first_chunk = next(unhealthy_iterator, None) if unhealthy_iterator else None

if first_chunk is not None:
    # Basic info
    print(f"First chunk shape: {first_chunk.shape}")
    print(f"Columns: {list(first_chunk.columns)}")
    print(f"Data types:\n{first_chunk.dtypes}")
    print(f"Memory usage: {first_chunk.memory_usage(deep=True).sum() / 1024**2:.2f} MB")

    print("\nFirst 5 rows:")
    display(first_chunk.head())
else:
    print("No chunks could be read from data/unhealthy_data_preprocessed.pq")

## Healthy Preprocessing

In [None]:
dataset_file = 'data/gene_tpm_2022-06-06_v10_breast_mammary_tissue.gct'
output_file = 'data/healthy_data.pq'

# GCT files start with a version line and a dimensions line, so the column
# header is the third line (skiprows=2); the first column becomes the index.
gct_chunk_iterator = load_csv_in_chunks(
    file_path=dataset_file,
    sep='\t',
    skiprows=2,
    header=0,
    index_col=0,
)

if gct_chunk_iterator:
    print("Saving GCT dataset as parquet...")

    save_data_as_parquet(
        chunk_iterator=gct_chunk_iterator,
        output_parquet_path=output_file
    )
else:
    print(f"Failed to load {dataset_file} using load_csv_in_chunks.")

# Clean up (all three names are always bound at this point)
del gct_chunk_iterator, dataset_file, output_file

gc.collect()

In [None]:
# Preview the raw healthy dataset (named healthy_iterator: this is healthy data).
healthy_iterator = load_parquet_in_chunks('data/healthy_data.pq')

# If load_parquet_in_chunks is a generator (its docstring says it yields
# chunks), the returned object is always truthy, so the `if` alone cannot
# detect a missing/empty file; next(..., None) avoids a StopIteration.
first_chunk = next(healthy_iterator, None) if healthy_iterator else None

if first_chunk is not None:
    # Basic info
    print(f"First chunk shape: {first_chunk.shape}")
    print(f"Columns: {list(first_chunk.columns)}")
    print(f"Data types:\n{first_chunk.dtypes}")
    print(f"Memory usage: {first_chunk.memory_usage(deep=True).sum() / 1024**2:.2f} MB")

    print("\nFirst 5 rows:")
    display(first_chunk.head())
else:
    print("No chunks could be read from data/healthy_data.pq")

In [None]:
healthy_data_input = 'data/healthy_data.pq'
healthy_data_output = 'data/healthy_data_preprocessed.pq'
exclude_cols = ['Description']

chunk_iterator = load_parquet_in_chunks(
    file_path=healthy_data_input
)

# Drop the annotation column(s) listed in exclude_cols before the expression
# matrix is filtered and transposed.
chunk_iterator = drop_dataframe_chunks(chunk_generator=chunk_iterator, drop_columns=exclude_cols)
chunk_iterator = filter_rows(chunk_iterator)
# NOTE(review): unlike the unhealthy pipeline, this transpose uses default
# output_batch_size/dtype (not dtype='float32') and there is no
# rename_index(..., 'sample_id') step — confirm the two preprocessed datasets
# end up with matching dtypes and index names.
chunk_iterator = transpose_dataframe_chunks(chunk_generator=chunk_iterator)
chunk_iterator = set_index_column(chunk_iterator, 'sample_id')
chunk_iterator = clean_duplicate_nans(chunk_iterator)

save_data_as_parquet(
    chunk_iterator=chunk_iterator,
    output_parquet_path=healthy_data_output
)

# Clean up
del chunk_iterator, healthy_data_input, healthy_data_output, exclude_cols
# chunk_size only exists if the unhealthy preprocessing cell was run.
if 'chunk_size' in globals():
    del chunk_size
gc.collect()

In [None]:
# Preview the preprocessed healthy dataset (named healthy_iterator: this is healthy data).
healthy_iterator = load_parquet_in_chunks('data/healthy_data_preprocessed.pq')

# If load_parquet_in_chunks is a generator (its docstring says it yields
# chunks), the returned object is always truthy, so the `if` alone cannot
# detect a missing/empty file; next(..., None) avoids a StopIteration.
first_chunk = next(healthy_iterator, None) if healthy_iterator else None

if first_chunk is not None:
    # Basic info
    print(f"First chunk shape: {first_chunk.shape}")
    print(f"Columns: {list(first_chunk.columns)}")
    print(f"Data types:\n{first_chunk.dtypes}")
    print(f"Memory usage: {first_chunk.memory_usage(deep=True).sum() / 1024**2:.2f} MB")

    print("\nFirst 5 rows:")
    display(first_chunk.head())
else:
    print("No chunks could be read from data/healthy_data_preprocessed.pq")

## Merge datasets

In [None]:
healthy_data_path = 'data/healthy_data_preprocessed.pq'
unhealthy_data_path = 'data/unhealthy_data_preprocessed.pq'

# Streaming pipeline: pair healthy/unhealthy chunks, record their metadata on
# the way through, concatenate each pair, and write the merged stream. The
# metadata parquet is only written once save_data_as_parquet has consumed the
# whole stream.
healthy_iterator = load_parquet_in_chunks(healthy_data_path)
unhealthy_iterator = load_parquet_in_chunks(unhealthy_data_path)
aligned_pair_iterator = prepare_metadata_generator(
    align_gene_columns_generator(healthy_iterator, unhealthy_iterator),
    output_metadata_path="data/merged_metadata.pq",
)
merged_chunk_iterator = merge_datasets_generator(aligned_pair_iterator)

save_data_as_parquet(chunk_iterator=merged_chunk_iterator, output_parquet_path='data/merged_dataset.pq')

# Clean up
del healthy_iterator, unhealthy_iterator, aligned_pair_iterator, merged_chunk_iterator

gc.collect()

In [None]:
merged_iterator = load_parquet_in_chunks('data/merged_dataset.pq')

# If load_parquet_in_chunks is a generator (its docstring says it yields
# chunks), the returned object is always truthy, so the `if` alone cannot
# detect a missing/empty file; next(..., None) avoids a StopIteration.
first_chunk = next(merged_iterator, None) if merged_iterator else None

if first_chunk is not None:
    # Basic info
    print(f"First chunk shape: {first_chunk.shape}")
    print(f"Columns: {list(first_chunk.columns)}")
    print(f"Data types:\n{first_chunk.dtypes}")
    print(f"Memory usage: {first_chunk.memory_usage(deep=True).sum() / 1024**2:.2f} MB")

    print("\nFirst 5 rows:")
    display(first_chunk.head())
else:
    print("No chunks could be read from data/merged_dataset.pq")

# Clean up
del merged_iterator, first_chunk

gc.collect()

In [None]:
metadata_iterator = load_parquet_in_chunks('data/merged_metadata.pq')

# If load_parquet_in_chunks is a generator (its docstring says it yields
# chunks), the returned object is always truthy, so the `if` alone cannot
# detect a missing/empty file; next(..., None) avoids a StopIteration.
first_chunk = next(metadata_iterator, None) if metadata_iterator else None

if first_chunk is not None:
    # Basic info
    print(f"First chunk shape: {first_chunk.shape}")
    print(f"Columns: {list(first_chunk.columns)}")
    print(f"Data types:\n{first_chunk.dtypes}")
    print(f"Memory usage: {first_chunk.memory_usage(deep=True).sum() / 1024**2:.2f} MB")

    print("\nFirst 5 rows:")
    display(first_chunk.head())
else:
    print("No chunks could be read from data/merged_metadata.pq")

# Clean up
del metadata_iterator, first_chunk

gc.collect()

In [None]:
# Re-open both preprocessed datasets and the merged metadata from disk so the
# validation reads exactly what was written.
healthy_iterator, unhealthy_iterator, metadata_iterator = (
    load_parquet_in_chunks(path)
    for path in (
        'data/healthy_data_preprocessed.pq',
        'data/unhealthy_data_preprocessed.pq',
        'data/merged_metadata.pq',
    )
)

check_metadata(healthy_iterator, unhealthy_iterator, metadata_iterator)

## Simple PCA plot

In [None]:
from matplotlib import pyplot as plt
import sklearn
# Import the estimators explicitly: depending on the scikit-learn version, a
# bare `import sklearn` does not load the `preprocessing`/`decomposition`
# submodules, so `sklearn.preprocessing.StandardScaler` can raise AttributeError.
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import seaborn as sns
import pandas as pd
import numpy as np

# Load data
df = pd.read_parquet("data/merged_dataset.pq")
metadata = pd.read_parquet("data/merged_metadata.pq")

# Set sample_id as index for metadata to match with df
metadata = metadata.set_index('sample_id')

# Align metadata with df samples
y = metadata.loc[df.index, "condition"]

print(f"Dataset shape: {df.shape}")
print(f"Samples: {df.shape[0]}, Genes: {df.shape[1]}")
print(f"Condition distribution: {y.value_counts()}")

# Calculate mean differences between conditions
gene_columns = df.columns
mean_healthy = df[y == 0][gene_columns].mean()  # Condition 0 = healthy
mean_unhealthy = df[y == 1][gene_columns].mean()  # Condition 1 = unhealthy
mean_diff = (mean_unhealthy - mean_healthy).abs()

print("\nTop 5 genes by absolute mean difference:")
print(mean_diff.nlargest(5))

# Select top k genes by mean difference.
# NOTE: the selection uses the condition labels, so part of the separation in
# the PCA below is built in by this supervised step.
k_genes = 50_000
top_k_genes = mean_diff.nlargest(k_genes).index
x_selected = df[top_k_genes]

print(f"\nOriginal number of genes: {df.shape[1]}")
print(f"Number of genes after selection (top {k_genes} by mean difference): {x_selected.shape[1]}")

# Scale the data
scaler = StandardScaler()
x_scaled = scaler.fit_transform(x_selected)

print("Shape of x_scaled:", x_scaled.shape)

# Perform PCA
pca = PCA(n_components=2)
pca_result = pca.fit_transform(x_scaled)

print("Shape of principal components:", pca_result.shape)

# Create PCA DataFrame
pca_df = pd.DataFrame(
    pca_result,
    columns=["PC1", "PC2"],
    index=x_selected.index
)
pca_df["condition"] = y

print(f"Shape of PCA DataFrame: {pca_df.shape}")
print("\nExplained Variance Ratio:")
print(f"PC1: {pca.explained_variance_ratio_[0]:.4f}")
print(f"PC2: {pca.explained_variance_ratio_[1]:.4f}")
print(f"Total Explained Variance (PC1 + PC2): {pca.explained_variance_ratio_.sum():.4f}")

# Generate PCA plot
print("Generating PCA plot...")
plt.figure(figsize=(12, 8))

# Create condition labels for better visualization
condition_labels = {0: 'Healthy', 1: 'Unhealthy'}
# Colours keyed by label (a plain list is assigned in order of appearance,
# which could swap the colours).
condition_palette = {'Healthy': '#2E8B57', 'Unhealthy': '#DC143C'}
pca_df['condition_label'] = pca_df['condition'].map(condition_labels)

sns.scatterplot(
    data=pca_df,
    x="PC1",
    y="PC2",
    hue="condition_label",
    hue_order=['Healthy', 'Unhealthy'],
    palette=condition_palette,  # Green for healthy, red for unhealthy
    alpha=0.7,
    s=50
)

plt.title(f'PCA of Gene Expression Data (Top {k_genes} Most Discriminative Genes)')
plt.xlabel(f'Principal Component 1 ({pca.explained_variance_ratio_[0]*100:.2f}% Variance Explained)')
plt.ylabel(f'Principal Component 2 ({pca.explained_variance_ratio_[1]*100:.2f}% Variance Explained)')
plt.grid(True, linestyle='--', alpha=0.6)
plt.legend(title='Condition')
plt.tight_layout()
plt.show()

# Print some summary statistics
print("\nSummary Statistics:")
print(f"Healthy samples (condition 0): {(y == 0).sum()}")
print(f"Unhealthy samples (condition 1): {(y == 1).sum()}")
print(f"Total samples: {len(y)}")

## Log Transformation

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import os
import pyarrow.parquet as pq # Import for your generator function
import gc # For garbage collection

In [None]:
MERGED_DATA_PATH = "data/merged_dataset.pq"
# Upper bound on plotted values; a fixed seed makes the subsample (and thus
# the plots) reproducible between runs.
MAX_PLOT_VALUES = 1_000_000
rng = np.random.default_rng(0)

print("Starting log transformation and sampling for plots...")

raw_sample_data = None
transformed_sample_data = None

# Only the first chunk is sampled; it is used as a representative slice for
# the distribution plots.
for i, chunk in enumerate(load_parquet_in_chunks(MERGED_DATA_PATH, chunk_size=8000)):
    print(f"Processing chunk {i+1}...")
    values = chunk.to_numpy().ravel()
    if len(values) > MAX_PLOT_VALUES:
        values = rng.choice(values, size=MAX_PLOT_VALUES, replace=False)
    raw_sample_data = values
    print(f"Captured {len(raw_sample_data)} raw expression values for plotting.")

    # Transform the *same* sampled entries, so both histograms describe
    # identical values (log2(x + 1) maps 0 to 0).
    transformed_sample_data = np.log2(raw_sample_data + 1)
    print(f"Captured {len(transformed_sample_data)} transformed expression values for plotting.")
    break

if raw_sample_data is None:
    raise ValueError(f"No data could be read from {MERGED_DATA_PATH}")

print("\nLog transformation and sampling complete. Generating plots...")

plt.style.use('seaborn-v0_8-darkgrid')

# A log-scaled axis cannot represent zero (or NaN) values, which would make
# the histogram bins non-finite, so only strictly positive values are plotted.
positive_raw = raw_sample_data[raw_sample_data > 0]
print(f"Plotting {len(positive_raw)} positive raw values on the log axis "
      f"({len(raw_sample_data) - len(positive_raw)} zero/negative/NaN values omitted).")

plt.figure(figsize=(10, 6))
sns.histplot(positive_raw, bins=50, kde=True, color='skyblue', edgecolor='black', stat='density', log_scale=True)
plt.title('Distribution of Raw TPM Values (Sampled Data)')
plt.xlabel('Raw TPM Value (log scale)')
plt.ylabel('Density')
plt.grid(True, which="both", ls="--", c='0.7')
plt.show()

plt.figure(figsize=(10, 6))
sns.histplot(transformed_sample_data, bins=50, kde=True, color='lightcoral', edgecolor='black', stat='density')
plt.title('Distribution of Log2(TPM + 1) Values (Sampled Data)')
plt.xlabel('Log2(TPM + 1) Value')
plt.ylabel('Density')
plt.grid(True, which="both", ls="--", c='0.7')
plt.show()

print("\nPlots generated. Review them to confirm the transformation's effect.")

## Batch correction

In [None]:
import pandas as pd
import numpy as np
from combat.pycombat import pycombat
import pyarrow.parquet as pq
import gc
from inmoose.pycombat import pycombat_norm
from inmoose.cohort_qc.cohort_metric import CohortMetric
from inmoose.cohort_qc.qc_report import QCReport

In [None]:
data_log_transformed = "data/merged_dataset_log_transformed.pq"
merged_metadata = "data/merged_metadata.pq"
output_path = "data/merged_dataset_batch_corrected.pq"

# NOTE(review): the log-transformation cell above only samples values for
# plotting; confirm merged_dataset_log_transformed.pq is written elsewhere.
df = pd.read_parquet(data_log_transformed)

metadata = pd.read_parquet(merged_metadata)
metadata = metadata.set_index('sample_id')

if 'batch' in metadata.columns:
    batches_series = metadata['batch']
else:
    # The merged metadata built in this notebook holds only sample_id and
    # condition. Batch is fully confounded with condition here (healthy = GTEx,
    # unhealthy = GDC), so rebuild the labels from it.
    print("Metadata has no 'batch' column; deriving it from 'condition' (0 -> GTEx, 1 -> GDC).")
    batches_series = metadata['condition'].map({0: 'GTEx', 1: 'GDC'})
    metadata['batch'] = batches_series

common_samples = df.index.intersection(batches_series.index)
if len(common_samples) == 0:
    # Raise instead of exit(): exit() inside a notebook shuts down the kernel.
    raise ValueError("Error: No common samples across data and metadata. Check indices.")

df_aligned = df.loc[common_samples]
batches_aligned = batches_series.loc[common_samples]
metadata_aligned = metadata.loc[common_samples]

print(f"Aligned data for ComBat. Data shape: {df_aligned.shape}, Batches series length: {len(batches_aligned)}")
print(f"Condition distribution:\n{metadata_aligned['condition'].value_counts()}")
print("Warning: Running ComBat without covariates due to perfect confounding")

print("Transposing data for ComBat (Genes x Samples)...")
data_for_combat = df_aligned.T
print(f"Data for ComBat shape: {data_for_combat.shape}")

# Apply batch correction without covariates
corrected_data_df_t = pycombat_norm(
    data_for_combat,
    batches_aligned
)
print("Batch correction without covariates complete.")

corrected_data_df = corrected_data_df_t.T
print(f"Corrected data transposed back to Samples x Genes. Shape: {corrected_data_df.shape}")

corrected_data_df.to_parquet(output_path, index=True)
print("Batch-corrected dataset saved!")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import pandas as pd
import numpy as np

def create_batch_correction_plots(original_data_path, corrected_data_path, metadata_path):
    """
    Plot PCA projections of the data before and after batch correction.

    Loads both samples x genes matrices and the metadata, keeps samples present
    in all three, and restricts both matrices to the 5000 genes with the highest
    variance in the original data. Each matrix is standardised and projected to
    two principal components (fitted separately). A 2x2 grid is drawn: coloured
    by batch (top row) and by condition (bottom row).

    Args:
        original_data_path (str): Parquet file with the uncorrected matrix.
        corrected_data_path (str): Parquet file with the batch-corrected matrix.
        metadata_path (str): Parquet file with ``sample_id`` and ``condition``
            columns and, optionally, ``batch``. Without ``batch``, labels are
            derived from condition (0 -> GTEx, 1 -> GDC).

    Returns:
        tuple: ``(plot_df_orig, plot_df_corr)`` DataFrames with PC1, PC2, batch,
        condition and dataset columns.
    """
    print("Loading data for plotting...")

    # Load data
    original_df = pd.read_parquet(original_data_path)
    corrected_df = pd.read_parquet(corrected_data_path)
    metadata_df = pd.read_parquet(metadata_path).set_index('sample_id')

    # Align data
    common_samples = original_df.index.intersection(corrected_df.index).intersection(metadata_df.index)
    original_aligned = original_df.loc[common_samples]
    corrected_aligned = corrected_df.loc[common_samples]
    metadata_aligned = metadata_df.loc[common_samples]

    if 'batch' in metadata_aligned.columns:
        batch_labels = metadata_aligned['batch']
    else:
        # Batch is fully confounded with condition in this dataset
        # (healthy = GTEx, unhealthy = GDC).
        batch_labels = metadata_aligned['condition'].map({0: 'GTEx', 1: 'GDC'})
    condition_labels = metadata_aligned['condition'].map({0: 'Healthy', 1: 'Unhealthy'})

    print(f"Plot data shape: {original_aligned.shape}")

    # Select top variable genes for PCA (to speed up computation)
    print("Selecting top 5000 most variable genes...")
    gene_vars = original_aligned.var(axis=0)
    top_genes = gene_vars.nlargest(5000).index

    original_subset = original_aligned[top_genes]
    corrected_subset = corrected_aligned[top_genes]

    # Standardize and perform PCA (the scaler is refitted for each matrix)
    print("Performing PCA...")
    scaler = StandardScaler()

    original_scaled = scaler.fit_transform(original_subset)
    pca_orig = PCA(n_components=2)
    original_pca = pca_orig.fit_transform(original_scaled)

    corrected_scaled = scaler.fit_transform(corrected_subset)
    pca_corr = PCA(n_components=2)
    corrected_pca = pca_corr.fit_transform(corrected_scaled)

    # Create plotting dataframes
    plot_df_orig = pd.DataFrame({
        'PC1': original_pca[:, 0],
        'PC2': original_pca[:, 1],
        'batch': batch_labels,
        'condition': condition_labels,
        'dataset': 'Original'
    })

    plot_df_corr = pd.DataFrame({
        'PC1': corrected_pca[:, 0],
        'PC2': corrected_pca[:, 1],
        'batch': batch_labels,
        'condition': condition_labels,
        'dataset': 'Batch Corrected'
    })

    # Colours keyed by label so each condition keeps its colour regardless of
    # which one appears first in the data.
    condition_palette = {'Healthy': '#2E8B57', 'Unhealthy': '#DC143C'}
    condition_order = ['Healthy', 'Unhealthy']

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle('Batch Correction Validation', fontsize=16, fontweight='bold')

    # Plot 1: Original data colored by batch
    sns.scatterplot(data=plot_df_orig, x='PC1', y='PC2', hue='batch',
                    alpha=0.7, s=50, ax=axes[0, 0])
    axes[0, 0].set_title(f'Original Data - Colored by Batch\nPC1: {pca_orig.explained_variance_ratio_[0]:.2%}, PC2: {pca_orig.explained_variance_ratio_[1]:.2%}')
    axes[0, 0].grid(True, alpha=0.3)

    # Plot 2: Corrected data colored by batch
    sns.scatterplot(data=plot_df_corr, x='PC1', y='PC2', hue='batch',
                    alpha=0.7, s=50, ax=axes[0, 1])
    axes[0, 1].set_title(f'Batch Corrected Data - Colored by Batch\nPC1: {pca_corr.explained_variance_ratio_[0]:.2%}, PC2: {pca_corr.explained_variance_ratio_[1]:.2%}')
    axes[0, 1].grid(True, alpha=0.3)

    # Plot 3: Original data colored by condition
    sns.scatterplot(data=plot_df_orig, x='PC1', y='PC2', hue='condition',
                    hue_order=condition_order, palette=condition_palette,
                    alpha=0.7, s=50, ax=axes[1, 0])
    axes[1, 0].set_title('Original Data - Colored by Condition')
    axes[1, 0].grid(True, alpha=0.3)

    # Plot 4: Corrected data colored by condition
    sns.scatterplot(data=plot_df_corr, x='PC1', y='PC2', hue='condition',
                    hue_order=condition_order, palette=condition_palette,
                    alpha=0.7, s=50, ax=axes[1, 1])
    axes[1, 1].set_title('Batch Corrected Data - Colored by Condition')
    axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Summary statistics
    print("\n" + "="*50)
    print("VISUAL VALIDATION SUMMARY")
    print("="*50)
    print("Expected results after good batch correction:")
    print("• Top row: Batch separation should be REDUCED")
    print("• Bottom row: Condition separation should be PRESERVED")
    print("\nIf batch effects were successfully removed:")
    print("• GTEx and GDC samples should mix better in corrected data")
    print("• Healthy vs Unhealthy separation should remain clear")

    return plot_df_orig, plot_df_corr

In [None]:
# Quantitative validation (validate_batch_correction_with_inmoose, writing to
# qc_report/cohort_qc) is currently disabled; only the visual check runs.

print("\n2. VISUAL VALIDATION (Custom Plots)")
print("-" * 40)
orig_plot_df, corr_plot_df = create_batch_correction_plots(
    "data/merged_dataset_log_transformed.pq",
    "data/merged_dataset_batch_corrected.pq",
    "data/merged_metadata.pq",
)

## Feature Selection

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import os
import gc

In [None]:
def perform_feature_selection(data_path, metadata_path, output_path,
                            variance_threshold=0.1, top_k_genes=5000,
                            create_plots=True):
    """
    Perform variance-based feature selection on batch-corrected data.

    Samples are restricted to those present in both the data and the metadata.
    Genes (columns) whose variance is below ``variance_threshold`` are dropped,
    then the ``top_k_genes`` remaining genes with the largest variance are kept.
    The selected matrix is written to ``output_path`` as parquet (with index).

    Args:
        data_path (str): Path to batch-corrected dataset (samples x genes parquet)
        metadata_path (str): Path to metadata parquet with a ``sample_id`` column
        output_path (str): Path to save selected features
        variance_threshold (float): Minimum variance a gene needs to be kept
        top_k_genes (int): Maximum number of top variable genes to select
        create_plots (bool): Whether to call ``create_feature_selection_plots``

    Returns:
        tuple: (df_aligned, df_selected, top_var_genes, metadata_aligned) — the
        sample-aligned full matrix, the selected matrix, the selected gene
        labels (decreasing variance) and the aligned metadata.

    Raises:
        ValueError: If no gene reaches ``variance_threshold``.
    """
    print("Loading batch-corrected data for feature selection...")

    # Load data
    df = pd.read_parquet(data_path)
    metadata = pd.read_parquet(metadata_path).set_index('sample_id')

    # Align data
    common_samples = df.index.intersection(metadata.index)
    df_aligned = df.loc[common_samples]
    metadata_aligned = metadata.loc[common_samples]

    n_genes = df_aligned.shape[1]
    print(f"Original dataset shape: {df_aligned.shape}")
    print(f"Total genes before selection: {n_genes}")

    # Step 1: Remove low-variance genes
    print(f"\nStep 1: Removing genes with variance < {variance_threshold}")
    gene_variances = df_aligned.var(axis=0)
    high_var_variances = gene_variances[gene_variances >= variance_threshold]

    print(f"Genes after variance filtering: {len(high_var_variances)}")
    print(f"Genes removed (low variance): {n_genes - len(high_var_variances)}")

    if high_var_variances.empty:
        # Nothing to select; continuing would save an empty matrix and break
        # the percentage and plotting code below.
        raise ValueError(f"No genes have variance >= {variance_threshold}; lower variance_threshold.")

    # Step 2: Select top k most variable genes among those that passed Step 1
    print(f"\nStep 2: Selecting top {top_k_genes} most variable genes")
    top_var_genes = high_var_variances.nlargest(top_k_genes).index

    df_selected = df_aligned[top_var_genes]
    print(f"Final dataset shape: {df_selected.shape}")
    print(f"Feature reduction: {n_genes} → {df_selected.shape[1]} ({df_selected.shape[1]/n_genes*100:.1f}%)")

    selected_variances = gene_variances[top_var_genes]
    print("\nVariance Statistics:")
    print(f"Original genes - Mean: {gene_variances.mean():.3f}, Median: {gene_variances.median():.3f}")
    print(f"Selected genes - Mean: {selected_variances.mean():.3f}, Median: {selected_variances.median():.3f}")

    # Create plots if requested
    if create_plots:
        create_feature_selection_plots(df_aligned, df_selected, gene_variances,
                                       top_var_genes, metadata_aligned, variance_threshold)

    # Save selected features
    df_selected.to_parquet(output_path, index=True)
    print(f"\nFeature-selected dataset saved to: {output_path}")

    return df_aligned, df_selected, top_var_genes, metadata_aligned


def create_feature_selection_plots(original_df, selected_df, gene_variances,
                                   selected_genes, metadata, variance_threshold,
                                   random_state=None):
    """
    Create visualization plots for feature selection validation.

    Draws a 2x2 figure:
      1. histogram of per-gene variances, marking the variance threshold and the
         smallest variance among the selected genes;
      2. cumulative share of total variance captured by genes ranked by variance,
         marking how many genes were selected;
      3. 2-component PCA of the (possibly subsampled) original data;
      4. 2-component PCA of the selected data.
    PCA points are coloured by ``metadata['condition']``. A text summary is
    printed after the figure is shown.

    Args:
        original_df (pd.DataFrame): Samples x genes matrix before selection.
        selected_df (pd.DataFrame): Samples x genes matrix after selection; rows
            are looked up with the same labels as ``original_df``.
        gene_variances (pd.Series): Per-gene variance indexed by gene name.
        selected_genes (pd.Index): Names of the selected genes.
        metadata (pd.DataFrame): Per-sample metadata with a ``condition`` column,
            indexed like ``original_df``.
        variance_threshold (float): Threshold used by the variance filter.
        random_state (int | None): Seed for the row subsample drawn when there are
            more than 2000 samples. ``None`` (default) draws from NumPy's global
            random state, as before.

    Returns:
        tuple[np.ndarray, np.ndarray]: PCA coordinates of the original and of the
        selected data (one row per plotted sample, two columns).
    """
    print("\nGenerating feature selection plots...")

    # Set up the plotting style
    plt.style.use('default')
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('Feature Selection Analysis', fontsize=16, fontweight='bold')

    # Plot 1: Variance distribution
    ax1 = axes[0, 0]
    selected_min_var = gene_variances[selected_genes].min()
    sns.histplot(gene_variances, bins=50, kde=True, alpha=0.7, ax=ax1)
    ax1.axvline(variance_threshold, color='red', linestyle='--',
                label=f'Threshold: {variance_threshold}')
    ax1.axvline(selected_min_var, color='green', linestyle='--',
                label=f'Selected min: {selected_min_var:.3f}')
    ax1.set_xlabel('Gene Variance')
    ax1.set_ylabel('Count')
    ax1.set_title('Distribution of Gene Variances')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Plot 2: Cumulative variance explained
    ax2 = axes[0, 1]
    sorted_variances = gene_variances.sort_values(ascending=False)
    cumsum_var = sorted_variances.cumsum()
    total_var = sorted_variances.sum()

    x_range = range(1, len(sorted_variances) + 1)
    ax2.plot(x_range, cumsum_var / total_var * 100, 'b-', alpha=0.7)

    # Mark where selected genes end. The captured share is read off the ranked
    # curve, so it equals the selected genes' share only when they are the
    # top-n genes by variance.
    n_selected = len(selected_genes)
    ax2.axvline(n_selected, color='red', linestyle='--',
                label=f'Selected genes: {n_selected}')
    # Guard n_selected == 0: iloc[-1] would silently report the share of ALL genes.
    if n_selected > 0:
        var_explained = cumsum_var.iloc[n_selected - 1] / total_var * 100
    else:
        var_explained = 0.0
    ax2.axhline(var_explained, color='red', linestyle=':', alpha=0.7,
                label=f'Variance captured: {var_explained:.1f}%')

    ax2.set_xlabel('Number of Genes (ranked by variance)')
    ax2.set_ylabel('Cumulative Variance Explained (%)')
    ax2.set_title('Cumulative Variance Explained')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    ax2.set_xlim(0, min(10000, len(sorted_variances)))  # Zoom in for better view

    # Plot 3 & 4: PCA before and after feature selection
    from sklearn.decomposition import PCA
    from sklearn.preprocessing import StandardScaler

    # Sample data if too large for plotting
    n_samples_plot = min(2000, original_df.shape[0])
    if original_df.shape[0] > n_samples_plot:
        sampler = np.random if random_state is None else np.random.RandomState(random_state)
        plot_indices = sampler.choice(original_df.index, n_samples_plot, replace=False)
        orig_plot = original_df.loc[plot_indices]
        sel_plot = selected_df.loc[plot_indices]
        meta_plot = metadata.loc[plot_indices]
    else:
        orig_plot = original_df
        sel_plot = selected_df
        meta_plot = metadata

    # Cap the "before" PCA at the 5000 highest-variance genes. If 5000 or more genes
    # were selected by variance, the before/after panels use the same genes.
    n_genes_pca = min(5000, orig_plot.shape[1])
    if orig_plot.shape[1] > n_genes_pca:
        pca_genes = gene_variances.nlargest(n_genes_pca).index
        orig_plot_pca = orig_plot[pca_genes]
    else:
        orig_plot_pca = orig_plot

    # PCA on original data
    orig_scaled = StandardScaler().fit_transform(orig_plot_pca)
    pca_orig = PCA(n_components=2)
    orig_pca_result = pca_orig.fit_transform(orig_scaled)

    # PCA on selected data (independently standardized)
    sel_scaled = StandardScaler().fit_transform(sel_plot)
    pca_sel = PCA(n_components=2)
    sel_pca_result = pca_sel.fit_transform(sel_scaled)

    # Plot original data PCA
    ax3 = axes[1, 0]
    scatter3 = ax3.scatter(orig_pca_result[:, 0], orig_pca_result[:, 1],
                           c=meta_plot['condition'], cmap='RdYlBu', alpha=0.6, s=30)
    ax3.set_xlabel(f'PC1 ({pca_orig.explained_variance_ratio_[0]:.2%} variance)')
    ax3.set_ylabel(f'PC2 ({pca_orig.explained_variance_ratio_[1]:.2%} variance)')
    ax3.set_title(f'PCA - Before Selection\n({orig_plot_pca.shape[1]} genes)')
    ax3.grid(True, alpha=0.3)
    plt.colorbar(scatter3, ax=ax3, label='Condition')

    # Plot selected data PCA
    ax4 = axes[1, 1]
    scatter4 = ax4.scatter(sel_pca_result[:, 0], sel_pca_result[:, 1],
                           c=meta_plot['condition'], cmap='RdYlBu', alpha=0.6, s=30)
    ax4.set_xlabel(f'PC1 ({pca_sel.explained_variance_ratio_[0]:.2%} variance)')
    ax4.set_ylabel(f'PC2 ({pca_sel.explained_variance_ratio_[1]:.2%} variance)')
    ax4.set_title(f'PCA - After Selection\n({sel_plot.shape[1]} genes)')
    ax4.grid(True, alpha=0.3)
    plt.colorbar(scatter4, ax=ax4, label='Condition')

    plt.tight_layout()
    plt.show()

    # Summary information
    print(f"\n{'='*60}")
    print("FEATURE SELECTION SUMMARY")
    print(f"{'='*60}")
    print(f"Original genes: {original_df.shape[1]:,}")
    print(f"Selected genes: {selected_df.shape[1]:,}")
    print(f"Reduction: {(1 - selected_df.shape[1]/original_df.shape[1])*100:.1f}%")
    print(f"Variance explained by selected genes: {var_explained:.1f}%")
    print(f"PC1+PC2 variance (before): {(pca_orig.explained_variance_ratio_[0] + pca_orig.explained_variance_ratio_[1])*100:.1f}%")
    print(f"PC1+PC2 variance (after): {(pca_sel.explained_variance_ratio_[0] + pca_sel.explained_variance_ratio_[1])*100:.1f}%")

    return orig_pca_result, sel_pca_result

In [None]:
# Plain variance-based selection on the batch-corrected matrix (top 5,000 genes).
# The call returns (aligned data, selected data, selected gene names, aligned metadata).
original_data, selected_data, selected_gene_list, metadata = perform_feature_selection(
    create_plots=True,
    top_k_genes=5000,
    variance_threshold=0.1,
    output_path="data/merged_dataset_feature_selected.pq",
    metadata_path="data/merged_metadata.pq",
    data_path="data/merged_dataset_batch_corrected.pq",
)

# Report completion and the width of the resulting feature matrix.
print("\nFeature selection complete!")
print(f"Ready for model training with {selected_data.shape[1]} selected features")

In [None]:
def _find_perfectly_separating_features(X, y):
    """
    Return the columns of ``X`` on which a single threshold separates the classes.

    A column is flagged when every non-missing class-0 value is strictly below
    every non-missing class-1 value, or vice versa (the per-class value ranges do
    not overlap). Columns where either class has no non-missing values are never
    flagged. Comparing ranges, rather than testing whether the sets of exact
    values are disjoint, matters for continuous data: almost any two float
    samples differ, so a set test would flag nearly every feature.

    Args:
        X (pd.DataFrame): Samples x features matrix (numeric columns).
        y (pd.Series): Labels with values 0 / 1, indexed like ``X``.

    Returns:
        list: Labels of the separating columns, in ``X``'s column order.
    """
    class0 = X.loc[y == 0]
    class1 = X.loc[y == 1]
    # min()/max() skip NaN. A column that is all-NaN within a class yields NaN,
    # and every comparison with NaN is False, so such columns are never flagged.
    separated = (class0.max() < class1.min()) | (class1.max() < class0.min())
    return X.columns[separated.to_numpy(dtype=bool)].tolist()


def perform_feature_selection_safe(data_path, metadata_path, output_path,
                                   variance_threshold=0.1, top_k_genes=3000,
                                   create_plots=True, exclude_perfect_separation=True):
    """
    Perform variance-based feature selection with optional perfect separation removal.

    Steps:
      0. (optional) drop features whose class-0 and class-1 value ranges do not
         overlap, i.e. a single threshold on the feature alone classifies every sample;
      1. drop features with variance below ``variance_threshold``;
      2. keep the ``top_k_genes`` highest-variance remaining features;
      3. (optional) re-check the final set for perfect separation.
    The selected matrix is written to ``output_path`` as parquet.

    Args:
        data_path (str): Parquet file with the samples x genes matrix; its row
            index is matched against metadata ``sample_id`` values.
        metadata_path (str): Parquet file with ``sample_id`` and ``condition`` columns.
        output_path (str): Destination parquet file for the selected features.
        variance_threshold (float): Minimum per-gene variance to keep.
        top_k_genes (int): Maximum number of genes to keep.
        create_plots (bool): Whether to draw diagnostic plots; plotting failures
            are printed, not raised.
        exclude_perfect_separation (bool): Whether to run steps 0 and 3.

    Returns:
        tuple: ``(df_aligned, df_selected, top_var_genes, metadata_aligned,
        excluded_features)``; ``excluded_features`` maps ``'perfect_separation'``
        and ``'low_variance'`` to lists of removed feature names.

    Raises:
        ValueError: If the data and metadata share no sample IDs.
    """
    print("Loading batch-corrected data for SAFE feature selection...")

    # Load data
    df = pd.read_parquet(data_path)
    metadata = pd.read_parquet(metadata_path).set_index('sample_id')

    # Align data
    common_samples = df.index.intersection(metadata.index)
    if len(common_samples) == 0:
        raise ValueError("No overlapping sample IDs between data and metadata")
    df_aligned = df.loc[common_samples]
    metadata_aligned = metadata.loc[common_samples]
    y = metadata_aligned['condition']

    print(f"Original dataset shape: {df_aligned.shape}")
    print(f"Total genes before selection: {df_aligned.shape[1]}")

    excluded_features = {'perfect_separation': [], 'low_variance': []}

    # Step 0: Identify and exclude perfectly separating features
    if exclude_perfect_separation:
        print(f"\nStep 0: Identifying perfectly separating features...")

        perfect_features = _find_perfectly_separating_features(df_aligned, y)

        print(f"Found {len(perfect_features)} perfectly separating features")
        if perfect_features:
            print(f"Examples: {perfect_features[:5]}")

        # Remove perfect features
        df_cleaned = df_aligned.drop(columns=perfect_features)
        excluded_features['perfect_separation'] = perfect_features

        print(f"After removing perfect features: {df_cleaned.shape[1]} genes remaining")
    else:
        df_cleaned = df_aligned
        print("Skipping perfect separation check...")

    # Step 1: Remove low-variance genes
    print(f"\nStep 1: Removing genes with variance < {variance_threshold}")
    gene_variances = df_cleaned.var(axis=0)

    # Genes with NaN variance satisfy neither comparison and land in neither list.
    low_var_genes = gene_variances[gene_variances < variance_threshold].index
    high_var_genes = gene_variances[gene_variances >= variance_threshold].index

    df_high_var = df_cleaned[high_var_genes]
    excluded_features['low_variance'] = low_var_genes.tolist()

    print(f"Genes after variance filtering: {len(high_var_genes)}")
    print(f"Genes removed (low variance): {len(low_var_genes)}")

    # Step 2: Select top k most variable genes
    print(f"\nStep 2: Selecting top {top_k_genes} most variable genes")

    # Ensure we don't select more genes than available; the k largest variances
    # are all >= threshold when k <= len(high_var_genes).
    k_actual = min(top_k_genes, len(high_var_genes))
    top_var_genes = gene_variances.nlargest(k_actual).index

    df_selected = df_high_var[top_var_genes]
    print(f"Final dataset shape: {df_selected.shape}")
    print(f"Feature reduction: {df_aligned.shape[1]} → {df_selected.shape[1]} ({df_selected.shape[1]/df_aligned.shape[1]*100:.1f}%)")

    # Final check: ensure no perfect separation remains
    if exclude_perfect_separation:
        print(f"\nStep 3: Final validation - checking for remaining perfect separation...")
        remaining_perfect = _find_perfectly_separating_features(df_selected, y)

        if remaining_perfect:
            print(f"⚠️  WARNING: {len(remaining_perfect)} perfect features still remain!")
            print(f"Examples: {remaining_perfect[:5]}")
            print("Consider more aggressive filtering...")
        else:
            print("✅ No perfect separation detected in final feature set")

    # Create plots if requested (with error handling)
    if create_plots:
        try:
            create_safe_feature_selection_plots(df_aligned, df_selected, gene_variances,
                                                top_var_genes, metadata_aligned, variance_threshold,
                                                excluded_features)
        except Exception as e:
            print(f"⚠️  Plotting failed: {str(e)}")
            print("Continuing without plots...")

    # Save selected features
    df_selected.to_parquet(output_path, index=True)
    print(f"\nSAFE feature-selected dataset saved to: {output_path}")

    # Print summary
    print(f"\n{'='*60}")
    print("SAFE FEATURE SELECTION SUMMARY")
    print(f"{'='*60}")
    print(f"Original features: {df_aligned.shape[1]:,}")
    print(f"Perfect separation features removed: {len(excluded_features['perfect_separation']):,}")
    print(f"Low variance features removed: {len(excluded_features['low_variance']):,}")
    print(f"Final selected features: {df_selected.shape[1]:,}")
    print(f"Total reduction: {(1 - df_selected.shape[1]/df_aligned.shape[1])*100:.1f}%")
    print(f"Features-to-samples ratio: {df_selected.shape[1]/df_selected.shape[0]:.2f}")

    return df_aligned, df_selected, top_var_genes, metadata_aligned, excluded_features


def create_safe_feature_selection_plots(original_df, selected_df, gene_variances,
                                        selected_genes, metadata, variance_threshold,
                                        excluded_features, random_state=None):
    """
    Create visualization plots for safe feature selection with robust error handling.

    Draws a 2x2 figure:
      1. bar chart of feature counts after each pipeline stage;
      2. histogram of per-gene variances, marking the threshold and the smallest
         selected variance;
      3. 2-component PCA of the original data;
      4. 2-component PCA of the selected data.
    PCA points are coloured by ``metadata['condition']``. Failures are printed
    rather than raised. A PCA failure replaces both PCA panels with an error
    message. A failure of the whole figure closes it and prints a text summary.

    Args:
        original_df (pd.DataFrame): Samples x genes matrix before selection.
        selected_df (pd.DataFrame): Samples x genes matrix after selection.
        gene_variances (pd.Series): Per-gene variance indexed by gene name.
        selected_genes (pd.Index): Names of the selected genes.
        metadata (pd.DataFrame): Per-sample metadata with a ``condition`` column.
        variance_threshold (float): Threshold used by the variance filter.
        excluded_features (dict): Lists of removed feature names under the keys
            ``'perfect_separation'`` and ``'low_variance'``.
        random_state (int | None): Seed for the row subsample drawn when there are
            more than 1000 samples. ``None`` (default) uses NumPy's global random state.

    Returns:
        None
    """
    print("\nGenerating safe feature selection plots...")

    fig = None
    try:
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        fig.suptitle('Safe Feature Selection Analysis', fontsize=16, fontweight='bold')

        # Plot 1: Feature removal breakdown
        ax1 = axes[0, 0]
        categories = ['Original', 'Perfect\nRemoved', 'Low Var\nRemoved', 'Final\nSelected']
        n_original = original_df.shape[1]
        n_perfect = len(excluded_features['perfect_separation'])
        n_low_var = len(excluded_features['low_variance'])
        counts = [
            n_original,
            n_original - n_perfect,
            n_original - n_perfect - n_low_var,
            selected_df.shape[1]
        ]

        bars = ax1.bar(categories, counts, color=['red', 'orange', 'yellow', 'green'], alpha=0.7)
        ax1.set_ylabel('Number of Features')
        ax1.set_title('Feature Selection Pipeline')

        # Offset labels by 1% of the tallest bar so they sit just above each bar at
        # any scale (a fixed +50 floats far above the bars when counts are small).
        label_offset = max(counts) * 0.01
        for bar, count in zip(bars, counts):
            ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + label_offset,
                     str(count), ha='center', va='bottom', fontweight='bold')

        # Plot 2: Variance distribution (NaN variances are ignored)
        ax2 = axes[0, 1]
        valid_variances = gene_variances.dropna()

        if len(valid_variances) > 0:
            # histplot needs at least one bin; len // 10 is 0 for fewer than 10 genes.
            n_bins = max(1, min(50, len(valid_variances) // 10))
            sns.histplot(valid_variances, bins=n_bins, kde=True, alpha=0.7, ax=ax2)

            ax2.axvline(variance_threshold, color='red', linestyle='--',
                        label=f'Threshold: {variance_threshold}')

            # Only mark the selected minimum for selected genes that have a variance entry
            selected_variances = gene_variances[gene_variances.index.isin(selected_genes)]
            if len(selected_variances) > 0:
                ax2.axvline(selected_variances.min(), color='green', linestyle='--',
                            label=f'Selected min: {selected_variances.min():.3f}')

            ax2.set_xlabel('Gene Variance')
            ax2.set_ylabel('Count')
            ax2.set_title('Gene Variance Distribution')
            ax2.legend()
            ax2.grid(True, alpha=0.3)
        else:
            ax2.text(0.5, 0.5, 'No variance data available',
                     ha='center', va='center', transform=ax2.transAxes)
            ax2.set_title('Gene Variance Distribution - No Data')

        # Plot 3 & 4: PCA before and after with error checking
        ax3, ax4 = axes[1, 0], axes[1, 1]
        try:
            from sklearn.decomposition import PCA
            from sklearn.preprocessing import StandardScaler

            # Sample data for plotting if too large
            n_samples_plot = min(1000, original_df.shape[0])
            if original_df.shape[0] > n_samples_plot:
                sampler = np.random if random_state is None else np.random.RandomState(random_state)
                plot_indices = sampler.choice(original_df.index, n_samples_plot, replace=False)
                orig_plot = original_df.loc[plot_indices]
                sel_plot = selected_df.loc[plot_indices]
                meta_plot = metadata.loc[plot_indices]
            else:
                orig_plot = original_df
                sel_plot = selected_df
                meta_plot = metadata

            # Original data PCA
            if orig_plot.shape[1] > 1 and orig_plot.shape[0] > 1:
                # Restrict to at most 1000 genes, preferring the most variable ones
                n_genes_pca = min(1000, orig_plot.shape[1])
                if orig_plot.shape[1] > n_genes_pca:
                    available_genes = gene_variances.index.intersection(orig_plot.columns)
                    if len(available_genes) > 0:
                        pca_genes = gene_variances[available_genes].nlargest(min(n_genes_pca, len(available_genes))).index
                        orig_plot_pca = orig_plot[pca_genes]
                    else:
                        orig_plot_pca = orig_plot.iloc[:, :n_genes_pca]
                else:
                    orig_plot_pca = orig_plot

                # Keep numeric columns holding only finite values (StandardScaler rejects NaN/inf)
                orig_plot_pca = (orig_plot_pca.select_dtypes(include=[np.number])
                                 .replace([np.inf, -np.inf], np.nan)
                                 .dropna(axis=1))

                if orig_plot_pca.shape[1] > 1:
                    orig_scaled = StandardScaler().fit_transform(orig_plot_pca)
                    pca_orig = PCA(n_components=2)
                    orig_pca_result = pca_orig.fit_transform(orig_scaled)

                    ax3.scatter(orig_pca_result[:, 0], orig_pca_result[:, 1],
                                c=meta_plot['condition'], cmap='RdYlBu', alpha=0.6, s=30)
                    ax3.set_xlabel(f'PC1 ({pca_orig.explained_variance_ratio_[0]:.2%})')
                    ax3.set_ylabel(f'PC2 ({pca_orig.explained_variance_ratio_[1]:.2%})')
                    ax3.set_title(f'PCA - Before Safe Selection\n({orig_plot_pca.shape[1]} genes)')
                    ax3.grid(True, alpha=0.3)
                else:
                    ax3.text(0.5, 0.5, 'Insufficient features for PCA',
                             ha='center', va='center', transform=ax3.transAxes)
                    ax3.set_title('PCA - Before Selection (No Data)')
            else:
                ax3.text(0.5, 0.5, 'Insufficient data for PCA',
                         ha='center', va='center', transform=ax3.transAxes)
                ax3.set_title('PCA - Before Selection (No Data)')

            # Selected data PCA
            if sel_plot.shape[1] > 1 and sel_plot.shape[0] > 1:
                sel_plot_clean = (sel_plot.select_dtypes(include=[np.number])
                                  .replace([np.inf, -np.inf], np.nan)
                                  .dropna(axis=1))

                if sel_plot_clean.shape[1] > 1:
                    # Use a fresh scaler: one created in the branch above would not
                    # exist when that branch was skipped.
                    sel_scaled = StandardScaler().fit_transform(sel_plot_clean)
                    pca_sel = PCA(n_components=2)
                    sel_pca_result = pca_sel.fit_transform(sel_scaled)

                    ax4.scatter(sel_pca_result[:, 0], sel_pca_result[:, 1],
                                c=meta_plot['condition'], cmap='RdYlBu', alpha=0.6, s=30)
                    ax4.set_xlabel(f'PC1 ({pca_sel.explained_variance_ratio_[0]:.2%})')
                    ax4.set_ylabel(f'PC2 ({pca_sel.explained_variance_ratio_[1]:.2%})')
                    ax4.set_title(f'PCA - After Safe Selection\n({sel_plot_clean.shape[1]} genes)')
                    ax4.grid(True, alpha=0.3)
                else:
                    ax4.text(0.5, 0.5, 'Insufficient features for PCA',
                             ha='center', va='center', transform=ax4.transAxes)
                    ax4.set_title('PCA - After Selection (No Data)')
            else:
                ax4.text(0.5, 0.5, 'Insufficient data for PCA',
                         ha='center', va='center', transform=ax4.transAxes)
                ax4.set_title('PCA - After Selection (No Data)')

        except Exception as e:
            print(f"⚠️  PCA plotting failed: {str(e)}")
            # Replace whatever was partially drawn with the error message
            for i, ax in enumerate([ax3, ax4]):
                ax.clear()
                ax.text(0.5, 0.5, f'PCA Error:\n{str(e)[:50]}...',
                        ha='center', va='center', transform=ax.transAxes)
                ax.set_title(f'PCA Plot {i+1} - Error')

        plt.tight_layout()
        plt.show()

        print("✅ Safe feature selection plots generated!")

    except Exception as e:
        # Close the half-drawn figure so it is not left open (and possibly
        # auto-displayed later by a notebook backend).
        if fig is not None:
            plt.close(fig)
        print(f"❌ Plotting failed with error: {str(e)}")
        print("Continuing without plots...")

        # Print basic statistics instead
        print(f"\n📊 FEATURE SELECTION SUMMARY:")
        print(f"Original features: {original_df.shape[1]:,}")
        print(f"Perfect separation removed: {len(excluded_features.get('perfect_separation', [])):,}")
        print(f"Low variance removed: {len(excluded_features.get('low_variance', [])):,}")
        print(f"Final selected features: {selected_df.shape[1]:,}")

In [None]:
# Safe selection: drop perfectly separating and low-variance genes, then keep the
# 3,000 most variable. Plotting problems are reported, not raised.
original_data, selected_data, selected_gene_list, metadata, excluded_features = perform_feature_selection_safe(
    exclude_perfect_separation=True,
    create_plots=True,  # Will fall back gracefully if fails
    top_k_genes=3000,
    variance_threshold=0.1,
    output_path="data/merged_dataset_feature_selected_safe.pq",
    metadata_path="data/merged_metadata.pq",
    data_path="data/merged_dataset_batch_corrected.pq",
)

# Summarise what was removed and what remains for training.
print("\n🎯 SAFE feature selection complete!")
print(f"Perfect separating features excluded: {len(excluded_features['perfect_separation'])}")
print(f"Ready for training with {selected_data.shape[1]} safe features")

## Train-Test Split

In [None]:
from sklearn.model_selection import train_test_split
import os
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

In [None]:
def create_train_test_split(feature_selected_path, metadata_path, output_dir="data/splits",
                            test_size=0.2, random_state=42, create_validation=False,
                            val_size=0.25):
    """
    Create train-test splits for the feature-selected dataset.

    Splits are stratified on ``metadata['condition']``. Each feature matrix is
    saved as ``X_<split>.pq``. Each label set is saved as ``y_<split>.pq`` (one
    ``condition`` column) in ``output_dir``. The splits are then plotted with
    ``create_split_visualization``.

    Args:
        feature_selected_path (str): Path to feature-selected dataset
        metadata_path (str): Path to metadata (with ``sample_id`` and ``condition`` columns)
        output_dir (str): Directory to save split datasets
        test_size (float): Proportion of data for test set (default: 0.2)
        random_state (int): Random state for reproducibility
        create_validation (bool): Whether to create a validation set from training data
        val_size (float): Fraction of the *training* portion moved to the validation
            set when ``create_validation`` is True (default 0.25, i.e. 20% of all
            samples when ``test_size=0.2``).

    Returns:
        dict: Paths to saved split files, keyed ``'X_train'``, (``'X_val'``),
        ``'X_test'``, ``'y_train'``, (``'y_val'``), ``'y_test'``.
    """
    print("Loading feature-selected data for splitting...")

    # Load data
    df = pd.read_parquet(feature_selected_path)
    metadata = pd.read_parquet(metadata_path).set_index('sample_id')

    # Align data
    common_samples = df.index.intersection(metadata.index)
    df_aligned = df.loc[common_samples]
    metadata_aligned = metadata.loc[common_samples]

    print(f"Dataset shape: {df_aligned.shape}")
    print(f"Condition distribution: {metadata_aligned['condition'].value_counts()}")

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Prepare features and labels
    X = df_aligned
    y = metadata_aligned['condition']

    # Initial train-test split (stratified)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y,
        test_size=test_size,
        random_state=random_state,
        stratify=y
    )

    print(f"\nTrain-Test Split Results:")
    print(f"Training set: {X_train.shape[0]} samples ({X_train.shape[0]/len(X)*100:.1f}%)")
    print(f"Test set: {X_test.shape[0]} samples ({X_test.shape[0]/len(X)*100:.1f}%)")
    print(f"Training condition distribution: {y_train.value_counts()}")
    print(f"Test condition distribution: {y_test.value_counts()}")

    if create_validation:
        # Carve the validation set out of the training portion (stratified)
        X_train_final, X_val, y_train_final, y_val = train_test_split(
            X_train, y_train,
            test_size=val_size,
            random_state=random_state,
            stratify=y_train
        )

        print(f"\nWith Validation Set:")
        print(f"Final training: {X_train_final.shape[0]} samples ({X_train_final.shape[0]/len(X)*100:.1f}%)")
        print(f"Validation: {X_val.shape[0]} samples ({X_val.shape[0]/len(X)*100:.1f}%)")
        print(f"Test: {X_test.shape[0]} samples ({X_test.shape[0]/len(X)*100:.1f}%)")

        parts = {
            'train': (X_train_final, y_train_final),
            'val': (X_val, y_val),
            'test': (X_test, y_test),
        }
    else:
        X_train_final, y_train_final = X_train, y_train
        X_val = y_val = None
        parts = {
            'train': (X_train, y_train),
            'test': (X_test, y_test),
        }

    # Save the splits
    x_paths = {f'X_{name}': os.path.join(output_dir, f'X_{name}.pq') for name in parts}
    y_paths = {f'y_{name}': os.path.join(output_dir, f'y_{name}.pq') for name in parts}
    for name, (X_part, y_part) in parts.items():
        X_part.to_parquet(x_paths[f'X_{name}'], index=True)
        y_part.to_frame('condition').to_parquet(y_paths[f'y_{name}'], index=True)
    # Same key order as before: all X paths, then all y paths
    file_paths = {**x_paths, **y_paths}

    if create_validation:
        print(f"\nSaved train/validation/test splits to {output_dir}/")
    else:
        print(f"\nSaved train/test splits to {output_dir}/")

    # Pass the FINAL training set: the validation rows were carved out of X_train,
    # so passing X_train here would plot every validation sample twice.
    create_split_visualization(X_train_final, X_test, y_train_final, y_test, X_val, y_val)

    return file_paths


def create_split_visualization(X_train, X_test, y_train, y_test, X_val=None, y_val=None):
    """
    Create visualization of the train-test split.

    Standardizes the concatenation of all split feature matrices and fits one
    2-component PCA on it. Two side-by-side scatter plots are drawn: one
    coloured by split membership and one by condition (0 = Healthy,
    1 = Unhealthy).

    Args:
        X_train, X_test (pd.DataFrame): Feature matrices of the splits.
        y_train, y_test (pd.Series): Condition labels aligned with the matrices.
        X_val (pd.DataFrame | None): Optional validation features.
        y_val (pd.Series | None): Optional validation labels (used when ``X_val`` is given).
    """

    print("\nCreating split visualization...")

    # Combine data for consistent PCA; split_order fixes the plotting/legend order
    if X_val is not None:
        X_combined = pd.concat([X_train, X_val, X_test])
        y_combined = pd.concat([y_train, y_val, y_test])
        split_labels = (['Train'] * len(X_train) +
                        ['Validation'] * len(X_val) +
                        ['Test'] * len(X_test))
        split_order = ['Train', 'Validation', 'Test']
    else:
        X_combined = pd.concat([X_train, X_test])
        y_combined = pd.concat([y_train, y_test])
        split_labels = ['Train'] * len(X_train) + ['Test'] * len(X_test)
        split_order = ['Train', 'Test']
    split_labels = np.asarray(split_labels)

    # Perform PCA
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X_combined)
    pca = PCA(n_components=2)
    pca_result = pca.fit_transform(X_scaled)

    # Create plot
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))

    # Plot 1: Colored by split. A fixed split->colour mapping replaces iteration
    # over set(split_labels), whose hash-dependent order changed colours and
    # legend order from run to run.
    ax1 = axes[0]
    split_colors = {'Train': '#1f77b4', 'Test': '#ff7f0e', 'Validation': '#2ca02c'}
    for split_type in split_order:
        mask = split_labels == split_type
        ax1.scatter(pca_result[mask, 0], pca_result[mask, 1],
                    c=split_colors[split_type], label=split_type, alpha=0.7, s=50)

    ax1.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%} variance)')
    ax1.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%} variance)')
    ax1.set_title('Data Split Visualization')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Plot 2: Colored by condition
    ax2 = axes[1]
    condition_colors = {0: '#2E8B57', 1: '#DC143C'}  # Green for healthy, red for unhealthy
    condition_values = y_combined.to_numpy()
    for condition in [0, 1]:
        mask = condition_values == condition
        label = 'Healthy' if condition == 0 else 'Unhealthy'
        ax2.scatter(pca_result[mask, 0], pca_result[mask, 1],
                    c=condition_colors[condition], label=label, alpha=0.7, s=50)

    ax2.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%} variance)')
    ax2.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%} variance)')
    ax2.set_title('Condition Distribution Across Splits')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print("Split visualization complete!")

In [None]:
# For simple train-test split (recommended for initial testing)
# Simple stratified 80/20 train/test split without a validation set
# (recommended for initial testing).
# NOTE(review): this reads the plain feature-selected file, not the
# "_safe" output produced by perform_feature_selection_safe above — confirm
# which selection is intended for training.
split_paths = create_train_test_split(
    create_validation=False,
    random_state=42,
    test_size=0.2,
    output_dir="data/splits",
    metadata_path="data/merged_metadata.pq",
    feature_selected_path="data/merged_dataset_feature_selected.pq",
)

# List every file that was written, one per line.
print("Split files created:")
for split_name, path in split_paths.items():
    print(f"  {split_name}: {path}")

## Handle class imbalance

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from collections import Counter
import lightgbm as lgb
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
import shap

In [None]:
def _imbalance_severity(n_healthy, n_unhealthy):
    """
    Classify the imbalance between two positive class counts.

    Uses the minority/majority ratio, so the label does not depend on which
    class is the minority: < 0.2 is SEVERE, < 0.5 is MODERATE, otherwise MILD.

    Returns:
        tuple[str, str]: ``(severity, recommendation)``.
    """
    balance = min(n_healthy, n_unhealthy) / max(n_healthy, n_unhealthy)
    if balance < 0.2:
        return "SEVERE", "Consider using scale_pos_weight + focal loss"
    if balance < 0.5:
        return "MODERATE", "scale_pos_weight should handle this well"
    return "MILD", "scale_pos_weight may not be necessary"


def analyze_class_imbalance_lightgbm(y_train_path, y_test_path, create_plots=True):
    """
    Analyze class imbalance specifically for LightGBM training.

    Reads the ``condition`` column (0 = healthy, 1 = unhealthy) from two parquet
    label files and prints per-split class counts. It derives LightGBM's
    ``scale_pos_weight`` as (#class 0) / (#class 1) in the training set.

    Args:
        y_train_path (str): Path to training labels (parquet with ``condition``)
        y_test_path (str): Path to test labels (parquet with ``condition``)
        create_plots (bool): Whether to create visualization plots

    Returns:
        dict: Class counts per split, ``scale_pos_weight``, unhealthy:healthy
        ratios (``*_imbalance_ratio``), unhealthy percentages
        (``*_minority_percentage``; the key name assumes class 1 is the
        minority), plus ``imbalance_severity`` and ``imbalance_recommendation``.

    Raises:
        ValueError: If the training labels do not contain both class 0 and class 1.
    """
    print("Analyzing class imbalance for LightGBM...")

    # Load labels
    y_train = pd.read_parquet(y_train_path)['condition']
    y_test = pd.read_parquet(y_test_path)['condition']

    # Calculate class distributions
    train_counts = y_train.value_counts().sort_index()
    test_counts = y_test.value_counts().sort_index()

    total_train = len(y_train)
    total_test = len(y_test)

    healthy_train = train_counts.get(0, 0)
    unhealthy_train = train_counts.get(1, 0)
    if healthy_train == 0 or unhealthy_train == 0:
        raise ValueError(
            "Training labels must contain both class 0 (healthy) and class 1 "
            f"(unhealthy); got counts {train_counts.to_dict()}"
        )
    healthy_test = test_counts.get(0, 0)
    unhealthy_test = test_counts.get(1, 0)

    def pct(count, total):
        # Percentage that degrades to NaN instead of dividing by zero
        return count / total * 100 if total else float('nan')

    # LightGBM scale_pos_weight = (#negative) / (#positive) in the training set
    scale_pos_weight = healthy_train / unhealthy_train

    train_ratio = unhealthy_train / healthy_train
    # A test split with no class-0 samples has no defined ratio
    test_ratio = unhealthy_test / healthy_test if healthy_test else float('nan')

    severity, recommendation = _imbalance_severity(healthy_train, unhealthy_train)

    stats = {
        'train_healthy': healthy_train,
        'train_unhealthy': unhealthy_train,
        'test_healthy': healthy_test,
        'test_unhealthy': unhealthy_test,
        'train_total': total_train,
        'test_total': total_test,
        'scale_pos_weight': scale_pos_weight,
        'train_imbalance_ratio': train_ratio,
        'test_imbalance_ratio': test_ratio,
        'train_minority_percentage': pct(unhealthy_train, total_train),
        'test_minority_percentage': pct(unhealthy_test, total_test),
        'imbalance_severity': severity,
        'imbalance_recommendation': recommendation,
    }

    # Print summary
    print(f"\n{'='*50}")
    print("LIGHTGBM CLASS IMBALANCE ANALYSIS")
    print(f"{'='*50}")
    print(f"Training Set:")
    print(f"  Healthy (0): {healthy_train:,} samples ({pct(healthy_train, total_train):.1f}%)")
    print(f"  Unhealthy (1): {unhealthy_train:,} samples ({pct(unhealthy_train, total_train):.1f}%)")
    print(f"  Imbalance ratio: 1:{train_ratio:.2f}")

    print(f"\nTest Set:")
    print(f"  Healthy (0): {healthy_test:,} samples ({pct(healthy_test, total_test):.1f}%)")
    print(f"  Unhealthy (1): {unhealthy_test:,} samples ({pct(unhealthy_test, total_test):.1f}%)")
    print(f"  Imbalance ratio: 1:{test_ratio:.2f}")

    print(f"\nLightGBM Parameter:")
    print(f"  scale_pos_weight = {scale_pos_weight:.3f}")

    print(f"  Imbalance severity: {severity}")
    print(f"  Recommendation: {recommendation}")

    if create_plots:
        create_lightgbm_imbalance_plots(stats)

    return stats


def _imbalance_pct(part, total):
    """Return ``part`` as a percentage of ``total``; 0.0 when ``total`` is 0."""
    return part / total * 100 if total else 0.0


def _imbalance_label_bars(ax, bars, labels, offset, **text_kwargs):
    """Write one centred text label just above each bar, ``offset`` data units up."""
    for bar, label in zip(bars, labels):
        ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + offset,
                label, ha='center', va='bottom', **text_kwargs)


def _imbalance_pad_ylim(ax, top, factor=1.15):
    """Extend the y-axis above the tallest bar so bar labels fit inside the axes.

    Text artists are not considered by matplotlib's autoscaling, so without this
    the labels on the tallest bar can be clipped or overlap the axes title.
    """
    ax.set_ylim(0, top * factor if top > 0 else 1)


def create_lightgbm_imbalance_plots(stats):
    """
    Create focused visualization plots for LightGBM class imbalance.

    Draws a 2x2 figure and shows it with ``plt.show()``:
      1. Training-set class counts (title shows ``scale_pos_weight``).
      2. Test-set class counts.
      3. Training-set healthy/unhealthy relative frequency.
      4. Side-by-side class percentages, training vs. test.

    Args:
        stats (dict): Statistics as built by ``analyze_class_imbalance_lightgbm``.
            Reads 'train_healthy', 'train_unhealthy', 'train_total',
            'test_healthy', 'test_unhealthy', 'test_total',
            'train_imbalance_ratio' and 'scale_pos_weight'.

    Returns:
        None.
    """
    # Local import (as in create_optimization_plots) so this function does not
    # depend on a notebook-global ``plt`` having been imported elsewhere.
    import matplotlib.pyplot as plt

    print("\nGenerating LightGBM-focused imbalance plots...")

    class_labels = ['Healthy', 'Unhealthy']
    colors = ['#2E8B57', '#DC143C']

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle('LightGBM Class Imbalance Analysis', fontsize=16, fontweight='bold')

    # Plot 1: Training set distribution with scale_pos_weight annotation
    ax1 = axes[0, 0]
    train_counts = [stats['train_healthy'], stats['train_unhealthy']]
    bars1 = ax1.bar(class_labels, train_counts, color=colors, alpha=0.7, edgecolor='black')
    ax1.set_title(f'Training Set Distribution\nscale_pos_weight = {stats["scale_pos_weight"]:.3f}')
    ax1.set_ylabel('Number of Samples')
    ax1.grid(True, alpha=0.3)
    _imbalance_label_bars(
        ax1, bars1,
        [f'{count:,}\n({_imbalance_pct(count, stats["train_total"]):.1f}%)' for count in train_counts],
        max(train_counts) * 0.01, fontweight='bold')
    _imbalance_pad_ylim(ax1, max(train_counts))

    # Plot 2: Test set distribution
    ax2 = axes[0, 1]
    test_counts = [stats['test_healthy'], stats['test_unhealthy']]
    bars2 = ax2.bar(class_labels, test_counts, color=colors, alpha=0.7, edgecolor='black')
    ax2.set_title('Test Set Distribution')
    ax2.set_ylabel('Number of Samples')
    ax2.grid(True, alpha=0.3)
    _imbalance_label_bars(
        ax2, bars2,
        [f'{count:,}\n({_imbalance_pct(count, stats["test_total"]):.1f}%)' for count in test_counts],
        max(test_counts) * 0.01, fontweight='bold')
    _imbalance_pad_ylim(ax2, max(test_counts))

    # Plot 3: Imbalance ratio visualization (healthy normalised to 1)
    ax3 = axes[1, 0]
    ratio = stats['train_imbalance_ratio']
    ratio_data = [1, ratio]
    ratio_labels = ['Healthy\n(Reference)', f'Unhealthy\n(1:{ratio:.2f})']
    bars3 = ax3.bar(ratio_labels, ratio_data, color=colors, alpha=0.7, edgecolor='black')
    ax3.set_title('Class Imbalance Ratio\n(Training Set)')
    ax3.set_ylabel('Relative Frequency')
    ax3.grid(True, alpha=0.3)
    _imbalance_label_bars(ax3, bars3, [f'{value:.2f}' for value in ratio_data],
                          max(ratio_data) * 0.01, fontweight='bold')
    _imbalance_pad_ylim(ax3, max(ratio_data))

    # Plot 4: Train vs Test comparison
    ax4 = axes[1, 1]
    x = np.arange(2)
    width = 0.35

    train_props = [_imbalance_pct(stats['train_healthy'], stats['train_total']),
                   _imbalance_pct(stats['train_unhealthy'], stats['train_total'])]
    test_props = [_imbalance_pct(stats['test_healthy'], stats['test_total']),
                  _imbalance_pct(stats['test_unhealthy'], stats['test_total'])]

    train_bars = ax4.bar(x - width / 2, train_props, width, label='Training',
                         color=colors, alpha=0.7, edgecolor='black')
    # Both series reuse the per-class colours, so hatch the test bars; otherwise
    # the 'Training' and 'Test' legend entries differ only by a slight alpha change.
    test_bars = ax4.bar(x + width / 2, test_props, width, label='Test',
                        color=colors, alpha=0.5, edgecolor='black', hatch='//')

    ax4.set_xlabel('Class')
    ax4.set_ylabel('Percentage (%)')
    ax4.set_title('Distribution Consistency Check')
    ax4.set_xticks(x)
    ax4.set_xticklabels(class_labels)
    ax4.legend()
    ax4.grid(True, alpha=0.3)

    for bars in (train_bars, test_bars):
        _imbalance_label_bars(ax4, bars, [f'{bar.get_height():.1f}%' for bar in bars],
                              0.5, fontsize=9)
    _imbalance_pad_ylim(ax4, max(train_props + test_props))

    plt.tight_layout()
    plt.show()

    print("LightGBM imbalance analysis plots generated!")
    print(f"\n✅ Ready for LightGBM training with scale_pos_weight = {stats['scale_pos_weight']:.3f}")

In [None]:
# Run the LightGBM-oriented class-imbalance analysis on the saved label splits;
# the returned stats dict supplies the scale_pos_weight used for tuning below.
lightgbm_stats = analyze_class_imbalance_lightgbm(
    create_plots=True,
    y_test_path="data/splits/y_test.pq",
    y_train_path="data/splits/y_train.pq",
)

print(f"\n🚀 LightGBM scale_pos_weight ready: {lightgbm_stats['scale_pos_weight']:.3f}")

## Hyperparameter tuning

In [None]:
import optuna
import lightgbm as lgb
from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.metrics import roc_auc_score
import pandas as pd
import numpy as np

In [None]:
def optimize_lightgbm_hyperparameters(X_train_path, y_train_path, scale_pos_weight,
                                    n_trials=100, cv_folds=5, random_state=42,
                                    target_column='condition'):
    """
    Optimize LightGBM hyperparameters using Optuna for the breast cancer dataset.

    Each Optuna trial is scored by the mean ROC AUC over stratified K-fold
    cross-validation. Within every fold the model trains for up to 1000 boosting
    rounds with early stopping (patience 50) on that fold's validation split.

    Args:
        X_train_path (str): Path to the training-features parquet file.
        y_train_path (str): Path to the training-labels parquet file.
        scale_pos_weight (float): Class imbalance weight from previous analysis;
            held fixed, not tuned.
        n_trials (int): Number of optimization trials.
        cv_folds (int): Number of cross-validation folds.
        random_state (int): Seed for the fold split, the TPE sampler and LightGBM.
        target_column (str): Column of the labels file that holds the target.

    Returns:
        dict:
            'best_params': best tuned hyperparameters merged with the fixed
                settings (objective, metric, scale_pos_weight, ...).
            'best_score': best mean CV AUC.
            'best_num_boost_round': mean early-stopping iteration over the best
                trial's folds (None if unavailable). ``best_params`` carries no
                round count, so pass this as ``num_boost_round`` when training
                the final model.
            'study': the ``optuna.Study``.
            'n_trials': number of trials requested.

    Note:
        The validation fold that drives early stopping is also the fold that is
        scored, so the CV AUC is somewhat optimistic; confirm on held-out data.
    """
    print("Loading training data for hyperparameter optimization...")

    X_train = pd.read_parquet(X_train_path)
    y_train = pd.read_parquet(y_train_path)[target_column]

    print(f"Training data shape: {X_train.shape}")
    print(f"Class distribution: {y_train.value_counts()}")

    # Settings shared by every trial and by the returned best_params; only the
    # hyperparameters suggested inside `objective` are tuned.
    fixed_params = {
        'objective': 'binary',
        'metric': 'binary_logloss',
        'boosting_type': 'gbdt',
        'scale_pos_weight': scale_pos_weight,  # Fixed from class imbalance analysis
        'random_state': random_state,
        'verbose': -1,
    }

    # The fold split does not depend on the trial, so compute it once instead of
    # once per trial; the fixed random_state keeps every trial on the same folds.
    folds = list(StratifiedKFold(n_splits=cv_folds, shuffle=True,
                                 random_state=random_state).split(X_train, y_train))

    def objective(trial):
        """Return the mean fold ROC AUC for one set of suggested hyperparameters."""
        params = {
            **fixed_params,
            'num_leaves': trial.suggest_int('num_leaves', 10, 300),
            'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.3, log=True),
            'feature_fraction': trial.suggest_float('feature_fraction', 0.4, 1.0),
            'bagging_fraction': trial.suggest_float('bagging_fraction', 0.4, 1.0),
            'bagging_freq': trial.suggest_int('bagging_freq', 1, 7),
            'min_child_samples': trial.suggest_int('min_child_samples', 5, 100),
            'reg_alpha': trial.suggest_float('reg_alpha', 1e-8, 10.0, log=True),
            'reg_lambda': trial.suggest_float('reg_lambda', 1e-8, 10.0, log=True),
        }

        fold_aucs = []
        fold_best_iterations = []
        for train_idx, val_idx in folds:
            X_fold_train, X_fold_val = X_train.iloc[train_idx], X_train.iloc[val_idx]
            y_fold_train, y_fold_val = y_train.iloc[train_idx], y_train.iloc[val_idx]

            train_data = lgb.Dataset(X_fold_train, label=y_fold_train)
            val_data = lgb.Dataset(X_fold_val, label=y_fold_val, reference=train_data)

            model = lgb.train(
                params,
                train_data,
                valid_sets=[val_data],
                valid_names=['eval'],
                num_boost_round=1000,
                callbacks=[
                    # verbose=False: otherwise every fold of every trial prints an
                    # early-stopping banner (log_evaluation(0) does not cover it).
                    lgb.early_stopping(stopping_rounds=50, verbose=False),
                    lgb.log_evaluation(0),
                ],
            )

            val_pred = model.predict(X_fold_val, num_iteration=model.best_iteration)
            fold_aucs.append(roc_auc_score(y_fold_val, val_pred))
            fold_best_iterations.append(model.best_iteration)

        # Keep the round counts chosen by early stopping so the best trial's value
        # can be returned together with its hyperparameters.
        trial.set_user_attr('mean_best_iteration', float(np.mean(fold_best_iterations)))
        return np.mean(fold_aucs)

    print(f"\nStarting hyperparameter optimization with {n_trials} trials...")
    print("This may take several minutes...")

    # Seeded sampler for reproducible trial suggestions.
    sampler = optuna.samplers.TPESampler(seed=random_state)
    study = optuna.create_study(direction='maximize', sampler=sampler)
    study.optimize(objective, n_trials=n_trials)

    # Tuned values first, then the fixed settings (fixed settings win on a clash).
    best_params = {**study.best_params, **fixed_params}

    mean_best_iteration = study.best_trial.user_attrs.get('mean_best_iteration')
    best_num_boost_round = (max(1, int(round(mean_best_iteration)))
                            if mean_best_iteration is not None else None)

    print(f"\n{'='*60}")
    print("HYPERPARAMETER OPTIMIZATION RESULTS")
    print(f"{'='*60}")
    print(f"Best CV AUC Score: {study.best_value:.4f}")
    print("Best Parameters:")
    for param, value in study.best_params.items():
        print(f"  {param}: {value}")
    if best_num_boost_round is not None:
        print(f"  num_boost_round (mean early-stopping iteration): {best_num_boost_round}")

    create_optimization_plots(study)

    return {
        'best_params': best_params,
        'best_score': study.best_value,
        'best_num_boost_round': best_num_boost_round,
        'study': study,
        'n_trials': n_trials,
    }

def _scatter_param_vs_score(ax, trials, param_name, xlabel, title, color,
                            empty_message, log_x=False):
    """Scatter one hyperparameter's sampled values against each trial's CV AUC on ``ax``."""
    points = [(t.params[param_name], t.value) for t in trials
              if t.value is not None and param_name in t.params]
    if not points:
        ax.text(0.5, 0.5, empty_message, ha='center', va='center', transform=ax.transAxes)
        ax.set_title(title)
        return

    xs, ys = zip(*points)
    ax.scatter(xs, ys, alpha=0.6, color=color, s=30)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('CV AUC Score')
    ax.set_title(title)
    if log_x:
        ax.set_xscale('log')
    ax.grid(True, alpha=0.3)


def create_optimization_plots(study):
    """
    Create plots to visualize the hyperparameter optimization process.

    Draws a 2x2 figure and shows it with ``plt.show()``:
      1. Objective value (CV AUC) per trial, with the best value as a line.
      2. Optuna parameter importances, when they can be computed.
      3. learning_rate vs. CV AUC (log-scaled x-axis).
      4. num_leaves vs. CV AUC.

    Trials without an objective value are excluded from every panel. If no trial
    has a value, a message is printed and no figure is drawn.

    Args:
        study (optuna.Study): The study returned by the optimization.
    """
    import matplotlib.pyplot as plt

    print("\nGenerating optimization plots...")

    scored_trials = [t for t in study.trials if t.value is not None]
    if not scored_trials:
        # study.best_value would raise here; report instead of crashing the caller.
        print("No trials with an objective value; skipping optimization plots.")
        return

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('Hyperparameter Optimization Analysis', fontsize=16, fontweight='bold')

    # Plot 1: Optimization history
    ax1 = axes[0, 0]
    ax1.plot([t.number for t in scored_trials], [t.value for t in scored_trials],
             'b-', alpha=0.6)
    ax1.axhline(study.best_value, color='red', linestyle='--',
                label=f'Best: {study.best_value:.4f}')
    ax1.set_xlabel('Trial Number')
    ax1.set_ylabel('CV AUC Score')
    ax1.set_title('Optimization History')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Plot 2: Parameter importance (optional panel; Optuna raises e.g. when there
    # are too few completed trials to estimate importances).
    ax2 = axes[0, 1]
    ax2.set_title('Parameter Importance')
    try:
        importance = optuna.importance.get_param_importances(study)
    except Exception as err:  # best effort: never abort the whole figure
        print(f"Parameter importance unavailable: {err}")
        ax2.text(0.5, 0.5, 'Parameter importance\nnot available',
                 ha='center', va='center', transform=ax2.transAxes)
    else:
        # Importances come most-important first, while barh draws the first item
        # at the bottom; reverse so the most important parameter is on top.
        names = list(importance.keys())[::-1]
        scores = list(importance.values())[::-1]
        ax2.barh(names, scores, alpha=0.7, color='skyblue', edgecolor='black')
        ax2.set_xlabel('Importance')
        ax2.grid(True, alpha=0.3)

    # Plot 3: Learning rate vs performance
    _scatter_param_vs_score(axes[1, 0], scored_trials, 'learning_rate',
                            xlabel='Learning Rate',
                            title='Learning Rate vs Performance',
                            color='green',
                            empty_message='No learning rate data',
                            log_x=True)

    # Plot 4: Num leaves vs performance
    _scatter_param_vs_score(axes[1, 1], scored_trials, 'num_leaves',
                            xlabel='Number of Leaves',
                            title='Num Leaves vs Performance',
                            color='orange',
                            empty_message='No num_leaves data')

    plt.tight_layout()
    plt.show()

    print("Optimization plots generated!")

In [None]:
# Tune LightGBM with 50 Optuna trials, scoring each trial by 10-fold stratified
# CV AUC and reusing the scale_pos_weight from the class-imbalance analysis.
optimization_results = optimize_lightgbm_hyperparameters(
    X_train_path="data/splits/X_train.pq",
    y_train_path="data/splits/y_train.pq",
    scale_pos_weight=lightgbm_stats['scale_pos_weight'],
    n_trials=50,
    cv_folds=10,
    random_state=42,
)

print("\n🎯 Hyperparameter optimization complete!")
# Plain string: the original f-string had no placeholders (flake8/ruff F541).
print("Best parameters ready for final model training!")