# Imports and modular functions

In [None]:
from multiprocessing import Lock, Process
import os
import pandas as pd
import mygene
import csv
import pickle
import random
import gc

def load_csv_in_chunks(file_path, chunk_size=8000, **kwargs):
    """
    Open a CSV file for lazy, chunk-by-chunk reading.

    The file is not parsed up front: pandas hands back a reader object that
    produces one DataFrame of at most ``chunk_size`` rows each time it is
    advanced.

    Args:
        file_path (str): Location of the CSV file on disk.
        chunk_size (int): Maximum number of rows in each DataFrame chunk.
        **kwargs: Forwarded unchanged to pd.read_csv()
                  (e.g., sep=',', header=0, index_col=None, usecols=None).

    Returns:
        The iterator returned by pd.read_csv(..., chunksize=chunk_size), or
        None when the file does not exist or pd.read_csv() raises.
    """
    if os.path.exists(file_path):
        print(f"Preparing to load {file_path} in chunks of size {chunk_size}...")
        try:
            return pd.read_csv(file_path, chunksize=chunk_size, **kwargs)
        except Exception as e:
            # Problems are reported and signalled with None rather than raised.
            print(f"Error loading CSV file {file_path}: {e}")
            return None

    print(f"Error: File not found at {file_path}")
    return None


def load_pickle_in_chunks(file_path, chunk_size=8000):
    """
    Generator that serves a pickled pandas DataFrame in row slices.

    The pickle is read in full with pd.read_pickle() the first time the
    generator is advanced; after that, independent copies of consecutive
    row ranges are yielded. A missing file or a load error is printed and
    the generator simply ends without yielding anything.

    NOTE: unpickling can execute arbitrary code, so only point this at
    pickle files from a trusted source.

    Args:
        file_path (str): The path to the pickle file.
        chunk_size (int): The number of rows per chunk.

    Yields:
        pandas.DataFrame: Copies of successive slices of at most
        ``chunk_size`` rows.
    """
    if not os.path.exists(file_path):
        print(f"Error: File not found at {file_path}")
        return None

    print(f"Preparing to load {file_path} in chunks of size {chunk_size}...")
    try:
        frame = pd.read_pickle(file_path)
        total_rows = len(frame)
        print(f"Loaded DataFrame with {total_rows} rows. Yielding chunks...", flush=True)

        # Positional slicing clips at the end of the frame, so the last
        # chunk is automatically shorter when the rows do not divide evenly.
        for offset in range(0, total_rows, chunk_size):
            yield frame.iloc[offset:offset + chunk_size].copy()

    except Exception as e:
        print(f"Error loading pickle file {file_path}: {e}")
        return None


def save_data_to_csv(data_to_save, output_file_path, delete_original_path=False, index=False, **kwargs):
    """
    Saves a pandas DataFrame or an iterator of DataFrame chunks to a CSV file.
    Optionally deletes the original file after successful saving.

    For chunked input the caller's ``mode`` and ``header`` apply to the first
    chunk only; every later chunk is appended without a header, so the result
    is one contiguous CSV. If the iterator yields nothing, no file is written,
    a warning is printed and the original file is left untouched.

    Args:
        data_to_save (pd.DataFrame or iterator): DataFrame or iterator yielding DataFrames.
        output_file_path (str): Path to save the new CSV file.
        delete_original_path (str, optional): Path to the original file to delete
            once the data has been written. Defaults to False (nothing is deleted).
        index (bool, optional): Whether to write DataFrame index. Defaults to False.
        **kwargs: Additional keyword arguments to pass to df.to_csv().
            ``mode`` (default 'w') and ``header`` (default True) are honoured as
            described above.

    Raises:
        Exception: Any error raised while writing or deleting is printed and
            then re-raised unchanged.
    """
    mode = kwargs.pop('mode', 'w')
    header = kwargs.pop('header', True)
    wrote_anything = False

    try:
        if isinstance(data_to_save, pd.DataFrame):
            print(f"Saving DataFrame to {output_file_path}...")
            data_to_save.to_csv(output_file_path, index=index, mode=mode, header=header, **kwargs)
            wrote_anything = True
        else:
            print(f"Saving data in chunks to {output_file_path}...")
            for i, chunk_df in enumerate(data_to_save):
                is_first = i == 0
                chunk_df.to_csv(
                    output_file_path,
                    index=index,
                    # Respect the requested mode/header once, then always append.
                    mode=mode if is_first else 'a',
                    header=header if is_first else False,
                    **kwargs,
                )
                wrote_anything = True

                if is_first:
                    print(f"Written first chunk to {output_file_path}")
                else:
                    print(f"Appended chunk {i+1} to {output_file_path}")

        if not wrote_anything:
            # Deleting the source after writing nothing would lose data.
            print(f"Warning: No data chunks were provided; nothing was written to {output_file_path}.")
            return

        print(f"Successfully saved data to {output_file_path}")

        # Handle file deletion if requested
        if delete_original_path:
            # Compare normalised paths so './x.csv' and 'x.csv' count as the same file.
            same_file = os.path.abspath(delete_original_path) == os.path.abspath(output_file_path)
            if same_file:
                print("Warning: Original and output paths are the same. Original file not deleted.")
            elif os.path.exists(delete_original_path):
                os.remove(delete_original_path)
                print(f"Successfully deleted original file: {delete_original_path}")
            else:
                print(f"Warning: Original file not found at {delete_original_path}")

    except Exception as e:
        print(f"Error saving data to {output_file_path}: {e}")
        if os.path.exists(output_file_path):
            print(f"Partial data might have been written to {output_file_path}.")
        raise


