# Imports and modular functions

In [None]:
import os
import pandas as pd
import mygene
import csv
import random
import gc


def load_csv_in_chunks(file_path, chunk_size=8000, **kwargs):
    """
    Open a CSV file for chunked reading so it never has to be loaded whole.

    Args:
        file_path (str): Location of the CSV file on disk.
        chunk_size (int): Number of rows in each DataFrame chunk.
        **kwargs: Passed through unchanged to pd.read_csv()
                  (e.g., sep=',', header=0, index_col=None, usecols=None).

    Returns:
        The object returned by pd.read_csv(..., chunksize=chunk_size), which
        yields DataFrame chunks when iterated. None is returned (after
        printing a message) when the file does not exist or pd.read_csv raises.
    """
    file_is_missing = not os.path.exists(file_path)
    if file_is_missing:
        print(f"Error: File not found at {file_path}")
        return None

    print(f"Preparing to load {file_path} in chunks of size {chunk_size}...")

    # Single exit point: the reader stays None unless pd.read_csv succeeds.
    reader = None
    try:
        reader = pd.read_csv(file_path, chunksize=chunk_size, **kwargs)
    except Exception as e:
        print(f"Error loading CSV file {file_path}: {e}")
    return reader


def save_data_to_csv(data_to_save, output_file_path, delete_original_path=False, index=False, **kwargs):
    """
    Saves a pandas DataFrame or an iterable of DataFrame chunks to a CSV file.
    Optionally deletes the original file after successful saving.

    Args:
        data_to_save (pd.DataFrame or iterable): DataFrame, or iterable yielding DataFrames.
        output_file_path (str): Path to save the new CSV file.
        delete_original_path (str, optional): Path to the original file to delete
            once the data has been written. Defaults to False (nothing is deleted).
        index (bool, optional): Whether to write DataFrame index. Defaults to False.
        **kwargs: Additional keyword arguments to pass to df.to_csv(). 'mode'
            (default 'w') and 'header' (default True) apply to a single
            DataFrame or to the first chunk; later chunks are always appended
            without a header.

    Raises:
        Exception: Any error raised while writing or deleting is reported on
            stdout and then re-raised unchanged.
    """
    mode = kwargs.pop('mode', 'w')
    header = kwargs.pop('header', True)

    try:
        if isinstance(data_to_save, pd.DataFrame):
            print(f"Saving DataFrame to {output_file_path}...")
            data_to_save.to_csv(output_file_path, index=index, mode=mode, header=header, **kwargs)
        else:
            print(f"Saving data in chunks to {output_file_path}...")
            chunks_written = 0
            for i, chunk_df in enumerate(data_to_save):
                # The first chunk honours the caller's mode/header (so mode='a'
                # really appends to an existing file); every later chunk is
                # appended below it without repeating the header.
                chunk_mode = mode if i == 0 else 'a'
                chunk_header = header if i == 0 else False

                chunk_df.to_csv(output_file_path, index=index, mode=chunk_mode, header=chunk_header, **kwargs)
                chunks_written += 1

                if i == 0:
                    print(f"Written first chunk to {output_file_path}")
                else:
                    print(f"Appended chunk {i+1} to {output_file_path}")

            if chunks_written == 0:
                # Nothing was written, so reporting success or deleting the
                # source file would silently lose data.
                print(f"Warning: No data chunks to save; nothing was written to {output_file_path}.")
                if delete_original_path:
                    print(f"Warning: Original file not deleted: {delete_original_path}")
                return

        print(f"Successfully saved data to {output_file_path}")

        # Handle file deletion if requested
        if delete_original_path:
            # Compare normalised absolute paths so that two spellings of the
            # same location (e.g. 'x.csv' and './x.csv') are recognised as one
            # file and the freshly written output is never removed.
            same_file = os.path.abspath(delete_original_path) == os.path.abspath(output_file_path)
            if same_file:
                print(f"Warning: Original and output paths are the same. Original file not deleted.")
            elif os.path.exists(delete_original_path):
                os.remove(delete_original_path)
                print(f"Successfully deleted original file: {delete_original_path}")
            else:
                print(f"Warning: Original file not found at {delete_original_path}")

    except Exception as e:
        print(f"Error saving data to {output_file_path}: {e}")
        if os.path.exists(output_file_path):
            print(f"Partial data might have been written to {output_file_path}.")
        raise


