# Imports and modular functions

*Always re-run this cell if any of these functions change or if the **Python kernel crashes**.* 😁😄😃🙂😐🙁☹️😢😭💀

In [None]:
from typing import Iterator, List, Optional, Union
import pyarrow.parquet as pq
import pyarrow as pa
import pandas as pd
import numpy as np
import tempfile
import mygene
import os
import gc

def load_csv_in_chunks(file_path, chunk_size=8000, **kwargs):
    """
    Open a CSV file for chunked reading so it never has to be loaded whole.

    Args:
        file_path (str): Location of the CSV file.
        chunk_size (int): Number of rows in each DataFrame chunk.
        **kwargs: Forwarded unchanged to pd.read_csv()
                  (e.g., sep=',', header=0, index_col=None, usecols=None).

    Returns:
        The chunk iterator created by pd.read_csv(..., chunksize=chunk_size),
        or None when the file does not exist or pd.read_csv() raises.
    """
    file_missing = not os.path.exists(file_path)
    if file_missing:
        print(f"Error: File not found at {file_path}")
        return None

    print(f"Preparing to load {file_path} in chunks of size {chunk_size}...")

    reader = None
    try:
        reader = pd.read_csv(file_path, chunksize=chunk_size, **kwargs)
    except Exception as e:
        # Any failure while opening the file is reported and turned into None.
        print(f"Error loading CSV file {file_path}: {e}")
    finally:
        # Runs on both the success and the failure path.
        gc.collect()

    return reader


def load_parquet_in_chunks(file_path, chunk_size=8000):
    """
    Stream a parquet file as DataFrame chunks instead of loading it whole.

    This is a generator: nothing (including the existence check) runs until
    the first chunk is requested.

    Args:
        file_path (str): Location of the parquet file.
        chunk_size (int): Number of rows requested per chunk.

    Yields:
        pandas.DataFrame: One DataFrame per record batch read from the file.
        Nothing is yielded when the file is missing; a read error is printed
        and ends the iteration.
    """
    if not os.path.exists(file_path):
        print(f"Error: File not found at {file_path}")
        return

    print(f"Preparing to load {file_path} in chunks of size {chunk_size}...")

    try:
        source = pq.ParquetFile(file_path)
        print(f"Total rows in file: {source.metadata.num_rows}")

        for record_batch in source.iter_batches(batch_size=chunk_size):
            yield record_batch.to_pandas()

    except Exception as e:
        # Errors are reported and end the stream rather than propagating.
        print(f"Error loading parquet file {file_path}: {e}")
        return

    finally:
        # Also runs when the consumer abandons the generator early.
        gc.collect()


def save_data_as_parquet(chunk_iterator, output_parquet_path, preserve_index=True):
    """
    Stream DataFrame chunks into a single parquet file through PyArrow.

    The writer is created lazily from the schema of the first non-empty chunk.
    When the index is preserved and named, index values are tracked across
    chunks and repeated values are reported (they are still written).

    Args:
        chunk_iterator: Iterator yielding DataFrame chunks
        output_parquet_path (str): Destination path of the parquet file
        preserve_index (bool): Whether the DataFrame index is written as well
    """
    print(f"Processing chunks and saving to {output_parquet_path}...")

    parquet_writer = None
    rows_written = 0
    seen_index_values = set()

    try:
        for position, frame in enumerate(chunk_iterator):
            if frame.empty:
                continue

            # Duplicate tracking only applies to a named, preserved index.
            if preserve_index and frame.index.name is not None:
                current_values = set(frame.index)
                repeated = seen_index_values.intersection(current_values)
                if repeated:
                    print(f"Warning: Found {len(repeated)} duplicate index values in chunk {position}")
                    print(f"First few duplicates: {list(repeated)[:5]}")
                seen_index_values.update(current_values)

            arrow_table = pa.Table.from_pandas(frame, preserve_index=preserve_index)

            if parquet_writer is None:
                # The first written chunk fixes the schema of the whole file.
                parquet_writer = pq.ParquetWriter(output_parquet_path, arrow_table.schema)
                if preserve_index and frame.index.name:
                    print(f"Preserving index: '{frame.index.name}' (dtype: {frame.index.dtype})")

            parquet_writer.write_table(arrow_table)
            rows_written += frame.shape[0]

    finally:
        # Close the file even when a chunk fails part-way through.
        if parquet_writer:
            parquet_writer.close()
        gc.collect()

    if rows_written > 0:
        print(f"Successfully saved {rows_written} rows to parquet")
        if preserve_index:
            print(f"Total unique index values: {len(seen_index_values)}")
    else:
        print("No data to save!")


def drop_dataframe_chunks(
    chunk_generator: Iterator[pd.DataFrame],
    drop_rows: Optional[Union[List[Union[int, str]], List[int], List[str]]] = None,
    drop_columns: Optional[Union[List[Union[int, str]], List[int], List[str]]] = None
) -> Iterator[pd.DataFrame]:
    """
    Generator function that drops specified rows and/or columns from DataFrame chunks.
    Designed to be chained with other generators.

    Row selection works in one of two modes:
      * every entry of drop_rows is an int -> entries are 0-based positions
        *within each chunk* (out-of-range positions are ignored);
      * otherwise -> entries are compared, as strings, with the chunk index labels.

    Args:
        chunk_generator: An iterable that yields pandas DataFrame chunks
        drop_rows: List of row positions/labels to drop, or None to keep all rows
        drop_columns: List of column names to drop (unknown names are ignored),
            or None to keep all columns

    Yields:
        pd.DataFrame: A new DataFrame per chunk with the rows/columns dropped.
        Empty input chunks are passed through unchanged; chunks that become
        empty after dropping are not yielded.
    """
    for chunk in chunk_generator:
        if chunk.empty:
            yield chunk
            continue

        # Work on a copy so the yielded frame never aliases the caller's chunk.
        processed_chunk = chunk.copy()

        if drop_columns is not None:
            columns_to_drop = [col for col in drop_columns if col in processed_chunk.columns]
            if columns_to_drop:
                processed_chunk = processed_chunk.drop(columns=columns_to_drop)

        if drop_rows:
            if all(isinstance(row, int) for row in drop_rows):
                # Positional mode. A boolean mask removes exactly the requested
                # positions; dropping by the label found at that position would
                # also remove every other row sharing a duplicated label.
                chunk_length = len(processed_chunk)
                valid_positions = [pos for pos in drop_rows if 0 <= pos < chunk_length]
                if valid_positions:
                    keep_mask = np.ones(chunk_length, dtype=bool)
                    keep_mask[valid_positions] = False
                    processed_chunk = processed_chunk.loc[keep_mask]
            else:
                # Label mode: compare as strings so 1 and "1" match the same label.
                drop_labels = [str(row) for row in drop_rows]
                drop_mask = processed_chunk.index.astype(str).isin(drop_labels)
                if drop_mask.any():
                    processed_chunk = processed_chunk.loc[~drop_mask]

        if not processed_chunk.empty:
            yield processed_chunk