def save_data_as_pickle(chunk_iterator, output_pickle_path, process_func=None):
    """
    Collect DataFrame chunks, stack them row-wise and pickle the result.

    Each chunk is optionally passed through ``process_func`` first; chunks
    that are empty afterwards are dropped. The surviving chunks are
    concatenated with their original index labels and written with
    DataFrame.to_pickle(). When nothing survives, no file is written.

    Args:
        chunk_iterator: Iterable of DataFrame chunks (e.g. from load_csv_in_chunks)
        output_pickle_path (str): Destination of the pickle file
        process_func (callable, optional): Transformation applied to every
            chunk; it must return a DataFrame
    """
    print(f"Processing chunks and saving to {output_pickle_path}...")

    kept = []
    for piece in chunk_iterator:
        piece = process_func(piece) if process_func else piece
        if not piece.empty:
            kept.append(piece)

    if not kept:
        print("No data to save!")
        return

    print("Concatenating all processed chunks...")
    combined = pd.concat(kept, axis=0, ignore_index=False)

    print(f"Saving final dataset to {output_pickle_path}...")
    combined.to_pickle(output_pickle_path)

    n_rows, n_cols = combined.shape
    print(f"✅ Successfully saved {n_rows} rows, {n_cols} columns to pickle")

    # Drop the large intermediates right away and ask the collector to run.
    del kept, combined
    gc.collect()


