In [None]:
import os
import pandas as pd
import mygene
import csv
import random

# MyGene API test

Manual MyGene API test for checking if data is correct

In [None]:
# Check mygene for specific gene IDs
def check_mygene(gene_ids):
    """
    Query MyGene for the symbols of the given gene IDs.

    Args:
        gene_ids (list): Gene IDs to look up; they are searched with the
            'ensemblgene' scope and restricted to species 'human'.

    Returns:
        The raw result of `MyGeneInfo.querymany`, unmodified.
    """
    client = mygene.MyGeneInfo()
    # The raw hits are returned (not reduced to a query -> symbol dict) so the
    # caller can inspect everything the API sent back.
    return client.querymany(
        gene_ids,
        scopes='ensemblgene',
        fields='symbol',
        species='human',
    )

# Two hand-picked batches of Ensembl gene IDs used for the manual check.
genes = [
    "ENSG00000223972",
    "ENSG00000227232",
    "ENSG00000278267",
    "ENSG00000243485",
    "ENSG00000237613",
]
genes2 = [
    "ENSG00000000003",
    "ENSG00000000005",
    "ENSG00000000419",
    "ENSG00000000457",
    "ENSG00000000460",
]
gene_results = check_mygene(genes)
gene_results2 = check_mygene(genes2)

# Print every raw hit, with a separator line between the two batches
print("Gene symbols found:")
for batch_number, batch_hits in enumerate((gene_results, gene_results2)):
    if batch_number > 0:
        print("\n-================-")
    for hit in batch_hits:
        print(hit)

# Alternative lookup via getgenes, kept for reference (not executed):
# def getgenes(gene_ids):
#     mg = mygene.MyGeneInfo()
#     results = mg.getgenes(gene_ids, fields='symbol')
#     return results

# Example usage of getgenes
# gene_info = getgenes(genes)
# print("Gene information fetched:")
# for gene in gene_info:
#     print(f"Gene ID: {gene['query']}, Symbol: {gene.get('symbol', 'N/A')}")

# Test ordering of columns while preserving sample data.

Testing if ordering the columns preserves the sample data correctly.

In [None]:
# Small fixture: one identifier column plus three versioned gene columns,
# deliberately listed out of sorted order.
test_df = pd.DataFrame(
    {
        'sample_id': ['patient1', 'patient2', 'patient3'],
        'ENSG00000227232.5': [40, 50, 60],
        'ENSG00000000003.15': [100, 200, 300],
        'ENSG00000278267.1': [7, 8, 9],
    }
)

print("=== ORIGINAL DATAFRAME ===")
print(test_df)

gene_column_prefix = "ENSG"
original_to_base = {}   # versioned column name -> base gene ID
base_gene_ids = set()   # base gene IDs already claimed by a column

versioned_gene_columns = [
    name for name in test_df.columns
    if isinstance(name, str) and name.startswith(gene_column_prefix)
]
for versioned_name in versioned_gene_columns:
    stripped_id = versioned_name.split('.')[0]
    if stripped_id in base_gene_ids:
        # Only the first column seen for a base ID is renamed.
        continue
    base_gene_ids.add(stripped_id)
    original_to_base[versioned_name] = stripped_id

renamed_df = test_df.rename(columns=original_to_base)

# Non-gene columns keep their original order; gene columns follow, sorted by base ID.
non_gene_cols = [
    col for col in renamed_df.columns
    if not isinstance(col, str) or not col.startswith(gene_column_prefix)
]
common_gene_base_ids = sorted(base_gene_ids)
final_columns = non_gene_cols + common_gene_base_ids
reordered_df = renamed_df[final_columns].copy()

print("\n=== FINAL REORDERED DATAFRAME ===")
print(reordered_df)

# Convert .gct to .csv

In [None]:
dataset_file = 'data/GTEx_Analysis_2022-06-06_v10_RNASeQCv2.4.2_gene_reads.gct'
output_file = 'data/healthy_data.csv'
chunk_size = 8000
first_chunk = True

# The file is read as tab-separated text; skiprows=2 skips its first two lines
# so that the following line is used as the column header.
gct_reader = pd.read_csv(dataset_file, sep='\t', skiprows=2, chunksize=chunk_size)

for chunk in gct_reader:
    # The first chunk creates the output file (with header); every later
    # chunk is appended without repeating the header.
    write_options = {'mode': 'w'} if first_chunk else {'mode': 'a', 'header': False}
    chunk.to_csv(output_file, index=False, **write_options)
    first_chunk = False
    print(f"Processed and wrote a chunk of {len(chunk)} rows to {output_file}")

print(f"Successfully converted {dataset_file} to {output_file} by processing in chunks.")

# Remove the source GCT file once the CSV has been written, to save disk space
if os.path.exists(dataset_file):
    os.remove(dataset_file)
    print(f"Deleted the original GCT file: {dataset_file}")

# Transpose dataset

Flip dataset so that the rows become the columns and vice versa


In [None]:
input_csv_path = 'data/healthy_data_processed.csv'
output_transposed_csv_path = 'data/healthy_data_transposed.csv'

print(f"Starting transposition of {input_csv_path} to {output_transposed_csv_path}...")

# 1. Get gene IDs (from the 'Name' column, these will be the new header columns)
gene_ids_series = pd.read_csv(input_csv_path, usecols=['Name']).iloc[:, 0]
gene_ids_list = gene_ids_series.tolist()
print(f"Read {len(gene_ids_list)} gene IDs to be used as columns.")

# 2. Get sample IDs (the current column headers, excluding the first two columns)
# Read only the header row of the input CSV to get column names
header_df = pd.read_csv(input_csv_path, nrows=0)
original_column_names = header_df.columns.tolist()

# The slicing below is positional: it assumes 'Name' is first and 'Description'
# is second. If that does not hold, a real sample column would be dropped
# silently, so make the mismatch visible instead of hiding it.
expected_leading_columns = ['Name', 'Description']
if original_column_names[:2] != expected_leading_columns:
    print(f"Warning: expected the first two columns to be {expected_leading_columns}, "
          f"found {original_column_names[:2]}. They are skipped regardless.")
sample_ids_list = original_column_names[2:]
print(f"Found {len(sample_ids_list)} sample IDs to be used as rows.")