def keep_dataframe_chunks(
    chunk_generator: Iterator[pd.DataFrame],
    keep_rows: Optional[Union[List[Union[int, str]], List[int], List[str]]] = None,
    keep_columns: Optional[Union[List[Union[int, str]], List[int], List[str]]] = None
) -> Iterator[pd.DataFrame]:
    """
    Generator function that keeps only specified rows and/or columns from DataFrame chunks.
    Designed to be chained with other generators.

    Row selection works in one of two modes:
      * keep_rows is empty or every entry is an int -> entries are 0-based
        positions *within each chunk*, returned in the order given
        (out-of-range positions are ignored);
      * otherwise -> entries are compared, as strings, with the chunk index
        labels and matching rows are returned in chunk order.

    Args:
        chunk_generator: An iterable that yields pandas DataFrame chunks
        keep_rows: List of row positions/labels to keep, or None to keep all rows
        keep_columns: List of column names to keep (in the order given), or
            None to keep all columns

    Yields:
        pd.DataFrame: A new DataFrame per chunk holding only the kept data.
        Empty input chunks are passed through unchanged; chunks in which
        nothing matches are not yielded.
    """
    for chunk in chunk_generator:
        if chunk.empty:
            yield chunk
            continue

        # Work on a copy so the yielded frame never aliases the caller's chunk.
        processed_chunk = chunk.copy()

        if keep_columns is not None:
            columns_to_keep = [col for col in keep_columns if col in processed_chunk.columns]
            if columns_to_keep:
                processed_chunk = processed_chunk[columns_to_keep]
            else:
                # No requested column exists: the chunk contributes nothing.
                processed_chunk = processed_chunk.iloc[:0]

        if keep_rows is not None:
            if keep_rows and not all(isinstance(row, int) for row in keep_rows):
                # Label mode. A boolean mask returns each matching row exactly
                # once; `.loc[list_of_labels]` would repeat rows whenever the
                # index contains duplicated labels.
                keep_labels = [str(row) for row in keep_rows]
                match_mask = processed_chunk.index.astype(str).isin(keep_labels)
                processed_chunk = processed_chunk.loc[match_mask]
            else:
                # Positional mode (an empty keep_rows list keeps no rows).
                chunk_length = len(processed_chunk)
                valid_positions = [pos for pos in keep_rows if 0 <= pos < chunk_length]
                processed_chunk = processed_chunk.iloc[valid_positions]

        if not processed_chunk.empty:
            yield processed_chunk


def _filter_chunk_for_transpose(chunk, keep_columns, skip_lookup, skip_by_label):
    """Select keep_columns (when given) and remove the rows named in skip_lookup."""
    filtered = chunk
    if keep_columns is not None:
        filtered = filtered[keep_columns]
    if skip_lookup:
        if skip_by_label:
            keep_mask = [label not in skip_lookup for label in filtered.index]
        else:
            keep_mask = [pos not in skip_lookup for pos in range(len(filtered))]
        # A boolean mask selects rows by position, so duplicated index labels
        # are neither removed together nor repeated.
        filtered = filtered.loc[np.asarray(keep_mask, dtype=bool)]
    return filtered