def transpose_csv(input_csv, output_csv, feature_id_col='sample_id', exclude_cols=None,
                  batch_size=1000):
    """
    Transpose a features-by-samples CSV into a samples-by-features CSV.

    The input has one row per feature (e.g. gene) and one column per sample.
    The output has a header ``sample_id,<feature ids...>`` followed by one row
    per sample. The input is re-read once per batch of samples so that only
    ``batch_size`` sample columns are held in memory at a time.

    Rows too short to contain the feature-ID column are skipped in BOTH the
    ID pass and the value passes, so every output row has exactly as many
    values as there are feature IDs in the header. Cells missing from an
    otherwise usable row are written as empty strings.

    Args:
        input_csv (str): Path to the input CSV file
        output_csv (str): Path where the transposed CSV will be saved
        feature_id_col (str): Name of the column containing feature IDs (e.g., gene names)
        exclude_cols (list, optional): Column names to exclude from transposition.
            Defaults to None (exclude nothing).
        batch_size (int): Number of samples to process in each batch

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 or ``feature_id_col``
            is not present in the header.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # A set keeps the header scan linear; None avoids a shared mutable default.
    skipped_columns = set(exclude_cols or ())
    skipped_columns.add(feature_id_col)

    print("Reading header and identifying sample columns...")
    # newline='' is what the csv module expects so quoted newlines parse correctly.
    with open(input_csv, 'r', newline='') as f:
        header = next(csv.reader(f))

    # Find indices for sample columns and the feature_id column
    sample_cols = [i for i, col in enumerate(header) if col not in skipped_columns]
    sample_ids = [header[i] for i in sample_cols]
    feature_id_idx = header.index(feature_id_col)
    print(f"    Found {len(sample_cols)} samples to transpose.")

    def usable(row):
        # A row counts only if it reaches the feature-ID column; the same
        # test is used in every pass so IDs and values stay aligned.
        return len(row) > feature_id_idx

    # First pass: collect gene IDs
    print(f"Reading gene IDs from column '{feature_id_col}'...")
    with open(input_csv, 'r', newline='') as f:
        reader = csv.reader(f)
        next(reader)  # Skip header
        gene_ids = [row[feature_id_idx] for row in reader if usable(row)]

    print(f"    Found {len(gene_ids)} genes!")

    # Create output file with header
    with open(output_csv, 'w', newline='') as out_file:
        csv.writer(out_file).writerow(['sample_id'] + gene_ids)

    # Process samples in batches to balance memory usage and speed
    print(f"Writing data to {output_csv} in batches of {batch_size} samples...")
    total_batches = (len(sample_cols) - 1) // batch_size + 1
    for batch_start in range(0, len(sample_cols), batch_size):
        batch_end = min(batch_start + batch_size, len(sample_cols))
        current_batch_cols = sample_cols[batch_start:batch_end]
        current_batch_ids = sample_ids[batch_start:batch_end]

        print(f"    Processing batch {batch_start//batch_size + 1}/{total_batches} "
              f"(samples {batch_start+1}-{batch_end})...")

        # One accumulator per sample in this batch, filled in a single read of the input
        sample_values = [[] for _ in current_batch_cols]

        with open(input_csv, 'r', newline='') as f:
            reader = csv.reader(f)
            next(reader)  # Skip header
            for row in reader:
                if not usable(row):
                    continue
                row_len = len(row)
                for values, col_idx in zip(sample_values, current_batch_cols):
                    values.append(row[col_idx] if col_idx < row_len else "")

        # Now write each sample in the batch to the output file
        with open(output_csv, 'a', newline='') as out_file:
            writer = csv.writer(out_file)
            for sample_id, values in zip(current_batch_ids, sample_values):
                writer.writerow([sample_id] + values)

        # Release the batch before the next one is allocated
        sample_values = None
        gc.collect()

    print(f"🥶✅ Transposed {len(sample_ids)} samples × {len(gene_ids)} genes!")


def convert_gene_ids_to_symbols(dataset_chunk_iterator, mygene_client, gene_column_prefix="ENSG"):
    """
    Rename gene-ID columns to gene symbols using a single MyGene lookup.

    The gene columns are taken from the first chunk (string column names
    starting with ``gene_column_prefix``), looked up once via
    ``mygene_client.querymany``, and the resulting mapping is applied to every
    chunk. Columns without a symbol (and all non-gene columns such as
    'condition') keep their original name. If the lookup raises, the error is
    printed and all columns keep their original names.

    When the service returns several hits for one query, the first hit that
    carries a symbol wins; a later hit without a symbol never discards it.

    NOTE: all chunks are held in memory and concatenated before returning.

    Args:
        dataset_chunk_iterator: Iterator (or any iterable) yielding DataFrame chunks
        mygene_client: Object exposing ``querymany(ids, scopes=..., fields=..., ...)``
            that returns an iterable of dicts with a 'query' key and, when
            found, a 'symbol' key
        gene_column_prefix (str): Prefix to identify gene columns

    Returns:
        pd.DataFrame: Complete dataset with gene symbols as column names, or an
        empty DataFrame when the input yields no chunks
    """
    import warnings

    print("Converting gene IDs to symbols (chunked processing)...")

    # iter() lets a plain list of chunks be passed; it is a no-op for iterators.
    chunk_stream = iter(dataset_chunk_iterator)

    # Get first chunk to determine gene columns and create mapping
    try:
        first_chunk = next(chunk_stream)
    except StopIteration:
        print("❌ Error: Dataset is empty")
        return pd.DataFrame()
    print(f"Processing first chunk: {first_chunk.shape}")

    # Identify gene columns from first chunk
    gene_columns = [col for col in first_chunk.columns
                    if isinstance(col, str) and col.startswith(gene_column_prefix)]

    if not gene_columns:
        print("No gene columns found - returning original data")
        return pd.concat([first_chunk, *chunk_stream], axis=0, ignore_index=False)

    print(f"Found {len(gene_columns)} gene columns")

    # Only genes that actually received a symbol are stored; everything else
    # is left alone by DataFrame.rename and therefore keeps its original ID.
    gene_id_to_symbol = {}
    try:
        print("Making single MyGene API call...")
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            results = mygene_client.querymany(
                gene_columns,
                scopes='ensembl.gene',
                fields='symbol',
                species='human',
                verbose=False,
                silent=True
            )

        for hit in results:
            symbol = hit.get('symbol')
            if symbol:
                # setdefault: the first symbol seen for a query is kept.
                gene_id_to_symbol.setdefault(hit['query'], symbol)
    except Exception as e:
        print(f"⚠️ MyGene API error: {e}. Using original gene IDs")
        # Discard any partial mapping so the fallback is all-or-nothing.
        gene_id_to_symbol = {}
        symbols_found = 0
    else:
        # One entry per gene, so duplicate hits cannot inflate the count.
        symbols_found = len(gene_id_to_symbol)
        print(f"✅ Successfully converted {symbols_found}/{len(gene_columns)} genes to symbols")

    print("Applying gene symbol mapping to all chunks...")

    processed_chunks = [first_chunk.rename(columns=gene_id_to_symbol)]
    chunk_count = 1

    # Process remaining chunks with same mapping
    for chunk in chunk_stream:
        processed_chunks.append(chunk.rename(columns=gene_id_to_symbol))
        chunk_count += 1

        if chunk_count % 3 == 0:
            print(f"  ✓ Processed {chunk_count} chunks...")

    print(f"Concatenating {chunk_count} processed chunks...")
    final_dataset = pd.concat(processed_chunks, axis=0, ignore_index=False)

    print(f"✅ Final dataset shape: {final_dataset.shape}")
    print(f"   Symbols converted: {symbols_found}/{len(gene_columns)}")

    return final_dataset


def add_condition_labels_to_chunks(chunk_iterator, condition_label, dataset_name):
    """
    Stamp every chunk with a constant 'condition' column.

    The column is assigned directly on each incoming DataFrame, so the chunks
    are modified in place and the very same objects are returned.

    Args:
        chunk_iterator: Iterable yielding DataFrame chunks
        condition_label (int): Binary label (0 for healthy, 1 for unhealthy)
        dataset_name (str): Only used in the progress messages

    Returns:
        list: The labeled DataFrame chunks, in input order
    """
    print(f"Adding label '{condition_label}' to {dataset_name} dataset...")

    def _stamp(frame):
        frame['condition'] = condition_label
        return frame

    labeled_chunks = list(map(_stamp, chunk_iterator))

    print(f"✅ Completed {len(labeled_chunks)} {dataset_name} chunks")
    return labeled_chunks


def merge_labeled_datasets(healthy_chunks, unhealthy_chunks):
    """
    Stack the healthy chunks followed by the unhealthy chunks into one frame.

    Row labels are preserved (no index reset). A short summary with the number
    of rows whose 'condition' equals 0 and 1 is printed, so the chunks are
    expected to carry a 'condition' column.

    Args:
        healthy_chunks (list): List of healthy DataFrame chunks
        unhealthy_chunks (list): List of unhealthy DataFrame chunks

    Returns:
        pd.DataFrame: Merged dataset
    """
    print("Merging datasets...")

    merged_dataset = pd.concat(healthy_chunks + unhealthy_chunks, axis=0, ignore_index=False)

    n_samples = len(merged_dataset)
    n_features = len(merged_dataset.columns)
    print(f"✅ Merged dataset: {n_samples} samples, {n_features} features")

    healthy_total = (merged_dataset['condition'] == 0).sum()
    print(f"   Healthy (0): {healthy_total}")
    unhealthy_total = (merged_dataset['condition'] == 1).sum()
    print(f"   Unhealthy (1): {unhealthy_total}")

    return merged_dataset


def clean_duplicate_nans(chunk_iterator):
    """
    Generator that strips incomplete and repeated rows from each chunk.

    For every incoming DataFrame, rows containing any NaN are removed first,
    then exact duplicate rows within that chunk. Chunks that end up empty are
    skipped; the others are yielded one at a time. Duplicates are only
    detected inside a chunk, not across chunks.
    """
    print("Cleaning chunks by dropping NaNs and duplicates...", flush=True)

    yielded = 0
    for chunk_no, raw in enumerate(chunk_iterator, start=1):
        without_nans = raw.dropna()
        null_rows = len(raw) - len(without_nans)
        if null_rows > 0:
            print(f"    Chunk {chunk_no}: Dropped {null_rows} rows with null values.")

        if without_nans.empty:
            print(f"    Chunk {chunk_no}: Empty after dropping NaNs, skipping...", flush=True)
            continue

        deduped = without_nans.drop_duplicates()
        duplicate_rows = len(without_nans) - len(deduped)
        if duplicate_rows > 0:
            print(f"    Chunk {chunk_no}: Dropped {duplicate_rows} duplicate rows.")

        if deduped.empty:
            print(f"    Chunk {chunk_no}: Empty after deduplication, skipping...")
            continue

        yielded += 1
        print(f"    Chunk {chunk_no}: Yielded chunk with shape {deduped.shape}")
        yield deduped

    print(f"Finished processing. {yielded} non-empty chunks processed.")


def rename_index(chunk_iterator, index_name):
    """
    Lazily set the index name on each DataFrame chunk and pass it along.

    The name is assigned on the chunk's own index object, so the incoming
    chunks are modified in place and yielded as the same objects.
    """
    def _with_named_index(frame):
        frame.index.name = index_name
        return frame

    yield from map(_with_named_index, chunk_iterator)


def filter_rows(chunk_iterator):
    """
    Generator that drops rows whose values add up to zero.

    For each non-empty chunk the row-wise sum is computed and only rows with a
    non-zero sum are kept; chunks left with no rows are not yielded. Because
    the test is on the sum, this equals "all columns are zero" only for
    non-negative data: a row whose positive and negative values cancel out is
    removed as well.
    """
    print("Filtering rows where sum of all values equals 0...", flush=True)
    removed_overall = 0
    chunks_seen = 0

    for frame in chunk_iterator:
        if frame.empty:
            continue

        chunks_seen += 1

        keep_mask = frame.sum(axis=1) != 0
        kept = frame[keep_mask]

        dropped = len(frame) - len(kept)
        removed_overall += dropped

        if dropped > 0:
            print(f"  Chunk {chunks_seen}: Removed {dropped} zero-sum rows ({len(kept)} remaining)")
        else:
            print(f"  Chunk {chunks_seen}: No zero-sum rows found ({len(kept)} rows)")

        if not kept.empty:
            yield kept

    print(f"Filtering complete: {removed_overall} total zero-sum rows removed across {chunks_seen} chunks")


def _append_sample_ids(dataset_path: str, type_index: str, output_path: str, lock) -> None:
    """Worker: append one ``[first_column_value, type_index]`` row per data row of ``dataset_path``."""
    print(f"Starting process on extracting metadata from: {dataset_path}")
    with open(dataset_path, mode='r', newline='') as source:
        reader = csv.reader(source, delimiter=",")

        # Skip header (tolerating a completely empty file)
        next(reader, None)

        # Blank lines come back as empty lists and carry no sample ID.
        sample_ids = [row[0] for row in reader if row]

    # The lock keeps the two workers from interleaving their appends.
    with lock:
        with open(output_path, mode="a", newline='') as target:
            writer = csv.writer(target, delimiter=',')
            writer.writerows([sample_id, type_index] for sample_id in sample_ids)


def prepare_metadata(healthy_dataset_path: str, unhealthy_dataset_path: str) -> None:
    """
    Build ``data/merged_metadata_simple.csv`` from two sample-per-row CSV files.

    The output starts with the header ``sampid,condition`` and then holds one
    row per data row of each input: the value of the input's first column and
    the label "healthy" or "unhealthy". The two inputs are read by separate
    processes, so the relative order of the healthy and unhealthy groups in
    the output is not fixed.

    The worker is a module-level function (not a closure) so it can be pickled
    by multiprocessing start methods other than fork.

    Args:
        healthy_dataset_path (str): CSV whose first column holds healthy sample IDs
        unhealthy_dataset_path (str): CSV whose first column holds unhealthy sample IDs

    Raises:
        RuntimeError: If either worker process exits with a non-zero code,
            since the metadata file would then be incomplete.
    """
    output_path = "data/merged_metadata_simple.csv"

    # Make sure the target directory exists before the first write.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    with open(output_path, 'w', newline='') as csvfile:
        print("Creating metadata file")
        writer = csv.writer(csvfile, delimiter=",")
        writer.writerow(['sampid', 'condition'])

    lock = Lock()

    workers = [
        Process(target=_append_sample_ids, args=(healthy_dataset_path, "healthy", output_path, lock)),
        Process(target=_append_sample_ids, args=(unhealthy_dataset_path, "unhealthy", output_path, lock)),
    ]

    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    # A crashed worker would otherwise leave a silently truncated file behind.
    failed = [worker.name for worker in workers if worker.exitcode != 0]
    if failed:
        raise RuntimeError(f"Metadata extraction failed in worker process(es): {', '.join(failed)}")


def check_metadata(healthy_dataset_path: str, unhealthy_dataset_path: str, merged_metadata_path: str) -> None:
    """
    Verify that the merged metadata has one row per sample of both datasets.

    Each file is read with csv.reader and its row count minus one (the header)
    is taken as its number of data rows. The healthy and unhealthy counts must
    add up to the metadata count.

    Args:
        healthy_dataset_path (str): CSV file of the healthy dataset
        unhealthy_dataset_path (str): CSV file of the unhealthy dataset
        merged_metadata_path (str): CSV file of the merged metadata

    Raises:
        Exception: If the counts do not add up; the message lists all counts.
    """
    def _data_row_count(path):
        with open(path, 'r') as handle:
            return sum(1 for _ in csv.reader(handle)) - 1

    print(f"Checking metadata...")

    healthy_dataset_line_count = _data_row_count(healthy_dataset_path)
    unhealthy_dataset_line_count = _data_row_count(unhealthy_dataset_path)
    total = healthy_dataset_line_count + unhealthy_dataset_line_count
    metadata_line_count = _data_row_count(merged_metadata_path)

    summary = f"Healthy line count: {healthy_dataset_line_count} | Unhealthy line count: {unhealthy_dataset_line_count} | Total line count: {total} | Metadata line count: {metadata_line_count}"

    if total != metadata_line_count:
        raise Exception(summary)

    print("All good")
    print(f"Stats: {summary}")

def get_healthy_whole_blood_samples(metadata_path: str, min_rin: float = 7.0) -> list[str]:
    """
    Get good-quality whole blood sample IDs from the metadata.

    The metadata CSV is read with its first column as the index (the sample
    IDs). A sample is returned when its 'SMTSD' value is exactly
    'Whole Blood' AND its RNA integrity number 'SMRIN' is at least
    ``min_rin``. Samples with a missing SMRIN are excluded.

    Args:
        metadata_path (str): Path to the metadata CSV file; must contain the
            columns 'SMTSD' and 'SMRIN'
        min_rin (float): Lowest acceptable SMRIN value (inclusive). Defaults to 7.0.

    Returns:
        list[str]: Sample IDs (index labels) passing both filters, in file order
    """
    metadata_df = pd.read_csv(metadata_path, index_col=0)

    is_whole_blood = metadata_df['SMTSD'] == 'Whole Blood'
    # NaN >= min_rin is False, so samples without a RIN never pass.
    has_good_rin = metadata_df['SMRIN'] >= min_rin

    # Apply the combined mask to the index; taking the index of the bare
    # boolean Series would return every sample regardless of its RIN.
    return metadata_df.index[is_whole_blood & has_good_rin].tolist()

In [None]:
def align_gene_columns_simple(healthy_chunk_iterator, unhealthy_chunk_iterator,
                              gene_column_prefix="ENSG"):
    """
    Aligns gene columns between healthy and unhealthy datasets.

    Gene columns are the string column names starting with
    ``gene_column_prefix``. Their version suffix (everything from the first
    '.') is stripped to get a base gene ID; when several columns of one
    dataset share a base ID only the first is kept. Both datasets are then
    reduced to their non-gene columns followed by the base IDs common to
    both, in sorted order. The column layout is decided from the first chunk
    of each dataset and applied to all later chunks.

    NOTE: all aligned chunks are held in memory and returned as lists.

    Args:
        healthy_chunk_iterator: Iterator (or iterable) of healthy dataset chunks
        unhealthy_chunk_iterator: Iterator (or iterable) of unhealthy dataset chunks
        gene_column_prefix (str): Prefix to identify gene columns (default: "ENSG")

    Returns:
        tuple: (aligned_healthy_chunks, aligned_unhealthy_chunks), two lists of
        DataFrames. Both lists are empty when either dataset yields no chunk
        or when the datasets share no gene.
    """

    print("Starting gene column alignment...")

    # iter() lets plain lists of chunks be passed; it is a no-op for iterators.
    healthy_chunk_iterator = iter(healthy_chunk_iterator)
    unhealthy_chunk_iterator = iter(unhealthy_chunk_iterator)

    print("\nLoading first chunks from both datasets...")
    try:
        first_healthy_chunk = next(healthy_chunk_iterator)
        print(f"   ✓ Loaded healthy dataset first chunk: {first_healthy_chunk.shape}")

        first_unhealthy_chunk = next(unhealthy_chunk_iterator)
        print(f"   ✓ Loaded unhealthy dataset first chunk: {first_unhealthy_chunk.shape}")

    except StopIteration:
        print("ERROR: One or both datasets are empty!")
        # Same arity as the success path so callers can always unpack two values.
        return [], []

    def is_gene_column(column):
        return isinstance(column, str) and column.startswith(gene_column_prefix)

    print("\nProcessing gene columns and stripping version suffixes...")

    def extract_gene_info(chunk, dataset_name):
        """Extract gene columns and create mapping from original_column to base_id"""

        print(f"Processing {dataset_name} dataset...")

        original_to_base = {}
        base_gene_ids = set()
        total_gene_columns = 0
        duplicate_count = 0

        for column in chunk.columns:
            if not is_gene_column(column):
                continue

            total_gene_columns += 1
            base_gene_id = column.split('.')[0]

            # Handle duplicates (keep first occurrence)
            if base_gene_id in base_gene_ids:
                duplicate_count += 1
                print(f"Warning: Duplicate base gene {base_gene_id} found, skipping {column}")
                continue

            original_to_base[column] = base_gene_id
            base_gene_ids.add(base_gene_id)

        print(f"Total gene columns found: {total_gene_columns}")
        print(f"Unique base gene IDs: {len(base_gene_ids)}")
        if duplicate_count > 0:
            print(f"Duplicates removed: {duplicate_count}")

        return original_to_base, base_gene_ids

    # Process both datasets
    healthy_rename_mapping, healthy_base_genes = extract_gene_info(first_healthy_chunk, "HEALTHY")
    unhealthy_rename_mapping, unhealthy_base_genes = extract_gene_info(first_unhealthy_chunk, "UNHEALTHY")

    print("\nSummary:")
    print(f"   Healthy dataset: {len(healthy_base_genes)} unique genes")
    print(f"   Unhealthy dataset: {len(unhealthy_base_genes)} unique genes")

    print("\nFinding common genes between datasets...")
    common_base_genes = healthy_base_genes & unhealthy_base_genes

    if not common_base_genes:
        print("ERROR: No common genes found between datasets!")
        return [], []

    print(f"Common genes: {len(common_base_genes)}")
    print(f"Genes exclusive to healthy dataset: {len(healthy_base_genes - common_base_genes)}")
    print(f"Genes exclusive to unhealthy dataset: {len(unhealthy_base_genes - common_base_genes)}")

    print("\nProcessing all chunks and renaming to base gene IDs...")

    def process_all_chunks(chunk_iterator, first_chunk, rename_mapping, common_genes, dataset_name):
        """Process all chunks: rename gene columns to base IDs and keep only common genes"""
        print(f"   Processing {dataset_name} dataset chunks...")

        non_gene_cols = [col for col in first_chunk.columns if not is_gene_column(col)]
        # Original (versioned) names of the genes to keep. Skipped duplicates
        # are absent from rename_mapping, so selecting by original name drops
        # them before the rename and they cannot collide with the kept column.
        kept_gene_cols = [original for original, base in rename_mapping.items()
                          if base in common_genes]
        common_gene_base_ids = sorted(common_genes)  # Sorted base gene IDs
        final_columns = non_gene_cols + common_gene_base_ids

        print(f"      Non-gene columns: {len(non_gene_cols)}")
        print(f"      Common gene columns (base IDs): {len(common_gene_base_ids)}")
        print(f"      Total columns to keep: {len(final_columns)}")

        def align(chunk):
            selected = chunk[non_gene_cols + kept_gene_cols]
            # Rename to base IDs, then reorder so both datasets share one layout.
            return selected.rename(columns=rename_mapping)[final_columns]

        aligned_chunks = [align(first_chunk)]
        chunk_count = 1
        print(f"✓ Processed chunk 1 - shape: {aligned_chunks[0].shape}")

        # Process remaining chunks
        for chunk in chunk_iterator:
            aligned_chunks.append(align(chunk))
            chunk_count += 1

            if chunk_count % 3 == 0:
                print(f"✓ Processed {chunk_count} chunks so far...")

        print(f"{dataset_name} processing complete: {chunk_count} chunks processed")
        return aligned_chunks

    # Process both datasets
    print("\nProcessing HEALTHY dataset...")
    aligned_healthy = process_all_chunks(healthy_chunk_iterator, first_healthy_chunk,
                                         healthy_rename_mapping, common_base_genes, "HEALTHY")

    print("\nProcessing UNHEALTHY dataset...")
    aligned_unhealthy = process_all_chunks(unhealthy_chunk_iterator, first_unhealthy_chunk,
                                           unhealthy_rename_mapping, common_base_genes, "UNHEALTHY")

    return aligned_healthy, aligned_unhealthy


def _append_chunks_to_csv(chunks, output_path, label, index, header_columns,
                          first_chunk_note=""):
    """
    Streams the DataFrame chunks of one dataset into the merged CSV file.

    The first chunk written overall creates the file and defines the header.
    Every later chunk is appended without a header, so its columns are first
    brought into the header's order; otherwise values would silently land
    under the wrong column names.

    Args:
        chunks (iterable): Iterable yielding pandas DataFrame chunks.
        output_path (str): Path of the CSV file being written.
        label (str): Dataset name used in the progress messages.
        index (bool): Whether to write the DataFrame index to the CSV.
        header_columns (list or None): Column order of the header already
            written to the file, or None if nothing has been written yet.
        first_chunk_note (str): Extra text inserted into the message printed
            when this dataset provides the first chunk of the file.

    Returns:
        list or None: The header column order after processing `chunks`
        (still None if `chunks` yielded nothing and no header existed).

    Raises:
        ValueError: If a chunk's columns are not a permutation of the header.
    """
    for i, chunk in enumerate(chunks):
        if header_columns is None:
            chunk.to_csv(output_path, mode='w', header=True, index=index)
            print(f"  Written first {label} chunk (chunk {i+1}){first_chunk_note} to {output_path}")
            header_columns = list(chunk.columns)
            continue

        chunk_columns = list(chunk.columns)
        if chunk_columns != header_columns:
            is_permutation = (
                len(chunk_columns) == len(header_columns)
                and len(set(chunk_columns)) == len(chunk_columns)
                and set(chunk_columns) == set(header_columns)
            )
            if not is_permutation:
                raise ValueError(
                    f"Columns of {label} chunk {i+1} do not match the header "
                    f"already written to {output_path}"
                )
            # Same columns, different order: realign to the written header.
            chunk = chunk[header_columns]

        chunk.to_csv(output_path, mode='a', header=False, index=index)
        print(f"  Appended {label} chunk (chunk {i+1}) to {output_path}")

    return header_columns


def merge_datasets(healthy_chunk_iterator, unhealthy_chunk_iterator, output_path, index=True):
    """
    Merges healthy and unhealthy dataset chunks directly into a single CSV file
    to be memory-efficient.

    Healthy chunks are written first, followed by the unhealthy chunks. The
    header is taken from the first chunk written; later chunks whose columns
    are the same set in a different order are reordered to match it, and a
    chunk with a different set of columns aborts the merge.

    Args:
        healthy_chunk_iterator (iterator): Iterator yielding healthy DataFrame chunks.
        unhealthy_chunk_iterator (iterator): Iterator yielding unhealthy DataFrame chunks.
        output_path (str): Path to save the merged CSV file. An existing file
            at this path is removed before writing.
        index (bool): Whether to write DataFrame index to CSV. Defaults to True.

    Returns:
        str: The path to the merged output file if successful, None otherwise
        (no chunks available, or an error occurred; in the error case a
        partially written file may remain at output_path).
    """
    print(f"Merging datasets directly to {output_path}...")

    if os.path.exists(output_path):
        os.remove(output_path)

    # None until the first chunk (and with it the header) has been written.
    header_columns = None

    try:
        if healthy_chunk_iterator:
            print("Processing healthy dataset chunks...")
            header_columns = _append_chunks_to_csv(
                healthy_chunk_iterator, output_path, "healthy", index, header_columns
            )
        else:
            print("Warning: Healthy chunk iterator is None or empty.")

        if unhealthy_chunk_iterator:
            print("Processing unhealthy dataset chunks...")
            header_columns = _append_chunks_to_csv(
                unhealthy_chunk_iterator, output_path, "unhealthy", index, header_columns,
                first_chunk_note=" as first overall chunk",
            )
        else:
            print("Warning: Unhealthy chunk iterator is None or empty.")

        if header_columns is None:
            print(f"No data chunks were processed. Output file {output_path} may not have been created or is empty.")
            return None

        print(f"✅ Merged dataset saved to {output_path}")
        return output_path

    except Exception as e:
        # Callers rely on a None return to detect failure, so errors are
        # reported here instead of propagating.
        print(f"Error during merging datasets to {output_path}: {e}")
        return None

# Initial Preprocessing Unhealthy Dataset

In [None]:
# Stream the raw unhealthy RNA-seq CSV through the cleaning helpers and
# persist the result.
unhealthy_dataset_file = 'data/rna_seq_unstranded.csv'
unhealthy_output_file = 'data/unhealthy_data_preprocessed.pkl'
chunk_size = 8000

# chunk_size is passed explicitly so that changing the value above actually
# changes the read size (it was previously defined but never used).
chunk_iterator = load_csv_in_chunks(unhealthy_dataset_file, chunk_size=chunk_size, header=0, skiprows=[1, 2], index_col=0)
print(f"Starting preprocessing for Unhealthy Dataset: {unhealthy_dataset_file}")

# NOTE(review): a comment here used to announce an int64 -> uint32 conversion,
# but no statement in this cell performs one - confirm whether it is still wanted.

chunk_iterator = rename_index(chunk_iterator, 'sample_id')
chunk_iterator = clean_duplicate_nans(chunk_iterator)

# NOTE(review): the result is saved as a pickle, while the following cells read
# 'data/unhealthy_data_preprocessed.csv'. Confirm which format the rest of the
# pipeline expects; the CSV variant is kept below for reference.
save_data_as_pickle(chunk_iterator, unhealthy_output_file)
# save_data_to_csv(data_to_save=chunk_iterator, output_file_path=unhealthy_output_file, index=True)

In [None]:
# Hand the preprocessed unhealthy CSV to transpose_csv, with a temporary file
# as the output path.
# NOTE(review): the preceding cell saves its result to
# 'data/unhealthy_data_preprocessed.pkl', whereas this cell reads the .csv
# path - confirm that the CSV actually exists at this point.
unhealthy_temp_file = 'data/unhealthy_data_temp.csv'
# Rebinds unhealthy_output_file (previously the .pkl path) to the CSV path.
unhealthy_output_file = 'data/unhealthy_data_preprocessed.csv'
chunk_size = 8000

transpose_csv(input_csv=unhealthy_output_file, output_csv=unhealthy_temp_file, batch_size=chunk_size)

In [None]:
# Filter the transposed unhealthy data, then transpose it back, mirroring the
# healthy pipeline (transpose -> filter_rows -> transpose back).
# Relies on unhealthy_temp_file, unhealthy_output_file and chunk_size from the
# previous cell.
unhealthy_filtered_file = 'data/unhealthy_data_filtered_temp.csv'

# Read the transposed temp file rather than unhealthy_output_file: streaming a
# CSV in chunks while overwriting that same path can truncate the input
# mid-read, and the transposed copy would otherwise be deleted unused.
unhealthy_chunk_iterator = load_csv_in_chunks(file_path=unhealthy_temp_file, chunk_size=chunk_size, index_col=0, low_memory=False, header=0)
unhealthy_chunk_iterator = filter_rows(unhealthy_chunk_iterator)

# Write to a separate file so input and output never share a path.
save_data_to_csv(data_to_save=unhealthy_chunk_iterator, output_file_path=unhealthy_filtered_file, index=True)

# Restore the original orientation in the final preprocessed file.
transpose_csv(input_csv=unhealthy_filtered_file, output_csv=unhealthy_output_file, batch_size=chunk_size)

# Clean up
for temp_path in (unhealthy_temp_file, unhealthy_filtered_file):
    if os.path.exists(temp_path):
        os.remove(temp_path)
        print(f"Deleted temporary file: {temp_path}")

gc.collect()

# Initial Preprocessing Healthy Dataset
## Convert .gct to .csv

In [None]:
# Convert the tab-separated GTEx .gct file to a comma-separated CSV, chunk by chunk.
dataset_file = 'data/GTEx_Analysis_2022-06-06_v10_RNASeQCv2.4.2_gene_reads.gct'
output_file = 'data/healthy_data.csv'
chunk_size = 8000

# skiprows=2 drops the first two lines of the file before the header row
# (presumably the GCT version and dimension lines - confirm against the file).
gct_chunk_iterator = load_csv_in_chunks(
    file_path=dataset_file,
    chunk_size=chunk_size,
    sep='\t',
    skiprows=2
)

if gct_chunk_iterator:
    # NOTE(review): judging by the parameter name, delete_original_path asks
    # save_data_to_csv to remove the source .gct after conversion - confirm.
    save_data_to_csv(
        data_to_save=gct_chunk_iterator,
        output_file_path=output_file,
        delete_original_path=dataset_file,
        index=False
    )
else:
    print(f"Failed to load {dataset_file} using load_csv_in_chunks.")

# Clean up
# All four names are assigned unconditionally above, so no existence checks
# are needed. The previous guard tested for 'read_chunk_size', a name that is
# never defined, so chunk_size was never actually released.
del gct_chunk_iterator, dataset_file, output_file, chunk_size

gc.collect()

## Transpose healthy dataset

In [None]:
# Run transpose_csv on the converted healthy CSV.
healthy_data_input = 'data/healthy_data.csv'
healthy_data_output = 'data/healthy_data_transposed.csv'
# Passed as feature_id_col; presumably the column holding the gene identifiers
# that become the column names after transposition - confirm in transpose_csv.
feature_id_col = 'Name'
# Passed as exclude_cols; presumably columns dropped before transposing - confirm.
exclude_cols = ['Description']
batch_size = 3000

transpose_csv(
    input_csv=healthy_data_input,
    output_csv=healthy_data_output,
    feature_id_col=feature_id_col,
    exclude_cols=exclude_cols,
    batch_size=batch_size
)

## Load and clean dataset

In [None]:
# Load the transposed healthy CSV in chunks, pass it through
# clean_duplicate_nans and save the result as the preprocessed CSV.
healthy_dataset_file = 'data/healthy_data_transposed.csv'
healthy_output_file = 'data/healthy_data_preprocessed.csv'
chunk_size = 5000

# The first CSV column is used as the DataFrame index.
chunk_iterator = load_csv_in_chunks(file_path=healthy_dataset_file, chunk_size=chunk_size, index_col=0)
print(f"Starting preprocessing for Healthy Dataset: {healthy_dataset_file}")

chunk_iterator = clean_duplicate_nans(chunk_iterator)

save_data_to_csv(
    data_to_save=chunk_iterator,
    output_file_path=healthy_output_file,
    index=True
)

gc.collect()

## Drop genes with only zero values

In [None]:
# Run transpose_csv on the preprocessed healthy CSV, writing to a temporary
# file that the next cell filters.
healthy_data_input = 'data/healthy_data_preprocessed.csv'
healthy_dataset_temp = 'data/healthy_data_temp.csv'
# Rebinds healthy_output_file; later cells in this section rely on this value.
healthy_output_file = 'data/healthy_data_final.csv'
chunk_size = 5000

transpose_csv(input_csv=healthy_data_input, output_csv=healthy_dataset_temp, batch_size=chunk_size)

In [None]:
# Stream the transposed temp file through filter_rows and save the result to
# healthy_output_file. Relies on healthy_dataset_temp, healthy_output_file and
# chunk_size from the previous cell.
# Per this section's heading, filter_rows presumably drops rows (genes) that
# contain only zeros - confirm in its definition.
healthy_chunk_iterator = load_csv_in_chunks(file_path=healthy_dataset_temp, chunk_size=chunk_size, index_col=0, low_memory=False, header=0)
healthy_chunk_iterator = filter_rows(healthy_chunk_iterator)

save_data_to_csv(data_to_save=healthy_chunk_iterator, output_file_path=healthy_output_file, index=True)

# Clean up temporary files
if os.path.exists(healthy_dataset_temp):
    os.remove(healthy_dataset_temp)
    print(f"Deleted temporary file: {healthy_dataset_temp}")

In [None]:
# Transpose the filtered healthy data back and store it as the preprocessed
# CSV, then remove the intermediate filtered file.
# The intermediate path is bound locally so that the file being transposed and
# the file being deleted are always the same one, regardless of what an
# earlier cell left in healthy_output_file.
healthy_final_file = 'data/healthy_data_final.csv'
healthy_data_input = 'data/healthy_data_preprocessed.csv'
batch_size = 3000

transpose_csv(input_csv=healthy_final_file, output_csv=healthy_data_input, batch_size=batch_size)

if os.path.exists(healthy_final_file):
    os.remove(healthy_final_file)
    print(f"Deleted temporary file: {healthy_final_file}")

# Prepare datasets for merge

In [None]:
# Load both preprocessed datasets in chunks, align their gene columns and save
# each aligned dataset to its own CSV.
healthy_data_path = 'data/healthy_data_preprocessed.csv'
unhealthy_data_path = 'data/unhealthy_data_preprocessed.csv'
output_healthy_path = 'data/healthy_data_aligned.csv'
output_unhealthy_path = 'data/unhealthy_data_aligned.csv'
metadata_path = 'data/gtex_blood_samples_metadata.csv'
chunk_size = 1000

# In this cell the sample list is only used for the count printed below.
whole_blood_samples = get_healthy_whole_blood_samples(metadata_path)
print(f"Found {len(whole_blood_samples)} healthy whole blood samples in metadata")

# Load datasets
print("\nLoading datasets...")
print(f"   Healthy data: {healthy_data_path}")
print(f"   Unhealthy data: {unhealthy_data_path}")
print(f"   Chunk size: {chunk_size}")

# load_csv_in_chunks returns None when the file is missing or cannot be opened.
healthy_chunks = load_csv_in_chunks(healthy_data_path, chunk_size=chunk_size, index_col=0)
unhealthy_chunks = load_csv_in_chunks(unhealthy_data_path, chunk_size=chunk_size, index_col=0)

if healthy_chunks is None or unhealthy_chunks is None:
    print("ERROR: Failed to load one or both datasets")
else:
    print("Both datasets loaded successfully")

    # Align gene columns (strips version suffixes and keeps intersection only)
    print("\nStarting gene alignment...")
    aligned_healthy, aligned_unhealthy = align_gene_columns_simple(
        healthy_chunks, unhealthy_chunks
    )

    # Both results must be truthy (non-empty) before anything is saved.
    if aligned_healthy and aligned_unhealthy:
        print(f"   Saving healthy dataset to: {output_healthy_path}")
        save_data_to_csv(aligned_healthy, output_healthy_path, index=True)

        print(f"   Saving unhealthy dataset to: {output_unhealthy_path}")
        save_data_to_csv(aligned_unhealthy, output_unhealthy_path, index=True)

        print("Gene alignment completed successfully!")
    else:
        print("ERROR: Gene alignment failed")

In [None]:
# Metadata preparation
# Build the metadata for the two aligned datasets, then verify it.
healthy_path = 'data/healthy_data_aligned.csv'
unhealthy_path = 'data/unhealthy_data_aligned.csv'
metadata_path = 'data/merged_metadata_simple.csv'
# NOTE(review): prepare_metadata is not given metadata_path; presumably it
# writes to a default location that matches the path checked below - confirm.
prepare_metadata(healthy_dataset_path=healthy_path, unhealthy_dataset_path=unhealthy_path)

# Metadata check
check_metadata(healthy_dataset_path=healthy_path, unhealthy_dataset_path=unhealthy_path, merged_metadata_path=metadata_path)

In [None]:
# Merge datasets
# Stream both aligned datasets into one merged CSV (healthy chunks first,
# then unhealthy chunks).
healthy_data_path = 'data/healthy_data_aligned.csv'
unhealthy_data_path = 'data/unhealthy_data_aligned.csv'
merged_data_path = 'data/merged_dataset.csv'
healthy_chunk_iterator = load_csv_in_chunks(healthy_data_path, chunk_size=2000, index_col=0)
unhealthy_chunk_iterator = load_csv_in_chunks(unhealthy_data_path, chunk_size=2000, index_col=0)

# merge_datasets returns the output path on success and None otherwise; the
# index is written because its `index` parameter defaults to True.
merged_dataset = merge_datasets(healthy_chunk_iterator, unhealthy_chunk_iterator, merged_data_path)

# Preprocessing steps after PyDESeq2

- Possibly transpose, depending on the results from PyDESeq2
- Load as a DataFrame
- Add labels
- Convert gene IDs to their corresponding symbols

In [None]:
# TODO: set `dataset` to the path of the PyDESeq2 output. With the empty
# placeholder, load_csv_in_chunks reports "File not found" and returns None.
dataset = ''
chunk_size = 5000

healthy_chunk_iterator = load_csv_in_chunks(dataset, chunk_size=chunk_size, index_col=0, header=0)