# 3. Write the new transposed CSV
with open(output_transposed_csv_path, 'w', newline='') as outfile:
    writer = csv.writer(outfile)

    # Write the header for the transposed file: 'sample_id' followed by all gene IDs
    transposed_header = ['sample_id'] + gene_ids_list
    writer.writerow(transposed_header)
    print("Written header to output file.")

    # 4. Iterate through sample IDs (original columns) in chunks for memory efficiency
    sample_chunk_size = 3000  # Adjust based on memory; number of samples to process at once
    num_samples = len(sample_ids_list)

    # range() only yields start offsets below num_samples, so every slice is
    # non-empty and no empty-chunk guard is needed.
    for i in range(0, num_samples, sample_chunk_size):
        current_sample_ids_chunk = sample_ids_list[i:i + sample_chunk_size]

        # Re-read the input, loading only the columns of the current samples
        chunk_df = pd.read_csv(input_csv_path, usecols=current_sample_ids_chunk)

        # Each original column becomes one output row: sample ID, then its values
        for sample_id in current_sample_ids_chunk:
            sample_values_list = chunk_df[sample_id].tolist()

            output_row = [sample_id] + sample_values_list
            writer.writerow(output_row)

        print(f"Processed and wrote samples from index {i} to {min(i + sample_chunk_size - 1, num_samples - 1)}")

print(f"Successfully transposed {input_csv_path} to {output_transposed_csv_path}")

# Delete the untransposed input CSV to save space (this is not the GCT file)
if os.path.exists(input_csv_path):
    os.remove(input_csv_path)
    print(f"Deleted the untransposed input CSV file: {input_csv_path}")

# Load CSV in chunks

Use this method to load the csv datasets in chunks so that it doesn't crash the devcontainer :)

Also use this to get the `chunk_iterator` for the dataset and use that in combination with the other methods created

In [None]:
def load_csv_in_chunks(file_path, chunk_size=8000, **kwargs):
    """
    Open a CSV file for chunk-wise reading instead of loading it at once.

    Args:
        file_path (str): Location of the CSV file.
        chunk_size (int): Number of rows each yielded chunk should hold.
        **kwargs: Forwarded unchanged to pd.read_csv()
                  (e.g., sep=',', header=0, index_col=None, usecols=None).

    Returns:
        The iterator produced by pd.read_csv(..., chunksize=chunk_size), which
        yields DataFrame chunks. None is returned when the path does not exist
        or when pd.read_csv raises while opening the file.
    """
    if os.path.exists(file_path):
        print(f"Preparing to load {file_path} in chunks of size {chunk_size}...")
        try:
            return pd.read_csv(file_path, chunksize=chunk_size, **kwargs)
        except Exception as e:
            # Any failure while opening is reported and signalled with None.
            print(f"Error loading CSV file {file_path}: {e}")
            return None

    print(f"Error: File not found at {file_path}")
    return None

# Save data to new csv file

Use this to save the preprocessed data as a new csv file, can be a pandas `dataframe` or a `chunk_iterator`.

Optional: Can also delete the old csv file. This can be useful for the final dataset, once we know everything is set up properly. Otherwise, keep the old dataset so that we work in a non-destructive way.

In [None]:
def save_data_to_csv(data_to_save, output_file_path, delete_original_path=False, index=False, **kwargs):
    """
    Saves a pandas DataFrame or an iterator of DataFrame chunks to a CSV file.
    Optionally deletes the original file after successful saving.

    The original file is only deleted when data was actually written. An empty
    (or already exhausted) chunk iterator writes nothing, prints a warning and
    leaves the original file in place.

    Args:
        data_to_save (pd.DataFrame or iterator): DataFrame or iterator yielding DataFrames.
        output_file_path (str): Path to save the new CSV file.
        delete_original_path (str, optional): Path to the original file to delete.
            Defaults to False (nothing is deleted). It is never deleted when it
            refers to the same file as output_file_path.
        index (bool, optional): Whether to write DataFrame index. Defaults to False.
        **kwargs: Additional keyword arguments to pass to df.to_csv().
            'mode' (default 'w') applies to the DataFrame, or to the first chunk;
            later chunks are always appended. 'header' (default True) applies to
            the DataFrame, or to the first chunk; later chunks get no header.

    Raises:
        Exception: Any error raised while writing or deleting is reported and re-raised.
    """
    mode = kwargs.pop('mode', 'w')
    header = kwargs.pop('header', True)

    def _is_same_path(first, second):
        """Compare two paths after normalisation; fall back to == for non-path values."""
        try:
            return os.path.abspath(os.fspath(first)) == os.path.abspath(os.fspath(second))
        except TypeError:
            return first == second

    try:
        if isinstance(data_to_save, pd.DataFrame):
            print(f"Saving DataFrame to {output_file_path}...")
            data_to_save.to_csv(output_file_path, index=index, mode=mode, header=header, **kwargs)
            wrote_anything = True
        else:
            print(f"Saving data in chunks to {output_file_path}...")
            chunks_written = 0
            for i, chunk_df in enumerate(data_to_save):
                # Honour the caller's mode/header for the first chunk only;
                # every following chunk must append without a header.
                chunk_mode = mode if i == 0 else 'a'
                chunk_header = header if i == 0 else False

                chunk_df.to_csv(output_file_path, index=index, mode=chunk_mode, header=chunk_header, **kwargs)
                chunks_written += 1

                if i == 0:
                    print(f"Written first chunk to {output_file_path}")
                else:
                    print(f"Appended chunk {i+1} to {output_file_path}")
            wrote_anything = chunks_written > 0

        if not wrote_anything:
            # Deleting the source after writing nothing would lose the data.
            print(f"Warning: No chunks were received; nothing was written to {output_file_path}.")
            if delete_original_path:
                print(f"Original file was NOT deleted: {delete_original_path}")
            return

        print(f"Successfully saved data to {output_file_path}")

        # Handle file deletion if requested
        if delete_original_path:
            if _is_same_path(delete_original_path, output_file_path):
                print(f"Warning: Original and output paths are the same. Original file not deleted.")
            elif os.path.exists(delete_original_path):
                os.remove(delete_original_path)
                print(f"Successfully deleted original file: {delete_original_path}")
            else:
                print(f"Warning: Original file not found at {delete_original_path}")

    except Exception as e:
        print(f"Error saving data to {output_file_path}: {e}")
        if os.path.exists(output_file_path):
            print(f"Partial data might have been written to {output_file_path}.")
        raise

# Modify Dataframe

Use this method to perform different dataframe manipulations in a memory safe way.


`drop_row`: Drops a row by its index label.


`column_header`: Renames a column.


`index_header`: Sets/renames the index name.