def transpose_dataframe_chunks(
    chunk_generator,
    skip_rows=None,
    skip_columns=None,
    output_batch_size=1000,
    temp_dir="/tmp",
    dtype='uint32'
):
    """
    Generator function that collects DataFrame chunks, transposes the complete dataset
    through a temporary memory-mapped file, and yields the transposed result in batches.
    Designed to be chained with other generators.

    Note that all input chunks are first materialised in a list; only the
    transposed matrix lives in the memory map. Every value is cast to `dtype`.

    Args:
        chunk_generator: An iterable that yields pandas DataFrame chunks
        skip_rows: Rows to skip, or None. If the first entry is a str the
            entries are index labels; otherwise they are 0-based positions
            *within each chunk*. An empty collection skips nothing.
        skip_columns: Column names to skip, or None. The column set is taken
            from the first chunk.
        output_batch_size: Number of rows to yield in each output batch
        temp_dir: Directory for temporary memory-mapped file
        dtype: Data type for memory-mapped array (default: 'uint32')

    Yields:
        pd.DataFrame: Batches of the transposed DataFrame. Each batch has a
        leading 'sample_id' column holding the original column names, followed
        by one column per (kept) original row.

    Raises:
        Exception: Any error raised while building or reading the memory map
        is printed and re-raised. The temporary file is removed either way.
    """

    print(f"Starting memory-mapped transposition...")

    chunk_list = list(chunk_generator)
    if not chunk_list:
        print("No chunks received from generator. Exiting.")
        return

    first_chunk = chunk_list[0]
    original_rows = sum(len(chunk) for chunk in chunk_list)
    # Counted before column filtering so the "Original" line is truthful.
    original_cols = len(first_chunk.columns)

    if skip_columns is not None:
        filtered_cols = [col for col in first_chunk.columns.tolist() if col not in skip_columns]
        print(f"Skipping columns: {skip_columns}")
        keep_columns = filtered_cols
    else:
        filtered_cols = first_chunk.columns.tolist()
        keep_columns = None

    # Guard against an empty skip_rows: indexing its first element would raise.
    skip_lookup = set(skip_rows) if skip_rows else None
    skip_by_label = bool(skip_lookup) and isinstance(next(iter(skip_rows)), str)
    if skip_rows is not None:
        print(f"Skipping rows: {skip_rows}")

    n_output_rows = len(filtered_cols)

    # First pass: collect the labels of all rows that survive filtering; they
    # become the column names of the transposed output.
    all_row_indices = []
    for chunk in chunk_list:
        chunk_filtered = _filter_chunk_for_transpose(chunk, keep_columns, skip_lookup, skip_by_label)
        all_row_indices.extend(chunk_filtered.index.tolist())

    n_output_cols = len(all_row_indices)

    print(f"Dataset dimensions:")
    print(f"    Original: {original_rows} rows x {original_cols} columns")
    print(f"    After filtering: {len(all_row_indices)} rows x {len(filtered_cols)} columns")
    print(f"    After transpose: {n_output_rows} rows x {n_output_cols} columns")

    mmap_array = None
    temp_filename = None

    # Create memory-mapped file
    try:
        # delete=False: the file must outlive this handle so np.memmap can reopen it.
        with tempfile.NamedTemporaryFile(dir=temp_dir, delete=False, suffix='.mmap') as temp_file:
            temp_filename = temp_file.name

        # Create memory-mapped array in transposed orientation
        mmap_array = np.memmap(
            temp_filename,
            dtype=dtype,
            mode='w+',
            shape=(n_output_rows, n_output_cols)
        )
        print(f"Memory map created successfully")

        print(f"Filling memory map with data...")
        current_feature_idx = 0
        total_chunks = len(chunk_list)

        # Second pass: write each chunk, transposed, into its column slice.
        for chunk_idx, chunk in enumerate(chunk_list):
            chunk_filtered = _filter_chunk_for_transpose(chunk, keep_columns, skip_lookup, skip_by_label)

            chunk_data = chunk_filtered.values.T.astype(dtype)
            chunk_feature_count = chunk_filtered.shape[0]

            mmap_array[:, current_feature_idx:current_feature_idx + chunk_feature_count] = chunk_data
            current_feature_idx += chunk_feature_count

            mmap_array.flush()
            del chunk_data, chunk_filtered

            print(f"    Processed chunk {chunk_idx + 1}/{total_chunks}")

        # The source chunks are no longer needed once the map is filled.
        del chunk_list
        gc.collect()

        print(f"Memory map filled successfully")

        print(f"Yielding transposed batches (size: {output_batch_size})...")
        total_batches = (n_output_rows + output_batch_size - 1) // output_batch_size

        for batch_idx in range(total_batches):
            start_row = batch_idx * output_batch_size
            end_row = min(start_row + output_batch_size, n_output_rows)

            # Copy so the yielded frame does not depend on the temporary file.
            batch_data = mmap_array[start_row:end_row, :].copy()
            batch_sample_ids = filtered_cols[start_row:end_row]

            batch_df = pd.DataFrame(
                data=batch_data,
                index=batch_sample_ids,
                columns=all_row_indices
            )

            batch_df = batch_df.reset_index().rename(columns={'index': 'sample_id'})

            yield batch_df

            del batch_data, batch_df
            gc.collect()

            print(f"    Yielded batch {batch_idx + 1}/{total_batches}")

        print(f"Transposition completed successfully!")

    except Exception as e:
        print(f"Error during memory-mapped transposition: {e}")
        raise

    finally:
        # Also reached when the consumer stops iterating early.
        print(f"Cleaning up temporary files...")
        try:
            # Drop the memmap reference before deleting the backing file.
            mmap_array = None
            if temp_filename is not None and os.path.exists(temp_filename):
                os.unlink(temp_filename)
                print(f"Temporary file removed: {temp_filename}")
        except Exception as cleanup_error:
            print(f"Warning: Could not clean up temp file: {cleanup_error}")


def convert_gene_ids_to_symbols(dataset_chunk_iterator, mygene_client, gene_column_prefix="ENSG"):
    """
    Rename gene-ID columns to gene symbols using a single MyGene query.

    The gene columns are taken from the first chunk only, looked up once via
    `mygene_client.querymany`, and the resulting rename is applied to every
    chunk. All chunks are then concatenated, so the complete dataset is held
    in memory at the end.

    Args:
        dataset_chunk_iterator: Iterable or iterator yielding DataFrame chunks
        mygene_client: Object exposing a `querymany` method (a MyGene client)
        gene_column_prefix (str): Prefix to identify gene columns

    Returns:
        pd.DataFrame: Complete dataset with gene symbols as column names.
        An empty DataFrame is returned when no chunks are supplied. Gene IDs
        without a symbol, and all gene IDs if the query fails, keep their
        original names.
    """
    print("Converting gene IDs to symbols (chunked processing)...")

    # iter() lets callers pass any iterable (e.g. a list), not only an iterator.
    chunk_source = iter(dataset_chunk_iterator)

    # Get first chunk to determine gene columns and create mapping
    try:
        first_chunk = next(chunk_source)
    except StopIteration:
        print("❌ Error: Dataset is empty")
        return pd.DataFrame()
    print(f"Processing first chunk: {first_chunk.shape}")

    # Identify gene columns from first chunk
    gene_columns = [col for col in first_chunk.columns
                   if isinstance(col, str) and col.startswith(gene_column_prefix)]

    if not gene_columns:
        print("No gene columns found - returning original data")
        return pd.concat([first_chunk, *chunk_source], axis=0, ignore_index=False)

    print(f"Found {len(gene_columns)} gene columns")

    # Single MyGene API call for all gene IDs
    gene_id_to_symbol = {}
    symbols_found = 0

    try:
        print("Making single MyGene API call...")
        import warnings
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            results = mygene_client.querymany(
                gene_columns,
                scopes='ensembl.gene',
                fields='symbol',
                species='human',
                verbose=False,
                silent=True
            )

        # Process API results; entries without a symbol keep the gene ID.
        for result in results:
            gene_id = result['query']
            symbol = result.get('symbol')
            if symbol:
                gene_id_to_symbol[gene_id] = symbol
                symbols_found += 1
            else:
                gene_id_to_symbol[gene_id] = gene_id

        print(f"✅ Successfully converted {symbols_found}/{len(gene_columns)} genes to symbols")

    except Exception as e:
        # Best-effort: fall back to the original IDs rather than failing.
        print(f"⚠️ MyGene API error: {e}. Using original gene IDs")
        gene_id_to_symbol = {gene_id: gene_id for gene_id in gene_columns}
        # Nothing was converted, even if the failure happened mid-way.
        symbols_found = 0

    # Final column mapping: 'condition' is never renamed, and columns without
    # a lookup result keep their name.
    final_column_mapping = {
        col: col if col == 'condition' else gene_id_to_symbol.get(col, col)
        for col in first_chunk.columns
    }

    print("Applying gene symbol mapping to all chunks...")

    processed_chunks = [first_chunk.rename(columns=final_column_mapping)]
    chunk_count = 1

    # Process remaining chunks with same mapping
    for chunk in chunk_source:
        processed_chunks.append(chunk.rename(columns=final_column_mapping))
        chunk_count += 1

        if chunk_count % 3 == 0:
            print(f"  ✓ Processed {chunk_count} chunks...")

    print(f"Concatenating {chunk_count} processed chunks...")
    final_dataset = pd.concat(processed_chunks, axis=0, ignore_index=False)

    print(f"✅ Final dataset shape: {final_dataset.shape}")
    print(f"   Symbols converted: {symbols_found}/{len(gene_columns)}")

    return final_dataset