def transpose_csv(input_csv, output_csv, feature_id_col='Name', exclude_cols=('Description',),
                           batch_size=1000):
    """
    Transpose a features-by-samples CSV into a samples-by-features CSV.

    The input is re-read once per batch of sample columns, so only
    `batch_size` columns are held in memory at a time. The output starts with
    a header row 'sample_id, <feature IDs...>' followed by one row per sample.

    Args:
        input_csv (str): Path to the input CSV file
        output_csv (str): Path where the transposed CSV will be saved
        feature_id_col (str): Name of the column containing feature IDs (e.g., gene names)
        exclude_cols (list): Column names to leave out of the transposition.
            None (or an empty sequence) excludes nothing; a single string is
            treated as one column name.
        batch_size (int): Number of samples to process in each batch

    Raises:
        ValueError: If batch_size is not positive, the input file has no
            header row, or feature_id_col is not in the header.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    if exclude_cols is None:
        exclude_cols = ()
    elif isinstance(exclude_cols, str):
        exclude_cols = (exclude_cols,)
    # Built once so each header cell is checked with a set lookup.
    skipped_columns = {feature_id_col, *exclude_cols}

    print(f"Reading header and identifying sample columns...")
    # newline='' lets the csv module handle line endings itself, which matters
    # for quoted fields containing newlines.
    with open(input_csv, 'r', newline='') as f:
        header = next(csv.reader(f), None)
    if header is None:
        raise ValueError(f"Input CSV {input_csv} is empty; there is no header row to transpose.")

    # Find indices for sample columns and the feature_id column
    sample_cols = [i for i, col in enumerate(header) if col not in skipped_columns]
    sample_ids = [header[i] for i in sample_cols]
    feature_id_idx = header.index(feature_id_col)
    print(f"    Found {len(sample_cols)} samples to transpose.")

    # First pass: collect gene IDs. Rows too short to contain the feature-ID
    # column (including blank rows) are skipped here AND in the value passes
    # below, so every value stays aligned with its gene ID in the header.
    print(f"Reading gene IDs from column '{feature_id_col}'...")
    gene_ids = []
    with open(input_csv, 'r', newline='') as f:
        reader = csv.reader(f)
        next(reader)  # Skip header
        for row in reader:
            if len(row) > feature_id_idx:
                gene_ids.append(row[feature_id_idx])

    print(f"    Found {len(gene_ids)} genes!")

    # Create output file with header
    with open(output_csv, 'w', newline='') as out_file:
        writer = csv.writer(out_file)
        writer.writerow(['sample_id'] + gene_ids)

    # Process samples in batches to balance memory usage and speed
    print(f"Writing data to {output_csv} in batches of {batch_size} samples...")
    total_batches = (len(sample_cols) - 1) // batch_size + 1
    for batch_start in range(0, len(sample_cols), batch_size):
        batch_end = min(batch_start + batch_size, len(sample_cols))
        current_batch_cols = sample_cols[batch_start:batch_end]
        current_batch_ids = sample_ids[batch_start:batch_end]

        print(f"    Processing batch {batch_start//batch_size + 1}/{total_batches} "
              f"(samples {batch_start+1}-{batch_end})...")

        # For each batch, read the input file once and extract all columns for this batch
        sample_values = [[] for _ in current_batch_cols]

        with open(input_csv, 'r', newline='') as f:
            reader = csv.reader(f)
            next(reader)  # Skip header
            for row in reader:
                if len(row) <= feature_id_idx:
                    continue  # same rows the gene-ID pass skipped
                row_length = len(row)
                for values, col_idx in zip(sample_values, current_batch_cols):
                    # A row that ends before this sample column contributes an empty cell.
                    values.append(row[col_idx] if col_idx < row_length else "")

        # Now write each sample in the batch to the output file
        with open(output_csv, 'a', newline='') as out_file:
            writer = csv.writer(out_file)
            for sample_id, values in zip(current_batch_ids, sample_values):
                writer.writerow([sample_id] + values)

        # Drop the batch buffers before the next batch is allocated
        sample_values = None
        gc.collect()

    print(f"🥶✅ Transposed {len(sample_ids)} samples × {len(gene_ids)} genes!")


def convert_gene_ids_to_symbols(dataset_chunk_iterator, mygene_client, gene_column_prefix="ENSG"):
    """
    Rename gene-ID columns to gene symbols using a single MyGene query.

    The gene columns are taken from the first chunk only, and the resulting
    column mapping is applied to every chunk. All renamed chunks are then
    concatenated, so the complete dataset is held in memory on return.

    Args:
        dataset_chunk_iterator: Iterable yielding DataFrame chunks
        mygene_client: Object exposing querymany() (e.g. a MyGene client instance)
        gene_column_prefix (str): Prefix to identify gene columns

    Returns:
        pd.DataFrame: Complete dataset with gene symbols as column names.
        Columns without a symbol keep their original name. An empty DataFrame
        is returned when the input yields no chunks. If the query raises, all
        original column names are kept.
    """
    import warnings

    print("Converting gene IDs to symbols (chunked processing)...")

    # iter() makes plain lists acceptable as well as true iterators.
    chunk_source = iter(dataset_chunk_iterator)

    # Get first chunk to determine gene columns and create mapping
    try:
        first_chunk = next(chunk_source)
        print(f"Processing first chunk: {first_chunk.shape}")
    except StopIteration:
        print("❌ Error: Dataset is empty")
        return pd.DataFrame()

    # Identify gene columns from first chunk
    gene_columns = [col for col in first_chunk.columns
                   if isinstance(col, str) and col.startswith(gene_column_prefix)]

    if not gene_columns:
        print("No gene columns found - returning original data")
        return pd.concat([first_chunk, *chunk_source], axis=0, ignore_index=False)

    print(f"Found {len(gene_columns)} gene columns")

    # Only IDs that actually received a symbol are stored; everything else
    # falls back to its original name when the column mapping is built.
    symbol_by_gene_id = {}
    symbols_found = 0

    try:
        print("Making single MyGene API call...")
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            results = mygene_client.querymany(
                gene_columns,
                scopes='ensembl.gene',
                fields='symbol',
                species='human',
                verbose=False,
                silent=True
            )

        # One query may come back with several hits. A hit without a symbol
        # must not erase a symbol found by another hit for the same ID; among
        # hits that do carry a symbol, the last one wins.
        for result in results:
            gene_id = result['query']
            symbol = result.get('symbol')
            if symbol:
                symbol_by_gene_id[gene_id] = symbol

        # Count distinct gene columns with a symbol, not hits, so the number
        # can never exceed the number of gene columns.
        symbols_found = sum(1 for gene_id in gene_columns if gene_id in symbol_by_gene_id)
        print(f"✅ Successfully converted {symbols_found}/{len(gene_columns)} genes to symbols")

    except Exception as e:
        print(f"⚠️ MyGene API error: {e}. Using original gene IDs")
        # Discard any partial results so the mapping and the count stay consistent.
        symbol_by_gene_id = {}
        symbols_found = 0

    # Create final column mapping ('condition' is never renamed).
    # NOTE(review): two gene IDs that resolve to the same symbol will produce
    # duplicate column names in the result — confirm downstream code copes.
    final_column_mapping = {
        col: col if col == 'condition' else symbol_by_gene_id.get(col, col)
        for col in first_chunk.columns
    }

    print("Applying gene symbol mapping to all chunks...")

    # Process first chunk
    processed_chunks = [first_chunk.rename(columns=final_column_mapping)]
    chunk_count = 1

    # Process remaining chunks with same mapping
    for chunk in chunk_source:
        processed_chunks.append(chunk.rename(columns=final_column_mapping))
        chunk_count += 1

        if chunk_count % 3 == 0:
            print(f"  ✓ Processed {chunk_count} chunks...")

    print(f"Concatenating {chunk_count} processed chunks...")
    final_dataset = pd.concat(processed_chunks, axis=0, ignore_index=False)

    print(f"✅ Final dataset shape: {final_dataset.shape}")
    print(f"   Symbols converted: {symbols_found}/{len(gene_columns)}")

    return final_dataset


def add_condition_labels_to_chunks(chunk_iterator, condition_label, dataset_name):
    """
    Stamp every chunk with a constant 'condition' column.

    Each chunk is modified in place (the column is assigned directly on the
    DataFrame that was yielded) and the same objects are collected.

    Args:
        chunk_iterator: Iterable yielding DataFrame chunks
        condition_label (int): Value written to the 'condition' column
            (0 for healthy, 1 for unhealthy)
        dataset_name (str): Used only in the progress messages

    Returns:
        list: The labeled DataFrame chunks, in the order they were yielded
    """
    print(f"Adding label '{condition_label}' to {dataset_name} dataset...")

    def _stamp(frame):
        # Direct column assignment: mutates the caller's DataFrame.
        frame['condition'] = condition_label
        return frame

    stamped = [_stamp(frame) for frame in chunk_iterator]

    print(f"✅ Completed {len(stamped)} {dataset_name} chunks")
    return stamped


def merge_labeled_datasets(healthy_chunks, unhealthy_chunks):
    """
    Stack the healthy chunks followed by the unhealthy chunks into one DataFrame.

    Row indices are kept as they are (no re-indexing). A short summary with
    the per-label counts from the 'condition' column is printed.

    Args:
        healthy_chunks (list): List of healthy DataFrame chunks
        unhealthy_chunks (list): List of unhealthy DataFrame chunks

    Returns:
        pd.DataFrame: Row-wise concatenation of all chunks
    """
    print("Merging datasets...")

    merged = pd.concat(healthy_chunks + unhealthy_chunks, axis=0, ignore_index=False)

    sample_total = len(merged)
    feature_total = len(merged.columns)
    print(f"✅ Merged dataset: {sample_total} samples, {feature_total} features")

    # Per-label tallies; the column is read after the first summary line,
    # and once for each label line, in the same order as the messages.
    healthy_total = (merged['condition'] == 0).sum()
    print(f"   Healthy (0): {healthy_total}")
    unhealthy_total = (merged['condition'] == 1).sum()
    print(f"   Unhealthy (1): {unhealthy_total}")

    return merged


def clean_duplicate_nans(chunk_iterator):
    """
    Remove duplicate rows, then rows containing nulls, from each chunk.

    Every chunk is cleaned in place and independently of the others, so
    duplicates are only detected within a single chunk. Chunks that end up
    empty are left out of the result.

    Returns a list of the cleaned, non-empty DataFrame chunks.
    """
    kept = []
    for chunk_no, frame in enumerate(chunk_iterator, start=1):
        rows_before = len(frame)
        frame.drop_duplicates(inplace=True)
        removed = rows_before - len(frame)
        if removed > 0:
            print(f"        Chunk {chunk_no}: Dropped {removed} duplicate data rows.")

        rows_before = len(frame)
        frame.dropna(inplace=True)
        removed = rows_before - len(frame)
        if removed > 0:
            print(f"        Chunk {chunk_no}: Dropped {removed} rows with null values.")

        if not frame.empty:
            kept.append(frame)
        gc.collect()

    print(f"    Finished dup/NaN cleaning. {len(kept)} non-empty chunks remaining.")
    return kept


def drop_column(chunk_iterator, column_name_to_drop):
    """
    Remove one named column from every chunk that has it.

    Chunks are modified in place. Whether the column was found is reported
    for the first chunk only. Chunks that are empty after the drop are left
    out of the result.

    Returns a list of the processed, non-empty DataFrame chunks.
    """
    kept = []
    print(f"    Attempting to drop column: '{column_name_to_drop}' from all chunks...")
    for position, frame in enumerate(chunk_iterator):
        is_first_chunk = position == 0
        column_present = column_name_to_drop in frame.columns

        if column_present:
            frame.drop(columns=[column_name_to_drop], inplace=True)

        # Outcome is only reported once, based on the first chunk.
        if is_first_chunk and column_present:
            print(f"        Dropped '{column_name_to_drop}' column (first occurrence in chunk {position + 1}).")
        elif is_first_chunk:
            print(f"        Warning: Column '{column_name_to_drop}' not found in the first chunk ({position + 1}).")

        if not frame.empty:
            kept.append(frame)
        gc.collect()

    print(f"    Finished column drop attempt. {len(kept)} chunks remaining.")
    return kept


def rename_index(chunk_iterator, new_index_name):
    """
    Set the index name of every chunk to `new_index_name`.

    The name is assigned in place on each chunk whose index is not already
    called that. The rename is reported for the first chunk only. Empty
    chunks are left out of the result.

    Returns a list of the processed, non-empty DataFrame chunks.
    """
    kept = []
    print(f"    Attempting to rename index to: '{new_index_name}' for all chunks...")
    for position, frame in enumerate(chunk_iterator):
        needs_rename = frame.index.name != new_index_name
        if needs_rename:
            frame.index.name = new_index_name
        if needs_rename and position == 0:
            print(f"        Renamed index to '{new_index_name}' (first occurrence in chunk {position + 1}).")

        if not frame.empty:
            kept.append(frame)
        gc.collect()

    print(f"    Finished index renaming. {len(kept)} chunks remaining.")
    return kept


def filter_rows(chunk_iterator):
    """
    Processes an iterator of DataFrame chunks to filter out rows where all data columns are zero.

    A row is kept when at least one of its cells compares unequal to 0.
    Chunks that are empty on arrival, or empty after filtering, are left out
    of the result. The input chunks themselves are not modified; each kept
    entry is the `.loc` selection of the surviving rows.

    Returns a list of filtered DataFrame chunks.
    """
    filtered_chunks_list = []
    print(f"    Attempting to filter rows with all zero counts from all chunks...")
    for i, chunk_df in enumerate(chunk_iterator):
        if not chunk_df.empty:
            original_rows = len(chunk_df)
            has_nonzero_cell = (chunk_df != 0).any(axis=1)
            chunk_df_filtered = chunk_df.loc[has_nonzero_cell]

            dropped_rows = original_rows - len(chunk_df_filtered)
            if dropped_rows > 0:
                print(f"        Chunk {i+1}: Dropped {dropped_rows} rows with all zero counts.")

            # A chunk whose rows were all zero is simply skipped. (No manual
            # `del` of the loop variables: the names are rebound on the next
            # iteration, and deleting one before reading it again raises
            # UnboundLocalError.)
            if not chunk_df_filtered.empty:
                filtered_chunks_list.append(chunk_df_filtered)
        gc.collect()
    print(f"    Finished zero-row filtering. {len(filtered_chunks_list)} non-empty chunks remaining.")
    return filtered_chunks_list

# Initial Preprocessing Unhealthy Dataset

# Initial Preprocessing Healthy Dataset
## Convert .gct to .csv

In [None]:
# Convert the tab-separated .gct file to a comma-separated CSV, streaming it
# chunk by chunk so the whole table is never held in memory.
dataset_file = 'data/GTEx_Analysis_2022-06-06_v10_RNASeQCv2.4.2_gene_reads.gct'
output_file = 'data/healthy_data.csv'
chunk_size = 8000

# skiprows=2 skips the first two lines of the file, so the third line is read
# as the column header (presumably the .gct preamble lines — confirm against the file).
gct_chunk_iterator = load_csv_in_chunks(
    file_path=dataset_file,
    chunk_size=chunk_size,
    sep='\t',
    skiprows=2
)

# load_csv_in_chunks signals failure with None, so test for exactly that.
if gct_chunk_iterator is not None:
    # The source .gct file is deleted once the CSV has been written.
    save_data_to_csv(
        data_to_save=gct_chunk_iterator,
        output_file_path=output_file,
        delete_original_path=dataset_file,
        index=False
    )
else:
    print(f"Failed to load {dataset_file} using load_csv_in_chunks.")

# Clean up: every name below is assigned unconditionally above, so a plain
# `del` is safe and removes all four of them.
del gct_chunk_iterator, dataset_file, output_file, chunk_size

gc.collect()

## Transpose healthy dataset

Run this cell separately. Once the transpose has finished, restart the Python kernel before running the next cell; otherwise the kernel will most likely crash if that cell performs a memory-intensive task.

In [None]:
# Settings for the transpose step: the column that carries the feature IDs,
# the columns left out of the result, and how many sample columns are
# handled per pass over the input file.
batch_size = 3000
exclude_cols = ['Description']
feature_id_col = 'Name'

# Source (features as rows) and destination (samples as rows).
healthy_data_input = 'data/healthy_data.csv'
healthy_data_output = 'data/healthy_data_transposed.csv'

transpose_csv(
    healthy_data_input,
    healthy_data_output,
    feature_id_col=feature_id_col,
    exclude_cols=exclude_cols,
    batch_size=batch_size,
)

## Load and clean dataset

In [None]:
# Full cleaning pipeline for the transposed healthy dataset in one cell:
# de-duplicate / drop nulls, drop a stray column, name the index, remove
# all-zero rows, then write the result.
healthy_dataset_file = 'data/healthy_data_transposed.csv'
healthy_output_file = 'data/healthy_data_preprocessed.csv'
config_chunk_size = 1000

chunk_iterator = load_csv_in_chunks(file_path=healthy_dataset_file, chunk_size=config_chunk_size, index_col=0)

# load_csv_in_chunks returns None when the file is missing or unreadable;
# without this guard the first cleaning step would fail with a TypeError.
if chunk_iterator is None:
    print(f"Failed to load {healthy_dataset_file} using load_csv_in_chunks.")
else:
    print(f"Starting preprocessing for Healthy Dataset: {healthy_dataset_file}")

    # Each step below returns a list of chunks, which the next step iterates.
    chunk_iterator = clean_duplicate_nans(chunk_iterator)
    gc.collect()

    print("Dropping column 'Unnamed: 0'...")
    chunk_iterator = drop_column(chunk_iterator, column_name_to_drop='Unnamed: 0')
    gc.collect()

    # NOTE(review): transpose_csv writes 'sample_id' as the first header cell,
    # so the index loaded here identifies samples — confirm that 'gene_id' is
    # the intended label.
    chunk_iterator = rename_index(chunk_iterator, new_index_name='gene_id')
    gc.collect()

    chunk_iterator = filter_rows(chunk_iterator)
    gc.collect()

    save_data_to_csv(
        data_to_save=chunk_iterator,
        output_file_path=healthy_output_file,
        index=True
    )

# Every name is assigned unconditionally above, so a plain `del` is safe.
del healthy_dataset_file, healthy_output_file, config_chunk_size, chunk_iterator
gc.collect()

In [None]:
# Step 1 of the split pipeline: remove duplicate rows and rows with nulls
# from the transposed healthy dataset.
healthy_dataset_file = 'data/healthy_data_transposed.csv'
healthy_output_file = 'data/healthy_data_preprocessed_1.csv'
config_chunk_size = 500

chunk_iterator = load_csv_in_chunks(file_path=healthy_dataset_file, chunk_size=config_chunk_size, index_col=0)

# load_csv_in_chunks returns None when the file is missing or unreadable.
if chunk_iterator is None:
    print(f"Failed to load {healthy_dataset_file} using load_csv_in_chunks.")
else:
    print(f"Starting preprocessing for Healthy Dataset: {healthy_dataset_file}")

    # Only duplicates and nulls are removed in this step; no row filtering happens here.
    print("Cleaning duplicates and NaNs...")
    chunk_iterator = clean_duplicate_nans(chunk_iterator)

    print("saving cleaned data to CSV...")
    save_data_to_csv(
        data_to_save=chunk_iterator,
        output_file_path=healthy_output_file,
        index=True
    )

In [None]:
# Step 2 of the split pipeline: drop the 'Unnamed: 0' column if it is present.
healthy_dataset_file = 'data/healthy_data_preprocessed_1.csv'
healthy_output_file = 'data/healthy_data_preprocessed_2.csv'
config_chunk_size = 1000

chunk_iterator = load_csv_in_chunks(file_path=healthy_dataset_file, chunk_size=config_chunk_size, index_col=0)

# load_csv_in_chunks returns None when the file is missing or unreadable.
if chunk_iterator is None:
    print(f"Failed to load {healthy_dataset_file} using load_csv_in_chunks.")
else:
    print(f"Starting preprocessing for Healthy Dataset: {healthy_dataset_file}")

    chunk_iterator = drop_column(chunk_iterator, column_name_to_drop='Unnamed: 0')

    save_data_to_csv(
        data_to_save=chunk_iterator,
        output_file_path=healthy_output_file,
        index=True
    )

In [None]:
# Step 3 of the split pipeline: set the index name of every chunk.
healthy_dataset_file = 'data/healthy_data_preprocessed_2.csv'
healthy_output_file = 'data/healthy_data_preprocessed_3.csv'
config_chunk_size = 1000

chunk_iterator = load_csv_in_chunks(file_path=healthy_dataset_file, chunk_size=config_chunk_size, index_col=0)

# load_csv_in_chunks returns None when the file is missing or unreadable.
if chunk_iterator is None:
    print(f"Failed to load {healthy_dataset_file} using load_csv_in_chunks.")
else:
    print(f"Starting preprocessing for Healthy Dataset: {healthy_dataset_file}")

    # NOTE(review): transpose_csv writes 'sample_id' as the first header cell,
    # so the index loaded here identifies samples — confirm that 'gene_id' is
    # the intended label.
    chunk_iterator = rename_index(chunk_iterator, new_index_name='gene_id')

    save_data_to_csv(
        data_to_save=chunk_iterator,
        output_file_path=healthy_output_file,
        index=True
    )

In [None]:
# Step 4 of the split pipeline: remove rows whose data columns are all zero.
healthy_dataset_file = 'data/healthy_data_preprocessed_3.csv'
healthy_output_file = 'data/healthy_data_preprocessed_4.csv'
config_chunk_size = 1000

chunk_iterator = load_csv_in_chunks(file_path=healthy_dataset_file, chunk_size=config_chunk_size, index_col=0)

# load_csv_in_chunks returns None when the file is missing or unreadable.
if chunk_iterator is None:
    print(f"Failed to load {healthy_dataset_file} using load_csv_in_chunks.")
else:
    print(f"Starting preprocessing for Healthy Dataset: {healthy_dataset_file}")

    chunk_iterator = filter_rows(chunk_iterator)

    save_data_to_csv(
        data_to_save=chunk_iterator,
        output_file_path=healthy_output_file,
        index=True
    )

In [None]:
# 