In [None]:
def modify_dataframe_element(df_chunk, new_value, row_label=None, old_name=None, element_type=None):
    """
    Apply one small structural change to a DataFrame chunk.

    Args:
        df_chunk (pd.DataFrame): Chunk to operate on.
        new_value (any): New column name ('column_header') or new index name
                         ('index_header'); unused by 'drop_row'.
        row_label (any, optional): Index label of the row to drop ('drop_row').
        old_name (any, optional): Existing column name to rename ('column_header').
        element_type (str): Which operation to run:
                           'drop_row': drop the row with index label `row_label`
                           'column_header': rename column `old_name` to `new_value`
                           'index_header': set the index name to `new_value`

    Returns:
        pd.DataFrame: For 'drop_row' a new DataFrame without the row. For
        'column_header' and 'index_header' the same `df_chunk` object, modified
        in place. When the request is invalid or its target is missing, a
        message is printed and `df_chunk` is returned unchanged.
    """
    if element_type not in ('drop_row', 'column_header', 'index_header'):
        print(f"Error: Invalid element_type '{element_type}'. Use: 'drop_row', 'column_header', or 'index_header'.")
        return df_chunk

    if element_type == 'index_header':
        previous_name = df_chunk.index.name
        print(f"Setting index name: '{previous_name}' → '{new_value}'")
        df_chunk.index.name = new_value
        return df_chunk

    if element_type == 'column_header':
        if old_name is None:
            print("Error: 'old_name' required for 'column_header' operation.")
            return df_chunk
        if old_name not in df_chunk.columns:
            print(f"Warning: Column '{old_name}' not found. No changes made.")
            return df_chunk
        print(f"Renaming column: '{old_name}' → '{new_value}'")
        # Renames on the passed-in object itself (callers see the change).
        df_chunk.rename(columns={old_name: new_value}, inplace=True)
        return df_chunk

    # Remaining case: 'drop_row'
    if row_label is None:
        print("Error: 'row_label' required for 'drop_row' operation.")
        return df_chunk
    if row_label not in df_chunk.index:
        print(f"Warning: Row '{row_label}' not found. No changes made.")
        return df_chunk
    print(f"Dropping row: {row_label}")
    return df_chunk.drop(index=row_label)

# Convert gene IDs to corresponding symbol

Use this to remove the version suffix from the gene IDs (.XX) and convert the gene IDs to the symbols.

We could keep using the gene IDs without the version suffix, but in general the symbols are better for readability and should be more consistent between GENCODE versions.

Also removes duplicate genes found in the dataset, only keeping the first and removing the columns of the others.

In [None]:
def align_gene_columns_simple(healthy_chunk_iterator, unhealthy_chunk_iterator, 
                              gene_column_prefix="ENSG"):
    """
    Restrict two chunked datasets to the gene columns they have in common.

    Gene columns are the string column labels starting with
    `gene_column_prefix`. Each one is reduced to its base ID (the text before
    the first '.'); when several columns share a base ID only the first is
    kept. The column mapping is derived from the first chunk of each dataset
    and then applied to every chunk. Output columns are the non-gene columns
    in their original order followed by the common base IDs in sorted order.
    All aligned chunks are collected in lists held in memory.

    Args:
        healthy_chunk_iterator: Iterator for healthy dataset chunks
        unhealthy_chunk_iterator: Iterator for unhealthy dataset chunks
        gene_column_prefix (str): Prefix to identify gene columns (default: "ENSG")

    Returns:
        tuple: (aligned_healthy_chunks, aligned_unhealthy_chunks, alignment_info).
        ([], [], {}) is returned when either iterator is empty or when the two
        datasets share no gene.
    """

    def _is_gene_column(label):
        """True for string column labels carrying the gene prefix."""
        return isinstance(label, str) and label.startswith(gene_column_prefix)

    def _collect_base_ids(frame, label):
        """Build the versioned-name -> base-ID mapping and the set of base IDs of one chunk."""
        print(f"Processing {label} dataset...")

        rename_map = {}
        seen_base_ids = set()
        gene_column_total = 0
        skipped_duplicates = 0

        for name in frame.columns:
            if not _is_gene_column(name):
                continue
            gene_column_total += 1
            base_id = name.split('.')[0]

            if base_id in seen_base_ids:
                # Later columns with an already-seen base ID are left unmapped.
                skipped_duplicates += 1
                print(f"Warning: Duplicate base gene {base_id} found, skipping {name}")
                continue

            rename_map[name] = base_id
            seen_base_ids.add(base_id)

        print(f"Total gene columns found: {gene_column_total}")
        print(f"Unique base gene IDs: {len(seen_base_ids)}")
        if skipped_duplicates > 0:
            print(f"Duplicates removed: {skipped_duplicates}")

        return rename_map, seen_base_ids

    def _align_all_chunks(remaining_chunks, head_chunk, rename_map, shared_genes, label):
        """Rename gene columns to base IDs in every chunk and keep only the shared ones."""
        print(f"   Processing {label} dataset chunks...")

        head_renamed = head_chunk.rename(columns=rename_map)

        # Column selection is fixed from the first chunk and reused for the rest.
        passthrough_cols = [col for col in head_renamed.columns if not _is_gene_column(col)]
        ordered_gene_ids = sorted(shared_genes)
        keep_cols = passthrough_cols + ordered_gene_ids

        print(f"      Non-gene columns: {len(passthrough_cols)}")
        print(f"      Common gene columns (base IDs): {len(ordered_gene_ids)}")
        print(f"      Total columns to keep: {len(keep_cols)}")

        collected = [head_renamed[keep_cols].copy()]
        print(f"✓ Processed chunk 1 - shape: {collected[0].shape}")

        # The first chunk was already consumed by the caller, so counting resumes at 2.
        for position, piece in enumerate(remaining_chunks, start=2):
            collected.append(piece.rename(columns=rename_map)[keep_cols].copy())
            if position % 3 == 0:
                print(f"✓ Processed {position} chunks so far...")

        print(f"{label} processing complete: {len(collected)} chunks processed")
        return collected

    print("Starting gene column alignment...")

    print("\nLoading first chunks from both datasets...")
    try:
        healthy_head = next(healthy_chunk_iterator)
        print(f"   ✓ Loaded healthy dataset first chunk: {healthy_head.shape}")

        unhealthy_head = next(unhealthy_chunk_iterator)
        print(f"   ✓ Loaded unhealthy dataset first chunk: {unhealthy_head.shape}")
    except StopIteration:
        print("ERROR: One or both datasets are empty!")
        return [], [], {}

    print("\nProcessing gene columns and stripping version suffixes...")
    healthy_map, healthy_ids = _collect_base_ids(healthy_head, "HEALTHY")
    unhealthy_map, unhealthy_ids = _collect_base_ids(unhealthy_head, "UNHEALTHY")

    print("\nSummary:")
    print(f"   Healthy dataset: {len(healthy_ids)} unique genes")
    print(f"   Unhealthy dataset: {len(unhealthy_ids)} unique genes")

    print("\nFinding common genes between datasets...")
    shared_ids = healthy_ids & unhealthy_ids
    if not shared_ids:
        print("ERROR: No common genes found between datasets!")
        return [], [], {}

    healthy_only_total = len(healthy_ids - shared_ids)
    unhealthy_only_total = len(unhealthy_ids - shared_ids)

    print(f"Common genes: {len(shared_ids)}")
    print(f"Genes exclusive to healthy dataset: {healthy_only_total}")
    print(f"Genes exclusive to unhealthy dataset: {unhealthy_only_total}")

    print("\nProcessing all chunks and renaming to base gene IDs...")

    print("\nProcessing HEALTHY dataset...")
    aligned_healthy = _align_all_chunks(
        healthy_chunk_iterator, healthy_head, healthy_map, shared_ids, "HEALTHY"
    )

    print("\nProcessing UNHEALTHY dataset...")
    aligned_unhealthy = _align_all_chunks(
        unhealthy_chunk_iterator, unhealthy_head, unhealthy_map, shared_ids, "UNHEALTHY"
    )

    def _final_column_total(chunks):
        """Non-gene columns of the first aligned chunk plus the number of shared genes."""
        non_gene_total = sum(1 for col in chunks[0].columns if not _is_gene_column(col))
        return non_gene_total + len(shared_ids)

    alignment_info = {
        'total_healthy_genes_original': len(healthy_map),
        'total_unhealthy_genes_original': len(unhealthy_map),
        'common_genes': len(shared_ids),
        'genes_only_in_healthy': healthy_only_total,
        'genes_only_in_unhealthy': unhealthy_only_total,
        'healthy_chunks_processed': len(aligned_healthy),
        'unhealthy_chunks_processed': len(aligned_unhealthy),
        'final_healthy_columns': _final_column_total(aligned_healthy),
        'final_unhealthy_columns': _final_column_total(aligned_unhealthy),
    }

    divider = "=" * 25
    print("\nAlignment complete!")
    print(divider)
    print("Final Summary:")
    print(f"Original genes - Healthy: {alignment_info['total_healthy_genes_original']}")
    print(f"Original genes - Unhealthy: {alignment_info['total_unhealthy_genes_original']}")
    print(f"Common genes kept: {alignment_info['common_genes']}")
    print(f"Healthy chunks processed: {alignment_info['healthy_chunks_processed']}")
    print(f"Unhealthy chunks processed: {alignment_info['unhealthy_chunks_processed']}")
    print(divider)

    return aligned_healthy, aligned_unhealthy, alignment_info