def add_condition_labels_to_chunks(chunk_iterator, condition_label, dataset_name):
    """
    Tag every chunk with a constant 'condition' column and collect the chunks.

    The chunks are modified in place: the same DataFrame objects that come out
    of the iterator are returned, now carrying the extra column.

    Args:
        chunk_iterator: Iterator yielding DataFrame chunks
        condition_label (int): Binary label (0 for healthy, 1 for unhealthy)
        dataset_name (str): Name used in the progress messages

    Returns:
        list: The labeled DataFrame chunks, in iteration order
    """
    print(f"Adding label '{condition_label}' to {dataset_name} dataset...")

    def _tag(frame):
        # Assigning a scalar broadcasts it to every row of the chunk.
        frame['condition'] = condition_label
        return frame

    labeled_chunks = [_tag(frame) for frame in chunk_iterator]

    print(f"✅ Completed {len(labeled_chunks)} {dataset_name} chunks")
    return labeled_chunks


def merge_labeled_datasets(healthy_chunks, unhealthy_chunks):
    """
    Stack the healthy chunks followed by the unhealthy chunks into one DataFrame.

    The original row index of each chunk is kept. The chunks are expected to
    carry a 'condition' column, which is used for the summary printout.

    Args:
        healthy_chunks (list): List of healthy DataFrame chunks
        unhealthy_chunks (list): List of unhealthy DataFrame chunks

    Returns:
        pd.DataFrame: Merged dataset
    """
    print("Merging datasets...")

    merged_dataset = pd.concat(healthy_chunks + unhealthy_chunks, axis=0, ignore_index=False)

    sample_total = len(merged_dataset)
    feature_total = len(merged_dataset.columns)
    conditions = merged_dataset['condition']

    print(f"✅ Merged dataset: {sample_total} samples, {feature_total} features")
    print(f"   Healthy (0): {(conditions == 0).sum()}")
    print(f"   Unhealthy (1): {(conditions == 1).sum()}")

    return merged_dataset


def clean_duplicate_nans(chunk_iterator):
    """
    Clean DataFrame chunks one at a time: rows containing any null value are
    dropped first, then exact duplicate rows within the chunk.

    Duplicates are only detected inside a single chunk, not across chunks.
    Chunks that end up empty are skipped instead of being yielded.
    """
    print("Cleaning chunks by dropping NaNs and duplicates...", flush=True)

    yielded_chunks = 0
    for number, frame in enumerate(chunk_iterator, start=1):
        without_nans = frame.dropna()
        nan_rows = len(frame) - len(without_nans)
        if nan_rows > 0:
            print(f"    Chunk {number}: Dropped {nan_rows} rows with null values.")

        if without_nans.empty:
            print(f"    Chunk {number}: Empty after dropping NaNs, skipping...", flush=True)
            continue

        deduplicated = without_nans.drop_duplicates()
        duplicate_rows = len(without_nans) - len(deduplicated)
        if duplicate_rows > 0:
            print(f"    Chunk {number}: Dropped {duplicate_rows} duplicate rows.")

        if deduplicated.empty:
            print(f"    Chunk {number}: Empty after deduplication, skipping...")
            continue

        yielded_chunks += 1
        print(f"    Chunk {number}: Yielded chunk with shape {deduplicated.shape}")
        yield deduplicated

    print(f"Finished processing. {yielded_chunks} non-empty chunks processed.")


def rename_index(chunk_iterator, index_name):
    """
    Give the index of every DataFrame chunk the name `index_name`.

    The name is set on the chunk's own index (no copy is made), and the very
    same chunk object is yielded again.
    """
    for frame in chunk_iterator:
        row_index = frame.index
        row_index.name = index_name
        yield frame


def filter_rows(chunk_iterator):
    """
    Generator that removes rows whose values sum to exactly 0.

    Every column except 'sample_id' takes part in the row sum. Empty chunks
    are skipped; a chunk with no column other than 'sample_id' is yielded
    untouched. Summing errors are reported and re-raised.
    """

    print("Filtering rows where sum of all numeric values equals 0...")

    rows_seen = 0
    rows_dropped = 0
    chunk_number = 0

    for frame in chunk_iterator:
        if frame.empty:
            continue

        chunk_number += 1
        value_columns = [name for name in frame.columns if name != 'sample_id']

        if not value_columns:
            print(f"    Warning: Chunk {chunk_number} has no numeric columns, keeping all rows")
            yield frame
            continue

        try:
            has_signal = frame[value_columns].sum(axis=1) != 0
            kept = frame[has_signal]
        except Exception as e:
            print(f"    Error in chunk {chunk_number}: {e}")
            print(f"    Column types: {frame.dtypes}")
            raise

        rows_in = len(frame)
        removed_here = rows_in - len(kept)
        rows_seen += rows_in
        rows_dropped += removed_here

        if removed_here > 0:
            print(f"    Chunk {chunk_number}: Removed {removed_here}/{rows_in} rows with zero sum")

        yield kept

    print(f"Filtering complete: {rows_dropped}/{rows_seen} rows removed")