In [None]:
def align_gene_columns_between_datasets(healthy_chunk_iterator, unhealthy_chunk_iterator, 
                                       gene_column_prefix="ENSG"):
    """
    Keep only the gene columns that occur in both the healthy and the
    unhealthy dataset, after stripping version suffixes from the column names.

    A gene column is a string column label starting with `gene_column_prefix`;
    its base ID is the text before the first '.'. The mapping is built from
    the first chunk of each dataset and applied to all chunks. When several
    columns share a base ID, only the first is kept. Output columns are the
    non-gene columns followed by the common base IDs in sorted order.

    Args:
        healthy_chunk_iterator: Iterator yielding healthy DataFrame chunks
        unhealthy_chunk_iterator: Iterator yielding unhealthy DataFrame chunks  
        gene_column_prefix (str): Prefix to identify gene columns (e.g., "ENSG")

    Returns:
        tuple: (aligned_healthy_chunks, aligned_unhealthy_chunks, alignment_info).
        ([], [], {}) when either iterator yields nothing or no gene is shared.
    """

    def _is_gene_column(label):
        """True for string column labels carrying the gene prefix."""
        return isinstance(label, str) and label.startswith(gene_column_prefix)

    def _build_rename_map(frame, dataset_name):
        """Return (set of base gene IDs, {original column name: base gene ID}) for one chunk."""
        base_ids = set()
        rename_map = {}
        for name in frame.columns:
            if not _is_gene_column(name):
                continue
            base_id = name.split('.')[0]  # Drop the version suffix
            if base_id in base_ids:
                print(f"Warning: Duplicate gene {base_id} in {dataset_name} dataset")
                continue
            base_ids.add(base_id)
            rename_map[name] = base_id
        return base_ids, rename_map

    def _align_chunks(remaining_chunks, head_chunk, rename_map, shared_genes, dataset_name):
        """Rename gene columns in every chunk and select non-gene plus shared gene columns."""
        head_renamed = head_chunk.rename(columns=rename_map)

        # The selection is decided once, on the first chunk, and reused afterwards.
        keep_cols = [col for col in head_renamed.columns if not _is_gene_column(col)]
        keep_cols = keep_cols + sorted(shared_genes)

        collected = [head_renamed[keep_cols].copy()]

        # The first chunk is already in the list, so counting resumes at 2.
        for position, piece in enumerate(remaining_chunks, start=2):
            collected.append(piece.rename(columns=rename_map)[keep_cols].copy())
            if position % 5 == 0:
                print(f"  Processed {position} {dataset_name} chunks...")

        print(f"  Completed {len(collected)} {dataset_name} chunks")
        return collected

    print("Starting gene column alignment...")

    # Pull the first chunk of each dataset; they define the column mapping
    try:
        healthy_head = next(healthy_chunk_iterator)
        unhealthy_head = next(unhealthy_chunk_iterator)
    except StopIteration:
        print("Error: One or both datasets are empty")
        return [], [], {}

    healthy_ids, healthy_rename_map = _build_rename_map(healthy_head, "healthy")
    unhealthy_ids, unhealthy_rename_map = _build_rename_map(unhealthy_head, "unhealthy")

    print(f"Healthy dataset: {len(healthy_ids)} unique genes")
    print(f"Unhealthy dataset: {len(unhealthy_ids)} unique genes")

    # Genes present in both datasets
    shared_ids = healthy_ids & unhealthy_ids
    if not shared_ids:
        print("Error: No common genes found between datasets!")
        return [], [], {}

    shared_total = len(shared_ids)
    print(f"Common genes found: {shared_total}")

    print("Processing healthy dataset chunks...")
    aligned_healthy = _align_chunks(
        healthy_chunk_iterator, healthy_head, healthy_rename_map, shared_ids, "healthy"
    )

    print("Processing unhealthy dataset chunks...")
    aligned_unhealthy = _align_chunks(
        unhealthy_chunk_iterator, unhealthy_head, unhealthy_rename_map, shared_ids, "unhealthy"
    )

    # Summary statistics
    alignment_info = {
        'total_healthy_genes': len(healthy_ids),
        'total_unhealthy_genes': len(unhealthy_ids),
        'common_genes': shared_total,
        'genes_only_in_healthy': len(healthy_ids) - shared_total,
        'genes_only_in_unhealthy': len(unhealthy_ids) - shared_total,
        'healthy_chunks_processed': len(aligned_healthy),
        'unhealthy_chunks_processed': len(aligned_unhealthy),
    }

    print("\n=== Gene Alignment Summary ===")
    print(f"✓ Common genes: {alignment_info['common_genes']}")
    print(f"• Healthy-only genes: {alignment_info['genes_only_in_healthy']}")
    print(f"• Unhealthy-only genes: {alignment_info['genes_only_in_unhealthy']}")
    print("=============================\n")

    return aligned_healthy, aligned_unhealthy, alignment_info