def check_metadata(healthy_generator, unhealthy_generator, merged_metadata_generator):
    """
    Verify that the metadata has one row per sample of the two datasets.

    All three generators are fully consumed (healthy, then unhealthy, then
    metadata) while their rows are counted.

    Args:
        healthy_generator: Iterable of healthy dataset chunks
        unhealthy_generator: Iterable of unhealthy dataset chunks
        merged_metadata_generator: Iterable of metadata chunks

    Raises:
        Exception: When healthy + unhealthy rows differ from the metadata rows;
            the message lists all counts.
    """
    healthy_rows = sum(len(chunk) for chunk in healthy_generator)
    unhealthy_rows = sum(len(chunk) for chunk in unhealthy_generator)
    metadata_rows = sum(len(chunk) for chunk in merged_metadata_generator)

    total = healthy_rows + unhealthy_rows

    # The same summary is used for the error and for the success report.
    summary = f"Healthy line count: {healthy_rows} | Unhealthy line count: {unhealthy_rows} | Total line count: {total} | Metadata line count: {metadata_rows}"

    if total == metadata_rows:
        print("All good")
        print(f"Stats: {summary}")
        return

    raise Exception(summary)


def _sample_metadata_records(chunk, condition):
    """Build one {'sample_id', 'condition'} record per row from the chunk's first column."""
    if chunk is None:
        return []
    return [{'sample_id': str(sid), 'condition': condition}
            for sid in chunk.iloc[:, 0].tolist()]


def prepare_metadata_generator(aligned_pair_iterator, output_metadata_path="data/merged_metadata.pq"):
    """
    Generator that creates metadata from aligned chunk pairs and saves to parquet.

    The first column of every chunk is taken as the sample ID. The metadata
    file is written only once the input is exhausted, i.e. after the consumer
    has iterated this generator to the end.

    Args:
        aligned_pair_iterator: Iterable yielding (healthy_chunk, unhealthy_chunk)
            tuples; either element may be None
        output_metadata_path: Path to save metadata parquet file

    Yields:
        tuple: (healthy_chunk, unhealthy_chunk) - passes through the original chunks
    """

    all_metadata = []

    for healthy_chunk, unhealthy_chunk in aligned_pair_iterator:
        all_metadata.extend(_sample_metadata_records(healthy_chunk, 'healthy'))
        all_metadata.extend(_sample_metadata_records(unhealthy_chunk, 'unhealthy'))

        # Yield the chunks unchanged (pass-through)
        yield healthy_chunk, unhealthy_chunk
        # Release the pair before the next one is fetched.
        del healthy_chunk, unhealthy_chunk

    # Save metadata at the end
    if all_metadata:
        metadata_df = pd.DataFrame(all_metadata)
        metadata_df.to_parquet(output_metadata_path, index=False)
        print(f"Metadata saved to {output_metadata_path} with {len(metadata_df)} records")
    else:
        print("No metadata to save")

    # Locals are released when the generator finishes. Deleting them by name
    # after a globals() check raised UnboundLocalError whenever a same-named
    # global existed (common in a notebook) while the local was never assigned.
    gc.collect()

def get_healthy_whole_blood_samples(metadata_path: str) -> list[str]:
    """
    Get healthy whole blood samples from the metadata.

    A sample qualifies when its 'SMTSD' value is 'Whole Blood' and its RNA
    integrity number ('SMRIN') is at least 7.0. Samples with a missing SMRIN
    do not qualify.

    Args:
        metadata_path (str): Path to the metadata CSV file. Its first column
            is read as the index and supplies the returned sample IDs.

    Returns:
        list[str]: List of healthy whole blood sample IDs
    """
    metadata_df = pd.read_csv(metadata_path, index_col=0)

    # Get all samples with 'Whole Blood' SMTSD
    whole_blood_samples = metadata_df[metadata_df['SMTSD'] == 'Whole Blood']

    # Filter on RNA integrity (SMRIN) to remove low quality samples. The
    # comparison only builds a boolean mask; it must be applied to the frame,
    # otherwise every whole-blood sample is returned.
    passes_quality = whole_blood_samples["SMRIN"] >= 7.0
    healthy_samples = whole_blood_samples[passes_quality]

    # Return the index values (sample IDs) of healthy whole blood samples
    return healthy_samples.index.tolist()