In [None]:
def preprocess_gene_columns(chunk_iterator, mygene_client, id_column_prefix="ENSG"):
    """
    Rename gene ID columns to gene symbols and combine all chunks.

    Gene columns are the string column labels of the first chunk that start
    with `id_column_prefix`. They are looked up with a single
    `mygene_client.querymany` call. A gene keeps its original ID when:
      * no symbol is returned for it,
      * its symbol is already used by another column (this avoids duplicate
        column names in the result), or
      * the lookup itself fails (then every gene keeps its ID).
    When several hits are returned for the same query, the first hit that
    carries a symbol is used.

    Args:
        chunk_iterator: Iterator yielding DataFrame chunks.
        mygene_client: Object providing `querymany(ids, scopes=..., fields=...,
            species=..., verbose=...)` that returns a list of dicts with a
            'query' key and, when found, a 'symbol' key.
        id_column_prefix (str): Prefix identifying gene ID columns.

    Returns:
        pd.DataFrame: All chunks concatenated (original index kept) with the
        renamed columns. An empty DataFrame when the iterator yields nothing.
    """
    try:
        first_chunk = next(chunk_iterator)
    except StopIteration:
        print("Error: Dataset is empty")
        return pd.DataFrame()

    original_columns = first_chunk.columns.tolist()

    # Find gene columns (should already be base IDs without version suffixes)
    gene_columns = [col for col in original_columns
                   if isinstance(col, str) and col.startswith(id_column_prefix)]

    if not gene_columns:
        print("No gene columns found for symbol conversion")
        return pd.concat([first_chunk] + list(chunk_iterator), ignore_index=False)

    print(f"Converting {len(gene_columns)} gene IDs to symbols...")
    gene_symbols = {}
    lookup_failed = False

    try:
        results = mygene_client.querymany(
            gene_columns,
            scopes='ensembl.gene',
            fields='symbol',
            species='human',
            verbose=False
        )

        for result in results:
            query_id = result['query']
            symbol = result.get('symbol')
            # First usable hit per query wins; later hits never overwrite it.
            if isinstance(symbol, str) and symbol and query_id not in gene_symbols:
                gene_symbols[query_id] = symbol

    except Exception as e:
        print(f"MyGene API error: {e}")
        print("Using original gene IDs as fallback")
        gene_symbols = {}
        lookup_failed = True

    # Build the rename mapping for gene columns only; columns missing from the
    # mapping are left untouched by DataFrame.rename.
    column_mapping = {}
    if not lookup_failed:
        gene_column_set = set(gene_columns)
        # Names that are already taken: start with every non-gene column.
        taken_names = {col for col in original_columns if col not in gene_column_set}
        converted_count = 0

        for gene_id in gene_columns:
            symbol = gene_symbols.get(gene_id)
            if symbol is None:
                print(f"No symbol found for {gene_id}, keeping original ID")
                new_name = gene_id
            elif symbol in taken_names:
                print(f"Warning: Symbol '{symbol}' is already used by another column, "
                      f"keeping original ID {gene_id}")
                new_name = gene_id
            else:
                new_name = symbol
                if new_name != gene_id:
                    converted_count += 1
            taken_names.add(new_name)
            column_mapping[gene_id] = new_name

        print(f"Successfully converted {converted_count} genes to symbols")

    # Apply mapping and combine chunks
    print("Applying symbol mapping to all chunks...")
    all_chunks = [first_chunk.rename(columns=column_mapping)]

    for chunk in chunk_iterator:
        all_chunks.append(chunk.rename(columns=column_mapping))

    final_df = pd.concat(all_chunks, ignore_index=False)
    print(f"Final dataset shape: {final_df.shape}")

    return final_df

In [None]:
def add_condition_labels_to_chunks(chunk_iterator, condition_label, dataset_name):
    """
    Stamp every chunk of a dataset with a 'condition' column.

    Each chunk yielded by the iterator is modified in place (the 'condition'
    column is set to `condition_label`) and collected into a list.

    Args:
        chunk_iterator: Iterator yielding DataFrame chunks
        condition_label (int): Binary label (0 for healthy, 1 for unhealthy)
        dataset_name (str): Name used in the progress messages

    Returns:
        list: The labeled DataFrame chunks, in iteration order
    """
    print(f"Adding label '{condition_label}' to {dataset_name} dataset...")

    def _stamp(frame):
        """Set the 'condition' column on the frame itself and hand it back."""
        frame['condition'] = condition_label
        return frame

    labeled_chunks = [_stamp(frame) for frame in chunk_iterator]

    print(f"✅ Completed {len(labeled_chunks)} {dataset_name} chunks")
    return labeled_chunks

def merge_labeled_datasets(healthy_chunks, unhealthy_chunks):
    """
    Stack the healthy chunks followed by the unhealthy chunks into one DataFrame.

    Rows are concatenated along axis 0 and each chunk's index is kept as is.
    The merged frame must have a 'condition' column; it is used to report how
    many rows carry label 0 and label 1.

    Args:
        healthy_chunks (list): List of healthy DataFrame chunks
        unhealthy_chunks (list): List of unhealthy DataFrame chunks

    Returns:
        pd.DataFrame: Merged dataset
    """
    print("Merging datasets...")

    merged_dataset = pd.concat(
        healthy_chunks + unhealthy_chunks,
        axis=0,
        ignore_index=False,
    )

    sample_total = len(merged_dataset)
    feature_total = len(merged_dataset.columns)
    print(f"✅ Merged dataset: {sample_total} samples, {feature_total} features")

    # Per-label row counts
    labels = merged_dataset['condition']
    print(f"   Healthy (0): {labels.eq(0).sum()}")
    print(f"   Unhealthy (1): {labels.eq(1).sum()}")

    return merged_dataset

def convert_gene_ids_to_symbols(dataset_chunk_iterator, mygene_client, gene_column_prefix="ENSG"):
    """
    Rename gene-ID columns to gene symbols using a single MyGene lookup and
    return all chunks concatenated into one DataFrame.

    The columns of the first chunk decide which IDs are looked up; the
    resulting mapping is then applied to every chunk.

    A column keeps its original gene ID when:
      * MyGene returns no symbol for it,
      * the lookup fails (the exception is reported, IDs are kept), or
      * its symbol is already in use as a column name (for example two gene
        IDs resolving to the same symbol), which avoids duplicate labels.
    When MyGene returns several hits for the same ID, the first hit that
    carries a symbol wins; later hits never overwrite it.

    Args:
        dataset_chunk_iterator: Iterable yielding DataFrame chunks
        mygene_client: Object exposing ``querymany`` (e.g. ``mygene.MyGeneInfo()``)
        gene_column_prefix (str): Prefix to identify gene columns

    Returns:
        pd.DataFrame: Complete dataset with gene symbols as column names;
        an empty DataFrame when the input yields no chunks.
    """
    import warnings

    print("Converting gene IDs to symbols (chunked processing)...")

    # iter() lets callers pass a list or generator as well as a chunked reader.
    chunks = iter(dataset_chunk_iterator)
    first_chunk = next(chunks, None)
    if first_chunk is None:
        print("❌ Error: Dataset is empty")
        return pd.DataFrame()
    print(f"Processing first chunk: {first_chunk.shape}")

    # Identify gene columns from first chunk
    gene_columns = [col for col in first_chunk.columns
                    if isinstance(col, str) and col.startswith(gene_column_prefix)]

    if not gene_columns:
        print("No gene columns found - returning original data")
        return pd.concat([first_chunk, *chunks], axis=0, ignore_index=False)

    print(f"Found {len(gene_columns)} gene columns")

    # Only columns that actually get a new name end up in this mapping;
    # DataFrame.rename leaves every other column untouched.
    gene_id_to_symbol = {}

    try:
        print("Making single MyGene API call...")
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            results = mygene_client.querymany(
                gene_columns,
                scopes='ensembl.gene',
                fields='symbol',
                species='human',
                verbose=False,
                silent=True
            )

        queried_ids = set(gene_columns)
        # Every existing column label is reserved so a symbol can never
        # collide with another column (gene or non-gene such as 'condition').
        taken_names = set(first_chunk.columns)

        for hit in results:
            gene_id = hit['query']
            symbol = hit.get('symbol')
            if (not symbol
                    or gene_id not in queried_ids
                    or gene_id in gene_id_to_symbol   # earlier hit already won
                    or symbol in taken_names):        # would duplicate a label
                continue
            gene_id_to_symbol[gene_id] = symbol
            taken_names.add(symbol)

        print(f"✅ Successfully converted {len(gene_id_to_symbol)}/{len(gene_columns)} genes to symbols")

    except Exception as e:
        # Best effort: fall back to the original IDs instead of aborting.
        print(f"⚠️ MyGene API error: {e}. Using original gene IDs")
        gene_id_to_symbol = {}

    # Derived from the final mapping so the summary below always matches
    # what was really renamed (also after the fallback above).
    symbols_found = len(gene_id_to_symbol)

    print("Applying gene symbol mapping to all chunks...")

    processed_chunks = [first_chunk.rename(columns=gene_id_to_symbol)]
    for chunk in chunks:
        processed_chunks.append(chunk.rename(columns=gene_id_to_symbol))

        if len(processed_chunks) % 3 == 0:
            print(f"  ✓ Processed {len(processed_chunks)} chunks...")

    print(f"Concatenating {len(processed_chunks)} processed chunks...")
    final_dataset = pd.concat(processed_chunks, axis=0, ignore_index=False)

    print(f"✅ Final dataset shape: {final_dataset.shape}")
    print(f"   Symbols converted: {symbols_found}/{len(gene_columns)}")

    return final_dataset

# Actual preprocessing code

In [None]:
# Input/output locations for the raw unhealthy and healthy datasets.
# NOTE: `unhealthy_data_path` and `healthy_data_path` are reassigned by the
# alignment cell further down ("Prepare datasets for merge"); re-run this cell
# before re-running the preprocessing cells that follow it.
unhealthy_data_path = 'data/rna_seq_unstranded.csv'
output_unhealthy_processed_path = 'data/unhealthy_data_processed.csv'

healthy_data_path = 'data/healthy_data_preprocessed.csv'
output_healthy_processed_path = 'data/healthy_data_processed.csv'

mg = mygene.MyGeneInfo() # MyGene client; intended to be passed as `mygene_client` to the gene-ID -> symbol helpers (e.g. convert_gene_ids_to_symbols)

## Unhealthy Preprocessing

In [None]:
# Stream the unhealthy dataset from disk chunk by chunk
df_unhealthy_chunk_iterator = load_csv_in_chunks(
    unhealthy_data_path,
    header=0,         # first row (Ensembl IDs) supplies the column names
    skiprows=[1, 2],  # drop the `gene_name` and `gene_type` rows of the original file
    index_col=0       # `Sample IDs` becomes the index
)

if not df_unhealthy_chunk_iterator:
    print(f"Failed to load unhealthy data chunks from {unhealthy_data_path}. Check file path and format.")
else:
    # Normalise the index header of every chunk to 'sample_id'
    processed_unhealthy_chunks = []
    for df_chunk in df_unhealthy_chunk_iterator:
        if df_chunk.index.name != 'sample_id':
            df_chunk = modify_dataframe_element(
                df_chunk,
                new_value='sample_id',
                old_name=df_chunk.index.name,
                element_type='index_header',
            )
        processed_unhealthy_chunks.append(df_chunk)

    # The gene-symbol conversion step (preprocess_gene_columns with the `mg`
    # client and the "ENSG" prefix) is currently disabled in this cell; the
    # chunks are saved with the column names they were loaded with.
    save_data_to_csv(processed_unhealthy_chunks, output_unhealthy_processed_path, index=True)

## Healthy Preprocessing

### Transpose the healthy dataset

In [None]:
# Transpose the healthy dataset: the gene IDs found in the `Name` column become
# the output columns and every remaining input column (one per sample) becomes
# an output row. The input is re-read in column batches so the full matrix is
# never held in memory at once.
current_input_csv_path = 'data/healthy_data_preprocessed.csv'
# The sampling step below reads 'data/healthy_data_transposed.csv', so the
# transposed matrix has to be written to exactly that path.
current_output_transposed_csv_path = 'data/healthy_data_transposed.csv'
feature_id_col_name = 'Name'          # input column holding the gene IDs
columns_to_exclude = ['Description']  # input column that is not a sample
sample_processing_batch_size = 500    # sample columns read per pass over the input

print(f"Starting transposition of {current_input_csv_path} to {current_output_transposed_csv_path} using preferred logic...")

try:
    gene_ids_series = pd.read_csv(current_input_csv_path, usecols=[feature_id_col_name]).iloc[:, 0]
    gene_ids_list = gene_ids_series.tolist()
    print(f"Read {len(gene_ids_list)} gene IDs to be used as columns.")
except ValueError as e:
    # Raised by pandas when the requested column is missing from the file.
    print(f"Error reading gene IDs: {e}. Ensure '{feature_id_col_name}' exists in '{current_input_csv_path}'.")
    gene_ids_list = []