def align_gene_columns_generator(healthy_chunk_generator, unhealthy_chunk_generator,
                                gene_column_prefix="ENSG"):
    """
    Generator that puts healthy and unhealthy chunks on one shared column layout.

    Gene columns (names starting with `gene_column_prefix`) have everything
    after the first '.' stripped; only genes present in the first chunk of
    both datasets are kept, in sorted order, after the non-gene columns of the
    healthy dataset. The layout is decided from the first chunk of each
    dataset and reused for all later chunks.

    Args:
        healthy_chunk_generator: Generator yielding healthy dataset chunks
        unhealthy_chunk_generator: Generator yielding unhealthy dataset chunks
        gene_column_prefix (str): Prefix to identify gene columns (default: "ENSG")

    Yields:
        tuple: (aligned_healthy_chunk, aligned_unhealthy_chunk) for each chunk pair.
               Once one dataset runs out, None takes its place in the tuple.
               Nothing is yielded if either dataset is empty or no gene is shared.
    """

    print("Starting gene column alignment...")

    healthy_iter = iter(healthy_chunk_generator)
    unhealthy_iter = iter(unhealthy_chunk_generator)

    try:
        first_healthy_chunk = next(healthy_iter)
        first_unhealthy_chunk = next(unhealthy_iter)
    except StopIteration:
        print("ERROR: One or both datasets are empty!")
        return

    def is_gene_column(name):
        """True for string column names that carry the gene prefix."""
        return isinstance(name, str) and name.startswith(gene_column_prefix)

    print("\nProcessing gene columns and stripping version suffixes...")

    def collect_gene_columns(chunk, dataset_name):
        """Map each first-seen versioned gene column to its base gene ID."""
        print(f"Processing {dataset_name} dataset...")

        column_to_base = {}
        seen_base_ids = set()
        gene_column_total = 0

        for column in chunk.columns:
            if not is_gene_column(column):
                continue

            gene_column_total += 1
            base_id = column.split('.')[0]

            # Only the first column of a base gene is kept.
            if base_id in seen_base_ids:
                print(f"Warning: Duplicate base gene {base_id} found, skipping {column}")
                continue

            column_to_base[column] = base_id
            seen_base_ids.add(base_id)

        skipped = gene_column_total - len(seen_base_ids)

        print(f"Total gene columns found: {gene_column_total}")
        print(f"Unique base gene IDs: {len(seen_base_ids)}")
        if skipped > 0:
            print(f"Duplicates removed: {skipped}")

        return column_to_base, seen_base_ids

    healthy_rename_mapping, healthy_base_genes = collect_gene_columns(first_healthy_chunk, "HEALTHY")
    unhealthy_rename_mapping, unhealthy_base_genes = collect_gene_columns(first_unhealthy_chunk, "UNHEALTHY")

    common_base_genes = healthy_base_genes.intersection(unhealthy_base_genes)

    if len(common_base_genes) == 0:
        print("ERROR: No common genes found between datasets!")
        return

    print(f"Common genes: {len(common_base_genes)}")
    print(f"Genes exclusive to healthy dataset: {len(healthy_base_genes - common_base_genes)}")
    print(f"Genes exclusive to unhealthy dataset: {len(unhealthy_base_genes - common_base_genes)}")

    # Renaming never touches non-gene columns, so they can be read straight
    # from the untouched first chunks.
    healthy_non_gene_cols = [name for name in first_healthy_chunk.columns if not is_gene_column(name)]
    unhealthy_non_gene_cols = [name for name in first_unhealthy_chunk.columns if not is_gene_column(name)]

    # Validate that non-gene columns match
    healthy_extra = set(healthy_non_gene_cols) - set(unhealthy_non_gene_cols)
    unhealthy_extra = set(unhealthy_non_gene_cols) - set(healthy_non_gene_cols)
    if healthy_extra or unhealthy_extra:
        print("Warning: Non-gene columns differ between datasets!")
        print(f"Healthy only: {healthy_extra}")
        print(f"Unhealthy only: {unhealthy_extra}")

    # The healthy dataset decides which non-gene columns are kept.
    final_columns = healthy_non_gene_cols + sorted(common_base_genes)

    print("\nProcessing and yielding aligned chunks...")

    def align(chunk, rename_mapping):
        """Rename gene columns to base IDs and select the shared layout."""
        if chunk is None:
            return None
        return chunk.rename(columns=rename_mapping)[final_columns].copy()

    def pull_next(source, label, pairs_so_far):
        """Return (chunk, exhausted) for the next item of `source`."""
        try:
            return next(source), False
        except StopIteration:
            print(f"{label} dataset exhausted after {pairs_so_far} chunks")
            return None, True

    yield align(first_healthy_chunk, healthy_rename_mapping), align(first_unhealthy_chunk, unhealthy_rename_mapping)

    pairs_yielded = 1
    healthy_done = False
    unhealthy_done = False

    while not (healthy_done and unhealthy_done):
        healthy_chunk = None
        unhealthy_chunk = None

        if not healthy_done:
            healthy_chunk, healthy_done = pull_next(healthy_iter, "Healthy", pairs_yielded)
        if not unhealthy_done:
            unhealthy_chunk, unhealthy_done = pull_next(unhealthy_iter, "Unhealthy", pairs_yielded)

        if healthy_chunk is None and unhealthy_chunk is None:
            break

        yield align(healthy_chunk, healthy_rename_mapping), align(unhealthy_chunk, unhealthy_rename_mapping)
        pairs_yielded += 1

    print(f"\nAlignment complete!")

    gc.collect()


def merge_datasets_generator(aligned_pair_iterator):
    """
    Generator that merges healthy and unhealthy dataset chunks into single merged chunks.

    Each pair is concatenated with ``pd.concat`` (healthy rows first, then
    unhealthy rows). A side that is ``None`` is left out, and a pair in which
    both sides are ``None`` produces no output.

    Args:
        aligned_pair_iterator: Iterable yielding (healthy_chunk, unhealthy_chunk)
            tuples. Either element may be None.

    Yields:
        pandas.DataFrame: Merged chunks containing both healthy and unhealthy data
    """
    try:
        for healthy_chunk, unhealthy_chunk in aligned_pair_iterator:
            chunks_to_merge = [
                chunk
                for chunk in (healthy_chunk, unhealthy_chunk)
                if chunk is not None
            ]

            if not chunks_to_merge:
                continue

            # The merged frame is never bound to a local name, so the generator
            # does not keep it alive after the consumer is done with it.
            yield pd.concat(chunks_to_merge)
    finally:
        # Runs on exhaustion AND when the consumer stops early (close/error).
        gc.collect()


def set_index_column(chunk_generator, column_name, drop=True):
    """
    Generator that promotes one column of every DataFrame chunk to the index.

    Chunks that are empty, or that do not contain ``column_name``, are passed
    through unchanged.

    Args:
        chunk_generator: An iterable that yields pandas DataFrame chunks
        column_name: Name of the column to set as index
        drop: Whether to drop the column after setting it as index (default: True)

    Yields:
        pd.DataFrame: Chunks with the specified column set as index
    """
    for frame in chunk_generator:
        # Nothing to re-index: forward the chunk as-is.
        if frame.empty or column_name not in frame.columns:
            yield frame
        else:
            yield frame.set_index(column_name, drop=drop)


def get_healthy_whole_blood_samples(metadata_path: str) -> list[str]:
    """
    Get healthy whole blood samples from the metadata.

    A sample is kept when its 'SMTSD' value is 'Whole Blood' and its RNA
    integrity number ('SMRIN') is at least 7.0. Rows with a missing 'SMRIN'
    are excluded because the comparison evaluates to False for NaN.

    Args:
        metadata_path (str): Path to the metadata CSV file. The first column
            is used as the index and supplies the returned sample IDs; the
            file must contain the 'SMTSD' and 'SMRIN' columns.

    Returns:
        list[str]: List of healthy whole blood sample IDs
    """
    metadata_df = pd.read_csv(metadata_path, index_col=0)

    # Tissue filter: only samples whose SMTSD is 'Whole Blood'
    is_whole_blood = metadata_df['SMTSD'] == 'Whole Blood'
    # Quality filter on RNA integrity (SMRIN) to remove low quality samples
    has_good_integrity = metadata_df['SMRIN'] >= 7.0

    # Apply BOTH masks to the rows; taking the index of the boolean mask itself
    # would return every whole blood sample regardless of quality.
    healthy_whole_blood = metadata_df[is_whole_blood & has_good_integrity]
    return healthy_whole_blood.index.tolist()