if gene_ids_list:
    # Header-only read: every column that is neither the gene-ID column nor
    # explicitly excluded is treated as a sample.
    header_df = pd.read_csv(current_input_csv_path, nrows=0)
    original_column_names = header_df.columns.tolist()

    non_sample_columns = [feature_id_col_name] + columns_to_exclude
    sample_ids_list = [col for col in original_column_names if col not in non_sample_columns]
    print(f"Found {len(sample_ids_list)} sample IDs to be used as rows.")

    if sample_ids_list:
        # Start indices of batches that raised; used for the final report so a
        # partial output file is never announced as a success.
        failed_batch_starts = []

        with open(current_output_transposed_csv_path, 'w', newline='') as outfile:
            writer = csv.writer(outfile)
            transposed_header = ['sample_id'] + gene_ids_list
            writer.writerow(transposed_header)
            print("Written header to output file.")

            num_samples = len(sample_ids_list)

            for i in range(0, num_samples, sample_processing_batch_size):
                current_sample_ids_batch = sample_ids_list[i:i + sample_processing_batch_size]

                try:
                    batch_df = pd.read_csv(current_input_csv_path, usecols=current_sample_ids_batch)

                    # One output row per sample: its ID followed by its value for every gene.
                    for sample_id_in_batch in current_sample_ids_batch:
                        sample_values_list = batch_df[sample_id_in_batch].tolist()
                        writer.writerow([sample_id_in_batch] + sample_values_list)

                    print(f"Processed and wrote samples from index {i} to {min(i + sample_processing_batch_size - 1, num_samples - 1)}")
                except Exception as e_batch:
                    # Keep going with the remaining batches, but remember the failure.
                    failed_batch_starts.append(i)
                    print(f"Error processing batch of samples starting at index {i}: {e_batch}")
                    print(f"Problematic sample IDs in batch might be: {current_sample_ids_batch[:5]}")

        if failed_batch_starts:
            print(f"⚠️ Transposition finished with {len(failed_batch_starts)} failed batch(es) "
                  f"(start indices: {failed_batch_starts}); {current_output_transposed_csv_path} is incomplete.")
        else:
            print(f"Successfully transposed {current_input_csv_path} to {current_output_transposed_csv_path}")

        # Quick look at the first rows of the written file
        df_healthy_transposed_sample = pd.read_csv(current_output_transposed_csv_path, nrows=5)
        print(df_healthy_transposed_sample.head())
    else:
        print(f"No sample IDs found to process in {current_input_csv_path} after excluding specified columns.")
else:
    if os.path.exists(current_input_csv_path): # Only print if input file was valid but no gene_ids
        print(f"No gene IDs found in column '{feature_id_col_name}' in {current_input_csv_path}. Transposition aborted.")

### Reduce the samples of the healthy dataset

In [None]:
# Down-sample the transposed healthy dataset to a fixed number of randomly
# chosen samples (rows), streaming the file in chunks.
input_for_sampling_path = 'data/healthy_data_transposed.csv'
output_sampled_path = 'data/healthy_data_transposed_sampled.csv'
num_samples_to_keep = 1000
sampling_read_chunk_size = 1000

print(f"Starting sampling of {input_for_sampling_path} to select {num_samples_to_keep} samples.")

# Reduce the sample size for the healthy dataset to be more equivalent to the unhealthy dataset
if not os.path.exists(input_for_sampling_path):
    print(f"Error: Input file for sampling not found: {input_for_sampling_path}")
else:
    try:
        # Always start from an empty selection: without this, the "no sample
        # IDs" branch below would leave the name undefined (NameError) or
        # silently reuse a value left over from an earlier run of this cell.
        selected_sample_ids = []

        # 1. Get all sample IDs from the 'sample_id' column of the transposed file
        print("Reading all sample IDs from the transformed healthy dataset...")
        all_sample_ids_df = pd.read_csv(input_for_sampling_path, usecols=['sample_id'])
        all_sample_ids = all_sample_ids_df['sample_id'].tolist()
        print(f"Found {len(all_sample_ids)} total samples in {input_for_sampling_path}.")

        if not all_sample_ids:
            print("Error: No sample IDs found. Cannot proceed with sampling.")
        elif len(all_sample_ids) < num_samples_to_keep:
            print(f"Warning: Total samples ({len(all_sample_ids)}) is less than desired samples ({num_samples_to_keep}). Using all available samples.")
            selected_sample_ids = all_sample_ids
        else:
            # 2. Randomly select N sample IDs
            random.seed(5) # for reproducibility
            selected_sample_ids = random.sample(all_sample_ids, num_samples_to_keep)
            print(f"Randomly selected {len(selected_sample_ids)} sample IDs.")

        if selected_sample_ids:
            # 3. Create the new CSV with only the selected samples
            print(f"Writing {len(selected_sample_ids)} selected samples to {output_sampled_path}...")

            # 'sample_id' is used as the index so rows can be filtered by ID
            chunk_iterator_for_sampling = load_csv_in_chunks(
                input_for_sampling_path,
                chunk_size=sampling_read_chunk_size,
                index_col='sample_id'
            )

            if chunk_iterator_for_sampling:
                # The first non-empty chunk creates the file and writes the
                # header; every later chunk is appended without a header.
                first_save_chunk = True
                for i_chunk, df_chunk_original in enumerate(chunk_iterator_for_sampling):
                    # Filter the chunk to keep only rows whose index (sample_id) is in selected_sample_ids
                    df_chunk_filtered = df_chunk_original[df_chunk_original.index.isin(selected_sample_ids)]

                    if not df_chunk_filtered.empty:
                        if first_save_chunk:
                            save_data_to_csv(df_chunk_filtered, output_sampled_path, index=True, header=True, mode='w')
                            first_save_chunk = False
                        else:
                            save_data_to_csv(df_chunk_filtered, output_sampled_path, index=True, header=False, mode='a')
                        print(f"  Processed original chunk {i_chunk+1}, wrote {len(df_chunk_filtered)} sampled rows to {output_sampled_path}.")
                print(f"Finished writing sampled data to {output_sampled_path}")
            else:
                print(f"Failed to load chunks from {input_for_sampling_path} for sampling.")
        else:
            print("No samples were selected. Output file will not be created.")

    except Exception as e:
        print(f"An error occurred during sampling: {e}")

### Rename index name

In [None]:
# # Replace the existing healthy data processing cell with this more efficient version
# sampled_healthy_data_input_path = 'data/healthy_data_transformed_sampled.csv'
# processing_chunk_size = 1000

# print(f"Loading sampled healthy data from: {sampled_healthy_data_input_path}")
# df_healthy_chunk_iterator = load_csv_in_chunks(
#     sampled_healthy_data_input_path,
#     header=0,
#     index_col=0,
#     chunk_size=processing_chunk_size
# )

# # Process the healthy dataset efficiently with single MyGene API call
# if df_healthy_chunk_iterator:
#     # First ensure all chunks have proper index name
#     processed_chunks = []
#     for i, df_chunk in enumerate(df_healthy_chunk_iterator):
#         if df_chunk.index.name != 'sample_id':
#             df_chunk = modify_dataframe_element(df_chunk, new_value='sample_id', old_name=df_chunk.index.name, element_type='index_header')
#         processed_chunks.append(df_chunk)
    
#     # Save the processed healthy data
#     save_data_to_csv(df_healthy_processed, output_healthy_processed_path, index=True)
# else:
#     print(f"Failed to load healthy data chunks from {sampled_healthy_data_input_path}. Check file path and format.")

### Prepare datasets for merge

Remove version suffix from gene IDs.

Process datasets so that they only have the genes that are present in both datasets.

Translate the gene IDs to their corresponding symbols for interpretability.

In [None]:
# Align the sampled healthy dataset and the processed unhealthy dataset on
# their common gene columns and write both aligned datasets to disk.
# NOTE: this cell rebinds `healthy_data_path` / `unhealthy_data_path`, which the
# setup cell points at the raw inputs; re-run the setup cell before re-running
# any of the preprocessing cells above.
healthy_data_path = 'data/healthy_data_transposed_sampled.csv'
unhealthy_data_path = 'data/unhealthy_data_processed.csv'
output_healthy_path = 'data/healthy_data_aligned.csv'
output_unhealthy_path = 'data/unhealthy_data_aligned.csv'
chunk_size = 1000

print("🚀 STARTING DATASET ALIGNMENT PROCESS")

# Load datasets
print("\n📂 LOADING DATASETS...")
print(f"   Healthy data: {healthy_data_path}")
print(f"   Unhealthy data: {unhealthy_data_path}")
print(f"   Chunk size: {chunk_size}")

healthy_chunks = load_csv_in_chunks(healthy_data_path, chunk_size=chunk_size, index_col=0)
unhealthy_chunks = load_csv_in_chunks(unhealthy_data_path, chunk_size=chunk_size, index_col=0)

if healthy_chunks is None or unhealthy_chunks is None:
    print("❌ ERROR: Failed to load one or both datasets")
else:
    print("✅ Both datasets loaded successfully")

    # Align gene columns (strips version suffixes and keeps intersection only)
    print("\n🔄 STARTING GENE ALIGNMENT...")
    aligned_healthy, aligned_unhealthy, alignment_info = align_gene_columns_simple(
        healthy_chunks, unhealthy_chunks
    )

    if aligned_healthy and aligned_unhealthy:
        # Save results. Both files are written: the final preparation cell
        # loads the healthy AND the unhealthy aligned file.
        print("\n💾 SAVING ALIGNED DATASETS...")
        print(f"   Saving healthy dataset to: {output_healthy_path}")
        save_data_to_csv(aligned_healthy, output_healthy_path, index=True)

        print(f"   Saving unhealthy dataset to: {output_unhealthy_path}")
        save_data_to_csv(aligned_unhealthy, output_unhealthy_path, index=True)

        # Final verification
        print("\n🔍 FINAL VERIFICATION...")
        print("   Loading saved files to verify structure...")

        # Check first few rows of each saved file
        healthy_check = pd.read_csv(output_healthy_path, nrows=3, index_col=0)
        unhealthy_check = pd.read_csv(output_unhealthy_path, nrows=3, index_col=0)

        print(f"   ✅ Healthy dataset saved - Shape: {healthy_check.shape}")
        print(f"      Sample columns: {list(healthy_check.columns[:5])}...")

        print(f"   ✅ Unhealthy dataset saved - Shape: {unhealthy_check.shape}")
        print(f"      Sample columns: {list(unhealthy_check.columns[:5])}...")

        print("\n🎉 ALIGNMENT PROCESS COMPLETED SUCCESSFULLY!")
        print("   Both datasets now contain only common genes with version suffixes stripped.")

    else:
        print("ERROR: Gene alignment failed")

In [None]:
print("🚀 FINAL DATASET PREPARATION")
print("=" * 40)

# Configuration
healthy_aligned_path = 'data/healthy_data_aligned.csv'
unhealthy_aligned_path = 'data/unhealthy_data_aligned.csv'
output_healthy_path = 'data/healthy_data_labeled.csv'
output_unhealthy_path = 'data/unhealthy_data_labeled.csv'
output_merged_path = 'data/merged_dataset.csv'
final_path = 'data/final_dataset.csv'
chunk_size = 1000

# Load datasets
healthy_chunks = load_csv_in_chunks(healthy_aligned_path, chunk_size=chunk_size, index_col=0)
unhealthy_chunks = load_csv_in_chunks(unhealthy_aligned_path, chunk_size=chunk_size, index_col=0)

if healthy_chunks is None or unhealthy_chunks is None:
    print("❌ ERROR: Failed to load one or both datasets")
else:
    print("✅ Both datasets loaded successfully")

    # Label datasets (0 = healthy, 1 = unhealthy). This runs only after a
    # successful load: iterating a failed (None) loader would raise TypeError.
    healthy_labeled = add_condition_labels_to_chunks(healthy_chunks, 0, 'HEALTHY')
    unhealthy_labeled = add_condition_labels_to_chunks(unhealthy_chunks, 1, 'UNHEALTHY')

    # Persist both labeled datasets so neither labeling pass is thrown away.
    print("Saving labeled datasets...")
    save_data_to_csv(healthy_labeled, output_healthy_path, index=True)
    save_data_to_csv(unhealthy_labeled, output_unhealthy_path, index=True)

# # Merge datasets
# merged_dataset = merge_labeled_datasets(healthy_labeled, unhealthy_labeled)

# print(f"   Saving merged dataset to: {output_merged_path}")
# save_data_to_csv(merged_dataset, output_merged_path, index=True)

# merged_dataset_chunks = load_csv_in_chunks(output_merged_path, chunk_size=chunk_size, index_col=0)
# if merged_dataset_chunks is not None:
#     # Step 2: Convert gene IDs to symbols
#     print("\n🧬 CONVERTING GENE IDS TO SYMBOLS")
#     print("-" * 35)

#     final_dataset = convert_gene_ids_to_symbols(merged_dataset_chunks, mg)

#     # Step 3: Save final dataset
#     print("\n💾 SAVING FINAL DATASET")
#     print("-" * 25)
    
#     print(f"Saving to: {final_path}")
#     try:
#         save_data_to_csv(final_dataset, final_path, index=True)
#         print(f"✅ Successfully saved!")
        
#         # Final summary
#         print(f"\n📊 FINAL DATASET SUMMARY")
#         print("-" * 25)
#         print(f"Shape: {final_dataset.shape}")
#         print(f"Binary labels: {final_dataset['condition'].value_counts().sort_index().to_dict()}")
#         print(f"Sample columns: {list(final_dataset.columns[:5])}...")
        
#         print("\n🎉 DATASET READY FOR ML TRAINING!")
        
#     except Exception as e:
#         print(f"❌ Error saving: {e}")
# else:
#     print("❌ Cannot proceed - failed to load merged dataset chunks for gene ID conversion.")