# Preprocessing Unhealthy Dataset

- Load `.csv` file as `pandas.DataFrame`
- Set name of index as `sample_id` (Had no index name)
- Remove duplicates and NaNs by dropping rows containing them
- Transpose dataset
- Set index to `'sample_id'` column
- Remove genes that have a total sum of 0 recorded across all patients
- Transpose dataset
- Set index to `'sample_id'` column
- Save preprocessed dataset as `.pq`

In [None]:
unhealthy_dataset_file = 'data/rna_seq_unstranded.csv'
unhealthy_output_file = 'data/unhealthy_data_preprocessed.pq'
chunk_size = 8000

# Pass chunk_size explicitly so the read size follows the setting above
# instead of silently relying on the loader's default.
chunk_iterator = load_csv_in_chunks(
    unhealthy_dataset_file,
    chunk_size=chunk_size,
    header=0,
    skiprows=[1, 2],
    index_col=0,
)

# load_csv_in_chunks returns None when the file is missing or cannot be opened
if chunk_iterator is not None:
    print(f"Starting preprocessing for Unhealthy Dataset: {unhealthy_dataset_file}")

    # Each step wraps the previous iterator; the chain is consumed by
    # save_data_as_parquet below.
    chunk_iterator = rename_index(chunk_iterator, 'sample_id')
    chunk_iterator = clean_duplicate_nans(chunk_iterator)
    chunk_iterator = transpose_dataframe_chunks(chunk_iterator, output_batch_size=chunk_size) # Also converts dtypes to uint32
    chunk_iterator = set_index_column(chunk_iterator, 'sample_id')
    chunk_iterator = filter_rows(chunk_iterator)
    chunk_iterator = transpose_dataframe_chunks(chunk_iterator, output_batch_size=chunk_size)
    chunk_iterator = set_index_column(chunk_iterator, 'sample_id')

    save_data_as_parquet(chunk_iterator, unhealthy_output_file)
else:
    print(f"Failed to load {unhealthy_dataset_file} using load_csv_in_chunks.")

# Clean up
gc.collect()

### Display basic unhealthy dataset info

In [None]:
# NOTE: chunk_size is not set in this cell; it must already be defined by a
# previously executed cell.
unhealthy_iterator = load_parquet_in_chunks('data/unhealthy_data_preprocessed.pq', chunk_size=chunk_size)

if unhealthy_iterator:
    # A default for next() avoids an uncaught StopIteration when the loader
    # produces no chunk (e.g. missing or empty file).
    first_chunk = next(unhealthy_iterator, None)

    if first_chunk is None:
        print("No chunks available in data/unhealthy_data_preprocessed.pq")
    else:
        # Basic info
        print(f"First chunk shape: {first_chunk.shape}")
        print(f"Columns: {list(first_chunk.columns)}")
        print(f"Data types:\n{first_chunk.dtypes}")
        print(f"Memory usage: {first_chunk.memory_usage(deep=True).sum() / 1024**2:.2f} MB")

        print("\nFirst 5 rows:")
        display(first_chunk.head())

# Preprocess healthy dataset
## Convert .gct to .pq (parquet)

- Load `.gct` file as `pandas.DataFrame` using `load_csv_in_chunks`
- Store `pandas.DataFrame` as `.pq` file

In [None]:
dataset_file = 'data/GTEx_Analysis_2022-06-06_v10_RNASeQCv2.4.2_gene_reads.gct'
output_file = 'data/healthy_data.pq'
chunk_size = 3000

# Read the GCT file as tab-separated text: the first two lines are skipped,
# the following line is used as the header and the first column as the index.
gct_chunk_iterator = load_csv_in_chunks(
    file_path=dataset_file,
    chunk_size=chunk_size,
    sep='\t',
    skiprows=2,
    header=0,
    index_col=0,
)

# load_csv_in_chunks returns None when the file is missing or cannot be opened
if gct_chunk_iterator:
    print("Saving GCT dataset as parquet...")

    save_data_as_parquet(
        chunk_iterator=gct_chunk_iterator,
        output_parquet_path=output_file
    )
else:
    print(f"Failed to load {dataset_file} using load_csv_in_chunks.")

# Clean up
if 'gct_chunk_iterator' in locals() or 'gct_chunk_iterator' in globals():
    del gct_chunk_iterator
if 'dataset_file' in locals() or 'dataset_file' in globals():
    del dataset_file
if 'output_file' in locals() or 'output_file' in globals():
    del output_file
# Check the name that is actually deleted (the old guard tested a variable
# that is never defined, so chunk_size was never released).
if 'chunk_size' in locals() or 'chunk_size' in globals():
    del chunk_size

gc.collect()

- Load dataset from `.pq` file
- Rename index to `'gene_id'`
- Drop the `Description` column
- Remove genes that have a total sum of 0 recorded across all patients
- Transpose dataset
- Set index to `'sample_id'` column
- Remove duplicates and NaNs by dropping rows containing them
- Drop samples that are not whole blood samples
- Save preprocessed dataset as `.pq`

In [None]:
healthy_data_input = 'data/healthy_data.pq'
healthy_data_output = 'data/healthy_data_preprocessed.pq'
exclude_cols = ['Description']
chunk_size = 3000

# NOTE(review): the metadata-driven sample selection below is disabled and a
# short hard-coded list of sample IDs is used instead — confirm this is
# intended before a full run.
# metadata_path = 'data/gtex_blood_samples_metadata.csv'
# whole_blood_samples = get_healthy_whole_blood_samples(metadata_path)
# print(f"Found {len(whole_blood_samples)} healthy whole blood samples in metadata")

whole_blood_samples = [
    'GTEX-1117F-0005-SM-HL9SH',
    'GTEX-1122O-0526-SM-5N9DM',
    'GTEX-1122O-0826-SM-5GICV',
    'GTEX-117YX-0526-SM-5EGJH',
    'GTEX-11DYG-0011-R10b-SM-DNZZO',
]

chunk_iterator = load_parquet_in_chunks(
    file_path=healthy_data_input,
    chunk_size=chunk_size
)

# Each step wraps the previous iterator; the chain is consumed by
# save_data_as_parquet below.
chunk_iterator = rename_index(chunk_iterator, 'gene_id')
# Use the configured exclude_cols rather than repeating the literal here
chunk_iterator = drop_dataframe_chunks(chunk_generator=chunk_iterator, drop_columns=exclude_cols)
chunk_iterator = filter_rows(chunk_iterator)
chunk_iterator = transpose_dataframe_chunks(chunk_generator=chunk_iterator, output_batch_size=chunk_size)
chunk_iterator = set_index_column(chunk_iterator, 'sample_id', drop=True)
chunk_iterator = clean_duplicate_nans(chunk_iterator)
chunk_iterator = keep_dataframe_chunks(chunk_iterator, keep_rows=whole_blood_samples)

save_data_as_parquet(
    chunk_iterator=chunk_iterator,
    output_parquet_path=healthy_data_output
)

# Clean up
if 'chunk_iterator' in locals() or 'chunk_iterator' in globals():
    del chunk_iterator
if 'healthy_data_input' in locals() or 'healthy_data_input' in globals():
    del healthy_data_input
if 'healthy_data_output' in locals() or 'healthy_data_output' in globals():
    del healthy_data_output
if 'exclude_cols' in locals() or 'exclude_cols' in globals():
    del exclude_cols
if 'chunk_size' in locals() or 'chunk_size' in globals():
    del chunk_size
gc.collect()

### Display basic healthy dataset info

In [None]:
# This cell inspects the HEALTHY dataset, so the iterator is named accordingly
# (it previously reused the name unhealthy_iterator).
healthy_iterator = load_parquet_in_chunks('data/healthy_data_preprocessed.pq', chunk_size=5000)

if healthy_iterator:
    # A default for next() avoids an uncaught StopIteration when the loader
    # produces no chunk (e.g. missing or empty file).
    first_chunk = next(healthy_iterator, None)

    if first_chunk is None:
        print("No chunks available in data/healthy_data_preprocessed.pq")
    else:
        # Basic info
        print(f"First chunk shape: {first_chunk.shape}")
        print(f"Columns: {list(first_chunk.columns)}")
        print(f"Data types:\n{first_chunk.dtypes}")
        print(f"Memory usage: {first_chunk.memory_usage(deep=True).sum() / 1024**2:.2f} MB")

        print("\nFirst 5 rows:")
        display(first_chunk.head())

# Prepare and merge datasets

- Load datasets from `.pq` files
- Align datasets by removing gene versioning suffix from Gene IDs, removing duplicate genes, keeping the common (intersection) genes, sorting the common genes in same order
- Generate custom metadata (sample_id, condition) based on the datasets
- Merge datasets into single dataset file
- Save merged dataset as `.pq` file

In [None]:
healthy_data_path = 'data/healthy_data_preprocessed.pq'
unhealthy_data_path = 'data/unhealthy_data_preprocessed.pq'
chunk_size = 3000

# Stream both preprocessed datasets chunk by chunk
healthy_iterator = load_parquet_in_chunks(healthy_data_path, chunk_size=chunk_size)
unhealthy_iterator = load_parquet_in_chunks(unhealthy_data_path, chunk_size=chunk_size)

# Align the two streams, then pass the aligned pairs through the metadata step
# (which is given the metadata output path) before merging each pair.
aligned_pair_iterator = prepare_metadata_generator(
    align_gene_columns_generator(healthy_iterator, unhealthy_iterator),
    output_metadata_path="data/merged_metadata.pq",
)
merged_chunk_iterator = merge_datasets_generator(aligned_pair_iterator)

# Writing the output is what actually drives the lazy pipeline above
save_data_as_parquet(
    chunk_iterator=merged_chunk_iterator,
    output_parquet_path='data/merged_dataset.pq',
)

# Clean up: drop the pipeline objects from the notebook namespace if present
for stale_name in (
    'healthy_iterator',
    'unhealthy_iterator',
    'aligned_pair_iterator',
    'merged_chunk_iterator',
):
    globals().pop(stale_name, None)
del stale_name

gc.collect()

### Display basic merged dataset info

In [None]:
# NOTE(review): the five settings below are not used anywhere in this cell
# (the load call uses a literal path and chunk size), and the .csv paths do not
# match the .pq files written by the earlier cells. They are kept because later
# cells may read these globals — confirm and remove if nothing depends on them.
healthy_data_path = 'data/healthy_data_preprocessed.csv'
unhealthy_data_path = 'data/unhealthy_data_preprocessed.csv'
output_healthy_path = 'data/healthy_data_aligned.csv'
output_unhealthy_path = 'data/unhealthy_data_aligned.csv'
chunk_size = 2000

merged_iterator = load_parquet_in_chunks('data/merged_dataset.pq', chunk_size=5000)

if merged_iterator:
    # A default for next() avoids an uncaught StopIteration when the loader
    # produces no chunk (e.g. missing or empty file).
    first_chunk = next(merged_iterator, None)

    if first_chunk is None:
        print("No chunks available in data/merged_dataset.pq")
    else:
        # Basic info
        print(f"First chunk shape: {first_chunk.shape}")
        print(f"Columns: {list(first_chunk.columns)}")
        print(f"Data types:\n{first_chunk.dtypes}")
        print(f"Memory usage: {first_chunk.memory_usage(deep=True).sum() / 1024**2:.2f} MB")

        print("\nFirst 5 rows:")
        display(first_chunk.head())

# Clean up
if 'merged_iterator' in locals() or 'merged_iterator' in globals():
    del merged_iterator
if 'first_chunk' in locals() or 'first_chunk' in globals():
    del first_chunk

gc.collect()

### Verify metadata

Double check if generated metadata is correct

In [None]:
# Re-open the two preprocessed datasets and the generated metadata as fresh
# chunk streams, then hand all three to the metadata check.
healthy_iterator = load_parquet_in_chunks(
    'data/healthy_data_preprocessed.pq',
    chunk_size=5000,
)
unhealthy_iterator = load_parquet_in_chunks(
    'data/unhealthy_data_preprocessed.pq',
    chunk_size=5000,
)
metadata_iterator = load_parquet_in_chunks(
    'data/merged_metadata.pq',
    chunk_size=5000,
)

check_metadata(
    healthy_generator=healthy_iterator,
    unhealthy_generator=unhealthy_iterator,
    merged_metadata_generator=metadata_iterator,
)